/
/
/
1"""WebSocket client handler for Music Assistant API."""
2
3from __future__ import annotations
4
5import asyncio
6import inspect
7import logging
8from concurrent import futures
9from contextlib import suppress
10from typing import TYPE_CHECKING, Any, Final
11
12from aiohttp import WSMsgType, web
13from music_assistant_models.api import (
14 CommandMessage,
15 ErrorResultMessage,
16 MessageType,
17 SuccessResultMessage,
18)
19from music_assistant_models.auth import AuthProviderType, User, UserRole
20from music_assistant_models.enums import EventType
21from music_assistant_models.errors import (
22 AuthenticationRequired,
23 InsufficientPermissions,
24 InvalidCommand,
25 InvalidToken,
26)
27
28from music_assistant.constants import HOMEASSISTANT_SYSTEM_USER, VERBOSE_LOG_LEVEL
29from music_assistant.helpers.api import APICommandHandler, parse_arguments
30
31from .helpers.auth_middleware import (
32 is_request_from_ingress,
33 set_current_token,
34 set_current_user,
35 set_sendspin_player_id,
36)
37from .helpers.auth_providers import get_ha_user_details, get_ha_user_role
38
39if TYPE_CHECKING:
40 from music_assistant_models.event import MassEvent
41
42 from music_assistant.controllers.webserver import WebserverController
43
# Upper bound on serialized messages waiting in a client's write queue; a
# client whose queue overflows gets its connection cancelled.
MAX_PENDING_MSG = 512
# Cancellation exception types from asyncio and concurrent.futures, grouped so
# both can be caught or suppressed together.
CANCELLATION_ERRORS: Final = (
    asyncio.CancelledError,
    futures.CancelledError,
)
46
47
class WebsocketClientHandler:
    """Handle an active websocket client connection.

    One instance serves one websocket connection. It:

    - prepares the websocket;
    - authenticates the client, either through the ``auth`` command or
      automatically from Ingress headers;
    - dispatches API commands;
    - forwards Music Assistant events.

    Outgoing messages are serialized and placed on a bounded queue that a
    dedicated writer task drains.
    """

    def __init__(self, webserver: WebserverController, request: web.Request) -> None:
        """Initialize an active connection.

        :param webserver: The webserver controller that received the request.
        :param request: The incoming HTTP request to upgrade to a websocket.
        """
        self.webserver = webserver
        self.mass = webserver.mass
        self.request = request
        self.wsock = web.WebSocketResponse(heartbeat=30)
        # Serialized outgoing messages; None is the sentinel that stops the writer.
        self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG)
        self._handle_task: asyncio.Task[Any] | None = None
        self._writer_task: asyncio.Task[None] | None = None
        self._logger = webserver.logger
        # Set after the auth command, or from the Ingress headers.
        self._authenticated_user: User | None = None
        self._current_token: str | None = None  # Will be set after auth command
        self._token_id: str | None = None  # Will be set after auth for tracking revocation
        self._sendspin_player_id: str | None = None  # Set if client is a sendspin web player
        self._is_ingress = is_request_from_ingress(request)
        self._events_unsub_callback: Any = None  # Will be set after authentication
        # Track WebRTC session ID if this is a WebRTC gateway connection
        self._webrtc_session_id: str | None = request.query.get("webrtc_session_id")
        # Try to dynamically detect the base_url of a client if proxied or behind Ingress
        self.base_url: str | None = None
        if forward_host := request.headers.get("X-Forwarded-Host"):
            ingress_path = request.headers.get("X-Ingress-Path", "")
            forward_proto = request.headers.get("X-Forwarded-Proto", request.protocol)
            self.base_url = f"{forward_proto}://{forward_host}{ingress_path}"

    async def disconnect(self) -> None:
        """Disconnect client.

        Cancels the connection handler and writer tasks, then waits for the
        writer task to finish.
        """
        self._cancel()
        if self._writer_task is not None:
            await self._writer_task

    async def handle_client(self) -> web.WebSocketResponse:
        """Handle a websocket response.

        Steps:

        1. Prepare the websocket (10 second timeout).
        2. Start the writer task and send the server info.
        3. Reject the client if onboarding is incomplete.
        4. Authenticate Ingress clients.
        5. Process incoming commands until the connection closes.

        Every exit after the writer task has started goes through the same
        shutdown path. That path queues the stop sentinel, waits for the
        writer to flush pending messages, and closes the socket. Because of
        this, early error responses such as "Setup required" are delivered,
        and the writer task is never left waiting on the queue.

        :return: The websocket response, handed back to aiohttp.
        """
        # ruff: noqa: PLR0915
        request = self.request
        wsock = self.wsock
        try:
            async with asyncio.timeout(10):
                await wsock.prepare(request)
        except TimeoutError:
            self._logger.warning("Timeout preparing request from %s", request.remote)
            return wsock

        self._logger.log(VERBOSE_LOG_LEVEL, "Connection from %s", request.remote)
        self._handle_task = asyncio.current_task()
        self._writer_task = self.mass.create_task(self._writer())

        disconnect_warn = None

        try:
            # send server(version) info when client connects
            server_info = self.mass.get_server_info()
            await self._send_message(server_info)

            # Block until onboarding is complete. Returning here still runs the
            # finally block below, which flushes this error before closing.
            if not self.webserver.auth.has_users and not self._is_ingress:
                await self._send_message(ErrorResultMessage("connection", 503, "Setup required"))
                return wsock

            # For Ingress connections, auto-create/link user and subscribe to events immediately.
            # For regular connections, events will be subscribed after successful authentication.
            # NOTE(review): an Ingress request without HA user headers has no user yet,
            # but is still subscribed here; confirm this is intended for the Ingress network.
            if self._is_ingress:
                await self._handle_ingress_auth()
                self._subscribe_to_events()

            while not wsock.closed:
                msg = await wsock.receive()

                if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
                    break

                if msg.type != WSMsgType.TEXT:
                    continue

                self._logger.log(VERBOSE_LOG_LEVEL, "Received: %s", msg.data)

                try:
                    command_msg = CommandMessage.from_json(msg.data)
                except ValueError:
                    disconnect_warn = f"Received invalid JSON: {msg.data}"
                    break

                await self._handle_command(command_msg)

        except asyncio.CancelledError:
            self._logger.debug("Connection closed by client")

        except Exception:
            self._logger.exception("Unexpected error inside websocket API")

        finally:
            # Handle connection shutting down.
            if self._events_unsub_callback:
                self._events_unsub_callback()
                self._logger.log(VERBOSE_LOG_LEVEL, "Unsubscribed from events")

            # Unregister from webserver tracking
            self.webserver.unregister_websocket_client(self)

            try:
                # The None sentinel stops the writer after everything queued before it.
                self._to_write.put_nowait(None)
                # Make sure all error messages are written before closing
                await self._writer_task
                await wsock.close()
            except asyncio.QueueFull:  # can be raised by put_nowait
                self._writer_task.cancel()

            finally:
                if disconnect_warn is None:
                    self._logger.log(VERBOSE_LOG_LEVEL, "Disconnected")
                else:
                    self._logger.warning("Disconnected: %s", disconnect_warn)

        return wsock

    async def _handle_command(self, msg: CommandMessage) -> None:
        """Handle an incoming command from the client.

        The ``auth`` command is handled inline. Any other command is looked up
        in the registered command handlers and checked against the handler's
        authentication and role requirements. The handler then runs in a
        separate task, so slow commands do not block reading further messages.
        """
        self._logger.debug("Handling command %s", msg.command)

        # Handle special "auth" command
        if msg.command == "auth":
            await self._handle_auth_command(msg)
            return

        # work out handler for the given path/command
        handler = self.mass.command_handlers.get(msg.command)

        if handler is None:
            await self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    InvalidCommand.error_code,
                    f"Invalid command: {msg.command}",
                )
            )
            self._logger.warning("Invalid command: %s", msg.command)
            return

        # Check authentication if required
        if handler.authenticated or handler.required_role:
            # For Ingress, user should already be set from _handle_ingress_auth
            # For regular connections, user must be set via auth command
            if self._authenticated_user is None:
                await self._send_message(
                    ErrorResultMessage(
                        msg.message_id,
                        AuthenticationRequired.error_code,
                        "Authentication required. Please send auth command first.",
                    )
                )
                return

            # Set user, token, and sendspin player in context for API methods.
            # The handler task created below inherits this context.
            set_current_user(self._authenticated_user)
            set_current_token(self._current_token)
            set_sendspin_player_id(self._sendspin_player_id)

            # Check role if required
            if (
                handler.required_role == "admin"
                and self._authenticated_user.role != UserRole.ADMIN
            ):
                await self._send_message(
                    ErrorResultMessage(
                        msg.message_id,
                        InsufficientPermissions.error_code,
                        "Admin access required",
                    )
                )
                return

        # schedule task to handle the command
        self.mass.create_task(self._run_handler(handler, msg))

    async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None:
        """Run command handler and send response.

        Result handling depends on what the handler returns:

        - Async generator: items are sent in partial result messages of 500
          items, followed by a final message with the remainder (possibly
          empty).
        - Coroutine: it is awaited and its value is sent.
        - Anything else: the value is sent as-is.

        Any exception is logged and reported back to the client as an error
        result.
        """
        try:
            args = parse_arguments(handler.signature, handler.type_hints, msg.args)
            result: Any = handler.target(**args)
            if hasattr(result, "__anext__"):
                # handle async generator (for really large listings)
                items: list[Any] = []
                async for item in result:
                    items.append(item)
                    if len(items) >= 500:
                        await self._send_message(
                            SuccessResultMessage(msg.message_id, items, partial=True)
                        )
                        items = []
                result = items
            elif inspect.iscoroutine(result):
                result = await result
            await self._send_message(SuccessResultMessage(msg.message_id, result))
        except Exception as err:
            if self._logger.isEnabledFor(logging.DEBUG):
                self._logger.exception("Error handling message: %s", msg)
            else:
                self._logger.error("Error handling message: %s: %s", msg.command, str(err))
            err_msg = str(err) or err.__class__.__name__
            await self._send_message(
                ErrorResultMessage(msg.message_id, getattr(err, "error_code", 999), err_msg)
            )

    async def _writer(self) -> None:
        """Write outgoing messages.

        Drains the write queue and sends each entry over the websocket. It
        stops in three cases: it reads the ``None`` sentinel, the socket is
        closed, or the task is cancelled. An entry may also be a zero-argument
        callable that returns the message string.
        """
        # Exceptions if Socket disconnected or cancelled by connection handler
        with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
            while not self.wsock.closed:
                if (process := await self._to_write.get()) is None:
                    break

                if callable(process):
                    message: str = process()
                else:
                    message = process
                self._logger.log(VERBOSE_LOG_LEVEL, "Writing: %s", message)
                await self.wsock.send_str(message)

    def _enqueue_message(self, message: str) -> None:
        """Queue a serialized message for the writer task.

        If the queue is full, the client is not reading its messages. In that
        case the overflow is logged and the connection is cancelled.
        """
        try:
            self._to_write.put_nowait(message)
        except asyncio.QueueFull:
            self._logger.error("Client exceeded max pending messages: %s", MAX_PENDING_MSG)

            self._cancel()

    async def _send_message(self, message: MessageType) -> None:
        """Send a message to the client (for large response messages).

        Runs JSON serialization in executor to avoid blocking for large messages.
        Closes connection if the client is not reading the messages.

        Async friendly.
        """
        # Run JSON serialization in executor to avoid blocking for large messages
        loop = asyncio.get_running_loop()
        _message = await loop.run_in_executor(None, message.to_json)
        self._enqueue_message(_message)

    def _send_message_sync(self, message: MessageType) -> None:
        """Send a message from a sync context (for small messages like events).

        Serializes inline without executor overhead since events are typically small.
        Closes connection if the client is not reading the messages.
        """
        self._enqueue_message(message.to_json())

    async def _handle_auth_command(self, msg: CommandMessage) -> None:
        """Handle WebSocket authentication command.

        On success it does the following:

        - stores the user, token and token id;
        - replies with the user details;
        - subscribes the client to events;
        - registers the client with the webserver.

        On failure it replies with an error result and leaves the connection
        state unchanged.

        :param msg: The auth command message with access token.
        """
        # Extract token from args (support both 'token' and 'access_token' for backward compat)
        args = msg.args or {}
        token = args.get("token") or args.get("access_token")
        if not token:
            await self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    AuthenticationRequired.error_code,
                    "token required in args",
                )
            )
            return

        # Authenticate with token
        user = await self.webserver.auth.authenticate_with_token(token)
        if not user:
            await self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    InvalidToken.error_code,
                    "Invalid or expired token",
                )
            )
            return

        # Security: Deny homeassistant system user on regular (non-Ingress) webserver
        if not self._is_ingress and user.username == HOMEASSISTANT_SYSTEM_USER:
            await self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    InvalidToken.error_code,
                    "Home Assistant system user not allowed on regular webserver",
                )
            )
            return

        # Get token_id for tracking revocation events
        token_id = await self.webserver.auth.get_token_id_from_token(token)

        # Store authenticated user, token, and token_id
        self._authenticated_user = user
        self._current_token = token
        self._token_id = token_id
        self._logger.info("WebSocket client authenticated as %s", user.username)

        # Send success response
        await self._send_message(
            SuccessResultMessage(
                msg.message_id,
                {"authenticated": True, "user": user.to_dict()},
            )
        )

        # Subscribe to events after successful authentication
        self._subscribe_to_events()

        # Register with webserver for tracking
        self.webserver.register_websocket_client(self)

    async def _handle_ingress_auth(self) -> None:
        """Handle authentication for Ingress connections (auto-create/link user).

        Uses the ``X-Remote-User-*`` headers. A user is resolved in this
        order:

        1. the user already linked to the HA user id;
        2. an existing user with the same username;
        3. a newly created user.

        The resolved user is linked to the Home Assistant provider and updated
        with the display name and avatar. The HA details are fetched once and
        used both when creating the user and when updating it.
        """
        ingress_user_id = self.request.headers.get("X-Remote-User-ID")
        ingress_username = self.request.headers.get("X-Remote-User-Name")
        ingress_display_name = self.request.headers.get("X-Remote-User-Display-Name")

        if not (ingress_user_id and ingress_username):
            # No HA user headers - allow homeassistant system user to connect with token.
            # This allows the Home Assistant integration to connect via the internal network.
            # The token authentication happens in _handle_auth_command.
            self._logger.debug("Ingress connection without user headers, expecting token auth")
            return

        # HA is the source of truth for user details; fall back to the Ingress
        # headers when the API lookup does not return values.
        ha_username, ha_display_name, avatar_url = await get_ha_user_details(
            self.mass, ingress_user_id
        )
        final_display_name = ha_display_name or ingress_display_name

        # Try to find existing user linked to this HA user ID
        user = await self.webserver.auth.get_user_by_provider_link(
            AuthProviderType.HOME_ASSISTANT, ingress_user_id
        )

        if not user:
            # Check if a user with this username already exists
            user = await self.webserver.auth.get_user_by_username(ingress_username)

        if not user:
            # Auto-create user for Ingress (they're already authenticated by HA)
            role = await get_ha_user_role(self.mass, ingress_user_id)
            user = await self.webserver.auth.create_user(
                username=ha_username or ingress_username,
                role=role,
                display_name=final_display_name,
                avatar_url=avatar_url,
            )

        # Link to Home Assistant provider (or create the link if user already existed)
        await self.webserver.auth.link_user_to_provider(
            user, AuthProviderType.HOME_ASSISTANT, ingress_user_id
        )

        # Update user with HA details if available (HA is source of truth)
        if final_display_name or avatar_url:
            user = await self.webserver.auth.update_user(
                user,
                display_name=final_display_name,
                avatar_url=avatar_url,
            )

        self._authenticated_user = user
        self._logger.debug("Ingress user authenticated: %s", user.username)

    def _subscribe_to_events(self) -> None:
        """Subscribe to Mass events and forward them to the client.

        Does nothing if this connection is already subscribed. If the
        authenticated user has a player filter, player and queue events are
        dropped when their object id is neither in that filter nor this
        client's sendspin player id.
        """
        if self._events_unsub_callback is not None:
            # Already subscribed
            return

        def handle_event(event: MassEvent) -> None:
            # filter events for objects the user has no access to
            if (
                self._authenticated_user
                and self._authenticated_user.player_filter
                and event.event
                in (
                    EventType.PLAYER_ADDED,
                    EventType.PLAYER_REMOVED,
                    EventType.PLAYER_UPDATED,
                    EventType.QUEUE_ADDED,
                    EventType.QUEUE_ITEMS_UPDATED,
                    EventType.QUEUE_TIME_UPDATED,
                    EventType.QUEUE_UPDATED,
                )
                and event.object_id
                and event.object_id not in self._authenticated_user.player_filter
                and event.object_id != self._sendspin_player_id
            ):
                return

            self._send_message_sync(event)

        self._events_unsub_callback = self.mass.subscribe(handle_event)
        self._logger.debug("Subscribed to events")

    def _cancel(self) -> None:
        """Cancel the connection handler task and the writer task, if started."""
        if self._handle_task is not None:
            self._handle_task.cancel()
        if self._writer_task is not None:
            self._writer_task.cancel()
460