import re
from datetime import datetime, timedelta
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse

import numpy as np
import pandas as pd
from django.conf import settings
from rest_framework import status
from rest_framework.exceptions import ValidationError
from rest_framework.pagination import PageNumberPagination
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ViewSet
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.tokens import RefreshToken
from bot.utils.scrape_details import extract_product_info

from .models import ProductScrapeTask, User
from .serializers import ProductScrapeTaskSerializer, UserSerializer
from .tasks import (run_1688_scrapper, run_alibaba_scraper,
                    run_made_in_china_scraper, scrape_images)


class StandardPagination(PageNumberPagination):
    """
    Global pagination class
    """
    page_size = 25
    page_size_query_param = 'limit'
    max_page_size = 100
    page_query_param = 'current_page'

    def get_paginated_response(self, data):
        if not hasattr(self, 'request'):
            return Response({
                'success':False,
                'message':'Internal Server Error',
                'data':[]
            }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        
        return Response({
            'success':True,
            'message':'Data fetched successfully',
            'data':data,
            'pagination':{
                'current_page':self.page.number,
                'limit':self.get_page_size(self.request),
                'count':self.page.paginator.count,
                'total_pages':self.page.paginator.num_pages
            }
        }, status=status.HTTP_200_OK)
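
# Example of the envelope produced by StandardPagination.get_paginated_response
# (illustrative values only; the shape follows the code above):
#
#     {
#         "success": true,
#         "message": "Data fetched successfully",
#         "data": [...serialized objects...],
#         "pagination": {"current_page": 1, "limit": 25, "count": 120, "total_pages": 5}
#     }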

class ScraperView(APIView):
    permission_classes = [IsAuthenticated]
    authentication_classes = [JWTAuthentication]

    def post(self, request):
        try:
            source = request.data.get('source')
            if not source:
                return Response({"success": False, "message": "Source is required."}, status=status.HTTP_400_BAD_REQUEST)

            # Normalise the source once (e.g. "Made In China" -> "madeinchina") before dispatching.
            normalized_source = source.replace(" ", "").lower()

            if normalized_source == "alibaba":
                return self.scrape_alibaba(request)

            if normalized_source == "madeinchina":
                return self.scrape_made_in_china(request)

            # scrape_1688 is defined below; routing it here is an assumption about the intended behaviour.
            if normalized_source == "1688":
                return self.scrape_1688(request)

            return Response({"success": False, "message": "Unsupported source."}, status=status.HTTP_400_BAD_REQUEST)

        except Exception as e:
            return Response({"success": False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        
    def scrape_alibaba(self, request):
        try:
            # Build query params
            params = {}
            filters = request.data  # get the full request body for additional params
            search_name = filters.get("query")
            if not search_name:
                return Response({"success":False, "message": "Missing query in request body."}, status=status.HTTP_400_BAD_REQUEST)

            source = filters.get('source')
            url = "https://www.alibaba.com/trade/search?"

            # Keyword search
            params["SearchText"] = search_name
            
            if filters.get("num_products"):
                num_products = int(filters["num_products"])
            else:
                num_products = 48

            if filters.get('skip'):
                skip_products = int(filters['skip'])
            else:
                skip_products = 0

            if num_products > 1000:
                return Response({
                    'success':False,
                    'message':'Maximum number of products is 1000'
                }, status=status.HTTP_400_BAD_REQUEST)
            
            # Alibaba search results show 48 products per page.
            products_per_page = 48
            # Page to start scraping from
            start_page = (skip_products // products_per_page) + 1
            # How many products to skip inside that first page
            skip_in_page = skip_products % products_per_page
            # Total pages needed (an extra page may be fetched when skip_in_page > 0)
            num_pages = ((num_products + skip_in_page) + products_per_page - 1) // products_per_page
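            # Worked example (hypothetical inputs): skip=50, num_products=100 ->
            # start_page = 50 // 48 + 1 = 2, skip_in_page = 50 % 48 = 2,
            # num_pages = (100 + 2 + 47) // 48 = 3 pages fetched in total.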

            # Trade assurance
            if filters.get("trade_assurance"):
                params["ta"] = "y"

            # Verified Supplier
            if filters.get("verified_supplier"):
                params["assessmentCompany"] = "true"

            # Verified Pro
            if filters.get("verified_pro"):
                params["verifiedPro"] = "1"

            # Merge by Supplier
            if filters.get("merge_by_supplier"):
                params["mergeResult"] = "true"

            # Review score
            review_score = filters.get("review_score")
            if review_score in ["4", "4.5", "5"]:
                params["reviewScore"] = str(review_score)

            # Paid Sample
            if filters.get("paid_sample"):
                params["freeSample"] = "1"

            # Price Range
            min_price = filters.get("min_price")
            if min_price:
                params["pricef"] = str(min_price)

            max_price = filters.get("max_price")
            if max_price:
                params["pricet"] = str(max_price)

            # MOQ
            moq = filters.get("moq")
            if moq:
                params["moqt"] = f"MOQT{moq}"

            # Supplier Country/Region
            country = filters.get("supplier_country_region")
            if country:
                params["country"] = country

            # Export Country
            export_countries = filters.get("exported_to")
            if export_countries:
                # e.g. ["US", "UK"] -> "US,UK"; urlencode() below percent-encodes the value,
                # so the raw comma-joined string is stored here to avoid double-encoding.
                if isinstance(export_countries, list):
                    params["exportCountry"] = ",".join(export_countries)
                else:
                    params["exportCountry"] = export_countries

            # Management Certificate
            mgmt_cert = filters.get("management_certificate")
            if mgmt_cert:
                params["companyAuthTag"] = ",".join(mgmt_cert) if isinstance(mgmt_cert, list) else mgmt_cert

            # Product Certificate
            prod_cert = filters.get("product_certificate")
            if prod_cert:
                params["productAuthTag"] = ",".join(prod_cert) if isinstance(prod_cert, list) else prod_cert

            # Add new params to original URL
            parsed = urlparse(url)
            original_query = parse_qs(parsed.query)
            original_query.update(params)

            # Flatten and encode query
            flat_query = {k: v if isinstance(v, str) else v[0] for k, v in original_query.items()}
            new_query = urlencode(flat_query, doseq=True)
            final_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', new_query, ''))
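            # Example of a final_url this can produce (hypothetical filter values):
            #   https://www.alibaba.com/trade/search?SearchText=usb+cable&ta=y&pricef=1&pricet=10&country=CN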
            
            scrape_task = ProductScrapeTask.objects.create(
                user=request.user,
                search_params=filters, 
                source=source,
                status="IN_PROGRESS"
            )
            
            # Launch background task
            run_alibaba_scraper.delay(final_url, search_name, scrape_task.id, request.user.id, num_pages, start_page, num_products, skip_in_page)

            return Response({
                "success":True,
                "message": "Scraping started successfully",
                "data":ProductScrapeTaskSerializer(scrape_task).data
            }, status=status.HTTP_200_OK)
        
        except Exception as e:
            import traceback
            print("Exception occurred:")
            traceback.print_exc()  # Print the full traceback to the console
            return Response({"success":False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        
    def scrape_made_in_china(self, request):
        try:
            # Build query params
            filters = request.data  # get the full request body for additional params
            search_name = filters.get("query")
            if not search_name:
                return Response({"success":False, "message": "Missing query in request body."}, status=status.HTTP_400_BAD_REQUEST)

            source = filters.get('source')
            
            if filters.get("num_products"):
                num_products = int(filters["num_products"])
            else:
                num_products = 31

            if filters.get('skip'):
                skip_products = int(filters['skip'])
            else:
                skip_products = 0

            if num_products > 1000:
                return Response({
                    'success':False,
                    'message':'Maximum number of products is 1000'
                }, status=status.HTTP_400_BAD_REQUEST)
            
            # The per-page calculation is left to the Made-in-China scraper task;
            # the raw skip offset is passed through unchanged.
            skip_in_page = skip_products
            # Create scrape task
            scrape_task = ProductScrapeTask.objects.create(
                user=request.user,
                search_params=filters, 
                source=source,
                status="IN_PROGRESS"
            )
            
            # Launch background task
           
            run_made_in_china_scraper.delay(
                search_name,
                scrape_task.id,
                request.user.id,
                num_products,
                skip_in_page,
                filters
            )

            return Response({
                "success":True,
                "message": "Scraping started successfully",
                "data":ProductScrapeTaskSerializer(scrape_task).data
            }, status=status.HTTP_200_OK)
            
        except Exception as e:
            import traceback
            print("Exception occurred:")
            traceback.print_exc()
            return Response({"success":False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        
    def scrape_1688(self, request):
        try:
            # Build query params
            filters = request.data  # get the full request body for additional params
            search_name = filters.get("query")
            if not search_name:
                return Response({"success":False, "message": "Missing query in request body."}, status=status.HTTP_400_BAD_REQUEST)

            source = filters.get('source')
            
            if filters.get("num_products"):
                num_products = int(filters["num_products"])
            else:
                num_products = 31

            if filters.get('skip'):
                skip_products = int(filters['skip'])
            else:
                skip_products = 0

            if num_products > 1000:
                return Response({
                    'success':False,
                    'message':'Maximum number of products is 1000'
                }, status=status.HTTP_400_BAD_REQUEST)
            
            # The per-page calculation is left to the 1688 scraper task;
            # the raw skip offset is passed through unchanged.
            skip_in_page = skip_products
            # Create scrape task
            scrape_task = ProductScrapeTask.objects.create(
                user=request.user,
                search_params=filters, 
                source=source,
                status="IN_PROGRESS"
            )
            
            # Launch background task
           
            run_1688_scrapper.delay(
                search_name,
                scrape_task.id,
                request.user.id,
                num_products,
                skip_in_page,
                filters
            )

            return Response({
                "success":True,
                "message": "Scraping started successfully",
                "data":ProductScrapeTaskSerializer(scrape_task).data
            }, status=status.HTTP_200_OK)
            
        except Exception as e:
            import traceback
            print("Exception occurred:")
            traceback.print_exc()
            return Response({"success":False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)


class AlibabaScrapperView(APIView):
    permission_classes = [IsAuthenticated]
    authentication_classes = [JWTAuthentication]
    
    def post(self, request):
        try:
            # Build query params
            params = {}
            filters = request.data  # get the full request body for additional params
            
            search_name = filters.get("query")
            print(filters.get("source"))
            if not search_name:
                return Response({"success":False, "message": "Missing query in request body."}, status=status.HTTP_400_BAD_REQUEST)

            source = filters.get('source')
            if source == '1688':
                url = "https://s.1688.com/selloffer/offer_search.htm?keywords="
            if source == 'MADE IN CHINA':
                url = "https://s.1688.com/selloffer/offer_search.htm?keywords="
            else:
                url = "https://www.alibaba.com/trade/search?"

            # Keyword Search
            if filters.get("query"):
                params["SearchText"] = filters["query"]
            
            if filters.get("num_products"):
                num_products = int(filters["num_products"])
            else:
                num_products = 48

            if filters.get('skip'):
                skip_products = int(filters['skip'])
            else:
                skip_products = 0

            if num_products > 1000:
                return Response({
                    'success':False,
                    'message':'Maximum number of products is 1000'
                }, status=status.HTTP_400_BAD_REQUEST)

            # Alibaba search results show 48 products per page.
            products_per_page = 48
            # Page to start scraping from
            start_page = (skip_products // products_per_page) + 1
            # How many products to skip inside that first page
            skip_in_page = skip_products % products_per_page
            # Total pages needed (an extra page may be fetched when skip_in_page > 0)
            num_pages = ((num_products + skip_in_page) + products_per_page - 1) // products_per_page

            # Trade assurance
            if filters.get("trade_assurance"):
                params["ta"] = "y"

            # Verified Supplier
            if filters.get("verified_supplier"):
                params["assessmentCompany"] = "true"

            # Verified Pro
            if filters.get("verified_pro"):
                params["verifiedPro"] = "1"

            # Merge by Supplier
            if filters.get("merge_by_supplier"):
                params["mergeResult"] = "true"

            # Review score
            review_score = filters.get("review_score")
            if review_score in ["4", "4.5", "5"]:
                params["reviewScore"] = str(review_score)

            # Paid Sample
            if filters.get("paid_sample"):
                params["freeSample"] = "1"

            # Price Range
            min_price = filters.get("min_price")
            if min_price:
                params["pricef"] = str(min_price)

            max_price = filters.get("max_price")
            if max_price:
                params["pricet"] = str(max_price)

            # MOQ
            moq = filters.get("moq")
            if moq:
                params["moqt"] = f"MOQT{moq}"

            # Supplier Country/Region
            country = filters.get("supplier_country_region")
            if country:
                params["country"] = country

            # Export Country
            export_countries = filters.get("exported_to")
            if export_countries:
                # e.g. ["US", "UK"] -> "US,UK"; urlencode() below percent-encodes the value,
                # so the raw comma-joined string is stored here to avoid double-encoding.
                if isinstance(export_countries, list):
                    params["exportCountry"] = ",".join(export_countries)
                else:
                    params["exportCountry"] = export_countries

            # Management Certificate
            mgmt_cert = filters.get("management_certificate")
            if mgmt_cert:
                params["companyAuthTag"] = ",".join(mgmt_cert) if isinstance(mgmt_cert, list) else mgmt_cert

            # Product Certificate
            prod_cert = filters.get("product_certificate")
            if prod_cert:
                params["productAuthTag"] = ",".join(prod_cert) if isinstance(prod_cert, list) else prod_cert

            # Add new params to original URL
            parsed = urlparse(url)
            original_query = parse_qs(parsed.query)
            original_query.update(params)

            # Flatten and encode query
            flat_query = {k: v if isinstance(v, str) else v[0] for k, v in original_query.items()}
            new_query = urlencode(flat_query, doseq=True)
            final_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', new_query, ''))
            
            scrape_task = ProductScrapeTask.objects.create(
                user=request.user,
                search_params=filters, 
                source=source,
                status="IN_PROGRESS"
            )
            
            # Launch background task
            run_alibaba_scraper.delay(final_url, search_name, scrape_task.id, request.user.id, num_pages, start_page, num_products, skip_in_page)

            return Response({
                "success":True,
                "message": "Scraping started successfully",
                "data":ProductScrapeTaskSerializer(scrape_task).data
            }, status=status.HTTP_200_OK)
        
        except Exception as e:
            import traceback
            print("Exception occurred:")
            traceback.print_exc()  # This prints the full traceback to the console
            return Response({"success":False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

class LoginAPIView(APIView):
    permission_classes = [AllowAny]
    authentication_classes = []
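
    # Example request body (illustrative values only):
    #     {"email": "user@example.com", "password": "secret"}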

    def post(self, request):
        email = request.data.get("email")
        password = request.data.get("password")

        if not email or not password:
            return Response({"success":False, "message": "Email and password are required."}, status=status.HTTP_400_BAD_REQUEST)

        try:
            user = User.objects.get(email=email)
            if user.check_password(password):
                # Generate JWT tokens
                refresh = RefreshToken.for_user(user)
                response = Response({
                    "success": True,
                    "message": "Login successful.",
                    "data":UserSerializer(user).data
                }, status=status.HTTP_200_OK)
                # Set tokens in cookies. httponly and secure are left off here;
                # in production these would normally be True (HTTPS only, not readable from JS).
                expiry = datetime.utcnow() + timedelta(days=2)

                response.set_cookie(
                    key='access_token',
                    value=str(refresh.access_token),
                    expires=expiry,
                    httponly=False,
                    secure=False,
                    samesite='Lax'
                )

                # For refresh token
                refresh_expiry = datetime.utcnow() + timedelta(days=7)

                response.set_cookie(
                    key='refresh_token',
                    value=str(refresh),
                    expires=refresh_expiry,
                    httponly=False,
                    secure=False,
                    samesite='Lax'
                )
                return response
            else:
                return Response({"success":False, "message": "Invalid credentials."}, status=status.HTTP_401_UNAUTHORIZED)
        except User.DoesNotExist:
            return Response({"success":False, "message": "User not found."}, status=status.HTTP_404_NOT_FOUND)
        except Exception as e:
            import traceback
            print("Exception occurred during login:")
            traceback.print_exc()
            return Response({"success":False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

class VerifyUserAPIView(APIView):
    authentication_classes = [JWTAuthentication]
    permission_classes = [IsAuthenticated]

    def get(self, request):
        user = request.user
        # Only authenticated staff users are treated as verified here.
        if user.is_authenticated and user.is_staff:
            return Response({
                'success':True,
                'message':'User is authenticated',
                'data':UserSerializer(user).data
            }, status=status.HTTP_200_OK)
        return Response({
            'success':False,
            'message':'User is not authenticated'
        }, status=status.HTTP_401_UNAUTHORIZED)
        

class ProductScrapeTaskViewSet(ViewSet):
    authentication_classes = [JWTAuthentication]
    permission_classes = [IsAuthenticated]
    pagination_class = StandardPagination

    def list(self, request):
        user = request.user
        data = ProductScrapeTask.objects.filter(user=user).order_by('-started_at')
        paginator = self.pagination_class()
        page = None
        try:
            page = paginator.paginate_queryset(data, request)
        except Exception:
            return Response({
                'success':False,
                'message':'Page not found',
                'data':[]
            }, status=status.HTTP_200_OK)
        
        if page is not None:
            serialized_data = ProductScrapeTaskSerializer(page, many=True).data
            return paginator.get_paginated_response(serialized_data)
        
        serialized_data = ProductScrapeTaskSerializer(data, many=True).data
        return paginator.get_paginated_response(serialized_data)


    def retrieve(self, request, pk=None):
        try:
            if not pk:
                return Response({
                    'success':False,
                    'message':'Task id is Required'
                }, status=status.HTTP_400_BAD_REQUEST)
            
            data = ProductScrapeTask.objects.get(pk=pk)
            products = []
            pagination = {}
            if data.result_file and data.result_file.url:
                file_path = data.result_file.path
                df = pd.read_csv(file_path)
                df = df.replace({np.nan: None, np.inf: None, -np.inf: None})
                df.columns = [self.format_header(col) for col in df.columns]
                if 'image_url' in df.columns:
                    df['image_url'] = df['image_url'].apply(
                        lambda x: x.split(',') if isinstance(x, str) else x
                    )

                # --- Search filter ---
                search_query = request.query_params.get('query')
                if search_query:
                    search_query = search_query.strip().lower()
                    # Match against the title/description columns when they are present in the CSV.
                    mask = pd.Series(False, index=df.index)
                    if 'title' in df.columns:
                        mask |= df['title'].fillna('').str.lower().str.contains(search_query, regex=False)
                    if 'description' in df.columns:
                        mask |= df['description'].fillna('').str.lower().str.contains(search_query, regex=False)
                    df = df[mask]

                # --- Filter by Price ----
                min_price = request.query_params.get('min_price')
                max_price = request.query_params.get('max_price')

                if 'price' in df.columns:
                    if min_price is not None:
                        try:
                            min_price_val = float(min_price)
                            df = df[df['price'].astype(float) >= min_price_val]
                        except ValueError:
                            raise ValidationError({'min_price': 'min_price must be a valid number'})
                    if max_price is not None:
                        try:
                            max_price_val = float(max_price)
                            df = df[df['price'].astype(float) <= max_price_val]
                        except ValueError:
                            raise ValidationError({'max_price': 'max_price must be a valid number'})

                # --- Pagination ---
                page_number = request.query_params.get('current_page', 1)
                page_size = request.query_params.get('limit', 20)

                try:
                    page_number = int(page_number)
                    page_size = int(page_size)
                    assert page_number > 0 and page_size > 0
                except (ValueError, AssertionError):
                    raise ValidationError({'success':False, 'message': 'page and limit must be positive integers'})

                total_products = len(df)
                start_idx = (page_number - 1) * page_size
                end_idx = start_idx + page_size

                paged_df = df.iloc[start_idx:end_idx]
                products = paged_df.to_dict(orient='records')

                # Metadata for pagination
                pagination = {
                    'current_page': page_number,
                    'limit': page_size,
                    'count': total_products,
                    'total_pages': (total_products + page_size - 1) // page_size,
                }
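                # Worked example (hypothetical query string): current_page=2&limit=20 ->
                # start_idx = 20, end_idx = 40, so rows 20-39 of the filtered CSV are returned.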


            return Response({
                'success':True,
                'message':'Data fetched successfully',
                'data':{
                    'object': ProductScrapeTaskSerializer(data).data,
                    'products': products,
                    'pagination': pagination
                }
            }, status=status.HTTP_200_OK)

        except ProductScrapeTask.DoesNotExist:
            return Response({
                'success':False,
                'message':'Task not found'
            }, status=status.HTTP_404_NOT_FOUND)
    
    def destroy(self, request, pk=None):
        try:
            if not request.user.is_superuser:
                return Response({
                    'success':False,
                    'message':'Permission Denied'
                }, status=status.HTTP_403_FORBIDDEN)

            if not pk:
                return Response({
                    'success':False,
                    'message':'Task id is Required'
                }, status=status.HTTP_400_BAD_REQUEST)
            
            task = ProductScrapeTask.objects.get(pk=pk, user=request.user)
            task.delete()
            return Response({
                'success':True,
                'message':'Task deleted successfully'
            }, status=status.HTTP_200_OK)
        
        except ProductScrapeTask.DoesNotExist:
            return Response({
                'success':False,
                'message':'Task not found'
            }, status=status.HTTP_404_NOT_FOUND)
        except Exception as e:
            return Response({"success":False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    def format_header(self, header):
        # Convert header to lowercase, replace non-alphanumeric with underscore
        return re.sub(r'\W+', '_', header.strip().lower())
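
    # Examples: format_header("Image URLs") -> "image_urls"; format_header(" Price (USD) ") -> "price_usd_"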

class LogoutAPIView(APIView):
    permission_classes = [AllowAny]
    authentication_classes = []

    def get(self, request):
        try:
            response = Response({"success":True, "message": "Logout successful."}, status=status.HTTP_200_OK)
            response.delete_cookie('access_token')
            response.delete_cookie('refresh_token')
            return response
        except Exception as e:
            import traceback
            print("Exception occurred during logout:")
            traceback.print_exc()
            return Response({"success":False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)


        
class ProductScrapperDeleteAPIView(APIView):
    permission_classes = [IsAuthenticated]
    authentication_classes = [JWTAuthentication]

    def delete(self, request, pk=None):
        try:
            if not pk:
                return Response({
                    'success':False,
                    'message':'Task id is Required'
                }, status=status.HTTP_400_BAD_REQUEST)
            
            task = ProductScrapeTask.objects.get(pk=pk, user=request.user)
            task.delete()
            return Response({
                'success':True,
                'message':'Task deleted successfully'
            }, status=status.HTTP_200_OK)
        
        except ProductScrapeTask.DoesNotExist:
            return Response({
                'success':False,
                'message':'Task not found'
            }, status=status.HTTP_404_NOT_FOUND)
        except Exception as e:
            return Response({"success":False, "message": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)


def scrape_images_func(task_id):
    try:
        task = ProductScrapeTask.objects.get(id=task_id)
    except ProductScrapeTask.DoesNotExist:
        print("Task not found")
        return
    # The result CSV is expected to contain 'URL' and 'ID' columns.
    df = pd.read_csv(task.result_file)
    prod_urls = []
    images_tasks = []
    for row in df.itertuples(index=False):
        url = row.URL
        if url:
            prod_urls.append(url)
            result = extract_product_info(url)
            images = result.get('product', {}).get('images')
            product_id = row.ID
            images_tasks.append({
                'image_urls': images,
                'product_id': product_id
            })
    scrape_images(task_id, images_tasks)

from bot.china.made_in_china import MadeInChinaScraper
from playwright.async_api import async_playwright
from asgiref.sync import sync_to_async

EXECUTABLE_PATH = settings.EXECUTABLE_PATH

async def scrape_china_images(task_id):
    try:
        task = await sync_to_async(ProductScrapeTask.objects.get)(id=task_id)
    except ProductScrapeTask.DoesNotExist:
        print("Task not found")
        return
    df = pd.read_csv(task.result_file)
    prod_urls = []
    images_tasks = []

    china = MadeInChinaScraper()

    async with async_playwright() as p:
        browser = await p.chromium.launch(
            headless=True,
            args=[
                "--disable-blink-features=AutomationControlled",
                "--start-maximized"
            ],
            slow_mo=100,
            executable_path=EXECUTABLE_PATH
        )
        
        context = await browser.new_context(
            user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                    "AppleWebKit/537.36 (KHTML, like Gecko) "
                    "Chrome/115.0.0.0 Safari/537.36",
            viewport={"width": 1366, "height": 768},
            locale="en-US"
        )
        
        page = await context.new_page()

        for row in df.itertuples(index=False):
            url = row.URL
            if url:
                product_id = row.ID
                product_details = await china.scrape_product_details(page, url)
                images = product_details.get('images', [])
                
                images_tasks.append({
                    'image_urls': images,
                    'product_id': product_id
                })
        scrape_images(task_id, images_tasks)


import csv
import os
import tempfile
import shutil
import time


def fix_image_url_header(task_id):
    """
    Updates the CSV header for a ProductScrapeTask.
    If 'Image URLs' exists, rename it to 'Image URL'.
    """
    try:
        scrape_task = ProductScrapeTask.objects.get(id=task_id)
    except ProductScrapeTask.DoesNotExist:
        print(f"Error: Task with id {task_id} not found")
        return

    if not scrape_task.result_file:
        print(f"No file found for task {task_id}, waiting before retrying...")
        time.sleep(10)
        try:
            scrape_task = ProductScrapeTask.objects.get(id=task_id)
        except ProductScrapeTask.DoesNotExist:
            print(f"Error: Task with id {task_id} not found after waiting")
            return
        if not scrape_task.result_file:
            print(f"Still no file found for task {task_id}")
            return

    csv_path = scrape_task.result_file.path

    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV path does not exist: {csv_path}")

    # Create a temporary file for safe editing
    temp_fd, temp_path = tempfile.mkstemp()
    os.close(temp_fd)

    try:
        with open(csv_path, "r", encoding="utf-8", newline="") as infile, \
             open(temp_path, "w", encoding="utf-8", newline="") as outfile:

            reader = csv.reader(infile)
            writer = csv.writer(outfile)

            # Read and update header
            header = next(reader)
            updated_header = ["Image URL" if h.strip().lower() == "image urls" else h for h in header]
            writer.writerow(updated_header)

            # Write the remaining rows unchanged
            for row in reader:
                writer.writerow(row)

        # Replace the original file with updated one
        shutil.move(temp_path, csv_path)
        print(f"Task {task_id}: Header updated successfully!")

    except Exception as e:
        print(f"Error while updating header for task {task_id}: {e}")
        if os.path.exists(temp_path):
            os.remove(temp_path)
    finally:
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
