/
/
/
1"""Authentication manager for Music Assistant webserver."""
2
3from __future__ import annotations
4
5import contextlib
6import hashlib
7import logging
8import secrets
9from datetime import datetime, timedelta
10from sqlite3 import OperationalError
11from typing import TYPE_CHECKING, Any
12
13import jwt as pyjwt
14from music_assistant_models.auth import (
15 AuthProviderType,
16 AuthToken,
17 User,
18 UserAuthProvider,
19 UserRole,
20)
21from music_assistant_models.errors import (
22 AuthenticationRequired,
23 InsufficientPermissions,
24 InvalidDataError,
25)
26
27from music_assistant.constants import (
28 DB_TABLE_PLAYLOG,
29 HOMEASSISTANT_SYSTEM_USER,
30 MASS_LOGGER_NAME,
31)
32from music_assistant.controllers.webserver.helpers.auth_middleware import (
33 get_current_token,
34 get_current_user,
35)
36from music_assistant.controllers.webserver.helpers.auth_providers import (
37 AuthResult,
38 BuiltinLoginProvider,
39 HomeAssistantOAuthProvider,
40 HomeAssistantProviderConfig,
41 LoginProvider,
42 normalize_username,
43)
44from music_assistant.helpers.api import api_command
45from music_assistant.helpers.database import DatabaseConnection
46from music_assistant.helpers.datetime import utc
47from music_assistant.helpers.json import json_dumps, json_loads
48from music_assistant.helpers.jwt_auth import JWTHelper
49
50if TYPE_CHECKING:
51 from music_assistant.controllers.webserver import WebserverController
52
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.auth")

# Version of the auth.db schema. Bump it (and add a matching step to
# AuthenticationManager._migrate_database) whenever the schema changes.
DB_SCHEMA_VERSION = 4

# Token lifetimes, expressed in days.
# Short-lived tokens get their expiry pushed forward every time they are used.
TOKEN_SHORT_LIVED_EXPIRATION = 30
# Long-lived tokens (roughly 10 years) never have their expiry extended.
TOKEN_LONG_LIVED_EXPIRATION = 365 * 10
61
62
63class AuthenticationManager:
64 """Manager for authentication and user management (part of webserver controller)."""
65
66 def __init__(self, webserver: WebserverController) -> None:
67 """
68 Initialize the authentication manager.
69
70 :param webserver: WebserverController instance.
71 """
72 self.webserver = webserver
73 self.mass = webserver.mass
74 self.database: DatabaseConnection = None # type: ignore[assignment]
75 self.login_providers: dict[str, LoginProvider] = {}
76 self.logger = LOGGER
77 self._has_users: bool = False
78 self.jwt_helper: JWTHelper = None # type: ignore[assignment]
79
80 async def setup(self) -> None:
81 """Initialize the authentication manager."""
82 # Setup database
83 db_path = self.mass.storage_path + "/auth.db"
84 self.database = DatabaseConnection(db_path)
85 await self.database.setup()
86
87 # Create database schema and handle migrations
88 await self._setup_database()
89
90 # Initialize JWT helper with secret key
91 jwt_secret = await self._get_or_create_jwt_secret()
92 self.jwt_helper = JWTHelper(jwt_secret)
93
94 # Setup login providers
95 await self._setup_login_providers()
96
97 self._has_users = await self._has_non_system_users()
98
99 self.logger.info(
100 "Authentication manager initialized (providers=%d)", len(self.login_providers)
101 )
102
103 async def close(self) -> None:
104 """Cleanup on exit."""
105 if self.database:
106 await self.database.close()
107
108 @property
109 def has_users(self) -> bool:
110 """Check if any users exist in the system."""
111 return self._has_users
112
113 async def _setup_database(self) -> None:
114 """Set up database schema and handle migrations."""
115 # Always create tables if they don't exist
116 await self._create_database_tables()
117
118 # Check current schema version
119 try:
120 if db_row := await self.database.get_row("settings", {"key": "schema_version"}):
121 prev_version = int(db_row["value"])
122 else:
123 prev_version = DB_SCHEMA_VERSION
124 except (KeyError, ValueError, Exception):
125 # settings table doesn't exist yet or other error
126 prev_version = 0
127
128 # Perform migration if needed
129 if prev_version < DB_SCHEMA_VERSION:
130 self.logger.warning(
131 "Performing database migration from schema version %s to %s",
132 prev_version,
133 DB_SCHEMA_VERSION,
134 )
135 await self._migrate_database(prev_version)
136
137 # Store current schema version
138 await self.database.insert_or_replace(
139 "settings",
140 {"key": "schema_version", "value": str(DB_SCHEMA_VERSION), "type": "int"},
141 )
142
143 # Create indexes
144 await self._create_database_indexes()
145 await self.database.commit()
146
147 async def _create_database_tables(self) -> None:
148 """Create database tables."""
149 # Settings table (for schema version and other settings)
150 await self.database.execute(
151 """
152 CREATE TABLE IF NOT EXISTS settings (
153 key TEXT PRIMARY KEY,
154 value TEXT,
155 type TEXT
156 )
157 """
158 )
159 # Users table
160 await self.database.execute(
161 """
162 CREATE TABLE IF NOT EXISTS users (
163 user_id TEXT PRIMARY KEY,
164 username TEXT NOT NULL UNIQUE,
165 role TEXT NOT NULL,
166 enabled INTEGER NOT NULL DEFAULT 1,
167 created_at TEXT NOT NULL,
168 display_name TEXT,
169 avatar_url TEXT,
170 preferences json NOT NULL DEFAULT '{}',
171 player_filter json NOT NULL DEFAULT '[]',
172 provider_filter json NOT NULL DEFAULT '[]'
173 )
174 """
175 )
176 # User auth provider links (many-to-many)
177 await self.database.execute(
178 """
179 CREATE TABLE IF NOT EXISTS user_auth_providers (
180 link_id TEXT PRIMARY KEY,
181 user_id TEXT NOT NULL,
182 provider_type TEXT NOT NULL,
183 provider_user_id TEXT NOT NULL,
184 created_at TEXT NOT NULL,
185 UNIQUE(provider_type, provider_user_id),
186 FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
187 )
188 """
189 )
190 # Auth tokens table
191 await self.database.execute(
192 """
193 CREATE TABLE IF NOT EXISTS auth_tokens (
194 token_id TEXT PRIMARY KEY,
195 user_id TEXT NOT NULL,
196 token_hash TEXT NOT NULL UNIQUE,
197 name TEXT NOT NULL,
198 created_at TEXT NOT NULL,
199 expires_at TEXT,
200 last_used_at TEXT,
201 is_long_lived INTEGER NOT NULL DEFAULT 0,
202 FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
203 )
204 """
205 )
206 await self.database.commit()
207
208 async def _create_database_indexes(self) -> None:
209 """Create database indexes."""
210 await self.database.execute(
211 "CREATE INDEX IF NOT EXISTS idx_user_auth_providers_user "
212 "ON user_auth_providers(user_id)"
213 )
214 await self.database.execute(
215 "CREATE INDEX IF NOT EXISTS idx_user_auth_providers_provider "
216 "ON user_auth_providers(provider_type, provider_user_id)"
217 )
218 await self.database.execute(
219 "CREATE INDEX IF NOT EXISTS idx_tokens_user ON auth_tokens(user_id)"
220 )
221 await self.database.execute(
222 "CREATE INDEX IF NOT EXISTS idx_tokens_hash ON auth_tokens(token_hash)"
223 )
224
225 async def _migrate_database(self, from_version: int) -> None:
226 """Perform database migration.
227
228 :param from_version: The schema version to migrate from.
229 """
230 self.logger.info(
231 "Migrating auth database from version %s to %s", from_version, DB_SCHEMA_VERSION
232 )
233 # Migration to version 2: Recreate tables due to password salt breaking change
234 if from_version < 2:
235 # Drop all auth-related tables
236 await self.database.execute("DROP TABLE IF EXISTS auth_tokens")
237 await self.database.execute("DROP TABLE IF EXISTS user_auth_providers")
238 await self.database.execute("DROP TABLE IF EXISTS users")
239 await self.database.commit()
240
241 # Recreate tables with current schema
242 await self._create_database_tables()
243
244 # Migration to version 3: Add player_filter and provider_filter columns
245 if from_version < 3:
246 with contextlib.suppress(OperationalError):
247 # Column(s) may already exist
248 await self.database.execute(
249 "ALTER TABLE users ADD COLUMN player_filter json NOT NULL DEFAULT '[]'"
250 )
251 await self.database.execute(
252 "ALTER TABLE users ADD COLUMN provider_filter json NOT NULL DEFAULT '[]'"
253 )
254 await self.database.commit()
255
256 # Migration to version 4: Make usernames case-insensitive by converting to lowercase
257 if from_version < 4:
258 await self.database.execute("UPDATE users SET username = LOWER(username)")
259 await self.database.commit()
260
261 async def _get_or_create_jwt_secret(self) -> str:
262 """Get or create JWT secret key from database.
263
264 :return: JWT secret key for signing tokens.
265 """
266 # Try to get existing secret
267 if secret_row := await self.database.get_row("settings", {"key": "jwt_secret"}):
268 return str(secret_row["value"])
269
270 # Generate new secret
271 jwt_secret = JWTHelper.generate_secret_key()
272
273 # Store in database
274 await self.database.insert_or_replace(
275 "settings",
276 {"key": "jwt_secret", "value": jwt_secret, "type": "string"},
277 )
278 await self.database.commit()
279
280 self.logger.info("Generated new JWT secret key")
281 return jwt_secret
282
283 async def _setup_login_providers(self) -> None:
284 """Set up available login providers based on configuration."""
285 # Always enable built-in provider
286 self.login_providers["builtin"] = BuiltinLoginProvider(self.mass, "builtin", {})
287
288 # Home Assistant OAuth provider
289 # Automatically enabled if HA provider (plugin) is configured
290 ha_provider = None
291 for provider in self.mass.providers:
292 if provider.domain == "hass" and provider.available:
293 ha_provider = provider
294 break
295
296 if ha_provider:
297 # Get URL from the HA provider config
298 ha_url = ha_provider.config.get_value("url")
299 assert isinstance(ha_url, str)
300 ha_config: HomeAssistantProviderConfig = {"ha_url": ha_url}
301 self.login_providers["homeassistant"] = HomeAssistantOAuthProvider(
302 self.mass, "homeassistant", ha_config
303 )
304 self.logger.info(
305 "Home Assistant OAuth provider enabled (using URL from HA provider: %s)",
306 ha_url,
307 )
308
309 async def _sync_ha_oauth_provider(self) -> None:
310 """
311 Sync HA OAuth provider with HA provider availability (dynamic check).
312
313 Adds the provider if HA is available, removes it if HA is not available.
314 """
315 # Find HA provider
316 ha_provider = None
317 for provider in self.mass.providers:
318 if provider.domain == "hass" and provider.available:
319 ha_provider = provider
320 break
321
322 if ha_provider:
323 # HA provider exists and is available - ensure OAuth provider is registered
324 if "homeassistant" not in self.login_providers:
325 # Get URL from the HA provider config
326 ha_url = ha_provider.config.get_value("url")
327 assert isinstance(ha_url, str)
328 ha_config: HomeAssistantProviderConfig = {"ha_url": ha_url}
329 self.login_providers["homeassistant"] = HomeAssistantOAuthProvider(
330 self.mass, "homeassistant", ha_config
331 )
332 self.logger.info(
333 "Home Assistant OAuth provider dynamically enabled (using URL: %s)",
334 ha_url,
335 )
336 # HA provider not available - remove OAuth provider if present
337 elif "homeassistant" in self.login_providers:
338 del self.login_providers["homeassistant"]
339 self.logger.info("Home Assistant OAuth provider removed (HA provider not available)")
340
341 async def authenticate_with_credentials(
342 self, provider_id: str, credentials: dict[str, Any]
343 ) -> AuthResult:
344 """
345 Authenticate a user with credentials.
346
347 :param provider_id: The login provider ID.
348 :param credentials: Provider-specific credentials.
349 """
350 provider = self.login_providers.get(provider_id)
351 if not provider:
352 return AuthResult(success=False, error="Invalid provider")
353
354 return await provider.authenticate(credentials)
355
356 async def authenticate_with_token(self, token: str) -> User | None:
357 """
358 Authenticate a user with an access token.
359
360 Supports both JWT tokens and legacy hash-based tokens for backward compatibility.
361
362 :param token: The access token (JWT or legacy hash token).
363 """
364 # Try to decode as JWT first
365 try:
366 payload = self.jwt_helper.decode_token(token, verify_exp=True)
367 token_id = payload.get("jti")
368 user_id = payload.get("sub")
369 is_long_lived = payload.get("is_long_lived", False)
370
371 if not token_id or not user_id:
372 return None
373
374 token_row = await self.database.get_row("auth_tokens", {"token_id": token_id})
375 if not token_row:
376 return None
377
378 # Database expiration is source of truth
379 if token_row["expires_at"]:
380 db_expires_at = datetime.fromisoformat(token_row["expires_at"])
381 if utc() > db_expires_at:
382 await self.database.delete("auth_tokens", {"token_id": token_id})
383 return None
384
385 # Update last used timestamp
386 now = utc()
387 updates = {"last_used_at": now.isoformat()}
388
389 if not is_long_lived:
390 # Short-lived token: extend expiration on each use (sliding window)
391 new_expires_at = now + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION)
392 updates["expires_at"] = new_expires_at.isoformat()
393
394 # Update database
395 await self.database.update(
396 "auth_tokens",
397 {"token_id": token_id},
398 updates,
399 )
400
401 return await self.get_user(user_id)
402
403 except pyjwt.ExpiredSignatureError:
404 if token_id := self.jwt_helper.get_token_id(token):
405 await self.database.delete("auth_tokens", {"token_id": token_id})
406 return None
407 except pyjwt.InvalidTokenError:
408 self.logger.debug("Token is not a valid JWT, trying legacy hash lookup")
409 except Exception as err:
410 self.logger.debug("Error decoding JWT token: %s, trying legacy hash lookup", err)
411
412 # Fallback to legacy hash-based token lookup
413 token_hash = hashlib.sha256(token.encode()).hexdigest()
414 token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash})
415 if not token_row:
416 return None
417
418 # Check if token is expired
419 if token_row["expires_at"]:
420 expires_at = datetime.fromisoformat(token_row["expires_at"])
421 if utc() > expires_at:
422 # Token expired, delete it
423 await self.database.delete("auth_tokens", {"token_id": token_row["token_id"]})
424 return None
425
426 # Implement sliding expiration for short-lived tokens
427 is_long_lived = bool(token_row["is_long_lived"])
428 now = utc()
429 updates = {"last_used_at": now.isoformat()}
430
431 if not is_long_lived and token_row["expires_at"]:
432 # Short-lived token: extend expiration on each use (sliding window)
433 new_expires_at = now + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION)
434 updates["expires_at"] = new_expires_at.isoformat()
435
436 # Update last used timestamp and potentially expiration
437 await self.database.update(
438 "auth_tokens",
439 {"token_id": token_row["token_id"]},
440 updates,
441 )
442
443 # Get user
444 return await self.get_user(token_row["user_id"])
445
446 async def get_token_id_from_token(self, token: str) -> str | None:
447 """
448 Get token_id from a token string (for tracking revocation).
449
450 :param token: The access token (JWT or legacy hash token).
451 :return: The token_id or None if token not found.
452 """
453 # Try to extract from JWT first
454 if token_id := self.jwt_helper.get_token_id(token):
455 return token_id
456
457 # Fallback: Hash-based lookup for legacy tokens
458 token_hash = hashlib.sha256(token.encode()).hexdigest()
459 token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash})
460 if not token_row:
461 return None
462
463 return str(token_row["token_id"])
464
465 @api_command("auth/user", required_role="admin")
466 async def get_user(self, user_id: str) -> User | None:
467 """
468 Get user by ID (admin only).
469
470 :param user_id: The user ID.
471 :return: User object or None if not found.
472 """
473 user_row = await self.database.get_row("users", {"user_id": user_id})
474 if not user_row or not user_row["enabled"]:
475 return None
476
477 return User(
478 user_id=user_row["user_id"],
479 username=user_row["username"],
480 role=UserRole(user_row["role"]),
481 enabled=bool(user_row["enabled"]),
482 created_at=datetime.fromisoformat(user_row["created_at"]),
483 display_name=user_row["display_name"],
484 avatar_url=user_row["avatar_url"],
485 preferences=json_loads(user_row["preferences"]),
486 player_filter=json_loads(user_row["player_filter"]),
487 provider_filter=json_loads(user_row["provider_filter"]),
488 )
489
490 async def get_user_by_username(self, username: str) -> User | None:
491 """
492 Get user by username.
493
494 :param username: The username.
495 :return: User object or None if not found.
496 """
497 username = normalize_username(username)
498
499 user_row = await self.database.get_row("users", {"username": username})
500 if not user_row:
501 return None
502
503 return await self.get_user(user_row["user_id"])
504
505 async def get_user_by_provider_link(
506 self, provider_type: AuthProviderType, provider_user_id: str
507 ) -> User | None:
508 """
509 Get user by their provider link.
510
511 :param provider_type: The auth provider type.
512 :param provider_user_id: The user ID from the provider.
513 """
514 link_row = await self.database.get_row(
515 "user_auth_providers",
516 {
517 "provider_type": provider_type.value,
518 "provider_user_id": provider_user_id,
519 },
520 )
521 if not link_row:
522 return None
523
524 return await self.get_user(link_row["user_id"])
525
526 async def create_user(
527 self,
528 username: str,
529 role: UserRole = UserRole.USER,
530 display_name: str | None = None,
531 avatar_url: str | None = None,
532 preferences: dict[str, Any] | None = None,
533 player_filter: list[str] | None = None,
534 provider_filter: list[str] | None = None,
535 ) -> User:
536 """
537 Create a new user.
538
539 :param username: The username.
540 :param role: The user role (default: USER).
541 :param display_name: Optional display name.
542 :param avatar_url: Optional avatar URL.
543 :param preferences: Optional user preferences dict.
544 :param player_filter: Optional list of player IDs user has access to.
545 :param provider_filter: Optional list of provider instance IDs user has access to.
546 """
547 normalized_username = normalize_username(username)
548
549 # Check if this is the first non-system user
550 is_first_user = not await self._has_non_system_users()
551
552 user_id = secrets.token_urlsafe(32)
553 created_at = utc()
554 if preferences is None:
555 preferences = {}
556 if player_filter is None:
557 player_filter = []
558 if provider_filter is None:
559 provider_filter = []
560
561 user_data = {
562 "user_id": user_id,
563 "username": normalized_username,
564 "role": role.value,
565 "enabled": True,
566 "created_at": created_at.isoformat(),
567 "display_name": display_name,
568 "avatar_url": avatar_url,
569 "preferences": json_dumps(preferences),
570 "player_filter": json_dumps(player_filter),
571 "provider_filter": json_dumps(provider_filter),
572 }
573
574 await self.database.insert("users", user_data)
575
576 user = User(
577 user_id=user_id,
578 username=normalized_username,
579 role=role,
580 enabled=True,
581 created_at=created_at,
582 display_name=display_name,
583 avatar_url=avatar_url,
584 preferences=preferences,
585 player_filter=player_filter,
586 provider_filter=provider_filter,
587 )
588
589 # If this is the first non-system user, migrate playlog entries to them
590 if is_first_user and normalized_username != HOMEASSISTANT_SYSTEM_USER:
591 self._has_users = True
592 await self._migrate_playlog_to_first_user(user_id)
593
594 return user
595
596 async def _has_non_system_users(self) -> bool:
597 """Check if any non-system users exist."""
598 user_rows = await self.database.get_rows("users", limit=10)
599 return any(row["username"] != HOMEASSISTANT_SYSTEM_USER for row in user_rows)
600
601 async def _migrate_playlog_to_first_user(self, user_id: str) -> None:
602 """
603 Migrate all existing playlog entries to the first user.
604
605 This is called automatically when the first non-system user is created.
606 All existing playlog entries (which have NULL userid) will be updated
607 to belong to this first user.
608
609 :param user_id: The user ID of the first user.
610 """
611 try:
612 # Update all playlog entries with NULL userid to this user
613 await self.mass.music.database.execute(
614 f"UPDATE {DB_TABLE_PLAYLOG} SET userid = :userid WHERE userid IS NULL",
615 {"userid": user_id},
616 )
617 await self.mass.music.database.commit()
618 self.logger.info("Migrated existing playlog entries to first user: %s", user_id)
619 except Exception as err:
620 self.logger.warning("Failed to migrate playlog entries: %s", err)
621
622 async def get_homeassistant_system_user(self) -> User:
623 """
624 Get or create the Home Assistant system user.
625
626 This is a special system user created automatically for Home Assistant integration.
627 It bypasses normal authentication but is restricted to the ingress webserver.
628
629 :return: The Home Assistant system user.
630 """
631 username = HOMEASSISTANT_SYSTEM_USER
632 display_name = "Home Assistant Integration"
633 role = UserRole.USER
634
635 normalized_username = normalize_username(username)
636
637 # Try to find existing user by username
638 user_row = await self.database.get_row("users", {"username": normalized_username})
639 if user_row:
640 # Use get_user to ensure preferences are parsed correctly
641 user = await self.get_user(user_row["user_id"])
642 assert user is not None # User exists in DB, so get_user must return it
643 return user
644
645 # Create new system user
646 user = await self.create_user(
647 username=username,
648 role=role,
649 display_name=display_name,
650 )
651 self.logger.debug("Created Home Assistant system user: %s (role: %s)", username, role.value)
652 return user
653
654 async def get_homeassistant_system_user_token(self) -> str:
655 """
656 Get or create an auth token for the Home Assistant system user.
657
658 This method ensures only one active token exists for the HA integration.
659 If an old token exists, it is deleted and a new one is created.
660 The token auto-renews on use (expires after 30 days of inactivity).
661
662 :return: Authentication token for the Home Assistant system user.
663 """
664 token_name = "Home Assistant Integration"
665
666 # Get the system user
667 system_user = await self.get_homeassistant_system_user()
668
669 # Delete any existing tokens with this name to avoid accumulation
670 # We can't retrieve the plain token from the hash, so we always create a new one
671 existing_tokens = await self.database.get_rows(
672 "auth_tokens",
673 {"user_id": system_user.user_id, "name": token_name},
674 )
675 for token_row in existing_tokens:
676 await self.database.delete("auth_tokens", {"token_id": token_row["token_id"]})
677
678 # Create a new token for the system user
679 return await self.create_token(
680 user=system_user,
681 name=token_name,
682 is_long_lived=False,
683 )
684
685 async def link_user_to_provider(
686 self,
687 user: User,
688 provider_type: AuthProviderType,
689 provider_user_id: str,
690 ) -> UserAuthProvider:
691 """
692 Link a user to an authentication provider.
693
694 If a link already exists for this provider/provider_user_id, returns the existing link.
695
696 :param user: The user to link.
697 :param provider_type: The provider type.
698 :param provider_user_id: The user ID from the provider (e.g., password hash, OAuth ID).
699 """
700 # Check if a link already exists for this provider/provider_user_id
701 existing_link = await self.database.get_row(
702 "user_auth_providers",
703 {
704 "provider_type": provider_type.value,
705 "provider_user_id": provider_user_id,
706 },
707 )
708
709 if existing_link:
710 # Link already exists - return it
711 return UserAuthProvider(
712 link_id=existing_link["link_id"],
713 user_id=existing_link["user_id"],
714 provider_type=AuthProviderType(existing_link["provider_type"]),
715 provider_user_id=existing_link["provider_user_id"],
716 created_at=datetime.fromisoformat(existing_link["created_at"]),
717 )
718
719 # Create new link
720 link_id = secrets.token_urlsafe(32)
721 created_at = utc()
722 link_data = {
723 "link_id": link_id,
724 "user_id": user.user_id,
725 "provider_type": provider_type.value,
726 "provider_user_id": provider_user_id,
727 "created_at": created_at.isoformat(),
728 }
729
730 await self.database.insert("user_auth_providers", link_data)
731
732 return UserAuthProvider(
733 link_id=link_id,
734 user_id=user.user_id,
735 provider_type=provider_type,
736 provider_user_id=provider_user_id,
737 created_at=created_at,
738 )
739
740 async def update_user(
741 self,
742 user: User,
743 username: str | None = None,
744 display_name: str | None = None,
745 avatar_url: str | None = None,
746 ) -> User:
747 """
748 Update a user's profile information.
749
750 :param user: The user to update.
751 :param username: New username (optional).
752 :param display_name: New display name (optional).
753 :param avatar_url: New avatar URL (optional).
754 """
755 updates = {}
756 if username is not None:
757 # Normalize username for case-insensitive authentication
758 updates["username"] = normalize_username(username)
759 if display_name is not None:
760 updates["display_name"] = display_name
761 if avatar_url is not None:
762 updates["avatar_url"] = avatar_url
763
764 if updates:
765 await self.database.update("users", {"user_id": user.user_id}, updates)
766
767 # Return updated user
768 updated_user = await self.get_user(user.user_id)
769 assert updated_user is not None # User exists, so get_user must return it
770 return updated_user
771
772 async def update_user_preferences(
773 self,
774 user: User,
775 preferences: dict[str, Any],
776 ) -> User:
777 """
778 Update a user's preferences.
779
780 :param user: The user to update.
781 :param preferences: New preferences dict (completely replaces existing preferences).
782 """
783 # Verify user exists
784 current_user = await self.get_user(user.user_id)
785 if not current_user:
786 raise ValueError(f"User {user.user_id} not found")
787
788 # Update database with new preferences (complete replacement)
789 await self.database.update(
790 "users",
791 {"user_id": user.user_id},
792 {"preferences": json_dumps(preferences)},
793 )
794
795 # Return updated user
796 updated_user = await self.get_user(user.user_id)
797 assert updated_user is not None # User exists, so get_user must return it
798 return updated_user
799
800 async def update_provider_link(
801 self,
802 user: User,
803 provider_type: AuthProviderType,
804 provider_user_id: str,
805 ) -> None:
806 """
807 Update a user's provider link (e.g., change password).
808
809 :param user: The user.
810 :param provider_type: The provider type.
811 :param provider_user_id: The new provider user ID (e.g., new password hash).
812 """
813 # Find existing link
814 link_row = await self.database.get_row(
815 "user_auth_providers",
816 {
817 "user_id": user.user_id,
818 "provider_type": provider_type.value,
819 },
820 )
821
822 if link_row:
823 # Update existing link
824 await self.database.update(
825 "user_auth_providers",
826 {"link_id": link_row["link_id"]},
827 {"provider_user_id": provider_user_id},
828 )
829 else:
830 # Create new link
831 await self.link_user_to_provider(user, provider_type, provider_user_id)
832
833 async def create_token(self, user: User, name: str, is_long_lived: bool = False) -> str:
834 """
835 Create a new access token for a user.
836
837 :param user: The user to create the token for.
838 :param name: A name/description for the token (e.g., device name).
839 :param is_long_lived: Whether this is a long-lived token (default: False).
840 Short-lived tokens (False): Auto-renewing on use, expire after 30 days of inactivity.
841 Long-lived tokens (True): No auto-renewal, expire after 10 years.
842 """
843 # Generate unique token ID
844 token_id = secrets.token_urlsafe(32)
845
846 # Calculate expiration based on token type
847 created_at = utc()
848 if is_long_lived:
849 # Long-lived tokens expire after 10 years (no auto-renewal)
850 expires_at = created_at + timedelta(days=TOKEN_LONG_LIVED_EXPIRATION)
851 else:
852 # Short-lived tokens expire after 30 days (with auto-renewal on use)
853 expires_at = created_at + timedelta(days=TOKEN_SHORT_LIVED_EXPIRATION)
854
855 # Generate JWT token
856 token = self.jwt_helper.encode_token(
857 user=user,
858 token_id=token_id,
859 token_name=name,
860 expires_at=expires_at,
861 is_long_lived=is_long_lived,
862 )
863
864 # Store token hash in database for revocation checking
865 token_hash = hashlib.sha256(token.encode()).hexdigest()
866 token_data = {
867 "token_id": token_id,
868 "user_id": user.user_id,
869 "token_hash": token_hash,
870 "name": name,
871 "created_at": created_at.isoformat(),
872 "expires_at": expires_at.isoformat(),
873 "is_long_lived": 1 if is_long_lived else 0,
874 }
875 await self.database.insert("auth_tokens", token_data)
876
877 return token
878
879 @api_command("auth/token/revoke")
880 async def revoke_token(self, token_id: str) -> None:
881 """
882 Revoke an auth token.
883
884 :param token_id: The token ID to revoke.
885 """
886 user = get_current_user()
887 if not user:
888 raise AuthenticationRequired("Not authenticated")
889
890 token_row = await self.database.get_row("auth_tokens", {"token_id": token_id})
891 if not token_row:
892 raise InvalidDataError("Token not found")
893
894 # Check permissions - users can only revoke their own tokens unless admin
895 if token_row["user_id"] != user.user_id and user.role != UserRole.ADMIN:
896 raise InsufficientPermissions("You can only revoke your own tokens")
897
898 await self.database.delete("auth_tokens", {"token_id": token_id})
899
900 # Disconnect any WebSocket connections using this token
901 self.webserver.disconnect_websockets_for_token(token_id)
902
903 @api_command("auth/tokens")
904 async def get_user_tokens(self, user_id: str | None = None) -> list[AuthToken]:
905 """
906 Get current user's auth tokens or another user's tokens (admin only).
907
908 :param user_id: Optional user ID to get tokens for (admin only).
909 :return: List of auth tokens.
910 """
911 current_user = get_current_user()
912 if not current_user:
913 return []
914
915 # If user_id is provided and different from current user, require admin
916 if user_id and user_id != current_user.user_id:
917 if current_user.role != UserRole.ADMIN:
918 return []
919 target_user = await self.get_user(user_id)
920 if not target_user:
921 return []
922 else:
923 target_user = current_user
924
925 token_rows = await self.database.get_rows(
926 "auth_tokens", {"user_id": target_user.user_id}, limit=100
927 )
928 return [AuthToken.from_dict(dict(row)) for row in token_rows]
929
930 @api_command("auth/users", required_role="admin")
931 async def list_users(self) -> list[User]:
932 """
933 Get all users (admin only).
934
935 System users are excluded from the list.
936
937 :return: List of user objects.
938 """
939 user_rows = await self.database.get_rows("users", limit=1000)
940 users = []
941 for row in user_rows:
942 # Skip system users
943 if row["username"] == HOMEASSISTANT_SYSTEM_USER:
944 continue
945 users.append(
946 User(
947 user_id=row["user_id"],
948 username=row["username"],
949 role=UserRole(row["role"]),
950 enabled=bool(row["enabled"]),
951 created_at=datetime.fromisoformat(row["created_at"]),
952 display_name=row["display_name"],
953 avatar_url=row["avatar_url"],
954 preferences=json_loads(row["preferences"]),
955 player_filter=json_loads(row["player_filter"]),
956 provider_filter=json_loads(row["provider_filter"]),
957 )
958 )
959 return users
960
961 async def update_user_role(self, user_id: str, new_role: UserRole, admin_user: User) -> bool:
962 """
963 Update a user's role (admin only).
964
965 :param user_id: The user ID to update.
966 :param new_role: The new role to assign.
967 :param admin_user: The admin user performing the action.
968 """
969 if admin_user.role != UserRole.ADMIN:
970 return False
971
972 user_row = await self.database.get_row("users", {"user_id": user_id})
973 if not user_row:
974 return False
975
976 await self.database.update(
977 "users",
978 {"user_id": user_id},
979 {"role": new_role.value},
980 )
981 return True
982
983 @api_command("auth/user/enable", required_role="admin")
984 async def enable_user(self, user_id: str) -> None:
985 """
986 Enable user account (admin only).
987
988 :param user_id: The user ID.
989 """
990 await self.database.update(
991 "users",
992 {"user_id": user_id},
993 {"enabled": 1},
994 )
995
996 @api_command("auth/user/disable", required_role="admin")
997 async def disable_user(self, user_id: str) -> None:
998 """
999 Disable user account (admin only).
1000
1001 :param user_id: The user ID.
1002 """
1003 admin_user = get_current_user()
1004 if not admin_user:
1005 raise AuthenticationRequired("Not authenticated")
1006
1007 # Cannot disable yourself
1008 if user_id == admin_user.user_id:
1009 raise InvalidDataError("Cannot disable your own account")
1010
1011 await self.database.update(
1012 "users",
1013 {"user_id": user_id},
1014 {"enabled": 0},
1015 )
1016
1017 # Disconnect all WebSocket connections for this user
1018 self.webserver.disconnect_websockets_for_user(user_id)
1019
1020 async def get_login_providers(self) -> list[dict[str, Any]]:
1021 """Get list of available login providers (dynamically checks for HA provider)."""
1022 # Sync HA OAuth provider with HA provider availability
1023 await self._sync_ha_oauth_provider()
1024
1025 providers = []
1026 for provider_id, provider in self.login_providers.items():
1027 providers.append(
1028 {
1029 "provider_id": provider_id,
1030 "provider_type": provider.provider_type.value,
1031 "requires_redirect": provider.requires_redirect,
1032 }
1033 )
1034 return providers
1035
1036 @api_command("auth/login", authenticated=False)
1037 async def login(
1038 self,
1039 username: str | None = None,
1040 password: str | None = None,
1041 provider_id: str = "builtin",
1042 device_name: str | None = None,
1043 **extra_credentials: Any,
1044 ) -> dict[str, Any]:
1045 """Authenticate user with credentials via WebSocket.
1046
1047 This command allows clients to authenticate over the WebSocket connection
1048 using username/password or other provider-specific credentials.
1049
1050 :param username: Username for authentication (for builtin provider).
1051 :param password: Password for authentication (for builtin provider).
1052 :param provider_id: The login provider ID (defaults to "builtin").
1053 :param device_name: Optional device name for the token (e.g., "iPhone 15", "Desktop PC").
1054 :param extra_credentials: Additional provider-specific credentials.
1055 :return: Authentication result with access token if successful.
1056 """
1057 # Build credentials dict from parameters
1058 credentials: dict[str, Any] = {}
1059 if username is not None:
1060 credentials["username"] = username
1061 if password is not None:
1062 credentials["password"] = password
1063 credentials.update(extra_credentials)
1064
1065 auth_result = await self.authenticate_with_credentials(provider_id, credentials)
1066
1067 if not auth_result.success:
1068 return {
1069 "success": False,
1070 "error": auth_result.error or "Authentication failed",
1071 }
1072
1073 if not auth_result.user:
1074 return {
1075 "success": False,
1076 "error": "Authentication failed: no user returned",
1077 }
1078
1079 # Create short-lived access token with device name if provided
1080 token_name = device_name or f"WebSocket Session - {auth_result.user.username}"
1081 token = await self.create_token(
1082 auth_result.user,
1083 is_long_lived=False,
1084 name=token_name,
1085 )
1086
1087 return {
1088 "success": True,
1089 "access_token": token,
1090 "user": {
1091 "user_id": auth_result.user.user_id,
1092 "username": auth_result.user.username,
1093 "display_name": auth_result.user.display_name,
1094 "role": auth_result.user.role.value,
1095 },
1096 }
1097
1098 @api_command("auth/providers", authenticated=False)
1099 async def get_providers(self) -> list[dict[str, Any]]:
1100 """Get list of available authentication providers.
1101
1102 Returns information about all available login providers including
1103 whether they require OAuth redirect flow.
1104 """
1105 return await self.get_login_providers()
1106
1107 @api_command("auth/authorization_url", authenticated=False)
1108 async def get_auth_url(
1109 self,
1110 provider_id: str,
1111 return_url: str | None = None,
1112 ) -> dict[str, str | None]:
1113 """Get OAuth authorization URL for authentication.
1114
1115 For OAuth providers (like Home Assistant), this returns the URL that
1116 the user should visit in their browser to authorize the application.
1117
1118 :param provider_id: The provider ID (e.g., "hass").
1119 :param return_url: URL to redirect to after OAuth completes.
1120 :return: Dictionary with authorization_url.
1121 """
1122 auth_url = await self.get_authorization_url(provider_id, return_url)
1123 if not auth_url:
1124 return {
1125 "authorization_url": None,
1126 "error": "Provider does not support OAuth or does not exist",
1127 }
1128
1129 return {
1130 "authorization_url": auth_url,
1131 }
1132
1133 async def get_authorization_url(
1134 self, provider_id: str, return_url: str | None = None
1135 ) -> str | None:
1136 """
1137 Get OAuth authorization URL for a provider.
1138
1139 :param provider_id: The provider ID.
1140 :param return_url: Optional URL to redirect to after successful login.
1141 """
1142 provider = self.login_providers.get(provider_id)
1143 if not provider or not provider.requires_redirect:
1144 return None
1145
1146 # Build callback redirect_uri
1147 redirect_uri = f"{self.webserver.base_url}/auth/callback?provider_id={provider_id}"
1148 return await provider.get_authorization_url(redirect_uri, return_url)
1149
1150 async def handle_oauth_callback(
1151 self, provider_id: str, code: str, state: str, redirect_uri: str
1152 ) -> AuthResult:
1153 """
1154 Handle OAuth callback.
1155
1156 :param provider_id: The provider ID.
1157 :param code: OAuth authorization code.
1158 :param state: OAuth state parameter.
1159 :param redirect_uri: The callback URL.
1160 """
1161 provider = self.login_providers.get(provider_id)
1162 if not provider:
1163 return AuthResult(success=False, error="Invalid provider")
1164
1165 return await provider.handle_oauth_callback(code, state, redirect_uri)
1166
1167 @api_command("auth/token/create")
1168 async def create_long_lived_token(self, name: str, user_id: str | None = None) -> str:
1169 """
1170 Create a new long-lived access token for current user or another user (admin only).
1171
1172 Long-lived tokens are intended for external integrations and API access.
1173 They expire after 10 years and do NOT auto-renew on use.
1174
1175 Short-lived tokens (for regular user sessions) are only created during login
1176 and auto-renew on each use (sliding 30-day expiration window).
1177
1178 :param name: The name/description for the token (e.g., "Home Assistant", "Mobile App").
1179 :param user_id: Optional user ID to create token for (admin only).
1180 :return: The created token string.
1181 """
1182 current_user = get_current_user()
1183 if not current_user:
1184 raise AuthenticationRequired("Not authenticated")
1185
1186 # If user_id is provided and different from current user, require admin
1187 if user_id and user_id != current_user.user_id:
1188 if current_user.role != UserRole.ADMIN:
1189 raise InsufficientPermissions(
1190 "Admin access required to create tokens for other users"
1191 )
1192 target_user = await self.get_user(user_id)
1193 if not target_user:
1194 raise InvalidDataError("User not found")
1195 else:
1196 target_user = current_user
1197
1198 # Create a long-lived token (only long-lived tokens can be created via this command)
1199 token = await self.create_token(target_user, name, is_long_lived=True)
1200 self.logger.info("Created long-lived token '%s' for user '%s'", name, target_user.username)
1201 return token
1202
1203 @api_command("auth/user/create", required_role="admin")
1204 async def create_user_with_api(
1205 self,
1206 username: str,
1207 password: str,
1208 role: str = "user",
1209 display_name: str | None = None,
1210 avatar_url: str | None = None,
1211 player_filter: list[str] | None = None,
1212 provider_filter: list[str] | None = None,
1213 ) -> User:
1214 """
1215 Create a new user with built-in authentication (admin only).
1216
1217 :param username: The username (minimum 2 characters).
1218 :param password: The password (minimum 8 characters).
1219 :param role: User role - "admin" or "user" (default: "user").
1220 :param display_name: Optional display name.
1221 :param avatar_url: Optional avatar URL.
1222 :param player_filter: Optional list of player IDs user has access to.
1223 :param provider_filter: Optional list of provider instance IDs user has access to.
1224 :return: Created user object.
1225 """
1226 # Validation
1227 if not username or len(username) < 2:
1228 raise InvalidDataError("Username must be at least 2 characters")
1229
1230 if not password or len(password) < 8:
1231 raise InvalidDataError("Password must be at least 8 characters")
1232
1233 # Validate role
1234 try:
1235 user_role = UserRole(role)
1236 except ValueError as err:
1237 raise InvalidDataError("Invalid role. Must be 'admin' or 'user'") from err
1238
1239 # Get built-in provider
1240 builtin_provider = self.login_providers.get("builtin")
1241 if not builtin_provider or not isinstance(builtin_provider, BuiltinLoginProvider):
1242 raise InvalidDataError("Built-in auth provider not available")
1243
1244 # Create user with password
1245 user = await builtin_provider.create_user_with_password(
1246 username,
1247 password,
1248 role=user_role,
1249 player_filter=player_filter,
1250 provider_filter=provider_filter,
1251 )
1252
1253 # Update optional fields if provided
1254 if display_name or avatar_url:
1255 updated_user = await self.update_user(
1256 user, display_name=display_name, avatar_url=avatar_url
1257 )
1258 if updated_user:
1259 user = updated_user
1260
1261 self.logger.info("User created by admin: %s (role: %s)", username, role)
1262 return user
1263
1264 @api_command("auth/user/delete", required_role="admin")
1265 async def delete_user(self, user_id: str) -> None:
1266 """
1267 Delete user account (admin only).
1268
1269 :param user_id: The user ID.
1270 """
1271 admin_user = get_current_user()
1272 if not admin_user:
1273 raise AuthenticationRequired("Not authenticated")
1274
1275 # Don't allow deleting yourself
1276 if user_id == admin_user.user_id:
1277 raise InvalidDataError("Cannot delete your own account")
1278
1279 # Delete user from database
1280 await self.database.delete("users", {"user_id": user_id})
1281 await self.database.commit()
1282
1283 # Disconnect all WebSocket connections for this user
1284 self.webserver.disconnect_websockets_for_user(user_id)
1285
1286 @api_command("auth/me")
1287 async def get_current_user_info(self) -> User:
1288 """Get current authenticated user information."""
1289 current_user_obj = get_current_user()
1290 if not current_user_obj:
1291 raise AuthenticationRequired("Not authenticated")
1292 return current_user_obj
1293
1294 async def _update_profile_password(
1295 self,
1296 target_user: User,
1297 password: str,
1298 is_admin_update: bool,
1299 current_user: User,
1300 ) -> None:
1301 """Update user password (helper method)."""
1302 if len(password) < 8:
1303 raise InvalidDataError("Password must be at least 8 characters")
1304
1305 builtin_provider = self.login_providers.get("builtin")
1306 if not builtin_provider or not isinstance(builtin_provider, BuiltinLoginProvider):
1307 raise InvalidDataError("Built-in auth not available")
1308
1309 # Update password (used for both admin resets and user password changes)
1310 await builtin_provider.reset_password(target_user, password)
1311
1312 if is_admin_update:
1313 self.logger.info(
1314 "Password reset for user %s by admin %s",
1315 target_user.username,
1316 current_user.username,
1317 )
1318 else:
1319 self.logger.info("Password changed for user %s", target_user.username)
1320
1321 async def update_user_filters(
1322 self,
1323 target_user: User,
1324 player_filter: list[str] | None,
1325 provider_filter: list[str] | None,
1326 ) -> User:
1327 """Update user player and provider filters (helper method)."""
1328 updates = {}
1329 if player_filter is not None:
1330 updates["player_filter"] = json_dumps(player_filter)
1331 if provider_filter is not None:
1332 updates["provider_filter"] = json_dumps(provider_filter)
1333
1334 if updates:
1335 await self.database.update("users", {"user_id": target_user.user_id}, updates)
1336 # Refresh target user to get updated filters
1337 refreshed_user = await self.get_user(target_user.user_id)
1338 if not refreshed_user:
1339 raise InvalidDataError("Failed to refresh user after filter update")
1340 return refreshed_user
1341 return target_user
1342
1343 @api_command("auth/user/update")
1344 async def update_user_profile(
1345 self,
1346 user_id: str | None = None,
1347 username: str | None = None,
1348 display_name: str | None = None,
1349 avatar_url: str | None = None,
1350 password: str | None = None,
1351 role: str | None = None,
1352 preferences: dict[str, Any] | None = None,
1353 player_filter: list[str] | None = None,
1354 provider_filter: list[str] | None = None,
1355 ) -> User:
1356 """
1357 Update user profile information.
1358
1359 Users can update their own profile. Admins can update any user including role and password.
1360
1361 :param user_id: User ID to update (optional, defaults to current user).
1362 :param username: New username (optional).
1363 :param display_name: New display name (optional).
1364 :param avatar_url: New avatar URL (optional).
1365 :param password: New password (optional, minimum 8 characters).
1366 :param role: New role - "admin" or "user" (optional, set by admin only).
1367 :param preferences: User preferences dict (completely replaces existing, optional).
1368 :param player_filter: List of player IDs user has access to (set by admin only, optional).
1369 :param provider_filter: List of provider instance IDs user has access to (set by admin only, optional).
1370 :return: Updated user object.
1371 """ # noqa: E501
1372 current_user_obj = get_current_user()
1373 if not current_user_obj:
1374 raise AuthenticationRequired("Not authenticated")
1375
1376 # Determine target user
1377 is_admin = current_user_obj.role == UserRole.ADMIN
1378 if user_id and user_id != current_user_obj.user_id:
1379 # Updating another user - requires admin
1380 if not is_admin:
1381 raise InsufficientPermissions("Admin access required")
1382 target_user = await self.get_user(user_id)
1383 if not target_user:
1384 raise InvalidDataError("User not found")
1385 else:
1386 # Updating own profile
1387 target_user = current_user_obj
1388
1389 # Update role (admin only)
1390 if role:
1391 if not is_admin:
1392 raise InsufficientPermissions("Only admins can update user roles")
1393
1394 try:
1395 new_role = UserRole(role)
1396 except ValueError as err:
1397 raise InvalidDataError("Invalid role. Must be 'admin' or 'user'") from err
1398
1399 success = await self.update_user_role(target_user.user_id, new_role, current_user_obj)
1400 if not success:
1401 raise InvalidDataError("Failed to update role")
1402
1403 # Refresh target user to get updated role
1404 refreshed_user = await self.get_user(target_user.user_id)
1405 if not refreshed_user:
1406 raise InvalidDataError("Failed to refresh user after role update")
1407 target_user = refreshed_user
1408
1409 # Update basic profile fields
1410 if username or display_name or avatar_url:
1411 updated_user = await self.update_user(
1412 target_user,
1413 username=username,
1414 display_name=display_name,
1415 avatar_url=avatar_url,
1416 )
1417 if not updated_user:
1418 raise InvalidDataError("Failed to update user profile")
1419 target_user = updated_user
1420
1421 # Update preferences if provided
1422 if preferences is not None:
1423 target_user = await self.update_user_preferences(target_user, preferences)
1424
1425 # Update player_filter and provider_filter (admin only)
1426 if player_filter is not None or provider_filter is not None:
1427 if not is_admin:
1428 raise InsufficientPermissions("Only admins can update player/provider filters")
1429 target_user = await self.update_user_filters(
1430 target_user, player_filter, provider_filter
1431 )
1432
1433 # Update password if provided
1434 if password:
1435 await self._update_profile_password(target_user, password, is_admin, current_user_obj)
1436
1437 return target_user
1438
1439 @api_command("auth/logout")
1440 async def logout(self) -> None:
1441 """Logout current user by revoking the current token."""
1442 user = get_current_user()
1443 if not user:
1444 raise AuthenticationRequired("Not authenticated")
1445
1446 # Get current token from context
1447 token = get_current_token()
1448 if not token:
1449 raise InvalidDataError("No token in context")
1450
1451 # Find and revoke the token
1452 token_hash = hashlib.sha256(token.encode()).hexdigest()
1453 token_row = await self.database.get_row("auth_tokens", {"token_hash": token_hash})
1454 if token_row:
1455 await self.database.delete("auth_tokens", {"token_id": token_row["token_id"]})
1456
1457 # Disconnect any WebSocket connections using this token
1458 self.webserver.disconnect_websockets_for_token(token_row["token_id"])
1459
1460 @api_command("auth/user/providers")
1461 async def get_my_providers(self) -> list[dict[str, Any]]:
1462 """
1463 Get current user's linked authentication providers.
1464
1465 :return: List of provider links.
1466 """
1467 user = get_current_user()
1468 if not user:
1469 return []
1470
1471 # Get provider links from database
1472 rows = await self.database.get_rows("user_auth_providers", {"user_id": user.user_id})
1473 providers = [UserAuthProvider.from_dict(dict(row)) for row in rows]
1474 return [p.to_dict() for p in providers]
1475
1476 @api_command("auth/user/unlink_provider", required_role="admin")
1477 async def unlink_provider(self, user_id: str, provider_type: str) -> bool:
1478 """
1479 Unlink authentication provider from user (admin only).
1480
1481 :param user_id: The user ID.
1482 :param provider_type: Provider type to unlink.
1483 :return: True if successful.
1484 """
1485 await self.database.delete(
1486 "user_auth_providers", {"user_id": user_id, "provider_type": provider_type}
1487 )
1488 await self.database.commit()
1489 return True
1490