/
/
/
1"""Main Music Assistant class."""
2
3from __future__ import annotations
4
5import asyncio
6import inspect
7import logging
8import os
9import pathlib
10import threading
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
12from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar, cast, overload
13from uuid import uuid4
14
15import aiofiles
16from aiofiles.os import wrap
17from music_assistant_models.api import ServerInfoMessage
18from music_assistant_models.auth import UserRole
19from music_assistant_models.enums import EventType, ProviderType
20from music_assistant_models.errors import MusicAssistantError, SetupFailedError
21from music_assistant_models.event import MassEvent
22from music_assistant_models.helpers import set_global_cache_values
23from music_assistant_models.provider import ProviderManifest
24from zeroconf import (
25 InterfaceChoice,
26 IPVersion,
27 NonUniqueNameException,
28 ServiceStateChange,
29 Zeroconf,
30)
31from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
32
33from music_assistant.constants import (
34 API_SCHEMA_VERSION,
35 CONF_PROVIDERS,
36 CONF_SERVER_ID,
37 CONF_ZEROCONF_INTERFACES,
38 CONFIGURABLE_CORE_CONTROLLERS,
39 MASS_LOGGER_NAME,
40 MIN_SCHEMA_VERSION,
41 VERBOSE_LOG_LEVEL,
42)
43from music_assistant.controllers.cache import CacheController
44from music_assistant.controllers.config import ConfigController
45from music_assistant.controllers.metadata import MetaDataController
46from music_assistant.controllers.music import MusicController
47from music_assistant.controllers.player_queues import PlayerQueuesController
48from music_assistant.controllers.players.player_controller import PlayerController
49from music_assistant.controllers.streams import StreamsController
50from music_assistant.controllers.webserver import WebserverController
51from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
52from music_assistant.helpers.aiohttp_client import create_clientsession
53from music_assistant.helpers.api import APICommandHandler, api_command
54from music_assistant.helpers.images import get_icon_string
55from music_assistant.helpers.util import (
56 TaskManager,
57 get_ip_pton,
58 get_package_version,
59 is_hass_supervisor,
60 load_provider_module,
61)
62from music_assistant.models import ProviderInstanceType
63from music_assistant.models.music_provider import MusicProvider
64from music_assistant.models.player_provider import PlayerProvider
65
66if TYPE_CHECKING:
67 from types import TracebackType
68
69 from aiohttp import ClientSession
70 from music_assistant_models.config_entries import ProviderConfig
71
72 from music_assistant.models.core_controller import CoreController
73
# awaitable wrappers (via aiofiles.os.wrap) around blocking os / os.path helpers,
# so they can be used with `await` from async code
isdir = wrap(os.path.isdir)
isfile = wrap(os.path.isfile)
mkdirs = wrap(os.makedirs)
rmfile = wrap(os.remove)
listdir = wrap(os.listdir)
rename = wrap(os.rename)

# an event subscriber callback: either a plain function or a coroutine function
EventCallBackType = Callable[[MassEvent], None] | Callable[[MassEvent], Coroutine[Any, Any, None]]
# a registered subscription: (callback, optional event-type filter, optional object-id filter);
# a None filter means "match everything"
EventSubscriptionType = tuple[
    EventCallBackType, tuple[EventType, ...] | None, tuple[str, ...] | None
]

LOGGER = logging.getLogger(MASS_LOGGER_NAME)

# directory of this module and the bundled "providers" package next to it
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")

# result type of coroutines/callables scheduled via create_task/call_later
_R = TypeVar("_R")
# concrete provider class requested through get_provider(provider_type=...)
_ProviderT = TypeVar("_ProviderT", bound=ProviderInstanceType)
93
94
def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
    """Narrow ``provider`` to a MusicProvider when its type is ProviderType.MUSIC."""
    provider_kind = provider.type
    return provider_kind == ProviderType.MUSIC
98
99
def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
    """Narrow ``provider`` to a PlayerProvider when its type is ProviderType.PLAYER."""
    provider_kind = provider.type
    return provider_kind == ProviderType.PLAYER
103
104
105class MusicAssistant:
106 """Main MusicAssistant (Server) object."""
107
108 loop: asyncio.AbstractEventLoop
109 aiozc: AsyncZeroconf
110 config: ConfigController
111 webserver: WebserverController
112 cache: CacheController
113 metadata: MetaDataController
114 music: MusicController
115 players: PlayerController
116 player_queues: PlayerQueuesController
117 streams: StreamsController
118 _aiobrowser: AsyncServiceBrowser
119
120 def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None:
121 """Initialize the MusicAssistant Server."""
122 self.storage_path = storage_path
123 self.cache_path = cache_path
124 self.safe_mode = safe_mode
125 # we dynamically register command handlers which can be consumed by the apis
126 self.command_handlers: dict[str, APICommandHandler] = {}
127 self._subscribers: set[EventSubscriptionType] = set()
128 self._provider_manifests: dict[str, ProviderManifest] = {}
129 self._providers: dict[str, ProviderInstanceType] = {}
130 self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
131 self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
132 self.closing = False
133 self.running_as_hass_addon: bool = False
134 self.version: str = "0.0.0"
135 self.dev_mode = (
136 os.environ.get("PYTHONDEVMODE") == "1"
137 or pathlib.Path(__file__).parent.resolve().parent.resolve().joinpath(".venv").exists()
138 )
139 self._http_session: ClientSession | None = None
140 self._http_session_no_ssl: ClientSession | None = None
141 self._mdns_locks: dict[str, asyncio.Lock] = {}
142
143 async def start(self) -> None:
144 """Start running the Music Assistant server."""
145 self.loop = asyncio.get_running_loop()
146 self.loop_thread_id = getattr(self.loop, "_thread_id") # noqa: B009
147 self.running_as_hass_addon = await is_hass_supervisor()
148 self.version = await get_package_version("music_assistant") or "0.0.0"
149 # setup config controller first and fetch important config values
150 self.config = ConfigController(self)
151 await self.config.setup()
152 # create shared zeroconf instance
153 zeroconf_interfaces = self.config.get_raw_core_config_value(
154 "players", CONF_ZEROCONF_INTERFACES, "default"
155 )
156 # IPv6 requires InterfaceChoice.All, so only enable when zeroconf_interfaces is "all"
157 use_all_interfaces = zeroconf_interfaces == "all"
158 self.aiozc = AsyncZeroconf(
159 ip_version=IPVersion.All if use_all_interfaces else IPVersion.V4Only,
160 interfaces=InterfaceChoice.All if use_all_interfaces else InterfaceChoice.Default,
161 )
162 # load all available providers from manifest files
163 await self.__load_provider_manifests()
164 # setup/migrate storage
165 await self._setup_storage()
166 LOGGER.info(
167 "Starting Music Assistant Server (%s) version %s - HA add-on: %s - Safe mode: %s",
168 self.server_id,
169 self.version,
170 self.running_as_hass_addon,
171 self.safe_mode,
172 )
173 # setup other core controllers
174 self.cache = CacheController(self)
175 self.webserver = WebserverController(self)
176 self.metadata = MetaDataController(self)
177 self.music = MusicController(self)
178 self.players = PlayerController(self)
179 self.player_queues = PlayerQueuesController(self)
180 self.streams = StreamsController(self)
181 # add manifests for core controllers
182 for controller_name in CONFIGURABLE_CORE_CONTROLLERS:
183 controller: CoreController = getattr(self, controller_name)
184 self._provider_manifests[controller.domain] = controller.manifest
185 await self.cache.setup(await self.config.get_core_config("cache"))
186 # load streams controller early so we can abort if we can't load it
187 await self.streams.setup(await self.config.get_core_config("streams"))
188 await self.music.setup(await self.config.get_core_config("music"))
189 await self.metadata.setup(await self.config.get_core_config("metadata"))
190 await self.players.setup(await self.config.get_core_config("players"))
191 await self.player_queues.setup(await self.config.get_core_config("player_queues"))
192 # load webserver/api last so the api/frontend is
193 # not yet available while we're starting (or performing migrations)
194 self._register_api_commands()
195 await self.webserver.setup(await self.config.get_core_config("webserver"))
196 # setup discovery
197 await self._setup_discovery()
198 # load providers
199 if not self.safe_mode:
200 await self._load_providers()
201
202 async def stop(self) -> None:
203 """Stop running the music assistant server."""
204 LOGGER.info("Stop called, cleaning up...")
205 self.signal_event(EventType.SHUTDOWN)
206 self.closing = True
207 # cancel all running tasks
208 for task in self._tracked_tasks.values():
209 task.cancel()
210 # cleanup all providers
211 await asyncio.gather(
212 *[self.unload_provider(prov_id) for prov_id in list(self._providers.keys())],
213 return_exceptions=True,
214 )
215 # stop core controllers
216 await self.streams.close()
217 await self.webserver.close()
218 await self.metadata.close()
219 await self.music.close()
220 await self.player_queues.close()
221 await self.players.close()
222 # cleanup cache and config
223 await self.config.close()
224 await self.cache.close()
225 # close/cleanup shared http session
226 if self._http_session:
227 self._http_session.detach()
228 if self._http_session.connector:
229 await self._http_session.connector.close()
230 if self._http_session_no_ssl:
231 self._http_session_no_ssl.detach()
232 if self._http_session_no_ssl.connector:
233 await self._http_session_no_ssl.connector.close()
234
235 @property
236 def server_id(self) -> str:
237 """Return unique ID of this server."""
238 if not self.config.initialized:
239 return ""
240 return self.config.get(CONF_SERVER_ID) # type: ignore[no-any-return]
241
242 @property
243 def http_session(self) -> ClientSession:
244 """
245 Return the shared HTTP Client session (with SSL).
246
247 NOTE: May only be called from the event loop.
248 """
249 if self._http_session is None:
250 self._http_session = create_clientsession(self, verify_ssl=True)
251 return self._http_session
252
253 @property
254 def http_session_no_ssl(self) -> ClientSession:
255 """
256 Return the shared HTTP Client session (without SSL).
257
258 NOTE: May only be called from the event loop thread.
259 """
260 if self._http_session_no_ssl is None:
261 self._http_session_no_ssl = create_clientsession(self, verify_ssl=False)
262 return self._http_session_no_ssl
263
264 @api_command("info")
265 def get_server_info(self) -> ServerInfoMessage:
266 """Return Info of this server."""
267 return ServerInfoMessage(
268 server_id=self.server_id,
269 server_version=self.version,
270 schema_version=API_SCHEMA_VERSION,
271 min_supported_schema_version=MIN_SCHEMA_VERSION,
272 base_url=self.webserver.base_url,
273 homeassistant_addon=self.running_as_hass_addon,
274 onboard_done=self.config.onboard_done,
275 )
276
277 @api_command("providers/manifests")
278 def get_provider_manifests(self) -> list[ProviderManifest]:
279 """Return all Provider manifests."""
280 return list(self._provider_manifests.values())
281
282 @api_command("providers/manifests/get")
283 def get_provider_manifest(self, domain: str) -> ProviderManifest:
284 """Return Provider manifests of single provider(domain)."""
285 return self._provider_manifests[domain]
286
287 @api_command("providers")
288 def get_providers(
289 self, provider_type: ProviderType | None = None
290 ) -> list[ProviderInstanceType]:
291 """
292 Return all loaded/running Providers (instances).
293
294 Optionally filtered by ProviderType.
295 Note that this applies user filters for music providers (for non admin users).
296 """
297 user = get_current_user()
298 user_provider_filter = (
299 user.provider_filter if user and user.role != UserRole.ADMIN else None
300 )
301 return [
302 x
303 for x in self._providers.values()
304 if (provider_type is None or provider_type == x.type)
305 # apply user provider filter
306 and (
307 not user_provider_filter
308 or x.instance_id in user_provider_filter
309 or x.type != ProviderType.MUSIC
310 )
311 ]
312
313 @api_command("logging/get", required_role=UserRole.ADMIN)
314 async def get_application_log(self) -> str:
315 """Return the application log from file."""
316 logfile = os.path.join(self.storage_path, "musicassistant.log")
317 async with aiofiles.open(logfile) as _file:
318 return str(await _file.read())
319
320 @property
321 def providers(self) -> list[ProviderInstanceType]:
322 """
323 Return all loaded/running Providers (instances).
324
325 Note that this skips user filters so may only be called from internal code.
326 """
327 return list(self._providers.values())
328
329 @overload
330 def get_provider(
331 self,
332 provider_instance_or_domain: str,
333 return_unavailable: bool = False,
334 provider_type: None = None,
335 ) -> ProviderInstanceType | None: ...
336
337 @overload
338 def get_provider(
339 self,
340 provider_instance_or_domain: str,
341 return_unavailable: bool = False,
342 *,
343 provider_type: type[_ProviderT],
344 ) -> _ProviderT | None: ...
345
346 def get_provider(
347 self,
348 provider_instance_or_domain: str,
349 return_unavailable: bool = False,
350 provider_type: type[_ProviderT] | None = None,
351 ) -> ProviderInstanceType | _ProviderT | None:
352 """Return provider by instance id or domain.
353
354 :param provider_instance_or_domain: Instance ID or domain of the provider.
355 :param return_unavailable: Also return unavailable providers.
356 :param provider_type: Optional type hint for the expected provider type (unused at runtime).
357 """
358 # lookup by instance_id first
359 if prov := self._providers.get(provider_instance_or_domain):
360 if return_unavailable or prov.available:
361 return prov
362 if not getattr(prov, "is_streaming_provider", None):
363 # no need to lookup other instances because this provider has unique data
364 return None
365 provider_instance_or_domain = prov.domain
366 # fallback to match on domain
367 for prov in self._providers.values():
368 if prov.domain != provider_instance_or_domain:
369 continue
370 if return_unavailable or prov.available:
371 return prov
372 return None
373
374 def get_provider_instances(
375 self,
376 domain: str,
377 return_unavailable: bool = False,
378 provider_type: ProviderType | None = None,
379 ) -> list[ProviderInstanceType]:
380 """
381 Return all provider instances for a given domain.
382
383 Note that this skips user filters so may only be called from internal code.
384 """
385 return [
386 prov
387 for prov in self._providers.values()
388 if (provider_type is None or provider_type == prov.type)
389 and prov.domain == domain
390 and (return_unavailable or prov.available)
391 ]
392
393 def signal_event(
394 self,
395 event: EventType,
396 object_id: str | None = None,
397 data: Any = None,
398 ) -> None:
399 """Signal event to subscribers."""
400 if self.closing:
401 return
402
403 self.verify_event_loop_thread("signal_event")
404
405 if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
406 # do not log queue time updated events because that is too chatty
407 LOGGER.getChild("event").log(VERBOSE_LOG_LEVEL, "%s %s", event.value, object_id or "")
408
409 event_obj = MassEvent(event=event, object_id=object_id, data=data)
410 for cb_func, event_filter, id_filter in self._subscribers:
411 if not (event_filter is None or event in event_filter):
412 continue
413 if not (id_filter is None or object_id in id_filter):
414 continue
415 if inspect.iscoroutinefunction(cb_func):
416 if TYPE_CHECKING:
417 cb_func = cast("Callable[[MassEvent], Coroutine[Any, Any, None]]", cb_func)
418 self.create_task(cb_func, event_obj)
419 else:
420 if TYPE_CHECKING:
421 cb_func = cast("Callable[[MassEvent], None]", cb_func)
422 self.loop.call_soon_threadsafe(cb_func, event_obj)
423
424 def subscribe(
425 self,
426 cb_func: EventCallBackType,
427 event_filter: EventType | tuple[EventType, ...] | None = None,
428 id_filter: str | tuple[str, ...] | None = None,
429 ) -> Callable[[], None]:
430 """Add callback to event listeners.
431
432 Returns function to remove the listener.
433 :param cb_func: callback function or coroutine
434 :param event_filter: Optionally only listen for these events
435 :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
436 """
437 if isinstance(event_filter, EventType):
438 event_filter = (event_filter,)
439 if isinstance(id_filter, str):
440 id_filter = (id_filter,)
441 listener = (cb_func, event_filter, id_filter)
442 self._subscribers.add(listener)
443
444 def remove_listener() -> None:
445 self._subscribers.remove(listener)
446
447 return remove_listener
448
449 def create_task(
450 self,
451 target: Callable[..., Coroutine[Any, Any, _R]] | Awaitable[_R],
452 *args: Any,
453 task_id: str | None = None,
454 abort_existing: bool = False,
455 **kwargs: Any,
456 ) -> asyncio.Task[_R]:
457 """Create Task on (main) event loop from Coroutine(function).
458
459 Tasks created by this helper will be properly cancelled on stop.
460 """
461 if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
462 # prevent duplicate tasks if task_id is given and already present
463 if abort_existing:
464 existing.cancel()
465 else:
466 return existing
467 self.verify_event_loop_thread("create_task")
468
469 if inspect.iscoroutinefunction(target):
470 # coroutine function
471 task = self.loop.create_task(target(*args, **kwargs))
472 elif inspect.iscoroutine(target):
473 # coroutine
474 task = self.loop.create_task(target)
475 elif callable(target):
476 raise RuntimeError("Function is not a coroutine or coroutine function")
477 else:
478 raise RuntimeError("Target is missing")
479
480 if task_id is None:
481 task_id = uuid4().hex
482
483 def task_done_callback(_task: asyncio.Task[Any]) -> None:
484 self._tracked_tasks.pop(task_id, None)
485 # log unhandled exceptions
486 if (
487 LOGGER.isEnabledFor(logging.DEBUG)
488 and not _task.cancelled()
489 and (err := _task.exception())
490 ):
491 task_name = _task.get_name() if hasattr(_task, "get_name") else str(_task)
492 LOGGER.warning(
493 "Exception in task %s - target: %s: %s",
494 task_name,
495 str(target),
496 str(err),
497 exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
498 )
499
500 self._tracked_tasks[task_id] = task
501 task.add_done_callback(task_done_callback)
502 return task
503
504 def call_later(
505 self,
506 delay: float,
507 target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
508 *args: Any,
509 task_id: str | None = None,
510 **kwargs: Any,
511 ) -> asyncio.TimerHandle:
512 """
513 Run callable/awaitable after given delay.
514
515 Use task_id for debouncing.
516 """
517 self.verify_event_loop_thread("call_later")
518
519 if not task_id:
520 task_id = uuid4().hex
521
522 if existing := self._tracked_timers.get(task_id):
523 existing.cancel()
524
525 def _create_task(_target: Coroutine[Any, Any, _R]) -> None:
526 self._tracked_timers.pop(task_id)
527 self.create_task(_target, *args, task_id=task_id, abort_existing=True, **kwargs)
528
529 if inspect.iscoroutinefunction(target) or inspect.iscoroutine(target):
530 # coroutine function
531 if TYPE_CHECKING:
532 target = cast("Coroutine[Any, Any, _R]", target)
533 handle = self.loop.call_later(delay, _create_task, target)
534 else:
535 # regular callable
536 if TYPE_CHECKING:
537 target = cast("Callable[..., _R]", target)
538 handle = self.loop.call_later(delay, target, *args)
539 self._tracked_timers[task_id] = handle
540 return handle
541
542 def get_task(self, task_id: str) -> asyncio.Task[Any] | None:
543 """Get existing scheduled task."""
544 if existing := self._tracked_tasks.get(task_id):
545 # prevent duplicate tasks if task_id is given and already present
546 return existing
547 return None
548
549 def cancel_task(self, task_id: str) -> None:
550 """Cancel existing scheduled task."""
551 if existing := self._tracked_tasks.pop(task_id, None):
552 existing.cancel()
553
554 def cancel_timer(self, task_id: str) -> None:
555 """Cancel existing scheduled timer."""
556 if existing := self._tracked_timers.pop(task_id, None):
557 existing.cancel()
558
559 def register_api_command(
560 self,
561 command: str,
562 handler: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
563 authenticated: bool = True,
564 required_role: str | None = None,
565 alias: bool = False,
566 ) -> Callable[[], None]:
567 """Dynamically register a command on the API.
568
569 :param command: The command name/path.
570 :param handler: The function to handle the command.
571 :param authenticated: Whether authentication is required (default: True).
572 :param required_role: Required user role ("admin" or "user")
573 None means any authenticated user.
574 :param alias: Whether this is an alias for backward compatibility (default: False).
575 Aliases are not shown in API documentation but remain functional.
576
577 Returns handle to unregister.
578 """
579 if command in self.command_handlers:
580 msg = f"Command {command} is already registered"
581 raise RuntimeError(msg)
582 self.command_handlers[command] = APICommandHandler.parse(
583 command, handler, authenticated, required_role, alias
584 )
585
586 def unregister() -> None:
587 self.command_handlers.pop(command, None)
588
589 return unregister
590
591 async def load_provider_config(
592 self,
593 prov_conf: ProviderConfig,
594 ) -> None:
595 """Try to load a provider and catch errors."""
596 # cancel existing (re)load timer if needed
597 task_id = f"load_provider_{prov_conf.instance_id}"
598 if existing := self._tracked_timers.pop(task_id, None):
599 existing.cancel()
600
601 await self._load_provider(prov_conf)
602
603 # (re)load any dependants
604 prov_configs = await self.config.get_provider_configs(include_values=True)
605 for dep_prov_conf in prov_configs:
606 if not dep_prov_conf.enabled:
607 continue
608 manifest = self.get_provider_manifest(dep_prov_conf.domain)
609 if not manifest.depends_on:
610 continue
611 if manifest.depends_on == prov_conf.domain:
612 await self._load_provider(dep_prov_conf)
613
    async def load_provider(
        self,
        instance_id: str,
        allow_retry: bool = False,
    ) -> None:
        """Try to load a provider and catch errors.

        Looks up the stored config for ``instance_id`` and loads it via
        load_provider_config. Nothing happens if that config no longer exists or
        the provider is disabled. Exceptions raised while loading are caught: the
        message is stored as ``last_error`` in the provider config and a warning
        is logged.

        :param instance_id: Instance ID of the provider config to load.
        :param allow_retry: When True, a load that failed with a
            MusicAssistantError is rescheduled to run again after 120 seconds.
        """
        try:
            prov_conf = await self.config.get_provider_config(instance_id)
        except KeyError:
            # Was deleted before we could run
            return

        if not prov_conf.enabled:
            # Was disabled before we could run
            return

        # cancel existing (re)load timer if needed
        # (same task_id as used for the retry timer scheduled below)
        task_id = f"load_provider_{instance_id}"
        if existing := self._tracked_timers.pop(task_id, None):
            existing.cancel()

        try:
            await self.load_provider_config(prov_conf)
        except Exception as exc:
            # if loading failed, we store the error in the config object
            # so we can show something useful to the user
            prov_conf.last_error = str(exc)
            self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", str(exc))

            # auto schedule a retry if the (re)load failed with a handled exception
            # unhandled exceptions (e.g. ValueError) are likely bugs that won't resolve themselves
            will_retry = allow_retry and isinstance(exc, MusicAssistantError)
            if will_retry:
                self.call_later(
                    120,
                    self.load_provider,
                    instance_id,
                    allow_retry,
                    task_id=task_id,
                )
            LOGGER.warning(
                "Error loading provider(instance) %s: %s%s",
                prov_conf.name or prov_conf.instance_id,
                str(exc) or exc.__class__.__name__,
                " (will be retried later)" if will_retry else "",
                # log full stack trace if verbose logging is enabled
                exc_info=exc if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
            )
            return

        # (re)load any dependents if needed
        # NOTE(review): despite the comment above, this loop *unloads* loaded-but-
        # unavailable providers that depend on this domain (load_provider_config has
        # already reloaded the enabled dependants) - confirm this is intended.
        for dep_prov in self.providers:
            if dep_prov.available:
                continue
            if dep_prov.manifest.depends_on == prov_conf.domain:
                await self.unload_provider(dep_prov.instance_id)
670
    async def unload_provider(self, instance_id: str, is_removed: bool = False) -> None:
        """Unload a provider.

        Safe to call for an instance that is not loaded. In that case only the
        music sync schedule for ``instance_id`` is removed.

        For a loaded instance, the following happens in order:

        - its mDNS types are removed from the shared browser;
        - the player/music controllers are notified;
        - providers depending on its domain are unloaded (recursively);
        - its players are unregistered;
        - the provider's own unload() is called.

        The instance is always removed from the registry afterwards, even when
        unload() fails.

        :param instance_id: Instance ID of the provider to unload.
        :param is_removed: Passed to the provider's unload() and used as the
            ``permanent`` flag when unregistering its players (presumably True
            when the provider config is being deleted - confirm with callers).
        """
        self.music.unschedule_provider_sync(instance_id)
        if provider := self._providers.get(instance_id):
            # remove mdns discovery if needed
            # NOTE(review): the browser's types are shared by all instances, so with
            # several instances of one domain this also stops discovery for the
            # instances that stay loaded - confirm this is acceptable.
            if provider.manifest.mdns_discovery:
                for mdns_type in provider.manifest.mdns_discovery:
                    self._aiobrowser.types.discard(mdns_type)
            # NOTE(review): an exception from the controller hooks below propagates
            # before the try/finally, leaving the provider registered.
            if isinstance(provider, PlayerProvider):
                await self.players.on_provider_unload(provider)
            if isinstance(provider, MusicProvider):
                await self.music.on_provider_unload(provider)
            # check if there are no other providers dependent of this provider
            # (self.providers returns a copy, so unloading while iterating is safe)
            for dep_prov in self.providers:
                if dep_prov.manifest.depends_on == provider.domain:
                    await self.unload_provider(dep_prov.instance_id)
            if is_player_provider(provider):
                # unregister all players of this provider
                for player in provider.players:
                    await self.players.unregister(player.player_id, permanent=is_removed)
            try:
                await provider.unload(is_removed)
            except Exception as err:
                # errors from the provider's own unload are logged, not raised
                LOGGER.warning(
                    "Error while unloading provider %s: %s", provider.name, str(err), exc_info=err
                )
            finally:
                # always drop the instance and publish the updated provider list
                self._providers.pop(instance_id, None)
                await self._update_available_providers_cache()
                self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
701
702 async def unload_provider_with_error(self, instance_id: str, error: str) -> None:
703 """Unload a provider when it got into trouble which needs user interaction."""
704 self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", error)
705 await self.unload_provider(instance_id)
706
707 async def run_provider_discovery(self, instance_id: str) -> None:
708 """
709 Run mDNS discovery for a given provider.
710
711 In case of a PlayerProvider, will also call its own discovery method.
712 """
713 provider = self.get_provider(instance_id, return_unavailable=False)
714 if not provider:
715 raise KeyError(f"Provider with instance ID {instance_id} not found")
716 if provider.manifest.mdns_discovery:
717 if provider.instance_id not in self._mdns_locks:
718 self._mdns_locks[provider.instance_id] = asyncio.Lock()
719 async with self._mdns_locks[provider.instance_id]:
720 for mdns_type in provider.manifest.mdns_discovery or []:
721 for mdns_name in set(self.aiozc.zeroconf.cache.cache):
722 if mdns_type not in mdns_name or mdns_type == mdns_name:
723 continue
724 info = AsyncServiceInfo(mdns_type, mdns_name)
725 if await info.async_request(self.aiozc.zeroconf, 3000):
726 await provider.on_mdns_service_state_change(
727 mdns_name, ServiceStateChange.Added, info
728 )
729 if isinstance(provider, PlayerProvider):
730 await provider.discover_players()
731
732 def verify_event_loop_thread(self, what: str) -> None:
733 """Report and raise if we are not running in the event loop thread."""
734 if self.loop_thread_id != threading.get_ident():
735 raise RuntimeError(
736 f"Non-Async operation detected: {what} may only be called from the eventloop."
737 )
738
739 def _register_api_commands(self) -> None:
740 """Register all methods decorated as api_command within a class(instance)."""
741 for cls in (
742 self,
743 self.config,
744 self.metadata,
745 self.music,
746 self.players,
747 self.player_queues,
748 self.webserver,
749 self.webserver.auth,
750 ):
751 for attr_name in dir(cls):
752 if attr_name.startswith("__"):
753 continue
754 try:
755 obj = getattr(cls, attr_name)
756 except (AttributeError, RuntimeError):
757 # Skip properties that fail during initialization
758 continue
759 if hasattr(obj, "api_cmd"):
760 # method is decorated with our api decorator
761 authenticated = getattr(obj, "api_authenticated", True)
762 required_role = getattr(obj, "api_required_role", None)
763 self.register_api_command(obj.api_cmd, obj, authenticated, required_role)
764
765 async def _load_providers(self) -> None:
766 """Load providers from config."""
767 # create default config for any 'builtin' providers (e.g. URL provider)
768 for prov_manifest in self._provider_manifests.values():
769 if prov_manifest.type == ProviderType.CORE:
770 # core controllers are not real providers
771 continue
772 if not prov_manifest.builtin:
773 continue
774 await self.config.create_builtin_provider_config(prov_manifest.domain)
775
776 # load all configured (and enabled) providers
777 prov_configs = await self.config.get_provider_configs(include_values=True)
778 for prov_conf in prov_configs:
779 if not prov_conf.enabled:
780 continue
781 # Use a task so we can load multiple providers at once.
782 # If a provider fails, that will not block the loading of other providers.
783 self.create_task(self.load_provider(prov_conf.instance_id, allow_retry=True))
784
    async def _load_provider(self, conf: ProviderConfig) -> None:
        """Load (or reload) a provider.

        Any loaded instance with the same instance_id is unloaded first. If the
        provider depends on another provider that is not loaded, this returns
        silently without loading anything.

        :param conf: Config of the provider instance to load.
        :raises SetupFailedError: If any of the following holds:

            - the provider is disabled;
            - its config is invalid;
            - its manifest is missing;
            - another instance is loaded while the manifest disallows multiple
              instances;
            - the module's setup() does not finish within 30 seconds.

            Exceptions from loading the module, from setup() itself or from
            handle_async_init() propagate unchanged.
        """
        # if provider is already loaded, stop and unload it first
        await self.unload_provider(conf.instance_id)
        LOGGER.debug("Loading provider %s", conf.name or conf.domain)
        if not conf.enabled:
            msg = "Provider is disabled"
            raise SetupFailedError(msg)

        # validate config
        try:
            conf.validate()
        except (KeyError, ValueError, AttributeError, TypeError) as err:
            msg = "Configuration is invalid"
            raise SetupFailedError(msg) from err

        domain = conf.domain
        prov_manifest = self._provider_manifests.get(domain)
        # check for other instances of this provider
        # (this instance itself was unloaded above, so it never counts here)
        existing = next((x for x in self.providers if x.domain == domain), None)
        if existing and prov_manifest and not prov_manifest.multi_instance:
            msg = f"Provider {domain} already loaded and only one instance allowed."
            raise SetupFailedError(msg)
        # check valid manifest (just in case)
        if not prov_manifest:
            msg = f"Provider {domain} manifest not found"
            raise SetupFailedError(msg)

        # handle dependency on other provider
        if prov_manifest.depends_on and not self.get_provider(prov_manifest.depends_on):
            # we can safely ignore this completely as the setup will be retried later
            # automatically when the dependency is loaded
            return

        # try to setup the module
        prov_mod = await load_provider_module(domain, prov_manifest.requirements)
        try:
            async with asyncio.timeout(30):
                provider = await prov_mod.setup(self, prov_manifest, conf)
        except TimeoutError as err:
            msg = f"Provider {domain} did not load within 30 seconds"
            raise SetupFailedError(msg) from err

        # run async setup
        # NOTE(review): if this raises, the instance returned by setup() is dropped
        # without being registered or unloaded - confirm no cleanup is needed.
        await provider.handle_async_init()

        # if we reach this point, the provider loaded successfully
        self._providers[provider.instance_id] = provider
        LOGGER.info(
            "Loaded %s provider %s",
            provider.type.value,
            provider.name,
        )
        provider.available = True

        # adapt logging name if needed
        provider._set_log_level_from_config(provider.config)

        # execute post load actions
        # (run in a background task so this method does not wait for discovery)
        async def _on_provider_loaded() -> None:
            await provider.loaded_in_mass()
            await self.run_provider_discovery(provider.instance_id)
            # push instance name to config (to persist it if it was autogenerated)
            if provider.default_name != conf.default_name:
                self.config.set_provider_default_name(provider.instance_id, provider.default_name)

        self.create_task(_on_provider_loaded())

        # clear any previous error in config and signal update
        self.config.set(f"{CONF_PROVIDERS}/{conf.instance_id}/last_error", None)
        self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
        await self._update_available_providers_cache()
        # let the music/player controllers pick up the new provider
        if isinstance(provider, MusicProvider):
            await self.music.on_provider_loaded(provider)
        if isinstance(provider, PlayerProvider):
            await self.players.on_provider_loaded(provider)
861
async def __load_provider_manifests(self) -> None:
    """Preload the manifest files of all available providers.

    Every non-hidden subdirectory of ``PROVIDERS_PATH`` is treated as a provider
    package. The ``manifest.json`` inside it (if present) is parsed in its own
    task and stored in ``self._provider_manifests``, keyed by the manifest's
    domain. Directories whose name starts with an underscore (demo/test
    providers) are only considered when ``self.dev_mode`` is enabled.
    """
    # Optional icon files that may live next to a manifest, paired with the
    # manifest attribute they fill in when the manifest has no inline icon.
    icon_files = (
        ("icon_svg", "icon.svg"),
        ("icon_svg_dark", "icon_dark.svg"),
        ("icon_svg_monochrome", "icon_monochrome.svg"),
    )

    async def load_provider_manifest(provider_domain: str, provider_path: str) -> None:
        """Load the manifest (plus optional icon files) of one provider directory.

        :param provider_domain: Directory name of the provider (used for logging).
        :param provider_path: Full path of the provider directory.
        """
        manifest_path = os.path.join(provider_path, "manifest.json")
        # Probe the single file we need directly, instead of listing the
        # directory and stat-ing every entry in it.
        if not await isfile(manifest_path):
            return
        try:
            provider_manifest: ProviderManifest = await ProviderManifest.parse(manifest_path)
            for attr_name, icon_file in icon_files:
                # an icon embedded in the manifest takes precedence over the file
                if getattr(provider_manifest, attr_name):
                    continue
                icon_path = os.path.join(provider_path, icon_file)
                if await isfile(icon_path):
                    setattr(provider_manifest, attr_name, await get_icon_string(icon_path))
            # override Home Assistant provider if we're running as add-on
            if provider_manifest.domain == "hass" and self.running_as_hass_addon:
                provider_manifest.builtin = True
                provider_manifest.allow_disable = False

            self._provider_manifests[provider_manifest.domain] = provider_manifest
            LOGGER.debug("Loaded manifest for provider %s", provider_manifest.name)
        except Exception:
            # Log and swallow, so one broken manifest does not abort the others.
            LOGGER.exception(
                "Error while loading manifest for provider %s",
                provider_domain,
            )

    async with TaskManager(self) as tg:
        for dir_str in await asyncio.to_thread(os.listdir, PROVIDERS_PATH):  # noqa: PTH208, RUF100
            if dir_str.startswith("."):
                # skip hidden directories
                continue
            dir_path = os.path.join(PROVIDERS_PATH, dir_str)
            if dir_str.startswith("_") and not self.dev_mode:
                # only load demo/test providers if debug mode is enabled (e.g. for development)
                continue
            if not await isdir(dir_path):
                continue
            tg.create_task(load_provider_manifest(dir_str, dir_path))
917
async def _setup_discovery(self) -> None:
    """Handle setup of MDNS discovery.

    Starts one shared mdns browser for every service type listed in any
    provider manifest's ``mdns_discovery``, dispatching to
    ``self._on_mdns_service_state_change``. It also advertises this server as
    ``_mass._tcp.local.`` on the network. The service is registered on the
    first call and updated on later calls, as tracked by
    ``self.mass_zc_service_set``.
    """
    # collect the union of all mdns service types the providers are interested in
    all_types: set[str] = set()
    for prov_manifest in self._provider_manifests.values():
        if prov_manifest.mdns_discovery:
            all_types.update(prov_manifest.mdns_discovery)
    # This method can run more than once (see the register/update branch
    # below). Cancel a previously created browser first; otherwise it keeps
    # running with the same handler and every mdns event is dispatched twice.
    if (previous_browser := getattr(self, "_aiobrowser", None)) is not None:
        await previous_browser.async_cancel()
    # create a global mdns browser
    self._aiobrowser = AsyncServiceBrowser(
        self.aiozc.zeroconf,
        list(all_types),
        handlers=[self._on_mdns_service_state_change],
    )
    # register MA itself on mdns to be discovered
    zeroconf_type = "_mass._tcp.local."
    server_id = self.server_id
    LOGGER.debug("Starting Zeroconf broadcast...")
    info = AsyncServiceInfo(
        zeroconf_type,
        name=f"{server_id}.{zeroconf_type}",
        addresses=[await get_ip_pton(self.webserver.publish_ip)],
        port=self.webserver.publish_port,
        properties=self.get_server_info().to_dict(),
        server="mass.local.",
    )
    try:
        existing = getattr(self, "mass_zc_service_set", None)
        if existing:
            await self.aiozc.async_update_service(info)
        else:
            await self.aiozc.async_register_service(info)
            self.mass_zc_service_set = True
    except NonUniqueNameException:
        LOGGER.error(
            "Music Assistant instance with identical name present in the local network!"
        )
953
def _on_mdns_service_state_change(
    self,
    zeroconf: Zeroconf,
    service_type: str,
    name: str,
    state_change: ServiceStateChange,
) -> None:
    """Fan an mdns service state change out to the interested providers.

    This is the synchronous callback registered on the mdns browser. It logs the
    change, then schedules one task per available provider whose manifest
    lists ``service_type`` in ``mdns_discovery``.
    """

    async def _dispatch(provider: ProviderInstanceType) -> None:
        instance_id = provider.instance_id
        # lazily create the per-provider lock
        if instance_id not in self._mdns_locks:
            self._mdns_locks[instance_id] = asyncio.Lock()
        # a removed service has no info to resolve; otherwise query it
        # (3000 ms timeout) before handing it to the provider
        info: AsyncServiceInfo | None = None
        if state_change != ServiceStateChange.Removed:
            info = AsyncServiceInfo(service_type, name)
            await info.async_request(zeroconf, 3000)
        # serialize event handling per provider instance so events for the
        # same provider are not processed concurrently
        async with self._mdns_locks[instance_id]:
            await provider.on_mdns_service_state_change(name, state_change, info)

    LOGGER.log(
        VERBOSE_LOG_LEVEL,
        "Service %s of type %s state changed: %s",
        name,
        service_type,
        state_change,
    )
    for provider in self._providers.values():
        wanted_types = provider.manifest.mdns_discovery
        if wanted_types and provider.available and service_type in wanted_types:
            self.create_task(_dispatch(provider))
990
991 async def __aenter__(self) -> Self:
992 """Return Context manager."""
993 await self.start()
994 return self
995
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> bool | None:
    """Stop the server when leaving an ``async with`` block.

    Always returns None (implicitly), so an exception raised inside the block
    is never suppressed.
    """
    await self.stop()
1005
async def _update_available_providers_cache(self) -> None:
    """Publish the current set of loaded providers to the global cache values.

    Keys written:
    - ``provider_domains`` / ``provider_instance_ids``: domains and instance ids
      of all loaded providers.
    - ``available_providers``: union of those two sets.
    - ``unique_providers``: as returned by ``self.music.get_unique_providers()``.
    - ``streaming_providers``: domains of music providers flagged as streaming.
    - ``non_streaming_providers``: instance ids of all other providers.
    """
    providers = list(self.providers)
    domains = {prov.domain for prov in providers}
    instance_ids = {prov.instance_id for prov in providers}
    unique_providers = self.music.get_unique_providers()
    streaming_domains: set[str] = set()
    non_streaming_ids: set[str] = set()
    for prov in providers:
        if is_music_provider(prov) and prov.is_streaming_provider:
            streaming_domains.add(prov.domain)
        else:
            non_streaming_ids.add(prov.instance_id)
    await set_global_cache_values(
        {
            "provider_domains": domains,
            "provider_instance_ids": instance_ids,
            "available_providers": domains | instance_ids,
            "unique_providers": unique_providers,
            "streaming_providers": streaming_domains,
            "non_streaming_providers": non_streaming_ids,
        }
    )
1029
async def _setup_storage(self) -> None:
    """Create the storage and cache directories if they do not exist yet."""
    for folder in (self.storage_path, self.cache_path):
        if await isdir(folder):
            continue
        await mkdirs(folder)
1036