/
/
/
1"""Manage MediaItems of type Podcast."""
2
3from __future__ import annotations
4
5from collections.abc import AsyncGenerator
6from typing import TYPE_CHECKING, Any
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.errors import MediaNotFoundError, ProviderUnavailableError
10from music_assistant_models.media_items import Podcast, PodcastEpisode, ProviderMapping, UniqueList
11
12from music_assistant.constants import DB_TABLE_PLAYLOG, DB_TABLE_PODCASTS
13from music_assistant.controllers.media.base import MediaControllerBase
14from music_assistant.helpers.compare import (
15 compare_media_item,
16 compare_podcast,
17 create_safe_string,
18 loose_compare_strings,
19)
20from music_assistant.helpers.database import UNSET
21from music_assistant.helpers.json import serialize_to_json
22from music_assistant.models.music_provider import MusicProvider
23
24if TYPE_CHECKING:
25 from music_assistant_models.media_items import Track
26
27 from music_assistant import MusicAssistant
28
29
class PodcastsController(MediaControllerBase[Podcast]):
    """Controller managing MediaItems of type Podcast.

    Podcasts themselves are stored in the library database; their episodes are
    not, and are always fetched live from the owning provider.
    """

    db_table = DB_TABLE_PODCASTS
    media_type = MediaType.PODCAST
    item_cls = Podcast

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the podcast-specific API commands.

        :param mass: The MusicAssistant instance.
        """
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/podcast_episodes", self.episodes)
        self.mass.register_api_command(f"music/{api_base}/podcast_episode", self.episode)
        self.mass.register_api_command(f"music/{api_base}/podcast_versions", self.versions)

    async def library_items(
        self,
        favorite: bool | None = None,
        search: str | None = None,
        limit: int = 500,
        offset: int = 0,
        order_by: str = "sort_name",
        provider: str | list[str] | None = None,
        **kwargs: Any,
    ) -> list[Podcast]:
        """Get in-database podcasts.

        When a search on the first page (offset 0) yields fewer than 25 results,
        podcasts whose publisher matches the search query are appended as well.
        Items already present in the result are not repeated, and the combined
        list never holds more than ``limit`` items.

        :param favorite: Filter by favorite status.
        :param search: Filter by search query.
        :param limit: Maximum number of items to return.
        :param offset: Number of items to skip.
        :param order_by: Order by field (e.g. 'sort_name', 'timestamp_added').
        :param provider: Filter by provider instance ID (single string or list).
        :param kwargs: Accepted for interface compatibility; not used here.
        """
        provider_filter = self._ensure_provider_filter(provider)
        result = await self.get_library_items_by_query(
            favorite=favorite,
            search=search,
            limit=limit,
            offset=offset,
            order_by=order_by,
            provider_filter=provider_filter,
            in_library_only=True,
        )
        if not search or len(result) >= 25 or offset:
            return result
        # Few hits on the first page: also include podcasts whose publisher
        # matches the query. The search text is passed as a bound parameter.
        publisher_matches = await self.get_library_items_by_query(
            favorite=favorite,
            search=None,
            limit=limit,
            order_by=order_by,
            provider_filter=provider_filter,
            extra_query_parts=["WHERE podcasts.publisher LIKE :search"],
            extra_query_params={"search": f"%{search}%"},
            in_library_only=True,
        )
        # A podcast can match on both name and publisher; don't list it twice.
        seen_ids = {item.item_id for item in result}
        extra = [item for item in publisher_matches if item.item_id not in seen_ids]
        # Honor the documented `limit` contract for the combined list.
        return (result + extra)[:limit]

    async def episodes(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Return podcast episodes for the given provider podcast id.

        :param item_id: Podcast id (library id when the provider is "library").
        :param provider_instance_id_or_domain: Provider instance id/domain, or "library".
        :raises MediaNotFoundError: If a library podcast with ``item_id`` does not exist.
        """
        # a library podcast is resolved to one of its provider mappings first
        if provider_instance_id_or_domain == "library":
            library_podcast = await self.get_library_item(item_id)
            if not library_podcast:
                raise MediaNotFoundError(f"Podcast {item_id} not found in library")
            provider_instance_id_or_domain, item_id = self._select_provider_id(library_podcast)
        # podcast episodes are not stored in the db/library
        # so we always need to fetch them from the provider
        async for episode in self._get_provider_podcast_episodes(
            item_id, provider_instance_id_or_domain
        ):
            yield episode

    async def episode(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> PodcastEpisode:
        """Return single podcast episode by the given provider episode id.

        :param item_id: The episode id on the provider.
        :param provider_instance_id_or_domain: Provider instance id or domain.
        :raises ProviderUnavailableError: If no music provider matches.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            raise ProviderUnavailableError("Provider not found")
        return await prov.get_podcast_episode(item_id)

    async def versions(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> UniqueList[Podcast]:
        """Return all versions of a podcast we can find on all providers.

        Each unique provider that supports podcasts in its library is searched by
        podcast name. A hit counts as a version when its name loosely matches and
        it does not share a provider mapping with the base podcast.

        :param item_id: The podcast id on the given provider.
        :param provider_instance_id_or_domain: Provider instance id or domain.
        """
        podcast = await self.get_provider_item(item_id, provider_instance_id_or_domain)
        search_query = podcast.name
        result: UniqueList[Podcast] = UniqueList()
        for provider_id in self.mass.music.get_unique_providers():
            provider = self.mass.get_provider(provider_id)
            if not isinstance(provider, MusicProvider):
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            result.extend(
                prov_item
                for prov_item in await self.search(search_query, provider_id)
                if loose_compare_strings(podcast.name, prov_item.name)
                # make sure that the 'base' version is NOT included
                and not podcast.provider_mappings.intersection(prov_item.provider_mappings)
            )
        return result

    async def _add_library_item(self, item: Podcast, overwrite_existing: bool = False) -> int:
        """Add a new record to the database and return its new id.

        :param item: The podcast to insert.
        :param overwrite_existing: Accepted for interface compatibility; not used here.
        """
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "version": item.version,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "publisher": item.publisher,
                "total_episodes": item.total_episodes or 0,
                "search_name": create_safe_string(item.name, True, True),
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Podcast, overwrite: bool = False
    ) -> None:
        """Update existing record in the database.

        With ``overwrite`` the values from ``update`` replace the stored ones.
        Without it, existing scalar values win (falling back to ``update`` when
        empty), while metadata, external ids and provider mappings are merged.

        :param item_id: Library id of the podcast to update.
        :param update: Podcast carrying the new values.
        :param overwrite: Replace instead of merge.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        metadata = update.metadata if overwrite else cur_item.metadata.update(update.metadata)
        cur_item.external_ids.update(update.external_ids)
        name = update.name if overwrite else cur_item.name
        sort_name = update.sort_name if overwrite else cur_item.sort_name or update.sort_name
        # publisher/total_episodes honor `overwrite` like every other field
        publisher = update.publisher if overwrite else cur_item.publisher or update.publisher
        total_episodes = (
            update.total_episodes
            if overwrite
            else cur_item.total_episodes or update.total_episodes
        )
        await self.mass.music.database.update(
            self.db_table,
            {"item_id": db_id},
            {
                "name": name,
                "sort_name": sort_name,
                "version": update.version if overwrite else cur_item.version or update.version,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(
                    update.external_ids if overwrite else cur_item.external_ids
                ),
                "publisher": publisher,
                "total_episodes": total_episodes or 0,
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                # UNSET presumably leaves the stored value untouched - verify in helpers.database
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    async def _get_provider_podcast_episodes(
        self, item_id: str, provider_instance_id_or_domain: str
    ) -> AsyncGenerator[PodcastEpisode, None]:
        """Return podcast episodes for the given provider podcast id.

        Yields nothing if no music provider matches. Episodes that lack resume
        info from the provider get it filled in from the playlog table, if present.

        :param item_id: The podcast id on the provider.
        :param provider_instance_id_or_domain: Provider instance id or domain.
        """
        prov = self.mass.get_provider(provider_instance_id_or_domain)
        if not isinstance(prov, MusicProvider):
            return

        async def set_resume_position(episode: PodcastEpisode) -> None:
            """Fill resume info on ``episode`` from the playlog table when missing."""
            if episode.fully_played is not None or episode.resume_position_ms:
                # provider supports resume info, we can skip
                return
            # for providers that do not natively support providing resume info,
            # we fallback to the playlog db table
            resume_info_db_row = await self.mass.music.database.get_row(
                DB_TABLE_PLAYLOG,
                {
                    "item_id": episode.item_id,
                    "provider": prov.instance_id,
                    "media_type": MediaType.PODCAST_EPISODE,
                },
            )
            if resume_info_db_row is None:
                return
            if resume_info_db_row["seconds_played"]:
                episode.resume_position_ms = int(resume_info_db_row["seconds_played"] * 1000)
            if resume_info_db_row["fully_played"] is not None:
                episode.fully_played = resume_info_db_row["fully_played"]

        # grab the episodes from the provider
        # note that we do not cache any of this because its
        # always a rather small list and we want fresh resume info
        async for item in prov.get_podcast_episodes(item_id):
            await set_resume_position(item)
            yield item

    async def radio_mode_base_tracks(
        self,
        item: Podcast,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        Not supported for podcasts; always raises.

        :param item: The Podcast to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use.
        :raises NotImplementedError: Always.
        """
        msg = "Dynamic tracks not supported for Podcast MediaItem"
        raise NotImplementedError(msg)

    async def match_provider(
        self, db_podcast: Podcast, provider: MusicProvider, strict: bool = True
    ) -> list[ProviderMapping]:
        """
        Try to find match on (streaming) provider for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.

        :param db_podcast: The (library) podcast to match.
        :param provider: The provider to search on.
        :param strict: Passed through to the compare helpers.
        :return: Provider mappings of every matching provider podcast (may be empty).
        """
        self.logger.debug(
            "Trying to match podcast %s on provider %s",
            db_podcast.name,
            provider.name,
        )
        matches: list[ProviderMapping] = []
        search_str = db_podcast.name
        search_result = await self.search(search_str, provider.instance_id)
        for search_result_item in search_result:
            if not search_result_item.available:
                continue
            # cheap pre-check on the search result before fetching the full item
            if not compare_media_item(db_podcast, search_result_item, strict=strict):
                continue
            # we must fetch the full podcast version, search results can be simplified objects
            prov_podcast = await self.get_provider_item(
                search_result_item.item_id,
                search_result_item.provider,
                fallback=search_result_item,
            )
            if compare_podcast(db_podcast, prov_podcast, strict=strict):
                # 100% match
                matches.extend(prov_podcast.provider_mappings)
        if not matches:
            self.logger.debug(
                "Could not find match for Podcast %s on provider %s",
                db_podcast.name,
                provider.name,
            )
        return matches

    async def match_providers(self, db_podcast: Podcast) -> None:
        """Try to find match on all (streaming) providers for the provided (database) podcast.

        This is used to link objects of different providers/qualities together.
        Providers whose domain is already mapped, that lack search support or
        podcast library support, or that are not streaming providers are skipped.

        :param db_podcast: The library podcast to match; non-library items are ignored.
        """
        if db_podcast.provider != "library":
            return  # Matching only supported for database items

        # try to find match on all providers
        cur_provider_domains = {x.provider_domain for x in db_podcast.provider_mappings}
        for provider in self.mass.music.providers:
            if provider.domain in cur_provider_domains:
                continue
            if ProviderFeature.SEARCH not in provider.supported_features:
                continue
            if not provider.library_supported(MediaType.PODCAST):
                continue
            if not provider.is_streaming_provider:
                # matching on unique providers is pointless as they push (all) their content to MA
                continue
            if match := await self.match_provider(db_podcast, provider):
                # 100% match, we update the db with the additional provider mapping(s)
                await self.add_provider_mappings(db_podcast.item_id, match)
                cur_provider_domains.add(provider.domain)
320 await self.add_provider_mappings(db_podcast.item_id, match)
321 cur_provider_domains.add(provider.domain)
322