import os
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union

import httpx

import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
    OptionalRerankParams,
    RerankBilledUnits,
    RerankResponse,
    RerankResponseDocument,
    RerankResponseMeta,
    RerankResponseResult,
    RerankTokens,
)
from litellm.utils import token_counter

from ..common_utils import HuggingFaceError

# The litellm logging class is imported only while a static type checker is
# running; at runtime ``LoggingClass`` is just ``Any``, so this module does not
# import the logging module when it is loaded.
# NOTE(review): presumably this avoids a circular import — confirm.
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any


class _HuggingFaceRerankResponseItemRequired(TypedDict):
    """Keys carried by every HuggingFace rerank API response item."""

    index: int  # Position of the scored document in the request's "texts" list
    score: float  # Score assigned to that document; surfaced as relevance_score


class HuggingFaceRerankResponseItem(
    _HuggingFaceRerankResponseItemRequired, total=False
):
    """Type definition for HuggingFace rerank API response items.

    ``index`` and ``score`` are required keys. ``text`` is declared with
    ``total=False`` because the key is only included when the request set
    ``return_text=True``; a required key typed ``Optional[str]`` would instead
    claim the key is always present with a possibly-``None`` value.
    """

    text: Optional[str]  # Only included when return_text=True


class HuggingFaceRerankResponse(TypedDict):
    """Placeholder type for the complete HuggingFace rerank API response.

    No keys are declared here: the response is a list of
    HuggingFaceRerankResponseItem rather than a mapping.
    """


# Type alias for the actual response structure: the parsed JSON body of a
# rerank call is annotated in this module as a list of
# HuggingFaceRerankResponseItem entries.
HuggingFaceRerankResponseList = List[HuggingFaceRerankResponseItem]


class HuggingFaceRerankConfig(BaseRerankConfig):
    def get_api_base(self, model: str, api_base: Optional[str]) -> str:
        if api_base is not None:
            return api_base
        elif os.getenv("HF_API_BASE") is not None:
            return os.getenv("HF_API_BASE", "")
        elif os.getenv("HUGGINGFACE_API_BASE") is not None:
            return os.getenv("HUGGINGFACE_API_BASE", "")
        else:
            return "https://api-inference.huggingface.co"

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        Get the complete URL for the API call, including the /rerank suffix if necessary.
        """
        # Get base URL from api_base or default
        base_url = self.get_api_base(model=model, api_base=api_base)

        # Remove trailing slashes and ensure we have the /rerank endpoint
        base_url = base_url.rstrip("/")
        if not base_url.endswith("/rerank"):
            base_url = f"{base_url}/rerank"

        return base_url

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """
        Translate Cohere-style rerank parameters into HuggingFace request fields.

        Only ``non_default_params`` is read; the other arguments are not used by
        this method. Entries whose value is ``None`` are skipped, and keys other
        than the four below are ignored:

            documents        -> texts
            return_documents -> return_text (only when the value is a bool)
            top_n            -> top_n
            query            -> query

        Returns:
            OptionalRerankParams built from the translated key/value pairs.
        """
        optional_rerank_params: Dict[str, Any] = {}
        if non_default_params is not None:
            for k, v in non_default_params.items():
                if v is None:
                    continue
                if k == "documents":
                    optional_rerank_params["texts"] = v
                elif k == "return_documents":
                    # Non-boolean values are dropped rather than forwarded.
                    if isinstance(v, bool):
                        optional_rerank_params["return_text"] = v
                elif k in ("top_n", "query"):
                    optional_rerank_params[k] = v

        return OptionalRerankParams(**optional_rerank_params)  # type: ignore

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # Get API credentials
        api_key, api_base = self.get_api_credentials(api_key=api_key, api_base=api_base)

        default_headers = {
            "accept": "application/json",
            "content-type": "application/json",
        }

        if api_key:
            default_headers["Authorization"] = f"Bearer {api_key}"

        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Union[OptionalRerankParams, dict],
        headers: dict,
    ) -> dict:
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for HuggingFace rerank")
        if "texts" not in optional_rerank_params:
            raise ValueError(
                "Cohere 'documents' param is required for HuggingFace rerank"
            )
        # Ensure return_text is a boolean value
        # HuggingFace API expects return_text parameter, corresponding to our return_documents parameter
        request_body = {
            "raw_scores": False,
            "truncate": False,
            "truncation_direction": "Right",
        }

        request_body.update(optional_rerank_params)

        return request_body

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LoggingClass,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Convert a raw HuggingFace rerank HTTP response into a RerankResponse.

        The response body is expected to be a JSON array of objects carrying
        ``index`` and ``score`` (plus ``text`` when the server returns it).
        Token usage in the metadata is an estimate from ``token_counter``; a
        cheap heuristic is used if counting fails.

        Parameters:
            model: Model name, passed to ``token_counter``.
            raw_response: HTTP response whose JSON body holds the rerank results.
            request_data: Request body that was sent. ``query``, ``texts``,
                ``return_text`` and ``return_documents`` are read from it; it is
                never mutated here, so the shared ``{}`` default is harmless.
            model_response, logging_obj, api_key, optional_params,
            litellm_params: Not read by this method.

        Returns:
            A RerankResponse with a fresh UUID id, one result per usable item
            (non-object items and items lacking ``index`` or ``score`` are
            skipped), and metadata holding one billed search unit plus the
            estimated token counts.

        Raises:
            HuggingFaceError: If the body is not valid JSON or is not a JSON
                array.
        """
        try:
            raw_response_json: HuggingFaceRerankResponseList = raw_response.json()
        except Exception:
            raise HuggingFaceError(
                message=getattr(raw_response, "text", str(raw_response)),
                status_code=getattr(raw_response, "status_code", 500),
            )

        # Everything below iterates over result objects; any other JSON shape
        # (e.g. an error object) is reported as a provider error instead of
        # failing later with an AttributeError.
        if not isinstance(raw_response_json, list):
            raise HuggingFaceError(
                message=getattr(raw_response, "text", str(raw_response)),
                status_code=getattr(raw_response, "status_code", 500),
            )

        query = request_data.get("query") or ""
        original_documents = request_data.get("texts") or []

        # Assemble the scored text up front so the tokenizer path and the
        # fallback estimate both see the query together with the documents.
        input_parts = [query]
        for doc in original_documents:
            if isinstance(doc, str):
                input_parts.append(doc)
            elif isinstance(doc, dict) and isinstance(doc.get("text"), str):
                input_parts.append(doc["text"])
        # Each part (query first, then documents) is followed by one space.
        input_text = "".join(f"{part} " for part in input_parts)

        try:
            # Output tokens are estimated from the stringified parsed response.
            estimated_output_tokens = token_counter(
                model=model, text=str(raw_response_json)
            )
            estimated_input_tokens = token_counter(model=model, text=input_text)
        except Exception:
            # Token counts are best-effort metadata: fall back to rough
            # estimates (10 tokens per result, ~4 characters per input token)
            # rather than failing the whole rerank call.
            estimated_output_tokens = (
                len(raw_response_json) * 10 if raw_response_json else 10
            )
            estimated_input_tokens = len(input_text) // 4

        rerank_meta = RerankResponseMeta(
            api_version={"version": "1.0"},
            billed_units=RerankBilledUnits(search_units=1),
            tokens=RerankTokens(
                input_tokens=estimated_input_tokens,
                output_tokens=estimated_output_tokens,
            ),
        )

        # Documents are attached when either the HuggingFace flag (return_text)
        # or the Cohere-style flag (return_documents) is truthy in the request.
        should_return_documents = request_data.get(
            "return_text", False
        ) or request_data.get("return_documents", False)

        results = []
        for item in raw_response_json:
            if not isinstance(item, dict):
                continue

            index = item.get("index")
            score = item.get("score")

            # Skip items that don't have the required fields
            if index is None or score is None:
                continue

            result = RerankResponseResult(index=index, relevance_score=score)

            if should_return_documents:
                returned_text = item.get("text", "")

                # 1. Prefer text returned directly by the API
                if returned_text:
                    result["document"] = RerankResponseDocument(text=returned_text)
                # 2. Otherwise look the document up in the original request
                elif isinstance(index, int) and 0 <= index < len(original_documents):
                    doc = original_documents[index]
                    if isinstance(doc, str):
                        result["document"] = RerankResponseDocument(text=doc)
                    elif isinstance(doc, dict) and "text" in doc:
                        result["document"] = RerankResponseDocument(text=doc["text"])

            results.append(result)

        return RerankResponse(
            id=str(uuid.uuid4()),
            results=results,
            meta=rerank_meta,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Wrap an error message and HTTP status code in a HuggingFaceError.

        ``headers`` is not read by this method.
        """
        huggingface_error = HuggingFaceError(
            message=error_message, status_code=status_code
        )
        return huggingface_error

    def get_api_credentials(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Get API key and base URL from multiple sources.
        Returns tuple of (api_key, api_base).

        Parameters:
            api_key: API key provided directly to this function, takes precedence over all other sources
            api_base: API base provided directly to this function, takes precedence over all other sources
        """
        # Get API key from multiple sources
        final_api_key = (
            api_key or litellm.huggingface_key or get_secret_str("HUGGINGFACE_API_KEY")
        )

        # Get API base from multiple sources
        final_api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("HF_API_BASE")
            or get_secret_str("HUGGINGFACE_API_BASE")
        )

        return final_api_key, final_api_base
