/
/
/
1"""Main Music Assistant class."""
2
3from __future__ import annotations
4
5import asyncio
6import inspect
7import logging
8import os
9import pathlib
10import threading
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
12from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar, cast, overload
13from uuid import uuid4
14
15import aiofiles
16from aiofiles.os import wrap
17from music_assistant_models.api import ServerInfoMessage
18from music_assistant_models.auth import UserRole
19from music_assistant_models.enums import EventType, ProviderType
20from music_assistant_models.errors import MusicAssistantError, SetupFailedError
21from music_assistant_models.event import MassEvent
22from music_assistant_models.helpers import set_global_cache_values
23from music_assistant_models.provider import ProviderManifest
24from zeroconf import (
25 InterfaceChoice,
26 IPVersion,
27 NonUniqueNameException,
28 ServiceStateChange,
29 Zeroconf,
30)
31from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
32
33from music_assistant.constants import (
34 API_SCHEMA_VERSION,
35 CONF_PROVIDERS,
36 CONF_SERVER_ID,
37 CONF_ZEROCONF_INTERFACES,
38 CONFIGURABLE_CORE_CONTROLLERS,
39 MASS_LOGGER_NAME,
40 MIN_SCHEMA_VERSION,
41 VERBOSE_LOG_LEVEL,
42)
43from music_assistant.controllers.cache import CacheController
44from music_assistant.controllers.config import ConfigController
45from music_assistant.controllers.metadata import MetaDataController
46from music_assistant.controllers.music import MusicController
47from music_assistant.controllers.player_queues import PlayerQueuesController
48from music_assistant.controllers.players.player_controller import PlayerController
49from music_assistant.controllers.streams import StreamsController
50from music_assistant.controllers.webserver import WebserverController
51from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
52from music_assistant.helpers.aiohttp_client import create_clientsession
53from music_assistant.helpers.api import APICommandHandler, api_command
54from music_assistant.helpers.images import get_icon_string
55from music_assistant.helpers.util import (
56 TaskManager,
57 get_ip_pton,
58 get_package_version,
59 is_hass_supervisor,
60 load_provider_module,
61)
62from music_assistant.models import ProviderInstanceType
63from music_assistant.models.music_provider import MusicProvider
64from music_assistant.models.player_provider import PlayerProvider
65
66if TYPE_CHECKING:
67 from types import TracebackType
68
69 from aiohttp import ClientSession
70 from music_assistant_models.config_entries import ProviderConfig
71
72 from music_assistant.models.core_controller import CoreController
73
# awaitable versions of blocking os / os.path helpers, created with aiofiles.os.wrap
isdir = wrap(os.path.isdir)
isfile = wrap(os.path.isfile)
mkdirs = wrap(os.makedirs)
rmfile = wrap(os.remove)
listdir = wrap(os.listdir)
rename = wrap(os.rename)

# an event callback is a plain function or a coroutine function receiving the MassEvent
EventCallBackType = Callable[[MassEvent], None] | Callable[[MassEvent], Coroutine[Any, Any, None]]
# (callback, optional event type filter, optional object id filter),
# as stored in MusicAssistant._subscribers by subscribe()
EventSubscriptionType = tuple[
    EventCallBackType, tuple[EventType, ...] | None, tuple[str, ...] | None
]

LOGGER = logging.getLogger(MASS_LOGGER_NAME)

# directory of this module; provider packages are scanned from its "providers" subdirectory
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")

# generic result type used by create_task / call_later
_R = TypeVar("_R")
# generic provider type used by the get_provider overloads
_ProviderT = TypeVar("_ProviderT", bound=ProviderInstanceType)
93
94
def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
    """Return True when *provider* is a music provider (narrows it to MusicProvider)."""
    provider_kind = provider.type
    return provider_kind == ProviderType.MUSIC
98
99
def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
    """Return True when *provider* is a player provider (narrows it to PlayerProvider)."""
    provider_kind = provider.type
    return provider_kind == ProviderType.PLAYER
103
104
class MusicAssistant:
    """Main MusicAssistant (Server) object.

    Owns the core controllers, the loaded providers and their manifests,
    the event subscribers and the tracked tasks/timers.
    """

    # The attributes below are assigned in start() (and _aiobrowser in
    # _setup_discovery), not in __init__, so they only exist once the server is started.
    loop: asyncio.AbstractEventLoop
    aiozc: AsyncZeroconf
    config: ConfigController
    webserver: WebserverController
    cache: CacheController
    metadata: MetaDataController
    music: MusicController
    players: PlayerController
    player_queues: PlayerQueuesController
    streams: StreamsController
    _aiobrowser: AsyncServiceBrowser
119
120 def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None:
121 """Initialize the MusicAssistant Server."""
122 self.storage_path = storage_path
123 self.cache_path = cache_path
124 self.safe_mode = safe_mode
125 # we dynamically register command handlers which can be consumed by the apis
126 self.command_handlers: dict[str, APICommandHandler] = {}
127 self._subscribers: set[EventSubscriptionType] = set()
128 self._provider_manifests: dict[str, ProviderManifest] = {}
129 self._providers: dict[str, ProviderInstanceType] = {}
130 self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
131 self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
132 self.closing = False
133 self.running_as_hass_addon: bool = False
134 self.version: str = "0.0.0"
135 self.dev_mode = (
136 os.environ.get("PYTHONDEVMODE") == "1"
137 or pathlib.Path(__file__).parent.resolve().parent.resolve().joinpath(".venv").exists()
138 )
139 self._http_session: ClientSession | None = None
140 self._http_session_no_ssl: ClientSession | None = None
141 self._mdns_locks: dict[str, asyncio.Lock] = {}
142
    async def start(self) -> None:
        """Start running the Music Assistant server.

        Startup order: config controller, shared zeroconf instance, provider
        manifests, storage setup/migration, core controllers, API command
        registration + webserver, mDNS discovery and finally (unless running
        in safe mode) the configured providers.
        """
        self.loop = asyncio.get_running_loop()
        # remember the loop's thread id for verify_event_loop_thread();
        # NOTE: this reads the private `_thread_id` attribute of the event loop
        self.loop_thread_id = getattr(self.loop, "_thread_id")  # noqa: B009
        self.running_as_hass_addon = await is_hass_supervisor()
        self.version = await get_package_version("music_assistant") or "0.0.0"
        # setup config controller first and fetch important config values
        self.config = ConfigController(self)
        await self.config.setup()
        # create shared zeroconf instance
        # TODO: enumerate interfaces and enable IPv6 support
        # the interface choice is read (raw) from the streams core controller config
        zeroconf_interfaces = self.config.get_raw_core_config_value(
            "streams", CONF_ZEROCONF_INTERFACES, "default"
        )
        self.aiozc = AsyncZeroconf(
            ip_version=IPVersion.V4Only,
            interfaces=InterfaceChoice.All
            if zeroconf_interfaces == "all"
            else InterfaceChoice.Default,
        )
        # load all available providers from manifest files
        await self.__load_provider_manifests()
        # setup/migrate storage
        await self._setup_storage()
        LOGGER.info(
            "Starting Music Assistant Server (%s) version %s - HA add-on: %s - Safe mode: %s",
            self.server_id,
            self.version,
            self.running_as_hass_addon,
            self.safe_mode,
        )
        # setup other core controllers
        self.cache = CacheController(self)
        self.webserver = WebserverController(self)
        self.metadata = MetaDataController(self)
        self.music = MusicController(self)
        self.players = PlayerController(self)
        self.player_queues = PlayerQueuesController(self)
        self.streams = StreamsController(self)
        # add manifests for core controllers (so they are listed with the provider manifests)
        for controller_name in CONFIGURABLE_CORE_CONTROLLERS:
            controller: CoreController = getattr(self, controller_name)
            self._provider_manifests[controller.domain] = controller.manifest
        await self.cache.setup(await self.config.get_core_config("cache"))
        # load streams controller early so we can abort if we can't load it
        await self.streams.setup(await self.config.get_core_config("streams"))
        await self.music.setup(await self.config.get_core_config("music"))
        await self.metadata.setup(await self.config.get_core_config("metadata"))
        await self.players.setup(await self.config.get_core_config("players"))
        await self.player_queues.setup(await self.config.get_core_config("player_queues"))
        # load webserver/api last so the api/frontend is
        # not yet available while we're starting (or performing migrations)
        self._register_api_commands()
        await self.webserver.setup(await self.config.get_core_config("webserver"))
        # setup discovery
        await self._setup_discovery()
        # load providers (skipped entirely in safe mode)
        if not self.safe_mode:
            await self._load_providers()
202
203 async def stop(self) -> None:
204 """Stop running the music assistant server."""
205 LOGGER.info("Stop called, cleaning up...")
206 self.signal_event(EventType.SHUTDOWN)
207 self.closing = True
208 # cancel all running tasks
209 for task in self._tracked_tasks.values():
210 task.cancel()
211 # cleanup all providers
212 await asyncio.gather(
213 *[self.unload_provider(prov_id) for prov_id in list(self._providers.keys())],
214 return_exceptions=True,
215 )
216 # stop core controllers
217 await self.streams.close()
218 await self.webserver.close()
219 await self.metadata.close()
220 await self.music.close()
221 await self.player_queues.close()
222 await self.players.close()
223 # cleanup cache and config
224 await self.config.close()
225 await self.cache.close()
226 # close/cleanup shared http session
227 if self._http_session:
228 self._http_session.detach()
229 if self._http_session.connector:
230 await self._http_session.connector.close()
231 if self._http_session_no_ssl:
232 self._http_session_no_ssl.detach()
233 if self._http_session_no_ssl.connector:
234 await self._http_session_no_ssl.connector.close()
235
236 @property
237 def server_id(self) -> str:
238 """Return unique ID of this server."""
239 if not self.config.initialized:
240 return ""
241 return self.config.get(CONF_SERVER_ID) # type: ignore[no-any-return]
242
243 @property
244 def http_session(self) -> ClientSession:
245 """
246 Return the shared HTTP Client session (with SSL).
247
248 NOTE: May only be called from the event loop.
249 """
250 if self._http_session is None:
251 self._http_session = create_clientsession(self, verify_ssl=True)
252 return self._http_session
253
254 @property
255 def http_session_no_ssl(self) -> ClientSession:
256 """
257 Return the shared HTTP Client session (without SSL).
258
259 NOTE: May only be called from the event loop thread.
260 """
261 if self._http_session_no_ssl is None:
262 self._http_session_no_ssl = create_clientsession(self, verify_ssl=False)
263 return self._http_session_no_ssl
264
265 @api_command("info")
266 def get_server_info(self) -> ServerInfoMessage:
267 """Return Info of this server."""
268 return ServerInfoMessage(
269 server_id=self.server_id,
270 server_version=self.version,
271 schema_version=API_SCHEMA_VERSION,
272 min_supported_schema_version=MIN_SCHEMA_VERSION,
273 base_url=self.webserver.base_url,
274 homeassistant_addon=self.running_as_hass_addon,
275 onboard_done=self.config.onboard_done,
276 )
277
278 @api_command("providers/manifests")
279 def get_provider_manifests(self) -> list[ProviderManifest]:
280 """Return all Provider manifests."""
281 return list(self._provider_manifests.values())
282
283 @api_command("providers/manifests/get")
284 def get_provider_manifest(self, domain: str) -> ProviderManifest:
285 """Return Provider manifests of single provider(domain)."""
286 return self._provider_manifests[domain]
287
288 @api_command("providers")
289 def get_providers(
290 self, provider_type: ProviderType | None = None
291 ) -> list[ProviderInstanceType]:
292 """
293 Return all loaded/running Providers (instances).
294
295 Optionally filtered by ProviderType.
296 Note that this applies user filters for music providers (for non admin users).
297 """
298 user = get_current_user()
299 user_provider_filter = (
300 user.provider_filter if user and user.role != UserRole.ADMIN else None
301 )
302 return [
303 x
304 for x in self._providers.values()
305 if (provider_type is None or provider_type == x.type)
306 # apply user provider filter
307 and (
308 not user_provider_filter
309 or x.instance_id in user_provider_filter
310 or x.type != ProviderType.MUSIC
311 )
312 ]
313
314 @api_command("logging/get", required_role=UserRole.ADMIN)
315 async def get_application_log(self) -> str:
316 """Return the application log from file."""
317 logfile = os.path.join(self.storage_path, "musicassistant.log")
318 async with aiofiles.open(logfile) as _file:
319 return str(await _file.read())
320
321 @property
322 def providers(self) -> list[ProviderInstanceType]:
323 """
324 Return all loaded/running Providers (instances).
325
326 Note that this skips user filters so may only be called from internal code.
327 """
328 return list(self._providers.values())
329
330 @overload
331 def get_provider(
332 self,
333 provider_instance_or_domain: str,
334 return_unavailable: bool = False,
335 provider_type: None = None,
336 ) -> ProviderInstanceType | None: ...
337
338 @overload
339 def get_provider(
340 self,
341 provider_instance_or_domain: str,
342 return_unavailable: bool = False,
343 *,
344 provider_type: type[_ProviderT],
345 ) -> _ProviderT | None: ...
346
347 def get_provider(
348 self,
349 provider_instance_or_domain: str,
350 return_unavailable: bool = False,
351 provider_type: type[_ProviderT] | None = None,
352 ) -> ProviderInstanceType | _ProviderT | None:
353 """Return provider by instance id or domain.
354
355 :param provider_instance_or_domain: Instance ID or domain of the provider.
356 :param return_unavailable: Also return unavailable providers.
357 :param provider_type: Optional type hint for the expected provider type (unused at runtime).
358 """
359 # lookup by instance_id first
360 if prov := self._providers.get(provider_instance_or_domain):
361 if return_unavailable or prov.available:
362 return prov
363 if not getattr(prov, "is_streaming_provider", None):
364 # no need to lookup other instances because this provider has unique data
365 return None
366 provider_instance_or_domain = prov.domain
367 # fallback to match on domain
368 for prov in self._providers.values():
369 if prov.domain != provider_instance_or_domain:
370 continue
371 if return_unavailable or prov.available:
372 return prov
373 return None
374
375 def get_provider_instances(
376 self,
377 domain: str,
378 return_unavailable: bool = False,
379 provider_type: ProviderType | None = None,
380 ) -> list[ProviderInstanceType]:
381 """
382 Return all provider instances for a given domain.
383
384 Note that this skips user filters so may only be called from internal code.
385 """
386 return [
387 prov
388 for prov in self._providers.values()
389 if (provider_type is None or provider_type == prov.type)
390 and prov.domain == domain
391 and (return_unavailable or prov.available)
392 ]
393
394 def signal_event(
395 self,
396 event: EventType,
397 object_id: str | None = None,
398 data: Any = None,
399 ) -> None:
400 """Signal event to subscribers."""
401 if self.closing:
402 return
403
404 self.verify_event_loop_thread("signal_event")
405
406 if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
407 # do not log queue time updated events because that is too chatty
408 LOGGER.getChild("event").log(VERBOSE_LOG_LEVEL, "%s %s", event.value, object_id or "")
409
410 event_obj = MassEvent(event=event, object_id=object_id, data=data)
411 for cb_func, event_filter, id_filter in self._subscribers:
412 if not (event_filter is None or event in event_filter):
413 continue
414 if not (id_filter is None or object_id in id_filter):
415 continue
416 if inspect.iscoroutinefunction(cb_func):
417 if TYPE_CHECKING:
418 cb_func = cast("Callable[[MassEvent], Coroutine[Any, Any, None]]", cb_func)
419 self.create_task(cb_func, event_obj)
420 else:
421 if TYPE_CHECKING:
422 cb_func = cast("Callable[[MassEvent], None]", cb_func)
423 self.loop.call_soon_threadsafe(cb_func, event_obj)
424
425 def subscribe(
426 self,
427 cb_func: EventCallBackType,
428 event_filter: EventType | tuple[EventType, ...] | None = None,
429 id_filter: str | tuple[str, ...] | None = None,
430 ) -> Callable[[], None]:
431 """Add callback to event listeners.
432
433 Returns function to remove the listener.
434 :param cb_func: callback function or coroutine
435 :param event_filter: Optionally only listen for these events
436 :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
437 """
438 if isinstance(event_filter, EventType):
439 event_filter = (event_filter,)
440 if isinstance(id_filter, str):
441 id_filter = (id_filter,)
442 listener = (cb_func, event_filter, id_filter)
443 self._subscribers.add(listener)
444
445 def remove_listener() -> None:
446 self._subscribers.remove(listener)
447
448 return remove_listener
449
450 def create_task(
451 self,
452 target: Callable[..., Coroutine[Any, Any, _R]] | Awaitable[_R],
453 *args: Any,
454 task_id: str | None = None,
455 abort_existing: bool = False,
456 **kwargs: Any,
457 ) -> asyncio.Task[_R]:
458 """Create Task on (main) event loop from Coroutine(function).
459
460 Tasks created by this helper will be properly cancelled on stop.
461 """
462 if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
463 # prevent duplicate tasks if task_id is given and already present
464 if abort_existing:
465 existing.cancel()
466 else:
467 return existing
468 self.verify_event_loop_thread("create_task")
469
470 if inspect.iscoroutinefunction(target):
471 # coroutine function
472 task = self.loop.create_task(target(*args, **kwargs))
473 elif inspect.iscoroutine(target):
474 # coroutine
475 task = self.loop.create_task(target)
476 elif callable(target):
477 raise RuntimeError("Function is not a coroutine or coroutine function")
478 else:
479 raise RuntimeError("Target is missing")
480
481 if task_id is None:
482 task_id = uuid4().hex
483
484 def task_done_callback(_task: asyncio.Task[Any]) -> None:
485 self._tracked_tasks.pop(task_id, None)
486 # log unhandled exceptions
487 if (
488 LOGGER.isEnabledFor(logging.DEBUG)
489 and not _task.cancelled()
490 and (err := _task.exception())
491 ):
492 task_name = _task.get_name() if hasattr(_task, "get_name") else str(_task)
493 LOGGER.warning(
494 "Exception in task %s - target: %s: %s",
495 task_name,
496 str(target),
497 str(err),
498 exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
499 )
500
501 self._tracked_tasks[task_id] = task
502 task.add_done_callback(task_done_callback)
503 return task
504
505 def call_later(
506 self,
507 delay: float,
508 target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
509 *args: Any,
510 task_id: str | None = None,
511 **kwargs: Any,
512 ) -> asyncio.TimerHandle:
513 """
514 Run callable/awaitable after given delay.
515
516 Use task_id for debouncing.
517 """
518 self.verify_event_loop_thread("call_later")
519
520 if not task_id:
521 task_id = uuid4().hex
522
523 if existing := self._tracked_timers.get(task_id):
524 existing.cancel()
525
526 def _create_task(_target: Coroutine[Any, Any, _R]) -> None:
527 self._tracked_timers.pop(task_id)
528 self.create_task(_target, *args, task_id=task_id, abort_existing=True, **kwargs)
529
530 if inspect.iscoroutinefunction(target) or inspect.iscoroutine(target):
531 # coroutine function
532 if TYPE_CHECKING:
533 target = cast("Coroutine[Any, Any, _R]", target)
534 handle = self.loop.call_later(delay, _create_task, target)
535 else:
536 # regular callable
537 if TYPE_CHECKING:
538 target = cast("Callable[..., _R]", target)
539 handle = self.loop.call_later(delay, target, *args)
540 self._tracked_timers[task_id] = handle
541 return handle
542
543 def get_task(self, task_id: str) -> asyncio.Task[Any] | None:
544 """Get existing scheduled task."""
545 if existing := self._tracked_tasks.get(task_id):
546 # prevent duplicate tasks if task_id is given and already present
547 return existing
548 return None
549
550 def cancel_task(self, task_id: str) -> None:
551 """Cancel existing scheduled task."""
552 if existing := self._tracked_tasks.pop(task_id, None):
553 existing.cancel()
554
555 def cancel_timer(self, task_id: str) -> None:
556 """Cancel existing scheduled timer."""
557 if existing := self._tracked_timers.pop(task_id, None):
558 existing.cancel()
559
560 def register_api_command(
561 self,
562 command: str,
563 handler: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
564 authenticated: bool = True,
565 required_role: str | None = None,
566 alias: bool = False,
567 ) -> Callable[[], None]:
568 """Dynamically register a command on the API.
569
570 :param command: The command name/path.
571 :param handler: The function to handle the command.
572 :param authenticated: Whether authentication is required (default: True).
573 :param required_role: Required user role ("admin" or "user")
574 None means any authenticated user.
575 :param alias: Whether this is an alias for backward compatibility (default: False).
576 Aliases are not shown in API documentation but remain functional.
577
578 Returns handle to unregister.
579 """
580 if command in self.command_handlers:
581 msg = f"Command {command} is already registered"
582 raise RuntimeError(msg)
583 self.command_handlers[command] = APICommandHandler.parse(
584 command, handler, authenticated, required_role, alias
585 )
586
587 def unregister() -> None:
588 self.command_handlers.pop(command, None)
589
590 return unregister
591
592 async def load_provider_config(
593 self,
594 prov_conf: ProviderConfig,
595 ) -> None:
596 """Try to load a provider and catch errors."""
597 # cancel existing (re)load timer if needed
598 task_id = f"load_provider_{prov_conf.instance_id}"
599 if existing := self._tracked_timers.pop(task_id, None):
600 existing.cancel()
601
602 await self._load_provider(prov_conf)
603
604 # (re)load any dependants
605 prov_configs = await self.config.get_provider_configs(include_values=True)
606 for dep_prov_conf in prov_configs:
607 if not dep_prov_conf.enabled:
608 continue
609 manifest = self.get_provider_manifest(dep_prov_conf.domain)
610 if not manifest.depends_on:
611 continue
612 if manifest.depends_on == prov_conf.domain:
613 await self._load_provider(dep_prov_conf)
614
    async def load_provider(
        self,
        instance_id: str,
        allow_retry: bool = False,
    ) -> None:
        """Try to load a provider and catch errors.

        Looks up the config for ``instance_id`` and (re)loads the provider via
        load_provider_config. Returns early if the config no longer exists or the
        provider is disabled. Load errors are not raised: they are stored as
        ``last_error`` in the config and logged; with ``allow_retry`` a
        MusicAssistantError schedules another attempt after 120 seconds.

        :param instance_id: Instance ID of the provider to load.
        :param allow_retry: Auto-schedule a retry on handled (MusicAssistantError) failures.
        """
        try:
            prov_conf = await self.config.get_provider_config(instance_id)
        except KeyError:
            # Was deleted before we could run
            return

        if not prov_conf.enabled:
            # Was disabled before we could run
            return

        # cancel existing (re)load timer if needed
        task_id = f"load_provider_{instance_id}"
        if existing := self._tracked_timers.pop(task_id, None):
            existing.cancel()

        try:
            await self.load_provider_config(prov_conf)
        except Exception as exc:
            # if loading failed, we store the error in the config object
            # so we can show something useful to the user
            prov_conf.last_error = str(exc)
            self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", str(exc))

            # auto schedule a retry if the (re)load failed with a handled exception
            # unhandled exceptions (e.g. ValueError) are likely bugs that won't resolve themselves
            will_retry = allow_retry and isinstance(exc, MusicAssistantError)
            if will_retry:
                # uses the same task_id, so a newer load attempt cancels this pending retry
                self.call_later(
                    120,
                    self.load_provider,
                    instance_id,
                    allow_retry,
                    task_id=task_id,
                )
            LOGGER.warning(
                "Error loading provider(instance) %s: %s%s",
                prov_conf.name or prov_conf.instance_id,
                str(exc) or exc.__class__.__name__,
                " (will be retried later)" if will_retry else "",
                # log full stack trace if verbose logging is enabled
                exc_info=exc if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
            )
            return

        # NOTE(review): the previous comment said "(re)load any dependents", but this loop
        # only UNLOADS dependents of this domain that are currently not available;
        # confirm that is the intended behavior.
        for dep_prov in self.providers:
            if dep_prov.available:
                continue
            if dep_prov.manifest.depends_on == prov_conf.domain:
                await self.unload_provider(dep_prov.instance_id)
671
672 async def unload_provider(self, instance_id: str, is_removed: bool = False) -> None:
673 """Unload a provider."""
674 self.music.unschedule_provider_sync(instance_id)
675 if provider := self._providers.get(instance_id):
676 # remove mdns discovery if needed
677 if provider.manifest.mdns_discovery:
678 for mdns_type in provider.manifest.mdns_discovery:
679 self._aiobrowser.types.discard(mdns_type)
680 if isinstance(provider, PlayerProvider):
681 await self.players.on_provider_unload(provider)
682 if isinstance(provider, MusicProvider):
683 await self.music.on_provider_unload(provider)
684 # check if there are no other providers dependent of this provider
685 for dep_prov in self.providers:
686 if dep_prov.manifest.depends_on == provider.domain:
687 await self.unload_provider(dep_prov.instance_id)
688 if is_player_provider(provider):
689 # unregister all players of this provider
690 for player in provider.players:
691 await self.players.unregister(player.player_id, permanent=is_removed)
692 try:
693 await provider.unload(is_removed)
694 except Exception as err:
695 LOGGER.warning(
696 "Error while unloading provider %s: %s", provider.name, str(err), exc_info=err
697 )
698 finally:
699 self._providers.pop(instance_id, None)
700 await self._update_available_providers_cache()
701 self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
702
703 async def unload_provider_with_error(self, instance_id: str, error: str) -> None:
704 """Unload a provider when it got into trouble which needs user interaction."""
705 self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", error)
706 await self.unload_provider(instance_id)
707
708 async def run_provider_discovery(self, instance_id: str) -> None:
709 """
710 Run mDNS discovery for a given provider.
711
712 In case of a PlayerProvider, will also call its own discovery method.
713 """
714 provider = self.get_provider(instance_id, return_unavailable=False)
715 if not provider:
716 raise KeyError(f"Provider with instance ID {instance_id} not found")
717 if provider.manifest.mdns_discovery:
718 if provider.instance_id not in self._mdns_locks:
719 self._mdns_locks[provider.instance_id] = asyncio.Lock()
720 async with self._mdns_locks[provider.instance_id]:
721 for mdns_type in provider.manifest.mdns_discovery or []:
722 for mdns_name in set(self.aiozc.zeroconf.cache.cache):
723 if mdns_type not in mdns_name or mdns_type == mdns_name:
724 continue
725 info = AsyncServiceInfo(mdns_type, mdns_name)
726 if await info.async_request(self.aiozc.zeroconf, 3000):
727 await provider.on_mdns_service_state_change(
728 mdns_name, ServiceStateChange.Added, info
729 )
730 if isinstance(provider, PlayerProvider):
731 await provider.discover_players()
732
733 def verify_event_loop_thread(self, what: str) -> None:
734 """Report and raise if we are not running in the event loop thread."""
735 if self.loop_thread_id != threading.get_ident():
736 raise RuntimeError(
737 f"Non-Async operation detected: {what} may only be called from the eventloop."
738 )
739
740 def _register_api_commands(self) -> None:
741 """Register all methods decorated as api_command within a class(instance)."""
742 for cls in (
743 self,
744 self.config,
745 self.metadata,
746 self.music,
747 self.players,
748 self.player_queues,
749 self.webserver,
750 self.webserver.auth,
751 ):
752 for attr_name in dir(cls):
753 if attr_name.startswith("__"):
754 continue
755 try:
756 obj = getattr(cls, attr_name)
757 except (AttributeError, RuntimeError):
758 # Skip properties that fail during initialization
759 continue
760 if hasattr(obj, "api_cmd"):
761 # method is decorated with our api decorator
762 authenticated = getattr(obj, "api_authenticated", True)
763 required_role = getattr(obj, "api_required_role", None)
764 self.register_api_command(obj.api_cmd, obj, authenticated, required_role)
765
766 async def _load_providers(self) -> None:
767 """Load providers from config."""
768 # create default config for any 'builtin' providers (e.g. URL provider)
769 for prov_manifest in self._provider_manifests.values():
770 if prov_manifest.type == ProviderType.CORE:
771 # core controllers are not real providers
772 continue
773 if not prov_manifest.builtin:
774 continue
775 await self.config.create_builtin_provider_config(prov_manifest.domain)
776
777 # load all configured (and enabled) providers
778 prov_configs = await self.config.get_provider_configs(include_values=True)
779 for prov_conf in prov_configs:
780 if not prov_conf.enabled:
781 continue
782 # Use a task so we can load multiple providers at once.
783 # If a provider fails, that will not block the loading of other providers.
784 self.create_task(self.load_provider(prov_conf.instance_id, allow_retry=True))
785
    async def _load_provider(self, conf: ProviderConfig) -> None:
        """Load (or reload) a provider from its config.

        An already loaded instance with the same instance id is unloaded first.
        Returns without loading when the provider depends on another provider
        that is not loaded (yet).

        :param conf: Config of the provider instance to load.
        :raises SetupFailedError: When the provider is disabled, its config is invalid,
            its manifest is missing, another instance is loaded while only one is
            allowed, or the module setup does not finish within 30 seconds.
            Other exceptions raised by the provider module propagate unchanged.
        """
        # if provider is already loaded, stop and unload it first
        await self.unload_provider(conf.instance_id)
        LOGGER.debug("Loading provider %s", conf.name or conf.domain)
        if not conf.enabled:
            msg = "Provider is disabled"
            raise SetupFailedError(msg)

        # validate config
        try:
            conf.validate()
        except (KeyError, ValueError, AttributeError, TypeError) as err:
            msg = "Configuration is invalid"
            raise SetupFailedError(msg) from err

        domain = conf.domain
        prov_manifest = self._provider_manifests.get(domain)
        # check for other instances of this provider
        existing = next((x for x in self.providers if x.domain == domain), None)
        if existing and prov_manifest and not prov_manifest.multi_instance:
            msg = f"Provider {domain} already loaded and only one instance allowed."
            raise SetupFailedError(msg)
        # check valid manifest (just in case)
        if not prov_manifest:
            msg = f"Provider {domain} manifest not found"
            raise SetupFailedError(msg)

        # handle dependency on other provider
        if prov_manifest.depends_on and not self.get_provider(prov_manifest.depends_on):
            # we can safely ignore this completely as the setup will be retried later
            # automatically when the dependency is loaded
            return

        # try to setup the module
        prov_mod = await load_provider_module(domain, prov_manifest.requirements)
        try:
            async with asyncio.timeout(30):
                provider = await prov_mod.setup(self, prov_manifest, conf)
        except TimeoutError as err:
            msg = f"Provider {domain} did not load within 30 seconds"
            raise SetupFailedError(msg) from err

        # run async setup (note: not covered by the 30 second timeout above)
        await provider.handle_async_init()

        # if we reach this point, the provider loaded successfully
        self._providers[provider.instance_id] = provider
        LOGGER.info(
            "Loaded %s provider %s",
            provider.type.value,
            provider.name,
        )
        provider.available = True

        # execute post load actions (in a background task, see create_task below)
        async def _on_provider_loaded() -> None:
            await provider.loaded_in_mass()
            await self.run_provider_discovery(provider.instance_id)
            # push instance name to config (to persist it if it was autogenerated)
            if provider.default_name != conf.default_name:
                self.config.set_provider_default_name(provider.instance_id, provider.default_name)

        self.create_task(_on_provider_loaded())

        # clear any previous error in config and signal update
        self.config.set(f"{CONF_PROVIDERS}/{conf.instance_id}/last_error", None)
        self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
        await self._update_available_providers_cache()
        if isinstance(provider, MusicProvider):
            await self.music.on_provider_loaded(provider)
        if isinstance(provider, PlayerProvider):
            await self.players.on_provider_loaded(provider)
859
    async def __load_provider_manifests(self) -> None:
        """Preload all available provider manifest files.

        Scans every subdirectory of PROVIDERS_PATH (skipping hidden directories and,
        outside dev mode, directories starting with an underscore) and stores each
        parsed manifest.json in self._provider_manifests, keyed by domain.
        """

        async def load_provider_manifest(provider_domain: str, provider_path: str) -> None:
            """Load manifest.json (plus optional icon files) from a single provider directory.

            ``provider_domain`` is the directory name; it is only used in the error log.
            Errors are logged and do not abort loading of the other manifests.
            """
            # get files in subdirectory
            for file_str in await asyncio.to_thread(os.listdir, provider_path):  # noqa: PTH208, RUF100
                file_path = os.path.join(provider_path, file_str)
                if not await isfile(file_path):
                    continue
                if file_str != "manifest.json":
                    continue
                try:
                    provider_manifest: ProviderManifest = await ProviderManifest.parse(file_path)
                    # check for icon.svg file
                    if not provider_manifest.icon_svg:
                        icon_path = os.path.join(provider_path, "icon.svg")
                        if await isfile(icon_path):
                            provider_manifest.icon_svg = await get_icon_string(icon_path)
                    # check for dark_icon file
                    if not provider_manifest.icon_svg_dark:
                        icon_path = os.path.join(provider_path, "icon_dark.svg")
                        if await isfile(icon_path):
                            provider_manifest.icon_svg_dark = await get_icon_string(icon_path)
                    # check for icon_monochrome file
                    if not provider_manifest.icon_svg_monochrome:
                        icon_path = os.path.join(provider_path, "icon_monochrome.svg")
                        if await isfile(icon_path):
                            provider_manifest.icon_svg_monochrome = await get_icon_string(icon_path)
                    # override Home Assistant provider if we're running as add-on
                    if provider_manifest.domain == "hass" and self.running_as_hass_addon:
                        provider_manifest.builtin = True
                        provider_manifest.allow_disable = False

                    self._provider_manifests[provider_manifest.domain] = provider_manifest
                    LOGGER.debug("Loaded manifest for provider %s", provider_manifest.name)
                except Exception as exc:
                    LOGGER.exception(
                        "Error while loading manifest for provider %s",
                        provider_domain,
                        exc_info=exc,
                    )

        # one TaskManager task per provider directory
        async with TaskManager(self) as tg:
            for dir_str in await asyncio.to_thread(os.listdir, PROVIDERS_PATH):  # noqa: PTH208, RUF100
                if dir_str.startswith("."):
                    # skip hidden directories
                    continue
                dir_path = os.path.join(PROVIDERS_PATH, dir_str)
                if dir_str.startswith("_") and not self.dev_mode:
                    # only load demo/test providers (prefixed with "_") in dev mode
                    continue
                if not await isdir(dir_path):
                    continue
                tg.create_task(load_provider_manifest(dir_str, dir_path))
915
916 async def _setup_discovery(self) -> None:
917 """Handle setup of MDNS discovery."""
918 # create a global mdns browser
919 all_types: set[str] = set()
920 for prov_manifest in self._provider_manifests.values():
921 if prov_manifest.mdns_discovery:
922 all_types.update(prov_manifest.mdns_discovery)
923 self._aiobrowser = AsyncServiceBrowser(
924 self.aiozc.zeroconf,
925 list(all_types),
926 handlers=[self._on_mdns_service_state_change],
927 )
928 # register MA itself on mdns to be discovered
929 zeroconf_type = "_mass._tcp.local."
930 server_id = self.server_id
931 LOGGER.debug("Starting Zeroconf broadcast...")
932 info = AsyncServiceInfo(
933 zeroconf_type,
934 name=f"{server_id}.{zeroconf_type}",
935 addresses=[await get_ip_pton(self.webserver.publish_ip)],
936 port=self.webserver.publish_port,
937 properties=self.get_server_info().to_dict(),
938 server="mass.local.",
939 )
940 try:
941 existing = getattr(self, "mass_zc_service_set", None)
942 if existing:
943 await self.aiozc.async_update_service(info)
944 else:
945 await self.aiozc.async_register_service(info)
946 self.mass_zc_service_set = True
947 except NonUniqueNameException:
948 LOGGER.error(
949 "Music Assistant instance with identical name present in the local network!"
950 )
951
952 def _on_mdns_service_state_change(
953 self,
954 zeroconf: Zeroconf,
955 service_type: str,
956 name: str,
957 state_change: ServiceStateChange,
958 ) -> None:
959 """Handle MDNS service state callback."""
960
961 async def process_mdns_state_change(prov: ProviderInstanceType) -> None:
962 if prov.instance_id not in self._mdns_locks:
963 self._mdns_locks[prov.instance_id] = asyncio.Lock()
964 if state_change == ServiceStateChange.Removed:
965 info = None
966 else:
967 info = AsyncServiceInfo(service_type, name)
968 await info.async_request(zeroconf, 3000)
969 # use a lock per provider instance to avoid
970 # race conditions in processing mdns events
971 async with self._mdns_locks[prov.instance_id]:
972 await prov.on_mdns_service_state_change(name, state_change, info)
973
974 LOGGER.log(
975 VERBOSE_LOG_LEVEL,
976 "Service %s of type %s state changed: %s",
977 name,
978 service_type,
979 state_change,
980 )
981 for prov in self._providers.values():
982 if not prov.manifest.mdns_discovery:
983 continue
984 if not prov.available:
985 continue
986 if service_type in prov.manifest.mdns_discovery:
987 self.create_task(process_mdns_state_change(prov))
988
989 async def __aenter__(self) -> Self:
990 """Return Context manager."""
991 await self.start()
992 return self
993
994 async def __aexit__(
995 self,
996 exc_type: type[BaseException] | None,
997 exc_val: BaseException | None,
998 exc_tb: TracebackType | None,
999 ) -> bool | None:
1000 """Exit context manager."""
1001 await self.stop()
1002 return None
1003
1004 async def _update_available_providers_cache(self) -> None:
1005 """Update the global cache variable of loaded/available providers."""
1006 await set_global_cache_values(
1007 {
1008 "provider_domains": {x.domain for x in self.providers},
1009 "provider_instance_ids": {x.instance_id for x in self.providers},
1010 "available_providers": {
1011 *{x.domain for x in self.providers},
1012 *{x.instance_id for x in self.providers},
1013 },
1014 "unique_providers": self.music.get_unique_providers(),
1015 "streaming_providers": {
1016 x.domain
1017 for x in self.providers
1018 if is_music_provider(x) and x.is_streaming_provider
1019 },
1020 "non_streaming_providers": {
1021 x.instance_id
1022 for x in self.providers
1023 if not (is_music_provider(x) and x.is_streaming_provider)
1024 },
1025 }
1026 )
1027
1028 async def _setup_storage(self) -> None:
1029 """Handle Setup of storage/cache folder(s)."""
1030 if not await isdir(self.storage_path):
1031 await mkdirs(self.storage_path)
1032 if not await isdir(self.cache_path):
1033 await mkdirs(self.cache_path)
1034