import asyncio
import contextvars
from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast, overload

import httpx

import litellm
from litellm import Logging, client, exception_type, get_litellm_params
from litellm.constants import DEFAULT_IMAGE_ENDPOINT_MODEL
from litellm.constants import request_timeout as DEFAULT_REQUEST_TIMEOUT
from litellm.exceptions import LiteLLMUnknownProvider
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.mock_functions import mock_image_generation
from litellm.llms.base_llm import BaseImageEditConfig, BaseImageGenerationConfig
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.custom_llm import CustomLLM

#################### Initialize provider clients ####################
# Module-level BaseLLMHTTPHandler instance. `image_generation` in this module
# calls its `image_generation_handler` for providers that are dispatched
# through the generic HTTP handler (the Recraft branch).
llm_http_handler: BaseLLMHTTPHandler = BaseLLMHTTPHandler()
from litellm.main import (
    azure_chat_completions,
    base_llm_aiohttp_handler,
    base_llm_http_handler,
    bedrock_image_generation,
    openai_chat_completions,
    openai_image_variations,
    vertex_image_generation,
)

###########################################
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.llms.openai import ImageGenerationRequestQuality
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import (
    LITELLM_IMAGE_VARIATION_PROVIDERS,
    FileTypes,
    LlmProviders,
    all_litellm_params,
)
from litellm.utils import (
    ImageResponse,
    ProviderConfigManager,
    get_llm_provider,
    get_optional_params_image_gen,
)

from .utils import ImageEditRequestUtils


##### Image Generation #######################
@client
async def aimage_generation(*args, **kwargs) -> ImageResponse:
    """
    Asynchronously call `image_generation` with the given arguments.

    The synchronous `image_generation` is executed in the default executor with
    `aimg_generation=True`. Its result may be a dict, an `ImageResponse`
    (e.g. the caching scenario) or a coroutine, which is awaited here.

    Parameters:
    - `args` (tuple): Positional arguments forwarded to `image_generation`
      (`prompt` first, then `model`).
    - `kwargs` (dict): Keyword arguments forwarded to `image_generation`.

    Returns:
    - `response` (ImageResponse): The response produced by `image_generation`.

    Raises:
    - Whatever `exception_type` maps the underlying failure to. A `ValueError`
      ("Unable to get Image Response...") is the underlying failure when the
      synchronous call returns something that is neither a dict, an
      `ImageResponse` nor a coroutine.
    """
    loop = asyncio.get_running_loop()

    # `image_generation(prompt, model, ...)`: the first positional argument is
    # the prompt, so a positional model can only be the SECOND one.
    if "model" in kwargs:
        model = kwargs["model"]
    elif len(args) > 1:
        model = args[1]
    else:
        # `image_generation` accepts model=None and picks its own default.
        model = None

    ### PASS ARGS TO Image Generation ###
    kwargs["aimg_generation"] = True
    # Only used to map exceptions; honour an explicitly supplied provider.
    custom_llm_provider = kwargs.get("custom_llm_provider", None)
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(image_generation, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        if custom_llm_provider is None and model is not None:
            _, custom_llm_provider, _, _ = get_llm_provider(
                model=model, api_base=kwargs.get("api_base", None)
            )

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)

        response: Optional[ImageResponse] = None
        if isinstance(init_response, dict):
            response = ImageResponse(**init_response)
        elif isinstance(init_response, ImageResponse):  ## CACHING SCENARIO
            response = init_response
        elif asyncio.iscoroutine(init_response):
            response = await init_response  # type: ignore

        if response is None:
            raise ValueError(
                "Unable to get Image Response. Please pass a valid llm_provider."
            )

        return response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=args,
            extra_kwargs=kwargs,
        )


# Overload for when aimg_generation=True (returns Coroutine)
# Typing-only stub: `aimg_generation` is keyword-only and must be the literal
# `True`; in that case the declared return type is a coroutine resolving to an
# ImageResponse. `aimage_generation` in this module sets this flag before
# delegating to `image_generation`.
@overload
def image_generation(
    prompt: str,
    model: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    input_fidelity: Optional[str] = None,
    timeout=600,  # default to 10 minutes
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    custom_llm_provider=None,
    *,
    aimg_generation: Literal[True],
    **kwargs,
) -> Coroutine[Any, Any, ImageResponse]:
    ...


# Overload for when aimg_generation=False or not specified (returns ImageResponse)
# Typing-only stub: with the keyword-only `aimg_generation` flag left at its
# default (`False`), the declared return type is a plain ImageResponse.
@overload
def image_generation(
    prompt: str,
    model: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    input_fidelity: Optional[str] = None,
    timeout=600,  # default to 10 minutes
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    custom_llm_provider=None,
    *,
    aimg_generation: Literal[False] = False,
    **kwargs,
) -> ImageResponse:
    ...


@client
def image_generation(  # noqa: PLR0915
    prompt: str,
    model: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    input_fidelity: Optional[str] = None,
    timeout=600,  # default to 10 minutes
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    custom_llm_provider=None,
    **kwargs,
) -> Union[
    ImageResponse,
    Coroutine[Any, Any, ImageResponse],
]:
    """
    Maps the https://api.openai.com/v1/images/generations endpoint.

    Resolves the provider for `model` and routes the request to the matching
    handler: Azure, OpenAI / OpenAI-compatible providers, Bedrock, Vertex AI,
    Recraft (via the generic HTTP handler) or a custom handler registered in
    `litellm.custom_provider_map`.

    Parameters:
    - `prompt`: Text prompt forwarded to the provider handler.
    - `model` / `custom_llm_provider`: When both are None the call defaults to
      "dall-e-2" on "openai"; otherwise they are resolved via `get_llm_provider`.
    - `n`, `quality`, `response_format`, `size`, `style`, `user`,
      `input_fidelity`: Mapped into provider params by
      `get_optional_params_image_gen`.
    - `timeout`, `api_key`, `api_base`, `api_version`: Forwarded to the handler
      (the Azure and Vertex AI branches fall back to litellm settings / secrets).
    - `kwargs`: LiteLLM-specific values (e.g. `aimg_generation`,
      `mock_response`, `client`, `headers`, `extra_headers`, `base_model`,
      `litellm_logging_obj`); keys that are neither OpenAI nor LiteLLM params
      are passed through as provider-specific params.

    Returns:
    - The handler's result: an `ImageResponse`, or whatever the handler returns
      when `aimg_generation` is set (declared as a coroutine).

    Raises:
    - Any failure is mapped through `exception_type`. A provider that no branch
      handles raises `LiteLLMUnknownProvider` (then mapped the same way).
    """
    try:
        aimg_generation = kwargs.get("aimg_generation", False)
        litellm_call_id = kwargs.get("litellm_call_id", None)
        logger_fn = kwargs.get("logger_fn", None)
        mock_response: Optional[str] = kwargs.get("mock_response", None)  # type: ignore
        proxy_server_request = kwargs.get("proxy_server_request", None)
        azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
        model_info = kwargs.get("model_info", None)
        metadata = kwargs.get("metadata", {})
        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
        client = kwargs.get("client", None)
        extra_headers = kwargs.get("extra_headers", None)
        # Work on a copy: the headers are extended below (extra headers, Azure
        # api-key) and the caller's dict must not be mutated.
        headers: dict = dict(kwargs.get("headers", None) or {})
        base_model = kwargs.get("base_model", None)
        if extra_headers is not None:
            headers.update(extra_headers)
        model_response: ImageResponse = litellm.utils.ImageResponse()
        dynamic_api_key: Optional[str] = None
        if model is not None or custom_llm_provider is not None:
            model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
                model=model,  # type: ignore
                custom_llm_provider=custom_llm_provider,
                api_base=api_base,
            )
        else:
            model = "dall-e-2"
            custom_llm_provider = "openai"  # default to dall-e-2 on openai
        model_response._hidden_params["model"] = model
        openai_params = [
            "user",
            "request_timeout",
            "api_base",
            "api_version",
            "api_key",
            "deployment_id",
            "organization",
            "base_url",
            "default_headers",
            "timeout",
            "max_retries",
            "n",
            "quality",
            "size",
            "style",
            "input_fidelity",
        ]
        litellm_params = all_litellm_params
        default_params = openai_params + litellm_params
        non_default_params = {
            k: v for k, v in kwargs.items() if k not in default_params
        }  # model-specific params - pass them straight to the model/provider

        # Provider-specific config is only looked up for known LlmProviders values.
        image_generation_config: Optional[BaseImageGenerationConfig] = None
        if (
            custom_llm_provider is not None
            and custom_llm_provider in LlmProviders._member_map_.values()
        ):
            image_generation_config = (
                ProviderConfigManager.get_provider_image_generation_config(
                    model=base_model or model,
                    provider=LlmProviders(custom_llm_provider),
                )
            )

        optional_params = get_optional_params_image_gen(
            model=base_model or model,
            n=n,
            quality=quality,
            response_format=response_format,
            size=size,
            style=style,
            user=user,
            input_fidelity=input_fidelity,
            custom_llm_provider=custom_llm_provider,
            provider_config=image_generation_config,
            **non_default_params,
        )

        litellm_params_dict = get_litellm_params(**kwargs)

        logging: Logging = litellm_logging_obj
        logging.update_environment_variables(
            model=model,
            user=user,
            optional_params=optional_params,
            litellm_params={
                "timeout": timeout,
                "azure": False,
                "litellm_call_id": litellm_call_id,
                "logger_fn": logger_fn,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
            },
            custom_llm_provider=custom_llm_provider,
        )
        if "custom_llm_provider" not in logging.model_call_details:
            logging.model_call_details["custom_llm_provider"] = custom_llm_provider
        if mock_response is not None:
            # Short-circuit: no provider call is made for mocked responses.
            return mock_image_generation(model=model, mock_response=mock_response)

        if custom_llm_provider == "azure":
            # azure configs
            api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")

            api_version = (
                api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )

            api_key = (
                api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )

            azure_ad_token = optional_params.pop(
                "azure_ad_token", None
            ) or get_secret_str("AZURE_AD_TOKEN")

            # Caller-supplied headers win over these defaults.
            default_headers = {
                "Content-Type": "application/json;",
                "api-key": api_key,
            }
            for k, v in default_headers.items():
                if k not in headers:
                    headers[k] = v

            model_response = azure_chat_completions.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                azure_ad_token=azure_ad_token,
                azure_ad_token_provider=azure_ad_token_provider,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                api_version=api_version,
                aimg_generation=aimg_generation,
                client=client,
                headers=headers,
                litellm_params=litellm_params_dict,
            )
        elif (
            custom_llm_provider == "openai"
            or custom_llm_provider in litellm.openai_compatible_providers
        ):
            model_response = openai_chat_completions.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                api_key=api_key or dynamic_api_key,
                api_base=api_base,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                aimg_generation=aimg_generation,
                client=client,
            )
        elif custom_llm_provider == "bedrock":
            if model is None:
                raise Exception("Model needs to be set for bedrock")
            model_response = bedrock_image_generation.image_generation(  # type: ignore
                model=model,
                prompt=prompt,
                timeout=timeout,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                aimg_generation=aimg_generation,
                client=client,
                api_base=api_base,
                api_key=api_key,
            )
        elif custom_llm_provider == "vertex_ai":
            vertex_ai_project = (
                optional_params.pop("vertex_project", None)
                or optional_params.pop("vertex_ai_project", None)
                or litellm.vertex_project
                or get_secret_str("VERTEXAI_PROJECT")
            )
            vertex_ai_location = (
                optional_params.pop("vertex_location", None)
                or optional_params.pop("vertex_ai_location", None)
                or litellm.vertex_location
                or get_secret_str("VERTEXAI_LOCATION")
            )
            vertex_credentials = (
                optional_params.pop("vertex_credentials", None)
                or optional_params.pop("vertex_ai_credentials", None)
                or get_secret_str("VERTEXAI_CREDENTIALS")
            )

            api_base = (
                api_base
                or litellm.api_base
                or get_secret_str("VERTEXAI_API_BASE")
                or get_secret_str("VERTEX_API_BASE")
            )

            model_response = vertex_image_generation.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                vertex_project=vertex_ai_project,
                vertex_location=vertex_ai_location,
                vertex_credentials=vertex_credentials,
                aimg_generation=aimg_generation,
                api_base=api_base,
                client=client,
            )
        #########################################################
        # Providers using llm_http_handler
        #########################################################
        elif custom_llm_provider in (
            litellm.LlmProviders.RECRAFT,
        ):
            if image_generation_config is None:
                raise ValueError(f"image generation config is not supported for {custom_llm_provider}")

            return llm_http_handler.image_generation_handler(
                model=model,
                prompt=prompt,
                image_generation_provider_config=image_generation_config,
                image_generation_optional_request_params=optional_params,
                custom_llm_provider=custom_llm_provider,
                litellm_params=litellm_params_dict,
                logging_obj=litellm_logging_obj,
                timeout=timeout,
                client=client,
            )
        elif (
            custom_llm_provider in litellm._custom_providers
        ):  # Assume custom LLM provider
            # Get the Custom Handler
            custom_handler: Optional[CustomLLM] = None
            for item in litellm.custom_provider_map:
                if item["provider"] == custom_llm_provider:
                    custom_handler = item["custom_handler"]

            if custom_handler is None:
                raise LiteLLMUnknownProvider(
                    model=model, custom_llm_provider=custom_llm_provider
                )

            ## ROUTE LLM CALL ##
            if aimg_generation is True:
                # Only hand over the client when it matches the async handler type.
                async_custom_client: Optional[AsyncHTTPHandler] = None
                if client is not None and isinstance(client, AsyncHTTPHandler):
                    async_custom_client = client

                ## CALL FUNCTION
                model_response = custom_handler.aimage_generation(  # type: ignore
                    model=model,
                    prompt=prompt,
                    api_key=api_key,
                    api_base=api_base,
                    model_response=model_response,
                    optional_params=optional_params,
                    logging_obj=litellm_logging_obj,
                    timeout=timeout,
                    client=async_custom_client,
                )
            else:
                # Only hand over the client when it matches the sync handler type.
                custom_client: Optional[HTTPHandler] = None
                if client is not None and isinstance(client, HTTPHandler):
                    custom_client = client

                ## CALL FUNCTION
                model_response = custom_handler.image_generation(
                    model=model,
                    prompt=prompt,
                    api_key=api_key,
                    api_base=api_base,
                    model_response=model_response,
                    optional_params=optional_params,
                    logging_obj=litellm_logging_obj,
                    timeout=timeout,
                    client=custom_client,
                )
        else:
            # No branch handles this provider: fail loudly instead of returning
            # the empty placeholder ImageResponse created above.
            raise LiteLLMUnknownProvider(
                model=model, custom_llm_provider=custom_llm_provider
            )

        return model_response
    except Exception as e:
        ## Map to OpenAI Exception
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=locals(),
            extra_kwargs=kwargs,
        )


@client
async def aimage_variation(*args, **kwargs) -> ImageResponse:
    """
    Asynchronously call `image_variation` with the given arguments.

    The synchronous `image_variation` is executed once in the default executor
    with `async_call=True` added to the keyword arguments. A dict result is
    converted to an `ImageResponse`, a coroutine result is awaited, and any
    other result is returned as-is.

    Parameters:
    - `args` (tuple): Positional arguments to be passed to the `image_variation` function.
    - `kwargs` (dict): Keyword arguments to be passed to the `image_variation` function.

    Returns:
    - `response` (Any): The response returned by the `image_variation` function.

    Raises:
    - Whatever `exception_type` maps the underlying failure to.
    """
    loop = asyncio.get_running_loop()
    model = kwargs.get("model", None)
    custom_llm_provider = kwargs.get("custom_llm_provider", None)
    ### PASS ARGS TO Image Variation ###
    kwargs["async_call"] = True
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(image_variation, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        # Resolved only so exceptions can be mapped to the right provider.
        if custom_llm_provider is None and model is not None:
            _, custom_llm_provider, _, _ = get_llm_provider(
                model=model, api_base=kwargs.get("api_base", None)
            )

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
        if isinstance(init_response, dict):
            response = ImageResponse(**init_response)
        elif asyncio.iscoroutine(init_response):
            response = await init_response  # type: ignore
        else:
            # ImageResponse (e.g. CACHING SCENARIO) or any other value that is
            # already computed. The call is NOT re-executed: running the
            # synchronous function again would issue a second provider request.
            response = init_response
        return response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=args,
            extra_kwargs=kwargs,
        )


@client
def image_variation(
    image: FileTypes,
    model: str = "dall-e-2",  # set to dall-e-2 by default - like OpenAI.
    n: int = 1,
    response_format: Literal["url", "b64_json"] = "url",
    size: Optional[str] = None,
    user: Optional[str] = None,
    **kwargs,
) -> ImageResponse:
    """
    Create a variation of `image` using the provider resolved for `model`.

    The provider is resolved with `get_llm_provider`, must be a member of
    `LITELLM_IMAGE_VARIATION_PROVIDERS`, and must expose a model-info config
    from which the API key and API base are derived. The request is then sent
    through the OpenAI or Topaz handler.

    Raises:
    - `ValueError`: unsupported provider, missing provider config, missing API
      key / API base, or a handler that produced no response.

    NOTE(review): `n`, `response_format`, `size` and `user` are accepted but the
    handlers below are called with `optional_params={}` - confirm whether they
    are meant to be forwarded.
    """
    # Pull the pieces of kwargs this function needs.
    http_client = kwargs.get("client", None)
    logging_obj = cast(LiteLLMLoggingObj, kwargs.get("litellm_logging_obj"))
    litellm_params = get_litellm_params(**kwargs)

    # Resolve model + provider; key/base are re-derived from the provider config below.
    model, custom_llm_provider, _, _ = get_llm_provider(
        model=model,
        custom_llm_provider=litellm_params.get("custom_llm_provider", None),
        api_base=litellm_params.get("api_base", None),
        api_key=litellm_params.get("api_key", None),
    )

    # Both enum conversions raise ValueError for an unknown / unsupported provider.
    try:
        llm_provider = LlmProviders(custom_llm_provider)
        variation_provider = LITELLM_IMAGE_VARIATION_PROVIDERS(llm_provider)
    except ValueError:
        raise ValueError(
            f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )

    blank_response = ImageResponse()
    result: Optional[ImageResponse] = None

    model_info_config = ProviderConfigManager.get_provider_model_info(
        model=model or "",  # openai defaults to dall-e-2
        provider=llm_provider,
    )
    if model_info_config is None:
        raise ValueError(
            f"image variation provider has no known model info config - required for getting api keys, etc.: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )

    api_key = model_info_config.get_api_key(litellm_params.get("api_key", None))
    api_base = model_info_config.get_api_base(litellm_params.get("api_base", None))

    routes_to_openai = variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.OPENAI
    routes_to_topaz = (
        not routes_to_openai
        and variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.TOPAZ
    )

    if routes_to_openai:
        if api_key is None:
            raise ValueError("API key is required for OpenAI image variations")
        if api_base is None:
            raise ValueError("API base is required for OpenAI image variations")

        openai_timeout = litellm_params.get("timeout", None)
        result = openai_image_variations.image_variations(
            model_response=blank_response,
            api_key=api_key,
            api_base=api_base,
            model=model,
            image=image,
            timeout=openai_timeout,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
            optional_params={},
            litellm_params=litellm_params,
        )
    elif routes_to_topaz:
        if api_key is None:
            raise ValueError("API key is required for Topaz image variations")
        if api_base is None:
            raise ValueError("API base is required for Topaz image variations")

        # Topaz falls back to the library-wide default request timeout.
        topaz_timeout = litellm_params.get("timeout", None) or DEFAULT_REQUEST_TIMEOUT
        result = base_llm_aiohttp_handler.image_variations(
            model_response=blank_response,
            api_key=api_key,
            api_base=api_base,
            model=model,
            image=image,
            timeout=topaz_timeout,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
            optional_params={},
            litellm_params=litellm_params,
            client=http_client,
        )

    # Covers both "no branch matched" and "handler returned nothing".
    if result is None:
        raise ValueError(
            f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )
    return result


@client
def image_edit(
    image: FileTypes,
    prompt: str,
    model: Optional[str] = None,
    mask: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    user: Optional[str] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse]]:
    """
    Edit an image through the resolved provider, similar to OpenAI's images/edits endpoint.

    Resolves the provider for `model` (falling back to
    `DEFAULT_IMAGE_ENDPOINT_MODEL`), fetches that provider's image-edit config,
    derives the request parameters from the call arguments and delegates to
    `base_llm_http_handler.image_edit_handler`. When `async_call=True` is
    present in `kwargs`, the handler is invoked with `_is_async=True`.

    Raises:
    - Any failure (including an unsupported provider) is mapped through
      `litellm.exception_type`.
    """
    # Snapshot of the call arguments; must be taken before anything else is bound.
    local_vars = locals()
    try:
        logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
        call_id: Optional[str] = kwargs.get("litellm_call_id", None)
        # The async marker is removed from kwargs so it is not forwarded below.
        async_marker = kwargs.pop("async_call", False)
        run_async = async_marker is True

        # Resolve the provider for the requested (or default) model.
        generic_params = GenericLiteLLMParams(**kwargs)
        model, custom_llm_provider, _, _ = get_llm_provider(
            model=model or DEFAULT_IMAGE_ENDPOINT_MODEL,
            custom_llm_provider=custom_llm_provider,
        )

        # Look up the provider-specific image edit config.
        edit_config: Optional[
            BaseImageEditConfig
        ] = ProviderConfigManager.get_provider_image_edit_config(
            model=model,
            provider=litellm.LlmProviders(custom_llm_provider),
        )
        if edit_config is None:
            raise ValueError(f"image edit is not supported for {custom_llm_provider}")

        # Merge the remaining kwargs into the argument snapshot, then keep only
        # the parameters that are valid image edit options.
        local_vars.update(kwargs)
        requested_params: ImageEditOptionalRequestParams = (
            ImageEditRequestUtils.get_requested_image_edit_optional_param(local_vars)
        )

        # Translate them into the parameters used for the image edit request.
        request_params: Dict = ImageEditRequestUtils.get_optional_params_image_edit(
            model=model,
            image_edit_provider_config=edit_config,
            image_edit_optional_params=requested_params,
        )

        # Pre Call logging
        logged_optional_params = dict(request_params)
        logged_litellm_params = {
            "litellm_call_id": call_id,
            **request_params,
        }
        logging_obj.update_environment_variables(
            model=model,
            user=user,
            optional_params=logged_optional_params,
            litellm_params=logged_litellm_params,
            custom_llm_provider=custom_llm_provider,
        )

        # The handler decides sync vs async based on the _is_async flag.
        effective_timeout = timeout or DEFAULT_REQUEST_TIMEOUT
        http_client = kwargs.get("client")
        return base_llm_http_handler.image_edit_handler(
            model=model,
            image=image,
            prompt=prompt,
            image_edit_provider_config=edit_config,
            image_edit_optional_request_params=request_params,
            custom_llm_provider=custom_llm_provider,
            litellm_params=generic_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=effective_timeout,
            _is_async=run_async,
            client=http_client,
        )

    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )


@client
async def aimage_edit(
    image: FileTypes,
    model: str,
    prompt: str,
    mask: Optional[str] = None,
    n: Optional[int] = None,
    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    user: Optional[str] = None,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> ImageResponse:
    """
    Asynchronously calls the `image_edit` function with the given arguments.

    `image_edit` is executed in the default executor with `async_call=True`;
    if it returns a coroutine, that coroutine is awaited here.

    Parameters:
    - `image`, `model`, `prompt`, `mask`, `n`, `quality`, `response_format`,
      `size`, `user`, `timeout`: Forwarded unchanged to `image_edit`.
    - `extra_headers`, `extra_query`, `extra_body`: Forwarded unchanged to
      `image_edit`.
    - `custom_llm_provider`: Provider to use; resolved from `model` via
      `litellm.get_llm_provider` when not given.
    - `kwargs` (dict): Additional keyword arguments passed to `image_edit`.

    Returns:
    - `response` (Any): The response returned by the `image_edit` function.

    Raises:
    - Whatever `litellm.exception_type` maps the underlying failure to.
    """
    local_vars = locals()
    try:
        loop = asyncio.get_running_loop()
        kwargs["async_call"] = True

        # get custom llm provider so we can use this for mapping exceptions
        if custom_llm_provider is None:
            # NOTE(review): `local_vars` holds this function's parameter names,
            # so "base_url" is never a key and api_base is None here - confirm
            # whether kwargs["base_url"] / kwargs["api_base"] was intended.
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(
                model=model, api_base=local_vars.get("base_url", None)
            )

        func = partial(
            image_edit,
            image=image,
            prompt=prompt,
            mask=mask,
            model=model,
            n=n,
            quality=quality,
            response_format=response_format,
            size=size,
            user=user,
            # Forward the extra_* arguments too; otherwise the async path
            # silently drops them while the sync path honours them.
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            custom_llm_provider=custom_llm_provider,
            **kwargs,
        )

        # Run the sync entry point in the executor with the current context.
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response

        return response
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
