/
/
/
1"""Tests for remote access feature."""
2
3from unittest.mock import AsyncMock, Mock, patch
4
5import pytest
6from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection
7from aiortc.rtcdtlstransport import RTCCertificate
8
9from music_assistant.controllers.webserver.remote_access import RemoteAccessInfo
10from music_assistant.controllers.webserver.remote_access.gateway import (
11 WebRTCGateway,
12 WebRTCSession,
13)
14from music_assistant.helpers.webrtc_certificate import (
15 _generate_certificate,
16 create_peer_connection_with_certificate,
17 get_remote_id_from_certificate,
18)
19
20
@pytest.fixture
def mock_certificate() -> Mock:
    """Provide a stand-in RTCCertificate exposing exactly one SHA-256 fingerprint."""
    fingerprint = Mock()
    fingerprint.algorithm = "sha-256"
    # 32 colon-separated hex octets, i.e. the length of a SHA-256 digest.
    fingerprint.value = (
        "AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99:"
        "AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99"
    )

    certificate = Mock()
    certificate.getFingerprints.return_value = [fingerprint]
    return certificate
33
34
async def test_get_remote_id_from_certificate(mock_certificate: Mock) -> None:
    """Test remote ID generation from certificate fingerprint.

    The ID must be 26 characters drawn only from the base32 alphabet (uppercase
    letters and the digits 2-7, no ``=`` padding), and deriving it twice from the
    same certificate must give the same value.
    """
    remote_id = get_remote_id_from_certificate(mock_certificate)

    # Should be base32 encoded, uppercase, no padding.
    # ``isalnum()`` / ``upper()`` alone would also accept 0, 1, 8, 9 and other
    # non-base32 characters, so check membership in the exact alphabet.
    # Assumes the standard RFC 4648 alphabet (as emitted by base64.b32encode).
    base32_alphabet = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567")
    assert set(remote_id) <= base32_alphabet
    # 128 bits = 16 bytes -> 26 base32 chars (without padding)
    assert len(remote_id) == 26
    # The same certificate must always map to the same remote ID.
    assert get_remote_id_from_certificate(mock_certificate) == remote_id
44
45
async def test_remote_access_info_dataclass() -> None:
    """Check that RemoteAccessInfo keeps every constructor field as given."""
    field_values = {
        "enabled": True,
        "running": True,
        "connected": False,
        "remote_id": "VVPN3TLP34YMGIZDINCEKQKSIR",
        "using_ha_cloud": False,
        "signaling_url": "wss://signaling.music-assistant.io/ws",
    }
    info = RemoteAccessInfo(**field_values)

    for field_name, expected in field_values.items():
        actual = getattr(info, field_name)
        if isinstance(expected, bool):
            # Flags are compared by identity (``is True`` / ``is False``).
            assert actual is expected
        else:
            assert actual == expected
63
64
async def test_webrtc_gateway_initialization(mock_certificate: Mock) -> None:
    """Check that constructor arguments are stored and the gateway starts idle."""
    signaling_url = "wss://test.example.com/ws"
    local_ws_url = "ws://localhost:8095/ws"
    gateway = WebRTCGateway(
        http_session=Mock(),
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
        signaling_url=signaling_url,
        local_ws_url=local_ws_url,
    )

    # Configuration is exposed exactly as passed in.
    assert gateway._remote_id == "TEST-REMOTE-ID"
    assert gateway.signaling_url == signaling_url
    assert gateway.local_ws_url == local_ws_url

    # Nothing has been started or connected yet.
    assert gateway.is_running is False
    assert gateway.is_connected is False

    # Some ICE servers are configured even though none were passed.
    assert len(gateway.ice_servers) > 0
82
83
async def test_webrtc_gateway_custom_ice_servers(mock_certificate: Mock) -> None:
    """Check that explicitly supplied ICE servers are exposed unchanged."""
    stun_entry = {"urls": "stun:custom.stun.server:3478"}
    turn_entry = {
        "urls": "turn:custom.turn.server:3478",
        "username": "user",
        "credential": "pass",
    }
    servers = [stun_entry, turn_entry]

    gateway = WebRTCGateway(
        http_session=Mock(),
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
        ice_servers=servers,
    )

    assert gateway.ice_servers == servers
100
101
async def test_webrtc_gateway_start_stop(mock_certificate: Mock) -> None:
    """Test WebRTCGateway start and stop.

    ``stop()`` runs in a ``finally`` block, so a failed assertion after
    ``start()`` does not leave the gateway's ``_run_task`` running past the test.
    """
    mock_session = Mock()
    gateway = WebRTCGateway(
        http_session=mock_session,
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    # Mock the _run method to avoid actual connection
    with patch.object(gateway, "_run", new_callable=AsyncMock):
        await gateway.start()
        try:
            assert gateway.is_running is True
            assert gateway._run_task is not None
        finally:
            await gateway.stop()
        assert gateway.is_running is False
119
120
async def test_webrtc_gateway_handle_registration_message(mock_certificate: Mock) -> None:
    """Feed a ``registered`` confirmation to the signaling handler.

    The test passes as long as handling the message raises nothing.
    """
    gateway = WebRTCGateway(
        http_session=Mock(),
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )
    # Stand-in for the signaling WebSocket connection.
    gateway._signaling_ws = Mock()

    await gateway._handle_signaling_message(
        {"type": "registered", "remoteId": "TEST-REMOTE-ID"}
    )
137
138
async def test_webrtc_gateway_handle_error_message(mock_certificate: Mock) -> None:
    """Feed an ``error`` message to the signaling handler.

    The test passes as long as handling the message raises nothing.
    """
    gateway = WebRTCGateway(
        http_session=Mock(),
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    await gateway._handle_signaling_message({"type": "error", "message": "Test error"})
151
152
async def test_webrtc_gateway_create_session(mock_certificate: Mock) -> None:
    """Test WebRTCGateway creates sessions for clients.

    The created session is closed in ``finally``, so a failed assertion does not
    leave the session open.
    """
    mock_session = Mock()
    gateway = WebRTCGateway(
        http_session=mock_session,
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    session_id = "test-session-123"
    await gateway._create_session(session_id)
    try:
        assert session_id in gateway.sessions
        session = gateway.sessions[session_id]
        assert session.session_id == session_id
        assert session.peer_connection is not None
    finally:
        # Cleanup always runs, even when an assertion above fails.
        await gateway._close_session(session_id)
171
172
async def test_webrtc_gateway_close_session(mock_certificate: Mock) -> None:
    """Check that closing a session removes it from the gateway's session map."""
    gateway = WebRTCGateway(
        http_session=Mock(),
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )
    target = "test-session-456"

    # Precondition: the session is registered once created.
    await gateway._create_session(target)
    assert target in gateway.sessions

    # After closing, it must no longer be tracked.
    await gateway._close_session(target)
    assert target not in gateway.sessions
188
189
async def test_webrtc_gateway_close_nonexistent_session(mock_certificate: Mock) -> None:
    """Closing an unknown session ID must be a silent no-op rather than an error."""
    gateway = WebRTCGateway(
        http_session=Mock(),
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )
    unknown_id = "nonexistent-session"

    await gateway._close_session(unknown_id)
201
202
async def test_webrtc_gateway_default_ice_servers(mock_certificate: Mock) -> None:
    """Test WebRTCGateway uses default ICE servers.

    An ICE server's ``urls`` entry may be a single URL string or a list of URL
    strings. Both shapes are flattened before looking for a STUN URL; otherwise
    a list would be checked by membership instead of substring.
    """
    mock_session = Mock()
    gateway = WebRTCGateway(
        http_session=mock_session,
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    assert len(gateway.ice_servers) > 0

    all_urls: list[str] = []
    for server in gateway.ice_servers:
        urls = server["urls"]
        # Wrap a bare string so it is not iterated character by character.
        all_urls.extend([urls] if isinstance(urls, str) else urls)

    # Should have at least one STUN server (plain ``stun:`` or TLS ``stuns:``)
    assert any(url.startswith(("stun:", "stuns:")) for url in all_urls)
215
216
async def test_webrtc_gateway_handle_client_connected(mock_certificate: Mock) -> None:
    """Test WebRTCGateway handles client-connected message.

    A ``client-connected`` signaling message must create a session keyed by its
    ``sessionId``. The session is closed in ``finally``, so a failed assertion
    does not leave it open.
    """
    mock_session = Mock()
    gateway = WebRTCGateway(
        http_session=mock_session,
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    session_id = "test-session"
    message = {"type": "client-connected", "sessionId": session_id}
    try:
        await gateway._handle_signaling_message(message)
        # Session should be created
        assert session_id in gateway.sessions
    finally:
        # Cleanup always runs. Closing a session that was never created is
        # tolerated by the gateway, so this is safe even on failure.
        await gateway._close_session(session_id)
234
235
async def test_webrtc_gateway_handle_client_disconnected(mock_certificate: Mock) -> None:
    """Test WebRTCGateway handles client-disconnected message.

    A ``client-disconnected`` signaling message must remove the matching session.
    If the handler fails to do so, the ``finally`` block still closes the session
    so it does not outlive the test.
    """
    mock_session = Mock()
    gateway = WebRTCGateway(
        http_session=mock_session,
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    # Create a session first
    session_id = "test-disconnect-session"
    await gateway._create_session(session_id)
    try:
        assert session_id in gateway.sessions

        # Handle disconnect
        message = {"type": "client-disconnected", "sessionId": session_id}
        await gateway._handle_signaling_message(message)

        # Session should be removed
        assert session_id not in gateway.sessions
    finally:
        # No-op when the handler already removed the session; closing an unknown
        # session is tolerated by the gateway.
        await gateway._close_session(session_id)
256
257
async def test_webrtc_gateway_reconnection_logic(mock_certificate: Mock) -> None:
    """Check that the reconnect delay starts at 10, doubles, and stays capped.

    NOTE(review): the doubling below re-implements the backoff formula inside the
    test instead of driving the gateway's own reconnect path. It therefore only
    pins the initial delay and the relationship to ``_max_reconnect_delay``;
    consider exercising the real reconnect code instead.
    """
    gateway = WebRTCGateway(
        http_session=Mock(),
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    def double_with_cap() -> None:
        # One simulated failed connection: double the delay, clamped to the max.
        gateway._current_reconnect_delay = min(
            gateway._current_reconnect_delay * 2, gateway._max_reconnect_delay
        )

    assert gateway._current_reconnect_delay == 10
    start_delay = gateway._current_reconnect_delay

    double_with_cap()
    assert gateway._current_reconnect_delay == start_delay * 2

    # Repeated failures must never push the delay past the cap.
    for _ in range(10):
        double_with_cap()
    assert gateway._current_reconnect_delay <= gateway._max_reconnect_delay
285
286
async def test_webrtc_gateway_session_data_structures() -> None:
    """Test WebRTCSession data structure.

    A freshly built session holds its ID and peer connection. The message queue
    must exist, and every other field listed below starts as ``None``. The peer
    connection is closed in ``finally``, so it is released even when an assertion
    fails.
    """
    config = RTCConfiguration()
    pc = RTCPeerConnection(configuration=config)

    try:
        session = WebRTCSession(session_id="test-123", peer_connection=pc)

        assert session.session_id == "test-123"
        assert session.peer_connection is pc
        assert session.data_channel is None
        assert session.local_ws is None
        assert session.message_queue is not None
        assert session.forward_to_local_task is None
        assert session.forward_from_local_task is None
    finally:
        # Cleanup always runs.
        await pc.close()
304
305
async def test_webrtc_gateway_handle_offer_without_session(mock_certificate: Mock) -> None:
    """Test WebRTCGateway handles offer for non-existent session gracefully.

    Handling the offer must not raise, and it must not implicitly register a
    session for the unknown ID.
    """
    mock_session = Mock()
    gateway = WebRTCGateway(
        http_session=mock_session,
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    # Try to handle offer for non-existent session (should not crash)
    session_id = "nonexistent-session"
    offer_data = {"sdp": "test-sdp", "type": "offer"}
    await gateway._handle_offer(session_id, offer_data)

    # A stray offer must not create a session as a side effect.
    assert session_id not in gateway.sessions
320
321
async def test_webrtc_gateway_handle_ice_candidate_without_session(mock_certificate: Mock) -> None:
    """Test WebRTCGateway handles ICE candidate for non-existent session gracefully.

    Handling the candidate must not raise, and it must not implicitly register a
    session for the unknown ID.
    """
    mock_session = Mock()
    gateway = WebRTCGateway(
        http_session=mock_session,
        remote_id="TEST-REMOTE-ID",
        certificate=mock_certificate,
    )

    # Try to handle ICE candidate for non-existent session (should not crash)
    session_id = "nonexistent-session"
    candidate_data = {
        "candidate": "candidate:1 1 UDP 1234 192.168.1.1 12345 typ host",
        "sdpMid": "0",
        "sdpMLineIndex": 0,
    }
    await gateway._handle_ice_candidate(session_id, candidate_data)

    # A stray candidate must not create a session as a side effect.
    assert session_id not in gateway.sessions
340
341
async def test_create_peer_connection_with_certificate() -> None:
    """Verify create_peer_connection_with_certificate installs exactly our certificate.

    The helper depends on aiortc's name-mangled private certificate list. This
    test first proves that attribute still exists, then checks that the custom
    certificate fully replaces the auto-generated one, which DTLS pinning
    relies on.
    """
    mangled_attr = "_RTCPeerConnection__certificates"

    # Guard against aiortc renaming its internals: if this fails, the helper
    # needs updating.
    probe = RTCPeerConnection()
    try:
        assert hasattr(probe, mangled_attr)
    finally:
        await probe.close()

    # Build a real certificate and make sure the helper installs only that one.
    private_key, x509_cert = _generate_certificate()
    certificate = RTCCertificate(key=private_key, cert=x509_cert)
    config = RTCConfiguration(
        iceServers=[RTCIceServer(urls="stun:stun.example.com:3478")]
    )

    pc = create_peer_connection_with_certificate(certificate, configuration=config)
    try:
        installed = getattr(pc, mangled_attr)
        assert len(installed) == 1
        assert installed[0] is certificate
    finally:
        await pc.close()
370