music-assistant-server

27 KBPY
albums.py
27 KB622 lines • python
1"""Manage MediaItems of type Album."""
2
3from __future__ import annotations
4
5import contextlib
6from collections.abc import Iterable
7from typing import TYPE_CHECKING, Any, cast
8
9from music_assistant_models.enums import AlbumType, MediaType, ProviderFeature
10from music_assistant_models.errors import InvalidDataError, MediaNotFoundError, MusicAssistantError
11from music_assistant_models.media_items import (
12    Album,
13    Artist,
14    ItemMapping,
15    MediaItemImage,
16    ProviderMapping,
17    Track,
18    UniqueList,
19)
20
21from music_assistant.constants import DB_TABLE_ALBUM_ARTISTS, DB_TABLE_ALBUM_TRACKS, DB_TABLE_ALBUMS
22from music_assistant.controllers.media.base import MediaControllerBase
23from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
24from music_assistant.helpers.compare import (
25    compare_album,
26    compare_artists,
27    compare_media_item,
28    create_safe_string,
29    loose_compare_strings,
30)
31from music_assistant.helpers.database import UNSET
32from music_assistant.helpers.json import serialize_to_json
33from music_assistant.models.music_provider import MusicProvider
34
35if TYPE_CHECKING:
36    from music_assistant import MusicAssistant
37
38
class AlbumsController(MediaControllerBase[Album]):
    """Controller managing MediaItems of type Album."""

    db_table = DB_TABLE_ALBUMS
    media_type = MediaType.ALBUM
    item_cls = Album

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class."""
        super().__init__(mass)
        # Base SELECT for library albums: every album row is returned together with its
        # provider mappings and its (library) artists, each aggregated into a JSON array.
        self.base_query = """
        SELECT
            albums.*,
            (SELECT JSON_GROUP_ARRAY(
                json_object(
                'item_id', album_pm.provider_item_id,
                    'provider_domain', album_pm.provider_domain,
                        'provider_instance', album_pm.provider_instance,
                        'available', album_pm.available,
                        'audio_format', json(album_pm.audio_format),
                        'url', album_pm.url,
                        'details', album_pm.details,
                        'in_library', album_pm.in_library,
                        'is_unique', album_pm.is_unique
                )) FROM provider_mappings album_pm WHERE album_pm.item_id = albums.item_id AND album_pm.media_type = 'album') AS provider_mappings,
            (SELECT JSON_GROUP_ARRAY(
                json_object(
                'item_id', artists.item_id,
                'provider', 'library',
                    'name', artists.name,
                    'sort_name', artists.sort_name,
                    'media_type', 'artist'
                )) FROM artists JOIN album_artists on album_artists.album_id = albums.item_id  WHERE artists.item_id = album_artists.artist_id) AS artists
            FROM albums"""  # noqa: E501
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/album_tracks", self.tracks)
        self.mass.register_api_command(f"music/{api_base}/album_versions", self.versions)

    async def get(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        recursive: bool = True,
    ) -> Album:
        """Return (full) details for a single media item.

        :param item_id: Item ID of the album.
        :param provider_instance_id_or_domain: Provider instance ID or domain of the album.
        :param recursive: When True, resolve artist ItemMappings into full Artist objects.
        """
        album = await super().get(
            item_id,
            provider_instance_id_or_domain,
        )
        if not recursive:
            return album

        # append artist details to full album item (resolve ItemMappings);
        # artists that can not be found (anymore) are silently left out
        album_artists: UniqueList[Artist | ItemMapping] = UniqueList()
        for artist in album.artists:
            if not isinstance(artist, ItemMapping):
                album_artists.append(artist)
                continue
            with contextlib.suppress(MediaNotFoundError):
                album_artists.append(
                    await self.mass.music.artists.get(
                        artist.item_id,
                        artist.provider,
                    )
                )
        album.artists = album_artists
        return album

    async def library_items(
        self,
        favorite: bool | None = None,
        search: str | None = None,
        limit: int = 500,
        offset: int = 0,
        order_by: str = "sort_name",
        provider: str | list[str] | None = None,
        genre: int | list[int] | None = None,
        album_types: list[AlbumType] | None = None,
        **kwargs: Any,
    ) -> list[Album]:
        """Get in-database albums.

        A search query in the form "artist - title" filters on both album title and
        artist name. For a plain search query that yields few results (first page
        only), albums whose artist name matches the query are appended.

        :param favorite: Filter by favorite status.
        :param search: Filter by search query.
        :param limit: Maximum number of items to return.
        :param offset: Number of items to skip.
        :param order_by: Order by field (e.g. 'sort_name', 'timestamp_added').
        :param provider: Filter by provider instance ID (single string or list).
        :param genre: Filter by genre id(s).
        :param album_types: Filter by album types.
        :param kwargs: Extra keyword arguments; ignored by this implementation.
        """
        extra_query_params: dict[str, Any] = {}
        extra_query_parts: list[str] = []
        extra_join_parts: list[str] = []
        artist_table_joined = False
        # optional album type filter
        if album_types:
            extra_query_parts.append("albums.album_type IN :album_types")
            extra_query_params["album_types"] = [x.value for x in album_types]
        if order_by and "artist_name" in order_by:
            # join artist table to allow sorting on artist name
            extra_join_parts.append(
                "JOIN album_artists ON album_artists.album_id = albums.item_id "
                "JOIN artists ON artists.item_id = album_artists.artist_id "
            )
            artist_table_joined = True
        if search and " - " in search:
            # handle combined artist + title search
            artist_str, title_str = search.split(" - ", 1)
            search = None
            title_str = create_safe_string(title_str, True, True)
            artist_str = create_safe_string(artist_str, True, True)
            extra_query_parts.append("albums.search_name LIKE :search_title")
            extra_query_params["search_title"] = f"%{title_str}%"
            # use join with artists table to filter on artist name
            extra_join_parts.append(
                "JOIN album_artists ON album_artists.album_id = albums.item_id "
                "JOIN artists ON artists.item_id = album_artists.artist_id "
                "AND artists.search_name LIKE :search_artist"
                if not artist_table_joined
                else "AND artists.search_name LIKE :search_artist"
            )
            artist_table_joined = True
            extra_query_params["search_artist"] = f"%{artist_str}%"
        # pass copies: the same lists/dict are extended and reused for the follow-up
        # artist query below, so the query builder must not be able to alter them
        result = await self.get_library_items_by_query(
            favorite=favorite,
            search=search,
            genre_ids=genre,
            limit=limit,
            offset=offset,
            order_by=order_by,
            provider_filter=self._ensure_provider_filter(provider),
            extra_query_parts=list(extra_query_parts),
            extra_query_params=dict(extra_query_params),
            extra_join_parts=list(extra_join_parts),
            in_library_only=True,
        )

        # Calculate how many more items we need to reach the original limit
        remaining_limit = limit - len(result)

        if search and len(result) < 25 and not offset and remaining_limit > 0:
            # append albums whose artist name matches the search query
            safe_search = create_safe_string(search, True, True)
            extra_join_parts.append(
                "JOIN album_artists ON album_artists.album_id = albums.item_id "
                "JOIN artists ON artists.item_id = album_artists.artist_id "
                "AND artists.search_name LIKE :search_artist"
                if not artist_table_joined
                else "AND artists.search_name LIKE :search_artist"
            )
            extra_query_params["search_artist"] = f"%{safe_search}%"
            seen_uris = {item.uri for item in result}

            for album in await self.get_library_items_by_query(
                favorite=favorite,
                search=None,
                # apply the same genre filter as the primary query
                genre_ids=genre,
                limit=remaining_limit,
                order_by=order_by,
                provider_filter=self._ensure_provider_filter(provider),
                extra_query_parts=extra_query_parts,
                extra_query_params=extra_query_params,
                extra_join_parts=extra_join_parts,
                in_library_only=True,
            ):
                # prevent duplicates: the artist may also be in the title, and the
                # artist join can yield the same album once per matching artist
                if album.uri in seen_uris:
                    continue
                seen_uris.add(album.uri)
                result.append(album)
                # Stop if we've reached the original limit
                if len(result) >= limit:
                    break
        return result

    async def library_count(
        self, favorite_only: bool = False, album_types: list[AlbumType] | None = None
    ) -> int:
        """Return the total number of items in the library.

        :param favorite_only: Only count albums marked as favorite.
        :param album_types: Only count albums of these types.
        """
        sql_query = f"SELECT item_id FROM {self.db_table}"
        query_parts: list[str] = []
        query_params: dict[str, Any] = {}
        if favorite_only:
            query_parts.append("favorite = 1")
        if album_types:
            query_parts.append("albums.album_type IN :album_types")
            query_params["album_types"] = [x.value for x in album_types]
        if query_parts:
            sql_query += f" WHERE {' AND '.join(query_parts)}"
        return await self.mass.music.database.get_count_from_query(sql_query, query_params)

    async def remove_item_from_library(self, item_id: str | int, recursive: bool = True) -> None:
        """Delete item from the library(database).

        :param item_id: Library (database) ID of the album.
        :param recursive: Also remove the album's library tracks.
        :raises MusicAssistantError: If not recursive and the album still has tracks.
        """
        db_id = int(item_id)  # ensure integer
        # recursively also remove album tracks
        for db_track in await self.get_library_album_tracks(db_id):
            if not recursive:
                raise MusicAssistantError("Album still has tracks linked")
            with contextlib.suppress(MediaNotFoundError):
                await self.mass.music.tracks.remove_item_from_library(db_track.item_id)
        # delete entry(s) from albumtracks table
        await self.mass.music.database.delete(DB_TABLE_ALBUM_TRACKS, {"album_id": db_id})
        # delete entry(s) from album artists table
        await self.mass.music.database.delete(DB_TABLE_ALBUM_ARTISTS, {"album_id": db_id})
        # delete the album itself from db
        # this will raise if the item still has references and recursive is false
        await super().remove_item_from_library(item_id)

    async def tracks(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        in_library_only: bool = False,
    ) -> list[Track]:
        """Return album tracks for the given provider album id.

        If the album is not in the library, its tracks are fetched from the provider.
        Otherwise the library tracks are returned, merged (unless in_library_only)
        with the deduplicated tracks of every provider mapping of the album that the
        current user is allowed to see.

        :param item_id: Item ID of the album.
        :param provider_instance_id_or_domain: Provider instance ID or domain of the album.
        :param in_library_only: Only return tracks that are in the library.
        """
        # always check if we have a library item for this album
        library_album = await self.get_library_item_by_prov_id(
            item_id, provider_instance_id_or_domain
        )
        if not library_album:
            album_tracks = await self._get_provider_album_tracks(
                item_id, provider_instance_id_or_domain
            )
            if album_tracks and not album_tracks[0].image:
                # set album image from provider album if not present on tracks
                prov_album = await self.get_provider_item(item_id, provider_instance_id_or_domain)
                if prov_album.image:
                    for track in album_tracks:
                        if not track.image:
                            track.metadata.add_image(prov_album.image)
            return album_tracks

        db_items = await self.get_library_album_tracks(library_album.item_id)
        if in_library_only:
            # return in-library items only
            return sorted(db_items, key=lambda x: (x.disc_number, x.track_number))
        result: list[Track] = list(db_items)

        # return all (unique) items from all providers
        # because we are returning the items from all providers combined,
        # we need to make sure that we don't return duplicates: a track is skipped when
        # its provider item id, its disc/track position or its name/version was seen
        unique_ids: set[str] = {f"{x.disc_number}.{x.track_number}" for x in db_items}
        unique_ids.update(f"{x.name.lower()}.{x.version.lower()}" for x in db_items)
        for db_item in db_items:
            unique_ids.update(x.item_id for x in db_item.provider_mappings)
        user = get_current_user()
        user_provider_filter = user.provider_filter if user and user.provider_filter else None
        # always prefer album image (identical for every track, so compute it once)
        album_images = [library_album.image] if library_album.image else []
        for provider_mapping in library_album.provider_mappings:
            if (
                user_provider_filter
                and provider_mapping.provider_instance not in user_provider_filter
            ):
                continue
            provider_tracks = await self._get_provider_album_tracks(
                provider_mapping.item_id, provider_mapping.provider_instance
            )
            for provider_track in provider_tracks:
                # In some cases (looking at you YTM) the disc/track number is not obtained from
                # library_tracks. Ensure to update the disc/track number when interacting with
                # album tracks
                db_track = next(
                    (
                        x
                        for x in db_items
                        if x.sort_name == provider_track.sort_name
                        and x.version == provider_track.version
                    ),
                    None,
                )
                if (
                    db_track
                    and db_track.track_number == 0
                    and db_track.track_number != provider_track.track_number
                ):
                    await self._set_album_track(
                        db_id=int(library_album.item_id),
                        db_track_id=int(db_track.item_id),
                        track=provider_track,
                    )
                position_id = f"{provider_track.disc_number}.{provider_track.track_number}"
                name_id = f"{provider_track.name.lower()}.{provider_track.version.lower()}"
                if (
                    provider_track.item_id in unique_ids
                    or position_id in unique_ids
                    or name_id in unique_ids
                ):
                    continue
                # register every key of the accepted track, so the same track coming
                # from another provider mapping is skipped too; a track_number of 0 is
                # treated as unknown (as in the workaround above), so it is not used
                # to deduplicate on position
                unique_ids.add(provider_track.item_id)
                unique_ids.add(name_id)
                if provider_track.track_number:
                    unique_ids.add(position_id)
                provider_track.album = library_album
                track_images: list[MediaItemImage] = provider_track.metadata.images or []
                provider_track.metadata.images = UniqueList(album_images + track_images)
                result.append(provider_track)
        # NOTE: we need to return the results sorted on disc/track here
        # to ensure the correct order at playback
        return sorted(result, key=lambda x: (x.disc_number, x.track_number))

    async def versions(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> UniqueList[Album]:
        """Return all versions of an album we can find on all providers.

        Searches every unique provider that supports albums for "artist - album" and
        keeps results with a loosely matching name and at least one matching artist,
        excluding the album that was passed in.
        """
        album = await self.get_provider_item(item_id, provider_instance_id_or_domain)
        search_query = f"{album.artists[0].name} - {album.name}" if album.artists else album.name
        result: UniqueList[Album] = UniqueList()
        for provider_id in self.mass.music.get_unique_providers():
            provider = self.mass.get_provider(provider_id)
            if not provider or not isinstance(provider, MusicProvider):
                continue
            if not provider.library_supported(MediaType.ALBUM):
                continue
            result.extend(
                prov_item
                for prov_item in await self.search(search_query, provider_id)
                if loose_compare_strings(album.name, prov_item.name)
                and compare_artists(prov_item.artists, album.artists, any_match=True)
                # make sure that the 'base' version is NOT included
                and not album.provider_mappings.intersection(prov_item.provider_mappings)
            )
        return result

    async def get_library_album_tracks(
        self,
        item_id: str | int,
    ) -> list[Track]:
        """Return in-database album tracks for the given database album.

        :param item_id: Library (database) ID of the album.
        """
        db_id = int(item_id)  # ensure integer
        # NOTE(review): this query part carries its own "WHERE" keyword, presumably
        # matching how the tracks query builder composes its query — verify there
        return await self.mass.music.tracks.get_library_items_by_query(
            extra_query_parts=["WHERE album_tracks.album_id = :album_id"],
            extra_query_params={"album_id": db_id},
        )

    async def add_item_mapping_as_album_to_library(self, item: ItemMapping) -> Album:
        """
        Add an ItemMapping as an Album to the library.

        This is only used in special occasions as is basically adds an album
        to the db without a lot of mandatory data, such as artists.
        """
        album = self.album_from_item_mapping(item)
        return await self.add_item_to_library(album)

    async def _add_library_item(self, item: Album, overwrite_existing: bool = False) -> int:
        """Add a new record to the database and return its database ID.

        :param item: The album to insert.
        :param overwrite_existing: Not used by this implementation.
        :raises InvalidDataError: If item is not an Album (e.g. an ItemMapping).
        """
        if not isinstance(item, Album):  # TODO: Remove this once the codebase is fully typed
            msg = "Not a valid Album object (ItemMapping can not be added to db)"  # type: ignore[unreachable]
            raise InvalidDataError(msg)
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "version": item.version,
                "favorite": item.favorite,
                # store the plain value, consistent with _update_library_item
                "album_type": item.album_type.value,
                "year": item.year,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        # set album artist(s)
        await self._set_album_artists(db_id, item.artists)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Album, overwrite: bool = False
    ) -> None:
        """Update existing record in the database.

        Without overwrite, existing values are kept and only missing ones are filled
        from the update; provider mappings and artists are merged. With overwrite,
        the update replaces the stored values.

        :param item_id: Library (database) ID of the album.
        :param update: Album holding the new values.
        :param overwrite: Replace instead of merge.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        # an UNKNOWN album type in the update never replaces the stored type
        if getattr(update, "album_type", AlbumType.UNKNOWN) != AlbumType.UNKNOWN:
            album_type = update.album_type
        else:
            album_type = cur_item.album_type
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                "name": name,
                "sort_name": sort_name,
                "version": update.version if overwrite else cur_item.version or update.version,
                "year": update.year if overwrite else cur_item.year or update.year,
                "album_type": album_type.value,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        # set album artist(s)
        artists = update.artists if overwrite else cur_item.artists + update.artists
        await self._set_album_artists(db_id, artists, overwrite=overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    async def _get_provider_album_tracks(
        self, item_id: str, provider_instance_id_or_domain: str
    ) -> list[Track]:
        """Return album tracks for the given provider album id.

        Returns an empty list when the provider is not available.
        """
        if prov := self.mass.get_provider(provider_instance_id_or_domain):
            prov = cast("MusicProvider", prov)
            return await prov.get_album_tracks(item_id)
        return []

    async def radio_mode_base_tracks(
        self,
        item: Album,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        :param item: The Album to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use
            (not used by this implementation).
        """
        return await self.tracks(item.item_id, item.provider, in_library_only=False)

    async def _set_album_artists(
        self,
        db_id: int,
        artists: Iterable[Artist | ItemMapping],
        overwrite: bool = False,
    ) -> None:
        """Store Album Artists.

        :param db_id: Library (database) ID of the album.
        :param artists: Artists to link to the album.
        :param overwrite: Remove all existing album-artist links first.
        """
        if overwrite:
            # on overwrite, clear the album_artists table first
            await self.mass.music.database.delete(
                DB_TABLE_ALBUM_ARTISTS,
                {
                    "album_id": db_id,
                },
            )
        for artist in artists:
            await self._set_album_artist(db_id, artist=artist, overwrite=overwrite)

    async def _set_album_artist(
        self, db_id: int, artist: Artist | ItemMapping, overwrite: bool = False
    ) -> ItemMapping:
        """Store Album Artist info.

        The artist is added to the library first if it is not there yet (or when
        overwrite is set), then linked to the album.

        :return: ItemMapping of the library artist.
        """
        db_artist: Artist | ItemMapping | None = None
        if artist.provider == "library":
            db_artist = artist
        elif existing := await self.mass.music.artists.get_library_item_by_prov_id(
            artist.item_id, artist.provider
        ):
            db_artist = existing

        if not db_artist or overwrite:
            # Convert ItemMapping to Artist if needed
            artist_to_add = (
                self.mass.music.artists.artist_from_item_mapping(artist)
                if isinstance(artist, ItemMapping)
                else artist
            )
            db_artist = await self.mass.music.artists.add_item_to_library(
                artist_to_add, overwrite_existing=overwrite
            )
        # write (or update) record in album_artists table
        await self.mass.music.database.insert_or_replace(
            DB_TABLE_ALBUM_ARTISTS,
            {
                "album_id": db_id,
                "artist_id": int(db_artist.item_id),
            },
        )
        return ItemMapping.from_item(db_artist)

    async def _set_album_track(self, db_id: int, db_track_id: int, track: Track) -> None:
        """Store Album Track info (disc/track number of a track on an album)."""
        # write (or update) record in album_tracks table
        await self.mass.music.database.insert_or_replace(
            DB_TABLE_ALBUM_TRACKS,
            {
                "album_id": db_id,
                "track_id": db_track_id,
                "track_number": track.track_number,
                "disc_number": track.disc_number,
            },
        )

    async def match_provider(
        self, db_album: Album, provider: MusicProvider, strict: bool = True
    ) -> list[ProviderMapping]:
        """
        Try to find match on (streaming) provider for the provided (database) album.

        This is used to link objects of different providers/qualities together.

        :param db_album: The library album to match.
        :param provider: The provider to search on.
        :param strict: Use strict comparison.
        :return: Provider mappings of all matching provider albums (may be empty).
        """
        self.logger.debug("Trying to match album %s on provider %s", db_album.name, provider.name)
        matches: list[ProviderMapping] = []
        if db_album.artists:
            search_str = f"{db_album.artists[0].name} - {db_album.name}"
        else:
            # no artist known: fall back to searching on the album name only
            search_str = db_album.name
        search_result = await self.search(search_str, provider.instance_id)
        for search_result_item in search_result:
            if not search_result_item.available:
                continue
            if not compare_media_item(db_album, search_result_item, strict=strict):
                continue
            # we must fetch the full album version, search results can be simplified objects
            prov_album = await self.get_provider_item(
                search_result_item.item_id,
                search_result_item.provider,
                fallback=search_result_item,
            )
            if compare_album(db_album, prov_album, strict=strict):
                # 100% match
                matches.extend(prov_album.provider_mappings)
        if not matches:
            self.logger.debug(
                "Could not find match for Album %s on provider %s",
                db_album.name,
                provider.name,
            )
        return matches

    async def match_providers(self, db_album: Album) -> None:
        """Try to find match on all (streaming) providers for the provided (database) album.

        This is used to link objects of different providers/qualities together.
        """
        if db_album.provider != "library":
            return  # Matching only supported for database items
        if not db_album.artists:
            return  # guard

        # try to find match on all providers; once a domain matched,
        # other instances of that same domain are skipped
        processed_domains: set[str] = set()
        for provider in self.mass.music.providers:
            if provider.domain in processed_domains:
                continue
            if ProviderFeature.SEARCH not in provider.supported_features:
                continue
            if not provider.library_supported(MediaType.ALBUM):
                continue
            if not provider.is_streaming_provider:
                # matching on unique providers is pointless as they push (all) their content to MA
                continue
            if match := await self.match_provider(db_album, provider):
                # 100% match, we update the db with the additional provider mapping(s)
                await self.add_provider_mappings(db_album.item_id, match)
                processed_domains.add(provider.domain)

    def album_from_item_mapping(self, item: ItemMapping) -> Album:
        """Create an Album object from an ItemMapping object.

        The provider domain/instance of the mapping are only filled in when the
        provider is currently available; otherwise they are None.
        """
        domain, instance_id = None, None
        if prov := self.mass.get_provider(item.provider):
            domain = prov.domain
            instance_id = prov.instance_id
        return Album.from_dict(
            {
                **item.to_dict(),
                "provider_mappings": [
                    {
                        "item_id": item.item_id,
                        "provider_domain": domain,
                        "provider_instance": instance_id,
                        "available": item.available,
                    }
                ],
            }
        )
622