music-assistant-server

35.2 KB • PY
gateway.py
35.2 KB • 844 lines • python
1"""Music Assistant WebRTC Gateway.
2
3This module provides WebRTC-based remote access to Music Assistant instances.
4It connects to a signaling server and handles incoming WebRTC connections,
5bridging them to the local WebSocket API.
6"""
7
8from __future__ import annotations
9
10import asyncio
11import contextlib
12import json
13import logging
14from collections.abc import Awaitable, Callable
15from dataclasses import dataclass, field
16from typing import TYPE_CHECKING, Any
17from urllib.parse import urlparse
18
19import aiohttp
20from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection, RTCSessionDescription
21from aiortc.sdp import candidate_from_sdp
22
23from music_assistant.constants import MASS_LOGGER_NAME, VERBOSE_LOG_LEVEL
24from music_assistant.helpers.webrtc_certificate import create_peer_connection_with_certificate
25
26if TYPE_CHECKING:
27    from aiortc.rtcdtlstransport import RTCCertificate
28
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access")

# Quiet the verbose third-party WebRTC/ICE libraries: only WARNING and above
# from aioice and aiortc reach the log.
for _third_party_logger in ("aioice", "aiortc"):
    logging.getLogger(_third_party_logger).setLevel(logging.WARNING)
del _third_party_logger
34
35
@dataclass
class WebRTCSession:
    """State for one remote client connected to the gateway over WebRTC.

    Besides the peer connection, a session holds two optional bridges: the
    main API data channel (``ma-api``) paired with the local Music Assistant
    WebSocket API, and the ``sendspin`` data channel paired with the internal
    Sendspin server.
    """

    session_id: str
    peer_connection: RTCPeerConnection

    # --- Main API bridge: data channel <-> local MA WebSocket API ---
    data_channel: Any = field(default=None)
    local_ws: Any = field(default=None)
    # Messages received on ``data_channel`` that wait to be written to ``local_ws``.
    message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue)
    forward_to_local_task: asyncio.Task[None] | None = field(default=None)
    forward_from_local_task: asyncio.Task[None] | None = field(default=None)

    # --- Sendspin bridge: data channel <-> internal Sendspin WebSocket ---
    sendspin_channel: Any = field(default=None)
    sendspin_ws: Any = field(default=None)
    # Text or binary messages from ``sendspin_channel`` awaiting forwarding.
    sendspin_queue: asyncio.Queue[str | bytes] = field(default_factory=asyncio.Queue)
    # Extracted from first sendspin auth message.
    sendspin_player_id: str | None = field(default=None)
    sendspin_to_local_task: asyncio.Task[None] | None = field(default=None)
    sendspin_from_local_task: asyncio.Task[None] | None = field(default=None)
55
56
57class WebRTCGateway:
58    """WebRTC Gateway for Music Assistant Remote Access.
59
60    This gateway:
61    1. Connects to a signaling server
62    2. Registers with a unique Remote ID
63    3. Handles incoming WebRTC connections from remote PWA clients
64    4. Bridges WebRTC DataChannel messages to the local WebSocket API
65    """
66
67    # Close code 4000 means this connection was replaced by a new one from the same server
68    # In that case, we should not reconnect as another connection is now active
69    CLOSE_CODE_REPLACED = 4000
70
71    # Default ICE servers (public STUN only - used as fallback)
72    DEFAULT_ICE_SERVERS: list[dict[str, Any]] = [
73        {"urls": "stun:stun.home-assistant.io:3478"},
74        {"urls": "stun:stun.l.google.com:19302"},
75        {"urls": "stun:stun1.l.google.com:19302"},
76        {"urls": "stun:stun.cloudflare.com:3478"},
77    ]
78
79    def __init__(
80        self,
81        http_session: aiohttp.ClientSession,
82        remote_id: str,
83        certificate: RTCCertificate,
84        signaling_url: str = "wss://signaling.music-assistant.io/ws",
85        local_ws_url: str = "ws://localhost:8095/ws",
86        sendspin_url: str = "ws://localhost:8927/sendspin",
87        ice_servers: list[dict[str, Any]] | None = None,
88        ice_servers_callback: Callable[[], Awaitable[list[dict[str, Any]]]] | None = None,
89        set_sendspin_player_callback: Callable[[str, str], None] | None = None,
90    ) -> None:
91        """
92        Initialize the WebRTC Gateway.
93
94        :param http_session: Shared aiohttp ClientSession for HTTP/WebSocket connections.
95        :param remote_id: Remote ID for this server instance.
96        :param certificate: Persistent RTCCertificate for DTLS, enabling client-side pinning.
97        :param signaling_url: WebSocket URL of the signaling server.
98        :param local_ws_url: Local WebSocket URL to bridge to.
99        :param sendspin_url: Internal Sendspin WebSocket URL to bridge to.
100        :param ice_servers: List of ICE server configurations (used at registration time).
101        :param ice_servers_callback: Optional callback to fetch fresh ICE servers for each session.
102        :param set_sendspin_player_callback: Callback to set sendspin player for a session.
103        """
104        self.http_session = http_session
105        self.signaling_url = signaling_url
106        self.local_ws_url = local_ws_url
107        self.sendspin_url = sendspin_url
108        self._remote_id = remote_id
109        self._certificate = certificate
110        self.logger = LOGGER
111        self._ice_servers_callback = ice_servers_callback
112        self._set_sendspin_player_callback = set_sendspin_player_callback
113
114        # Static ICE servers used at registration time (relayed to clients via signaling server)
115        self.ice_servers = ice_servers or self.DEFAULT_ICE_SERVERS
116
117        self.sessions: dict[str, WebRTCSession] = {}
118        self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None
119        self._running = False
120        self._reconnect_delay = 10  # Wait 10 seconds before reconnecting
121        self._max_reconnect_delay = 300  # Max 5 minutes between reconnects
122        self._current_reconnect_delay = 10
123        self._run_task: asyncio.Task[None] | None = None
124        self._is_connected = False
125        self._connecting = False
126
127    @property
128    def is_running(self) -> bool:
129        """Return whether the gateway is running."""
130        return self._running
131
132    @property
133    def is_connected(self) -> bool:
134        """Return whether the gateway is connected to the signaling server."""
135        return self._is_connected
136
137    async def _get_fresh_ice_servers(self) -> list[dict[str, Any]]:
138        """Get fresh ICE servers for a new WebRTC session.
139
140        If an ice_servers_callback was provided, it will be called to get fresh
141        TURN credentials. Otherwise, returns the static ice_servers.
142
143        :return: List of ICE server configurations with fresh credentials.
144        """
145        if self._ice_servers_callback:
146            try:
147                fresh_servers = await self._ice_servers_callback()
148                if fresh_servers:
149                    return fresh_servers
150            except Exception:
151                self.logger.exception("Failed to fetch fresh ICE servers, using cached servers")
152        return self.ice_servers
153
154    async def start(self) -> None:
155        """Start the WebRTC Gateway."""
156        if self._running:
157            self.logger.warning("WebRTC Gateway already running, skipping start")
158            return
159        self.logger.info("Starting WebRTC Gateway")
160        self.logger.debug("Signaling URL: %s", self.signaling_url)
161        self.logger.debug("Local WS URL: %s", self.local_ws_url)
162        self._running = True
163        self._run_task = asyncio.create_task(self._run())
164        self.logger.debug("WebRTC Gateway start task created")
165
166    async def stop(self) -> None:
167        """Stop the WebRTC Gateway."""
168        self.logger.info("Stopping WebRTC Gateway")
169        self._running = False
170
171        # Close all sessions
172        for session_id in list(self.sessions.keys()):
173            await self._close_session(session_id)
174
175        # Close signaling connection gracefully
176        if self._signaling_ws and not self._signaling_ws.closed:
177            try:
178                await self._signaling_ws.close()
179            except Exception:
180                self.logger.debug("Error closing signaling WebSocket", exc_info=True)
181
182        # Cancel run task and wait for it to finish
183        if self._run_task and not self._run_task.done():
184            self._run_task.cancel()
185            with contextlib.suppress(asyncio.CancelledError):
186                await self._run_task
187
188        # Wait briefly for any in-progress connection to notice _running=False
189        if self._connecting:
190            await asyncio.sleep(0.1)
191
192        self._signaling_ws = None
193        self._connecting = False
194
195    async def _run(self) -> None:
196        """Run the main loop with reconnection logic."""
197        self.logger.debug("WebRTC Gateway _run() loop starting")
198        while self._running:
199            should_reconnect = True
200            try:
201                should_reconnect = await self._connect_to_signaling()
202                # Connection closed gracefully or with error
203                self._is_connected = False
204                if self._running and should_reconnect:
205                    self.logger.warning(
206                        "Signaling server connection lost. Reconnecting in %ss...",
207                        self._current_reconnect_delay,
208                    )
209            except Exception:
210                self._is_connected = False
211                self.logger.exception("Signaling connection error")
212                if self._running:
213                    self.logger.info(
214                        "Reconnecting to signaling server in %ss",
215                        self._current_reconnect_delay,
216                    )
217
218            if self._running and should_reconnect:
219                await asyncio.sleep(self._current_reconnect_delay)
220                # Exponential backoff with max limit
221                self._current_reconnect_delay = min(
222                    self._current_reconnect_delay * 2, self._max_reconnect_delay
223                )
224            elif not should_reconnect:
225                # Connection was replaced by another instance, stop the run loop
226                self.logger.info("Connection replaced, stopping reconnection attempts")
227                self._running = False
228                break
229
    async def _connect_to_signaling(self) -> bool:
        """Connect to the signaling server and serve it until the socket closes.

        Opens the signaling WebSocket, sends the ``register-server`` message via
        ``_register()``, resets the reconnect backoff to its base value and then
        runs ``_signaling_message_loop`` until the connection ends. Exceptions
        derived from ``Exception`` are logged here instead of being raised. The
        ``finally`` block always clears ``_is_connected``, ``_connecting`` and
        ``_signaling_ws``.

        :return: True if reconnection should be attempted, False if connection was replaced.
            False is also returned when another attempt is already in progress or
            when the gateway was stopped while the WebSocket was being opened.
        """
        if self._connecting:
            self.logger.warning("Already connecting to signaling server, skipping")
            # NOTE(review): _run() treats any False result as "connection replaced"
            # and stops reconnecting entirely, not just for this one cycle.
            return False  # Don't trigger another reconnect cycle
        self._connecting = True
        close_code: int | None = None
        self.logger.info("Connecting to signaling server: %s", self.signaling_url)
        try:
            self._signaling_ws = await self.http_session.ws_connect(
                self.signaling_url,
                heartbeat=35.0,  # Send ping every 35s (slightly above server's 30s interval)
            )
            # Check if we were stopped while connecting
            if not self._running:
                self.logger.debug("Gateway stopped during connection, closing WebSocket")
                await self._signaling_ws.close()
                self._signaling_ws = None
                self._connecting = False
                return False
            self.logger.debug("WebSocket connection established, id=%s", id(self._signaling_ws))
            self.logger.debug("Sending registration")
            await self._register()
            # Socket is up and registration was sent: reset the exponential backoff.
            self._current_reconnect_delay = self._reconnect_delay
            self.logger.debug("Registration sent, waiting for confirmation...")

            # Run message loop and get close code
            close_code = await self._signaling_message_loop(self._signaling_ws)

            # Get close code from WebSocket if not already set from CLOSE message
            # NOTE(review): aiohttp's `async for` appears to end on CLOSE frames without
            # yielding them, so this fallback is likely the usual source of the code.
            if close_code is None:
                close_code = self._signaling_ws.close_code
            ws_exception = self._signaling_ws.exception()
            self.logger.debug(
                "Message loop exited - WebSocket closed: %s, close_code: %s, exception: %s",
                self._signaling_ws.closed,
                close_code,
                ws_exception,
            )
        except TimeoutError:
            self.logger.error("Timeout connecting to signaling server")
        except aiohttp.ClientError as err:
            self.logger.error("Failed to connect to signaling server: %s", err)
        except Exception:
            self.logger.exception("Unexpected error in signaling connection")
        finally:
            # Runs on every exit path (including the early return above).
            self._is_connected = False
            self._connecting = False
            self._signaling_ws = None

        # Check if this connection was replaced by another one
        # (CLOSE_CODE_REPLACED: another connection for the same server is now active).
        if close_code == self.CLOSE_CODE_REPLACED:
            self.logger.info("Connection was replaced by another instance - not reconnecting")
            return False

        return True
289
290    async def _signaling_message_loop(self, ws: aiohttp.ClientWebSocketResponse) -> int | None:
291        """Process messages from the signaling WebSocket.
292
293        :param ws: The WebSocket connection to process messages from.
294        :return: Close code if connection was closed with a code, None otherwise.
295        """
296        close_code: int | None = None
297        self.logger.debug("Entering message loop")
298        async for msg in ws:
299            if msg.type == aiohttp.WSMsgType.TEXT:
300                try:
301                    await self._handle_signaling_message(json.loads(msg.data))
302                except Exception:
303                    self.logger.exception("Error handling signaling message")
304            elif msg.type == aiohttp.WSMsgType.PING:
305                self.logger.log(VERBOSE_LOG_LEVEL, "Received WebSocket PING")
306            elif msg.type == aiohttp.WSMsgType.PONG:
307                self.logger.log(VERBOSE_LOG_LEVEL, "Received WebSocket PONG")
308            elif msg.type == aiohttp.WSMsgType.CLOSE:
309                close_code = msg.data
310                self.logger.warning(
311                    "Signaling server sent close frame: code=%s, reason=%s",
312                    msg.data,
313                    msg.extra,
314                )
315                break
316            elif msg.type == aiohttp.WSMsgType.CLOSED:
317                self.logger.warning("Signaling server closed connection")
318                break
319            elif msg.type == aiohttp.WSMsgType.ERROR:
320                self.logger.error("WebSocket error: %s", ws.exception())
321                break
322            else:
323                self.logger.warning("Unexpected WebSocket message type: %s", msg.type)
324        return close_code
325
326    async def _register(self) -> None:
327        """Register with the signaling server."""
328        if self._signaling_ws:
329            await self._signaling_ws.send_json(
330                {
331                    "type": "register-server",
332                    "remoteId": self._remote_id,
333                    "iceServers": self.ice_servers,
334                }
335            )
336
337    async def _handle_signaling_message(self, message: dict[str, Any]) -> None:
338        """Handle incoming signaling messages.
339
340        :param message: The signaling message.
341        """
342        msg_type = message.get("type")
343
344        if msg_type in ("ping", "pong"):
345            # Ignore JSON-level ping/pong messages - we use WebSocket protocol-level heartbeat
346            # The signaling server still sends these for backward compatibility with older clients
347            pass
348        elif msg_type == "registered":
349            self._is_connected = True
350            self.logger.info("Registered with signaling server")
351        elif msg_type == "error":
352            error_msg = message.get("error") or message.get("message", "Unknown error")
353            self.logger.error("Signaling server error: %s", error_msg)
354        elif msg_type == "client-connected":
355            session_id = message.get("sessionId")
356            if session_id:
357                await self._create_session(session_id)
358                # Send session-ready with fresh ICE servers for the client
359                fresh_ice_servers = await self._get_fresh_ice_servers()
360                if self._signaling_ws:
361                    await self._signaling_ws.send_json(
362                        {
363                            "type": "session-ready",
364                            "sessionId": session_id,
365                            "iceServers": fresh_ice_servers,
366                        }
367                    )
368        elif msg_type == "client-disconnected":
369            session_id = message.get("sessionId")
370            if session_id:
371                await self._close_session(session_id)
372        elif msg_type == "offer":
373            session_id = message.get("sessionId")
374            offer_data = message.get("data")
375            if session_id and offer_data:
376                await self._handle_offer(session_id, offer_data)
377        elif msg_type == "ice-candidate":
378            session_id = message.get("sessionId")
379            candidate_data = message.get("data")
380            if session_id and candidate_data:
381                await self._handle_ice_candidate(session_id, candidate_data)
382
383    async def _create_session(self, session_id: str) -> None:
384        """Create a new WebRTC session.
385
386        :param session_id: The session ID.
387        """
388        session_ice_servers = await self._get_fresh_ice_servers()
389        config = RTCConfiguration(
390            iceServers=[RTCIceServer(**server) for server in session_ice_servers]
391        )
392        pc = create_peer_connection_with_certificate(self._certificate, configuration=config)
393        session = WebRTCSession(session_id=session_id, peer_connection=pc)
394        self.sessions[session_id] = session
395
396        @pc.on("datachannel")
397        def on_datachannel(channel: Any) -> None:
398            if channel.label == "sendspin":
399                session.sendspin_channel = channel
400                asyncio.create_task(self._setup_sendspin_channel(session))
401            else:
402                session.data_channel = channel
403                asyncio.create_task(self._setup_data_channel(session))
404
405        @pc.on("icecandidate")
406        async def on_icecandidate(candidate: Any) -> None:
407            if candidate and self._signaling_ws:
408                await self._signaling_ws.send_json(
409                    {
410                        "type": "ice-candidate",
411                        "sessionId": session_id,
412                        "data": {
413                            "candidate": candidate.candidate,
414                            "sdpMid": candidate.sdpMid,
415                            "sdpMLineIndex": candidate.sdpMLineIndex,
416                        },
417                    }
418                )
419
420        @pc.on("connectionstatechange")
421        async def on_connectionstatechange() -> None:
422            if pc.connectionState == "failed":
423                await self._close_session(session_id)
424
425    async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None:
426        """Handle incoming WebRTC offer.
427
428        :param session_id: The session ID.
429        :param offer: The offer data.
430        """
431        session = self.sessions.get(session_id)
432        if not session:
433            return
434        pc = session.peer_connection
435
436        if pc.connectionState in ("closed", "failed"):
437            return
438
439        sdp = offer.get("sdp")
440        sdp_type = offer.get("type")
441        if not sdp or not sdp_type:
442            self.logger.error("Invalid offer data: missing sdp or type")
443            return
444
445        try:
446            await pc.setRemoteDescription(
447                RTCSessionDescription(
448                    sdp=str(sdp),
449                    type=str(sdp_type),
450                )
451            )
452
453            if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
454                return
455
456            answer = await pc.createAnswer()
457
458            if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
459                return
460
461            await pc.setLocalDescription(answer)
462
463            # Wait for ICE gathering to complete before sending the answer
464            # aiortc doesn't support trickle ICE, candidates are embedded in SDP after gathering
465            gather_timeout = 30
466            gather_start = asyncio.get_event_loop().time()
467            while pc.iceGatheringState != "complete":
468                if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
469                    return
470                if asyncio.get_event_loop().time() - gather_start > gather_timeout:
471                    self.logger.warning("Session %s ICE gathering timeout", session_id)
472                    break
473                await asyncio.sleep(0.1)
474
475            if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
476                return
477
478            if self._signaling_ws:
479                await self._signaling_ws.send_json(
480                    {
481                        "type": "answer",
482                        "sessionId": session_id,
483                        "data": {
484                            "sdp": pc.localDescription.sdp,
485                            "type": pc.localDescription.type,
486                        },
487                    }
488                )
489        except Exception:
490            self.logger.exception("Error handling offer for session %s", session_id)
491            # Clean up the session on error
492            await self._close_session(session_id)
493
494    async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None:
495        """Handle incoming ICE candidate.
496
497        :param session_id: The session ID.
498        :param candidate: The ICE candidate data.
499        """
500        session = self.sessions.get(session_id)
501        if not session or not candidate:
502            return
503
504        pc = session.peer_connection
505        if pc.connectionState in ("closed", "failed"):
506            return
507
508        candidate_str = candidate.get("candidate")
509        sdp_mid = candidate.get("sdpMid")
510        sdp_mline_index = candidate.get("sdpMLineIndex")
511
512        if not candidate_str:
513            return
514
515        try:
516            # Parse ICE candidate - browser sends "candidate:..." format
517            if candidate_str.startswith("candidate:"):
518                sdp_candidate_str = candidate_str[len("candidate:") :]
519            else:
520                sdp_candidate_str = candidate_str
521
522            ice_candidate = candidate_from_sdp(sdp_candidate_str)
523            ice_candidate.sdpMid = str(sdp_mid) if sdp_mid else None
524            ice_candidate.sdpMLineIndex = (
525                int(sdp_mline_index) if sdp_mline_index is not None else None
526            )
527
528            if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
529                return
530
531            await session.peer_connection.addIceCandidate(ice_candidate)
532        except Exception:
533            self.logger.exception("Failed to add ICE candidate for session %s", session_id)
534
535    async def _setup_data_channel(self, session: WebRTCSession) -> None:
536        """Set up data channel and bridge to local WebSocket.
537
538        :param session: The WebRTC session.
539        """
540        channel = session.data_channel
541        if not channel:
542            return
543        try:
544            # Include session_id in URL so server can track WebRTC sessions
545            ws_url = f"{self.local_ws_url}?webrtc_session_id={session.session_id}"
546            session.local_ws = await self.http_session.ws_connect(ws_url)
547            loop = asyncio.get_event_loop()
548
549            # Store task references for proper cleanup
550            session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session))
551            session.forward_from_local_task = asyncio.create_task(self._forward_from_local(session))
552
553            @channel.on("message")  # type: ignore[untyped-decorator]
554            def on_message(message: str) -> None:
555                # Called from aiortc thread, use call_soon_threadsafe
556                # Only queue message if session is still active
557                if session.forward_to_local_task and not session.forward_to_local_task.done():
558                    loop.call_soon_threadsafe(session.message_queue.put_nowait, message)
559
560            @channel.on("close")  # type: ignore[untyped-decorator]
561            def on_close() -> None:
562                # Called from aiortc thread, use call_soon_threadsafe to schedule task
563                asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop)
564
565        except Exception:
566            self.logger.exception("Failed to connect to local WebSocket")
567
568    async def _forward_to_local(self, session: WebRTCSession) -> None:
569        """Forward messages from WebRTC DataChannel to local WebSocket.
570
571        :param session: The WebRTC session.
572        """
573        try:
574            while session.local_ws and not session.local_ws.closed:
575                message = await session.message_queue.get()
576
577                # Check if this is an HTTP proxy request
578                try:
579                    msg_data = json.loads(message)
580                    if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request":
581                        # Handle HTTP proxy request
582                        await self._handle_http_proxy_request(session, msg_data)
583                        continue
584                except (json.JSONDecodeError, ValueError):
585                    pass
586
587                # Regular WebSocket message
588                if session.local_ws and not session.local_ws.closed:
589                    await session.local_ws.send_str(message)
590        except asyncio.CancelledError:
591            # Task was cancelled during cleanup, this is expected
592            self.logger.debug("Forward to local task cancelled for session %s", session.session_id)
593            raise
594        except Exception:
595            self.logger.exception("Error forwarding to local WebSocket")
596
597    async def _forward_from_local(self, session: WebRTCSession) -> None:
598        """Forward messages from local WebSocket to WebRTC DataChannel.
599
600        :param session: The WebRTC session.
601        """
602        try:
603            async for msg in session.local_ws:
604                if msg.type == aiohttp.WSMsgType.TEXT:
605                    if session.data_channel and session.data_channel.readyState == "open":
606                        session.data_channel.send(msg.data)
607                elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
608                    break
609        except asyncio.CancelledError:
610            # Task was cancelled during cleanup, this is expected
611            self.logger.debug(
612                "Forward from local task cancelled for session %s", session.session_id
613            )
614            raise
615        except Exception:
616            self.logger.exception("Error forwarding from local WebSocket")
617
618    async def _handle_http_proxy_request(
619        self, session: WebRTCSession, request_data: dict[str, Any]
620    ) -> None:
621        """Handle HTTP proxy request from remote client.
622
623        :param session: The WebRTC session.
624        :param request_data: The HTTP proxy request data.
625        """
626        request_id = request_data.get("id")
627        method = request_data.get("method", "GET")
628        path = request_data.get("path", "/")
629        headers = request_data.get("headers", {})
630
631        # Build local HTTP URL from the WebSocket URL.
632        # Handle both ws:// and wss:// schemes.
633        parsed = urlparse(self.local_ws_url)
634        http_scheme = "https" if parsed.scheme == "wss" else "http"
635        local_http_url = f"{http_scheme}://{parsed.netloc}{path}"
636
637        self.logger.debug("HTTP proxy request: %s %s", method, local_http_url)
638
639        try:
640            # Use shared HTTP session for this request
641            async with self.http_session.request(
642                method, local_http_url, headers=headers
643            ) as response:
644                # Read response body
645                body = await response.read()
646
647                # Prepare response data
648                response_data = {
649                    "type": "http-proxy-response",
650                    "id": request_id,
651                    "status": response.status,
652                    "headers": dict(response.headers),
653                    "body": body.hex(),  # Send as hex string to avoid encoding issues
654                }
655
656                # Send response back through data channel
657                if session.data_channel and session.data_channel.readyState == "open":
658                    session.data_channel.send(json.dumps(response_data))
659
660        except Exception as err:
661            self.logger.exception("Error handling HTTP proxy request")
662            # Send error response
663            error_response = {
664                "type": "http-proxy-response",
665                "id": request_id,
666                "status": 500,
667                "headers": {"Content-Type": "text/plain"},
668                "body": str(err).encode().hex(),
669            }
670            if session.data_channel and session.data_channel.readyState == "open":
671                session.data_channel.send(json.dumps(error_response))
672
673    async def _close_session(self, session_id: str) -> None:
674        """Close a WebRTC session.
675
676        :param session_id: The session ID.
677        """
678        session = self.sessions.pop(session_id, None)
679        if not session:
680            return
681
682        # Cancel forwarding tasks first to prevent race conditions
683        if session.forward_to_local_task and not session.forward_to_local_task.done():
684            session.forward_to_local_task.cancel()
685            with contextlib.suppress(asyncio.CancelledError):
686                await session.forward_to_local_task
687
688        if session.forward_from_local_task and not session.forward_from_local_task.done():
689            session.forward_from_local_task.cancel()
690            with contextlib.suppress(asyncio.CancelledError):
691                await session.forward_from_local_task
692
693        # Cancel sendspin forwarding tasks
694        if session.sendspin_to_local_task and not session.sendspin_to_local_task.done():
695            session.sendspin_to_local_task.cancel()
696            with contextlib.suppress(asyncio.CancelledError):
697                await session.sendspin_to_local_task
698
699        if session.sendspin_from_local_task and not session.sendspin_from_local_task.done():
700            session.sendspin_from_local_task.cancel()
701            with contextlib.suppress(asyncio.CancelledError):
702                await session.sendspin_from_local_task
703
704        # Close connections
705        if session.local_ws and not session.local_ws.closed:
706            await session.local_ws.close()
707        if session.sendspin_ws and not session.sendspin_ws.closed:
708            await session.sendspin_ws.close()
709        if session.data_channel:
710            session.data_channel.close()
711        if session.sendspin_channel:
712            session.sendspin_channel.close()
713        await session.peer_connection.close()
714
715    async def _setup_sendspin_channel(self, session: WebRTCSession) -> None:
716        """Set up sendspin data channel and bridge to internal sendspin server.
717
718        :param session: The WebRTC session.
719        """
720        channel = session.sendspin_channel
721        if not channel:
722            return
723
724        try:
725            loop = asyncio.get_event_loop()
726
727            @channel.on("message")  # type: ignore[untyped-decorator]
728            def on_message(message: str | bytes) -> None:
729                # Queue if task not yet created (None) or still running.
730                # Only drop when task exists and is done (shutdown).
731                if (
732                    session.sendspin_to_local_task is None
733                    or not session.sendspin_to_local_task.done()
734                ):
735                    loop.call_soon_threadsafe(session.sendspin_queue.put_nowait, message)
736
737            @channel.on("close")  # type: ignore[untyped-decorator]
738            def on_close() -> None:
739                if session.sendspin_ws and not session.sendspin_ws.closed:
740                    asyncio.run_coroutine_threadsafe(session.sendspin_ws.close(), loop)
741
742            session.sendspin_ws = await self.http_session.ws_connect(self.sendspin_url)
743            self.logger.debug("Sendspin channel connected for session %s", session.session_id)
744
745            # Start forwarding tasks - queued messages will be processed
746            session.sendspin_to_local_task = asyncio.create_task(
747                self._forward_sendspin_to_local(session)
748            )
749            session.sendspin_from_local_task = asyncio.create_task(
750                self._forward_sendspin_from_local(session)
751            )
752
753        except Exception:
754            self.logger.exception(
755                "Failed to connect sendspin channel to internal server for session %s",
756                session.session_id,
757            )
758            # Clean up partial state on failure
759            if session.sendspin_to_local_task:
760                session.sendspin_to_local_task.cancel()
761            if session.sendspin_from_local_task:
762                session.sendspin_from_local_task.cancel()
763            if session.sendspin_ws and not session.sendspin_ws.closed:
764                await session.sendspin_ws.close()
765
766    async def _forward_sendspin_to_local(self, session: WebRTCSession) -> None:
767        """Forward messages from sendspin DataChannel to internal sendspin server.
768
769        :param session: The WebRTC session.
770        """
771        first_message = True
772        try:
773            while session.sendspin_ws and not session.sendspin_ws.closed:
774                message = await session.sendspin_queue.get()
775
776                # Check only the first message for client_id extraction
777                if first_message:
778                    first_message = False
779                    if isinstance(message, str):
780                        self._try_extract_sendspin_client_id(session, message)
781
782                if session.sendspin_ws and not session.sendspin_ws.closed:
783                    if isinstance(message, bytes):
784                        await session.sendspin_ws.send_bytes(message)
785                    else:
786                        await session.sendspin_ws.send_str(message)
787        except asyncio.CancelledError:
788            self.logger.debug(
789                "Sendspin forward to local task cancelled for session %s",
790                session.session_id,
791            )
792            raise
793        except Exception:
794            self.logger.exception("Error forwarding sendspin to local")
795
796    def _try_extract_sendspin_client_id(self, session: WebRTCSession, message: str) -> None:
797        """Try to extract client_id from sendspin auth message and set on websocket client.
798
799        :param session: The WebRTC session.
800        :param message: The first sendspin message (expected to be auth).
801        """
802        try:
803            data = json.loads(message)
804            if data.get("type") != "auth":
805                return  # Not an auth message
806
807            # This is an auth message - extract client_id if present
808            if client_id := data.get("client_id"):
809                session.sendspin_player_id = client_id
810                self.logger.debug(
811                    "Extracted sendspin player %s for session %s",
812                    client_id,
813                    session.session_id,
814                )
815                # Use callback to set sendspin player on the websocket client
816                if self._set_sendspin_player_callback:
817                    self._set_sendspin_player_callback(session.session_id, client_id)
818        except (json.JSONDecodeError, TypeError):
819            pass  # Not valid JSON, ignore
820
821    async def _forward_sendspin_from_local(self, session: WebRTCSession) -> None:
822        """Forward messages from internal sendspin server to sendspin DataChannel.
823
824        :param session: The WebRTC session.
825        """
826        if not session.sendspin_ws or session.sendspin_ws.closed:
827            return
828
829        try:
830            async for msg in session.sendspin_ws:
831                if msg.type in {aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY}:
832                    if session.sendspin_channel and session.sendspin_channel.readyState == "open":
833                        session.sendspin_channel.send(msg.data)
834                elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
835                    break
836        except asyncio.CancelledError:
837            self.logger.debug(
838                "Sendspin forward from local task cancelled for session %s",
839                session.session_id,
840            )
841            raise
842        except Exception:
843            self.logger.exception("Error forwarding sendspin from local")
844