"""httpx Auth implementation for Keycloak OIDC Device Flow."""

from typing import Any, Callable, Generator, Optional

import httpx

from acmadauth.tokenset import TokenSet

from ..handlers import KeycloakOIDCDeviceFlowHandler
from ..tokencaches.base import TokenCache


class HttpxDeviceFlowAuth(httpx.Auth):
    """
    httpx.Auth that injects Bearer tokens and refreshes / device-logins as needed.

    The same login options (``open_browser``, ``show_qr``, ``show_url`` and
    ``login_url_callback``) are used for every device login this object
    triggers, including the re-login performed after a 401 response.

    Args:
        flow: The device flow handler to use.
        cache: The token cache to use.
        open_browser: Whether to open the browser automatically during device flow.
        show_qr: Whether to display a QR code for the device flow URL.
        show_url: Whether to display the device flow URL.
        login_url_callback: Optional callback function to receive the login URL.

    """

    requires_request_body = False

    def __init__(
        self,
        flow: KeycloakOIDCDeviceFlowHandler,
        cache: TokenCache,
        *,
        open_browser: bool = False,
        show_qr: bool = True,
        show_url: bool = True,
        login_url_callback: Optional[Callable[[str], None]] = None,
    ) -> None:
        self.flow = flow
        self.cache = cache
        self.open_browser = open_browser
        self.show_qr = show_qr
        self.show_url = show_url
        self.login_url_callback = login_url_callback

    def _ensure_tokens(self) -> TokenSet:
        """Ask the flow handler for tokens using this instance's login options."""
        # Single call site so that every login path forwards the full set of
        # options; the 401 path previously dropped show_url and
        # login_url_callback.
        return self.flow.ensure_tokens(
            self.cache,
            open_browser=self.open_browser,
            show_qr=self.show_qr,
            show_url=self.show_url,
            login_url_callback=self.login_url_callback,
        )

    def _renew_tokens(self) -> TokenSet:
        """
        Obtain replacement tokens after the server rejected the current ones.

        Tries the cached refresh token first (when one is present and not
        expired) and saves the result to the cache. If that is not possible
        or fails, the cache is cleared and tokens are requested from the flow
        handler again.

        Returns:
            The replacement token set.
        """
        cached = self.cache.load()
        if (
            cached
            and cached.refresh_token
            and not cached.refresh_token.is_expired()
        ):
            try:
                refreshed: TokenSet = self.flow.refresh(cached.refresh_token.token)
                self.cache.save(refreshed)
            except Exception:
                # Deliberate best-effort: any failure while refreshing or
                # saving falls through to clearing the cache and logging in
                # again below.
                pass
            else:
                return refreshed

        self.cache.clear()
        return self._ensure_tokens()

    def auth_flow(
        self, request: httpx.Request
    ) -> Generator[httpx.Request, Any, None]:
        """
        The auth flow that injects tokens into requests. Automatically refreshes or runs device flow as needed.

        Args:
            request: The outgoing HTTP request.

        Yields:
            The request with its Authorization header set; yielded a second
            time with renewed tokens if the first response has status 401.
        """
        tokens = self._ensure_tokens()
        request.headers["Authorization"] = (
            f"{tokens.token_type} {tokens.access_token}"
        )
        response = yield request
        if response.status_code != 401:
            return

        # Retry once on 401 using a refresh token if available.
        # Consume the body of the rejected response before issuing the retry.
        response.read()
        new_tokens = self._renew_tokens()
        request.headers["Authorization"] = (
            f"{new_tokens.token_type} {new_tokens.access_token}"
        )
        yield request
