music-assistant-server

26.8 KBPY
albums.py
26.8 KB616 lines • python
1"""Manage MediaItems of type Album."""
2
3from __future__ import annotations
4
5import contextlib
6from collections.abc import Iterable
7from typing import TYPE_CHECKING, Any, cast
8
9from music_assistant_models.enums import AlbumType, MediaType, ProviderFeature
10from music_assistant_models.errors import InvalidDataError, MediaNotFoundError, MusicAssistantError
11from music_assistant_models.media_items import (
12    Album,
13    Artist,
14    ItemMapping,
15    MediaItemImage,
16    ProviderMapping,
17    Track,
18    UniqueList,
19)
20
21from music_assistant.constants import DB_TABLE_ALBUM_ARTISTS, DB_TABLE_ALBUM_TRACKS, DB_TABLE_ALBUMS
22from music_assistant.controllers.media.base import MediaControllerBase
23from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user
24from music_assistant.helpers.compare import (
25    compare_album,
26    compare_artists,
27    compare_media_item,
28    create_safe_string,
29    loose_compare_strings,
30)
31from music_assistant.helpers.database import UNSET
32from music_assistant.helpers.json import serialize_to_json
33from music_assistant.models.music_provider import MusicProvider
34
35if TYPE_CHECKING:
36    from music_assistant import MusicAssistant
37
38
class AlbumsController(MediaControllerBase[Album]):
    """Controller managing MediaItems of type Album."""

    db_table = DB_TABLE_ALBUMS
    media_type = MediaType.ALBUM
    item_cls = Album

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class."""
        super().__init__(mass)
        # Library SELECT: every album row is returned together with its provider
        # mappings and its (library) album artists, each aggregated into a JSON
        # array by a correlated subquery.
        self.base_query = """
        SELECT
            albums.*,
            (SELECT JSON_GROUP_ARRAY(
                json_object(
                'item_id', album_pm.provider_item_id,
                    'provider_domain', album_pm.provider_domain,
                        'provider_instance', album_pm.provider_instance,
                        'available', album_pm.available,
                        'audio_format', json(album_pm.audio_format),
                        'url', album_pm.url,
                        'details', album_pm.details,
                        'in_library', album_pm.in_library,
                        'is_unique', album_pm.is_unique
                )) FROM provider_mappings album_pm WHERE album_pm.item_id = albums.item_id AND album_pm.media_type = 'album') AS provider_mappings,
            (SELECT JSON_GROUP_ARRAY(
                json_object(
                'item_id', artists.item_id,
                'provider', 'library',
                    'name', artists.name,
                    'sort_name', artists.sort_name,
                    'media_type', 'artist'
                )) FROM artists JOIN album_artists on album_artists.album_id = albums.item_id  WHERE artists.item_id = album_artists.artist_id) AS artists
            FROM albums"""  # noqa: E501
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/album_tracks", self.tracks)
        self.mass.register_api_command(f"music/{api_base}/album_versions", self.versions)

    async def get(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        recursive: bool = True,
    ) -> Album:
        """Return (full) details for a single media item.

        :param item_id: Item id of the album (library or provider id).
        :param provider_instance_id_or_domain: Provider instance id or domain of the item.
        :param recursive: When True, resolve ItemMapping album artists into full
            Artist objects; artists that can not be found are left out of the result.
        """
        album = await super().get(
            item_id,
            provider_instance_id_or_domain,
        )
        if not recursive:
            return album

        # append artist details to full album item (resolve ItemMappings)
        album_artists: UniqueList[Artist | ItemMapping] = UniqueList()
        for artist in album.artists:
            if not isinstance(artist, ItemMapping):
                album_artists.append(artist)
                continue
            with contextlib.suppress(MediaNotFoundError):
                album_artists.append(
                    await self.mass.music.artists.get(
                        artist.item_id,
                        artist.provider,
                    )
                )
        album.artists = album_artists
        return album

    @staticmethod
    def _artist_search_join(artist_table_joined: bool) -> str:
        """Return the JOIN (fragment) used to filter albums on ``:search_artist``.

        If the artists table is already joined, only the extra join condition is
        returned so it extends that existing JOIN (this assumes the join parts are
        concatenated in order by ``get_library_items_by_query`` - verify there).
        """
        if artist_table_joined:
            return "AND artists.search_name LIKE :search_artist"
        return (
            "JOIN album_artists ON album_artists.album_id = albums.item_id "
            "JOIN artists ON artists.item_id = album_artists.artist_id "
            "AND artists.search_name LIKE :search_artist"
        )

    async def library_items(
        self,
        favorite: bool | None = None,
        search: str | None = None,
        limit: int = 500,
        offset: int = 0,
        order_by: str = "sort_name",
        provider: str | list[str] | None = None,
        album_types: list[AlbumType] | None = None,
    ) -> list[Album]:
        """Get in-database albums.

        A search of the form "artist - title" filters on both the album title and
        the artist name. A plain search on the first page (offset 0) that returns
        fewer than 25 albums is supplemented with albums whose artist name matches
        the search, up to ``limit`` items in total.

        :param favorite: Filter by favorite status.
        :param search: Filter by search query.
        :param limit: Maximum number of items to return.
        :param offset: Number of items to skip.
        :param order_by: Order by field (e.g. 'sort_name', 'timestamp_added').
        :param provider: Filter by provider instance ID (single string or list).
        :param album_types: Filter by album types.
        """
        extra_query_params: dict[str, Any] = {}
        extra_query_parts: list[str] = []
        extra_join_parts: list[str] = []
        artist_table_joined = False
        # optional album type filter
        if album_types:
            extra_query_parts.append("albums.album_type IN :album_types")
            extra_query_params["album_types"] = [x.value for x in album_types]
        if order_by and "artist_name" in order_by:
            # join artist table to allow sorting on artist name
            extra_join_parts.append(
                "JOIN album_artists ON album_artists.album_id = albums.item_id "
                "JOIN artists ON artists.item_id = album_artists.artist_id "
            )
            artist_table_joined = True
        if search and " - " in search:
            # handle combined artist + title search
            artist_str, title_str = search.split(" - ", 1)
            search = None
            title_str = create_safe_string(title_str, True, True)
            artist_str = create_safe_string(artist_str, True, True)
            extra_query_parts.append("albums.search_name LIKE :search_title")
            extra_query_params["search_title"] = f"%{title_str}%"
            # use join with artists table to filter on artist name
            extra_join_parts.append(self._artist_search_join(artist_table_joined))
            artist_table_joined = True
            extra_query_params["search_artist"] = f"%{artist_str}%"
        result = await self.get_library_items_by_query(
            favorite=favorite,
            search=search,
            limit=limit,
            offset=offset,
            order_by=order_by,
            provider_filter=self._ensure_provider_filter(provider),
            extra_query_parts=extra_query_parts,
            extra_query_params=extra_query_params,
            extra_join_parts=extra_join_parts,
        )

        # Calculate how many more items we need to reach the original limit
        remaining_limit = limit - len(result)

        if search and len(result) < 25 and not offset and remaining_limit > 0:
            # append albums whose artist name matches the search
            search = create_safe_string(search, True, True)
            extra_join_parts.append(self._artist_search_join(artist_table_joined))
            extra_query_params["search_artist"] = f"%{search}%"
            existing_uris = {item.uri for item in result}

            for album in await self.get_library_items_by_query(
                favorite=favorite,
                search=None,
                limit=remaining_limit,
                order_by=order_by,
                provider_filter=self._ensure_provider_filter(provider),
                extra_query_parts=extra_query_parts,
                extra_query_params=extra_query_params,
                extra_join_parts=extra_join_parts,
            ):
                # prevent duplicates: the artist may also be in the title, or an
                # album with multiple matching artists may be returned more than once
                if album.uri in existing_uris:
                    continue
                existing_uris.add(album.uri)
                result.append(album)
                # Stop if we've reached the original limit
                if len(result) >= limit:
                    break
        return result

    async def library_count(
        self, favorite_only: bool = False, album_types: list[AlbumType] | None = None
    ) -> int:
        """Return the total number of items in the library.

        :param favorite_only: Only count favorite albums.
        :param album_types: Only count albums of these types.
        """
        sql_query = f"SELECT item_id FROM {self.db_table}"
        query_parts: list[str] = []
        query_params: dict[str, Any] = {}
        if favorite_only:
            query_parts.append("favorite = 1")
        if album_types:
            query_parts.append("albums.album_type IN :album_types")
            query_params["album_types"] = [x.value for x in album_types]
        if query_parts:
            sql_query += f" WHERE {' AND '.join(query_parts)}"
        return await self.mass.music.database.get_count_from_query(sql_query, query_params)

    async def remove_item_from_library(self, item_id: str | int, recursive: bool = True) -> None:
        """Delete item from the library(database).

        :param item_id: Library (database) id of the album.
        :param recursive: Also remove the album's library tracks.
        :raises MusicAssistantError: If recursive is False and the album still has tracks.
        """
        db_id = int(item_id)  # ensure integer
        # recursively also remove album tracks
        for db_track in await self.get_library_album_tracks(db_id):
            if not recursive:
                raise MusicAssistantError("Album still has tracks linked")
            with contextlib.suppress(MediaNotFoundError):
                await self.mass.music.tracks.remove_item_from_library(db_track.item_id)
        # delete entry(s) from albumtracks table
        await self.mass.music.database.delete(DB_TABLE_ALBUM_TRACKS, {"album_id": db_id})
        # delete entry(s) from album artists table
        await self.mass.music.database.delete(DB_TABLE_ALBUM_ARTISTS, {"album_id": db_id})
        # delete the album itself from db
        # this will raise if the item still has references and recursive is false
        await super().remove_item_from_library(item_id)

    async def tracks(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
        in_library_only: bool = False,
    ) -> list[Track]:
        """Return album tracks for the given provider album id.

        If the album is not in the library, the tracks are fetched from the provider
        (falling back to the album image for tracks without one). Otherwise the
        library tracks are combined with the tracks of every provider mapping the
        current user may access, de-duplicated on provider item id, disc/track
        position and lowercased name/version.

        :param item_id: Item id of the album.
        :param provider_instance_id_or_domain: Provider instance id or domain of the album.
        :param in_library_only: Only return library tracks (only applies when the
            album itself is in the library).
        :return: The tracks; sorted on (disc_number, track_number) when the album
            is in the library.
        """
        # always check if we have a library item for this album
        library_album = await self.get_library_item_by_prov_id(
            item_id, provider_instance_id_or_domain
        )
        if not library_album:
            album_tracks = await self._get_provider_album_tracks(
                item_id, provider_instance_id_or_domain
            )
            if album_tracks and not album_tracks[0].image:
                # set album image from provider album if not present on tracks
                prov_album = await self.get_provider_item(item_id, provider_instance_id_or_domain)
                if prov_album.image:
                    for track in album_tracks:
                        if not track.image:
                            track.metadata.add_image(prov_album.image)
            return album_tracks

        db_items = await self.get_library_album_tracks(library_album.item_id)
        if in_library_only:
            # return in-library items only
            return sorted(db_items, key=lambda x: (x.disc_number, x.track_number))

        result: list[Track] = list(db_items)
        # return all (unique) items from all providers
        # because we are returning the items from all providers combined,
        # we need to make sure that we don't return duplicates
        unique_ids: set[str] = {f"{x.disc_number}.{x.track_number}" for x in db_items}
        unique_ids.update(f"{x.name.lower()}.{x.version.lower()}" for x in db_items)
        for db_item in db_items:
            unique_ids.update(x.item_id for x in db_item.provider_mappings)
        user = get_current_user()
        user_provider_filter = user.provider_filter if user and user.provider_filter else None
        for provider_mapping in library_album.provider_mappings:
            if (
                user_provider_filter
                and provider_mapping.provider_instance not in user_provider_filter
            ):
                continue
            provider_tracks = await self._get_provider_album_tracks(
                provider_mapping.item_id, provider_mapping.provider_instance
            )
            for provider_track in provider_tracks:
                # In some cases (looking at you YTM) the disc/track number is not obtained from
                # library_tracks. Ensure to update the disc/track number when interacting with
                # album tracks
                db_track = next(
                    (
                        x
                        for x in db_items
                        if x.sort_name == provider_track.sort_name
                        and x.version == provider_track.version
                    ),
                    None,
                )
                if db_track and db_track.track_number == 0 and provider_track.track_number != 0:
                    await self._set_album_track(
                        db_id=int(library_album.item_id),
                        db_track_id=int(db_track.item_id),
                        track=provider_track,
                    )
                    # mirror the stored position on the returned item, so the final
                    # sort is already correct on this call (not only the next one)
                    db_track.disc_number = provider_track.disc_number
                    db_track.track_number = provider_track.track_number
                    unique_ids.add(f"{db_track.disc_number}.{db_track.track_number}")
                if provider_track.item_id in unique_ids:
                    continue
                position_id = f"{provider_track.disc_number}.{provider_track.track_number}"
                if position_id in unique_ids:
                    continue
                name_id = f"{provider_track.name.lower()}.{provider_track.version.lower()}"
                if name_id in unique_ids:
                    continue
                # register every key of this track, so the same track returned by
                # another provider mapping of this album is skipped as well
                unique_ids.add(provider_track.item_id)
                unique_ids.add(name_id)
                if provider_track.track_number:
                    # unnumbered tracks all share position "<disc>.0"; registering
                    # that would hide every following unnumbered track
                    unique_ids.add(position_id)
                provider_track.album = library_album
                # always prefer album image
                album_images = [library_album.image] if library_album.image else []
                track_images: list[MediaItemImage] = provider_track.metadata.images or []
                provider_track.metadata.images = UniqueList(album_images + track_images)
                result.append(provider_track)
        # NOTE: we need to return the results sorted on disc/track here
        # to ensure the correct order at playback
        return sorted(result, key=lambda x: (x.disc_number, x.track_number))

    async def versions(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> UniqueList[Album]:
        """Return all versions of an album we can find on all providers.

        Searches every unique music provider that supports albums in its library
        for "<first artist> - <album name>" (or just the album name when the album
        has no artists). A result counts as a version when its name loosely matches
        and at least one artist matches; the album itself is excluded.
        """
        album = await self.get_provider_item(item_id, provider_instance_id_or_domain)
        search_query = f"{album.artists[0].name} - {album.name}" if album.artists else album.name
        result: UniqueList[Album] = UniqueList()
        for provider_id in self.mass.music.get_unique_providers():
            provider = self.mass.get_provider(provider_id)
            if not provider or not isinstance(provider, MusicProvider):
                continue
            if not provider.library_supported(MediaType.ALBUM):
                continue
            result.extend(
                prov_item
                for prov_item in await self.search(search_query, provider_id)
                if loose_compare_strings(album.name, prov_item.name)
                and compare_artists(prov_item.artists, album.artists, any_match=True)
                # make sure that the 'base' version is NOT included
                and not album.provider_mappings.intersection(prov_item.provider_mappings)
            )
        return result

    async def get_library_album_tracks(
        self,
        item_id: str | int,
    ) -> list[Track]:
        """Return in-database album tracks for the given database album.

        :param item_id: Library (database) id of the album.
        """
        db_id = int(item_id)  # ensure integer
        return await self.mass.music.tracks.get_library_items_by_query(
            extra_query_parts=["WHERE album_tracks.album_id = :album_id"],
            extra_query_params={"album_id": db_id},
        )

    async def add_item_mapping_as_album_to_library(self, item: ItemMapping) -> Album:
        """
        Add an ItemMapping as an Album to the library.

        This is only used in special occasions as is basically adds an album
        to the db without a lot of mandatory data, such as artists.
        """
        album = self.album_from_item_mapping(item)
        return await self.add_item_to_library(album)

    async def _add_library_item(self, item: Album, overwrite_existing: bool = False) -> int:
        """Add a new record to the database.

        Inserts the album row, then stores its provider mappings and album artists.

        :return: The new library (database) id.
        :raises InvalidDataError: If item is not an Album.
        """
        if not isinstance(item, Album):  # TODO: Remove this once the codebase is fully typed
            msg = "Not a valid Album object (ItemMapping can not be added to db)"  # type: ignore[unreachable]
            raise InvalidDataError(msg)
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "version": item.version,
                "favorite": item.favorite,
                "album_type": item.album_type,
                "year": item.year,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        # set album artist(s)
        await self._set_album_artists(db_id, item.artists)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Album, overwrite: bool = False
    ) -> None:
        """Update existing record in the database.

        Without overwrite, existing values win over the update (the update only fills
        empty fields), while metadata, external ids, provider mappings and artists
        are merged. With overwrite, the update replaces the stored values.
        The album type is only changed when the update carries a known type.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        if getattr(update, "album_type", AlbumType.UNKNOWN) != AlbumType.UNKNOWN:
            album_type = update.album_type
        else:
            album_type = cur_item.album_type
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                "name": name,
                "sort_name": sort_name,
                "version": update.version if overwrite else cur_item.version or update.version,
                "year": update.year if overwrite else cur_item.year or update.year,
                "album_type": album_type.value,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        # set album artist(s)
        artists = update.artists if overwrite else cur_item.artists + update.artists
        await self._set_album_artists(db_id, artists, overwrite=overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    async def _get_provider_album_tracks(
        self, item_id: str, provider_instance_id_or_domain: str
    ) -> list[Track]:
        """Return album tracks for the given provider album id.

        Returns an empty list when the provider is not available.
        """
        if prov := self.mass.get_provider(provider_instance_id_or_domain):
            prov = cast("MusicProvider", prov)
            return await prov.get_album_tracks(item_id)
        return []

    async def radio_mode_base_tracks(
        self,
        item: Album,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        :param item: The Album to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use.
            NOTE: currently not used by this implementation.
        """
        return await self.tracks(item.item_id, item.provider, in_library_only=False)

    async def _set_album_artists(
        self,
        db_id: int,
        artists: Iterable[Artist | ItemMapping],
        overwrite: bool = False,
    ) -> None:
        """Store Album Artists.

        :param db_id: Library (database) id of the album.
        :param artists: The album artists to link.
        :param overwrite: Remove all existing artist links of the album first.
        """
        if overwrite:
            # on overwrite, clear the album_artists table first
            await self.mass.music.database.delete(
                DB_TABLE_ALBUM_ARTISTS,
                {
                    "album_id": db_id,
                },
            )
        for artist in artists:
            await self._set_album_artist(db_id, artist=artist, overwrite=overwrite)

    async def _set_album_artist(
        self, db_id: int, artist: Artist | ItemMapping, overwrite: bool = False
    ) -> ItemMapping:
        """Store Album Artist info.

        Resolves the artist to a library artist (adding it to the library when it is
        not there yet, or always when overwrite is set) and links it to the album.

        :return: ItemMapping of the library artist.
        """
        db_artist: Artist | ItemMapping | None = None
        if artist.provider == "library":
            db_artist = artist
        elif existing := await self.mass.music.artists.get_library_item_by_prov_id(
            artist.item_id, artist.provider
        ):
            db_artist = existing

        if not db_artist or overwrite:
            # Convert ItemMapping to Artist if needed
            artist_to_add = (
                self.mass.music.artists.artist_from_item_mapping(artist)
                if isinstance(artist, ItemMapping)
                else artist
            )
            db_artist = await self.mass.music.artists.add_item_to_library(
                artist_to_add, overwrite_existing=overwrite
            )
        # write (or update) record in album_artists table
        await self.mass.music.database.insert_or_replace(
            DB_TABLE_ALBUM_ARTISTS,
            {
                "album_id": db_id,
                "artist_id": int(db_artist.item_id),
            },
        )
        return ItemMapping.from_item(db_artist)

    async def _set_album_track(self, db_id: int, db_track_id: int, track: Track) -> None:
        """Store Album Track info (disc/track number of a track on an album)."""
        # write (or update) record in album_tracks table
        await self.mass.music.database.insert_or_replace(
            DB_TABLE_ALBUM_TRACKS,
            {
                "album_id": db_id,
                "track_id": db_track_id,
                "track_number": track.track_number,
                "disc_number": track.disc_number,
            },
        )

    async def match_provider(
        self, db_album: Album, provider: MusicProvider, strict: bool = True
    ) -> list[ProviderMapping]:
        """
        Try to find match on (streaming) provider for the provided (database) album.

        This is used to link objects of different providers/qualities together.

        :param db_album: The library album to match.
        :param provider: The provider to search on.
        :param strict: Passed to the compare helpers.
        :return: Provider mappings of all matching provider albums (may be empty).
        """
        self.logger.debug("Trying to match album %s on provider %s", db_album.name, provider.name)
        matches: list[ProviderMapping] = []
        # an album without artists can only be searched on its name
        if db_album.artists:
            search_str = f"{db_album.artists[0].name} - {db_album.name}"
        else:
            search_str = db_album.name
        search_result = await self.search(search_str, provider.instance_id)
        for search_result_item in search_result:
            if not search_result_item.available:
                continue
            if not compare_media_item(db_album, search_result_item, strict=strict):
                continue
            # we must fetch the full album version, search results can be simplified objects
            prov_album = await self.get_provider_item(
                search_result_item.item_id,
                search_result_item.provider,
                fallback=search_result_item,
            )
            if compare_album(db_album, prov_album, strict=strict):
                # 100% match
                matches.extend(prov_album.provider_mappings)
        if not matches:
            self.logger.debug(
                "Could not find match for Album %s on provider %s",
                db_album.name,
                provider.name,
            )
        return matches

    async def match_providers(self, db_album: Album) -> None:
        """Try to find match on all (streaming) providers for the provided (database) album.

        This is used to link objects of different providers/qualities together.
        Only library albums with at least one artist are matched, and each provider
        domain is matched at most once.
        """
        if db_album.provider != "library":
            return  # Matching only supported for database items
        if not db_album.artists:
            return  # guard

        # try to find match on all providers
        processed_domains: set[str] = set()
        for provider in self.mass.music.providers:
            if provider.domain in processed_domains:
                continue
            if ProviderFeature.SEARCH not in provider.supported_features:
                continue
            if not provider.library_supported(MediaType.ALBUM):
                continue
            if not provider.is_streaming_provider:
                # matching on unique providers is pointless as they push (all) their content to MA
                continue
            if match := await self.match_provider(db_album, provider):
                # 100% match, we update the db with the additional provider mapping(s)
                await self.add_provider_mappings(db_album.item_id, match)
                processed_domains.add(provider.domain)

    def album_from_item_mapping(self, item: ItemMapping) -> Album:
        """Create an Album object from an ItemMapping object.

        The single provider mapping gets the domain/instance of the provider the
        mapping refers to, or None for both when that provider is not loaded.
        """
        domain, instance_id = None, None
        if prov := self.mass.get_provider(item.provider):
            domain = prov.domain
            instance_id = prov.instance_id
        return Album.from_dict(
            {
                **item.to_dict(),
                "provider_mappings": [
                    {
                        "item_id": item.item_id,
                        "provider_domain": domain,
                        "provider_instance": instance_id,
                        "available": item.available,
                    }
                ],
            }
        )
616