/
/
/
1"""Authentication provider base classes and implementations."""
2
3from __future__ import annotations
4
5import asyncio
6import hashlib
7import logging
8import secrets
9from abc import ABC, abstractmethod
10from dataclasses import dataclass
11from datetime import datetime, timedelta
12from typing import TYPE_CHECKING, Any, TypedDict, cast
13from urllib.parse import urlparse
14
15from hass_client import HomeAssistantClient
16from hass_client.exceptions import BaseHassClientError
17from hass_client.utils import base_url, get_auth_url, get_token, get_websocket_url
18from music_assistant_models.auth import AuthProviderType, User, UserRole
19from music_assistant_models.errors import AuthenticationFailed
20
21from music_assistant.constants import CONF_AUTH_ALLOW_SELF_REGISTRATION, MASS_LOGGER_NAME
22from music_assistant.helpers.datetime import utc
23
24if TYPE_CHECKING:
25 from music_assistant import MusicAssistant
26 from music_assistant.controllers.webserver.auth import AuthenticationManager
27 from music_assistant.providers.hass import HomeAssistantProvider
28
29
# Logger shared by all authentication code in this module: the "auth" child of the MA logger.
LOGGER = logging.getLogger(MASS_LOGGER_NAME).getChild("auth")
31
32
def normalize_username(username: str) -> str:
    """
    Return the canonical form of a username used for lookups and comparisons.

    Surrounding whitespace is removed and the result is lowercased, so ``"  Alice "``
    and ``"alice"`` refer to the same account.

    :param username: Raw username as supplied by a user or provider.
    :return: The stripped, lowercased username.
    """
    trimmed = username.strip()
    return trimmed.lower()
41
42
async def _wait_for_hass_provider(
    mass: MusicAssistant, wait_timeout: float, wait_interval: float = 0.5
) -> tuple[HomeAssistantProvider | None, float]:
    """
    Poll until the Home Assistant provider is loaded and available.

    Handles the startup race where authentication is requested before the HA provider
    has finished loading. The provider is always checked at least once, and once more
    after the timeout elapses, so an already-available provider is never missed.

    :param mass: MusicAssistant instance.
    :param wait_timeout: Maximum number of seconds to keep polling.
    :param wait_interval: Seconds to sleep between polls.
    :return: Tuple of (provider, or None if it never became available; seconds spent waiting).
    """
    elapsed = 0.0
    while True:
        hass_prov = mass.get_provider("hass")
        if hass_prov is not None and hass_prov.available:
            return cast("HomeAssistantProvider", hass_prov), elapsed
        if elapsed >= wait_timeout:
            return None, elapsed
        await asyncio.sleep(wait_interval)
        elapsed += wait_interval


async def get_ha_user_details(
    mass: MusicAssistant, ha_user_id: str, wait_timeout: float = 30.0
) -> tuple[str | None, str | None, str | None]:
    """
    Get user username, display name and avatar URL from Home Assistant.

    Uses the existing HA provider connection (which has admin access) to fetch
    user details from config/auth/list and the person entity.

    :param mass: MusicAssistant instance.
    :param ha_user_id: Home Assistant user ID.
    :param wait_timeout: Maximum time to wait for HA provider to become available (default 30s).
    :return: Tuple of (username, display_name, avatar_url) or all None if not found.
    """
    hass_prov, elapsed = await _wait_for_hass_provider(mass, wait_timeout)
    if hass_prov is None:
        LOGGER.debug("HA provider not available after %.1fs, cannot fetch user details", elapsed)
        return None, None, None
    return await hass_prov.get_user_details(ha_user_id)


async def get_ha_user_role(
    mass: MusicAssistant, ha_user_id: str, wait_timeout: float = 30.0
) -> UserRole:
    """
    Get user role based on Home Assistant admin status.

    :param mass: MusicAssistant instance.
    :param ha_user_id: The Home Assistant user ID to check.
    :param wait_timeout: Maximum time to wait for HA provider to become available (default 30s).
    :return: UserRole.ADMIN if the HA user is in the "system-admin" group, else UserRole.USER.
    :raises AuthenticationFailed: If the HA provider is unavailable, the user list cannot be
        retrieved, or the user is not present in it.
    """
    try:
        hass_prov, _ = await _wait_for_hass_provider(mass, wait_timeout)
        if hass_prov is None:
            raise RuntimeError("Home Assistant provider not available")

        # Query HA for the user list to check admin status
        result = await hass_prov.hass.send_command("config/auth/list")
        if not result:
            raise RuntimeError("Failed to retrieve user list from Home Assistant")

        ha_user = next((entry for entry in result if entry.get("id") == ha_user_id), None)
        if ha_user is None:
            raise RuntimeError(f"HA user ID {ha_user_id} not found in user list")

        # User is admin if they have "system-admin" in their group_ids
        # (tolerate a missing or null group_ids value).
        if "system-admin" in (ha_user.get("group_ids") or []):
            LOGGER.debug("HA user %s is admin, granting ADMIN role", ha_user_id)
            return UserRole.ADMIN
        return UserRole.USER
    except Exception as err:
        msg = f"Failed to check HA admin status: {err}"
        raise AuthenticationFailed(msg) from err
121
122
class LoginRateLimiter:
    """
    Per-username rate limiter for login attempts, to slow down brute-force attacks.

    Failed attempts are remembered for a sliding 30-minute window; the more failures a
    username has inside that window, the longer the required pause after its latest
    failure (see :meth:`get_delay`).
    """

    def __init__(self) -> None:
        """Initialize the rate limiter."""
        # Track failed attempts per username: {username: [timestamp1, timestamp2, ...]}
        # Timestamps are appended in chronological order, so the last entry is the newest.
        self._failed_attempts: dict[str, list[datetime]] = {}
        # Time window for tracking attempts (30 minutes)
        self._tracking_window = timedelta(minutes=30)
        # Minimum time between global sweeps of stale usernames (see _prune_stale_usernames)
        self._prune_interval = timedelta(minutes=5)
        self._last_prune: datetime | None = None
        # asyncio lock serializing coroutine access to _failed_attempts
        # (coroutine-safe, not thread-safe)
        self._lock = asyncio.Lock()

    def _cleanup_old_attempts(self, username: str) -> None:
        """
        Remove failed attempts outside the tracking window.

        :param username: The username to clean up.
        """
        if username not in self._failed_attempts:
            return

        cutoff_time = utc() - self._tracking_window
        self._failed_attempts[username] = [
            timestamp for timestamp in self._failed_attempts[username] if timestamp > cutoff_time
        ]

        # Remove username if no attempts left
        if not self._failed_attempts[username]:
            del self._failed_attempts[username]

    def _prune_stale_usernames(self, now: datetime) -> None:
        """
        Drop every username whose newest failed attempt is outside the tracking window.

        Per-username cleanup only runs when that same username is seen again. Without this
        sweep, failures against many distinct (e.g. random, non-existent) usernames would
        accumulate in memory forever. The sweep runs at most once per ``_prune_interval``
        to keep the cost of recording a failure low.

        :param now: Current timezone-aware time.
        """
        if self._last_prune is not None and now - self._last_prune < self._prune_interval:
            return
        self._last_prune = now

        cutoff_time = now - self._tracking_window
        # Attempts are chronological: if the newest is too old, all of them are.
        stale = [
            username
            for username, attempts in self._failed_attempts.items()
            if not attempts or attempts[-1] <= cutoff_time
        ]
        for username in stale:
            del self._failed_attempts[username]

    def get_delay(self, username: str) -> int:
        """
        Get the delay in seconds before next login attempt is allowed.

        Progressive delays based on failed attempts:
        - 1-2 attempts: no delay
        - 3-5 attempts: 30 seconds
        - 6-9 attempts: 60 seconds
        - 10-14 attempts: 120 seconds
        - 15+ attempts: 300 seconds (5 minutes)

        :param username: The username attempting to log in.
        :return: Delay in seconds (0 if no delay needed).
        """
        self._cleanup_old_attempts(username)

        if username not in self._failed_attempts:
            return 0

        attempt_count = len(self._failed_attempts[username])

        if attempt_count < 3:
            return 0
        if attempt_count < 6:
            return 30
        if attempt_count < 10:
            return 60
        if attempt_count < 15:
            return 120
        return 300  # 5 minutes max delay

    async def check_rate_limit(self, username: str) -> tuple[bool, int]:
        """
        Check if a login attempt is currently allowed.

        :param username: The username attempting to log in.
        :return: Tuple of (allowed, delay_seconds). If not allowed, delay_seconds is the
            remaining cooldown (always at least 1).
        """
        async with self._lock:
            self._cleanup_old_attempts(username)

            if username not in self._failed_attempts or not self._failed_attempts[username]:
                return True, 0

            # Get the most recent failed attempt
            last_attempt = self._failed_attempts[username][-1]
            required_delay = self.get_delay(username)

            if required_delay == 0:
                return True, 0

            # Calculate how much time has passed since last attempt
            time_since_last = (utc() - last_attempt).total_seconds()

            if time_since_last < required_delay:
                # Still in cooldown period. Never report 0 seconds while still blocking
                # (int() truncates a sub-second remainder to 0).
                remaining_delay = max(1, int(required_delay - time_since_last))
                return False, remaining_delay

            return True, 0

    async def record_failed_attempt(self, username: str) -> None:
        """
        Record a failed login attempt.

        Also opportunistically evicts usernames whose failures have all expired.

        :param username: The username that failed to log in.
        """
        async with self._lock:
            now = utc()
            self._prune_stale_usernames(now)
            self._cleanup_old_attempts(username)

            self._failed_attempts.setdefault(username, []).append(now)

            # Log warning for suspicious activity
            attempt_count = len(self._failed_attempts[username])
            if attempt_count == 10:
                LOGGER.warning(
                    "Suspicious login activity: 10 failed attempts for username '%s'", username
                )
            elif attempt_count == 20:
                LOGGER.warning(
                    "High suspicious login activity: 20 failed attempts for username '%s'. "
                    "Consider manually disabling this account.",
                    username,
                )

    async def clear_attempts(self, username: str) -> None:
        """
        Clear failed attempts for a username (called after successful login).

        :param username: The username to clear.
        """
        async with self._lock:
            self._failed_attempts.pop(username, None)
250
251
class LoginProviderConfig(TypedDict, total=False):
    """
    Base configuration for login providers.

    Declares no keys itself; provider-specific subclasses add their own keys.
    ``total=False`` makes any keys declared directly on this class optional.
    """
254
255
class HomeAssistantProviderConfig(LoginProviderConfig):
    """Configuration for Home Assistant OAuth provider."""

    # Base URL of the Home Assistant instance used for the OAuth flow. When MA runs as an
    # HA add-on this may be the internal supervisor URL, which HomeAssistantOAuthProvider
    # detects by the substring "supervisor" and tries to replace with an external URL.
    ha_url: str
260
261
@dataclass
class AuthResult:
    """
    Result of an authentication attempt.

    Returned by login providers: ``success`` says whether ``user`` is populated,
    otherwise ``error`` carries a human-readable reason.
    """

    # True when authentication succeeded.
    success: bool
    # The authenticated user (set by providers on success).
    user: User | None = None
    # Human-readable failure reason (set by providers on failure).
    error: str | None = None
    # Optional access token; none of the providers in this module set it
    # (presumably filled in by the caller — verify).
    access_token: str | None = None
    # URL to send the client to after login (set by the Home Assistant OAuth flow).
    return_url: str | None = None
271
272
class LoginProvider(ABC):
    """
    Abstract base for all login providers.

    Subclasses must implement ``provider_type``, ``requires_redirect`` and
    ``authenticate``. OAuth-style providers additionally override
    ``get_authorization_url`` and ``handle_oauth_callback``, whose defaults here
    report that OAuth is not supported.
    """

    def __init__(self, mass: MusicAssistant, provider_id: str, config: LoginProviderConfig) -> None:
        """
        Store the state shared by every provider.

        :param mass: MusicAssistant instance.
        :param provider_id: Unique identifier for this provider instance.
        :param config: Provider-specific configuration.
        """
        self.mass, self.provider_id, self.config = mass, provider_id, config
        self.logger = LOGGER

    # -- members every concrete provider must implement ---------------------

    @property
    @abstractmethod
    def provider_type(self) -> AuthProviderType:
        """Return the AuthProviderType implemented by this provider."""

    @property
    @abstractmethod
    def requires_redirect(self) -> bool:
        """Return True when logging in requires an OAuth-style browser redirect."""

    @abstractmethod
    async def authenticate(self, credentials: dict[str, Any]) -> AuthResult:
        """
        Authenticate a user from the supplied credentials.

        :param credentials: Provider-specific credentials (username/password, OAuth code, etc).
        """

    # -- members with defaults ----------------------------------------------

    @property
    def allow_self_registration(self) -> bool:
        """Return whether unknown users may register themselves (never, by default)."""
        return False

    @property
    def auth_manager(self) -> AuthenticationManager:
        """Return the authentication manager owned by the webserver controller."""
        webserver = self.mass.webserver
        return webserver.auth

    async def get_authorization_url(
        self, redirect_uri: str, return_url: str | None = None
    ) -> str | None:
        """
        Return the OAuth authorization URL, if this provider has one.

        :param redirect_uri: The callback URL for OAuth flow.
        :param return_url: Optional URL to redirect to after successful login.
        :return: Always None for the base implementation.
        """
        return None

    async def handle_oauth_callback(self, code: str, state: str, redirect_uri: str) -> AuthResult:
        """
        Complete an OAuth flow; the base implementation always fails.

        :param code: OAuth authorization code.
        :param state: OAuth state parameter for CSRF protection.
        :param redirect_uri: The callback URL.
        """
        unsupported = "OAuth not supported by this provider"
        return AuthResult(success=False, error=unsupported)
337
338
class BuiltinLoginProvider(LoginProvider):
    """
    Built-in username/password login provider.

    Passwords are not stored. A PBKDF2-SHA256 hash salted with the user id and server id
    is stored as the user's BUILTIN provider link, and login verifies a password by
    looking that link up. Failed logins are throttled per username by a LoginRateLimiter.
    """

    def __init__(self, mass: MusicAssistant, provider_id: str, config: LoginProviderConfig) -> None:
        """
        Initialize built-in login provider.

        :param mass: MusicAssistant instance.
        :param provider_id: Unique identifier for this provider instance.
        :param config: Provider-specific configuration.
        """
        super().__init__(mass, provider_id, config)
        self._rate_limiter = LoginRateLimiter()

    @property
    def provider_type(self) -> AuthProviderType:
        """Return the provider type."""
        return AuthProviderType.BUILTIN

    @property
    def requires_redirect(self) -> bool:
        """Return False - built-in provider doesn't need redirect."""
        return False

    async def authenticate(self, credentials: dict[str, Any]) -> AuthResult:
        """
        Authenticate user with username and password.

        Every failure (unknown user, wrong password, disabled account) is recorded in the
        per-username rate limiter; a successful login clears the recorded failures.

        :param credentials: Dict containing 'username' and 'password'.
        :return: AuthResult with the user on success, or an error message on failure.
        """
        username = credentials.get("username")
        password = credentials.get("password")

        # Credentials come from the client: reject missing, empty or non-string values
        # instead of crashing on them below.
        if not (isinstance(username, str) and username and isinstance(password, str) and password):
            return AuthResult(success=False, error="Username and password required")

        username = normalize_username(username)

        # Check rate limit before attempting authentication.
        # NOTE(review): the check and the later record_failed_attempt are separate lock
        # acquisitions with awaits in between, so concurrent requests for one username can
        # all pass this check before any failure is recorded.
        allowed, remaining_delay = await self._rate_limiter.check_rate_limit(username)
        if not allowed:
            self.logger.warning(
                "Rate limit exceeded for username '%s'. %d seconds remaining.",
                username,
                remaining_delay,
            )
            return AuthResult(
                success=False,
                error=f"Too many failed attempts. Please try again in {remaining_delay} seconds.",
            )

        # Look up the user first: the password hash is salted with the user_id
        user_row = await self.auth_manager.database.get_row("users", {"username": username})
        if not user_row:
            # Spend the same hashing work as for an existing user, so the response time
            # does not reveal whether the username exists; then record the failure.
            await self._async_hash_password(password, secrets.token_hex(16))
            await self._rate_limiter.record_failed_attempt(username)
            return AuthResult(success=False, error="Invalid username or password")

        user_id = user_row["user_id"]
        password_hash = await self._async_hash_password(password, user_id)

        # Verify the password by checking if provider link exists
        user = await self.auth_manager.get_user_by_provider_link(
            AuthProviderType.BUILTIN, password_hash
        )

        # The hash is salted with user_id, so any match must belong to this user; verify it
        # explicitly anyway (change_password performs the same check).
        if not user or str(user.user_id) != str(user_id):
            await self._rate_limiter.record_failed_attempt(username)
            return AuthResult(success=False, error="Invalid username or password")

        if not user.enabled:
            # Record failed attempt for disabled accounts too
            await self._rate_limiter.record_failed_attempt(username)
            return AuthResult(success=False, error="User account is disabled")

        # Successful login - clear any failed attempts
        await self._rate_limiter.clear_attempts(username)
        return AuthResult(success=True, user=user)

    async def create_user_with_password(
        self,
        username: str,
        password: str,
        role: UserRole = UserRole.USER,
        display_name: str | None = None,
        player_filter: list[str] | None = None,
        provider_filter: list[str] | None = None,
    ) -> User:
        """
        Create a new built-in user with password.

        The username is normalized the same way ``authenticate`` normalizes it before the
        lookup, so the created account can actually log in.

        :param username: The username.
        :param password: The password (will be hashed).
        :param role: The user role (default: USER).
        :param display_name: Optional display name.
        :param player_filter: Optional list of player IDs user has access to.
        :param provider_filter: Optional list of provider instance IDs user has access to.
        :return: The created user.
        """
        user = await self.auth_manager.create_user(
            username=normalize_username(username),
            role=role,
            display_name=display_name,
            player_filter=player_filter,
            provider_filter=provider_filter,
        )

        # Hash password using user_id for enhanced security
        password_hash = await self._async_hash_password(password, user.user_id)
        await self.auth_manager.link_user_to_provider(user, AuthProviderType.BUILTIN, password_hash)

        return user

    async def change_password(self, user: User, old_password: str, new_password: str) -> bool:
        """
        Change user password after verifying the current one.

        :param user: The user.
        :param old_password: Current password for verification.
        :param new_password: The new password.
        :return: True if the password was changed, False if the old password did not match.
        """
        old_password_hash = await self._async_hash_password(old_password, user.user_id)
        existing_user = await self.auth_manager.get_user_by_provider_link(
            AuthProviderType.BUILTIN, old_password_hash
        )

        if not existing_user or existing_user.user_id != user.user_id:
            return False

        new_password_hash = await self._async_hash_password(new_password, user.user_id)
        await self.auth_manager.update_provider_link(
            user, AuthProviderType.BUILTIN, new_password_hash
        )

        return True

    async def reset_password(self, user: User, new_password: str) -> None:
        """
        Reset user password (admin only - no old password verification).

        :param user: The user whose password to reset.
        :param new_password: The new password.
        """
        new_password_hash = await self._async_hash_password(new_password, user.user_id)
        await self.auth_manager.update_provider_link(
            user, AuthProviderType.BUILTIN, new_password_hash
        )

    async def _async_hash_password(self, password: str, user_id: str) -> str:
        """
        Run :meth:`_hash_password` in a worker thread.

        PBKDF2 with 100k iterations is CPU-bound; running it inline would block the event
        loop for every login and password change.

        :param password: Plain text password.
        :param user_id: User ID to include in salt.
        :return: Hex-encoded password hash.
        """
        return await asyncio.to_thread(self._hash_password, password, user_id)

    def _hash_password(self, password: str, user_id: str) -> str:
        """
        Hash password with salt combining user ID and server ID.

        :param password: Plain text password.
        :param user_id: User ID to include in salt (random token for high entropy).
        :return: Hex-encoded PBKDF2-HMAC-SHA256 hash (100000 iterations).
        """
        # Combine user_id (random) and server_id for maximum security
        salt = f"{user_id}:{self.mass.server_id}"
        return hashlib.pbkdf2_hmac(
            "sha256", password.encode(), salt.encode(), iterations=100000
        ).hex()
508
509
class HomeAssistantOAuthProvider(LoginProvider):
    """
    Home Assistant OAuth login provider.

    The client is redirected to Home Assistant's authorize page. The returned
    authorization code is exchanged for a token, which identifies the HA user. That HA
    user is then mapped to (or, if allowed, registered as) a Music Assistant user.
    """

    # Pending OAuth states older than this are treated as expired and discarded, so
    # abandoned login attempts do not accumulate in memory forever.
    OAUTH_STATE_TTL = timedelta(minutes=30)

    def __init__(self, mass: MusicAssistant, provider_id: str, config: LoginProviderConfig) -> None:
        """
        Initialize Home Assistant OAuth provider.

        :param mass: MusicAssistant instance.
        :param provider_id: Unique identifier for this provider instance.
        :param config: Provider-specific configuration.
        """
        super().__init__(mass, provider_id, config)
        # Store OAuth state -> return_url mapping to support concurrent sessions
        self._oauth_sessions: dict[str, str | None] = {}
        # OAuth state -> creation time, used to expire abandoned sessions
        self._oauth_session_created: dict[str, datetime] = {}

    @property
    def allow_self_registration(self) -> bool:
        """Return whether self-registration is allowed, read dynamically from config."""
        return bool(self.mass.webserver.config.get_value(CONF_AUTH_ALLOW_SELF_REGISTRATION))

    @property
    def provider_type(self) -> AuthProviderType:
        """Return the provider type."""
        return AuthProviderType.HOME_ASSISTANT

    @property
    def requires_redirect(self) -> bool:
        """Return True - Home Assistant OAuth requires redirect."""
        return True

    async def authenticate(self, credentials: dict[str, Any]) -> AuthResult:
        """
        Not used for OAuth providers - use handle_oauth_callback instead.

        :param credentials: Not used.
        """
        return AuthResult(success=False, error="Use OAuth flow for Home Assistant authentication")

    def _prune_expired_oauth_sessions(self) -> None:
        """
        Forget OAuth states created more than OAUTH_STATE_TTL ago.

        States without a recorded creation time are left untouched.
        """
        cutoff = utc() - self.OAUTH_STATE_TTL
        expired = [
            state for state, created_at in self._oauth_session_created.items() if created_at < cutoff
        ]
        for state in expired:
            self._oauth_sessions.pop(state, None)
            del self._oauth_session_created[state]

    async def _get_external_ha_url(self) -> str | None:
        """
        Get the external URL for Home Assistant from the config API.

        This is needed when MA runs as HA add-on and connects via internal docker network
        (http://supervisor/api) but needs the external URL for OAuth redirects.

        :return: External URL if available, otherwise the configured URL (None if unset).
        """
        configured_url = self.config.get("ha_url")
        ha_url = cast("str", configured_url) if configured_url else None
        if not ha_url:
            return None

        # Check if we're using the internal supervisor URL
        if "supervisor" not in ha_url.lower():
            # Not using internal URL, return as-is
            return ha_url

        # We're using internal URL - try to get external URL from HA provider
        ha_provider = self.mass.get_provider("hass")
        if not ha_provider:
            # No HA provider available, use configured URL
            return ha_url

        ha_provider = cast("HomeAssistantProvider", ha_provider)

        try:
            hass_client = ha_provider.hass
            if not hass_client or not hass_client.connected:
                return ha_url

            # This WebSocket command returns internal, external, and cloud URLs
            network_urls = await hass_client.send_command("network/url")

            if network_urls:
                # Priority: external (manually configured) > cloud (Nabu Casa) > internal (LAN)
                external_url = network_urls.get("external")
                cloud_url = network_urls.get("cloud")
                internal_url = network_urls.get("internal")

                final_url = cast("str", external_url or cloud_url or internal_url)
                if final_url:
                    self.logger.debug(
                        "Using HA URL for OAuth: %s (from network/url, configured: %s)",
                        final_url,
                        ha_url,
                    )
                    return final_url
        except Exception as err:
            # Best effort: fall back to the configured URL below
            self.logger.warning("Failed to fetch HA network URLs: %s", err, exc_info=True)

        # Fallback to configured URL
        return ha_url

    async def get_authorization_url(
        self, redirect_uri: str, return_url: str | None = None
    ) -> str | None:
        """
        Get Home Assistant OAuth authorization URL using hass_client.

        :param redirect_uri: The callback URL.
        :param return_url: Optional URL to redirect to after successful login.
        :return: The authorization URL, or None if no HA URL is configured.
        """
        # Get the correct HA URL (external URL if running as add-on)
        ha_url = await self._get_external_ha_url()
        if not ha_url:
            return None

        # If HA URL is still the internal supervisor URL (no external_url in HA config),
        # infer it from redirect_uri (the URL the user is accessing MA from)
        if "supervisor" in ha_url.lower():
            parsed = urlparse(redirect_uri)
            if parsed.hostname:
                host = parsed.hostname
                # urlparse strips the brackets from IPv6 literals; a URL needs them back
                if ":" in host:
                    host = f"[{host}]"
                if parsed.scheme == "https":
                    # HTTPS - use default port 443 (no port in URL)
                    inferred_ha_url = f"{parsed.scheme}://{host}"
                else:
                    # HTTP - assume HA runs on default port 8123
                    inferred_ha_url = f"{parsed.scheme}://{host}:8123"

                self.logger.debug(
                    "HA external_url not configured, inferring from callback URL: %s",
                    inferred_ha_url,
                )
                ha_url = inferred_ha_url
            else:
                self.logger.debug(
                    "Cannot infer HA URL from callback URL %s, using %s", redirect_uri, ha_url
                )

        # Drop abandoned sessions before registering a new one
        self._prune_expired_oauth_sessions()
        state = secrets.token_urlsafe(32)
        # Store return_url keyed by state to support concurrent OAuth sessions
        self._oauth_sessions[state] = return_url
        self._oauth_session_created[state] = utc()

        # Use base_url of callback as client_id (same as HA provider does)
        client_id = base_url(redirect_uri)

        return cast(
            "str",
            get_auth_url(
                ha_url,
                redirect_uri,
                client_id=client_id,
                state=state,
            ),
        )

    async def _fetch_ha_user_id_via_websocket(self, ha_url: str, access_token: str) -> str | None:
        """
        Fetch the HA user ID from Home Assistant via WebSocket using OAuth token.

        :param ha_url: Home Assistant URL.
        :param access_token: Access token for WebSocket authentication.
        :return: The HA user ID or None if fetch fails.
        """
        ws_url = get_websocket_url(ha_url)

        try:
            # Context manager handles connect/disconnect
            async with HomeAssistantClient(ws_url, access_token, self.mass.http_session) as client:
                result = await client.send_command("auth/current_user")
                if result and (user_id := result.get("id")):
                    return str(user_id)
                self.logger.warning("auth/current_user returned no user data or missing id")
                return None
        except BaseHassClientError as ws_error:
            self.logger.error("Failed to fetch HA user via WebSocket: %s", ws_error)
            return None

    async def _get_or_create_user(
        self,
        username: str,
        display_name: str | None,
        ha_user_id: str,
        avatar_url: str | None = None,
    ) -> User | None:
        """
        Get or create a user for Home Assistant OAuth authentication.

        Updates existing users with display_name and avatar_url from HA on each OAuth login
        (HA is considered the source of truth for these fields).

        :param username: Username from Home Assistant.
        :param display_name: Display name from Home Assistant.
        :param ha_user_id: Home Assistant user ID.
        :param avatar_url: Avatar URL from Home Assistant person entity.
        :return: User object, or None if the user is new and self-registration is disabled.
        """
        # Check if user already linked to HA
        user = await self.auth_manager.get_user_by_provider_link(
            AuthProviderType.HOME_ASSISTANT, ha_user_id
        )
        if user:
            if display_name or avatar_url:
                user = await self.auth_manager.update_user(
                    user,
                    display_name=display_name,
                    avatar_url=avatar_url,
                )
            return user

        username = normalize_username(username)

        # Check if a user with this username already exists (from built-in provider).
        # NOTE(review): this links an existing account purely by matching username, and it
        # happens even when self-registration is disabled — confirm that HA usernames are
        # trusted to map onto existing MA accounts.
        user_row = await self.auth_manager.database.get_row("users", {"username": username})
        if user_row:
            user_dict = dict(user_row)
            existing_user = User(
                user_id=user_dict["user_id"],
                username=user_dict["username"],
                role=UserRole(user_dict["role"]),
                enabled=bool(user_dict["enabled"]),
                created_at=datetime.fromisoformat(user_dict["created_at"]),
                display_name=user_dict["display_name"],
                avatar_url=user_dict["avatar_url"],
            )

            await self.auth_manager.link_user_to_provider(
                existing_user, AuthProviderType.HOME_ASSISTANT, ha_user_id
            )

            if display_name or avatar_url:
                existing_user = await self.auth_manager.update_user(
                    existing_user,
                    display_name=display_name,
                    avatar_url=avatar_url,
                )

            return existing_user

        # New HA user - check if self-registration allowed
        if not self.allow_self_registration:
            return None

        # Determine role based on HA admin status
        role = await get_ha_user_role(self.mass, ha_user_id)

        user = await self.auth_manager.create_user(
            username=username,
            role=role,
            display_name=display_name or username,
            avatar_url=avatar_url,
        )

        await self.auth_manager.link_user_to_provider(
            user, AuthProviderType.HOME_ASSISTANT, ha_user_id
        )

        return user

    async def handle_oauth_callback(self, code: str, state: str, redirect_uri: str) -> AuthResult:
        """
        Handle Home Assistant OAuth callback using hass_client.

        :param code: OAuth authorization code.
        :param state: OAuth state parameter.
        :param redirect_uri: The callback URL.
        :return: AuthResult with the user and return_url on success, or an error.
        """
        # Expire abandoned sessions first, so a stale state is rejected below
        self._prune_expired_oauth_sessions()
        if state not in self._oauth_sessions:
            return AuthResult(success=False, error="Invalid or expired state parameter")

        # Retrieve and remove the return_url for this session (states are single-use)
        return_url = self._oauth_sessions.pop(state)
        self._oauth_session_created.pop(state, None)

        # Get the correct HA URL (external URL if running as add-on).
        # NOTE(review): get_authorization_url may additionally infer an HA URL from
        # redirect_uri when only the supervisor URL is known; this path does not — confirm
        # token exchange and the WebSocket lookup work against the URL returned here.
        ha_url = await self._get_external_ha_url()
        if not ha_url:
            return AuthResult(success=False, error="Home Assistant URL not configured")

        try:
            # Use base_url of callback as client_id (same as HA provider does)
            client_id = base_url(redirect_uri)

            # hass_client's get_token utility - no client_secret needed
            try:
                token_details = await get_token(ha_url, code, client_id=client_id)
            except Exception as token_error:
                self.logger.error(
                    "Failed to get token from HA: %s (client_id: %s, ha_url: %s)",
                    token_error,
                    client_id,
                    ha_url,
                )
                return AuthResult(
                    success=False, error=f"Failed to exchange OAuth code: {token_error}"
                )

            access_token = token_details.get("access_token")
            if not access_token:
                return AuthResult(success=False, error="No access token received from HA")

            # Get the HA user ID from the OAuth token via WebSocket
            ha_user_id = await self._fetch_ha_user_id_via_websocket(ha_url, access_token)
            if not ha_user_id:
                return AuthResult(
                    success=False,
                    error="Failed to get user ID from Home Assistant",
                )

            # Get username, display name and avatar from HA provider (has admin access)
            username, display_name, avatar_url = await get_ha_user_details(self.mass, ha_user_id)

            # Fall back to HA user ID as username if not found
            if not username:
                self.logger.warning("Could not get username from HA, using user ID as fallback")
                username = ha_user_id

            user = await self._get_or_create_user(username, display_name, ha_user_id, avatar_url)

            if not user:
                return AuthResult(
                    success=False,
                    error="Self-registration is disabled. Please contact an administrator.",
                )

            # Disabled accounts must not log in through HA either
            # (BuiltinLoginProvider rejects them with the same message).
            if not user.enabled:
                return AuthResult(success=False, error="User account is disabled")

            return AuthResult(success=True, user=user, return_url=return_url)

        except Exception as e:
            self.logger.exception("Error during Home Assistant OAuth callback")
            return AuthResult(success=False, error=str(e))
843