import csv
import sys
import json
import os


# Decimal parsing, NumPy arrays, plotting and density-estimation dependencies.


from decimal import InvalidOperation
from decimal import Decimal
import numpy as np

import matplotlib.pyplot as plt
import matplotlib
# Select the non-interactive 'Agg' backend (originally noted as being for PyCharm
# compatibility). NOTE(review): with Agg, plt.show() does not open a window, and
# this call runs after matplotlib.pyplot is already imported above — confirm both
# are intended.
matplotlib.use('Agg')

from scipy.stats import gaussian_kde  # Used to calculate point density



# Columns holding (potentially very large) integer timestamps/counters. These are
# parsed via Decimal so no digits are lost to float rounding.
_TIME_COLUMNS = ("sensortime", "phonetime", "timediff")


def _convert_csv_cell(value, column):
    """Convert one stripped cell to int (time columns) or float; None if blank or unparseable."""
    if not value:
        return None
    try:
        if column in _TIME_COLUMNS:
            # Decimal -> int keeps every digit of large integers that float64 would round.
            return int(Decimal(value))
        return float(value)
    except (InvalidOperation, ValueError, OverflowError):
        # OverflowError: int(Decimal("Infinity")) has no integer representation.
        return None


def _convert_csv_row(row, headers):
    """Return the numeric version of *row*, or None if it is not a data row.

    A row is treated as non-data (e.g. footer/summary text) when its cell count
    does not match the header, or when none of its cells parses as a number.
    """
    cells = [cell.strip() for cell in row]
    # Tolerate trailing delimiters ("1,2,3,") that add empty cells past the header width.
    while len(cells) > len(headers) and not cells[-1]:
        cells.pop()
    if len(cells) != len(headers):
        return None
    converted = [_convert_csv_cell(value, column) for value, column in zip(cells, headers)]
    if all(value is None for value in converted):
        return None
    return converted


def read_csv(file_path):
    """Read a sensor CSV export into a dict of per-column NumPy arrays.

    Rows before the header row are skipped. The header row is the first row
    that contains both ``W`` and ``sensortime`` after whitespace stripping.
    Each following row is parsed per column:

    - ``sensortime``/``phonetime``/``timediff`` cells become Python ``int``
      (via ``Decimal``, preserving large-integer precision).
    - All other cells become ``float``.
    - Blank or unparseable cells become ``None``.

    Blank (or whitespace-only) rows are skipped. Rows whose width does not match
    the header, or in which no cell is numeric, are collected as ``extra_text``.

    Args:
        file_path: Path to the CSV file.

    Returns:
        A dict mapping each header name to a 1-D ``dtype=object`` array with one
        entry per data row. It also holds ``"extra_text"``: the non-data rows,
        each re-joined with commas. If *file_path* does not exist, a JSON string
        ``{"Success": false, "Error": ...}`` is returned instead; this is kept
        for backward compatibility.

    Raises:
        ValueError: If no header row is found.
    """
    if not os.path.exists(file_path):
        return json.dumps({"Success": False, "Error": f"The file '{file_path}' does not exist."})

    with open(file_path, mode='r', newline='') as file:
        reader = csv.reader(file)

        # Advance the shared reader up to and including the header row.
        headers = None
        for row in reader:
            clean_row = [col.strip() for col in row]
            if "W" in clean_row and "sensortime" in clean_row:
                headers = clean_row
                break

        if headers is None:
            raise ValueError("Valid data headers not found in CSV file.")

        data_rows = []
        extra_text = []
        for row in reader:
            if not any(cell.strip() for cell in row):
                continue
            numeric_row = _convert_csv_row(row, headers)
            if numeric_row is None:
                extra_text.append(",".join(row))
            else:
                data_rows.append(numeric_row)

    # dtype=object keeps Python ints (exact timestamps), floats and None side by side.
    # An explicit (0, n) shape keeps column slicing valid when there are no data rows.
    if data_rows:
        data = np.array(data_rows, dtype=object)
    else:
        data = np.empty((0, len(headers)), dtype=object)

    data_dict = {name: data[:, i] for i, name in enumerate(headers)}
    data_dict["extra_text"] = extra_text
    return data_dict


def find_stable_periods(data, accel_column='accMag', time_column='sensortime', stable_threshold=0.5,
                        instability_tolerance=5):
    """
    Find stable stretches of an acceleration signal.

    A sample is "stable" when it lies within ``mean ± stable_threshold``; the mean
    is taken over all used samples. A stable period starts at the first stable
    sample. It ends only once ``instability_tolerance`` consecutive unstable
    samples are seen, and its end index is the last stable sample before that run.
    Periods whose time span is not positive are discarded. A period still open
    when the data ends is not reported.

    Samples where either the acceleration or the time value is None are ignored.
    Reported indices always refer to rows of the original *data* columns. Time
    values are subtracted as-is rather than converted to float, so large integer
    timestamps keep full precision.

    Args:
        data: Mapping of column name -> sequence of samples.
        accel_column: Column holding the acceleration values.
        time_column: Column holding the time values.
        stable_threshold: Half-width of the stable band around the mean.
        instability_tolerance: Consecutive unstable samples needed to close a period.

    Returns:
        A list of period dicts. Each dict has these keys:

        - ``stable_start_index``
        - ``stable_end_index``
        - ``time_difference``: an int, in the units of *time_column*.
        - ``average_between_spikes``: the mean acceleration from the start sample
          up to, but excluding, the end sample, rounded to 2 dp.

        NOTE: when exactly one period is found, that dict itself is returned
        rather than a one-element list.

    Raises:
        ValueError: If either column is missing from *data*.
    """
    if accel_column not in data or time_column not in data:
        raise ValueError(f"Columns {accel_column} and {time_column} not found in data.")

    raw_accel = list(data[accel_column])
    raw_time = list(data[time_column])

    # Drop samples missing either value *jointly* so acceleration and time stay
    # aligned, and remember each kept sample's original row number.
    rows = [i for i, (a, t) in enumerate(zip(raw_accel, raw_time)) if a is not None and t is not None]
    acceleration = np.array([raw_accel[i] for i in rows], dtype=np.float64)
    times = [raw_time[i] for i in rows]

    if acceleration.size == 0:
        return []

    mean_accel = np.mean(acceleration)
    stable_range_min = mean_accel - stable_threshold
    stable_range_max = mean_accel + stable_threshold

    stable_periods = []
    in_stable_range = False
    stable_start_idx = None
    instability_count = 0

    for i, value in enumerate(acceleration):
        if stable_range_min <= value <= stable_range_max:
            if not in_stable_range:
                stable_start_idx = i
                in_stable_range = True
            instability_count = 0  # any stable sample resets the unstable run
        elif in_stable_range:
            instability_count += 1
            if instability_count >= instability_tolerance:
                # The unstable run covers i-tolerance+1 .. i, so i-tolerance is the last stable sample.
                stable_end_idx = i - instability_tolerance
                time_diff = times[stable_end_idx] - times[stable_start_idx]

                if time_diff > 0:  # ignore zero/negative-length periods
                    stable_periods.append({
                        "stable_start_index": rows[stable_start_idx],
                        "stable_end_index": rows[stable_end_idx],
                        "time_difference": int(time_diff),
                        "average_between_spikes": round(np.mean(acceleration[stable_start_idx:stable_end_idx]), 2)
                    })

                in_stable_range = False
                instability_count = 0

    # Kept for backward compatibility: a single period is returned as a bare dict.
    return stable_periods[0] if len(stable_periods) == 1 else stable_periods


def analyze_file(file_path, showPlot=True):
    """Analyse one recorded test CSV and return the results as a JSON string.

    The analysis runs in four steps:

    1. Read the file with ``read_csv``.
    2. Locate stable periods of the default acceleration column with
       ``find_stable_periods`` (threshold 0.2).
    3. Pick the longest period by ``time_difference``, excluding the first one.
    4. Summarise the quaternion data over that sample range with
       ``process_quaternions``.

    Args:
        file_path: Path to the CSV file.
        showPlot: Forwarded to ``process_quaternions`` to draw the 3D density plot.

    Returns:
        A JSON string. On success it contains ``"Success": "Success"`` together
        with these fields:

        - the test name
        - the sample threshold
        - the chosen sample range
        - angular range and stability
        - the period duration (``sampleTime``)
        - ``fall_risk``

        If the file is missing, or fewer than two stable periods are found, it
        contains ``"Success": false`` and an ``"Error"`` message instead.

    Raises:
        ValueError: Propagated from ``read_csv`` (no header row) or
            ``process_quaternions`` (invalid range).
    """
    if not os.path.exists(file_path):
        return json.dumps({"Success": False, "Error": f"The file '{file_path}' does not exist."})

    data = read_csv(file_path)

    sampleThreshold = .2
    spike_results = find_stable_periods(data, stable_threshold=sampleThreshold)
    # find_stable_periods returns a bare dict when it finds exactly one period; normalise to a list.
    if isinstance(spike_results, dict):
        spike_results = [spike_results]

    # The first stable period is excluded from the search.
    # NOTE(review): presumably it is the rest phase before the test starts — confirm.
    candidates = spike_results[1:]
    if not candidates:
        return json.dumps({
            "Success": False,
            "Error": (f"Not enough stable periods found in '{file_path}' to analyze "
                      f"(found {len(spike_results)}, need at least 2)."),
        })

    # On ties, max() keeps the earliest period.
    max_record = max(candidates, key=lambda record: record['time_difference'])
    time_between_spikes = max_record['time_difference']
    before_index = max_record['stable_start_index']
    after_index = max_record['stable_end_index']

    angular_range, stability = process_quaternions(
        data, range_param=(before_index, after_index), showPlot=showPlot
    )

    # Fall-risk score: 10 minus twice the FIRST stable period's mean acceleration,
    # rounded to 1 dp. NOTE(review): an earlier comment called this "an inverse of
    # stability", but `stability` is not used — confirm the intended formula.
    fall_risk = round(10 - (spike_results[0].get("average_between_spikes", 0) * 2), 1)

    # NOTE(review): success is reported as the string "Success" while errors use
    # False; kept unchanged for existing consumers.
    response = {
        "Success": "Success",
        "testname": os.path.basename(file_path),
        "sampleThreshold": sampleThreshold,
        "startSample": before_index,
        "endSample": after_index,
        "angularRange": angular_range,
        "stability": stability,
        "sampleTime": time_between_spikes,
        "fall_risk": fall_risk
    }

    return json.dumps(response, indent=4)


def _quaternion_component(values):
    """Return *values* as a float64 array, with missing (None) samples replaced by 0."""
    component = np.array(values, dtype=object)
    return np.where(component == None, 0, component).astype(np.float64)  # noqa: E711 (elementwise)


def process_quaternions(data, range_param=None, showPlot=True):
    """
    Summarise quaternion samples over a sample range and optionally plot them.

    1) If *showPlot* is true, draws a 3D scatter of the (X, Y, Z) components,
       coloured by a Gaussian-KDE estimate of sample density.
    2) Returns the spread (max - min) of the quaternion magnitude over the range.
    3) Returns the standard deviation of the quaternion magnitude over the range.

    Missing (None) components are treated as 0.

    Parameters:
    data (dict): Mapping with 'W', 'X', 'Y' and 'Z' sequences. They are assumed
        to be of equal length; only len(data['W']) is checked.
    range_param (tuple): A (start, end) half-open sample range. None (or a bool)
        selects the full range.
    showPlot (bool): Whether to estimate the density and draw the plot.

    Returns:
    tuple: (angular_range, stability) as NumPy floats.

    Raises:
    ValueError: If the range lies outside [0, len(data['W'])] or is empty.
    """
    sample_count = len(data['W'])

    # Treat a bool (an int subclass, never a meaningful range) like None: use the full range.
    if range_param is None or isinstance(range_param, bool):
        range_param = (0, sample_count)

    start, end = range_param
    start, end = int(start), int(end)

    if start < 0 or end > sample_count or start >= end:
        raise ValueError(f"Invalid range_param: {range_param}. Must be within (0, {sample_count})")

    w, x, y, z = (_quaternion_component(data[key])[start:end] for key in ('W', 'X', 'Y', 'Z'))

    if showPlot:
        # The KDE is only needed for colouring the plot. Computing it only here
        # also avoids the LinAlgError it raises on degenerate (e.g. constant) data
        # when no plot is wanted.
        xyz = np.vstack([x, y, z])
        density = gaussian_kde(xyz)(xyz)

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        scatter = ax.scatter(x, y, z, c=density, cmap='viridis', marker='o')
        fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=5).set_label('Sample Density')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

        plt.draw()
        plt.pause(0.1)
        plt.show()
        plt.close(fig)  # release the figure so repeated calls do not accumulate open figures

    magnitude = np.sqrt(w ** 2 + x ** 2 + y ** 2 + z ** 2)

    # NOTE(review): for unit quaternions the magnitude is ~1, so this measures the
    # spread of the magnitude rather than a rotation angle — confirm intent.
    angular_range = np.max(magnitude) - np.min(magnitude)

    # Stability: standard deviation of the magnitude over the range.
    stability = np.std(magnitude)

    return angular_range, stability





if __name__ == "__main__":
    # Usage: <script> [csv_path] [show_plot]
    #   csv_path   defaults to a sample recording file name
    #   show_plot  is true only for "true", "1" or "yes" (case-insensitive);
    #              when omitted, the plot is shown
    cli_args = sys.argv[1:]
    target_file = cli_args[0] if cli_args else "2-lb-eo-15-2-2025-charts.csv"
    if len(cli_args) > 1:
        plot_enabled = cli_args[1].lower() in ['true', '1', 'yes']
    else:
        plot_enabled = True

    print(analyze_file(target_file, plot_enabled))

