from rest_framework.views import exception_handler
from rest_framework.response import Response
from rest_framework import status
from rest_framework_simplejwt.exceptions import InvalidToken
from rest_framework.exceptions import ValidationError
import logging
import traceback
from django.conf import settings


# Module-level logger named after this module (standard ``getLogger(__name__)`` pattern).
logger = logging.getLogger(__name__)

def _summarize_validation_errors(data):
    """Return a one-line summary of a DRF ``ValidationError`` payload.

    ``data`` is ``response.data`` as built by DRF's default handler: either a
    dict mapping field names to error lists, or a bare list of errors (when
    the ``ValidationError`` was raised without field names).

    The first field whose value is a non-empty list is reported as
    ``"<field>: <first error>"``. A bare list reports its first error. If
    neither applies, ``"Validation failed."`` is returned.
    """
    if isinstance(data, dict):
        for field, errors in data.items():
            # Skip empty lists (errors[0] would raise inside the handler) and
            # nested dicts from nested serializers.
            if isinstance(errors, list) and errors:
                return f"{field}: {errors[0]}"
    elif isinstance(data, list) and data:
        return str(data[0])
    return "Validation failed."


def _unhandled_exception_response(exc, context):
    """Log an exception DRF did not handle and build the HTTP 500 response.

    The full traceback always goes to the module logger. The client sees
    ``str(exc)`` and the formatted traceback only when ``settings.DEBUG`` is
    true. Otherwise it gets a generic message, so internals are not leaked.
    """
    view = context.get("view")
    request = context.get("request")
    logger.error(
        "Unhandled exception in %s (%s %s)",
        type(view).__name__ if view is not None else "unknown view",
        getattr(request, "method", "-"),
        getattr(request, "path", "-"),
        exc_info=(type(exc), exc, exc.__traceback__),
    )

    error_detail = {"success": False, "message": "Internal Server Error"}
    if settings.DEBUG:
        error_detail["message"] = str(exc)
        error_detail["traceback"] = "".join(
            traceback.format_exception(type(exc), exc, exc.__traceback__)
        )
    return Response(error_detail, status=status.HTTP_500_INTERNAL_SERVER_ERROR)


def custom_exception_handler(exc, context):
    """
    Custom exception handler for DRF that returns uniform error responses.

    Every response body carries ``"success": False`` and a ``"message"``:

    * SimpleJWT ``InvalidToken``: a fixed "session expired" message, with
      the original error detail under ``"details"``.
    * DRF ``ValidationError``: the first error flattened into ``"message"``.
    * Any other exception DRF's default handler recognises: DRF's payload,
      with ``"detail"`` renamed to ``"message"``.
    * Anything else: logged with its traceback and answered with HTTP 500.
      The exception text and traceback are included only in DEBUG mode.

    Args:
        exc: The exception raised while handling the request.
        context: DRF handler context. ``"view"`` and ``"request"`` are read
            (when present) only to enrich the log message.

    Returns:
        A ``Response`` describing the error.
    """
    # DRF's default handler returns a Response for exceptions it knows about
    # and None for everything else.
    response = exception_handler(exc, context)

    if response is not None:
        # Customize SimpleJWT error responses.
        if isinstance(exc, InvalidToken):
            response.data = {
                "code": "token_not_valid",
                "success": False,
                "message": "Your session has expired. Please login again.",
                "details": exc.detail,
            }
            return response

        # Flatten DRF ValidationError into a single message.
        if isinstance(exc, ValidationError):
            return Response(
                {
                    "success": False,
                    "message": _summarize_validation_errors(response.data),
                },
                status=response.status_code,
            )

        # An APIException raised with a list detail yields a list payload;
        # wrap it so the consistent fields can be added.
        if not isinstance(response.data, dict):
            response.data = {"detail": response.data}
        response.data["success"] = False
        if "detail" in response.data:
            response.data["message"] = response.data.pop("detail")
        return response

    return _unhandled_exception_response(exc, context)
