"""Music Assistant WebRTC Gateway.

This module provides WebRTC-based remote access to Music Assistant instances.
It connects to a signaling server and handles incoming WebRTC connections,
bridging them to the local WebSocket API.
"""
7
from __future__ import annotations

import asyncio
import contextlib
import json
import logging
from collections.abc import Awaitable, Callable, Coroutine
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from urllib.parse import urlsplit

import aiohttp
from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection, RTCSessionDescription
from aiortc.sdp import candidate_from_sdp

from music_assistant.constants import MASS_LOGGER_NAME, VERBOSE_LOG_LEVEL
from music_assistant.helpers.webrtc_certificate import create_peer_connection_with_certificate

if TYPE_CHECKING:
    from aiortc.rtcdtlstransport import RTCCertificate
27
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access")

# Reduce verbose logging from the WebRTC stack: cap aioice and aiortc at WARNING.
for _chatty_logger_name in ("aioice", "aiortc"):
    logging.getLogger(_chatty_logger_name).setLevel(logging.WARNING)
33
34
@dataclass
class WebRTCSession:
    """State kept for one remote client connected over WebRTC.

    Besides the peer connection, a session can carry two bridged data channels:
    the main API channel (ma-api), relayed to the local Music Assistant WebSocket
    API, and the sendspin channel, relayed to the internal sendspin server.
    """

    session_id: str
    peer_connection: RTCPeerConnection

    # ---- Main API channel (ma-api) <-> local MA WebSocket API ----
    data_channel: Any = None
    local_ws: Any = None
    # Messages received on the data channel, waiting to be sent to local_ws.
    message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue)
    forward_to_local_task: asyncio.Task[None] | None = None
    forward_from_local_task: asyncio.Task[None] | None = None

    # ---- Sendspin channel <-> internal sendspin server ----
    sendspin_channel: Any = None
    sendspin_ws: Any = None
    # Text or binary sendspin frames waiting to be sent to sendspin_ws.
    sendspin_queue: asyncio.Queue[str | bytes] = field(default_factory=asyncio.Queue)
    # Player id taken from the first sendspin auth message, when present.
    sendspin_player_id: str | None = None
    sendspin_to_local_task: asyncio.Task[None] | None = None
    sendspin_from_local_task: asyncio.Task[None] | None = None
54
55
class WebRTCGateway:
    """WebRTC Gateway for Music Assistant Remote Access.

    This gateway:
    1. Connects to a signaling server
    2. Registers with a unique Remote ID
    3. Handles incoming WebRTC connections from remote PWA clients
    4. Bridges WebRTC DataChannel messages to the local WebSocket API
       (and the "sendspin" DataChannel to the internal sendspin server)
    """

    # Close code 4000 means this connection was replaced by a new one from the same server.
    # In that case we should not reconnect, as another connection is now active.
    CLOSE_CODE_REPLACED = 4000

    # Default ICE servers (public STUN only - used as fallback)
    DEFAULT_ICE_SERVERS: list[dict[str, Any]] = [
        {"urls": "stun:stun.home-assistant.io:3478"},
        {"urls": "stun:stun.l.google.com:19302"},
        {"urls": "stun:stun1.l.google.com:19302"},
        {"urls": "stun:stun.cloudflare.com:3478"},
    ]

    def __init__(
        self,
        http_session: aiohttp.ClientSession,
        remote_id: str,
        certificate: RTCCertificate,
        signaling_url: str = "wss://signaling.music-assistant.io/ws",
        local_ws_url: str = "ws://localhost:8095/ws",
        sendspin_url: str = "ws://localhost:8927/sendspin",
        ice_servers: list[dict[str, Any]] | None = None,
        ice_servers_callback: Callable[[], Awaitable[list[dict[str, Any]]]] | None = None,
        set_sendspin_player_callback: Callable[[str, str], None] | None = None,
    ) -> None:
        """
        Initialize the WebRTC Gateway.

        :param http_session: Shared aiohttp ClientSession for HTTP/WebSocket connections.
        :param remote_id: Remote ID for this server instance.
        :param certificate: Persistent RTCCertificate for DTLS, enabling client-side pinning.
        :param signaling_url: WebSocket URL of the signaling server.
        :param local_ws_url: Local WebSocket URL to bridge to.
        :param sendspin_url: Internal Sendspin WebSocket URL to bridge to.
        :param ice_servers: List of ICE server configurations (used at registration time).
        :param ice_servers_callback: Optional callback to fetch fresh ICE servers for each session.
        :param set_sendspin_player_callback: Callback to set sendspin player for a session.
        """
        self.http_session = http_session
        self.signaling_url = signaling_url
        self.local_ws_url = local_ws_url
        self.sendspin_url = sendspin_url
        self._remote_id = remote_id
        self._certificate = certificate
        self.logger = LOGGER
        self._ice_servers_callback = ice_servers_callback
        self._set_sendspin_player_callback = set_sendspin_player_callback

        # Static ICE servers used at registration time (relayed to clients via signaling server)
        self.ice_servers = ice_servers or self.DEFAULT_ICE_SERVERS

        self.sessions: dict[str, WebRTCSession] = {}
        self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None
        self._running = False
        self._reconnect_delay = 10  # Wait 10 seconds before reconnecting
        self._max_reconnect_delay = 300  # Max 5 minutes between reconnects
        self._current_reconnect_delay = 10
        self._run_task: asyncio.Task[None] | None = None
        self._is_connected = False
        self._connecting = False
        # Strong references to fire-and-forget tasks: the event loop only keeps weak
        # references to tasks, so an unreferenced task can vanish mid-execution.
        self._background_tasks: set[asyncio.Task[None]] = set()

    @property
    def is_running(self) -> bool:
        """Return whether the gateway is running."""
        return self._running

    @property
    def is_connected(self) -> bool:
        """Return whether the gateway is connected to the signaling server."""
        return self._is_connected

    async def _get_fresh_ice_servers(self) -> list[dict[str, Any]]:
        """Get fresh ICE servers for a new WebRTC session.

        If an ice_servers_callback was provided, it will be called to get fresh
        TURN credentials. Otherwise (or if it fails / returns nothing), returns the
        static ice_servers.

        :return: List of ICE server configurations with fresh credentials.
        """
        if self._ice_servers_callback:
            try:
                fresh_servers = await self._ice_servers_callback()
                if fresh_servers:
                    return fresh_servers
            except Exception:
                self.logger.exception("Failed to fetch fresh ICE servers, using cached servers")
        return self.ice_servers

    def _create_background_task(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
        """Schedule *coro* as a task and keep a strong reference until it finishes.

        :param coro: The coroutine to run.
        :return: The created task.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def start(self) -> None:
        """Start the WebRTC Gateway (spawns the signaling/reconnect loop)."""
        if self._running:
            self.logger.warning("WebRTC Gateway already running, skipping start")
            return
        self.logger.info("Starting WebRTC Gateway")
        self.logger.debug("Signaling URL: %s", self.signaling_url)
        self.logger.debug("Local WS URL: %s", self.local_ws_url)
        self._running = True
        self._run_task = asyncio.create_task(self._run())
        self.logger.debug("WebRTC Gateway start task created")

    async def stop(self) -> None:
        """Stop the WebRTC Gateway and close every active session."""
        self.logger.info("Stopping WebRTC Gateway")
        self._running = False

        # Shut down signaling first so no new sessions can be created while the
        # existing ones are being torn down below.
        if self._signaling_ws and not self._signaling_ws.closed:
            try:
                await self._signaling_ws.close()
            except Exception:
                self.logger.debug("Error closing signaling WebSocket", exc_info=True)

        # Cancel run task and wait for it to finish
        if self._run_task and not self._run_task.done():
            self._run_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._run_task

        # Wait briefly for any in-progress connection to notice _running=False
        if self._connecting:
            await asyncio.sleep(0.1)

        # Close all sessions
        for session_id in list(self.sessions):
            await self._close_session(session_id)

        self._signaling_ws = None
        self._connecting = False
        self._is_connected = False

    async def _run(self) -> None:
        """Run the main loop with reconnection logic (exponential backoff)."""
        self.logger.debug("WebRTC Gateway _run() loop starting")
        while self._running:
            should_reconnect = True
            try:
                should_reconnect = await self._connect_to_signaling()
                # Connection closed gracefully or with error
                self._is_connected = False
                if self._running and should_reconnect:
                    self.logger.warning(
                        "Signaling server connection lost. Reconnecting in %ss...",
                        self._current_reconnect_delay,
                    )
            except Exception:
                self._is_connected = False
                self.logger.exception("Signaling connection error")
                if self._running:
                    self.logger.info(
                        "Reconnecting to signaling server in %ss",
                        self._current_reconnect_delay,
                    )

            if not self._running:
                # stop() was requested; this is not a "replaced" connection.
                break
            if not should_reconnect:
                # Connection was replaced by another instance, stop the run loop
                self.logger.info("Connection replaced, stopping reconnection attempts")
                self._running = False
                break

            await asyncio.sleep(self._current_reconnect_delay)
            # Exponential backoff with max limit
            self._current_reconnect_delay = min(
                self._current_reconnect_delay * 2, self._max_reconnect_delay
            )

    async def _connect_to_signaling(self) -> bool:
        """Connect to the signaling server and process messages until it disconnects.

        :return: True if reconnection should be attempted, False if the connection was
            replaced, the gateway was stopped while connecting, or a connect attempt
            was already in progress.
        """
        if self._connecting:
            self.logger.warning("Already connecting to signaling server, skipping")
            return False  # Don't trigger another reconnect cycle
        self._connecting = True
        close_code: int | None = None
        self.logger.info("Connecting to signaling server: %s", self.signaling_url)
        try:
            self._signaling_ws = await self.http_session.ws_connect(
                self.signaling_url,
                heartbeat=35.0,  # Send ping every 35s (slightly above server's 30s interval)
            )
            # Check if we were stopped while connecting
            if not self._running:
                self.logger.debug("Gateway stopped during connection, closing WebSocket")
                await self._signaling_ws.close()
                return False  # the finally block resets the connection state
            self.logger.debug("WebSocket connection established, id=%s", id(self._signaling_ws))
            self.logger.debug("Sending registration")
            await self._register()
            self._current_reconnect_delay = self._reconnect_delay
            self.logger.debug("Registration sent, waiting for confirmation...")

            # Run message loop and get close code
            close_code = await self._signaling_message_loop(self._signaling_ws)

            # Get close code from WebSocket if not already set from CLOSE message
            if close_code is None:
                close_code = self._signaling_ws.close_code
            ws_exception = self._signaling_ws.exception()
            self.logger.debug(
                "Message loop exited - WebSocket closed: %s, close_code: %s, exception: %s",
                self._signaling_ws.closed,
                close_code,
                ws_exception,
            )
        except TimeoutError:
            self.logger.error("Timeout connecting to signaling server")
        except aiohttp.ClientError as err:
            self.logger.error("Failed to connect to signaling server: %s", err)
        except Exception:
            self.logger.exception("Unexpected error in signaling connection")
        finally:
            self._is_connected = False
            self._connecting = False
            self._signaling_ws = None

        # Check if this connection was replaced by another one
        if close_code == self.CLOSE_CODE_REPLACED:
            self.logger.info("Connection was replaced by another instance - not reconnecting")
            return False

        return True

    async def _signaling_message_loop(self, ws: aiohttp.ClientWebSocketResponse) -> int | None:
        """Process messages from the signaling WebSocket.

        NOTE: aiohttp's async iterator typically ends the iteration on close frames
        itself, so the CLOSE/CLOSED branches may rarely run; the caller falls back to
        ``ws.close_code`` in that case.

        :param ws: The WebSocket connection to process messages from.
        :return: Close code if connection was closed with a code, None otherwise.
        """
        close_code: int | None = None
        self.logger.debug("Entering message loop")
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.TEXT:
                try:
                    await self._handle_signaling_message(json.loads(msg.data))
                except Exception:
                    self.logger.exception("Error handling signaling message")
            elif msg.type == aiohttp.WSMsgType.PING:
                self.logger.log(VERBOSE_LOG_LEVEL, "Received WebSocket PING")
            elif msg.type == aiohttp.WSMsgType.PONG:
                self.logger.log(VERBOSE_LOG_LEVEL, "Received WebSocket PONG")
            elif msg.type == aiohttp.WSMsgType.CLOSE:
                close_code = msg.data
                self.logger.warning(
                    "Signaling server sent close frame: code=%s, reason=%s",
                    msg.data,
                    msg.extra,
                )
                break
            elif msg.type == aiohttp.WSMsgType.CLOSED:
                self.logger.warning("Signaling server closed connection")
                break
            elif msg.type == aiohttp.WSMsgType.ERROR:
                self.logger.error("WebSocket error: %s", ws.exception())
                break
            else:
                self.logger.warning("Unexpected WebSocket message type: %s", msg.type)
        return close_code

    async def _register(self) -> None:
        """Register with the signaling server (remote id + static ICE servers)."""
        if self._signaling_ws:
            await self._signaling_ws.send_json(
                {
                    "type": "register-server",
                    "remoteId": self._remote_id,
                    "iceServers": self.ice_servers,
                }
            )

    async def _handle_signaling_message(self, message: dict[str, Any]) -> None:
        """Handle incoming signaling messages.

        :param message: The decoded signaling message.
        """
        msg_type = message.get("type")

        if msg_type in ("ping", "pong"):
            # Ignore JSON-level ping/pong messages - we use WebSocket protocol-level heartbeat
            # The signaling server still sends these for backward compatibility with older clients
            pass
        elif msg_type == "registered":
            self._is_connected = True
            self.logger.info("Registered with signaling server")
        elif msg_type == "error":
            error_msg = message.get("error") or message.get("message", "Unknown error")
            self.logger.error("Signaling server error: %s", error_msg)
        elif msg_type == "client-connected":
            session_id = message.get("sessionId")
            if session_id:
                # Fetch once: the same ICE servers configure our peer connection and are
                # sent to the client in session-ready.
                fresh_ice_servers = await self._get_fresh_ice_servers()
                await self._create_session(session_id, fresh_ice_servers)
                if self._signaling_ws:
                    await self._signaling_ws.send_json(
                        {
                            "type": "session-ready",
                            "sessionId": session_id,
                            "iceServers": fresh_ice_servers,
                        }
                    )
        elif msg_type == "client-disconnected":
            session_id = message.get("sessionId")
            if session_id:
                await self._close_session(session_id)
        elif msg_type == "offer":
            session_id = message.get("sessionId")
            offer_data = message.get("data")
            if session_id and offer_data:
                await self._handle_offer(session_id, offer_data)
        elif msg_type == "ice-candidate":
            session_id = message.get("sessionId")
            candidate_data = message.get("data")
            if session_id and candidate_data:
                await self._handle_ice_candidate(session_id, candidate_data)
        else:
            self.logger.debug("Ignoring unknown signaling message type: %s", msg_type)

    async def _create_session(
        self, session_id: str, ice_servers: list[dict[str, Any]] | None = None
    ) -> None:
        """Create a new WebRTC session.

        An existing session with the same id is closed first so its peer
        connection is not leaked.

        :param session_id: The session ID.
        :param ice_servers: ICE servers for the peer connection; fetched fresh when None.
        """
        if session_id in self.sessions:
            self.logger.debug("Session %s already exists, replacing it", session_id)
            await self._close_session(session_id)
        if ice_servers is None:
            ice_servers = await self._get_fresh_ice_servers()
        config = RTCConfiguration(iceServers=[RTCIceServer(**server) for server in ice_servers])
        pc = create_peer_connection_with_certificate(self._certificate, configuration=config)
        session = WebRTCSession(session_id=session_id, peer_connection=pc)
        self.sessions[session_id] = session

        @pc.on("datachannel")
        def on_datachannel(channel: Any) -> None:
            if channel.label == "sendspin":
                session.sendspin_channel = channel
                self._create_background_task(self._setup_sendspin_channel(session))
            else:
                session.data_channel = channel
                self._create_background_task(self._setup_data_channel(session))

        @pc.on("icecandidate")
        async def on_icecandidate(candidate: Any) -> None:
            if candidate and self._signaling_ws:
                await self._signaling_ws.send_json(
                    {
                        "type": "ice-candidate",
                        "sessionId": session_id,
                        "data": {
                            "candidate": candidate.candidate,
                            "sdpMid": candidate.sdpMid,
                            "sdpMLineIndex": candidate.sdpMLineIndex,
                        },
                    }
                )

        @pc.on("connectionstatechange")
        async def on_connectionstatechange() -> None:
            if pc.connectionState == "failed":
                await self._close_if_current(session)

    def _is_session_active(self, session: WebRTCSession) -> bool:
        """Return True if *session* is still registered and its peer connection is usable.

        The identity check guards against acting on a session that was replaced by a
        newer one with the same id.
        """
        return self.sessions.get(
            session.session_id
        ) is session and session.peer_connection.connectionState not in ("closed", "failed")

    async def _close_if_current(self, session: WebRTCSession) -> None:
        """Close *session* only if it is still the registered session for its id."""
        if self.sessions.get(session.session_id) is session:
            await self._close_session(session.session_id)

    async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None:
        """Handle incoming WebRTC offer and send back the answer via signaling.

        :param session_id: The session ID.
        :param offer: The offer data (expects "sdp" and "type" keys).
        """
        session = self.sessions.get(session_id)
        if session is None or not self._is_session_active(session):
            return
        pc = session.peer_connection

        sdp = offer.get("sdp")
        sdp_type = offer.get("type")
        if not sdp or not sdp_type:
            self.logger.error("Invalid offer data: missing sdp or type")
            return

        try:
            await pc.setRemoteDescription(RTCSessionDescription(sdp=str(sdp), type=str(sdp_type)))
            if not self._is_session_active(session):
                return

            answer = await pc.createAnswer()
            if not self._is_session_active(session):
                return

            await pc.setLocalDescription(answer)

            # Wait for ICE gathering to complete before sending the answer
            # aiortc doesn't support trickle ICE, candidates are embedded in SDP after gathering
            loop = asyncio.get_running_loop()
            gather_deadline = loop.time() + 30
            while pc.iceGatheringState != "complete":
                if not self._is_session_active(session):
                    return
                if loop.time() > gather_deadline:
                    self.logger.warning("Session %s ICE gathering timeout", session_id)
                    break
                await asyncio.sleep(0.1)

            if not self._is_session_active(session):
                return

            if self._signaling_ws:
                await self._signaling_ws.send_json(
                    {
                        "type": "answer",
                        "sessionId": session_id,
                        "data": {
                            "sdp": pc.localDescription.sdp,
                            "type": pc.localDescription.type,
                        },
                    }
                )
        except Exception:
            self.logger.exception("Error handling offer for session %s", session_id)
            # Clean up the session on error (but never a newer session with the same id)
            await self._close_if_current(session)

    async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None:
        """Handle incoming ICE candidate from the remote client.

        :param session_id: The session ID.
        :param candidate: The ICE candidate data ("candidate", "sdpMid", "sdpMLineIndex").
        """
        session = self.sessions.get(session_id)
        if session is None or not candidate or not self._is_session_active(session):
            return

        candidate_str = candidate.get("candidate")
        if not candidate_str:
            # Empty candidate (end-of-candidates) - nothing to add
            return
        sdp_mid = candidate.get("sdpMid")
        sdp_mline_index = candidate.get("sdpMLineIndex")

        try:
            # Browser sends "candidate:..." format; aiortc parses the part after the prefix
            ice_candidate = candidate_from_sdp(candidate_str.removeprefix("candidate:"))
            ice_candidate.sdpMid = str(sdp_mid) if sdp_mid else None
            ice_candidate.sdpMLineIndex = (
                int(sdp_mline_index) if sdp_mline_index is not None else None
            )

            if not self._is_session_active(session):
                return

            await session.peer_connection.addIceCandidate(ice_candidate)
        except Exception:
            self.logger.exception("Failed to add ICE candidate for session %s", session_id)

    async def _setup_data_channel(self, session: WebRTCSession) -> None:
        """Set up data channel and bridge to local WebSocket.

        Channel handlers are registered before connecting to the local WebSocket so
        messages sent by the client right away are queued rather than lost.

        :param session: The WebRTC session.
        """
        channel = session.data_channel
        if not channel:
            return
        loop = asyncio.get_running_loop()

        @channel.on("message")  # type: ignore[untyped-decorator]
        def on_message(message: str) -> None:
            # Queue while the forwarder is not created yet (None) or still running;
            # drop once it has finished (session shutting down).
            task = session.forward_to_local_task
            if task is None or not task.done():
                loop.call_soon_threadsafe(session.message_queue.put_nowait, message)

        @channel.on("close")  # type: ignore[untyped-decorator]
        def on_close() -> None:
            # run_coroutine_threadsafe is safe whichever thread invokes this callback
            asyncio.run_coroutine_threadsafe(self._close_if_current(session), loop)

        try:
            # Include session_id in URL so server can track WebRTC sessions
            ws_url = f"{self.local_ws_url}?webrtc_session_id={session.session_id}"
            local_ws = await self.http_session.ws_connect(ws_url)
        except Exception:
            self.logger.exception("Failed to connect to local WebSocket")
            # Without the API bridge the session is unusable; tear it down so the
            # client notices and queued messages are released.
            await self._close_if_current(session)
            return

        if not self._is_session_active(session) or channel.readyState in ("closing", "closed"):
            # Session/channel went away while we were connecting - don't leak the socket
            await local_ws.close()
            return

        session.local_ws = local_ws
        # Store task references for proper cleanup
        session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session))
        session.forward_from_local_task = asyncio.create_task(self._forward_from_local(session))

    async def _forward_to_local(self, session: WebRTCSession) -> None:
        """Forward messages from WebRTC DataChannel to local WebSocket.

        Messages of type "http-proxy-request" are handled locally instead of forwarded.

        :param session: The WebRTC session.
        """
        try:
            while session.local_ws and not session.local_ws.closed:
                message = await session.message_queue.get()

                # Check if this is an HTTP proxy request
                try:
                    msg_data = json.loads(message)
                    if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request":
                        await self._handle_http_proxy_request(session, msg_data)
                        continue
                except ValueError:  # includes json.JSONDecodeError
                    pass

                # Regular WebSocket message
                if session.local_ws and not session.local_ws.closed:
                    await session.local_ws.send_str(message)
        except asyncio.CancelledError:
            # Task was cancelled during cleanup, this is expected
            self.logger.debug("Forward to local task cancelled for session %s", session.session_id)
            raise
        except Exception:
            self.logger.exception("Error forwarding to local WebSocket")

    async def _forward_from_local(self, session: WebRTCSession) -> None:
        """Forward text messages from local WebSocket to WebRTC DataChannel.

        :param session: The WebRTC session.
        """
        try:
            async for msg in session.local_ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if session.data_channel and session.data_channel.readyState == "open":
                        session.data_channel.send(msg.data)
                elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
                    break
        except asyncio.CancelledError:
            # Task was cancelled during cleanup, this is expected
            self.logger.debug(
                "Forward from local task cancelled for session %s", session.session_id
            )
            raise
        except Exception:
            self.logger.exception("Error forwarding from local WebSocket")

    def _build_local_http_url(self, path: str) -> str:
        """Build the local HTTP URL for a proxied request path.

        Scheme and host:port are derived from ``local_ws_url`` (``wss`` maps to
        ``https``, anything else to ``http``). The remote-supplied path is forced to
        start with "/" so it can never be read as part of the authority (e.g.
        "@other-host/..."), which would send the request to another host.

        :param path: Request path (plus optional query) sent by the remote client.
        :return: Absolute URL on the local server.
        """
        parts = urlsplit(self.local_ws_url)
        scheme = "https" if parts.scheme == "wss" else "http"
        path = str(path or "/")
        if not path.startswith("/"):
            path = f"/{path}"
        return f"{scheme}://{parts.netloc}{path}"

    async def _handle_http_proxy_request(
        self, session: WebRTCSession, request_data: dict[str, Any]
    ) -> None:
        """Handle HTTP proxy request from remote client.

        The response (or a 500 error response) is sent back over the data channel
        with the body hex-encoded.

        :param session: The WebRTC session.
        :param request_data: The HTTP proxy request data.
        """
        request_id = request_data.get("id")
        method = request_data.get("method", "GET")
        path = request_data.get("path", "/")
        headers = request_data.get("headers", {})

        try:
            local_http_url = self._build_local_http_url(path)
            self.logger.debug("HTTP proxy request: %s %s", method, local_http_url)

            # Use shared HTTP session for this request
            async with self.http_session.request(
                method, local_http_url, headers=headers
            ) as response:
                body = await response.read()

                response_data = {
                    "type": "http-proxy-response",
                    "id": request_id,
                    "status": response.status,
                    "headers": dict(response.headers),
                    "body": body.hex(),  # Send as hex string to avoid encoding issues
                }

                # Send response back through data channel
                if session.data_channel and session.data_channel.readyState == "open":
                    session.data_channel.send(json.dumps(response_data))

        except Exception as err:
            self.logger.exception("Error handling HTTP proxy request")
            error_response = {
                "type": "http-proxy-response",
                "id": request_id,
                "status": 500,
                "headers": {"Content-Type": "text/plain"},
                "body": str(err).encode().hex(),
            }
            if session.data_channel and session.data_channel.readyState == "open":
                session.data_channel.send(json.dumps(error_response))

    @staticmethod
    async def _cancel_task(task: asyncio.Task[None] | None) -> None:
        """Cancel *task* (if still running) and wait for it to finish."""
        if task is None or task.done():
            return
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task

    async def _close_session(self, session_id: str) -> None:
        """Close a WebRTC session and release all of its resources.

        Every cleanup step is best-effort: a failure in one step is logged and the
        remaining resources are still released. Calling this for an unknown id is a no-op.

        :param session_id: The session ID.
        """
        session = self.sessions.pop(session_id, None)
        if not session:
            return

        # Cancel forwarding tasks first to prevent race conditions
        for task in (
            session.forward_to_local_task,
            session.forward_from_local_task,
            session.sendspin_to_local_task,
            session.sendspin_from_local_task,
        ):
            await self._cancel_task(task)

        # Close local WebSocket connections
        for ws in (session.local_ws, session.sendspin_ws):
            if ws is not None and not ws.closed:
                try:
                    await ws.close()
                except Exception:
                    self.logger.debug(
                        "Error closing local WebSocket for session %s", session_id, exc_info=True
                    )

        # Close data channels
        for channel in (session.data_channel, session.sendspin_channel):
            if channel is not None:
                try:
                    channel.close()
                except Exception:
                    self.logger.debug(
                        "Error closing data channel for session %s", session_id, exc_info=True
                    )

        try:
            await session.peer_connection.close()
        except Exception:
            self.logger.debug(
                "Error closing peer connection for session %s", session_id, exc_info=True
            )

    async def _setup_sendspin_channel(self, session: WebRTCSession) -> None:
        """Set up sendspin data channel and bridge to internal sendspin server.

        :param session: The WebRTC session.
        """
        channel = session.sendspin_channel
        if not channel:
            return
        loop = asyncio.get_running_loop()

        @channel.on("message")  # type: ignore[untyped-decorator]
        def on_message(message: str | bytes) -> None:
            # Queue if task not yet created (None) or still running.
            # Only drop when task exists and is done (shutdown).
            task = session.sendspin_to_local_task
            if task is None or not task.done():
                loop.call_soon_threadsafe(session.sendspin_queue.put_nowait, message)

        @channel.on("close")  # type: ignore[untyped-decorator]
        def on_close() -> None:
            if session.sendspin_ws and not session.sendspin_ws.closed:
                asyncio.run_coroutine_threadsafe(session.sendspin_ws.close(), loop)

        try:
            sendspin_ws = await self.http_session.ws_connect(self.sendspin_url)
        except Exception:
            self.logger.exception(
                "Failed to connect sendspin channel to internal server for session %s",
                session.session_id,
            )
            # Nothing will ever drain the queue now: close the channel so the client
            # stops sending and messages stop piling up in memory.
            try:
                channel.close()
            except Exception:
                self.logger.debug("Error closing sendspin channel", exc_info=True)
            return

        if not self._is_session_active(session) or channel.readyState in ("closing", "closed"):
            # Session/channel went away while we were connecting - don't leak the socket
            await sendspin_ws.close()
            return

        session.sendspin_ws = sendspin_ws
        self.logger.debug("Sendspin channel connected for session %s", session.session_id)

        # Start forwarding tasks - queued messages will be processed
        session.sendspin_to_local_task = asyncio.create_task(
            self._forward_sendspin_to_local(session)
        )
        session.sendspin_from_local_task = asyncio.create_task(
            self._forward_sendspin_from_local(session)
        )

    async def _forward_sendspin_to_local(self, session: WebRTCSession) -> None:
        """Forward messages from sendspin DataChannel to internal sendspin server.

        The first text message is inspected for a sendspin auth client_id.

        :param session: The WebRTC session.
        """
        first_message = True
        try:
            while session.sendspin_ws and not session.sendspin_ws.closed:
                message = await session.sendspin_queue.get()

                # Check only the first message for client_id extraction
                if first_message:
                    first_message = False
                    if isinstance(message, str):
                        self._try_extract_sendspin_client_id(session, message)

                if session.sendspin_ws and not session.sendspin_ws.closed:
                    if isinstance(message, bytes):
                        await session.sendspin_ws.send_bytes(message)
                    else:
                        await session.sendspin_ws.send_str(message)
        except asyncio.CancelledError:
            self.logger.debug(
                "Sendspin forward to local task cancelled for session %s",
                session.session_id,
            )
            raise
        except Exception:
            self.logger.exception("Error forwarding sendspin to local")

    def _try_extract_sendspin_client_id(self, session: WebRTCSession, message: str) -> None:
        """Try to extract client_id from sendspin auth message and set on websocket client.

        Invalid JSON, non-object JSON and non-auth messages are ignored.

        :param session: The WebRTC session.
        :param message: The first sendspin message (expected to be auth).
        """
        try:
            data = json.loads(message)
            # Valid JSON that is not an object (list, number, ...) cannot be an auth message
            if not isinstance(data, dict) or data.get("type") != "auth":
                return

            # This is an auth message - extract client_id if present
            if client_id := data.get("client_id"):
                session.sendspin_player_id = client_id
                self.logger.debug(
                    "Extracted sendspin player %s for session %s",
                    client_id,
                    session.session_id,
                )
                # Use callback to set sendspin player on the websocket client
                if self._set_sendspin_player_callback:
                    self._set_sendspin_player_callback(session.session_id, client_id)
        except (json.JSONDecodeError, TypeError):
            pass  # Not valid JSON, ignore

    async def _forward_sendspin_from_local(self, session: WebRTCSession) -> None:
        """Forward messages from internal sendspin server to sendspin DataChannel.

        :param session: The WebRTC session.
        """
        if not session.sendspin_ws or session.sendspin_ws.closed:
            return

        try:
            async for msg in session.sendspin_ws:
                if msg.type in {aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY}:
                    if session.sendspin_channel and session.sendspin_channel.readyState == "open":
                        session.sendspin_channel.send(msg.data)
                elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
                    break
        except asyncio.CancelledError:
            self.logger.debug(
                "Sendspin forward from local task cancelled for session %s",
                session.session_id,
            )
            raise
        except Exception:
            self.logger.exception("Error forwarding sendspin from local")
843