/
/
/
1"""Yandex Music provider implementation."""
2
3from __future__ import annotations
4
5from typing import TYPE_CHECKING
6
7from music_assistant_models.enums import MediaType
8from music_assistant_models.errors import (
9 InvalidDataError,
10 LoginFailed,
11 MediaNotFoundError,
12 ProviderUnavailableError,
13)
14from music_assistant_models.media_items import (
15 Album,
16 Artist,
17 ItemMapping,
18 MediaItemType,
19 Playlist,
20 SearchResults,
21 Track,
22)
23
24from music_assistant.controllers.cache import use_cache
25from music_assistant.models.music_provider import MusicProvider
26
27from .api_client import YandexMusicClient
28from .constants import CONF_TOKEN, PLAYLIST_ID_SPLITTER
29from .parsers import parse_album, parse_artist, parse_playlist, parse_track
30from .streaming import YandexMusicStreamingManager
31
32if TYPE_CHECKING:
33 from collections.abc import AsyncGenerator
34
35 from music_assistant_models.streamdetails import StreamDetails
36
37
class YandexMusicProvider(MusicProvider):
    """Implementation of a Yandex Music MusicProvider.

    The provider creates and connects a ``YandexMusicClient`` and a
    ``YandexMusicStreamingManager`` in ``handle_async_init`` and clears both
    again in ``unload``.
    """

    _client: YandexMusicClient | None = None
    _streaming: YandexMusicStreamingManager | None = None

    # Number of track IDs passed to one ``client.get_tracks`` call when
    # resolving short track references (liked tracks, playlist entries).
    _TRACK_BATCH_SIZE = 50

    @property
    def client(self) -> YandexMusicClient:
        """Return the Yandex Music client.

        :raises ProviderUnavailableError: If the provider is not initialized.
        """
        if self._client is None:
            raise ProviderUnavailableError("Provider not initialized")
        return self._client

    @property
    def streaming(self) -> YandexMusicStreamingManager:
        """Return the streaming manager.

        :raises ProviderUnavailableError: If the provider is not initialized.
        """
        if self._streaming is None:
            raise ProviderUnavailableError("Provider not initialized")
        return self._streaming

    async def handle_async_init(self) -> None:
        """Handle async initialization of the provider.

        :raises LoginFailed: If no Yandex Music token is configured.
        """
        token = self.config.get_value(CONF_TOKEN)
        if not token:
            raise LoginFailed("No Yandex Music token provided")

        # Only store the client once connect() has succeeded, so that a
        # failing connect does not leave an unconnected client on the provider.
        client = YandexMusicClient(str(token))
        await client.connect()
        self._client = client
        self._streaming = YandexMusicStreamingManager(self)
        self.logger.info("Successfully connected to Yandex Music")

    async def unload(self, is_removed: bool = False) -> None:
        """Handle unload/close of the provider.

        :param is_removed: Whether the provider is being removed.
        """
        client = self._client
        self._client = None
        self._streaming = None
        try:
            if client:
                await client.disconnect()
        finally:
            # Run the base-class cleanup even if disconnect() raises.
            await super().unload(is_removed)

    def get_item_mapping(self, media_type: MediaType | str, key: str, name: str) -> ItemMapping:
        """Create a generic item mapping.

        :param media_type: The media type (a MediaType or its string value).
        :param key: The item ID.
        :param name: The item name.
        :return: An ItemMapping instance.
        """
        if isinstance(media_type, str):
            media_type = MediaType(media_type)
        return ItemMapping(
            media_type=media_type,
            item_id=key,
            provider=self.instance_id,
            name=name,
        )

    # Search

    @use_cache(3600 * 24 * 14)
    async def search(
        self, search_query: str, media_types: list[MediaType], limit: int = 5
    ) -> SearchResults:
        """Perform search on Yandex Music.

        :param search_query: The search query.
        :param media_types: List of media types to search for.
        :param limit: Maximum number of results per type.
        :return: SearchResults with found items.
        """
        result = SearchResults()

        # Map MediaType to the Yandex API search type.
        type_mapping = {
            MediaType.TRACK: "track",
            MediaType.ALBUM: "album",
            MediaType.ARTIST: "artist",
            MediaType.PLAYLIST: "playlist",
        }
        requested_types = [type_mapping[mt] for mt in media_types if mt in type_mapping]
        if not requested_types:
            # Nothing searchable was requested; skip the API call entirely.
            return result

        # Use the specific type if only one was requested, otherwise search all.
        search_type = requested_types[0] if len(requested_types) == 1 else "all"

        search_result = await self.client.search(search_query, search_type=search_type, limit=limit)
        if not search_result:
            return result

        if MediaType.TRACK in media_types and search_result.tracks:
            tracks: list[Track] = []
            for track in search_result.tracks.results[:limit]:
                try:
                    tracks.append(parse_track(self, track))
                except InvalidDataError as err:
                    self.logger.debug("Error parsing track: %s", err)
            result.tracks = tracks

        if MediaType.ALBUM in media_types and search_result.albums:
            albums: list[Album] = []
            for album in search_result.albums.results[:limit]:
                try:
                    albums.append(parse_album(self, album))
                except InvalidDataError as err:
                    self.logger.debug("Error parsing album: %s", err)
            result.albums = albums

        if MediaType.ARTIST in media_types and search_result.artists:
            artists: list[Artist] = []
            for artist in search_result.artists.results[:limit]:
                try:
                    artists.append(parse_artist(self, artist))
                except InvalidDataError as err:
                    self.logger.debug("Error parsing artist: %s", err)
            result.artists = artists

        if MediaType.PLAYLIST in media_types and search_result.playlists:
            playlists: list[Playlist] = []
            for playlist in search_result.playlists.results[:limit]:
                try:
                    playlists.append(parse_playlist(self, playlist))
                except InvalidDataError as err:
                    self.logger.debug("Error parsing playlist: %s", err)
            result.playlists = playlists

        return result

    # Get single items

    @use_cache(3600 * 24 * 30)
    async def get_artist(self, prov_artist_id: str) -> Artist:
        """Get artist details by ID.

        :param prov_artist_id: The provider artist ID.
        :return: Artist object.
        :raises MediaNotFoundError: If artist not found.
        """
        artist = await self.client.get_artist(prov_artist_id)
        if not artist:
            raise MediaNotFoundError(f"Artist {prov_artist_id} not found")
        return parse_artist(self, artist)

    @use_cache(3600 * 24 * 30)
    async def get_album(self, prov_album_id: str) -> Album:
        """Get album details by ID.

        :param prov_album_id: The provider album ID.
        :return: Album object.
        :raises MediaNotFoundError: If album not found.
        """
        album = await self.client.get_album(prov_album_id)
        if not album:
            raise MediaNotFoundError(f"Album {prov_album_id} not found")
        return parse_album(self, album)

    @use_cache(3600 * 24 * 30)
    async def get_track(self, prov_track_id: str) -> Track:
        """Get track details by ID.

        :param prov_track_id: The provider track ID.
        :return: Track object.
        :raises MediaNotFoundError: If track not found.
        """
        yandex_track = await self.client.get_track(prov_track_id)
        if not yandex_track:
            raise MediaNotFoundError(f"Track {prov_track_id} not found")
        return parse_track(self, yandex_track)

    def _split_playlist_id(self, prov_playlist_id: str) -> tuple[str, str]:
        """Split a provider playlist ID into ``(owner_id, kind)``.

        An ID without PLAYLIST_ID_SPLITTER is treated as a playlist ``kind``
        owned by the current client user.

        :param prov_playlist_id: The provider playlist ID (format: "owner_id:kind").
        :return: Tuple of owner ID and playlist kind.
        """
        if PLAYLIST_ID_SPLITTER in prov_playlist_id:
            owner_id, kind = prov_playlist_id.split(PLAYLIST_ID_SPLITTER, 1)
            return owner_id, kind
        return str(self.client.user_id), prov_playlist_id

    @use_cache(3600 * 24 * 30)
    async def get_playlist(self, prov_playlist_id: str) -> Playlist:
        """Get playlist details by ID.

        :param prov_playlist_id: The provider playlist ID (format: "owner_id:kind").
        :return: Playlist object.
        :raises MediaNotFoundError: If playlist not found.
        """
        owner_id, kind = self._split_playlist_id(prov_playlist_id)
        playlist = await self.client.get_playlist(owner_id, kind)
        if not playlist:
            raise MediaNotFoundError(f"Playlist {prov_playlist_id} not found")
        return parse_playlist(self, playlist)

    # Get related items

    @use_cache(3600 * 24 * 30)
    async def get_album_tracks(self, prov_album_id: str) -> list[Track]:
        """Get album tracks.

        Disc and track numbers are derived from each track's position within
        the album's volumes (1-based).

        :param prov_album_id: The provider album ID.
        :return: List of Track objects.
        """
        album = await self.client.get_album_with_tracks(prov_album_id)
        if not album or not album.volumes:
            return []

        tracks: list[Track] = []
        for disc_number, volume in enumerate(album.volumes, start=1):
            for track_number, track in enumerate(volume, start=1):
                try:
                    parsed_track = parse_track(self, track)
                except InvalidDataError as err:
                    self.logger.debug("Error parsing album track: %s", err)
                    continue
                parsed_track.disc_number = disc_number
                parsed_track.track_number = track_number
                tracks.append(parsed_track)
        return tracks

    async def _iter_full_tracks(
        self, track_ids: list[str], context: str
    ) -> AsyncGenerator[Track, None]:
        """Fetch full track details in batches and yield parsed tracks.

        Tracks that fail to parse are logged at debug level and skipped.

        :param track_ids: Yandex track IDs to resolve.
        :param context: Label used in the parse-error log message.
        """
        batch_size = self._TRACK_BATCH_SIZE
        for start in range(0, len(track_ids), batch_size):
            full_tracks = await self.client.get_tracks(track_ids[start : start + batch_size])
            for track in full_tracks:
                try:
                    parsed_track = parse_track(self, track)
                except InvalidDataError as err:
                    self.logger.debug("Error parsing %s: %s", context, err)
                    continue
                yield parsed_track

    @use_cache(3600 * 3)
    async def get_playlist_tracks(self, prov_playlist_id: str, page: int = 0) -> list[Track]:
        """Get playlist tracks.

        Paging is not supported here. All tracks are returned for page 0 and
        every later page is empty, so a caller that requests pages until it
        receives an empty list stops after the first page. The previous code
        returned the full list for every page.

        :param prov_playlist_id: The provider playlist ID (format: "owner_id:kind").
        :param page: Page number for pagination.
        :return: List of Track objects.
        """
        if page > 0:
            return []

        owner_id, kind = self._split_playlist_id(prov_playlist_id)
        playlist = await self.client.get_playlist(owner_id, kind)
        if not playlist or not playlist.tracks:
            return []

        # Yandex returns TrackShort objects; full track info is fetched separately.
        track_ids = [
            str(track.track_id) if hasattr(track, "track_id") else str(track.id)
            for track in playlist.tracks
            if track
        ]
        if not track_ids:
            return []

        return [track async for track in self._iter_full_tracks(track_ids, "playlist track")]

    @use_cache(3600 * 24 * 7)
    async def get_artist_albums(self, prov_artist_id: str) -> list[Album]:
        """Get artist's albums.

        :param prov_artist_id: The provider artist ID.
        :return: List of Album objects.
        """
        albums = await self.client.get_artist_albums(prov_artist_id)
        result: list[Album] = []
        for album in albums:
            try:
                result.append(parse_album(self, album))
            except InvalidDataError as err:
                self.logger.debug("Error parsing artist album: %s", err)
        return result

    @use_cache(3600 * 24 * 7)
    async def get_artist_toptracks(self, prov_artist_id: str) -> list[Track]:
        """Get artist's top tracks.

        :param prov_artist_id: The provider artist ID.
        :return: List of Track objects.
        """
        tracks = await self.client.get_artist_tracks(prov_artist_id)
        result: list[Track] = []
        for track in tracks:
            try:
                result.append(parse_track(self, track))
            except InvalidDataError as err:
                self.logger.debug("Error parsing artist track: %s", err)
        return result

    # Library methods

    async def get_library_artists(self) -> AsyncGenerator[Artist, None]:
        """Retrieve library (liked) artists from Yandex Music."""
        artists = await self.client.get_liked_artists()
        for artist in artists:
            try:
                yield parse_artist(self, artist)
            except InvalidDataError as err:
                self.logger.debug("Error parsing library artist: %s", err)

    async def get_library_albums(self) -> AsyncGenerator[Album, None]:
        """Retrieve library (liked) albums from Yandex Music."""
        albums = await self.client.get_liked_albums()
        for album in albums:
            try:
                yield parse_album(self, album)
            except InvalidDataError as err:
                self.logger.debug("Error parsing library album: %s", err)

    async def get_library_tracks(self) -> AsyncGenerator[Track, None]:
        """Retrieve library (liked) tracks from Yandex Music.

        Liked tracks arrive as short references; full details are fetched in
        batches and yielded as each batch is parsed.
        """
        track_shorts = await self.client.get_liked_tracks()
        if not track_shorts:
            return

        track_ids = [str(ts.track_id) for ts in track_shorts if ts.track_id]
        async for track in self._iter_full_tracks(track_ids, "library track"):
            yield track

    async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]:
        """Retrieve the user's playlists from Yandex Music."""
        playlists = await self.client.get_user_playlists()
        for playlist in playlists:
            try:
                yield parse_playlist(self, playlist)
            except InvalidDataError as err:
                self.logger.debug("Error parsing library playlist: %s", err)

    # Library edit methods

    async def library_add(self, item: MediaItemType) -> bool:
        """Add item to library.

        Only tracks, albums and artists are supported; other types return False.

        :param item: The media item to add.
        :return: True if successful.
        """
        prov_item_id = self._get_provider_item_id(item)
        if not prov_item_id:
            return False

        if item.media_type == MediaType.TRACK:
            return await self.client.like_track(prov_item_id)
        if item.media_type == MediaType.ALBUM:
            return await self.client.like_album(prov_item_id)
        if item.media_type == MediaType.ARTIST:
            return await self.client.like_artist(prov_item_id)
        return False

    async def library_remove(self, prov_item_id: str, media_type: MediaType) -> bool:
        """Remove item from library.

        Only tracks, albums and artists are supported; other types return False.

        :param prov_item_id: The provider item ID.
        :param media_type: The media type.
        :return: True if successful.
        """
        if media_type == MediaType.TRACK:
            return await self.client.unlike_track(prov_item_id)
        if media_type == MediaType.ALBUM:
            return await self.client.unlike_album(prov_item_id)
        if media_type == MediaType.ARTIST:
            return await self.client.unlike_artist(prov_item_id)
        return False

    def _get_provider_item_id(self, item: MediaItemType) -> str | None:
        """Return this provider instance's item ID for a media item.

        Prefers a provider mapping for this instance and falls back to the
        item's own ID when the item belongs to this instance.

        :param item: The media item.
        :return: The provider item ID, or None if the item is not from this instance.
        """
        for mapping in item.provider_mappings:
            if mapping.provider_instance == self.instance_id:
                return mapping.item_id
        return item.item_id if item.provider == self.instance_id else None

    # Streaming

    async def get_stream_details(
        self, item_id: str, media_type: MediaType = MediaType.TRACK
    ) -> StreamDetails:
        """Get stream details for a track.

        :param item_id: The track ID.
        :param media_type: The media type (should be TRACK; not used here).
        :return: StreamDetails for the track.
        """
        return await self.streaming.get_stream_details(item_id)
421