"""
Module defining data structures for handling JWT token sets.

"""

import time
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional

import jwt
import marshmallow_dataclass
from marshmallow import ValidationError

from .errors import TokenSetError
from .schemas import TokenResponseSchema


@dataclass
class JWTToken:
    """
    A JWT token string bundled with the data needed to judge its expiry.

    Parameters:
        token: The JWT token string.
        expires_in: Lifetime of the token in seconds, if the issuer supplied one.
        time_obtained: Timestamp at which the token was obtained (defaults to
            the current time when the instance is created).
    """

    token: str
    expires_in: Optional[int]
    time_obtained: float = field(default_factory=time.time)

    def get_claims(self) -> Dict[str, Any]:
        """
        Read the claims out of the token without verifying it.

        Neither the signature nor the 'exp' claim is verified during decoding.

        Returns:
            Dict[str, Any]: The decoded claims, or an empty dictionary if
            decoding fails for any reason.
        """
        decode_options = {
            "verify_signature": False,
            "verify_exp": False,
        }
        try:
            claims = jwt.decode(self.token, options=decode_options)
        except Exception:
            # Best effort: any decoding failure is reported as "no claims".
            claims = {}
        return claims

    def is_expired(self, margin: int = 15) -> bool:
        """
        Tell whether the token has expired, or will do so within ``margin`` seconds.

        Two deadlines are considered: the 'exp' claim of the token and
        ``time_obtained + expires_in``. The earliest available one wins. A
        falsy ``expires_in`` (``None`` or ``0``) contributes no deadline. When
        no deadline is available at all, the token is reported as expired.

        Parameters:
            margin: Safety window in seconds subtracted from the deadline.

        Returns:
            True if the token is expired or will expire within the margin, False otherwise.

        Raises:
            ValueError: If the token is not set.
        """
        if not self.token:
            raise ValueError("Token is not set.")

        claimed_deadline: Optional[int] = self.get_claims().get("exp")

        derived_deadline: Optional[float] = None
        if self.expires_in:
            derived_deadline = self.time_obtained + self.expires_in

        # Collect whichever deadlines are known, normalised to float.
        deadlines: list[float] = []
        for candidate in (claimed_deadline, derived_deadline):
            if candidate is not None:
                deadlines.append(float(candidate))

        if deadlines:
            return time.time() >= (min(deadlines) - margin)

        # Without any expiry information the token cannot be trusted.
        return True

    def __str__(self) -> str:
        """
        Give the raw token as the string form of this object.

        Returns:
            The JWT token string or an empty string if the token is not set.
        """
        return self.token if self.token else ""


@dataclass
class TokenSet:
    """
    Class representing a set of tokens including access and refresh tokens.

    Parameters:
        access_token: The access token (JWTToken).
        refresh_token: The refresh token (JWTToken), if available.
        token_type: The type of the token (e.g., "Bearer").
        scope: The scope of the token, if available.
    """

    access_token: JWTToken
    refresh_token: Optional[JWTToken]
    token_type: str
    scope: Optional[str] = None

    def to_json(self) -> Dict[str, Any]:
        """
        Serialize the TokenSet to a JSON-compatible dictionary.

        Nested JWTToken dataclasses are converted to dictionaries as well.

        Returns:
            The serialized TokenSet.
        """
        return asdict(self)

    @staticmethod
    def from_token_response(data: Dict[str, Any]) -> "TokenSet":
        """
        Create a TokenSet from a token response dictionary.

        The data is validated with TokenResponseSchema first. The refresh
        token is only built when the response carries a non-empty
        'refresh_token'; 'token_type' falls back to "Bearer" when missing
        or empty.

        Parameters:
            data: The dictionary containing the token response data.

        Returns:
            The created TokenSet object.

        Raises:
            TokenSetError: If the data does not pass schema validation. The
                original marshmallow ValidationError is attached as the cause.
            KeyError: If the validated data has no 'access_token' entry
                (presumably ruled out by the schema - verify against
                TokenResponseSchema).
        """
        schema = TokenResponseSchema()
        try:
            validated_data = schema.load(data)
        except ValidationError as err:
            # Chain explicitly so the validation details stay in the traceback.
            raise TokenSetError(
                f"Invalid token response data: {err.messages}") from err

        access_token = JWTToken(
            token=validated_data["access_token"],
            expires_in=validated_data.get("expires_in"),
        )

        # Look the refresh token up once; an absent or empty value means none.
        raw_refresh_token = validated_data.get("refresh_token")
        refresh_token: Optional[JWTToken] = None
        if raw_refresh_token:
            refresh_token = JWTToken(
                token=raw_refresh_token,
                expires_in=validated_data.get("refresh_expires_in"),
            )

        return TokenSet(
            access_token=access_token,
            refresh_token=refresh_token,
            token_type=validated_data.get("token_type") or "Bearer",
            scope=validated_data.get("scope"),
        )

    @staticmethod
    def from_json(data: Dict[str, Any]) -> "TokenSet":
        """
        Deserialize a TokenSet from a JSON-compatible dictionary.

        The schema is derived from the TokenSet dataclass via
        marshmallow_dataclass.

        Parameters:
            data: The dictionary containing the token set data.

        Returns:
            The deserialized TokenSet object.

        Raises:
            marshmallow.ValidationError: If the data is invalid.
        """
        schema = marshmallow_dataclass.class_schema(TokenSet)()
        return schema.load(data)
