/
/
/
1"""SnapCastProvider."""
2
3import asyncio
4import logging
5import re
6import shutil
7import socket
8from pathlib import Path
9from typing import cast
10
11from bidict import bidict
12from music_assistant_models.enums import PlaybackState
13from music_assistant_models.errors import SetupFailedError
14from snapcast.control import create_server
15from snapcast.control.client import Snapclient
16from snapcast.control.group import Snapgroup
17from snapcast.control.server import Snapserver
18from zeroconf import NonUniqueNameException
19from zeroconf.asyncio import AsyncServiceInfo
20
21from music_assistant.helpers.process import AsyncProcess
22from music_assistant.helpers.util import get_ip_pton
23from music_assistant.models.player_provider import PlayerProvider
24from music_assistant.providers.snapcast.constants import (
25 CONF_SERVER_BUFFER_SIZE,
26 CONF_SERVER_CHUNK_MS,
27 CONF_SERVER_CONTROL_PORT,
28 CONF_SERVER_HOST,
29 CONF_SERVER_INITIAL_VOLUME,
30 CONF_SERVER_SEND_AUDIO_TO_MUTED,
31 CONF_SERVER_TRANSPORT_CODEC,
32 CONF_STREAM_IDLE_THRESHOLD,
33 CONF_USE_EXTERNAL_SERVER,
34 CONTROL_SCRIPT,
35 CONTROL_SOCKET_PATH_TEMPLATE,
36 DEFAULT_SNAPSERVER_PORT,
37 SNAPWEB_DIR,
38)
39from music_assistant.providers.snapcast.player import SnapCastPlayer
40from music_assistant.providers.snapcast.socket_server import SnapcastSocketServer
41
42
class SnapCastProvider(PlayerProvider):
    """Music Assistant player provider backed by a Snapcast server.

    Depending on configuration it either spawns and supervises a built-in
    ``snapserver`` process or connects to an external Snapserver. Every
    Snapclient is exposed as a :class:`SnapCastPlayer` with an ``ma_``-prefixed
    player id, and one Unix socket server is kept per queue for use by the
    Snapserver control script.
    """

    _snapserver: Snapserver
    _snapserver_runner: asyncio.Task[None] | None
    _snapserver_started: asyncio.Event | None
    _snapcast_server_host: str
    _snapcast_server_control_port: int
    _ids_map: bidict[str, str]  # ma_id / snapclient_id
    _use_builtin_server: bool
    _stop_called: bool
    _controlscript_available: bool
    _socket_servers: dict[str, SnapcastSocketServer]  # queue_id -> socket server

    async def handle_async_init(self) -> None:
        """Handle async initialization of the provider.

        Reads the provider configuration, starts the built-in Snapserver when
        no external server is configured, and opens the control connection.

        :raises SetupFailedError: If the control connection cannot be opened
            or the built-in Snapserver does not report startup in time.
        """
        # keep the snapcast library's log level in line with this provider
        logging.getLogger("snapcast").setLevel(self.logger.level)
        self._use_builtin_server = not self.config.get_value(CONF_USE_EXTERNAL_SERVER)
        self._stop_called = False
        self._controlscript_available = False
        self._socket_servers = {}
        # initialised up-front so every cleanup path can rely on these attributes
        self._snapserver_runner = None
        self._snapserver_started = None
        if self._use_builtin_server:
            self._snapcast_server_host = "127.0.0.1"
            self._snapcast_server_control_port = DEFAULT_SNAPSERVER_PORT
            self._snapcast_server_buffer_size = self.config.get_value(CONF_SERVER_BUFFER_SIZE)
            self._snapcast_server_chunk_ms = self.config.get_value(CONF_SERVER_CHUNK_MS)
            self._snapcast_server_initial_volume = self.config.get_value(CONF_SERVER_INITIAL_VOLUME)
            self._snapcast_server_send_to_muted = self.config.get_value(
                CONF_SERVER_SEND_AUDIO_TO_MUTED
            )
            self._snapcast_server_transport_codec = self.config.get_value(
                CONF_SERVER_TRANSPORT_CODEC
            )
        else:
            self._snapcast_server_host = str(self.config.get_value(CONF_SERVER_HOST))
            self._snapcast_server_control_port = int(
                str(self.config.get_value(CONF_SERVER_CONTROL_PORT))
            )
        self._snapcast_stream_idle_threshold = self.config.get_value(CONF_STREAM_IDLE_THRESHOLD)
        self._ids_map = bidict({})

        if self._use_builtin_server:
            await self._start_builtin_server()
        try:
            self._snapserver = await create_server(
                self.mass.loop,
                self._snapcast_server_host,
                port=self._snapcast_server_control_port,
                reconnect=True,
            )
        except OSError as err:
            # don't leave an orphaned built-in snapserver process behind
            # (no-op when an external server is used)
            await self._stop_builtin_server()
            msg = "Unable to start the Snapserver connection ?"
            raise SetupFailedError(msg) from err
        self._snapserver.set_on_update_callback(self._handle_update)
        self.logger.info(
            "Started connection to Snapserver %s",
            f"{self._snapcast_server_host}:{self._snapcast_server_control_port}",
        )
        # register callback for when the connection gets lost to the snapserver
        self._snapserver.set_on_disconnect_callback(self._handle_disconnect)

    async def loaded_in_mass(self) -> None:
        """Call after the provider has been loaded."""
        await super().loaded_in_mass()
        # initial load of players
        self._handle_update()

    async def unload(self, is_removed: bool = False) -> None:
        """Handle close/cleanup of the provider.

        Stops players that are still playing, closes all socket servers and the
        control connection, and finally stops the built-in Snapserver (if any).
        Each step is best effort so that one failure does not leave the rest
        of the resources running.
        """
        self._stop_called = True
        # may be missing when handle_async_init failed before connecting
        snapserver: Snapserver | None = getattr(self, "_snapserver", None)

        if snapserver is not None:
            for snap_client in snapserver.clients:
                player_id = self._find_ma_id(snap_client.identifier)
                if player_id is None:
                    continue
                player = self.mass.players.get(player_id, raise_unavailable=False)
                if not player or player.playback_state != PlaybackState.PLAYING:
                    continue
                try:
                    await player.stop()
                except Exception as err:
                    self.logger.warning(
                        "Failed to stop player %s during unload: %s", player_id, err
                    )

        # close socket servers before the snapserver goes away
        # (the built-in runner does the same on cancellation; this also covers
        # the external-server mode, where no runner exists)
        await self._stop_all_socket_servers()
        if snapserver is not None:
            snapserver.stop()
        await self._stop_builtin_server()

    async def _start_builtin_server(self, timeout: float = 10) -> None:
        """Start the built-in Snapserver and wait until it reports startup.

        :param timeout: Seconds to wait for the server to become ready.
        :raises SetupFailedError: If the server does not start within *timeout*.
        """
        if not self._use_builtin_server:
            return
        self._snapserver_started = asyncio.Event()
        self._snapserver_runner = self.mass.create_task(self._builtin_server_runner())
        try:
            await asyncio.wait_for(self._snapserver_started.wait(), timeout)
        except TimeoutError as err:
            # don't leak the (possibly still running) snapserver process
            await self._stop_builtin_server()
            msg = "Built-in Snapserver did not start in time"
            raise SetupFailedError(msg) from err

    async def _stop_builtin_server(self, timeout: float = 10) -> None:
        """Stop the built-in Snapserver, waiting briefly for its cleanup.

        No-op when no built-in server runner was started (external server).

        :param timeout: Seconds to wait for the cancelled runner to finish.
        """
        runner = getattr(self, "_snapserver_runner", None)
        if runner is None:
            return
        self.logger.info("Stopping, built-in Snapserver")
        if not runner.done():
            runner.cancel()
            # let the runner close socket servers and terminate the process;
            # asyncio.wait does not re-raise the task's CancelledError
            await asyncio.wait({runner}, timeout=timeout)
        started = getattr(self, "_snapserver_started", None)
        if started is not None:
            started.clear()

    async def _stop_all_socket_servers(self) -> None:
        """Stop and forget every active per-queue socket server."""
        for queue_id in list(self._socket_servers):
            # pop first so a failing stop() cannot leave a stale entry
            if (socket_server := self._socket_servers.pop(queue_id, None)) is not None:
                await socket_server.stop()

    def _setup_controlscript(self) -> bool:
        """Copy control script to plugin directory (blocking I/O).

        :return: True if successful, False otherwise.
        """
        plugin_dir = Path("/usr/share/snapserver/plug-ins")
        control_dest = plugin_dir / "control.py"
        logger = self.logger.getChild("snapserver")
        try:
            plugin_dir.mkdir(parents=True, exist_ok=True)
            # Clean up existing file
            control_dest.unlink(missing_ok=True)
            if not CONTROL_SCRIPT.exists():
                logger.warning("Control script does not exist: %s", CONTROL_SCRIPT)
                return False
            # Copy the control script to the plugin directory
            shutil.copy2(CONTROL_SCRIPT, control_dest)
            # Ensure it's executable
            control_dest.chmod(0o755)
            logger.debug("Copied controlscript to: %s", control_dest)
            return True
        except OSError as err:  # PermissionError is a subclass of OSError
            logger.warning(
                "Could not copy controlscript (metadata/control disabled): %s",
                err,
            )
            return False

    async def _register_mdns_services(self) -> None:
        """Announce the built-in Snapserver services via zeroconf (best effort).

        Registration errors are logged and otherwise ignored.
        """
        for name, port in (
            ("-http", 1780),
            ("-jsonrpc", 1705),
            ("-stream", 1704),
            ("-tcp", 1705),
            ("", 1704),
        ):
            zeroconf_type = f"_snapcast{name}._tcp.local."
            try:
                info = AsyncServiceInfo(
                    zeroconf_type,
                    name=f"Snapcast.{zeroconf_type}",
                    properties={"is_mass": "true"},
                    addresses=[await get_ip_pton(str(self.mass.streams.publish_ip))],
                    port=port,
                    server=f"{socket.gethostname()}.local",
                )
                # remember per service whether it was registered before,
                # so a restart updates instead of re-registering
                attr_name = f"zc_service_set{name}"
                if getattr(self, attr_name, None):
                    await self.mass.aiozc.async_update_service(info)
                else:
                    await self.mass.aiozc.async_register_service(info, strict=False)
                setattr(self, attr_name, True)
            except NonUniqueNameException:
                self.logger.debug(
                    "Could not register mdns record for %s as its already in use",
                    zeroconf_type,
                )
            except Exception as err:
                self.logger.exception(
                    "Could not register mdns record for %s: %s", zeroconf_type, str(err)
                )

    async def _builtin_server_runner(self) -> None:
        """Start running the builtin snapserver and relay its output to the log.

        Sets ``_snapserver_started`` (2 seconds after the startup banner) and
        copies the control script once the server reports its version.

        :raises RuntimeError: If the server was already marked as started.
        """
        assert self._snapserver_started is not None  # for type checking
        if self._snapserver_started.is_set():
            raise RuntimeError("Snapserver is already started!")
        logger = self.logger.getChild("snapserver")
        logger.info("Starting builtin Snapserver...")
        await self._register_mdns_services()

        args = [
            "snapserver",
            # config settings taken from
            # https://raw.githubusercontent.com/badaix/snapcast/86cd4b2b63e750a72e0dfe6a46d47caf01426c8d/server/etc/snapserver.conf
            f"--server.datadir={self.mass.storage_path}",
            "--http.enabled=true",
            "--http.port=1780",
            f"--http.doc_root={SNAPWEB_DIR}",
            "--tcp.enabled=true",
            f"--tcp.port={self._snapcast_server_control_port}",
            "--stream.sampleformat=48000:16:2",
            f"--stream.buffer={self._snapcast_server_buffer_size}",
            f"--stream.chunk_ms={self._snapcast_server_chunk_ms}",
            f"--stream.codec={self._snapcast_server_transport_codec}",
            f"--stream.send_to_muted={str(self._snapcast_server_send_to_muted).lower()}",
            f"--streaming_client.initial_volume={self._snapcast_server_initial_volume}",
        ]
        startup_seen = False
        async with AsyncProcess(args, stdout=True, name="snapserver") as snapserver_proc:
            try:
                # keep reading from stdout until exit
                async for raw_data in snapserver_proc.iter_any():
                    # output is not guaranteed to be valid UTF-8; never let a
                    # decode error take down the server runner
                    text = raw_data.decode(errors="replace").strip()
                    for line in text.split("\n"):
                        logger.debug(line)
                        # the version banner marks a completed startup
                        # (match any major version, and react only once)
                        if startup_seen or "(Snapserver) Version " not in line:
                            continue
                        startup_seen = True
                        # delay init a small bit to prevent race conditions
                        # where we try to connect too soon
                        self.mass.loop.call_later(2, self._snapserver_started.set)
                        # Copy control script after snapserver starts
                        # (run in executor to avoid blocking)
                        loop = asyncio.get_running_loop()
                        self._controlscript_available = await loop.run_in_executor(
                            None, self._setup_controlscript
                        )
                if not self._stop_called and not self.mass.closing:
                    logger.warning("Snapserver process exited unexpectedly")
            except asyncio.CancelledError:
                # Currently, MA doesn't guarantee a defined shutdown order;
                # Make sure to close socket servers before
                # shutting down the snapcast server.
                #
                # The snapserver doesn't always cleanup the control script processes
                # properly. We do it explicitly when closing a socket server.
                # Should be fixed on the server side, though.
                await self._stop_all_socket_servers()
                raise

    def _find_ma_id(self, snap_client_id: str | None) -> str | None:
        """Return the MA player id registered for a Snapclient id, or None."""
        if not snap_client_id:
            return None
        return self._ids_map.inverse.get(snap_client_id)

    def _get_ma_id(self, snap_client_id: str) -> str:
        """Return the MA player id for a registered Snapclient id.

        :raises KeyError: If the Snapclient id has not been registered.
        """
        ma_id = self._ids_map.inverse.get(snap_client_id)
        if ma_id is None:
            raise KeyError(f"Unknown Snapclient id: {snap_client_id}")
        return ma_id

    def _get_snapclient_id(self, player_id: str) -> str:
        """Return the Snapclient id for a registered MA player id.

        :raises KeyError: If the player id has not been registered.
        """
        snap_id = self._ids_map.get(player_id)
        if snap_id is None:
            raise KeyError(f"Unknown Snapcast player id: {player_id}")
        return snap_id

    def _generate_and_register_id(self, snap_client_id: str) -> str:
        """Return the MA player id for a Snapclient, registering it if new.

        New ids are ``ma_`` followed by the Snapclient id with all non-word
        characters removed.
        """
        if (existing := self._ids_map.inverse.get(snap_client_id)) is not None:
            return existing
        new_id = "ma_" + re.sub(r"\W+", "", snap_client_id)
        self._ids_map[new_id] = snap_client_id
        return new_id

    def _handle_player_init(self, snap_client: Snapclient) -> SnapCastPlayer:
        """Process Snapcast add to Player controller.

        Reuses the already registered player or creates a new one, then
        schedules ``register_or_update`` on the MA event loop.
        """
        player_id = self._generate_and_register_id(snap_client.identifier)
        existing = self.mass.players.get(player_id, raise_unavailable=False)
        if existing:
            player = cast("SnapCastPlayer", existing)  # for type checking
        else:
            player = SnapCastPlayer(
                provider=self,
                player_id=player_id,
                snap_client=snap_client,
                snap_client_id=snap_client.identifier,
            )
            player.setup()
        asyncio.run_coroutine_threadsafe(
            self.mass.players.register_or_update(player), loop=self.mass.loop
        )
        return player

    def _handle_update(self) -> None:
        """Process Snapcast init Player/Group and set callbacks.

        Clients without an identifier are skipped; every other client gets its
        player's update callback attached directly.
        """
        for snap_client in self._snapserver.clients:
            if not snap_client.identifier:
                self.logger.warning(
                    "Detected Snapclient %s without identifier, skipping", snap_client.friendly_name
                )
                continue
            ma_player = self._handle_player_init(snap_client)
            snap_client.set_callback(ma_player._handle_player_update)
        for snap_group in self._snapserver.groups:
            snap_group.set_callback(self._handle_group_update)

    def _handle_group_update(self, snap_group: Snapgroup) -> None:
        """Process Snapcast group callback by refreshing all known players."""
        for snap_client in self._snapserver.clients:
            # clients may not be registered yet (e.g. no identifier, or not
            # yet seen by _handle_update): skip them instead of crashing
            player_id = self._find_ma_id(snap_client.identifier)
            if player_id is None:
                continue
            ma_player = self.mass.players.get(player_id)
            if isinstance(ma_player, SnapCastPlayer):
                ma_player._handle_player_update(snap_client)

    def _handle_disconnect(self, exc: Exception) -> None:
        """Handle disconnect callback from snapserver."""
        if self._stop_called or self.mass.closing:
            # we're instructed to stop/exit, so no need to restart the connection
            return
        self.logger.info(
            "Connection to SnapServer lost, reason: %s. Reloading provider in 5 seconds.",
            str(exc),
        )
        # schedule a reload of the provider
        self.mass.call_later(5, self.mass.load_provider, self.instance_id, allow_retry=True)

    async def remove_player(self, player_id: str) -> None:
        """Remove the client from the snapserver when it is deleted."""
        success, error_msg = await self._snapserver.delete_client(
            self._get_snapclient_id(player_id)
        )
        if success:
            self.logger.debug("Snapclient removed %s", player_id)
        else:
            self.logger.warning("Unable to remove snapclient %s: %s", player_id, error_msg)

    async def get_or_create_socket_server(self, queue_id: str) -> str:
        """Get or create a socket server for the given queue.

        :param queue_id: The queue ID to create a socket server for.
        :return: The path to the Unix socket.
        """
        if (existing := self._socket_servers.get(queue_id)) is not None:
            return existing.socket_path

        socket_path = CONTROL_SOCKET_PATH_TEMPLATE.format(queue_id=queue_id)
        socket_server = SnapcastSocketServer(
            mass=self.mass,
            queue_id=queue_id,
            socket_path=socket_path,
            streamserver_ip=str(self.mass.streams.publish_ip),
            streamserver_port=cast("int", self.mass.streams.publish_port),
        )
        await socket_server.start()
        self._socket_servers[queue_id] = socket_server
        self.logger.debug("Created socket server for queue %s at %s", queue_id, socket_path)
        return socket_path

    async def stop_socket_server(self, queue_id: str) -> None:
        """Stop and remove the socket server for the given queue.

        :param queue_id: The queue ID to stop the socket server for.
        """
        # pop first so a failing stop() cannot leave a stale entry behind
        socket_server = self._socket_servers.pop(queue_id, None)
        if socket_server is None:
            return
        await socket_server.stop()
        self.logger.debug("Stopped socket server for queue %s", queue_id)
379