/
/
/
1"""Manage MediaItems of type Podcast."""
2
3from __future__ import annotations
4
5from collections.abc import AsyncGenerator
6from typing import TYPE_CHECKING, Any
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.errors import MediaNotFoundError, ProviderUnavailableError
10from music_assistant_models.media_items import Podcast, PodcastEpisode, ProviderMapping, UniqueList
11
12from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PODCASTS
13from music_assistant.controllers.media.base import MediaControllerBase
14from music_assistant.helpers.compare import (
15 compare_media_item,
16 compare_podcast,
17 create_safe_string,
18 loose_compare_strings,
19)
20from music_assistant.helpers.database import UNSET
21from music_assistant.helpers.json import serialize_to_json
22from music_assistant.models.music_provider import MusicProvider
23
24if TYPE_CHECKING:
25 from music_assistant_models.media_items import Track
26
27 from music_assistant import MusicAssistant
28
29
class PodcastsController(MediaControllerBase[Podcast]):
    """Controller managing MediaItems of type Podcast.

    Podcasts are stored in the library database; their episodes are not stored
    and are always fetched from the provider.
    """

    db_table = DB_TABLE_PODCASTS
    media_type = MediaType.PODCAST
    item_cls = Podcast

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the podcast-specific API commands."""
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/podcast_episodes", self.episodes)
        self.mass.register_api_command(f"music/{api_base}/podcast_episode", self.episode)
        self.mass.register_api_command(f"music/{api_base}/podcast_versions", self.versions)

    async def library_items(
        self,
        favorite: bool | None = None,
        search: str | None = None,
        limit: int = 500,
        offset: int = 0,
        order_by: str = "sort_name",
        provider: str | list[str] | None = None,
    ) -> list[Podcast]:
        """Get in-database podcasts.

        When a search query yields fewer than 25 matches on the first page
        (offset 0), podcasts whose publisher matches the query are appended,
        skipping podcasts that are already in the result and never exceeding
        ``limit`` items in total.

        :param favorite: Filter by favorite status.
        :param search: Filter by search query.
        :param limit: Maximum number of items to return.
        :param offset: Number of items to skip.
        :param order_by: Order by field (e.g. 'sort_name', 'timestamp_added').
        :param provider: Filter by provider instance ID (single string or list).
        """
        provider_filter = self._ensure_provider_filter(provider)
        result = await self.get_library_items_by_query(
            favorite=favorite,
            search=search,
            limit=limit,
            offset=offset,
            order_by=order_by,
            provider_filter=provider_filter,
        )
        if not search or len(result) >= 25 or offset:
            return result
        # append publisher matches to the (small) name-based result
        extra_query_parts: list[str] = [
            f"WHERE {self.db_table}.publisher LIKE :search",
        ]
        extra_query_params: dict[str, Any] = {
            "search": f"%{search}%",
        }
        publisher_matches = await self.get_library_items_by_query(
            favorite=favorite,
            search=None,
            limit=limit,
            order_by=order_by,
            provider_filter=provider_filter,
            extra_query_parts=extra_query_parts,
            extra_query_params=extra_query_params,
        )
        combined: list[Podcast] = list(result)
        # a podcast may match both on name and on publisher: don't return it twice
        seen_ids = {item.item_id for item in combined}
        for item in publisher_matches:
            if len(combined) >= limit:
                break
            if item.item_id in seen_ids:
                continue
            seen_ids.add(item.item_id)
            combined.append(item)
        return combined

    async def episodes(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Return podcast episodes for the given provider podcast id.

        :param item_id: Item id of the podcast (library id when the provider is "library").
        :param provider_instance_id_or_domain: Provider instance id/domain, or "library".
        :raises MediaNotFoundError: If the podcast does not exist in the library.
        """
        # always check if we have a library item for this podcast
        if provider_instance_id_or_domain == "library":
            library_podcast = await self.get_library_item(item_id)
            if not library_podcast:
                raise MediaNotFoundError(f"Podcast {item_id} not found in library")
            # resolve the library item to a provider-specific (provider, item_id) pair
            provider_instance_id_or_domain, item_id = self._select_provider_id(library_podcast)
        # podcast episodes are not stored in the db/library
        # so we always need to fetch them from the provider
        async for episode in self._get_provider_podcast_episodes(
            item_id, provider_instance_id_or_domain
        ):
            yield episode

    async def episode(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> PodcastEpisode:
        """Return single podcast episode by the given provider episode id.

        :param item_id: Provider-specific item id of the episode.
        :param provider_instance_id_or_domain: Provider instance id or domain.
        :raises ProviderUnavailableError: If no music provider matches.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            raise ProviderUnavailableError("Provider not found")
        return await prov.get_podcast_episode(item_id)

    async def versions(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> UniqueList[Podcast]:
        """Return all versions of a podcast we can find on all providers.

        Searches every unique provider that supports podcasts for the podcast's
        name and returns loosely name-matching results, excluding the podcast
        the lookup started from.

        :param item_id: Item id of the podcast.
        :param provider_instance_id_or_domain: Provider instance id or domain.
        """
        podcast = await self.get_provider_item(item_id, provider_instance_id_or_domain)
        search_query = podcast.name
        result: UniqueList[Podcast] = UniqueList()
        for provider_id in self.mass.music.get_unique_providers():
            provider = self.mass.get_provider(provider_id)
            if not isinstance(provider, MusicProvider):
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            result.extend(
                prov_item
                for prov_item in await self.search(search_query, provider_id)
                if loose_compare_strings(podcast.name, prov_item.name)
                # make sure that the 'base' version is NOT included
                and not podcast.provider_mappings.intersection(prov_item.provider_mappings)
            )
        return result

    async def _add_library_item(self, item: Podcast, overwrite_existing: bool = False) -> int:
        """Add a new record to the database and return its database id.

        :param item: The podcast to insert.
        :param overwrite_existing: Not used by this implementation.
        """
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "version": item.version,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "publisher": item.publisher,
                "total_episodes": item.total_episodes or 0,
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Podcast, overwrite: bool = False
    ) -> None:
        """Update existing record in the database.

        :param item_id: Library (database) id of the podcast.
        :param update: Podcast object holding the new values.
        :param overwrite: Replace stored values instead of merging them.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        # honor 'overwrite' like all other fields do
        publisher = update.publisher if overwrite else cur_item.publisher or update.publisher
        # the episode count changes as new episodes get published, so prefer the
        # value from the update and only fall back to the stored one when missing
        total_episodes = (
            update.total_episodes
            if overwrite
            else update.total_episodes or cur_item.total_episodes
        )
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                "name": name,
                "sort_name": sort_name,
                "version": update.version if overwrite else cur_item.version or update.version,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "publisher": publisher,
                "total_episodes": total_episodes or 0,
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    async def _get_provider_podcast_episodes(
        self, item_id: str, provider_instance_id_or_domain: str
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Return podcast episodes for the given provider podcast id.

        Yields nothing when no matching music provider is available. Episodes
        without provider-supplied resume info get it from the playlog table.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            return

        async def set_resume_position(episode: PodcastEpisode) -> None:
            if episode.fully_played is not None or episode.resume_position_ms:
                # provider supports resume info, we can skip
                return
            # for providers that do not natively support providing resume info,
            # we fallback to the playlog db table
            resume_info_db_row = await self.mass.music.database.get_row(
                DB_TABLE_PLAYLOG,
                {
                    "item_id": episode.item_id,
                    "provider": prov.instance_id,
                    "media_type": MediaType.PODCAST_EPISODE,
                },
            )
            if resume_info_db_row is None:
                return
            if resume_info_db_row["seconds_played"]:
                episode.resume_position_ms = int(resume_info_db_row["seconds_played"] * 1000)
            if resume_info_db_row["fully_played"] is not None:
                # the db column may hold an integer flag; normalize to a real bool
                episode.fully_played = bool(resume_info_db_row["fully_played"])

        # grab the episodes from the provider
        # note that we do not cache any of this because its
        # always a rather small list and we want fresh resume info
        async for item in prov.get_podcast_episodes(item_id):
            await set_resume_position(item)
            yield item

    async def radio_mode_base_tracks(
        self,
        item: Podcast,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        Not supported for podcasts; always raises.

        :param item: The Podcast to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use.
        :raises NotImplementedError: Always.
        """
        msg = "Dynamic tracks not supported for Podcast MediaItem"
        raise NotImplementedError(msg)

    async def match_provider(
        self, db_podcast: Podcast, provider: MusicProvider, strict: bool = True
    ) -> list[ProviderMapping]:
        """
        Try to find match on (streaming) provider for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.

        :param db_podcast: The library podcast to find matches for.
        :param provider: The provider to search on.
        :param strict: Passed through to the comparison helpers.
        :return: Provider mappings of all matching provider podcasts (may be empty).
        """
        self.logger.debug(
            "Trying to match podcast %s on provider %s",
            db_podcast.name,
            provider.name,
        )
        matches: list[ProviderMapping] = []
        search_str = db_podcast.name
        search_result = await self.search(search_str, provider.instance_id)
        for search_result_item in search_result:
            if not search_result_item.available:
                continue
            if not compare_media_item(db_podcast, search_result_item, strict=strict):
                continue
            # we must fetch the full podcast version, search results can be simplified objects
            prov_podcast = await self.get_provider_item(
                search_result_item.item_id,
                search_result_item.provider,
                fallback=search_result_item,
            )
            if compare_podcast(db_podcast, prov_podcast, strict=strict):
                # 100% match
                matches.extend(prov_podcast.provider_mappings)
        if not matches:
            self.logger.debug(
                "Could not find match for Podcast %s on provider %s",
                db_podcast.name,
                provider.name,
            )
        return matches

    async def match_providers(self, db_podcast: Podcast) -> None:
        """Try to find match on all (streaming) providers for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.
        Only library items are matched; any found mappings are added to the library item.
        """
        if db_podcast.provider != "library":
            return  # Matching only supported for database items

        # try to find match on all providers
        cur_provider_domains = {x.provider_domain for x in db_podcast.provider_mappings}
        for provider in self.mass.music.providers:
            if provider.domain in cur_provider_domains:
                continue
            if ProviderFeature.SEARCH not in provider.supported_features:
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            if not provider.is_streaming_provider:
                # matching on unique providers is pointless as they push (all) their content to MA
                continue
            if match := await self.match_provider(db_podcast, provider):
                # 100% match, we update the db with the additional provider mapping(s)
                await self.add_provider_mappings(db_podcast.item_id, match)
                cur_provider_domains.add(provider.domain)
319