music-assistant-server

17.3 KB · PY
api.py
17.3 KB · 445 lines · python
1"""Helpers for dealing with API's to interact with Music Assistant."""
2
3from __future__ import annotations
4
5import importlib
6import inspect
7import logging
8from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable, Sequence
9from dataclasses import MISSING, dataclass
10from datetime import datetime
11from enum import Enum
12from types import NoneType, UnionType
13from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints
14
15from mashumaro.exceptions import MissingField
16from music_assistant_models.media_items.media_item import MediaItem
17
18from music_assistant.helpers.util import try_parse_bool
19
LOGGER = logging.getLogger(__name__)

_F = TypeVar("_F", bound=Callable[..., Any])

# Memo of _resolve_string_type results: alias name -> resolved type object, or the
# name itself when it is unknown or failed to import (so it is not retried).
_TYPE_ALIAS_CACHE: dict[str, Any] = {}

# Known type alias names and the (module, attribute) they are imported from.
_TYPE_ALIAS_SOURCES: dict[str, tuple[str, str]] = {
    "ConfigValueType": ("music_assistant_models.config_entries", "ConfigValueType"),
    "MediaItemType": ("music_assistant_models.media_items", "MediaItemType"),
}


def _resolve_string_type(type_str: str) -> Any:
    """
    Turn a string type reference back into the type object it names.

    Type aliases such as ConfigValueType are replaced by their name while type hints
    are post-processed (to avoid isinstance() problems with complex unions); this
    undoes that substitution when a value is actually parsed. Every outcome, including
    failures, is memoized in the module-level cache so each alias is imported once.

    :param type_str: String name of the type (e.g., "ConfigValueType").
    :return: The actual type object, or the string if resolution fails.
    """
    try:
        return _TYPE_ALIAS_CACHE[type_str]
    except KeyError:
        pass

    resolved: Any = type_str  # unknown names (and failed imports) resolve to themselves
    source = _TYPE_ALIAS_SOURCES.get(type_str)
    if source is not None:
        module_name, attr_name = source
        try:
            resolved = getattr(importlib.import_module(module_name), attr_name)
        except (ImportError, AttributeError) as err:
            LOGGER.warning("Failed to resolve type alias %s: %s", type_str, err)

    _TYPE_ALIAS_CACHE[type_str] = resolved
    return resolved
66
67
68def _resolve_generic_type_args(
69    args: tuple[Any, ...],
70    func: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
71    config_value_type: Any,
72    media_item_type: Any,
73) -> tuple[list[Any], bool]:
74    """Resolve TypeVars and type aliases in generic type arguments.
75
76    :param args: Type arguments from a generic type (e.g., from list[T] or dict[K, V])
77    :param func: The function being analyzed
78    :param config_value_type: The ConfigValueType type alias to compare against
79    :param media_item_type: The MediaItemType type alias to compare against
80    :return: Tuple of (resolved_args, changed) where changed is True if any args were modified
81    """
82    new_args: list[Any] = []
83    changed = False
84
85    for arg in args:
86        # Check if arg matches ConfigValueType union (type alias that was expanded)
87        if arg == config_value_type:
88            # Replace with string reference to preserve type alias
89            new_args.append("ConfigValueType")
90            changed = True
91        # Check if arg matches MediaItemType union (type alias that was expanded)
92        elif arg == media_item_type:
93            # Replace with string reference to preserve type alias
94            new_args.append("MediaItemType")
95            changed = True
96        elif isinstance(arg, TypeVar):
97            # For ItemCls, resolve to concrete type
98            if arg.__name__ == "ItemCls" and hasattr(func, "__self__"):
99                if hasattr(func.__self__, "item_cls"):
100                    new_args.append(func.__self__.item_cls)
101                    changed = True
102                else:
103                    new_args.append(arg)
104            # For ConfigValue TypeVars, resolve to string name
105            elif "ConfigValue" in arg.__name__:
106                new_args.append("ConfigValueType")
107                changed = True
108            else:
109                new_args.append(arg)
110        # Check if arg is a Union containing a TypeVar
111        elif get_origin(arg) in (Union, UnionType):
112            union_args = get_args(arg)
113            for union_arg in union_args:
114                if isinstance(union_arg, TypeVar) and union_arg.__bound__ is not None:
115                    # Resolve the TypeVar in the union
116                    union_arg_index = union_args.index(union_arg)
117                    resolved = _resolve_typevar_in_union(
118                        union_arg, func, union_args, union_arg_index
119                    )
120                    new_args.append(resolved)
121                    changed = True
122                    break
123            else:
124                # No TypeVar found in union, keep as-is
125                new_args.append(arg)
126        else:
127            new_args.append(arg)
128
129    return new_args, changed
130
131
132def _resolve_typevar_in_union(
133    arg: TypeVar,
134    func: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
135    args: tuple[Any, ...],
136    i: int,
137) -> Any:
138    """Resolve a TypeVar found in a Union to its concrete type.
139
140    :param arg: The TypeVar to resolve.
141    :param func: The function being analyzed.
142    :param args: All args from the Union.
143    :param i: Index of the TypeVar in the args.
144    """
145    bound_type = arg.__bound__
146    if not bound_type or not hasattr(arg, "__name__"):
147        return bound_type
148
149    type_var_name = arg.__name__
150
151    # Map TypeVar names to their type alias names
152    if "ConfigValue" in type_var_name:
153        return "ConfigValueType"
154
155    if type_var_name == "ItemCls":
156        # Resolve ItemCls to the actual media item class (e.g., Artist, Album, Track)
157        if hasattr(func, "__self__") and hasattr(func.__self__, "item_cls"):
158            resolved_type = func.__self__.item_cls
159            # Preserve other types in the union (like None for Optional)
160            other_args = [a for j, a in enumerate(args) if j != i]
161            if other_args:
162                # Reconstruct union with resolved type
163                return Union[resolved_type, *other_args]
164            return resolved_type
165        # Fallback to bound if we can't get item_cls
166        return bound_type
167
168    # Check if the bound is MediaItemType by comparing the union
169    from music_assistant_models.media_items import (  # noqa: PLC0415
170        MediaItemType as media_item_type,  # noqa: N813
171    )
172
173    if bound_type == media_item_type:
174        return "MediaItemType"
175
176    # Fallback to the bound type
177    return bound_type
178
179
180@dataclass
181class APICommandHandler:
182    """Model for an API command handler."""
183
184    command: str
185    signature: inspect.Signature
186    type_hints: dict[str, Any]
187    target: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]]
188    authenticated: bool = True
189    required_role: str | None = None  # "admin" or "user" or None
190    alias: bool = False  # If True, this is an alias for backward compatibility
191
192    @classmethod
193    def parse(
194        cls,
195        command: str,
196        func: Callable[..., Coroutine[Any, Any, Any] | AsyncGenerator[Any, Any]],
197        authenticated: bool = True,
198        required_role: str | None = None,
199        alias: bool = False,
200    ) -> APICommandHandler:
201        """Parse APICommandHandler by providing a function.
202
203        :param command: The command name/path.
204        :param func: The function to handle the command.
205        :param authenticated: Whether authentication is required (default: True).
206        :param required_role: Required user role ("admin" or "user")
207            None for any authenticated user.
208        :param alias: Whether this is an alias for backward compatibility (default: False).
209        """
210        type_hints = get_type_hints(func)
211        # workaround for generic typevar ItemCls that needs to be resolved
212        # to the real media item type. TODO: find a better way to do this
213        # without this hack
214        # Import type aliases to compare against
215        from music_assistant_models.config_entries import (  # noqa: PLC0415
216            ConfigValueType as config_value_type,  # noqa: N813
217        )
218        from music_assistant_models.media_items import (  # noqa: PLC0415
219            MediaItemType as media_item_type,  # noqa: N813
220        )
221
222        for key, value in type_hints.items():
223            # Handle generic types (list, tuple, dict, etc.) that may contain TypeVars
224            # For example: list[ItemCls] should become list[Artist]
225            # For example: dict[str, ConfigValueType] should preserve ConfigValueType
226            origin = get_origin(value)
227            if origin in (list, tuple, set, frozenset, dict):
228                args = get_args(value)
229                if args:
230                    new_args, changed = _resolve_generic_type_args(
231                        args, func, config_value_type, media_item_type
232                    )
233                    if changed:
234                        # Reconstruct the generic type with resolved TypeVars
235                        type_hints[key] = origin[tuple(new_args)]
236                continue
237
238            # Handle Union types that may contain TypeVars
239            # For example: _ConfigValueT | ConfigValueType should become just "ConfigValueType"
240            # when _ConfigValueT is bound to ConfigValueType
241            if origin is Union or origin is UnionType:
242                args = get_args(value)
243                # Check if union contains a TypeVar
244                # If the TypeVar's bound is a union that was flattened into the current union,
245                # we can just use the bound type for documentation purposes
246                typevar_found = False
247                for i, arg in enumerate(args):
248                    if isinstance(arg, TypeVar) and arg.__bound__ is not None:
249                        typevar_found = True
250                        type_hints[key] = _resolve_typevar_in_union(arg, func, args, i)
251                        break
252                if typevar_found:
253                    continue
254            if not hasattr(value, "__name__"):
255                continue
256            if value.__name__ == "ItemCls":
257                type_hints[key] = func.__self__.item_cls  # type: ignore[attr-defined]
258            # Resolve TypeVars to their bound type for API documentation
259            # This handles cases like _ConfigValueT which should show as ConfigValueType
260            elif isinstance(value, TypeVar):
261                if value.__bound__ is not None:
262                    type_hints[key] = value.__bound__
263        return APICommandHandler(
264            command=command,
265            signature=inspect.signature(func),
266            type_hints=type_hints,
267            target=func,
268            authenticated=authenticated,
269            required_role=required_role,
270            alias=alias,
271        )
272
273
def api_command(
    command: str, authenticated: bool = True, required_role: str | None = None
) -> Callable[[_F], _F]:
    """Mark a function as the handler of an API route/command.

    The returned decorator only attaches metadata attributes (``api_cmd``,
    ``api_authenticated``, ``api_required_role``) and returns the same function.

    :param command: The command name/path.
    :param authenticated: Whether authentication is required (default: True).
    :param required_role: Required user role ("admin" or "user"), None means any authenticated user.
    """
    command_metadata = {
        "api_cmd": command,
        "api_authenticated": authenticated,
        "api_required_role": required_role,
    }

    def decorate(func: _F) -> _F:
        for attr_name, attr_value in command_metadata.items():
            setattr(func, attr_name, attr_value)
        return func

    return decorate
291
292
293def parse_arguments(
294    func_sig: inspect.Signature,
295    func_types: dict[str, Any],
296    args: dict[str, Any] | None,
297    strict: bool = False,
298) -> dict[str, Any]:
299    """Parse (and convert) incoming arguments to correct types."""
300    if args is None:
301        args = {}
302    final_args = {}
303    # ignore extra args if not strict
304    if strict:
305        for key, value in args.items():
306            if key not in func_sig.parameters:
307                raise KeyError(f"Invalid parameter: '{key}'")
308    # parse arguments to correct type
309    for name, param in func_sig.parameters.items():
310        value = args.get(name)
311        default = MISSING if param.default is inspect.Parameter.empty else param.default
312        try:
313            final_args[name] = parse_value(name, value, func_types[name], default)
314        except TypeError:
315            # retry one more time with allow_value_convert=True
316            final_args[name] = parse_value(
317                name, value, func_types[name], default, allow_value_convert=True
318            )
319    return final_args
320
321
def parse_utc_timestamp(datetime_string: str) -> datetime:
    """Parse an ISO 8601 string into a datetime.

    Any UTC offset present in the string is kept as-is; a string without offset
    yields a naive datetime (no conversion to UTC is performed here).
    """
    parsed = datetime.fromisoformat(datetime_string)
    return parsed
325
326
327def parse_value(  # noqa: PLR0911
328    name: str,
329    value: Any,
330    value_type: Any,
331    default: Any = MISSING,
332    allow_value_convert: bool = False,
333) -> Any:
334    """Try to parse a value from raw (json) data and type annotations."""
335    # Resolve string type hints early for proper handling
336    if isinstance(value_type, str):
337        value_type = _resolve_string_type(value_type)
338        # If still a string after resolution, return value as-is
339        if isinstance(value_type, str):
340            LOGGER.debug("Unknown string type hint: %s, returning value as-is", value_type)
341            return value
342
343    if isinstance(value, dict) and hasattr(value_type, "from_dict"):
344        # Only validate media_type for actual MediaItem subclasses, not for other classes
345        # like StreamDetails that have a media_type field for a different purpose
346        if (
347            "media_type" in value
348            and value_type.__name__ != "ItemMapping"
349            and issubclass(value_type, MediaItem)
350            and value["media_type"] != value_type.media_type
351        ):
352            msg = "Invalid MediaType"
353            raise ValueError(msg)
354        return value_type.from_dict(value)
355
356    if value is None and not isinstance(default, type(MISSING)):
357        return default
358    if value is None and value_type is NoneType:
359        return None
360    origin = get_origin(value_type)
361    if origin in (tuple, list, Sequence, Iterable):
362        # For abstract types like Sequence and Iterable, use list as the concrete type
363        concrete_type = list if origin in (Sequence, Iterable) else origin
364        return concrete_type(
365            parse_value(
366                name, subvalue, get_args(value_type)[0], allow_value_convert=allow_value_convert
367            )
368            for subvalue in value
369            if subvalue is not None
370        )
371    if origin is dict:
372        subkey_type = get_args(value_type)[0]
373        subvalue_type = get_args(value_type)[1]
374        return {
375            parse_value(subkey, subkey, subkey_type): parse_value(
376                f"{subkey}.value", subvalue, subvalue_type, allow_value_convert=allow_value_convert
377            )
378            for subkey, subvalue in value.items()
379        }
380    if origin is Union or origin is UnionType:
381        # try all possible types
382        sub_value_types = get_args(value_type)
383        for sub_arg_type in sub_value_types:
384            if value is NoneType and sub_arg_type is NoneType:
385                return value
386            # try them all until one succeeds
387            try:
388                return parse_value(
389                    name, value, sub_arg_type, allow_value_convert=allow_value_convert
390                )
391            except (KeyError, TypeError, ValueError, MissingField):
392                pass
393        # if we get to this point, all possibilities failed
394        # find out if we should raise or log this
395        err = (
396            f"Value {value} of type {type(value)} is invalid for {name}, "
397            f"expected value of type {value_type}"
398        )
399        if NoneType not in sub_value_types:
400            # raise exception, we have no idea how to handle this value
401            raise TypeError(err)
402        # failed to parse the (sub) value but None allowed, log only
403        logging.getLogger(__name__).warning(err)
404        return None
405    if origin is type:
406        assert isinstance(value, str)  # for type checking
407        return eval(value)
408    if value_type is Any:
409        return value
410    if value is None and value_type is not NoneType:
411        msg = f"`{name}` of type `{value_type}` is required."
412        raise KeyError(msg)
413
414    try:
415        if issubclass(value_type, Enum):
416            return value_type(value)
417        if issubclass(value_type, datetime):
418            assert isinstance(value, str)  # for type checking
419            return parse_utc_timestamp(value)
420    except TypeError:
421        # happens if value_type is not a class
422        pass
423
424    if allow_value_convert:
425        # allow conversion of common types/mistakes
426        if value_type is float and isinstance(value, int):
427            return float(value)
428        if value_type is int and isinstance(value, float):
429            return int(value)
430        if value_type is int and isinstance(value, str) and value.isnumeric():
431            return int(value)
432        if value_type is float and isinstance(value, str) and value.isnumeric():
433            return float(value)
434        if value_type is bool and isinstance(value, str | int):
435            return try_parse_bool(value)
436
437    if not isinstance(value, value_type):
438        # all options failed, raise exception
439        msg = (
440            f"Value {value} of type {type(value)} is invalid for {name}, "
441            f"expected value of type {value_type}"
442        )
443        raise TypeError(msg)
444    return value
445