/
/
/
1"""Manage MediaItems of type Radio."""
2
3from __future__ import annotations
4
5import asyncio
6from typing import TYPE_CHECKING
7
8from music_assistant_models.enums import MediaType, ProviderFeature
9from music_assistant_models.media_items import ProviderMapping, Radio, Track
10
11from music_assistant.constants import DB_TABLE_RADIOS
12from music_assistant.helpers.compare import (
13 compare_media_item,
14 compare_radio,
15 create_safe_string,
16 loose_compare_strings,
17)
18from music_assistant.helpers.database import UNSET
19from music_assistant.helpers.json import serialize_to_json
20from music_assistant.models.music_provider import MusicProvider
21
22from .base import MediaControllerBase
23
24if TYPE_CHECKING:
25 from music_assistant import MusicAssistant
26
27
class RadioController(MediaControllerBase[Radio]):
    """Controller managing MediaItems of type Radio."""

    db_table = DB_TABLE_RADIOS
    media_type = MediaType.RADIO
    item_cls = Radio

    def __init__(self, mass: MusicAssistant) -> None:
        """Initialize class and register the radio-specific API commands."""
        super().__init__(mass)
        # register (extra) api handlers
        api_base = self.api_base
        self.mass.register_api_command(f"music/{api_base}/radio_versions", self.versions)

    async def versions(
        self,
        item_id: str,
        provider_instance_id_or_domain: str,
    ) -> list[Radio]:
        """Return all versions of a radio station we can find on all providers.

        The station's name is searched concurrently on every unique provider and
        only results whose name loosely matches the station's name are kept.
        Results are keyed by their item_id, so a later result with the same
        item_id replaces an earlier one. Item ids of the station's own provider
        mappings are removed, so the 'base' version is never returned.

        :param item_id: Item id of the radio station to find versions for.
        :param provider_instance_id_or_domain: Provider instance id or domain of that item.
        :return: The other versions found (may be empty).
        """
        radio = await self.get(item_id, provider_instance_id_or_domain)
        # perform a search on all provider(types) to collect all versions/variants
        search_results = await asyncio.gather(
            *(
                self.search(radio.name, provider_domain)
                for provider_domain in self.mass.music.get_unique_providers()
            )
        )
        all_versions = {
            prov_item.item_id: prov_item
            for prov_items in search_results
            for prov_item in prov_items
            if loose_compare_strings(radio.name, prov_item.name)
        }
        # make sure that the 'base' version is NOT included
        for prov_version in radio.provider_mappings:
            all_versions.pop(prov_version.item_id, None)

        # return the aggregated result
        return list(all_versions.values())

    async def _add_library_item(self, item: Radio, overwrite_existing: bool = False) -> int:
        """Add a new radio record to the database and return its database id.

        Also stores the item's provider mappings for the new record.

        :param item: The radio item to insert.
        :param overwrite_existing: Not used by this implementation; kept so the
            signature stays compatible (presumably with the base controller — confirm).
        :return: The database id of the inserted row.
        """
        assert self.mass.music.database is not None  # For type checking
        db_id = await self.mass.music.database.insert(
            self.db_table,
            {
                "name": item.name,
                "sort_name": item.sort_name,
                "favorite": item.favorite,
                "metadata": serialize_to_json(item.metadata),
                "external_ids": serialize_to_json(item.external_ids),
                "search_name": create_safe_string(item.name, True, True),
                # sort_name may be None; fall back to an empty search string
                "search_sort_name": create_safe_string(item.sort_name or "", True, True),
                # UNSET leaves the column untouched when no date is known
                "timestamp_added": int(item.date_added.timestamp()) if item.date_added else UNSET,
            },
        )
        # update/set provider_mappings table
        await self.set_provider_mappings(db_id, item.provider_mappings)
        self.logger.debug("added %s to database (id: %s)", item.name, db_id)
        return db_id

    async def _update_library_item(
        self, item_id: str | int, update: Radio, overwrite: bool = False
    ) -> None:
        """Update an existing radio record in the database.

        With ``overwrite`` the name, sort name, metadata, external ids and
        provider mappings are replaced by those of ``update``. Without it:
        - the existing name is kept;
        - the existing sort name is kept, falling back to the update's;
        - metadata, external ids and provider mappings are merged.

        :param item_id: Database id of the record (str or int).
        :param update: Radio item carrying the new values.
        :param overwrite: Replace instead of merge.
        """
        db_id = int(item_id)  # ensure integer
        cur_item = await self.get_library_item(db_id)
        if overwrite:
            name = update.name
            sort_name = update.sort_name
            metadata = update.metadata
            external_ids = update.external_ids
        else:
            name = cur_item.name
            sort_name = cur_item.sort_name or update.sort_name
            # MediaItemMetadata.update merges into cur_item.metadata and its
            # return value is used as the merged result
            metadata = cur_item.metadata.update(update.metadata)
            cur_item.external_ids.update(update.external_ids)
            external_ids = cur_item.external_ids
        match = {"item_id": db_id}
        assert self.mass.music.database is not None  # For type checking
        await self.mass.music.database.update(
            self.db_table,
            match,
            {
                "name": name,
                "sort_name": sort_name,
                "metadata": serialize_to_json(metadata),
                "external_ids": serialize_to_json(external_ids),
                "search_name": create_safe_string(name, True, True),
                "search_sort_name": create_safe_string(sort_name or "", True, True),
                # UNSET leaves the stored timestamp untouched when the update has none
                "timestamp_added": int(update.date_added.timestamp())
                if update.date_added
                else UNSET,
            },
        )
        # update/set provider_mappings table
        provider_mappings = (
            update.provider_mappings
            if overwrite
            else {*update.provider_mappings, *cur_item.provider_mappings}
        )
        await self.set_provider_mappings(db_id, provider_mappings, overwrite)
        self.logger.debug("updated %s in database: (id %s)", update.name, db_id)

    async def radio_mode_base_tracks(
        self,
        item: Radio,
        preferred_provider_instances: list[str] | None = None,
    ) -> list[Track]:
        """
        Get the list of base tracks from the controller used to calculate the dynamic radio.

        Not supported for Radio items: this always raises.

        :param item: The Radio to get base tracks for.
        :param preferred_provider_instances: List of preferred provider instance IDs to use.
        :raises NotImplementedError: Always.
        """
        msg = "Dynamic tracks not supported for Radio MediaItem"
        raise NotImplementedError(msg)

    async def match_provider(
        self, db_radio: Radio, provider: MusicProvider, strict: bool = True
    ) -> list[ProviderMapping]:
        """
        Try to find match on (streaming) provider for the provided (database) radio.

        This is used to link objects of different providers/qualities together.
        The radio's name is searched on the given provider instance. Each
        available result that passes ``compare_media_item`` is re-fetched in
        full and, if it also passes ``compare_radio``, all of its provider
        mappings are collected.

        :param db_radio: The (database) radio to match.
        :param provider: The provider to search on.
        :param strict: Passed through to the comparison helpers.
        :return: Provider mappings of every matching item (empty if none matched).
        """
        self.logger.debug(
            "Trying to match radio %s on provider %s",
            db_radio.name,
            provider.name,
        )
        matches: list[ProviderMapping] = []
        search_result = await self.search(db_radio.name, provider.instance_id)
        for search_result_item in search_result:
            if not search_result_item.available:
                continue
            if not compare_media_item(db_radio, search_result_item, strict=strict):
                continue
            # we must fetch the full radio version, search results can be simplified objects
            prov_radio = await self.get_provider_item(
                search_result_item.item_id,
                search_result_item.provider,
                fallback=search_result_item,
            )
            if compare_radio(db_radio, prov_radio, strict=strict):
                # 100% match
                matches.extend(prov_radio.provider_mappings)
        if not matches:
            self.logger.debug(
                "Could not find match for Radio %s on provider %s",
                db_radio.name,
                provider.name,
            )
        return matches

    async def match_providers(self, db_radio: Radio) -> None:
        """Try to find match on all (streaming) providers for the provided (database) radio.

        This is used to link objects of different providers/qualities together.
        Only library items are matched. A provider is skipped in any of these cases:
        - its domain is already mapped;
        - it lacks the SEARCH feature;
        - it does not support RADIO in its library;
        - it is not a streaming provider.
        Matches found are stored as additional provider mappings.

        :param db_radio: The library radio to match.
        """
        if db_radio.provider != "library":
            return  # Matching only supported for database items

        # try to find match on all providers
        cur_provider_domains = {x.provider_domain for x in db_radio.provider_mappings}
        for provider in self.mass.music.providers:
            if provider.domain in cur_provider_domains:
                continue
            if ProviderFeature.SEARCH not in provider.supported_features:
                continue
            if not provider.library_supported(MediaType.RADIO):
                continue
            if not provider.is_streaming_provider:
                # matching on unique providers is pointless as they push (all) their content to MA
                continue
            if match := await self.match_provider(db_radio, provider):
                # 100% match, we update the db with the additional provider mapping(s)
                await self.add_provider_mappings(db_radio.item_id, match)
                # skip further instances of the same provider domain
                cur_provider_domains.add(provider.domain)
206