from itertools import groupby
from operator import itemgetter
import shutil
from pathlib import Path
import sys
import json
import os

import pandas as pd
import logging

import numpy as np


# Module-level state shared between the functions below.
# use_logging: when True, main() configures the logging module to write to 'output.log'.
use_logging = False
# single_file_passed: set to True when a specific file is named on the command
# line; log_msg() sends messages to logging.info instead of print when it is set.
single_file_passed = False
# Logistic stability scores written by print_stability_breakdown() and read by
# analyze_file(). They stay None until a breakdown has been computed.
log_score_total = None
log_score_first = None
log_score_second = None

def log_msg(*args):
    """Emit one message made of the space-joined string form of every argument.

    The message is sent to ``logging.info`` when the module-level
    ``single_file_passed`` flag is set, and to ``print`` otherwise.
    """
    # NOTE(review): routing depends on single_file_passed, while logging is only
    # configured in main() when use_logging is True — confirm this is intended.
    text = ' '.join(map(str, args))
    emit = logging.info if single_file_passed else print
    emit(text)


# ----------------------------------------
# Read and Clean CSV File
# ----------------------------------------
def read_csv(file_path):
    """Load a CSV file into an all-numeric DataFrame.

    The first cell of the file's first line decides how many leading rows are
    skipped: 4 when it contains "fall" (case-insensitive), otherwise 2. Every
    column is read as text and then converted with ``pd.to_numeric`` using
    ``errors='coerce'``, so cells that are not numeric become NaN. Rows that
    are entirely NaN after conversion are dropped.

    Parameters:
        file_path (str): Path of the CSV file to read.

    Returns:
        DataFrame on success. On failure (missing file, or any exception
        raised while reading) a dict ``{"Success": False, "Error": <message>}``.
    """
    if not os.path.exists(file_path):
        return {"Success": False, "Error": f"The file '{file_path}' does not exist."}

    try:
        # Peek at the first line only; it selects the header offset.
        with open(file_path, 'r', encoding='utf-8') as handle:
            header_line = handle.readline()

        leading_cell = header_line.strip().split(',')[0]
        rows_to_skip = 4 if "fall" in leading_cell.lower() else 2

        # Malformed lines are skipped rather than aborting the whole load.
        frame = pd.read_csv(file_path, skiprows=rows_to_skip, dtype=str, on_bad_lines="skip")
        frame.columns = [name.strip() for name in frame.columns]

        log_msg("📋 Headers from CSV:", frame.columns.tolist())

        frame.columns = frame.columns.str.strip()

        # Coerce every column to numbers; unparseable cells turn into NaN.
        frame = frame.apply(pd.to_numeric, errors='coerce')

        # Rows with no numeric content at all (e.g. trailing comments) are removed.
        frame.dropna(how='all', inplace=True)
        return frame

    except Exception as e:
        return {"Success": False, "Error": str(e)}


# ----------------------------------------
# 2. Find the Most Stable Period
# ----------------------------------------
def find_stable_period(data, column_index=5, start_range=40, end_range=None, window_size=10):
    """
    Locates the span of low-variation rows inside a search range.

    A rolling standard deviation is computed over the chosen column. Rows whose
    rolling value is at or below the 40th percentile of all rolling values are
    treated as stable. The result is the first and last such row label that
    falls inside [start_range, end_range]; rows in between are not required to
    be stable themselves.

    Parameters:
        data (DataFrame): Cleaned CSV data.
        column_index (int): Positional index of the column to analyze (default: 5).
        start_range (int): Lowest row label considered.
        end_range (int or None): Highest row label considered. None means the last row position.
        window_size (int): Number of consecutive points in the rolling window.

    Returns:
        tuple: (stable_start, stable_end), or (None, None) when no stable row
        lies inside the range.
    """
    column = data.iloc[:, column_index]
    upper_bound = len(column) - 1 if end_range is None else end_range

    # Rolling spread; the first window_size - 1 entries are NaN.
    spread = column.rolling(window=window_size).std()

    # Anything at or below the 40th percentile of the spread counts as stable.
    cutoff = spread.quantile(0.4)
    calm_labels = spread.loc[spread <= cutoff].index

    inside = calm_labels[(calm_labels >= start_range) & (calm_labels <= upper_bound)]
    if inside.empty:
        return None, None

    return inside.min(), inside.max()



def find_longest_stable_after_termination(data, column_index=13, start_after=0, window_size=5, std_threshold=0.01, min_length=10):
    """
    Returns the longest run of consecutive low-variation rows at or after a given row.

    A row counts as stable when the rolling standard deviation ending at that
    row is strictly below std_threshold. Consecutive stable row labels form a
    run; runs shorter than min_length are ignored.

    Parameters:
        data (DataFrame): Full dataset.
        column_index (int): Positional index of the column to analyze.
        start_after (int): Row position where scanning starts.
        window_size (int): Rolling std window.
        std_threshold (float): Upper bound (exclusive) on the rolling std of a stable row.
        min_length (int): Minimum number of rows in a qualifying run.

    Returns:
        (start_index, end_index): First and last label of the longest run (the
        earliest one on ties), or (None, None) when no run qualifies.
    """
    tail = data.iloc[start_after:, column_index]
    spread = tail.rolling(window=window_size).std()
    calm_labels = spread[spread < std_threshold].index.tolist()

    # Labels belong to the same run while (label - position) stays constant,
    # i.e. while each label is exactly one more than the previous one.
    runs = []
    current_run = []
    current_offset = None
    for position, label in enumerate(calm_labels):
        offset = label - position
        if current_run and offset != current_offset:
            runs.append(current_run)
            current_run = []
        current_run.append(label)
        current_offset = offset
    if current_run:
        runs.append(current_run)

    candidates = [run for run in runs if len(run) >= min_length]
    if not candidates:
        return None, None

    best = max(candidates, key=len)
    return best[0], best[-1]


#
# Calculate the stability
#

def comparative_stability(cv):
    """
    Converts a Coefficient of Variation (CV) into a 0-100 stability score.

    The score follows a fixed logistic curve, 100 / (1 + exp(30 * (cv - 0.845))),
    rounded to one decimal place, so different samples are scored on the same scale:
      - CV = 0.845 gives exactly 50
      - CVs well below 0.845 approach 100 (stable)
      - CVs well above 0.845 approach 0 (unstable)
    """
    center = 0.845  # CV at which the score is 50 (halfway between 0.78 and 0.91)
    slope = 30      # Larger values make the transition around the center sharper

    exponent = slope * (cv - center)
    raw_score = 100 / (1 + np.exp(exponent))
    return round(raw_score, 1)

def print_stability_breakdown(data, column_index, startRow, endRow):
    """
    Logs a stability breakdown for rows startRow..endRow (inclusive) and for its two halves.

    For the whole range and for each half this reports the standard deviation,
    the mean, the coefficient of variation (std / mean, or inf when the mean is
    0) and linear 0-100 stability scores. It also stores the logistic scores
    from comparative_stability() in the module-level globals log_score_first,
    log_score_second and log_score_total (the plain average of the two halves).

    Parameters:
        data (DataFrame): Dataset to read from.
        column_index (int): Positional index of the column to analyze.
        startRow (int): First row position of the range.
        endRow (int): Last row position of the range (inclusive).

    Returns:
        None. When the range holds no non-NaN values a notice is printed and the
        three log_score_* globals are left as None.
    """
    global log_score_second, log_score_first, log_score_total

    # Clear results from any earlier call so a caller can never pick up scores
    # that belong to a different range or file when this call bails out early.
    log_score_total = None
    log_score_first = None
    log_score_second = None

    values = data.iloc[startRow:endRow + 1, column_index].dropna()
    total_len = len(values)

    if total_len == 0:
        print(" No valid data in selected range.")
        return

    # The split point is based on the number of non-NaN values in the range.
    mid = startRow + total_len // 2
    first_half = data.iloc[startRow:mid, column_index].dropna()
    second_half = data.iloc[mid:endRow + 1, column_index].dropna()

    # Standard deviations
    std_total = np.std(values)
    std_first = np.std(first_half)
    std_second = np.std(second_half)

    # Means
    mean_total = np.mean(values)
    mean_first = np.mean(first_half)
    mean_second = np.mean(second_half)

    # Coefficient of variation
    cv_total = std_total / mean_total if mean_total != 0 else float('inf')
    cv_first = std_first / mean_first if mean_first != 0 else float('inf')
    cv_second = std_second / mean_second if mean_second != 0 else float('inf')

    # Weighted average of the halves
    weighted_std = (std_first * len(first_half) + std_second * len(second_half)) / total_len

    def linear_score(std_val, max_std):
        """Map std_val linearly to 0-100: 0 -> 100, max_std or more -> 0."""
        return round((1 - min(std_val / max_std, 1.0)) * 100, 1)

    # "Stability Score" lines use a 0.05 ceiling.
    score_total = linear_score(std_total, 0.05)
    score_first = linear_score(std_first, 0.05)
    score_second = linear_score(std_second, 0.05)

    # "Comparative Stability" lines use a tighter 0.04 ceiling.
    comp_score_first = linear_score(std_first, 0.04)
    comp_score_second = linear_score(std_second, 0.04)

    log_score_first = comparative_stability(cv_first)
    log_score_second = comparative_stability(cv_second)
    log_score_total = (log_score_first + log_score_second) / 2

    log_msg(f"\n Stability Breakdown for Rows {startRow} to {endRow}")
    log_msg(f"   Total samples: {total_len}")
    log_msg(f"   Overall Std Dev:     {std_total:.6f}")
    log_msg(f"   Overall Mean:        {mean_total:.6f}")
    log_msg(f"   Overall Coef Var:    {cv_total:.6f}")
    log_msg(f"   Stability Score:     {score_total:.1f}%")

    log_msg(f"   First Half Std Dev:  {std_first:.6f}")
    log_msg(f"   First Half Mean:     {mean_first:.6f}")
    log_msg(f"   First Half Coef Var: {cv_first:.6f}")
    log_msg(f"   First Half Stability Score: {score_first:.1f}%")
    log_msg(f"   First Half Comparative Stability: {comp_score_first:.1f}%")

    log_msg(f"   Second Half Std Dev:  {std_second:.6f}")
    log_msg(f"   Second Half Mean:     {mean_second:.6f}")
    log_msg(f"   Second Half Coef Var: {cv_second:.6f}")
    log_msg(f"   Second Half Stability Score: {score_second:.1f}%")
    log_msg(f"   Second Half Comparative Stability: {comp_score_second:.1f}%")

    log_msg(f"   Weighted Std Dev:    {weighted_std:.6f}")

    log_msg(f"   Logistic Stability Score (Overall):     {log_score_total:.1f}%")
    log_msg(f"   Logistic Stability Score (First Half):  {log_score_first:.1f}%")
    log_msg(f"   Logistic Stability Score (Second Half): {log_score_second:.1f}%")
    log_msg("────────────────────────────────────────────")

def calculate_stability(data, column_index=4, startRow=0, endRow=None):
    """
    Measures how much one column varies over a row range.

    Parameters:
        data (DataFrame): The dataset containing numerical values.
        column_index (int): Positional index of the column to analyze (default: 4).
        startRow (int): First row position of the range.
        endRow (int or None): Last row position (inclusive). None, or a value
            past the end of the data, is replaced by the last row.

    Returns:
        The result of np.std over the non-NaN values in the range (lower value
        = higher stability), or None when the range has no non-NaN values.

    Raises:
        ValueError: If startRow is negative or not strictly below endRow.
    """
    last_row = len(data) - 1
    if endRow is None or endRow > last_row:
        endRow = last_row

    if startRow < 0 or startRow >= endRow:
        raise ValueError(f"Invalid range: startRow {startRow} must be less than endRow {endRow}")

    # Slice is inclusive of endRow; missing readings are ignored.
    window = data.iloc[startRow:endRow + 1, column_index].dropna()

    return None if window.empty else np.std(window)


def calculate_time_difference(data, column_index, firstValOffset, lastValOffset):
    """
    Subtracts the value at one row from the value at another row of the same column.

    Parameters:
        data (DataFrame): The dataset containing numerical values.
        column_index (int): Positional index of the column to read.
        firstValOffset (int): Row position of the value that is subtracted.
        lastValOffset (int): Row position of the value it is subtracted from.

    Returns:
        The difference (value at lastValOffset - value at firstValOffset), or
        None if either value is NaN.

    Raises:
        ValueError: If either offset is negative or not below len(data).
    """
    row_count = len(data)
    for offset in (firstValOffset, lastValOffset):
        if not 0 <= offset < row_count:
            raise ValueError("Offsets are out of bounds of the dataset")

    start_value = data.iloc[firstValOffset, column_index]
    end_value = data.iloc[lastValOffset, column_index]

    # A missing reading at either end makes the difference meaningless.
    if pd.isna(start_value) or pd.isna(end_value):
        return None

    return end_value - start_value


# ----------------------------------------
# Analyze File and Return JSON
# ----------------------------------------
def analyze_file(file_path, showPlot=True):
    """Analyze one CSV file, write its Excel report, and return the results as JSON.

    Steps: load the file with read_csv(), find the longest stable region in
    column 13, shrink that region by 5 rows at each end, log a stability
    breakdown for it, compute raw stability values for the region and its two
    halves, write the Excel report via process_csv_and_add_charts(), and build
    the JSON response.

    Parameters:
        file_path (str): Path of the CSV file to analyze.
        showPlot (bool): Accepted for interface compatibility; not used by this function.

    Returns:
        str: JSON text. On success it carries "Success": "Success" plus the
        region bounds, elapsed time and stability values. On any failure it is
        {"Success": false, "Error": <message>}.
    """
    try:
        global log_score_total, log_score_first, log_score_second
        data = read_csv(file_path)

        if not isinstance(data, pd.DataFrame):
            return json.dumps({"Success": False, "Error": "Failed to load data."})

        # Reported in the response as-is; nothing in this function updates them.
        startCnt = 0
        endCnt = 0

        startCnt1, endCnt1 = find_longest_stable_region(data, column_index=13)

        if startCnt1 is None or endCnt1 is None:
            log_msg(f"⚠️ Skipping file '{file_path}' due to missing startCnt1 or endCnt1.")
            return json.dumps({"Success": False, "Error": "No stable region found."})

        # Trim 5 rows from each end of the detected region.
        adj_start = startCnt1 + 5
        adj_end = endCnt1 - 5

        if adj_start >= adj_end:
            log_msg(f"⚠️ Skipping file '{file_path}' due to invalid stability range ({adj_start} to {adj_end})")
            return json.dumps({"Success": False, "Error": "No stable region found."})

        # Also fills the log_score_* globals used in the response below.
        print_stability_breakdown(data, column_index=13, startRow=adj_start, endRow=adj_end)

        # NOTE(review): these calls use calculate_stability's default column (4)
        # while the region itself was detected on column 13 — confirm intended.
        stabilityScore = calculate_stability(data, startRow=adj_start, endRow=adj_end)
        timeInMsec = calculate_time_difference(data, firstValOffset=adj_start, lastValOffset=adj_end, column_index=8)
        log_msg("🧪 Full range time (ms): " + str(timeInMsec))

        # Midpoint for halves
        mid = (adj_start + adj_end) // 2

        stabilityScore1 = calculate_stability(data, startRow=adj_start, endRow=mid)
        log_msg(f"📏 Stability (first half): {stabilityScore1}")

        stabilityScore2 = calculate_stability(data, startRow=mid + 1, endRow=adj_end)
        log_msg(f"📏 Stability (second half): {stabilityScore2}")

        response = {
            "Success": "Success",
            "testname": file_path.split("/")[-1],
            "startCnt": startCnt,
            "stopCnt": endCnt,
            "startCnt1": startCnt1,
            "stopCnt1": endCnt1,
            "timeInMsec": timeInMsec,
            "stabilityRaw": stabilityScore,
            "stabilityRaw1": stabilityScore1,
            "stabilityRaw2": stabilityScore2,
            "stabilityScore": round(min(log_score_total, 999.99), 2),
            "stabilityScore1": round(min(log_score_first, 999.99), 2),
            "stabilityScore2": round(min(log_score_second, 999.99), 2),
            "fall_risk": 50,
        }

        resultFileName = file_path
        if file_path.lower().endswith(".csv"):
            if single_file_passed:
                # Single-file runs write the report into a fixed output directory.
                output_dir = "/home/kinometric/learn/processed"
                Path(output_dir).mkdir(parents=True, exist_ok=True)
                base_name = Path(file_path).stem
                resultFileName = f"{output_dir}/{base_name}.xlsx"
            else:
                # Default: save next to original CSV
                resultFileName = file_path[:-4] + ".xlsx"

        process_csv_and_add_charts(file_path, resultFileName, startCnt1, endCnt1, stabilityScore, stabilityScore1, stabilityScore2, timeInMsec)

        def to_builtin(value):
            """Turn NumPy scalars into plain Python numbers; pass anything else through."""
            if isinstance(value, np.integer):
                return int(value)
            if isinstance(value, np.floating):
                return float(value)
            return value

        # This must happen before the first json.dumps call: NumPy integers
        # (e.g. an int64 time difference) are not JSON serializable, and the
        # resulting TypeError would turn a successful analysis into a failure.
        response = {key: to_builtin(value) for key, value in response.items()}

        serialized = json.dumps(response, indent=4)
        log_msg("✅ JSON being returned:")
        log_msg(serialized)

        return serialized
    except Exception as e:
        print(f"❌ Error analyzing file '{file_path}': {e}")
        return json.dumps({"Success": False, "Error": str(e)})


def score_from_stability(value):
    """
    Converts a stability value to a score with a linear mapping:
      0.0    → 100
      0.0012 → 0
    Values above 0.0012 return negative scores. The result is rounded to two
    decimal places.
    """
    zero_score_at = 0.0012  # Stability value that maps to a score of 0
    return round((zero_score_at - value) / zero_score_at * 100, 2)


def find_longest_stable_region(data, column_index=13,
                                motion_threshold=0.3,
                                settle_threshold=0.1,
                                exit_threshold=0.2,
                                min_length=10):
    """
    Scans one column from top to bottom and returns the longest stable region
    that follows a motion burst.

    The scan moves through three states:
      - idle: waiting for a value above motion_threshold
      - watching: motion seen, waiting for a value below settle_threshold
      - stable: inside a region, which ends at the first value above
        exit_threshold (the scan then returns to idle) or at the end of data

    A region is kept when (end - start) >= min_length. At every transition the
    values from 3 rows before to 3 rows after are printed.

    Returns:
        (start, end) of the longest kept region (the earliest on ties), or
        (None, None) when no region was kept.
    """
    series = data.iloc[:, column_index].reset_index(drop=True)
    row_count = len(series)

    def print_context(label, idx, val):
        print(f"\n{label} at index {idx}, val={val:.4f}")
        first_shown = max(0, idx - 3)
        stop_shown = min(row_count, idx + 4)
        for j in range(first_shown, stop_shown):
            prefix = "➜" if j == idx else "  "
            print(f"{prefix} [{j}] = {series[j]:.4f}")

    IDLE, WATCHING, STABLE = range(3)
    state = IDLE
    region_start = None
    regions = []

    for i, val in enumerate(series):
        if state == IDLE:
            if val > motion_threshold:
                state = WATCHING
                print_context("🚀 Motion", i, val)
        elif state == WATCHING:
            if val < settle_threshold:
                region_start = i
                state = STABLE
                print_context("🔽 Enter Stable", i, val)
        elif val > exit_threshold:
            # Leaving a stable region; a new motion burst is required before
            # another region can start.
            print_context("🔼 Exit Stable", i, val)
            state = IDLE
            if i - region_start >= min_length:
                regions.append((region_start, i))

    if state == STABLE:
        # Data ended while still stable: close the region at the last row.
        last_row = row_count - 1
        if last_row - region_start >= min_length:
            regions.append((region_start, last_row))
        print_context("✅ Stable to EOF", last_row, series[last_row])

    if not regions:
        print("\n❌ No stable regions found.")
        return None, None

    longest = max(regions, key=lambda bounds: bounds[1] - bounds[0])
    print(f"\n🏁 Longest stable region: {longest[0]} to {longest[1]} (length={longest[1] - longest[0]})")
    return longest




# ----------------------------------------
# Save as excel with charts
# ----------------------------------------

def process_csv(csv_file, excel_file, skip_rows=4):
    """Convert a CSV file to an Excel workbook and reopen it for editing.

    The header offset is skip_rows unless the first cell of the first line
    lacks "fall" (case-insensitive), in which case 2 rows are skipped. Column
    names are stripped, a 1-based "Index" column is inserted first, and every
    column that converts cleanly is made numeric before the data is written to
    excel_file.

    Parameters:
        csv_file (str): Path of the source CSV.
        excel_file (str): Path of the workbook to write.
        skip_rows (int): Rows to skip when the first cell contains "fall".

    Returns:
        (DataFrame, workbook, active worksheet), or (None, None, None) when
        anything fails after the existence check (the error is printed).

    Raises:
        FileNotFoundError: If csv_file does not exist.
    """
    import pandas as pd
    from openpyxl import load_workbook

    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"File not found: {csv_file}")

    try:
        with open(csv_file, 'r', encoding='utf-8') as handle:
            leading_cell = handle.readline().strip().split(',')[0]

        if "fall" not in leading_cell.lower():
            skip_rows = 2

        frame = pd.read_csv(csv_file, skiprows=skip_rows)
        frame.columns = [name.strip() for name in frame.columns]

        # Row counter used as the chart category axis.
        frame.insert(0, "Index", range(1, len(frame) + 1))

        # Columns that cannot be parsed as numbers are left exactly as read.
        for name in frame.columns:
            try:
                frame[name] = pd.to_numeric(frame[name])
            except Exception:
                pass

        # NOTE(review): the "Index" column is already filled for every row at
        # this point, so this looks unable to drop anything — confirm.
        frame.dropna(how='all', inplace=True)

        frame.to_excel(excel_file, index=False, engine='openpyxl')

        workbook = load_workbook(excel_file)
        return frame, workbook, workbook.active
    except Exception as e:
        print(f"❌ Error processing CSV: {e}")
        return None, None, None
       # return {"Success": False, "Error": str(e)}


def create_chart(ws, df, header_name, chart_title, position, color="000000", line_col_offset=30):
    """Add a line chart of one DataFrame column to the worksheet.

    Parameters:
        ws: Worksheet that already holds the data (header in row 1).
        df (DataFrame): Data whose column order matches the worksheet columns.
        header_name (str): Name of the column to plot.
        chart_title (str): Title shown on the chart.
        position (str): Anchor cell for the chart, e.g. "R5".
        color (str): Hex colour applied to the series line.
        line_col_offset (int): Not used by this function.

    Returns:
        None. If header_name is not a column of df, the available headers are
        logged and no chart is added.
    """
    from openpyxl.utils import get_column_letter
    from openpyxl.chart import LineChart, Reference, Series

    last_row = ws.max_row

    if header_name not in df.columns:
        log_msg(f"❌ Column '{header_name}' not found in DataFrame.")
        log_msg("🧪 Available headers:", df.columns.tolist())
        return

    # DataFrame positions are 0-based, worksheet columns are 1-based.
    excel_col_index = df.columns.get_loc(header_name) + 1
    excel_col_letter = get_column_letter(excel_col_index)

    log_msg(f"📊 Chart for '{header_name}' -> Excel column: {excel_col_letter} ({excel_col_index})")

    # Log a handful of cell values from the top of the column as a sanity check.
    for sheet_row in range(2, min(last_row, 10)):
        sample = ws.cell(row=sheet_row, column=excel_col_index).value
        log_msg(f"Row {sheet_row} [{excel_col_letter}]: {sample}")

    # Values start at row 1 so the header can serve as the series title;
    # categories come from column A starting below the header.
    value_ref = Reference(ws, min_col=excel_col_index, min_row=1, max_row=last_row)
    category_ref = Reference(ws, min_col=1, min_row=2, max_row=last_row)

    chart = LineChart()
    chart.title = chart_title
    chart.width = 32.5
    chart.height = 15.25
    chart.y_axis.title = header_name
    chart.x_axis.title = "Index"
    chart.add_data(value_ref, titles_from_data=True)
    chart.set_categories(category_ref)

    for series in chart.series:
        series.graphicalProperties.line.solidFill = color

    chart.legend = None
    chart.x_axis.delete = False
    chart.y_axis.delete = False

    ws.add_chart(chart, position)


def highlight_stable_region(ws, startCnt, endCnt):
    """Fill every cell in worksheet rows startCnt + 2 .. endCnt + 2 with solid yellow.

    The + 2 offset presumably maps 0-based data rows onto a 1-based sheet with
    a header in row 1 — verify against the caller.
    """
    from openpyxl.styles import PatternFill

    yellow = PatternFill(start_color="FFFF00", end_color="FFFF00", fill_type="solid")
    first_row, last_row = startCnt + 2, endCnt + 2
    for cells in ws.iter_rows(min_row=first_row, max_row=last_row):
        for cell in cells:
            cell.fill = yellow


def process_csv_and_add_charts(csv_file, excel_file, startCnt, endCnt, stabilityScore, stabilityScore1, stabilityScore2, timeInMsec):
    """Build the Excel report for one CSV: data, summary cells, highlight and charts.

    Parameters:
        csv_file (str): Source CSV path.
        excel_file (str): Workbook path to create and save.
        startCnt (int): First row of the stable region (written to R2).
        endCnt (int): Last row of the stable region (written to S2).
        stabilityScore, stabilityScore1, stabilityScore2: Raw stability values
            for the full region and its halves (written to U2, V2, W2); the
            matching score_from_stability() results go to U3, V3, W3. A raw
            value of None leaves its score cell empty.
        timeInMsec: Elapsed time over the region (written to T2).

    Returns:
        None. Errors are logged with a traceback instead of being raised.
    """
    try:
        df, wb, ws = process_csv(csv_file, excel_file)

        # process_csv signals failure with (None, None, None); stop here with a
        # clear message instead of failing later on the first cell access.
        if ws is None:
            log_msg(f"❌ Error in process_csv_and_add_charts: could not build workbook from '{csv_file}'")
            return

        # Write summary info
        ws["R2"].value = startCnt
        ws["S2"].value = endCnt
        ws["T2"].value = timeInMsec

        raw_scores = (stabilityScore, stabilityScore1, stabilityScore2)
        for col_letter, raw in zip("UVW", raw_scores):
            ws[f"{col_letter}2"].value = raw
            # A missing raw value cannot be scored; leave the cell empty.
            ws[f"{col_letter}3"].value = None if raw is None else score_from_stability(raw)

        # Shade the rows of the stable region.
        highlight_stable_region(ws, startCnt, endCnt)

        # Charts are looked up by column header, not by Excel letter.
        create_chart(ws, df, "Q-diff", "Q-diff Stability", "R5")
        create_chart(ws, df, "accMag", "Acceleration Magnitude", "R35")

        wb.save(excel_file)
        log_msg(f"✅ Excel file saved with charts: {excel_file}")
    except Exception as e:
        log_msg(f"❌ Error in process_csv_and_add_charts: {e}")
        import traceback
        log_msg(traceback.format_exc())


# ----------------------------------------
# Main Execution
# ----------------------------------------
def main():
    """Parse the command line, analyze the selected CSV files and file away their reports.

    Argument conventions:
      - No args: process every *.csv in the current directory.
      - 1 arg: "true/false/yes/no/1/0" is taken as the plot flag and all files
        are processed; anything else is taken as the file to process.
      - 2 args: the second argument is the file to process. The first is used
        as the plot flag when it is one of the words above.

    For each file the JSON result is both logged and printed, and any .xls or
    .xlsx file with the same stem next to the CSV is moved into ./processed.
    The CSV itself is left in place.
    """
    global single_file_passed

    if use_logging:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            filename='output.log',
            filemode='w'  # overwrite on every run; 'a' would append
        )

    flag_words = ('true', 'false', '1', '0', 'yes', 'no')
    show_plot = True
    target_files = []
    cli_args = sys.argv[1:]

    if cli_args:
        first_word = cli_args[0].lower()
        if first_word in flag_words:
            show_plot = first_word in ('true', '1', 'yes')
        else:
            # Not a flag word, so it names the file to process.
            target_files = [Path(cli_args[0])]
            single_file_passed = True

    if len(cli_args) >= 2:
        # An explicit second argument always wins as the file to process.
        target_files = [Path(cli_args[1])]
        single_file_passed = True

    current_dir = Path.cwd()
    processed_dir = current_dir / 'processed'
    processed_dir.mkdir(exist_ok=True)

    # Without an explicit file, fall back to every CSV in the current directory.
    if not target_files:
        target_files = list(current_dir.glob('*.csv'))

    for csv_file in target_files:
        if not csv_file.exists():
            log_msg(f"File not found: {csv_file}")
            continue

        result = analyze_file(str(csv_file), show_plot)
        log_msg(result)
        print(result)

        # Move a matching spreadsheet, if one sits next to the CSV.
        for ext in ('.xls', '.xlsx'):
            spreadsheet = csv_file.with_suffix(ext)
            if not spreadsheet.exists():
                continue
            shutil.move(str(spreadsheet), str(processed_dir / spreadsheet.name))
            log_msg(f"Moved {spreadsheet.name} to {processed_dir}")


if __name__ == "__main__":
    # Set the single-file flag before main() runs: it applies when a second
    # argument is present, or when the only argument is not a boolean-like word.
    cli_args = sys.argv[1:]
    boolean_words = ['true', 'false', '1', '0', 'yes', 'no']
    if cli_args and (len(cli_args) >= 2 or cli_args[0].lower() not in boolean_words):
        single_file_passed = True

    main()
