/
/
/
1"""Main Music Assistant class."""
2
3from __future__ import annotations
4
5import asyncio
6import inspect
7import logging
8import os
9import pathlib
10import threading
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
12from enum import Enum
13from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar, cast, overload
14from uuid import uuid4
15
16import aiofiles
17from aiofiles.os import wrap
18from music_assistant_models.api import ServerInfoMessage
19from music_assistant_models.auth import UserRole
20from music_assistant_models.enums import EventType, ProviderType
21from music_assistant_models.errors import MusicAssistantError, SetupFailedError
22from music_assistant_models.event import MassEvent
23from music_assistant_models.helpers import set_global_cache_values
24from music_assistant_models.provider import ProviderManifest
25from zeroconf import (
26 InterfaceChoice,
27 IPVersion,
28 NonUniqueNameException,
29 ServiceStateChange,
30 Zeroconf,
31)
32from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
33
34from music_assistant.constants import (
35 API_SCHEMA_VERSION,
36 CONF_DEFAULT_PROVIDERS_SETUP,
37 CONF_PROVIDERS,
38 CONF_SERVER_ID,
39 CONF_ZEROCONF_INTERFACES,
40 CONFIGURABLE_CORE_CONTROLLERS,
41 DEFAULT_PROVIDERS,
42 MASS_LOGGER_NAME,
43 MIN_SCHEMA_VERSION,
44 VERBOSE_LOG_LEVEL,
45)
46from music_assistant.controllers.cache import CacheController
47from music_assistant.controllers.config import ConfigController
48from music_assistant.controllers.metadata import MetaDataController
49from music_assistant.controllers.music import MusicController
50from music_assistant.controllers.player_queues import PlayerQueuesController
51from music_assistant.controllers.players import PlayerController
52from music_assistant.controllers.streams import StreamsController
53from music_assistant.controllers.webserver import WebserverController
54from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
55from music_assistant.helpers.aiohttp_client import create_clientsession
56from music_assistant.helpers.api import APICommandHandler, api_command
57from music_assistant.helpers.images import get_icon_string
58from music_assistant.helpers.util import (
59 TaskManager,
60 get_ip_pton,
61 get_package_version,
62 is_hass_supervisor,
63 load_provider_module,
64)
65from music_assistant.models import ProviderInstanceType
66from music_assistant.models.music_provider import MusicProvider
67from music_assistant.models.player_provider import PlayerProvider
68
69if TYPE_CHECKING:
70 from types import TracebackType
71
72 from aiohttp import ClientSession
73 from music_assistant_models.config_entries import ProviderConfig
74
75 from music_assistant.models.core_controller import CoreController
76
# Awaitable versions of blocking os / os.path filesystem helpers, created with
# aiofiles.os.wrap so they can be awaited from async code.
isdir = wrap(os.path.isdir)
isfile = wrap(os.path.isfile)
mkdirs = wrap(os.makedirs)
rmfile = wrap(os.remove)
listdir = wrap(os.listdir)
rename = wrap(os.rename)

# An event callback is either a plain function or a coroutine function taking a MassEvent.
EventCallBackType = Callable[[MassEvent], None] | Callable[[MassEvent], Coroutine[Any, Any, None]]
# A registered subscription: (callback, event-type filter, object-id filter).
# A filter of None matches every event type / object id (see MusicAssistant.signal_event).
EventSubscriptionType = tuple[
    EventCallBackType, tuple[EventType, ...] | None, tuple[str, ...] | None
]

LOGGER = logging.getLogger(MASS_LOGGER_NAME)

# Absolute directory of this module and the "providers" subdirectory next to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")

# Generic result type of targets passed to create_task / call_later.
_R = TypeVar("_R")
# Concrete provider class used by the typed overload of MusicAssistant.get_provider.
_ProviderT = TypeVar("_ProviderT", bound=ProviderInstanceType)
96
97
def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
    """Narrow *provider* to a MusicProvider when its type is ProviderType.MUSIC."""
    provider_kind = provider.type
    return provider_kind == ProviderType.MUSIC
101
102
def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
    """Narrow *provider* to a PlayerProvider when its type is ProviderType.PLAYER."""
    provider_kind = provider.type
    return provider_kind == ProviderType.PLAYER
106
107
class CoreState(Enum):
    """Enum representing the core state of the Music Assistant server."""

    # initial state, set in MusicAssistant.__init__ and kept while start() runs
    STARTING = "starting"
    # set by start() after core controllers, builtin providers and discovery are set up
    RUNNING = "running"
    # set by stop() while shutdown/cleanup is in progress
    STOPPING = "stopping"
    # set by stop() once all cleanup has completed
    STOPPED = "stopped"
115
116
117class MusicAssistant:
    """Main MusicAssistant (Server) object."""

    # NOTE: these are class-level annotations only (no values are assigned here);
    # the instance attributes are assigned later, mostly in start().
    loop: asyncio.AbstractEventLoop  # the running event loop, assigned in start()
    aiozc: AsyncZeroconf  # shared zeroconf instance, created in start()
    # core controllers, all instantiated in start()
    config: ConfigController
    webserver: WebserverController
    cache: CacheController
    metadata: MetaDataController
    music: MusicController
    players: PlayerController
    player_queues: PlayerQueuesController
    streams: StreamsController
    # mDNS service browser; not assigned in __init__ - presumably created by
    # _setup_discovery (not visible here), TODO confirm
    _aiobrowser: AsyncServiceBrowser
131
132 def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None:
133 """Initialize the MusicAssistant Server."""
134 self._state = CoreState.STARTING
135 self.storage_path = storage_path
136 self.cache_path = cache_path
137 self.safe_mode = safe_mode
138 # we dynamically register command handlers which can be consumed by the apis
139 self.command_handlers: dict[str, APICommandHandler] = {}
140 self._subscribers: set[EventSubscriptionType] = set()
141 self._provider_manifests: dict[str, ProviderManifest] = {}
142 self._providers: dict[str, ProviderInstanceType] = {}
143 self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
144 self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
145 self.running_as_hass_addon: bool = False
146 self.version: str = "0.0.0"
147 self.logger = LOGGER
148 self.dev_mode = (
149 os.environ.get("PYTHONDEVMODE") == "1"
150 or pathlib.Path(__file__).parent.resolve().parent.resolve().joinpath(".venv").exists()
151 )
152 self._http_session: ClientSession | None = None
153 self._http_session_no_ssl: ClientSession | None = None
154 self._mdns_locks: dict[str, asyncio.Lock] = {}
155
    async def start(self) -> None:
        """
        Start running the Music Assistant server.

        Startup order: config controller, shared zeroconf instance, provider manifests,
        storage setup, core controllers (webserver/API first), builtin providers and
        discovery. The state is then set to RUNNING and, unless in safe mode, the
        regular providers are loaded.
        """
        self.loop = asyncio.get_running_loop()
        # NOTE: reads the loop's private `_thread_id` attribute; verify_event_loop_thread
        # compares against it to reject calls made from other threads
        self.loop_thread_id = getattr(self.loop, "_thread_id")  # noqa: B009
        self.running_as_hass_addon = await is_hass_supervisor()
        self.version = await get_package_version("music_assistant") or "0.0.0"
        # setup config controller first and fetch important config values
        self.config = ConfigController(self)
        await self.config.setup()
        # create shared zeroconf instance
        zeroconf_interfaces = self.config.get_raw_core_config_value(
            "players", CONF_ZEROCONF_INTERFACES, "default"
        )
        # IPv6 requires InterfaceChoice.All, so only enable when zeroconf_interfaces is "all"
        use_all_interfaces = zeroconf_interfaces == "all"
        self.aiozc = AsyncZeroconf(
            ip_version=IPVersion.All if use_all_interfaces else IPVersion.V4Only,
            interfaces=InterfaceChoice.All if use_all_interfaces else InterfaceChoice.Default,
        )
        # load all available providers from manifest files
        await self.__load_provider_manifests()
        # setup/migrate storage
        await self._setup_storage()
        LOGGER.info(
            "Starting Music Assistant Server (%s) version %s - HA add-on: %s - Safe mode: %s",
            self.server_id,
            self.version,
            self.running_as_hass_addon,
            self.safe_mode,
        )
        # setup other core controllers
        self.cache = CacheController(self)
        self.webserver = WebserverController(self)
        self.metadata = MetaDataController(self)
        self.music = MusicController(self)
        self.players = PlayerController(self)
        self.player_queues = PlayerQueuesController(self)
        self.streams = StreamsController(self)
        # add manifests for core controllers
        for controller_name in CONFIGURABLE_CORE_CONTROLLERS:
            controller: CoreController = getattr(self, controller_name)
            self._provider_manifests[controller.domain] = controller.manifest
        # load webserver/api first so the api/frontend is available as soon as possible,
        # other controllers are not yet available while we're starting (or performing migrations)
        self._register_api_commands()
        await self.webserver.setup(await self.config.get_core_config("webserver"))
        await self.cache.setup(await self.config.get_core_config("cache"))
        await self.streams.setup(await self.config.get_core_config("streams"))
        await self.music.setup(await self.config.get_core_config("music"))
        await self.metadata.setup(await self.config.get_core_config("metadata"))
        await self.players.setup(await self.config.get_core_config("players"))
        await self.player_queues.setup(await self.config.get_core_config("player_queues"))
        # load builtin providers (always needed, also in safe mode)
        # NOTE: this runs before _setup_discovery, so the mDNS browser does not exist yet
        await self._load_builtin_providers()
        # setup discovery
        await self._setup_discovery()
        # at this point we are fully up and running,
        # set state to running to signal we're ready
        self._state = CoreState.RUNNING
        # load regular providers (skip when in safe mode)
        # providers are loaded in background tasks so they won't block
        # the startup if they fail or take a long time to load
        if not self.safe_mode:
            await self._load_providers()
220
221 async def stop(self) -> None:
222 """Stop running the music assistant server."""
223 LOGGER.info("Stop called, cleaning up...")
224 self.signal_event(EventType.SHUTDOWN)
225 self._state = CoreState.STOPPING
226 # cancel all running tasks
227 for task in list(self._tracked_tasks.values()):
228 task.cancel()
229 # cleanup all providers
230 await asyncio.gather(
231 *[self.unload_provider(prov_id) for prov_id in list(self._providers.keys())],
232 return_exceptions=True,
233 )
234 # stop core controllers
235 await self.streams.close()
236 await self.webserver.close()
237 await self.metadata.close()
238 await self.music.close()
239 await self.player_queues.close()
240 await self.players.close()
241 # cleanup cache and config
242 await self.config.close()
243 await self.cache.close()
244 # close/cleanup shared http session
245 if self._http_session:
246 self._http_session.detach()
247 if self._http_session.connector:
248 await self._http_session.connector.close()
249 if self._http_session_no_ssl:
250 self._http_session_no_ssl.detach()
251 if self._http_session_no_ssl.connector:
252 await self._http_session_no_ssl.connector.close()
253 self._state = CoreState.STOPPED
254
255 @property
256 def state(self) -> CoreState:
257 """Return current state of the core."""
258 return self._state
259
260 @property
261 def closing(self) -> bool:
262 """Return true if the server is (in the process of) closing."""
263 return self._state in (CoreState.STOPPING, CoreState.STOPPED)
264
265 @property
266 def server_id(self) -> str:
267 """Return unique ID of this server."""
268 if not self.config.initialized:
269 return ""
270 return self.config.get(CONF_SERVER_ID) # type: ignore[no-any-return]
271
272 @property
273 def http_session(self) -> ClientSession:
274 """
275 Return the shared HTTP Client session (with SSL).
276
277 NOTE: May only be called from the event loop.
278 """
279 if self._http_session is None:
280 self._http_session = create_clientsession(self, verify_ssl=True)
281 return self._http_session
282
283 @property
284 def http_session_no_ssl(self) -> ClientSession:
285 """
286 Return the shared HTTP Client session (without SSL).
287
288 NOTE: May only be called from the event loop thread.
289 """
290 if self._http_session_no_ssl is None:
291 self._http_session_no_ssl = create_clientsession(self, verify_ssl=False)
292 return self._http_session_no_ssl
293
294 @api_command("info")
295 def get_server_info(self) -> ServerInfoMessage:
296 """Return Info of this server."""
297 return ServerInfoMessage(
298 server_id=self.server_id,
299 server_version=self.version,
300 schema_version=API_SCHEMA_VERSION,
301 min_supported_schema_version=MIN_SCHEMA_VERSION,
302 base_url=self.webserver.base_url,
303 homeassistant_addon=self.running_as_hass_addon,
304 onboard_done=self.config.onboard_done,
305 )
306
307 @api_command("providers/manifests")
308 def get_provider_manifests(self) -> list[ProviderManifest]:
309 """Return all Provider manifests."""
310 return list(self._provider_manifests.values())
311
312 @api_command("providers/manifests/get")
313 def get_provider_manifest(self, instance_id_or_domain: str) -> ProviderManifest:
314 """Return Provider manifests of single provider(domain)."""
315 if instance_id_or_domain in self._provider_manifests:
316 return self._provider_manifests[instance_id_or_domain]
317 if provider := self.get_provider(instance_id_or_domain, return_unavailable=True):
318 return provider.manifest
319 raise KeyError(f"Provider manifest not found for {instance_id_or_domain}")
320
321 @api_command("providers")
322 def get_providers(
323 self, provider_type: ProviderType | None = None
324 ) -> list[ProviderInstanceType]:
325 """
326 Return all loaded/running Providers (instances).
327
328 Optionally filtered by ProviderType.
329 Note that this applies user filters for music providers (for non admin users).
330 """
331 user = get_current_user()
332 user_provider_filter = (
333 user.provider_filter if user and user.role != UserRole.ADMIN else None
334 )
335 return [
336 x
337 for x in list(self._providers.values())
338 if (provider_type is None or provider_type == x.type)
339 # apply user provider filter
340 and (
341 not user_provider_filter
342 or x.instance_id in user_provider_filter
343 or x.type != ProviderType.MUSIC
344 )
345 ]
346
347 @api_command("logging/get", required_role=UserRole.ADMIN)
348 async def get_application_log(self) -> str:
349 """Return the application log from file."""
350 logfile = os.path.join(self.storage_path, "musicassistant.log")
351 async with aiofiles.open(logfile) as _file:
352 return str(await _file.read())
353
354 @property
355 def providers(self) -> list[ProviderInstanceType]:
356 """
357 Return all loaded/running Providers (instances).
358
359 Note that this skips user filters so may only be called from internal code.
360 """
361 return list(self._providers.values())
362
363 @overload
364 def get_provider(
365 self,
366 provider_instance_or_domain: str,
367 return_unavailable: bool = False,
368 provider_type: None = None,
369 ) -> ProviderInstanceType | None: ...
370
371 @overload
372 def get_provider(
373 self,
374 provider_instance_or_domain: str,
375 return_unavailable: bool = False,
376 *,
377 provider_type: type[_ProviderT],
378 ) -> _ProviderT | None: ...
379
380 def get_provider(
381 self,
382 provider_instance_or_domain: str,
383 return_unavailable: bool = False,
384 provider_type: type[_ProviderT] | None = None,
385 ) -> ProviderInstanceType | _ProviderT | None:
386 """Return provider by instance id or domain.
387
388 :param provider_instance_or_domain: Instance ID or domain of the provider.
389 :param return_unavailable: Also return unavailable providers.
390 :param provider_type: Optional type hint for the expected provider type (unused at runtime).
391 """
392 # lookup by instance_id first
393 if prov := self._providers.get(provider_instance_or_domain):
394 if return_unavailable or prov.available:
395 return prov
396 if not getattr(prov, "is_streaming_provider", None):
397 # no need to lookup other instances because this provider has unique data
398 return None
399 provider_instance_or_domain = prov.domain
400 # fallback to match on domain
401 for prov in list(self._providers.values()):
402 if prov.domain != provider_instance_or_domain:
403 continue
404 if return_unavailable or prov.available:
405 return prov
406 return None
407
408 def get_provider_instances(
409 self,
410 domain: str,
411 return_unavailable: bool = False,
412 provider_type: ProviderType | None = None,
413 ) -> list[ProviderInstanceType]:
414 """
415 Return all provider instances for a given domain.
416
417 Note that this skips user filters so may only be called from internal code.
418 """
419 return [
420 prov
421 for prov in list(self._providers.values())
422 if (provider_type is None or provider_type == prov.type)
423 and prov.domain == domain
424 and (return_unavailable or prov.available)
425 ]
426
427 def signal_event(
428 self,
429 event: EventType,
430 object_id: str | None = None,
431 data: Any = None,
432 ) -> None:
433 """Signal event to subscribers."""
434 if self.closing:
435 return
436
437 self.verify_event_loop_thread("signal_event")
438
439 if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
440 # do not log queue time updated events because that is too chatty
441 LOGGER.getChild("event").log(VERBOSE_LOG_LEVEL, "%s %s", event.value, object_id or "")
442
443 event_obj = MassEvent(event=event, object_id=object_id, data=data)
444 for cb_func, event_filter, id_filter in list(self._subscribers):
445 if not (event_filter is None or event in event_filter):
446 continue
447 if not (id_filter is None or object_id in id_filter):
448 continue
449 if inspect.iscoroutinefunction(cb_func):
450 if TYPE_CHECKING:
451 cb_func = cast("Callable[[MassEvent], Coroutine[Any, Any, None]]", cb_func)
452 self.create_task(cb_func, event_obj)
453 else:
454 if TYPE_CHECKING:
455 cb_func = cast("Callable[[MassEvent], None]", cb_func)
456 self.loop.call_soon_threadsafe(cb_func, event_obj)
457
458 def subscribe(
459 self,
460 cb_func: EventCallBackType,
461 event_filter: EventType | tuple[EventType, ...] | None = None,
462 id_filter: str | tuple[str, ...] | None = None,
463 ) -> Callable[[], None]:
464 """Add callback to event listeners.
465
466 Returns function to remove the listener.
467 :param cb_func: callback function or coroutine
468 :param event_filter: Optionally only listen for these events
469 :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
470 """
471 if isinstance(event_filter, EventType):
472 event_filter = (event_filter,)
473 if isinstance(id_filter, str):
474 id_filter = (id_filter,)
475 listener = (cb_func, event_filter, id_filter)
476 self._subscribers.add(listener)
477
478 def remove_listener() -> None:
479 self._subscribers.remove(listener)
480
481 return remove_listener
482
483 def create_task(
484 self,
485 target: Callable[..., Coroutine[Any, Any, _R]] | Awaitable[_R],
486 *args: Any,
487 task_id: str | None = None,
488 abort_existing: bool = False,
489 eager_start: bool = True,
490 **kwargs: Any,
491 ) -> asyncio.Task[_R]:
492 """Create Task on (main) event loop from Coroutine(function).
493
494 Tasks created by this helper will be properly cancelled on stop.
495
496 :param target: Coroutine function or awaitable to run as a task.
497 :param args: Arguments to pass to the coroutine function.
498 :param task_id: Optional ID to track and deduplicate tasks.
499 :param abort_existing: If True, cancel existing task with same task_id.
500 :param eager_start: If True (default), start task immediately without waiting
501 for next event loop iteration. This ensures proper ordering
502 when creating multiple tasks in sequence.
503 :param kwargs: Keyword arguments to pass to the coroutine function.
504 """
505 if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
506 # prevent duplicate tasks if task_id is given and already present
507 if abort_existing:
508 existing.cancel()
509 else:
510 return existing
511 self.verify_event_loop_thread("create_task")
512
513 if inspect.iscoroutinefunction(target):
514 # coroutine function
515 coro = target(*args, **kwargs)
516 elif inspect.iscoroutine(target):
517 # coroutine
518 coro = target
519 elif callable(target):
520 raise RuntimeError("Function is not a coroutine or coroutine function")
521 else:
522 raise RuntimeError("Target is missing")
523
524 # Use asyncio.Task directly with eager_start for immediate execution
525 task: asyncio.Task[_R] = asyncio.Task(coro, loop=self.loop, eager_start=eager_start)
526
527 if task_id is None:
528 task_id = uuid4().hex
529
530 def task_done_callback(_task: asyncio.Task[Any]) -> None:
531 self._tracked_tasks.pop(task_id, None)
532 # log unhandled exceptions
533 if (
534 LOGGER.isEnabledFor(logging.DEBUG)
535 and not _task.cancelled()
536 and (err := _task.exception())
537 ):
538 task_name = _task.get_name() if hasattr(_task, "get_name") else str(_task)
539 LOGGER.warning(
540 "Exception in task %s - target: %s: %s",
541 task_name,
542 str(target),
543 str(err),
544 exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
545 )
546
547 self._tracked_tasks[task_id] = task
548 task.add_done_callback(task_done_callback)
549 return task
550
551 def call_later(
552 self,
553 delay: float,
554 target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
555 *args: Any,
556 task_id: str | None = None,
557 **kwargs: Any,
558 ) -> asyncio.TimerHandle:
559 """
560 Run callable/awaitable after given delay.
561
562 Use task_id for debouncing.
563 """
564 self.verify_event_loop_thread("call_later")
565
566 if not task_id:
567 task_id = uuid4().hex
568
569 if existing := self._tracked_timers.get(task_id):
570 existing.cancel()
571
572 def _create_task(_target: Coroutine[Any, Any, _R]) -> None:
573 self._tracked_timers.pop(task_id)
574 self.create_task(_target, *args, task_id=task_id, abort_existing=True, **kwargs)
575
576 def _call_sync(_target: Callable[..., _R]) -> None:
577 self._tracked_timers.pop(task_id)
578 _target(*args, **kwargs)
579
580 if inspect.iscoroutinefunction(target) or inspect.iscoroutine(target):
581 # coroutine function
582 if TYPE_CHECKING:
583 target = cast("Coroutine[Any, Any, _R]", target)
584 handle = self.loop.call_later(delay, _create_task, target)
585 else:
586 # regular sync callable
587 if TYPE_CHECKING:
588 target = cast("Callable[..., _R]", target)
589 handle = self.loop.call_later(delay, _call_sync, target)
590 self._tracked_timers[task_id] = handle
591 return handle
592
593 def get_task(self, task_id: str) -> asyncio.Task[Any] | None:
594 """Get existing scheduled task."""
595 if existing := self._tracked_tasks.get(task_id):
596 # prevent duplicate tasks if task_id is given and already present
597 return existing
598 return None
599
600 def cancel_task(self, task_id: str) -> None:
601 """Cancel existing scheduled task."""
602 if existing := self._tracked_tasks.pop(task_id, None):
603 existing.cancel()
604
605 def cancel_timer(self, task_id: str) -> None:
606 """Cancel existing scheduled timer."""
607 if existing := self._tracked_timers.pop(task_id, None):
608 existing.cancel()
609
610 def register_api_command(
611 self,
612 command: str,
613 handler: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
614 authenticated: bool = True,
615 required_role: str | None = None,
616 alias: bool = False,
617 ) -> Callable[[], None]:
618 """Dynamically register a command on the API.
619
620 :param command: The command name/path.
621 :param handler: The function to handle the command.
622 :param authenticated: Whether authentication is required (default: True).
623 :param required_role: Required user role ("admin" or "user")
624 None means any authenticated user.
625 :param alias: Whether this is an alias for backward compatibility (default: False).
626 Aliases are not shown in API documentation but remain functional.
627
628 Returns handle to unregister.
629 """
630 if command in self.command_handlers:
631 msg = f"Command {command} is already registered"
632 raise RuntimeError(msg)
633 self.command_handlers[command] = APICommandHandler.parse(
634 command, handler, authenticated, required_role, alias
635 )
636
637 def unregister() -> None:
638 self.command_handlers.pop(command, None)
639
640 return unregister
641
642 async def load_provider_config(
643 self,
644 prov_conf: ProviderConfig,
645 ) -> None:
646 """Try to load a provider and catch errors."""
647 # cancel existing (re)load timer if needed
648 task_id = f"load_provider_{prov_conf.instance_id}"
649 if existing := self._tracked_timers.pop(task_id, None):
650 existing.cancel()
651
652 await self._load_provider(prov_conf)
653
654 # (re)load any dependants
655 prov_configs = await self.config.get_provider_configs(include_values=True)
656 for dep_prov_conf in prov_configs:
657 if not dep_prov_conf.enabled:
658 continue
659 manifest = self.get_provider_manifest(dep_prov_conf.domain)
660 if not manifest.depends_on:
661 continue
662 if manifest.depends_on == prov_conf.domain:
663 await self._load_provider(dep_prov_conf)
664
665 async def load_provider(
666 self,
667 instance_id: str,
668 allow_retry: bool = False,
669 ) -> None:
670 """Try to load a provider and catch errors."""
671 try:
672 prov_conf = await self.config.get_provider_config(instance_id)
673 except KeyError:
674 # Was deleted before we could run
675 return
676
677 if not prov_conf.enabled:
678 # Was disabled before we could run
679 return
680
681 # cancel existing (re)load timer if needed
682 task_id = f"load_provider_{instance_id}"
683 if existing := self._tracked_timers.pop(task_id, None):
684 existing.cancel()
685
686 try:
687 await self.load_provider_config(prov_conf)
688 except Exception as exc:
689 # if loading failed, we store the error in the config object
690 # so we can show something useful to the user
691 prov_conf.last_error = str(exc)
692 self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", str(exc))
693
694 # auto schedule a retry if the (re)load failed with a handled exception
695 # unhandled exceptions (e.g. ValueError) are likely bugs that won't resolve themselves
696 will_retry = allow_retry and isinstance(exc, MusicAssistantError)
697 if will_retry:
698 self.call_later(
699 120,
700 self.load_provider,
701 instance_id,
702 allow_retry,
703 task_id=task_id,
704 )
705 LOGGER.warning(
706 "Error loading provider(instance) %s: %s%s",
707 prov_conf.name or prov_conf.instance_id,
708 str(exc) or exc.__class__.__name__,
709 " (will be retried later)" if will_retry else "",
710 # log full stack trace if verbose logging is enabled
711 exc_info=exc if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
712 )
713 return
714
715 # (re)load any dependents if needed
716 for dep_prov in self.providers:
717 if dep_prov.available:
718 continue
719 if dep_prov.manifest.depends_on == prov_conf.domain:
720 await self.unload_provider(dep_prov.instance_id)
721
722 async def unload_provider(self, instance_id: str, is_removed: bool = False) -> None:
723 """Unload a provider."""
724 self.music.unschedule_provider_sync(instance_id)
725 if provider := self._providers.get(instance_id):
726 # remove mdns discovery if needed
727 if provider.manifest.mdns_discovery:
728 for mdns_type in provider.manifest.mdns_discovery:
729 self._aiobrowser.types.discard(mdns_type)
730 if isinstance(provider, PlayerProvider):
731 await self.players.on_provider_unload(provider)
732 if isinstance(provider, MusicProvider):
733 await self.music.on_provider_unload(provider)
734 # check if there are no other providers dependent of this provider
735 for dep_prov in self.providers:
736 if dep_prov.manifest.depends_on == provider.domain:
737 await self.unload_provider(dep_prov.instance_id)
738 if is_player_provider(provider):
739 # unregister all players of this provider
740 for player in provider.players:
741 await self.players.unregister(player.player_id, permanent=is_removed)
742 try:
743 await provider.unload(is_removed)
744 except Exception as err:
745 LOGGER.warning(
746 "Error while unloading provider %s: %s", provider.name, str(err), exc_info=err
747 )
748 finally:
749 self._providers.pop(instance_id, None)
750 await self._update_available_providers_cache()
751 self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
752
753 async def unload_provider_with_error(self, instance_id: str, error: str) -> None:
754 """Unload a provider when it got into trouble which needs user interaction."""
755 self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", error)
756 await self.unload_provider(instance_id)
757
758 async def run_provider_discovery(self, instance_id: str) -> None:
759 """
760 Run mDNS discovery for a given provider.
761
762 In case of a PlayerProvider, will also call its own discovery method.
763 """
764 provider = self.get_provider(instance_id, return_unavailable=False)
765 if not provider:
766 raise KeyError(f"Provider with instance ID {instance_id} not found")
767 if provider.manifest.mdns_discovery:
768 if provider.instance_id not in self._mdns_locks:
769 self._mdns_locks[provider.instance_id] = asyncio.Lock()
770 async with self._mdns_locks[provider.instance_id]:
771 for mdns_type in provider.manifest.mdns_discovery or []:
772 for mdns_name in set(self.aiozc.zeroconf.cache.cache):
773 if mdns_type not in mdns_name or mdns_type == mdns_name:
774 continue
775 info = AsyncServiceInfo(mdns_type, mdns_name)
776 if await info.async_request(self.aiozc.zeroconf, 3000):
777 await provider.on_mdns_service_state_change(
778 mdns_name, ServiceStateChange.Added, info
779 )
780 if isinstance(provider, PlayerProvider):
781 await provider.discover_players()
782
783 def verify_event_loop_thread(self, what: str) -> None:
784 """Report and raise if we are not running in the event loop thread."""
785 if self.loop_thread_id != threading.get_ident():
786 raise RuntimeError(
787 f"Non-Async operation detected: {what} may only be called from the eventloop."
788 )
789
790 def _register_api_commands(self) -> None:
791 """Register all methods decorated as api_command within a class(instance)."""
792 for cls in (
793 self,
794 self.config,
795 self.metadata,
796 self.music,
797 self.players,
798 self.player_queues,
799 self.webserver,
800 self.webserver.auth,
801 ):
802 for attr_name in dir(cls):
803 if attr_name.startswith("__"):
804 continue
805 try:
806 obj = getattr(cls, attr_name)
807 except (AttributeError, RuntimeError):
808 # Skip properties that fail during initialization
809 continue
810 if hasattr(obj, "api_cmd"):
811 # method is decorated with our api decorator
812 authenticated = getattr(obj, "api_authenticated", True)
813 required_role = getattr(obj, "api_required_role", None)
814 self.register_api_command(obj.api_cmd, obj, authenticated, required_role)
815
816 async def _load_builtin_providers(self) -> None:
817 """
818 Load all builtin providers.
819
820 Builtin providers are always needed (also in safe mode) and are fully awaited.
821 On error, setup will fail.
822 """
823 # create default config for any 'builtin' providers
824 for prov_manifest in self._provider_manifests.values():
825 if prov_manifest.type == ProviderType.CORE:
826 # core controllers are not real providers
827 continue
828 if not prov_manifest.builtin:
829 continue
830 await self.config.create_builtin_provider_config(prov_manifest.domain)
831
832 # load all configured (and enabled) builtin providers
833 prov_configs = await self.config.get_provider_configs(include_values=True)
834 builtin_configs: list[ProviderConfig] = [
835 prov_conf
836 for prov_conf in prov_configs
837 if (manifest := self._provider_manifests.get(prov_conf.domain))
838 and manifest.builtin
839 and (prov_conf.enabled or manifest.allow_disable is False)
840 ]
841
842 # load builtin providers and wait for them to complete
843 await asyncio.gather(
844 *[self.load_provider(conf.instance_id, allow_retry=True) for conf in builtin_configs]
845 )
846
async def _load_providers(self) -> None:
    """
    Load regular (non-builtin) providers from config.

    Regular providers are loaded in background tasks
    and can fail without affecting core setup.

    Before that, every entry of DEFAULT_PROVIDERS that was not processed yet gets
    a provider config created. Entries flagged as requiring mdns are only set up
    once one of their manifest's mdns service types is present in the zeroconf
    cache; otherwise they are simply retried on a later run. Processed domains
    are persisted under CONF_DEFAULT_PROVIDERS_SETUP.
    """
    # handle default providers setup
    self.config.set_default(CONF_DEFAULT_PROVIDERS_SETUP, set())
    # Copy into a fresh set before mutating it: the value is persisted and, if the
    # storage format has no set type, it may come back as a list (which has no .add()).
    default_providers_setup: set[str] = set(
        cast("set[str]", self.config.get(CONF_DEFAULT_PROVIDERS_SETUP)) or ()
    )
    changes_made = False
    for default_provider, require_mdns in DEFAULT_PROVIDERS:
        if default_provider in default_providers_setup:
            # already processed/setup before, skip
            continue
        if not (manifest := self._provider_manifests.get(default_provider)):
            continue
        if require_mdns:
            # only set up the provider once at least one of its mdns service types
            # has been seen on the network (i.e. is present in the zeroconf cache)
            mdns_types = manifest.mdns_discovery or []
            seen_names = set(self.aiozc.zeroconf.cache.cache)
            if not any(
                mdns_type in mdns_name for mdns_name in seen_names for mdns_type in mdns_types
            ):
                continue
        await self.config.create_builtin_provider_config(manifest.domain)
        changes_made = True
        # TEMP: migration - to be removed after 2.8 release
        # enable all existing players of the default providers if they are not already enabled
        # due to the linked protocol feature we introduced
        for player_config in await self.config.get_player_configs(
            provider=default_provider, include_disabled=True
        ):
            if player_config.enabled:
                continue
            await self.config.save_player_config(player_config.player_id, {"enabled": True})
        default_providers_setup.add(default_provider)
    if changes_made:
        self.config.set(CONF_DEFAULT_PROVIDERS_SETUP, default_providers_setup)
        self.config.save(True)
    # load all configured (and enabled) regular (non-builtin) providers;
    # configs without a known manifest are included on purpose so their load
    # attempt fails (and is reported) instead of being silently ignored
    prov_configs = await self.config.get_provider_configs(include_values=True)
    other_configs: list[ProviderConfig] = [
        prov_conf
        for prov_conf in prov_configs
        if prov_conf.enabled
        and (
            not (manifest := self._provider_manifests.get(prov_conf.domain))
            or not manifest.builtin
        )
    ]
    # load providers concurrently via tasks
    for prov_conf in other_configs:
        # Use a task so we can load multiple providers at once.
        # If a provider fails, that will not block the loading of other providers.
        self.create_task(self.load_provider(prov_conf.instance_id, allow_retry=True))
905
async def _load_provider(self, conf: ProviderConfig) -> None:
    """
    Load (or reload) a provider.

    Any provider already loaded under the same instance_id is unloaded first.
    On success the provider is registered in ``self._providers``, marked
    available, its post-load actions are scheduled in the background and the
    music/player controllers are notified.

    If the manifest declares a ``depends_on`` provider that is not loaded yet,
    this returns without loading and without raising.

    :param conf: Config of the provider instance to load.
    :raises SetupFailedError: If the config is disabled or invalid, another
        instance of the domain is loaded while the manifest disallows multiple
        instances, no manifest exists for the domain, or the module's ``setup``
        does not complete within 30 seconds.
    """
    # if provider is already loaded, stop and unload it first
    await self.unload_provider(conf.instance_id)
    LOGGER.debug("Loading provider %s", conf.name or conf.domain)
    if not conf.enabled:
        msg = "Provider is disabled"
        raise SetupFailedError(msg)

    # validate config
    try:
        conf.validate()
    except (KeyError, ValueError, AttributeError, TypeError) as err:
        msg = "Configuration is invalid"
        raise SetupFailedError(msg) from err

    domain = conf.domain
    prov_manifest = self._provider_manifests.get(domain)
    # check for other instances of this provider
    existing = next((x for x in self.providers if x.domain == domain), None)
    if existing and prov_manifest and not prov_manifest.multi_instance:
        msg = f"Provider {domain} already loaded and only one instance allowed."
        raise SetupFailedError(msg)
    # check valid manifest (just in case)
    if not prov_manifest:
        msg = f"Provider {domain} manifest not found"
        raise SetupFailedError(msg)

    # handle dependency on other provider
    if prov_manifest.depends_on and not self.get_provider(prov_manifest.depends_on):
        # we can safely ignore this completely as the setup will be retried later
        # automatically when the dependency is loaded
        return

    # try to setup the module
    prov_mod = await load_provider_module(domain, prov_manifest.requirements)
    try:
        # only the module's setup() call is bounded by this timeout; the module
        # load above and handle_async_init() below are not
        async with asyncio.timeout(30):
            provider = await prov_mod.setup(self, prov_manifest, conf)
    except TimeoutError as err:
        msg = f"Provider {domain} did not load within 30 seconds"
        raise SetupFailedError(msg) from err

    # run async setup
    await provider.handle_async_init()

    # if we reach this point, the provider loaded successfully
    self._providers[provider.instance_id] = provider
    LOGGER.info(
        "Loaded %s provider %s",
        provider.type.value,
        provider.name,
    )
    # mdns events are only routed to providers flagged as available
    provider.available = True

    # apply the log level configured for this provider
    provider._set_log_level_from_config(provider.config)

    # execute post load actions (scheduled as a background task, not awaited here)
    async def _on_provider_loaded() -> None:
        await provider.loaded_in_mass()
        await self.run_provider_discovery(provider.instance_id)
        # push instance name to config (to persist it if it was autogenerated)
        if provider.default_name != conf.default_name:
            self.config.set_provider_default_name(provider.instance_id, provider.default_name)

    self.create_task(_on_provider_loaded())

    # clear any previous error in config and signal update
    self.config.set(f"{CONF_PROVIDERS}/{conf.instance_id}/last_error", None)
    self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
    await self._update_available_providers_cache()
    # hand the new provider to the controllers that manage its provider type
    if isinstance(provider, MusicProvider):
        await self.music.on_provider_loaded(provider)
    if isinstance(provider, PlayerProvider):
        await self.players.on_provider_loaded(provider)
982
async def __load_provider_manifests(self) -> None:
    """
    Preload all available provider manifest files.

    Every non-hidden subdirectory of PROVIDERS_PATH is inspected concurrently;
    subdirectories starting with an underscore (demo/test providers) are only
    inspected in dev mode. Each parsed manifest is stored in
    ``self._provider_manifests`` keyed by its domain. A manifest that fails to
    load is logged and skipped, so one broken provider does not block the rest.
    """
    # Optional icon files that may live next to manifest.json, mapped to the
    # manifest attribute each one fills in when the manifest has no such icon.
    icon_files = (
        ("icon_svg", "icon.svg"),
        ("icon_svg_dark", "icon_dark.svg"),
        ("icon_svg_monochrome", "icon_monochrome.svg"),
    )

    async def load_provider_manifest(provider_domain: str, provider_path: str) -> None:
        """Load the manifest (plus optional icon files) of a single provider directory."""
        # Only manifest.json matters, so check for it directly instead of
        # listing the directory and stat-ing every file in it.
        manifest_path = os.path.join(provider_path, "manifest.json")
        if not await isfile(manifest_path):
            return
        try:
            provider_manifest: ProviderManifest = await ProviderManifest.parse(manifest_path)
            for attr_name, icon_file in icon_files:
                if getattr(provider_manifest, attr_name):
                    # icon already provided by the manifest itself
                    continue
                icon_path = os.path.join(provider_path, icon_file)
                if await isfile(icon_path):
                    setattr(provider_manifest, attr_name, await get_icon_string(icon_path))
            # override Home Assistant provider if we're running as add-on
            if provider_manifest.domain == "hass" and self.running_as_hass_addon:
                provider_manifest.builtin = True
                provider_manifest.allow_disable = False

            self._provider_manifests[provider_manifest.domain] = provider_manifest
            LOGGER.log(
                VERBOSE_LOG_LEVEL, "Loaded manifest for provider %s", provider_manifest.name
            )
        except Exception:
            # best-effort: a broken manifest must not prevent loading the others
            LOGGER.exception(
                "Error while loading manifest for provider %s",
                provider_domain,
            )

    async with TaskManager(self) as tg:
        for dir_str in await asyncio.to_thread(os.listdir, PROVIDERS_PATH):  # noqa: PTH208, RUF100
            if dir_str.startswith("."):
                # skip hidden directories
                continue
            if dir_str.startswith("_") and not self.dev_mode:
                # only load demo/test providers when running in dev mode
                continue
            dir_path = os.path.join(PROVIDERS_PATH, dir_str)
            if not await isdir(dir_path):
                continue
            tg.create_task(load_provider_manifest(dir_str, dir_path))
    self.logger.debug("Loaded %s provider manifests", len(self._provider_manifests))
1041
async def _setup_discovery(self) -> None:
    """
    Handle setup of MDNS discovery.

    Starts one shared mdns browser for the service types listed by all known
    provider manifests (its events are handled by
    ``_on_mdns_service_state_change``) and announces this Music Assistant server
    itself on mdns. Calling it again keeps the existing browser and updates the
    existing service announcement instead of registering it a second time.
    """
    # Create the global mdns browser only once: every browser invokes
    # _on_mdns_service_state_change, so a second one (when this method runs
    # again) would leak the first and dispatch each mdns event twice.
    if getattr(self, "_aiobrowser", None) is None:
        all_types: set[str] = set()
        for prov_manifest in self._provider_manifests.values():
            if prov_manifest.mdns_discovery:
                all_types.update(prov_manifest.mdns_discovery)
        self._aiobrowser = AsyncServiceBrowser(
            self.aiozc.zeroconf,
            list(all_types),
            handlers=[self._on_mdns_service_state_change],
        )
    # register MA itself on mdns to be discovered
    zeroconf_type = "_mass._tcp.local."
    server_id = self.server_id
    LOGGER.debug("Starting Zeroconf broadcast...")
    info = AsyncServiceInfo(
        zeroconf_type,
        name=f"{server_id}.{zeroconf_type}",
        addresses=[await get_ip_pton(self.webserver.publish_ip)],
        port=self.webserver.publish_port,
        properties=self.get_server_info().to_dict(),
        server="mass.local.",
    )
    try:
        if getattr(self, "mass_zc_service_set", False):
            # already announced before: refresh the announcement in place
            await self.aiozc.async_update_service(info)
        else:
            await self.aiozc.async_register_service(info)
            self.mass_zc_service_set = True
    except NonUniqueNameException:
        LOGGER.error(
            "Music Assistant instance with identical name present in the local network!"
        )
1077
def _on_mdns_service_state_change(
    self,
    zeroconf: Zeroconf,
    service_type: str,
    name: str,
    state_change: ServiceStateChange,
) -> None:
    """
    Route a zeroconf service state change to the interested providers.

    Each available provider whose manifest lists ``service_type`` under
    ``mdns_discovery`` gets its own background task. Unless the service was
    removed, that task first requests the service info. It then calls the
    provider's ``on_mdns_service_state_change`` while holding a per-provider lock.
    """
    LOGGER.log(
        VERBOSE_LOG_LEVEL,
        "Service %s of type %s state changed: %s",
        name,
        service_type,
        state_change,
    )

    async def _forward_to(target: ProviderInstanceType) -> None:
        # lazily create the lock that serializes this provider's mdns handling
        self._mdns_locks.setdefault(target.instance_id, asyncio.Lock())
        service_info: AsyncServiceInfo | None = None
        if state_change != ServiceStateChange.Removed:
            # resolve details of the added/updated service (timeout in ms)
            service_info = AsyncServiceInfo(service_type, name)
            await service_info.async_request(zeroconf, 3000)
        # one event at a time per provider instance to avoid race conditions
        async with self._mdns_locks[target.instance_id]:
            await target.on_mdns_service_state_change(name, state_change, service_info)

    interested = [
        candidate
        for candidate in self._providers.values()
        if candidate.manifest.mdns_discovery
        and candidate.available
        and service_type in candidate.manifest.mdns_discovery
    ]
    for candidate in interested:
        self.create_task(_forward_to(candidate))
1114
1115 async def __aenter__(self) -> Self:
1116 """Return Context manager."""
1117 await self.start()
1118 return self
1119
1120 async def __aexit__(
1121 self,
1122 exc_type: type[BaseException] | None,
1123 exc_val: BaseException | None,
1124 exc_tb: TracebackType | None,
1125 ) -> bool | None:
1126 """Exit context manager."""
1127 await self.stop()
1128 return None
1129
async def _update_available_providers_cache(self) -> None:
    """
    Publish the currently loaded providers to the global model-level cache.

    Streaming music providers are published by domain. Every other provider
    is published by instance_id.
    """
    loaded = list(self.providers)
    domains = {prov.domain for prov in loaded}
    instance_ids = {prov.instance_id for prov in loaded}
    streaming_domains: set[str] = set()
    non_streaming_ids: set[str] = set()
    for prov in loaded:
        if is_music_provider(prov) and prov.is_streaming_provider:
            streaming_domains.add(prov.domain)
        else:
            non_streaming_ids.add(prov.instance_id)
    cache_values = {
        "provider_domains": domains,
        "provider_instance_ids": instance_ids,
        "available_providers": domains | instance_ids,
        "unique_providers": self.music.get_unique_providers(),
        "streaming_providers": streaming_domains,
        "non_streaming_providers": non_streaming_ids,
    }
    await set_global_cache_values(cache_values)
1153
async def _setup_storage(self) -> None:
    """Ensure the storage folder and then the cache folder exist, creating any that are missing."""
    for folder in (self.storage_path, self.cache_path):
        if await isdir(folder):
            continue
        await mkdirs(folder)
1160