music-assistant-server

35.2 KBPY
gateway.py
35.2 KB843 lines • python
1"""Music Assistant WebRTC Gateway.
2
3This module provides WebRTC-based remote access to Music Assistant instances.
4It connects to a signaling server and handles incoming WebRTC connections,
5bridging them to the local WebSocket API.
6"""
7
8from __future__ import annotations
9
10import asyncio
11import contextlib
12import json
13import logging
14from collections.abc import Awaitable, Callable
15from dataclasses import dataclass, field
16from typing import TYPE_CHECKING, Any
17
18import aiohttp
19from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection, RTCSessionDescription
20from aiortc.sdp import candidate_from_sdp
21
22from music_assistant.constants import MASS_LOGGER_NAME, VERBOSE_LOG_LEVEL
23from music_assistant.helpers.webrtc_certificate import create_peer_connection_with_certificate
24
25if TYPE_CHECKING:
26    from aiortc.rtcdtlstransport import RTCCertificate
27
# Module-level logger for remote access; a ".remote_access" child of the
# logger named by MASS_LOGGER_NAME.
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.remote_access")

# Reduce verbose logging from aiortc/aioice: both library loggers are raised to
# WARNING so their DEBUG/INFO records are dropped.
logging.getLogger("aioice").setLevel(logging.WARNING)
logging.getLogger("aiortc").setLevel(logging.WARNING)
33
34
@dataclass
class WebRTCSession:
    """Represents an active WebRTC session with a remote client.

    Bundles the peer connection with the bridging state of the two data
    channels the gateway handles: the main API channel and the sendspin
    channel. Only ``session_id`` and ``peer_connection`` are required; all
    other fields start empty and are filled in by the gateway as the channels
    are set up.
    """

    # Identifier of this session, as received from the signaling server.
    session_id: str
    # The aiortc peer connection backing this session.
    peer_connection: RTCPeerConnection
    # Main API channel (ma-api) - bridges to local MA WebSocket API
    data_channel: Any = None
    # WebSocket to the local API, opened once the data channel is set up.
    local_ws: Any = None
    # Messages received on the data channel, waiting to be sent to local_ws.
    message_queue: asyncio.Queue[str] = field(default_factory=asyncio.Queue)
    # Task draining message_queue into local_ws.
    forward_to_local_task: asyncio.Task[None] | None = None
    # Task relaying local_ws messages back over the data channel.
    forward_from_local_task: asyncio.Task[None] | None = None
    # Sendspin channel - bridges to internal sendspin server
    sendspin_channel: Any = None
    # WebSocket to the internal sendspin server.
    sendspin_ws: Any = None
    # Text or binary messages received on the sendspin channel, waiting to be forwarded.
    sendspin_queue: asyncio.Queue[str | bytes] = field(default_factory=asyncio.Queue)
    sendspin_player_id: str | None = None  # Extracted from first sendspin auth message
    # Forwarding tasks for the sendspin bridge (channel -> server, server -> channel).
    sendspin_to_local_task: asyncio.Task[None] | None = None
    sendspin_from_local_task: asyncio.Task[None] | None = None
54
55
56class WebRTCGateway:
57    """WebRTC Gateway for Music Assistant Remote Access.
58
59    This gateway:
60    1. Connects to a signaling server
61    2. Registers with a unique Remote ID
62    3. Handles incoming WebRTC connections from remote PWA clients
63    4. Bridges WebRTC DataChannel messages to the local WebSocket API
64    """
65
    # Close code 4000 means this connection was replaced by a new one from the same server
    # In that case, we should not reconnect as another connection is now active
    # (_connect_to_signaling compares the WebSocket close code against this value).
    CLOSE_CODE_REPLACED = 4000

    # Default ICE servers (public STUN only - used as fallback)
    # __init__ uses this list when no (or an empty) ice_servers argument is given.
    DEFAULT_ICE_SERVERS: list[dict[str, Any]] = [
        {"urls": "stun:stun.home-assistant.io:3478"},
        {"urls": "stun:stun.l.google.com:19302"},
        {"urls": "stun:stun1.l.google.com:19302"},
        {"urls": "stun:stun.cloudflare.com:3478"},
    ]
77
78    def __init__(
79        self,
80        http_session: aiohttp.ClientSession,
81        remote_id: str,
82        certificate: RTCCertificate,
83        signaling_url: str = "wss://signaling.music-assistant.io/ws",
84        local_ws_url: str = "ws://localhost:8095/ws",
85        sendspin_url: str = "ws://localhost:8927/sendspin",
86        ice_servers: list[dict[str, Any]] | None = None,
87        ice_servers_callback: Callable[[], Awaitable[list[dict[str, Any]]]] | None = None,
88        set_sendspin_player_callback: Callable[[str, str], None] | None = None,
89    ) -> None:
90        """
91        Initialize the WebRTC Gateway.
92
93        :param http_session: Shared aiohttp ClientSession for HTTP/WebSocket connections.
94        :param remote_id: Remote ID for this server instance.
95        :param certificate: Persistent RTCCertificate for DTLS, enabling client-side pinning.
96        :param signaling_url: WebSocket URL of the signaling server.
97        :param local_ws_url: Local WebSocket URL to bridge to.
98        :param sendspin_url: Internal Sendspin WebSocket URL to bridge to.
99        :param ice_servers: List of ICE server configurations (used at registration time).
100        :param ice_servers_callback: Optional callback to fetch fresh ICE servers for each session.
101        :param set_sendspin_player_callback: Callback to set sendspin player for a session.
102        """
103        self.http_session = http_session
104        self.signaling_url = signaling_url
105        self.local_ws_url = local_ws_url
106        self.sendspin_url = sendspin_url
107        self._remote_id = remote_id
108        self._certificate = certificate
109        self.logger = LOGGER
110        self._ice_servers_callback = ice_servers_callback
111        self._set_sendspin_player_callback = set_sendspin_player_callback
112
113        # Static ICE servers used at registration time (relayed to clients via signaling server)
114        self.ice_servers = ice_servers or self.DEFAULT_ICE_SERVERS
115
116        self.sessions: dict[str, WebRTCSession] = {}
117        self._signaling_ws: aiohttp.ClientWebSocketResponse | None = None
118        self._running = False
119        self._reconnect_delay = 10  # Wait 10 seconds before reconnecting
120        self._max_reconnect_delay = 300  # Max 5 minutes between reconnects
121        self._current_reconnect_delay = 10
122        self._run_task: asyncio.Task[None] | None = None
123        self._is_connected = False
124        self._connecting = False
125
    @property
    def is_running(self) -> bool:
        """Return whether the gateway is running.

        The flag is set by ``start()`` and cleared by ``stop()``; the run loop
        also clears it when it stops reconnecting because the connection was
        replaced.
        """
        return self._running
130
    @property
    def is_connected(self) -> bool:
        """Return whether the gateway is connected to the signaling server.

        Becomes True once the signaling server confirms registration and is
        reset to False whenever the signaling connection ends.
        """
        return self._is_connected
135
136    async def _get_fresh_ice_servers(self) -> list[dict[str, Any]]:
137        """Get fresh ICE servers for a new WebRTC session.
138
139        If an ice_servers_callback was provided, it will be called to get fresh
140        TURN credentials. Otherwise, returns the static ice_servers.
141
142        :return: List of ICE server configurations with fresh credentials.
143        """
144        if self._ice_servers_callback:
145            try:
146                fresh_servers = await self._ice_servers_callback()
147                if fresh_servers:
148                    return fresh_servers
149            except Exception:
150                self.logger.exception("Failed to fetch fresh ICE servers, using cached servers")
151        return self.ice_servers
152
153    async def start(self) -> None:
154        """Start the WebRTC Gateway."""
155        if self._running:
156            self.logger.warning("WebRTC Gateway already running, skipping start")
157            return
158        self.logger.info("Starting WebRTC Gateway")
159        self.logger.debug("Signaling URL: %s", self.signaling_url)
160        self.logger.debug("Local WS URL: %s", self.local_ws_url)
161        self._running = True
162        self._run_task = asyncio.create_task(self._run())
163        self.logger.debug("WebRTC Gateway start task created")
164
165    async def stop(self) -> None:
166        """Stop the WebRTC Gateway."""
167        self.logger.info("Stopping WebRTC Gateway")
168        self._running = False
169
170        # Close all sessions
171        for session_id in list(self.sessions.keys()):
172            await self._close_session(session_id)
173
174        # Close signaling connection gracefully
175        if self._signaling_ws and not self._signaling_ws.closed:
176            try:
177                await self._signaling_ws.close()
178            except Exception:
179                self.logger.debug("Error closing signaling WebSocket", exc_info=True)
180
181        # Cancel run task and wait for it to finish
182        if self._run_task and not self._run_task.done():
183            self._run_task.cancel()
184            with contextlib.suppress(asyncio.CancelledError):
185                await self._run_task
186
187        # Wait briefly for any in-progress connection to notice _running=False
188        if self._connecting:
189            await asyncio.sleep(0.1)
190
191        self._signaling_ws = None
192        self._connecting = False
193
194    async def _run(self) -> None:
195        """Run the main loop with reconnection logic."""
196        self.logger.debug("WebRTC Gateway _run() loop starting")
197        while self._running:
198            should_reconnect = True
199            try:
200                should_reconnect = await self._connect_to_signaling()
201                # Connection closed gracefully or with error
202                self._is_connected = False
203                if self._running and should_reconnect:
204                    self.logger.warning(
205                        "Signaling server connection lost. Reconnecting in %ss...",
206                        self._current_reconnect_delay,
207                    )
208            except Exception:
209                self._is_connected = False
210                self.logger.exception("Signaling connection error")
211                if self._running:
212                    self.logger.info(
213                        "Reconnecting to signaling server in %ss",
214                        self._current_reconnect_delay,
215                    )
216
217            if self._running and should_reconnect:
218                await asyncio.sleep(self._current_reconnect_delay)
219                # Exponential backoff with max limit
220                self._current_reconnect_delay = min(
221                    self._current_reconnect_delay * 2, self._max_reconnect_delay
222                )
223            elif not should_reconnect:
224                # Connection was replaced by another instance, stop the run loop
225                self.logger.info("Connection replaced, stopping reconnection attempts")
226                self._running = False
227                break
228
229    async def _connect_to_signaling(self) -> bool:
230        """Connect to the signaling server.
231
232        :return: True if reconnection should be attempted, False if connection was replaced.
233        """
234        if self._connecting:
235            self.logger.warning("Already connecting to signaling server, skipping")
236            return False  # Don't trigger another reconnect cycle
237        self._connecting = True
238        close_code: int | None = None
239        self.logger.info("Connecting to signaling server: %s", self.signaling_url)
240        try:
241            self._signaling_ws = await self.http_session.ws_connect(
242                self.signaling_url,
243                heartbeat=35.0,  # Send ping every 35s (slightly above server's 30s interval)
244            )
245            # Check if we were stopped while connecting
246            if not self._running:
247                self.logger.debug("Gateway stopped during connection, closing WebSocket")
248                await self._signaling_ws.close()
249                self._signaling_ws = None
250                self._connecting = False
251                return False
252            self.logger.debug("WebSocket connection established, id=%s", id(self._signaling_ws))
253            self.logger.debug("Sending registration")
254            await self._register()
255            self._current_reconnect_delay = self._reconnect_delay
256            self.logger.debug("Registration sent, waiting for confirmation...")
257
258            # Run message loop and get close code
259            close_code = await self._signaling_message_loop(self._signaling_ws)
260
261            # Get close code from WebSocket if not already set from CLOSE message
262            if close_code is None:
263                close_code = self._signaling_ws.close_code
264            ws_exception = self._signaling_ws.exception()
265            self.logger.debug(
266                "Message loop exited - WebSocket closed: %s, close_code: %s, exception: %s",
267                self._signaling_ws.closed,
268                close_code,
269                ws_exception,
270            )
271        except TimeoutError:
272            self.logger.error("Timeout connecting to signaling server")
273        except aiohttp.ClientError as err:
274            self.logger.error("Failed to connect to signaling server: %s", err)
275        except Exception:
276            self.logger.exception("Unexpected error in signaling connection")
277        finally:
278            self._is_connected = False
279            self._connecting = False
280            self._signaling_ws = None
281
282        # Check if this connection was replaced by another one
283        if close_code == self.CLOSE_CODE_REPLACED:
284            self.logger.info("Connection was replaced by another instance - not reconnecting")
285            return False
286
287        return True
288
289    async def _signaling_message_loop(self, ws: aiohttp.ClientWebSocketResponse) -> int | None:
290        """Process messages from the signaling WebSocket.
291
292        :param ws: The WebSocket connection to process messages from.
293        :return: Close code if connection was closed with a code, None otherwise.
294        """
295        close_code: int | None = None
296        self.logger.debug("Entering message loop")
297        async for msg in ws:
298            if msg.type == aiohttp.WSMsgType.TEXT:
299                try:
300                    await self._handle_signaling_message(json.loads(msg.data))
301                except Exception:
302                    self.logger.exception("Error handling signaling message")
303            elif msg.type == aiohttp.WSMsgType.PING:
304                self.logger.log(VERBOSE_LOG_LEVEL, "Received WebSocket PING")
305            elif msg.type == aiohttp.WSMsgType.PONG:
306                self.logger.log(VERBOSE_LOG_LEVEL, "Received WebSocket PONG")
307            elif msg.type == aiohttp.WSMsgType.CLOSE:
308                close_code = msg.data
309                self.logger.warning(
310                    "Signaling server sent close frame: code=%s, reason=%s",
311                    msg.data,
312                    msg.extra,
313                )
314                break
315            elif msg.type == aiohttp.WSMsgType.CLOSED:
316                self.logger.warning("Signaling server closed connection")
317                break
318            elif msg.type == aiohttp.WSMsgType.ERROR:
319                self.logger.error("WebSocket error: %s", ws.exception())
320                break
321            else:
322                self.logger.warning("Unexpected WebSocket message type: %s", msg.type)
323        return close_code
324
325    async def _register(self) -> None:
326        """Register with the signaling server."""
327        if self._signaling_ws:
328            await self._signaling_ws.send_json(
329                {
330                    "type": "register-server",
331                    "remoteId": self._remote_id,
332                    "iceServers": self.ice_servers,
333                }
334            )
335
336    async def _handle_signaling_message(self, message: dict[str, Any]) -> None:
337        """Handle incoming signaling messages.
338
339        :param message: The signaling message.
340        """
341        msg_type = message.get("type")
342
343        if msg_type in ("ping", "pong"):
344            # Ignore JSON-level ping/pong messages - we use WebSocket protocol-level heartbeat
345            # The signaling server still sends these for backward compatibility with older clients
346            pass
347        elif msg_type == "registered":
348            self._is_connected = True
349            self.logger.info("Registered with signaling server")
350        elif msg_type == "error":
351            error_msg = message.get("error") or message.get("message", "Unknown error")
352            self.logger.error("Signaling server error: %s", error_msg)
353        elif msg_type == "client-connected":
354            session_id = message.get("sessionId")
355            if session_id:
356                await self._create_session(session_id)
357                # Send session-ready with fresh ICE servers for the client
358                fresh_ice_servers = await self._get_fresh_ice_servers()
359                if self._signaling_ws:
360                    await self._signaling_ws.send_json(
361                        {
362                            "type": "session-ready",
363                            "sessionId": session_id,
364                            "iceServers": fresh_ice_servers,
365                        }
366                    )
367        elif msg_type == "client-disconnected":
368            session_id = message.get("sessionId")
369            if session_id:
370                await self._close_session(session_id)
371        elif msg_type == "offer":
372            session_id = message.get("sessionId")
373            offer_data = message.get("data")
374            if session_id and offer_data:
375                await self._handle_offer(session_id, offer_data)
376        elif msg_type == "ice-candidate":
377            session_id = message.get("sessionId")
378            candidate_data = message.get("data")
379            if session_id and candidate_data:
380                await self._handle_ice_candidate(session_id, candidate_data)
381
382    async def _create_session(self, session_id: str) -> None:
383        """Create a new WebRTC session.
384
385        :param session_id: The session ID.
386        """
387        session_ice_servers = await self._get_fresh_ice_servers()
388        config = RTCConfiguration(
389            iceServers=[RTCIceServer(**server) for server in session_ice_servers]
390        )
391        pc = create_peer_connection_with_certificate(self._certificate, configuration=config)
392        session = WebRTCSession(session_id=session_id, peer_connection=pc)
393        self.sessions[session_id] = session
394
395        @pc.on("datachannel")
396        def on_datachannel(channel: Any) -> None:
397            if channel.label == "sendspin":
398                session.sendspin_channel = channel
399                asyncio.create_task(self._setup_sendspin_channel(session))
400            else:
401                session.data_channel = channel
402                asyncio.create_task(self._setup_data_channel(session))
403
404        @pc.on("icecandidate")
405        async def on_icecandidate(candidate: Any) -> None:
406            if candidate and self._signaling_ws:
407                await self._signaling_ws.send_json(
408                    {
409                        "type": "ice-candidate",
410                        "sessionId": session_id,
411                        "data": {
412                            "candidate": candidate.candidate,
413                            "sdpMid": candidate.sdpMid,
414                            "sdpMLineIndex": candidate.sdpMLineIndex,
415                        },
416                    }
417                )
418
419        @pc.on("connectionstatechange")
420        async def on_connectionstatechange() -> None:
421            if pc.connectionState == "failed":
422                await self._close_session(session_id)
423
424    async def _handle_offer(self, session_id: str, offer: dict[str, Any]) -> None:
425        """Handle incoming WebRTC offer.
426
427        :param session_id: The session ID.
428        :param offer: The offer data.
429        """
430        session = self.sessions.get(session_id)
431        if not session:
432            return
433        pc = session.peer_connection
434
435        if pc.connectionState in ("closed", "failed"):
436            return
437
438        sdp = offer.get("sdp")
439        sdp_type = offer.get("type")
440        if not sdp or not sdp_type:
441            self.logger.error("Invalid offer data: missing sdp or type")
442            return
443
444        try:
445            await pc.setRemoteDescription(
446                RTCSessionDescription(
447                    sdp=str(sdp),
448                    type=str(sdp_type),
449                )
450            )
451
452            if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
453                return
454
455            answer = await pc.createAnswer()
456
457            if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
458                return
459
460            await pc.setLocalDescription(answer)
461
462            # Wait for ICE gathering to complete before sending the answer
463            # aiortc doesn't support trickle ICE, candidates are embedded in SDP after gathering
464            gather_timeout = 30
465            gather_start = asyncio.get_event_loop().time()
466            while pc.iceGatheringState != "complete":
467                if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
468                    return
469                if asyncio.get_event_loop().time() - gather_start > gather_timeout:
470                    self.logger.warning("Session %s ICE gathering timeout", session_id)
471                    break
472                await asyncio.sleep(0.1)
473
474            if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
475                return
476
477            if self._signaling_ws:
478                await self._signaling_ws.send_json(
479                    {
480                        "type": "answer",
481                        "sessionId": session_id,
482                        "data": {
483                            "sdp": pc.localDescription.sdp,
484                            "type": pc.localDescription.type,
485                        },
486                    }
487                )
488        except Exception:
489            self.logger.exception("Error handling offer for session %s", session_id)
490            # Clean up the session on error
491            await self._close_session(session_id)
492
493    async def _handle_ice_candidate(self, session_id: str, candidate: dict[str, Any]) -> None:
494        """Handle incoming ICE candidate.
495
496        :param session_id: The session ID.
497        :param candidate: The ICE candidate data.
498        """
499        session = self.sessions.get(session_id)
500        if not session or not candidate:
501            return
502
503        pc = session.peer_connection
504        if pc.connectionState in ("closed", "failed"):
505            return
506
507        candidate_str = candidate.get("candidate")
508        sdp_mid = candidate.get("sdpMid")
509        sdp_mline_index = candidate.get("sdpMLineIndex")
510
511        if not candidate_str:
512            return
513
514        try:
515            # Parse ICE candidate - browser sends "candidate:..." format
516            if candidate_str.startswith("candidate:"):
517                sdp_candidate_str = candidate_str[len("candidate:") :]
518            else:
519                sdp_candidate_str = candidate_str
520
521            ice_candidate = candidate_from_sdp(sdp_candidate_str)
522            ice_candidate.sdpMid = str(sdp_mid) if sdp_mid else None
523            ice_candidate.sdpMLineIndex = (
524                int(sdp_mline_index) if sdp_mline_index is not None else None
525            )
526
527            if session_id not in self.sessions or pc.connectionState in ("closed", "failed"):
528                return
529
530            await session.peer_connection.addIceCandidate(ice_candidate)
531        except Exception:
532            self.logger.exception("Failed to add ICE candidate for session %s", session_id)
533
534    async def _setup_data_channel(self, session: WebRTCSession) -> None:
535        """Set up data channel and bridge to local WebSocket.
536
537        :param session: The WebRTC session.
538        """
539        channel = session.data_channel
540        if not channel:
541            return
542        try:
543            # Include session_id in URL so server can track WebRTC sessions
544            ws_url = f"{self.local_ws_url}?webrtc_session_id={session.session_id}"
545            session.local_ws = await self.http_session.ws_connect(ws_url)
546            loop = asyncio.get_event_loop()
547
548            # Store task references for proper cleanup
549            session.forward_to_local_task = asyncio.create_task(self._forward_to_local(session))
550            session.forward_from_local_task = asyncio.create_task(self._forward_from_local(session))
551
552            @channel.on("message")  # type: ignore[untyped-decorator]
553            def on_message(message: str) -> None:
554                # Called from aiortc thread, use call_soon_threadsafe
555                # Only queue message if session is still active
556                if session.forward_to_local_task and not session.forward_to_local_task.done():
557                    loop.call_soon_threadsafe(session.message_queue.put_nowait, message)
558
559            @channel.on("close")  # type: ignore[untyped-decorator]
560            def on_close() -> None:
561                # Called from aiortc thread, use call_soon_threadsafe to schedule task
562                asyncio.run_coroutine_threadsafe(self._close_session(session.session_id), loop)
563
564        except Exception:
565            self.logger.exception("Failed to connect to local WebSocket")
566
567    async def _forward_to_local(self, session: WebRTCSession) -> None:
568        """Forward messages from WebRTC DataChannel to local WebSocket.
569
570        :param session: The WebRTC session.
571        """
572        try:
573            while session.local_ws and not session.local_ws.closed:
574                message = await session.message_queue.get()
575
576                # Check if this is an HTTP proxy request
577                try:
578                    msg_data = json.loads(message)
579                    if isinstance(msg_data, dict) and msg_data.get("type") == "http-proxy-request":
580                        # Handle HTTP proxy request
581                        await self._handle_http_proxy_request(session, msg_data)
582                        continue
583                except (json.JSONDecodeError, ValueError):
584                    pass
585
586                # Regular WebSocket message
587                if session.local_ws and not session.local_ws.closed:
588                    await session.local_ws.send_str(message)
589        except asyncio.CancelledError:
590            # Task was cancelled during cleanup, this is expected
591            self.logger.debug("Forward to local task cancelled for session %s", session.session_id)
592            raise
593        except Exception:
594            self.logger.exception("Error forwarding to local WebSocket")
595
596    async def _forward_from_local(self, session: WebRTCSession) -> None:
597        """Forward messages from local WebSocket to WebRTC DataChannel.
598
599        :param session: The WebRTC session.
600        """
601        try:
602            async for msg in session.local_ws:
603                if msg.type == aiohttp.WSMsgType.TEXT:
604                    if session.data_channel and session.data_channel.readyState == "open":
605                        session.data_channel.send(msg.data)
606                elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
607                    break
608        except asyncio.CancelledError:
609            # Task was cancelled during cleanup, this is expected
610            self.logger.debug(
611                "Forward from local task cancelled for session %s", session.session_id
612            )
613            raise
614        except Exception:
615            self.logger.exception("Error forwarding from local WebSocket")
616
617    async def _handle_http_proxy_request(
618        self, session: WebRTCSession, request_data: dict[str, Any]
619    ) -> None:
620        """Handle HTTP proxy request from remote client.
621
622        :param session: The WebRTC session.
623        :param request_data: The HTTP proxy request data.
624        """
625        request_id = request_data.get("id")
626        method = request_data.get("method", "GET")
627        path = request_data.get("path", "/")
628        headers = request_data.get("headers", {})
629
630        # Build local HTTP URL
631        # Extract host and port from local_ws_url (ws://localhost:8095/ws)
632        ws_url_parts = self.local_ws_url.replace("ws://", "").split("/")
633        host_port = ws_url_parts[0]  # localhost:8095
634        local_http_url = f"http://{host_port}{path}"
635
636        self.logger.debug("HTTP proxy request: %s %s", method, local_http_url)
637
638        try:
639            # Use shared HTTP session for this request
640            async with self.http_session.request(
641                method, local_http_url, headers=headers
642            ) as response:
643                # Read response body
644                body = await response.read()
645
646                # Prepare response data
647                response_data = {
648                    "type": "http-proxy-response",
649                    "id": request_id,
650                    "status": response.status,
651                    "headers": dict(response.headers),
652                    "body": body.hex(),  # Send as hex string to avoid encoding issues
653                }
654
655                # Send response back through data channel
656                if session.data_channel and session.data_channel.readyState == "open":
657                    session.data_channel.send(json.dumps(response_data))
658
659        except Exception as err:
660            self.logger.exception("Error handling HTTP proxy request")
661            # Send error response
662            error_response = {
663                "type": "http-proxy-response",
664                "id": request_id,
665                "status": 500,
666                "headers": {"Content-Type": "text/plain"},
667                "body": str(err).encode().hex(),
668            }
669            if session.data_channel and session.data_channel.readyState == "open":
670                session.data_channel.send(json.dumps(error_response))
671
672    async def _close_session(self, session_id: str) -> None:
673        """Close a WebRTC session.
674
675        :param session_id: The session ID.
676        """
677        session = self.sessions.pop(session_id, None)
678        if not session:
679            return
680
681        # Cancel forwarding tasks first to prevent race conditions
682        if session.forward_to_local_task and not session.forward_to_local_task.done():
683            session.forward_to_local_task.cancel()
684            with contextlib.suppress(asyncio.CancelledError):
685                await session.forward_to_local_task
686
687        if session.forward_from_local_task and not session.forward_from_local_task.done():
688            session.forward_from_local_task.cancel()
689            with contextlib.suppress(asyncio.CancelledError):
690                await session.forward_from_local_task
691
692        # Cancel sendspin forwarding tasks
693        if session.sendspin_to_local_task and not session.sendspin_to_local_task.done():
694            session.sendspin_to_local_task.cancel()
695            with contextlib.suppress(asyncio.CancelledError):
696                await session.sendspin_to_local_task
697
698        if session.sendspin_from_local_task and not session.sendspin_from_local_task.done():
699            session.sendspin_from_local_task.cancel()
700            with contextlib.suppress(asyncio.CancelledError):
701                await session.sendspin_from_local_task
702
703        # Close connections
704        if session.local_ws and not session.local_ws.closed:
705            await session.local_ws.close()
706        if session.sendspin_ws and not session.sendspin_ws.closed:
707            await session.sendspin_ws.close()
708        if session.data_channel:
709            session.data_channel.close()
710        if session.sendspin_channel:
711            session.sendspin_channel.close()
712        await session.peer_connection.close()
713
714    async def _setup_sendspin_channel(self, session: WebRTCSession) -> None:
715        """Set up sendspin data channel and bridge to internal sendspin server.
716
717        :param session: The WebRTC session.
718        """
719        channel = session.sendspin_channel
720        if not channel:
721            return
722
723        try:
724            loop = asyncio.get_event_loop()
725
726            @channel.on("message")  # type: ignore[untyped-decorator]
727            def on_message(message: str | bytes) -> None:
728                # Queue if task not yet created (None) or still running.
729                # Only drop when task exists and is done (shutdown).
730                if (
731                    session.sendspin_to_local_task is None
732                    or not session.sendspin_to_local_task.done()
733                ):
734                    loop.call_soon_threadsafe(session.sendspin_queue.put_nowait, message)
735
736            @channel.on("close")  # type: ignore[untyped-decorator]
737            def on_close() -> None:
738                if session.sendspin_ws and not session.sendspin_ws.closed:
739                    asyncio.run_coroutine_threadsafe(session.sendspin_ws.close(), loop)
740
741            session.sendspin_ws = await self.http_session.ws_connect(self.sendspin_url)
742            self.logger.debug("Sendspin channel connected for session %s", session.session_id)
743
744            # Start forwarding tasks - queued messages will be processed
745            session.sendspin_to_local_task = asyncio.create_task(
746                self._forward_sendspin_to_local(session)
747            )
748            session.sendspin_from_local_task = asyncio.create_task(
749                self._forward_sendspin_from_local(session)
750            )
751
752        except Exception:
753            self.logger.exception(
754                "Failed to connect sendspin channel to internal server for session %s",
755                session.session_id,
756            )
757            # Clean up partial state on failure
758            if session.sendspin_to_local_task:
759                session.sendspin_to_local_task.cancel()
760            if session.sendspin_from_local_task:
761                session.sendspin_from_local_task.cancel()
762            if session.sendspin_ws and not session.sendspin_ws.closed:
763                await session.sendspin_ws.close()
764
765    async def _forward_sendspin_to_local(self, session: WebRTCSession) -> None:
766        """Forward messages from sendspin DataChannel to internal sendspin server.
767
768        :param session: The WebRTC session.
769        """
770        first_message = True
771        try:
772            while session.sendspin_ws and not session.sendspin_ws.closed:
773                message = await session.sendspin_queue.get()
774
775                # Check only the first message for client_id extraction
776                if first_message:
777                    first_message = False
778                    if isinstance(message, str):
779                        self._try_extract_sendspin_client_id(session, message)
780
781                if session.sendspin_ws and not session.sendspin_ws.closed:
782                    if isinstance(message, bytes):
783                        await session.sendspin_ws.send_bytes(message)
784                    else:
785                        await session.sendspin_ws.send_str(message)
786        except asyncio.CancelledError:
787            self.logger.debug(
788                "Sendspin forward to local task cancelled for session %s",
789                session.session_id,
790            )
791            raise
792        except Exception:
793            self.logger.exception("Error forwarding sendspin to local")
794
795    def _try_extract_sendspin_client_id(self, session: WebRTCSession, message: str) -> None:
796        """Try to extract client_id from sendspin auth message and set on websocket client.
797
798        :param session: The WebRTC session.
799        :param message: The first sendspin message (expected to be auth).
800        """
801        try:
802            data = json.loads(message)
803            if data.get("type") != "auth":
804                return  # Not an auth message
805
806            # This is an auth message - extract client_id if present
807            if client_id := data.get("client_id"):
808                session.sendspin_player_id = client_id
809                self.logger.debug(
810                    "Extracted sendspin player %s for session %s",
811                    client_id,
812                    session.session_id,
813                )
814                # Use callback to set sendspin player on the websocket client
815                if self._set_sendspin_player_callback:
816                    self._set_sendspin_player_callback(session.session_id, client_id)
817        except (json.JSONDecodeError, TypeError):
818            pass  # Not valid JSON, ignore
819
820    async def _forward_sendspin_from_local(self, session: WebRTCSession) -> None:
821        """Forward messages from internal sendspin server to sendspin DataChannel.
822
823        :param session: The WebRTC session.
824        """
825        if not session.sendspin_ws or session.sendspin_ws.closed:
826            return
827
828        try:
829            async for msg in session.sendspin_ws:
830                if msg.type in {aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY}:
831                    if session.sendspin_channel and session.sendspin_channel.readyState == "open":
832                        session.sendspin_channel.send(msg.data)
833                elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
834                    break
835        except asyncio.CancelledError:
836            self.logger.debug(
837                "Sendspin forward from local task cancelled for session %s",
838                session.session_id,
839            )
840            raise
841        except Exception:
842            self.logger.exception("Error forwarding sendspin from local")
843