music-assistant-server

31.6 KBPY
auth_providers.py
31.6 KB843 lines • python
1"""Authentication provider base classes and implementations."""
2
3from __future__ import annotations
4
5import asyncio
6import hashlib
7import logging
8import secrets
9from abc import ABC, abstractmethod
10from dataclasses import dataclass
11from datetime import datetime, timedelta
12from typing import TYPE_CHECKING, Any, TypedDict, cast
13from urllib.parse import urlparse
14
15from hass_client import HomeAssistantClient
16from hass_client.exceptions import BaseHassClientError
17from hass_client.utils import base_url, get_auth_url, get_token, get_websocket_url
18from music_assistant_models.auth import AuthProviderType, User, UserRole
19from music_assistant_models.errors import AuthenticationFailed
20
21from music_assistant.constants import CONF_AUTH_ALLOW_SELF_REGISTRATION, MASS_LOGGER_NAME
22from music_assistant.helpers.datetime import utc
23
24if TYPE_CHECKING:
25    from music_assistant import MusicAssistant
26    from music_assistant.controllers.webserver.auth import AuthenticationManager
27    from music_assistant.providers.hass import HomeAssistantProvider
28
29
30LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.auth")
31
32
def normalize_username(username: str) -> str:
    """
    Return the canonical form of a username for case-insensitive matching.

    Leading and trailing whitespace is removed first, then the result is
    lowercased, so ``"  Alice "`` and ``"alice"`` map to the same value.

    :param username: The raw username as supplied by the caller.
    :return: The trimmed, lowercased username.
    """
    trimmed = username.strip()
    return trimmed.lower()
41
42
async def get_ha_user_details(
    mass: MusicAssistant, ha_user_id: str, wait_timeout: float = 30.0
) -> tuple[str | None, str | None, str | None]:
    """
    Get user username, display name and avatar URL from Home Assistant.

    Delegates to the ``get_user_details`` method of the provider registered
    as ``"hass"``. Because that provider may still be starting up when the
    first login arrives, this polls for it every 0.5 seconds until it
    reports itself available or ``wait_timeout`` has elapsed.

    The provider is always checked at least once (so ``wait_timeout=0`` means
    "do not wait" rather than "do not look") and once more when the deadline
    is reached.

    :param mass: MusicAssistant instance.
    :param ha_user_id: Home Assistant user ID.
    :param wait_timeout: Maximum time to wait for HA provider to become available (default 30s).
    :return: Tuple of (username, display_name, avatar_url) or all None if the
        provider did not become available in time.
    """
    wait_interval = 0.5
    elapsed = 0.0
    while True:
        hass_prov = mass.get_provider("hass")
        if hass_prov is not None and hass_prov.available:
            break
        if elapsed >= wait_timeout:
            LOGGER.debug(
                "HA provider not available after %.1fs, cannot fetch user details", elapsed
            )
            return None, None, None
        # Never sleep past the deadline
        pause = min(wait_interval, wait_timeout - elapsed)
        await asyncio.sleep(pause)
        elapsed += pause

    hass_prov = cast("HomeAssistantProvider", hass_prov)
    return await hass_prov.get_user_details(ha_user_id)
75
76
async def get_ha_user_role(
    mass: MusicAssistant, ha_user_id: str, wait_timeout: float = 30.0
) -> UserRole:
    """
    Get user role based on Home Assistant admin status.

    Waits for the provider registered as ``"hass"`` to become available
    (polling every 0.5 seconds, checking at least once and once more at the
    deadline), then sends the ``config/auth/list`` command and looks up the
    entry whose ``id`` equals ``ha_user_id``. A user whose ``group_ids``
    contains ``"system-admin"`` is mapped to ADMIN, any other listed user to
    USER.

    :param mass: MusicAssistant instance.
    :param ha_user_id: The Home Assistant user ID to check.
    :param wait_timeout: Maximum time to wait for HA provider to become available (default 30s).
    :return: ``UserRole.ADMIN`` or ``UserRole.USER``.
    :raises AuthenticationFailed: If the provider is unavailable, the user list
        cannot be retrieved, the user is not in the list, or any other error
        occurs while querying Home Assistant.
    """
    try:
        # Wait for the HA provider to become available (handles race condition at startup)
        wait_interval = 0.5
        elapsed = 0.0
        while True:
            hass_prov = mass.get_provider("hass")
            if hass_prov is not None and hass_prov.available:
                break
            if elapsed >= wait_timeout:
                raise RuntimeError("Home Assistant provider not available")
            # Never sleep past the deadline
            pause = min(wait_interval, wait_timeout - elapsed)
            await asyncio.sleep(pause)
            elapsed += pause

        hass_prov = cast("HomeAssistantProvider", hass_prov)
        # Query HA for user list to check admin status
        result = await hass_prov.hass.send_command("config/auth/list")
        if not result:
            raise RuntimeError("Failed to retrieve user list from Home Assistant")
        for ha_user in result:
            if ha_user.get("id") != ha_user_id:
                continue
            # User is admin if they have "system-admin" in their group_ids;
            # a missing or null group list simply means "not admin".
            group_ids = ha_user.get("group_ids") or []
            if "system-admin" in group_ids:
                LOGGER.debug("HA user %s is admin, granting ADMIN role", ha_user_id)
                return UserRole.ADMIN
            return UserRole.USER
        raise RuntimeError(f"HA user ID {ha_user_id} not found in user list")
    except Exception as err:
        # Every failure here means the role cannot be trusted, so surface it
        # uniformly as an authentication failure with the cause chained.
        msg = f"Failed to check HA admin status: {err}"
        raise AuthenticationFailed(msg) from err
121
122
class LoginRateLimiter:
    """
    Rate limiter for login attempts to prevent brute force attacks.

    Failed attempts are tracked per username inside a 30 minute sliding
    window. The more failures fall inside that window, the longer the caller
    has to wait (measured from the most recent failure) before the next
    attempt is allowed.
    """

    # (minimum number of failed attempts, delay in seconds), strictest tier first
    _DELAY_TIERS: tuple[tuple[int, int], ...] = ((15, 300), (10, 120), (6, 60), (3, 30))

    def __init__(self) -> None:
        """Initialize the rate limiter."""
        # Track failed attempts per username: {username: [timestamp1, timestamp2, ...]}
        # NOTE(review): entries are only expired when the same username is seen
        # again, so many distinct usernames can make this dict grow.
        self._failed_attempts: dict[str, list[datetime]] = {}
        # Time window for tracking attempts (30 minutes)
        self._tracking_window = timedelta(minutes=30)
        # Serializes coroutines that read/modify _failed_attempts
        # (an asyncio lock: protects against interleaved tasks, not OS threads)
        self._lock = asyncio.Lock()

    def _cleanup_old_attempts(self, username: str) -> None:
        """
        Remove failed attempts outside the tracking window.

        Drops the username's entry entirely once no attempts remain.

        :param username: The username to clean up.
        """
        attempts = self._failed_attempts.get(username)
        if attempts is None:
            return

        cutoff_time = utc() - self._tracking_window
        recent = [timestamp for timestamp in attempts if timestamp > cutoff_time]
        if recent:
            self._failed_attempts[username] = recent
        else:
            # Remove username if no attempts left
            del self._failed_attempts[username]

    def get_delay(self, username: str) -> int:
        """
        Get the delay in seconds before next login attempt is allowed.

        Progressive delays based on failed attempts:
        - 1-2 attempts: no delay
        - 3-5 attempts: 30 seconds
        - 6-9 attempts: 60 seconds
        - 10-14 attempts: 120 seconds
        - 15+ attempts: 300 seconds (5 minutes)

        :param username: The username attempting to log in.
        :return: Delay in seconds (0 if no delay needed).
        """
        self._cleanup_old_attempts(username)

        attempt_count = len(self._failed_attempts.get(username, ()))
        for threshold, delay in self._DELAY_TIERS:
            if attempt_count >= threshold:
                return delay
        return 0

    async def check_rate_limit(self, username: str) -> tuple[bool, int]:
        """
        Check if login attempt is allowed and report the remaining cooldown.

        :param username: The username attempting to log in.
        :return: Tuple of (allowed, delay_seconds). When not allowed,
            delay_seconds is the remaining cooldown rounded up to whole
            seconds, so it is always at least 1.
        """
        async with self._lock:
            self._cleanup_old_attempts(username)

            attempts = self._failed_attempts.get(username)
            if not attempts:
                return True, 0

            # The cooldown is measured from the most recent failed attempt
            last_attempt = attempts[-1]
            required_delay = self.get_delay(username)
            if required_delay == 0:
                return True, 0

            remaining = required_delay - (utc() - last_attempt).total_seconds()
            if remaining <= 0:
                return True, 0

            # Round up: truncating would report "0 seconds" while still blocked
            remaining_delay = int(remaining)
            if remaining_delay < remaining:
                remaining_delay += 1
            return False, remaining_delay

    async def record_failed_attempt(self, username: str) -> None:
        """
        Record a failed login attempt.

        Logs a warning when the number of failures inside the tracking window
        reaches exactly 10 and exactly 20.

        :param username: The username that failed to log in.
        """
        async with self._lock:
            self._cleanup_old_attempts(username)

            attempts = self._failed_attempts.setdefault(username, [])
            attempts.append(utc())

            # Log warning for suspicious activity
            attempt_count = len(attempts)
            if attempt_count == 10:
                LOGGER.warning(
                    "Suspicious login activity: 10 failed attempts for username '%s'", username
                )
            elif attempt_count == 20:
                LOGGER.warning(
                    "High suspicious login activity: 20 failed attempts for username '%s'. "
                    "Consider manually disabling this account.",
                    username,
                )

    async def clear_attempts(self, username: str) -> None:
        """
        Clear failed attempts for a username (called after successful login).

        :param username: The username to clear.
        """
        async with self._lock:
            self._failed_attempts.pop(username, None)
250
251
class LoginProviderConfig(TypedDict, total=False):
    """
    Common base type for login provider configuration dictionaries.

    Declares no keys of its own; provider-specific configs subclass it.
    """
254
255
class HomeAssistantProviderConfig(LoginProviderConfig):
    """Configuration dictionary for the Home Assistant OAuth login provider."""

    # URL of the Home Assistant instance used for the OAuth flow
    ha_url: str
260
261
@dataclass
class AuthResult:
    """
    Outcome of a single authentication attempt.

    Only ``success`` is mandatory; every other field defaults to None.
    """

    # True when authentication succeeded
    success: bool
    # The authenticated user (set by the providers on success)
    user: User | None = None
    # Human-readable failure reason (set by the providers on failure)
    error: str | None = None
    # Optional access token associated with the result
    access_token: str | None = None
    # Optional URL to send the client to after a successful login
    return_url: str | None = None
271
272
class LoginProvider(ABC):
    """
    Abstract interface shared by all login providers.

    Subclasses must supply ``provider_type``, ``requires_redirect`` and
    ``authenticate``. The OAuth hooks have no-op defaults so that providers
    without a redirect flow do not need to override them.
    """

    def __init__(self, mass: MusicAssistant, provider_id: str, config: LoginProviderConfig) -> None:
        """
        Store the references every provider needs.

        :param mass: MusicAssistant instance.
        :param provider_id: Unique identifier for this provider instance.
        :param config: Provider-specific configuration.
        """
        self.logger = LOGGER
        self.config = config
        self.provider_id = provider_id
        self.mass = mass

    @property
    def allow_self_registration(self) -> bool:
        """Tell whether unknown users may register themselves (off by default)."""
        return False

    @property
    def auth_manager(self) -> AuthenticationManager:
        """Return the authentication manager owned by the webserver."""
        webserver = self.mass.webserver
        return webserver.auth

    @property
    @abstractmethod
    def provider_type(self) -> AuthProviderType:
        """Identify which kind of provider this is."""

    @property
    @abstractmethod
    def requires_redirect(self) -> bool:
        """Tell whether logging in involves an OAuth redirect."""

    @abstractmethod
    async def authenticate(self, credentials: dict[str, Any]) -> AuthResult:
        """
        Verify the supplied credentials and report the outcome.

        :param credentials: Provider-specific credentials (username/password, OAuth code, etc).
        """

    async def get_authorization_url(
        self, redirect_uri: str, return_url: str | None = None
    ) -> str | None:
        """
        Build the OAuth authorization URL, for providers that have one.

        The base implementation has no OAuth flow and always returns None.

        :param redirect_uri: The callback URL for OAuth flow.
        :param return_url: Optional URL to redirect to after successful login.
        """
        return None

    async def handle_oauth_callback(self, code: str, state: str, redirect_uri: str) -> AuthResult:
        """
        Complete an OAuth flow, for providers that have one.

        The base implementation always reports failure.

        :param code: OAuth authorization code.
        :param state: OAuth state parameter for CSRF protection.
        :param redirect_uri: The callback URL.
        """
        unsupported = "OAuth not supported by this provider"
        return AuthResult(success=False, error=unsupported)
337
338
class BuiltinLoginProvider(LoginProvider):
    """
    Built-in username/password login provider.

    A password is represented by its PBKDF2 hash (salted with the user ID and
    the server ID), which is stored as the user's BUILTIN provider link.
    Verifying a password therefore means hashing it and looking up the user
    linked to that hash.
    """

    def __init__(self, mass: MusicAssistant, provider_id: str, config: LoginProviderConfig) -> None:
        """
        Initialize built-in login provider.

        :param mass: MusicAssistant instance.
        :param provider_id: Unique identifier for this provider instance.
        :param config: Provider-specific configuration.
        """
        super().__init__(mass, provider_id, config)
        # Per-provider tracker of failed logins, keyed by normalized username
        self._rate_limiter = LoginRateLimiter()

    @property
    def provider_type(self) -> AuthProviderType:
        """Return the provider type."""
        return AuthProviderType.BUILTIN

    @property
    def requires_redirect(self) -> bool:
        """Return False - built-in provider doesn't need redirect."""
        return False

    async def authenticate(self, credentials: dict[str, Any]) -> AuthResult:
        """
        Authenticate user with username and password.

        Every failure after the input check (unknown username, wrong password,
        disabled account) is recorded with the rate limiter; a successful
        login clears the recorded failures for that username.

        :param credentials: Dict containing 'username' and 'password'.
        :return: AuthResult with ``success=True`` and the user on success,
            otherwise ``success=False`` and an error message.
        """
        username = credentials.get("username")
        password = credentials.get("password")

        # Credentials are client-supplied: reject missing, empty or non-string
        # values here instead of failing later in str.strip()/str.encode().
        if not (isinstance(username, str) and isinstance(password, str)):
            return AuthResult(success=False, error="Username and password required")
        if not username or not password:
            return AuthResult(success=False, error="Username and password required")

        username = normalize_username(username)

        # Check rate limit before attempting authentication
        allowed, remaining_delay = await self._rate_limiter.check_rate_limit(username)
        if not allowed:
            self.logger.warning(
                "Rate limit exceeded for username '%s'. %d seconds remaining.",
                username,
                remaining_delay,
            )
            return AuthResult(
                success=False,
                error=f"Too many failed attempts. Please try again in {remaining_delay} seconds.",
            )

        # First, look up user by username to get user_id
        # This is needed to create the password hash with user_id in the salt
        user_row = await self.auth_manager.database.get_row("users", {"username": username})
        if not user_row:
            # Spend the same hashing cost as for an existing account so the
            # response time does not reveal whether the username exists.
            self._hash_password(password, username)
            # Record failed attempt even if username doesn't exist, so unknown
            # and known usernames are rate limited identically
            await self._rate_limiter.record_failed_attempt(username)
            return AuthResult(success=False, error="Invalid username or password")

        user_id = user_row["user_id"]

        # Hash the password using user_id for enhanced security
        password_hash = self._hash_password(password, user_id)

        # Verify the password by checking if provider link exists
        user = await self.auth_manager.get_user_by_provider_link(
            AuthProviderType.BUILTIN, password_hash
        )

        # The link must belong to the account that was looked up by username
        # (same ownership check as in change_password).
        if not user or user.user_id != user_id:
            # Record failed attempt
            await self._rate_limiter.record_failed_attempt(username)
            return AuthResult(success=False, error="Invalid username or password")

        # Check if user is enabled
        if not user.enabled:
            # Record failed attempt for disabled accounts too
            await self._rate_limiter.record_failed_attempt(username)
            return AuthResult(success=False, error="User account is disabled")

        # Successful login - clear any failed attempts
        await self._rate_limiter.clear_attempts(username)
        return AuthResult(success=True, user=user)

    async def create_user_with_password(
        self,
        username: str,
        password: str,
        role: UserRole = UserRole.USER,
        display_name: str | None = None,
        player_filter: list[str] | None = None,
        provider_filter: list[str] | None = None,
    ) -> User:
        """
        Create a new built-in user with password.

        The user is created first because its generated user ID is part of
        the password salt; the password hash is then linked to the user.

        :param username: The username.
        :param password: The password (will be hashed).
        :param role: The user role (default: USER).
        :param display_name: Optional display name.
        :param player_filter: Optional list of player IDs user has access to.
        :param provider_filter: Optional list of provider instance IDs user has access to.
        :return: The newly created user.
        """
        # Create the user
        user = await self.auth_manager.create_user(
            username=username,
            role=role,
            display_name=display_name,
            player_filter=player_filter,
            provider_filter=provider_filter,
        )

        # Hash password using user_id for enhanced security
        password_hash = self._hash_password(password, user.user_id)
        await self.auth_manager.link_user_to_provider(user, AuthProviderType.BUILTIN, password_hash)

        return user

    async def change_password(self, user: User, old_password: str, new_password: str) -> bool:
        """
        Change user password.

        :param user: The user.
        :param old_password: Current password for verification.
        :param new_password: The new password.
        :return: True if the password was changed, False if the old password
            did not match this user.
        """
        # Verify old password first using user_id
        old_password_hash = self._hash_password(old_password, user.user_id)
        existing_user = await self.auth_manager.get_user_by_provider_link(
            AuthProviderType.BUILTIN, old_password_hash
        )

        if not existing_user or existing_user.user_id != user.user_id:
            return False

        # Update password link with new hash using user_id
        new_password_hash = self._hash_password(new_password, user.user_id)
        await self.auth_manager.update_provider_link(
            user, AuthProviderType.BUILTIN, new_password_hash
        )

        return True

    async def reset_password(self, user: User, new_password: str) -> None:
        """
        Reset user password (admin only - no old password verification).

        This method performs no permission check itself; the caller is
        responsible for restricting it to administrators.

        :param user: The user whose password to reset.
        :param new_password: The new password.
        """
        # Hash new password using user_id and update provider link
        new_password_hash = self._hash_password(new_password, user.user_id)
        await self.auth_manager.update_provider_link(
            user, AuthProviderType.BUILTIN, new_password_hash
        )

    def _hash_password(self, password: str, user_id: str) -> str:
        """
        Hash password with salt combining user ID and server ID.

        Uses PBKDF2-HMAC-SHA256 with 100000 iterations.

        :param password: Plain text password.
        :param user_id: User ID to include in salt (random token for high entropy).
        :return: Hex-encoded derived key.
        """
        # Combine user_id (random) and server_id for maximum security
        salt = f"{user_id}:{self.mass.server_id}"
        return hashlib.pbkdf2_hmac(
            "sha256", password.encode(), salt.encode(), iterations=100000
        ).hex()
508
509
class HomeAssistantOAuthProvider(LoginProvider):
    """
    Home Assistant OAuth login provider.

    Flow: ``get_authorization_url`` issues a random ``state`` and returns the
    HA authorize URL; after the user logs in at HA, ``handle_oauth_callback``
    validates the state, exchanges the code for a token, resolves the HA user
    and maps it to a (possibly new) local user.
    """

    # How long an issued OAuth state stays valid. Without an expiry, flows that
    # are started but never completed would accumulate in memory forever.
    _OAUTH_SESSION_TTL = timedelta(minutes=30)

    def __init__(self, mass: MusicAssistant, provider_id: str, config: LoginProviderConfig) -> None:
        """
        Initialize Home Assistant OAuth provider.

        :param mass: MusicAssistant instance.
        :param provider_id: Unique identifier for this provider instance.
        :param config: Provider-specific configuration.
        """
        super().__init__(mass, provider_id, config)
        # Store OAuth state -> return_url mapping to support concurrent sessions
        self._oauth_sessions: dict[str, str | None] = {}
        # OAuth state -> moment after which the state is no longer accepted
        self._oauth_session_expiry: dict[str, datetime] = {}

    @property
    def allow_self_registration(self) -> bool:
        """Return whether self-registration is allowed, read dynamically from config."""
        return bool(self.mass.webserver.config.get_value(CONF_AUTH_ALLOW_SELF_REGISTRATION))

    @property
    def provider_type(self) -> AuthProviderType:
        """Return the provider type."""
        return AuthProviderType.HOME_ASSISTANT

    @property
    def requires_redirect(self) -> bool:
        """Return True - Home Assistant OAuth requires redirect."""
        return True

    async def authenticate(self, credentials: dict[str, Any]) -> AuthResult:
        """
        Not used for OAuth providers - use handle_oauth_callback instead.

        :param credentials: Not used.
        :return: Always a failed AuthResult pointing the caller to the OAuth flow.
        """
        return AuthResult(success=False, error="Use OAuth flow for Home Assistant authentication")

    def _prune_oauth_sessions(self) -> None:
        """Discard pending OAuth states whose time-to-live has passed."""
        now = utc()
        expired = [
            state for state, expires_at in self._oauth_session_expiry.items() if expires_at <= now
        ]
        for state in expired:
            self._oauth_session_expiry.pop(state, None)
            self._oauth_sessions.pop(state, None)

    async def _get_external_ha_url(self) -> str | None:
        """
        Get the external URL for Home Assistant from the config API.

        This is needed when MA runs as HA add-on and connects via internal docker network
        (http://supervisor/api) but needs the external URL for OAuth redirects.

        If the configured URL does not contain "supervisor" it is returned
        unchanged. Otherwise the HA provider's client is asked for its
        ``network/url`` data; on any failure the configured URL is returned.

        :return: External URL if available, otherwise the configured URL, or
            None if no URL is configured.
        """
        ha_url = cast("str", self.config.get("ha_url")) if self.config.get("ha_url") else None
        if not ha_url:
            return None

        # Check if we're using the internal supervisor URL
        if "supervisor" not in ha_url.lower():
            # Not using internal URL, return as-is
            return ha_url

        # We're using internal URL - try to get external URL from HA provider
        ha_provider = self.mass.get_provider("hass")
        if not ha_provider:
            # No HA provider available, use configured URL
            return ha_url

        ha_provider = cast("HomeAssistantProvider", ha_provider)

        try:
            # Access the hass client from the provider
            hass_client = ha_provider.hass
            if not hass_client or not hass_client.connected:
                return ha_url

            # Get network URLs from Home Assistant using WebSocket API
            # This command returns internal, external, and cloud URLs
            network_urls = await hass_client.send_command("network/url")

            if network_urls:
                # Priority: external > cloud > internal
                # External is the manually configured external URL
                # Cloud is the Nabu Casa cloud URL
                # Internal is the local network URL
                external_url = network_urls.get("external")
                cloud_url = network_urls.get("cloud")
                internal_url = network_urls.get("internal")

                # Use external URL first, then cloud, then internal
                final_url = cast("str", external_url or cloud_url or internal_url)
                if final_url:
                    self.logger.debug(
                        "Using HA URL for OAuth: %s (from network/url, configured: %s)",
                        final_url,
                        ha_url,
                    )
                    return final_url
        except Exception as err:
            # Best effort only: a failed lookup must not break the login flow
            self.logger.warning("Failed to fetch HA network URLs: %s", err, exc_info=True)

        # Fallback to configured URL
        return ha_url

    async def get_authorization_url(
        self, redirect_uri: str, return_url: str | None = None
    ) -> str | None:
        """
        Get Home Assistant OAuth authorization URL using hass_client.

        Issues a fresh random ``state`` and remembers ``return_url`` under it
        until the callback arrives or the state expires.

        :param redirect_uri: The callback URL.
        :param return_url: Optional URL to redirect to after successful login.
        :return: The authorization URL, or None if no HA URL is configured.
        """
        # Get the correct HA URL (external URL if running as add-on)
        ha_url = await self._get_external_ha_url()
        if not ha_url:
            return None

        # If HA URL is still the internal supervisor URL (no external_url in HA config),
        # infer from redirect_uri (the URL user is accessing MA from)
        if "supervisor" in ha_url.lower():
            # Extract scheme and host from redirect_uri to build external HA URL
            parsed = urlparse(redirect_uri)
            # HA typically runs on port 8123, but use default ports for HTTPS (443) or HTTP (80)
            if parsed.scheme == "https":
                # HTTPS - use default port 443 (no port in URL)
                inferred_ha_url = f"{parsed.scheme}://{parsed.hostname}"
            else:
                # HTTP - assume HA runs on default port 8123
                inferred_ha_url = f"{parsed.scheme}://{parsed.hostname}:8123"

            self.logger.debug(
                "HA external_url not configured, inferring from callback URL: %s",
                inferred_ha_url,
            )
            ha_url = inferred_ha_url

        # Drop abandoned flows before adding a new one so the maps stay bounded
        self._prune_oauth_sessions()

        state = secrets.token_urlsafe(32)
        # Store return_url keyed by state to support concurrent OAuth sessions
        # This prevents race conditions when multiple users/sessions login simultaneously
        self._oauth_sessions[state] = return_url
        self._oauth_session_expiry[state] = utc() + self._OAUTH_SESSION_TTL

        # Use base_url of callback as client_id (same as HA provider does)
        client_id = base_url(redirect_uri)

        # Use hass_client's get_auth_url utility
        return cast(
            "str",
            get_auth_url(
                ha_url,
                redirect_uri,
                client_id=client_id,
                state=state,
            ),
        )

    async def _fetch_ha_user_id_via_websocket(self, ha_url: str, access_token: str) -> str | None:
        """
        Fetch the HA user ID from Home Assistant via WebSocket using OAuth token.

        :param ha_url: Home Assistant URL.
        :param access_token: Access token for WebSocket authentication.
        :return: The HA user ID or None if fetch fails.
        """
        ws_url = get_websocket_url(ha_url)

        try:
            # Use context manager to automatically handle connect/disconnect
            async with HomeAssistantClient(ws_url, access_token, self.mass.http_session) as client:
                # Use the auth/current_user command to get user ID
                result = await client.send_command("auth/current_user")
                if result and (user_id := result.get("id")):
                    return str(user_id)
                self.logger.warning("auth/current_user returned no user data or missing id")
                return None
        except BaseHassClientError as ws_error:
            self.logger.error("Failed to fetch HA user via WebSocket: %s", ws_error)
            return None

    async def _get_or_create_user(
        self,
        username: str,
        display_name: str | None,
        ha_user_id: str,
        avatar_url: str | None = None,
    ) -> User | None:
        """
        Get or create a user for Home Assistant OAuth authentication.

        Resolution order:
        1. A user already linked to ``ha_user_id`` is returned.
        2. An existing user with the same (normalized) username is linked to
           ``ha_user_id`` and returned.
        3. Otherwise a new user is created, if self-registration is allowed.

        Updates existing users with display_name and avatar_url from HA on each OAuth login
        (HA is considered the source of truth for these fields).

        :param username: Username from Home Assistant.
        :param display_name: Display name from Home Assistant.
        :param ha_user_id: Home Assistant user ID.
        :param avatar_url: Avatar URL from Home Assistant person entity.
        :return: User object, or None if the user is unknown and
            self-registration is disabled.
        """
        # Check if user already linked to HA
        user = await self.auth_manager.get_user_by_provider_link(
            AuthProviderType.HOME_ASSISTANT, ha_user_id
        )
        if user:
            # Update user with HA details if available (HA is source of truth)
            if display_name or avatar_url:
                user = await self.auth_manager.update_user(
                    user,
                    display_name=display_name,
                    avatar_url=avatar_url,
                )
            return user

        username = normalize_username(username)

        # Check if a user with this username already exists (from built-in provider)
        user_row = await self.auth_manager.database.get_row("users", {"username": username})
        if user_row:
            # User exists with this username - link them to HA provider
            user_dict = dict(user_row)
            existing_user = User(
                user_id=user_dict["user_id"],
                username=user_dict["username"],
                role=UserRole(user_dict["role"]),
                enabled=bool(user_dict["enabled"]),
                created_at=datetime.fromisoformat(user_dict["created_at"]),
                display_name=user_dict["display_name"],
                avatar_url=user_dict["avatar_url"],
            )

            # Link existing user to Home Assistant
            await self.auth_manager.link_user_to_provider(
                existing_user, AuthProviderType.HOME_ASSISTANT, ha_user_id
            )

            # Update user with HA details if available (HA is source of truth)
            if display_name or avatar_url:
                existing_user = await self.auth_manager.update_user(
                    existing_user,
                    display_name=display_name,
                    avatar_url=avatar_url,
                )

            return existing_user

        # New HA user - check if self-registration allowed
        if not self.allow_self_registration:
            return None

        # Determine role based on HA admin status
        role = await get_ha_user_role(self.mass, ha_user_id)

        # Create new user
        user = await self.auth_manager.create_user(
            username=username,
            role=role,
            display_name=display_name or username,
            avatar_url=avatar_url,
        )

        # Link to Home Assistant
        await self.auth_manager.link_user_to_provider(
            user, AuthProviderType.HOME_ASSISTANT, ha_user_id
        )

        return user

    async def handle_oauth_callback(self, code: str, state: str, redirect_uri: str) -> AuthResult:
        """
        Handle Home Assistant OAuth callback using hass_client.

        A ``state`` is accepted only once and only until it expires. Disabled
        accounts are rejected, as in the built-in provider.

        :param code: OAuth authorization code.
        :param state: OAuth state parameter.
        :param redirect_uri: The callback URL.
        :return: AuthResult with the user and the stored return_url on
            success, otherwise a failed AuthResult with an error message.
        """
        # Expired states must not be accepted, so drop them before the lookup
        self._prune_oauth_sessions()

        # Verify state and retrieve return_url from session
        if state not in self._oauth_sessions:
            return AuthResult(success=False, error="Invalid or expired state parameter")

        # Retrieve and remove the return_url for this session (cleanup)
        return_url = self._oauth_sessions.pop(state)
        self._oauth_session_expiry.pop(state, None)

        # Get the correct HA URL (external URL if running as add-on)
        # This must be the same URL used in get_authorization_url
        ha_url = await self._get_external_ha_url()
        if not ha_url:
            return AuthResult(success=False, error="Home Assistant URL not configured")

        try:
            # Use base_url of callback as client_id (same as HA provider does)
            client_id = base_url(redirect_uri)

            # Use hass_client's get_token utility - no client_secret needed!
            try:
                token_details = await get_token(ha_url, code, client_id=client_id)
            except Exception as token_error:
                self.logger.error(
                    "Failed to get token from HA: %s (client_id: %s, ha_url: %s)",
                    token_error,
                    client_id,
                    ha_url,
                )
                return AuthResult(
                    success=False, error=f"Failed to exchange OAuth code: {token_error}"
                )

            access_token = token_details.get("access_token")
            if not access_token:
                return AuthResult(success=False, error="No access token received from HA")

            # Get the HA user ID from the OAuth token via WebSocket
            ha_user_id = await self._fetch_ha_user_id_via_websocket(ha_url, access_token)
            if not ha_user_id:
                return AuthResult(
                    success=False,
                    error="Failed to get user ID from Home Assistant",
                )

            # Get username, display name and avatar from HA provider (has admin access)
            username, display_name, avatar_url = await get_ha_user_details(self.mass, ha_user_id)

            # Fall back to HA user ID as username if not found
            if not username:
                self.logger.warning("Could not get username from HA, using user ID as fallback")
                username = ha_user_id

            # Get or create user
            user = await self._get_or_create_user(username, display_name, ha_user_id, avatar_url)

            if not user:
                return AuthResult(
                    success=False,
                    error="Self-registration is disabled. Please contact an administrator.",
                )

            # A disabled account must not be able to log in through HA either
            if not user.enabled:
                return AuthResult(success=False, error="User account is disabled")

            return AuthResult(success=True, user=user, return_url=return_url)

        except Exception as e:
            # Boundary of the login flow: log with traceback and report failure
            self.logger.exception("Error during Home Assistant OAuth callback")
            return AuthResult(success=False, error=str(e))
843