import csv
import io
import json
import math
import os
from typing import List, Optional, Tuple, Dict

import pandas as pd
import psycopg2


# Movement scoring functions
def clamp_0_10(v):
    """Limit *v* to the closed interval [0.0, 10.0].

    Values below 0 become 0.0, values at or above 10 become 10.0, and values
    strictly between are returned unchanged. A NaN input yields 10.0 because
    every comparison with NaN is false.
    """
    capped = v if v < 10.0 else 10.0
    return capped if capped > 0.0 else 0.0


def interpolate_piecewise_linear(x, knots_x, knots_y):
    """Evaluate the piecewise-linear curve through (knots_x[i], knots_y[i]) at *x*.

    Inputs outside [knots_x[0], knots_x[-1]] are clamped to the end values.
    A binary search finds the bracketing segment, so knots_x is expected to
    be sorted ascending (assumed, not checked).
    """
    if x <= knots_x[0]:
        return knots_y[0]
    if x >= knots_x[-1]:
        return knots_y[-1]

    # Narrow [left, right] until they are adjacent knots around x.
    left, right = 0, len(knots_x) - 1
    while right - left > 1:
        probe = (left + right) // 2
        if knots_x[probe] <= x:
            left = probe
        else:
            right = probe

    x_left, x_right = knots_x[left], knots_x[right]
    y_left, y_right = knots_y[left], knots_y[right]
    fraction = (x - x_left) / (x_right - x_left)
    return y_left + fraction * (y_right - y_left)


def load_movement_model(path="trained_mapper.json"):
    """Return the parsed JSON movement model stored at *path*.

    Best effort: if the file is missing or cannot be read/parsed, a warning is
    printed and None is returned (callers then score movement as 0).
    """
    try:
        with open(path, "r", encoding="utf-8") as handle:
            model = json.load(handle)
    except FileNotFoundError:
        print(f"Warning: Movement model file '{path}' not found. Movement scores will be set to 0.")
        model = None
    except Exception as exc:
        print(f"Warning: Error loading movement model: {exc}. Movement scores will be set to 0.")
        model = None
    return model


def score_movement(x, model):
    """Map a raw movement value *x* to a 0-10 score using *model*.

    *model* is the dict from load_movement_model; its "knots_x"/"knots_y"
    lists define the piecewise-linear mapping. Without a model the score is 0.0.
    """
    if model is not None:
        raw_score = interpolate_piecewise_linear(x, model["knots_x"], model["knots_y"])
        return clamp_0_10(raw_score)
    return 0.0


# Database connection parameters.
# Each value can be overridden through an environment variable; the literal
# fallbacks preserve the previous hard-coded configuration.
# SECURITY NOTE: the fallback password is committed to source control. Prefer
# setting KINOSAVE_DB_PASSWORD and removing the literal default.
DB_CONFIG = {
    "dbname": os.environ.get("KINOSAVE_DB_NAME", "kinosave"),
    "user": os.environ.get("KINOSAVE_DB_USER", "efsi"),
    "password": os.environ.get("KINOSAVE_DB_PASSWORD", "efsi1"),
    "host": os.environ.get("KINOSAVE_DB_HOST", "kinometric.com"),
    "port": os.environ.get("KINOSAVE_DB_PORT", "5432"),
}

# Patient IDs to exclude from processing
EXCLUDED_PATIENT_IDS = {55, 56, 57, 1, 2, 30}

# 0-based column indices into the data table that follows the 'W' header line.
# Index 13 is spreadsheet column N (the accMag column per the original author);
# the constant keeps its historical name COL_O for compatibility.
COL_O = 13
COL_F = 5  # Q-diff; 0-based index 5 is spreadsheet column F
WINDOW = 50  # number of consecutive samples in the least-movement window
START_ROW = 50  # 1-based data row (counted after the header line) where reading starts


def get_patient_info(patient_id: int) -> Optional[Dict]:
    """Look up a patient's display name in the ``patients`` table.

    Opens a fresh connection using ``DB_CONFIG`` for every call.

    Args:
        patient_id: Value matched against ``patients.patient_id``.

    Returns:
        ``{"name": "<first_name> <last_name>"}`` for the fetched row, or None
        when no row matches or any database error occurs. Errors are printed,
        not raised.
    """
    try:
        conn = psycopg2.connect(**DB_CONFIG)
        try:
            # The cursor context manager closes the cursor even if the query
            # fails; the enclosing finally always releases the connection.
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT first_name, last_name
                    FROM patients
                    WHERE patient_id = %s;
                """, (patient_id,))
                row = cur.fetchone()
        finally:
            conn.close()
    except Exception as e:
        print(f"Database error for patient_id {patient_id}: {e}")
        return None

    if row is None:
        return None
    return {"name": f"{row[0]} {row[1]}"}


def parse_float(val) -> Optional[float]:
    """Coerce *val* to a float, returning None when that is not possible.

    Strings have thousands-separator commas removed and surrounding whitespace
    stripped first. None, unconvertible values and NaN all yield None.
    """
    if val is None:
        return None
    candidate = val.replace(",", "").strip() if isinstance(val, str) else val
    try:
        number = float(candidate)
    except Exception:
        return None
    return None if math.isnan(number) else number


def read_column_values_csv(path: str, col_index_0based: int, start_row: int) -> List[float]:
    """Read the valid numeric values of one column of a CSV export.

    The file may start with arbitrary preamble lines. The table begins at the
    first line whose stripped text starts with ``"W,"`` or ``"W "``; that line
    is used as the header row.

    Args:
        path: CSV file to read.
        col_index_0based: 0-based position of the wanted column in the table.
        start_row: 1-based data row (counted after the header) at which to
            start reading. Values <= 1 start at the first data row.

    Returns:
        The cells converted with ``parse_float``, in file order, with
        unparseable/NaN cells dropped. This is best effort: problems are
        reported with a printed message, and an empty list is returned instead
        of raising.
    """
    values: List[float] = []
    try:
        with open(path, 'r') as file:
            lines = file.readlines()

        header_line_idx = next(
            (i for i, line in enumerate(lines) if line.strip().startswith(('W,', 'W '))),
            None,
        )
        if header_line_idx is None:
            print(f"  Warning: Could not find header line starting with 'W' in {path}")
            return values

        # Parse the text already decoded above rather than re-opening the file.
        # This avoids a second read and guarantees pandas sees exactly the
        # lines that were scanned. Re-reading by path would decode with pandas'
        # own default encoding, which can differ from open()'s locale default.
        df = pd.read_csv(io.StringIO("".join(lines[header_line_idx:])))

        # Clean column names (remove leading/trailing spaces)
        df.columns = df.columns.str.strip()

        if col_index_0based >= len(df.columns):
            print(f"  Warning: Column index {col_index_0based} exceeds available columns ({len(df.columns)})")
            print(f"  Available columns: {list(df.columns)}")
            return values

        start_idx = max(start_row - 1, 0)  # 1-based row -> 0-based position
        if start_idx >= len(df):
            print(f"  Warning: Start row {start_row} exceeds data length ({len(df)})")
            return values

        column_data = df.iloc[start_idx:, col_index_0based]
        parsed = (parse_float(val) for val in column_data)
        values.extend(num for num in parsed if num is not None)

    except Exception as e:
        print(f"  Error reading CSV {path}: {e}")

    return values


def best_50_window_min_range(values: List[float], *,
                             window: Optional[int] = None) -> Optional[Tuple[float, int, int]]:
    """Find the contiguous window with the smallest max-min spread (least movement).

    Args:
        values: Samples to scan.
        window: Window length in samples (keyword-only). Defaults to the
            module-level ``WINDOW`` constant.

    Returns:
        ``(min_range, start_idx, end_idx)``. The indices are 0-based and
        inclusive, and identify the earliest window achieving the minimum
        range. Returns None when fewer than ``window`` values are available.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    size = WINDOW if window is None else window
    if size < 1:
        raise ValueError(f"window must be >= 1, got {size}")
    if len(values) < size:
        return None

    best_range: Optional[float] = None
    best_start = 0
    for start in range(len(values) - size + 1):
        chunk = values[start:start + size]
        spread = max(chunk) - min(chunk)
        # Strict '<' keeps the earliest window when several tie.
        if best_range is None or spread < best_range:
            best_range = spread
            best_start = start
    return (best_range, best_start, best_start + size - 1)


def extract_prefix_number(filename: str) -> Optional[int]:
    """Return the integer formed by the leading digits of the file's base name.

    The directory part of *filename* is ignored. Returns None when the base
    name does not start with a digit, or when the leading characters are
    ``str.isdigit()`` characters that ``int()`` cannot parse (e.g. superscripts).
    """
    base = os.path.basename(filename)
    end = 0
    while end < len(base) and base[end].isdigit():
        end += 1
    digits = base[:end]
    if not digits:
        return None
    try:
        return int(digits)
    except ValueError:
        return None


def extract_test_condition(filename: str) -> str:
    """Derive the test-condition label from codes embedded in *filename*.

    Matching is case-insensitive. The code may use ``-`` or ``_`` as its
    separator (e.g. ``lb-eo`` or ``LB_EO``). If several codes are present,
    the first one in the order left-open, right-open, left-closed,
    right-closed wins.

    Returns:
        ``"left-open"``, ``"right-open"``, ``"left-closed"``,
        ``"right-closed"``, or ``"unknown"`` when no code is found.
    """
    lowered = filename.lower()
    # (first code part, second code part, label), checked in priority order.
    codes = (
        ("lb", "eo", "left-open"),
        ("rb", "eo", "right-open"),
        ("lb", "ec", "left-closed"),
        ("rb", "ec", "right-closed"),
    )
    for side, eyes, label in codes:
        if f"{side}-{eyes}" in lowered or f"{side}_{eyes}" in lowered:
            return label
    return "unknown"


# Result tuples built by main() have the layout:
# (name, best50_o, movement_f, movement_score, test_condition, csv_file,
#  o_start, o_end, f_start, f_end)


def _last_name_key(result: tuple) -> str:
    """Sort key: the last whitespace-separated word of the result's name ('' if blank)."""
    words = result[0].split()
    return words[-1] if words else ""


def _print_results_table(heading: str, rows: List[tuple]) -> None:
    """Print one console results table: heading, column titles, rule, one line per result."""
    print(heading, flush=True)
    print(f"{'Name':<25} {'Best50Δ_O':>12} {'Movement':>12} {'MoveScore':>9} {'TestCondition':<12}   File",
          flush=True)
    print("-" * 110, flush=True)
    for name, best50_o, movement_f, movement_score, test_condition, cfile, *_ in rows:
        print(
            f"{name:<25} {best50_o:12.6f} {movement_f:12.6f} {movement_score:9.2f} {test_condition:<12}   {cfile}",
            flush=True)


def _write_results_csv(path: str, rows: List[tuple]) -> None:
    """Write result tuples to *path* as a UTF-8 CSV with a header row, overwriting the file."""
    with open(path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([
            "Name",
            "Best50Δ_O", "O_StartIndex", "O_EndIndex",
            "Movement", "F_StartIndex", "F_EndIndex",
            "MovementScore",
            "TestCondition",
            "File"
        ])
        for name, best50_o, movement_f, movement_score, test_condition, cfile, s_o, e_o, s_f, e_f in rows:
            writer.writerow([
                name,
                f"{best50_o:.6f}", s_o, e_o,
                f"{movement_f:.6f}", s_f, e_f,
                f"{movement_score:.2f}",
                test_condition,
                cfile
            ])


def main():
    """Score every number-prefixed CSV in the current directory and report the results.

    For each file, the leading number is the patient ID. Excluded IDs are
    skipped, and the patient name is looked up in the database. The function
    then finds the least-movement window in columns COL_O and COL_F, and
    scores the column-F range with the trained movement model.

    Results are printed twice (sorted by last name and by score) and written
    to results_by_name.csv and results_by_score.csv in the current directory.
    A one-line summary of counts follows.
    """
    # Load movement scoring model
    movement_model = load_movement_model("trained_mapper.json")
    if movement_model:
        print(f"Loaded movement model with {len(movement_model['knots_x'])} knots", flush=True)
    else:
        print("Movement scoring disabled - model not available", flush=True)

    # Only get files that start with a number
    all_files = [f for f in os.listdir(".") if os.path.isfile(f) and f.lower().endswith(".csv")]
    files = [f for f in all_files if extract_prefix_number(f) is not None]

    print(f"Found {len(all_files)} total .csv file(s) in current directory.", flush=True)
    print(f"Found {len(files)} .csv file(s) that start with a number.", flush=True)
    print(f"Excluded patient IDs: {sorted(EXCLUDED_PATIENT_IDS)}", flush=True)

    if not files:
        print("No CSV files starting with numbers to process.", flush=True)
        if all_files:
            print(f"Skipped files (don't start with numbers): {', '.join(all_files)}", flush=True)
        return

    processed = 0
    skipped_unknown = 0
    skipped_not_enough = 0
    errored = 0

    # Files for different test conditions can share a patient ID. Cache each
    # lookup (including misses) so the database is queried once per patient
    # per run, rather than once per file.
    patient_cache: Dict[int, Optional[Dict]] = {}

    results = []
    for csv_file in sorted(files):
        print(f"Processing: {csv_file}", flush=True)

        patient_id = extract_prefix_number(csv_file)
        # 'not' (rather than 'is None') deliberately also rejects an ID of 0.
        if not patient_id:
            skipped_unknown += 1
            print(f"  → Skipped (no valid patient ID prefix): {csv_file}", flush=True)
            continue

        if patient_id in EXCLUDED_PATIENT_IDS:
            skipped_unknown += 1
            print(f"  → Skipped (excluded patient ID {patient_id}): {csv_file}", flush=True)
            continue

        if patient_id not in patient_cache:
            patient_cache[patient_id] = get_patient_info(patient_id)
        patient_info = patient_cache[patient_id]
        if not patient_info:
            skipped_unknown += 1
            print(f"  → Skipped (patient ID {patient_id} not found in database): {csv_file}", flush=True)
            continue

        name = patient_info["name"]
        test_condition = extract_test_condition(csv_file)

        # read_column_values_csv reports its own errors and returns [], so this
        # handler is only a defensive backstop.
        try:
            values_o = read_column_values_csv(csv_file, COL_O, START_ROW)
            values_f = read_column_values_csv(csv_file, COL_F, START_ROW)
        except Exception as e:
            errored += 1
            print(f"  → Error reading CSV: {e}", flush=True)
            continue

        best_o = best_50_window_min_range(values_o)
        best_f = best_50_window_min_range(values_f)

        if best_o is None:
            skipped_not_enough += 1
            print(f"  → Skipped (not enough numeric data in column O; need ≥ {WINDOW})", flush=True)
            continue
        if best_f is None:
            skipped_not_enough += 1
            print(f"  → Skipped (not enough numeric data in column F; need ≥ {WINDOW})", flush=True)
            continue

        min_range_o, s_o, e_o = best_o
        movement_f, s_f, e_f = best_f

        # The column-F window range is what the trained model scores.
        movement_score = score_movement(movement_f, movement_model)

        results.append((name, min_range_o, movement_f, movement_score, test_condition, csv_file, s_o, e_o, s_f, e_f))
        processed += 1
        print(f"  → Success: {name} ({test_condition}) Movement: {movement_f:.6f}, Score: {movement_score:.2f}",
              flush=True)

    if results:
        results_by_name = sorted(results, key=_last_name_key)
        # Highest score first; sorted() is stable, so ties keep processing order.
        results_by_score = sorted(results, key=lambda x: -x[3])

        _print_results_table("\n=== RESULTS VERSION 1: Sorted by Last Name ===", results_by_name)
        _print_results_table("\n=== RESULTS VERSION 2: Sorted by MovementScore (Highest First) ===",
                             results_by_score)

        _write_results_csv("results_by_name.csv", results_by_name)
        print("\n✅ Results by name written to results_by_name.csv", flush=True)

        _write_results_csv("results_by_score.csv", results_by_score)
        print("✅ Results by score written to results_by_score.csv", flush=True)
    else:
        print("\nNo results to display (nothing met the criteria).", flush=True)

    # Summary
    print(f"\nSummary: processed={processed}, skipped_unknown={skipped_unknown}, "
          f"skipped_not_enough={skipped_not_enough}, errors={errored}", flush=True)


def test_database_connection():
    """Print the looked-up name (or "Not found") for a few sample patient IDs."""
    print("Testing database connection...", flush=True)
    sample_ids = (30, 150, 55)  # Use your test IDs

    for pid in sample_ids:
        info = get_patient_info(pid)
        label = info['name'] if info else "Not found"
        print(f"Patient {pid}: {label}")


if __name__ == "__main__":
    # To verify database connectivity and patient lookups before processing,
    # call test_database_connection() here before main().
    main()