/
/
/
1"""Main Music Assistant class."""
2
3from __future__ import annotations
4
5import asyncio
6import inspect
7import logging
8import os
9import pathlib
10import threading
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
12from enum import Enum
13from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar, cast, overload
14from uuid import uuid4
15
16import aiofiles
17from aiofiles.os import wrap
18from music_assistant_models.api import ServerInfoMessage
19from music_assistant_models.auth import UserRole
20from music_assistant_models.enums import EventType, ProviderType
21from music_assistant_models.errors import MusicAssistantError, SetupFailedError
22from music_assistant_models.event import MassEvent
23from music_assistant_models.helpers import set_global_cache_values
24from music_assistant_models.provider import ProviderManifest
25from zeroconf import (
26 InterfaceChoice,
27 IPVersion,
28 NonUniqueNameException,
29 ServiceStateChange,
30 Zeroconf,
31)
32from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
33
34from music_assistant.constants import (
35 API_SCHEMA_VERSION,
36 CONF_DEFAULT_PROVIDERS_SETUP,
37 CONF_PROVIDERS,
38 CONF_SERVER_ID,
39 CONF_ZEROCONF_INTERFACES,
40 CONFIGURABLE_CORE_CONTROLLERS,
41 DEFAULT_PROVIDERS,
42 MASS_LOGGER_NAME,
43 MIN_SCHEMA_VERSION,
44 VERBOSE_LOG_LEVEL,
45)
46from music_assistant.controllers.cache import CacheController
47from music_assistant.controllers.config import ConfigController
48from music_assistant.controllers.metadata import MetaDataController
49from music_assistant.controllers.music import MusicController
50from music_assistant.controllers.player_queues import PlayerQueuesController
51from music_assistant.controllers.players import PlayerController
52from music_assistant.controllers.streams import StreamsController
53from music_assistant.controllers.webserver import WebserverController
54from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
55from music_assistant.helpers.aiohttp_client import create_clientsession
56from music_assistant.helpers.api import APICommandHandler, api_command
57from music_assistant.helpers.images import get_icon_string
58from music_assistant.helpers.util import (
59 TaskManager,
60 get_ip_pton,
61 get_package_version,
62 is_hass_supervisor,
63 load_provider_module,
64)
65from music_assistant.models import ProviderInstanceType
66from music_assistant.models.music_provider import MusicProvider
67from music_assistant.models.player_provider import PlayerProvider
68
69if TYPE_CHECKING:
70 from types import TracebackType
71
72 from aiohttp import ClientSession
73 from music_assistant_models.config_entries import ProviderConfig
74
75 from music_assistant.models.core_controller import CoreController
76
77isdir = wrap(os.path.isdir)
78isfile = wrap(os.path.isfile)
79mkdirs = wrap(os.makedirs)
80rmfile = wrap(os.remove)
81listdir = wrap(os.listdir)
82rename = wrap(os.rename)
83
84EventCallBackType = Callable[[MassEvent], None] | Callable[[MassEvent], Coroutine[Any, Any, None]]
85EventSubscriptionType = tuple[
86 EventCallBackType, tuple[EventType, ...] | None, tuple[str, ...] | None
87]
88
89LOGGER = logging.getLogger(MASS_LOGGER_NAME)
90
91BASE_DIR = os.path.dirname(os.path.abspath(__file__))
92PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")
93
94_R = TypeVar("_R")
95_ProviderT = TypeVar("_ProviderT", bound=ProviderInstanceType)
96
97
98def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
99 """Type guard that returns true if a provider is a music provider."""
100 return provider.type == ProviderType.MUSIC
101
102
103def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
104 """Type guard that returns true if a provider is a player provider."""
105 return provider.type == ProviderType.PLAYER
106
107
108class CoreState(Enum):
109 """Enum representing the core state of the Music Assistant server."""
110
111 STARTING = "starting"
112 RUNNING = "running"
113 STOPPING = "stopping"
114 STOPPED = "stopped"
115
116
117class MusicAssistant:
118 """Main MusicAssistant (Server) object."""
119
120 loop: asyncio.AbstractEventLoop
121 aiozc: AsyncZeroconf
122 config: ConfigController
123 webserver: WebserverController
124 cache: CacheController
125 metadata: MetaDataController
126 music: MusicController
127 players: PlayerController
128 player_queues: PlayerQueuesController
129 streams: StreamsController
130 _aiobrowser: AsyncServiceBrowser
131
132 def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None:
133 """Initialize the MusicAssistant Server."""
134 self._state = CoreState.STARTING
135 self.storage_path = storage_path
136 self.cache_path = cache_path
137 self.safe_mode = safe_mode
138 # we dynamically register command handlers which can be consumed by the apis
139 self.command_handlers: dict[str, APICommandHandler] = {}
140 self._subscribers: set[EventSubscriptionType] = set()
141 self._provider_manifests: dict[str, ProviderManifest] = {}
142 self._providers: dict[str, ProviderInstanceType] = {}
143 self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
144 self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
145 self.running_as_hass_addon: bool = False
146 self.version: str = "0.0.0"
147 self.logger = LOGGER
148 self.dev_mode = (
149 os.environ.get("PYTHONDEVMODE") == "1"
150 or pathlib.Path(__file__).parent.resolve().parent.resolve().joinpath(".venv").exists()
151 )
152 self._http_session: ClientSession | None = None
153 self._http_session_no_ssl: ClientSession | None = None
154 self._mdns_locks: dict[str, asyncio.Lock] = {}
155
156 async def start(self) -> None:
157 """Start running the Music Assistant server."""
158 self.loop = asyncio.get_running_loop()
159 self.loop_thread_id = getattr(self.loop, "_thread_id") # noqa: B009
160 self.running_as_hass_addon = await is_hass_supervisor()
161 self.version = await get_package_version("music_assistant") or "0.0.0"
162 # setup config controller first and fetch important config values
163 self.config = ConfigController(self)
164 await self.config.setup()
165 # create shared zeroconf instance
166 zeroconf_interfaces = self.config.get_raw_core_config_value(
167 "players", CONF_ZEROCONF_INTERFACES, "default"
168 )
169 # IPv6 requires InterfaceChoice.All, so only enable when zeroconf_interfaces is "all"
170 use_all_interfaces = zeroconf_interfaces == "all"
171 self.aiozc = AsyncZeroconf(
172 ip_version=IPVersion.All if use_all_interfaces else IPVersion.V4Only,
173 interfaces=InterfaceChoice.All if use_all_interfaces else InterfaceChoice.Default,
174 )
175 # load all available providers from manifest files
176 await self.__load_provider_manifests()
177 # setup/migrate storage
178 await self._setup_storage()
179 LOGGER.info(
180 "Starting Music Assistant Server (%s) version %s - HA add-on: %s - Safe mode: %s",
181 self.server_id,
182 self.version,
183 self.running_as_hass_addon,
184 self.safe_mode,
185 )
186 # setup other core controllers
187 self.cache = CacheController(self)
188 self.webserver = WebserverController(self)
189 self.metadata = MetaDataController(self)
190 self.music = MusicController(self)
191 self.players = PlayerController(self)
192 self.player_queues = PlayerQueuesController(self)
193 self.streams = StreamsController(self)
194 # add manifests for core controllers
195 for controller_name in CONFIGURABLE_CORE_CONTROLLERS:
196 controller: CoreController = getattr(self, controller_name)
197 self._provider_manifests[controller.domain] = controller.manifest
198
199 # setup all core controllers in parallel
200 async def setup_controller(controller: CoreController) -> None:
201 await controller.setup(await self.config.get_core_config(controller.domain))
202 controller.initialized.set()
203
204 async with asyncio.TaskGroup() as tg:
205 tg.create_task(setup_controller(self.cache))
206 tg.create_task(setup_controller(self.streams))
207 tg.create_task(setup_controller(self.music))
208 tg.create_task(setup_controller(self.metadata))
209 tg.create_task(setup_controller(self.players))
210 tg.create_task(setup_controller(self.player_queues))
211
212 # load webserver/api now that the core controllers are setup and ready to be used
213 self._register_api_commands()
214 await self.webserver.setup(await self.config.get_core_config("webserver"))
215 # load builtin providers (always needed, also in safe mode)
216 await self._load_builtin_providers()
217 # setup discovery
218 await self._setup_discovery()
219 # load regular providers (skip when in safe mode)
220 # providers are loaded in background tasks so they won't block
221 # the startup if they fail or take a long time to load
222 if not self.safe_mode:
223 await self._load_providers()
224 # at this point we are fully up and running,
225 # set state to running to signal we're ready
226 self._state = CoreState.RUNNING
227
228 async def stop(self) -> None:
229 """Stop running the music assistant server."""
230 LOGGER.info("Stop called, cleaning up...")
231 self.signal_event(EventType.SHUTDOWN)
232 self._state = CoreState.STOPPING
233 # cancel all running tasks
234 for task in list(self._tracked_tasks.values()):
235 task.cancel()
236 # cleanup all providers
237 await asyncio.gather(
238 *[self.unload_provider(prov_id) for prov_id in list(self._providers.keys())],
239 return_exceptions=True,
240 )
241 # stop core controllers
242 await self.streams.close()
243 await self.webserver.close()
244 await self.metadata.close()
245 await self.music.close()
246 await self.player_queues.close()
247 await self.players.close()
248 # cleanup cache and config
249 await self.config.close()
250 await self.cache.close()
251 # close/cleanup shared http session
252 if self._http_session:
253 self._http_session.detach()
254 if self._http_session.connector:
255 await self._http_session.connector.close()
256 if self._http_session_no_ssl:
257 self._http_session_no_ssl.detach()
258 if self._http_session_no_ssl.connector:
259 await self._http_session_no_ssl.connector.close()
260 self._state = CoreState.STOPPED
261
262 @property
263 def state(self) -> CoreState:
264 """Return current state of the core."""
265 return self._state
266
267 @property
268 def closing(self) -> bool:
269 """Return true if the server is (in the process of) closing."""
270 return self._state in (CoreState.STOPPING, CoreState.STOPPED)
271
272 @property
273 def server_id(self) -> str:
274 """Return unique ID of this server."""
275 if not self.config.initialized:
276 return ""
277 return self.config.get(CONF_SERVER_ID) # type: ignore[no-any-return]
278
279 @property
280 def http_session(self) -> ClientSession:
281 """
282 Return the shared HTTP Client session (with SSL).
283
284 NOTE: May only be called from the event loop.
285 """
286 if self._http_session is None:
287 self._http_session = create_clientsession(self, verify_ssl=True)
288 return self._http_session
289
290 @property
291 def http_session_no_ssl(self) -> ClientSession:
292 """
293 Return the shared HTTP Client session (without SSL).
294
295 NOTE: May only be called from the event loop thread.
296 """
297 if self._http_session_no_ssl is None:
298 self._http_session_no_ssl = create_clientsession(self, verify_ssl=False)
299 return self._http_session_no_ssl
300
301 @api_command("info")
302 def get_server_info(self) -> ServerInfoMessage:
303 """Return Info of this server."""
304 return ServerInfoMessage(
305 server_id=self.server_id,
306 server_version=self.version,
307 schema_version=API_SCHEMA_VERSION,
308 min_supported_schema_version=MIN_SCHEMA_VERSION,
309 base_url=self.webserver.base_url,
310 homeassistant_addon=self.running_as_hass_addon,
311 onboard_done=self.config.onboard_done,
312 )
313
314 @api_command("providers/manifests")
315 def get_provider_manifests(self) -> list[ProviderManifest]:
316 """Return all Provider manifests."""
317 return list(self._provider_manifests.values())
318
319 @api_command("providers/manifests/get")
320 def get_provider_manifest(self, instance_id_or_domain: str) -> ProviderManifest:
321 """Return Provider manifests of single provider(domain)."""
322 if instance_id_or_domain in self._provider_manifests:
323 return self._provider_manifests[instance_id_or_domain]
324 if provider := self.get_provider(instance_id_or_domain, return_unavailable=True):
325 return provider.manifest
326 raise KeyError(f"Provider manifest not found for {instance_id_or_domain}")
327
328 @api_command("providers")
329 def get_providers(
330 self, provider_type: ProviderType | None = None
331 ) -> list[ProviderInstanceType]:
332 """
333 Return all loaded/running Providers (instances).
334
335 Optionally filtered by ProviderType.
336 Note that this applies user filters for music providers (for non admin users).
337 """
338 user = get_current_user()
339 user_provider_filter = (
340 user.provider_filter if user and user.role != UserRole.ADMIN else None
341 )
342 return [
343 x
344 for x in list(self._providers.values())
345 if (provider_type is None or provider_type == x.type)
346 # apply user provider filter
347 and (
348 not user_provider_filter
349 or x.instance_id in user_provider_filter
350 or x.type != ProviderType.MUSIC
351 )
352 ]
353
354 @api_command("logging/get", required_role=UserRole.ADMIN)
355 async def get_application_log(self) -> str:
356 """Return the application log from file."""
357 logfile = os.path.join(self.storage_path, "musicassistant.log")
358 async with aiofiles.open(logfile) as _file:
359 return str(await _file.read())
360
361 @property
362 def providers(self) -> list[ProviderInstanceType]:
363 """
364 Return all loaded/running Providers (instances).
365
366 Note that this skips user filters so may only be called from internal code.
367 """
368 return list(self._providers.values())
369
370 @overload
371 def get_provider(
372 self,
373 provider_instance_or_domain: str,
374 return_unavailable: bool = False,
375 provider_type: None = None,
376 ) -> ProviderInstanceType | None: ...
377
378 @overload
379 def get_provider(
380 self,
381 provider_instance_or_domain: str,
382 return_unavailable: bool = False,
383 *,
384 provider_type: type[_ProviderT],
385 ) -> _ProviderT | None: ...
386
387 def get_provider(
388 self,
389 provider_instance_or_domain: str,
390 return_unavailable: bool = False,
391 provider_type: type[_ProviderT] | None = None,
392 ) -> ProviderInstanceType | _ProviderT | None:
393 """Return provider by instance id or domain.
394
395 :param provider_instance_or_domain: Instance ID or domain of the provider.
396 :param return_unavailable: Also return unavailable providers.
397 :param provider_type: Optional type hint for the expected provider type (unused at runtime).
398 """
399 # lookup by instance_id first
400 if prov := self._providers.get(provider_instance_or_domain):
401 if return_unavailable or prov.available:
402 return prov
403 if not getattr(prov, "is_streaming_provider", None):
404 # no need to lookup other instances because this provider has unique data
405 return None
406 provider_instance_or_domain = prov.domain
407 # fallback to match on domain
408 for prov in list(self._providers.values()):
409 if prov.domain != provider_instance_or_domain:
410 continue
411 if return_unavailable or prov.available:
412 return prov
413 return None
414
415 def get_provider_instances(
416 self,
417 domain: str,
418 return_unavailable: bool = False,
419 provider_type: ProviderType | None = None,
420 ) -> list[ProviderInstanceType]:
421 """
422 Return all provider instances for a given domain.
423
424 Note that this skips user filters so may only be called from internal code.
425 """
426 return [
427 prov
428 for prov in list(self._providers.values())
429 if (provider_type is None or provider_type == prov.type)
430 and prov.domain == domain
431 and (return_unavailable or prov.available)
432 ]
433
434 def signal_event(
435 self,
436 event: EventType,
437 object_id: str | None = None,
438 data: Any = None,
439 ) -> None:
440 """Signal event to subscribers."""
441 if self.closing:
442 return
443
444 self.verify_event_loop_thread("signal_event")
445
446 if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
447 # do not log queue time updated events because that is too chatty
448 LOGGER.getChild("event").log(VERBOSE_LOG_LEVEL, "%s %s", event.value, object_id or "")
449
450 event_obj = MassEvent(event=event, object_id=object_id, data=data)
451 for cb_func, event_filter, id_filter in list(self._subscribers):
452 if not (event_filter is None or event in event_filter):
453 continue
454 if not (id_filter is None or object_id in id_filter):
455 continue
456 if inspect.iscoroutinefunction(cb_func):
457 if TYPE_CHECKING:
458 cb_func = cast("Callable[[MassEvent], Coroutine[Any, Any, None]]", cb_func)
459 self.create_task(cb_func, event_obj)
460 else:
461 if TYPE_CHECKING:
462 cb_func = cast("Callable[[MassEvent], None]", cb_func)
463 self.loop.call_soon_threadsafe(cb_func, event_obj)
464
465 def subscribe(
466 self,
467 cb_func: EventCallBackType,
468 event_filter: EventType | tuple[EventType, ...] | None = None,
469 id_filter: str | tuple[str, ...] | None = None,
470 ) -> Callable[[], None]:
471 """Add callback to event listeners.
472
473 Returns function to remove the listener.
474 :param cb_func: callback function or coroutine
475 :param event_filter: Optionally only listen for these events
476 :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
477 """
478 if isinstance(event_filter, EventType):
479 event_filter = (event_filter,)
480 if isinstance(id_filter, str):
481 id_filter = (id_filter,)
482 listener = (cb_func, event_filter, id_filter)
483 self._subscribers.add(listener)
484
485 def remove_listener() -> None:
486 self._subscribers.remove(listener)
487
488 return remove_listener
489
490 def create_task(
491 self,
492 target: Callable[..., Coroutine[Any, Any, _R]] | Awaitable[_R],
493 *args: Any,
494 task_id: str | None = None,
495 abort_existing: bool = False,
496 eager_start: bool = True,
497 **kwargs: Any,
498 ) -> asyncio.Task[_R]:
499 """Create Task on (main) event loop from Coroutine(function).
500
501 Tasks created by this helper will be properly cancelled on stop.
502
503 :param target: Coroutine function or awaitable to run as a task.
504 :param args: Arguments to pass to the coroutine function.
505 :param task_id: Optional ID to track and deduplicate tasks.
506 :param abort_existing: If True, cancel existing task with same task_id.
507 :param eager_start: If True (default), start task immediately without waiting
508 for next event loop iteration. This ensures proper ordering
509 when creating multiple tasks in sequence.
510 :param kwargs: Keyword arguments to pass to the coroutine function.
511 """
512 if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
513 # prevent duplicate tasks if task_id is given and already present
514 if abort_existing:
515 existing.cancel()
516 else:
517 return existing
518 self.verify_event_loop_thread("create_task")
519
520 if inspect.iscoroutinefunction(target):
521 # coroutine function
522 coro = target(*args, **kwargs)
523 elif inspect.iscoroutine(target):
524 # coroutine
525 coro = target
526 elif callable(target):
527 raise RuntimeError("Function is not a coroutine or coroutine function")
528 else:
529 raise RuntimeError("Target is missing")
530
531 # Use asyncio.Task directly with eager_start for immediate execution
532 task: asyncio.Task[_R] = asyncio.Task(coro, loop=self.loop, eager_start=eager_start)
533
534 if task_id is None:
535 task_id = uuid4().hex
536
537 def task_done_callback(_task: asyncio.Task[Any]) -> None:
538 self._tracked_tasks.pop(task_id, None)
539 # log unhandled exceptions
540 if (
541 LOGGER.isEnabledFor(logging.DEBUG)
542 and not _task.cancelled()
543 and (err := _task.exception())
544 ):
545 task_name = _task.get_name() if hasattr(_task, "get_name") else str(_task)
546 LOGGER.warning(
547 "Exception in task %s - target: %s: %s",
548 task_name,
549 str(target),
550 str(err),
551 exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
552 )
553
554 self._tracked_tasks[task_id] = task
555 task.add_done_callback(task_done_callback)
556 return task
557
558 def call_later(
559 self,
560 delay: float,
561 target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
562 *args: Any,
563 task_id: str | None = None,
564 **kwargs: Any,
565 ) -> asyncio.TimerHandle:
566 """
567 Run callable/awaitable after given delay.
568
569 Use task_id for debouncing.
570 """
571 self.verify_event_loop_thread("call_later")
572
573 if not task_id:
574 task_id = uuid4().hex
575
576 if existing := self._tracked_timers.get(task_id):
577 existing.cancel()
578
579 def _create_task(_target: Coroutine[Any, Any, _R]) -> None:
580 self._tracked_timers.pop(task_id)
581 self.create_task(_target, *args, task_id=task_id, abort_existing=True, **kwargs)
582
583 def _call_sync(_target: Callable[..., _R]) -> None:
584 self._tracked_timers.pop(task_id)
585 _target(*args, **kwargs)
586
587 if inspect.iscoroutinefunction(target) or inspect.iscoroutine(target):
588 # coroutine function
589 if TYPE_CHECKING:
590 target = cast("Coroutine[Any, Any, _R]", target)
591 handle = self.loop.call_later(delay, _create_task, target)
592 else:
593 # regular sync callable
594 if TYPE_CHECKING:
595 target = cast("Callable[..., _R]", target)
596 handle = self.loop.call_later(delay, _call_sync, target)
597 self._tracked_timers[task_id] = handle
598 return handle
599
600 def get_task(self, task_id: str) -> asyncio.Task[Any] | None:
601 """Get existing scheduled task."""
602 if existing := self._tracked_tasks.get(task_id):
603 # prevent duplicate tasks if task_id is given and already present
604 return existing
605 return None
606
607 def cancel_task(self, task_id: str) -> None:
608 """Cancel existing scheduled task."""
609 if existing := self._tracked_tasks.pop(task_id, None):
610 existing.cancel()
611
612 def cancel_timer(self, task_id: str) -> None:
613 """Cancel existing scheduled timer."""
614 if existing := self._tracked_timers.pop(task_id, None):
615 existing.cancel()
616
617 def register_api_command(
618 self,
619 command: str,
620 handler: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
621 authenticated: bool = True,
622 required_role: str | None = None,
623 alias: bool = False,
624 ) -> Callable[[], None]:
625 """Dynamically register a command on the API.
626
627 :param command: The command name/path.
628 :param handler: The function to handle the command.
629 :param authenticated: Whether authentication is required (default: True).
630 :param required_role: Required user role ("admin" or "user")
631 None means any authenticated user.
632 :param alias: Whether this is an alias for backward compatibility (default: False).
633 Aliases are not shown in API documentation but remain functional.
634
635 Returns handle to unregister.
636 """
637 if command in self.command_handlers:
638 msg = f"Command {command} is already registered"
639 raise RuntimeError(msg)
640 self.command_handlers[command] = APICommandHandler.parse(
641 command, handler, authenticated, required_role, alias
642 )
643
644 def unregister() -> None:
645 self.command_handlers.pop(command, None)
646
647 return unregister
648
649 async def load_provider_config(
650 self,
651 prov_conf: ProviderConfig,
652 ) -> None:
653 """Try to load a provider and catch errors."""
654 # cancel existing (re)load timer if needed
655 task_id = f"load_provider_{prov_conf.instance_id}"
656 if existing := self._tracked_timers.pop(task_id, None):
657 existing.cancel()
658
659 await self._load_provider(prov_conf)
660
661 # (re)load any dependants
662 prov_configs = await self.config.get_provider_configs(include_values=True)
663 for dep_prov_conf in prov_configs:
664 if not dep_prov_conf.enabled:
665 continue
666 manifest = self.get_provider_manifest(dep_prov_conf.domain)
667 if not manifest.depends_on:
668 continue
669 if manifest.depends_on == prov_conf.domain:
670 await self._load_provider(dep_prov_conf)
671
672 async def load_provider(
673 self,
674 instance_id: str,
675 allow_retry: bool = False,
676 ) -> None:
677 """Try to load a provider and catch errors."""
678 try:
679 prov_conf = await self.config.get_provider_config(instance_id)
680 except KeyError:
681 # Was deleted before we could run
682 return
683
684 if not prov_conf.enabled:
685 # Was disabled before we could run
686 return
687
688 # cancel existing (re)load timer if needed
689 task_id = f"load_provider_{instance_id}"
690 if existing := self._tracked_timers.pop(task_id, None):
691 existing.cancel()
692
693 try:
694 await self.load_provider_config(prov_conf)
695 except Exception as exc:
696 # if loading failed, we store the error in the config object
697 # so we can show something useful to the user
698 prov_conf.last_error = str(exc)
699 self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", str(exc))
700
701 # auto schedule a retry if the (re)load failed with a handled exception
702 # unhandled exceptions (e.g. ValueError) are likely bugs that won't resolve themselves
703 will_retry = allow_retry and isinstance(exc, MusicAssistantError)
704 if will_retry:
705 self.call_later(
706 120,
707 self.load_provider,
708 instance_id,
709 allow_retry,
710 task_id=task_id,
711 )
712 LOGGER.warning(
713 "Error loading provider(instance) %s: %s%s",
714 prov_conf.name or prov_conf.instance_id,
715 str(exc) or exc.__class__.__name__,
716 " (will be retried later)" if will_retry else "",
717 # log full stack trace if verbose logging is enabled
718 exc_info=exc if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
719 )
720 return
721
722 # (re)load any dependents if needed
723 for dep_prov in self.providers:
724 if dep_prov.available:
725 continue
726 if dep_prov.manifest.depends_on == prov_conf.domain:
727 await self.unload_provider(dep_prov.instance_id)
728
729 async def unload_provider(self, instance_id: str, is_removed: bool = False) -> None:
730 """Unload a provider."""
731 self.music.unschedule_provider_sync(instance_id)
732 if provider := self._providers.get(instance_id):
733 # remove mdns discovery if needed
734 if provider.manifest.mdns_discovery:
735 for mdns_type in provider.manifest.mdns_discovery:
736 self._aiobrowser.types.discard(mdns_type)
737 if isinstance(provider, PlayerProvider):
738 await self.players.on_provider_unload(provider)
739 if isinstance(provider, MusicProvider):
740 await self.music.on_provider_unload(provider)
741 # check if there are no other providers dependent of this provider
742 for dep_prov in self.providers:
743 if dep_prov.manifest.depends_on == provider.domain:
744 await self.unload_provider(dep_prov.instance_id)
745 if is_player_provider(provider):
746 # unregister all players of this provider
747 for player in provider.players:
748 await self.players.unregister(player.player_id, permanent=is_removed)
749 try:
750 await provider.unload(is_removed)
751 except Exception as err:
752 LOGGER.warning(
753 "Error while unloading provider %s: %s", provider.name, str(err), exc_info=err
754 )
755 finally:
756 self._providers.pop(instance_id, None)
757 await self._update_available_providers_cache()
758 self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
759
760 async def unload_provider_with_error(self, instance_id: str, error: str) -> None:
761 """Unload a provider when it got into trouble which needs user interaction."""
762 self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", error)
763 await self.unload_provider(instance_id)
764
765 async def run_provider_discovery(self, instance_id: str) -> None:
766 """
767 Run mDNS discovery for a given provider.
768
769 In case of a PlayerProvider, will also call its own discovery method.
770 """
771 provider = self.get_provider(instance_id, return_unavailable=False)
772 if not provider:
773 raise KeyError(f"Provider with instance ID {instance_id} not found")
774 if provider.manifest.mdns_discovery:
775 if provider.instance_id not in self._mdns_locks:
776 self._mdns_locks[provider.instance_id] = asyncio.Lock()
777 async with self._mdns_locks[provider.instance_id]:
778 for mdns_type in provider.manifest.mdns_discovery or []:
779 for mdns_name in set(self.aiozc.zeroconf.cache.cache):
780 if mdns_type not in mdns_name or mdns_type == mdns_name:
781 continue
782 info = AsyncServiceInfo(mdns_type, mdns_name)
783 if await info.async_request(self.aiozc.zeroconf, 3000):
784 await provider.on_mdns_service_state_change(
785 mdns_name, ServiceStateChange.Added, info
786 )
787 if isinstance(provider, PlayerProvider):
788 await provider.discover_players()
789
790 def verify_event_loop_thread(self, what: str) -> None:
791 """Report and raise if we are not running in the event loop thread."""
792 if self.loop_thread_id != threading.get_ident():
793 raise RuntimeError(
794 f"Non-Async operation detected: {what} may only be called from the eventloop."
795 )
796
797 def _register_api_commands(self) -> None:
798 """Register all methods decorated as api_command within a class(instance)."""
799 for cls in (
800 self,
801 self.config,
802 self.metadata,
803 self.music,
804 self.players,
805 self.player_queues,
806 self.webserver,
807 self.webserver.auth,
808 ):
809 for attr_name in dir(cls):
810 if attr_name.startswith("__"):
811 continue
812 try:
813 obj = getattr(cls, attr_name)
814 except (AttributeError, RuntimeError):
815 # Skip properties that fail during initialization
816 continue
817 if hasattr(obj, "api_cmd"):
818 # method is decorated with our api decorator
819 authenticated = getattr(obj, "api_authenticated", True)
820 required_role = getattr(obj, "api_required_role", None)
821 self.register_api_command(obj.api_cmd, obj, authenticated, required_role)
822
823 async def _load_builtin_providers(self) -> None:
824 """
825 Load all builtin providers.
826
827 Builtin providers are always needed (also in safe mode) and are fully awaited.
828 On error, setup will fail.
829 """
830 # create default config for any 'builtin' providers
831 for prov_manifest in self._provider_manifests.values():
832 if prov_manifest.type == ProviderType.CORE:
833 # core controllers are not real providers
834 continue
835 if not prov_manifest.builtin:
836 continue
837 await self.config.create_builtin_provider_config(prov_manifest.domain)
838
839 # load all configured (and enabled) builtin providers
840 prov_configs = await self.config.get_provider_configs(include_values=True)
841 builtin_configs: list[ProviderConfig] = [
842 prov_conf
843 for prov_conf in prov_configs
844 if (manifest := self._provider_manifests.get(prov_conf.domain))
845 and manifest.builtin
846 and (prov_conf.enabled or manifest.allow_disable is False)
847 ]
848
849 # load builtin providers and wait for them to complete
850 async with asyncio.TaskGroup() as tg:
851 for conf in builtin_configs:
852 tg.create_task(self.load_provider(conf.instance_id, allow_retry=True))
853
854 async def _load_providers(self) -> None:
855 """
856 Load regular (non-builtin) providers from config.
857
858 Regular providers are loaded in background tasks
859 and can fail without affecting core setup.
860 """
861 # handle default providers setup
862 self.config.set_default(CONF_DEFAULT_PROVIDERS_SETUP, set())
863 default_providers_setup = cast("set[str]", self.config.get(CONF_DEFAULT_PROVIDERS_SETUP))
864 changes_made = False
865 for default_provider, require_mdns in DEFAULT_PROVIDERS:
866 if default_provider in default_providers_setup:
867 # already processed/setup before, skip
868 continue
869 if not (manifest := self._provider_manifests.get(default_provider)):
870 continue
871 if require_mdns:
872 # if mdns discovery is required, check if we have seen any mdns entries
873 # for this provider before setting it up
874 for mdns_name in set(self.aiozc.zeroconf.cache.cache):
875 if manifest.mdns_discovery and any(
876 mdns_type in mdns_name for mdns_type in manifest.mdns_discovery
877 ):
878 break
879 else:
880 continue
881 await self.config.create_builtin_provider_config(manifest.domain)
882 changes_made = True
883 # TEMP: migration - to be removed after 2.8 release
884 # enable all existing players of the default providers if they are not already enabled
885 # due to the linked protocol feature we introduced
886 for player_config in await self.config.get_player_configs(
887 provider=default_provider, include_disabled=True
888 ):
889 if player_config.enabled:
890 continue
891 await self.config.save_player_config(player_config.player_id, {"enabled": True})
892 default_providers_setup.add(default_provider)
893 if changes_made:
894 self.config.set(CONF_DEFAULT_PROVIDERS_SETUP, default_providers_setup)
895 self.config.save(True)
896 # load all configured (and enabled) regular (non-builtin) providers
897 prov_configs = await self.config.get_provider_configs(include_values=True)
898 other_configs: list[ProviderConfig] = [
899 prov_conf
900 for prov_conf in prov_configs
901 if prov_conf.enabled
902 and (
903 not (manifest := self._provider_manifests.get(prov_conf.domain))
904 or not manifest.builtin
905 )
906 ]
907 # load providers concurrently via tasks
908 for prov_conf in other_configs:
909 # Use a task so we can load multiple providers at once.
910 # If a provider fails, that will not block the loading of other providers.
911 self.create_task(self.load_provider(prov_conf.instance_id, allow_retry=True))
912
913 async def _load_provider(self, conf: ProviderConfig) -> None:
914 """Load (or reload) a provider."""
915 # if provider is already loaded, stop and unload it first
916 await self.unload_provider(conf.instance_id)
917 LOGGER.debug("Loading provider %s", conf.name or conf.domain)
918 if not conf.enabled:
919 msg = "Provider is disabled"
920 raise SetupFailedError(msg)
921
922 # validate config
923 try:
924 conf.validate()
925 except (KeyError, ValueError, AttributeError, TypeError) as err:
926 msg = "Configuration is invalid"
927 raise SetupFailedError(msg) from err
928
929 domain = conf.domain
930 prov_manifest = self._provider_manifests.get(domain)
931 # check for other instances of this provider
932 existing = next((x for x in self.providers if x.domain == domain), None)
933 if existing and prov_manifest and not prov_manifest.multi_instance:
934 msg = f"Provider {domain} already loaded and only one instance allowed."
935 raise SetupFailedError(msg)
936 # check valid manifest (just in case)
937 if not prov_manifest:
938 msg = f"Provider {domain} manifest not found"
939 raise SetupFailedError(msg)
940
941 # handle dependency on other provider
942 if prov_manifest.depends_on and not self.get_provider(prov_manifest.depends_on):
943 # we can safely ignore this completely as the setup will be retried later
944 # automatically when the dependency is loaded
945 return
946
947 # try to setup the module
948 prov_mod = await load_provider_module(domain, prov_manifest.requirements)
949 try:
950 async with asyncio.timeout(30):
951 provider = await prov_mod.setup(self, prov_manifest, conf)
952 except TimeoutError as err:
953 msg = f"Provider {domain} did not load within 30 seconds"
954 raise SetupFailedError(msg) from err
955
956 # run async setup
957 await provider.handle_async_init()
958
959 # if we reach this point, the provider loaded successfully
960 self._providers[provider.instance_id] = provider
961 LOGGER.info(
962 "Loaded %s provider %s",
963 provider.type.value,
964 provider.name,
965 )
966 provider.available = True
967
968 # adapt logging name if needed
969 provider._set_log_level_from_config(provider.config)
970
971 # execute post load actions
972 async def _on_provider_loaded() -> None:
973 await provider.loaded_in_mass()
974 await self.run_provider_discovery(provider.instance_id)
975 # push instance name to config (to persist it if it was autogenerated)
976 if provider.default_name != conf.default_name:
977 self.config.set_provider_default_name(provider.instance_id, provider.default_name)
978
979 self.create_task(_on_provider_loaded())
980
981 # clear any previous error in config and signal update
982 self.config.set(f"{CONF_PROVIDERS}/{conf.instance_id}/last_error", None)
983 self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
984 await self._update_available_providers_cache()
985 if isinstance(provider, MusicProvider):
986 await self.music.on_provider_loaded(provider)
987 if isinstance(provider, PlayerProvider):
988 await self.players.on_provider_loaded(provider)
989
990 async def __load_provider_manifests(self) -> None:
991 """Preload all available provider manifest files."""
992
993 async def load_provider_manifest(provider_domain: str, provider_path: str) -> None:
994 """Preload all available provider manifest files."""
995 # get files in subdirectory
996 for file_str in await asyncio.to_thread(os.listdir, provider_path): # noqa: PTH208, RUF100
997 file_path = os.path.join(provider_path, file_str)
998 if not await isfile(file_path):
999 continue
1000 if file_str != "manifest.json":
1001 continue
1002 try:
1003 provider_manifest: ProviderManifest = await ProviderManifest.parse(file_path)
1004 # check for icon.svg file
1005 if not provider_manifest.icon_svg:
1006 icon_path = os.path.join(provider_path, "icon.svg")
1007 if await isfile(icon_path):
1008 provider_manifest.icon_svg = await get_icon_string(icon_path)
1009 # check for dark_icon file
1010 if not provider_manifest.icon_svg_dark:
1011 icon_path = os.path.join(provider_path, "icon_dark.svg")
1012 if await isfile(icon_path):
1013 provider_manifest.icon_svg_dark = await get_icon_string(icon_path)
1014 # check for icon_monochrome file
1015 if not provider_manifest.icon_svg_monochrome:
1016 icon_path = os.path.join(provider_path, "icon_monochrome.svg")
1017 if await isfile(icon_path):
1018 provider_manifest.icon_svg_monochrome = await get_icon_string(icon_path)
1019 # override Home Assistant provider if we're running as add-on
1020 if provider_manifest.domain == "hass" and self.running_as_hass_addon:
1021 provider_manifest.builtin = True
1022 provider_manifest.allow_disable = False
1023
1024 self._provider_manifests[provider_manifest.domain] = provider_manifest
1025 LOGGER.log(
1026 VERBOSE_LOG_LEVEL, "Loaded manifest for provider %s", provider_manifest.name
1027 )
1028 except Exception as exc:
1029 LOGGER.exception(
1030 "Error while loading manifest for provider %s",
1031 provider_domain,
1032 exc_info=exc,
1033 )
1034
1035 async with TaskManager(self) as tg:
1036 for dir_str in await asyncio.to_thread(os.listdir, PROVIDERS_PATH): # noqa: PTH208, RUF100
1037 if dir_str.startswith("."):
1038 # skip hidden directories
1039 continue
1040 dir_path = os.path.join(PROVIDERS_PATH, dir_str)
1041 if dir_str.startswith("_") and not self.dev_mode:
1042 # only load demo/test providers if debug mode is enabled (e.g. for development)
1043 continue
1044 if not await isdir(dir_path):
1045 continue
1046 tg.create_task(load_provider_manifest(dir_str, dir_path))
1047 self.logger.debug("Loaded %s provider manifests", len(self._provider_manifests))
1048
1049 async def _setup_discovery(self) -> None:
1050 """Handle setup of MDNS discovery."""
1051 # create a global mdns browser
1052 all_types: set[str] = set()
1053 for prov_manifest in self._provider_manifests.values():
1054 if prov_manifest.mdns_discovery:
1055 all_types.update(prov_manifest.mdns_discovery)
1056 self._aiobrowser = AsyncServiceBrowser(
1057 self.aiozc.zeroconf,
1058 list(all_types),
1059 handlers=[self._on_mdns_service_state_change],
1060 )
1061 # register MA itself on mdns to be discovered
1062 zeroconf_type = "_mass._tcp.local."
1063 server_id = self.server_id
1064 LOGGER.debug("Starting Zeroconf broadcast...")
1065 info = AsyncServiceInfo(
1066 zeroconf_type,
1067 name=f"{server_id}.{zeroconf_type}",
1068 addresses=[await get_ip_pton(self.webserver.publish_ip)],
1069 port=self.webserver.publish_port,
1070 properties=self.get_server_info().to_dict(),
1071 server="mass.local.",
1072 )
1073 try:
1074 existing = getattr(self, "mass_zc_service_set", None)
1075 if existing:
1076 await self.aiozc.async_update_service(info)
1077 else:
1078 await self.aiozc.async_register_service(info)
1079 self.mass_zc_service_set = True
1080 except NonUniqueNameException:
1081 LOGGER.error(
1082 "Music Assistant instance with identical name present in the local network!"
1083 )
1084
1085 def _on_mdns_service_state_change(
1086 self,
1087 zeroconf: Zeroconf,
1088 service_type: str,
1089 name: str,
1090 state_change: ServiceStateChange,
1091 ) -> None:
1092 """Handle MDNS service state callback."""
1093
1094 async def process_mdns_state_change(prov: ProviderInstanceType) -> None:
1095 if prov.instance_id not in self._mdns_locks:
1096 self._mdns_locks[prov.instance_id] = asyncio.Lock()
1097 if state_change == ServiceStateChange.Removed:
1098 info = None
1099 else:
1100 info = AsyncServiceInfo(service_type, name)
1101 await info.async_request(zeroconf, 3000)
1102 # use a lock per provider instance to avoid
1103 # race conditions in processing mdns events
1104 async with self._mdns_locks[prov.instance_id]:
1105 await prov.on_mdns_service_state_change(name, state_change, info)
1106
1107 LOGGER.log(
1108 VERBOSE_LOG_LEVEL,
1109 "Service %s of type %s state changed: %s",
1110 name,
1111 service_type,
1112 state_change,
1113 )
1114 for prov in list(self._providers.values()):
1115 if not prov.manifest.mdns_discovery:
1116 continue
1117 if not prov.available:
1118 continue
1119 if service_type in prov.manifest.mdns_discovery:
1120 self.create_task(process_mdns_state_change(prov))
1121
1122 async def __aenter__(self) -> Self:
1123 """Return Context manager."""
1124 await self.start()
1125 return self
1126
1127 async def __aexit__(
1128 self,
1129 exc_type: type[BaseException] | None,
1130 exc_val: BaseException | None,
1131 exc_tb: TracebackType | None,
1132 ) -> bool | None:
1133 """Exit context manager."""
1134 await self.stop()
1135 return None
1136
1137 async def _update_available_providers_cache(self) -> None:
1138 """Update the global cache variable of loaded/available providers."""
1139 await set_global_cache_values(
1140 {
1141 "provider_domains": {x.domain for x in self.providers},
1142 "provider_instance_ids": {x.instance_id for x in self.providers},
1143 "available_providers": {
1144 *{x.domain for x in self.providers},
1145 *{x.instance_id for x in self.providers},
1146 },
1147 "unique_providers": self.music.get_unique_providers(),
1148 "streaming_providers": {
1149 x.domain
1150 for x in self.providers
1151 if is_music_provider(x) and x.is_streaming_provider
1152 },
1153 "non_streaming_providers": {
1154 x.instance_id
1155 for x in self.providers
1156 if not (is_music_provider(x) and x.is_streaming_provider)
1157 },
1158 }
1159 )
1160
1161 async def _setup_storage(self) -> None:
1162 """Handle Setup of storage/cache folder(s)."""
1163 if not await isdir(self.storage_path):
1164 await mkdirs(self.storage_path)
1165 if not await isdir(self.cache_path):
1166 await mkdirs(self.cache_path)
1167