/
/
/
1"""Main Music Assistant class."""
2
3from __future__ import annotations
4
5import asyncio
6import inspect
7import logging
8import os
9import pathlib
10import threading
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
12from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar, cast, overload
13from uuid import uuid4
14
15import aiofiles
16from aiofiles.os import wrap
17from music_assistant_models.api import ServerInfoMessage
18from music_assistant_models.auth import UserRole
19from music_assistant_models.enums import EventType, ProviderType
20from music_assistant_models.errors import MusicAssistantError, SetupFailedError
21from music_assistant_models.event import MassEvent
22from music_assistant_models.helpers import set_global_cache_values
23from music_assistant_models.provider import ProviderManifest
24from zeroconf import (
25 InterfaceChoice,
26 IPVersion,
27 NonUniqueNameException,
28 ServiceStateChange,
29 Zeroconf,
30)
31from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
32
33from music_assistant.constants import (
34 API_SCHEMA_VERSION,
35 CONF_DEFAULT_PROVIDERS_SETUP,
36 CONF_PROVIDERS,
37 CONF_SERVER_ID,
38 CONF_ZEROCONF_INTERFACES,
39 CONFIGURABLE_CORE_CONTROLLERS,
40 DEFAULT_PROVIDERS,
41 MASS_LOGGER_NAME,
42 MIN_SCHEMA_VERSION,
43 VERBOSE_LOG_LEVEL,
44)
45from music_assistant.controllers.cache import CacheController
46from music_assistant.controllers.config import ConfigController
47from music_assistant.controllers.metadata import MetaDataController
48from music_assistant.controllers.music import MusicController
49from music_assistant.controllers.player_queues import PlayerQueuesController
50from music_assistant.controllers.players import PlayerController
51from music_assistant.controllers.streams import StreamsController
52from music_assistant.controllers.webserver import WebserverController
53from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
54from music_assistant.helpers.aiohttp_client import create_clientsession
55from music_assistant.helpers.api import APICommandHandler, api_command
56from music_assistant.helpers.images import get_icon_string
57from music_assistant.helpers.util import (
58 TaskManager,
59 get_ip_pton,
60 get_package_version,
61 is_hass_supervisor,
62 load_provider_module,
63)
64from music_assistant.models import ProviderInstanceType
65from music_assistant.models.music_provider import MusicProvider
66from music_assistant.models.player_provider import PlayerProvider
67
68if TYPE_CHECKING:
69 from types import TracebackType
70
71 from aiohttp import ClientSession
72 from music_assistant_models.config_entries import ProviderConfig
73
74 from music_assistant.models.core_controller import CoreController
75
# Awaitable variants of blocking os / os.path helpers, produced with aiofiles' ``wrap``.
isdir = wrap(os.path.isdir)
isfile = wrap(os.path.isfile)
mkdirs = wrap(os.makedirs)
rmfile = wrap(os.remove)
listdir = wrap(os.listdir)
rename = wrap(os.rename)

# An event callback is either a plain function or a coroutine function taking a MassEvent.
EventCallBackType = Callable[[MassEvent], None] | Callable[[MassEvent], Coroutine[Any, Any, None]]
# Subscription record: (callback, optional event-type filter, optional object-id filter).
EventSubscriptionType = tuple[
    EventCallBackType, tuple[EventType, ...] | None, tuple[str, ...] | None
]

# Logger shared by this module (named after MASS_LOGGER_NAME).
LOGGER = logging.getLogger(MASS_LOGGER_NAME)

# Directory that contains this module, and the "providers" directory inside it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")

# Generic type variables used in the signatures of this module.
_R = TypeVar("_R")
_ProviderT = TypeVar("_ProviderT", bound=ProviderInstanceType)
95
96
def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
    """Narrow *provider* to ``MusicProvider`` when its ``type`` is ``ProviderType.MUSIC``."""
    provider_kind = provider.type
    return provider_kind == ProviderType.MUSIC
100
101
def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
    """Narrow *provider* to ``PlayerProvider`` when its ``type`` is ``ProviderType.PLAYER``."""
    provider_kind = provider.type
    return provider_kind == ProviderType.PLAYER
105
106
107class MusicAssistant:
108 """Main MusicAssistant (Server) object."""
109
110 loop: asyncio.AbstractEventLoop
111 aiozc: AsyncZeroconf
112 config: ConfigController
113 webserver: WebserverController
114 cache: CacheController
115 metadata: MetaDataController
116 music: MusicController
117 players: PlayerController
118 player_queues: PlayerQueuesController
119 streams: StreamsController
120 _aiobrowser: AsyncServiceBrowser
121
122 def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None:
123 """Initialize the MusicAssistant Server."""
124 self.storage_path = storage_path
125 self.cache_path = cache_path
126 self.safe_mode = safe_mode
127 # we dynamically register command handlers which can be consumed by the apis
128 self.command_handlers: dict[str, APICommandHandler] = {}
129 self._subscribers: set[EventSubscriptionType] = set()
130 self._provider_manifests: dict[str, ProviderManifest] = {}
131 self._providers: dict[str, ProviderInstanceType] = {}
132 self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
133 self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
134 self.closing = False
135 self.running_as_hass_addon: bool = False
136 self.version: str = "0.0.0"
137 self.dev_mode = (
138 os.environ.get("PYTHONDEVMODE") == "1"
139 or pathlib.Path(__file__).parent.resolve().parent.resolve().joinpath(".venv").exists()
140 )
141 self._http_session: ClientSession | None = None
142 self._http_session_no_ssl: ClientSession | None = None
143 self._mdns_locks: dict[str, asyncio.Lock] = {}
144
145 async def start(self) -> None:
146 """Start running the Music Assistant server."""
147 self.loop = asyncio.get_running_loop()
148 self.loop_thread_id = getattr(self.loop, "_thread_id") # noqa: B009
149 self.running_as_hass_addon = await is_hass_supervisor()
150 self.version = await get_package_version("music_assistant") or "0.0.0"
151 # setup config controller first and fetch important config values
152 self.config = ConfigController(self)
153 await self.config.setup()
154 # create shared zeroconf instance
155 zeroconf_interfaces = self.config.get_raw_core_config_value(
156 "players", CONF_ZEROCONF_INTERFACES, "default"
157 )
158 # IPv6 requires InterfaceChoice.All, so only enable when zeroconf_interfaces is "all"
159 use_all_interfaces = zeroconf_interfaces == "all"
160 self.aiozc = AsyncZeroconf(
161 ip_version=IPVersion.All if use_all_interfaces else IPVersion.V4Only,
162 interfaces=InterfaceChoice.All if use_all_interfaces else InterfaceChoice.Default,
163 )
164 # load all available providers from manifest files
165 await self.__load_provider_manifests()
166 # setup/migrate storage
167 await self._setup_storage()
168 LOGGER.info(
169 "Starting Music Assistant Server (%s) version %s - HA add-on: %s - Safe mode: %s",
170 self.server_id,
171 self.version,
172 self.running_as_hass_addon,
173 self.safe_mode,
174 )
175 # setup other core controllers
176 self.cache = CacheController(self)
177 self.webserver = WebserverController(self)
178 self.metadata = MetaDataController(self)
179 self.music = MusicController(self)
180 self.players = PlayerController(self)
181 self.player_queues = PlayerQueuesController(self)
182 self.streams = StreamsController(self)
183 # add manifests for core controllers
184 for controller_name in CONFIGURABLE_CORE_CONTROLLERS:
185 controller: CoreController = getattr(self, controller_name)
186 self._provider_manifests[controller.domain] = controller.manifest
187 await self.cache.setup(await self.config.get_core_config("cache"))
188 # load streams controller early so we can abort if we can't load it
189 await self.streams.setup(await self.config.get_core_config("streams"))
190 await self.music.setup(await self.config.get_core_config("music"))
191 await self.metadata.setup(await self.config.get_core_config("metadata"))
192 await self.players.setup(await self.config.get_core_config("players"))
193 await self.player_queues.setup(await self.config.get_core_config("player_queues"))
194 # load webserver/api last so the api/frontend is
195 # not yet available while we're starting (or performing migrations)
196 self._register_api_commands()
197 await self.webserver.setup(await self.config.get_core_config("webserver"))
198 # setup discovery
199 await self._setup_discovery()
200 # load providers
201 if not self.safe_mode:
202 await self._load_providers()
203
204 async def stop(self) -> None:
205 """Stop running the music assistant server."""
206 LOGGER.info("Stop called, cleaning up...")
207 self.signal_event(EventType.SHUTDOWN)
208 self.closing = True
209 # cancel all running tasks
210 for task in self._tracked_tasks.values():
211 task.cancel()
212 # cleanup all providers
213 await asyncio.gather(
214 *[self.unload_provider(prov_id) for prov_id in list(self._providers.keys())],
215 return_exceptions=True,
216 )
217 # stop core controllers
218 await self.streams.close()
219 await self.webserver.close()
220 await self.metadata.close()
221 await self.music.close()
222 await self.player_queues.close()
223 await self.players.close()
224 # cleanup cache and config
225 await self.config.close()
226 await self.cache.close()
227 # close/cleanup shared http session
228 if self._http_session:
229 self._http_session.detach()
230 if self._http_session.connector:
231 await self._http_session.connector.close()
232 if self._http_session_no_ssl:
233 self._http_session_no_ssl.detach()
234 if self._http_session_no_ssl.connector:
235 await self._http_session_no_ssl.connector.close()
236
237 @property
238 def server_id(self) -> str:
239 """Return unique ID of this server."""
240 if not self.config.initialized:
241 return ""
242 return self.config.get(CONF_SERVER_ID) # type: ignore[no-any-return]
243
244 @property
245 def http_session(self) -> ClientSession:
246 """
247 Return the shared HTTP Client session (with SSL).
248
249 NOTE: May only be called from the event loop.
250 """
251 if self._http_session is None:
252 self._http_session = create_clientsession(self, verify_ssl=True)
253 return self._http_session
254
255 @property
256 def http_session_no_ssl(self) -> ClientSession:
257 """
258 Return the shared HTTP Client session (without SSL).
259
260 NOTE: May only be called from the event loop thread.
261 """
262 if self._http_session_no_ssl is None:
263 self._http_session_no_ssl = create_clientsession(self, verify_ssl=False)
264 return self._http_session_no_ssl
265
266 @api_command("info")
267 def get_server_info(self) -> ServerInfoMessage:
268 """Return Info of this server."""
269 return ServerInfoMessage(
270 server_id=self.server_id,
271 server_version=self.version,
272 schema_version=API_SCHEMA_VERSION,
273 min_supported_schema_version=MIN_SCHEMA_VERSION,
274 base_url=self.webserver.base_url,
275 homeassistant_addon=self.running_as_hass_addon,
276 onboard_done=self.config.onboard_done,
277 )
278
279 @api_command("providers/manifests")
280 def get_provider_manifests(self) -> list[ProviderManifest]:
281 """Return all Provider manifests."""
282 return list(self._provider_manifests.values())
283
284 @api_command("providers/manifests/get")
285 def get_provider_manifest(self, domain: str) -> ProviderManifest:
286 """Return Provider manifests of single provider(domain)."""
287 return self._provider_manifests[domain]
288
289 @api_command("providers")
290 def get_providers(
291 self, provider_type: ProviderType | None = None
292 ) -> list[ProviderInstanceType]:
293 """
294 Return all loaded/running Providers (instances).
295
296 Optionally filtered by ProviderType.
297 Note that this applies user filters for music providers (for non admin users).
298 """
299 user = get_current_user()
300 user_provider_filter = (
301 user.provider_filter if user and user.role != UserRole.ADMIN else None
302 )
303 return [
304 x
305 for x in self._providers.values()
306 if (provider_type is None or provider_type == x.type)
307 # apply user provider filter
308 and (
309 not user_provider_filter
310 or x.instance_id in user_provider_filter
311 or x.type != ProviderType.MUSIC
312 )
313 ]
314
315 @api_command("logging/get", required_role=UserRole.ADMIN)
316 async def get_application_log(self) -> str:
317 """Return the application log from file."""
318 logfile = os.path.join(self.storage_path, "musicassistant.log")
319 async with aiofiles.open(logfile) as _file:
320 return str(await _file.read())
321
322 @property
323 def providers(self) -> list[ProviderInstanceType]:
324 """
325 Return all loaded/running Providers (instances).
326
327 Note that this skips user filters so may only be called from internal code.
328 """
329 return list(self._providers.values())
330
331 @overload
332 def get_provider(
333 self,
334 provider_instance_or_domain: str,
335 return_unavailable: bool = False,
336 provider_type: None = None,
337 ) -> ProviderInstanceType | None: ...
338
339 @overload
340 def get_provider(
341 self,
342 provider_instance_or_domain: str,
343 return_unavailable: bool = False,
344 *,
345 provider_type: type[_ProviderT],
346 ) -> _ProviderT | None: ...
347
348 def get_provider(
349 self,
350 provider_instance_or_domain: str,
351 return_unavailable: bool = False,
352 provider_type: type[_ProviderT] | None = None,
353 ) -> ProviderInstanceType | _ProviderT | None:
354 """Return provider by instance id or domain.
355
356 :param provider_instance_or_domain: Instance ID or domain of the provider.
357 :param return_unavailable: Also return unavailable providers.
358 :param provider_type: Optional type hint for the expected provider type (unused at runtime).
359 """
360 # lookup by instance_id first
361 if prov := self._providers.get(provider_instance_or_domain):
362 if return_unavailable or prov.available:
363 return prov
364 if not getattr(prov, "is_streaming_provider", None):
365 # no need to lookup other instances because this provider has unique data
366 return None
367 provider_instance_or_domain = prov.domain
368 # fallback to match on domain
369 for prov in self._providers.values():
370 if prov.domain != provider_instance_or_domain:
371 continue
372 if return_unavailable or prov.available:
373 return prov
374 return None
375
376 def get_provider_instances(
377 self,
378 domain: str,
379 return_unavailable: bool = False,
380 provider_type: ProviderType | None = None,
381 ) -> list[ProviderInstanceType]:
382 """
383 Return all provider instances for a given domain.
384
385 Note that this skips user filters so may only be called from internal code.
386 """
387 return [
388 prov
389 for prov in self._providers.values()
390 if (provider_type is None or provider_type == prov.type)
391 and prov.domain == domain
392 and (return_unavailable or prov.available)
393 ]
394
395 def signal_event(
396 self,
397 event: EventType,
398 object_id: str | None = None,
399 data: Any = None,
400 ) -> None:
401 """Signal event to subscribers."""
402 if self.closing:
403 return
404
405 self.verify_event_loop_thread("signal_event")
406
407 if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
408 # do not log queue time updated events because that is too chatty
409 LOGGER.getChild("event").log(VERBOSE_LOG_LEVEL, "%s %s", event.value, object_id or "")
410
411 event_obj = MassEvent(event=event, object_id=object_id, data=data)
412 for cb_func, event_filter, id_filter in self._subscribers:
413 if not (event_filter is None or event in event_filter):
414 continue
415 if not (id_filter is None or object_id in id_filter):
416 continue
417 if inspect.iscoroutinefunction(cb_func):
418 if TYPE_CHECKING:
419 cb_func = cast("Callable[[MassEvent], Coroutine[Any, Any, None]]", cb_func)
420 self.create_task(cb_func, event_obj)
421 else:
422 if TYPE_CHECKING:
423 cb_func = cast("Callable[[MassEvent], None]", cb_func)
424 self.loop.call_soon_threadsafe(cb_func, event_obj)
425
426 def subscribe(
427 self,
428 cb_func: EventCallBackType,
429 event_filter: EventType | tuple[EventType, ...] | None = None,
430 id_filter: str | tuple[str, ...] | None = None,
431 ) -> Callable[[], None]:
432 """Add callback to event listeners.
433
434 Returns function to remove the listener.
435 :param cb_func: callback function or coroutine
436 :param event_filter: Optionally only listen for these events
437 :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
438 """
439 if isinstance(event_filter, EventType):
440 event_filter = (event_filter,)
441 if isinstance(id_filter, str):
442 id_filter = (id_filter,)
443 listener = (cb_func, event_filter, id_filter)
444 self._subscribers.add(listener)
445
446 def remove_listener() -> None:
447 self._subscribers.remove(listener)
448
449 return remove_listener
450
451 def create_task(
452 self,
453 target: Callable[..., Coroutine[Any, Any, _R]] | Awaitable[_R],
454 *args: Any,
455 task_id: str | None = None,
456 abort_existing: bool = False,
457 **kwargs: Any,
458 ) -> asyncio.Task[_R]:
459 """Create Task on (main) event loop from Coroutine(function).
460
461 Tasks created by this helper will be properly cancelled on stop.
462 """
463 if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
464 # prevent duplicate tasks if task_id is given and already present
465 if abort_existing:
466 existing.cancel()
467 else:
468 return existing
469 self.verify_event_loop_thread("create_task")
470
471 if inspect.iscoroutinefunction(target):
472 # coroutine function
473 task = self.loop.create_task(target(*args, **kwargs))
474 elif inspect.iscoroutine(target):
475 # coroutine
476 task = self.loop.create_task(target)
477 elif callable(target):
478 raise RuntimeError("Function is not a coroutine or coroutine function")
479 else:
480 raise RuntimeError("Target is missing")
481
482 if task_id is None:
483 task_id = uuid4().hex
484
485 def task_done_callback(_task: asyncio.Task[Any]) -> None:
486 self._tracked_tasks.pop(task_id, None)
487 # log unhandled exceptions
488 if (
489 LOGGER.isEnabledFor(logging.DEBUG)
490 and not _task.cancelled()
491 and (err := _task.exception())
492 ):
493 task_name = _task.get_name() if hasattr(_task, "get_name") else str(_task)
494 LOGGER.warning(
495 "Exception in task %s - target: %s: %s",
496 task_name,
497 str(target),
498 str(err),
499 exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
500 )
501
502 self._tracked_tasks[task_id] = task
503 task.add_done_callback(task_done_callback)
504 return task
505
506 def call_later(
507 self,
508 delay: float,
509 target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
510 *args: Any,
511 task_id: str | None = None,
512 **kwargs: Any,
513 ) -> asyncio.TimerHandle:
514 """
515 Run callable/awaitable after given delay.
516
517 Use task_id for debouncing.
518 """
519 self.verify_event_loop_thread("call_later")
520
521 if not task_id:
522 task_id = uuid4().hex
523
524 if existing := self._tracked_timers.get(task_id):
525 existing.cancel()
526
527 def _create_task(_target: Coroutine[Any, Any, _R]) -> None:
528 self._tracked_timers.pop(task_id)
529 self.create_task(_target, *args, task_id=task_id, abort_existing=True, **kwargs)
530
531 def _call_sync(_target: Callable[..., _R]) -> None:
532 self._tracked_timers.pop(task_id)
533 _target(*args, **kwargs)
534
535 if inspect.iscoroutinefunction(target) or inspect.iscoroutine(target):
536 # coroutine function
537 if TYPE_CHECKING:
538 target = cast("Coroutine[Any, Any, _R]", target)
539 handle = self.loop.call_later(delay, _create_task, target)
540 else:
541 # regular sync callable
542 if TYPE_CHECKING:
543 target = cast("Callable[..., _R]", target)
544 handle = self.loop.call_later(delay, _call_sync, target)
545 self._tracked_timers[task_id] = handle
546 return handle
547
548 def get_task(self, task_id: str) -> asyncio.Task[Any] | None:
549 """Get existing scheduled task."""
550 if existing := self._tracked_tasks.get(task_id):
551 # prevent duplicate tasks if task_id is given and already present
552 return existing
553 return None
554
555 def cancel_task(self, task_id: str) -> None:
556 """Cancel existing scheduled task."""
557 if existing := self._tracked_tasks.pop(task_id, None):
558 existing.cancel()
559
560 def cancel_timer(self, task_id: str) -> None:
561 """Cancel existing scheduled timer."""
562 if existing := self._tracked_timers.pop(task_id, None):
563 existing.cancel()
564
565 def register_api_command(
566 self,
567 command: str,
568 handler: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
569 authenticated: bool = True,
570 required_role: str | None = None,
571 alias: bool = False,
572 ) -> Callable[[], None]:
573 """Dynamically register a command on the API.
574
575 :param command: The command name/path.
576 :param handler: The function to handle the command.
577 :param authenticated: Whether authentication is required (default: True).
578 :param required_role: Required user role ("admin" or "user")
579 None means any authenticated user.
580 :param alias: Whether this is an alias for backward compatibility (default: False).
581 Aliases are not shown in API documentation but remain functional.
582
583 Returns handle to unregister.
584 """
585 if command in self.command_handlers:
586 msg = f"Command {command} is already registered"
587 raise RuntimeError(msg)
588 self.command_handlers[command] = APICommandHandler.parse(
589 command, handler, authenticated, required_role, alias
590 )
591
592 def unregister() -> None:
593 self.command_handlers.pop(command, None)
594
595 return unregister
596
597 async def load_provider_config(
598 self,
599 prov_conf: ProviderConfig,
600 ) -> None:
601 """Try to load a provider and catch errors."""
602 # cancel existing (re)load timer if needed
603 task_id = f"load_provider_{prov_conf.instance_id}"
604 if existing := self._tracked_timers.pop(task_id, None):
605 existing.cancel()
606
607 await self._load_provider(prov_conf)
608
609 # (re)load any dependants
610 prov_configs = await self.config.get_provider_configs(include_values=True)
611 for dep_prov_conf in prov_configs:
612 if not dep_prov_conf.enabled:
613 continue
614 manifest = self.get_provider_manifest(dep_prov_conf.domain)
615 if not manifest.depends_on:
616 continue
617 if manifest.depends_on == prov_conf.domain:
618 await self._load_provider(dep_prov_conf)
619
620 async def load_provider(
621 self,
622 instance_id: str,
623 allow_retry: bool = False,
624 ) -> None:
625 """Try to load a provider and catch errors."""
626 try:
627 prov_conf = await self.config.get_provider_config(instance_id)
628 except KeyError:
629 # Was deleted before we could run
630 return
631
632 if not prov_conf.enabled:
633 # Was disabled before we could run
634 return
635
636 # cancel existing (re)load timer if needed
637 task_id = f"load_provider_{instance_id}"
638 if existing := self._tracked_timers.pop(task_id, None):
639 existing.cancel()
640
641 try:
642 await self.load_provider_config(prov_conf)
643 except Exception as exc:
644 # if loading failed, we store the error in the config object
645 # so we can show something useful to the user
646 prov_conf.last_error = str(exc)
647 self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", str(exc))
648
649 # auto schedule a retry if the (re)load failed with a handled exception
650 # unhandled exceptions (e.g. ValueError) are likely bugs that won't resolve themselves
651 will_retry = allow_retry and isinstance(exc, MusicAssistantError)
652 if will_retry:
653 self.call_later(
654 120,
655 self.load_provider,
656 instance_id,
657 allow_retry,
658 task_id=task_id,
659 )
660 LOGGER.warning(
661 "Error loading provider(instance) %s: %s%s",
662 prov_conf.name or prov_conf.instance_id,
663 str(exc) or exc.__class__.__name__,
664 " (will be retried later)" if will_retry else "",
665 # log full stack trace if verbose logging is enabled
666 exc_info=exc if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
667 )
668 return
669
670 # (re)load any dependents if needed
671 for dep_prov in self.providers:
672 if dep_prov.available:
673 continue
674 if dep_prov.manifest.depends_on == prov_conf.domain:
675 await self.unload_provider(dep_prov.instance_id)
676
677 async def unload_provider(self, instance_id: str, is_removed: bool = False) -> None:
678 """Unload a provider."""
679 self.music.unschedule_provider_sync(instance_id)
680 if provider := self._providers.get(instance_id):
681 # remove mdns discovery if needed
682 if provider.manifest.mdns_discovery:
683 for mdns_type in provider.manifest.mdns_discovery:
684 self._aiobrowser.types.discard(mdns_type)
685 if isinstance(provider, PlayerProvider):
686 await self.players.on_provider_unload(provider)
687 if isinstance(provider, MusicProvider):
688 await self.music.on_provider_unload(provider)
689 # check if there are no other providers dependent of this provider
690 for dep_prov in self.providers:
691 if dep_prov.manifest.depends_on == provider.domain:
692 await self.unload_provider(dep_prov.instance_id)
693 if is_player_provider(provider):
694 # unregister all players of this provider
695 for player in provider.players:
696 await self.players.unregister(player.player_id, permanent=is_removed)
697 try:
698 await provider.unload(is_removed)
699 except Exception as err:
700 LOGGER.warning(
701 "Error while unloading provider %s: %s", provider.name, str(err), exc_info=err
702 )
703 finally:
704 self._providers.pop(instance_id, None)
705 await self._update_available_providers_cache()
706 self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
707
708 async def unload_provider_with_error(self, instance_id: str, error: str) -> None:
709 """Unload a provider when it got into trouble which needs user interaction."""
710 self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", error)
711 await self.unload_provider(instance_id)
712
713 async def run_provider_discovery(self, instance_id: str) -> None:
714 """
715 Run mDNS discovery for a given provider.
716
717 In case of a PlayerProvider, will also call its own discovery method.
718 """
719 provider = self.get_provider(instance_id, return_unavailable=False)
720 if not provider:
721 raise KeyError(f"Provider with instance ID {instance_id} not found")
722 if provider.manifest.mdns_discovery:
723 if provider.instance_id not in self._mdns_locks:
724 self._mdns_locks[provider.instance_id] = asyncio.Lock()
725 async with self._mdns_locks[provider.instance_id]:
726 for mdns_type in provider.manifest.mdns_discovery or []:
727 for mdns_name in set(self.aiozc.zeroconf.cache.cache):
728 if mdns_type not in mdns_name or mdns_type == mdns_name:
729 continue
730 info = AsyncServiceInfo(mdns_type, mdns_name)
731 if await info.async_request(self.aiozc.zeroconf, 3000):
732 await provider.on_mdns_service_state_change(
733 mdns_name, ServiceStateChange.Added, info
734 )
735 if isinstance(provider, PlayerProvider):
736 await provider.discover_players()
737
738 def verify_event_loop_thread(self, what: str) -> None:
739 """Report and raise if we are not running in the event loop thread."""
740 if self.loop_thread_id != threading.get_ident():
741 raise RuntimeError(
742 f"Non-Async operation detected: {what} may only be called from the eventloop."
743 )
744
745 def _register_api_commands(self) -> None:
746 """Register all methods decorated as api_command within a class(instance)."""
747 for cls in (
748 self,
749 self.config,
750 self.metadata,
751 self.music,
752 self.players,
753 self.player_queues,
754 self.webserver,
755 self.webserver.auth,
756 ):
757 for attr_name in dir(cls):
758 if attr_name.startswith("__"):
759 continue
760 try:
761 obj = getattr(cls, attr_name)
762 except (AttributeError, RuntimeError):
763 # Skip properties that fail during initialization
764 continue
765 if hasattr(obj, "api_cmd"):
766 # method is decorated with our api decorator
767 authenticated = getattr(obj, "api_authenticated", True)
768 required_role = getattr(obj, "api_required_role", None)
769 self.register_api_command(obj.api_cmd, obj, authenticated, required_role)
770
771 async def _load_providers(self) -> None:
772 """Load providers from config."""
773 # create default config for any 'builtin' providers (e.g. URL provider)
774 for prov_manifest in self._provider_manifests.values():
775 if prov_manifest.type == ProviderType.CORE:
776 # core controllers are not real providers
777 continue
778 if not prov_manifest.builtin:
779 continue
780 await self.config.create_builtin_provider_config(prov_manifest.domain)
781 # handle default providers setup
782 self.config.set_default(CONF_DEFAULT_PROVIDERS_SETUP, set())
783 default_providers_setup = cast("set[str]", self.config.get(CONF_DEFAULT_PROVIDERS_SETUP))
784 changes_made = False
785 for default_provider, require_mdns in DEFAULT_PROVIDERS:
786 if default_provider in default_providers_setup:
787 # already processed/setup before, skip
788 continue
789 if not (manifest := self._provider_manifests.get(default_provider)):
790 continue
791 if require_mdns:
792 # if mdns discovery is required, check if we have seen any mdns entries
793 # for this provider before setting it up
794 for mdns_name in set(self.aiozc.zeroconf.cache.cache):
795 if manifest.mdns_discovery and any(
796 mdns_type in mdns_name for mdns_type in manifest.mdns_discovery
797 ):
798 break
799 else:
800 continue
801 await self.config.create_builtin_provider_config(manifest.domain)
802 changes_made = True
803 # TEMP: migration - to be removed after 2.8 release
804 # enable all existing players of the default providers if they are not already enabled
805 # due to the linked protocol feature we introduced
806 for player_config in await self.config.get_player_configs(
807 provider=default_provider, include_disabled=True
808 ):
809 if player_config.enabled:
810 continue
811 await self.config.save_player_config(player_config.player_id, {"enabled": True})
812 default_providers_setup.add(default_provider)
813 if changes_made:
814 self.config.set(CONF_DEFAULT_PROVIDERS_SETUP, default_providers_setup)
815 self.config.save(True)
816
817 # load all configured (and enabled) providers
818 # builtin providers are loaded first (and awaited) before loading the rest
819 prov_configs = await self.config.get_provider_configs(include_values=True)
820 builtin_configs: list[ProviderConfig] = []
821 other_configs: list[ProviderConfig] = []
822 for prov_conf in prov_configs:
823 if not prov_conf.enabled:
824 continue
825 manifest = self._provider_manifests.get(prov_conf.domain)
826 if manifest and manifest.builtin:
827 builtin_configs.append(prov_conf)
828 else:
829 other_configs.append(prov_conf)
830
831 # load builtin providers first and wait for them to complete
832 await asyncio.gather(
833 *[self.load_provider(conf.instance_id, allow_retry=True) for conf in builtin_configs]
834 )
835
836 # load remaining providers concurrently via tasks
837 for prov_conf in other_configs:
838 # Use a task so we can load multiple providers at once.
839 # If a provider fails, that will not block the loading of other providers.
840 self.create_task(self.load_provider(prov_conf.instance_id, allow_retry=True))
841
async def _load_provider(self, conf: ProviderConfig) -> None:
    """
    Load (or reload) a single provider instance from its configuration.

    An already loaded instance with the same instance id is unloaded first.
    After validating the config and the manifest, the provider module is set up
    (bounded by a 30 second timeout), registered, marked as available and
    announced through a PROVIDERS_UPDATED event.

    If the manifest declares a dependency on a provider that is not loaded
    (yet), this method returns without loading anything.

    :param conf: Configuration of the provider instance to load.
    :raises SetupFailedError: When the provider is disabled, its configuration
        is invalid, its manifest is unknown, another instance of a
        single-instance provider is already loaded or the setup timed out.
    """
    instance_id = conf.instance_id
    # a (re)load always starts clean: stop and unload a running instance first
    await self.unload_provider(instance_id)
    LOGGER.debug("Loading provider %s", conf.name or conf.domain)
    if not conf.enabled:
        msg = "Provider is disabled"
        raise SetupFailedError(msg)

    # the config must be valid before we attempt anything else
    try:
        conf.validate()
    except (KeyError, ValueError, AttributeError, TypeError) as err:
        msg = "Configuration is invalid"
        raise SetupFailedError(msg) from err

    domain = conf.domain
    manifest = self._provider_manifests.get(domain)
    # a manifest should always exist (just in case)
    if not manifest:
        msg = f"Provider {domain} manifest not found"
        raise SetupFailedError(msg)
    # refuse a second instance if the provider only supports a single one
    already_loaded = any(prov.domain == domain for prov in self.providers)
    if already_loaded and not manifest.multi_instance:
        msg = f"Provider {domain} already loaded and only one instance allowed."
        raise SetupFailedError(msg)

    # handle dependency on other provider
    dependency = manifest.depends_on
    if dependency and not self.get_provider(dependency):
        # we can safely ignore this completely as the setup will be retried later
        # automatically when the dependency is loaded
        return

    # import the provider module and let it create the provider instance
    prov_mod = await load_provider_module(domain, manifest.requirements)
    try:
        async with asyncio.timeout(30):
            provider = await prov_mod.setup(self, manifest, conf)
    except TimeoutError as err:
        msg = f"Provider {domain} did not load within 30 seconds"
        raise SetupFailedError(msg) from err

    # run async setup
    await provider.handle_async_init()

    # reaching this point means the provider loaded successfully
    self._providers[provider.instance_id] = provider
    LOGGER.info("Loaded %s provider %s", provider.type.value, provider.name)
    provider.available = True

    # adapt logging name if needed
    provider._set_log_level_from_config(provider.config)

    # post load actions run in a background task
    async def _on_provider_loaded() -> None:
        await provider.loaded_in_mass()
        await self.run_provider_discovery(provider.instance_id)
        # push instance name to config (to persist it if it was autogenerated)
        resolved_name = provider.default_name
        if resolved_name != conf.default_name:
            self.config.set_provider_default_name(provider.instance_id, resolved_name)

    self.create_task(_on_provider_loaded())

    # clear any previous error in config and signal update
    self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", None)
    self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
    await self._update_available_providers_cache()
    # let the matching core controller(s) know about the new provider
    if isinstance(provider, MusicProvider):
        await self.music.on_provider_loaded(provider)
    if isinstance(provider, PlayerProvider):
        await self.players.on_provider_loaded(provider)
918
async def __load_provider_manifests(self) -> None:
    """
    Preload the manifest of every provider found in the providers folder.

    Each (non hidden) subdirectory of PROVIDERS_PATH is inspected concurrently.
    Directories starting with an underscore are only considered in dev mode.
    Successfully parsed manifests are stored in `self._provider_manifests`,
    keyed by provider domain. A failure for one provider is logged and does not
    prevent the other manifests from loading.
    """
    # manifest attribute -> name of the optional icon file that can fill it
    icon_files = (
        ("icon_svg", "icon.svg"),
        ("icon_svg_dark", "icon_dark.svg"),
        ("icon_svg_monochrome", "icon_monochrome.svg"),
    )

    async def load_provider_manifest(provider_domain: str, provider_path: str) -> None:
        """
        Load the manifest (and optional icon files) of a single provider folder.

        :param provider_domain: Name of the provider folder (used for logging).
        :param provider_path: Full path to the provider folder.
        """
        # do the cheap name check on the listing first so that only the manifest
        # itself needs a stat call instead of every file in the folder
        entries = await asyncio.to_thread(os.listdir, provider_path)  # noqa: PTH208, RUF100
        if "manifest.json" not in entries:
            return
        manifest_path = os.path.join(provider_path, "manifest.json")
        if not await isfile(manifest_path):
            return
        try:
            provider_manifest: ProviderManifest = await ProviderManifest.parse(manifest_path)
            # fill icons the manifest does not define from the icon files (if present)
            for attr_name, icon_filename in icon_files:
                if getattr(provider_manifest, attr_name):
                    continue
                icon_path = os.path.join(provider_path, icon_filename)
                if await isfile(icon_path):
                    setattr(provider_manifest, attr_name, await get_icon_string(icon_path))
            # override Home Assistant provider if we're running as add-on
            if provider_manifest.domain == "hass" and self.running_as_hass_addon:
                provider_manifest.builtin = True
                provider_manifest.allow_disable = False

            self._provider_manifests[provider_manifest.domain] = provider_manifest
            LOGGER.debug("Loaded manifest for provider %s", provider_manifest.name)
        except Exception:
            # LOGGER.exception already attaches the active exception/traceback
            LOGGER.exception("Error while loading manifest for provider %s", provider_domain)

    async with TaskManager(self) as tg:
        for dir_str in await asyncio.to_thread(os.listdir, PROVIDERS_PATH):  # noqa: PTH208, RUF100
            if dir_str.startswith("."):
                # skip hidden directories
                continue
            dir_path = os.path.join(PROVIDERS_PATH, dir_str)
            if dir_str.startswith("_") and not self.dev_mode:
                # only load demo/test providers if dev mode is enabled (e.g. for development)
                continue
            if not await isdir(dir_path):
                continue
            tg.create_task(load_provider_manifest(dir_str, dir_path))
974
async def _setup_discovery(self) -> None:
    """
    Set up mDNS discovery and announce this server on the local network.

    A single service browser is created for all mDNS service types requested
    by the known provider manifests. The server itself is published as a
    `_mass._tcp.local.` service. On the first call the service is registered,
    on later calls the existing registration is updated. A name conflict on
    the network is logged and not raised.
    """
    # collect every mdns service type any provider manifest is interested in
    all_types: set[str] = set()
    for prov_manifest in self._provider_manifests.values():
        if prov_manifest.mdns_discovery:
            all_types.update(prov_manifest.mdns_discovery)
    # This method supports being called more than once (see the update/register
    # branch below). Cancel a browser created by an earlier call, otherwise it
    # keeps running next to the new one and every event is delivered twice.
    previous_browser = getattr(self, "_aiobrowser", None)
    if previous_browser is not None:
        await previous_browser.async_cancel()
    # create a global mdns browser
    self._aiobrowser = AsyncServiceBrowser(
        self.aiozc.zeroconf,
        list(all_types),
        handlers=[self._on_mdns_service_state_change],
    )
    # register MA itself on mdns to be discovered
    zeroconf_type = "_mass._tcp.local."
    server_id = self.server_id
    LOGGER.debug("Starting Zeroconf broadcast...")
    info = AsyncServiceInfo(
        zeroconf_type,
        name=f"{server_id}.{zeroconf_type}",
        addresses=[await get_ip_pton(self.webserver.publish_ip)],
        port=self.webserver.publish_port,
        properties=self.get_server_info().to_dict(),
        server="mass.local.",
    )
    try:
        # the flag only exists once a previous call registered the service
        existing = getattr(self, "mass_zc_service_set", None)
        if existing:
            await self.aiozc.async_update_service(info)
        else:
            await self.aiozc.async_register_service(info)
        self.mass_zc_service_set = True
    except NonUniqueNameException:
        LOGGER.error("Music Assistant instance with identical name present in the local network!")
1010
def _on_mdns_service_state_change(
    self,
    zeroconf: Zeroconf,
    service_type: str,
    name: str,
    state_change: ServiceStateChange,
) -> None:
    """
    Dispatch an mDNS service state change to the interested providers.

    Every loaded and available provider whose manifest lists the service type
    gets the event delivered in its own background task.

    :param zeroconf: The Zeroconf instance that reported the change.
    :param service_type: The mDNS service type of the changed service.
    :param name: The full name of the changed service.
    :param state_change: What happened to the service.
    """

    async def process_mdns_state_change(prov: ProviderInstanceType) -> None:
        lock_key = prov.instance_id
        # make sure a lock exists for this provider instance
        self._mdns_locks.setdefault(lock_key, asyncio.Lock())
        # a removed service has no details left to resolve
        info: AsyncServiceInfo | None = None
        if state_change != ServiceStateChange.Removed:
            info = AsyncServiceInfo(service_type, name)
            await info.async_request(zeroconf, 3000)
        # use a lock per provider instance to avoid
        # race conditions in processing mdns events
        async with self._mdns_locks[lock_key]:
            await prov.on_mdns_service_state_change(name, state_change, info)

    LOGGER.log(
        VERBOSE_LOG_LEVEL,
        "Service %s of type %s state changed: %s",
        name,
        service_type,
        state_change,
    )
    for prov in self._providers.values():
        # only available providers that subscribed to mdns discovery qualify
        if not prov.manifest.mdns_discovery or not prov.available:
            continue
        if service_type not in prov.manifest.mdns_discovery:
            continue
        self.create_task(process_mdns_state_change(prov))
1047
1048 async def __aenter__(self) -> Self:
1049 """Return Context manager."""
1050 await self.start()
1051 return self
1052
1053 async def __aexit__(
1054 self,
1055 exc_type: type[BaseException] | None,
1056 exc_val: BaseException | None,
1057 exc_tb: TracebackType | None,
1058 ) -> bool | None:
1059 """Exit context manager."""
1060 await self.stop()
1061 return None
1062
async def _update_available_providers_cache(self) -> None:
    """
    Refresh the global cache values that describe the loaded providers.

    Publishes the sets of provider domains and instance ids (and their union),
    the unique music providers and the split between streaming provider
    domains and non-streaming provider instance ids.
    """
    loaded = list(self.providers)
    domains = {prov.domain for prov in loaded}
    instance_ids = {prov.instance_id for prov in loaded}
    cache_values = {
        "provider_domains": domains,
        "provider_instance_ids": instance_ids,
        # both domains and instance ids can be used to refer to a provider
        "available_providers": domains | instance_ids,
        "unique_providers": self.music.get_unique_providers(),
        "streaming_providers": {
            prov.domain
            for prov in loaded
            if is_music_provider(prov) and prov.is_streaming_provider
        },
        # everything that is not a streaming music provider
        "non_streaming_providers": {
            prov.instance_id
            for prov in loaded
            if not (is_music_provider(prov) and prov.is_streaming_provider)
        },
    }
    await set_global_cache_values(cache_values)
1086
async def _setup_storage(self) -> None:
    """Create the storage and cache folder(s) if they do not exist yet."""
    for folder in (self.storage_path, self.cache_path):
        if await isdir(folder):
            continue
        await mkdirs(folder)
1093