music-assistant-server

5.8 KB · PY
test_api_client.py
5.8 KB · 169 lines · python
1"""Unit tests for YandexMusicClient (api_client.py)."""
2
3from __future__ import annotations
4
5from unittest import mock
6
7import pytest
8from music_assistant_models.errors import ResourceTemporarilyUnavailable
9from yandex_music.exceptions import NetworkError
10
11from music_assistant.providers.yandex_music.api_client import YandexMusicClient
12
13
def _make_client() -> tuple[YandexMusicClient, mock.AsyncMock]:
    """Build a YandexMusicClient whose underlying ClientAsync is an AsyncMock.

    The client is constructed with a fake token, and its private ``_client``
    and ``_user_id`` attributes are overwritten with the mock and a fixed id
    (12345).

    :return: Tuple of (YandexMusicClient, mock_underlying_client).
    """
    fake_backend = mock.AsyncMock()
    yandex_client = YandexMusicClient(token="fake_token")
    yandex_client._client = fake_backend
    yandex_client._user_id = 12345
    return yandex_client, fake_backend
24
25
26# -- get_liked_albums: batching -------------------------------------------------
27
28
async def test_get_liked_albums_batching() -> None:
    """get_liked_albums resolves liked album stubs with one batched client.albums() call."""
    client, underlying = _make_client()
    album_ids = (1, 2, 3)

    def _like_for(aid: int) -> object:
        # A "like" entry wrapping a bare album stub that has no cover_uri.
        stub = type("Album", (), {"id": aid, "cover_uri": None})()
        return type("Like", (), {"album": stub})()

    likes = [_like_for(aid) for aid in album_ids]
    # The richer album objects the batch endpoint hands back.
    full_albums = [
        type("Album", (), {"id": aid, "cover_uri": f"cover_{aid}"})() for aid in album_ids
    ]

    underlying.users_likes_albums = mock.AsyncMock(return_value=likes)
    underlying.albums = mock.AsyncMock(return_value=full_albums)

    result = await client.get_liked_albums()

    # Exactly one batched call, with the ids passed as strings.
    underlying.albums.assert_awaited_once_with([str(aid) for aid in album_ids])
    assert result == full_albums
    assert None not in [album.cover_uri for album in result]
53
54
async def test_get_liked_albums_batch_fallback_on_network_error() -> None:
    """A NetworkError from client.albums() makes get_liked_albums use the minimal liked stubs."""
    client, underlying = _make_client()

    stubs = [type("Album", (), {"id": aid, "cover_uri": None})() for aid in (10, 20)]
    underlying.users_likes_albums = mock.AsyncMock(
        return_value=[type("Like", (), {"album": stub})() for stub in stubs]
    )
    # An exception instance as side_effect is raised on every await.
    underlying.albums = mock.AsyncMock(side_effect=NetworkError("timeout"))

    result = await client.get_liked_albums()

    # The result should be built from the album stubs carried by the likes.
    assert len(result) == 2
    assert sorted(album.id for album in result) == [10, 20]
74
75
76# -- get_tracks: retry on NetworkError -------------------------------------------
77
78
async def test_get_tracks_retry_on_network_error_then_success() -> None:
    """get_tracks retries after a NetworkError and returns the second attempt's result."""
    client, underlying = _make_client()

    expected = type("Track", (), {"id": 400, "title": "Test Track"})()
    # The first await raises; the second returns the track list.
    outcomes = [NetworkError("timeout"), [expected]]
    underlying.tracks = mock.AsyncMock(side_effect=outcomes)

    assert await client.get_tracks(["400"]) == [expected]
    assert underlying.tracks.await_count == 2
90
91
async def test_get_tracks_retry_on_network_error_both_fail() -> None:
    """Two consecutive NetworkErrors surface as ResourceTemporarilyUnavailable."""
    client, underlying = _make_client()
    errors = [NetworkError("timeout"), NetworkError("timeout again")]
    underlying.tracks = mock.AsyncMock(side_effect=errors)

    with pytest.raises(ResourceTemporarilyUnavailable):
        await client.get_tracks(["400"])

    # Both attempts hit the backend before giving up.
    assert underlying.tracks.await_count == 2
104
105
106# -- get_my_wave_tracks --------------------------------------------------------
107
108
async def test_get_my_wave_tracks_returns_tracks_and_batch_id() -> None:
    """get_my_wave_tracks calls rotor_station_tracks and returns ordered tracks and batch_id.

    The sequence stub and the full track share id 100, so checking the id alone
    cannot show which object came back. The title check, which only the full
    track from client.tracks() has, pins that the full track is returned.
    """
    client, underlying = _make_client()

    seq_track = type("TrackShort", (), {"id": 100, "track_id": 100})()
    sequence_item = type("SequenceItem", (), {"track": seq_track})()
    result_obj = type(
        "StationTracksResult",
        (),
        {"sequence": [sequence_item], "batch_id": "batch_abc"},
    )()
    underlying.rotor_station_tracks = mock.AsyncMock(return_value=result_obj)

    full_track = type("Track", (), {"id": 100, "title": "My Wave Track"})()
    underlying.tracks = mock.AsyncMock(return_value=[full_track])

    tracks, batch_id = await client.get_my_wave_tracks()

    underlying.rotor_station_tracks.assert_awaited_once()
    # The empty-sequence test asserts this fetch is skipped; here it must happen.
    underlying.tracks.assert_awaited_once()
    assert batch_id == "batch_abc"
    assert len(tracks) == 1
    assert tracks[0].id == 100
    # Only the full track carries a title; the TrackShort stub does not.
    assert tracks[0].title == "My Wave Track"
131
132
async def test_get_my_wave_tracks_empty_sequence_returns_empty() -> None:
    """An empty rotor sequence yields no tracks and a None batch_id, with no tracks() fetch."""
    client, underlying = _make_client()
    underlying.rotor_station_tracks = mock.AsyncMock(
        return_value=type("StationTracksResult", (), {"sequence": [], "batch_id": None})()
    )

    tracks, batch_id = await client.get_my_wave_tracks()

    assert tracks == []
    assert batch_id is None
    # With nothing in the sequence, the full-track lookup must not be awaited.
    underlying.tracks.assert_not_awaited()
145
146
async def test_send_rotor_station_feedback_posts() -> None:
    """send_rotor_station_feedback issues one POST to the station's feedback endpoint."""
    client, underlying = _make_client()
    underlying.base_url = "https://api.music.yandex.net"
    underlying._request = mock.AsyncMock()

    sent = await client.send_rotor_station_feedback(
        "user:onyourwave",
        "trackStarted",
        track_id="12345",
        batch_id="batch_xyz",
    )
    assert sent is True

    post = underlying._request.post
    post.assert_awaited_once()
    # Positional arguments of the POST: (url, body, ...).
    positional = post.await_args.args
    assert "rotor/station/user:onyourwave/feedback" in positional[0]
    payload = positional[1]
    assert payload["type"] == "trackStarted"
    assert payload["trackId"] == "12345"
    assert payload["batchId"] == "batch_xyz"
169