package com.efsi.kinometric;

import android.util.Log;
import org.json.JSONObject;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * BalanceScorer - Complete balance test scoring engine for Android.
 *
 * Exact port of Python tuning_engine.py. All scoring logic matches the server
 * (tuning_engine.py, analyze8.py) and cross-validation suite (cross_validate.py).
 *
 * Usage:
 *   ScoringConfig config = new ScoringConfig(jsonString);  // or ScoringConfig.createDefault()
 *   BalanceScorer scorer = new BalanceScorer(config);
 *
 *   // From CSV file content:
 *   Map<String, Object> result = scorer.analyzeFromCsv(csvContent);
 *
 *   // From live sensor arrays:
 *   Map<String, Object> result = scorer.analyze(qDiffArray, sensortimeArray);
 *
 *   // Result keys: movement_score, test_duration, sample_count, windows,
 *   //   stability, fatigue, duration_penalty, scoring_mode, seven_second_bonus, risk_level
 */
public class BalanceScorer {
    private static final String TAG = "BalanceScorer";

    private final ScoringConfig config;

    public BalanceScorer(ScoringConfig config) {
        this.config = config;
    }

    // =========================================================================
    // Public API
    // =========================================================================

    /**
     * Analyze a CSV string and return results matching server format.
     *
     * @param csvContent Full CSV file as a string
     * @return Results map, or null if file is unparseable
     */
    public Map<String, Object> analyzeFromCsv(String csvContent) {
        double[][] parsed = parseCsv(csvContent);
        if (parsed == null) return null;
        return analyze(parsed[0], parsed[1]);
    }

    /**
     * Analyze pre-parsed sensor data arrays.
     *
     * @param qDiff      All q_diff values (before warmup skip)
     * @param sensortime All sensortime values in milliseconds (before warmup skip)
     * @return Results map, or null if insufficient data
     */
    public Map<String, Object> analyze(double[] qDiff, double[] sensortime) {
        int startRow = config.startRowSkip;

        if (startRow >= qDiff.length) return null;

        // Apply warmup skip
        double[] qd = Arrays.copyOfRange(qDiff, startRow, qDiff.length);
        double[] st = Arrays.copyOfRange(sensortime, startRow, sensortime.length);

        if (qd.length < config.minSamples) {
            double totalDur = st.length > 1 ? durationSeconds(st, 0, st.length - 1) : 0.0;
            Map<String, Object> result = new HashMap<>();
            result.put("insufficient_data", true);
            result.put("reason", "Recording too short (" + qd.length + " samples, "
                    + round2(totalDur) + "s after warmup skip)");
            result.put("sample_count", qd.length);
            result.put("test_duration", round2(totalDur));
            return result;
        }

        double totalDuration = durationSeconds(st, 0, st.length - 1);
        int sampleCount = qd.length;

        // ----- Analyze each target window -----
        Map<String, Map<String, Object>> windowsResults = new HashMap<>();
        for (int targetDur : config.targetWindowsSeconds) {
            String key = targetDur + "_second";
            Integer windowSize = findWindowSize(st, (double) targetDur, 0);

            if (windowSize == null || windowSize > qd.length) {
                windowsResults.put(key, null);
                continue;
            }

            double bestScore = -1;
            Map<String, Object> bestWindow = null;

            for (int i = 0; i <= qd.length - windowSize; i++) {
                double[] windowData = Arrays.copyOfRange(qd, i, i + windowSize);
                double minVal = arrayMin(windowData);
                double maxVal = percentile(windowData, config.outlierPercentile);
                double rangeVal = maxVal - minVal;

                double score = scoreMovement(rangeVal);

                if (score > bestScore) {
                    bestScore = score;
                    double actualDur = durationSeconds(st, i, i + windowSize);
                    double timeStart = durationSeconds(st, 0, i);
                    bestWindow = new HashMap<>();
                    bestWindow.put("score", round2(score));
                    bestWindow.put("raw_score", round2(score)); // score == raw (rescale=none)
                    bestWindow.put("actual_duration", round2(actualDur));
                    bestWindow.put("time_start", round2(timeStart));
                    bestWindow.put("time_end", round2(timeStart + actualDur));
                    bestWindow.put("q_diff_range", round6(rangeVal));
                    bestWindow.put("sample_count", windowSize);
                }
            }

            windowsResults.put(key, bestWindow);
        }

        // ----- Stability -----
        int stableCount = 0;
        for (double v : qd) {
            if (v < config.stableThreshold) stableCount++;
        }
        double percentStable = qd.length > 0 ? round1(stableCount * 100.0 / qd.length) : 0.0;

        // Longest continuous stable stretch (by duration, not sample count)
        double maxStableDuration = 0.0;
        Integer bestStart = null;
        Integer bestEnd = null;
        Integer currentStart = null;
        for (int i = 0; i < qd.length; i++) {
            if (qd[i] < config.stableThreshold) {
                if (currentStart == null) currentStart = i;
            } else {
                if (currentStart != null) {
                    double dur = durationSeconds(st, currentStart, i);
                    if (dur > maxStableDuration) {
                        maxStableDuration = dur;
                        bestStart = currentStart;
                        bestEnd = i;
                    }
                    currentStart = null;
                }
            }
        }
        if (currentStart != null) {
            double dur = durationSeconds(st, currentStart, qd.length - 1);
            if (dur > maxStableDuration) {
                maxStableDuration = dur;
                bestStart = currentStart;
                bestEnd = qd.length - 1;
            }
        }

        double totalStableTime = stableCount > 0 ? totalDuration * stableCount / qd.length : 0.0;

        // ----- Analyze 7s window for bonus (even if not in target_windows) -----
        if (config.sevenSecondBonusEnabled && !containsInt(config.targetWindowsSeconds, 7)) {
            Integer ws7 = findWindowSize(st, 7.0, 0);
            if (ws7 != null && ws7 <= qd.length) {
                double bestScore = -1;
                Map<String, Object> bestWindow = null;
                for (int i = 0; i <= qd.length - ws7; i++) {
                    double[] windowData = Arrays.copyOfRange(qd, i, i + ws7);
                    double minVal = arrayMin(windowData);
                    double maxVal = percentile(windowData, config.outlierPercentile);
                    double rangeVal = maxVal - minVal;
                    double score = scoreMovement(rangeVal);
                    if (score > bestScore) {
                        bestScore = score;
                        double actualDur = durationSeconds(st, i, i + ws7);
                        double timeStart = durationSeconds(st, 0, i);
                        bestWindow = new HashMap<>();
                        bestWindow.put("score", round2(score));
                        bestWindow.put("raw_score", round2(score));
                        bestWindow.put("actual_duration", round2(actualDur));
                        bestWindow.put("time_start", round2(timeStart));
                        bestWindow.put("time_end", round2(timeStart + actualDur));
                        bestWindow.put("q_diff_range", round6(rangeVal));
                        bestWindow.put("sample_count", ws7);
                    }
                }
                windowsResults.put("7_second", bestWindow);
            } else {
                windowsResults.put("7_second", null);
            }
        }

        // ----- Overall movement score -----
        double movementScore;
        double durationPenalty = 0;
        String durationPenaltyReason = null;
        boolean qualityBonusApplied = false;
        Double stableStretchMeanQdiff = null;

        if ("duration".equals(config.scoringMode)) {
            // Duration mode: score from longest stable stretch
            double rawDur = interpolate(maxStableDuration, config.durationKnotsX, config.durationKnotsY);
            movementScore = clamp010(rawDur);

            // Quality bonus
            if (config.qualityBonusEnabled && bestStart != null) {
                double sum = 0;
                int count = bestEnd - bestStart;
                if (count > 0) {
                    for (int i = bestStart; i < bestEnd; i++) {
                        sum += qd[i];
                    }
                    double meanQd = sum / count;
                    stableStretchMeanQdiff = round6(meanQd);
                    if (meanQd < config.stableThreshold * config.qualityBonusThresholdPct) {
                        movementScore = Math.min(10.0, movementScore + config.qualityBonusValue);
                        qualityBonusApplied = true;
                    }
                }
            }
        } else {
            // Q-diff mode: only use target window scores for base calculation
            List<Double> availableScores = new ArrayList<>();
            for (int w : config.targetWindowsSeconds) {
                String key = w + "_second";
                Map<String, Object> wResult = windowsResults.get(key);
                if (wResult != null && wResult.get("score") != null) {
                    availableScores.add(((Number) wResult.get("score")).doubleValue());
                }
            }

            if (availableScores.isEmpty()) {
                movementScore = 0.0;
            } else if ("min".equals(config.aggregation)) {
                movementScore = Collections.min(availableScores);
            } else if ("max".equals(config.aggregation)) {
                movementScore = Collections.max(availableScores);
            } else if ("avg".equals(config.aggregation) || "mean".equals(config.aggregation)) {
                double sum = 0;
                for (double s : availableScores) sum += s;
                movementScore = sum / availableScores.size();
            } else {
                movementScore = Collections.min(availableScores);
            }

            // Duration penalty (qdiff mode only) — config-driven window check
            int[] targetWins = config.targetWindowsSeconds.clone();
            Arrays.sort(targetWins);
            int maxWin = targetWins.length > 0 ? targetWins[targetWins.length - 1] : 5;
            int minWin = targetWins.length > 0 ? targetWins[0] : 3;
            boolean hasMax = windowsResults.get(maxWin + "_second") != null;
            boolean hasMinOnly = windowsResults.get(minWin + "_second") != null;
            if (hasMinOnly) {
                // Check that ALL other windows are null
                for (int w : targetWins) {
                    if (w != minWin && windowsResults.get(w + "_second") != null) {
                        hasMinOnly = false;
                        break;
                    }
                }
            }

            if (hasMinOnly) {
                durationPenalty = config.durationPenalty3sOnly;
                durationPenaltyReason = "3s_only";
            } else if (!hasMax && !hasMinOnly) {
                durationPenalty = config.durationPenalty5sMax;
                durationPenaltyReason = "5s_max";
            }

            if (durationPenalty > 0) {
                movementScore = Math.max(0, movementScore - durationPenalty);
            }
        }

        // ----- 7-second bonus: cap score unless 7s window qualifies -----
        Double bonusApplied = null;
        if (config.sevenSecondBonusEnabled && !"duration".equals(config.scoringMode)) {
            Map<String, Object> w7 = windowsResults.get("7_second");
            if (w7 != null && w7.get("q_diff_range") != null) {
                double qdiff7s = ((Number) w7.get("q_diff_range")).doubleValue();
                double bonusScore = config.maxWithout7s;
                for (int i = 0; i < config.bonusTierMaxQdiff.length; i++) {
                    if (qdiff7s <= config.bonusTierMaxQdiff[i]) {
                        bonusScore = config.bonusTierScore[i];
                        break;
                    }
                }
                movementScore = Math.min(movementScore, bonusScore);
                bonusApplied = bonusScore;
            } else {
                movementScore = Math.min(movementScore, config.maxWithout7s);
                bonusApplied = config.maxWithout7s;
            }
        }

        double avgSampleRate = totalDuration > 0 ? sampleCount / totalDuration : 0.0;
        Map<String, Object> fatigue = calculateFatigue(windowsResults);

        // Build stability map
        Map<String, Object> stability = new HashMap<>();
        stability.put("percent_stable", percentStable);
        stability.put("total_stable_time", round1(totalStableTime));
        stability.put("continuous_duration", round1(maxStableDuration));
        stability.put("threshold_used", config.stableThreshold);

        // Build result
        Map<String, Object> result = new HashMap<>();
        result.put("movement_score", round2(movementScore));
        result.put("test_duration", round2(totalDuration));
        result.put("sample_count", sampleCount);
        result.put("avg_sample_rate", round1(avgSampleRate));
        result.put("windows", windowsResults);
        result.put("stability", stability);
        result.put("percent_stable", percentStable);
        result.put("fatigue", fatigue);
        result.put("duration_penalty", durationPenalty);
        result.put("duration_penalty_reason", durationPenaltyReason);
        result.put("scoring_mode", config.scoringMode);
        result.put("max_stable_duration", round2(maxStableDuration));
        result.put("seven_second_bonus", bonusApplied);
        result.put("quality_bonus_applied", qualityBonusApplied);
        result.put("stable_stretch_mean_qdiff", stableStretchMeanQdiff);
        result.put("risk_level", riskLevel(movementScore));

        return result;
    }

    // =========================================================================
    // Live recording indicator — lightweight stability check
    // =========================================================================

    /**
     * Returns the longest continuous stable duration (seconds) from accumulated
     * sensor data. Call periodically (~500ms) during recording for the live
     * stability indicator.
     *
     * This skips warmup rows and uses config.stableThreshold (0.007).
     * Much cheaper than full analyze() — no windowing or scoring.
     *
     * @param qDiff      Accumulated q_diff values (including warmup)
     * @param sensortime Accumulated sensortime values in milliseconds
     * @return Longest continuous stable stretch in seconds, or 0.0 if insufficient data
     */
    public double getCurrentStableDuration(double[] qDiff, double[] sensortime) {
        int startRow = config.startRowSkip;
        if (startRow >= qDiff.length) return 0.0;

        double[] qd = Arrays.copyOfRange(qDiff, startRow, qDiff.length);
        double[] st = Arrays.copyOfRange(sensortime, startRow, sensortime.length);

        if (qd.length < 2) return 0.0;

        double maxDuration = 0.0;
        Integer currentStart = null;

        for (int i = 0; i < qd.length; i++) {
            if (qd[i] < config.stableThreshold) {
                if (currentStart == null) currentStart = i;
            } else {
                if (currentStart != null) {
                    double dur = durationSeconds(st, currentStart, i);
                    if (dur > maxDuration) maxDuration = dur;
                    currentStart = null;
                }
            }
        }

        // Check final stretch (still stable at end of data)
        if (currentStart != null) {
            double dur = durationSeconds(st, currentStart, qd.length - 1);
            if (dur > maxDuration) maxDuration = dur;
        }

        return Math.round(maxDuration * 100.0) / 100.0;
    }

    /**
     * Returns the usable test duration in seconds (after warmup skip).
     * Call during recording for the live duration indicator.
     *
     * @param sensortime Accumulated sensortime values in milliseconds
     * @return Duration in seconds after warmup, or 0.0 if insufficient data
     */
    public double getCurrentTestDuration(double[] sensortime) {
        int startRow = config.startRowSkip;
        if (startRow >= sensortime.length) return 0.0;

        return Math.round(
            durationSeconds(sensortime, startRow, sensortime.length - 1) * 100.0
        ) / 100.0;
    }

    // =========================================================================
    // Risk level
    // =========================================================================

    public String riskLevel(double score) {
        if (score < config.criticalBelow) return "Critical";
        if (score < config.highBelow) return "High";
        if (score < config.moderateBelow) return "Moderate";
        return "Low";
    }

    // =========================================================================
    // CSV parsing
    // =========================================================================

    /**
     * Parse CSV content into [qDiff[], sensortime[]] arrays.
     * Returns null if the CSV is too short or unparseable.
     */
    private double[][] parseCsv(String content) {
        String[] lines = content.split("\n");
        int dataStart = config.headerRows;
        if (lines.length <= dataStart) return null;

        List<Double> qDiffList = new ArrayList<>();
        List<Double> stList = new ArrayList<>();

        int maxCol = Math.max(config.qDiffColumn, config.sensortimeColumn);

        for (int i = dataStart; i < lines.length; i++) {
            String line = lines[i].trim();
            if (line.isEmpty()) continue;

            String[] cols = line.split(",");
            if (cols.length <= maxCol) continue;

            try {
                double qd = Double.parseDouble(cols[config.qDiffColumn].trim());
                double st = Double.parseDouble(cols[config.sensortimeColumn].trim());
                qDiffList.add(qd);
                stList.add(st);
            } catch (NumberFormatException e) {
                // skip malformed rows
            }
        }

        if (qDiffList.isEmpty()) return null;

        double[] qDiff = new double[qDiffList.size()];
        double[] sensortime = new double[stList.size()];
        for (int i = 0; i < qDiffList.size(); i++) {
            qDiff[i] = qDiffList.get(i);
            sensortime[i] = stList.get(i);
        }

        return new double[][]{qDiff, sensortime};
    }

    // =========================================================================
    // Scoring functions (exact match to Python tuning_engine.py)
    // =========================================================================

    /**
     * Piecewise linear interpolation through knot points.
     * Uses the config's scoring curve by default, or custom knots if provided.
     */
    private double interpolate(double x, double[] knotsX, double[] knotsY) {
        if (x <= knotsX[0]) return knotsY[0];
        if (x >= knotsX[knotsX.length - 1]) return knotsY[knotsY.length - 1];

        int lo = 0, hi = knotsX.length - 1;
        while (hi - lo > 1) {
            int mid = (lo + hi) / 2;
            if (knotsX[mid] <= x) {
                lo = mid;
            } else {
                hi = mid;
            }
        }

        double t = (x - knotsX[lo]) / (knotsX[hi] - knotsX[lo]);
        return knotsY[lo] + t * (knotsY[hi] - knotsY[lo]);
    }

    /**
     * Score a q-diff range value. Returns the score (0-10).
     * With rescale=none (production), score == raw.
     */
    private double scoreMovement(double qDiffRange) {
        double raw = interpolate(qDiffRange, config.knotsX, config.knotsY);
        raw = clamp010(raw);

        // Cap perfectly stable tests
        if (qDiffRange < config.stableCapThreshold) {
            raw = config.stableCapScore;
        }

        return raw;
    }

    // =========================================================================
    // Window helpers
    // =========================================================================

    private static double durationSeconds(double[] st, int startIdx, int endIdx) {
        if (startIdx >= endIdx || endIdx >= st.length) return 0.0;
        return (st[endIdx] - st[startIdx]) / 1000.0;
    }

    /**
     * Find number of samples spanning targetSeconds from startIdx.
     * Uses binary search (matches np.searchsorted in Python).
     */
    private static Integer findWindowSize(double[] st, double targetSeconds, int startIdx) {
        if (startIdx >= st.length - 1) return null;
        double targetTime = st[startIdx] + targetSeconds * 1000.0;

        // Binary search for first index >= targetTime
        int lo = startIdx, hi = st.length;
        while (lo < hi) {
            int mid = (lo + hi) / 2;
            if (st[mid] < targetTime) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        if (lo >= st.length) return null;
        return lo - startIdx;
    }

    // =========================================================================
    // Fatigue calculation
    // =========================================================================

    private Map<String, Object> calculateFatigue(Map<String, Map<String, Object>> windows) {
        Map<String, Object> w3 = windows.get("3_second");
        Map<String, Object> w5 = windows.get("5_second");
        Map<String, Object> w7 = windows.get("7_second");

        Double s3 = w3 != null ? ((Number) w3.get("score")).doubleValue() : null;
        Double s5 = w5 != null ? ((Number) w5.get("score")).doubleValue() : null;
        Double s7 = w7 != null ? ((Number) w7.get("score")).doubleValue() : null;

        Map<String, Object> result = new HashMap<>();

        if (s3 == null || s7 == null) {
            result.put("pattern", "insufficient");
            result.put("flagged", false);
            return result;
        }

        double shortTerm = (s5 != null) ? Math.max(s3, s5) : s3;
        double drop = round2(shortTerm - s7);
        Double drop5to7 = (s5 != null) ? round2(s5 - s7) : null;

        String pattern;
        if (shortTerm >= config.fatiguesMinShortTerm && drop > config.fatiguesMinDrop) {
            pattern = "fatigues";
        } else if (shortTerm >= config.consistentMinShortTerm && drop <= config.consistentMaxDrop) {
            pattern = "consistent";
        } else if (shortTerm < config.unstableMaxShortTerm && drop <= config.unstableMaxDrop) {
            pattern = "unstable";
        } else {
            pattern = "declining";
        }

        result.put("short_term_score", round2(shortTerm));
        result.put("drop_3to7", drop);
        result.put("drop_5to7", drop5to7);
        result.put("pattern", pattern);
        result.put("flagged", "fatigues".equals(pattern));

        return result;
    }

    // =========================================================================
    // Math utilities
    // =========================================================================

    private static double clamp010(double v) {
        return Math.max(0.0, Math.min(10.0, v));
    }

    private static double arrayMin(double[] arr) {
        double m = arr[0];
        for (int i = 1; i < arr.length; i++) {
            if (arr[i] < m) m = arr[i];
        }
        return m;
    }

    /**
     * Percentile calculation matching numpy's default (linear interpolation).
     */
    private static double percentile(double[] data, int pct) {
        double[] sorted = data.clone();
        Arrays.sort(sorted);
        if (sorted.length == 1) return sorted[0];

        double rank = (pct / 100.0) * (sorted.length - 1);
        int lower = (int) Math.floor(rank);
        int upper = (int) Math.ceil(rank);

        if (lower == upper) return sorted[lower];

        double fraction = rank - lower;
        return sorted[lower] + fraction * (sorted[upper] - sorted[lower]);
    }

    private static boolean containsInt(int[] arr, int value) {
        for (int v : arr) {
            if (v == value) return true;
        }
        return false;
    }

    private static double round1(double v) {
        return Math.round(v * 10.0) / 10.0;
    }

    private static double round2(double v) {
        return Math.round(v * 100.0) / 100.0;
    }

    private static double round6(double v) {
        return Math.round(v * 1000000.0) / 1000000.0;
    }

    /**
     * Convert results to JSON string for storage/upload.
     */
    public static String resultsToJson(Map<String, Object> results) {
        try {
            JSONObject json = mapToJson(results);
            return json.toString();
        } catch (Exception e) {
            Log.e(TAG, "Error converting results to JSON: " + e.getMessage());
            return "{}";
        }
    }

    @SuppressWarnings("unchecked")
    private static JSONObject mapToJson(Map<String, Object> map) throws Exception {
        JSONObject json = new JSONObject();
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            Object val = entry.getValue();
            if (val == null) {
                json.put(entry.getKey(), JSONObject.NULL);
            } else if (val instanceof Map) {
                json.put(entry.getKey(), mapToJson((Map<String, Object>) val));
            } else {
                json.put(entry.getKey(), val);
            }
        }
        return json;
    }
}
