"""
MCP Guardrail Handler for Unified Guardrails.

This handler works with the synthetic "messages" payload generated by
`ProxyLogging._convert_mcp_to_llm_format`, which always produces a single user
message whose `content` string encodes the MCP tool name and arguments. The
handler simply feeds that text through the configured guardrail and writes the
result back onto the message.
"""

from typing import TYPE_CHECKING, Any, Dict, Optional

from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs

if TYPE_CHECKING:
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from mcp.types import CallToolResult


class MCPGuardrailTranslationHandler(BaseTranslation):
    """
    Bridge MCP tool-call payloads to the unified guardrail API.

    Only the request direction does any work: the first message's ``content``
    (when it is a string) is sent to the guardrail as a single text, and the
    first text the guardrail returns, if any, replaces it in place. The
    response direction is a pass-through.
    """

    @staticmethod
    def _get_message_content(message: Any) -> Any:
        """Read ``content`` from a dict-style or attribute-style message."""
        if isinstance(message, dict):
            return message.get("content")
        return getattr(message, "content", None)

    @staticmethod
    def _set_message_content(message: Any, new_content: Any) -> None:
        """Write ``content`` onto a dict-style or attribute-style message."""
        if isinstance(message, dict):
            message["content"] = new_content
            return
        setattr(message, "content", new_content)

    async def process_input_messages(
        self,
        data: Dict[str, Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Guardrail the first message's text and write any replacement back.

        Args:
            data: Request payload. Its ``messages`` list is read; when the
                guardrail returns text, the first message is mutated in place.
                ``mcp_tool_name`` is read only for debug logging.
            guardrail_to_apply: Guardrail whose ``apply_guardrail`` is awaited
                with ``input_type="request"``.
            litellm_logging_obj: Forwarded to the guardrail as ``logging_obj``.

        Returns:
            The same ``data`` object that was passed in, on every path.
        """
        message_list = data.get("messages")
        if not (isinstance(message_list, list) and message_list):
            verbose_proxy_logger.debug("MCP Guardrail: No messages to process")
            return data

        target_message = message_list[0]
        original_text = self._get_message_content(target_message)
        # Only plain-string content is guardrailed; anything else (None, lists
        # of content parts, ...) is left untouched.
        if not isinstance(original_text, str):
            verbose_proxy_logger.debug(
                "MCP Guardrail: Message content missing or not a string",
            )
            return data

        guardrail_result = await guardrail_to_apply.apply_guardrail(
            inputs=GenericGuardrailAPIInputs(texts=[original_text]),
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        # A falsy result or a missing/empty "texts" list means "no change".
        returned_texts = guardrail_result.get("texts", []) if guardrail_result else []

        if not returned_texts:
            verbose_proxy_logger.debug(
                "MCP Guardrail: Guardrail returned no text updates for tool %s",
                data.get("mcp_tool_name"),
            )
            return data

        self._set_message_content(target_message, returned_texts[0])
        verbose_proxy_logger.debug(
            "MCP Guardrail: Updated content for tool %s",
            data.get("mcp_tool_name"),
        )
        return data

    async def process_output_response(
        self,
        response: "CallToolResult",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Return the MCP tool result unchanged.

        Output-side guardrailing is not implemented for MCP tools. This method
        only emits a debug log and hands ``response`` back as-is; the guardrail
        and remaining arguments are accepted for interface compatibility and
        ignored.
        """
        verbose_proxy_logger.debug(
            "MCP Guardrail: Output processing not implemented for MCP tools",
        )
        return response
