/
/
/
1"""SnapCastProvider."""
2
3import asyncio
4import logging
5import re
6import shutil
7import socket
8from pathlib import Path
9from typing import cast
10
11from bidict import bidict
12from music_assistant_models.enums import PlaybackState
13from music_assistant_models.errors import SetupFailedError
14from snapcast.control import create_server
15from snapcast.control.client import Snapclient
16from snapcast.control.group import Snapgroup
17from snapcast.control.server import Snapserver
18from zeroconf import NonUniqueNameException
19from zeroconf.asyncio import AsyncServiceInfo
20
21from music_assistant.helpers.process import AsyncProcess
22from music_assistant.helpers.util import get_ip_pton
23from music_assistant.models.player_provider import PlayerProvider
24from music_assistant.providers.snapcast.constants import (
25 CONF_SERVER_BUFFER_SIZE,
26 CONF_SERVER_CHUNK_MS,
27 CONF_SERVER_CONTROL_PORT,
28 CONF_SERVER_HOST,
29 CONF_SERVER_INITIAL_VOLUME,
30 CONF_SERVER_SEND_AUDIO_TO_MUTED,
31 CONF_SERVER_TRANSPORT_CODEC,
32 CONF_STREAM_IDLE_THRESHOLD,
33 CONF_USE_EXTERNAL_SERVER,
34 CONTROL_SCRIPT,
35 CONTROL_SOCKET_PATH_TEMPLATE,
36 DEFAULT_SNAPSERVER_PORT,
37 SNAPWEB_DIR,
38)
39from music_assistant.providers.snapcast.player import SnapCastPlayer
40from music_assistant.providers.snapcast.socket_server import SnapcastSocketServer
41
42
class SnapCastProvider(PlayerProvider):
    """Player provider exposing Snapcast clients as Music Assistant players.

    Connects to a Snapserver over its control port - either a builtin
    ``snapserver`` process started (and advertised via mDNS) by this provider,
    or an external server configured by the user - and keeps one
    SnapCastPlayer per Snapclient in sync with the server state.
    """

    _snapserver: Snapserver
    _snapserver_runner: asyncio.Task[None] | None
    _snapserver_started: asyncio.Event | None
    _snapcast_server_host: str
    _snapcast_server_control_port: int
    _ids_map: bidict[str, str]  # ma_id / snapclient_id
    _use_builtin_server: bool
    _stop_called: bool
    _controlscript_available: bool
    _socket_servers: dict[str, SnapcastSocketServer]  # queue_id -> socket server
    _zc_services: dict[str, AsyncServiceInfo]  # zeroconf type -> registered mDNS service

    async def handle_async_init(self) -> None:
        """Read the config, start the builtin server if needed and connect to it.

        :raises SetupFailedError: if the builtin Snapserver does not come up in
            time or the connection to the Snapserver cannot be opened.
        """
        # set snapcast logging
        logging.getLogger("snapcast").setLevel(self.logger.level)
        self._use_builtin_server = not self.config.get_value(CONF_USE_EXTERNAL_SERVER)
        self._stop_called = False
        self._controlscript_available = False
        self._socket_servers = {}
        # initialize builtin-server state up front, so cleanup is safe even when
        # an external server is used or setup fails half-way
        self._snapserver_runner = None
        self._snapserver_started = None
        self._zc_services = {}
        if self._use_builtin_server:
            self._snapcast_server_host = "127.0.0.1"
            self._snapcast_server_control_port = DEFAULT_SNAPSERVER_PORT
            self._snapcast_server_buffer_size = self.config.get_value(CONF_SERVER_BUFFER_SIZE)
            self._snapcast_server_chunk_ms = self.config.get_value(CONF_SERVER_CHUNK_MS)
            self._snapcast_server_initial_volume = self.config.get_value(CONF_SERVER_INITIAL_VOLUME)
            self._snapcast_server_send_to_muted = self.config.get_value(
                CONF_SERVER_SEND_AUDIO_TO_MUTED
            )
            self._snapcast_server_transport_codec = self.config.get_value(
                CONF_SERVER_TRANSPORT_CODEC
            )
        else:
            self._snapcast_server_host = str(self.config.get_value(CONF_SERVER_HOST))
            self._snapcast_server_control_port = int(
                str(self.config.get_value(CONF_SERVER_CONTROL_PORT))
            )
        self._snapcast_stream_idle_threshold = self.config.get_value(CONF_STREAM_IDLE_THRESHOLD)
        self._ids_map = bidict({})

        # no-op when an external server is configured
        await self._start_builtin_server()
        try:
            self._snapserver = await create_server(
                self.mass.loop,
                self._snapcast_server_host,
                port=self._snapcast_server_control_port,
                reconnect=True,
            )
        except OSError as err:
            # don't leave an orphaned builtin snapserver behind when setup fails
            await self._stop_builtin_server()
            msg = "Unable to start the Snapserver connection ?"
            raise SetupFailedError(msg) from err
        self._snapserver.set_on_update_callback(self._handle_update)
        self.logger.info(
            "Started connection to Snapserver %s",
            f"{self._snapcast_server_host}:{self._snapcast_server_control_port}",
        )
        # register callback for when the connection gets lost to the snapserver
        self._snapserver.set_on_disconnect_callback(self._handle_disconnect)

    async def loaded_in_mass(self) -> None:
        """Call after the provider has been loaded."""
        await super().loaded_in_mass()
        # initial load of players
        self._handle_update()

    async def unload(self, is_removed: bool = False) -> None:
        """Handle close/cleanup of the provider.

        Cleanup is best effort: a failure while stopping one socket server or
        player is logged and does not prevent the remaining teardown (including
        stopping the builtin Snapserver).
        """
        self._stop_called = True
        # Stop all socket servers
        socket_servers = list(self._socket_servers.items())
        self._socket_servers.clear()
        for queue_id, socket_server in socket_servers:
            try:
                await socket_server.stop()
            except Exception as err:
                self.logger.warning(
                    "Error while stopping socket server for queue %s: %s", queue_id, str(err)
                )
        # the connection may not exist when setup failed before it was created
        snapserver: Snapserver | None = getattr(self, "_snapserver", None)
        if snapserver is not None:
            for snap_client in snapserver.clients:
                player_id = self._find_ma_id(snap_client.identifier)
                if player_id is None:
                    continue
                player = self.mass.players.get(player_id, raise_unavailable=False)
                if not player or player.playback_state != PlaybackState.PLAYING:
                    continue
                try:
                    await player.stop()
                except Exception as err:
                    self.logger.warning("Error while stopping player %s: %s", player_id, str(err))
            snapserver.stop()
        await self._stop_builtin_server()

    async def _start_builtin_server(self) -> None:
        """Start the built-in Snapserver and wait (max 10s) until it is ready.

        Does nothing when an external server is configured.

        :raises SetupFailedError: if the server does not report readiness in time;
            the runner is stopped before raising.
        """
        if not self._use_builtin_server:
            return
        started = asyncio.Event()
        self._snapserver_started = started
        self._snapserver_runner = self.mass.create_task(self._builtin_server_runner())
        try:
            await asyncio.wait_for(started.wait(), 10)
        except TimeoutError as err:
            await self._stop_builtin_server()
            msg = "Timeout while waiting for the builtin Snapserver to start"
            raise SetupFailedError(msg) from err

    async def _stop_builtin_server(self) -> None:
        """Stop the built-in Snapserver and withdraw its mDNS advertisements.

        Safe to call when no builtin server is running (external server, failed
        setup or repeated calls).
        """
        await self._unregister_mdns_services()
        runner = self._snapserver_runner
        if runner is not None and not runner.done():
            self.logger.info("Stopping, built-in Snapserver")
            runner.cancel()
            # wait (bounded) for the runner to unwind its `async with AsyncProcess`
            # block, so a provider reload does not race the old process.
            # NOTE(review): assumes AsyncProcess terminates the process on exit.
            await asyncio.wait({runner}, timeout=10)
        if self._snapserver_started is not None:
            self._snapserver_started.clear()

    def _setup_controlscript(self) -> bool:
        """Copy control script to plugin directory (blocking I/O).

        Any existing ``control.py`` in the plugin directory is removed first.

        :return: True if successful, False otherwise.
        """
        plugin_dir = Path("/usr/share/snapserver/plug-ins")
        control_dest = plugin_dir / "control.py"
        logger = self.logger.getChild("snapserver")
        try:
            plugin_dir.mkdir(parents=True, exist_ok=True)
            # Clean up existing file
            control_dest.unlink(missing_ok=True)
            if not CONTROL_SCRIPT.exists():
                logger.warning("Control script does not exist: %s", CONTROL_SCRIPT)
                return False
            # Copy the control script to the plugin directory
            shutil.copy2(CONTROL_SCRIPT, control_dest)
            # Ensure it's executable
            control_dest.chmod(0o755)
            logger.debug("Copied controlscript to: %s", control_dest)
            return True
        except OSError as err:  # includes PermissionError
            logger.warning(
                "Could not copy controlscript (metadata/control disabled): %s",
                err,
            )
            return False

    async def _register_mdns_services(self) -> None:
        """Advertise the builtin Snapserver services via mDNS (best effort).

        Services already registered by this instance are updated instead of
        re-registered; failures are logged and skipped.
        """
        control_port = self._snapcast_server_control_port
        for name, port in (
            ("-http", 1780),
            ("-jsonrpc", control_port),
            ("-stream", 1704),
            ("-tcp", control_port),
            ("", 1704),
        ):
            zeroconf_type = f"_snapcast{name}._tcp.local."
            try:
                info = AsyncServiceInfo(
                    zeroconf_type,
                    name=f"Snapcast.{zeroconf_type}",
                    properties={"is_mass": "true"},
                    addresses=[await get_ip_pton(str(self.mass.streams.publish_ip))],
                    port=port,
                    server=f"{socket.gethostname()}.local",
                )
                if zeroconf_type in self._zc_services:
                    await self.mass.aiozc.async_update_service(info)
                else:
                    await self.mass.aiozc.async_register_service(info, strict=False)
                self._zc_services[zeroconf_type] = info
            except NonUniqueNameException:
                self.logger.debug(
                    "Could not register mdns record for %s as its already in use",
                    zeroconf_type,
                )
            except Exception as err:
                self.logger.exception(
                    "Could not register mdns record for %s: %s", zeroconf_type, str(err)
                )

    async def _unregister_mdns_services(self) -> None:
        """Withdraw the mDNS services registered by this instance (best effort)."""
        for zeroconf_type, info in list(self._zc_services.items()):
            try:
                await self.mass.aiozc.async_unregister_service(info)
            except Exception as err:
                self.logger.debug(
                    "Could not unregister mdns record for %s: %s", zeroconf_type, str(err)
                )
        self._zc_services.clear()

    async def _builtin_server_runner(self) -> None:
        """Run the builtin snapserver process until it exits or the task is cancelled.

        Registers the mDNS services, starts ``snapserver`` and forwards its
        stdout to the debug log. When the version banner appears,
        ``_snapserver_started`` is set 2 seconds later and the control script
        is installed.

        :raises RuntimeError: if the server is already marked as started.
        """
        assert self._snapserver_started is not None  # for type checking
        started_event = self._snapserver_started
        if started_event.is_set():
            raise RuntimeError("Snapserver is already started!")
        logger = self.logger.getChild("snapserver")
        logger.info("Starting builtin Snapserver...")
        # register the snapcast mdns services
        await self._register_mdns_services()

        args = [
            "snapserver",
            # config settings taken from
            # https://raw.githubusercontent.com/badaix/snapcast/86cd4b2b63e750a72e0dfe6a46d47caf01426c8d/server/etc/snapserver.conf
            f"--server.datadir={self.mass.storage_path}",
            "--http.enabled=true",
            "--http.port=1780",
            f"--http.doc_root={SNAPWEB_DIR}",
            "--tcp.enabled=true",
            f"--tcp.port={self._snapcast_server_control_port}",
            "--stream.sampleformat=48000:16:2",
            f"--stream.buffer={self._snapcast_server_buffer_size}",
            f"--stream.chunk_ms={self._snapcast_server_chunk_ms}",
            f"--stream.codec={self._snapcast_server_transport_codec}",
            f"--stream.send_to_muted={str(self._snapcast_server_send_to_muted).lower()}",
            f"--streaming_client.initial_volume={self._snapcast_server_initial_volume}",
        ]
        banner_seen = False
        pending = b""
        async with AsyncProcess(args, stdout=True, name="snapserver") as snapserver_proc:
            # keep reading from stdout until exit. Output arrives in arbitrary
            # chunks, so buffer raw bytes and only handle complete lines: this keeps
            # the version banner (and multi-byte characters) from being split.
            async for raw_data in snapserver_proc.iter_any():
                pending += raw_data
                *complete_lines, pending = pending.split(b"\n")
                for raw_line in complete_lines:
                    line = raw_line.decode(errors="replace").strip()
                    if not line:
                        continue
                    logger.debug(line)
                    if banner_seen or "(Snapserver) Version " not in line:
                        continue
                    banner_seen = True
                    # delay init a small bit to prevent race conditions
                    # where we try to connect too soon
                    self.mass.loop.call_later(2, started_event.set)
                    # Copy control script after snapserver starts
                    # (run in executor to avoid blocking)
                    loop = asyncio.get_running_loop()
                    self._controlscript_available = await loop.run_in_executor(
                        None, self._setup_controlscript
                    )
        if pending.strip():
            # flush a final line that was not newline-terminated
            logger.debug(pending.decode(errors="replace").strip())
        if not self._stop_called:
            logger.warning("Builtin Snapserver process exited unexpectedly")

    def _find_ma_id(self, snap_client_id: str) -> str | None:
        """Return the MA player id mapped to a Snapclient id, or None if unknown."""
        return self._ids_map.inverse.get(snap_client_id)

    def _get_ma_id(self, snap_client_id: str) -> str:
        """Return the MA player id mapped to a Snapclient id (asserts it is known)."""
        ma_id = self._find_ma_id(snap_client_id)
        assert ma_id is not None  # for type checking
        return ma_id

    def _get_snapclient_id(self, player_id: str) -> str:
        """Return the Snapclient id mapped to an MA player id (asserts it is known)."""
        snap_id = self._ids_map.get(player_id)
        assert snap_id is not None  # for type checking
        return snap_id

    def _generate_and_register_id(self, snap_client_id: str) -> str:
        """Return the MA player id for a Snapclient id, creating the mapping if new.

        New ids are ``"ma_"`` followed by the Snapclient id with all non-word
        characters removed.
        """
        if (existing := self._find_ma_id(snap_client_id)) is not None:
            return existing
        new_id = "ma_" + re.sub(r"\W+", "", snap_client_id)
        self._ids_map[new_id] = snap_client_id
        return new_id

    def _handle_player_init(self, snap_client: Snapclient) -> SnapCastPlayer:
        """Create or look up the MA player for a Snapclient and return it.

        A (re)registration with the player controller is scheduled on the MA
        event loop; this method does not wait for it to complete.
        """
        player_id = self._generate_and_register_id(snap_client.identifier)
        player = self.mass.players.get(player_id, raise_unavailable=False)
        if not player:
            player = SnapCastPlayer(
                provider=self,
                player_id=player_id,
                snap_client=snap_client,
                snap_client_id=snap_client.identifier,
            )
            player.setup()
        else:
            player = cast("SnapCastPlayer", player)  # for type checking
        asyncio.run_coroutine_threadsafe(
            self.mass.players.register_or_update(player), loop=self.mass.loop
        )
        return player

    def _handle_update(self) -> None:
        """Sync all Snapclients to MA players and (re)attach update callbacks."""
        for snap_client in self._snapserver.clients:
            if not snap_client.identifier:
                self.logger.warning(
                    "Detected Snapclient %s without identifier, skipping", snap_client.friendly_name
                )
                continue
            ma_player = self._handle_player_init(snap_client)
            snap_client.set_callback(ma_player._handle_player_update)
        for snap_group in self._snapserver.groups:
            snap_group.set_callback(self._handle_group_update)

    def _handle_group_update(self, snap_group: Snapgroup) -> None:
        """Refresh MA players after a Snapcast group changed.

        Every known client is refreshed, not only the members of ``snap_group``;
        clients without an MA player (yet) are skipped.
        """
        for snap_client in self._snapserver.clients:
            player_id = self._find_ma_id(snap_client.identifier)
            if player_id is None:
                continue
            ma_player = self.mass.players.get(player_id, raise_unavailable=False)
            if isinstance(ma_player, SnapCastPlayer):
                ma_player._handle_player_update(snap_client)

    def _handle_disconnect(self, exc: Exception) -> None:
        """Handle disconnect callback from snapserver by scheduling a provider reload."""
        if self._stop_called or self.mass.closing:
            # we're instructed to stop/exit, so no need to restart the connection
            return
        self.logger.info(
            "Connection to SnapServer lost, reason: %s. Reloading provider in 5 seconds.",
            str(exc),
        )
        # schedule a reload of the provider
        self.mass.call_later(5, self.mass.load_provider, self.instance_id, allow_retry=True)

    async def remove_player(self, player_id: str) -> None:
        """Remove the client from the snapserver when it is deleted.

        Players without a known Snapclient mapping are skipped with a warning.
        """
        snap_client_id = self._ids_map.get(player_id)
        if snap_client_id is None:
            self.logger.warning(
                "Unable to remove snapclient %s: %s", player_id, "unknown snapclient"
            )
            return
        success, error_msg = await self._snapserver.delete_client(snap_client_id)
        if success:
            self.logger.debug("Snapclient removed %s", player_id)
        else:
            self.logger.warning("Unable to remove snapclient %s: %s", player_id, error_msg)

    async def get_or_create_socket_server(self, queue_id: str) -> str:
        """Get or create a socket server for the given queue.

        :param queue_id: The queue ID to create a socket server for.
        :return: The path to the Unix socket.
        """
        if (existing := self._socket_servers.get(queue_id)) is not None:
            return existing.socket_path

        socket_path = CONTROL_SOCKET_PATH_TEMPLATE.format(queue_id=queue_id)
        socket_server = SnapcastSocketServer(
            mass=self.mass,
            queue_id=queue_id,
            socket_path=socket_path,
            streamserver_ip=str(self.mass.streams.publish_ip),
            streamserver_port=cast("int", self.mass.streams.publish_port),
        )
        await socket_server.start()
        self._socket_servers[queue_id] = socket_server
        self.logger.debug("Created socket server for queue %s at %s", queue_id, socket_path)
        return socket_path

    async def stop_socket_server(self, queue_id: str) -> None:
        """Stop and remove the socket server for the given queue.

        The server is removed from the registry before stopping, so a failing
        ``stop()`` never leaves a broken server behind for reuse.

        :param queue_id: The queue ID to stop the socket server for.
        """
        socket_server = self._socket_servers.pop(queue_id, None)
        if socket_server is None:
            return
        await socket_server.stop()
        self.logger.debug("Stopped socket server for queue %s", queue_id)
369