import time
import warnings

import keyring
import keyring.errors
import pytest

from acmadauth.tokencaches.jsoncache import JsonFileTokenCache
from acmadauth.tokencaches.keyringcache import KeyringTokenCache
from acmadauth.tokenset import JWTToken, TokenSet


def _make_tokens() -> TokenSet:
    """Build a fresh Bearer ``TokenSet`` for the cache round-trip tests.

    The access and refresh tokens share one ``time_obtained`` timestamp
    (the current ``time.time()``) and the same ``expires_in`` of 3600
    (presumably seconds — confirm against ``JWTToken``).
    """
    obtained_at = time.time()
    lifetime = 3600

    access = JWTToken(token="access", expires_in=lifetime, time_obtained=obtained_at)
    refresh = JWTToken(token="refresh", expires_in=lifetime, time_obtained=obtained_at)

    return TokenSet(access_token=access, refresh_token=refresh, token_type="Bearer")


def test_json_cache_roundtrip_and_clear(tmp_path):
    """Round-trip tokens through ``JsonFileTokenCache`` and then clear it.

    Checks that:
    * constructing the cache emits exactly one ``UserWarning``;
    * ``save()`` writes the backing file and ``load()`` returns equal tokens;
    * ``clear()`` removes the backing file.
    """
    path = tmp_path / "tokens.json"
    with warnings.catch_warnings(record=True) as caught:
        # "always" so the warning is recorded even if an identical one was
        # already emitted (and de-duplicated) earlier in the session.
        warnings.simplefilter("always")
        cache = JsonFileTokenCache(str(path))
    assert len(caught) == 1
    assert issubclass(caught[-1].category, UserWarning)

    tokens = _make_tokens()
    cache.save(tokens)
    # Without this check, the post-clear assertion below would pass
    # vacuously if save() never wrote the file.
    assert path.exists()

    assert cache.load() == tokens

    cache.clear()
    assert not path.exists()


def test_keyring_cache_roundtrip_and_clear(monkeypatch):
    """Round-trip tokens through ``KeyringTokenCache`` backed by a fake keyring.

    ``keyring.get_password`` / ``set_password`` / ``delete_password`` are
    monkeypatched onto an in-memory dict, so no real OS keyring is touched.
    """
    # In-memory stand-in for the keyring, keyed by (service, username).
    store: dict[tuple[str, str], str] = {}

    def fake_get(service, username):
        return store.get((service, username))

    def fake_set(service, username, password):
        store[(service, username)] = password

    def fake_delete(service, username):
        # Deleting a missing entry raises PasswordDeleteError, as keyring's
        # delete_password is documented to do.
        try:
            del store[(service, username)]
        except KeyError:
            raise keyring.errors.PasswordDeleteError("missing") from None

    monkeypatch.setattr(keyring, "get_password", fake_get)
    monkeypatch.setattr(keyring, "set_password", fake_set)
    monkeypatch.setattr(keyring, "delete_password", fake_delete)

    cache = KeyringTokenCache("acmad:test")
    tokens = _make_tokens()

    cache.save(tokens)
    # Guard against the final ``store == {}`` check passing vacuously if
    # save() never reached the (fake) keyring.
    assert store

    assert cache.load() == tokens

    cache.clear()
    assert store == {}


def test_json_cache_load_invalid_json(tmp_path):
    """``load()`` on a file containing malformed JSON must fail loudly.

    The exception type is not pinned. Only its message must contain
    "Failed to load token cache".
    """
    path = tmp_path / "tokens.json"
    path.write_text("{invalid json")

    with warnings.catch_warnings(record=True) as caught:
        # "always" so the warning is recorded even if already emitted earlier.
        warnings.simplefilter("always")
        cache = JsonFileTokenCache(str(path))
    assert len(caught) == 1
    assert issubclass(caught[-1].category, UserWarning)

    # ``match`` is applied with re.search against str(exception); the
    # expected text contains no regex metacharacters.
    with pytest.raises(Exception, match="Failed to load token cache"):
        cache.load()
