music-assistant-server

8.5 KBPY
socket_server.py
8.5 KB255 lines • python
1"""Unix socket server for Snapcast control script communication.
2
3This module provides a secure communication channel between the Snapcast control script
4and Music Assistant, avoiding the need to expose the WebSocket API to the control script.
5"""
6
7from __future__ import annotations
8
9import asyncio
10import inspect
11import json
12import logging
13from contextlib import suppress
14from pathlib import Path
15from typing import TYPE_CHECKING, Any
16
17from music_assistant_models.enums import EventType
18
19if TYPE_CHECKING:
20    from music_assistant.mass import MusicAssistant
21
LOGGER = logging.getLogger(__name__)

# Translation table between two loop/repeat vocabularies: "all"/"one"/"off" on the
# key side, "playlist"/"track"/"none" on the value side.
# NOTE(review): keys look like Music Assistant repeat modes and values like the
# control script's loop-status names — confirm against the control script.
LOOP_STATUS_MAP = {
    "all": "playlist",
    "one": "track",
    "off": "none",
}
# Inverse lookup of LOOP_STATUS_MAP (value -> key), built by pairing values with keys.
LOOP_STATUS_MAP_REVERSE = dict(zip(LOOP_STATUS_MAP.values(), LOOP_STATUS_MAP.keys()))
30
31
class SnapcastSocketServer:
    """Unix socket server for a single Snapcast control script connection.

    Each stream gets its own socket server instance to handle control script communication.
    The socket is a local IPC channel without authentication; access is restricted by the
    socket file's permissions (``0o600``, owner only).

    Messages are newline-delimited JSON in both directions. Replies and forwarded
    ``queue_updated`` events are written to the most recently connected client.
    """

    def __init__(
        self,
        mass: MusicAssistant,
        queue_id: str,
        socket_path: str,
        streamserver_ip: str,
        streamserver_port: int,
    ) -> None:
        """Initialize the socket server.

        :param mass: The MusicAssistant instance.
        :param queue_id: The queue ID this socket serves.
        :param socket_path: Path to the Unix socket file.
        :param streamserver_ip: IP address of the stream server (for image proxy).
        :param streamserver_port: Port of the stream server (for image proxy).
        """
        self.mass = mass
        self.queue_id = queue_id
        self.socket_path = socket_path
        self.streamserver_ip = streamserver_ip
        self.streamserver_port = streamserver_port
        self._server: asyncio.AbstractServer | None = None
        self._client_writer: asyncio.StreamWriter | None = None
        self._unsub_callback: Any = None
        # Strong references to in-flight event-forwarding tasks: the event loop only keeps
        # weak references to tasks, so an unreferenced task may vanish mid-execution.
        self._pending_sends: set[asyncio.Task[None]] = set()
        self._logger = LOGGER.getChild(queue_id)

    async def start(self) -> None:
        """Start the Unix socket server and subscribe to updates of this queue."""
        socket_path = Path(self.socket_path)
        # Remove any leftover file at the socket path before binding.
        socket_path.unlink(missing_ok=True)

        self._server = await asyncio.start_unix_server(
            self._handle_client,
            path=self.socket_path,
        )
        # Set permissions so only the current user can access the socket.
        # NOTE(review): between bind and chmod the socket exists with the permissions it
        # was created with (umask-dependent) — a brief window worth confirming is acceptable.
        socket_path.chmod(0o600)
        self._logger.debug("Started Unix socket server at %s", self.socket_path)

        # Subscribe to queue events so they can be forwarded to the control script.
        self._unsub_callback = self.mass.subscribe(
            self._handle_mass_event,
            (EventType.QUEUE_UPDATED,),
            self.queue_id,
        )

    async def stop(self) -> None:
        """Stop the Unix socket server.

        Unsubscribes from events, drops undelivered event forwards, asks the connected
        control script (if any) to shut down, closes the server and removes the socket file.
        """
        if self._unsub_callback:
            self._unsub_callback()
            self._unsub_callback = None

        # Undelivered event forwards are pointless once the client is being shut down.
        for task in list(self._pending_sends):
            task.cancel()
        self._pending_sends.clear()

        # Capture the writer locally: a failed shutdown notification resets
        # self._client_writer to None, which previously crashed the close() call below.
        writer = self._client_writer
        if writer is not None:
            with suppress(Exception):
                await self.notify_shutdown()
            writer.close()
            with suppress(Exception):
                await writer.wait_closed()
            if self._client_writer is writer:
                self._client_writer = None

        if self._server:
            self._server.close()
            await self._server.wait_closed()
            self._server = None

        # Clean up socket file
        Path(self.socket_path).unlink(missing_ok=True)
        self._logger.debug("Stopped Unix socket server")

    async def notify_shutdown(self) -> None:
        """Tell the control script to exit."""
        await self._send_message(
            {
                "event": "shutdown",
                "object_id": self.queue_id,
            }
        )

    async def _handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Serve one control script connection until it disconnects.

        Each received line is treated as one JSON message. Cancellation is propagated
        to the caller after the connection has been cleaned up.

        :param reader: Stream to read newline-delimited JSON requests from.
        :param writer: Stream for this connection; it becomes the active reply target.
        """
        self._logger.debug("Control script connected")
        # The newest connection becomes the target for replies and forwarded events.
        self._client_writer = writer

        try:
            while True:
                line = await reader.readline()
                if not line:
                    break  # EOF: the control script closed its end
                await self._process_line(line)
        except ConnectionResetError:
            self._logger.debug("Control script connection reset")
        except ValueError as err:
            # StreamReader.readline raises ValueError when a line exceeds the buffer limit.
            self._logger.warning("Oversized message from control script: %s", err)
        finally:
            # Only detach if a newer connection has not already replaced this one.
            if self._client_writer is writer:
                self._client_writer = None
            writer.close()
            with suppress(Exception):
                await writer.wait_closed()
            self._logger.debug("Control script disconnected")

    async def _process_line(self, line: bytes) -> None:
        """Decode one received line as a JSON object and dispatch it.

        Blank lines are ignored; undecodable or non-object input is logged and dropped.

        :param line: Raw bytes of one line read from the control script.
        """
        try:
            text = line.decode().strip()
        except UnicodeDecodeError as err:
            self._logger.warning("Invalid UTF-8 from control script: %s", err)
            return
        if not text:
            return

        try:
            message = json.loads(text)
        except json.JSONDecodeError as err:
            self._logger.warning("Invalid JSON from control script: %s", err)
            return

        if not isinstance(message, dict):
            self._logger.warning("Ignoring non-object message from control script: %r", message)
            return

        try:
            await self._handle_message(message)
        except Exception as err:
            # Boundary catch: one bad message must not tear down the connection.
            self._logger.exception("Error handling control script message: %s", err)

    async def _handle_message(self, message: dict[str, Any]) -> None:
        """Handle a message from the control script.

        Expects ``command`` and optionally ``message_id`` and ``args`` (a JSON object;
        ``null`` or missing means no arguments). Replies with a result or an error.

        :param message: The JSON message from the control script.
        """
        msg_id = message.get("message_id")
        command = message.get("command")
        args = message.get("args")
        if args is None:
            args = {}

        if not command:
            await self._send_error(msg_id, "Missing command")
            return
        if not isinstance(args, dict):
            await self._send_error(msg_id, "Invalid args: expected a JSON object")
            return

        try:
            result = await self._execute_command(command, args)
            await self._send_result(msg_id, result)
        except Exception as err:
            self._logger.exception("Error executing command %s: %s", command, err)
            await self._send_error(msg_id, str(err))

    async def _execute_command(self, command: str, args: dict[str, Any]) -> Any:
        """Execute a Music Assistant API command.

        :param command: The API command to execute.
        :param args: The arguments for the command, passed as keyword arguments.
        :return: The result of the command; awaitables are awaited and async
            generators are collected into a list.
        :raises ValueError: If the command is not registered.
        """
        handler = self.mass.command_handlers.get(command)
        if handler is None:
            raise ValueError(f"Unknown command: {command}")

        result = handler.target(**args)
        if inspect.isawaitable(result):
            result = await result
        if inspect.isasyncgen(result):
            # An async generator cannot be JSON-encoded; materialize it.
            result = [item async for item in result]
        return result

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """Recursively convert model objects into JSON-serializable structures.

        Objects exposing ``to_dict()`` are converted with it; dicts, lists and tuples are
        converted element-wise (tuples become lists); anything else is returned unchanged.

        :param value: The value to convert.
        :return: The converted value.
        """
        if hasattr(value, "to_dict"):
            return value.to_dict()
        if isinstance(value, dict):
            return {key: SnapcastSocketServer._to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [SnapcastSocketServer._to_jsonable(item) for item in value]
        return value

    async def _send_result(self, msg_id: str | None, result: Any) -> None:
        """Send a success result to the control script.

        :param msg_id: The message ID from the request.
        :param result: The result data; omitted from the response when ``None``.
        """
        if self._client_writer is None:
            return

        response: dict[str, Any] = {"message_id": msg_id}
        if result is not None:
            response["result"] = self._to_jsonable(result)

        await self._send_message(response)

    async def _send_error(self, msg_id: str | None, error: str) -> None:
        """Send an error result to the control script.

        :param msg_id: The message ID from the request.
        :param error: The error message.
        """
        if self._client_writer is None:
            return

        response = {
            "message_id": msg_id,
            "error": error,
        }
        await self._send_message(response)

    async def _send_message(self, message: dict[str, Any]) -> None:
        """Send one newline-terminated JSON message to the control script.

        Does nothing when no client is connected. On a connection error the writer is
        detached, unless a newer connection has replaced it meanwhile.

        :param message: The message to send; must be JSON-serializable.
        :raises TypeError: If the message cannot be JSON-encoded.
        """
        writer = self._client_writer
        if writer is None:
            return

        data = json.dumps(message) + "\n"
        try:
            writer.write(data.encode())
            await writer.drain()
        except ConnectionError:
            self._logger.debug("Failed to send message - connection closed")
            if self._client_writer is writer:
                self._client_writer = None

    def _on_send_done(self, task: asyncio.Task[None]) -> None:
        """Release a finished event-forwarding task and log any failure it raised.

        :param task: The completed forwarding task.
        """
        self._pending_sends.discard(task)
        if task.cancelled():
            return
        err = task.exception()
        if err is not None:
            self._logger.warning("Failed to forward event to control script: %s", err)

    def _handle_mass_event(self, event: Any) -> None:
        """Handle Music Assistant events and forward to control script.

        Only ``QUEUE_UPDATED`` events for this server's queue are forwarded. The send is
        scheduled as a task because this callback is synchronous.

        :param event: The Music Assistant event.
        """
        if self._client_writer is None:
            return
        if event.event != EventType.QUEUE_UPDATED or event.object_id != self.queue_id:
            return

        event_msg = {
            "event": "queue_updated",
            "object_id": event.object_id,
            "data": self._to_jsonable(event.data),
        }
        # Keep a reference until the task finishes (see self._pending_sends).
        task = asyncio.create_task(self._send_message(event_msg))
        self._pending_sends.add(task)
        task.add_done_callback(self._on_send_done)