music-assistant-server

13.7 KB · PY
podcasts.py
13.7 KB · 319 lines · python
1"""Manage MediaItems of type Podcast."""
2
3from __future__ import annotations
4
5from collections.abc import AsyncGenerator
6from typing import TYPE_CHECKING, Any
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.errors import MediaNotFoundError, ProviderUnavailableError
10from music_assistant_models.media_items import Podcast, PodcastEpisode, ProviderMapping, UniqueList
11
12from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PODCASTS
13from music_assistant.controllers.media.base import MediaControllerBase
14from music_assistant.helpers.compare import (
15    compare_media_item,
16    compare_podcast,
17    create_safe_string,
18    loose_compare_strings,
19)
20from music_assistant.helpers.database import UNSET
21from music_assistant.helpers.json import serialize_to_json
22from music_assistant.models.music_provider import MusicProvider
23
24if TYPE_CHECKING:
25    from music_assistant_models.media_items import Track
26
27    from music_assistant import MusicAssistant
28
29
class PodcastsController(MediaControllerBase[Podcast]):
    """Controller managing MediaItems of type Podcast."""

    db_table = DB_TABLE_PODCASTS
    media_type = MediaType.PODCAST
    item_cls = Podcast

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the podcast-specific API commands."""
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/podcast_episodes", self.episodes)
        self.mass.register_api_command(f"music/{api_base}/podcast_episode", self.episode)
        self.mass.register_api_command(f"music/{api_base}/podcast_versions", self.versions)

    async def library_items(
        self,
        favorite: bool | None = None,
        search: str | None = None,
        limit: int = 500,
        offset: int = 0,
        order_by: str = "sort_name",
        provider: str | list[str] | None = None,
    ) -> list[Podcast]:
        """Get in-database podcasts.

        When a search on the first page (no offset) yields fewer than 25 matches,
        podcasts whose publisher matches the search query are appended to the
        result (without duplicates and without exceeding `limit`).

        :param favorite: Filter by favorite status.
        :param search: Filter by search query.
        :param limit: Maximum number of items to return.
        :param offset: Number of items to skip.
        :param order_by: Order by field (e.g. 'sort_name', 'timestamp_added').
        :param provider: Filter by provider instance ID (single string or list).
        """
        result = await self.get_library_items_by_query(
            favorite=favorite,
            search=search,
            limit=limit,
            offset=offset,
            order_by=order_by,
            provider_filter=self._ensure_provider_filter(provider),
        )
        if search and len(result) < 25 and len(result) < limit and not offset:
            # append publisher items to result
            extra_query_parts: list[str] = [
                "WHERE podcasts.publisher LIKE :search",
            ]
            extra_query_params: dict[str, Any] = {
                "search": f"%{search}%",
            }
            publisher_matches = await self.get_library_items_by_query(
                favorite=favorite,
                search=None,
                limit=limit,
                order_by=order_by,
                provider_filter=self._ensure_provider_filter(provider),
                extra_query_parts=extra_query_parts,
                extra_query_params=extra_query_params,
            )
            combined = list(result)
            # a podcast can match on both its name and its publisher;
            # make sure it is only listed once
            seen_ids = {item.item_id for item in combined}
            for item in publisher_matches:
                if len(combined) >= limit:
                    # honour the documented maximum number of items
                    break
                if item.item_id in seen_ids:
                    continue
                seen_ids.add(item.item_id)
                combined.append(item)
            return combined
        return result

    async def episodes(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Return podcast episodes for the given podcast id.

        :param item_id: The podcast id (library id when provider is 'library').
        :param provider_instance_id_or_domain: 'library' or a provider instance id/domain.
        :raises MediaNotFoundError: If a 'library' podcast id is not found in the library.
        """
        # always check if we have a library item for this podcast
        if provider_instance_id_or_domain == "library":
            library_podcast = await self.get_library_item(item_id)
            if not library_podcast:
                raise MediaNotFoundError(f"Podcast {item_id} not found in library")
            # resolve the library item to one of its provider mappings
            provider_instance_id_or_domain, item_id = self._select_provider_id(library_podcast)
        # podcast episodes are not stored in the db/library
        # so we always need to fetch them from the provider
        async for episode in self._get_provider_podcast_episodes(
            item_id, provider_instance_id_or_domain
        ):
            yield episode

    async def episode(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> PodcastEpisode:
        """Return a single podcast episode by its provider episode id.

        :param item_id: The episode id on the provider.
        :param provider_instance_id_or_domain: The provider instance id or domain.
        :raises ProviderUnavailableError: If no music provider matches the given id.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            raise ProviderUnavailableError("Provider not found")
        return await prov.get_podcast_episode(item_id)

    async def versions(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> UniqueList[Podcast]:
        """Return all versions of a podcast we can find on all providers.

        Searches every unique provider that supports podcasts in its library for
        items with a loosely matching name, excluding the given item itself.
        """
        podcast = await self.get_provider_item(item_id, provider_instance_id_or_domain)
        search_query = podcast.name
        result: UniqueList[Podcast] = UniqueList()
        for provider_id in self.mass.music.get_unique_providers():
            provider = self.mass.get_provider(provider_id)
            if not isinstance(provider, MusicProvider):
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            result.extend(
                prov_item
                for prov_item in await self.search(search_query, provider_id)
                if loose_compare_strings(podcast.name, prov_item.name)
                # make sure that the 'base' version is NOT included
                and not podcast.provider_mappings.intersection(prov_item.provider_mappings)
            )
        return result

    async def _add_library_item(self, item: Podcast, overwrite_existing: bool = False) -> int:
        """Add a new record to the database and return its db id."""
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "version": item.version,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "publisher": item.publisher,
                "total_episodes": item.total_episodes or 0,
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Podcast, overwrite: bool = False
    ) -> None:
        """Update existing record in the database.

        :param item_id: The library (db) id of the podcast.
        :param update: Podcast object holding the new values.
        :param overwrite: Replace stored values instead of only filling in blanks.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        # honour 'overwrite' for the publisher just like for the other fields
        publisher = update.publisher if overwrite else cur_item.publisher or update.publisher
        # the episode count of a podcast changes as new episodes get published,
        # so a freshly reported count wins over the (possibly stale) stored one
        total_episodes = update.total_episodes or cur_item.total_episodes or 0
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                "name": name,
                "sort_name": sort_name,
                "version": update.version if overwrite else cur_item.version or update.version,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "publisher": publisher,
                "total_episodes": total_episodes,
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    async def _get_provider_podcast_episodes(
        self, item_id: str, provider_instance_id_or_domain: str
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Return podcast episodes for the given provider podcast id.

        Yields nothing when no music provider matches the given id.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            return

        async def set_resume_position(episode: PodcastEpisode) -> None:
            if episode.fully_played is not None or episode.resume_position_ms:
                # provider supports resume info, we can skip
                return
            # for providers that do not natively support providing resume info,
            # we fallback to the playlog db table
            resume_info_db_row = await self.mass.music.database.get_row(
                DB_TABLE_PLAYLOG,
                {
                    "item_id": episode.item_id,
                    "provider": prov.instance_id,
                    "media_type": MediaType.PODCAST_EPISODE,
                },
            )
            if resume_info_db_row is None:
                return
            if resume_info_db_row["seconds_played"]:
                episode.resume_position_ms = int(resume_info_db_row["seconds_played"] * 1000)
            if resume_info_db_row["fully_played"] is not None:
                # sqlite hands back booleans as integers (0/1); normalize to bool
                episode.fully_played = bool(resume_info_db_row["fully_played"])

        # grab the episodes from the provider
        # note that we do not cache any of this because its
        # always a rather small list and we want fresh resume info
        async for item in prov.get_podcast_episodes(item_id):
            await set_resume_position(item)
            yield item

    async def radio_mode_base_tracks(
        self,
        item: Podcast,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        Not supported for podcasts; always raises.

        :param item: The Podcast to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use.
        :raises NotImplementedError: Always.
        """
        msg = "Dynamic tracks not supported for Podcast MediaItem"
        raise NotImplementedError(msg)

    async def match_provider(
        self, db_podcast: Podcast, provider: MusicProvider, strict: bool = True
    ) -> list[ProviderMapping]:
        """
        Try to find match on (streaming) provider for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.

        :param db_podcast: The (library) podcast to find matches for.
        :param provider: The provider to search on.
        :param strict: Passed through to the compare helpers.
        :return: Provider mappings of all matching provider items (may be empty).
        """
        self.logger.debug(
            "Trying to match podcast %s on provider %s",
            db_podcast.name,
            provider.name,
        )
        matches: list[ProviderMapping] = []
        search_str = db_podcast.name
        search_result = await self.search(search_str, provider.instance_id)
        for search_result_item in search_result:
            if not search_result_item.available:
                continue
            if not compare_media_item(db_podcast, search_result_item, strict=strict):
                continue
            # we must fetch the full podcast version, search results can be simplified objects
            prov_podcast = await self.get_provider_item(
                search_result_item.item_id,
                search_result_item.provider,
                fallback=search_result_item,
            )
            if compare_podcast(db_podcast, prov_podcast, strict=strict):
                # 100% match
                matches.extend(prov_podcast.provider_mappings)
        if not matches:
            self.logger.debug(
                "Could not find match for Podcast %s on provider %s",
                db_podcast.name,
                provider.name,
            )
        return matches

    async def match_providers(self, db_podcast: Podcast) -> None:
        """Try to find match on all (streaming) providers for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.
        Only library items are considered; any match found is stored as an
        additional provider mapping.
        """
        if db_podcast.provider != "library":
            return  # Matching only supported for database items

        # try to find match on all providers
        cur_provider_domains = {x.provider_domain for x in db_podcast.provider_mappings}
        for provider in self.mass.music.providers:
            if provider.domain in cur_provider_domains:
                continue
            if ProviderFeature.SEARCH not in provider.supported_features:
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            if not provider.is_streaming_provider:
                # matching on unique providers is pointless as they push (all) their content to MA
                continue
            if match := await self.match_provider(db_podcast, provider):
                # 100% match, we update the db with the additional provider mapping(s)
                await self.add_provider_mappings(db_podcast.item_id, match)
                cur_provider_domains.add(provider.domain)