music-assistant-server

8.1 KBPY
provider.py
8.1 KB208 lines • python
1"""Model/base for a Provider implementation within Music Assistant."""
2
3from __future__ import annotations
4
5import logging
6from typing import TYPE_CHECKING, Any, final
7
8from music_assistant_models.errors import UnsupportedFeaturedException
9
10from music_assistant.constants import CONF_LOG_LEVEL, MASS_LOGGER_NAME
11
12if TYPE_CHECKING:
13    from music_assistant_models.config_entries import ProviderConfig
14    from music_assistant_models.enums import ProviderFeature, ProviderStage, ProviderType
15    from music_assistant_models.provider import ProviderManifest
16    from zeroconf import ServiceStateChange
17    from zeroconf.asyncio import AsyncServiceInfo
18
19    from music_assistant.mass import MusicAssistant
20
21
class Provider:
    """
    Base representation of a Provider implementation within Music Assistant.

    Holds the provider's manifest and (instance) config, exposes identity
    properties derived from them (type, domain, instance_id, name, stage),
    manages a per-provider child logger and offers a set of lifecycle hooks
    (async init, loaded, unload, config update, mDNS callback) that concrete
    providers may override.
    """

    def __init__(
        self,
        mass: MusicAssistant,
        manifest: ProviderManifest,
        config: ProviderConfig,
        supported_features: set[ProviderFeature] | None = None,
    ) -> None:
        """
        Initialize the Provider.

        :param mass: The MusicAssistant instance this provider belongs to.
        :param manifest: The provider's manifest (type, domain, name, stage, ...).
        :param config: The configuration of this provider instance.
        :param supported_features: Features this provider supports (empty set if omitted).
        """
        self.mass = mass
        self.manifest = manifest
        self.config = config
        self._supported_features = supported_features or set()
        # NOTE: called before `self.available` is assigned below; the helper reads
        # `available` through getattr() with a False fallback for that reason.
        self._set_log_level_from_config(config)
        self.cache = mass.cache
        self.available = False

    @property
    def supported_features(self) -> set[ProviderFeature]:
        """Return the features supported by this Provider."""
        # should not be overridden in normal circumstances
        return self._supported_features

    async def handle_async_init(self) -> None:
        """Handle async initialization of the provider (no-op by default)."""

    async def loaded_in_mass(self) -> None:
        """Call after the provider has been loaded (no-op by default)."""

    async def unload(self, is_removed: bool = False) -> None:
        """
        Handle unload/close of the provider.

        Called when provider is deregistered (e.g. MA exiting or config reloading).
        is_removed will be set to True when the provider is removed from the configuration.
        The base implementation does nothing.
        """

    async def update_config(self, config: ProviderConfig, changed_keys: set[str]) -> None:
        """
        Handle logic when the config is updated.

        Override this method in your provider implementation if you need
        to perform any additional setup logic after the provider is registered and
        the self.config was loaded, and whenever the config changes.

        The default implementation reloads the provider on any config change
        (except log-level-only changes), since provider reloads are lightweight
        and most providers cache config values at setup time.

        :param config: The new (updated) provider config.
        :param changed_keys: Keys that changed; config values are prefixed with "values/".
        """
        # always update the stored config so dynamic reads pick up new values
        self.config = config

        # update log level if changed; a name change also affects the logger's name
        if f"values/{CONF_LOG_LEVEL}" in changed_keys or "name" in changed_keys:
            self._set_log_level_from_config(config)

        # reload if any non-log-level value keys changed
        value_keys_changed = {
            k for k in changed_keys if k.startswith("values/") and k != f"values/{CONF_LOG_LEVEL}"
        }
        if value_keys_changed:
            self.logger.info(
                "Config updated, reloading provider %s (instance_id=%s)",
                self.domain,
                self.instance_id,
            )
            # a fixed task_id per instance is passed along with the delayed reload
            task_id = f"provider_reload_{self.instance_id}"
            self.mass.call_later(1, self.mass.load_provider_config, config, task_id=task_id)

    async def on_mdns_service_state_change(
        self, name: str, state_change: ServiceStateChange, info: AsyncServiceInfo | None
    ) -> None:
        """Handle MDNS service state callback (no-op by default)."""

    @property
    @final
    def type(self) -> ProviderType:
        """Return type of this provider."""
        return self.manifest.type

    @property
    @final
    def domain(self) -> str:
        """Return domain for this provider."""
        return self.manifest.domain

    @property
    @final
    def instance_id(self) -> str:
        """Return instance_id for this provider(instance)."""
        return self.config.instance_id

    @property
    @final
    def name(self) -> str:
        """Return (custom) friendly name for this provider instance."""
        if self.config.name:
            # always prefer user-set name from config
            return self.config.name
        return self.default_name

    @property
    @final
    def default_name(self) -> str:
        """
        Return a default friendly name for this provider instance.

        When at most one instance of this provider's domain is stored in the
        config, this is simply the manifest name. Otherwise a postfix is appended
        in square brackets: the value of `instance_name_postfix` when it is set,
        else the 1-based position of this instance among the stored instances of
        the same domain.
        """
        # create default name based on instance count
        prov_confs = self.mass.config.get("providers", {}).values()
        instances = [x["instance_id"] for x in prov_confs if x["domain"] == self.domain]
        if len(instances) <= 1:
            # only one instance (or no instances yet at all) - return provider name
            return self.manifest.name
        postfix = self.instance_name_postfix
        if not postfix:
            # default implementation - simply use the (1-based) instance number/index.
            # An instance that is not (yet) stored in the config is numbered after the
            # existing ones instead of letting list.index() raise ValueError.
            if self.instance_id in instances:
                postfix = str(instances.index(self.instance_id) + 1)
            else:
                postfix = str(len(instances) + 1)
        # append instance postfix to provider name (use the resolved local value,
        # not the property, which may be None)
        return f"{self.manifest.name} [{postfix}]"

    @property
    def instance_name_postfix(self) -> str | None:
        """Return a (default) instance name postfix for this provider instance."""
        return None

    @property
    @final
    def stage(self) -> ProviderStage:
        """Return the stage of this provider."""
        return self.manifest.stage

    def unload_with_error(self, error: str) -> None:
        """Schedule unloading of this provider (after a short delay) with an error message."""
        self.mass.call_later(1, self.mass.unload_provider, self.instance_id, error)

    def to_dict(self) -> dict[str, Any]:
        """Return Provider(instance) as serializable dict."""
        return {
            "type": self.type.value,
            "domain": self.domain,
            "name": self.name,
            "default_name": self.default_name,
            "instance_name_postfix": self.instance_name_postfix,
            "instance_id": self.instance_id,
            "lookup_key": self.instance_id,  # include for backwards compatibility
            "supported_features": [x.value for x in self.supported_features],
            "available": self.available,
            # only some provider subclasses define this attribute
            "is_streaming_provider": getattr(self, "is_streaming_provider", None),
        }

    def supports_feature(self, feature: ProviderFeature) -> bool:
        """Return True if this provider supports the given feature."""
        return feature in self.supported_features

    def check_feature(self, feature: ProviderFeature) -> None:
        """
        Check if this provider supports the given feature.

        :raises UnsupportedFeaturedException: If the feature is not supported.
        """
        if not self.supports_feature(feature):
            raise UnsupportedFeaturedException(
                f"Provider {self.name} does not support feature {feature.name}"
            )

    def _update_config_value(self, key: str, value: Any, encrypted: bool = False) -> None:
        """
        Update a config value.

        Persists the raw value through the config controller and mirrors it into
        the in-memory config of this provider instance.
        NOTE(review): assumes `key` already exists in self.config.values
        (a missing key raises KeyError here).
        """
        self.mass.config.set_raw_provider_config_value(self.instance_id, key, value, encrypted)
        # also update the cached copy within the provider instance
        self.config.values[key].value = value

    def _set_log_level_from_config(self, config: ProviderConfig) -> None:
        """
        Set log level from config.

        (Re)creates `self.logger` as a child of the Music Assistant logger and
        applies the configured level ("GLOBAL" inherits the MA logger's level).
        """
        mass_logger = logging.getLogger(MASS_LOGGER_NAME)
        # self.name is only available after async_init. Otherwise we run into a race condition.
        # see https://github.com/music-assistant/support/issues/4801
        logging_name = self.domain
        if getattr(self, "available", False):
            # async_init completed
            logging_name = self.name
        self.logger = mass_logger.getChild(logging_name)
        log_level = str(config.get_value(CONF_LOG_LEVEL))
        if log_level == "GLOBAL":
            self.logger.setLevel(mass_logger.level)
        else:
            self.logger.setLevel(log_level)
        if logging.getLogger().level > self.logger.level:
            # if the root logger's level is higher, we need to adjust that too
            logging.getLogger().setLevel(self.logger.level)
        self.logger.debug("Log level configured to %s", log_level)
208