music-assistant-server

24.6 KBPY
playlists.py
24.6 KB553 lines • python
1"""Manage MediaItems of type Playlist."""
2
3from __future__ import annotations
4
5from collections.abc import AsyncGenerator
6from typing import TYPE_CHECKING, cast
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.errors import (
10    InvalidDataError,
11    InvalidProviderURI,
12    MediaNotFoundError,
13    ProviderUnavailableError,
14)
15from music_assistant_models.media_items import (
16    Playlist,
17    Track,
18)
19
20from music_assistant.constants import (
21    DB_TABLE_PLAYLISTS,
22    PLAYLIST_MEDIA_TYPES,
23    PlaylistPlayableItem,
24)
25from music_assistant.helpers.compare import create_safe_string
26from music_assistant.helpers.database import UNSET
27from music_assistant.helpers.json import serialize_to_json
28from music_assistant.helpers.security import is_safe_name
29from music_assistant.helpers.uri import create_uri, parse_uri
30from music_assistant.helpers.util import guard_single_request
31from music_assistant.models.music_provider import MusicProvider
32
33from .base import MediaControllerBase
34
35if TYPE_CHECKING:
36    from music_assistant import MusicAssistant
37
38
class PlaylistController(MediaControllerBase[Playlist]):
    """Controller managing MediaItems of type Playlist."""

    db_table = DB_TABLE_PLAYLISTS
    media_type = MediaType.PLAYLIST
    item_cls = Playlist

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the playlist-specific API commands."""
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/create_playlist", self.create_playlist)
        self.mass.register_api_command("music/playlists/playlist_tracks", self.tracks)
        self.mass.register_api_command(
            "music/playlists/add_playlist_tracks", self.add_playlist_tracks
        )
        self.mass.register_api_command(
            "music/playlists/remove_playlist_tracks", self.remove_playlist_tracks
        )

    def _verify_update_allowed(self, current_item: Playlist, update: Playlist) -> None:
        """Verify that the update is allowed from a security perspective.

        For provider instances whose id starts with "filesystem" or "builtin", the
        item_id of an already existing mapping may not be changed, to prevent path
        traversal attacks through a crafted item_id. Mappings for provider instances
        that are not yet mapped on the current item are not checked here.

        :param current_item: The item as currently stored in the library.
        :param update: The incoming (updated) item.
        :raises InvalidDataError: If an existing filesystem-based mapping's item_id changed.
        """
        # lookup of the current mappings: provider_instance -> item_id
        current_mappings = {
            mapping.provider_instance: mapping.item_id for mapping in current_item.provider_mappings
        }

        for update_mapping in update.provider_mappings:
            # only existing mappings that are being modified are of interest
            if update_mapping.provider_instance not in current_mappings:
                continue
            current_item_id = current_mappings[update_mapping.provider_instance]
            # Disallow item_id changes for filesystem-based providers (filesystem, builtin)
            if (
                current_item_id != update_mapping.item_id
                and update_mapping.provider_instance.startswith(("filesystem", "builtin"))
            ):
                msg = (
                    f"Updating item_id is not allowed for filesystem-based providers: "
                    f"attempted to change '{current_item_id}' to '{update_mapping.item_id}'"
                )
                raise InvalidDataError(msg)

    async def tracks(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        force_refresh: bool = False,
    ) -> AsyncGenerator[PlaylistPlayableItem, None]:
        """Yield the playlist items for the given (provider or library) playlist id.

        For a library playlist ("library" as provider), the provider mapping chosen by
        `_select_provider_id` is used to fetch the items.

        :param item_id: Playlist id on the given provider (or library db id).
        :param provider_instance_id_or_domain: Provider instance id/domain, or "library".
        :param force_refresh: Bypass/refresh the cache while fetching pages.
        """
        if provider_instance_id_or_domain == "library":
            library_item = await self.get_library_item(item_id)
            provider_instance_id_or_domain, item_id = self._select_provider_id(library_item)
        # playlist tracks are not stored in the db,
        # we always fetch them (cached) from the provider, page by page,
        # until the provider returns an empty page
        page = 0
        while True:
            tracks = await self._get_provider_playlist_tracks(
                item_id,
                provider_instance_id_or_domain,
                page=page,
                force_refresh=force_refresh,
            )
            if not tracks:
                break
            for track in tracks:
                yield track
            page += 1

    async def create_playlist(
        self, name: str, provider_instance_or_domain: str | None = None
    ) -> Playlist:
        """Create a new playlist on a provider and add it to the library.

        :param name: Name for the new playlist (validated with `is_safe_name`).
        :param provider_instance_or_domain: Provider instance id or domain to create the
            playlist on. When omitted, the builtin provider is used.
        :raises ProviderUnavailableError: If the provider is not found or is not a
            music provider.
        :raises InvalidDataError: If the name is not a valid playlist name.
        """
        # if provider is omitted, just pick builtin provider
        provider_key = provider_instance_or_domain or "builtin"
        provider = self.mass.get_provider(provider_key)
        # guard the builtin fallback too: a missing or non-music provider would
        # otherwise fail later with an opaque AttributeError
        if provider is None or not isinstance(provider, MusicProvider):
            msg = f"Provider {provider_key} is not available"
            raise ProviderUnavailableError(msg)

        if not is_safe_name(name):
            msg = f"{name} is not a valid Playlist name"
            raise InvalidDataError(msg)
        # create playlist on the provider
        playlist = await provider.create_playlist(name)
        for prov_mapping in playlist.provider_mappings:
            # when manually creating a playlist, it's always in the library
            prov_mapping.in_library = True
        # add the new playlist to the library
        return await self.add_item_to_library(playlist, False)

    async def add_playlist_tracks(self, db_playlist_id: str | int, uris: list[str]) -> None:
        """Add items (by uri) to a library playlist.

        Album and playlist uris are expanded into their individual items. Items that
        already exist in the playlist, or that cannot be resolved to something the
        playlist's provider accepts, are skipped with a log message.

        :param db_playlist_id: Library (database) id of the target playlist.
        :param uris: Music Assistant style uris (provider://media_type/item_id).
        :raises MediaNotFoundError: If the playlist is not found in the library.
        :raises InvalidDataError: If the playlist is not editable.
        :raises InvalidProviderURI: If a uri contains a newline (each uri is also
            passed through `parse_uri` up front for validation).
        :raises ProviderUnavailableError: If the playlist's provider is not available.
        """
        # ruff: noqa: PLR0915
        db_id = int(db_playlist_id)  # ensure integer
        playlist = await self.get_library_item(db_id)
        if not playlist:
            msg = f"Playlist with id {db_id} not found"
            raise MediaNotFoundError(msg)
        if not playlist.is_editable:
            msg = f"Playlist {playlist.name} is not editable"
            raise InvalidDataError(msg)
        # Validate uris to prevent code injection
        for uri in uris:
            # Prevent code injection via newlines in URIs
            if "\n" in uri or "\r" in uri:
                msg = "Invalid URI: newlines not allowed"
                raise InvalidProviderURI(msg)
            await parse_uri(uri)
        # use _select_provider_id to respect user's provider filter
        playlist_prov_instance, playlist_prov_item_id = self._select_provider_id(playlist)
        playlist_prov = self.mass.get_provider(playlist_prov_instance)
        if (
            not playlist_prov
            or not isinstance(playlist_prov, MusicProvider)
            or not playlist_prov.available
        ):
            msg = f"Provider {playlist_prov_instance} is not available"
            raise ProviderUnavailableError(msg)

        # grab all existing track ids/uris in the playlist so we can check for duplicates;
        # read them from the exact provider playlist we are about to write to
        cur_playlist_track_ids: set[str] = set()
        cur_playlist_track_uris: set[str] = set()
        async for item in self.tracks(playlist_prov_item_id, playlist_prov_instance):
            if item.item_id:
                cur_playlist_track_ids.add(item.item_id)
            if item.uri:
                cur_playlist_track_uris.add(item.uri)

        # unwrap album/playlist uris to individual item uris
        unwrapped_uris = await self._unwrap_playlist_uris(uris, playlist)

        # work out the ids that need to be added:
        # filter out duplicates and items that do not exist on the provider.
        ids_to_add: list[str] = []
        for uri in unwrapped_uris:
            # skip if item already in the playlist
            if uri in cur_playlist_track_uris:
                self.logger.info(
                    "Not adding %s to playlist %s - it already exists",
                    uri,
                    playlist.name,
                )
                continue

            # parse uri for further processing
            media_type, provider_instance_id_or_domain, item_id = await parse_uri(uri)

            # non-track items can only be added to builtin playlists
            if media_type != MediaType.TRACK and playlist_prov.domain != "builtin":
                self.logger.warning(
                    "Not adding %s to playlist %s - only supported in builtin playlists",
                    uri,
                    playlist.name,
                )
                continue

            # skip if item already in the playlist (same id, possibly a different uri form)
            if item_id in cur_playlist_track_ids:
                self.logger.warning(
                    "Not adding %s to playlist %s - it already exists",
                    uri,
                    playlist.name,
                )
                continue

            # special: the builtin provider can handle uri's from all providers (with uri as id)
            if playlist_prov.domain == "builtin":
                # For non-library URIs, add directly (they're already portable provider URIs)
                if provider_instance_id_or_domain != "library":
                    if uri not in ids_to_add:
                        ids_to_add.append(uri)
                    self.logger.info(
                        "Adding %s to playlist %s",
                        uri,
                        playlist.name,
                    )
                    continue
                # For library URIs, convert to provider URIs to survive DB rebuilds
                provider_uri = await self._library_uri_to_provider_uri(
                    uri, media_type, playlist, playlist_prov
                )
                if (
                    provider_uri is not None
                    and provider_uri not in ids_to_add
                    and provider_uri not in cur_playlist_track_uris
                ):
                    ids_to_add.append(provider_uri)
                    self.logger.info(
                        "Adding %s to playlist %s",
                        provider_uri,
                        playlist.name,
                    )
                continue

            # if target playlist is an exact provider match, we can add it
            if provider_instance_id_or_domain != "library":
                item_prov = self.mass.get_provider(provider_instance_id_or_domain)
                if not item_prov or not item_prov.available:
                    self.logger.warning(
                        "Skip adding %s to playlist: Provider %s is not available",
                        uri,
                        provider_instance_id_or_domain,
                    )
                    continue
                if item_prov.instance_id == playlist_prov.instance_id:
                    if item_id not in ids_to_add:
                        ids_to_add.append(item_id)
                    continue

            # For provider-specific playlists: match tracks with quality sorting
            # (Non-track items can only be added to builtin playlists, validated earlier)
            full_track = await self.mass.music.tracks.get(
                item_id,
                provider_instance_id_or_domain,
                recursive=provider_instance_id_or_domain != "library",
            )
            track_prov_domains = {x.provider_domain for x in full_track.provider_mappings}
            if (
                playlist_prov.is_streaming_provider
                and playlist_prov.domain not in track_prov_domains
            ):
                # try to match the track to the playlist provider
                full_track.provider_mappings.update(
                    await self.mass.music.tracks.match_provider(
                        full_track, playlist_prov, strict=False
                    )
                )

            # a track can contain multiple versions on the same provider
            # simply sort by quality and just add the first available version
            for track_version in sorted(
                full_track.provider_mappings, key=lambda x: x.quality, reverse=True
            ):
                if not track_version.available:
                    continue
                if track_version.item_id in cur_playlist_track_ids:
                    break  # already existing in the playlist
                item_prov = self.mass.get_provider(track_version.provider_instance)
                if not item_prov:
                    continue
                track_version_uri = create_uri(
                    MediaType.TRACK,
                    item_prov.instance_id,
                    track_version.item_id,
                )
                if track_version_uri in cur_playlist_track_uris:
                    self.logger.warning(
                        "Not adding %s to playlist %s - it already exists",
                        full_track.name,
                        playlist.name,
                    )
                    break  # already existing in the playlist
                # Add track to provider-specific playlist
                if item_prov.instance_id == playlist_prov.instance_id:
                    if track_version.item_id not in ids_to_add:
                        ids_to_add.append(track_version.item_id)
                    self.logger.info(
                        "Adding %s to playlist %s",
                        full_track.name,
                        playlist.name,
                    )
                    break
            else:
                # loop exhausted without a usable version on the playlist's provider
                self.logger.warning(
                    "Can't add %s to playlist %s - it is not available on provider %s",
                    full_track.name,
                    playlist.name,
                    playlist_prov.name,
                )

        if not ids_to_add:
            return

        # actually add the tracks to the playlist on the provider
        await playlist_prov.add_playlist_tracks(playlist_prov_item_id, ids_to_add)
        # invalidate cache so tracks get refreshed
        self._refresh_playlist_tracks(playlist)
        await self.update_item_in_library(db_playlist_id, playlist)

    async def _unwrap_playlist_uris(self, uris: list[str], playlist: Playlist) -> list[str]:
        """Expand album/playlist uris into item uris and drop unsupported entries.

        Only Music Assistant style uris (provider://media_type/item_id) are accepted;
        anything else, or a media type not in PLAYLIST_MEDIA_TYPES, is skipped with a
        warning.

        :param uris: The uris as passed to `add_playlist_tracks`.
        :param playlist: Target playlist (only used for log messages).
        """
        unwrapped_uris: list[str] = []
        for uri in uris:
            if not ("://" in uri and len(uri.split("/")) >= 4):
                # NOT a music assistant-style uri (provider://media_type/item_id)
                self.logger.warning(
                    "Not adding %s to playlist %s - not a valid uri", uri, playlist.name
                )
                continue
            # music assistant-style uri: provider://media_type/item_id
            provider_instance_id_or_domain, rest = uri.split("://", 1)
            media_type_str, item_id = rest.split("/", 1)
            media_type = MediaType(media_type_str)
            if media_type == MediaType.ALBUM:
                album_tracks = await self.mass.music.albums.tracks(
                    item_id, provider_instance_id_or_domain
                )
                unwrapped_uris.extend(
                    track.uri for track in album_tracks if track.uri is not None
                )
            elif media_type == MediaType.PLAYLIST:
                async for item in self.tracks(item_id, provider_instance_id_or_domain):
                    if item.uri is not None:
                        unwrapped_uris.append(item.uri)
            elif media_type in PLAYLIST_MEDIA_TYPES:
                unwrapped_uris.append(uri)
            else:
                self.logger.warning(
                    "Not adding %s to playlist %s - media type not supported in playlists",
                    uri,
                    playlist.name,
                )
        return unwrapped_uris

    async def _library_uri_to_provider_uri(
        self,
        uri: str,
        media_type: MediaType,
        playlist: Playlist,
        playlist_prov: MusicProvider,
    ) -> str | None:
        """Convert a library uri into a provider uri for storage in a builtin playlist.

        The library item's provider mappings are sorted by quality (highest first) and
        the first available mapping whose provider is loaded is used. Tracks are first
        matched to the playlist provider when that provider is a streaming provider
        the track is not yet mapped to.

        :param uri: The library uri to convert.
        :param media_type: Media type parsed from the uri.
        :param playlist: Target playlist (only used for log messages).
        :param playlist_prov: Provider of the target playlist.
        :return: The provider uri, or None (with a warning logged) when no usable
            mapping exists.
        """
        # Get the full item from library to access all provider mappings
        full_item = await self.mass.music.get_item_by_uri(uri)
        if not hasattr(full_item, "provider_mappings"):
            self.logger.warning(
                "Can't add %s to playlist %s - unsupported media type",
                uri,
                playlist.name,
            )
            return None

        # For tracks, try to match to playlist provider
        # For non-track items, just use first available mapping
        provider_mappings = full_item.provider_mappings
        if media_type == MediaType.TRACK:
            # Cast to Track for mypy - we know it's a track from media_type check
            full_track = cast("Track", full_item)
            track_prov_domains = {x.provider_domain for x in provider_mappings}
            if (
                playlist_prov.is_streaming_provider
                and playlist_prov.domain not in track_prov_domains
            ):
                provider_mappings.update(
                    await self.mass.music.tracks.match_provider(
                        full_track, playlist_prov, strict=False
                    )
                )

        # Sort by quality (highest first) for deterministic selection
        for prov_mapping in sorted(provider_mappings, key=lambda x: x.quality, reverse=True):
            if not prov_mapping.available:
                continue
            item_prov = self.mass.get_provider(prov_mapping.provider_instance)
            if not item_prov:
                continue
            # Create provider URI from the first usable mapping
            return create_uri(media_type, item_prov.instance_id, prov_mapping.item_id)

        self.logger.warning(
            "Can't add %s to playlist %s - no available provider mapping",
            uri,
            playlist.name,
        )
        return None

    async def add_playlist_track(self, db_playlist_id: str | int, track_uri: str) -> None:
        """Add (single) track to playlist; thin wrapper around `add_playlist_tracks`."""
        await self.add_playlist_tracks(db_playlist_id, [track_uri])

    async def remove_playlist_tracks(
        self, db_playlist_id: str | int, positions_to_remove: tuple[int, ...]
    ) -> None:
        """Remove multiple tracks (by position) from a library playlist.

        :param db_playlist_id: Library (database) id of the playlist.
        :param positions_to_remove: Positions of the items to remove, as understood by
            the provider's `remove_playlist_tracks`.
        :raises MediaNotFoundError: If the playlist is not found in the library.
        :raises InvalidDataError: If the playlist is not editable or the provider does
            not support editing playlists.
        :raises ProviderUnavailableError: If the playlist's provider is not available.
        """
        db_id = int(db_playlist_id)  # ensure integer
        playlist = await self.get_library_item(db_id)
        if not playlist:
            msg = f"Playlist with id {db_id} not found"
            raise MediaNotFoundError(msg)
        if not playlist.is_editable:
            msg = f"Playlist {playlist.name} is not editable"
            raise InvalidDataError(msg)
        # use _select_provider_id to respect user's provider filter
        playlist_prov_instance, playlist_prov_item_id = self._select_provider_id(playlist)
        provider = self.mass.get_provider(playlist_prov_instance)
        if not provider or not isinstance(provider, MusicProvider) or not provider.available:
            msg = f"Provider {playlist_prov_instance} is not available"
            raise ProviderUnavailableError(msg)
        if ProviderFeature.PLAYLIST_TRACKS_EDIT not in provider.supported_features:
            msg = f"Provider {provider.name} does not support editing playlists"
            raise InvalidDataError(msg)
        await provider.remove_playlist_tracks(playlist_prov_item_id, positions_to_remove)
        # invalidate cache so the removed tracks disappear from the (cached) listing,
        # same as add_playlist_tracks does after adding
        self._refresh_playlist_tracks(playlist)
        await self.update_item_in_library(db_playlist_id, playlist)

    async def _add_library_item(self, item: Playlist, overwrite_existing: bool = False) -> int:
        """Add a new playlist record to the database and return its db id."""
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "owner": item.owner,
                "is_editable": item.is_editable,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Playlist, overwrite: bool = False
    ) -> None:
        """Update an existing playlist record in the database.

        With `overwrite=False` the existing values are merged with the update
        (metadata and external ids merged, existing name/sort_name kept); with
        `overwrite=True` the update's values replace the stored ones.

        :raises InvalidDataError: Via `_verify_update_allowed` for disallowed
            item_id changes on filesystem-based mappings.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        self._verify_update_allowed(cur_item, update)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                # name/sort_name come from the update only when overwriting;
                # owner prefers the update when set; is_editable always follows the update
                "name": name,
                "sort_name": sort_name,
                "owner": update.owner or cur_item.owner,
                "is_editable": update.is_editable,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    @guard_single_request  # type: ignore[type-var]  # TODO: fix typing in util.py
    async def _get_provider_playlist_tracks(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        page: int = 0,
        force_refresh: bool = False,
    ) -> list[PlaylistPlayableItem]:
        """Return one page of playlist items for the given provider playlist id.

        Returns an empty list when the provider is not found. Callers must resolve
        library ids to a provider first (see `tracks`).
        """
        assert provider_instance_id_or_domain != "library"
        if not (provider := self.mass.get_provider(provider_instance_id_or_domain)):
            return []
        provider = cast("MusicProvider", provider)
        async with self.mass.cache.handle_refresh(force_refresh):
            # Builtin provider overrides to return list[PlaylistPlayableItem],
            # others return list[Track]. Since Track is part of PlaylistPlayableItem union,
            # this is safe at runtime. Type ignore needed because list is invariant.
            return await provider.get_playlist_tracks(item_id, page=page)  # type: ignore[return-value]

    async def radio_mode_base_tracks(
        self,
        item: Playlist,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        :param item: The Playlist to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use
            (currently not used by this controller).
        """
        return [
            x
            async for x in self.tracks(item.item_id, item.provider)
            # Radio mode only works with Tracks (filter out all other types)
            if isinstance(x, Track) and x.available
        ]

    async def match_providers(self, db_item: Playlist) -> None:
        """Try to find match on all (streaming) providers for the provided (database) item.

        This is used to link objects of different providers/qualities together.
        """
        # playlists can only be matched on the same provider (if not unique)
        if self.mass.music.match_provider_instances(db_item):
            await self.add_provider_mappings(db_item.item_id, db_item.provider_mappings)

    def _refresh_playlist_tracks(self, playlist: Playlist) -> None:
        """Schedule a forced cache refresh of the playlist's tracks (debounced, 5 seconds)."""

        async def _refresh(playlist: Playlist) -> None:
            # simply iterate all tracks with force_refresh=True to refresh the cache
            async for _ in self.tracks(playlist.item_id, playlist.provider, force_refresh=True):
                pass

        # a fixed task id per playlist lets repeated calls collapse into one refresh
        task_id = f"refresh_playlist_tracks_{playlist.item_id}"
        self.mass.call_later(5, _refresh, playlist, task_id=task_id)  # debounce multiple calls
548            async for _ in self.tracks(playlist.item_id, playlist.provider, force_refresh=True):
549                pass
550
551        task_id = f"refresh_playlist_tracks_{playlist.item_id}"
552        self.mass.call_later(5, _refresh, playlist, task_id=task_id)  # debounce multiple calls
553