/
/
/
1"""Unix socket server for Snapcast control script communication.
2
3This module provides a secure communication channel between the Snapcast control script
4and Music Assistant, avoiding the need to expose the WebSocket API to the control script.
5"""
6
7from __future__ import annotations
8
9import asyncio
10import inspect
11import json
12import logging
13from contextlib import suppress
14from pathlib import Path
15from typing import TYPE_CHECKING, Any
16
17from music_assistant_models.enums import EventType
18
19if TYPE_CHECKING:
20 from music_assistant.mass import MusicAssistant
21
# Module-wide logger; SnapcastSocketServer derives a per-queue child logger from it.
LOGGER = logging.getLogger(__name__)

# Repeat-mode name ("all" / "one" / "off") -> loop-status name ("playlist" / "track" / "none").
# NOTE(review): the exact producer/consumer of each vocabulary is not visible in this
# module — confirm against the callers before extending the table.
LOOP_STATUS_MAP = dict(
    all="playlist",
    one="track",
    off="none",
)
# Inverse lookup (loop-status name -> repeat-mode name), built by swapping each pair.
LOOP_STATUS_MAP_REVERSE = dict(zip(LOOP_STATUS_MAP.values(), LOOP_STATUS_MAP.keys()))
30
31
class SnapcastSocketServer:
    """Unix socket server for a single Snapcast control script connection.

    Each stream gets its own socket server instance to handle control script communication.
    The wire format is newline-delimited JSON in both directions: the script sends
    ``{"message_id", "command", "args"}`` requests and receives ``{"message_id", "result"}``
    or ``{"message_id", "error"}`` replies, plus unsolicited ``queue_updated`` and
    ``shutdown`` event messages. No authentication is performed; access control relies on
    the socket file being restricted to the owning user (mode 0o600, set in ``start``).
    """

    def __init__(
        self,
        mass: MusicAssistant,
        queue_id: str,
        socket_path: str,
        streamserver_ip: str,
        streamserver_port: int,
    ) -> None:
        """Initialize the socket server.

        :param mass: The MusicAssistant instance.
        :param queue_id: The queue ID this socket serves.
        :param socket_path: Path to the Unix socket file.
        :param streamserver_ip: IP address of the stream server (for image proxy).
        :param streamserver_port: Port of the stream server (for image proxy).
        """
        self.mass = mass
        self.queue_id = queue_id
        self.socket_path = socket_path
        self.streamserver_ip = streamserver_ip
        self.streamserver_port = streamserver_port
        self._server: asyncio.AbstractServer | None = None
        # Writer of the currently connected control script (None when nobody is connected).
        self._client_writer: asyncio.StreamWriter | None = None
        # Unsubscribe function returned by mass.subscribe(); invoked by stop().
        self._unsub_callback: Any = None
        # Strong references to in-flight event-forwarding tasks. The event loop only keeps
        # weak references to tasks, so an unreferenced task may vanish before it finishes.
        self._background_tasks: set[asyncio.Task[None]] = set()
        self._logger = LOGGER.getChild(queue_id)

    async def start(self) -> None:
        """Start the Unix socket server and subscribe to updates of this queue."""
        socket_path = Path(self.socket_path)
        # Remove any leftover file at the socket path before binding.
        socket_path.unlink(missing_ok=True)

        self._server = await asyncio.start_unix_server(
            self._handle_client,
            path=self.socket_path,
        )
        # The socket file is created with the process umask; tighten it so only the
        # current user can connect.
        socket_path.chmod(0o600)
        self._logger.debug("Started Unix socket server at %s", self.socket_path)

        self._unsub_callback = self.mass.subscribe(
            self._handle_mass_event,
            (EventType.QUEUE_UPDATED,),
            self.queue_id,
        )

    async def stop(self) -> None:
        """Stop the server: unsubscribe, notify and close the client, remove the socket."""
        if self._unsub_callback:
            self._unsub_callback()
            self._unsub_callback = None

        # Snapshot the writer: the awaits below yield to _send_message/_handle_client,
        # either of which may reset self._client_writer to None in the meantime.
        writer = self._client_writer
        if writer is not None:
            with suppress(Exception):
                await self.notify_shutdown()
            writer.close()
            with suppress(Exception):
                await writer.wait_closed()
            if self._client_writer is writer:
                self._client_writer = None

        if self._server:
            self._server.close()
            await self._server.wait_closed()
            self._server = None

        # Clean up socket file
        Path(self.socket_path).unlink(missing_ok=True)
        self._logger.debug("Stopped Unix socket server")

    async def notify_shutdown(self) -> None:
        """Tell the control script to exit (no-op when no script is connected)."""
        await self._send_message(
            {
                "event": "shutdown",
                "object_id": self.queue_id,
            }
        )

    async def _handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Serve one control script connection until it disconnects.

        Each received line is parsed as JSON and dispatched to ``_handle_message``.
        The most recent connection becomes the target for replies and events.
        Cancellation is propagated after cleanup rather than swallowed.
        """
        self._logger.debug("Control script connected")
        self._client_writer = writer

        try:
            while True:
                line = await reader.readline()
                if not line:
                    break  # EOF: the script closed its end.

                try:
                    message = json.loads(line.decode().strip())
                except (json.JSONDecodeError, UnicodeDecodeError) as err:
                    self._logger.warning("Invalid JSON from control script: %s", err)
                    continue

                try:
                    await self._handle_message(message)
                except Exception as err:
                    self._logger.exception("Error handling control script message: %s", err)
        except ConnectionError:
            self._logger.debug("Control script connection reset")
        finally:
            # A newer connection may have replaced this one; don't clobber its writer.
            if self._client_writer is writer:
                self._client_writer = None
            writer.close()
            with suppress(Exception):
                await writer.wait_closed()
            self._logger.debug("Control script disconnected")

    async def _handle_message(self, message: Any) -> None:
        """Handle a message from the control script and send back a reply.

        :param message: The decoded JSON message; expected to be an object with
            ``command``, optional ``message_id`` and optional ``args`` (an object).
        """
        # The payload comes straight from json.loads, so its shape is not guaranteed.
        if not isinstance(message, dict):
            await self._send_error(None, "Invalid message: expected a JSON object")
            return

        msg_id = message.get("message_id")
        command = message.get("command")
        args = message.get("args")
        if args is None:
            # Treat a missing args key and an explicit null alike.
            args = {}

        if not command:
            await self._send_error(msg_id, "Missing command")
            return
        if not isinstance(args, dict):
            await self._send_error(msg_id, "Invalid args: expected a JSON object")
            return

        try:
            result = await self._execute_command(command, args)
            await self._send_result(msg_id, result)
        except Exception as err:
            self._logger.exception("Error executing command %s: %s", command, err)
            await self._send_error(msg_id, str(err))

    async def _execute_command(self, command: str, args: dict[str, Any]) -> Any:
        """Execute a Music Assistant API command.

        NOTE(review): there is no allow-list here — the connected script may invoke any
        handler registered in ``mass.command_handlers``, not only ones for this queue.

        :param command: The API command to execute.
        :param args: Keyword arguments passed to the command handler.
        :return: The result of the command (awaited first if the handler returns an awaitable).
        :raises ValueError: If the command is not registered.
        """
        handler = self.mass.command_handlers.get(command)
        if handler is None:
            raise ValueError(f"Unknown command: {command}")

        result = handler.target(**args)
        if inspect.isawaitable(result):
            result = await result
        return result

    async def _send_result(self, msg_id: str | None, result: Any) -> None:
        """Send a success result to the control script.

        The ``result`` key is omitted when the command returned None; objects that
        provide ``to_dict()`` are serialized through it.

        :param msg_id: The message ID from the request.
        :param result: The result data.
        """
        if not self._client_writer:
            return

        response: dict[str, Any] = {"message_id": msg_id}
        if result is not None:
            response["result"] = result.to_dict() if hasattr(result, "to_dict") else result

        await self._send_message(response)

    async def _send_error(self, msg_id: str | None, error: str) -> None:
        """Send an error result to the control script.

        :param msg_id: The message ID from the request (None if it could not be read).
        :param error: The error message.
        """
        if not self._client_writer:
            return

        await self._send_message({"message_id": msg_id, "error": error})

    async def _send_message(self, message: dict[str, Any]) -> None:
        """Send one JSON line to the control script, if one is connected.

        Serialization errors (e.g. a non-JSON-serializable value) propagate to the caller.
        A lost connection is logged and the writer forgotten.

        :param message: The message to send.
        """
        writer = self._client_writer
        if writer is None:
            return

        data = (json.dumps(message) + "\n").encode()
        try:
            writer.write(data)
            await writer.drain()
        except ConnectionError:
            self._logger.debug("Failed to send message - connection closed")
            # Only forget the writer if it is still current; a new client may have
            # connected while drain() was pending.
            if self._client_writer is writer:
                self._client_writer = None

    def _handle_mass_event(self, event: Any) -> None:
        """Forward this queue's QUEUE_UPDATED events to the control script.

        Must be called from the event loop thread, since it schedules the send with
        ``asyncio.create_task``.

        :param event: The Music Assistant event.
        """
        if self._client_writer is None:
            return
        if event.event != EventType.QUEUE_UPDATED or event.object_id != self.queue_id:
            return

        data = event.data
        event_msg = {
            "event": "queue_updated",
            "object_id": event.object_id,
            "data": data.to_dict() if hasattr(data, "to_dict") else data,
        }
        task = asyncio.create_task(self._send_message(event_msg))
        # Keep the task alive until it finishes; the callback drops the reference.
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task: asyncio.Task[None]) -> None:
        """Release a finished forwarding task and log any failure it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        err = task.exception()
        if err is not None:
            self._logger.warning("Failed to forward event to control script: %s", err)
255