music-assistant-server

27.1 KB • PY
api_client.py
27.1 KB • 678 lines • python
1"""API client wrapper for KION Music."""
2
3from __future__ import annotations
4
5import logging
6from datetime import UTC, datetime
7from typing import TYPE_CHECKING, Any, cast
8
9from music_assistant_models.errors import (
10    LoginFailed,
11    ProviderUnavailableError,
12    ResourceTemporarilyUnavailable,
13)
14from yandex_music import Album as YandexAlbum
15from yandex_music import Artist as YandexArtist
16from yandex_music import ClientAsync, Search, TrackShort
17from yandex_music import Playlist as YandexPlaylist
18from yandex_music import Track as YandexTrack
19from yandex_music.exceptions import BadRequestError, NetworkError, UnauthorizedError
20from yandex_music.utils.sign_request import get_sign_request
21
22if TYPE_CHECKING:
23    from yandex_music import DownloadInfo
24
25from .constants import DEFAULT_BASE_URL, DEFAULT_LIMIT, ROTOR_FEEDBACK_FROM, ROTOR_STATION_MY_MIX
26
27# get-file-info with quality=lossless returns FLAC; default /tracks/.../download-info often does not
28# Prefer flac-mp4/aac-mp4 (KION API moved to these formats around 2025)
29GET_FILE_INFO_CODECS = "flac-mp4,flac,aac-mp4,aac,he-aac,mp3,he-aac-mp4"
30
31LOGGER = logging.getLogger(__name__)
32
33
class KionMusicClient:
    """Async wrapper around yandex-music-api's ClientAsync for KION Music.

    Maps library exceptions onto Music Assistant error types. A few calls
    (rotor tracks, rotor feedback, lossless file info) reconnect and retry
    once when they hit a connection error.
    """

    def __init__(self, token: str, base_url: str | None = None) -> None:
        """Initialize the KION Music client.

        No network I/O happens here; call :meth:`connect` before any API method.

        :param token: KION Music OAuth token.
        :param base_url: Optional API base URL (defaults to KION Music API).
        """
        self._token = token
        self._base_url = base_url or DEFAULT_BASE_URL
        self._client: ClientAsync | None = None
        self._user_id: int | None = None

    @property
    def user_id(self) -> int:
        """Return the account UID of the connected user.

        :raises ProviderUnavailableError: If :meth:`connect` has not succeeded yet.
        """
        if self._user_id is None:
            raise ProviderUnavailableError("Client not initialized, call connect() first")
        return self._user_id

    async def connect(self) -> bool:
        """Initialize the client and verify token validity.

        The new client is stored only after the account info has been read, so a
        failed attempt does not leave a half-initialized client behind.

        :return: True if connection was successful.
        :raises LoginFailed: If the token is invalid or no account info is returned.
        :raises ResourceTemporarilyUnavailable: On network errors during init.
        """
        try:
            client = await ClientAsync(self._token, base_url=self._base_url).init()
        except UnauthorizedError as err:
            raise LoginFailed("Invalid KION Music token") from err
        except NetworkError as err:
            msg = "Network error connecting to KION Music"
            raise ResourceTemporarilyUnavailable(msg) from err
        if client.me is None or client.me.account is None:
            raise LoginFailed("Failed to get account info")
        self._client = client
        self._user_id = client.me.account.uid
        LOGGER.debug("Connected to KION Music as user %s", self._user_id)
        return True

    async def disconnect(self) -> None:
        """Drop the client reference and the cached user ID."""
        self._client = None
        self._user_id = None

    def _ensure_connected(self) -> ClientAsync:
        """Return the connected client.

        :raises ProviderUnavailableError: If :meth:`connect` has not been called.
        """
        if self._client is None:
            raise ProviderUnavailableError("Client not connected, call connect() first")
        return self._client

    def _is_connection_error(self, err: Exception) -> bool:
        """Return True if the exception looks like a connection or server drop.

        Any ``NetworkError`` counts. So does any exception whose message mentions
        "disconnect", "connection" or "timeout" (case-insensitive heuristic).
        """
        if isinstance(err, NetworkError):
            return True
        msg = str(err).lower()
        return "disconnect" in msg or "connection" in msg or "timeout" in msg

    async def _reconnect(self) -> None:
        """Disconnect and connect again to recover from connection errors."""
        await self.disconnect()
        await self.connect()

    # Rotor (radio station) methods

    async def get_rotor_station_tracks(
        self,
        station_id: str,
        queue: str | int | None = None,
    ) -> tuple[list[YandexTrack], str | None]:
        """Get tracks from a rotor station (e.g. user:onyourwave or track:1234).

        The station response only references tracks, so full track objects are
        fetched afterwards via :meth:`get_tracks` and returned in station order.
        On a connection error the client reconnects and retries once.

        :param station_id: Station ID (e.g. ROTOR_STATION_MY_MIX or "track:1234" for similar).
        :param queue: Optional track ID for pagination (first track of previous batch).
        :return: Tuple of (list of track objects, batch_id for feedback or None).
            Errors are logged and yield ``([], None)`` instead of raising.
        :raises ProviderUnavailableError: If the client is not connected.
        """
        for attempt in range(2):
            client = self._ensure_connected()
            try:
                result = await client.rotor_station_tracks(station_id, settings2=True, queue=queue)
                if not result:
                    return ([], None)
                if not result.sequence:
                    return ([], result.batch_id)
                track_ids: list[str] = []
                for seq in result.sequence:
                    if seq.track is None:
                        continue
                    # Accept either an ``id`` or a ``track_id`` attribute on the track reference.
                    tid = getattr(seq.track, "id", None) or getattr(seq.track, "track_id", None)
                    if tid is not None:
                        track_ids.append(str(tid))
                if not track_ids:
                    return ([], result.batch_id)
                full_tracks = await self.get_tracks(track_ids)
                # Re-key by ID so the output follows station order and skips missing IDs.
                order_map = {str(t.id): t for t in full_tracks if getattr(t, "id", None)}
                ordered = [order_map[tid] for tid in track_ids if tid in order_map]
                return (ordered, result.batch_id)
            except BadRequestError as err:
                LOGGER.warning("Error fetching rotor station %s tracks: %s", station_id, err)
                return ([], None)
            except Exception as err:
                # Best effort: also covers ResourceTemporarilyUnavailable from get_tracks().
                if attempt == 0 and self._is_connection_error(err):
                    LOGGER.warning(
                        "Connection error fetching rotor tracks, reconnecting: %s",
                        err,
                    )
                    try:
                        await self._reconnect()
                    except Exception as recon_err:
                        LOGGER.warning("Reconnect failed: %s", recon_err)
                        return ([], None)
                else:
                    LOGGER.warning("Error fetching rotor station tracks: %s", err)
                    return ([], None)
        return ([], None)

    async def get_my_mix_tracks(
        self, queue: str | int | None = None
    ) -> tuple[list[YandexTrack], str | None]:
        """Get tracks from the My Mix radio station.

        :param queue: Optional track ID of the last track from the previous batch (API uses it for
            pagination; do not pass batch_id).
        :return: Tuple of (list of track objects, batch_id for feedback).
        """
        return await self.get_rotor_station_tracks(ROTOR_STATION_MY_MIX, queue=queue)

    async def send_rotor_station_feedback(
        self,
        station_id: str,
        feedback_type: str,
        *,
        batch_id: str | None = None,
        track_id: str | None = None,
        total_played_seconds: int | None = None,
    ) -> bool:
        """Send rotor station feedback for My Mix recommendations.

        Used to report radioStarted, trackStarted, trackFinished and skip events
        so the service can improve later recommendations. On a connection error
        the client reconnects and retries once.

        :param station_id: Station ID (e.g. ROTOR_STATION_MY_MIX).
        :param feedback_type: One of 'radioStarted', 'trackStarted', 'trackFinished', 'skip'.
        :param batch_id: Optional batch ID from the last get_my_mix_tracks response.
        :param track_id: Track ID (required for trackStarted, trackFinished, skip).
        :param total_played_seconds: Seconds played (for trackFinished, skip).
        :return: True if the request succeeded; failures are logged and return False.
        :raises ProviderUnavailableError: If the client is not connected.
        """
        payload: dict[str, Any] = {
            "type": feedback_type,
            # UTC ISO-8601 timestamp with a trailing "Z" instead of "+00:00".
            "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
        }
        if feedback_type == "radioStarted":
            payload["from"] = ROTOR_FEEDBACK_FROM
        if track_id is not None:
            payload["trackId"] = track_id
        if total_played_seconds is not None:
            payload["totalPlayedSeconds"] = total_played_seconds
        if batch_id is not None:
            payload["batchId"] = batch_id

        for attempt in range(2):
            client = self._ensure_connected()
            url = f"{client.base_url}/rotor/station/{station_id}/feedback"
            try:
                await client.request.post(url, payload)
                return True
            except BadRequestError as err:
                LOGGER.debug("Rotor feedback %s failed: %s", feedback_type, err)
                return False
            except Exception as err:
                if attempt == 0 and self._is_connection_error(err):
                    LOGGER.warning(
                        "Connection error on rotor feedback %s, reconnecting: %s",
                        feedback_type,
                        err,
                    )
                    try:
                        await self._reconnect()
                    except Exception as recon_err:
                        LOGGER.debug("Reconnect failed: %s", recon_err)
                        return False
                else:
                    LOGGER.debug("Rotor feedback %s failed: %s", feedback_type, err)
                    return False
        return False

    # Library methods

    async def get_liked_tracks(self) -> list[TrackShort]:
        """Get user's liked tracks.

        :return: List of liked track objects.
        :raises ResourceTemporarilyUnavailable: On API or network errors.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_tracks()
            if result is None:
                return []
            return result.tracks or []
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching liked tracks: %s", err)
            raise ResourceTemporarilyUnavailable("Failed to fetch liked tracks") from err

    async def get_liked_albums(self, batch_size: int = 50) -> list[YandexAlbum]:
        """Get user's liked albums with full details (including cover art).

        The users_likes_albums endpoint returns minimal album data without
        cover_uri, so full album details are fetched in batches afterwards. If a
        batch fails, the minimal data from the likes response is used for it.

        :param batch_size: Number of album IDs per details request (must be >= 1).
        :return: List of liked album objects with full details.
        :raises ValueError: If batch_size is smaller than 1.
        :raises ResourceTemporarilyUnavailable: If the likes list cannot be fetched.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        client = self._ensure_connected()
        try:
            result = await client.users_likes_albums()
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching liked albums: %s", err)
            raise ResourceTemporarilyUnavailable("Failed to fetch liked albums") from err
        if result is None:
            return []
        liked_albums = [like.album for like in result if like.album is not None and like.album.id]
        album_ids = [str(album.id) for album in liked_albums]
        if not album_ids:
            return []
        full_albums: list[YandexAlbum] = []
        for start in range(0, len(album_ids), batch_size):
            batch = album_ids[start : start + batch_size]
            try:
                batch_result = await client.albums(batch)
            except (BadRequestError, NetworkError) as batch_err:
                LOGGER.warning("Error fetching album details batch: %s", batch_err)
                # Fall back to the minimal data from the likes response for this batch.
                batch_set = set(batch)
                full_albums.extend(a for a in liked_albums if str(a.id) in batch_set)
                continue
            if batch_result:
                full_albums.extend(batch_result)
        return full_albums

    async def get_liked_artists(self) -> list[YandexArtist]:
        """Get user's liked artists.

        :return: List of liked artist objects.
        :raises ResourceTemporarilyUnavailable: On API or network errors.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_artists()
            if result is None:
                return []
            return [like.artist for like in result if like.artist is not None]
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching liked artists: %s", err)
            raise ResourceTemporarilyUnavailable("Failed to fetch liked artists") from err

    async def get_user_playlists(self) -> list[YandexPlaylist]:
        """Get user's playlists.

        :return: List of playlist objects.
        :raises ResourceTemporarilyUnavailable: On API or network errors.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_playlists_list()
            if result is None:
                return []
            return list(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching playlists: %s", err)
            raise ResourceTemporarilyUnavailable("Failed to fetch playlists") from err

    # Search

    async def search(
        self,
        query: str,
        search_type: str = "all",
        limit: int = DEFAULT_LIMIT,
    ) -> Search | None:
        """Search for tracks, albums, artists, or playlists.

        :param query: Search query string.
        :param search_type: Type of search ('all', 'track', 'album', 'artist', 'playlist').
        :param limit: Maximum number of results per type (values < 1 disable trimming).
        :return: Search results object.
        :raises ResourceTemporarilyUnavailable: On API or network errors.
        """
        client = self._ensure_connected()
        try:
            result = await client.search(query, type_=search_type, page=0, nocorrect=False)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Search error: %s", err)
            raise ResourceTemporarilyUnavailable("Search failed") from err
        if result is not None and limit > 0:
            # client.search() has no page-size argument, so apply the limit locally.
            self._truncate_search_results(result, limit)
        return result

    @staticmethod
    def _truncate_search_results(result: Search, limit: int) -> None:
        """Trim each per-type ``results`` list of a search response to at most ``limit``."""
        for attr in ("tracks", "albums", "artists", "playlists"):
            section = getattr(result, attr, None)
            items = getattr(section, "results", None)
            if items and len(items) > limit:
                section.results = items[:limit]

    # Get single items

    async def get_track(self, track_id: str) -> YandexTrack | None:
        """Get a single track by ID.

        :param track_id: Track ID.
        :return: Track object or None if not found or on error.
        """
        client = self._ensure_connected()
        try:
            tracks = await client.tracks([track_id])
            return tracks[0] if tracks else None
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching track %s: %s", track_id, err)
            return None

    async def get_tracks(self, track_ids: list[str]) -> list[YandexTrack]:
        """Get multiple tracks by IDs.

        Network errors are retried once; a bad request yields an empty list.

        :param track_ids: List of track IDs.
        :return: List of track objects (empty for an empty ID list).
        :raises ResourceTemporarilyUnavailable: On network errors after retry.
        """
        client = self._ensure_connected()
        if not track_ids:
            return []
        # BadRequestError subclasses NetworkError in yandex-music, so it must be
        # caught first or its handler is unreachable.
        try:
            result = await client.tracks(track_ids)
        except BadRequestError as err:
            LOGGER.error("Error fetching tracks: %s", err)
            return []
        except NetworkError as err:
            # Retry once on network errors (timeout, disconnect, etc.)
            LOGGER.warning("Network error fetching tracks, retrying once: %s", err)
            try:
                result = await client.tracks(track_ids)
            except BadRequestError as retry_err:
                LOGGER.error("Error fetching tracks: %s", retry_err)
                return []
            except NetworkError as retry_err:
                LOGGER.error("Error fetching tracks (retry failed): %s", retry_err)
                raise ResourceTemporarilyUnavailable("Failed to fetch tracks") from retry_err
        return result or []

    async def get_album(self, album_id: str) -> YandexAlbum | None:
        """Get a single album by ID.

        :param album_id: Album ID.
        :return: Album object or None if not found or on error.
        """
        client = self._ensure_connected()
        try:
            albums = await client.albums([album_id])
            return albums[0] if albums else None
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching album %s: %s", album_id, err)
            return None

    async def get_album_with_tracks(self, album_id: str) -> YandexAlbum | None:
        """Get an album with its tracks.

        First tries albums/{id}/with-tracks with the web client's resumeStream,
        richTracks and withListeningFinished options. If the library rejects those
        kwargs with TypeError, it retries without them.

        :param album_id: Album ID.
        :return: Album object with tracks or None if not found or on error.
        """
        client = self._ensure_connected()
        try:
            try:
                return await client.albums_with_tracks(
                    album_id,
                    resumeStream=True,
                    richTracks=True,
                    withListeningFinished=True,
                )
            except TypeError:
                # Older yandex-music may not accept these kwargs.
                return await client.albums_with_tracks(album_id)
        except (BadRequestError, NetworkError) as err:
            # Covers both the kwargs call and the fallback call.
            LOGGER.error("Error fetching album with tracks %s: %s", album_id, err)
            return None

    async def get_artist(self, artist_id: str) -> YandexArtist | None:
        """Get a single artist by ID.

        :param artist_id: Artist ID.
        :return: Artist object or None if not found or on error.
        """
        client = self._ensure_connected()
        try:
            artists = await client.artists([artist_id])
            return artists[0] if artists else None
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching artist %s: %s", artist_id, err)
            return None

    async def get_artist_albums(
        self, artist_id: str, limit: int = DEFAULT_LIMIT
    ) -> list[YandexAlbum]:
        """Get artist's albums (first page only).

        :param artist_id: Artist ID.
        :param limit: Maximum number of albums (used as the page size).
        :return: List of album objects; empty on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.artists_direct_albums(artist_id, page=0, page_size=limit)
            if result is None:
                return []
            return result.albums or []
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching artist albums %s: %s", artist_id, err)
            return []

    async def get_artist_tracks(
        self, artist_id: str, limit: int = DEFAULT_LIMIT
    ) -> list[YandexTrack]:
        """Get artist's top tracks (first page only).

        :param artist_id: Artist ID.
        :param limit: Maximum number of tracks (used as the page size).
        :return: List of track objects; empty on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.artists_tracks(artist_id, page=0, page_size=limit)
            if result is None:
                return []
            return result.tracks or []
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching artist tracks %s: %s", artist_id, err)
            return []

    async def get_playlist(self, user_id: str, playlist_id: str) -> YandexPlaylist | None:
        """Get a playlist by ID.

        :param user_id: User ID (owner of the playlist).
        :param playlist_id: Playlist ID (kind); must be an integer string.
        :return: Playlist object, or None if not found, invalid, or on a bad request.
        :raises ResourceTemporarilyUnavailable: On network errors.
        """
        client = self._ensure_connected()
        try:
            kind = int(playlist_id)
        except ValueError:
            LOGGER.error("Invalid playlist id %s/%s", user_id, playlist_id)
            return None
        # BadRequestError subclasses NetworkError in yandex-music, so it must be
        # caught first or a bad request would surface as "temporarily unavailable".
        try:
            result = await client.users_playlists(kind=kind, user_id=user_id)
        except BadRequestError as err:
            LOGGER.error("Error fetching playlist %s/%s: %s", user_id, playlist_id, err)
            return None
        except NetworkError as err:
            LOGGER.warning("Network error fetching playlist %s/%s: %s", user_id, playlist_id, err)
            raise ResourceTemporarilyUnavailable("Failed to fetch playlist") from err
        if isinstance(result, list):
            return result[0] if result else None
        return result

    # Streaming

    async def get_track_download_info(
        self, track_id: str, get_direct_links: bool = True
    ) -> list[DownloadInfo]:
        """Get download info for a track.

        :param track_id: Track ID.
        :param get_direct_links: Whether to get direct download links.
        :return: List of download info objects; empty on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.tracks_download_info(track_id, get_direct_links=get_direct_links)
            return result or []
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching download info for track %s: %s", track_id, err)
            return []

    async def get_track_file_info_lossless(self, track_id: str) -> dict[str, Any] | None:
        """Request a lossless stream via get-file-info (quality=lossless).

        The /tracks/{id}/download-info endpoint often returns only MP3. get-file-info
        with quality=lossless and the codecs in GET_FILE_INFO_CODECS returns FLAC
        when available. It tries transports=encraw first and falls back to
        transports=raw on UnauthorizedError.

        On a transient connection error it reconnects and retries once, so a
        momentary disconnect does not silently fall back to lossy quality.

        :param track_id: Track ID.
        :return: Parsed downloadInfo dict (url, codec, urls, ...) or None on error.
        """

        def _parse_file_info_result(raw: dict[str, Any] | None) -> dict[str, Any] | None:
            """Return the ``download_info`` mapping if it has a URL, else None."""
            if not raw or not isinstance(raw, dict):
                return None
            # NOTE(review): the key is read as snake_case; presumably the library's
            # request layer converts the API's camelCase "downloadInfo" — confirm.
            download_info = raw.get("download_info")
            if not isinstance(download_info, dict) or not download_info.get("url"):
                return None
            return cast("dict[str, Any]", download_info)

        for attempt in range(2):
            client = self._ensure_connected()
            # Fresh signature per attempt: it embeds the current timestamp.
            sign = get_sign_request(track_id)
            base_params = {
                "ts": sign.timestamp,
                "trackId": track_id,
                "quality": "lossless",
                "codecs": GET_FILE_INFO_CODECS,
                "sign": sign.value,
            }

            url = f"{client.base_url}/get-file-info"
            params_encraw = {**base_params, "transports": "encraw"}
            try:
                result = await client.request.get(url, params=params_encraw)
                return _parse_file_info_result(result)
            except UnauthorizedError as err:
                LOGGER.debug(
                    "get-file-info lossless for track %s (transports=encraw): %s %s",
                    track_id,
                    type(err).__name__,
                    getattr(err, "message", str(err)) or repr(err),
                )
                LOGGER.debug(
                    "If you have KION Music Plus and this track has lossless, "
                    "try a token from the web client (music.mts.ru)."
                )
                params_raw = {**base_params, "transports": "raw"}
                try:
                    result = await client.request.get(url, params=params_raw)
                    return _parse_file_info_result(result)
                except (BadRequestError, NetworkError, UnauthorizedError) as retry_err:
                    LOGGER.debug(
                        "get-file-info lossless for track %s (transports=raw): %s %s",
                        track_id,
                        type(retry_err).__name__,
                        getattr(retry_err, "message", str(retry_err)) or repr(retry_err),
                    )
                    return None
            except BadRequestError as err:
                LOGGER.debug(
                    "get-file-info lossless for track %s: %s %s",
                    track_id,
                    type(err).__name__,
                    getattr(err, "message", str(err)) or repr(err),
                )
                return None
            except Exception as err:
                if attempt == 0 and self._is_connection_error(err):
                    LOGGER.warning(
                        "Connection error on get-file-info lossless for track %s, reconnecting: %s",
                        track_id,
                        err,
                    )
                    try:
                        await self._reconnect()
                    except Exception as recon_err:
                        LOGGER.debug("Reconnect failed: %s", recon_err)
                        return None
                else:
                    LOGGER.debug(
                        "get-file-info lossless for track %s: %s %s",
                        track_id,
                        type(err).__name__,
                        getattr(err, "message", str(err)) or repr(err),
                    )
                    return None
        return None

    # Library modifications
    # The users_likes_* add/remove calls return a bool, so the result is
    # coerced with bool() rather than compared against None (False != None).

    async def like_track(self, track_id: str) -> bool:
        """Add a track to liked tracks.

        :param track_id: Track ID to like.
        :return: True if the API reported success; False otherwise or on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_tracks_add(track_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error liking track %s: %s", track_id, err)
            return False

    async def unlike_track(self, track_id: str) -> bool:
        """Remove a track from liked tracks.

        :param track_id: Track ID to unlike.
        :return: True if the API reported success; False otherwise or on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_tracks_remove(track_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error unliking track %s: %s", track_id, err)
            return False

    async def like_album(self, album_id: str) -> bool:
        """Add an album to liked albums.

        :param album_id: Album ID to like.
        :return: True if the API reported success; False otherwise or on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_albums_add(album_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error liking album %s: %s", album_id, err)
            return False

    async def unlike_album(self, album_id: str) -> bool:
        """Remove an album from liked albums.

        :param album_id: Album ID to unlike.
        :return: True if the API reported success; False otherwise or on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_albums_remove(album_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error unliking album %s: %s", album_id, err)
            return False

    async def like_artist(self, artist_id: str) -> bool:
        """Add an artist to liked artists.

        :param artist_id: Artist ID to like.
        :return: True if the API reported success; False otherwise or on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_artists_add(artist_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error liking artist %s: %s", artist_id, err)
            return False

    async def unlike_artist(self, artist_id: str) -> bool:
        """Remove an artist from liked artists.

        :param artist_id: Artist ID to unlike.
        :return: True if the API reported success; False otherwise or on error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_artists_remove(artist_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error unliking artist %s: %s", artist_id, err)
            return False