"""Keycloak OIDC Device Flow Handler"""


import time
from typing import Any, Optional

import httpx
import pyqrcode  # type: ignore[import-untyped]

from .errors import DeviceAuthError, OIDCDiscoveryError
from .tokencaches.base import TokenCache
from .tokenset import TokenSet


class KeycloakOIDCDeviceFlowHandler:
    """
    Handler for the Keycloak OIDC Device Authorization Grant (RFC 8628).

    Fetches the realm's OIDC discovery document, starts device authorization,
    polls the token endpoint until the user completes the login, and refreshes
    tokens.

    Args:
        realm_url: The Keycloak realm URL. The discovery document is read from
            ``<realm_url>/.well-known/openid-configuration``.
        client_id: The client ID.
        client_secret: The client secret, if applicable.
        scopes: The scopes to request. Default is "openid profile email offline_access".
        verify_tls: Whether to verify TLS certificates. Default is True.
        timeout: The HTTP request timeout in seconds. Default is 10.
    """

    # RFC 8628 section 3.5: on "slow_down" the polling interval MUST be
    # increased by 5 seconds for this and all subsequent requests.
    _SLOW_DOWN_INCREMENT_SECONDS = 5

    def __init__(
        self,
        realm_url: str,
        client_id: str,
        client_secret: Optional[str] = None,
        scopes: str = "openid profile email offline_access",
        verify_tls: bool = True,
        timeout: int = 10,
    ) -> None:
        self.realm_url = realm_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.scopes = scopes
        self.verify_tls = verify_tls
        self.timeout = timeout
        # Cached discovery document; stays empty until the first successful fetch.
        self._discovery_document: dict[str, Any] = {}

    def _get_discovery(self) -> dict[str, Any]:
        """
        Fetch and cache the OIDC discovery document.

        Returns:
            The OIDC discovery document.

        Raises:
            OIDCDiscoveryError: If the document cannot be fetched or is not a
                JSON object.
        """
        if not self._discovery_document:
            url = f"{self.realm_url.rstrip('/')}/.well-known/openid-configuration"
            try:
                with httpx.Client(verify=self.verify_tls, timeout=self.timeout) as client:
                    response = client.get(url)
                    response.raise_for_status()
                    document = response.json()
            except httpx.HTTPError as e:
                raise OIDCDiscoveryError(
                    f"Failed to fetch OIDC discovery document: {str(e)}") from e
            except ValueError as e:
                # response.json() raises a ValueError subclass on malformed JSON.
                raise OIDCDiscoveryError(
                    f"Invalid OIDC discovery document: {str(e)}") from e
            if not isinstance(document, dict):
                raise OIDCDiscoveryError(
                    "Invalid OIDC discovery document: expected a JSON object.")
            self._discovery_document = document
        return self._discovery_document

    def _token_endpoint(self) -> str:
        """
        Return the token endpoint URL from the discovery document.

        Raises:
            OIDCDiscoveryError: If the discovery document has no token_endpoint.
        """
        token_url = self._get_discovery().get("token_endpoint")
        if not token_url:
            raise OIDCDiscoveryError(
                "Token endpoint not found in discovery document.")
        return token_url

    def _token_auth(self) -> httpx.Auth | httpx._client.UseClientDefault:
        """
        Return the HTTP client authentication for token/device requests.

        Keycloak supports client_secret_basic and client_secret_post. When
        client_secret is set, HTTP Basic (client_secret_basic) is used because
        it is widely supported.

        Returns:
            An ``httpx.BasicAuth`` for (client_id, client_secret) if
            client_secret is set. Otherwise ``httpx.USE_CLIENT_DEFAULT``, which
            defers to the client's own auth; the clients created in this class
            are built without any.
        """
        if not self.client_secret:
            return httpx.USE_CLIENT_DEFAULT
        return httpx.BasicAuth(self.client_id, self.client_secret)

    def _client_credentials(
        self,
        data: dict[str, str],
        client_secret_post: bool,
    ) -> tuple[dict[str, str], Any]:
        """
        Attach client authentication to a form-encoded request.

        Args:
            data: The form fields of the request.
            client_secret_post: Whether to send client_secret in the POST body
                instead of using HTTP Basic Auth.

        Returns:
            A ``(form_data, auth)`` pair to pass to ``client.post``. If
            client_secret_post is True and a secret is configured, the secret
            is added to a copy of ``data`` and no HTTP auth is used. Otherwise
            ``data`` is returned as-is together with ``self._token_auth()``.
        """
        if client_secret_post and self.client_secret:
            return {**data, "client_secret": self.client_secret}, httpx.USE_CLIENT_DEFAULT
        return data, self._token_auth()

    def __get_error_from_response(self, r: httpx.Response) -> Optional[str]:
        """
        Extract the OAuth ``error`` code from an HTTP response, if present.

        Args:
            r: The HTTP response.

        Returns:
            The error string if the body is a JSON object with a string
            ``error`` field, None otherwise.
        """
        try:
            data = r.json()
        except ValueError:
            # Body is not JSON (or cannot be decoded).
            return None
        if not isinstance(data, dict):
            return None
        error = data.get("error")
        return error if isinstance(error, str) else None

    def poll_token_endpoint(
        self,
        device_code: str,
        interval: int,
        expires_in: int,
        client_secret_post: bool = False,
    ) -> TokenSet:
        """
        Poll the token endpoint for tokens using the device code.

        Args:
            device_code: The device code obtained from the device authorization response.
            interval: The polling interval in seconds.
            expires_in: The expiration time of the device code in seconds.
            client_secret_post: Whether to send client_secret in the POST body instead of using HTTP Basic Auth.

        Returns:
            The token set obtained from the token endpoint.

        Raises:
            DeviceAuthError: If the device code expires, the user denies the
                request, or the token endpoint returns an unexpected error.
            OIDCDiscoveryError: If the token endpoint cannot be discovered.
        """
        token_url = self._token_endpoint()

        # Use a monotonic clock so wall-clock adjustments cannot move the deadline.
        deadline = time.monotonic() + float(expires_in)

        with httpx.Client(verify=self.verify_tls, timeout=self.timeout) as client:
            while True:
                if time.monotonic() >= deadline:
                    raise DeviceAuthError("Device code expired before authorization completed.")

                r = self.__make_token_request(
                    device_code, client_secret_post, token_url, client)

                if r.status_code == 200:
                    return TokenSet.from_token_response(r.json())

                # Device flow "pending/slow_down/denied/expired_token" errors
                # are 400s with a JSON body.
                if r.status_code != 400:
                    raise DeviceAuthError(f"Token polling failed: {r.status_code} {r.text}")

                error = self.__get_error_from_response(r)
                if error == "access_denied":
                    raise DeviceAuthError("The user denied the request.")
                if error == "expired_token":
                    raise DeviceAuthError("Device code expired.")
                if error == "slow_down":
                    interval += self._SLOW_DOWN_INCREMENT_SECONDS
                elif error != "authorization_pending":
                    raise DeviceAuthError(f"Device flow error: {error or r.text}")

                # Never sleep past the deadline; the check at the top of the
                # loop then reports expiry without an extra full interval.
                remaining = deadline - time.monotonic()
                time.sleep(max(0.0, min(float(interval), remaining)))

    def __make_token_request(
            self,
            device_code: str,
            client_secret_post: bool,
            token_url: str,
            client: httpx.Client) -> httpx.Response:
        """
        Make one device-code grant request to the token endpoint.

        Args:
            device_code: The device code.
            client_secret_post: Whether to send client_secret in the POST body instead of using HTTP Basic Auth.
            token_url: The token endpoint URL.
            client: The HTTPX client.

        Returns:
            The HTTP response from the token endpoint.
        """
        data, auth = self._client_credentials(
            {
                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                "device_code": device_code,
                "client_id": self.client_id,
            },
            client_secret_post,
        )
        return client.post(token_url, data=data, auth=auth)

    def refresh(self, refresh_token: str, *, client_secret_post: bool = False) -> TokenSet:
        """
        Refresh tokens using the refresh token.

        Args:
            refresh_token: The refresh token.
            client_secret_post: Whether to send client_secret in the POST body
                instead of using HTTP Basic Auth. Default is False.

        Returns:
            The new token set obtained from the token endpoint.

        Raises:
            DeviceAuthError: If the token endpoint does not answer with HTTP 200.
            OIDCDiscoveryError: If the token endpoint cannot be discovered.
        """
        token_url = self._token_endpoint()

        data, auth = self._client_credentials(
            {
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": self.client_id,
            },
            client_secret_post,
        )

        with httpx.Client(verify=self.verify_tls, timeout=self.timeout) as client:
            r = client.post(token_url, data=data, auth=auth)
            if r.status_code != 200:
                raise DeviceAuthError(
                    f"Refresh failed: {r.status_code} {r.text}"
                )
            return TokenSet.from_token_response(r.json())

    def start_device_authorization(self, *, client_secret_post: bool = False) -> dict[str, Any]:
        """
        Start the device authorization flow.

        Args:
            client_secret_post: Whether to send client_secret in the POST body
                instead of using HTTP Basic Auth. Default is False.

        Returns:
            The device authorization response (a JSON object).

        Raises:
            DeviceAuthError: If the endpoint is not advertised, the request
                fails, or the response is not a JSON object.
        """
        url = self._get_discovery().get("device_authorization_endpoint")
        if not url:
            raise DeviceAuthError(
                "Device Authorization Endpoint not found in discovery document."
            )

        data, auth = self._client_credentials(
            {
                "client_id": self.client_id,
                "scope": self.scopes,
            },
            client_secret_post,
        )
        with httpx.Client(verify=self.verify_tls, timeout=self.timeout) as client:
            r = client.post(url, data=data, auth=auth)
            if r.status_code != 200:
                raise DeviceAuthError(
                    f"Device authorization failed: {r.status_code} {r.text}"
                )
            try:
                body = r.json()
            except ValueError as e:
                raise DeviceAuthError(
                    f"Device authorization returned invalid JSON: {str(e)}"
                ) from e
        if not isinstance(body, dict):
            raise DeviceAuthError(
                "Device authorization returned an unexpected response body."
            )
        return body

    def _show_login_instructions(
        self,
        text_uri: str,
        link_uri: str,
        user_code: Optional[str],
        *,
        open_browser: bool,
        show_qr: bool,
    ) -> None:
        """
        Print the device-login prompt, optionally render a QR code, and
        optionally open a browser.

        Args:
            text_uri: URI printed as text (preferably the plain verification_uri).
            link_uri: URI used for the QR code and browser (preferably
                verification_uri_complete, which embeds the user code).
            user_code: The user code to display, if any.
            open_browser: Whether to try opening ``link_uri`` in a browser.
            show_qr: Whether to print ``link_uri`` as a terminal QR code.
        """
        print("\n=== Device Login Required ===")
        print(f"Open this URL:\n  {text_uri}\n")
        if user_code:
            # RFC 8628 section 3.3.1: clients MUST still display the user_code,
            # even when verification_uri_complete embeds it, so the user can
            # confirm it on the authorization server.
            print(f"Enter or confirm this code:\n  {user_code}\n")
        if show_qr:
            qr = pyqrcode.create(link_uri)
            print(qr.terminal(quiet_zone=1))

        if open_browser and isinstance(link_uri, str):
            import webbrowser
            try:
                webbrowser.open(link_uri)
            except webbrowser.Error as e:
                print(f"Failed to open browser: {str(e)}")

    def ensure_tokens(
        self,
        cache: TokenCache,
        *,
        open_browser: bool = False,
        show_qr: bool = True,
        client_secret_post: bool = False,
    ) -> TokenSet:
        """
        Load cached tokens, refresh if needed, otherwise run device flow.

        Args:
            cache: The token cache to use.
            open_browser: Whether to open the browser automatically during device flow.
            show_qr: Whether to display a QR code for the device flow URL.
            client_secret_post: Whether to send client_secret in the POST body
                instead of using HTTP Basic Auth. Default is False.

        Returns:
            The valid token set.

        Raises:
            DeviceAuthError: If device authorization fails or the device
                authorization response is missing required fields.
        """
        tokens = cache.load()
        if tokens and not tokens.access_token.is_expired():
            return tokens

        if tokens and tokens.refresh_token and not tokens.refresh_token.is_expired():
            try:
                new_tokens: TokenSet = self.refresh(
                    tokens.refresh_token.token,
                    client_secret_post=client_secret_post)
                cache.save(new_tokens)
                return new_tokens
            except (DeviceAuthError, httpx.HTTPError):
                # Fall back to a fresh device login.
                cache.clear()

        # Start device flow
        device = self.start_device_authorization(client_secret_post=client_secret_post)

        device_code = device.get("device_code")
        if not device_code:
            raise DeviceAuthError(
                "Device authorization response did not include a device_code.")

        # Keycloak usually provides both verification_uri and a "complete" one
        # that embeds the user code.
        plain_uri = device.get("verification_uri")
        complete_uri = device.get("verification_uri_complete")
        if not (plain_uri or complete_uri):
            raise DeviceAuthError(
                "Device authorization response did not include a verification URI.")

        # A missing, null or zero interval falls back to the RFC 8628 default
        # of 5 seconds instead of polling in a tight loop.
        interval = int(device.get("interval") or 5)
        expires_in = int(device.get("expires_in", 600))

        self._show_login_instructions(
            text_uri=plain_uri or complete_uri,
            link_uri=complete_uri or plain_uri,
            user_code=device.get("user_code"),
            open_browser=open_browser,
            show_qr=show_qr,
        )

        new_tokens = self.poll_token_endpoint(
            device_code=device_code,
            interval=interval,
            expires_in=expires_in,
            client_secret_post=client_secret_post)
        cache.save(new_tokens)
        return new_tokens
