/
/
/
1"""Unix socket server for Snapcast control script communication.
2
3This module provides a secure communication channel between the Snapcast control script
4and Music Assistant, avoiding the need to expose the WebSocket API to the control script.
5"""
6
7from __future__ import annotations
8
9import asyncio
10import inspect
11import json
12import logging
13from contextlib import suppress
14from pathlib import Path
15from typing import TYPE_CHECKING, Any
16
17from music_assistant_models.enums import EventType
18
19if TYPE_CHECKING:
20 from music_assistant.mass import MusicAssistant
21
LOGGER = logging.getLogger(__name__)

# NOTE(review): presumably maps Music Assistant repeat-mode values to the loopStatus
# values used by the Snapcast control script ("none" / "track" / "playlist") — confirm.
LOOP_STATUS_MAP = {
    "all": "playlist",
    "one": "track",
    "off": "none",
}
# Inverse lookup of LOOP_STATUS_MAP (loopStatus value -> repeat-mode value).
LOOP_STATUS_MAP_REVERSE = dict(zip(LOOP_STATUS_MAP.values(), LOOP_STATUS_MAP.keys()))
30
31
class SnapcastSocketServer:
    """Unix socket server for a single Snapcast control script connection.

    Each stream gets its own socket server instance to handle control script communication.
    The socket provides an IPC channel that doesn't require authentication: it is a local
    Unix socket whose file is restricted to the current user (mode 0o600).

    Protocol: newline-delimited JSON. Each request line is an object with a ``command``,
    optional ``args`` (object) and optional ``message_id``; each response line echoes the
    ``message_id`` and carries either ``result`` or ``error``. ``queue_updated`` events for
    the served queue are pushed to the connected client as they occur.
    """

    def __init__(
        self,
        mass: MusicAssistant,
        queue_id: str,
        socket_path: str,
        streamserver_ip: str,
        streamserver_port: int,
    ) -> None:
        """Initialize the socket server.

        :param mass: The MusicAssistant instance.
        :param queue_id: The queue ID this socket serves.
        :param socket_path: Path to the Unix socket file.
        :param streamserver_ip: IP address of the stream server (for image proxy).
        :param streamserver_port: Port of the stream server (for image proxy).
        """
        self.mass = mass
        self.queue_id = queue_id
        self.socket_path = socket_path
        self.streamserver_ip = streamserver_ip
        self.streamserver_port = streamserver_port
        self._server: asyncio.AbstractServer | None = None
        # Writer of the currently served control script connection (at most one).
        self._client_writer: asyncio.StreamWriter | None = None
        self._unsub_callback: Any = None
        # Strong references to fire-and-forget event sends: the event loop only keeps
        # weak references to tasks, so an unreferenced task may vanish mid-execution.
        self._background_tasks: set[asyncio.Task[None]] = set()
        self._logger = LOGGER.getChild(queue_id)

    async def start(self) -> None:
        """Start the Unix socket server and subscribe to this queue's update events."""
        socket_path = Path(self.socket_path)
        # Remove any stale file left at the socket path from a previous run.
        socket_path.unlink(missing_ok=True)

        self._server = await asyncio.start_unix_server(
            self._handle_client,
            path=self.socket_path,
        )
        # Restrict access to the current user.
        # NOTE(review): between bind and chmod the socket briefly has umask-derived
        # permissions; a small window remains where other users could connect.
        socket_path.chmod(0o600)
        self._logger.debug("Started Unix socket server at %s", self.socket_path)

        self._unsub_callback = self.mass.subscribe(
            self._handle_mass_event,
            (EventType.QUEUE_UPDATED,),
            self.queue_id,
        )

    async def stop(self) -> None:
        """Stop the server: unsubscribe, drop pending sends, disconnect, remove the socket."""
        if self._unsub_callback:
            self._unsub_callback()
            self._unsub_callback = None

        # Pending event sends are pointless once the client is being disconnected.
        for task in list(self._background_tasks):
            task.cancel()
        self._background_tasks.clear()

        if self._client_writer:
            self._client_writer.close()
            with suppress(Exception):
                await self._client_writer.wait_closed()
            self._client_writer = None

        if self._server:
            self._server.close()
            await self._server.wait_closed()
            self._server = None

        Path(self.socket_path).unlink(missing_ok=True)
        self._logger.debug("Stopped Unix socket server")

    async def _handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Serve one control script connection until it disconnects.

        Only one control script is served at a time: a new connection closes and
        replaces any previous one, so responses and events go to a single client.
        Cancellation is not swallowed; cleanup runs in ``finally`` and it propagates.
        """
        self._logger.debug("Control script connected")
        previous = self._client_writer
        if previous is not None and previous is not writer:
            self._logger.debug("Closing previous control script connection")
            previous.close()
        self._client_writer = writer

        try:
            while True:
                try:
                    line = await reader.readline()
                except ValueError as err:
                    # readline() raises ValueError when a line exceeds the stream limit.
                    self._logger.warning("Oversized message from control script: %s", err)
                    break
                if not line:
                    break  # EOF: the peer closed the connection
                if not line.strip():
                    continue  # blank line, nothing to parse

                try:
                    message = json.loads(line.decode())
                except (json.JSONDecodeError, UnicodeDecodeError) as err:
                    self._logger.warning("Invalid JSON from control script: %s", err)
                    continue
                try:
                    await self._handle_message(message)
                except Exception as err:
                    self._logger.exception("Error handling control script message: %s", err)
        except ConnectionResetError:
            self._logger.debug("Control script connection reset")
        finally:
            # Don't clobber the writer of a newer connection that replaced this one.
            if self._client_writer is writer:
                self._client_writer = None
            writer.close()
            with suppress(Exception):
                await writer.wait_closed()
            self._logger.debug("Control script disconnected")

    async def _handle_message(self, message: Any) -> None:
        """Validate one decoded request, execute it and send the response.

        :param message: The decoded JSON value sent by the control script; anything other
            than a JSON object is answered with an error.
        """
        if not isinstance(message, dict):
            await self._send_error(None, "Invalid message: expected a JSON object")
            return

        msg_id = message.get("message_id")
        command = message.get("command")
        args = message.get("args")
        if args is None:
            # Treat a missing or null "args" as "no arguments".
            args = {}

        if not command:
            await self._send_error(msg_id, "Missing command")
            return
        if not isinstance(args, dict):
            await self._send_error(msg_id, "Invalid args: expected a JSON object")
            return

        try:
            result = await self._execute_command(command, args)
            await self._send_result(msg_id, result)
        except Exception as err:
            self._logger.exception("Error executing command %s: %s", command, err)
            await self._send_error(msg_id, str(err))

    async def _execute_command(self, command: str, args: dict[str, Any]) -> Any:
        """Execute a Music Assistant API command.

        :param command: The API command to execute.
        :param args: Keyword arguments passed to the command handler's target.
        :return: The result of the command (awaited if the target returned an awaitable).
        :raises ValueError: If no handler is registered for ``command``.
        """
        handler = self.mass.command_handlers.get(command)
        if handler is None:
            raise ValueError(f"Unknown command: {command}")

        result = handler.target(**args)
        # Support sync and async targets alike: await coroutines, futures and tasks.
        if inspect.isawaitable(result):
            result = await result
        return result

    async def _send_result(self, msg_id: str | None, result: Any) -> None:
        """Send a success result to the control script.

        :param msg_id: The message ID from the request.
        :param result: The result data; omitted from the response when ``None``.
        """
        if not self._client_writer:
            return

        response: dict[str, Any] = {"message_id": msg_id}
        if result is not None:
            response["result"] = result.to_dict() if hasattr(result, "to_dict") else result

        await self._send_message(response)

    async def _send_error(self, msg_id: str | None, error: str) -> None:
        """Send an error result to the control script.

        :param msg_id: The message ID from the request.
        :param error: The error message.
        """
        if not self._client_writer:
            return

        response = {
            "message_id": msg_id,
            "error": error,
        }
        await self._send_message(response)

    async def _send_message(self, message: dict[str, Any]) -> None:
        """Serialize ``message`` as one JSON line and write it to the control script.

        :param message: The message to send.
        :raises TypeError: If the message contains values that cannot be serialized
            (propagated so request handling can report it as an error).
        """
        writer = self._client_writer
        if writer is None:
            return

        data = json.dumps(message, default=self._json_default) + "\n"
        try:
            writer.write(data.encode())
            await writer.drain()
        except (ConnectionResetError, BrokenPipeError):
            self._logger.debug("Failed to send message - connection closed")
            # A newer connection may have replaced this writer during drain().
            if self._client_writer is writer:
                self._client_writer = None

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Serialize objects ``json`` can't handle natively.

        Objects exposing ``to_dict()`` are converted (this also covers models nested in
        lists or dicts); anything else raises ``TypeError`` like the default encoder.
        """
        if hasattr(obj, "to_dict"):
            return obj.to_dict()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _handle_mass_event(self, event: Any) -> None:
        """Forward ``QUEUE_UPDATED`` events for this queue to the connected control script.

        :param event: The Music Assistant event.
        """
        if not self._client_writer:
            return
        if event.event != EventType.QUEUE_UPDATED or event.object_id != self.queue_id:
            return

        data = event.data.to_dict() if hasattr(event.data, "to_dict") else event.data
        event_msg = {
            "event": "queue_updated",
            "object_id": event.object_id,
            "data": data,
        }
        # This callback is synchronous, so schedule the send; keep a reference to the task.
        task = asyncio.create_task(self._send_message(event_msg))
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task: asyncio.Task[None]) -> None:
        """Release a finished event-send task and log any failure it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        err = task.exception()
        if err is not None:
            self._logger.warning("Failed to forward event to control script: %s", err)
242 # Schedule the send in the event loop
243 asyncio.create_task(self._send_message(event_msg))
244