/
/
/
1"""API client wrapper for KION Music."""
2
3from __future__ import annotations
4
5import logging
6from datetime import UTC, datetime
7from typing import TYPE_CHECKING, Any, cast
8
9from music_assistant_models.errors import (
10 LoginFailed,
11 ProviderUnavailableError,
12 ResourceTemporarilyUnavailable,
13)
14from yandex_music import Album as YandexAlbum
15from yandex_music import Artist as YandexArtist
16from yandex_music import ClientAsync, Search, TrackShort
17from yandex_music import Playlist as YandexPlaylist
18from yandex_music import Track as YandexTrack
19from yandex_music.exceptions import BadRequestError, NetworkError, UnauthorizedError
20from yandex_music.utils.sign_request import get_sign_request
21
22if TYPE_CHECKING:
23 from yandex_music import DownloadInfo
24
25from .constants import DEFAULT_BASE_URL, DEFAULT_LIMIT, ROTOR_FEEDBACK_FROM, ROTOR_STATION_MY_MIX
26
# Codec list sent to the get-file-info endpoint. Requesting quality=lossless there
# returns FLAC, which the default /tracks/.../download-info endpoint often does not.
# flac-mp4 and aac-mp4 come first because the KION API moved to these formats
# around 2025.
GET_FILE_INFO_CODECS = ",".join(
    (
        "flac-mp4",
        "flac",
        "aac-mp4",
        "aac",
        "he-aac",
        "mp3",
        "he-aac-mp4",
    )
)

LOGGER = logging.getLogger(__name__)
32
33
class KionMusicClient:
    """Async wrapper around yandex-music-api ``ClientAsync`` for KION Music.

    Holds one ``ClientAsync`` created by :meth:`connect` and translates library
    exceptions into Music Assistant error types. Some endpoints reconnect and retry
    once on connection-looking failures.
    """

    def __init__(self, token: str, base_url: str | None = None) -> None:
        """Initialize the KION Music client; no network I/O happens here.

        :param token: KION Music OAuth token.
        :param base_url: Optional API base URL (defaults to DEFAULT_BASE_URL).
        """
        self._token = token
        self._base_url = base_url or DEFAULT_BASE_URL
        self._client: ClientAsync | None = None
        self._user_id: int | None = None

    @property
    def user_id(self) -> int:
        """Return the account UID of the connected user.

        :raises ProviderUnavailableError: If connect() has not completed successfully.
        """
        if self._user_id is None:
            raise ProviderUnavailableError("Client not initialized, call connect() first")
        return self._user_id

    async def connect(self) -> bool:
        """Create the underlying client and verify the token via the account info.

        The new client is stored only after the account info has been validated.
        A failed login therefore never leaves an unauthenticated client behind for
        _ensure_connected() to hand out.

        :return: True if connection was successful.
        :raises LoginFailed: If the token is invalid or no account info is returned.
        :raises ResourceTemporarilyUnavailable: On network errors while connecting.
        """
        try:
            client = await ClientAsync(self._token, base_url=self._base_url).init()
        except UnauthorizedError as err:
            raise LoginFailed("Invalid KION Music token") from err
        except NetworkError as err:
            msg = "Network error connecting to KION Music"
            raise ResourceTemporarilyUnavailable(msg) from err
        if client.me is None or client.me.account is None:
            raise LoginFailed("Failed to get account info")
        self._client = client
        self._user_id = client.me.account.uid
        LOGGER.debug("Connected to KION Music as user %s", self._user_id)
        return True

    async def disconnect(self) -> None:
        """Drop the client reference and cached user ID (no network call is made)."""
        self._client = None
        self._user_id = None

    def _ensure_connected(self) -> ClientAsync:
        """Return the connected client.

        :raises ProviderUnavailableError: If there is no client (connect() not done).
        """
        if self._client is None:
            raise ProviderUnavailableError("Client not connected, call connect() first")
        return self._client

    def _is_connection_error(self, err: Exception) -> bool:
        """Return True if the exception indicates a connection or server drop.

        Any NetworkError counts. For other exceptions this is a heuristic on the
        lowercased message text ("disconnect", "connection", "timeout").
        """
        if isinstance(err, NetworkError):
            return True
        msg = str(err).lower()
        return any(marker in msg for marker in ("disconnect", "connection", "timeout"))

    async def _reconnect(self) -> None:
        """Disconnect and connect again to recover from Server disconnected / connection errors.

        :raises LoginFailed | ResourceTemporarilyUnavailable: Propagated from connect().
        """
        await self.disconnect()
        await self.connect()

    # Rotor (radio station) methods

    async def get_rotor_station_tracks(
        self,
        station_id: str,
        queue: str | int | None = None,
    ) -> tuple[list[YandexTrack], str | None]:
        """Get tracks from a rotor station (e.g. user:onyourwave or track:1234).

        The track IDs in the station sequence are re-fetched through get_tracks() to
        obtain full track objects, keeping the station's order. Tracks that
        get_tracks() does not return are dropped. On a connection-looking error the
        client reconnects and retries once. All other failures are logged and yield
        an empty list.

        :param station_id: Station ID (e.g. ROTOR_STATION_MY_MIX or "track:1234" for similar).
        :param queue: Optional track ID from the previous batch, used by the API for
            pagination.
        :return: Tuple of (list of track objects, batch_id for feedback or None).
        """
        for attempt in range(2):
            client = self._ensure_connected()
            try:
                result = await client.rotor_station_tracks(station_id, settings2=True, queue=queue)
                if not result:
                    return ([], None)
                batch_id = result.batch_id
                track_ids: list[str] = []
                for item in result.sequence or []:
                    if item.track is None:
                        continue
                    tid = getattr(item.track, "id", None) or getattr(item.track, "track_id", None)
                    if tid is not None:
                        track_ids.append(str(tid))
                if not track_ids:
                    return ([], batch_id)
                full_tracks = await self.get_tracks(track_ids)
                # Re-apply the station order in case get_tracks() returns the tracks
                # in a different order or omits some of them.
                by_id = {str(t.id): t for t in full_tracks if getattr(t, "id", None)}
                return ([by_id[tid] for tid in track_ids if tid in by_id], batch_id)
            except BadRequestError as err:
                LOGGER.warning("Error fetching rotor station %s tracks: %s", station_id, err)
                return ([], None)
            except Exception as err:
                if attempt == 0 and self._is_connection_error(err):
                    LOGGER.warning(
                        "Connection error fetching rotor tracks, reconnecting: %s",
                        err,
                    )
                    try:
                        await self._reconnect()
                    except Exception as recon_err:
                        LOGGER.warning("Reconnect failed: %s", recon_err)
                        return ([], None)
                else:
                    LOGGER.warning("Error fetching rotor station tracks: %s", err)
                    return ([], None)
        return ([], None)

    async def get_my_mix_tracks(
        self, queue: str | int | None = None
    ) -> tuple[list[YandexTrack], str | None]:
        """Get tracks from the "My Mix" radio station (ROTOR_STATION_MY_MIX).

        :param queue: Optional track ID of the last track from the previous batch (API uses it for
            pagination; do not pass batch_id).
        :return: Tuple of (list of track objects, batch_id for feedback or None).
        """
        return await self.get_rotor_station_tracks(ROTOR_STATION_MY_MIX, queue=queue)

    async def send_rotor_station_feedback(
        self,
        station_id: str,
        feedback_type: str,
        *,
        batch_id: str | None = None,
        track_id: str | None = None,
        total_played_seconds: int | None = None,
    ) -> bool:
        """Send rotor station feedback for My Mix recommendations.

        Reports radioStarted, trackStarted, trackFinished and skip events so the
        service can improve subsequent recommendations. The payload is POSTed to
        {base_url}/rotor/station/{station_id}/feedback with a UTC ISO-8601 timestamp
        ending in "Z". A connection-looking error triggers one reconnect and retry.

        :param station_id: Station ID (e.g. ROTOR_STATION_MY_MIX).
        :param feedback_type: One of 'radioStarted', 'trackStarted', 'trackFinished', 'skip'.
        :param batch_id: Optional batch ID from the last get_my_mix_tracks response.
        :param track_id: Track ID (required for trackStarted, trackFinished, skip).
        :param total_played_seconds: Seconds played (for trackFinished, skip).
        :return: True if the request succeeded, False otherwise (errors are logged).
        """
        client = self._ensure_connected()
        payload: dict[str, Any] = {
            "type": feedback_type,
            "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
        }
        if feedback_type == "radioStarted":
            payload["from"] = ROTOR_FEEDBACK_FROM
        if track_id is not None:
            payload["trackId"] = track_id
        if total_played_seconds is not None:
            payload["totalPlayedSeconds"] = total_played_seconds
        if batch_id is not None:
            payload["batchId"] = batch_id

        url = f"{client.base_url}/rotor/station/{station_id}/feedback"
        for attempt in range(2):
            client = self._ensure_connected()
            try:
                await client.request.post(url, payload)
                return True
            except BadRequestError as err:
                LOGGER.debug("Rotor feedback %s failed: %s", feedback_type, err)
                return False
            except Exception as err:
                if attempt == 0 and self._is_connection_error(err):
                    LOGGER.warning(
                        "Connection error on rotor feedback %s, reconnecting: %s",
                        feedback_type,
                        err,
                    )
                    try:
                        await self._reconnect()
                    except Exception as recon_err:
                        LOGGER.debug("Reconnect failed: %s", recon_err)
                        return False
                else:
                    LOGGER.debug("Rotor feedback %s failed: %s", feedback_type, err)
                    return False
        return False

    # Library methods

    async def get_liked_tracks(self) -> list[TrackShort]:
        """Get user's liked tracks.

        :return: List of liked track objects ([] if the API returns nothing).
        :raises ResourceTemporarilyUnavailable: On bad-request or network errors.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_tracks()
            if result is None:
                return []
            return result.tracks or []
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching liked tracks: %s", err)
            raise ResourceTemporarilyUnavailable("Failed to fetch liked tracks") from err

    async def get_liked_albums(self, batch_size: int = 50) -> list[YandexAlbum]:
        """Get user's liked albums with full details (including cover art).

        The users_likes_albums endpoint returns minimal album data without
        cover_uri, so full album details are fetched afterwards in batches. If a
        batch lookup fails, the minimal album objects for that batch are used.

        :param batch_size: Number of album IDs per albums() request; must be >= 1.
        :return: List of liked album objects, with full details where available.
        :raises ValueError: If batch_size is less than 1.
        :raises ResourceTemporarilyUnavailable: If the likes list itself cannot be fetched.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        client = self._ensure_connected()
        try:
            result = await client.users_likes_albums()
            if result is None:
                return []
            # (album_id, minimal album) pairs in like order. Slicing this list gives
            # both the IDs to request and the fallback data for the same batch.
            liked = [
                (str(like.album.id), like.album)
                for like in result
                if like.album is not None and like.album.id
            ]
            full_albums: list[YandexAlbum] = []
            for start in range(0, len(liked), batch_size):
                chunk = liked[start : start + batch_size]
                try:
                    batch_result = await client.albums([album_id for album_id, _ in chunk])
                except (BadRequestError, NetworkError) as batch_err:
                    LOGGER.warning("Error fetching album details batch: %s", batch_err)
                    # Fall back to minimal data for this batch
                    full_albums.extend(album for _, album in chunk)
                    continue
                if batch_result:
                    full_albums.extend(batch_result)
            return full_albums
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching liked albums: %s", err)
            raise ResourceTemporarilyUnavailable("Failed to fetch liked albums") from err

    async def get_liked_artists(self) -> list[YandexArtist]:
        """Get user's liked artists.

        :return: List of liked artist objects (likes without an artist are skipped).
        :raises ResourceTemporarilyUnavailable: On bad-request or network errors.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_artists()
            if result is None:
                return []
            return [like.artist for like in result if like.artist is not None]
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching liked artists: %s", err)
            raise ResourceTemporarilyUnavailable("Failed to fetch liked artists") from err

    async def get_user_playlists(self) -> list[YandexPlaylist]:
        """Get user's playlists.

        :return: List of playlist objects.
        :raises ResourceTemporarilyUnavailable: On bad-request or network errors.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_playlists_list()
            if result is None:
                return []
            return list(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching playlists: %s", err)
            raise ResourceTemporarilyUnavailable("Failed to fetch playlists") from err

    # Search

    async def search(
        self,
        query: str,
        search_type: str = "all",
        limit: int = DEFAULT_LIMIT,
    ) -> Search | None:
        """Search for tracks, albums, artists, or playlists (first result page).

        :param query: Search query string.
        :param search_type: Type of search ('all', 'track', 'album', 'artist', 'playlist').
        :param limit: Maximum number of results per type. NOTE(review): this value is
            currently not forwarded to the API, so callers must trim results themselves.
        :return: Search results object.
        :raises ResourceTemporarilyUnavailable: On bad-request or network errors.
        """
        client = self._ensure_connected()
        try:
            return await client.search(query, type_=search_type, page=0, nocorrect=False)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Search error: %s", err)
            raise ResourceTemporarilyUnavailable("Search failed") from err

    # Get single items

    async def get_track(self, track_id: str) -> YandexTrack | None:
        """Get a single track by ID.

        :param track_id: Track ID.
        :return: Track object or None if not found or on error.
        """
        client = self._ensure_connected()
        try:
            tracks = await client.tracks([track_id])
            return tracks[0] if tracks else None
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching track %s: %s", track_id, err)
            return None

    async def get_tracks(self, track_ids: list[str]) -> list[YandexTrack]:
        """Get multiple tracks by IDs.

        Network errors are retried once on the same client (without reconnecting).

        :param track_ids: List of track IDs; an empty list returns [] without a request.
        :return: List of track objects ([] on a bad request).
        :raises ResourceTemporarilyUnavailable: On network errors after retry.
        """
        client = self._ensure_connected()
        if not track_ids:
            return []
        try:
            result = await client.tracks(track_ids)
            return result or []
        except BadRequestError as err:
            # Handled before NetworkError so this branch stays reachable even if
            # BadRequestError derives from NetworkError in yandex-music.
            LOGGER.error("Error fetching tracks: %s", err)
            return []
        except NetworkError as err:
            # Retry once on network errors (timeout, disconnect, etc.)
            LOGGER.warning("Network error fetching tracks, retrying once: %s", err)
            try:
                result = await client.tracks(track_ids)
            except BadRequestError as retry_err:
                LOGGER.error("Error fetching tracks: %s", retry_err)
                return []
            except NetworkError as retry_err:
                LOGGER.error("Error fetching tracks (retry failed): %s", retry_err)
                raise ResourceTemporarilyUnavailable("Failed to fetch tracks") from retry_err
            return result or []

    async def get_album(self, album_id: str) -> YandexAlbum | None:
        """Get a single album by ID.

        :param album_id: Album ID.
        :return: Album object or None if not found or on error.
        """
        client = self._ensure_connected()
        try:
            albums = await client.albums([album_id])
            return albums[0] if albums else None
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching album %s: %s", album_id, err)
            return None

    async def get_album_with_tracks(self, album_id: str) -> YandexAlbum | None:
        """Get an album with its tracks.

        Uses the same semantics as the web client: albums/{id}/with-tracks with
        resumeStream, richTracks and withListeningFinished, when the library passes
        them through. If the call rejects these kwargs (TypeError), it is repeated
        without them. Errors from either call are handled the same way.

        :param album_id: Album ID.
        :return: Album object with tracks or None if not found or on error.
        """
        client = self._ensure_connected()
        try:
            try:
                return await client.albums_with_tracks(
                    album_id,
                    resumeStream=True,
                    richTracks=True,
                    withListeningFinished=True,
                )
            except TypeError:
                # Older yandex-music may not accept these kwargs
                return await client.albums_with_tracks(album_id)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching album with tracks %s: %s", album_id, err)
            return None

    async def get_artist(self, artist_id: str) -> YandexArtist | None:
        """Get a single artist by ID.

        :param artist_id: Artist ID.
        :return: Artist object or None if not found or on error.
        """
        client = self._ensure_connected()
        try:
            artists = await client.artists([artist_id])
            return artists[0] if artists else None
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching artist %s: %s", artist_id, err)
            return None

    async def get_artist_albums(
        self, artist_id: str, limit: int = DEFAULT_LIMIT
    ) -> list[YandexAlbum]:
        """Get artist's albums (first page of size ``limit``).

        :param artist_id: Artist ID.
        :param limit: Maximum number of albums.
        :return: List of album objects ([] on error).
        """
        client = self._ensure_connected()
        try:
            result = await client.artists_direct_albums(artist_id, page=0, page_size=limit)
            if result is None:
                return []
            return result.albums or []
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching artist albums %s: %s", artist_id, err)
            return []

    async def get_artist_tracks(
        self, artist_id: str, limit: int = DEFAULT_LIMIT
    ) -> list[YandexTrack]:
        """Get artist's top tracks (first page of size ``limit``).

        :param artist_id: Artist ID.
        :param limit: Maximum number of tracks.
        :return: List of track objects ([] on error).
        """
        client = self._ensure_connected()
        try:
            result = await client.artists_tracks(artist_id, page=0, page_size=limit)
            if result is None:
                return []
            return result.tracks or []
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching artist tracks %s: %s", artist_id, err)
            return []

    async def get_playlist(self, user_id: str, playlist_id: str) -> YandexPlaylist | None:
        """Get a playlist by ID.

        :param user_id: User ID (owner of the playlist).
        :param playlist_id: Playlist ID (kind); must be an integer string.
        :return: Playlist object, or None if not found, on a bad request, or when
            playlist_id is not numeric.
        :raises ResourceTemporarilyUnavailable: On network errors.
        """
        client = self._ensure_connected()
        try:
            kind = int(playlist_id)
        except ValueError:
            LOGGER.error("Invalid playlist id %s/%s", user_id, playlist_id)
            return None
        try:
            result = await client.users_playlists(kind=kind, user_id=user_id)
        except BadRequestError as err:
            # Handled before NetworkError so this branch stays reachable even if
            # BadRequestError derives from NetworkError in yandex-music.
            LOGGER.error("Error fetching playlist %s/%s: %s", user_id, playlist_id, err)
            return None
        except NetworkError as err:
            LOGGER.warning("Network error fetching playlist %s/%s: %s", user_id, playlist_id, err)
            raise ResourceTemporarilyUnavailable("Failed to fetch playlist") from err
        # The API call may return either a single playlist or a list of them.
        if isinstance(result, list):
            return result[0] if result else None
        return result

    # Streaming

    async def get_track_download_info(
        self, track_id: str, get_direct_links: bool = True
    ) -> list[DownloadInfo]:
        """Get download info for a track.

        :param track_id: Track ID.
        :param get_direct_links: Whether to get direct download links.
        :return: List of download info objects ([] on error).
        """
        client = self._ensure_connected()
        try:
            result = await client.tracks_download_info(track_id, get_direct_links=get_direct_links)
            return result or []
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error fetching download info for track %s: %s", track_id, err)
            return []

    async def get_track_file_info_lossless(self, track_id: str) -> dict[str, Any] | None:
        """Request a lossless stream via get-file-info (quality=lossless).

        The /tracks/{id}/download-info endpoint often returns only MP3. get-file-info
        with quality=lossless and codecs=GET_FILE_INFO_CODECS returns FLAC when
        available. The request is signed with get_sign_request(). transports=encraw
        is tried first; after an UnauthorizedError, transports=raw is tried once.

        A transient connection error triggers one reconnect and retry, so that a
        momentary disconnect does not silently fall back to lossy quality.

        :param track_id: Track ID.
        :return: Parsed download_info dict (url, codec, urls, ...), or None on error or
            when the response has no URL.
        """

        def _parse_file_info_result(raw: Any) -> dict[str, Any] | None:
            # Return the "download_info" mapping only if it is a dict with a URL.
            if not isinstance(raw, dict):
                return None
            download_info = raw.get("download_info")
            if not isinstance(download_info, dict) or not download_info.get("url"):
                return None
            return cast("dict[str, Any]", download_info)

        for attempt in range(2):
            client = self._ensure_connected()
            # Re-sign on each attempt; the signature carries its own timestamp ("ts").
            sign = get_sign_request(track_id)
            base_params = {
                "ts": sign.timestamp,
                "trackId": track_id,
                "quality": "lossless",
                "codecs": GET_FILE_INFO_CODECS,
                "sign": sign.value,
            }

            url = f"{client.base_url}/get-file-info"
            params_encraw = {**base_params, "transports": "encraw"}
            try:
                result = await client.request.get(url, params=params_encraw)
                return _parse_file_info_result(result)
            except UnauthorizedError as err:
                LOGGER.debug(
                    "get-file-info lossless for track %s (transports=encraw): %s %s",
                    track_id,
                    type(err).__name__,
                    getattr(err, "message", str(err)) or repr(err),
                )
                LOGGER.debug(
                    "If you have KION Music Plus and this track has lossless, "
                    "try a token from the web client (music.mts.ru)."
                )
                params_raw = {**base_params, "transports": "raw"}
                try:
                    result = await client.request.get(url, params=params_raw)
                    return _parse_file_info_result(result)
                except (BadRequestError, NetworkError, UnauthorizedError) as retry_err:
                    LOGGER.debug(
                        "get-file-info lossless for track %s (transports=raw): %s %s",
                        track_id,
                        type(retry_err).__name__,
                        getattr(retry_err, "message", str(retry_err)) or repr(retry_err),
                    )
                    return None
            except BadRequestError as err:
                LOGGER.debug(
                    "get-file-info lossless for track %s: %s %s",
                    track_id,
                    type(err).__name__,
                    getattr(err, "message", str(err)) or repr(err),
                )
                return None
            except Exception as err:
                if attempt == 0 and self._is_connection_error(err):
                    LOGGER.warning(
                        "Connection error on get-file-info lossless for track %s, reconnecting: %s",
                        track_id,
                        err,
                    )
                    try:
                        await self._reconnect()
                    except Exception as recon_err:
                        LOGGER.debug("Reconnect failed: %s", recon_err)
                        return None
                else:
                    LOGGER.debug(
                        "get-file-info lossless for track %s: %s %s",
                        track_id,
                        type(err).__name__,
                        getattr(err, "message", str(err)) or repr(err),
                    )
                    return None
        return None

    # Library modifications
    #
    # The library's like/unlike calls return a success flag, so a falsy result
    # (e.g. False) means failure. Checking only `is not None` would report False
    # as success.

    async def like_track(self, track_id: str) -> bool:
        """Add a track to liked tracks.

        :param track_id: Track ID to like.
        :return: True if successful, False on failure or error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_tracks_add(track_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error liking track %s: %s", track_id, err)
            return False

    async def unlike_track(self, track_id: str) -> bool:
        """Remove a track from liked tracks.

        :param track_id: Track ID to unlike.
        :return: True if successful, False on failure or error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_tracks_remove(track_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error unliking track %s: %s", track_id, err)
            return False

    async def like_album(self, album_id: str) -> bool:
        """Add an album to liked albums.

        :param album_id: Album ID to like.
        :return: True if successful, False on failure or error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_albums_add(album_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error liking album %s: %s", album_id, err)
            return False

    async def unlike_album(self, album_id: str) -> bool:
        """Remove an album from liked albums.

        :param album_id: Album ID to unlike.
        :return: True if successful, False on failure or error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_albums_remove(album_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error unliking album %s: %s", album_id, err)
            return False

    async def like_artist(self, artist_id: str) -> bool:
        """Add an artist to liked artists.

        :param artist_id: Artist ID to like.
        :return: True if successful, False on failure or error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_artists_add(artist_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error liking artist %s: %s", artist_id, err)
            return False

    async def unlike_artist(self, artist_id: str) -> bool:
        """Remove an artist from liked artists.

        :param artist_id: Artist ID to unlike.
        :return: True if successful, False on failure or error.
        """
        client = self._ensure_connected()
        try:
            result = await client.users_likes_artists_remove(artist_id)
            return bool(result)
        except (BadRequestError, NetworkError) as err:
            LOGGER.error("Error unliking artist %s: %s", artist_id, err)
            return False
678