/
/
/
1"""Music Assistant WebRTC Gateway.
2
3This module provides WebRTC-based remote access to Music Assistant instances.
4It connects to a signaling server and handles incoming WebRTC connections,
5bridging them to the local WebSocket API.
6"""
7
8from __future__ import annotations
9
10import asyncio
11import contextlib
12import json
13import logging
14from collections.abc import Awaitable, Callable
15from dataclasses import dataclass, field
16from typing import TYPE_CHECKING, Any
17from urllib.parse import urlparse
18
19import aiohttp
20from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection, RTCSessionDescription
21from aiortc.sdp import candidate_from_sdp
22
23from music_assistant.constants import MASS_LOGGER_NAME, VERBOSE_LOG_LEVEL
24from music_assistant.helpers.webrtc_certificate import create_peer_connection_with_certificate
25
26if TYPE_CHECKING:
27 from aiortc.rtcdtlstransport import RTCCertificate
28
# Module-level logger, namespaced under MASS_LOGGER_NAME as "<root>.remote_access".
# WebRTCGateway instances use it as their `self.logger`.
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access")

# Reduce verbose logging from aiortc/aioice: only WARNING and above is emitted
# by those libraries' loggers once this module is imported.
logging.getLogger("aioice").setLevel(logging.WARNING)
logging.getLogger("aiortc").setLevel(logging.WARNING)
34
35
@dataclass
class WebRTCSession:
    """Represents an active WebRTC session with a remote client.

    Bundles the peer connection with the bridging state of the two data
    channels the gateway handles: the main API channel and the "sendspin"
    channel. Every field except ``session_id`` and ``peer_connection`` starts
    empty and is filled in by the gateway as channels are opened.
    """

    # Session identifier taken from the "sessionId" field of signaling messages
    session_id: str
    # The aiortc peer connection backing this session
    peer_connection: RTCPeerConnection
    # Main API channel (ma-api) - bridges to local MA WebSocket API
    # (any incoming data channel whose label is not "sendspin" is stored here)
    data_channel: Any = None
    # WebSocket connection to the gateway's local_ws_url, opened when the data channel is set up
    local_ws: Any = None
    # Messages received on the data channel, waiting to be forwarded to local_ws
    message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue)
    # Task draining message_queue into local_ws
    forward_to_local_task: asyncio.Task[None] | None = None
    # Task relaying local_ws text messages back over the data channel
    forward_from_local_task: asyncio.Task[None] | None = None
    # Sendspin channel - bridges to internal sendspin server
    sendspin_channel: Any = None
    # WebSocket connection to the gateway's sendspin_url
    sendspin_ws: Any = None
    # Text or binary messages received on the sendspin channel, waiting to be forwarded
    sendspin_queue: asyncio.Queue[str | bytes] = field(default_factory=asyncio.Queue)
    sendspin_player_id: str | None = None  # Extracted from first sendspin auth message
    # Task draining sendspin_queue into sendspin_ws
    sendspin_to_local_task: asyncio.Task[None] | None = None
    # Task relaying sendspin_ws messages back over the sendspin channel
    sendspin_from_local_task: asyncio.Task[None] | None = None
55
56
57class WebRTCGateway:
58 """WebRTC Gateway for Music Assistant Remote Access.
59
60 This gateway:
61 1. Connects to a signaling server
62 2. Registers with a unique Remote ID
63 3. Handles incoming WebRTC connections from remote PWA clients
64 4. Bridges WebRTC DataChannel messages to the local WebSocket API
65 """
66
    # Close code 4000 means this connection was replaced by a new one from the same server
    # In that case, we should not reconnect as another connection is now active
    # (compared against the signaling WebSocket's close code after the message loop ends).
    CLOSE_CODE_REPLACED = 4000

    # Default ICE servers (public STUN only - used as fallback)
    # Applied when no `ice_servers` list is passed to the constructor. Each entry is
    # expanded as keyword arguments to aiortc's RTCIceServer when a session is created.
    DEFAULT_ICE_SERVERS: list[dict[str, Any]] = [
        {"urls": "stun:stun.home-assistant.io:3478"},
        {"urls": "stun:stun.l.google.com:19302"},
        {"urls": "stun:stun1.l.google.com:19302"},
        {"urls": "stun:stun.cloudflare.com:3478"},
    ]
78
79 def __init__(
80 self,
81 http_session: aiohttp.ClientSession,
82 remote_id: str,
83 certificate: RTCCertificate,
84 signaling_url: str = "wss://signaling.music-assistant.io/ws",
85 local_ws_url: str = "ws://localhost:8095/ws",
86 sendspin_url: str = "ws://localhost:8927/sendspin",
87 ice_servers: list[dict[str, Any]] | None = None,
88 ice_servers_callback: Callable[[], Awaitable[list[dict[str, Any]]]] | None = None,
89 set_sendspin_player_callback: Callable[[str, str], None] | None = None,
90 ) -> None:
91 """
92 Initialize the WebRTC Gateway.
93
94 :param http_session: Shared aiohttp ClientSession for HTTP/WebSocket connections.
95 :param remote_id: Remote ID for this server instance.
96 :param certificate: Persistent RTCCertificate for DTLS, enabling client-side pinning.
97 :param signaling_url: WebSocket URL of the signaling server.
98 :param local_ws_url: Local WebSocket URL to bridge to.
99 :param sendspin_url: Internal Sendspin WebSocket URL to bridge to.
100 :param ice_servers: List of ICE server configurations (used at registration time).
101 :param ice_servers_callback: Optional callback to fetch fresh ICE servers for each session.
102 :param set_sendspin_player_callback: Callback to set sendspin player for a session.
103 """
104 self.http_session = http_session
105 self.signaling_url = signaling_url
106 self.local_ws_url = local_ws_url
107 self.sendspin_url = sendspin_url
108 self._remote_id = remote_id
109 self._certificate = certificate
110 self.logger = LOGGER
111 self._ice_servers_callback = ice_servers_callback
112 self._set_sendspin_player_callback = set_sendspin_player_callback
113
114 # Static ICE servers used at registration time (relayed to clients via signaling server)
115 self.ice_servers = ice_servers or self.DEFAULT_ICE_SERVERS
116
117 self.sessions: dict[str, WebRTCSession] = {}
118 self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None
119 self._running = False
120 self._reconnect_delay = 10 # Wait 10 seconds before reconnecting
121 self._max_reconnect_delay = 300 # Max 5 minutes between reconnects
122 self._current_reconnect_delay = 10
123 self._run_task: asyncio.Task[None] | None = None
124 self._is_connected = False
125 self._connecting = False
126
127 @property
128 def is_running(self) -> bool:
129 """Return whether the gateway is running."""
130 return self._running
131
132 @property
133 def is_connected(self) -> bool:
134 """Return whether the gateway is connected to the signaling server."""
135 return self._is_connected
136
137 async def _get_fresh_ice_servers(self) -> list[dict[str, Any]]:
138 """Get fresh ICE servers for a new WebRTC session.
139
140 If an ice_servers_callback was provided, it will be called to get fresh
141 TURN credentials. Otherwise, returns the static ice_servers.
142
143 :return: List of ICE server configurations with fresh credentials.
144 """
145 if self._ice_servers_callback:
146 try:
147 fresh_servers = await self._ice_servers_callback()
148 if fresh_servers:
149 return fresh_servers
150 except Exception:
151 self.logger.exception("Failed to fetch fresh ICE servers, using cached servers")
152 return self.ice_servers
153
154 async def start(self) -> None:
155 """Start the WebRTC Gateway."""
156 if self._running:
157 self.logger.warning("WebRTC Gateway already running, skipping start")
158 return
159 self.logger.info("Starting WebRTC Gateway")
160 self.logger.debug("Signaling URL: %s", self.signaling_url)
161 self.logger.debug("Local WS URL: %s", self.local_ws_url)
162 self._running = True
163 self._run_task = asyncio.create_task(self._run())
164 self.logger.debug("WebRTC Gateway start task created")
165
166 async def stop(self) -> None:
167 """Stop the WebRTC Gateway."""
168 self.logger.info("Stopping WebRTC Gateway")
169 self._running = False
170
171 # Close all sessions
172 for session_id in list(self.sessions.keys()):
173 await self._close_session(session_id)
174
175 # Close signaling connection gracefully
176 if self._signaling_ws and not self._signaling_ws.closed:
177 try:
178 await self._signaling_ws.close()
179 except Exception:
180 self.logger.debug("Error closing signaling WebSocket", exc_info=True)
181
182 # Cancel run task and wait for it to finish
183 if self._run_task and not self._run_task.done():
184 self._run_task.cancel()
185 with contextlib.suppress(asyncio.CancelledError):
186 await self._run_task
187
188 # Wait briefly for any in-progress connection to notice _running=False
189 if self._connecting:
190 await asyncio.sleep(0.1)
191
192 self._signaling_ws = None
193 self._connecting = False
194
195 async def _run(self) -> None:
196 """Run the main loop with reconnection logic."""
197 self.logger.debug("WebRTC Gateway _run() loop starting")
198 while self._running:
199 should_reconnect = True
200 try:
201 should_reconnect = await self._connect_to_signaling()
202 # Connection closed gracefully or with error
203 self._is_connected = False
204 if self._running and should_reconnect:
205 self.logger.warning(
206 "Signaling server connection lost. Reconnecting in %ss...",
207 self._current_reconnect_delay,
208 )
209 except Exception:
210 self._is_connected = False
211 self.logger.exception("Signaling connection error")
212 if self._running:
213 self.logger.info(
214 "Reconnecting to signaling server in %ss",
215 self._current_reconnect_delay,
216 )
217
218 if self._running and should_reconnect:
219 await asyncio.sleep(self._current_reconnect_delay)
220 # Exponential backoff with max limit
221 self._current_reconnect_delay = min(
222 self._current_reconnect_delay * 2, self._max_reconnect_delay
223 )
224 elif not should_reconnect:
225 # Connection was replaced by another instance, stop the run loop
226 self.logger.info("Connection replaced, stopping reconnection attempts")
227 self._running = False
228 break
229
230 async def _connect_to_signaling(self) -> bool:
231 """Connect to the signaling server.
232
233 :return: True if reconnection should be attempted, False if connection was replaced.
234 """
235 if self._connecting:
236 self.logger.warning("Already connecting to signaling server, skipping")
237 return False # Don't trigger another reconnect cycle
238 self._connecting = True
239 close_code: int | None = None
240 self.logger.info("Connecting to signaling server: %s", self.signaling_url)
241 try:
242 self._signaling_ws = await self.http_session.ws_connect(
243 self.signaling_url,
244 heartbeat=35.0, # Send ping every 35s (slightly above server's 30s interval)
245 )
246 # Check if we were stopped while connecting
247 if not self._running:
248 self.logger.debug("Gateway stopped during connection, closing WebSocket")
249 await self._signaling_ws.close()
250 self._signaling_ws = None
251 self._connecting = False
252 return False
253 self.logger.debug("WebSocket connection established, id=%s", id(self._signaling_ws))
254 self.logger.debug("Sending registration")
255 await self._register()
256 self._current_reconnect_delay = self._reconnect_delay
257 self.logger.debug("Registration sent, waiting for confirmation...")
258
259 # Run message loop and get close code
260 close_code = await self._signaling_message_loop(self._signaling_ws)
261
262 # Get close code from WebSocket if not already set from CLOSE message
263 if close_code is None:
264 close_code = self._signaling_ws.close_code
265 ws_exception = self._signaling_ws.exception()
266 self.logger.debug(
267 "Message loop exited - WebSocket closed: %s, close_code: %s, exception: %s",
268 self._signaling_ws.closed,
269 close_code,
270 ws_exception,
271 )
272 except TimeoutError:
273 self.logger.error("Timeout connecting to signaling server")
274 except aiohttp.ClientError as err:
275 self.logger.error("Failed to connect to signaling server: %s", err)
276 except Exception:
277 self.logger.exception("Unexpected error in signaling connection")
278 finally:
279 self._is_connected = False
280 self._connecting = False
281 self._signaling_ws = None
282
283 # Check if this connection was replaced by another one
284 if close_code == self.CLOSE_CODE_REPLACED:
285 self.logger.info("Connection was replaced by another instance - not reconnecting")
286 return False
287
288 return True
289
290 async def _signaling_message_loop(self, ws: aiohttp.ClientWebSocketResponse) -> int | None:
291 """Process messages from the signaling WebSocket.
292
293 :param ws: The WebSocket connection to process messages from.
294 :return: Close code if connection was closed with a code, None otherwise.
295 """
296 close_code: int | None = None
297 self.logger.debug("Entering message loop")
298 async for msg in ws:
299 if msg.type == aiohttp.WSMsgType.TEXT:
300 try:
301 await self._handle_signaling_message(json.loads(msg.data))
302 except Exception:
303 self.logger.exception("Error handling signaling message")
304 elif msg.type == aiohttp.WSMsgType.PING:
305 self.logger.log(VERBOSE_LOG_LEVEL, "Received WebSocket PING")
306 elif msg.type == aiohttp.WSMsgType.PONG:
307 self.logger.log(VERBOSE_LOG_LEVEL, "Received WebSocket PONG")
308 elif msg.type == aiohttp.WSMsgType.CLOSE:
309 close_code = msg.data
310 self.logger.warning(
311 "Signaling server sent close frame: code=%s, reason=%s",
312 msg.data,
313 msg.extra,
314 )
315 break
316 elif msg.type == aiohttp.WSMsgType.CLOSED:
317 self.logger.warning("Signaling server closed connection")
318 break
319 elif msg.type == aiohttp.WSMsgType.ERROR:
320 self.logger.error("WebSocket error: %s", ws.exception())
321 break
322 else:
323 self.logger.warning("Unexpected WebSocket message type: %s", msg.type)
324 return close_code
325
326 async def _register(self) -> None:
327 """Register with the signaling server."""
328 if self._signaling_ws:
329 await self._signaling_ws.send_json(
330 {
331 "type": "register-server",
332 "remoteId": self._remote_id,
333 "iceServers": self.ice_servers,
334 }
335 )
336
337 async def _handle_signaling_message(self, message: dict[str, Any]) -> None:
338 """Handle incoming signaling messages.
339
340 :param message: The signaling message.
341 """
342 msg_type = message.get("type")
343
344 if msg_type in ("ping", "pong"):
345 # Ignore JSON-level ping/pong messages - we use WebSocket protocol-level heartbeat
346 # The signaling server still sends these for backward compatibility with older clients
347 pass
348 elif msg_type == "registered":
349 self._is_connected = True
350 self.logger.info("Registered with signaling server")
351 elif msg_type == "error":
352 error_msg = message.get("error") or message.get("message", "Unknown error")
353 self.logger.error("Signaling server error: %s", error_msg)
354 elif msg_type == "client-connected":
355 session_id = message.get("sessionId")
356 if session_id:
357 await self._create_session(session_id)
358 # Send session-ready with fresh ICE servers for the client
359 fresh_ice_servers = await self._get_fresh_ice_servers()
360 if self._signaling_ws:
361 await self._signaling_ws.send_json(
362 {
363 "type": "session-ready",
364 "sessionId": session_id,
365 "iceServers": fresh_ice_servers,
366 }
367 )
368 elif msg_type == "client-disconnected":
369 session_id = message.get("sessionId")
370 if session_id:
371 await self._close_session(session_id)
372 elif msg_type == "offer":
373 session_id = message.get("sessionId")
374 offer_data = message.get("data")
375 if session_id and offer_data:
376 await self._handle_offer(session_id, offer_data)
377 elif msg_type == "ice-candidate":
378 session_id = message.get("sessionId")
379 candidate_data = message.get("data")
380 if session_id and candidate_data:
381 await self._handle_ice_candidate(session_id, candidate_data)
382
383 async def _create_session(self, session_id: str) -> None:
384 """Create a new WebRTC session.
385
386 :param session_id: The session ID.
387 """
388 session_ice_servers = await self._get_fresh_ice_servers()
389 config = RTCConfiguration(
390 iceServers=[RTCIceServer(**server) for server in session_ice_servers]
391 )
392 pc = create_peer_connection_with_certificate(self._certificate, configuration=config)
393 session = WebRTCSession(session_id=session_id, peer_connection=pc)
394 self.sessions[session_id] = session
395
396 @pc.on("datachannel")
397 def on_datachannel(channel: Any) -> None:
398 if channel.label == "sendspin":
399 session.sendspin_channel = channel
400 asyncio.create_task(self._setup_sendspin_channel(session))
401 else:
402 session.data_channel = channel
403 asyncio.create_task(self._setup_data_channel(session))
404
405 @pc.on("icecandidate")
406 async def on_icecandidate(candidate: Any) -> None:
407 if candidate and self._signaling_ws:
408 await self._signaling_ws.send_json(
409 {
410 "type": "ice-candidate",
411 "sessionId": session_id,
412 "data": {
413 "candidate": candidate.candidate,
414 "sdpMid": candidate.sdpMid,
415 "sdpMLineIndex": candidate.sdpMLineIndex,
416 },
417 }
418 )
419
420 @pc.on("connectionstatechange")
421 async def on_connectionstatechange() -> None:
422 if pc.connectionState == "failed":
423 await self._close_session(session_id)
424
425 async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None:
426 """Handle incoming WebRTC offer.
427
428 :param session_id: The session ID.
429 :param offer: The offer data.
430 """
431 session = self.sessions.get(session_id)
432 if not session:
433 return
434 pc = session.peer_connection
435
436 if pc.connectionState in ("closed", "failed"):
437 return
438
439 sdp = offer.get("sdp")
440 sdp_type = offer.get("type")
441 if not sdp or not sdp_type:
442 self.logger.error("Invalid offer data: missing sdp or type")
443 return
444
445 try:
446 await pc.setRemoteDescription(
447 RTCSessionDescription(
448 sdp=str(sdp),
449 type=str(sdp_type),
450 )
451 )
452
453 if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
454 return
455
456 answer = await pc.createAnswer()
457
458 if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
459 return
460
461 await pc.setLocalDescription(answer)
462
463 # Wait for ICE gathering to complete before sending the answer
464 # aiortc doesn't support trickle ICE, candidates are embedded in SDP after gathering
465 gather_timeout = 30
466 gather_start = asyncio.get_event_loop().time()
467 while pc.iceGatheringState != "complete":
468 if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
469 return
470 if asyncio.get_event_loop().time() - gather_start > gather_timeout:
471 self.logger.warning("Session %s ICE gathering timeout", session_id)
472 break
473 await asyncio.sleep(0.1)
474
475 if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
476 return
477
478 if self._signaling_ws:
479 await self._signaling_ws.send_json(
480 {
481 "type": "answer",
482 "sessionId": session_id,
483 "data": {
484 "sdp": pc.localDescription.sdp,
485 "type": pc.localDescription.type,
486 },
487 }
488 )
489 except Exception:
490 self.logger.exception("Error handling offer for session %s", session_id)
491 # Clean up the session on error
492 await self._close_session(session_id)
493
494 async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None:
495 """Handle incoming ICE candidate.
496
497 :param session_id: The session ID.
498 :param candidate: The ICE candidate data.
499 """
500 session = self.sessions.get(session_id)
501 if not session or not candidate:
502 return
503
504 pc = session.peer_connection
505 if pc.connectionState in ("closed", "failed"):
506 return
507
508 candidate_str = candidate.get("candidate")
509 sdp_mid = candidate.get("sdpMid")
510 sdp_mline_index = candidate.get("sdpMLineIndex")
511
512 if not candidate_str:
513 return
514
515 try:
516 # Parse ICE candidate - browser sends "candidate:..." format
517 if candidate_str.startswith("candidate:"):
518 sdp_candidate_str = candidate_str[len("candidate:") :]
519 else:
520 sdp_candidate_str = candidate_str
521
522 ice_candidate = candidate_from_sdp(sdp_candidate_str)
523 ice_candidate.sdpMid = str(sdp_mid) if sdp_mid else None
524 ice_candidate.sdpMLineIndex = (
525 int(sdp_mline_index) if sdp_mline_index is not None else None
526 )
527
528 if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
529 return
530
531 await session.peer_connection.addIceCandidate(ice_candidate)
532 except Exception:
533 self.logger.exception("Failed to add ICE candidate for session %s", session_id)
534
535 async def _setup_data_channel(self, session: WebRTCSession) -> None:
536 """Set up data channel and bridge to local WebSocket.
537
538 :param session: The WebRTC session.
539 """
540 channel = session.data_channel
541 if not channel:
542 return
543 try:
544 # Include session_id in URL so server can track WebRTC sessions
545 ws_url = f"{self.local_ws_url}?webrtc_session_id={session.session_id}"
546 session.local_ws = await self.http_session.ws_connect(ws_url)
547 loop = asyncio.get_event_loop()
548
549 # Store task references for proper cleanup
550 session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session))
551 session.forward_from_local_task = asyncio.create_task(self._forward_from_local(session))
552
553 @channel.on("message") # type: ignore[untyped-decorator]
554 def on_message(message: str) -> None:
555 # Called from aiortc thread, use call_soon_threadsafe
556 # Only queue message if session is still active
557 if session.forward_to_local_task and not session.forward_to_local_task.done():
558 loop.call_soon_threadsafe(session.message_queue.put_nowait, message)
559
560 @channel.on("close") # type: ignore[untyped-decorator]
561 def on_close() -> None:
562 # Called from aiortc thread, use call_soon_threadsafe to schedule task
563 asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop)
564
565 except Exception:
566 self.logger.exception("Failed to connect to local WebSocket")
567
568 async def _forward_to_local(self, session: WebRTCSession) -> None:
569 """Forward messages from WebRTC DataChannel to local WebSocket.
570
571 :param session: The WebRTC session.
572 """
573 try:
574 while session.local_ws and not session.local_ws.closed:
575 message = await session.message_queue.get()
576
577 # Check if this is an HTTP proxy request
578 try:
579 msg_data = json.loads(message)
580 if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request":
581 # Handle HTTP proxy request
582 await self._handle_http_proxy_request(session, msg_data)
583 continue
584 except (json.JSONDecodeError, ValueError):
585 pass
586
587 # Regular WebSocket message
588 if session.local_ws and not session.local_ws.closed:
589 await session.local_ws.send_str(message)
590 except asyncio.CancelledError:
591 # Task was cancelled during cleanup, this is expected
592 self.logger.debug("Forward to local task cancelled for session %s", session.session_id)
593 raise
594 except Exception:
595 self.logger.exception("Error forwarding to local WebSocket")
596
597 async def _forward_from_local(self, session: WebRTCSession) -> None:
598 """Forward messages from local WebSocket to WebRTC DataChannel.
599
600 :param session: The WebRTC session.
601 """
602 try:
603 async for msg in session.local_ws:
604 if msg.type == aiohttp.WSMsgType.TEXT:
605 if session.data_channel and session.data_channel.readyState == "open":
606 session.data_channel.send(msg.data)
607 elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
608 break
609 except asyncio.CancelledError:
610 # Task was cancelled during cleanup, this is expected
611 self.logger.debug(
612 "Forward from local task cancelled for session %s", session.session_id
613 )
614 raise
615 except Exception:
616 self.logger.exception("Error forwarding from local WebSocket")
617
618 async def _handle_http_proxy_request(
619 self, session: WebRTCSession, request_data: dict[str, Any]
620 ) -> None:
621 """Handle HTTP proxy request from remote client.
622
623 :param session: The WebRTC session.
624 :param request_data: The HTTP proxy request data.
625 """
626 request_id = request_data.get("id")
627 method = request_data.get("method", "GET")
628 path = request_data.get("path", "/")
629 headers = request_data.get("headers", {})
630
631 # Build local HTTP URL from the WebSocket URL.
632 # Handle both ws:// and wss:// schemes.
633 parsed = urlparse(self.local_ws_url)
634 http_scheme = "https" if parsed.scheme == "wss" else "http"
635 local_http_url = f"{http_scheme}://{parsed.netloc}{path}"
636
637 self.logger.debug("HTTP proxy request: %s %s", method, local_http_url)
638
639 try:
640 # Use shared HTTP session for this request
641 async with self.http_session.request(
642 method, local_http_url, headers=headers
643 ) as response:
644 # Read response body
645 body = await response.read()
646
647 # Prepare response data
648 response_data = {
649 "type": "http-proxy-response",
650 "id": request_id,
651 "status": response.status,
652 "headers": dict(response.headers),
653 "body": body.hex(), # Send as hex string to avoid encoding issues
654 }
655
656 # Send response back through data channel
657 if session.data_channel and session.data_channel.readyState == "open":
658 session.data_channel.send(json.dumps(response_data))
659
660 except Exception as err:
661 self.logger.exception("Error handling HTTP proxy request")
662 # Send error response
663 error_response = {
664 "type": "http-proxy-response",
665 "id": request_id,
666 "status": 500,
667 "headers": {"Content-Type": "text/plain"},
668 "body": str(err).encode().hex(),
669 }
670 if session.data_channel and session.data_channel.readyState == "open":
671 session.data_channel.send(json.dumps(error_response))
672
673 async def _close_session(self, session_id: str) -> None:
674 """Close a WebRTC session.
675
676 :param session_id: The session ID.
677 """
678 session = self.sessions.pop(session_id, None)
679 if not session:
680 return
681
682 # Cancel forwarding tasks first to prevent race conditions
683 if session.forward_to_local_task and not session.forward_to_local_task.done():
684 session.forward_to_local_task.cancel()
685 with contextlib.suppress(asyncio.CancelledError):
686 await session.forward_to_local_task
687
688 if session.forward_from_local_task and not session.forward_from_local_task.done():
689 session.forward_from_local_task.cancel()
690 with contextlib.suppress(asyncio.CancelledError):
691 await session.forward_from_local_task
692
693 # Cancel sendspin forwarding tasks
694 if session.sendspin_to_local_task and not session.sendspin_to_local_task.done():
695 session.sendspin_to_local_task.cancel()
696 with contextlib.suppress(asyncio.CancelledError):
697 await session.sendspin_to_local_task
698
699 if session.sendspin_from_local_task and not session.sendspin_from_local_task.done():
700 session.sendspin_from_local_task.cancel()
701 with contextlib.suppress(asyncio.CancelledError):
702 await session.sendspin_from_local_task
703
704 # Close connections
705 if session.local_ws and not session.local_ws.closed:
706 await session.local_ws.close()
707 if session.sendspin_ws and not session.sendspin_ws.closed:
708 await session.sendspin_ws.close()
709 if session.data_channel:
710 session.data_channel.close()
711 if session.sendspin_channel:
712 session.sendspin_channel.close()
713 await session.peer_connection.close()
714
715 async def _setup_sendspin_channel(self, session: WebRTCSession) -> None:
716 """Set up sendspin data channel and bridge to internal sendspin server.
717
718 :param session: The WebRTC session.
719 """
720 channel = session.sendspin_channel
721 if not channel:
722 return
723
724 try:
725 loop = asyncio.get_event_loop()
726
727 @channel.on("message") # type: ignore[untyped-decorator]
728 def on_message(message: str | bytes) -> None:
729 # Queue if task not yet created (None) or still running.
730 # Only drop when task exists and is done (shutdown).
731 if (
732 session.sendspin_to_local_task is None
733 or not session.sendspin_to_local_task.done()
734 ):
735 loop.call_soon_threadsafe(session.sendspin_queue.put_nowait, message)
736
737 @channel.on("close") # type: ignore[untyped-decorator]
738 def on_close() -> None:
739 if session.sendspin_ws and not session.sendspin_ws.closed:
740 asyncio.run_coroutine_threadsafe(session.sendspin_ws.close(), loop)
741
742 session.sendspin_ws = await self.http_session.ws_connect(self.sendspin_url)
743 self.logger.debug("Sendspin channel connected for session %s", session.session_id)
744
745 # Start forwarding tasks - queued messages will be processed
746 session.sendspin_to_local_task = asyncio.create_task(
747 self._forward_sendspin_to_local(session)
748 )
749 session.sendspin_from_local_task = asyncio.create_task(
750 self._forward_sendspin_from_local(session)
751 )
752
753 except Exception:
754 self.logger.exception(
755 "Failed to connect sendspin channel to internal server for session %s",
756 session.session_id,
757 )
758 # Clean up partial state on failure
759 if session.sendspin_to_local_task:
760 session.sendspin_to_local_task.cancel()
761 if session.sendspin_from_local_task:
762 session.sendspin_from_local_task.cancel()
763 if session.sendspin_ws and not session.sendspin_ws.closed:
764 await session.sendspin_ws.close()
765
766 async def _forward_sendspin_to_local(self, session: WebRTCSession) -> None:
767 """Forward messages from sendspin DataChannel to internal sendspin server.
768
769 :param session: The WebRTC session.
770 """
771 first_message = True
772 try:
773 while session.sendspin_ws and not session.sendspin_ws.closed:
774 message = await session.sendspin_queue.get()
775
776 # Check only the first message for client_id extraction
777 if first_message:
778 first_message = False
779 if isinstance(message, str):
780 self._try_extract_sendspin_client_id(session, message)
781
782 if session.sendspin_ws and not session.sendspin_ws.closed:
783 if isinstance(message, bytes):
784 await session.sendspin_ws.send_bytes(message)
785 else:
786 await session.sendspin_ws.send_str(message)
787 except asyncio.CancelledError:
788 self.logger.debug(
789 "Sendspin forward to local task cancelled for session %s",
790 session.session_id,
791 )
792 raise
793 except Exception:
794 self.logger.exception("Error forwarding sendspin to local")
795
796 def _try_extract_sendspin_client_id(self, session: WebRTCSession, message: str) -> None:
797 """Try to extract client_id from sendspin auth message and set on websocket client.
798
799 :param session: The WebRTC session.
800 :param message: The first sendspin message (expected to be auth).
801 """
802 try:
803 data = json.loads(message)
804 if data.get("type") != "auth":
805 return # Not an auth message
806
807 # This is an auth message - extract client_id if present
808 if client_id := data.get("client_id"):
809 session.sendspin_player_id = client_id
810 self.logger.debug(
811 "Extracted sendspin player %s for session %s",
812 client_id,
813 session.session_id,
814 )
815 # Use callback to set sendspin player on the websocket client
816 if self._set_sendspin_player_callback:
817 self._set_sendspin_player_callback(session.session_id, client_id)
818 except (json.JSONDecodeError, TypeError):
819 pass # Not valid JSON, ignore
820
821 async def _forward_sendspin_from_local(self, session: WebRTCSession) -> None:
822 """Forward messages from internal sendspin server to sendspin DataChannel.
823
824 :param session: The WebRTC session.
825 """
826 if not session.sendspin_ws or session.sendspin_ws.closed:
827 return
828
829 try:
830 async for msg in session.sendspin_ws:
831 if msg.type in {aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY}:
832 if session.sendspin_channel and session.sendspin_channel.readyState == "open":
833 session.sendspin_channel.send(msg.data)
834 elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
835 break
836 except asyncio.CancelledError:
837 self.logger.debug(
838 "Sendspin forward from local task cancelled for session %s",
839 session.session_id,
840 )
841 raise
842 except Exception:
843 self.logger.exception("Error forwarding sendspin from local")
844