import time

import httpx
import pytest
from requests.models import PreparedRequest

from acmadauth.auth.httpxauth import HttpxDeviceFlowAuth
from acmadauth.auth.requestsauth import RequestsDeviceFlowAuth
from acmadauth.errors import DeviceAuthError
from acmadauth.handlers import KeycloakOIDCDeviceFlowHandler
from acmadauth.tokencaches.memory import MemoryTokenCache
from acmadauth.tokenset import JWTToken, TokenSet


class _FakeFlow:
    """In-memory stand-in for a device-flow handler, used by the auth adapter tests.

    It records how it was exercised so tests can assert on it:
    ``ensure_calls`` counts ``ensure_tokens`` invocations, and
    ``refreshed_with`` holds the last token passed to ``refresh``.
    """

    def __init__(self, tokens: TokenSet, refreshed: TokenSet | None = None) -> None:
        # Token set returned by ensure_tokens().
        self._tokens = tokens
        # Token set returned by refresh(); falls back to the initial set when absent/falsy.
        self._refreshed = refreshed or tokens
        # Populated by refresh() so tests can check which refresh token was used.
        self.refreshed_with: str | None = None
        self.ensure_calls = 0

    def ensure_tokens(self, cache, *, open_browser=False, show_qr=True) -> TokenSet:
        """Record the call and hand back the initial token set.

        ``cache``, ``open_browser`` and ``show_qr`` are accepted only to match
        the real handler's signature; they are ignored.
        """
        self.ensure_calls = self.ensure_calls + 1
        return self._tokens

    def refresh(self, refresh_token: str) -> TokenSet:
        """Remember *refresh_token* and return the configured refreshed token set."""
        self.refreshed_with = refresh_token
        return self._refreshed


def _make_tokens(access: str, refresh: str | None = "refresh") -> TokenSet:
    """Build a Bearer ``TokenSet`` whose tokens are stamped "now" with a 3600 s lifetime.

    A falsy *refresh* (``None`` or ``""``) yields a token set without a refresh token.
    """
    obtained_at = time.time()
    lifetime = 3600

    def _jwt(value: str) -> JWTToken:
        # Every token in the set shares the same timestamp and lifetime.
        return JWTToken(token=value, expires_in=lifetime, time_obtained=obtained_at)

    refresh_jwt = _jwt(refresh) if refresh else None
    return TokenSet(
        access_token=_jwt(access),
        refresh_token=refresh_jwt,
        token_type="Bearer",
    )


def test_httpxauth_injects_header():
    """A cached access token is sent as a Bearer header; a 200 ends the auth flow."""
    token_set = _make_tokens("access")
    fake_flow = _FakeFlow(token_set)
    token_cache = MemoryTokenCache()
    token_cache.save(token_set)

    flow_steps = HttpxDeviceFlowAuth(fake_flow, token_cache).auth_flow(
        httpx.Request("GET", "https://example.test")
    )
    sent = next(flow_steps)
    assert sent.headers["Authorization"] == "Bearer access"

    ok_response = httpx.Response(200, request=sent, content=b"ok")
    # A successful response must not cause the generator to yield another request.
    with pytest.raises(StopIteration):
        flow_steps.send(ok_response)


def test_httpxauth_401_refreshes_and_retries():
    """On a 401, the auth refreshes with the cached refresh token and retries once."""
    stale = _make_tokens("old-access", refresh="refresh-token")
    fresh = _make_tokens("new-access", refresh="new-refresh")
    fake_flow = _FakeFlow(stale, fresh)
    token_cache = MemoryTokenCache()
    token_cache.save(stale)

    flow_steps = HttpxDeviceFlowAuth(fake_flow, token_cache).auth_flow(
        httpx.Request("GET", "https://example.test")
    )

    initial = next(flow_steps)
    assert initial.headers["Authorization"] == "Bearer old-access"

    # Rejecting the first attempt should produce a retried request carrying new tokens.
    rejected = httpx.Response(401, request=initial, content=b"nope")
    retried = flow_steps.send(rejected)
    assert fake_flow.refreshed_with == "refresh-token"
    assert retried.headers["Authorization"] == "Bearer new-access"

    accepted = httpx.Response(200, request=retried, content=b"ok")
    with pytest.raises(StopIteration):
        flow_steps.send(accepted)


def test_requestsauth_injects_header():
    """RequestsDeviceFlowAuth stamps a Bearer header onto a prepared request."""
    token_set = _make_tokens("access")
    fake_flow = _FakeFlow(token_set)
    token_cache = MemoryTokenCache()
    token_cache.save(token_set)
    requests_auth = RequestsDeviceFlowAuth(fake_flow, token_cache)

    prepared = PreparedRequest()
    prepared.prepare(method="GET", url="https://example.test")

    authed = requests_auth(prepared)
    assert authed.headers["Authorization"] == "Bearer access"


def test_device_authorization_missing_endpoint_raises():
    """Starting device authorization fails when discovery advertises no endpoints."""
    handler = KeycloakOIDCDeviceFlowHandler(
        realm_url="https://example.test", client_id="client"
    )

    def _empty_discovery():
        return {}

    # Replace discovery with an empty document so no device endpoint is available.
    handler._get_discovery = _empty_discovery

    with pytest.raises(DeviceAuthError):
        handler.start_device_authorization()
