"""WebSocket client handler for Music Assistant API."""
2
3from __future__ import annotations
4
5import asyncio
6import logging
7from concurrent import futures
8from contextlib import suppress
9from typing import TYPE_CHECKING, Any, Final
10
11from aiohttp import WSMsgType, web
12from music_assistant_models.api import (
13 CommandMessage,
14 ErrorResultMessage,
15 MessageType,
16 SuccessResultMessage,
17)
18from music_assistant_models.auth import AuthProviderType, User, UserRole
19from music_assistant_models.enums import EventType
20from music_assistant_models.errors import (
21 AuthenticationRequired,
22 InsufficientPermissions,
23 InvalidCommand,
24 InvalidToken,
25)
26
27from music_assistant.constants import HOMEASSISTANT_SYSTEM_USER, VERBOSE_LOG_LEVEL
28from music_assistant.helpers.api import APICommandHandler, parse_arguments
29
30from .helpers.auth_middleware import (
31 is_request_from_ingress,
32 set_current_token,
33 set_current_user,
34 set_sendspin_player_id,
35)
36from .helpers.auth_providers import get_ha_user_details, get_ha_user_role
37
38if TYPE_CHECKING:
39 from music_assistant_models.event import MassEvent
40
41 from music_assistant.controllers.webserver import WebserverController
42
# Capacity of each client's outgoing message queue (used as the asyncio.Queue maxsize).
MAX_PENDING_MSG: Final = 512
# Cancellation exception types of both asyncio and concurrent.futures.
CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
45
46
class WebsocketClientHandler:
    """Handle an active websocket client connection.

    One instance serves one websocket request. Incoming text frames are parsed
    into commands and dispatched to the registered command handlers; outgoing
    messages are serialized, placed on a bounded queue and written to the
    socket by a dedicated writer task.
    """

    def __init__(self, webserver: WebserverController, request: web.Request) -> None:
        """Initialize an active connection.

        :param webserver: The webserver controller that owns this connection.
        :param request: The incoming HTTP request to upgrade to a websocket.
        """
        self.webserver = webserver
        self.mass = webserver.mass
        self.request = request
        self.wsock = web.WebSocketResponse(heartbeat=30)
        self._to_write: asyncio.Queue[str | None] = asyncio.Queue(maxsize=MAX_PENDING_MSG)
        self._handle_task: asyncio.Task[Any] | None = None
        self._writer_task: asyncio.Task[None] | None = None
        self._logger = webserver.logger
        self._authenticated_user: User | None = (
            None  # Will be set after auth command or from Ingress
        )
        self._current_token: str | None = None  # Will be set after auth command
        self._token_id: str | None = None  # Will be set after auth for tracking revocation
        self._sendspin_player_id: str | None = None  # Set if client is a sendspin web player
        self._is_ingress = is_request_from_ingress(request)
        self._events_unsub_callback: Any = None  # Will be set after authentication
        # Track WebRTC session ID if this is a WebRTC gateway connection
        self._webrtc_session_id: str | None = request.query.get("webrtc_session_id")
        # try to dynamically detect the base_url of a client if proxied or behind Ingress
        self.base_url: str | None = None
        if forward_host := request.headers.get("X-Forwarded-Host"):
            ingress_path = request.headers.get("X-Ingress-Path", "")
            forward_proto = request.headers.get("X-Forwarded-Proto", request.protocol)
            self.base_url = f"{forward_proto}://{forward_host}{ingress_path}"

    async def disconnect(self) -> None:
        """Disconnect client.

        Cancels the handler and writer tasks and waits for the writer task.
        """
        self._cancel()
        if self._writer_task is not None:
            await self._writer_task

    async def handle_client(self) -> web.WebSocketResponse:
        """Handle a websocket response.

        Prepares the websocket, starts the writer task, sends the server info
        and then processes incoming commands until the connection ends.

        :return: The websocket response object of this connection.
        """
        # ruff: noqa: PLR0915
        request = self.request
        wsock = self.wsock
        try:
            async with asyncio.timeout(10):
                await wsock.prepare(request)
        except TimeoutError:
            self._logger.warning("Timeout preparing request from %s", request.remote)
            return wsock

        self._logger.log(VERBOSE_LOG_LEVEL, "Connection from %s", request.remote)
        self._handle_task = asyncio.current_task()
        self._writer_task = self.mass.create_task(self._writer())

        # send server(version) info when client connects
        server_info = self.mass.get_server_info()
        await self._send_message(server_info)

        # Block until onboarding is complete
        if not self.webserver.auth.has_users and not self._is_ingress:
            await self._send_message(ErrorResultMessage("connection", 503, "Setup required"))
            # Drain the queue before closing so the error above is actually written
            # and the writer task terminates instead of waiting on the queue.
            await self._flush_and_close()
            return wsock

        disconnect_warn: str | None = None

        try:
            # For Ingress connections, auto-create/link user and subscribe to events immediately
            # For regular connections, events will be subscribed after successful authentication
            # This runs inside the try block so a failure still triggers the cleanup below.
            if self._is_ingress:
                await self._handle_ingress_auth()
                self._subscribe_to_events()

            while not wsock.closed:
                msg = await wsock.receive()

                if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
                    break

                if msg.type != WSMsgType.TEXT:
                    continue

                # NOTE(review): this logs the raw payload, which for "auth" commands
                # contains the access token - confirm that is acceptable at this level.
                self._logger.log(VERBOSE_LOG_LEVEL, "Received: %s", msg.data)

                try:
                    command_msg = CommandMessage.from_json(msg.data)
                except ValueError:
                    disconnect_warn = f"Received invalid JSON: {msg.data}"
                    break

                await self._handle_command(command_msg)

        except asyncio.CancelledError:
            self._logger.debug("Connection closed by client")

        except Exception:
            self._logger.exception("Unexpected error inside websocket API")

        finally:
            # Handle connection shutting down.
            if self._events_unsub_callback:
                self._events_unsub_callback()
                self._logger.log(VERBOSE_LOG_LEVEL, "Unsubscribed from events")

            # Unregister from webserver tracking
            self.webserver.unregister_websocket_client(self)

            try:
                # Make sure all error messages are written before closing
                await self._flush_and_close()
            finally:
                if disconnect_warn is None:
                    self._logger.log(VERBOSE_LOG_LEVEL, "Disconnected")
                else:
                    self._logger.warning("Disconnected: %s", disconnect_warn)

        return wsock

    async def _flush_and_close(self) -> None:
        """Let the writer task drain the queue, then close the websocket.

        Enqueues the ``None`` stop sentinel, waits for the writer task to finish
        and closes the socket. If the queue has no room for the sentinel, the
        writer task is cancelled instead and the socket is not closed here.
        """
        writer_task = self._writer_task
        try:
            self._to_write.put_nowait(None)
        except asyncio.QueueFull:
            if writer_task is not None:
                writer_task.cancel()
            return
        if writer_task is not None:
            await writer_task
        await self.wsock.close()

    async def _handle_command(self, msg: CommandMessage) -> None:
        """Handle an incoming command from the client.

        Validates the command, enforces authentication and role requirements
        and schedules the handler as a separate task.

        :param msg: The parsed command message.
        """
        self._logger.debug("Handling command %s", msg.command)

        # Handle special "auth" command
        if msg.command == "auth":
            await self._handle_auth_command(msg)
            return

        # work out handler for the given path/command
        handler = self.mass.command_handlers.get(msg.command)

        if handler is None:
            await self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    InvalidCommand.error_code,
                    f"Invalid command: {msg.command}",
                )
            )
            self._logger.warning("Invalid command: %s", msg.command)
            return

        # Check authentication if required
        if handler.authenticated or handler.required_role:
            # For Ingress, user should already be set from _handle_ingress_auth
            # For regular connections, user must be set via auth command
            if self._authenticated_user is None:
                await self._send_message(
                    ErrorResultMessage(
                        msg.message_id,
                        AuthenticationRequired.error_code,
                        "Authentication required. Please send auth command first.",
                    )
                )
                return

            # Set user, token, and sendspin player in context for API methods
            set_current_user(self._authenticated_user)
            set_current_token(self._current_token)
            set_sendspin_player_id(self._sendspin_player_id)

            # Check role if required
            if (
                handler.required_role == "admin"
                and self._authenticated_user.role != UserRole.ADMIN
            ):
                await self._send_message(
                    ErrorResultMessage(
                        msg.message_id,
                        InsufficientPermissions.error_code,
                        "Admin access required",
                    )
                )
                return

        # schedule task to handle the command
        self.mass.create_task(self._run_handler(handler, msg))

    async def _run_handler(self, handler: APICommandHandler, msg: CommandMessage) -> None:
        """Run command handler and send response.

        Any exception raised by the handler is logged and reported to the
        client as an error result instead of being propagated.

        :param handler: The command handler to invoke.
        :param msg: The command message supplying the arguments and message id.
        """
        try:
            args = parse_arguments(handler.signature, handler.type_hints, msg.args)
            result: Any = handler.target(**args)
            if hasattr(result, "__anext__"):
                # handle async generator (for really large listings)
                # results are streamed in partial messages of 500 items each,
                # the remainder is sent as the final (non-partial) result
                items: list[Any] = []
                async for item in result:
                    items.append(item)
                    if len(items) >= 500:
                        await self._send_message(
                            SuccessResultMessage(msg.message_id, items, partial=True)
                        )
                        items = []
                result = items
            elif asyncio.iscoroutine(result):
                result = await result
            await self._send_message(SuccessResultMessage(msg.message_id, result))
        except Exception as err:
            # full traceback only when debug logging is enabled
            if self._logger.isEnabledFor(logging.DEBUG):
                self._logger.exception("Error handling message: %s", msg)
            else:
                self._logger.error("Error handling message: %s: %s", msg.command, str(err))
            err_msg = str(err) or err.__class__.__name__
            await self._send_message(
                ErrorResultMessage(msg.message_id, getattr(err, "error_code", 999), err_msg)
            )

    async def _writer(self) -> None:
        """Write outgoing messages.

        Runs until the socket is closed or the ``None`` stop sentinel is read
        from the queue.
        """
        # Exceptions if Socket disconnected or cancelled by connection handler
        with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
            while not self.wsock.closed:
                if (process := await self._to_write.get()) is None:
                    break

                # The queue is annotated as holding strings only, but a callable
                # that produces the string is tolerated as well.
                if callable(process):
                    message: str = process()
                else:
                    message = process
                self._logger.log(VERBOSE_LOG_LEVEL, "Writing: %s", message)
                await self.wsock.send_str(message)

    def _enqueue(self, serialized: str) -> None:
        """Queue a serialized message for the writer task.

        If the queue is full, the overflow is logged and the connection is cancelled.

        :param serialized: The JSON string to send to the client.
        """
        try:
            self._to_write.put_nowait(serialized)
        except asyncio.QueueFull:
            self._logger.error("Client exceeded max pending messages: %s", MAX_PENDING_MSG)

            self._cancel()

    async def _send_message(self, message: MessageType) -> None:
        """Send a message to the client (for large response messages).

        Runs JSON serialization in executor to avoid blocking for large messages.
        Closes connection if the client is not reading the messages.

        Async friendly.

        :param message: The message to serialize and send.
        """
        # Run JSON serialization in executor to avoid blocking for large messages
        loop = asyncio.get_running_loop()
        serialized = await loop.run_in_executor(None, message.to_json)
        self._enqueue(serialized)

    def _send_message_sync(self, message: MessageType) -> None:
        """Send a message from a sync context (for small messages like events).

        Serializes inline without executor overhead since events are typically small.

        :param message: The message to serialize and send.
        """
        self._enqueue(message.to_json())

    async def _handle_auth_command(self, msg: CommandMessage) -> None:
        """Handle WebSocket authentication command.

        On success the user, token and token id are stored on the connection,
        a success result is sent, events are subscribed and the client is
        registered with the webserver. On failure an error result is sent.

        :param msg: The auth command message with access token.
        """
        # Extract token from args (support both 'token' and 'access_token' for backward compat)
        args = msg.args or {}
        token = args.get("token") or args.get("access_token")
        if not token:
            await self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    AuthenticationRequired.error_code,
                    "token required in args",
                )
            )
            return

        # Authenticate with token
        user = await self.webserver.auth.authenticate_with_token(token)
        if not user:
            await self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    InvalidToken.error_code,
                    "Invalid or expired token",
                )
            )
            return

        # Security: Deny homeassistant system user on regular (non-Ingress) webserver
        if not self._is_ingress and user.username == HOMEASSISTANT_SYSTEM_USER:
            await self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    InvalidToken.error_code,
                    "Home Assistant system user not allowed on regular webserver",
                )
            )
            return

        # Get token_id for tracking revocation events
        token_id = await self.webserver.auth.get_token_id_from_token(token)

        # Store authenticated user, token, and token_id
        self._authenticated_user = user
        self._current_token = token
        self._token_id = token_id
        self._logger.info("WebSocket client authenticated as %s", user.username)

        # Send success response
        await self._send_message(
            SuccessResultMessage(
                msg.message_id,
                {"authenticated": True, "user": user.to_dict()},
            )
        )

        # Subscribe to events after successful authentication
        self._subscribe_to_events()

        # Register with webserver for tracking
        self.webserver.register_websocket_client(self)

    async def _handle_ingress_auth(self) -> None:
        """Handle authentication for Ingress connections (auto-create/link user).

        Reads the ``X-Remote-User-*`` request headers. When a user id and name
        are present, the matching user is looked up (or created and linked) and
        stored as the authenticated user; otherwise token auth is expected.
        """
        ingress_user_id = self.request.headers.get("X-Remote-User-ID")
        ingress_username = self.request.headers.get("X-Remote-User-Name")
        ingress_display_name = self.request.headers.get("X-Remote-User-Display-Name")

        if ingress_user_id and ingress_username:
            # Try to find existing user linked to this HA user ID
            user = await self.webserver.auth.get_user_by_provider_link(
                AuthProviderType.HOME_ASSISTANT, ingress_user_id
            )

            if not user:
                # Check if a user with this username already exists
                user = await self.webserver.auth.get_user_by_username(ingress_username)

                if not user:
                    # New user - fetch details from HA
                    ha_username, ha_display_name, avatar_url = await get_ha_user_details(
                        self.mass, ingress_user_id
                    )
                    # Auto-create user for Ingress (they're already authenticated by HA)
                    role = await get_ha_user_role(self.mass, ingress_user_id)
                    user = await self.webserver.auth.create_user(
                        username=ha_username or ingress_username,
                        role=role,
                        display_name=ha_display_name or ingress_display_name,
                        avatar_url=avatar_url,
                    )

                # Link to Home Assistant provider (or create the link if user already existed)
                await self.webserver.auth.link_user_to_provider(
                    user, AuthProviderType.HOME_ASSISTANT, ingress_user_id
                )

            # Update user with HA details if available (HA is source of truth)
            # Fall back to ingress headers if API lookup doesn't return values
            _, ha_display_name, avatar_url = await get_ha_user_details(self.mass, ingress_user_id)
            final_display_name = ha_display_name or ingress_display_name
            if final_display_name or avatar_url:
                user = await self.webserver.auth.update_user(
                    user,
                    display_name=final_display_name,
                    avatar_url=avatar_url,
                )

            self._authenticated_user = user
            self._logger.debug("Ingress user authenticated: %s", user.username)
        else:
            # No HA user headers - allow homeassistant system user to connect with token
            # This allows the Home Assistant integration to connect via the internal network
            # The token authentication happens in _handle_auth_command
            self._logger.debug("Ingress connection without user headers, expecting token auth")

    def _subscribe_to_events(self) -> None:
        """Subscribe to Mass events and forward them to the client.

        Does nothing when a subscription already exists.
        """
        if self._events_unsub_callback is not None:
            # Already subscribed
            return

        def handle_event(event: MassEvent) -> None:
            # filter events for objects the user has no access to
            if (
                self._authenticated_user
                and self._authenticated_user.player_filter
                and event.event
                in (
                    EventType.PLAYER_ADDED,
                    EventType.PLAYER_REMOVED,
                    EventType.PLAYER_UPDATED,
                    EventType.QUEUE_ADDED,
                    EventType.QUEUE_ITEMS_UPDATED,
                    EventType.QUEUE_TIME_UPDATED,
                    EventType.QUEUE_UPDATED,
                )
                and event.object_id
                and event.object_id not in self._authenticated_user.player_filter
                and event.object_id != self._sendspin_player_id
            ):
                return

            self._send_message_sync(event)

        self._events_unsub_callback = self.mass.subscribe(handle_event)
        self._logger.debug("Subscribed to events")

    def _cancel(self) -> None:
        """Cancel the connection by cancelling the handler and writer tasks."""
        if self._handle_task is not None:
            self._handle_task.cancel()
        if self._writer_task is not None:
            self._writer_task.cancel()