/
/
/
1"""Manage MediaItems of type Playlist."""
2
3from __future__ import annotations
4
5from collections.abc import AsyncGenerator
6from typing import TYPE_CHECKING, cast
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.errors import (
10 InvalidDataError,
11 InvalidProviderURI,
12 MediaNotFoundError,
13 ProviderUnavailableError,
14)
15from music_assistant_models.media_items import (
16 Playlist,
17 Track,
18)
19
20from music_assistant.constants import (
21 DB_TABLE_PLAYLISTS,
22 PLAYLIST_MEDIA_TYPES,
23 PlaylistPlayableItem,
24)
25from music_assistant.helpers.compare import create_safe_string
26from music_assistant.helpers.database import UNSET
27from music_assistant.helpers.json import serialize_to_json
28from music_assistant.helpers.security import is_safe_name
29from music_assistant.helpers.uri import create_uri, parse_uri
30from music_assistant.helpers.util import guard_single_request
31from music_assistant.models.music_provider import MusicProvider
32
33from .base import MediaControllerBase
34
35if TYPE_CHECKING:
36 from music_assistant import MusicAssistant
37
38
class PlaylistController(MediaControllerBase[Playlist]):
    """Controller managing MediaItems of type Playlist.

    Playlist tracks themselves are not stored in the database; they are always
    fetched (page by page) from the provider that owns the playlist.
    """

    db_table = DB_TABLE_PLAYLISTS
    media_type = MediaType.PLAYLIST
    item_cls = Playlist

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the playlist specific api commands."""
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/create_playlist", self.create_playlist)
        self.mass.register_api_command("music/playlists/playlist_tracks", self.tracks)
        self.mass.register_api_command(
            "music/playlists/add_playlist_tracks", self.add_playlist_tracks
        )
        self.mass.register_api_command(
            "music/playlists/remove_playlist_tracks", self.remove_playlist_tracks
        )

    def _verify_update_allowed(self, current_item: Playlist, update: Playlist) -> None:
        """Verify that the update is allowed from a security perspective.

        Prevents updating item_id for non-streaming providers to prevent path traversal attacks.

        :param current_item: The playlist as it currently exists.
        :param update: The playlist holding the requested changes.
        :raises InvalidDataError: If the item_id of an existing mapping on a provider
            instance starting with "filesystem" or "builtin" would be changed.
        """
        # Build lookup dict of current mappings: provider_instance -> item_id
        current_mappings = {
            mapping.provider_instance: mapping.item_id for mapping in current_item.provider_mappings
        }

        # Check if any existing mapping's item_id has been modified for non-streaming providers
        for update_mapping in update.provider_mappings:
            # Only check if this is an existing mapping being modified
            if update_mapping.provider_instance in current_mappings:
                current_item_id = current_mappings[update_mapping.provider_instance]

                # Disallow item_id changes for filesystem-based providers (filesystem, builtin)
                if (
                    current_item_id != update_mapping.item_id
                    and update_mapping.provider_instance.startswith(("filesystem", "builtin"))
                ):
                    msg = (
                        f"Updating item_id is not allowed for filesystem-based providers: "
                        f"attempted to change '{current_item_id}' to '{update_mapping.item_id}'"
                    )
                    raise InvalidDataError(msg)

    async def tracks(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        force_refresh: bool = False,
    ) -> AsyncGenerator[PlaylistPlayableItem, None]:
        """Yield the playlist tracks for the given provider playlist id.

        :param item_id: Id of the playlist (database id when the provider is "library").
        :param provider_instance_id_or_domain: Provider instance id or domain, or "library".
        :param force_refresh: Passed on to the cache handling of the page fetch.
        """
        if provider_instance_id_or_domain == "library":
            # resolve the library item to the provider playlist that backs it
            library_item = await self.get_library_item(item_id)
            provider_instance_id_or_domain, item_id = self._select_provider_id(library_item)
        # playlist tracks are not stored in the db,
        # we always fetch them (cached) from the provider
        page = 0
        while True:
            tracks = await self._get_provider_playlist_tracks(
                item_id,
                provider_instance_id_or_domain,
                page=page,
                force_refresh=force_refresh,
            )
            if not tracks:
                # an empty page marks the end of the playlist
                break
            for track in tracks:
                yield track
            page += 1

    async def create_playlist(
        self, name: str, provider_instance_or_domain: str | None = None
    ) -> Playlist:
        """Create new playlist.

        :param name: Name of the new playlist (must pass `is_safe_name`).
        :param provider_instance_or_domain: Provider to create the playlist on;
            the builtin provider is used when omitted.
        :raises ProviderUnavailableError: If the (requested or builtin) provider is not found.
        :raises InvalidDataError: If the name is not a valid playlist name.
        """
        # if provider is omitted, just pick builtin provider
        if provider_instance_or_domain:
            provider = self.mass.get_provider(provider_instance_or_domain)
        else:
            provider = self.mass.get_provider("builtin")
        # guard both lookups: the builtin fallback may be missing as well
        if provider is None:
            msg = f"Provider {provider_instance_or_domain or 'builtin'} is not available"
            raise ProviderUnavailableError(msg)
        provider = cast("MusicProvider", provider)

        if not is_safe_name(name):
            msg = f"{name} is not a valid Playlist name"
            raise InvalidDataError(msg)
        # create playlist on the provider
        playlist = await provider.create_playlist(name)
        for prov_mapping in playlist.provider_mappings:
            # when manually creating a playlist, it's always in the library
            prov_mapping.in_library = True
        # add the new playlist to the library
        return await self.add_item_to_library(playlist, False)

    async def add_playlist_tracks(self, db_playlist_id: str | int, uris: list[str]) -> None:
        """Add tracks to playlist.

        Album and playlist uris are unwrapped into their individual items. Items that
        are already part of the playlist, that are not supported by the playlist's
        provider or that can not be resolved are skipped (and logged).

        :param db_playlist_id: Database id of the (library) playlist.
        :param uris: Uris of the items (tracks, albums, playlists, ...) to add.
        :raises MediaNotFoundError: If the playlist does not exist in the library.
        :raises InvalidDataError: If the playlist is not editable.
        :raises InvalidProviderURI: If a uri contains a newline character.
        :raises ProviderUnavailableError: If the playlist's provider is not available.
        """
        # ruff: noqa: PLR0915
        db_id = int(db_playlist_id)  # ensure integer
        playlist = await self.get_library_item(db_id)
        if not playlist:
            msg = f"Playlist with id {db_id} not found"
            raise MediaNotFoundError(msg)
        if not playlist.is_editable:
            msg = f"Playlist {playlist.name} is not editable"
            raise InvalidDataError(msg)
        # Validate uris to prevent code injection
        for uri in uris:
            # Prevent code injection via newlines in URIs
            if "\n" in uri or "\r" in uri:
                msg = "Invalid URI: newlines not allowed"
                raise InvalidProviderURI(msg)
            await parse_uri(uri)
        # use _select_provider_id to respect user's provider filter
        playlist_prov_instance, playlist_prov_item_id = self._select_provider_id(playlist)
        playlist_prov = self.mass.get_provider(playlist_prov_instance)
        if not playlist_prov or not playlist_prov.available:
            raise ProviderUnavailableError(f"Provider {playlist_prov_instance} is not available")
        playlist_prov = cast("MusicProvider", playlist_prov)

        # grab all existing track ids/uris in the playlist so we can check for duplicates
        cur_playlist_track_ids: set[str] = set()
        cur_playlist_track_uris: set[str] = set()

        # collect current track IDs and URIs
        async for item in self.tracks(playlist.item_id, playlist.provider):
            if item.item_id:
                cur_playlist_track_ids.add(item.item_id)
            if item.uri:
                cur_playlist_track_uris.add(item.uri)

        # unwrap URIs to individual track URIs
        unwrapped_uris: list[str] = []
        for uri in uris:
            # URI could be a playlist or album uri, unwrap it
            if not ("://" in uri and len(uri.split("/")) >= 4):
                # NOT a music assistant-style uri (provider://media_type/item_id)
                self.logger.warning(
                    "Not adding %s to playlist %s - not a valid uri", uri, playlist.name
                )
                continue
            # music assistant-style uri
            # provider://media_type/item_id
            provider_instance_id_or_domain, rest = uri.split("://", 1)
            media_type_str, item_id = rest.split("/", 1)
            media_type = MediaType(media_type_str)
            if media_type == MediaType.ALBUM:
                album_tracks = await self.mass.music.albums.tracks(
                    item_id, provider_instance_id_or_domain
                )
                for track in album_tracks:
                    if track.uri is not None:
                        unwrapped_uris.append(track.uri)
            elif media_type == MediaType.PLAYLIST:
                async for item in self.tracks(item_id, provider_instance_id_or_domain):
                    if item.uri is not None:
                        unwrapped_uris.append(item.uri)
            elif media_type in PLAYLIST_MEDIA_TYPES:
                unwrapped_uris.append(uri)
            else:
                self.logger.warning(
                    "Not adding %s to playlist %s - media type not supported in playlists",
                    uri,
                    playlist.name,
                )
                continue

        # work out the track id's that need to be added
        # filter out duplicates and items that do not exist on the provider.
        ids_to_add: list[str] = []
        for uri in unwrapped_uris:
            # skip if item already in the playlist
            if uri in cur_playlist_track_uris:
                self.logger.info(
                    "Not adding %s to playlist %s - it already exists",
                    uri,
                    playlist.name,
                )
                continue

            # parse uri for further processing
            media_type, provider_instance_id_or_domain, item_id = await parse_uri(uri)

            # non-track items can only be added to builtin playlists
            if media_type != MediaType.TRACK and playlist_prov.domain != "builtin":
                self.logger.warning(
                    "Not adding %s to playlist %s - only supported in builtin playlists",
                    uri,
                    playlist.name,
                )
                continue

            # skip if item already in the playlist
            if item_id in cur_playlist_track_ids:
                self.logger.warning(
                    "Not adding %s to playlist %s - it already exists",
                    uri,
                    playlist.name,
                )
                continue

            # special: the builtin provider can handle uri's from all providers (with uri as id)
            if playlist_prov.domain == "builtin":
                # For non-library URIs, add directly (they're already portable provider URIs)
                if provider_instance_id_or_domain != "library":
                    if uri not in ids_to_add:
                        ids_to_add.append(uri)
                        self.logger.info(
                            "Adding %s to playlist %s",
                            uri,
                            playlist.name,
                        )
                    continue
                # For library URIs, convert to provider URIs to survive DB rebuilds
                # Get the full item from library to access all provider mappings
                full_item = await self.mass.music.get_item_by_uri(uri)
                if not hasattr(full_item, "provider_mappings"):
                    self.logger.warning(
                        "Can't add %s to playlist %s - unsupported media type",
                        uri,
                        playlist.name,
                    )
                    continue

                # For tracks, try to match to playlist provider
                # For non-track items, just use first available mapping
                provider_mappings = full_item.provider_mappings
                if media_type == MediaType.TRACK:
                    # Cast to Track for mypy - we know it's a track from media_type check
                    full_track = cast("Track", full_item)
                    # Try to match the track to additional providers
                    track_prov_domains = {x.provider_domain for x in provider_mappings}
                    if (
                        playlist_prov.is_streaming_provider
                        and playlist_prov.domain not in track_prov_domains
                    ):
                        provider_mappings.update(
                            await self.mass.music.tracks.match_provider(
                                full_track, playlist_prov, strict=False
                            )
                        )

                # Sort by quality (highest first) for deterministic selection
                provider_mappings = sorted(provider_mappings, key=lambda x: x.quality, reverse=True)

                # Add first available provider mapping
                for prov_mapping in provider_mappings:
                    if not prov_mapping.available:
                        continue
                    item_prov = self.mass.get_provider(prov_mapping.provider_instance)
                    if not item_prov:
                        continue
                    # Create provider URI from the mapping
                    provider_uri = create_uri(
                        media_type,
                        item_prov.instance_id,
                        prov_mapping.item_id,
                    )
                    if (
                        provider_uri not in ids_to_add
                        and provider_uri not in cur_playlist_track_uris
                    ):
                        ids_to_add.append(provider_uri)
                        self.logger.info(
                            "Adding %s to playlist %s",
                            provider_uri,
                            playlist.name,
                        )
                    # first usable mapping wins, whether it was new or a duplicate
                    break
                else:
                    # loop finished without finding any usable mapping
                    self.logger.warning(
                        "Can't add %s to playlist %s - no available provider mapping",
                        uri,
                        playlist.name,
                    )
                # builtin playlists are fully handled above
                continue

            # if target playlist is an exact provider match, we can add it
            if provider_instance_id_or_domain != "library":
                item_prov = self.mass.get_provider(provider_instance_id_or_domain)
                if not item_prov or not item_prov.available:
                    self.logger.warning(
                        "Skip adding %s to playlist: Provider %s is not available",
                        uri,
                        provider_instance_id_or_domain,
                    )
                    continue
                if item_prov.instance_id == playlist_prov.instance_id:
                    if item_id not in ids_to_add:
                        ids_to_add.append(item_id)
                    continue

            # For provider-specific playlists: match tracks with quality sorting
            # (Non-track items can only be added to builtin playlists, validated earlier)
            full_track = await self.mass.music.tracks.get(
                item_id,
                provider_instance_id_or_domain,
                recursive=provider_instance_id_or_domain != "library",
            )
            track_prov_domains = {x.provider_domain for x in full_track.provider_mappings}
            if (
                playlist_prov.domain != "builtin"
                and playlist_prov.is_streaming_provider
                and playlist_prov.domain not in track_prov_domains
            ):
                # try to match the track to the playlist provider
                full_track.provider_mappings.update(
                    await self.mass.music.tracks.match_provider(
                        full_track, playlist_prov, strict=False
                    )
                )

            # a track can contain multiple versions on the same provider
            # simply sort by quality and just add the first available version
            for track_version in sorted(
                full_track.provider_mappings, key=lambda x: x.quality, reverse=True
            ):
                if not track_version.available:
                    continue
                if track_version.item_id in cur_playlist_track_ids:
                    break  # already existing in the playlist
                item_prov = self.mass.get_provider(track_version.provider_instance)
                if not item_prov:
                    continue
                track_version_uri = create_uri(
                    MediaType.TRACK,
                    item_prov.instance_id,
                    track_version.item_id,
                )
                if track_version_uri in cur_playlist_track_uris:
                    self.logger.warning(
                        "Not adding %s to playlist %s - it already exists",
                        full_track.name,
                        playlist.name,
                    )
                    break  # already existing in the playlist
                # Add track to provider-specific playlist
                if item_prov.instance_id == playlist_prov.instance_id:
                    if track_version.item_id not in ids_to_add:
                        ids_to_add.append(track_version.item_id)
                    self.logger.info(
                        "Adding %s to playlist %s",
                        full_track.name,
                        playlist.name,
                    )
                    break
            else:
                # no version of the track lives on the playlist's provider
                self.logger.warning(
                    "Can't add %s to playlist %s - it is not available on provider %s",
                    full_track.name,
                    playlist.name,
                    playlist_prov.name,
                )

        if not ids_to_add:
            return

        # actually add the tracks to the playlist on the provider
        await playlist_prov.add_playlist_tracks(playlist_prov_item_id, ids_to_add)
        # invalidate cache so tracks get refreshed
        self._refresh_playlist_tracks(playlist)
        await self.update_item_in_library(db_playlist_id, playlist)

    async def add_playlist_track(self, db_playlist_id: str | int, track_uri: str) -> None:
        """Add (single) track to playlist.

        Convenience wrapper around `add_playlist_tracks` for a single uri.
        """
        await self.add_playlist_tracks(db_playlist_id, [track_uri])

    async def remove_playlist_tracks(
        self, db_playlist_id: str | int, positions_to_remove: tuple[int, ...]
    ) -> None:
        """Remove multiple tracks from playlist.

        :param db_playlist_id: Database id of the (library) playlist.
        :param positions_to_remove: Positions of the tracks to remove, passed to the provider.
        :raises MediaNotFoundError: If the playlist does not exist in the library.
        :raises InvalidDataError: If the playlist is not editable or its provider
            does not support editing playlist tracks.
        :raises ProviderUnavailableError: If the playlist's provider is not available.
        """
        db_id = int(db_playlist_id)  # ensure integer
        playlist = await self.get_library_item(db_id)
        if not playlist:
            msg = f"Playlist with id {db_id} not found"
            raise MediaNotFoundError(msg)
        if not playlist.is_editable:
            msg = f"Playlist {playlist.name} is not editable"
            raise InvalidDataError(msg)
        # use _select_provider_id to respect user's provider filter
        playlist_prov_instance, playlist_prov_item_id = self._select_provider_id(playlist)
        provider = self.mass.get_provider(playlist_prov_instance)
        if not provider or not isinstance(provider, MusicProvider):
            raise ProviderUnavailableError(f"Provider {playlist_prov_instance} is not available")
        if ProviderFeature.PLAYLIST_TRACKS_EDIT not in provider.supported_features:
            msg = f"Provider {provider.name} does not support editing playlists"
            raise InvalidDataError(msg)
        await provider.remove_playlist_tracks(playlist_prov_item_id, positions_to_remove)
        # invalidate cache so tracks get refreshed (same as when adding tracks)
        self._refresh_playlist_tracks(playlist)
        await self.update_item_in_library(db_playlist_id, playlist)

    async def _add_library_item(self, item: Playlist, overwrite_existing: bool = False) -> int:
        """Add a new record to the database.

        :param item: The playlist to store.
        :param overwrite_existing: Not used by this implementation.
        :return: The database id of the inserted row.
        """
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "owner": item.owner,
                "is_editable": item.is_editable,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Playlist, overwrite: bool = False
    ) -> None:
        """Update existing record in the database.

        :param item_id: Database id of the playlist to update.
        :param update: Playlist holding the new values.
        :param overwrite: When True the values of `update` replace the stored ones,
            otherwise they are merged with the current item.
        :raises InvalidDataError: If the update is rejected by `_verify_update_allowed`.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        self._verify_update_allowed(cur_item, update)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                # name/sort_name only come from the update when overwriting,
                # owner always prefers the value of the updated item
                "name": name,
                "sort_name": sort_name,
                "owner": update.owner or cur_item.owner,
                "is_editable": update.is_editable,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    @guard_single_request  # type: ignore[type-var]  # TODO: fix typing in util.py
    async def _get_provider_playlist_tracks(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        page: int = 0,
        force_refresh: bool = False,
    ) -> list[PlaylistPlayableItem]:
        """Return one page of playlist tracks for the given provider playlist id.

        :param item_id: Id of the playlist on the provider.
        :param provider_instance_id_or_domain: Provider instance id or domain (not "library").
        :param page: Page number to fetch, starting at 0.
        :param force_refresh: Passed on to the cache's refresh handling.
        :return: The items of the requested page; empty when the provider is not found.
        """
        assert provider_instance_id_or_domain != "library"
        if not (provider := self.mass.get_provider(provider_instance_id_or_domain)):
            return []
        provider = cast("MusicProvider", provider)
        async with self.mass.cache.handle_refresh(force_refresh):
            # Builtin provider overrides to return list[PlaylistPlayableItem],
            # others return list[Track]. Since Track is part of PlaylistPlayableItem union,
            # this is safe at runtime. Type ignore needed because list is invariant.
            return await provider.get_playlist_tracks(item_id, page=page)  # type: ignore[return-value]

    async def radio_mode_base_tracks(
        self,
        item: Playlist,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        :param item: The Playlist to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use.
            Not used by this implementation.
        """
        return [
            x
            async for x in self.tracks(item.item_id, item.provider)
            # Radio mode only works with Tracks (filter out all other types)
            if isinstance(x, Track) and x.available
        ]

    async def match_providers(self, db_item: Playlist) -> None:
        """Try to find match on all (streaming) providers for the provided (database) item.

        This is used to link objects of different providers/qualities together.
        """
        # playlists can only be matched on the same provider (if not unique)
        if self.mass.music.match_provider_instances(db_item):
            await self.add_provider_mappings(db_item.item_id, db_item.provider_mappings)

    def _refresh_playlist_tracks(self, playlist: Playlist) -> None:
        """Refresh playlist tracks by forcing a cache refresh.

        The refresh is scheduled through `call_later` with a per-playlist task id.
        """

        async def _refresh(playlist: Playlist) -> None:
            # simply iterate all tracks with force_refresh=True to refresh the cache
            async for _ in self.tracks(playlist.item_id, playlist.provider, force_refresh=True):
                pass

        task_id = f"refresh_playlist_tracks_{playlist.item_id}"
        self.mass.call_later(5, _refresh, playlist, task_id=task_id)  # debounce multiple calls