music-assistant-server

47.7 KBPY
mass.py
47.7 KB1,167 lines • python
1"""Main Music Assistant class."""
2
3from __future__ import annotations
4
5import asyncio
6import inspect
7import logging
8import os
9import pathlib
10import threading
11from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
12from enum import Enum
13from typing import TYPE_CHECKING, Any, Self, TypeGuard, TypeVar, cast, overload
14from uuid import uuid4
15
16import aiofiles
17from aiofiles.os import wrap
18from music_assistant_models.api import ServerInfoMessage
19from music_assistant_models.auth import UserRole
20from music_assistant_models.enums import EventType, ProviderType
21from music_assistant_models.errors import MusicAssistantError, SetupFailedError
22from music_assistant_models.event import MassEvent
23from music_assistant_models.helpers import set_global_cache_values
24from music_assistant_models.provider import ProviderManifest
25from zeroconf import (
26    InterfaceChoice,
27    IPVersion,
28    NonUniqueNameException,
29    ServiceStateChange,
30    Zeroconf,
31)
32from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
33
34from music_assistant.constants import (
35    API_SCHEMA_VERSION,
36    CONF_DEFAULT_PROVIDERS_SETUP,
37    CONF_PROVIDERS,
38    CONF_SERVER_ID,
39    CONF_ZEROCONF_INTERFACES,
40    CONFIGURABLE_CORE_CONTROLLERS,
41    DEFAULT_PROVIDERS,
42    MASS_LOGGER_NAME,
43    MIN_SCHEMA_VERSION,
44    VERBOSE_LOG_LEVEL,
45)
46from music_assistant.controllers.cache import CacheController
47from music_assistant.controllers.config import ConfigController
48from music_assistant.controllers.metadata import MetaDataController
49from music_assistant.controllers.music import MusicController
50from music_assistant.controllers.player_queues import PlayerQueuesController
51from music_assistant.controllers.players import PlayerController
52from music_assistant.controllers.streams import StreamsController
53from music_assistant.controllers.webserver import WebserverController
54from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
55from music_assistant.helpers.aiohttp_client import create_clientsession
56from music_assistant.helpers.api import APICommandHandler, api_command
57from music_assistant.helpers.images import get_icon_string
58from music_assistant.helpers.util import (
59    TaskManager,
60    get_ip_pton,
61    get_package_version,
62    is_hass_supervisor,
63    load_provider_module,
64)
65from music_assistant.models import ProviderInstanceType
66from music_assistant.models.music_provider import MusicProvider
67from music_assistant.models.player_provider import PlayerProvider
68
69if TYPE_CHECKING:
70    from types import TracebackType
71
72    from aiohttp import ClientSession
73    from music_assistant_models.config_entries import ProviderConfig
74
75    from music_assistant.models.core_controller import CoreController
76
# Awaitable wrappers (aiofiles `wrap`) around blocking os / os.path helpers,
# so filesystem calls can be awaited instead of blocking the event loop.
isdir = wrap(os.path.isdir)
isfile = wrap(os.path.isfile)
mkdirs = wrap(os.makedirs)
rmfile = wrap(os.remove)
listdir = wrap(os.listdir)
rename = wrap(os.rename)

# An event subscriber may be a plain function or a coroutine function taking a MassEvent.
EventCallBackType = Callable[[MassEvent], None] | Callable[[MassEvent], Coroutine[Any, Any, None]]
# Stored subscription: (callback, event-type filter, object-id filter); a None filter matches all.
EventSubscriptionType = tuple[
    EventCallBackType, tuple[EventType, ...] | None, tuple[str, ...] | None
]

LOGGER = logging.getLogger(MASS_LOGGER_NAME)

# Directory containing this module, and the "providers" sub-directory next to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROVIDERS_PATH = os.path.join(BASE_DIR, "providers")

# Result type of coroutines/callables scheduled through create_task()/call_later().
_R = TypeVar("_R")
# Concrete provider class requested through get_provider(..., provider_type=...).
_ProviderT = TypeVar("_ProviderT", bound=ProviderInstanceType)
96
97
def is_music_provider(provider: ProviderInstanceType) -> TypeGuard[MusicProvider]:
    """Narrow ``provider`` to a MusicProvider when its declared type is MUSIC.

    Only the provider's ``type`` attribute is inspected, not its class.
    """
    provider_kind = provider.type
    return provider_kind == ProviderType.MUSIC
101
102
def is_player_provider(provider: ProviderInstanceType) -> TypeGuard[PlayerProvider]:
    """Narrow ``provider`` to a PlayerProvider when its declared type is PLAYER.

    Only the provider's ``type`` attribute is inspected, not its class.
    """
    provider_kind = provider.type
    return provider_kind == ProviderType.PLAYER
106
107
class CoreState(Enum):
    """Lifecycle state of the Music Assistant server core.

    ``MusicAssistant.__init__`` sets STARTING, ``start()`` sets RUNNING once all
    setup has finished, and ``stop()`` moves through STOPPING to STOPPED.
    """

    STARTING = "starting"  # initial state until start() completes
    RUNNING = "running"  # start() finished; the server is fully up
    STOPPING = "stopping"  # stop() is tearing down tasks, providers and controllers
    STOPPED = "stopped"  # stop() finished its cleanup
115
116
class MusicAssistant:
    """Main MusicAssistant (Server) object.

    Holds the core controllers, the shared zeroconf instance, the API command
    registry, event subscribers and all registered provider instances. The
    controller/loop attributes annotated below are assigned in ``start()``,
    not in ``__init__``.
    """

    loop: asyncio.AbstractEventLoop  # the running loop, captured in start()
    aiozc: AsyncZeroconf  # shared zeroconf instance, created in start()
    config: ConfigController
    webserver: WebserverController
    cache: CacheController
    metadata: MetaDataController
    music: MusicController
    players: PlayerController
    player_queues: PlayerQueuesController
    streams: StreamsController
    # mDNS service browser; presumably assigned during discovery setup (not shown here)
    _aiobrowser: AsyncServiceBrowser
131
132    def __init__(self, storage_path: str, cache_path: str, safe_mode: bool = False) -> None:
133        """Initialize the MusicAssistant Server."""
134        self._state = CoreState.STARTING
135        self.storage_path = storage_path
136        self.cache_path = cache_path
137        self.safe_mode = safe_mode
138        # we dynamically register command handlers which can be consumed by the apis
139        self.command_handlers: dict[str, APICommandHandler] = {}
140        self._subscribers: set[EventSubscriptionType] = set()
141        self._provider_manifests: dict[str, ProviderManifest] = {}
142        self._providers: dict[str, ProviderInstanceType] = {}
143        self._tracked_tasks: dict[str, asyncio.Task[Any]] = {}
144        self._tracked_timers: dict[str, asyncio.TimerHandle] = {}
145        self.running_as_hass_addon: bool = False
146        self.version: str = "0.0.0"
147        self.logger = LOGGER
148        self.dev_mode = (
149            os.environ.get("PYTHONDEVMODE") == "1"
150            or pathlib.Path(__file__).parent.resolve().parent.resolve().joinpath(".venv").exists()
151        )
152        self._http_session: ClientSession | None = None
153        self._http_session_no_ssl: ClientSession | None = None
154        self._mdns_locks: dict[str, asyncio.Lock] = {}
155
156    async def start(self) -> None:
157        """Start running the Music Assistant server."""
158        self.loop = asyncio.get_running_loop()
159        self.loop_thread_id = getattr(self.loop, "_thread_id")  # noqa: B009
160        self.running_as_hass_addon = await is_hass_supervisor()
161        self.version = await get_package_version("music_assistant") or "0.0.0"
162        # setup config controller first and fetch important config values
163        self.config = ConfigController(self)
164        await self.config.setup()
165        # create shared zeroconf instance
166        zeroconf_interfaces = self.config.get_raw_core_config_value(
167            "players", CONF_ZEROCONF_INTERFACES, "default"
168        )
169        # IPv6 requires InterfaceChoice.All, so only enable when zeroconf_interfaces is "all"
170        use_all_interfaces = zeroconf_interfaces == "all"
171        self.aiozc = AsyncZeroconf(
172            ip_version=IPVersion.All if use_all_interfaces else IPVersion.V4Only,
173            interfaces=InterfaceChoice.All if use_all_interfaces else InterfaceChoice.Default,
174        )
175        # load all available providers from manifest files
176        await self.__load_provider_manifests()
177        # setup/migrate storage
178        await self._setup_storage()
179        LOGGER.info(
180            "Starting Music Assistant Server (%s) version %s - HA add-on: %s - Safe mode: %s",
181            self.server_id,
182            self.version,
183            self.running_as_hass_addon,
184            self.safe_mode,
185        )
186        # setup other core controllers
187        self.cache = CacheController(self)
188        self.webserver = WebserverController(self)
189        self.metadata = MetaDataController(self)
190        self.music = MusicController(self)
191        self.players = PlayerController(self)
192        self.player_queues = PlayerQueuesController(self)
193        self.streams = StreamsController(self)
194        # add manifests for core controllers
195        for controller_name in CONFIGURABLE_CORE_CONTROLLERS:
196            controller: CoreController = getattr(self, controller_name)
197            self._provider_manifests[controller.domain] = controller.manifest
198
199        # setup all core controllers in parallel
200        async def setup_controller(controller: CoreController) -> None:
201            await controller.setup(await self.config.get_core_config(controller.domain))
202            controller.initialized.set()
203
204        async with asyncio.TaskGroup() as tg:
205            tg.create_task(setup_controller(self.cache))
206            tg.create_task(setup_controller(self.streams))
207            tg.create_task(setup_controller(self.music))
208            tg.create_task(setup_controller(self.metadata))
209            tg.create_task(setup_controller(self.players))
210            tg.create_task(setup_controller(self.player_queues))
211
212        # load webserver/api now that the core controllers are setup and ready to be used
213        self._register_api_commands()
214        await self.webserver.setup(await self.config.get_core_config("webserver"))
215        # load builtin providers (always needed, also in safe mode)
216        await self._load_builtin_providers()
217        # setup discovery
218        await self._setup_discovery()
219        # load regular providers (skip when in safe mode)
220        # providers are loaded in background tasks so they won't block
221        # the startup if they fail or take a long time to load
222        if not self.safe_mode:
223            await self._load_providers()
224        # at this point we are fully up and running,
225        # set state to running to signal we're ready
226        self._state = CoreState.RUNNING
227
228    async def stop(self) -> None:
229        """Stop running the music assistant server."""
230        LOGGER.info("Stop called, cleaning up...")
231        self.signal_event(EventType.SHUTDOWN)
232        self._state = CoreState.STOPPING
233        # cancel all running tasks
234        for task in list(self._tracked_tasks.values()):
235            task.cancel()
236        # cleanup all providers
237        await asyncio.gather(
238            *[self.unload_provider(prov_id) for prov_id in list(self._providers.keys())],
239            return_exceptions=True,
240        )
241        # stop core controllers
242        await self.streams.close()
243        await self.webserver.close()
244        await self.metadata.close()
245        await self.music.close()
246        await self.player_queues.close()
247        await self.players.close()
248        # cleanup cache and config
249        await self.config.close()
250        await self.cache.close()
251        # close/cleanup shared http session
252        if self._http_session:
253            self._http_session.detach()
254            if self._http_session.connector:
255                await self._http_session.connector.close()
256        if self._http_session_no_ssl:
257            self._http_session_no_ssl.detach()
258            if self._http_session_no_ssl.connector:
259                await self._http_session_no_ssl.connector.close()
260        self._state = CoreState.STOPPED
261
262    @property
263    def state(self) -> CoreState:
264        """Return current state of the core."""
265        return self._state
266
267    @property
268    def closing(self) -> bool:
269        """Return true if the server is (in the process of) closing."""
270        return self._state in (CoreState.STOPPING, CoreState.STOPPED)
271
272    @property
273    def server_id(self) -> str:
274        """Return unique ID of this server."""
275        if not self.config.initialized:
276            return ""
277        return self.config.get(CONF_SERVER_ID)  # type: ignore[no-any-return]
278
279    @property
280    def http_session(self) -> ClientSession:
281        """
282        Return the shared HTTP Client session (with SSL).
283
284        NOTE: May only be called from the event loop.
285        """
286        if self._http_session is None:
287            self._http_session = create_clientsession(self, verify_ssl=True)
288        return self._http_session
289
290    @property
291    def http_session_no_ssl(self) -> ClientSession:
292        """
293        Return the shared HTTP Client session (without SSL).
294
295        NOTE: May only be called from the event loop thread.
296        """
297        if self._http_session_no_ssl is None:
298            self._http_session_no_ssl = create_clientsession(self, verify_ssl=False)
299        return self._http_session_no_ssl
300
301    @api_command("info")
302    def get_server_info(self) -> ServerInfoMessage:
303        """Return Info of this server."""
304        return ServerInfoMessage(
305            server_id=self.server_id,
306            server_version=self.version,
307            schema_version=API_SCHEMA_VERSION,
308            min_supported_schema_version=MIN_SCHEMA_VERSION,
309            base_url=self.webserver.base_url,
310            homeassistant_addon=self.running_as_hass_addon,
311            onboard_done=self.config.onboard_done,
312        )
313
314    @api_command("providers/manifests")
315    def get_provider_manifests(self) -> list[ProviderManifest]:
316        """Return all Provider manifests."""
317        return list(self._provider_manifests.values())
318
319    @api_command("providers/manifests/get")
320    def get_provider_manifest(self, instance_id_or_domain: str) -> ProviderManifest:
321        """Return Provider manifests of single provider(domain)."""
322        if instance_id_or_domain in self._provider_manifests:
323            return self._provider_manifests[instance_id_or_domain]
324        if provider := self.get_provider(instance_id_or_domain, return_unavailable=True):
325            return provider.manifest
326        raise KeyError(f"Provider manifest not found for {instance_id_or_domain}")
327
328    @api_command("providers")
329    def get_providers(
330        self, provider_type: ProviderType | None = None
331    ) -> list[ProviderInstanceType]:
332        """
333        Return all loaded/running Providers (instances).
334
335        Optionally filtered by ProviderType.
336        Note that this applies user filters for music providers (for non admin users).
337        """
338        user = get_current_user()
339        user_provider_filter = (
340            user.provider_filter if user and user.role != UserRole.ADMIN else None
341        )
342        return [
343            x
344            for x in list(self._providers.values())
345            if (provider_type is None or provider_type == x.type)
346            # apply user provider filter
347            and (
348                not user_provider_filter
349                or x.instance_id in user_provider_filter
350                or x.type != ProviderType.MUSIC
351            )
352        ]
353
354    @api_command("logging/get", required_role=UserRole.ADMIN)
355    async def get_application_log(self) -> str:
356        """Return the application log from file."""
357        logfile = os.path.join(self.storage_path, "musicassistant.log")
358        async with aiofiles.open(logfile) as _file:
359            return str(await _file.read())
360
361    @property
362    def providers(self) -> list[ProviderInstanceType]:
363        """
364        Return all loaded/running Providers (instances).
365
366        Note that this skips user filters so may only be called from internal code.
367        """
368        return list(self._providers.values())
369
370    @overload
371    def get_provider(
372        self,
373        provider_instance_or_domain: str,
374        return_unavailable: bool = False,
375        provider_type: None = None,
376    ) -> ProviderInstanceType | None: ...
377
378    @overload
379    def get_provider(
380        self,
381        provider_instance_or_domain: str,
382        return_unavailable: bool = False,
383        *,
384        provider_type: type[_ProviderT],
385    ) -> _ProviderT | None: ...
386
387    def get_provider(
388        self,
389        provider_instance_or_domain: str,
390        return_unavailable: bool = False,
391        provider_type: type[_ProviderT] | None = None,
392    ) -> ProviderInstanceType | _ProviderT | None:
393        """Return provider by instance id or domain.
394
395        :param provider_instance_or_domain: Instance ID or domain of the provider.
396        :param return_unavailable: Also return unavailable providers.
397        :param provider_type: Optional type hint for the expected provider type (unused at runtime).
398        """
399        # lookup by instance_id first
400        if prov := self._providers.get(provider_instance_or_domain):
401            if return_unavailable or prov.available:
402                return prov
403            if not getattr(prov, "is_streaming_provider", None):
404                # no need to lookup other instances because this provider has unique data
405                return None
406            provider_instance_or_domain = prov.domain
407        # fallback to match on domain
408        for prov in list(self._providers.values()):
409            if prov.domain != provider_instance_or_domain:
410                continue
411            if return_unavailable or prov.available:
412                return prov
413        return None
414
415    def get_provider_instances(
416        self,
417        domain: str,
418        return_unavailable: bool = False,
419        provider_type: ProviderType | None = None,
420    ) -> list[ProviderInstanceType]:
421        """
422        Return all provider instances for a given domain.
423
424        Note that this skips user filters so may only be called from internal code.
425        """
426        return [
427            prov
428            for prov in list(self._providers.values())
429            if (provider_type is None or provider_type == prov.type)
430            and prov.domain == domain
431            and (return_unavailable or prov.available)
432        ]
433
434    def signal_event(
435        self,
436        event: EventType,
437        object_id: str | None = None,
438        data: Any = None,
439    ) -> None:
440        """Signal event to subscribers."""
441        if self.closing:
442            return
443
444        self.verify_event_loop_thread("signal_event")
445
446        if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL):
447            # do not log queue time updated events because that is too chatty
448            LOGGER.getChild("event").log(VERBOSE_LOG_LEVEL, "%s %s", event.value, object_id or "")
449
450        event_obj = MassEvent(event=event, object_id=object_id, data=data)
451        for cb_func, event_filter, id_filter in list(self._subscribers):
452            if not (event_filter is None or event in event_filter):
453                continue
454            if not (id_filter is None or object_id in id_filter):
455                continue
456            if inspect.iscoroutinefunction(cb_func):
457                if TYPE_CHECKING:
458                    cb_func = cast("Callable[[MassEvent], Coroutine[Any, Any, None]]", cb_func)
459                self.create_task(cb_func, event_obj)
460            else:
461                if TYPE_CHECKING:
462                    cb_func = cast("Callable[[MassEvent], None]", cb_func)
463                self.loop.call_soon_threadsafe(cb_func, event_obj)
464
465    def subscribe(
466        self,
467        cb_func: EventCallBackType,
468        event_filter: EventType | tuple[EventType, ...] | None = None,
469        id_filter: str | tuple[str, ...] | None = None,
470    ) -> Callable[[], None]:
471        """Add callback to event listeners.
472
473        Returns function to remove the listener.
474            :param cb_func: callback function or coroutine
475            :param event_filter: Optionally only listen for these events
476            :param id_filter: Optionally only listen for these id's (player_id, queue_id, uri)
477        """
478        if isinstance(event_filter, EventType):
479            event_filter = (event_filter,)
480        if isinstance(id_filter, str):
481            id_filter = (id_filter,)
482        listener = (cb_func, event_filter, id_filter)
483        self._subscribers.add(listener)
484
485        def remove_listener() -> None:
486            self._subscribers.remove(listener)
487
488        return remove_listener
489
490    def create_task(
491        self,
492        target: Callable[..., Coroutine[Any, Any, _R]] | Awaitable[_R],
493        *args: Any,
494        task_id: str | None = None,
495        abort_existing: bool = False,
496        eager_start: bool = True,
497        **kwargs: Any,
498    ) -> asyncio.Task[_R]:
499        """Create Task on (main) event loop from Coroutine(function).
500
501        Tasks created by this helper will be properly cancelled on stop.
502
503        :param target: Coroutine function or awaitable to run as a task.
504        :param args: Arguments to pass to the coroutine function.
505        :param task_id: Optional ID to track and deduplicate tasks.
506        :param abort_existing: If True, cancel existing task with same task_id.
507        :param eager_start: If True (default), start task immediately without waiting
508                           for next event loop iteration. This ensures proper ordering
509                           when creating multiple tasks in sequence.
510        :param kwargs: Keyword arguments to pass to the coroutine function.
511        """
512        if task_id and (existing := self._tracked_tasks.get(task_id)) and not existing.done():
513            # prevent duplicate tasks if task_id is given and already present
514            if abort_existing:
515                existing.cancel()
516            else:
517                return existing
518        self.verify_event_loop_thread("create_task")
519
520        if inspect.iscoroutinefunction(target):
521            # coroutine function
522            coro = target(*args, **kwargs)
523        elif inspect.iscoroutine(target):
524            # coroutine
525            coro = target
526        elif callable(target):
527            raise RuntimeError("Function is not a coroutine or coroutine function")
528        else:
529            raise RuntimeError("Target is missing")
530
531        # Use asyncio.Task directly with eager_start for immediate execution
532        task: asyncio.Task[_R] = asyncio.Task(coro, loop=self.loop, eager_start=eager_start)
533
534        if task_id is None:
535            task_id = uuid4().hex
536
537        def task_done_callback(_task: asyncio.Task[Any]) -> None:
538            self._tracked_tasks.pop(task_id, None)
539            # log unhandled exceptions
540            if (
541                LOGGER.isEnabledFor(logging.DEBUG)
542                and not _task.cancelled()
543                and (err := _task.exception())
544            ):
545                task_name = _task.get_name() if hasattr(_task, "get_name") else str(_task)
546                LOGGER.warning(
547                    "Exception in task %s - target: %s: %s",
548                    task_name,
549                    str(target),
550                    str(err),
551                    exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None,
552                )
553
554        self._tracked_tasks[task_id] = task
555        task.add_done_callback(task_done_callback)
556        return task
557
558    def call_later(
559        self,
560        delay: float,
561        target: Coroutine[Any, Any, _R] | Awaitable[_R] | Callable[..., _R],
562        *args: Any,
563        task_id: str | None = None,
564        **kwargs: Any,
565    ) -> asyncio.TimerHandle:
566        """
567        Run callable/awaitable after given delay.
568
569        Use task_id for debouncing.
570        """
571        self.verify_event_loop_thread("call_later")
572
573        if not task_id:
574            task_id = uuid4().hex
575
576        if existing := self._tracked_timers.get(task_id):
577            existing.cancel()
578
579        def _create_task(_target: Coroutine[Any, Any, _R]) -> None:
580            self._tracked_timers.pop(task_id)
581            self.create_task(_target, *args, task_id=task_id, abort_existing=True, **kwargs)
582
583        def _call_sync(_target: Callable[..., _R]) -> None:
584            self._tracked_timers.pop(task_id)
585            _target(*args, **kwargs)
586
587        if inspect.iscoroutinefunction(target) or inspect.iscoroutine(target):
588            # coroutine function
589            if TYPE_CHECKING:
590                target = cast("Coroutine[Any, Any, _R]", target)
591            handle = self.loop.call_later(delay, _create_task, target)
592        else:
593            # regular sync callable
594            if TYPE_CHECKING:
595                target = cast("Callable[..., _R]", target)
596            handle = self.loop.call_later(delay, _call_sync, target)
597        self._tracked_timers[task_id] = handle
598        return handle
599
600    def get_task(self, task_id: str) -> asyncio.Task[Any] | None:
601        """Get existing scheduled task."""
602        if existing := self._tracked_tasks.get(task_id):
603            # prevent duplicate tasks if task_id is given and already present
604            return existing
605        return None
606
607    def cancel_task(self, task_id: str) -> None:
608        """Cancel existing scheduled task."""
609        if existing := self._tracked_tasks.pop(task_id, None):
610            existing.cancel()
611
612    def cancel_timer(self, task_id: str) -> None:
613        """Cancel existing scheduled timer."""
614        if existing := self._tracked_timers.pop(task_id, None):
615            existing.cancel()
616
617    def register_api_command(
618        self,
619        command: str,
620        handler: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
621        authenticated: bool = True,
622        required_role: str | None = None,
623        alias: bool = False,
624    ) -> Callable[[], None]:
625        """Dynamically register a command on the API.
626
627        :param command: The command name/path.
628        :param handler: The function to handle the command.
629        :param authenticated: Whether authentication is required (default: True).
630        :param required_role: Required user role ("admin" or "user")
631            None means any authenticated user.
632        :param alias: Whether this is an alias for backward compatibility (default: False).
633            Aliases are not shown in API documentation but remain functional.
634
635        Returns handle to unregister.
636        """
637        if command in self.command_handlers:
638            msg = f"Command {command} is already registered"
639            raise RuntimeError(msg)
640        self.command_handlers[command] = APICommandHandler.parse(
641            command, handler, authenticated, required_role, alias
642        )
643
644        def unregister() -> None:
645            self.command_handlers.pop(command, None)
646
647        return unregister
648
649    async def load_provider_config(
650        self,
651        prov_conf: ProviderConfig,
652    ) -> None:
653        """Try to load a provider and catch errors."""
654        # cancel existing (re)load timer if needed
655        task_id = f"load_provider_{prov_conf.instance_id}"
656        if existing := self._tracked_timers.pop(task_id, None):
657            existing.cancel()
658
659        await self._load_provider(prov_conf)
660
661        # (re)load any dependants
662        prov_configs = await self.config.get_provider_configs(include_values=True)
663        for dep_prov_conf in prov_configs:
664            if not dep_prov_conf.enabled:
665                continue
666            manifest = self.get_provider_manifest(dep_prov_conf.domain)
667            if not manifest.depends_on:
668                continue
669            if manifest.depends_on == prov_conf.domain:
670                await self._load_provider(dep_prov_conf)
671
    async def load_provider(
        self,
        instance_id: str,
        allow_retry: bool = False,
    ) -> None:
        """Try to load a provider by instance id, recording (not raising) any failure.

        Returns without doing anything when the config no longer exists or is
        disabled. On failure the error text is stored as the provider's
        ``last_error`` config value and logged; when ``allow_retry`` is set and the
        error is a ``MusicAssistantError``, a new attempt is scheduled in 120 seconds.

        :param instance_id: Instance ID of the provider config to load.
        :param allow_retry: Schedule an automatic retry for handled (MusicAssistantError) failures.
        """
        try:
            prov_conf = await self.config.get_provider_config(instance_id)
        except KeyError:
            # Was deleted before we could run
            return

        if not prov_conf.enabled:
            # Was disabled before we could run
            return

        # cancel existing (re)load timer if needed
        task_id = f"load_provider_{instance_id}"
        if existing := self._tracked_timers.pop(task_id, None):
            existing.cancel()

        try:
            await self.load_provider_config(prov_conf)
        except Exception as exc:
            # if loading failed, we store the error in the config object
            # so we can show something useful to the user
            prov_conf.last_error = str(exc)
            self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", str(exc))

            # auto schedule a retry if the (re)load failed with a handled exception
            # unhandled exceptions (e.g. ValueError) are likely bugs that won't resolve themselves
            will_retry = allow_retry and isinstance(exc, MusicAssistantError)
            if will_retry:
                # same task_id as above, so a later (re)load cancels this pending retry
                self.call_later(
                    120,
                    self.load_provider,
                    instance_id,
                    allow_retry,
                    task_id=task_id,
                )
            LOGGER.warning(
                "Error loading provider(instance) %s: %s%s",
                prov_conf.name or prov_conf.instance_id,
                str(exc) or exc.__class__.__name__,
                " (will be retried later)" if will_retry else "",
                # log full stack trace if verbose logging is enabled
                exc_info=exc if LOGGER.isEnabledFor(VERBOSE_LOG_LEVEL) else None,
            )
            return

        # Unload dependants of this provider's domain that are registered but not
        # available. NOTE(review): the original comment said "(re)load", but this
        # only unloads them (dependants are (re)loaded in load_provider_config) —
        # confirm that unloading here is the intended behavior.
        for dep_prov in self.providers:
            if dep_prov.available:
                continue
            if dep_prov.manifest.depends_on == prov_conf.domain:
                await self.unload_provider(dep_prov.instance_id)
728
    async def unload_provider(self, instance_id: str, is_removed: bool = False) -> None:
        """Unload a provider instance and everything that depends on it.

        Steps: unschedule its music sync, drop its mDNS types from the shared
        browser, notify the player/music controllers, recursively unload providers
        whose manifest ``depends_on`` its domain, unregister its players, then call
        ``provider.unload``. The instance is always removed from the registry and a
        PROVIDERS_UPDATED event is signalled afterwards, even if ``provider.unload``
        raised (that error is logged, not propagated).

        :param instance_id: Instance ID of the provider to unload.
        :param is_removed: True when the provider is being removed permanently;
            forwarded to ``provider.unload`` and used to unregister its players permanently.
        """
        self.music.unschedule_provider_sync(instance_id)
        if provider := self._providers.get(instance_id):
            # remove mdns discovery if needed
            # NOTE(review): this discards the types even if another loaded instance of
            # the same domain still lists them in its manifest — confirm intended.
            if provider.manifest.mdns_discovery:
                for mdns_type in provider.manifest.mdns_discovery:
                    self._aiobrowser.types.discard(mdns_type)
            # NOTE(review): errors from the two controller hooks below are not caught,
            # so the provider would stay registered if one of them raises.
            if isinstance(provider, PlayerProvider):
                await self.players.on_provider_unload(provider)
            if isinstance(provider, MusicProvider):
                await self.music.on_provider_unload(provider)
            # check if there are no other providers dependent of this provider
            for dep_prov in self.providers:
                if dep_prov.manifest.depends_on == provider.domain:
                    await self.unload_provider(dep_prov.instance_id)
            if is_player_provider(provider):
                # unregister all players of this provider
                for player in provider.players:
                    await self.players.unregister(player.player_id, permanent=is_removed)
            try:
                await provider.unload(is_removed)
            except Exception as err:
                LOGGER.warning(
                    "Error while unloading provider %s: %s", provider.name, str(err), exc_info=err
                )
            finally:
                self._providers.pop(instance_id, None)
                await self._update_available_providers_cache()
                self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
759
760    async def unload_provider_with_error(self, instance_id: str, error: str) -> None:
761        """Unload a provider when it got into trouble which needs user interaction."""
762        self.config.set(f"{CONF_PROVIDERS}/{instance_id}/last_error", error)
763        await self.unload_provider(instance_id)
764
765    async def run_provider_discovery(self, instance_id: str) -> None:
766        """
767        Run mDNS discovery for a given provider.
768
769        In case of a PlayerProvider, will also call its own discovery method.
770        """
771        provider = self.get_provider(instance_id, return_unavailable=False)
772        if not provider:
773            raise KeyError(f"Provider with instance ID {instance_id} not found")
774        if provider.manifest.mdns_discovery:
775            if provider.instance_id not in self._mdns_locks:
776                self._mdns_locks[provider.instance_id] = asyncio.Lock()
777            async with self._mdns_locks[provider.instance_id]:
778                for mdns_type in provider.manifest.mdns_discovery or []:
779                    for mdns_name in set(self.aiozc.zeroconf.cache.cache):
780                        if mdns_type not in mdns_name or mdns_type == mdns_name:
781                            continue
782                        info = AsyncServiceInfo(mdns_type, mdns_name)
783                        if await info.async_request(self.aiozc.zeroconf, 3000):
784                            await provider.on_mdns_service_state_change(
785                                mdns_name, ServiceStateChange.Added, info
786                            )
787        if isinstance(provider, PlayerProvider):
788            await provider.discover_players()
789
790    def verify_event_loop_thread(self, what: str) -> None:
791        """Report and raise if we are not running in the event loop thread."""
792        if self.loop_thread_id != threading.get_ident():
793            raise RuntimeError(
794                f"Non-Async operation detected: {what} may only be called from the eventloop."
795            )
796
797    def _register_api_commands(self) -> None:
798        """Register all methods decorated as api_command within a class(instance)."""
799        for cls in (
800            self,
801            self.config,
802            self.metadata,
803            self.music,
804            self.players,
805            self.player_queues,
806            self.webserver,
807            self.webserver.auth,
808        ):
809            for attr_name in dir(cls):
810                if attr_name.startswith("__"):
811                    continue
812                try:
813                    obj = getattr(cls, attr_name)
814                except (AttributeError, RuntimeError):
815                    # Skip properties that fail during initialization
816                    continue
817                if hasattr(obj, "api_cmd"):
818                    # method is decorated with our api decorator
819                    authenticated = getattr(obj, "api_authenticated", True)
820                    required_role = getattr(obj, "api_required_role", None)
821                    self.register_api_command(obj.api_cmd, obj, authenticated, required_role)
822
823    async def _load_builtin_providers(self) -> None:
824        """
825        Load all builtin providers.
826
827        Builtin providers are always needed (also in safe mode) and are fully awaited.
828        On error, setup will fail.
829        """
830        # create default config for any 'builtin' providers
831        for prov_manifest in self._provider_manifests.values():
832            if prov_manifest.type == ProviderType.CORE:
833                # core controllers are not real providers
834                continue
835            if not prov_manifest.builtin:
836                continue
837            await self.config.create_builtin_provider_config(prov_manifest.domain)
838
839        # load all configured (and enabled) builtin providers
840        prov_configs = await self.config.get_provider_configs(include_values=True)
841        builtin_configs: list[ProviderConfig] = [
842            prov_conf
843            for prov_conf in prov_configs
844            if (manifest := self._provider_manifests.get(prov_conf.domain))
845            and manifest.builtin
846            and (prov_conf.enabled or manifest.allow_disable is False)
847        ]
848
849        # load builtin providers and wait for them to complete
850        async with asyncio.TaskGroup() as tg:
851            for conf in builtin_configs:
852                tg.create_task(self.load_provider(conf.instance_id, allow_retry=True))
853
854    async def _load_providers(self) -> None:
855        """
856        Load regular (non-builtin) providers from config.
857
858        Regular providers are loaded in background tasks
859        and can fail without affecting core setup.
860        """
861        # handle default providers setup
862        self.config.set_default(CONF_DEFAULT_PROVIDERS_SETUP, set())
863        default_providers_setup = cast("set[str]", self.config.get(CONF_DEFAULT_PROVIDERS_SETUP))
864        changes_made = False
865        for default_provider, require_mdns in DEFAULT_PROVIDERS:
866            if default_provider in default_providers_setup:
867                # already processed/setup before, skip
868                continue
869            if not (manifest := self._provider_manifests.get(default_provider)):
870                continue
871            if require_mdns:
872                # if mdns discovery is required, check if we have seen any mdns entries
873                # for this provider before setting it up
874                for mdns_name in set(self.aiozc.zeroconf.cache.cache):
875                    if manifest.mdns_discovery and any(
876                        mdns_type in mdns_name for mdns_type in manifest.mdns_discovery
877                    ):
878                        break
879                else:
880                    continue
881            await self.config.create_builtin_provider_config(manifest.domain)
882            changes_made = True
883            # TEMP: migration - to be removed after 2.8 release
884            # enable all existing players of the default providers if they are not already enabled
885            # due to the linked protocol feature we introduced
886            for player_config in await self.config.get_player_configs(
887                provider=default_provider, include_disabled=True
888            ):
889                if player_config.enabled:
890                    continue
891                await self.config.save_player_config(player_config.player_id, {"enabled": True})
892            default_providers_setup.add(default_provider)
893        if changes_made:
894            self.config.set(CONF_DEFAULT_PROVIDERS_SETUP, default_providers_setup)
895            self.config.save(True)
896        # load all configured (and enabled) regular (non-builtin) providers
897        prov_configs = await self.config.get_provider_configs(include_values=True)
898        other_configs: list[ProviderConfig] = [
899            prov_conf
900            for prov_conf in prov_configs
901            if prov_conf.enabled
902            and (
903                not (manifest := self._provider_manifests.get(prov_conf.domain))
904                or not manifest.builtin
905            )
906        ]
907        # load providers concurrently via tasks
908        for prov_conf in other_configs:
909            # Use a task so we can load multiple providers at once.
910            # If a provider fails, that will not block the loading of other providers.
911            self.create_task(self.load_provider(prov_conf.instance_id, allow_retry=True))
912
913    async def _load_provider(self, conf: ProviderConfig) -> None:
914        """Load (or reload) a provider."""
915        # if provider is already loaded, stop and unload it first
916        await self.unload_provider(conf.instance_id)
917        LOGGER.debug("Loading provider %s", conf.name or conf.domain)
918        if not conf.enabled:
919            msg = "Provider is disabled"
920            raise SetupFailedError(msg)
921
922        # validate config
923        try:
924            conf.validate()
925        except (KeyError, ValueError, AttributeError, TypeError) as err:
926            msg = "Configuration is invalid"
927            raise SetupFailedError(msg) from err
928
929        domain = conf.domain
930        prov_manifest = self._provider_manifests.get(domain)
931        # check for other instances of this provider
932        existing = next((x for x in self.providers if x.domain == domain), None)
933        if existing and prov_manifest and not prov_manifest.multi_instance:
934            msg = f"Provider {domain} already loaded and only one instance allowed."
935            raise SetupFailedError(msg)
936        # check valid manifest (just in case)
937        if not prov_manifest:
938            msg = f"Provider {domain} manifest not found"
939            raise SetupFailedError(msg)
940
941        # handle dependency on other provider
942        if prov_manifest.depends_on and not self.get_provider(prov_manifest.depends_on):
943            # we can safely ignore this completely as the setup will be retried later
944            # automatically when the dependency is loaded
945            return
946
947        # try to setup the module
948        prov_mod = await load_provider_module(domain, prov_manifest.requirements)
949        try:
950            async with asyncio.timeout(30):
951                provider = await prov_mod.setup(self, prov_manifest, conf)
952        except TimeoutError as err:
953            msg = f"Provider {domain} did not load within 30 seconds"
954            raise SetupFailedError(msg) from err
955
956        # run async setup
957        await provider.handle_async_init()
958
959        # if we reach this point, the provider loaded successfully
960        self._providers[provider.instance_id] = provider
961        LOGGER.info(
962            "Loaded %s provider %s",
963            provider.type.value,
964            provider.name,
965        )
966        provider.available = True
967
968        # adapt logging name if needed
969        provider._set_log_level_from_config(provider.config)
970
971        # execute post load actions
972        async def _on_provider_loaded() -> None:
973            await provider.loaded_in_mass()
974            await self.run_provider_discovery(provider.instance_id)
975            # push instance name to config (to persist it if it was autogenerated)
976            if provider.default_name != conf.default_name:
977                self.config.set_provider_default_name(provider.instance_id, provider.default_name)
978
979        self.create_task(_on_provider_loaded())
980
981        # clear any previous error in config and signal update
982        self.config.set(f"{CONF_PROVIDERS}/{conf.instance_id}/last_error", None)
983        self.signal_event(EventType.PROVIDERS_UPDATED, data=self.get_providers())
984        await self._update_available_providers_cache()
985        if isinstance(provider, MusicProvider):
986            await self.music.on_provider_loaded(provider)
987        if isinstance(provider, PlayerProvider):
988            await self.players.on_provider_loaded(provider)
989
990    async def __load_provider_manifests(self) -> None:
991        """Preload all available provider manifest files."""
992
993        async def load_provider_manifest(provider_domain: str, provider_path: str) -> None:
994            """Preload all available provider manifest files."""
995            # get files in subdirectory
996            for file_str in await asyncio.to_thread(os.listdir, provider_path):  # noqa: PTH208, RUF100
997                file_path = os.path.join(provider_path, file_str)
998                if not await isfile(file_path):
999                    continue
1000                if file_str != "manifest.json":
1001                    continue
1002                try:
1003                    provider_manifest: ProviderManifest = await ProviderManifest.parse(file_path)
1004                    # check for icon.svg file
1005                    if not provider_manifest.icon_svg:
1006                        icon_path = os.path.join(provider_path, "icon.svg")
1007                        if await isfile(icon_path):
1008                            provider_manifest.icon_svg = await get_icon_string(icon_path)
1009                    # check for dark_icon file
1010                    if not provider_manifest.icon_svg_dark:
1011                        icon_path = os.path.join(provider_path, "icon_dark.svg")
1012                        if await isfile(icon_path):
1013                            provider_manifest.icon_svg_dark = await get_icon_string(icon_path)
1014                    # check for icon_monochrome file
1015                    if not provider_manifest.icon_svg_monochrome:
1016                        icon_path = os.path.join(provider_path, "icon_monochrome.svg")
1017                        if await isfile(icon_path):
1018                            provider_manifest.icon_svg_monochrome = await get_icon_string(icon_path)
1019                    # override Home Assistant provider if we're running as add-on
1020                    if provider_manifest.domain == "hass" and self.running_as_hass_addon:
1021                        provider_manifest.builtin = True
1022                        provider_manifest.allow_disable = False
1023
1024                    self._provider_manifests[provider_manifest.domain] = provider_manifest
1025                    LOGGER.log(
1026                        VERBOSE_LOG_LEVEL, "Loaded manifest for provider %s", provider_manifest.name
1027                    )
1028                except Exception as exc:
1029                    LOGGER.exception(
1030                        "Error while loading manifest for provider %s",
1031                        provider_domain,
1032                        exc_info=exc,
1033                    )
1034
1035        async with TaskManager(self) as tg:
1036            for dir_str in await asyncio.to_thread(os.listdir, PROVIDERS_PATH):  # noqa: PTH208, RUF100
1037                if dir_str.startswith("."):
1038                    # skip hidden directories
1039                    continue
1040                dir_path = os.path.join(PROVIDERS_PATH, dir_str)
1041                if dir_str.startswith("_") and not self.dev_mode:
1042                    # only load demo/test providers if debug mode is enabled (e.g. for development)
1043                    continue
1044                if not await isdir(dir_path):
1045                    continue
1046                tg.create_task(load_provider_manifest(dir_str, dir_path))
1047        self.logger.debug("Loaded %s provider manifests", len(self._provider_manifests))
1048
1049    async def _setup_discovery(self) -> None:
1050        """Handle setup of MDNS discovery."""
1051        # create a global mdns browser
1052        all_types: set[str] = set()
1053        for prov_manifest in self._provider_manifests.values():
1054            if prov_manifest.mdns_discovery:
1055                all_types.update(prov_manifest.mdns_discovery)
1056        self._aiobrowser = AsyncServiceBrowser(
1057            self.aiozc.zeroconf,
1058            list(all_types),
1059            handlers=[self._on_mdns_service_state_change],
1060        )
1061        # register MA itself on mdns to be discovered
1062        zeroconf_type = "_mass._tcp.local."
1063        server_id = self.server_id
1064        LOGGER.debug("Starting Zeroconf broadcast...")
1065        info = AsyncServiceInfo(
1066            zeroconf_type,
1067            name=f"{server_id}.{zeroconf_type}",
1068            addresses=[await get_ip_pton(self.webserver.publish_ip)],
1069            port=self.webserver.publish_port,
1070            properties=self.get_server_info().to_dict(),
1071            server="mass.local.",
1072        )
1073        try:
1074            existing = getattr(self, "mass_zc_service_set", None)
1075            if existing:
1076                await self.aiozc.async_update_service(info)
1077            else:
1078                await self.aiozc.async_register_service(info)
1079            self.mass_zc_service_set = True
1080        except NonUniqueNameException:
1081            LOGGER.error(
1082                "Music Assistant instance with identical name present in the local network!"
1083            )
1084
1085    def _on_mdns_service_state_change(
1086        self,
1087        zeroconf: Zeroconf,
1088        service_type: str,
1089        name: str,
1090        state_change: ServiceStateChange,
1091    ) -> None:
1092        """Handle MDNS service state callback."""
1093
1094        async def process_mdns_state_change(prov: ProviderInstanceType) -> None:
1095            if prov.instance_id not in self._mdns_locks:
1096                self._mdns_locks[prov.instance_id] = asyncio.Lock()
1097            if state_change == ServiceStateChange.Removed:
1098                info = None
1099            else:
1100                info = AsyncServiceInfo(service_type, name)
1101                await info.async_request(zeroconf, 3000)
1102            # use a lock per provider instance to avoid
1103            # race conditions in processing mdns events
1104            async with self._mdns_locks[prov.instance_id]:
1105                await prov.on_mdns_service_state_change(name, state_change, info)
1106
1107        LOGGER.log(
1108            VERBOSE_LOG_LEVEL,
1109            "Service %s of type %s state changed: %s",
1110            name,
1111            service_type,
1112            state_change,
1113        )
1114        for prov in list(self._providers.values()):
1115            if not prov.manifest.mdns_discovery:
1116                continue
1117            if not prov.available:
1118                continue
1119            if service_type in prov.manifest.mdns_discovery:
1120                self.create_task(process_mdns_state_change(prov))
1121
1122    async def __aenter__(self) -> Self:
1123        """Return Context manager."""
1124        await self.start()
1125        return self
1126
1127    async def __aexit__(
1128        self,
1129        exc_type: type[BaseException] | None,
1130        exc_val: BaseException | None,
1131        exc_tb: TracebackType | None,
1132    ) -> bool | None:
1133        """Exit context manager."""
1134        await self.stop()
1135        return None
1136
1137    async def _update_available_providers_cache(self) -> None:
1138        """Update the global cache variable of loaded/available providers."""
1139        await set_global_cache_values(
1140            {
1141                "provider_domains": {x.domain for x in self.providers},
1142                "provider_instance_ids": {x.instance_id for x in self.providers},
1143                "available_providers": {
1144                    *{x.domain for x in self.providers},
1145                    *{x.instance_id for x in self.providers},
1146                },
1147                "unique_providers": self.music.get_unique_providers(),
1148                "streaming_providers": {
1149                    x.domain
1150                    for x in self.providers
1151                    if is_music_provider(x) and x.is_streaming_provider
1152                },
1153                "non_streaming_providers": {
1154                    x.instance_id
1155                    for x in self.providers
1156                    if not (is_music_provider(x) and x.is_streaming_provider)
1157                },
1158            }
1159        )
1160
1161    async def _setup_storage(self) -> None:
1162        """Handle Setup of storage/cache folder(s)."""
1163        if not await isdir(self.storage_path):
1164            await mkdirs(self.storage_path)
1165        if not await isdir(self.cache_path):
1166            await mkdirs(self.cache_path)
1167