music-assistant-server

13.9 KBPY
podcasts.py
13.9 KB326 lines • python
1"""Manage MediaItems of type Podcast."""
2
3from __future__ import annotations
4
5from collections.abc import AsyncGenerator
6from typing import TYPE_CHECKING, Any
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.errors import MediaNotFoundError, ProviderUnavailableError
10from music_assistant_models.media_items import Podcast, PodcastEpisode, ProviderMapping, UniqueList
11
12from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PODCASTS
13from music_assistant.controllers.media.base import MediaControllerBase
14from music_assistant.helpers.compare import (
15    compare_media_item,
16    compare_podcast,
17    create_safe_string,
18    loose_compare_strings,
19)
20from music_assistant.helpers.database import UNSET
21from music_assistant.helpers.json import serialize_to_json
22from music_assistant.models.music_provider import MusicProvider
23
24if TYPE_CHECKING:
25    from music_assistant_models.media_items import Track
26
27    from music_assistant import MusicAssistant
28
29
class PodcastsController(MediaControllerBase[Podcast]):
    """Controller managing MediaItems of type Podcast.

    Podcasts themselves are stored in the library database table, podcast
    episodes are not: they are always fetched live from the provider.
    """

    db_table = DB_TABLE_PODCASTS
    media_type = MediaType.PODCAST
    item_cls = Podcast

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the podcast specific api commands."""
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/podcast_episodes", self.episodes)
        self.mass.register_api_command(f"music/{api_base}/podcast_episode", self.episode)
        self.mass.register_api_command(f"music/{api_base}/podcast_versions", self.versions)

    async def library_items(
        self,
        favorite: bool | None = None,
        search: str | None = None,
        limit: int = 500,
        offset: int = 0,
        order_by: str = "sort_name",
        provider: str | list[str] | None = None,
        genre: int | list[int] | None = None,
        **kwargs: Any,
    ) -> list[Podcast]:
        """Get in-database podcasts.

        When a search query yields fewer than 25 results on the first page
        (offset 0), library podcasts whose publisher matches the query are
        appended to the result. Podcasts already present in the first result
        are not appended a second time.

        :param favorite: Filter by favorite status.
        :param search: Filter by search query.
        :param limit: Maximum number of items to return (applied per query).
        :param offset: Number of items to skip.
        :param order_by: Order by field (e.g. 'sort_name', 'timestamp_added').
        :param provider: Filter by provider instance ID (single string or list).
        :param genre: Filter by genre id(s).
        :param kwargs: Accepted for call compatibility; not used by this method.
        """
        result = await self.get_library_items_by_query(
            favorite=favorite,
            search=search,
            genre_ids=genre,
            limit=limit,
            offset=offset,
            order_by=order_by,
            provider_filter=self._ensure_provider_filter(provider),
            in_library_only=True,
        )
        if not search or len(result) >= 25 or offset:
            return result

        # few hits on the first page: also look for matches on the publisher
        extra_query_parts: list[str] = [
            "WHERE podcasts.publisher LIKE :search",
        ]
        extra_query_params: dict[str, Any] = {
            "search": f"%{search}%",
        }
        publisher_matches = await self.get_library_items_by_query(
            favorite=favorite,
            search=None,
            genre_ids=genre,
            limit=limit,
            order_by=order_by,
            provider_filter=self._ensure_provider_filter(provider),
            extra_query_parts=extra_query_parts,
            extra_query_params=extra_query_params,
            in_library_only=True,
        )
        # a podcast can match on both its name and its publisher,
        # make sure it is only returned once
        seen_item_ids = {item.item_id for item in result}
        return result + [item for item in publisher_matches if item.item_id not in seen_item_ids]

    async def episodes(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Yield the podcast episodes for the given podcast.

        :param item_id: Id of the podcast (library id when provider is 'library').
        :param provider_instance_id_or_domain: Provider instance id or domain,
            or 'library' to address a library item.
        :raises MediaNotFoundError: If the library podcast can not be found.
        """
        # a library item is first resolved to one of its provider mappings
        if provider_instance_id_or_domain == "library":
            library_podcast = await self.get_library_item(item_id)
            if not library_podcast:
                raise MediaNotFoundError(f"Podcast {item_id} not found in library")
            provider_instance_id_or_domain, item_id = self._select_provider_id(library_podcast)
        # podcast episodes are not stored in the db/library
        # so we always need to fetch them from the provider
        async for episode in self._get_provider_podcast_episodes(
            item_id, provider_instance_id_or_domain
        ):
            yield episode

    async def episode(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> PodcastEpisode:
        """Return a single podcast episode by the given provider episode id.

        :param item_id: Id of the episode on the provider.
        :param provider_instance_id_or_domain: Provider instance id or domain.
        :raises ProviderUnavailableError: If the provider is not an (available) MusicProvider.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            raise ProviderUnavailableError("Provider not found")
        return await prov.get_podcast_episode(item_id)

    async def versions(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> UniqueList[Podcast]:
        """Return all versions of a podcast we can find on all providers.

        Every unique music provider that supports podcasts in its library is
        searched by podcast name; results with a (loosely) matching name are
        returned, excluding the podcast that was asked for itself.

        :param item_id: Id of the base podcast.
        :param provider_instance_id_or_domain: Provider instance id or domain of the base podcast.
        """
        podcast = await self.get_provider_item(item_id, provider_instance_id_or_domain)
        search_query = podcast.name
        result: UniqueList[Podcast] = UniqueList()
        for provider_id in self.mass.music.get_unique_providers():
            provider = self.mass.get_provider(provider_id)
            if not isinstance(provider, MusicProvider):
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            result.extend(
                prov_item
                for prov_item in await self.search(search_query, provider_id)
                if loose_compare_strings(podcast.name, prov_item.name)
                # make sure that the 'base' version is NOT included
                and not podcast.provider_mappings.intersection(prov_item.provider_mappings)
            )
        return result

    async def _add_library_item(self, item: Podcast, overwrite_existing: bool = False) -> int:
        """Add a new record to the database.

        :param item: The podcast to store.
        :param overwrite_existing: Not used by this implementation.
        :return: The database id of the inserted row.
        """
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "version": item.version,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "publisher": item.publisher,
                "total_episodes": item.total_episodes or 0,
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Podcast, overwrite: bool = False
    ) -> None:
        """Update existing record in the database.

        :param item_id: Database id of the podcast to update.
        :param update: Podcast carrying the new values.
        :param overwrite: When True the values of ``update`` replace the stored
            values, otherwise stored values are kept and only missing ones are
            filled in from ``update``.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        # NOTE(review): relies on metadata.update() returning the merged metadata object
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        # publisher and episode count follow the same overwrite rule as the other fields
        publisher = update.publisher if overwrite else cur_item.publisher or update.publisher
        total_episodes = (
            update.total_episodes if overwrite else cur_item.total_episodes or update.total_episodes
        )
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                "name": name,
                "sort_name": sort_name,
                "version": update.version if overwrite else cur_item.version or update.version,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "publisher": publisher,
                "total_episodes": total_episodes or 0,
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    async def _get_provider_podcast_episodes(
        self, item_id: str, provider_instance_id_or_domain: str
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Yield the podcast episodes for the given provider podcast id.

        Yields nothing when the provider is not an (available) MusicProvider.
        Episodes that carry no resume info from the provider get it filled in
        from the playlog database table, when a row exists there.

        :param item_id: Id of the podcast on the provider.
        :param provider_instance_id_or_domain: Provider instance id or domain.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            return

        async def set_resume_position(episode: PodcastEpisode) -> None:
            """Fill in resume info from the playlog when the provider gave none."""
            if episode.fully_played is not None or episode.resume_position_ms:
                # provider supports resume info, we can skip
                return
            # for providers that do not natively support providing resume info,
            # we fallback to the playlog db table
            resume_info_db_row = await self.mass.music.database.get_row(
                DB_TABLE_PLAYLOG,
                {
                    "item_id": episode.item_id,
                    "provider": prov.instance_id,
                    "media_type": MediaType.PODCAST_EPISODE,
                },
            )
            if resume_info_db_row is None:
                return
            if resume_info_db_row["seconds_played"]:
                # the playlog stores seconds, the episode expects milliseconds
                episode.resume_position_ms = int(resume_info_db_row["seconds_played"] * 1000)
            if resume_info_db_row["fully_played"] is not None:
                episode.fully_played = bool(resume_info_db_row["fully_played"])

        # grab the episodes from the provider
        # note that we do not cache any of this because its
        # always a rather small list and we want fresh resume info
        async for item in prov.get_podcast_episodes(item_id):
            await set_resume_position(item)
            yield item

    async def radio_mode_base_tracks(
        self,
        item: Podcast,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        :param item: The Podcast to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use.
        :raises NotImplementedError: Always, dynamic tracks are not supported for podcasts.
        """
        msg = "Dynamic tracks not supported for Podcast MediaItem"
        raise NotImplementedError(msg)

    async def match_provider(
        self, db_podcast: Podcast, provider: MusicProvider, strict: bool = True
    ) -> list[ProviderMapping]:
        """
        Try to find match on (streaming) provider for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.

        :param db_podcast: The (database) podcast to find a match for.
        :param provider: The provider to search on.
        :param strict: Passed on to the compare helpers.
        :return: Provider mappings of all matching podcasts (empty if none matched).
        """
        self.logger.debug(
            "Trying to match podcast %s on provider %s",
            db_podcast.name,
            provider.name,
        )
        matches: list[ProviderMapping] = []
        search_str = db_podcast.name
        search_result = await self.search(search_str, provider.instance_id)
        for search_result_item in search_result:
            if not search_result_item.available:
                continue
            # cheap pre-check on the search result before fetching the full item
            if not compare_media_item(db_podcast, search_result_item, strict=strict):
                continue
            # we must fetch the full podcast version, search results can be simplified objects
            prov_podcast = await self.get_provider_item(
                search_result_item.item_id,
                search_result_item.provider,
                fallback=search_result_item,
            )
            if compare_podcast(db_podcast, prov_podcast, strict=strict):
                # 100% match
                matches.extend(prov_podcast.provider_mappings)
        if not matches:
            self.logger.debug(
                "Could not find match for Podcast %s on provider %s",
                db_podcast.name,
                provider.name,
            )
        return matches

    async def match_providers(self, db_podcast: Podcast) -> None:
        """Try to find match on all (streaming) providers for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.

        :param db_podcast: The (database) podcast to find matches for; items
            that are not library items are ignored.
        """
        if db_podcast.provider != "library":
            return  # Matching only supported for database items

        # try to find match on all providers
        cur_provider_domains = {x.provider_domain for x in db_podcast.provider_mappings}
        for provider in self.mass.music.providers:
            if provider.domain in cur_provider_domains:
                continue
            if ProviderFeature.SEARCH not in provider.supported_features:
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            if not provider.is_streaming_provider:
                # matching on unique providers is pointless as they push (all) their content to MA
                continue
            if match := await self.match_provider(db_podcast, provider):
                # 100% match, we update the db with the additional provider mapping(s)
                await self.add_provider_mappings(db_podcast.item_id, match)
                # remember the domain so it is not matched a second time
                cur_provider_domains.add(provider.domain)
326