music-assistant-server

18.8 KBPY
provider.py
18.8 KB510 lines • python
1"""Pandora music provider for Music Assistant."""
2
from __future__ import annotations

import asyncio
import time
from typing import TYPE_CHECKING, Any

import aiohttp
from aiohttp import web
from music_assistant_models.enums import (
    ContentType,
    ImageType,
    MediaType,
    StreamType,
)
from music_assistant_models.errors import (
    InvalidDataError,
    LoginFailed,
    MediaNotFoundError,
    ProviderUnavailableError,
)
from music_assistant_models.media_items import (
    AudioFormat,
    MediaItemImage,
    MediaItemMetadata,
    ProviderMapping,
    Radio,
    SearchResults,
    UniqueList,
)
from music_assistant_models.streamdetails import MultiPartPath, StreamDetails, StreamMetadata

from music_assistant.constants import CONF_PASSWORD, CONF_USERNAME
from music_assistant.controllers.cache import use_cache
from music_assistant.helpers.compare import compare_strings
from music_assistant.models.music_provider import MusicProvider

from .constants import (
    LOGIN_ENDPOINT,
    PLAYLIST_FRAGMENT_ENDPOINT,
    STATIONS_ENDPOINT,
)
from .helpers import create_auth_headers, get_csrf_token, handle_pandora_error

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator
47
48
class PandoraStationSession:
    """Manages streaming state for a single Pandora station.

    Fragments fetched from Pandora are cached in ``fragments``. ``track_map``
    translates a sequential music track number (the ``track_num`` used in
    stream URLs) into a ``(fragment_index, track_index)`` pair, and
    ``cumulative_times`` holds the start offset of each mapped track, summed
    from the tracks' ``trackLength`` values.
    """

    def __init__(self, station_id: str):
        """Initialize a new station streaming session.

        Args:
            station_id: The Pandora station ID.
        """
        self.station_id = station_id
        # Cached fragment payloads indexed by fragment number (None = not loaded).
        self.fragments: list[dict[str, Any] | None] = []
        # music track number -> (fragment index, track index inside that fragment)
        self.track_map: list[tuple[int, int]] = []
        # Start offset of each mapped music track; parallel to track_map.
        self.cumulative_times: list[int] = []
        # Wall-clock time of the last use of this session.
        self.last_accessed = time.time()

    def get_track_duration(self, music_track_num: int) -> int:
        """Return the ``trackLength`` of a mapped music track as an int.

        Args:
            music_track_num: Index into ``track_map``.

        Returns:
            The track length, or 0 when the index is out of range, the
            fragment is not loaded, the track is missing, or its length is
            absent / not convertible to an int.
        """
        if not 0 <= music_track_num < len(self.track_map):
            return 0
        frag_idx, track_idx = self.track_map[music_track_num]
        if frag_idx >= len(self.fragments) or not (frag := self.fragments[frag_idx]):
            return 0
        # "tracks" may be absent or explicitly null in the payload.
        tracks = frag.get("tracks") or []
        if track_idx >= len(tracks):
            return 0
        try:
            return int(tracks[track_idx].get("trackLength") or 0)
        except (TypeError, ValueError):
            # Malformed length value: treat as unknown rather than crash.
            return 0
75
76
class PandoraProvider(MusicProvider):
    """Pandora Music Provider.

    Exposes the user's Pandora stations as radio items. Playback goes through
    a dynamic stream route: each stream is a list of placeholder parts whose
    URLs carry a sequential ``track_num``. The route maps that number onto a
    track from Pandora playlist fragments and redirects to its audio URL.
    """

    # Number of placeholder parts handed out for one radio stream.
    _STREAM_PART_COUNT = 1000
    # Maximum number of station sessions kept before LRU eviction.
    _MAX_SESSIONS = 10
    # Consecutive fragments without any music track before a request gives up.
    _MAX_EMPTY_FRAGMENTS = 5
    # Page size for the stations request; only this first page is requested.
    _STATIONS_PAGE_SIZE = 250
    # Timeout for API requests, matching the one used for login.
    _REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30)

    _auth_token: str | None = None
    _user_id: str | None = None
    _csrf_token: str | None = None
    _sessions: dict[str, PandoraStationSession]
    # One lock per station session; serializes fragment fetching.
    _fetch_locks: dict[str, asyncio.Lock]

    async def handle_async_init(self) -> None:
        """Authenticate with Pandora and register the dynamic stream route."""
        self._on_unload_callbacks = []
        self._sessions = {}
        self._fetch_locks = {}

        await self._login()

        # Register dynamic stream route
        self._on_unload_callbacks.append(
            self.mass.streams.register_dynamic_route(
                f"/{self.instance_id}_stream", self._handle_stream_request
            )
        )

    async def unload(self, is_removed: bool = False) -> None:
        """Handle unload/close of the provider (unregisters the stream route)."""
        for callback in getattr(self, "_on_unload_callbacks", []):
            callback()
        await super().unload(is_removed)

    async def _login(self) -> None:
        """Authenticate using the username/password from the provider config."""
        username = str(self.config.get_value(CONF_USERNAME))
        password = str(self.config.get_value(CONF_PASSWORD))
        await self._authenticate(username, password)

    async def _authenticate(self, username: str, password: str) -> None:
        """Authenticate with Pandora and store the auth token.

        :raises LoginFailed: Non-200 response or no auth token returned.
        :raises ProviderUnavailableError: Network error or timeout.
        """
        try:
            self._csrf_token = await get_csrf_token(self.mass.http_session)

            login_data = {
                "username": username,
                "password": password,
                "keepLoggedIn": True,
                "existingAuthToken": None,
            }

            headers = create_auth_headers(self._csrf_token)

            async with self.mass.http_session.post(
                LOGIN_ENDPOINT,
                headers=headers,
                json=login_data,
                timeout=self._REQUEST_TIMEOUT,
            ) as response:
                if response.status != 200:
                    raise LoginFailed(f"Login request failed with status {response.status}")

                response_data = await response.json()
                handle_pandora_error(response_data)

                self._auth_token = response_data.get("authToken")
                if not self._auth_token:
                    raise LoginFailed("No auth token received from Pandora")

                self._user_id = response_data.get("listenerId")
                self.logger.info("Successfully authenticated with Pandora")

        # aiohttp's total timeout raises TimeoutError, which is not a ClientError.
        except (aiohttp.ClientError, TimeoutError) as err:
            self.logger.exception("Network error during authentication")
            raise ProviderUnavailableError(
                "Unable to connect to Pandora for authentication"
            ) from err

    @staticmethod
    def _raise_for_status(status: int) -> None:
        """Translate a non-401 HTTP error status into a Music Assistant error."""
        if status == 404:
            raise MediaNotFoundError("Resource not found")
        if status >= 500:
            raise ProviderUnavailableError("Pandora server error")
        if status >= 400:
            raise InvalidDataError(f"Pandora API error: HTTP {status}")

    async def _api_request(
        self, method: str, url: str, data: dict[str, Any] | None = None, retry: bool = True
    ) -> dict[str, Any]:
        """Make an API request to Pandora.

        :param method: HTTP method (GET, POST, etc.)
        :param url: API endpoint URL
        :param data: Optional JSON data to send
        :param retry: Whether to retry once on 401 authentication errors
        :raises LoginFailed: Not authenticated, or still 401 after re-login.
        :raises MediaNotFoundError: HTTP 404.
        :raises ProviderUnavailableError: HTTP 5xx, network error or timeout.
        :raises InvalidDataError: Other HTTP 4xx or an unparseable response.
        """
        if not self._csrf_token or not self._auth_token:
            raise LoginFailed("Not authenticated with Pandora")

        headers = create_auth_headers(self._csrf_token, self._auth_token)

        try:
            async with self.mass.http_session.request(
                method, url, json=data, headers=headers, timeout=self._REQUEST_TIMEOUT
            ) as response:
                # Check status BEFORE parsing JSON
                if response.status != 401:
                    self._raise_for_status(response.status)
                    result: dict[str, Any] = await response.json()
                    handle_pandora_error(result)
                    return result
        except aiohttp.ContentTypeError as err:
            # Non-JSON body: a data problem, not a connectivity one.
            raise InvalidDataError("Invalid response from Pandora") from err
        except (aiohttp.ClientError, TimeoutError) as err:
            raise ProviderUnavailableError("Unable to connect to Pandora") from err
        except (ValueError, KeyError) as err:
            raise InvalidDataError("Invalid response from Pandora") from err

        # HTTP 401: auth token expired. The response has been released at this
        # point, so re-authentication does not keep its connection checked out.
        if not retry:
            raise LoginFailed("Pandora authentication failed after retry")
        await self._login()
        return await self._api_request(method, url, data, retry=False)

    @staticmethod
    def _pick_art_url(art: list[dict[str, Any]] | None) -> str | None:
        """Return the 500px artwork URL if present, else the last entry's URL."""
        if not art:
            return None
        preferred = next(
            (item["url"] for item in art if item.get("size") == 500 and item.get("url")),
            None,
        )
        return preferred or art[-1].get("url")

    @staticmethod
    def _track_length(track: dict[str, Any]) -> int:
        """Return a track's ``trackLength`` as an int (0 when missing or invalid)."""
        try:
            return int(track.get("trackLength") or 0)
        except (TypeError, ValueError):
            return 0

    @use_cache(3600)
    async def get_radio(self, prov_radio_id: str) -> Radio:
        """Get single radio station details.

        Looks the station up among the library stations so the real name and
        artwork are returned. Falls back to a placeholder item when the
        station is not found or the lookup fails.
        """
        try:
            async for station in self.get_library_radios():
                if station.item_id == prov_radio_id:
                    return station
        except (
            LoginFailed,
            MediaNotFoundError,
            InvalidDataError,
            ProviderUnavailableError,
        ) as err:
            self.logger.debug("Could not look up Pandora station %s: %s", prov_radio_id, err)

        return Radio(
            item_id=prov_radio_id,
            provider=self.instance_id,
            name=f"Pandora Station {prov_radio_id}",
            provider_mappings={
                ProviderMapping(
                    item_id=prov_radio_id,
                    provider_domain=self.domain,
                    provider_instance=self.instance_id,
                )
            },
        )

    async def get_library_radios(self) -> AsyncGenerator[Radio, None]:
        """Retrieve library/subscribed radio stations from the provider."""
        response = await self._api_request(
            "POST",
            STATIONS_ENDPOINT,
            data={
                "pageSize": self._STATIONS_PAGE_SIZE,
            },
        )

        stations = response.get("stations") or []
        self.logger.debug("Retrieved %d stations from Pandora", len(stations))

        for station in stations:
            station_id = station["stationId"]
            images = None
            if art_url := self._pick_art_url(station.get("art")):
                images = UniqueList(
                    [
                        MediaItemImage(
                            type=ImageType.THUMB,
                            path=art_url,
                            provider=self.instance_id,
                            remotely_accessible=True,
                        )
                    ]
                )
            yield Radio(
                item_id=station_id,
                provider=self.instance_id,
                name=station["name"],
                metadata=MediaItemMetadata(images=images),
                provider_mappings={
                    ProviderMapping(
                        item_id=station_id,
                        provider_domain=self.domain,
                        provider_instance=self.instance_id,
                    )
                },
            )

    async def get_stream_details(self, item_id: str, media_type: MediaType) -> StreamDetails:
        """Get streamdetails for a radio station.

        The stream is a fixed list of placeholder parts pointing at this
        provider's dynamic stream route, one per sequential track number.
        """
        if media_type != MediaType.RADIO:
            raise MediaNotFoundError(f"Unsupported media type: {media_type}")

        parts = [
            MultiPartPath(
                path=f"{self.mass.streams.base_url}/{self.instance_id}_stream?"
                f"station_id={item_id}&track_num={i}"
            )
            for i in range(self._STREAM_PART_COUNT)
        ]
        return StreamDetails(
            provider=self.instance_id,
            item_id=item_id,
            audio_format=AudioFormat(
                content_type=ContentType.AAC,
            ),
            media_type=MediaType.RADIO,
            stream_type=StreamType.HTTP,
            path=parts,
            can_seek=False,
            allow_seek=False,
            stream_metadata=StreamMetadata(
                title="Pandora Radio",
            ),
            stream_metadata_update_callback=self._update_stream_metadata,
            stream_metadata_update_interval=5,  # Check every 5 seconds
        )

    async def _get_fragment_data(
        self, session: PandoraStationSession, fragment_index: int
    ) -> dict[str, Any]:
        """Fetch one playlist fragment and map its music tracks into the session.

        A cached fragment is returned without a request. A newly fetched
        fragment is stored in ``session.fragments``, and each track whose
        title does not contain "curator message" is appended to
        ``session.track_map`` with its start offset in
        ``session.cumulative_times``. Offsets continue from the last mapped
        track, so fragments must be fetched in order
        (``fragment_index == len(session.fragments)``).

        :param session: The station session to fill.
        :param fragment_index: Index of the fragment to fetch.
        :raises InvalidDataError: Pandora returned malformed data (also logged).
        """
        if fragment_index < len(session.fragments):
            cached = session.fragments[fragment_index]
            if cached is not None:
                return cached

        fragment_request = {
            "stationId": session.station_id,
            "isStationStart": fragment_index == 0,
            "fragmentRequestReason": "Normal",
            "audioFormat": "aacplus",
            "startingAtTrackId": None,
            "onDemandArtistMessageArtistUidHex": None,
            "onDemandArtistMessageIdHex": None,
        }

        try:
            result = await self._api_request(
                "POST",
                PLAYLIST_FRAGMENT_ENDPOINT,
                data=fragment_request,
            )
        except InvalidDataError as err:
            self.logger.error("Invalid fragment data for station %s: %s", session.station_id, err)
            raise

        # Store in session cache
        while len(session.fragments) <= fragment_index:
            session.fragments.append(None)
        session.fragments[fragment_index] = result

        # Continue the timeline where the last mapped music track ends.
        if session.cumulative_times:
            last_music_track_num = len(session.track_map) - 1
            current_offset = session.cumulative_times[-1] + session.get_track_duration(
                last_music_track_num
            )
        else:
            current_offset = 0

        for track_idx, track in enumerate(result.get("tracks") or []):
            title = track.get("songTitle") or ""
            # Curator messages are not streamed, so they get no track number.
            if "curator message" in title.lower():
                continue
            session.track_map.append((fragment_index, track_idx))
            session.cumulative_times.append(current_offset)
            current_offset += self._track_length(track)

        return result

    async def _ensure_track_mapped(
        self, session: PandoraStationSession, music_track_num: int
    ) -> bool:
        """Fetch fragments in order until ``music_track_num`` is mapped.

        Returns False when Pandora returns ``_MAX_EMPTY_FRAGMENTS`` fragments in
        a row without any music track, instead of looping forever.
        """
        empty_fetches = 0
        while music_track_num >= len(session.track_map):
            mapped_before = len(session.track_map)
            await self._get_fragment_data(session, len(session.fragments))
            if len(session.track_map) > mapped_before:
                empty_fetches = 0
                continue
            empty_fetches += 1
            if empty_fetches >= self._MAX_EMPTY_FRAGMENTS:
                self.logger.warning(
                    "No playable tracks in %d consecutive fragments for station %s",
                    empty_fetches,
                    session.station_id,
                )
                return False
        return True

    @staticmethod
    def _lookup_track(
        session: PandoraStationSession, music_track_num: int
    ) -> dict[str, Any] | None:
        """Return the cached track data for a mapped music track, or None."""
        if not 0 <= music_track_num < len(session.track_map):
            return None
        frag_idx, track_idx = session.track_map[music_track_num]
        if frag_idx >= len(session.fragments) or not (fragment := session.fragments[frag_idx]):
            return None
        tracks = fragment.get("tracks") or []
        if track_idx >= len(tracks):
            return None
        track: dict[str, Any] = tracks[track_idx]
        return track

    async def _handle_stream_request(self, request: web.Request) -> web.Response:
        """Handle dynamic stream request.

        Map track numbers to Pandora fragments and redirect to audio URLs.
        Query parameters: ``station_id`` and a non-negative int ``track_num``.
        """
        if not (station_id := request.query.get("station_id")):
            return web.Response(status=400, text="Missing station_id")
        if not (track_num_str := request.query.get("track_num")):
            return web.Response(status=400, text="Missing track_num")

        try:
            music_track_num = int(track_num_str)
        except ValueError:
            return web.Response(status=400, text="Invalid track_num")
        if music_track_num < 0:
            # A negative index would wrap around track_map (or raise IndexError).
            return web.Response(status=400, text="Invalid track_num")

        # Get or create session with LRU eviction (also creates its fetch lock).
        session = self._get_or_create_session(station_id)
        fetch_lock = self._fetch_locks[station_id]

        try:
            # Fetching appends to the session's lists across awaits; concurrent
            # requests for one station would otherwise fetch the same fragment
            # index twice and corrupt track_map.
            async with fetch_lock:
                if not await self._ensure_track_mapped(session, music_track_num):
                    return web.Response(status=404, text="Track unavailable")
        except (MediaNotFoundError, InvalidDataError) as err:
            self.logger.error("Stream error: %s", err)
            return web.Response(status=404, text="Stream unavailable")
        except (ProviderUnavailableError, LoginFailed) as err:
            self.logger.error("Pandora service unavailable: %s", err)
            return web.Response(status=503, text="Service temporarily unavailable")

        track = self._lookup_track(session, music_track_num)
        if track is None:
            self.logger.error(
                "Track %d of station %s is not available in its fragment",
                music_track_num,
                station_id,
            )
            return web.Response(status=404, text="Track unavailable")

        if not (audio_url := track.get("audioURL")):
            self.logger.error("No audio URL in track data")
            return web.Response(status=404, text="Track unavailable")

        # Redirect to the actual audio URL
        return web.Response(status=302, headers={"Location": audio_url})

    def _get_or_create_session(self, station_id: str) -> PandoraStationSession:
        """Return the station's session, creating it and its fetch lock if needed.

        At most ``_MAX_SESSIONS`` sessions are kept; creating another evicts
        the session with the oldest ``last_accessed`` time.
        """
        session = self._sessions.get(station_id)
        if session is None:
            if len(self._sessions) >= self._MAX_SESSIONS:
                oldest = min(self._sessions.values(), key=lambda s: s.last_accessed)
                self.logger.debug("Evicting session for station %s", oldest.station_id)
                del self._sessions[oldest.station_id]
                self._fetch_locks.pop(oldest.station_id, None)
            session = self._sessions[station_id] = PandoraStationSession(station_id)
            self._fetch_locks[station_id] = asyncio.Lock()

        session.last_accessed = time.time()
        return session

    async def search(
        self,
        search_query: str,
        media_types: list[MediaType],
        limit: int = 25,
    ) -> SearchResults:
        """Search library radio stations by name.

        A station matches when ``compare_strings`` considers its name equal
        to the query, or when the query is a case-insensitive substring of
        the name.
        """
        # Search limited to library stations (API search requires legacy endpoints)
        if MediaType.RADIO not in media_types or limit <= 0:
            return SearchResults()

        needle = search_query.strip().casefold()
        results: list[Radio] = []

        async for station in self.get_library_radios():
            name = station.name or ""
            if compare_strings(name, search_query) or (
                bool(needle) and needle in name.casefold()
            ):
                results.append(station)
                if len(results) >= limit:
                    break

        return SearchResults(radio=results)

    async def _update_stream_metadata(
        self, streamdetails: StreamDetails, elapsed_time: int
    ) -> None:
        """Update stream metadata based on elapsed playback time.

        The current track is the mapped track whose [start, end) window from
        ``cumulative_times`` contains ``elapsed_time``. Nothing changes when
        the station has no session, the position is beyond the mapped tracks,
        or the title and artist already match.
        """
        session = self._sessions.get(streamdetails.item_id)
        if session is None:
            return
        session.last_accessed = time.time()

        metadata = streamdetails.stream_metadata
        if not metadata or not session.track_map or not session.cumulative_times:
            return

        starts = session.cumulative_times
        current_track_num = None
        for idx, start_time in enumerate(starts):
            if idx + 1 < len(starts):
                end_time = starts[idx + 1]
            else:
                end_time = start_time + session.get_track_duration(idx)
            if start_time <= elapsed_time < end_time:
                current_track_num = idx
                break

        if current_track_num is None:
            return

        track = self._lookup_track(session, current_track_num)
        if track is None:
            return

        title = track.get("songTitle", "Unknown Song")
        artist = track.get("artistName", "Unknown Artist")
        # Compare artist too: consecutive songs may share a title.
        if metadata.title == title and metadata.artist == artist:
            return

        metadata.title = title
        metadata.artist = artist
        metadata.album = track.get("albumTitle")
        metadata.image_url = self._pick_art_url(track.get("albumArt"))
        metadata.duration = self._track_length(track) or None
        metadata.uri = track.get("songDetailURL")
510