/
/
/
1"""Unix socket server for Snapcast control script communication.
2
3This module provides a secure communication channel between the Snapcast control script
4and Music Assistant, avoiding the need to expose the WebSocket API to the control script.
5"""
6
7from __future__ import annotations
8
9import asyncio
10import json
11import logging
12from contextlib import suppress
13from pathlib import Path
14from typing import TYPE_CHECKING, Any
15
16from music_assistant_models.enums import EventType
17
18if TYPE_CHECKING:
19 from music_assistant.mass import MusicAssistant
20
LOGGER = logging.getLogger(__name__)

# Music Assistant repeat-mode strings (keys) mapped to the loop-status strings
# exchanged with the control script (values). NOTE(review): values look like
# Snapcast/MPRIS-style `loopStatus` names — confirm against the control script.
LOOP_STATUS_MAP: dict[str, str] = {
    "all": "playlist",
    "one": "track",
    "off": "none",
}
# Inverse lookup: control-script loop status -> Music Assistant repeat mode.
LOOP_STATUS_MAP_REVERSE: dict[str, str] = dict(
    zip(LOOP_STATUS_MAP.values(), LOOP_STATUS_MAP.keys())
)
29
30
class SnapcastSocketServer:
    """Unix socket server for a single Snapcast control script connection.

    Each stream gets its own socket server instance to handle control script
    communication. The socket is an IPC channel that doesn't require
    authentication: the socket file is restricted to the current user
    (mode 0o600), so only local processes running as that user can connect.

    Wire protocol is newline-delimited JSON in both directions:

    * requests: ``{"message_id": ..., "command": "<api command>", "args": {...}}``
    * responses: ``{"message_id": ..., "result": ...}`` (``result`` omitted when
      the command returned ``None``) or ``{"message_id": ..., "error": "..."}``
    * events: ``{"event": "queue_updated", "object_id": ..., "data": ...}``

    Only one control script connection is served at a time; a new connection
    replaces (and closes) the previous one.
    """

    def __init__(
        self,
        mass: MusicAssistant,
        queue_id: str,
        socket_path: str,
        streamserver_ip: str,
        streamserver_port: int,
    ) -> None:
        """Initialize the socket server.

        :param mass: The MusicAssistant instance.
        :param queue_id: The queue ID this socket serves.
        :param socket_path: Path to the Unix socket file.
        :param streamserver_ip: IP address of the stream server (for image proxy).
        :param streamserver_port: Port of the stream server (for image proxy).
        """
        self.mass = mass
        self.queue_id = queue_id
        self.socket_path = socket_path
        self.streamserver_ip = streamserver_ip
        self.streamserver_port = streamserver_port
        self._server: asyncio.AbstractServer | None = None
        # Writer of the currently active control script connection, if any.
        self._client_writer: asyncio.StreamWriter | None = None
        self._unsub_callback: Any = None
        # Strong references to in-flight event-forwarding tasks: the event loop
        # only keeps weak references to tasks, so an unreferenced task may be
        # garbage collected before it finishes.
        self._pending_tasks: set[asyncio.Task[None]] = set()
        self._logger = LOGGER.getChild(queue_id)

    async def start(self) -> None:
        """Start the Unix socket server and subscribe to queue updates."""
        socket_path = Path(self.socket_path)
        # Remove a stale socket file left behind by a previous run.
        socket_path.unlink(missing_ok=True)

        self._server = await asyncio.start_unix_server(
            self._handle_client,
            path=self.socket_path,
        )
        # Restrict access to the current user. NOTE: between bind and this
        # chmod the socket briefly exists with umask-derived permissions.
        socket_path.chmod(0o600)
        self._logger.debug("Started Unix socket server at %s", self.socket_path)

        # Subscribe to queue events for this queue only.
        self._unsub_callback = self.mass.subscribe(
            self._handle_mass_event,
            (EventType.QUEUE_UPDATED,),
            self.queue_id,
        )

    async def stop(self) -> None:
        """Stop the Unix socket server and release all resources."""
        if self._unsub_callback:
            self._unsub_callback()
            self._unsub_callback = None

        # Abort event sends that have not completed yet. Done-callbacks run
        # later via the loop, so the set is not mutated while iterating.
        for task in self._pending_tasks:
            task.cancel()
        self._pending_tasks.clear()

        # Stop accepting new connections before closing the active one, so no
        # client can slip in between.
        server, self._server = self._server, None
        if server:
            server.close()

        writer, self._client_writer = self._client_writer, None
        if writer:
            writer.close()
            with suppress(Exception):
                await writer.wait_closed()

        if server:
            await server.wait_closed()

        # Clean up socket file
        Path(self.socket_path).unlink(missing_ok=True)
        self._logger.debug("Stopped Unix socket server")

    async def _handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Serve one control script connection until it disconnects.

        Reads newline-delimited JSON requests and dispatches each one to
        :meth:`_handle_message`; malformed lines are logged and skipped.

        :param reader: Stream to read requests from.
        :param writer: Stream to write responses and events to.
        """
        if self._client_writer is not None:
            # Only a single control script connection is served at a time.
            self._logger.debug("Replacing existing control script connection")
            self._client_writer.close()
        self._logger.debug("Control script connected")
        self._client_writer = writer

        try:
            while line := await reader.readline():
                if not line.strip():
                    continue  # ignore blank lines
                try:
                    # json.loads accepts bytes; ValueError covers both
                    # JSONDecodeError and UnicodeDecodeError.
                    message = json.loads(line)
                except ValueError as err:
                    self._logger.warning("Invalid JSON from control script: %s", err)
                    continue
                try:
                    await self._handle_message(message, writer)
                except Exception as err:
                    self._logger.exception("Error handling control script message: %s", err)
        except ConnectionError:
            self._logger.debug("Control script connection reset")
        except ValueError as err:
            # StreamReader.readline() raises ValueError when a line exceeds the
            # stream buffer limit; the stream position is then unreliable.
            self._logger.warning("Control script message too long, disconnecting: %s", err)
        finally:
            # Don't clobber a newer connection that replaced this one.
            if self._client_writer is writer:
                self._client_writer = None
            writer.close()
            with suppress(Exception):
                await writer.wait_closed()
            self._logger.debug("Control script disconnected")

    async def _handle_message(
        self, message: Any, writer: asyncio.StreamWriter | None = None
    ) -> None:
        """Handle a decoded message from the control script.

        Validates the request shape, executes the command and sends either a
        result or an error reply back on ``writer``.

        :param message: The decoded JSON message (expected to be an object).
        :param writer: Connection to reply on; defaults to the active connection.
        """
        if not isinstance(message, dict):
            await self._send_error(None, "Invalid message: expected a JSON object", writer)
            return

        msg_id = message.get("message_id")
        command = message.get("command")
        args = message.get("args")
        if args is None:
            # Treat a missing or explicit null "args" as "no arguments".
            args = {}

        if not command:
            await self._send_error(msg_id, "Missing command", writer)
            return
        if not isinstance(args, dict):
            await self._send_error(msg_id, "Invalid args: expected a JSON object", writer)
            return

        try:
            result = await self._execute_command(command, args)
            await self._send_result(msg_id, result, writer)
        except Exception as err:
            self._logger.exception("Error executing command %s: %s", command, err)
            await self._send_error(msg_id, str(err), writer)

    async def _execute_command(self, command: str, args: dict[str, Any]) -> Any:
        """Execute a Music Assistant API command.

        :param command: The API command to execute.
        :param args: Keyword arguments for the command handler.
        :return: The result of the command (awaited if the handler is async).
        :raises ValueError: If the command is not registered.
        """
        handler = self.mass.command_handlers.get(command)
        if handler is None:
            raise ValueError(f"Unknown command: {command}")

        # Handlers may be sync or async; await coroutine results.
        result = handler.target(**args)
        if asyncio.iscoroutine(result):
            result = await result
        return result

    def _resolve_writer(
        self, writer: asyncio.StreamWriter | None
    ) -> asyncio.StreamWriter | None:
        """Return ``writer`` if given, otherwise the active connection's writer."""
        return writer if writer is not None else self._client_writer

    @staticmethod
    def _serialize(value: Any) -> Any:
        """Convert ``value`` into a structure ``json.dumps`` can encode.

        Objects exposing ``to_dict()`` are converted via that method; lists,
        tuples and dicts are converted element by element; anything else is
        returned unchanged.

        :param value: The value to convert.
        :return: The converted value.
        """
        if hasattr(value, "to_dict"):
            return value.to_dict()
        if isinstance(value, dict):
            return {key: SnapcastSocketServer._serialize(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [SnapcastSocketServer._serialize(item) for item in value]
        return value

    async def _send_result(
        self, msg_id: str | None, result: Any, writer: asyncio.StreamWriter | None = None
    ) -> None:
        """Send a success result to the control script.

        :param msg_id: The message ID from the request.
        :param result: The result data; omitted from the reply when ``None``.
        :param writer: Connection to reply on; defaults to the active connection.
        """
        if self._resolve_writer(writer) is None:
            return

        response: dict[str, Any] = {"message_id": msg_id}
        if result is not None:
            response["result"] = self._serialize(result)

        await self._send_message(response, writer)

    async def _send_error(
        self, msg_id: str | None, error: str, writer: asyncio.StreamWriter | None = None
    ) -> None:
        """Send an error result to the control script.

        :param msg_id: The message ID from the request (``None`` if unknown).
        :param error: The error message.
        :param writer: Connection to reply on; defaults to the active connection.
        """
        if self._resolve_writer(writer) is None:
            return

        response = {
            "message_id": msg_id,
            "error": error,
        }
        await self._send_message(response, writer)

    async def _send_message(
        self, message: dict[str, Any], writer: asyncio.StreamWriter | None = None
    ) -> None:
        """Send one newline-terminated JSON message to the control script.

        Connection errors are logged and swallowed; serialization errors
        (``TypeError``/``ValueError`` from ``json.dumps``) propagate to the caller.

        :param message: The message to send.
        :param writer: Connection to send on; defaults to the active connection.
        """
        target = self._resolve_writer(writer)
        if target is None:
            return

        data = (json.dumps(message) + "\n").encode()
        try:
            target.write(data)
            await target.drain()
        except ConnectionError:
            self._logger.debug("Failed to send message - connection closed")
            if target is self._client_writer:
                self._client_writer = None

    async def _send_event(self, event_msg: dict[str, Any]) -> None:
        """Send an event message, logging any failure instead of losing it.

        :param event_msg: The event message to forward.
        """
        try:
            await self._send_message(event_msg)
        except Exception:
            self._logger.exception("Failed to forward event to control script")

    def _handle_mass_event(self, event: Any) -> None:
        """Forward ``queue_updated`` events for this queue to the control script.

        :param event: The Music Assistant event.
        """
        if not self._client_writer:
            return
        if event.event != EventType.QUEUE_UPDATED or event.object_id != self.queue_id:
            return

        event_msg = {
            "event": "queue_updated",
            "object_id": event.object_id,
            "data": self._serialize(event.data),
        }
        # This callback is synchronous, so schedule the write as a task and
        # hold a strong reference to it until it completes.
        task = asyncio.create_task(self._send_event(event_msg))
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)
243