music-assistant-server

8.2 KBPY
provider.py
8.2 KB212 lines • python
1"""Model/base for a Provider implementation within Music Assistant."""
2
3from __future__ import annotations
4
5import logging
6from typing import TYPE_CHECKING, Any, final
7
8from music_assistant_models.errors import UnsupportedFeaturedException
9
10from music_assistant.constants import CONF_LOG_LEVEL, MASS_LOGGER_NAME
11
12if TYPE_CHECKING:
13    from music_assistant_models.config_entries import ProviderConfig
14    from music_assistant_models.enums import ProviderFeature, ProviderStage, ProviderType
15    from music_assistant_models.provider import ProviderManifest
16    from zeroconf import ServiceStateChange
17    from zeroconf.asyncio import AsyncServiceInfo
18
19    from music_assistant.mass import MusicAssistant
20
21
class Provider:
    """Base representation of a Provider implementation within Music Assistant.

    Subclasses override the async lifecycle hooks (``handle_async_init``,
    ``loaded_in_mass``, ``unload``, ``update_config``,
    ``on_mdns_service_state_change``) as needed; the base implementations of
    all hooks except ``update_config`` do nothing. The identity properties
    (``type``, ``domain``, ``instance_id``, ``name``, ``default_name``,
    ``stage``) are marked ``@final`` and are derived from the manifest/config.
    """

    mass: MusicAssistant
    manifest: ProviderManifest
    config: ProviderConfig
    # assigned by _set_log_level_from_config (called from __init__ and update_config)
    logger: logging.Logger

    def __init__(
        self,
        mass: MusicAssistant,
        manifest: ProviderManifest,
        config: ProviderConfig,
        supported_features: set[ProviderFeature] | None = None,
    ) -> None:
        """Initialize Provider.

        mass: the MusicAssistant instance this provider belongs to.
        manifest: the provider manifest (type, domain, name, stage).
        config: the stored configuration of this provider instance.
        supported_features: features supported by this provider (default: none).
        """
        self.mass = mass
        self.manifest = manifest
        self.config = config
        self._supported_features = supported_features or set()
        # called before self.available exists, so the logger is named after the domain
        self._set_log_level_from_config(config)
        self.cache = mass.cache
        self.available = False

    @property
    def supported_features(self) -> set[ProviderFeature]:
        """Return the features supported by this Provider."""
        # should not be overridden in normal circumstances
        return self._supported_features

    async def handle_async_init(self) -> None:
        """Handle async initialization of the provider (no-op by default)."""

    async def loaded_in_mass(self) -> None:
        """Call after the provider has been loaded (no-op by default)."""

    async def unload(self, is_removed: bool = False) -> None:
        """
        Handle unload/close of the provider.

        Called when provider is deregistered (e.g. MA exiting or config reloading).
        is_removed will be set to True when the provider is removed from the configuration.
        """

    async def update_config(self, config: ProviderConfig, changed_keys: set[str]) -> None:
        """
        Handle logic when the config is updated.

        Override this method in your provider implementation if you need
        to perform any additional setup logic after the provider is registered and
        the self.config was loaded, and whenever the config changes.

        The default implementation always stores the new config, re-creates the
        logger when the log level or the name changed, and schedules a provider
        reload when any other ``values/...`` key changed. Changes to the log level
        or the name alone do not trigger a reload. Reloading is the default since
        provider reloads are lightweight and most providers cache config values
        at setup time.
        """
        # always update the stored config so dynamic reads pick up new values
        self.config = config

        # update log level if changed (the logger name also depends on self.name)
        if f"values/{CONF_LOG_LEVEL}" in changed_keys or "name" in changed_keys:
            self._set_log_level_from_config(config)

        # reload if any non-log-level value keys changed
        value_keys_changed = {
            k for k in changed_keys if k.startswith("values/") and k != f"values/{CONF_LOG_LEVEL}"
        }
        if value_keys_changed:
            self.logger.info(
                "Config updated, reloading provider %s (instance_id=%s)",
                self.domain,
                self.instance_id,
            )
            # NOTE(review): a fixed task_id per instance presumably lets rapid
            # successive updates coalesce into a single reload - confirm in call_later
            task_id = f"provider_reload_{self.instance_id}"
            self.mass.call_later(1, self.mass.load_provider_config, config, task_id=task_id)

    async def on_mdns_service_state_change(
        self, name: str, state_change: ServiceStateChange, info: AsyncServiceInfo | None
    ) -> None:
        """Handle MDNS service state callback (no-op by default)."""

    @property
    @final
    def type(self) -> ProviderType:
        """Return type of this provider."""
        return self.manifest.type

    @property
    @final
    def domain(self) -> str:
        """Return domain for this provider."""
        return self.manifest.domain

    @property
    @final
    def instance_id(self) -> str:
        """Return instance_id for this provider(instance)."""
        return self.config.instance_id

    @property
    @final
    def name(self) -> str:
        """Return (custom) friendly name for this provider instance."""
        if self.config.name:
            # always prefer user-set name from config
            return self.config.name
        return self.default_name

    @property
    @final
    def default_name(self) -> str:
        """Return a default friendly name for this provider instance.

        With at most one configured instance of this provider's domain, this is
        the plain manifest name. Otherwise a bracketed postfix is appended:
        ``instance_name_postfix`` when a subclass provides one, else the 1-based
        position of this instance among the stored instances of the same domain.
        """
        # create default name based on instance count
        prov_confs = self.mass.config.get("providers", {}).values()
        instances = [x["instance_id"] for x in prov_confs if x["domain"] == self.domain]
        if len(instances) <= 1:
            # only one instance (or no instances yet at all) - return provider name
            return self.manifest.name
        postfix = self.instance_name_postfix
        if not postfix:
            # default implementation - simply use the instance number/index
            if self.instance_id in instances:
                postfix = str(instances.index(self.instance_id) + 1)
            else:
                # this instance is not in the stored configs (yet):
                # number it after the existing instances instead of raising ValueError
                postfix = str(len(instances) + 1)
        # append instance name to provider name (use the resolved postfix,
        # not the property, which is None unless a subclass overrides it)
        return f"{self.manifest.name} [{postfix}]"

    @property
    def instance_name_postfix(self) -> str | None:
        """Return a (default) instance name postfix for this provider instance.

        Subclasses may override this to give instances a meaningful suffix;
        None means default_name falls back to the instance number.
        """
        return None

    @property
    @final
    def stage(self) -> ProviderStage:
        """Return the stage of this provider."""
        return self.manifest.stage

    def unload_with_error(self, error: str) -> None:
        """Unload provider with error message.

        Schedules mass.unload_provider(instance_id, error) after a short delay
        via mass.call_later rather than unloading synchronously.
        """
        self.mass.call_later(1, self.mass.unload_provider, self.instance_id, error)

    def to_dict(self) -> dict[str, Any]:
        """Return Provider(instance) as serializable dict."""
        return {
            "type": self.type.value,
            "domain": self.domain,
            "name": self.name,
            "default_name": self.default_name,
            "instance_name_postfix": self.instance_name_postfix,
            "instance_id": self.instance_id,
            "lookup_key": self.instance_id,  # include for backwards compatibility
            "supported_features": [x.value for x in self.supported_features],
            "available": self.available,
            # not defined on this base class; presumably set by subclasses
            "is_streaming_provider": getattr(self, "is_streaming_provider", None),
        }

    def supports_feature(self, feature: ProviderFeature) -> bool:
        """Return True if this provider supports the given feature."""
        return feature in self.supported_features

    def check_feature(self, feature: ProviderFeature) -> None:
        """Check if this provider supports the given feature.

        Raises UnsupportedFeaturedException if the feature is not supported.
        """
        if not self.supports_feature(feature):
            raise UnsupportedFeaturedException(
                f"Provider {self.name} does not support feature {feature.name}"
            )

    def _update_config_value(self, key: str, value: Any, encrypted: bool = False) -> None:
        """Persist a raw config value and mirror it into self.config.

        NOTE(review): assumes `key` already exists in self.config.values.
        """
        self.mass.config.set_raw_provider_config_value(self.instance_id, key, value, encrypted)
        # also update the cached copy within the provider instance
        self.config.values[key].value = value

    def _set_log_level_from_config(self, config: ProviderConfig) -> None:
        """Create self.logger and set its level from the config.

        The logger is a child of the Music Assistant logger. It is named after
        the domain until the provider is available, and after the friendly name
        afterwards. A level of "GLOBAL" copies the Music Assistant logger's level.
        If the root logger's level is higher, the root logger is lowered to match
        so the provider's records are not filtered out.
        """
        mass_logger = logging.getLogger(MASS_LOGGER_NAME)
        # self.name is only available after async_init. Otherwise we run into a race condition.
        # see https://github.com/music-assistant/support/issues/4801
        logging_name = self.domain
        if getattr(self, "available", False):
            # async_init completed
            logging_name = self.name
        self.logger = mass_logger.getChild(logging_name)
        log_level = str(config.get_value(CONF_LOG_LEVEL))
        if log_level == "GLOBAL":
            self.logger.setLevel(mass_logger.level)
        else:
            self.logger.setLevel(log_level)
        if logging.getLogger().level > self.logger.level:
            # if the root logger's level is higher, we need to adjust that too
            logging.getLogger().setLevel(self.logger.level)
        self.logger.debug("Log level configured to %s", log_level)
212