/
/
/
1"""Main Music Assistant class."""
2
3from __future__ import annotations
4
5import asyncio
6import inspect
7import logging
8import os
9import pathlib
10import threading
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
12from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar, cast, overload
13from uuid import uuid4
14
15import aiofiles
16from aiofiles.os import wrap
17from music_assistant_models.api import ServerInfoMessage
18from music_assistant_models.auth import UserRole
19from music_assistant_models.enums import EventType, ProviderType
20from music_assistant_models.errors import MusicAssistantError, SetupFailedError
21from music_assistant_models.event import MassEvent
22from music_assistant_models.helpers import set_global_cache_values
23from music_assistant_models.provider import ProviderManifest
24from zeroconf import (
25 InterfaceChoice,
26 IPVersion,
27 NonUniqueNameException,
28 ServiceStateChange,
29 Zeroconf,
30)
31from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
32
33from music_assistant.constants import (
34 API_SCHEMA_VERSION,
35 CONF_DEFAULT_PROVIDERS_SETUP,
36 CONF_PROVIDERS,
37 CONF_SERVER_ID,
38 CONF_ZEROCONF_INTERFACES,
39 CONFIGURABLE_CORE_CONTROLLERS,
40 DEFAULT_PROVIDERS,
41 MASS_LOGGER_NAME,
42 MIN_SCHEMA_VERSION,
43 VERBOSE_LOG_LEVEL,
44)
45from music_assistant.controllers.cache import CacheController
46from music_assistant.controllers.config import ConfigController
47from music_assistant.controllers.metadata import MetaDataController
48from music_assistant.controllers.music import MusicController
49from music_assistant.controllers.player_queues import PlayerQueuesController
50from music_assistant.controllers.players import PlayerController
51from music_assistant.controllers.streams import StreamsController
52from music_assistant.controllers.webserver import WebserverController
53from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
54from music_assistant.helpers.aiohttp_client import create_clientsession
55from music_assistant.helpers.api import APICommandHandler, api_command
56from music_assistant.helpers.images import get_icon_string
57from music_assistant.helpers.util import (
58 TaskManager,
59 get_ip_pton,
60 get_package_version,
61 is_hass_supervisor,
62 load_provider_module,
63)
64from music_assistant.models import ProviderInstanceType
65from music_assistant.models.music_provider import MusicProvider
66from music_assistant.models.player_provider import PlayerProvider
67
68if TYPE_CHECKING:
69 from types import TracebackType
70
71 from aiohttp import ClientSession
72 from music_assistant_models.config_entries import ProviderConfig
73
74 from music_assistant.models.core_controller import CoreController
75
# Variants of blocking ``os`` helpers produced with aiofiles' ``wrap`` so they can be
# awaited from async code.
isdir = wrap(os.path.isdir)
isfile = wrap(os.path.isfile)
mkdirs = wrap(os.makedirs)
rmfile = wrap(os.remove)
listdir = wrap(os.listdir)
rename = wrap(os.rename)

# An event callback is either a plain function or a coroutine function that
# receives the MassEvent.
EventCallBackType = Callable[[MassEvent], None] | Callable[[MassEvent], Coroutine[Any, Any, None]]
# A subscription is stored as (callback, optional event-type filter, optional object-id filter).
EventSubscriptionType = tuple[
    EventCallBackType, tuple[EventType, ...] | None, tuple[str, ...] | None
]

LOGGER = logging.getLogger(MASS_LOGGER_NAME)

# Directory that contains this module and the "providers" directory directly below it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")

# Generic return type used by the task helpers; provider type used by ``get_provider``.
_R = TypeVar("_R")
_ProviderT = TypeVar("_ProviderT", bound=ProviderInstanceType)
96
def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
    """Narrow *provider* to ``MusicProvider`` when its ``type`` equals ``ProviderType.MUSIC``."""
    provider_kind = provider.type
    return provider_kind == ProviderType.MUSIC
100
101
def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
    """Narrow *provider* to ``PlayerProvider`` when its ``type`` equals ``ProviderType.PLAYER``."""
    provider_kind = provider.type
    return provider_kind == ProviderType.PLAYER
105
106
107class MusicAssistant:
108 """Main MusicAssistant (Server) object."""
109
110 loop: asyncio.AbstractEventLoop
111 aiozc: AsyncZeroconf
112 config: ConfigController
113 webserver: WebserverController
114 cache: CacheController
115 metadata: MetaDataController
116 music: MusicController
117 players: PlayerController
118 player_queues: PlayerQueuesController
119 streams: StreamsController
120 _aiobrowser: AsyncServiceBrowser
121
def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None:
    """Create the server object; only in-memory state is prepared here.

    :param storage_path: Path stored as ``self.storage_path``.
    :param cache_path: Path stored as ``self.cache_path``.
    :param safe_mode: Stored as ``self.safe_mode`` (``start()`` skips loading
        providers when this is set).
    """
    # constructor arguments
    self.storage_path = storage_path
    self.cache_path = cache_path
    self.safe_mode = safe_mode

    # runtime flags, real values are filled in by start()
    self.closing = False
    self.running_as_hass_addon: bool = False
    self.version: str = "0.0.0"

    # dev mode is enabled by PYTHONDEVMODE=1 or by a ".venv" directory next to the package
    if os.environ.get("PYTHONDEVMODE") == "1":
        self.dev_mode = True
    else:
        package_parent = pathlib.Path(__file__).parent.resolve().parent.resolve()
        self.dev_mode = package_parent.joinpath(".venv").exists()

    # command handlers are registered dynamically and consumed by the apis
    self.command_handlers: dict[str, APICommandHandler] = {}
    # event subscriptions
    self._subscribers: set[EventSubscriptionType] = set()
    # provider registry: manifests by domain, loaded instances by instance id
    self._provider_manifests: dict[str, ProviderManifest] = {}
    self._providers: dict[str, ProviderInstanceType] = {}
    # bookkeeping for create_task / call_later
    self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
    self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
    # shared http sessions are created lazily by the http_session properties
    self._http_session: ClientSession | None = None
    self._http_session_no_ssl: ClientSession | None = None
    # one lock per provider instance, used by run_provider_discovery
    self._mdns_locks: dict[str, asyncio.Lock] = {}
144
async def start(self) -> None:
    """Start the server: config, zeroconf, storage, core controllers, API and providers.

    The steps run strictly in the order written below; the webserver/API is set up
    last among the controllers so it is not reachable while startup is in progress.
    """
    running_loop = asyncio.get_running_loop()
    self.loop = running_loop
    # remember the loop's thread id for verify_event_loop_thread
    self.loop_thread_id = getattr(running_loop, "_thread_id")  # noqa: B009
    self.running_as_hass_addon = await is_hass_supervisor()
    detected_version = await get_package_version("music_assistant")
    self.version = detected_version or "0.0.0"

    # the config controller comes first because everything else reads from it
    self.config = ConfigController(self)
    await self.config.setup()

    # shared zeroconf instance
    # IPv6 requires InterfaceChoice.All, so it is only enabled when the setting is "all"
    interface_setting = self.config.get_raw_core_config_value(
        "players", CONF_ZEROCONF_INTERFACES, "default"
    )
    if interface_setting == "all":
        ip_version, interfaces = IPVersion.All, InterfaceChoice.All
    else:
        ip_version, interfaces = IPVersion.V4Only, InterfaceChoice.Default
    self.aiozc = AsyncZeroconf(ip_version=ip_version, interfaces=interfaces)

    # provider manifests, then storage setup/migration
    await self.__load_provider_manifests()
    await self._setup_storage()
    LOGGER.info(
        "Starting Music Assistant Server (%s) version %s - HA add-on: %s - Safe mode: %s",
        self.server_id,
        self.version,
        self.running_as_hass_addon,
        self.safe_mode,
    )

    # instantiate the remaining core controllers
    self.cache = CacheController(self)
    self.webserver = WebserverController(self)
    self.metadata = MetaDataController(self)
    self.music = MusicController(self)
    self.players = PlayerController(self)
    self.player_queues = PlayerQueuesController(self)
    self.streams = StreamsController(self)

    # expose the configurable core controllers through the manifest registry
    for core_name in CONFIGURABLE_CORE_CONTROLLERS:
        core_controller: CoreController = getattr(self, core_name)
        self._provider_manifests[core_controller.domain] = core_controller.manifest

    # set up the controllers one by one; streams is deliberately early so startup
    # aborts quickly if it cannot be loaded
    for core_name in ("cache", "streams", "music", "metadata", "players", "player_queues"):
        core_controller = getattr(self, core_name)
        await core_controller.setup(await self.config.get_core_config(core_name))

    # webserver/api last, so the api/frontend is not yet available
    # while we are starting (or performing migrations)
    self._register_api_commands()
    await self.webserver.setup(await self.config.get_core_config("webserver"))

    await self._setup_discovery()

    # providers are skipped entirely in safe mode
    if not self.safe_mode:
        await self._load_providers()
203
async def stop(self) -> None:
    """Stop the server and release everything that ``start()`` created.

    Order of operations: signal SHUTDOWN, mark the server as closing, cancel
    tracked tasks and pending timers, unload all providers, close the core
    controllers, close config and cache, and finally close the shared HTTP
    sessions together with their connectors.
    """
    LOGGER.info("Stop called, cleaning up...")
    # the event must be sent before ``closing`` is set: signal_event is a no-op afterwards
    self.signal_event(EventType.SHUTDOWN)
    self.closing = True
    # cancel all running tasks
    for task in list(self._tracked_tasks.values()):
        task.cancel()
    # cancel timers scheduled through call_later (e.g. provider load retries),
    # so they cannot fire and spawn new work while we are shutting down
    for timer in list(self._tracked_timers.values()):
        timer.cancel()
    self._tracked_timers.clear()
    # cleanup all providers
    await asyncio.gather(
        *[self.unload_provider(prov_id) for prov_id in list(self._providers.keys())],
        return_exceptions=True,
    )
    # stop core controllers
    await self.streams.close()
    await self.webserver.close()
    await self.metadata.close()
    await self.music.close()
    await self.player_queues.close()
    await self.players.close()
    # cleanup cache and config
    await self.config.close()
    await self.cache.close()
    # close/cleanup the shared http sessions
    for session in (self._http_session, self._http_session_no_ssl):
        if session is None:
            continue
        # detach() clears the session's connector reference, so the connector
        # must be grabbed first; otherwise it would never be closed
        connector = session.connector
        session.detach()
        if connector is not None:
            await connector.close()
236
@property
def server_id(self) -> str:
    """Return the unique ID of this server, or an empty string before config is initialized."""
    config = self.config
    if config.initialized:
        return config.get(CONF_SERVER_ID)  # type: ignore[no-any-return]
    return ""
243
@property
def http_session(self) -> ClientSession:
    """
    Return the shared HTTP client session (with SSL verification).

    The session is created on first access and then reused.

    NOTE: May only be called from the event loop.
    """
    session = self._http_session
    if session is None:
        session = self._http_session = create_clientsession(self, verify_ssl=True)
    return session
254
@property
def http_session_no_ssl(self) -> ClientSession:
    """
    Return the shared HTTP client session (without SSL verification).

    The session is created on first access and then reused.

    NOTE: May only be called from the event loop thread.
    """
    session = self._http_session_no_ssl
    if session is None:
        session = self._http_session_no_ssl = create_clientsession(self, verify_ssl=False)
    return session
265
@api_command("info")
def get_server_info(self) -> ServerInfoMessage:
    """Return a ``ServerInfoMessage`` describing this server (id, versions, url, flags)."""
    info_fields = {
        "server_id": self.server_id,
        "server_version": self.version,
        "schema_version": API_SCHEMA_VERSION,
        "min_supported_schema_version": MIN_SCHEMA_VERSION,
        "base_url": self.webserver.base_url,
        "homeassistant_addon": self.running_as_hass_addon,
        "onboard_done": self.config.onboard_done,
    }
    return ServerInfoMessage(**info_fields)
278
@api_command("providers/manifests")
def get_provider_manifests(self) -> list[ProviderManifest]:
    """Return a new list with every known provider manifest."""
    return [*self._provider_manifests.values()]
283
@api_command("providers/manifests/get")
def get_provider_manifest(self, instance_id_or_domain: str) -> ProviderManifest:
    """Return the manifest for a provider domain or a loaded provider instance.

    :param instance_id_or_domain: Manifest key (domain) or provider instance id.
    :raises KeyError: If neither a manifest nor a provider matches.
    """
    # direct hit in the manifest registry
    try:
        return self._provider_manifests[instance_id_or_domain]
    except KeyError:
        pass
    # otherwise resolve through a loaded provider (unavailable ones included)
    provider = self.get_provider(instance_id_or_domain, return_unavailable=True)
    if provider:
        return provider.manifest
    raise KeyError(f"Provider manifest not found for {instance_id_or_domain}")
292
@api_command("providers")
def get_providers(
    self, provider_type: ProviderType | None = None
) -> list[ProviderInstanceType]:
    """
    Return all loaded/running providers (instances).

    Optionally filtered by ProviderType.
    For a non-admin user with a provider filter, music providers that are not in
    that filter are left out; providers of other types are always included.
    """
    user = get_current_user()
    if user and user.role != UserRole.ADMIN:
        allowed_instances = user.provider_filter
    else:
        allowed_instances = None

    visible: list[ProviderInstanceType] = []
    for prov in list(self._providers.values()):
        type_matches = provider_type is None or provider_type == prov.type
        if not type_matches:
            continue
        # apply user provider filter
        passes_user_filter = (
            not allowed_instances
            or prov.instance_id in allowed_instances
            or prov.type != ProviderType.MUSIC
        )
        if passes_user_filter:
            visible.append(prov)
    return visible
318
@api_command("logging/get", required_role=UserRole.ADMIN)
async def get_application_log(self) -> str:
    """Read ``musicassistant.log`` from the storage path and return its contents."""
    log_path = os.path.join(self.storage_path, "musicassistant.log")
    async with aiofiles.open(log_path) as log_handle:
        contents = await log_handle.read()
        return str(contents)
325
@property
def providers(self) -> list[ProviderInstanceType]:
    """
    Return a new list with all loaded/running providers (instances).

    No user filter is applied here, so this is meant for internal code only.
    """
    return [*self._providers.values()]
334
@overload
def get_provider(
    self,
    provider_instance_or_domain: str,
    return_unavailable: bool = False,
    provider_type: None = None,
) -> ProviderInstanceType | None: ...

@overload
def get_provider(
    self,
    provider_instance_or_domain: str,
    return_unavailable: bool = False,
    *,
    provider_type: type[_ProviderT],
) -> _ProviderT | None: ...

def get_provider(
    self,
    provider_instance_or_domain: str,
    return_unavailable: bool = False,
    provider_type: type[_ProviderT] | None = None,
) -> ProviderInstanceType | _ProviderT | None:
    """Look up a provider by instance id, falling back to a match on domain.

    :param provider_instance_or_domain: Instance ID or domain of the provider.
    :param return_unavailable: Also return unavailable providers.
    :param provider_type: Optional type hint for the expected provider type (unused at runtime).
    """
    wanted_domain = provider_instance_or_domain
    exact_match = self._providers.get(provider_instance_or_domain)
    if exact_match:
        if return_unavailable or exact_match.available:
            return exact_match
        if not getattr(exact_match, "is_streaming_provider", None):
            # this provider has unique data, another instance cannot stand in for it
            return None
        # unavailable streaming provider: try another instance of the same domain
        wanted_domain = exact_match.domain
    for candidate in list(self._providers.values()):
        if candidate.domain == wanted_domain and (return_unavailable or candidate.available):
            return candidate
    return None
379
def get_provider_instances(
    self,
    domain: str,
    return_unavailable: bool = False,
    provider_type: ProviderType | None = None,
) -> list[ProviderInstanceType]:
    """
    Return all provider instances for a given domain.

    No user filter is applied here, so this is meant for internal code only.

    :param domain: Provider domain to match.
    :param return_unavailable: Also include providers that are not available.
    :param provider_type: Optionally only include providers of this type.
    """
    matches: list[ProviderInstanceType] = []
    for prov in list(self._providers.values()):
        if not (provider_type is None or provider_type == prov.type):
            continue
        if prov.domain == domain and (return_unavailable or prov.available):
            matches.append(prov)
    return matches
398
def signal_event(
    self,
    event: EventType,
    object_id: str | None = None,
    data: Any = None,
) -> None:
    """Dispatch an event to every subscriber whose filters match.

    Does nothing once ``self.closing`` is set. Must be called from the event loop
    thread (enforced through ``verify_event_loop_thread``).

    :param event: The event type to signal.
    :param object_id: Optional id the event relates to; matched against id filters.
    :param data: Optional payload placed on the event object.
    """
    if self.closing:
        return

    self.verify_event_loop_thread("signal_event")

    if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
        # NOTE(review): an earlier comment said queue-time-updated events are skipped
        # here, but no such filter is applied - every event is logged at verbose level.
        LOGGER.getChild("event").log(VERBOSE_LOG_LEVEL, "%s %s", event.value, object_id or "")

    event_obj = MassEvent(event=event, object_id=object_id, data=data)
    # iterate over a snapshot so callbacks may (un)subscribe while we dispatch
    for callback, wanted_events, wanted_ids in list(self._subscribers):
        if wanted_events is not None and event not in wanted_events:
            continue
        if wanted_ids is not None and object_id not in wanted_ids:
            continue
        if inspect.iscoroutinefunction(callback):
            # coroutine callbacks run as tracked tasks
            if TYPE_CHECKING:
                callback = cast("Callable[[MassEvent], Coroutine[Any, Any, None]]", callback)
            self.create_task(callback, event_obj)
        else:
            # plain callbacks are scheduled on the loop
            if TYPE_CHECKING:
                callback = cast("Callable[[MassEvent], None]", callback)
            self.loop.call_soon_threadsafe(callback, event_obj)
429
def subscribe(
    self,
    cb_func: EventCallBackType,
    event_filter: EventType | tuple[EventType, ...] | None = None,
    id_filter: str | tuple[str, ...] | None = None,
) -> Callable[[], None]:
    """Add callback to event listeners.

    Returns function to remove the listener. Calling that function more than
    once is safe; later calls are no-ops.

    :param cb_func: callback function or coroutine
    :param event_filter: Optionally only listen for these events
    :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
    """
    # normalize single values to tuples so signal_event can use ``in`` on both filters
    if isinstance(event_filter, EventType):
        event_filter = (event_filter,)
    if isinstance(id_filter, str):
        id_filter = (id_filter,)
    listener = (cb_func, event_filter, id_filter)
    self._subscribers.add(listener)

    def remove_listener() -> None:
        # discard (not remove): a repeated unsubscribe must not raise KeyError,
        # matching the idempotent handle returned by register_api_command
        self._subscribers.discard(listener)

    return remove_listener
454
455 def create_task(
456 self,
457 target: Callable[..., Coroutine[Any, Any, _R]] | Awaitable[_R],
458 *args: Any,
459 task_id: str | None = None,
460 abort_existing: bool = False,
461 eager_start: bool = True,
462 **kwargs: Any,
463 ) -> asyncio.Task[_R]:
464 """Create Task on (main) event loop from Coroutine(function).
465
466 Tasks created by this helper will be properly cancelled on stop.
467
468 :param target: Coroutine function or awaitable to run as a task.
469 :param args: Arguments to pass to the coroutine function.
470 :param task_id: Optional ID to track and deduplicate tasks.
471 :param abort_existing: If True, cancel existing task with same task_id.
472 :param eager_start: If True (default), start task immediately without waiting
473 for next event loop iteration. This ensures proper ordering
474 when creating multiple tasks in sequence.
475 :param kwargs: Keyword arguments to pass to the coroutine function.
476 """
477 if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
478 # prevent duplicate tasks if task_id is given and already present
479 if abort_existing:
480 existing.cancel()
481 else:
482 return existing
483 self.verify_event_loop_thread("create_task")
484
485 if inspect.iscoroutinefunction(target):
486 # coroutine function
487 coro = target(*args, **kwargs)
488 elif inspect.iscoroutine(target):
489 # coroutine
490 coro = target
491 elif callable(target):
492 raise RuntimeError("Function is not a coroutine or coroutine function")
493 else:
494 raise RuntimeError("Target is missing")
495
496 # Use asyncio.Task directly with eager_start for immediate execution
497 task: asyncio.Task[_R] = asyncio.Task(coro, loop=self.loop, eager_start=eager_start)
498
499 if task_id is None:
500 task_id = uuid4().hex
501
502 def task_done_callback(_task: asyncio.Task[Any]) -> None:
503 self._tracked_tasks.pop(task_id, None)
504 # log unhandled exceptions
505 if (
506 LOGGER.isEnabledFor(logging.DEBUG)
507 and not _task.cancelled()
508 and (err := _task.exception())
509 ):
510 task_name = _task.get_name() if hasattr(_task, "get_name") else str(_task)
511 LOGGER.warning(
512 "Exception in task %s - target: %s: %s",
513 task_name,
514 str(target),
515 str(err),
516 exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
517 )
518
519 self._tracked_tasks[task_id] = task
520 task.add_done_callback(task_done_callback)
521 return task
522
def call_later(
    self,
    delay: float,
    target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
    *args: Any,
    task_id: str | None = None,
    **kwargs: Any,
) -> asyncio.TimerHandle:
    """
    Run callable/awaitable after given delay.

    Use task_id for debouncing: a pending timer with the same id is cancelled
    and replaced by the new one.

    :param delay: Delay passed to the event loop's ``call_later``.
    :param target: Coroutine (function) to run as a task, or a regular callable.
    :param args: Positional arguments for the target.
    :param task_id: Optional id of the timer; a random one is generated when omitted.
    :param kwargs: Keyword arguments for the target.
    """
    self.verify_event_loop_thread("call_later")

    timer_id = task_id or uuid4().hex

    # debounce: replace a pending timer that uses the same id
    pending = self._tracked_timers.get(timer_id)
    if pending:
        pending.cancel()

    def _run_as_task(_target: Coroutine[Any, Any, _R]) -> None:
        # the timer fired, so it is no longer pending
        self._tracked_timers.pop(timer_id)
        self.create_task(_target, *args, task_id=timer_id, abort_existing=True, **kwargs)

    def _run_sync(_target: Callable[..., _R]) -> None:
        # the timer fired, so it is no longer pending
        self._tracked_timers.pop(timer_id)
        _target(*args, **kwargs)

    is_coroutine_target = inspect.iscoroutinefunction(target) or inspect.iscoroutine(target)
    if is_coroutine_target:
        # coroutine (function): hand over to create_task when the timer fires
        if TYPE_CHECKING:
            target = cast("Coroutine[Any, Any, _R]", target)
        timer_handle = self.loop.call_later(delay, _run_as_task, target)
    else:
        # regular sync callable
        if TYPE_CHECKING:
            target = cast("Callable[..., _R]", target)
        timer_handle = self.loop.call_later(delay, _run_sync, target)
    self._tracked_timers[timer_id] = timer_handle
    return timer_handle
564
def get_task(self, task_id: str) -> asyncio.Task[Any] | None:
    """Return the tracked task registered under *task_id*, or None when there is none."""
    tracked = self._tracked_tasks.get(task_id)
    return tracked if tracked else None
571
def cancel_task(self, task_id: str) -> None:
    """Stop tracking the task registered under *task_id* and cancel it (no-op if unknown)."""
    tracked = self._tracked_tasks.pop(task_id, None)
    if tracked:
        tracked.cancel()
576
def cancel_timer(self, task_id: str) -> None:
    """Stop tracking the timer registered under *task_id* and cancel it (no-op if unknown)."""
    pending = self._tracked_timers.pop(task_id, None)
    if pending:
        pending.cancel()
581
def register_api_command(
    self,
    command: str,
    handler: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
    authenticated: bool = True,
    required_role: str | None = None,
    alias: bool = False,
) -> Callable[[], None]:
    """Dynamically register a command on the API.

    :param command: The command name/path.
    :param handler: The function to handle the command.
    :param authenticated: Whether authentication is required (default: True).
    :param required_role: Required user role ("admin" or "user")
        None means any authenticated user.
    :param alias: Whether this is an alias for backward compatibility (default: False).
        Aliases are not shown in API documentation but remain functional.
    :raises RuntimeError: If the command is already registered.

    Returns handle to unregister.
    """
    registry = self.command_handlers
    if command in registry:
        raise RuntimeError(f"Command {command} is already registered")
    registry[command] = APICommandHandler.parse(
        command, handler, authenticated, required_role, alias
    )

    def unregister() -> None:
        # popping with a default keeps repeated calls harmless
        self.command_handlers.pop(command, None)

    return unregister
613
async def load_provider_config(
    self,
    prov_conf: ProviderConfig,
) -> None:
    """Load the provider for *prov_conf*, then (re)load enabled providers depending on it.

    Errors raised while loading are not caught here; they propagate to the caller.
    """
    # cancel existing (re)load timer if needed
    timer_key = f"load_provider_{prov_conf.instance_id}"
    pending_timer = self._tracked_timers.pop(timer_key, None)
    if pending_timer:
        pending_timer.cancel()

    await self._load_provider(prov_conf)

    # (re)load any dependants: enabled configs whose manifest depends on this domain
    all_configs = await self.config.get_provider_configs(include_values=True)
    for candidate in all_configs:
        if not candidate.enabled:
            continue
        manifest = self.get_provider_manifest(candidate.domain)
        if manifest.depends_on and manifest.depends_on == prov_conf.domain:
            await self._load_provider(candidate)
636
async def load_provider(
    self,
    instance_id: str,
    allow_retry: bool = False,
) -> None:
    """Load a provider by instance id; failures are logged and stored, not raised.

    :param instance_id: Instance id of the provider config to load.
    :param allow_retry: When True, a failure caused by a ``MusicAssistantError``
        schedules another attempt after 120 seconds.
    """
    try:
        prov_conf = await self.config.get_provider_config(instance_id)
    except KeyError:
        # Was deleted before we could run
        return

    if not prov_conf.enabled:
        # Was disabled before we could run
        return

    # cancel existing (re)load timer if needed
    retry_timer_id = f"load_provider_{instance_id}"
    stale_timer = self._tracked_timers.pop(retry_timer_id, None)
    if stale_timer:
        stale_timer.cancel()

    try:
        await self.load_provider_config(prov_conf)
    except Exception as exc:
        # keep the error on the config object so something useful can be shown to the user
        error_text = str(exc)
        prov_conf.last_error = error_text
        self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", error_text)

        # only handled (MusicAssistantError) failures are retried; anything else
        # (e.g. ValueError) is likely a bug that will not resolve itself
        will_retry = allow_retry and isinstance(exc, MusicAssistantError)
        if will_retry:
            self.call_later(
                120,
                self.load_provider,
                instance_id,
                allow_retry,
                task_id=retry_timer_id,
            )
        LOGGER.warning(
            "Error loading provider(instance) %s: %s%s",
            prov_conf.name or prov_conf.instance_id,
            error_text or exc.__class__.__name__,
            " (will be retried later)" if will_retry else "",
            # log full stack trace if verbose logging is enabled
            exc_info=exc if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
        )
        return

    # after a successful load: unload providers that depend on this domain
    # and are currently not available
    for dependent in self.providers:
        if dependent.available:
            continue
        if dependent.manifest.depends_on == prov_conf.domain:
            await self.unload_provider(dependent.instance_id)
693
async def unload_provider(self, instance_id: str, is_removed: bool = False) -> None:
    """Unload a provider instance (and, first, any providers that depend on it).

    :param instance_id: Instance id of the provider to unload.
    :param is_removed: Passed on to the provider's ``unload`` and used as the
        ``permanent`` flag when unregistering its players.
    """
    self.music.unschedule_provider_sync(instance_id)
    provider = self._providers.get(instance_id)
    if not provider:
        # nothing loaded under this id
        return

    # remove mdns discovery if needed
    for mdns_type in provider.manifest.mdns_discovery or ():
        self._aiobrowser.types.discard(mdns_type)

    # let the players/music controllers react to the provider going away
    if isinstance(provider, PlayerProvider):
        await self.players.on_provider_unload(provider)
    if isinstance(provider, MusicProvider):
        await self.music.on_provider_unload(provider)

    # unload providers that depend on this provider first
    for dependent in self.providers:
        if dependent.manifest.depends_on == provider.domain:
            await self.unload_provider(dependent.instance_id)

    if is_player_provider(provider):
        # unregister all players of this provider
        for player in provider.players:
            await self.players.unregister(player.player_id, permanent=is_removed)

    try:
        await provider.unload(is_removed)
    except Exception as err:
        LOGGER.warning(
            "Error while unloading provider %s: %s", provider.name, str(err), exc_info=err
        )
    finally:
        # the instance is dropped from the registry even when unload failed
        self._providers.pop(instance_id, None)
        await self._update_available_providers_cache()
    self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
724
async def unload_provider_with_error(self, instance_id: str, error: str) -> None:
    """Store *error* as the provider's ``last_error`` config value, then unload the provider.

    Intended for providers that ran into trouble which needs user interaction.
    """
    error_key = f"{CONF_PROVIDERS}/{instance_id}/last_error"
    self.config.set(error_key, error)
    await self.unload_provider(instance_id)
729
async def run_provider_discovery(self, instance_id: str) -> None:
    """
    Run mDNS discovery for a given provider.

    Cached zeroconf names matching one of the provider's mDNS types are
    requested and reported to the provider as ``ServiceStateChange.Added``.
    In case of a PlayerProvider, its own ``discover_players`` is called as well.

    :raises KeyError: If no available provider exists for *instance_id*.
    """
    provider = self.get_provider(instance_id, return_unavailable=False)
    if not provider:
        raise KeyError(f"Provider with instance ID {instance_id} not found")

    if provider.manifest.mdns_discovery:
        # one lock per provider instance, so discovery runs for it do not overlap
        discovery_lock = self._mdns_locks.get(provider.instance_id)
        if discovery_lock is None:
            discovery_lock = self._mdns_locks[provider.instance_id] = asyncio.Lock()
        async with discovery_lock:
            for mdns_type in provider.manifest.mdns_discovery or []:
                for mdns_name in set(self.aiozc.zeroconf.cache.cache):
                    # skip names of other types and the bare type record itself
                    if mdns_type not in mdns_name or mdns_type == mdns_name:
                        continue
                    service_info = AsyncServiceInfo(mdns_type, mdns_name)
                    resolved = await service_info.async_request(self.aiozc.zeroconf, 3000)
                    if resolved:
                        await provider.on_mdns_service_state_change(
                            mdns_name, ServiceStateChange.Added, service_info
                        )

    if isinstance(provider, PlayerProvider):
        await provider.discover_players()
754
def verify_event_loop_thread(self, what: str) -> None:
    """Raise if the caller is not running in the event loop thread.

    :param what: Name of the operation, used in the error message.
    :raises RuntimeError: If the current thread id differs from ``self.loop_thread_id``.
    """
    current_thread_id = threading.get_ident()
    if current_thread_id == self.loop_thread_id:
        return
    raise RuntimeError(
        f"Non-Async operation detected: {what} may only be called from the eventloop."
    )
761
def _register_api_commands(self) -> None:
    """Register every attribute carrying ``api_cmd`` on the server object and its controllers."""
    command_sources = (
        self,
        self.config,
        self.metadata,
        self.music,
        self.players,
        self.player_queues,
        self.webserver,
        self.webserver.auth,
    )
    for source in command_sources:
        candidate_names = [name for name in dir(source) if not name.startswith("__")]
        for attr_name in candidate_names:
            try:
                member = getattr(source, attr_name)
            except (AttributeError, RuntimeError):
                # Skip properties that fail during initialization
                continue
            if not hasattr(member, "api_cmd"):
                continue
            # the member is decorated with our api decorator
            self.register_api_command(
                member.api_cmd,
                member,
                getattr(member, "api_authenticated", True),
                getattr(member, "api_required_role", None),
            )
787
async def _load_providers(self) -> None:
    """Create missing default/builtin provider configs, then load all enabled providers.

    Builtin providers are loaded first and awaited; all other providers are
    loaded concurrently in tracked tasks.
    """
    # create default config for any 'builtin' providers (e.g. URL provider);
    # core controllers are not real providers and are skipped
    for prov_manifest in self._provider_manifests.values():
        if prov_manifest.type == ProviderType.CORE or not prov_manifest.builtin:
            continue
        await self.config.create_builtin_provider_config(prov_manifest.domain)

    # handle default providers setup
    self.config.set_default(CONF_DEFAULT_PROVIDERS_SETUP, set())
    already_setup = cast("set[str]", self.config.get(CONF_DEFAULT_PROVIDERS_SETUP))
    changes_made = False
    for default_provider, require_mdns in DEFAULT_PROVIDERS:
        if default_provider in already_setup:
            # already processed/setup before, skip
            continue
        manifest = self._provider_manifests.get(default_provider)
        if not manifest:
            continue
        if require_mdns:
            # only set the provider up once an mdns entry of one of its
            # discovery types has been seen in the zeroconf cache
            seen_names = set(self.aiozc.zeroconf.cache.cache)
            seen_on_network = any(
                manifest.mdns_discovery
                and any(mdns_type in mdns_name for mdns_type in manifest.mdns_discovery)
                for mdns_name in seen_names
            )
            if not seen_on_network:
                continue
        await self.config.create_builtin_provider_config(manifest.domain)
        changes_made = True
        # TEMP: migration - to be removed after 2.8 release
        # enable all existing players of the default providers if they are not already enabled
        # due to the linked protocol feature we introduced
        player_configs = await self.config.get_player_configs(
            provider=default_provider, include_disabled=True
        )
        for player_config in player_configs:
            if not player_config.enabled:
                await self.config.save_player_config(player_config.player_id, {"enabled": True})
        already_setup.add(default_provider)
    if changes_made:
        self.config.set(CONF_DEFAULT_PROVIDERS_SETUP, already_setup)
        self.config.save(True)

    # split all configured (and enabled) providers into builtin and other
    prov_configs = await self.config.get_provider_configs(include_values=True)
    builtin_configs: list[ProviderConfig] = []
    other_configs: list[ProviderConfig] = []
    for prov_conf in prov_configs:
        if not prov_conf.enabled:
            continue
        manifest = self._provider_manifests.get(prov_conf.domain)
        bucket = builtin_configs if manifest and manifest.builtin else other_configs
        bucket.append(prov_conf)

    # load builtin providers first and wait for them to complete
    builtin_loads = [
        self.load_provider(conf.instance_id, allow_retry=True) for conf in builtin_configs
    ]
    await asyncio.gather(*builtin_loads)

    # load the remaining providers via tasks so they load concurrently and
    # a failing provider does not block the others
    for prov_conf in other_configs:
        self.create_task(self.load_provider(prov_conf.instance_id, allow_retry=True))
858
async def _load_provider(self, conf: ProviderConfig) -> None:
    """
    Load (or reload) a single provider instance.

    An instance with the same instance id that is already loaded is unloaded first.
    Nothing is loaded (and no error is raised) when the manifest declares a dependency
    on another provider that is not loaded (yet).

    :param conf: Configuration of the provider instance to load.
    :raises SetupFailedError: When the provider is disabled, its configuration is invalid,
        its manifest is unknown, another instance of the domain is already loaded while
        only one is allowed, or the setup did not complete within 30 seconds.
    """
    instance_id = conf.instance_id
    # a reload always starts with stopping/unloading the running instance
    await self.unload_provider(instance_id)
    LOGGER.debug("Loading provider %s", conf.name or conf.domain)
    if not conf.enabled:
        msg = "Provider is disabled"
        raise SetupFailedError(msg)

    # the config must pass validation before we attempt anything else
    try:
        conf.validate()
    except (KeyError, ValueError, AttributeError, TypeError) as err:
        msg = "Configuration is invalid"
        raise SetupFailedError(msg) from err

    domain = conf.domain
    manifest = self._provider_manifests.get(domain)
    # a manifest is required (just in case)
    if not manifest:
        msg = f"Provider {domain} manifest not found"
        raise SetupFailedError(msg)
    # refuse a second instance for providers that only allow a single one
    domain_already_loaded = any(prov.domain == domain for prov in self.providers)
    if domain_already_loaded and not manifest.multi_instance:
        msg = f"Provider {domain} already loaded and only one instance allowed."
        raise SetupFailedError(msg)

    # handle dependency on other provider
    dependency = manifest.depends_on
    if dependency and not self.get_provider(dependency):
        # we can safely ignore this completely as the setup will be retried later
        # automatically when the dependency is loaded
        return

    # import the provider module and run its setup, guarded by a timeout
    module = await load_provider_module(domain, manifest.requirements)
    try:
        async with asyncio.timeout(30):
            provider = await module.setup(self, manifest, conf)
    except TimeoutError as err:
        msg = f"Provider {domain} did not load within 30 seconds"
        raise SetupFailedError(msg) from err

    # run async setup
    await provider.handle_async_init()

    # reaching this point means the provider loaded successfully: register it
    self._providers[provider.instance_id] = provider
    LOGGER.info("Loaded %s provider %s", provider.type.value, provider.name)
    provider.available = True

    # adapt logging name if needed
    provider._set_log_level_from_config(provider.config)

    async def _on_provider_loaded() -> None:
        """Run the post load actions for the freshly loaded provider."""
        await provider.loaded_in_mass()
        await self.run_provider_discovery(provider.instance_id)
        # push instance name to config (to persist it if it was autogenerated)
        if provider.default_name != conf.default_name:
            self.config.set_provider_default_name(provider.instance_id, provider.default_name)

    # post load actions run in a task of their own
    self.create_task(_on_provider_loaded())

    # clear any previous error in config and signal update
    self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", None)
    self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
    await self._update_available_providers_cache()
    # let the controller(s) matching the provider type know about the new provider
    if isinstance(provider, MusicProvider):
        await self.music.on_provider_loaded(provider)
    if isinstance(provider, PlayerProvider):
        await self.players.on_provider_loaded(provider)
935
async def __load_provider_manifests(self) -> None:
    """
    Preload the manifest files of all available providers.

    Every (non hidden) subdirectory of PROVIDERS_PATH is scanned for a manifest.json file.
    Directories starting with an underscore (demo/test providers) are only considered
    when dev mode is enabled. Parsed manifests are stored in `self._provider_manifests`,
    keyed by provider domain. A failure to load one manifest is logged and does not
    prevent the other manifests from being loaded.
    """
    # manifest attribute -> icon file (in the provider directory) used to fill it
    # when the manifest itself does not provide a value for that attribute
    icon_fallbacks = (
        ("icon_svg", "icon.svg"),
        ("icon_svg_dark", "icon_dark.svg"),
        ("icon_svg_monochrome", "icon_monochrome.svg"),
    )

    async def load_provider_manifest(provider_domain: str, provider_path: str) -> None:
        """
        Load the manifest (and fallback icons) of a single provider directory.

        :param provider_domain: Name of the provider directory (only used for logging).
        :param provider_path: Full path to the provider directory.
        """
        # get files in subdirectory
        for file_str in await asyncio.to_thread(os.listdir, provider_path):  # noqa: PTH208, RUF100
            # compare the name first: this avoids a filesystem call for every other entry
            if file_str != "manifest.json":
                continue
            file_path = os.path.join(provider_path, file_str)
            if not await isfile(file_path):
                continue
            try:
                provider_manifest: ProviderManifest = await ProviderManifest.parse(file_path)
                # fill icons that are missing in the manifest from the icon files (if present)
                for attr_name, icon_filename in icon_fallbacks:
                    if getattr(provider_manifest, attr_name):
                        continue
                    icon_path = os.path.join(provider_path, icon_filename)
                    if await isfile(icon_path):
                        setattr(provider_manifest, attr_name, await get_icon_string(icon_path))
                # override Home Assistant provider if we're running as add-on
                if provider_manifest.domain == "hass" and self.running_as_hass_addon:
                    provider_manifest.builtin = True
                    provider_manifest.allow_disable = False

                self._provider_manifests[provider_manifest.domain] = provider_manifest
                LOGGER.debug("Loaded manifest for provider %s", provider_manifest.name)
            except Exception:
                # best effort: a broken manifest must not break loading the other providers
                # (LOGGER.exception already attaches the active exception to the record)
                LOGGER.exception(
                    "Error while loading manifest for provider %s",
                    provider_domain,
                )

    # load the manifests of all provider directories, one task per directory
    async with TaskManager(self) as tg:
        for dir_str in await asyncio.to_thread(os.listdir, PROVIDERS_PATH):  # noqa: PTH208, RUF100
            if dir_str.startswith("."):
                # skip hidden directories
                continue
            dir_path = os.path.join(PROVIDERS_PATH, dir_str)
            if dir_str.startswith("_") and not self.dev_mode:
                # only load demo/test providers if debug mode is enabled (e.g. for development)
                continue
            if not await isdir(dir_path):
                continue
            tg.create_task(load_provider_manifest(dir_str, dir_path))
991
async def _setup_discovery(self) -> None:
    """
    Set up MDNS discovery.

    Creates one global service browser for all mdns service types requested by the
    provider manifests and announces this server itself on mdns. When the server was
    announced before, the existing announcement is updated instead of registered again.
    """
    # collect the mdns service types of all provider manifests for one global browser
    discovery_types: set[str] = {
        mdns_type
        for manifest in self._provider_manifests.values()
        if manifest.mdns_discovery
        for mdns_type in manifest.mdns_discovery
    }
    self._aiobrowser = AsyncServiceBrowser(
        self.aiozc.zeroconf,
        list(discovery_types),
        handlers=[self._on_mdns_service_state_change],
    )

    # register MA itself on mdns to be discovered
    mass_service_type = "_mass._tcp.local."
    service_name = f"{self.server_id}.{mass_service_type}"
    LOGGER.debug("Starting Zeroconf broadcast...")
    service_info = AsyncServiceInfo(
        mass_service_type,
        name=service_name,
        addresses=[await get_ip_pton(self.webserver.publish_ip)],
        port=self.webserver.publish_port,
        properties=self.get_server_info().to_dict(),
        server="mass.local.",
    )
    try:
        # the flag attribute only exists once the service was registered
        if getattr(self, "mass_zc_service_set", None):
            await self.aiozc.async_update_service(service_info)
        else:
            await self.aiozc.async_register_service(service_info)
        self.mass_zc_service_set = True
    except NonUniqueNameException:
        LOGGER.error("Music Assistant instance with identical name present in the local network!")
1027
def _on_mdns_service_state_change(
    self,
    zeroconf: Zeroconf,
    service_type: str,
    name: str,
    state_change: ServiceStateChange,
) -> None:
    """
    Handle MDNS service state callback.

    Forwards the state change to every loaded and available provider whose manifest
    lists the service type for mdns discovery. Each provider is handled in its own task.

    :param zeroconf: The Zeroconf instance that reported the change.
    :param service_type: The mdns service type of the service.
    :param name: The name of the service that changed.
    :param state_change: The kind of state change that occurred.
    """

    async def process_mdns_state_change(prov: ProviderInstanceType) -> None:
        """Resolve the service info (unless removed) and pass the change to the provider."""
        instance_id = prov.instance_id
        if instance_id not in self._mdns_locks:
            self._mdns_locks[instance_id] = asyncio.Lock()
        # no info is requested for a service that was removed
        info: AsyncServiceInfo | None = None
        if state_change != ServiceStateChange.Removed:
            info = AsyncServiceInfo(service_type, name)
            await info.async_request(zeroconf, 3000)
        # use a lock per provider instance to avoid
        # race conditions in processing mdns events
        async with self._mdns_locks[instance_id]:
            await prov.on_mdns_service_state_change(name, state_change, info)

    LOGGER.log(
        VERBOSE_LOG_LEVEL,
        "Service %s of type %s state changed: %s",
        name,
        service_type,
        state_change,
    )
    # iterate over a copy of the currently loaded providers
    for prov in list(self._providers.values()):
        wanted_types = prov.manifest.mdns_discovery
        if wanted_types and prov.available and service_type in wanted_types:
            self.create_task(process_mdns_state_change(prov))
1064
async def __aenter__(self) -> Self:
    """Enter the async context manager: call `start()` and return this instance."""
    await self.start()
    return self
1069
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> bool | None:
    """
    Exit the async context manager by calling `stop()`.

    The exception details are accepted as the context manager protocol requires,
    but they are not inspected.
    """
    await self.stop()
    # returning None (falsy) means an exception raised in the context is not suppressed
    return None
1079
async def _update_available_providers_cache(self) -> None:
    """
    Refresh the global cache values describing the loaded providers.

    Publishes the domains and instance ids of all loaded providers, their union, the
    unique providers as reported by the music controller, the domains of the streaming
    music providers and the instance ids of all other providers.
    """
    domains: set[str] = set()
    instance_ids: set[str] = set()
    streaming_domains: set[str] = set()
    non_streaming_instance_ids: set[str] = set()
    for prov in self.providers:
        domains.add(prov.domain)
        instance_ids.add(prov.instance_id)
        # streaming music providers are tracked by domain, everything else by instance id
        if is_music_provider(prov) and prov.is_streaming_provider:
            streaming_domains.add(prov.domain)
        else:
            non_streaming_instance_ids.add(prov.instance_id)

    await set_global_cache_values(
        {
            "provider_domains": domains,
            "provider_instance_ids": instance_ids,
            "available_providers": domains | instance_ids,
            "unique_providers": self.music.get_unique_providers(),
            "streaming_providers": streaming_domains,
            "non_streaming_providers": non_streaming_instance_ids,
        }
    )
1103
async def _setup_storage(self) -> None:
    """Create the storage folder and the cache folder if they do not exist yet."""
    for folder in (self.storage_path, self.cache_path):
        if await isdir(folder):
            continue
        await mkdirs(folder)
1110