/
/
/
"""Unit tests for YandexMusicClient (api_client.py)."""
2
3from __future__ import annotations
4
5import base64
6import hashlib
7import hmac
8import re
9from unittest import mock
10
11import pytest
12from music_assistant_models.errors import ResourceTemporarilyUnavailable
13from yandex_music.exceptions import NetworkError
14from yandex_music.rotor.dashboard import Dashboard
15from yandex_music.rotor.station_result import StationResult
16from yandex_music.utils.sign_request import DEFAULT_SIGN_KEY
17
18from music_assistant.providers.yandex_music.api_client import (
19 GET_FILE_INFO_CODECS,
20 YandexMusicClient,
21)
22
23
def _make_client() -> tuple[YandexMusicClient, mock.AsyncMock]:
    """Build a YandexMusicClient whose underlying ClientAsync is a mock.

    ``connect`` is swapped for a stub that re-installs the same mock client and
    user id, so a ``_reconnect()`` triggered by the code under test ends up back
    on the mock instead of opening a real connection.

    :return: Tuple of (YandexMusicClient, mock_underlying_client).
    """
    fake_user_id = 12345
    yandex_client = YandexMusicClient(token="fake_token")
    underlying = mock.AsyncMock()

    def _install_mock() -> None:
        # Point the wrapper at the mock and give it a fixed user id.
        yandex_client._client = underlying
        yandex_client._user_id = fake_user_id

    async def _fake_connect() -> bool:
        _install_mock()
        return True

    _install_mock()
    yandex_client.connect = _fake_connect  # type: ignore[method-assign]
    return yandex_client, underlying
44
45
# -- get_liked_albums: batching -------------------------------------------------
47
48
async def test_get_liked_albums_batching() -> None:
    """Liked albums are re-fetched in a single client.albums() call for full metadata."""
    client, underlying = _make_client()
    album_ids = (1, 2, 3)

    # Like entries only carry stub albums that have no cover_uri.
    likes = [
        type("Like", (), {"album": type("Album", (), {"id": aid, "cover_uri": None})()})()
        for aid in album_ids
    ]
    # Complete album objects, as client.albums() returns them.
    full_albums = [
        type("Album", (), {"id": aid, "cover_uri": f"cover_{aid}"})() for aid in album_ids
    ]

    underlying.users_likes_albums = mock.AsyncMock(return_value=likes)
    underlying.albums = mock.AsyncMock(return_value=full_albums)

    result = await client.get_liked_albums()

    underlying.albums.assert_awaited_once_with(["1", "2", "3"])
    assert result == full_albums
    assert all(album.cover_uri is not None for album in result)
73
74
async def test_get_liked_albums_batch_fallback_on_network_error() -> None:
    """A NetworkError from client.albums() falls back to the stub albums in the likes."""
    client, underlying = _make_client()

    stubs = [type("Album", (), {"id": aid, "cover_uri": None})() for aid in (10, 20)]
    underlying.users_likes_albums = mock.AsyncMock(
        return_value=[type("Like", (), {"album": stub})() for stub in stubs]
    )
    underlying.albums = mock.AsyncMock(side_effect=NetworkError("timeout"))

    result = await client.get_liked_albums()

    # Only the minimal album data taken from the likes is available.
    assert len(result) == 2
    assert {album.id for album in result} == {10, 20}
94
95
# -- get_tracks: retry on NetworkError -------------------------------------------
97
98
async def test_get_tracks_retry_on_network_error_then_success() -> None:
    """One NetworkError is retried, and the second attempt's tracks are returned."""
    client, underlying = _make_client()

    expected_track = type("Track", (), {"id": 400, "title": "Test Track"})()
    # First await raises, second await returns the track list.
    attempts = [NetworkError("timeout"), [expected_track]]
    underlying.tracks = mock.AsyncMock(side_effect=attempts)

    result = await client.get_tracks(["400"])

    assert result == [expected_track]
    assert underlying.tracks.await_count == 2
110
111
async def test_get_tracks_retry_on_network_error_both_fail() -> None:
    """Two consecutive NetworkErrors surface as ResourceTemporarilyUnavailable."""
    client, underlying = _make_client()
    failures = [NetworkError("timeout"), NetworkError("timeout again")]
    underlying.tracks = mock.AsyncMock(side_effect=failures)

    with pytest.raises(ResourceTemporarilyUnavailable):
        await client.get_tracks(["400"])

    # The initial call plus exactly one retry.
    assert underlying.tracks.await_count == 2
124
125
# -- get_my_wave_tracks / send_rotor_station_feedback ---------------------------
127
128
async def test_get_my_wave_tracks_returns_tracks_and_batch_id() -> None:
    """get_my_wave_tracks queries rotor_station_tracks and returns tracks plus batch_id."""
    client, underlying = _make_client()

    short_track = type("TrackShort", (), {"id": 100, "track_id": 100})()
    station_result = type(
        "StationTracksResult",
        (),
        {
            "sequence": [type("SequenceItem", (), {"track": short_track})()],
            "batch_id": "batch_abc",
        },
    )()
    underlying.rotor_station_tracks = mock.AsyncMock(return_value=station_result)
    underlying.tracks = mock.AsyncMock(
        return_value=[type("Track", (), {"id": 100, "title": "My Wave Track"})()]
    )

    tracks, batch_id = await client.get_my_wave_tracks()

    underlying.rotor_station_tracks.assert_awaited_once()
    assert batch_id == "batch_abc"
    assert len(tracks) == 1
    assert tracks[0].id == 100
151
152
async def test_get_my_wave_tracks_empty_sequence_returns_empty() -> None:
    """An empty rotor sequence yields ([], None) without fetching any full tracks."""
    client, underlying = _make_client()

    empty_result = type("StationTracksResult", (), {"sequence": [], "batch_id": None})()
    underlying.rotor_station_tracks = mock.AsyncMock(return_value=empty_result)

    tracks, batch_id = await client.get_my_wave_tracks()

    assert tracks == []
    assert batch_id is None
    # No sequence items means there is nothing to resolve via client.tracks().
    underlying.tracks.assert_not_awaited()
165
166
async def test_send_rotor_station_feedback_posts() -> None:
    """send_rotor_station_feedback POSTs the event to the station's feedback endpoint."""
    client, underlying = _make_client()
    underlying.base_url = "https://api.music.yandex.net"
    underlying._request = mock.AsyncMock()

    ok = await client.send_rotor_station_feedback(
        "user:onyourwave",
        "trackStarted",
        track_id="12345",
        batch_id="batch_xyz",
    )

    assert ok is True
    post = underlying._request.post
    post.assert_awaited_once()
    # Positional arguments of the POST: (url, body, ...).
    positional = post.await_args.args
    url, body = positional[0], positional[1]
    assert "rotor/station/user:onyourwave/feedback" in url
    assert body["type"] == "trackStarted"
    assert body["trackId"] == "12345"
    assert body["batchId"] == "batch_xyz"
189
190
# -- LRC regex tests ---------------------------------------------------------
192
193
def test_lrc_regex_matches_valid_synced_lyrics() -> None:
    """The LRC timestamp regex finds well-formed [mm:ss], [mm:ss.xx] and [mm:ss.xxx] tags.

    The check uses an unanchored search (no ^), mirroring api_client.py, which
    deliberately accepts a timestamp anywhere in the text so that LRC metadata
    lines such as [ar:Artist] before the first timestamp do not prevent
    detection.

    NOTE(review): the pattern below is a local copy; if api_client.py exposes
    it as a constant, importing that would keep this test in sync with the
    implementation.
    """
    lrc_timestamp = re.compile(r"\[\d{2}:\d{2}(?:\.\d{2,3})?\]")

    should_match = (
        "[00:12]",  # minutes and seconds only, no fractional part
        "[00:12.34]",  # 2 fractional digits: the lower bound of \d{2,3}
        "[00:12.345]",  # 3 fractional digits: the upper bound of \d{2,3}
        "[12:34]",  # another whole-second timestamp
        "[99:59.99]",  # largest two-digit minute and second values
        "Some [00:12] text",  # embedded in text: found because the search is unanchored
    )

    for text in should_match:
        assert lrc_timestamp.search(text), f"Should match: {text}"
216
217
def test_lrc_regex_rejects_invalid_formats() -> None:
    """The LRC timestamp regex finds nothing in malformed tags.

    Covers missing brackets, one-digit minutes or seconds, and a fractional
    part whose length is outside the allowed 2-3 digits.
    """
    lrc_timestamp = re.compile(r"\[\d{2}:\d{2}(?:\.\d{2,3})?\]")

    must_not_match = (
        "[00:12",  # closing bracket missing
        "00:12]",  # opening bracket missing
        "[0:12]",  # one-digit minutes
        "[00:1]",  # one-digit seconds
        "[00:12.1]",  # one fractional digit (2-3 required)
        "[00:12.1234]",  # four fractional digits (2-3 required)
    )

    for text in must_not_match:
        assert not lrc_timestamp.search(text), f"Should NOT match: {text}"
234
235
# -- HMAC sign construction tests --------------------------------------------
237
238
def test_hmac_sign_construction_explicit() -> None:
    """HMAC sign input is built explicitly, with commas stripped from the codecs.

    NOTE(review): this test rebuilds the sign inline and never calls
    YandexMusicClient, so it documents the expected construction rather than
    verifying the client's implementation.
    """
    timestamp = 1234567890
    track_id = "12345"

    codecs_for_sign = GET_FILE_INFO_CODECS.replace(",", "")
    assert "," not in codecs_for_sign

    # Concatenation order: timestamp, track id, "lossless", codecs, "encraw".
    param_string = f"{timestamp}{track_id}lossless{codecs_for_sign}encraw"
    assert param_string == f"1234567890{track_id}lossless{codecs_for_sign}encraw"

    digest = hmac.new(
        DEFAULT_SIGN_KEY.encode(),
        param_string.encode(),
        hashlib.sha256,
    ).digest()
    # A 32-byte SHA-256 digest base64-encodes to 44 chars ending in a single "=";
    # dropping the last character leaves 43 chars and no padding.
    sign = base64.b64encode(digest).decode()[:-1]

    assert len(sign) == 43
    assert not sign.endswith("=")
267
268
# -- get_dashboard_stations --------------------------------------------------
270
271
async def test_get_dashboard_stations_returns_personalized_stations() -> None:
    """get_dashboard_stations() returns stations from rotor/stations/dashboard.

    The sample station has both a ``name`` and a ``rup_title``; the test expects
    the reported name to be ``station.name``.
    """
    client, underlying = _make_client()

    # Minimal client object for de_json; it only provides report_unknown_fields.
    _de_client = type("C", (), {"report_unknown_fields": False})()

    # "Grustnoe" (Russian for "Sad"), written with \u escapes so that a lossy
    # re-encoding of this file cannot corrupt the expected value again.
    station_name = "\u0413\u0440\u0443\u0441\u0442\u043d\u043e\u0435"

    station_result = StationResult.de_json(
        {
            "station": {
                "id": {"type": "mood", "tag": "sad"},
                "name": station_name,
                "restrictions": {},
                "restrictions2": {},
                "full_image_url": None,
                "id_for_from": "mood-sad",
                "icon": None,
            },
            "settings": None,
            "settings2": None,
            "ad_params": None,
            "rup_title": "Sad Songs",
            "rup_description": "",
        },
        _de_client,
    )

    dashboard = mock.MagicMock(spec=Dashboard)
    dashboard.stations = [station_result]
    underlying.rotor_stations_dashboard.return_value = dashboard

    stations = await client.get_dashboard_stations()

    assert len(stations) == 1
    station_id, name, _image_url = stations[0]
    assert station_id == "mood:sad"
    assert name == station_name  # station.name takes priority over rup_title
    # The endpoint is an AsyncMock: check it was awaited, not merely called.
    underlying.rotor_stations_dashboard.assert_awaited_once()
309
310
async def test_get_dashboard_stations_empty_on_error() -> None:
    """A NetworkError from the dashboard endpoint yields an empty station list."""
    client, underlying = _make_client()
    underlying.rotor_stations_dashboard.side_effect = NetworkError("timeout")

    assert await client.get_dashboard_stations() == []
319
320
async def test_get_dashboard_stations_skips_user_type() -> None:
    """Personal stations of type 'user' (e.g. My Wave) are dropped from the dashboard."""
    client, underlying = _make_client()

    station_payload = {
        "id": {"type": "user", "tag": "onyourwave"},
        "name": "My Wave",
        "restrictions": {},
        "restrictions2": {},
        "full_image_url": None,
        "id_for_from": "user-onyourwave",
        "icon": None,
    }
    personal_station = StationResult.de_json(
        {
            "station": station_payload,
            "settings": None,
            "settings2": None,
            "ad_params": None,
            "rup_title": "My Wave",
            "rup_description": "",
        },
        # Minimal client object for de_json; it only provides report_unknown_fields.
        type("C", (), {"report_unknown_fields": False})(),
    )

    underlying.rotor_stations_dashboard.return_value = mock.MagicMock(
        spec=Dashboard, stations=[personal_station]
    )

    assert await client.get_dashboard_stations() == []
354