/
/
/
1"""Manage MediaItems of type Podcast."""
2
3from __future__ import annotations
4
5from collections.abc import AsyncGenerator
6from typing import TYPE_CHECKING, Any
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.errors import MediaNotFoundError, ProviderUnavailableError
10from music_assistant_models.media_items import Podcast, PodcastEpisode, ProviderMapping, UniqueList
11
12from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PODCASTS
13from music_assistant.controllers.media.base import MediaControllerBase
14from music_assistant.helpers.compare import (
15 compare_media_item,
16 compare_podcast,
17 create_safe_string,
18 loose_compare_strings,
19)
20from music_assistant.helpers.database import UNSET
21from music_assistant.helpers.json import serialize_to_json
22from music_assistant.models.music_provider import MusicProvider
23
24if TYPE_CHECKING:
25 from music_assistant_models.media_items import Track
26
27 from music_assistant import MusicAssistant
28
29
class PodcastsController(MediaControllerBase[Podcast]):
    """Controller managing MediaItems of type Podcast.

    Podcasts themselves are stored in the library table ``DB_TABLE_PODCASTS``.
    Their episodes are never stored there; they are always fetched fresh from the
    owning music provider (see ``episodes``).
    """

    db_table = DB_TABLE_PODCASTS
    media_type = MediaType.PODCAST
    item_cls = Podcast

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the podcast-specific API commands."""
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/podcast_episodes", self.episodes)
        self.mass.register_api_command(f"music/{api_base}/podcast_episode", self.episode)
        self.mass.register_api_command(f"music/{api_base}/podcast_versions", self.versions)

    async def library_items(
        self,
        favorite: bool | None = None,
        search: str | None = None,
        limit: int = 500,
        offset: int = 0,
        order_by: str = "sort_name",
        provider: str | list[str] | None = None,
        genre: int | list[int] | None = None,
        **kwargs: Any,
    ) -> list[Podcast]:
        """Get in-database podcasts.

        When a search term is given and the first page of name matches is small
        (fewer than 25 hits, ``offset == 0``), podcasts whose publisher matches
        the search term are appended. Duplicates are skipped.

        :param favorite: Filter by favorite status.
        :param search: Filter by search query.
        :param limit: Maximum number of items to return.
        :param offset: Number of items to skip.
        :param order_by: Order by field (e.g. 'sort_name', 'timestamp_added').
        :param provider: Filter by provider instance ID (single string or list).
        :param genre: Filter by genre id(s).
        :return: The matching library podcasts.
        """
        provider_filter = self._ensure_provider_filter(provider)
        result = await self.get_library_items_by_query(
            favorite=favorite,
            search=search,
            genre_ids=genre,
            limit=limit,
            offset=offset,
            order_by=order_by,
            provider_filter=provider_filter,
            in_library_only=True,
        )
        if not search or len(result) >= 25 or offset:
            return result
        # append publisher items to result
        extra_query_parts: list[str] = [
            "WHERE podcasts.publisher LIKE :search",
        ]
        extra_query_params: dict[str, Any] = {
            "search": f"%{search}%",
        }
        publisher_matches = await self.get_library_items_by_query(
            favorite=favorite,
            search=None,
            genre_ids=genre,
            limit=limit,
            order_by=order_by,
            provider_filter=provider_filter,
            extra_query_parts=extra_query_parts,
            extra_query_params=extra_query_params,
            in_library_only=True,
        )
        # a podcast whose name AND publisher both match the search term would
        # otherwise be listed twice, so skip items already in the name results
        seen_ids = {item.item_id for item in result}
        return result + [item for item in publisher_matches if item.item_id not in seen_ids]

    async def episodes(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Return podcast episodes for the given provider podcast id.

        :param item_id: Podcast id, either a library id or a provider-specific id.
        :param provider_instance_id_or_domain: ``"library"`` or a provider instance id/domain.
        :raises MediaNotFoundError: A library id was given but no such library podcast exists.
        """
        # always check if we have a library item for this podcast
        if provider_instance_id_or_domain == "library":
            library_podcast = await self.get_library_item(item_id)
            if not library_podcast:
                raise MediaNotFoundError(f"Podcast {item_id} not found in library")
            # resolve the library item to one of its provider mappings
            provider_instance_id_or_domain, item_id = self._select_provider_id(library_podcast)
        # podcast episodes are not stored in the db/library
        # so we always need to fetch them from the provider
        async for episode in self._get_provider_podcast_episodes(
            item_id, provider_instance_id_or_domain
        ):
            yield episode

    async def episode(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> PodcastEpisode:
        """Return single podcast episode by the given provider podcast id.

        :param item_id: Provider-specific episode id.
        :param provider_instance_id_or_domain: Provider instance id or domain to query.
        :raises ProviderUnavailableError: The provider is not a loaded MusicProvider.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            raise ProviderUnavailableError("Provider not found")
        return await prov.get_podcast_episode(item_id)

    async def versions(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> UniqueList[Podcast]:
        """Return all versions of a podcast we can find on all providers.

        A "version" is a search hit on another provider whose name loosely matches
        the base podcast's name. Hits that share a provider mapping with the base
        podcast are excluded.
        """
        podcast = await self.get_provider_item(item_id, provider_instance_id_or_domain)
        search_query = podcast.name
        result: UniqueList[Podcast] = UniqueList()
        for provider_id in self.mass.music.get_unique_providers():
            provider = self.mass.get_provider(provider_id)
            if not isinstance(provider, MusicProvider):
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            result.extend(
                prov_item
                for prov_item in await self.search(search_query, provider_id)
                if loose_compare_strings(podcast.name, prov_item.name)
                # make sure that the 'base' version is NOT included
                and not podcast.provider_mappings.intersection(prov_item.provider_mappings)
            )
        return result

    async def _add_library_item(self, item: Podcast, overwrite_existing: bool = False) -> int:
        """Add a new record to the database.

        Also stores the item's provider mappings.

        :return: The database id of the newly inserted podcast.
        """
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "version": item.version,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "publisher": item.publisher,
                "total_episodes": item.total_episodes or 0,
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Podcast, overwrite: bool = False
    ) -> None:
        """Update existing record in the database.

        :param item_id: Database id of the podcast to update.
        :param update: Podcast carrying the new values.
        :param overwrite: If True, the values of ``update`` replace the stored ones.
            If False, they are merged: stored non-empty values win, and metadata,
            external ids and provider mappings are combined.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        # publisher/total_episodes follow the same overwrite-vs-merge rule as the
        # other fields; previously they always kept the stored value
        if overwrite:
            publisher = update.publisher
            total_episodes = update.total_episodes
        else:
            publisher = cur_item.publisher or update.publisher
            total_episodes = cur_item.total_episodes or update.total_episodes
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                "name": name,
                "sort_name": sort_name,
                "version": update.version if overwrite else cur_item.version or update.version,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "publisher": publisher,
                "total_episodes": total_episodes or 0,
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    async def _get_provider_podcast_episodes(
        self, item_id: str, provider_instance_id_or_domain: str
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Return podcast episodes for the given provider podcast id.

        Yields nothing if the provider is not a loaded MusicProvider. Each episode
        lacking provider-supplied resume info is enriched from the playlog table.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            return

        async def set_resume_position(episode: PodcastEpisode) -> None:
            if episode.fully_played is not None or episode.resume_position_ms:
                # provider supports resume info, we can skip
                return
            # for providers that do not natively support providing resume info,
            # we fallback to the playlog db table
            resume_info_db_row = await self.mass.music.database.get_row(
                DB_TABLE_PLAYLOG,
                {
                    "item_id": episode.item_id,
                    "provider": prov.instance_id,
                    "media_type": MediaType.PODCAST_EPISODE,
                },
            )
            if resume_info_db_row is None:
                return
            if resume_info_db_row["seconds_played"]:
                episode.resume_position_ms = int(resume_info_db_row["seconds_played"] * 1000)
            if resume_info_db_row["fully_played"] is not None:
                episode.fully_played = bool(resume_info_db_row["fully_played"])

        # grab the episodes from the provider
        # note that we do not cache any of this because its
        # always a rather small list and we want fresh resume info
        async for item in prov.get_podcast_episodes(item_id):
            await set_resume_position(item)
            yield item

    async def radio_mode_base_tracks(
        self,
        item: Podcast,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        Not supported for podcasts; always raises.

        :param item: The Podcast to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use.
        :raises NotImplementedError: Always.
        """
        msg = "Dynamic tracks not supported for Podcast MediaItem"
        raise NotImplementedError(msg)

    async def match_provider(
        self, db_podcast: Podcast, provider: MusicProvider, strict: bool = True
    ) -> list[ProviderMapping]:
        """
        Try to find match on (streaming) provider for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.

        :param db_podcast: The library podcast to match.
        :param provider: The provider to search on.
        :param strict: Passed through to the compare helpers.
        :return: Provider mappings of all matching provider podcasts (may be empty).
        """
        self.logger.debug(
            "Trying to match podcast %s on provider %s",
            db_podcast.name,
            provider.name,
        )
        matches: list[ProviderMapping] = []
        search_str = db_podcast.name
        search_result = await self.search(search_str, provider.instance_id)
        for search_result_item in search_result:
            if not search_result_item.available:
                continue
            # cheap pre-check on the (possibly simplified) search result
            if not compare_media_item(db_podcast, search_result_item, strict=strict):
                continue
            # we must fetch the full podcast version, search results can be simplified objects
            prov_podcast = await self.get_provider_item(
                search_result_item.item_id,
                search_result_item.provider,
                fallback=search_result_item,
            )
            if compare_podcast(db_podcast, prov_podcast, strict=strict):
                # 100% match
                matches.extend(prov_podcast.provider_mappings)
        if not matches:
            self.logger.debug(
                "Could not find match for Podcast %s on provider %s",
                db_podcast.name,
                provider.name,
            )
        return matches

    async def match_providers(self, db_podcast: Podcast) -> None:
        """Try to find match on all (streaming) providers for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.
        Only library items are matched. Providers whose domain is already mapped
        are skipped, as are non-searchable, non-podcast and non-streaming providers.
        """
        if db_podcast.provider != "library":
            return  # Matching only supported for database items

        # try to find match on all providers
        cur_provider_domains = {x.provider_domain for x in db_podcast.provider_mappings}
        for provider in self.mass.music.providers:
            if provider.domain in cur_provider_domains:
                continue
            if ProviderFeature.SEARCH not in provider.supported_features:
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            if not provider.is_streaming_provider:
                # matching on unique providers is pointless as they push (all) their content to MA
                continue
            if match := await self.match_provider(db_podcast, provider):
                # 100% match, we update the db with the additional provider mapping(s)
                await self.add_provider_mappings(db_podcast.item_id, match)
                # avoid matching a second instance of the same provider domain
                cur_provider_domains.add(provider.domain)
326