# litellm/proxy/guardrails/guardrail_hooks/pangea.py
import os
from typing import TYPE_CHECKING, Any, Optional, Protocol, Type

from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import LLMResponseTypes, ModelResponse, TextCompletionResponse

if TYPE_CHECKING:
    from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel


class PangeaGuardrailMissingSecrets(Exception):
    """Custom exception for missing Pangea secrets."""

    pass


class _Transformer(Protocol):
    def get_messages(self) -> list[dict]:  # noqa: E704
        ...

    def update_original_body(self, prompt_messages: list[dict]) -> Any:  # noqa: E704
        ...


class _TextCompletionRequest:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        return [{"role": "user", "content": self.body["prompt"]}]

    # This mutates the original dict, but we'll still return it anyways
    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == 1
        self.body["prompt"] = prompt_messages[0]["content"]
        return self.body


class _TextCompletionResponse:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []
        for choice in self.body["choices"]:
            messages.append({"role": "assistant", "content": choice["text"]})

        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == len(self.body["choices"])

        for choice, prompt_message in zip(self.body["choices"], prompt_messages):
            choice["text"] = prompt_message["content"]

        return self.body


class _ChatCompletionRequest:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []

        for message in self.body["messages"]:
            role = message["role"]
            content = message["content"]
            if isinstance(content, str):
                messages.append({"role": role, "content": content})
            if isinstance(content, list):
                for content_part in content:
                    if content_part["type"] == "text":
                        messages.append({"role": role, "content": content_part["text"]})

        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        count = 0

        for message in self.body["messages"]:
            content = message["content"]
            if isinstance(content, str):
                message["content"] = prompt_messages[count]["content"]
                count += 1
            if isinstance(content, list):
                for content_part in content:
                    if content_part["type"] == "text":
                        content_part["text"] = prompt_messages[count]["content"]
                        count += 1

        assert len(prompt_messages) == count
        return self.body


class _ChatCompletionResponse:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []

        for choice in self.body["choices"]:
            messages.append(
                {
                    "role": choice["message"]["role"],
                    "content": choice["message"]["content"],
                }
            )

        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == len(self.body["choices"])

        for choice, prompt_message in zip(self.body["choices"], prompt_messages):
            choice["message"]["content"] = prompt_message["content"]

        return self.body


def _get_transformer_for_request(body, call_type) -> Optional[_Transformer]:
    match call_type:
        case "text_completion" | "atext_completion":
            return _TextCompletionRequest(body)
        case "completion" | "acompletion":
            return _ChatCompletionRequest(body)

    return None


def _get_transformer_for_response(body) -> Optional[_Transformer]:
    match body:
        case TextCompletionResponse():
            return _TextCompletionResponse(body)
        case ModelResponse():
            return _ChatCompletionResponse(body)

    return None


class PangeaHandler(CustomGuardrail):
    """
    Pangea AI Guardrail handler to interact with the Pangea AI Guard service.

    This class implements the necessary hooks to call the Pangea AI Guard API
    for input and output scanning based on the configured recipe.
    """

    def __init__(
        self,
        guardrail_name: str,
        pangea_input_recipe: Optional[str] = None,
        pangea_output_recipe: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the PangeaHandler.

        Args:
            guardrail_name (str): The name of the guardrail instance.
            pangea_recipe (str): The Pangea recipe key to use for scanning.
            api_key (Optional[str]): The Pangea API key. Reads from PANGEA_API_KEY env var if None.
            api_base (Optional[str]): The Pangea API base URL. Reads from PANGEA_API_BASE env var or uses default if None.
            **kwargs: Additional arguments passed to the CustomGuardrail base class.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_key = api_key or os.environ.get("PANGEA_API_KEY")
        if not self.api_key:
            raise PangeaGuardrailMissingSecrets(
                "Pangea API Key not found. Set PANGEA_API_KEY environment variable or pass it in litellm_params."
            )

        # Default Pangea base URL if not provided
        self.api_base = (
            api_base
            or os.environ.get("PANGEA_API_BASE")
            or "https://ai-guard.aws.us.pangea.cloud"
        )
        self.pangea_input_recipe = pangea_input_recipe
        self.pangea_output_recipe = pangea_output_recipe
        self.guardrail_endpoint = f"{self.api_base}/v1/text/guard"

        # Pass relevant kwargs to the parent class
        super().__init__(guardrail_name=guardrail_name, **kwargs)
        verbose_proxy_logger.info(
            f"Initialized Pangea Guardrail: name={guardrail_name}, recipe={pangea_input_recipe}, api_base={self.api_base}"
        )

    async def _call_pangea_guard(self, payload: dict, hook_name: str) -> dict:
        """
        Makes the API call to the Pangea AI Guard endpoint.
        The function itself will raise an error in the case that a response
        should be blocked, but will return a list of redacted messages that the caller
        should act on.

        Args:
            payload (dict): The request payload.
            request_data (dict): Original request data (used for logging/headers).
            hook_name (str): Name of the hook calling this function (for logging).

        Raises:
            HTTPException: If the Pangea API returns a 'blocked: true' response.
            Exception: For other API call failures.

        Returns:
            list[dict]: The original response body
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        try:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail ({hook_name}): Calling endpoint {self.guardrail_endpoint} with payload: {payload}"
            )
            response = await self.async_handler.post(
                url=self.guardrail_endpoint, json=payload, headers=headers
            )
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)

            result = response.json()
            verbose_proxy_logger.debug(
                f"Pangea Guardrail ({hook_name}): Received response: {result}"
            )

            # Check if the request was blocked
            if result.get("result", {}).get("blocked") is True:
                verbose_proxy_logger.warning(
                    f"Pangea Guardrail ({hook_name}): Request blocked. Response: {result}"
                )
                raise HTTPException(
                    status_code=400,  # Bad Request, indicating violation
                    detail={
                        "error": "Violated Pangea guardrail policy",
                        "guardrail_name": self.guardrail_name,
                        "pangea_response": result.get("result"),
                    },
                )
            else:
                verbose_proxy_logger.info(
                    f"Pangea Guardrail ({hook_name}): Request passed. Response: {result.get('result', {}).get('detectors')}"
                )

            return result

        except HTTPException as e:
            # Re-raise HTTPException if it's the one we raised for blocking
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                f"Pangea Guardrail ({hook_name}): Error calling API: {e}. Response text: {getattr(e, 'response', None) and getattr(e.response, 'text', None)}"  # type: ignore
            )
            # Decide if you want to block by default on error, or allow through
            # Raising an exception here will block the request.
            # To allow through on error, you might just log and return.
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Error communicating with Pangea Guardrail",
                    "guardrail_name": self.guardrail_name,
                    "exception": str(e),
                },
            ) from e

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        event_type = GuardrailEventHooks.pre_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail (async_pre_call_hook): Guardrail is disabled {self.guardrail_name}."
            )
            return data

        transformer = _get_transformer_for_request(data, call_type)
        if not transformer:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_pre_call_hook): Skipping guardrail {self.guardrail_name}"
                f" because we cannot determine type of request: call_type '{call_type}'"
            )
            return

        messages = transformer.get_messages()
        if not messages:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_pre_call_hook): Skipping guardrail {self.guardrail_name}"
                " because messages is empty."
            )
            return

        ai_guard_payload = {
            "debug": False,  # Or make this configurable if needed
            "messages": messages,
        }
        if self.pangea_input_recipe:
            ai_guard_payload["recipe"] = self.pangea_input_recipe

        ai_guard_response = await self._call_pangea_guard(
            ai_guard_payload, "async_pre_call_hook"
        )
        # Add guardrail name to header if passed
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )
        prompt_messages = ai_guard_response.get("result", {}).get("prompt_messages", [])

        try:
            return transformer.update_original_body(prompt_messages)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to update original request body",
                    "guardrail_name": self.guardrail_name,
                    "exceptions": str(e),
                },
            ) from e

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        # This union isn't actually correct -- it can get other response types depending on the API called
        response: LLMResponseTypes,
    ):
        """
        Guardrail hook run after a successful LLM call (scans output).

        Args:
            data (dict): The original request data.
            user_api_key_dict (UserAPIKeyAuth): User API key details.
            response (LLMResponseTypes): The response object from the LLM call.
        """
        event_type = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail (async_pre_call_hook): Guardrail is disabled {self.guardrail_name}."
            )
            return data

        transformer = _get_transformer_for_response(response)
        if not transformer:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_post_call_success_hook): Skipping guardrail {self.guardrail_name}"
                " because we cannot determine type of request"
            )
            return

        messages = transformer.get_messages()
        verbose_proxy_logger.warning(f"GOT MESSAGES: {messages}")
        ai_guard_payload = {
            "debug": False,  # Or make this configurable if needed
            "messages": messages,
        }
        if self.pangea_output_recipe:
            ai_guard_payload["recipe"] = self.pangea_output_recipe

        ai_guard_response = await self._call_pangea_guard(
            ai_guard_payload, "post_call_success_hook"
        )
        prompt_messages = ai_guard_response.get("result", {}).get("prompt_messages", [])

        try:
            return transformer.update_original_body(prompt_messages)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to update original response body",
                    "guardrail_name": self.guardrail_name,
                    "exceptions": str(e),
                },
            ) from e

    @staticmethod
    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
        from litellm.types.proxy.guardrails.guardrail_hooks.pangea import (
            PangeaGuardrailConfigModel,
        )

        return PangeaGuardrailConfigModel
