/
/
/
1"""Manage MediaItems of type Playlist."""
2
3from __future__ import annotations
4
5from collections.abc import AsyncGenerator
6from typing import TYPE_CHECKING, cast
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.errors import (
10 InvalidDataError,
11 InvalidProviderURI,
12 MediaNotFoundError,
13 ProviderUnavailableError,
14)
15from music_assistant_models.media_items import Playlist, Track
16
17from music_assistant.constants import DB_TABLE_PLAYLISTS
18from music_assistant.helpers.compare import create_safe_string
19from music_assistant.helpers.database import UNSET
20from music_assistant.helpers.json import serialize_to_json
21from music_assistant.helpers.security import is_safe_name
22from music_assistant.helpers.uri import create_uri, parse_uri
23from music_assistant.helpers.util import guard_single_request
24from music_assistant.models.music_provider import MusicProvider
25
26from .base import MediaControllerBase
27
28if TYPE_CHECKING:
29 from music_assistant import MusicAssistant
30
31
class PlaylistController(MediaControllerBase[Playlist]):
    """Controller managing MediaItems of type Playlist."""

    db_table = DB_TABLE_PLAYLISTS
    media_type = MediaType.PLAYLIST
    item_cls = Playlist

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the playlist-specific API commands."""
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/create_playlist", self.create_playlist)
        self.mass.register_api_command("music/playlists/playlist_tracks", self.tracks)
        self.mass.register_api_command(
            "music/playlists/add_playlist_tracks", self.add_playlist_tracks
        )
        self.mass.register_api_command(
            "music/playlists/remove_playlist_tracks", self.remove_playlist_tracks
        )

    def _verify_update_allowed(self, current_item: Playlist, update: Playlist) -> None:
        """Verify that the update is allowed from a security perspective.

        Prevents updating item_id for non-streaming providers to prevent path traversal attacks.

        :param current_item: The playlist as currently stored in the library.
        :param update: The playlist holding the proposed new values.
        :raises InvalidDataError: If an existing mapping of a provider instance whose id
            starts with "filesystem" or "builtin" would get a different item_id.
        """
        # Build lookup dict of current mappings: provider_instance -> item_id
        current_mappings = {
            mapping.provider_instance: mapping.item_id for mapping in current_item.provider_mappings
        }

        # Check if any existing mapping's item_id has been modified for non-streaming providers
        for update_mapping in update.provider_mappings:
            # Only check if this is an existing mapping being modified
            if update_mapping.provider_instance in current_mappings:
                current_item_id = current_mappings[update_mapping.provider_instance]

                # Disallow item_id changes for filesystem-based providers (filesystem, builtin)
                if (
                    current_item_id != update_mapping.item_id
                    and update_mapping.provider_instance.startswith(("filesystem", "builtin"))
                ):
                    msg = (
                        f"Updating item_id is not allowed for filesystem-based providers: "
                        f"attempted to change '{current_item_id}' to '{update_mapping.item_id}'"
                    )
                    raise InvalidDataError(msg)

    async def tracks(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        force_refresh: bool = False,
    ) -> AsyncGenerator[Track, None]:
        """Yield the tracks of a playlist, fetched page by page from its provider.

        :param item_id: Playlist id on the given provider, or the library (database) id
            when provider_instance_id_or_domain is "library".
        :param provider_instance_id_or_domain: Provider instance id or domain, or "library".
        :param force_refresh: Passed to the cache's ``handle_refresh`` to request fresh data.
        """
        if provider_instance_id_or_domain == "library":
            # resolve the library item to (one of) its provider mapping(s)
            library_item = await self.get_library_item(item_id)
            provider_instance_id_or_domain, item_id = self._select_provider_id(library_item)
        # playlist tracks are not stored in the db,
        # we always fetch them (cached) from the provider
        page = 0
        while True:
            tracks = await self._get_provider_playlist_tracks(
                item_id,
                provider_instance_id_or_domain,
                page=page,
                force_refresh=force_refresh,
            )
            # an empty page marks the end of the playlist
            if not tracks:
                break
            for track in tracks:
                yield track
            page += 1

    async def create_playlist(
        self, name: str, provider_instance_or_domain: str | None = None
    ) -> Playlist:
        """Create a new playlist on a provider and add it to the library.

        :param name: Name of the new playlist; must pass ``is_safe_name``.
        :param provider_instance_or_domain: Provider to create the playlist on.
            When omitted (or empty), the builtin provider is used.
        :raises ProviderUnavailableError: If the (default) provider cannot be found
            or is not a music provider.
        :raises InvalidDataError: If the name is not a valid playlist name.
        """
        # if provider is omitted, just pick builtin provider
        provider_key = provider_instance_or_domain or "builtin"
        provider = self.mass.get_provider(provider_key)
        # isinstance (instead of a bare cast) also guards against a missing builtin provider
        if not isinstance(provider, MusicProvider):
            msg = f"Provider {provider_key} is not available"
            raise ProviderUnavailableError(msg)

        if not is_safe_name(name):
            msg = f"{name} is not a valid Playlist name"
            raise InvalidDataError(msg)
        # create playlist on the provider
        playlist = await provider.create_playlist(name)
        for prov_mapping in playlist.provider_mappings:
            # when manually creating a playlist, it's always in the library
            prov_mapping.in_library = True
        # add the new playlist to the library
        return await self.add_item_to_library(playlist, False)

    async def _get_editable_library_playlist(self, db_playlist_id: str | int) -> Playlist:
        """Return the library playlist with the given id, ensuring it is editable.

        :raises MediaNotFoundError: If no playlist was returned for the id.
        :raises InvalidDataError: If the playlist is not editable.
        """
        db_id = int(db_playlist_id)  # ensure integer
        playlist = await self.get_library_item(db_id)
        if not playlist:
            msg = f"Playlist with id {db_id} not found"
            raise MediaNotFoundError(msg)
        if not playlist.is_editable:
            msg = f"Playlist {playlist.name} is not editable"
            raise InvalidDataError(msg)
        return playlist

    def _get_playlist_edit_provider(self, playlist: Playlist) -> tuple[MusicProvider, str]:
        """Return the (available) music provider of a playlist and its provider item id.

        :raises ProviderUnavailableError: If the provider is missing, not a music
            provider, or not available.
        """
        # use _select_provider_id to respect user's provider filter
        prov_instance, prov_item_id = self._select_provider_id(playlist)
        provider = self.mass.get_provider(prov_instance)
        if not isinstance(provider, MusicProvider) or not provider.available:
            raise ProviderUnavailableError(f"Provider {prov_instance} is not available")
        return provider, prov_item_id

    async def _unwrap_track_uris(self, uris: list[str], playlist_name: str) -> list[str]:
        """Expand the given uris into a flat list of track uris.

        Only Music Assistant-style uris (provider://media_type/item_id) are handled:
        album and playlist uris are replaced by the uris of their tracks, track uris are
        kept as-is and anything else is skipped with a warning.

        :param uris: The uris to expand.
        :param playlist_name: Name of the target playlist (used in log messages only).
        """
        unwrapped_uris: list[str] = []
        for uri in uris:
            # URI could be a playlist or album uri, unwrap it
            if not ("://" in uri and len(uri.split("/")) >= 4):
                # NOT a music assistant-style uri (provider://media_type/item_id)
                self.logger.warning(
                    "Not adding %s to playlist %s - not a valid uri", uri, playlist_name
                )
                continue
            # music assistant-style uri
            # provider://media_type/item_id
            provider_instance_id_or_domain, rest = uri.split("://", 1)
            media_type_str, item_id = rest.split("/", 1)
            try:
                media_type = MediaType(media_type_str)
            except ValueError:
                # defensive: for an arbitrary url (e.g. the host in https://host/a/b)
                # the segment after "://" need not name a MediaType
                self.logger.warning(
                    "Not adding %s to playlist %s - not a valid uri", uri, playlist_name
                )
                continue
            if media_type == MediaType.ALBUM:
                album_tracks = await self.mass.music.albums.tracks(
                    item_id, provider_instance_id_or_domain
                )
                unwrapped_uris.extend(
                    track.uri for track in album_tracks if track.uri is not None
                )
            elif media_type == MediaType.PLAYLIST:
                async for track in self.tracks(item_id, provider_instance_id_or_domain):
                    if track.uri is not None:
                        unwrapped_uris.append(track.uri)
            elif media_type == MediaType.TRACK:
                unwrapped_uris.append(uri)
            else:
                self.logger.warning(
                    "Not adding %s to playlist %s - not a track", uri, playlist_name
                )
        return unwrapped_uris

    async def add_playlist_tracks(self, db_playlist_id: str | int, uris: list[str]) -> None:
        """Add tracks to a library playlist.

        Album and playlist uris are expanded into their tracks. Tracks already in the
        playlist are skipped. Tracks from other providers are matched onto the playlist's
        provider where possible; the builtin provider accepts uris from any provider.

        :param db_playlist_id: Library (database) id of the playlist.
        :param uris: Music Assistant uris of tracks, albums or playlists to add.
        :raises MediaNotFoundError: If the playlist is not in the library.
        :raises InvalidDataError: If the playlist is not editable.
        :raises InvalidProviderURI: If any uri contains a newline (``parse_uri`` is also
            called on every uri up front and may raise for invalid ones).
        :raises ProviderUnavailableError: If the playlist's provider is not available.
        """
        # ruff: noqa: PLR0915
        playlist = await self._get_editable_library_playlist(db_playlist_id)
        # Validate uris to prevent code injection
        for uri in uris:
            # Prevent code injection via newlines in URIs
            if "\n" in uri or "\r" in uri:
                msg = "Invalid URI: newlines not allowed"
                raise InvalidProviderURI(msg)
            await parse_uri(uri)
        playlist_prov, playlist_prov_item_id = self._get_playlist_edit_provider(playlist)

        # grab all existing track ids/uris in the playlist so we can check for duplicates
        cur_playlist_track_ids: set[str] = set()
        cur_playlist_track_uris: set[str] = set()
        async for item in self.tracks(playlist.item_id, playlist.provider):
            if item.item_id:
                cur_playlist_track_ids.add(item.item_id)
            if item.uri:
                cur_playlist_track_uris.add(item.uri)

        # unwrap URIs to individual track URIs
        unwrapped_uris = await self._unwrap_track_uris(uris, playlist.name)

        # work out the track id's that need to be added
        # filter out duplicates and items that do not exist on the provider.
        ids_to_add: list[str] = []
        for uri in unwrapped_uris:
            # skip if item already in the playlist
            if uri in cur_playlist_track_uris:
                self.logger.info(
                    "Not adding %s to playlist %s - it already exists",
                    uri,
                    playlist.name,
                )
                continue

            # parse uri for further processing
            _, provider_instance_id_or_domain, item_id = await parse_uri(uri)

            # skip if item already in the playlist
            if item_id in cur_playlist_track_ids:
                self.logger.warning(
                    "Not adding %s to playlist %s - it already exists",
                    uri,
                    playlist.name,
                )
                continue

            # special: the builtin provider can handle uri's from all providers (with uri as id)
            if provider_instance_id_or_domain != "library" and playlist_prov.domain == "builtin":
                # note: we try not to add library uri's to the builtin playlists
                # so we can survive db rebuilds
                if uri not in ids_to_add:
                    ids_to_add.append(uri)
                self.logger.info(
                    "Adding %s to playlist %s",
                    uri,
                    playlist.name,
                )
                continue

            # if target playlist is an exact provider match, we can add it
            if provider_instance_id_or_domain != "library":
                item_prov = self.mass.get_provider(provider_instance_id_or_domain)
                if not item_prov or not item_prov.available:
                    self.logger.warning(
                        "Skip adding %s to playlist: Provider %s is not available",
                        uri,
                        provider_instance_id_or_domain,
                    )
                    continue
                if item_prov.instance_id == playlist_prov.instance_id:
                    if item_id not in ids_to_add:
                        ids_to_add.append(item_id)
                    continue

            # ensure we have a full (library) track (including all provider mappings)
            full_track = await self.mass.music.tracks.get(
                item_id,
                provider_instance_id_or_domain,
                recursive=provider_instance_id_or_domain != "library",
            )
            track_prov_domains = {x.provider_domain for x in full_track.provider_mappings}
            if (
                playlist_prov.domain != "builtin"
                and playlist_prov.is_streaming_provider
                and playlist_prov.domain not in track_prov_domains
            ):
                # try to match the track to the playlist provider
                full_track.provider_mappings.update(
                    await self.mass.music.tracks.match_provider(
                        full_track, playlist_prov, strict=False
                    )
                )

            # a track can contain multiple versions on the same provider
            # simply sort by quality and just add the first available version
            for track_version in sorted(
                full_track.provider_mappings, key=lambda x: x.quality, reverse=True
            ):
                if not track_version.available:
                    continue
                if track_version.item_id in cur_playlist_track_ids:
                    break  # already existing in the playlist
                item_prov = self.mass.get_provider(track_version.provider_instance)
                if not item_prov:
                    continue
                track_version_uri = create_uri(
                    MediaType.TRACK,
                    item_prov.instance_id,
                    track_version.item_id,
                )
                if track_version_uri in cur_playlist_track_uris:
                    self.logger.warning(
                        "Not adding %s to playlist %s - it already exists",
                        full_track.name,
                        playlist.name,
                    )
                    break  # already existing in the playlist
                if playlist_prov.domain == "builtin":
                    # the builtin provider can handle uri's from all providers (with uri as id)
                    if track_version_uri not in ids_to_add:
                        ids_to_add.append(track_version_uri)
                    self.logger.info(
                        "Adding %s to playlist %s",
                        full_track.name,
                        playlist.name,
                    )
                    break
                if item_prov.instance_id == playlist_prov.instance_id:
                    if track_version.item_id not in ids_to_add:
                        ids_to_add.append(track_version.item_id)
                    self.logger.info(
                        "Adding %s to playlist %s",
                        full_track.name,
                        playlist.name,
                    )
                    break
            else:
                # loop exhausted without a break: no usable version on the playlist provider
                self.logger.warning(
                    "Can't add %s to playlist %s - it is not available on provider %s",
                    full_track.name,
                    playlist.name,
                    playlist_prov.name,
                )

        if not ids_to_add:
            return

        # actually add the tracks to the playlist on the provider
        await playlist_prov.add_playlist_tracks(playlist_prov_item_id, ids_to_add)
        # invalidate cache so tracks get refreshed
        self._refresh_playlist_tracks(playlist)
        await self.update_item_in_library(db_playlist_id, playlist)

    async def add_playlist_track(self, db_playlist_id: str | int, track_uri: str) -> None:
        """Add (single) track to playlist; see ``add_playlist_tracks``."""
        await self.add_playlist_tracks(db_playlist_id, [track_uri])

    async def remove_playlist_tracks(
        self, db_playlist_id: str | int, positions_to_remove: tuple[int, ...]
    ) -> None:
        """Remove multiple tracks (by position) from a library playlist.

        :param db_playlist_id: Library (database) id of the playlist.
        :param positions_to_remove: Positions of the tracks to remove, passed as-is to
            the provider's ``remove_playlist_tracks``.
        :raises MediaNotFoundError: If the playlist is not in the library.
        :raises InvalidDataError: If the playlist is not editable or its provider does
            not support editing playlists.
        :raises ProviderUnavailableError: If the playlist's provider is not available.
        """
        playlist = await self._get_editable_library_playlist(db_playlist_id)
        provider, playlist_prov_item_id = self._get_playlist_edit_provider(playlist)
        if ProviderFeature.PLAYLIST_TRACKS_EDIT not in provider.supported_features:
            msg = f"Provider {provider.name} does not support editing playlists"
            raise InvalidDataError(msg)
        await provider.remove_playlist_tracks(playlist_prov_item_id, positions_to_remove)
        # invalidate cache so the removed tracks disappear from the (cached) listing
        self._refresh_playlist_tracks(playlist)
        await self.update_item_in_library(db_playlist_id, playlist)

    async def _add_library_item(self, item: Playlist, overwrite_existing: bool = False) -> int:
        """Add a new playlist record (and its provider mappings) to the database.

        :param item: The playlist to store.
        :param overwrite_existing: Not used by this implementation.
        :return: The database id of the new record.
        """
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "owner": item.owner,
                "is_editable": item.is_editable,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Playlist, overwrite: bool = False
    ) -> None:
        """Update an existing playlist record (and its provider mappings) in the database.

        :param item_id: Library (database) id of the playlist.
        :param update: Playlist holding the new values.
        :param overwrite: Replace values instead of merging them with the current ones.
        :raises InvalidDataError: If the update changes the item_id of an existing
            filesystem/builtin provider mapping.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        self._verify_update_allowed(cur_item, update)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                # owner/is_editable always come from the update;
                # name/sort_name only replace the current values when overwriting
                "name": name,
                "sort_name": sort_name,
                "owner": update.owner or cur_item.owner,
                "is_editable": update.is_editable,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    @guard_single_request  # type: ignore[type-var] # TODO: fix typing in util.py
    async def _get_provider_playlist_tracks(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        page: int = 0,
        force_refresh: bool = False,
    ) -> list[Track]:
        """Return one page of playlist tracks for the given provider playlist id.

        Returns an empty list when the provider cannot be found.
        """
        assert provider_instance_id_or_domain != "library"
        if not (provider := self.mass.get_provider(provider_instance_id_or_domain)):
            return []
        provider = cast("MusicProvider", provider)
        async with self.mass.cache.handle_refresh(force_refresh):
            return await provider.get_playlist_tracks(item_id, page=page)

    async def radio_mode_base_tracks(
        self,
        item: Playlist,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        :param item: The Playlist to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use
            (not used by this implementation).
        """
        return [
            x
            async for x in self.tracks(item.item_id, item.provider)
            # filter out unavailable tracks
            if x.available
        ]

    async def match_providers(self, db_item: Playlist) -> None:
        """Try to find match on all (streaming) providers for the provided (database) item.

        This is used to link objects of different providers/qualities together.
        """
        # playlists can only be matched on the same provider (if not unique)
        if self.mass.music.match_provider_instances(db_item):
            await self.add_provider_mappings(db_item.item_id, db_item.provider_mappings)

    def _refresh_playlist_tracks(self, playlist: Playlist) -> None:
        """Schedule a forced cache refresh of the playlist tracks (in 5 seconds)."""

        async def _refresh(playlist: Playlist) -> None:
            # simply iterate all tracks with force_refresh=True to refresh the cache
            async for _ in self.tracks(playlist.item_id, playlist.provider, force_refresh=True):
                pass

        # a fixed task_id per playlist is used to debounce multiple calls
        task_id = f"refresh_playlist_tracks_{playlist.item_id}"
        self.mass.call_later(5, _refresh, playlist, task_id=task_id)
477