from fastapi import FastAPI, APIRouter, HTTPException, Depends, status, UploadFile, File
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from dotenv import load_dotenv
from starlette.middleware.cors import CORSMiddleware
from motor.motor_asyncio import AsyncIOMotorClient
import os
import logging
import io
import csv
import re
from pathlib import Path
from collections import defaultdict
from difflib import SequenceMatcher
from pydantic import BaseModel, Field, ConfigDict, EmailStr
from typing import Any, Dict, List, Optional
import uuid
from datetime import datetime, timezone, timedelta
from decimal import Decimal, ROUND_HALF_UP
import bcrypt
import jwt
from contextlib import asynccontextmanager
from pypdf import PdfReader
from fastapi import FastAPI
from fastapi.security import HTTPBearer
from fastapi.routing import APIRouter
from fastapi.responses import JSONResponse, Response

# Load variables from a ``.env`` file next to this module, if present (local dev).
# python-dotenv's load_dotenv() just returns False when the file is missing and, by
# default, does not override variables that are already set in the real environment.
ROOT_DIR = Path(__file__).parent
load_dotenv(ROOT_DIR / ".env")

# MongoDB connection settings. Both are required: fail fast at import time so a
# misconfigured deployment never starts serving requests.
MONGO_URL = os.getenv("MONGO_URL")
DB_NAME = os.getenv("DB_NAME")

if not MONGO_URL:
    raise RuntimeError("MONGO_URL is not set")

if not DB_NAME:
    raise RuntimeError("DB_NAME is not set")

# Module-level Motor client and database handle, used by the helpers below.
client = AsyncIOMotorClient(MONGO_URL)
db = client[DB_NAME]

# JWT Configuration
# The fallback secret below is a public placeholder: anyone who has read this source
# can forge tokens signed with it. Always provide JWT_SECRET outside local development.
_INSECURE_DEFAULT_JWT_SECRET = "CHANGE_THIS_IN_PRODUCTION"
SECRET_KEY = os.getenv("JWT_SECRET", _INSECURE_DEFAULT_JWT_SECRET)
if not SECRET_KEY or SECRET_KEY == _INSECURE_DEFAULT_JWT_SECRET:
    # Keep the original fallback (so existing local setups still boot), but make the
    # insecure configuration loud instead of silent.
    logging.getLogger(__name__).warning(
        "JWT_SECRET is not set (or empty); the JWT signing key is insecure. "
        "Set JWT_SECRET before deploying."
    )
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days

# FastAPI application. The routes below are registered on api_router, which
# serves them under the /api prefix.
app = FastAPI()
api_router = APIRouter(prefix="/api")
# Bearer-token extractor consumed by get_current_user (Authorization header).
security = HTTPBearer()

# Default terms & conditions printed on invoices. Every consumer in this module copies
# it with list(...), so per-user settings can be edited without mutating this shared list.
DEFAULT_INVOICE_TERMS = [
    "Thanks for doing business with us!",
    "No Discount Request Will be entertained during Final Payment.",
    "Delayed payment beyond 21 days may attract recovery action as per agreement terms.",
    "Subsidy, if applicable, is credited directly to the customer's bank account.",
    "Our responsibility is limited to the agreed work scope only.",
    "Workmanship warranty is valid for 2 years from commissioning.",
    "Component warranties are subject to manufacturer warranty terms.",
]


# ============= MODELS =============
class UserRegister(BaseModel):
    """Request body for account registration."""
    email: EmailStr
    password: str  # plain text from the client; only a bcrypt hash is stored
    name: str
    company_name: Optional[str] = None

class UserLogin(BaseModel):
    """Request body for email/password login."""
    email: EmailStr
    password: str

class UserResponse(BaseModel):
    """Public user profile returned by the auth endpoints (has no password field)."""
    model_config = ConfigDict(extra="ignore")  # silently drop keys not declared below
    id: str
    email: str
    name: str
    company_name: Optional[str] = None

class TokenResponse(BaseModel):
    """Auth response: a signed access token plus the user's public profile."""
    access_token: str
    token_type: str
    user: UserResponse

class Customer(BaseModel):
    """Stored customer record, tied to its owning user via ``user_id``."""
    model_config = ConfigDict(extra="ignore")  # ignore unknown keys (e.g. from stored documents)
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    user_id: str
    name: str
    email: Optional[str] = None
    phone: str
    address: Optional[str] = None
    billing_address: Optional[str] = None
    shipping_address: Optional[str] = None
    gst_number: Optional[str] = None
    customer_type: str = "B2C"  # free-form string, "B2C" by default (not validated here)
    state: Optional[str] = None  # fallback place of supply for invoices
    state_code: Optional[str] = None  # compared with the seller's code to choose intra/inter-state GST
    credit_balance: float = 0.0  # NOTE(review): not read in this chunk; exact meaning unconfirmed
    created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())  # ISO-8601, UTC

class CustomerCreate(BaseModel):
    """Client-supplied customer fields; id/user_id/created_at are presumably filled in by the route."""
    name: str
    email: Optional[str] = None
    phone: str
    address: Optional[str] = None
    billing_address: Optional[str] = None
    shipping_address: Optional[str] = None
    gst_number: Optional[str] = None
    customer_type: str = "B2C"
    state: Optional[str] = None
    state_code: Optional[str] = None

class Product(BaseModel):
    """Catalogue product owned by a user; supplies GST/HSN/unit fallbacks for invoice lines."""
    model_config = ConfigDict(extra="ignore")  # ignore unknown keys (e.g. from stored documents)
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    user_id: str
    name: str
    description: Optional[str] = None
    price: float
    stock: int = 0
    unit: str = "pcs"
    uqc: str = "PCS"  # presumably the GST Unit Quantity Code; confirm against invoice rendering
    product_type: str = "goods"  # free-form string, not validated here
    hsn_code: Optional[str] = None
    gst_rate: float = 18.0  # total GST percentage (split CGST+SGST or charged as IGST)
    cess_rate: float = 0.0  # cess percentage of the taxable value
    tax_category: str = "taxable"  # free-form string, not validated here
    low_stock_threshold: int = 5  # NOTE(review): presumably the stock level that triggers a low-stock flag
    created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())  # ISO-8601, UTC

class ProductCreate(BaseModel):
    """Client-supplied product fields (Product minus id/user_id/created_at)."""
    name: str
    description: Optional[str] = None
    price: float
    stock: int = 0
    unit: str = "pcs"
    uqc: str = "PCS"
    product_type: str = "goods"
    hsn_code: Optional[str] = None
    gst_rate: float = 18.0
    cess_rate: float = 0.0
    tax_category: str = "taxable"
    low_stock_threshold: int = 5

class InvoiceItem(BaseModel):
    """One invoice line, as received from the client and as stored.

    build_gst_invoice_totals rebuilds every line, so any client-sent tax split,
    amounts and totals are replaced server-side. Note the naming: ``total`` holds
    the pre-tax taxable value, while ``total_amount`` is the tax-inclusive line total.
    """
    product_id: str
    product_name: str
    quantity: float  # fractional quantities allowed
    price: float  # unit price before tax
    gst_rate: float  # total GST percentage for the line
    total: float  # taxable value (price * quantity), NOT including tax
    hsn_code: Optional[str] = None
    unit: str = "pcs"
    uqc: str = "PCS"
    taxable_value: float = 0.0
    cgst_rate: float = 0.0  # percent; gst_rate / 2 for intra-state supply, else 0
    cgst_amount: float = 0.0
    sgst_rate: float = 0.0  # percent; gst_rate / 2 for intra-state supply, else 0
    sgst_amount: float = 0.0
    igst_rate: float = 0.0  # percent; equals gst_rate for inter-state supply, else 0
    igst_amount: float = 0.0
    cess_rate: float = 0.0  # percent of the taxable value
    cess_amount: float = 0.0
    total_amount: float = 0.0  # taxable value + all taxes and cess for this line

class Invoice(BaseModel):
    """Stored invoice (tax invoice or proforma) with server-computed GST totals."""
    model_config = ConfigDict(extra="ignore")  # ignore unknown keys (e.g. from stored documents)
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    user_id: str
    invoice_number: str  # e.g. "INV-00001", or "PI-00001" for proformas
    customer_id: str
    customer_name: str  # denormalised copy of the customer's name
    items: List[InvoiceItem]
    invoice_type: str = "tax_invoice"  # "proforma_invoice" marks a proforma
    place_of_supply: Optional[str] = None
    place_state_code: Optional[str] = None  # buyer state code used to decide tax_type
    reverse_charge: bool = False
    tax_type: str = "intra_state"  # "intra_state" (CGST+SGST) or "inter_state" (IGST)
    subtotal: float  # sum of line taxable values
    gst_amount: float  # CGST + SGST + IGST (cess excluded)
    cgst_amount: float = 0.0
    sgst_amount: float = 0.0
    igst_amount: float = 0.0
    cess_amount: float = 0.0
    round_off_amount: float = 0.0  # total_amount minus the unrounded total (may be negative)
    total_amount: float  # grand total rounded to a whole unit
    payment_status: str = "pending"  # pending, paid, partial, proforma
    payment_method: Optional[str] = None
    notes: Optional[str] = None
    created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())  # ISO-8601, UTC
    paid_amount: float = 0.0
    balance_amount: float = 0.0

class InvoiceCreate(BaseModel):
    """Client payload for a new invoice; all totals are recomputed by build_gst_invoice_totals."""
    customer_id: str
    invoice_date: Optional[str] = None
    items: List[InvoiceItem]
    invoice_type: str = "tax_invoice"
    place_of_supply: Optional[str] = None  # defaults to the customer's state when omitted
    place_state_code: Optional[str] = None  # defaults to the customer's state_code when omitted
    reverse_charge: bool = False
    payment_method: Optional[str] = None
    notes: Optional[str] = None

class ProformaConvertRequest(BaseModel):
    """Options for converting a proforma invoice (the consuming route is not in this chunk)."""
    invoice_type: str = "tax_invoice"
    invoice_date: Optional[str] = None
    payment_method: Optional[str] = None

class Payment(BaseModel):
    """Payment recorded against an invoice, with denormalised invoice/customer labels."""
    model_config = ConfigDict(extra="ignore")  # ignore unknown keys (e.g. from stored documents)
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    user_id: str
    invoice_id: str
    invoice_number: str
    customer_id: str
    customer_name: str
    amount: float
    payment_method: str  # free-form string, not validated here
    payment_date: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())  # ISO-8601, UTC
    notes: Optional[str] = None

class PaymentCreate(BaseModel):
    """Client payload for recording a payment (amount validation is not done by this model)."""
    invoice_id: str
    amount: float
    payment_method: str
    notes: Optional[str] = None

class CompanySettings(BaseModel):
    """Per-user company profile plus invoice-presentation preferences."""
    model_config = ConfigDict(extra="ignore")  # ignore unknown keys (e.g. from stored documents)
    user_id: str
    company_name: str  # required, so callers must never store None here
    legal_name: Optional[str] = None
    trade_name: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    address: Optional[str] = None
    gst_number: Optional[str] = None
    pan: Optional[str] = None
    state: Optional[str] = None
    state_code: Optional[str] = None  # seller state code; matched against the buyer's to pick intra/inter-state
    gst_registration_type: str = "regular"
    bank_name: Optional[str] = None
    bank_account_number: Optional[str] = None
    bank_ifsc: Optional[str] = None
    bank_branch: Optional[str] = None
    signature_url: Optional[str] = None
    stamp_url: Optional[str] = None
    logo_url: Optional[str] = None
    invoice_terms: List[str] = Field(default_factory=lambda: list(DEFAULT_INVOICE_TERMS))  # fresh copy per instance
    invoice_footer: str = "This is a Computer Generated Invoice"
    # The show_* flags presumably toggle sections of the rendered invoice (renderer not in this chunk).
    show_bank_details: bool = True
    show_hsn_summary: bool = True
    show_declaration: bool = True
    show_tax_amount_words: bool = True
    show_terms_and_conditions: bool = True

class CompanySettingsUpdate(BaseModel):
    """Partial update of CompanySettings; every field is optional.

    NOTE(review): presumably only non-None fields are applied by the route, which is
    not in this chunk. Confirm before relying on None meaning "leave unchanged".
    """
    company_name: Optional[str] = None
    legal_name: Optional[str] = None
    trade_name: Optional[str] = None
    email: Optional[str] = None
    phone: Optional[str] = None
    address: Optional[str] = None
    gst_number: Optional[str] = None
    pan: Optional[str] = None
    state: Optional[str] = None
    state_code: Optional[str] = None
    gst_registration_type: Optional[str] = None
    bank_name: Optional[str] = None
    bank_account_number: Optional[str] = None
    bank_ifsc: Optional[str] = None
    bank_branch: Optional[str] = None
    signature_url: Optional[str] = None
    stamp_url: Optional[str] = None
    logo_url: Optional[str] = None
    invoice_terms: Optional[List[str]] = None
    invoice_footer: Optional[str] = None
    show_bank_details: Optional[bool] = None
    show_hsn_summary: Optional[bool] = None
    show_declaration: Optional[bool] = None
    show_tax_amount_words: Optional[bool] = None
    show_terms_and_conditions: Optional[bool] = None

class InvoiceItemUpdate(BaseModel):
    """Invoice line in an edit payload: core fields only, tax split is recomputed."""
    product_id: str
    product_name: str
    quantity: float
    price: float  # unit price before tax
    gst_rate: float  # total GST percentage
    total: float
    hsn_code: Optional[str] = None
    unit: str = "pcs"
    uqc: str = "PCS"
    cess_rate: float = 0.0  # 0 means "use the product's cess rate" in build_gst_invoice_totals


class InvoiceUpdate(BaseModel):
    """Client payload for editing an existing invoice's lines and metadata."""
    invoice_date: Optional[str] = None
    items: List[InvoiceItemUpdate]
    invoice_type: str = "tax_invoice"
    place_of_supply: Optional[str] = None
    place_state_code: Optional[str] = None
    reverse_charge: bool = False
    notes: Optional[str] = None


class PurchaseBillItem(BaseModel):
    """Stored purchase-bill line; name/HSN/units are copied from the user's product."""
    product_id: str
    product_name: str
    quantity: int  # whole units only (must be > 0, enforced in build_purchase_bill_totals)
    purchase_price: float
    total: float  # quantity * purchase_price, rounded to 2 decimals
    hsn_code: Optional[str] = None
    unit: str = "pcs"
    uqc: str = "PCS"


class PurchaseBillItemCreate(BaseModel):
    """Client-supplied purchase line: just the product reference, quantity and price."""
    product_id: str
    quantity: int
    purchase_price: float


class PurchaseBill(BaseModel):
    """Stored purchase bill from a supplier, with payment tracking."""
    model_config = ConfigDict(extra="ignore")  # ignore unknown keys (e.g. from stored documents)
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    user_id: str
    bill_number: str  # internal number such as "PB-00001"
    supplier_name: str
    supplier_invoice_number: Optional[str] = None  # the supplier's own invoice number
    items: List[PurchaseBillItem]
    subtotal: float
    total_amount: float  # no tax breakdown on purchase bills; equals subtotal
    paid_amount: float = 0.0
    balance_amount: float = 0.0
    payment_status: str = "pending"  # "pending", "partial" or "paid"
    payment_method: Optional[str] = None
    notes: Optional[str] = None
    created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())  # ISO-8601, UTC


class PurchaseBillCreate(BaseModel):
    """Client payload for a new purchase bill; validated by build_purchase_bill_totals."""
    supplier_name: str
    supplier_invoice_number: Optional[str] = None
    bill_date: Optional[str] = None
    items: List[PurchaseBillItemCreate]
    paid_amount: float = 0.0  # must be between 0 and the bill total
    payment_method: Optional[str] = None  # required when paid_amount > 0
    notes: Optional[str] = None


class PurchaseBillUpdate(BaseModel):
    """Client payload for editing a purchase bill (same fields and rules as PurchaseBillCreate)."""
    supplier_name: str
    supplier_invoice_number: Optional[str] = None
    bill_date: Optional[str] = None
    items: List[PurchaseBillItemCreate]
    paid_amount: float = 0.0
    payment_method: Optional[str] = None
    notes: Optional[str] = None


# ============= AUTH HELPERS =============
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    The UTF-8 hash (salt and cost embedded) is returned as text so it can be
    stored directly in the user document.
    """
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode("utf-8"), salt)
    return digest.decode("utf-8")

def verify_password(password: str, hashed: str) -> bool:
    """Return True when *password* matches the stored bcrypt *hashed* value.

    A missing or malformed stored hash is treated as a non-match rather than an
    error, so a corrupted user record yields a normal failed login instead of a 500.
    """
    if not hashed:
        return False
    try:
        return bcrypt.checkpw(password.encode("utf-8"), hashed.encode("utf-8"))
    except ValueError:
        # bcrypt raises ValueError ("Invalid salt") when the stored value is not a
        # well-formed bcrypt hash.
        return False

# ============= NEXT DOCUMENT NUMBER =============
PROFORMA_INVOICE_TYPE = "proforma_invoice"


def is_proforma_invoice(invoice_type: Optional[str]) -> bool:
    """Tell whether *invoice_type* is the proforma type (exact, case-sensitive match; None -> False)."""
    return PROFORMA_INVOICE_TYPE == invoice_type


async def get_next_document_number(user_id: str, invoice_type: str = "tax_invoice") -> str:
    """Allocate the next per-user document number for *invoice_type*.

    Proformas use their own counter and the ``PI-`` prefix; everything else shares
    the ``invoice`` counter with the ``INV-`` prefix. The sequence is zero-padded to
    five digits (e.g. ``INV-00042``). The counter document is created on first use
    and incremented in a single find_one_and_update call.

    NOTE(review): without a unique index on (user_id, type), two concurrent first-use
    upserts could create duplicate counter documents. Confirm such an index exists.
    """
    if is_proforma_invoice(invoice_type):
        counter_type, prefix = "proforma", "PI"
    else:
        counter_type, prefix = "invoice", "INV"

    updated = await db.counters.find_one_and_update(
        {"user_id": user_id, "type": counter_type},
        {"$inc": {"seq": 1}},
        upsert=True,
        return_document=True,  # post-update document (pymongo's ReturnDocument.AFTER is True)
    )
    sequence = updated["seq"]
    return f"{prefix}-{sequence:05d}"


async def get_next_purchase_bill_number(user_id: str) -> str:
    """Allocate the next per-user purchase-bill number, formatted ``PB-00001``.

    Increments (creating on first use) the user's ``purchase_bill`` counter in a
    single find_one_and_update call and formats the post-increment value.
    """
    counter_filter = {"user_id": user_id, "type": "purchase_bill"}
    updated = await db.counters.find_one_and_update(
        counter_filter,
        {"$inc": {"seq": 1}},
        upsert=True,
        return_document=True,  # post-update document (pymongo's ReturnDocument.AFTER is True)
    )
    next_seq = updated["seq"]
    return "PB-" + format(next_seq, "05d")

def round_money(value: float) -> float:
    """Round *value* to 2 decimal places using half-up rounding.

    Falsy inputs (None, 0, "") give 0.0. The number is converted via ``str`` first,
    so Decimal sees the short repr (``2.675`` -> ``"2.675"`` -> ``2.68``) rather than
    the float's binary expansion.
    """
    amount = Decimal(str(value)) if value else Decimal("0")
    return float(amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP))

def round_to_rupee(value: float) -> float:
    """Round *value* to the nearest whole unit, with halves rounded away from zero.

    Falsy inputs give 0.0; the result is always a float.
    """
    amount = Decimal(str(value)) if value else Decimal("0")
    whole = amount.quantize(Decimal("1"), rounding=ROUND_HALF_UP)
    return float(whole)

def build_default_company_settings(user: dict) -> Dict[str, Any]:
    """Build the initial company-settings document for a newly registered user.

    Args:
        user: mapping with ``id``, ``name`` and ``email``, plus an optional
            ``company_name``. A non-empty ``company_name`` becomes the company,
            legal and trade name; otherwise the user's ``name`` is used.

    Returns:
        A dict with every CompanySettings field populated: contact, GST and bank
        fields start as None, and presentation options use the module defaults.
        ``invoice_terms`` is a fresh copy of DEFAULT_INVOICE_TERMS.
    """
    # ``user.get("company_name", user["name"])`` only fell back when the key was
    # absent. Registration passes ``company_name=None`` explicitly, which left the
    # required CompanySettings.company_name empty. Fall back on any empty value.
    company_name = user.get("company_name") or user["name"]
    return {
        "user_id": user["id"],
        "company_name": company_name,
        "legal_name": company_name,
        "trade_name": company_name,
        "email": user["email"],
        "phone": None,
        "address": None,
        "gst_number": None,
        "pan": None,
        "state": None,
        "state_code": None,
        "gst_registration_type": "regular",
        "bank_name": None,
        "bank_account_number": None,
        "bank_ifsc": None,
        "bank_branch": None,
        "signature_url": None,
        "stamp_url": None,
        "logo_url": None,
        "invoice_terms": list(DEFAULT_INVOICE_TERMS),
        "invoice_footer": "This is a Computer Generated Invoice",
        "show_bank_details": True,
        "show_hsn_summary": True,
        "show_declaration": True,
        "show_tax_amount_words": True,
        "show_terms_and_conditions": True,
    }

def ensure_company_settings_defaults(settings: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict of *settings* with the invoice-presentation defaults filled in.

    Keys present in *settings* win over the defaults. An empty or missing
    ``invoice_terms`` is always replaced with a fresh copy of DEFAULT_INVOICE_TERMS.
    The input mapping is not modified.
    """
    defaults: Dict[str, Any] = {
        "invoice_terms": list(DEFAULT_INVOICE_TERMS),
        "invoice_footer": "This is a Computer Generated Invoice",
        "show_bank_details": True,
        "show_hsn_summary": True,
        "show_declaration": True,
        "show_tax_amount_words": True,
        "show_terms_and_conditions": True,
    }
    combined = {**defaults, **settings}
    if not combined.get("invoice_terms"):
        combined["invoice_terms"] = list(DEFAULT_INVOICE_TERMS)
    return combined

def collection_to_csv(rows: List[Dict[str, Any]], resource: str) -> str:
    """Render *rows* as CSV text using the fixed column layout for *resource*.

    Known layouts are ``customers``, ``products``, ``payments`` and
    ``purchase_bills``. Any other value falls back to the invoice layout. Keys
    missing from a row become empty cells, keys outside the layout are dropped, and
    ``None`` values are written as empty cells by the ``csv`` module. Line endings
    follow the csv module's default dialect (CRLF).
    """
    layouts: Dict[str, List[str]] = {
        "customers": [
            "id", "name", "email", "phone", "customer_type",
            "gst_number", "state", "state_code", "address",
            "billing_address", "shipping_address", "created_at",
        ],
        "products": [
            "id", "name", "description", "price", "stock", "low_stock_threshold",
            "unit", "uqc", "product_type", "hsn_code", "gst_rate",
            "cess_rate", "tax_category", "created_at",
        ],
        "payments": [
            "id", "invoice_number", "customer_name", "amount",
            "payment_method", "payment_date", "notes",
        ],
        "purchase_bills": [
            "id", "bill_number", "supplier_name", "supplier_invoice_number",
            "created_at", "subtotal", "total_amount", "paid_amount",
            "balance_amount", "payment_status", "payment_method", "notes",
        ],
    }
    invoice_layout = [
        "id", "invoice_number", "invoice_type", "customer_name", "created_at",
        "subtotal", "gst_amount", "cess_amount", "round_off_amount",
        "total_amount", "paid_amount", "balance_amount",
        "payment_status", "payment_method", "tax_type",
        "place_of_supply", "place_state_code",
    ]
    fieldnames = layouts.get(resource, invoice_layout)

    out = io.StringIO()
    csv_writer = csv.DictWriter(out, fieldnames=fieldnames)
    csv_writer.writeheader()
    csv_writer.writerows(
        {field: record.get(field, "") for field in fieldnames} for record in rows
    )
    return out.getvalue()

def parse_report_date(value: Optional[str], end_of_day: bool = False) -> Optional[str]:
    """Turn an ISO date/datetime string into a day-boundary ISO timestamp.

    Naive inputs are assumed to be UTC; an explicit offset is kept as given. The time
    is clamped to 00:00:00 (or 23:59:59.999999 when *end_of_day*). Empty/None input
    returns None; unparseable input raises ValueError from ``datetime.fromisoformat``.
    """
    if not value:
        return None
    moment = datetime.fromisoformat(value)
    if moment.tzinfo is None:
        moment = moment.replace(tzinfo=timezone.utc)
    boundary = (
        {"hour": 23, "minute": 59, "second": 59, "microsecond": 999999}
        if end_of_day
        else {"hour": 0, "minute": 0, "second": 0, "microsecond": 0}
    )
    return moment.replace(**boundary).isoformat()

def get_tax_type(company_settings: Optional[dict], customer: dict, invoice_data: InvoiceCreate) -> str:
    """Classify the supply as ``"intra_state"`` or ``"inter_state"`` for GST.

    The seller's code comes from *company_settings*. The buyer's code is the invoice's
    ``place_state_code``, falling back to the customer's ``state_code``. Only two
    non-empty, equal codes yield ``"intra_state"``; a missing code on either side
    defaults to ``"inter_state"`` (IGST).
    """
    seller_code = company_settings.get("state_code") if company_settings else None
    buyer_code = invoice_data.place_state_code or customer.get("state_code")
    same_state = bool(seller_code) and bool(buyer_code) and seller_code == buyer_code
    return "intra_state" if same_state else "inter_state"

async def build_gst_invoice_totals(
    user_id: str,
    customer: dict,
    invoice_data: InvoiceCreate
) -> Dict[str, Any]:
    """Recompute every invoice line and the invoice-level GST totals server-side.

    Client-sent tax fields on the items are not read. Each line's taxable value and
    its CGST/SGST (intra-state) or IGST (inter-state) amounts are derived from
    price, quantity and rates, with the tax type decided by ``get_tax_type``.

    Args:
        user_id: owner of the invoice; scopes the company-settings and product lookups.
        customer: stored customer document; ``state``/``state_code`` are fallbacks
            for the place of supply.
        invoice_data: create payload. Optional item attributes (cess_rate, hsn_code,
            unit, uqc) are read with ``getattr``, so presumably update payloads with
            InvoiceItemUpdate items are passed here too (confirm with callers).

    Returns:
        dict with ``items`` (list of InvoiceItem), ``tax_type``, ``subtotal``,
        ``gst_amount`` (CGST+SGST+IGST, cess excluded), ``cgst_amount``,
        ``sgst_amount``, ``igst_amount``, ``cess_amount``, ``round_off_amount``,
        ``total_amount`` (rounded to a whole unit), ``place_of_supply`` and
        ``place_state_code``.
    """
    company_settings = await db.company_settings.find_one({"user_id": user_id}, {"_id": 0})
    tax_type = get_tax_type(company_settings, customer, invoice_data)

    calculated_items = []
    subtotal = cgst_amount = sgst_amount = igst_amount = cess_amount = 0.0

    for item in invoice_data.items:
        # Catalogue product (may be None) supplies fallbacks for rate/HSN/unit.
        product = await db.products.find_one(
            {"id": item.product_id, "user_id": user_id},
            {"_id": 0}
        )
        taxable_value = round_money(item.price * item.quantity)
        # NOTE(review): gst_rate is a required float on the item models, so the product
        # fallback here only triggers if a caller bypasses validation with None.
        gst_rate = float(item.gst_rate if item.gst_rate is not None else (product or {}).get("gst_rate", 0))
        # A cess_rate of 0 (the model default) falls back to the product's cess rate, so
        # a line cannot override a non-zero product cess down to zero.
        cess_rate = float(getattr(item, "cess_rate", 0) or (product or {}).get("cess_rate", 0))

        if tax_type == "intra_state":
            # GST is split equally into CGST and SGST; each half is rounded separately.
            cgst_rate = sgst_rate = gst_rate / 2
            igst_rate = 0.0
            item_cgst = round_money(taxable_value * cgst_rate / 100)
            item_sgst = round_money(taxable_value * sgst_rate / 100)
            item_igst = 0.0
        else:
            cgst_rate = sgst_rate = 0.0
            igst_rate = gst_rate
            item_cgst = item_sgst = 0.0
            item_igst = round_money(taxable_value * igst_rate / 100)

        item_cess = round_money(taxable_value * cess_rate / 100)
        line_total = round_money(taxable_value + item_cgst + item_sgst + item_igst + item_cess)

        calculated_items.append(InvoiceItem(
            product_id=item.product_id,
            product_name=item.product_name,
            quantity=item.quantity,
            price=item.price,
            gst_rate=gst_rate,
            total=taxable_value,  # ``total`` is the pre-tax value; ``total_amount`` includes tax
            hsn_code=getattr(item, "hsn_code", None) or (product or {}).get("hsn_code"),
            unit=getattr(item, "unit", None) or (product or {}).get("unit", "pcs"),
            uqc=getattr(item, "uqc", None) or (product or {}).get("uqc", "PCS"),
            taxable_value=taxable_value,
            cgst_rate=cgst_rate,
            cgst_amount=item_cgst,
            sgst_rate=sgst_rate,
            sgst_amount=item_sgst,
            igst_rate=igst_rate,
            igst_amount=item_igst,
            cess_rate=cess_rate,
            cess_amount=item_cess,
            total_amount=line_total
        ))

        subtotal += taxable_value
        cgst_amount += item_cgst
        sgst_amount += item_sgst
        igst_amount += item_igst
        cess_amount += item_cess

    # Re-round the accumulated sums to remove float drift from repeated additions.
    subtotal = round_money(subtotal)
    cgst_amount = round_money(cgst_amount)
    sgst_amount = round_money(sgst_amount)
    igst_amount = round_money(igst_amount)
    cess_amount = round_money(cess_amount)
    gst_amount = round_money(cgst_amount + sgst_amount + igst_amount)
    base_total = round_money(subtotal + gst_amount + cess_amount)
    # Grand total is rounded to a whole unit; round_off_amount records the
    # adjustment (may be negative).
    total_amount = round_money(round_to_rupee(base_total))
    round_off_amount = round_money(total_amount - base_total)

    return {
        "items": calculated_items,
        "tax_type": tax_type,
        "subtotal": subtotal,
        "gst_amount": gst_amount,
        "cgst_amount": cgst_amount,
        "sgst_amount": sgst_amount,
        "igst_amount": igst_amount,
        "cess_amount": cess_amount,
        "round_off_amount": round_off_amount,
        "total_amount": total_amount,
        "place_of_supply": invoice_data.place_of_supply or customer.get("state"),
        "place_state_code": invoice_data.place_state_code or customer.get("state_code"),
    }


async def build_purchase_bill_totals(
    user_id: str,
    purchase_bill_data: PurchaseBillCreate | PurchaseBillUpdate
) -> Dict[str, Any]:
    """Validate a purchase-bill payload and compute its lines, totals and payment status.

    Product name, HSN code and units always come from the user's stored product,
    never from the payload. No tax is applied: ``subtotal`` and ``total_amount``
    are the same rounded sum of line totals.

    Args:
        user_id: owner; scopes every product lookup.
        purchase_bill_data: create or update payload (identical field sets).

    Returns:
        dict with ``items`` (list of PurchaseBillItem), ``subtotal``, ``total_amount``,
        ``paid_amount``, ``balance_amount``, ``payment_status`` ("paid", "partial" or
        "pending") and ``payment_method`` (None unless something was paid).

    Raises:
        HTTPException(400): no items, a non-positive quantity, a negative price, a
            negative paid amount, a paid amount above the total, or a paid amount
            without a payment method.
        HTTPException(404): an item references a product that does not exist for
            this user.
    """
    calculated_items: List[PurchaseBillItem] = []
    subtotal = 0.0

    if not purchase_bill_data.items:
        raise HTTPException(status_code=400, detail="Add at least one product")

    for item in purchase_bill_data.items:
        # Items are validated in order, and each item's own checks run before its
        # product lookup, so the first offending item decides which error is returned.
        if item.quantity <= 0:
            raise HTTPException(status_code=400, detail="Quantity must be greater than zero")
        if item.purchase_price < 0:
            raise HTTPException(status_code=400, detail="Purchase price cannot be negative")

        product = await db.products.find_one(
            {"id": item.product_id, "user_id": user_id},
            {"_id": 0}
        )
        if not product:
            raise HTTPException(status_code=404, detail="Product not found")

        line_total = round_money(item.quantity * item.purchase_price)
        subtotal += line_total
        calculated_items.append(PurchaseBillItem(
            product_id=item.product_id,
            product_name=product.get("name", ""),
            quantity=item.quantity,
            purchase_price=round_money(item.purchase_price),
            total=line_total,
            hsn_code=product.get("hsn_code"),
            unit=product.get("unit", "pcs"),
            uqc=product.get("uqc", "PCS"),
        ))

    total_amount = round_money(subtotal)
    paid_amount = round_money(purchase_bill_data.paid_amount or 0.0)
    if paid_amount < 0:
        raise HTTPException(status_code=400, detail="Paid amount cannot be negative")
    if paid_amount > total_amount:
        raise HTTPException(status_code=400, detail="Paid amount cannot exceed total amount")
    if paid_amount > 0 and not purchase_bill_data.payment_method:
        raise HTTPException(status_code=400, detail="Select a payment method for paid amount")

    # paid_amount <= total_amount is enforced above, so the balance is never negative;
    # zero means fully paid (including a zero-value bill with nothing paid).
    balance_amount = round_money(total_amount - paid_amount)
    if balance_amount <= 0:
        payment_status = "paid"
        balance_amount = 0.0
    elif paid_amount > 0:
        payment_status = "partial"
    else:
        payment_status = "pending"

    return {
        "items": calculated_items,
        "subtotal": round_money(subtotal),
        "total_amount": total_amount,
        "paid_amount": paid_amount,
        "balance_amount": balance_amount,
        "payment_status": payment_status,
        # A method is only recorded when money was actually paid.
        "payment_method": purchase_bill_data.payment_method if paid_amount > 0 else None,
    }


def normalize_lookup_text(value: Optional[str]) -> str:
    """Lower-case *value* and collapse each run of characters outside ``[a-z0-9]`` to one space.

    Leading and trailing spaces are stripped; falsy input (None, "", 0) yields "".
    Non-ASCII letters count as separators.
    """
    lowered = str(value or "").lower()
    collapsed = re.sub(r"[^a-z0-9]+", " ", lowered)
    return collapsed.strip()


def clean_pdf_lines(text: str) -> List[str]:
    """Split extracted PDF text into lines, squeezing internal whitespace and dropping blank lines."""
    squeezed = (re.sub(r"\s+", " ", raw).strip() for raw in str(text or "").splitlines())
    return [line for line in squeezed if line]


def parse_amount_token(token: str) -> float:
    """Parse a numeric token such as ``"1,234.50"`` into a float; blank/falsy input gives 0.0."""
    cleaned = str(token or "").replace(",", "").strip()
    if not cleaned:
        return 0.0
    return float(cleaned)


def find_purchase_bill_field(lines: List[str], labels: List[str]) -> str:
    """Return the value that follows the first of *labels* found in *lines*.

    A label must start at a word boundary (so "date" no longer matches inside
    "update"). It may be followed by an optional "." (as in "No."), an optional
    ":"/"-" separator and spaces. The captured value is at least two characters
    from ``[A-Za-z0-9/-. ]`` and is trimmed of surrounding spaces, dots, colons and
    dashes. Lines are scanned in order; a match that trims down to nothing is
    skipped. Labels are matched literally and case-insensitively.

    Returns "" when no label matches or *labels* is empty.
    """
    if not labels:
        return ""

    alternation = "|".join(re.escape(label) for label in labels)
    pattern = re.compile(
        rf"\b(?:{alternation})\.?\s*[:\-]?\s*([A-Za-z0-9\/\-. ]{{2,}})",
        re.IGNORECASE,
    )

    for line in lines:
        match = pattern.search(line)
        if not match:
            continue
        value = match.group(1).strip(" .:-")
        if value:
            return value
    return ""


def infer_supplier_name(lines: List[str]) -> str:
    """Guess the supplier's name from the first lines of a purchase bill.

    Strategy:
      1. In the first 20 lines, look for an explicit label ("Supplier", "Vendor",
         "Party", "Sold by", "From", "M/S") and return the text after it, trimmed
         of separator debris, provided it contains at least one letter.
      2. Otherwise return the first of the first 12 lines that normalises to at
         least four characters, contains a letter, and contains none of the
         header/table keywords below.

    Returns "" when nothing qualifies.
    """
    candidate_patterns = [
        re.compile(r"^(?:supplier|vendor|party|sold by|from)\s*[:\-]?\s*(.+)$", re.IGNORECASE),
        re.compile(r"^(?:m/s\.?|m/s)\s+(.+)$", re.IGNORECASE),
    ]

    for line in lines[:20]:
        for pattern in candidate_patterns:
            match = pattern.search(line)
            if not match:
                continue
            # Because ``[:\-]?`` is optional, a bare label such as "Supplier:" (name on
            # the next line) let ``(.+)`` capture just ":". Trim separators and only
            # accept candidates that actually contain a letter.
            candidate = match.group(1).strip().strip(":-").strip()
            if re.search(r"[A-Za-z]", candidate):
                return candidate

    ignored_keywords = (
        "invoice", "bill", "gst", "tax", "phone", "mobile", "email", "address",
        "qty", "quantity", "rate", "amount", "total", "subtotal", "hsn", "cgst",
        "sgst", "igst", "balance", "purchase", "date"
    )
    for line in lines[:12]:
        normalized = normalize_lookup_text(line)
        if len(normalized) < 4:
            continue
        if any(keyword in normalized for keyword in ignored_keywords):
            continue
        if not re.search(r"[A-Za-z]", line):
            continue
        return line.strip()

    return ""


def normalize_purchase_bill_date(value: Optional[str]) -> Optional[str]:
    """Parse a free-form bill date into ``YYYY-MM-DD``; return None if it cannot be parsed.

    "-" and "." are treated as "/" and commas as spaces, so "12-03-2024",
    "12.03.2024", "12-Mar-2024" and "March 12, 2024" are all accepted. Numeric
    dates are tried day-first before month-first, so ambiguous values such as
    "03/04/2024" read as 3 April. ISO strings (including datetimes) are accepted
    as a last resort.
    """
    if not value:
        return None

    text = str(value).strip()
    if not text:
        return None

    # Commas are replaced by spaces (strptime treats a run of spaces as one
    # separator), so "Mar 12, 2024" matches "%b %d %Y".
    normalized = text.replace(",", " ").replace(".", "/").replace("-", "/").strip()
    date_formats = [
        "%Y/%m/%d",
        "%d/%m/%Y",
        "%m/%d/%Y",
        "%d/%m/%y",
        "%m/%d/%y",
        "%d %b %Y",
        "%d %B %Y",
        "%b %d %Y",
        "%B %d %Y",
        # Month-name dates written with "-" or "." (e.g. "12-Mar-2024") become
        # "12/Mar/2024" after separator normalisation, so they need slash forms.
        "%d/%b/%Y",
        "%d/%B/%Y",
        "%d/%b/%y",
        "%d/%B/%y",
    ]

    for date_format in date_formats:
        try:
            return datetime.strptime(normalized, date_format).strftime("%Y-%m-%d")
        except ValueError:
            continue

    try:
        return datetime.fromisoformat(text).strftime("%Y-%m-%d")
    except ValueError:
        return None


def infer_line_item_from_line(line: str) -> Optional[Dict[str, Any]]:
    """Heuristically parse one text line of a purchase bill into a line item.

    A line is rejected (None) when it has no letters, mentions any header/footer
    keyword below, has fewer than two numeric tokens, or leaves no product name.
    Of the numbers, the last three are read as quantity, unit price and total.
    With only two numbers, they are quantity and price and the total is computed.

    Returns:
        ``{"product_name", "hsn_code", "quantity", "purchase_price", "total"}``,
        where quantity is rounded to a non-negative int, or None.
    """
    if not re.search(r"[A-Za-z]", line):
        return None

    lowered = normalize_lookup_text(line)
    skip_keywords = (
        "invoice", "bill", "subtotal", "total", "amount in words", "grand total",
        "round off", "cgst", "sgst", "igst", "cess", "discount", "tax", "balance",
        "payment", "supplier", "vendor", "phone", "email", "address"
    )
    if any(keyword in lowered for keyword in skip_keywords):
        return None

    numeric_matches = list(re.finditer(r"\d[\d,]*(?:\.\d+)?", line))
    if len(numeric_matches) < 2:
        return None

    # The first standalone 4-8 digit run is assumed to be the HSN code
    # (NOTE(review): this may misfire on 4+ digit prices).
    hsn_match = re.search(r"\b\d{4,8}\b", line)
    hsn_code = hsn_match.group(0) if hsn_match else None
    numeric_values = [parse_amount_token(match.group(0)) for match in numeric_matches]

    # Remove the first occurrence of the HSN value so it is not read as qty/price/total.
    if hsn_code and len(numeric_values) > 3:
        hsn_value = parse_amount_token(hsn_code)
        removed = False
        filtered_values = []
        for value in numeric_values:
            if not removed and value == hsn_value:
                removed = True
                continue
            filtered_values.append(value)
        numeric_values = filtered_values or numeric_values

    # A small leading integer on a row with >= 4 numbers is treated as a serial-number column.
    if len(numeric_values) >= 4 and numeric_values[0].is_integer() and 0 < numeric_values[0] <= 100:
        numeric_values = numeric_values[1:]

    if len(numeric_values) >= 3:
        quantity, purchase_price, total = numeric_values[-3], numeric_values[-2], numeric_values[-1]
    elif len(numeric_values) == 2:
        quantity, purchase_price = numeric_values[0], numeric_values[1]
        total = round_money(quantity * purchase_price)
    else:
        # Defensive: at least two numbers survive the filters above.
        quantity = 1.0
        purchase_price = numeric_values[0]
        total = round_money(quantity * purchase_price)

    # Item rows often start with a serial number ("1 Widget ..." / "1. Widget ...").
    # Slicing up to the very first number gave an empty name for such rows, so every
    # one of them was dropped. Skip a leading serial that is followed by a word, and
    # take the name up to the next number after it. Lines without such a prefix
    # behave exactly as before.
    leading_serial = re.match(r"\s*\d+[.)]?\s+(?=[A-Za-z])", line)
    name_start = leading_serial.end() if leading_serial else 0
    name_end = next(
        (match.start() for match in numeric_matches if match.start() >= name_start),
        len(line),
    )
    product_name = line[name_start:name_end].strip(" |:-")
    product_name = re.sub(r"^\d+\s+", "", product_name).strip()
    if not product_name:
        return None

    return {
        "product_name": product_name,
        "hsn_code": hsn_code,
        "quantity": max(int(round(quantity)), 0),
        "purchase_price": round_money(purchase_price),
        "total": round_money(total),
    }


def extract_purchase_items_from_lines(lines: List[str], products: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Parse candidate item rows from *lines* and align them with the user's catalogue.

    Rows that do not parse, or that have a non-positive quantity, are skipped. When a
    catalogue product matches, its name replaces the parsed name, and its HSN code
    fills in a missing one. Duplicates (same normalised name, quantity, price and
    total) are kept only once, in first-seen order.
    """
    items: List[Dict[str, Any]] = []
    seen = set()

    for raw_line in lines:
        candidate = infer_line_item_from_line(raw_line)
        if not candidate or candidate["quantity"] <= 0:
            continue

        matched = find_best_matching_product(candidate["product_name"], candidate["hsn_code"], products)
        if matched:
            catalogue_product = matched["product"]
            candidate["product_name"] = catalogue_product.get("name", candidate["product_name"])
            candidate["hsn_code"] = candidate["hsn_code"] or catalogue_product.get("hsn_code")

        dedupe_key = (
            normalize_lookup_text(candidate["product_name"]),
            candidate["quantity"],
            candidate["purchase_price"],
            candidate["total"],
        )
        if dedupe_key not in seen:
            seen.add(dedupe_key)
            items.append(candidate)

    return items


def find_best_matching_product(extracted_name: str, extracted_hsn_code: Optional[str], products: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Pick the catalogue product that best matches an extracted name/HSN pair.

    Scoring per product: the SequenceMatcher ratio of the normalised names, raised
    to 1.0 for an exact name match or 0.9 when one name contains the other, and
    raised to 0.96 when the whitespace-stripped HSN codes are equal. The first
    product with the highest score wins.

    Returns ``{"product": <product>, "score": <score rounded to 3 places>}``, or
    None when no product reaches a score of 0.72.
    """
    wanted_name = normalize_lookup_text(extracted_name)
    wanted_hsn = re.sub(r"\s+", "", str(extracted_hsn_code or ""))

    def _score(product: Dict[str, Any]) -> float:
        # Similarity of one catalogue product to the extracted name/HSN.
        name = normalize_lookup_text(product.get("name", ""))
        hsn = re.sub(r"\s+", "", str(product.get("hsn_code") or ""))
        score = 0.0
        if wanted_name and name:
            score = SequenceMatcher(None, wanted_name, name).ratio()
            if wanted_name == name:
                score = max(score, 1.0)
            elif wanted_name in name or name in wanted_name:
                score = max(score, 0.9)
        if wanted_hsn and hsn and wanted_hsn == hsn:
            score = max(score, 0.96)
        return score

    best_match, best_score = None, 0.0
    for product in products:
        score = _score(product)
        if score > best_score:  # strict: ties keep the earlier product
            best_match, best_score = product, score

    if not best_match or best_score < 0.72:
        return None
    return {"product": best_match, "score": round(best_score, 3)}


def _extract_purchase_bill_pdf_sync(file_bytes: bytes, products: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Blocking worker for extract_purchase_bill_pdf_locally: parse the PDF and build the draft."""
    try:
        reader = PdfReader(io.BytesIO(file_bytes))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Could not open the uploaded PDF") from exc

    pages_text = []
    for page in reader.pages:
        try:
            pages_text.append(page.extract_text() or "")
        except Exception:
            # Deliberate best-effort: a page pypdf cannot decode contributes no text
            # instead of failing the whole upload.
            pages_text.append("")

    full_text = "\n".join(pages_text)
    lines = clean_pdf_lines(full_text)
    if not lines:
        raise HTTPException(
            status_code=400,
            detail="This PDF has no readable text. Free auto-load currently works for text PDFs, not scanned image PDFs."
        )

    supplier_invoice_number = find_purchase_bill_field(
        lines,
        ["invoice no", "invoice number", "bill no", "bill number", "voucher no", "voucher number"]
    )
    raw_bill_date = find_purchase_bill_field(lines, ["invoice date", "bill date", "date", "dated"])
    # Prefer a normalised YYYY-MM-DD date, but keep the raw text if it cannot be parsed.
    bill_date = normalize_purchase_bill_date(raw_bill_date) or raw_bill_date or None

    return {
        "supplier_name": infer_supplier_name(lines),
        "supplier_invoice_number": supplier_invoice_number,
        "bill_date": bill_date,
        "paid_amount": 0.0,
        "payment_method": "",
        "notes": "",
        "items": extract_purchase_items_from_lines(lines, products),
    }


async def extract_purchase_bill_pdf_locally(file_bytes: bytes, products: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build a purchase-bill draft from a text-based PDF without external services.

    Args:
        file_bytes: raw bytes of the uploaded PDF.
        products: the user's catalogue, used to match extracted line items.

    Returns:
        dict with ``supplier_name``, ``supplier_invoice_number``, ``bill_date``,
        ``paid_amount`` (0.0), ``payment_method`` (""), ``notes`` ("") and ``items``.

    Raises:
        HTTPException(400): the PDF cannot be opened or contains no extractable text.
    """
    import asyncio

    # pypdf parsing and the fuzzy product matching are synchronous and CPU-bound.
    # Running them in a worker thread keeps the event loop free for other requests
    # while a large PDF is processed.
    return await asyncio.to_thread(_extract_purchase_bill_pdf_sync, file_bytes, products)

def create_access_token(data: dict) -> str:
    """Sign *data* as a JWT with an ``exp`` claim ACCESS_TOKEN_EXPIRE_MINUTES from now (UTC).

    The caller's dict is not mutated. An ``exp`` key already present in *data* is overwritten.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {**data, "exp": expires_at}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)

async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
    """Resolve the user that owns the Bearer JWT on the request.

    Returns:
        The user document (without ``_id``) whose ``id`` equals the token's
        ``sub`` claim.

    Raises:
        HTTPException: 401 "Token expired" for an expired token, "Invalid
            token" for any other decode failure or a missing ``sub``, and
            "User not found" when ``sub`` names no stored user.
    """
    # Only the decode step sits in the try. The previous version also wrapped
    # the DB lookup and caught ``jwt.JWTError``, which PyJWT does not define.
    # Evaluating that clause raised AttributeError, turning invalid tokens and
    # "User not found" into 500s.
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        # Base class for all PyJWT decode/validation errors.
        raise HTTPException(status_code=401, detail="Invalid token")

    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(status_code=401, detail="Invalid token")

    user = await db.users.find_one({"id": user_id}, {"_id": 0})
    if user is None:
        raise HTTPException(status_code=401, detail="User not found")
    return user
    
   
# ============= BASIC ROUTES =============
@api_router.get("/")
async def root():
    """Landing endpoint that identifies the API."""
    banner = {"message": "BillFlow API - Billing Software Backend"}
    return banner

# ============= AUTH ROUTES =============
@api_router.post("/auth/register", response_model=TokenResponse)
async def register(user_data: UserRegister):
    """Create an account plus its default company settings and log the user in.

    Returns:
        A bearer token and the public profile of the new user.

    Raises:
        HTTPException: 400 if an account with this email already exists.
    """
    # NOTE(review): the existence check and the insert are separate calls, so
    # two concurrent registrations can both pass; a unique index on
    # users.email would make this airtight.
    if await db.users.find_one({"email": user_data.email}, {"_id": 0}):
        raise HTTPException(status_code=400, detail="Email already registered")

    new_id = str(uuid.uuid4())
    record = dict(
        id=new_id,
        email=user_data.email,
        password=hash_password(user_data.password),
        name=user_data.name,
        company_name=user_data.company_name,
        created_at=datetime.now(timezone.utc).isoformat(),
    )
    await db.users.insert_one(record)

    # Seed the per-user company settings document.
    seed_owner = {
        "id": new_id,
        "name": user_data.name,
        "email": user_data.email,
        "company_name": user_data.company_name,
    }
    await db.company_settings.insert_one(build_default_company_settings(seed_owner))

    access_token = create_access_token({"sub": new_id})
    profile = UserResponse(
        id=new_id,
        email=record["email"],
        name=record["name"],
        company_name=record.get("company_name"),
    )
    return TokenResponse(access_token=access_token, token_type="bearer", user=profile)

@api_router.post("/auth/login", response_model=TokenResponse)
async def login(credentials: UserLogin):
    """Exchange email and password for a bearer token.

    Raises:
        HTTPException: 401 with the same message whether the email is unknown
            or the password is wrong.
    """
    account = await db.users.find_one({"email": credentials.email}, {"_id": 0})
    authenticated = bool(account) and verify_password(credentials.password, account["password"])
    if not authenticated:
        raise HTTPException(status_code=401, detail="Invalid email or password")

    access_token = create_access_token({"sub": account["id"]})
    profile = UserResponse(
        id=account["id"],
        email=account["email"],
        name=account["name"],
        company_name=account.get("company_name"),
    )
    return TokenResponse(access_token=access_token, token_type="bearer", user=profile)

@api_router.get("/auth/me", response_model=UserResponse)
async def get_me(current_user: dict = Depends(get_current_user)):
    """Return the profile of the user who owns the bearer token."""
    profile = UserResponse(**current_user)
    return profile

# ============= CUSTOMER ROUTES =============
@api_router.post("/customers", response_model=Customer)
async def create_customer(customer_data: CustomerCreate, current_user: dict = Depends(get_current_user)):
    """Create a customer owned by the current user and return it."""
    owner_id = current_user["id"]
    submitted = customer_data.model_dump()
    customer = Customer(user_id=owner_id, **submitted)
    await db.customers.insert_one(customer.model_dump())
    return customer

@api_router.get("/customers", response_model=List[Customer])
async def get_customers(current_user: dict = Depends(get_current_user)):
    """List the current user's customers (at most 1000)."""
    cursor = db.customers.find({"user_id": current_user["id"]}, {"_id": 0})
    return await cursor.to_list(1000)

@api_router.get("/customers/{customer_id}", response_model=Customer)
async def get_customer(customer_id: str, current_user: dict = Depends(get_current_user)):
    """Fetch one customer owned by the current user.

    Raises:
        HTTPException: 404 if no such customer belongs to the user.
    """
    owner_filter = {"id": customer_id, "user_id": current_user["id"]}
    found = await db.customers.find_one(owner_filter, {"_id": 0})
    if found:
        return found
    raise HTTPException(status_code=404, detail="Customer not found")

@api_router.put("/customers/{customer_id}", response_model=Customer)
async def update_customer(customer_id: str, customer_data: CustomerCreate, current_user: dict = Depends(get_current_user)):
    """Overwrite a customer owned by the current user and return the stored document.

    Raises:
        HTTPException: 404 if no customer with this id belongs to the user
            (including one deleted between the update and the re-read).
    """
    # One owner-scoped filter for both the write and the re-read, consistent
    # with every other query in this module.
    owner_filter = {"id": customer_id, "user_id": current_user["id"]}
    result = await db.customers.update_one(owner_filter, {"$set": customer_data.model_dump()})
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Customer not found")

    customer = await db.customers.find_one(owner_filter, {"_id": 0})
    if customer is None:
        # Deleted concurrently after the update; avoid a response-model 500.
        raise HTTPException(status_code=404, detail="Customer not found")
    return customer

@api_router.delete("/customers/{customer_id}")
async def delete_customer(customer_id: str, current_user: dict = Depends(get_current_user)):
    """Delete one of the current user's customers.

    Raises:
        HTTPException: 404 if no matching customer belongs to the user.
    """
    outcome = await db.customers.delete_one({"id": customer_id, "user_id": current_user["id"]})
    if not outcome.deleted_count:
        raise HTTPException(status_code=404, detail="Customer not found")
    return {"message": "Customer deleted successfully"}

# ============= PRODUCT ROUTES =============
@api_router.post("/products", response_model=Product)
async def create_product(product_data: ProductCreate, current_user: dict = Depends(get_current_user)):
    """Create a product owned by the current user and return it."""
    owner_id = current_user["id"]
    submitted = product_data.model_dump()
    product = Product(user_id=owner_id, **submitted)
    await db.products.insert_one(product.model_dump())
    return product

@api_router.get("/products", response_model=List[Product])
async def get_products(current_user: dict = Depends(get_current_user)):
    """List the current user's products (at most 1000)."""
    cursor = db.products.find({"user_id": current_user["id"]}, {"_id": 0})
    return await cursor.to_list(1000)

@api_router.get("/products/low-stock", response_model=List[Product])
async def get_low_stock_products(current_user: dict = Depends(get_current_user)):
    """Return the user's products whose stock is at or below their threshold.

    Missing ``stock`` / ``low_stock_threshold`` fields count as 0.
    """
    owned = await db.products.find({"user_id": current_user["id"]}, {"_id": 0}).to_list(1000)
    low_stock: List[Product] = []
    for doc in owned:
        stock_level = int(doc.get("stock", 0))
        threshold = int(doc.get("low_stock_threshold", 0))
        if stock_level <= threshold:
            low_stock.append(Product(**doc))
    return low_stock

@api_router.get("/products/{product_id}", response_model=Product)
async def get_product(product_id: str, current_user: dict = Depends(get_current_user)):
    """Fetch one product owned by the current user.

    Raises:
        HTTPException: 404 if no such product belongs to the user.
    """
    owner_filter = {"id": product_id, "user_id": current_user["id"]}
    found = await db.products.find_one(owner_filter, {"_id": 0})
    if found:
        return found
    raise HTTPException(status_code=404, detail="Product not found")

@api_router.put("/products/{product_id}", response_model=Product)
async def update_product(product_id: str, product_data: ProductCreate, current_user: dict = Depends(get_current_user)):
    """Overwrite a product owned by the current user and return the stored document.

    Raises:
        HTTPException: 404 if no product with this id belongs to the user
            (including one deleted between the update and the re-read).
    """
    # One owner-scoped filter for both the write and the re-read, consistent
    # with every other query in this module.
    owner_filter = {"id": product_id, "user_id": current_user["id"]}
    result = await db.products.update_one(owner_filter, {"$set": product_data.model_dump()})
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Product not found")

    product = await db.products.find_one(owner_filter, {"_id": 0})
    if product is None:
        # Deleted concurrently after the update; avoid a response-model 500.
        raise HTTPException(status_code=404, detail="Product not found")
    return product

@api_router.delete("/products/{product_id}")
async def delete_product(product_id: str, current_user: dict = Depends(get_current_user)):
    """Delete one of the current user's products.

    Raises:
        HTTPException: 404 if no matching product belongs to the user.
    """
    outcome = await db.products.delete_one({"id": product_id, "user_id": current_user["id"]})
    if not outcome.deleted_count:
        raise HTTPException(status_code=404, detail="Product not found")
    return {"message": "Product deleted successfully"}


# ============= PURCHASE BILL ROUTES =============
def _coerce_finite_float(value: Any) -> float:
    """Best-effort conversion of an extracted numeric field to a finite float.

    Missing, empty, non-numeric, NaN and infinite values all become 0.0, so a
    single garbled cell cannot turn the whole extraction into a 500 error.
    """
    import math

    try:
        number = float(value or 0)
    except (TypeError, ValueError):
        return 0.0
    return number if math.isfinite(number) else 0.0


@api_router.post("/purchase-bills/extract-pdf")
async def extract_purchase_bill_pdf(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Extract a draft purchase bill from an uploaded PDF (10 MB max).

    Each extracted line is fuzzy-matched against the user's products. Lines
    with no name or a non-positive quantity are dropped. Names that match no
    product are also listed in ``unmatched_items``.

    Raises:
        HTTPException: 400 for a non-``.pdf`` filename, an empty or oversized
            upload, or a PDF with no extractable items (the local extractor
            can raise further 400s).
    """
    max_pdf_bytes = 10 * 1024 * 1024

    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Upload a PDF purchase bill")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without copying the whole thing into process memory.
    file_bytes = await file.read(max_pdf_bytes + 1)
    if not file_bytes:
        raise HTTPException(status_code=400, detail="Uploaded PDF is empty")
    if len(file_bytes) > max_pdf_bytes:
        raise HTTPException(status_code=400, detail="PDF must be 10 MB or smaller")

    products = await db.products.find({"user_id": current_user["id"]}, {"_id": 0}).to_list(1000)
    extracted = await extract_purchase_bill_pdf_locally(file_bytes, products)

    extracted_items = []
    unmatched_items = []
    for item in extracted.get("items", []) or []:
        if not isinstance(item, dict):
            continue

        extracted_name = str(item.get("product_name") or "").strip()
        quantity = int(_coerce_finite_float(item.get("quantity")))
        purchase_price = round_money(_coerce_finite_float(item.get("purchase_price")))
        # A zero/missing/unparseable line total falls back to quantity x price.
        total = round_money(_coerce_finite_float(item.get("total")) or quantity * purchase_price)
        hsn_code = str(item.get("hsn_code") or "").strip() or None

        if not extracted_name or quantity <= 0:
            continue

        match = find_best_matching_product(extracted_name, hsn_code, products)
        extracted_item = {
            "product_name": extracted_name,
            "hsn_code": hsn_code,
            "quantity": quantity,
            "purchase_price": purchase_price,
            "total": total,
            "matched_product_id": match["product"]["id"] if match else "",
            "matched_product_name": match["product"]["name"] if match else "",
            "matched_unit": match["product"].get("unit", "pcs") if match else "pcs",
            "matched_hsn_code": match["product"].get("hsn_code", "") if match else "",
            "match_score": match["score"] if match else 0.0,
        }
        extracted_items.append(extracted_item)
        if not match:
            unmatched_items.append(extracted_name)

    if not extracted_items:
        raise HTTPException(
            status_code=400,
            detail="Could not find any purchase items in that PDF"
        )

    bill_date = normalize_purchase_bill_date(extracted.get("bill_date"))
    paid_amount = round_money(_coerce_finite_float(extracted.get("paid_amount")))
    payment_method = str(extracted.get("payment_method") or "").strip().lower()
    allowed_payment_methods = {"cash", "upi", "card", "bank_transfer", "cheque"}
    if payment_method not in allowed_payment_methods:
        payment_method = ""

    return {
        "supplier_name": str(extracted.get("supplier_name") or "").strip(),
        "supplier_invoice_number": str(extracted.get("supplier_invoice_number") or "").strip(),
        "bill_date": bill_date,
        "paid_amount": paid_amount,
        "payment_method": payment_method,
        "notes": str(extracted.get("notes") or "").strip(),
        "items": extracted_items,
        "unmatched_items": unmatched_items,
    }


@api_router.post("/purchase-bills", response_model=PurchaseBill)
async def create_purchase_bill(
    purchase_bill_data: PurchaseBillCreate,
    current_user: dict = Depends(get_current_user)
):
    """Record a supplier purchase bill and add its quantities to product stock.

    ``bill_date`` (ISO-8601, typically ``YYYY-MM-DD``) becomes the stored
    ``created_at``. Naive values are stored as UTC, the same way
    ``update_purchase_bill`` stores dates. Without a date the current UTC time
    is used.

    Raises:
        HTTPException: 400 if ``bill_date`` is not a valid ISO date. Errors from
            ``build_purchase_bill_totals`` propagate unchanged.
    """
    user_id = current_user["id"]

    # Validate the date before allocating a bill number so a bad date is a
    # clean 400 rather than a 500 that has already consumed a number.
    if purchase_bill_data.bill_date:
        try:
            bill_dt = datetime.fromisoformat(purchase_bill_data.bill_date)
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="Invalid bill date") from exc
        if bill_dt.tzinfo is None:
            bill_dt = bill_dt.replace(tzinfo=timezone.utc)
        created_at = bill_dt.isoformat()
    else:
        created_at = datetime.now(timezone.utc).isoformat()

    totals = await build_purchase_bill_totals(user_id, purchase_bill_data)
    bill_number = await get_next_purchase_bill_number(user_id)
    purchase_bill = PurchaseBill(
        user_id=user_id,
        bill_number=bill_number,
        supplier_name=purchase_bill_data.supplier_name,
        supplier_invoice_number=purchase_bill_data.supplier_invoice_number,
        items=totals["items"],
        subtotal=totals["subtotal"],
        total_amount=totals["total_amount"],
        paid_amount=totals["paid_amount"],
        balance_amount=totals["balance_amount"],
        payment_status=totals["payment_status"],
        payment_method=totals["payment_method"],
        notes=purchase_bill_data.notes,
        created_at=created_at,
    )
    await db.purchase_bills.insert_one(purchase_bill.model_dump())

    # Purchases increase stock by each line's quantity.
    for item in totals["items"]:
        await db.products.update_one(
            {"id": item.product_id, "user_id": user_id},
            {"$inc": {"stock": item.quantity}}
        )

    return purchase_bill


@api_router.get("/purchase-bills", response_model=List[PurchaseBill])
async def get_purchase_bills(current_user: dict = Depends(get_current_user)):
    """List the user's purchase bills, newest ``created_at`` first (at most 1000)."""
    owned = db.purchase_bills.find({"user_id": current_user["id"]}, {"_id": 0})
    newest_first = owned.sort("created_at", -1)
    return await newest_first.to_list(1000)


@api_router.get("/purchase-bills/{purchase_bill_id}", response_model=PurchaseBill)
async def get_purchase_bill(purchase_bill_id: str, current_user: dict = Depends(get_current_user)):
    """Fetch one purchase bill owned by the current user.

    Raises:
        HTTPException: 404 if no such bill belongs to the user.
    """
    owner_filter = {"id": purchase_bill_id, "user_id": current_user["id"]}
    found = await db.purchase_bills.find_one(owner_filter, {"_id": 0})
    if found:
        return found
    raise HTTPException(status_code=404, detail="Purchase bill not found")


@api_router.put("/purchase-bills/{purchase_bill_id}", response_model=PurchaseBill)
async def update_purchase_bill(
    purchase_bill_id: str,
    purchase_bill_data: PurchaseBillUpdate,
    current_user: dict = Depends(get_current_user)
):
    """Replace a purchase bill's contents and re-sync product stock.

    The old bill's quantities are removed from stock and the new bill's
    quantities are added. If recomputing totals fails, the removal is undone
    before the error propagates, so a rejected edit leaves inventory unchanged.

    ``bill_date`` (ISO-8601, typically ``YYYY-MM-DD``) replaces ``created_at``.
    Naive values are stored as UTC.

    Raises:
        HTTPException: 404 if the bill does not belong to the user; 400 if
            ``bill_date`` is not a valid ISO date. Errors from
            ``build_purchase_bill_totals`` propagate after the stock rollback.
    """
    user_id = current_user["id"]
    bill_filter = {"id": purchase_bill_id, "user_id": user_id}

    existing_bill = await db.purchase_bills.find_one(bill_filter, {"_id": 0})
    if not existing_bill:
        raise HTTPException(status_code=404, detail="Purchase bill not found")

    # Parse the date before touching stock so a malformed value fails cleanly.
    new_created_at = None
    if purchase_bill_data.bill_date:
        try:
            bill_dt = datetime.fromisoformat(purchase_bill_data.bill_date)
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="Invalid bill date") from exc
        if bill_dt.tzinfo is None:
            bill_dt = bill_dt.replace(tzinfo=timezone.utc)
        new_created_at = bill_dt.isoformat()

    # Take the old bill's quantities back out of stock.
    old_items = existing_bill.get("items", [])
    for item in old_items:
        await db.products.update_one(
            {"id": item["product_id"], "user_id": user_id},
            {"$inc": {"stock": -item["quantity"]}}
        )

    try:
        totals = await build_purchase_bill_totals(user_id, purchase_bill_data)
    except Exception:
        # Compensate: put the old quantities back so a rejected edit does not
        # silently shrink inventory.
        for item in old_items:
            await db.products.update_one(
                {"id": item["product_id"], "user_id": user_id},
                {"$inc": {"stock": item["quantity"]}}
            )
        raise

    update_data = {
        "supplier_name": purchase_bill_data.supplier_name,
        "supplier_invoice_number": purchase_bill_data.supplier_invoice_number,
        "items": [item.model_dump() for item in totals["items"]],
        "subtotal": totals["subtotal"],
        "total_amount": totals["total_amount"],
        "paid_amount": totals["paid_amount"],
        "balance_amount": totals["balance_amount"],
        "payment_status": totals["payment_status"],
        "payment_method": totals["payment_method"],
        "notes": purchase_bill_data.notes,
    }
    if new_created_at:
        update_data["created_at"] = new_created_at

    await db.purchase_bills.update_one(bill_filter, {"$set": update_data})

    # Add the new bill's quantities to stock.
    for item in totals["items"]:
        await db.products.update_one(
            {"id": item.product_id, "user_id": user_id},
            {"$inc": {"stock": item.quantity}}
        )

    return await db.purchase_bills.find_one(bill_filter, {"_id": 0})


@api_router.delete("/purchase-bills/{purchase_bill_id}")
async def delete_purchase_bill(purchase_bill_id: str, current_user: dict = Depends(get_current_user)):
    """Delete a purchase bill and take its quantities back out of stock.

    Raises:
        HTTPException: 404 if no such bill belongs to the user.
    """
    owner_id = current_user["id"]
    bill_scope = {"id": purchase_bill_id, "user_id": owner_id}

    bill = await db.purchase_bills.find_one(bill_scope, {"_id": 0})
    if not bill:
        raise HTTPException(status_code=404, detail="Purchase bill not found")

    # Reverse the stock increase applied when the bill was recorded.
    for line in bill.get("items", []):
        product_scope = {"id": line["product_id"], "user_id": owner_id}
        await db.products.update_one(product_scope, {"$inc": {"stock": -line["quantity"]}})

    await db.purchase_bills.delete_one(bill_scope)
    return {"message": "Purchase bill deleted successfully"}

# ============= INVOICE ROUTES =============
@api_router.post("/invoices", response_model=Invoice)
async def create_invoice(invoice_data: InvoiceCreate, current_user: dict = Depends(get_current_user)):
    """Create a tax invoice or proforma invoice for one of the user's customers.

    Non-proforma invoices deduct their line quantities from product stock.
    When ``payment_method`` is supplied they are recorded as fully paid, with
    a matching ``Payment``. Proforma invoices touch neither stock nor payments
    and cannot carry a payment method at creation.

    ``invoice_date`` (ISO-8601, typically ``YYYY-MM-DD``) becomes ``created_at``.
    Naive values are stored as UTC, the same way ``update_invoice`` stores
    dates. Without a date the current UTC time is used.

    Raises:
        HTTPException: 404 if the customer is not the user's; 400 for a
            proforma with a payment method or an invalid ``invoice_date``.
            Errors from ``build_gst_invoice_totals`` propagate unchanged.
    """
    user_id = current_user["id"]
    customer = await db.customers.find_one(
        {"id": invoice_data.customer_id, "user_id": user_id},
        {"_id": 0}
    )
    if not customer:
        raise HTTPException(status_code=404, detail="Customer not found")

    is_proforma = is_proforma_invoice(invoice_data.invoice_type)
    if is_proforma and invoice_data.payment_method:
        raise HTTPException(
            status_code=400,
            detail="Proforma invoices cannot be marked paid during creation"
        )

    if invoice_data.invoice_date:
        try:
            invoice_dt = datetime.fromisoformat(invoice_data.invoice_date)
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="Invalid invoice date") from exc
        if invoice_dt.tzinfo is None:
            invoice_dt = invoice_dt.replace(tzinfo=timezone.utc)
        created_at = invoice_dt.isoformat()
    else:
        created_at = datetime.now(timezone.utc).isoformat()

    # Compute totals (which may reject the request) BEFORE drawing the next
    # document number, so a failed validation does not leave a gap in the
    # invoice / PI sequence. convert_proforma_invoice uses the same order.
    totals = await build_gst_invoice_totals(user_id, customer, invoice_data)
    invoice_number = await get_next_document_number(user_id, invoice_data.invoice_type)
    total_amount = totals["total_amount"]

    if is_proforma:
        paid_amount = 0.0
        balance_amount = total_amount
        payment_status = "proforma"
    elif invoice_data.payment_method:
        paid_amount = total_amount
        balance_amount = 0.0
        payment_status = "paid"
    else:
        paid_amount = 0.0
        balance_amount = total_amount
        payment_status = "pending"

    invoice = Invoice(
        user_id=user_id,
        invoice_number=invoice_number,
        customer_id=invoice_data.customer_id,
        customer_name=customer["name"],
        items=totals["items"],
        invoice_type=invoice_data.invoice_type,
        place_of_supply=totals["place_of_supply"],
        place_state_code=totals["place_state_code"],
        reverse_charge=invoice_data.reverse_charge,
        tax_type=totals["tax_type"],
        subtotal=totals["subtotal"],
        gst_amount=totals["gst_amount"],
        cgst_amount=totals["cgst_amount"],
        sgst_amount=totals["sgst_amount"],
        igst_amount=totals["igst_amount"],
        cess_amount=totals["cess_amount"],
        round_off_amount=totals["round_off_amount"],
        total_amount=total_amount,
        payment_method=None if is_proforma else invoice_data.payment_method,
        notes=invoice_data.notes,
        payment_status=payment_status,
        paid_amount=paid_amount,
        balance_amount=balance_amount,
        created_at=created_at,
    )

    await db.invoices.insert_one(invoice.model_dump())

    if not is_proforma:
        # Deduct exactly the lines stored on the invoice (totals["items"]).
        # delete_invoice and update_invoice restore stock from those stored
        # lines, so using them here keeps the two sides symmetric.
        for item in totals["items"]:
            await db.products.update_one(
                {"id": item.product_id, "user_id": user_id},
                {"$inc": {"stock": -item.quantity}}
            )

        # Paid-at-creation invoices get a payment record for the full amount.
        if invoice_data.payment_method:
            payment = Payment(
                user_id=user_id,
                invoice_id=invoice.id,
                invoice_number=invoice.invoice_number,
                customer_id=invoice.customer_id,
                customer_name=invoice.customer_name,
                amount=total_amount,
                payment_method=invoice_data.payment_method
            )
            await db.payments.insert_one(payment.model_dump())

    return invoice



@api_router.get("/invoices", response_model=List[Invoice])
async def get_invoices(current_user: dict = Depends(get_current_user)):
    """List the user's invoices (proforma included), newest ``created_at`` first (at most 1000)."""
    owned = db.invoices.find({"user_id": current_user["id"]}, {"_id": 0})
    newest_first = owned.sort("created_at", -1)
    return await newest_first.to_list(1000)

@api_router.get("/invoices/{invoice_id}", response_model=Invoice)
async def get_invoice(invoice_id: str, current_user: dict = Depends(get_current_user)):
    """Fetch one invoice owned by the current user.

    Raises:
        HTTPException: 404 if no such invoice belongs to the user.
    """
    owner_filter = {"id": invoice_id, "user_id": current_user["id"]}
    found = await db.invoices.find_one(owner_filter, {"_id": 0})
    if found:
        return found
    raise HTTPException(status_code=404, detail="Invoice not found")

@api_router.delete("/invoices/{invoice_id}")
async def delete_invoice(invoice_id: str, current_user: dict = Depends(get_current_user)):
    """Delete an invoice together with its payments, returning stock where due.

    Raises:
        HTTPException: 404 if no such invoice belongs to the user.
    """
    owner_id = current_user["id"]
    invoice_scope = {"id": invoice_id, "user_id": owner_id}

    invoice = await db.invoices.find_one(invoice_scope, {"_id": 0})
    if not invoice:
        raise HTTPException(status_code=404, detail="Invoice not found")

    # Proformas never deducted stock, so only real invoices give it back.
    if not is_proforma_invoice(invoice.get("invoice_type")):
        for line in invoice["items"]:
            product_scope = {"id": line["product_id"], "user_id": owner_id}
            await db.products.update_one(product_scope, {"$inc": {"stock": line["quantity"]}})

    await db.payments.delete_many({"invoice_id": invoice_id, "user_id": owner_id})
    await db.invoices.delete_one(invoice_scope)

    return {"message": "Invoice and related payments deleted successfully"}


@api_router.put("/invoices/{invoice_id}", response_model=Invoice)
async def update_invoice(
    invoice_id: str,
    invoice_data: InvoiceUpdate,
    current_user: dict = Depends(get_current_user)
):
    """Replace an invoice's lines/type/notes, recompute GST totals and re-sync stock.

    Stock the old version deducted is given back, and the new version's lines
    are deducted (proformas touch no stock). Existing payments are kept: the
    new status is derived from the stored ``paid_amount``. Every check that
    needs no stock change runs first, and a failure while recomputing totals
    rolls back the stock restore, so a rejected edit leaves inventory as it
    was.

    ``invoice_date`` (ISO-8601, typically ``YYYY-MM-DD``) replaces
    ``created_at``. Naive values are stored as UTC.

    Raises:
        HTTPException: 404 if the invoice or its customer is not the user's;
            400 when converting a paid invoice to proforma or for an invalid
            ``invoice_date``. Errors from ``build_gst_invoice_totals``
            propagate after the stock rollback.
    """
    user_id = current_user["id"]
    invoice_filter = {"id": invoice_id, "user_id": user_id}

    invoice = await db.invoices.find_one(invoice_filter, {"_id": 0})
    if not invoice:
        raise HTTPException(status_code=404, detail="Invoice not found")

    was_proforma = is_proforma_invoice(invoice.get("invoice_type"))
    is_proforma = is_proforma_invoice(invoice_data.invoice_type)

    if is_proforma and invoice.get("paid_amount", 0.0) > 0:
        raise HTTPException(
            status_code=400,
            detail="Paid invoices cannot be converted to proforma invoices"
        )

    # ================= VALIDATE BEFORE TOUCHING STOCK =================
    customer = await db.customers.find_one(
        {"id": invoice["customer_id"], "user_id": user_id},
        {"_id": 0}
    )
    if not customer:
        raise HTTPException(status_code=404, detail="Customer not found")

    new_created_at = None
    if invoice_data.invoice_date:
        try:
            invoice_dt = datetime.fromisoformat(invoice_data.invoice_date)
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="Invalid invoice date") from exc
        if invoice_dt.tzinfo is None:
            invoice_dt = invoice_dt.replace(tzinfo=timezone.utc)
        new_created_at = invoice_dt.isoformat()

    # ================= RESTORE OLD STOCK =================
    # Kept ahead of the totals computation in case build_gst_invoice_totals
    # validates against current stock (not visible here).
    restored_items = [] if was_proforma else invoice["items"]
    for old_item in restored_items:
        await db.products.update_one(
            {"id": old_item["product_id"], "user_id": user_id},
            {"$inc": {"stock": old_item["quantity"]}}
        )

    # ================= RECALCULATE TOTALS =================
    try:
        totals = await build_gst_invoice_totals(user_id, customer, invoice_data)
    except Exception:
        # Roll the restore back so a rejected edit leaves inventory unchanged.
        for old_item in restored_items:
            await db.products.update_one(
                {"id": old_item["product_id"], "user_id": user_id},
                {"$inc": {"stock": -old_item["quantity"]}}
            )
        raise
    total_amount = totals["total_amount"]

    if is_proforma:
        paid_amount = 0.0
        balance_amount = total_amount
        payment_status = "proforma"
    else:
        paid_amount = invoice.get("paid_amount", 0.0)
        balance_amount = round_money(total_amount - paid_amount)

        if balance_amount <= 0:
            payment_status = "paid"
            balance_amount = 0.0
        elif paid_amount > 0:
            payment_status = "partial"
        else:
            payment_status = "pending"

    update_data = {
        "items": [i.model_dump() for i in totals["items"]],
        "invoice_type": invoice_data.invoice_type,
        "place_of_supply": totals["place_of_supply"],
        "place_state_code": totals["place_state_code"],
        "reverse_charge": invoice_data.reverse_charge,
        "tax_type": totals["tax_type"],
        "subtotal": totals["subtotal"],
        "gst_amount": totals["gst_amount"],
        "cgst_amount": totals["cgst_amount"],
        "sgst_amount": totals["sgst_amount"],
        "igst_amount": totals["igst_amount"],
        "cess_amount": totals["cess_amount"],
        "round_off_amount": totals["round_off_amount"],
        "total_amount": total_amount,
        "paid_amount": paid_amount,
        "balance_amount": balance_amount,
        "payment_status": payment_status,
        "payment_method": None if is_proforma else invoice.get("payment_method"),
        "notes": invoice_data.notes,
    }
    if new_created_at:
        update_data["created_at"] = new_created_at

    await db.invoices.update_one(invoice_filter, {"$set": update_data})

    if not is_proforma:
        # ================= DEDUCT NEW STOCK =================
        # Deduct the stored lines (totals["items"]), the same lines a later
        # edit or delete will restore from.
        for new_item in totals["items"]:
            await db.products.update_one(
                {"id": new_item.product_id, "user_id": user_id},
                {"$inc": {"stock": -new_item.quantity}}
            )

    return await db.invoices.find_one(invoice_filter, {"_id": 0})

@api_router.post("/invoices/{invoice_id}/convert", response_model=Invoice)
async def convert_proforma_invoice(
    invoice_id: str,
    conversion_data: ProformaConvertRequest,
    current_user: dict = Depends(get_current_user)
):
    """Convert a proforma invoice in place into a real invoice.

    The document gets a fresh number for the target type (``tax_invoice`` when
    the request asks for another proforma). Totals are recomputed and stock
    is deducted. When ``payment_method`` is given, the invoice is marked fully
    paid and a ``Payment`` is recorded.

    The invoice date is ``conversion_data.invoice_date`` or, failing that, the
    first 10 characters of the proforma's stored ``created_at``. Naive values
    are stored as UTC.

    The final write only succeeds if the document still has the invoice type
    read at the start. A concurrent conversion therefore cannot deduct stock
    or record payment twice.

    Raises:
        HTTPException: 404 if the invoice or customer is not the user's; 400 if
            the invoice is not a proforma or the date is invalid; 409 if the
            invoice changed (e.g. was converted) while this request ran.
    """
    user_id = current_user["id"]
    invoice = await db.invoices.find_one(
        {"id": invoice_id, "user_id": user_id},
        {"_id": 0}
    )
    if not invoice:
        raise HTTPException(status_code=404, detail="Invoice not found")
    if not is_proforma_invoice(invoice.get("invoice_type")):
        raise HTTPException(status_code=400, detail="Only proforma invoices can be converted")

    customer = await db.customers.find_one(
        {"id": invoice["customer_id"], "user_id": user_id},
        {"_id": 0}
    )
    if not customer:
        raise HTTPException(status_code=404, detail="Customer not found")

    target_invoice_type = (
        conversion_data.invoice_type
        if not is_proforma_invoice(conversion_data.invoice_type)
        else "tax_invoice"
    )
    # `or ""` avoids turning a missing/None created_at into the string "None".
    invoice_date = conversion_data.invoice_date or str(invoice.get("created_at") or "")[:10]

    # Validate the date before drawing a document number.
    new_created_at = None
    if invoice_date:
        try:
            converted_dt = datetime.fromisoformat(invoice_date)
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="Invalid invoice date") from exc
        if converted_dt.tzinfo is None:
            converted_dt = converted_dt.replace(tzinfo=timezone.utc)
        new_created_at = converted_dt.isoformat()

    invoice_payload = InvoiceCreate(
        customer_id=invoice["customer_id"],
        invoice_date=invoice_date,
        items=[InvoiceItem(**item) for item in invoice.get("items", [])],
        invoice_type=target_invoice_type,
        place_of_supply=invoice.get("place_of_supply"),
        place_state_code=invoice.get("place_state_code"),
        reverse_charge=invoice.get("reverse_charge", False),
        payment_method=conversion_data.payment_method,
        notes=invoice.get("notes"),
    )
    totals = await build_gst_invoice_totals(user_id, customer, invoice_payload)
    total_amount = totals["total_amount"]
    invoice_number = await get_next_document_number(user_id, target_invoice_type)

    if conversion_data.payment_method:
        paid_amount = total_amount
        balance_amount = 0.0
        payment_status = "paid"
    else:
        paid_amount = 0.0
        balance_amount = total_amount
        payment_status = "pending"

    update_data = {
        "invoice_number": invoice_number,
        "items": [i.model_dump() for i in totals["items"]],
        "invoice_type": target_invoice_type,
        "place_of_supply": totals["place_of_supply"],
        "place_state_code": totals["place_state_code"],
        "reverse_charge": invoice_payload.reverse_charge,
        "tax_type": totals["tax_type"],
        "subtotal": totals["subtotal"],
        "gst_amount": totals["gst_amount"],
        "cgst_amount": totals["cgst_amount"],
        "sgst_amount": totals["sgst_amount"],
        "igst_amount": totals["igst_amount"],
        "cess_amount": totals["cess_amount"],
        "round_off_amount": totals["round_off_amount"],
        "total_amount": total_amount,
        "paid_amount": paid_amount,
        "balance_amount": balance_amount,
        "payment_status": payment_status,
        "payment_method": conversion_data.payment_method,
    }
    if new_created_at:
        update_data["created_at"] = new_created_at

    # Conditional on the type we read: if another request already converted
    # (or otherwise retyped) this invoice, nothing matches and we stop before
    # any stock or payment side effects.
    result = await db.invoices.update_one(
        {"id": invoice_id, "user_id": user_id, "invoice_type": invoice.get("invoice_type")},
        {"$set": update_data}
    )
    if result.matched_count == 0:
        raise HTTPException(
            status_code=409,
            detail="Invoice changed while converting; reload and try again"
        )

    for item in totals["items"]:
        await db.products.update_one(
            {"id": item.product_id, "user_id": user_id},
            {"$inc": {"stock": -item.quantity}}
        )

    if conversion_data.payment_method:
        payment = Payment(
            user_id=user_id,
            invoice_id=invoice_id,
            invoice_number=invoice_number,
            customer_id=invoice["customer_id"],
            customer_name=invoice["customer_name"],
            amount=total_amount,
            payment_method=conversion_data.payment_method,
        )
        await db.payments.insert_one(payment.model_dump())

    return await db.invoices.find_one(
        {"id": invoice_id, "user_id": user_id},
        {"_id": 0}
    )


# ============= PAYMENT ROUTES =============
@api_router.post("/payments", response_model=Payment)
async def create_payment(payment_data: PaymentCreate, current_user: dict = Depends(get_current_user)):
    """Record a payment against a non-proforma invoice and refresh its balance.

    The invoice's ``paid_amount``, ``balance_amount`` and ``payment_status`` are
    recomputed from every stored payment for it (up to 1000), not just the new
    one. ``payment_method`` is overwritten with the latest method.

    Raises:
        HTTPException: 404 if the invoice is not the user's; 400 if it is a
            proforma invoice.
    """
    owner_id = current_user["id"]
    target_id = payment_data.invoice_id

    invoice = await db.invoices.find_one({"id": target_id, "user_id": owner_id}, {"_id": 0})
    if not invoice:
        raise HTTPException(status_code=404, detail="Invoice not found")
    if is_proforma_invoice(invoice.get("invoice_type")):
        raise HTTPException(
            status_code=400,
            detail="Payments cannot be added to a proforma invoice"
        )

    payment = Payment(
        user_id=owner_id,
        invoice_id=target_id,
        invoice_number=invoice["invoice_number"],
        customer_id=invoice["customer_id"],
        customer_name=invoice["customer_name"],
        amount=payment_data.amount,
        payment_method=payment_data.payment_method,
        notes=payment_data.notes
    )
    await db.payments.insert_one(payment.model_dump())

    # Rebuild the paid total from the full ledger rather than incrementing.
    ledger = await db.payments.find(
        {"invoice_id": target_id, "user_id": owner_id}, {"_id": 0}
    ).to_list(1000)
    total_paid = round_money(sum(entry["amount"] for entry in ledger))
    outstanding = round_money(invoice["total_amount"] - total_paid)

    if outstanding <= 0:
        status_label = "paid"
        outstanding = 0.0
    else:
        status_label = "partial" if total_paid > 0 else "pending"

    invoice_changes = {
        "payment_status": status_label,
        "paid_amount": total_paid,
        "balance_amount": outstanding,
        "payment_method": payment_data.payment_method,
    }
    await db.invoices.update_one(
        {"id": target_id, "user_id": owner_id},
        {"$set": invoice_changes}
    )

    return payment


@api_router.get("/payments", response_model=List[Payment])
async def get_payments(current_user: dict = Depends(get_current_user)):
    """List the user's payments, newest ``payment_date`` first (at most 1000)."""
    owned = db.payments.find({"user_id": current_user["id"]}, {"_id": 0})
    newest_first = owned.sort("payment_date", -1)
    return await newest_first.to_list(1000)

# ============= SETTINGS ROUTES =============
@api_router.get("/settings/company", response_model=CompanySettings)
async def get_company_settings(current_user: dict = Depends(get_current_user)):
    """Return the user's company settings, creating the default document on first access."""
    stored = await db.company_settings.find_one({"user_id": current_user["id"]}, {"_id": 0})
    if not stored:
        # First access (or a missing document): seed and persist defaults.
        stored = build_default_company_settings(current_user)
        await db.company_settings.insert_one(stored)
    normalized = ensure_company_settings_defaults(stored)
    return CompanySettings(**normalized)

@api_router.put("/settings/company", response_model=CompanySettings)
async def update_company_settings(settings_data: CompanySettingsUpdate, current_user: dict = Depends(get_current_user)):
    """Apply the non-None top-level fields of the request to the user's settings (upserting).

    Raises:
        HTTPException: 400 if every field in the request is None.
    """
    changes = {
        field: value
        for field, value in settings_data.model_dump().items()
        if value is not None
    }
    if not changes:
        raise HTTPException(status_code=400, detail="No data to update")

    owner_filter = {"user_id": current_user["id"]}
    await db.company_settings.update_one(owner_filter, {"$set": changes}, upsert=True)

    refreshed = await db.company_settings.find_one(owner_filter, {"_id": 0})
    return CompanySettings(**ensure_company_settings_defaults(refreshed))

# ============= DASHBOARD STATS =============
@api_router.get("/dashboard/stats")
async def get_dashboard_stats(current_user: dict = Depends(get_current_user)):
    """Aggregate the headline numbers shown on the current user's dashboard.

    Proforma invoices (``PROFORMA_INVOICE_TYPE``) are excluded from every
    invoice-based figure.

    Returns:
        A dict with customer/product/invoice counts, the number of products
        whose stock is at or below their low-stock threshold, revenue totals
        (total, paid, outstanding) over up to 1000 invoices, the five most
        recently created invoices, and revenue for each of the last six
        calendar months (oldest first, current month included).
    """
    user_id = current_user["id"]
    invoice_query = {
        "user_id": user_id,
        "invoice_type": {"$ne": PROFORMA_INVOICE_TYPE},
    }

    # Get counts
    total_customers = await db.customers.count_documents({"user_id": user_id})
    total_products = await db.products.count_documents({"user_id": user_id})
    all_products = await db.products.find({"user_id": user_id}, {"_id": 0}).to_list(1000)
    # `or 0` also covers fields stored as explicit null, which int() rejects.
    low_stock_count = sum(
        1 for product in all_products
        if int(product.get("stock") or 0) <= int(product.get("low_stock_threshold") or 0)
    )
    total_invoices = await db.invoices.count_documents(invoice_query)

    # Calculate revenue (tolerate invoices missing an amount field)
    invoices = await db.invoices.find(invoice_query, {"_id": 0}).to_list(1000)
    total_revenue = sum(inv.get("total_amount") or 0 for inv in invoices)
    paid_revenue = sum(inv.get("paid_amount") or 0 for inv in invoices)
    pending_revenue = sum(inv.get("balance_amount") or 0 for inv in invoices)

    # Recent invoices
    recent_invoices = await db.invoices.find(invoice_query, {"_id": 0}).sort("created_at", -1).limit(5).to_list(5)

    # Monthly revenue for the last six calendar months, oldest first.
    # Month boundaries use calendar arithmetic (year * 12 + month index) at
    # midnight UTC. Subtracting 30-day steps from the 1st drifts across
    # months of unequal length (e.g. skips February when run in March), and
    # carrying the current time-of-day into the bounds dropped invoices
    # created early on the 1st or late on a month's last day.
    now = datetime.now(timezone.utc)
    current_index = now.year * 12 + (now.month - 1)
    monthly_data = []
    for months_back in range(5, -1, -1):
        index = current_index - months_back
        month_start = datetime(index // 12, index % 12 + 1, 1, tzinfo=timezone.utc)
        next_index = index + 1
        next_month_start = datetime(next_index // 12, next_index % 12 + 1, 1, tzinfo=timezone.utc)
        # The current month is capped at "now"; past months end (exclusive)
        # at the first instant of the following month.
        month_end = min(next_month_start, now)

        # created_at is compared as an ISO-8601 string, as before.
        month_invoices = await db.invoices.find({
            **invoice_query,
            "created_at": {"$gte": month_start.isoformat(), "$lt": month_end.isoformat()},
        }, {"_id": 0}).to_list(1000)

        revenue = sum(inv.get("total_amount") or 0 for inv in month_invoices)
        monthly_data.append({
            "month": month_start.strftime("%b %Y"),
            "revenue": revenue,
        })

    return {
        "total_customers": total_customers,
        "total_products": total_products,
        "low_stock_count": low_stock_count,
        "total_invoices": total_invoices,
        "total_revenue": total_revenue,
        "paid_revenue": paid_revenue,
        "pending_revenue": pending_revenue,
        "recent_invoices": recent_invoices,
        "monthly_revenue": monthly_data
    }

@api_router.get("/exports/backup")
async def export_backup(current_user: dict = Depends(get_current_user)):
    """Export all of the current user's business data as a single JSON document.

    Each collection is capped at 5000 documents with ``_id`` projected out.
    If the user has no stored company settings, the defaults from
    ``build_default_company_settings`` are exported instead.
    """
    owner_filter = {"user_id": current_user["id"]}
    exported_collections = {}
    for collection_name in ("customers", "products", "invoices", "payments", "purchase_bills"):
        cursor = db[collection_name].find(owner_filter, {"_id": 0})
        exported_collections[collection_name] = await cursor.to_list(5000)

    stored_settings = await db.company_settings.find_one(owner_filter, {"_id": 0})

    payload = {
        "exported_at": datetime.now(timezone.utc).isoformat(),
        **exported_collections,
        "company_settings": ensure_company_settings_defaults(
            stored_settings or build_default_company_settings(current_user)
        ),
    }
    return JSONResponse(payload)

@api_router.get("/exports/{resource}")
async def export_resource(resource: str, format: str = "json", current_user: dict = Depends(get_current_user)):
    """Export one of the current user's collections as JSON or CSV.

    Args:
        resource: One of ``customers``, ``products``, ``invoices``,
            ``payments`` or ``purchase_bills``.
        format: ``"csv"`` for a CSV attachment; any other value yields JSON.
        current_user: Authenticated user resolved by ``get_current_user``.

    Raises:
        HTTPException: 404 for an unsupported ``resource``.
    """
    exportable = {
        "customers": db.customers,
        "products": db.products,
        "invoices": db.invoices,
        "payments": db.payments,
        "purchase_bills": db.purchase_bills,
    }
    collection = exportable.get(resource)
    # Compare against None: collection objects should not be truth-tested.
    if collection is None:
        raise HTTPException(status_code=404, detail="Unsupported export resource")

    rows = await collection.find({"user_id": current_user["id"]}, {"_id": 0}).to_list(5000)
    if format != "csv":
        return JSONResponse(rows)

    return Response(
        content=collection_to_csv(rows, resource),
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename={resource}.csv"},
    )

# ============= GST REPORTS =============
@api_router.get("/gst/reports")
async def get_gst_reports(
    from_date: Optional[str] = None,
    to_date: Optional[str] = None,
    current_user: dict = Depends(get_current_user)
):
    """Build GST summaries over the current user's non-proforma invoices.

    Args:
        from_date: Optional inclusive lower bound on ``created_at``, parsed by
            ``parse_report_date``.
        to_date: Optional inclusive upper bound on ``created_at``, parsed by
            ``parse_report_date(..., end_of_day=True)``.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        A dict containing the requested period, overall tax totals, a
        GSTR-3B outward-supplies summary, per-GST-rate and per-HSN-code
        breakdowns of line items, invoice rows split into B2B and B2C, and
        the raw invoices (newest first, at most 5000) as the sales register.
    """
    user_id = current_user["id"]
    query: Dict[str, Any] = {
        "user_id": user_id,
        "invoice_type": {"$ne": PROFORMA_INVOICE_TYPE},
    }
    date_query: Dict[str, str] = {}
    start = parse_report_date(from_date)
    end = parse_report_date(to_date, end_of_day=True)

    if start:
        date_query["$gte"] = start
    if end:
        date_query["$lte"] = end
    if date_query:
        query["created_at"] = date_query

    invoices = await db.invoices.find(query, {"_id": 0}).sort("created_at", -1).to_list(5000)

    # Load every referenced customer with one query instead of one
    # find_one per invoice (the previous N+1 pattern cost up to 5000
    # round-trips per report).
    customer_ids = list({invoice.get("customer_id") for invoice in invoices})
    customers_by_id: Dict[Any, dict] = {}
    if customer_ids:
        async for customer in db.customers.find(
            {"id": {"$in": customer_ids}, "user_id": user_id},
            {"_id": 0},
        ):
            # Keep the first match per id, like find_one did.
            customers_by_id.setdefault(customer.get("id"), customer)

    totals = {
        "invoice_count": len(invoices),
        "taxable_value": 0.0,
        "cgst_amount": 0.0,
        "sgst_amount": 0.0,
        "igst_amount": 0.0,
        "cess_amount": 0.0,
        "gst_amount": 0.0,
        "total_amount": 0.0,
    }
    rate_wise = defaultdict(lambda: {
        "gst_rate": 0.0,
        "taxable_value": 0.0,
        "cgst_amount": 0.0,
        "sgst_amount": 0.0,
        "igst_amount": 0.0,
        "cess_amount": 0.0,
        "total_amount": 0.0,
    })
    hsn_summary = defaultdict(lambda: {
        "hsn_code": "",
        "description": "",
        "quantity": 0.0,
        "taxable_value": 0.0,
        "cgst_amount": 0.0,
        "sgst_amount": 0.0,
        "igst_amount": 0.0,
        "cess_amount": 0.0,
        "total_amount": 0.0,
    })
    b2b = []
    b2c = []

    for invoice in invoices:
        totals["taxable_value"] += invoice.get("subtotal", 0)
        totals["cgst_amount"] += invoice.get("cgst_amount", 0)
        totals["sgst_amount"] += invoice.get("sgst_amount", 0)
        totals["igst_amount"] += invoice.get("igst_amount", 0)
        totals["cess_amount"] += invoice.get("cess_amount", 0)
        totals["gst_amount"] += invoice.get("gst_amount", 0)
        totals["total_amount"] += invoice.get("total_amount", 0)

        customer = customers_by_id.get(invoice.get("customer_id")) or {}
        # Explicit customer_type wins; otherwise a GST number marks the buyer as B2B.
        customer_type = customer.get("customer_type") or ("B2B" if customer.get("gst_number") else "B2C")
        invoice_row = {
            "invoice_number": invoice.get("invoice_number"),
            "date": invoice.get("created_at"),
            "customer_name": invoice.get("customer_name"),
            "gst_number": customer.get("gst_number"),
            "place_of_supply": invoice.get("place_of_supply"),
            "taxable_value": round_money(invoice.get("subtotal", 0)),
            "cgst_amount": round_money(invoice.get("cgst_amount", 0)),
            "sgst_amount": round_money(invoice.get("sgst_amount", 0)),
            "igst_amount": round_money(invoice.get("igst_amount", 0)),
            "cess_amount": round_money(invoice.get("cess_amount", 0)),
            "total_amount": round_money(invoice.get("total_amount", 0)),
            "payment_status": invoice.get("payment_status"),
        }
        if customer_type.upper() == "B2B" or customer.get("gst_number"):
            b2b.append(invoice_row)
        else:
            b2c.append(invoice_row)

        for item in invoice.get("items", []):
            # Older line items may lack taxable_value/total_amount; fall back to "total".
            taxable_value = item.get("taxable_value", item.get("total", 0))
            line_total = item.get("total_amount", item.get("total", 0))

            gst_rate = float(item.get("gst_rate", 0))
            rate_key = f"{gst_rate:.2f}"
            rate_wise[rate_key]["gst_rate"] = gst_rate
            rate_wise[rate_key]["taxable_value"] += taxable_value
            rate_wise[rate_key]["cgst_amount"] += item.get("cgst_amount", 0)
            rate_wise[rate_key]["sgst_amount"] += item.get("sgst_amount", 0)
            rate_wise[rate_key]["igst_amount"] += item.get("igst_amount", 0)
            rate_wise[rate_key]["cess_amount"] += item.get("cess_amount", 0)
            rate_wise[rate_key]["total_amount"] += line_total

            hsn_code = item.get("hsn_code") or "NA"
            hsn_summary[hsn_code]["hsn_code"] = hsn_code
            # NOTE(review): description is overwritten by the last item seen for this HSN code.
            hsn_summary[hsn_code]["description"] = item.get("product_name", "")
            hsn_summary[hsn_code]["quantity"] += item.get("quantity", 0)
            hsn_summary[hsn_code]["taxable_value"] += taxable_value
            hsn_summary[hsn_code]["cgst_amount"] += item.get("cgst_amount", 0)
            hsn_summary[hsn_code]["sgst_amount"] += item.get("sgst_amount", 0)
            hsn_summary[hsn_code]["igst_amount"] += item.get("igst_amount", 0)
            hsn_summary[hsn_code]["cess_amount"] += item.get("cess_amount", 0)
            hsn_summary[hsn_code]["total_amount"] += line_total

    for key in totals:
        if key != "invoice_count":
            totals[key] = round_money(totals[key])

    def rounded_rows(rows):
        """Return copies of *rows* with every float value passed through round_money."""
        output = []
        for row in rows:
            rounded = {}
            for key, value in row.items():
                rounded[key] = round_money(value) if isinstance(value, float) else value
            output.append(rounded)
        return output

    return {
        "period": {"from_date": from_date, "to_date": to_date},
        "summary": totals,
        "gstr_3b_outward": {
            "taxable_value": totals["taxable_value"],
            "igst_amount": totals["igst_amount"],
            "cgst_amount": totals["cgst_amount"],
            "sgst_amount": totals["sgst_amount"],
            "cess_amount": totals["cess_amount"],
        },
        "rate_wise": rounded_rows(rate_wise.values()),
        "hsn_summary": rounded_rows(hsn_summary.values()),
        "b2b": b2b,
        "b2c": b2c,
        "sales_register": invoices,
    }

# Include router (after every route above has been registered on it)
app.include_router(api_router)

# CORS origins come from the comma-separated CORS_ORIGINS environment
# variable; when it is unset or empty, the previous permissive default ("*")
# is kept so existing deployments behave the same.
# NOTE(review): with allow_credentials=True and "*", Starlette's
# CORSMiddleware may echo back any request Origin; pin real origins in
# production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Root logging configuration applied once at import time; every record is
# prefixed with its timestamp, logger name and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

@app.on_event("shutdown")
async def shutdown_db_client():
    """Release the module-wide Motor client when the application stops."""
    mongo_client = client
    mongo_client.close()

if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    # Hand uvicorn the already-built app object instead of the "server:app"
    # import string. The string only resolves when this file is importable
    # as `server`, and it makes uvicorn import the module a second time,
    # building a second FastAPI app and a second Mongo client.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
