import asyncio
from typing import Dict, List, Optional

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
    DBSpendUpdateTransactions,
    Litellm_EntityType,
    SpendUpdateQueueItem,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
    BaseUpdateQueue,
    service_logger_obj,
)
from litellm.types.services import ServiceTypes


class SpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for spend updates that should be committed to the database
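
    Illustrative usage (rough sketch; assumes SpendUpdateQueueItem is built from
    entity_type / entity_id / response_cost and that an event loop is running):

    ```
    queue = SpendUpdateQueue()
    await queue.add_update(
        SpendUpdateQueueItem(
            entity_type=Litellm_EntityType.USER,
            entity_id="user-123",
            response_cost=0.25,
        )
    )
    db_transactions = await queue.flush_and_get_aggregated_db_spend_update_transactions()
    ```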
    """

    def __init__(self):
        super().__init__()
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()

    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """Flush all updates from the queue and return all updates aggregated by entity type."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)

    async def add_update(self, update: SpendUpdateQueueItem):
        """Enqueue an update to the spend update queue"""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

        # if the queue is full, aggregate the updates
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """Concatenate all updates in the queue to reduce the size of in-memory queue"""
        updates: List[
            SpendUpdateQueueItem
        ] = await self.flush_all_updates_from_in_memory_queue()
        aggregated_updates = self._get_aggregated_spend_update_queue_item(updates)
        for update in aggregated_updates:
            await self.update_queue.put(update)
        return

    def _get_aggregated_spend_update_queue_item(
        self, updates: List[SpendUpdateQueueItem]
    ) -> List[SpendUpdateQueueItem]:
        """
        Reduce the size of the in-memory queue by aggregating updates by entity type + id.

        eg.

        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 100
            },
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 200
            }
        ]

        ```

        becomes

        ```

        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 300
            }
        ]

        ```
        """
        verbose_proxy_logger.debug(
            "Aggregating spend updates, number of updates: %s",
            len(updates),
        )
        aggregated_spend_updates: List[SpendUpdateQueueItem] = []

        _in_memory_map: Dict[str, SpendUpdateQueueItem] = {}
        """
        Used for combining several updates into a single update
        Key=entity_type:entity_id
        Value=SpendUpdateQueueItem
        """
        for update in updates:
            _key = f"{update.get('entity_type')}:{update.get('entity_id')}"
            if _key not in _in_memory_map:
                _in_memory_map[_key] = update
            else:
                current_cost = _in_memory_map[_key].get("response_cost", 0) or 0
                update_cost = update.get("response_cost", 0) or 0
                _in_memory_map[_key]["response_cost"] = current_cost + update_cost

        aggregated_spend_updates.extend(_in_memory_map.values())

        verbose_proxy_logger.debug(
            "Aggregated spend updates: %s", aggregated_spend_updates
        )
        return aggregated_spend_updates

    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """Aggregate updates by entity type."""
        # Initialize all transaction lists as empty dicts
        db_spend_update_transactions = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
        )

        # Map entity types to their corresponding transaction dictionary keys
        entity_type_to_dict_key = {
            Litellm_EntityType.USER: "user_list_transactions",
            Litellm_EntityType.END_USER: "end_user_list_transactions",
            Litellm_EntityType.KEY: "key_list_transactions",
            Litellm_EntityType.TEAM: "team_list_transactions",
            Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
            Litellm_EntityType.ORGANIZATION: "org_list_transactions",
        }

        for update in updates:
            entity_type = update.get("entity_type")
            entity_id = update.get("entity_id") or ""
            response_cost = update.get("response_cost") or 0

            if entity_type is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is None",
                    update,
                )
                continue

            dict_key = entity_type_to_dict_key.get(entity_type)
            if dict_key is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
                    update,
                )
                continue  # Skip unknown entity types

            # Type-safe access using if/elif statements
            if dict_key == "user_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "user_list_transactions"
                ]
            elif dict_key == "end_user_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "end_user_list_transactions"
                ]
            elif dict_key == "key_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "key_list_transactions"
                ]
            elif dict_key == "team_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "team_list_transactions"
                ]
            elif dict_key == "team_member_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "team_member_list_transactions"
                ]
            elif dict_key == "org_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "org_list_transactions"
                ]
            else:
                continue

            if transactions_dict is None:
                transactions_dict = {}

                # type ignore: dict_key is guaranteed to be one of ("user_list_transactions", "end_user_list_transactions", "key_list_transactions", "team_list_transactions", "team_member_list_transactions", "org_list_transactions")
                db_spend_update_transactions[dict_key] = transactions_dict  # type: ignore

            if entity_id not in transactions_dict:
                transactions_dict[entity_id] = 0

            transactions_dict[entity_id] += response_cost

        return db_spend_update_transactions

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
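        """
        Emit a gauge event with the given queue_size via the service logger.

        Fire-and-forget: the logging hook is scheduled with asyncio.create_task so
        enqueueing is not blocked on metrics emission.
        """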
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
