#!/usr/bin/env python3
"""
py-analysis.py - Deterministic Balance Test Clinical Analysis

Generates comprehensive clinical analysis from test data WITHOUT using AI API.
Produces the same JSON structure as analyze_patient_json_ai_single.py:
  summary, risk_level, trend, icd10_codes, interventions

Called by back_py-analysis.php with the same input format.
"""

import json
import sys
import argparse
import math
from datetime import datetime

# Test type key -> human-readable label. Keys are, in order: left/right leg
# with eyes open, then left/right leg with eyes closed.
TYPE_LABELS = dict(
    lbeo='Left Leg Eyes Open',
    rbeo='Right Leg Eyes Open',
    lbec='Left Leg Eyes Closed',
    rbec='Right Leg Eyes Closed',
)
# Short display codes are simply the upper-cased keys ('lbeo' -> 'LBEO'),
# kept in the same key order as TYPE_LABELS.
TYPE_SHORT = {key: key.upper() for key in TYPE_LABELS}


def risk_level_from_score(score):
    """Map a 0-10 balance score to a risk level.

    Score-to-risk mapping matching scoring_config.json thresholds:
    below 6 -> 'high', 6 up to (not including) 8 -> 'moderate',
    8 and above -> 'low'. Scores below 4 and between 4 and 6 both map to
    'high', so they share a single branch.

    Args:
        score: The overall score, or None when no score is available.

    Returns:
        'high', 'moderate', 'low', or 'unknown' when *score* is None or NaN.
    """
    # NaN compares False against every threshold and would otherwise fall
    # through to 'low' -- the least safe answer for unusable data.
    if score is None or (isinstance(score, float) and math.isnan(score)):
        return 'unknown'
    if score < 6:
        return 'high'
    if score < 8:
        return 'moderate'
    return 'low'


def score_descriptor(score):
    """Return a plain-English label for a score.

    None yields 'not tested'; otherwise the first band whose lower bound the
    score reaches wins, falling back to 'very poor' below 3.
    """
    if score is None:
        return 'not tested'
    bands = (
        (9, 'excellent'),
        (7, 'good'),
        (5, 'fair'),
        (3, 'poor'),
    )
    for floor, label in bands:
        if score >= floor:
            return label
    return 'very poor'


def format_score(score):
    """Format a score for display.

    Whole numbers are shown without a decimal point ('7'); other values are
    shown with one decimal ('6.3'). None and non-finite floats (NaN, +/-inf)
    are shown as 'N/A'.
    """
    if score is None:
        return 'N/A'
    # int(nan) raises ValueError and int(inf) raises OverflowError, which
    # would abort the whole report for a single bad value.
    if isinstance(score, float) and not math.isfinite(score):
        return 'N/A'
    if score == int(score):
        return str(int(score))
    return f"{score:.1f}"


def compute_trend(current_score, previous_tests, threshold=1.5):
    """Classify the current score against the mean of previous test scores.

    Args:
        current_score: This session's overall score, or None.
        previous_tests: Iterable of dicts; each dict's ``score`` is used when
            it is a finite int/float. Bools, strings, None, NaN/inf and
            non-dict entries are skipped.
        threshold: Minimum absolute difference from the previous mean that
            counts as a change. Defaults to 1.5 points.

    Returns:
        'improving' if current - mean > threshold, 'declining' if
        current - mean < -threshold, 'stable' otherwise, or
        'insufficient_data' when there is no usable current score or no
        usable previous score.
    """
    if not previous_tests or current_score is None:
        return 'insufficient_data'
    # A NaN current score makes every comparison False and would silently
    # fall through to 'stable'.
    if isinstance(current_score, float) and math.isnan(current_score):
        return 'insufficient_data'

    prev_scores = []
    for test in previous_tests:
        if not isinstance(test, dict):
            continue
        s = test.get('score')
        # bool is an int subclass (True == 1) and NaN/inf poison the mean.
        if isinstance(s, (int, float)) and not isinstance(s, bool) and math.isfinite(s):
            prev_scores.append(float(s))

    if not prev_scores:
        return 'insufficient_data'

    diff = current_score - sum(prev_scores) / len(prev_scores)
    if diff > threshold:
        return 'improving'
    if diff < -threshold:
        return 'declining'
    return 'stable'


def analyze_type(data):
    """Extract the fields used by the report from one test type's results entry.

    Args:
        data: The per-type dict from ``results_json`` (e.g. ``results['lbeo']``),
            or None when that condition was not tested.

    Returns:
        None if *data* is None. Otherwise a flat dict with ``score``, timing and
        scoring metadata, plus normalised ``windows``, ``stability`` and
        ``fatigue`` sub-dicts. Missing, ``null`` or non-object nested sections
        are treated as empty, so ``stability`` and ``fatigue`` always carry the
        same keys and ``windows`` only holds well-formed window entries.
    """
    if data is None:
        return None

    def section(key):
        # A .get() default only covers a *missing* key; an explicit JSON null
        # (or a non-object) must also become {} to avoid AttributeError below.
        value = data.get(key)
        return value if isinstance(value, dict) else {}

    info = {
        'score': data.get('movement_score'),
        'test_duration': data.get('test_duration'),
        'max_stable_duration': data.get('max_stable_duration'),
        'scoring_mode': data.get('scoring_mode', 'qdiff'),
        'duration_penalty': data.get('duration_penalty', 0),
        'duration_penalty_reason': data.get('duration_penalty_reason'),
        'seven_second_bonus': data.get('seven_second_bonus', 0),
    }

    # Windows: keep only the 3s/5s/7s entries that are present and well-formed.
    windows = section('windows')
    info['windows'] = {}
    for wkey in ('3_second', '5_second', '7_second'):
        w = windows.get(wkey)
        if isinstance(w, dict):
            info['windows'][wkey] = {
                'score': w.get('score'),
                'q_diff_range': w.get('q_diff_range'),
                'actual_duration': w.get('actual_duration'),
            }

    # Stability
    stability = section('stability')
    info['stability'] = {
        'percent_stable': stability.get('percent_stable'),
        'total_stable_time': stability.get('total_stable_time'),
        'continuous_duration': stability.get('continuous_duration'),
    }

    # Fatigue (a null/empty pattern is reported as 'unknown', not 'None').
    fatigue = section('fatigue')
    info['fatigue'] = {
        'flagged': fatigue.get('flagged', False),
        'pattern': fatigue.get('pattern') or 'unknown',
        'drop_3to7': fatigue.get('drop_3to7'),
        'short_term_score': fatigue.get('short_term_score'),
    }

    return info


def _is_number(value):
    """Return True if *value* is a real, finite int/float.

    Booleans are rejected even though ``bool`` subclasses ``int``. NaN and
    infinities are rejected because they break min() and threshold
    comparisons and cannot be formatted.
    """
    return (
        isinstance(value, (int, float))
        and not isinstance(value, bool)
        and math.isfinite(value)
    )


def _age_phrase(age):
    """Return the demographic phrase that opens the summary (e.g. '72-year-old patient')."""
    if age is None or str(age).strip().lower() in ('', 'unknown'):
        return 'Patient of unknown age'
    return f'{age}-year-old patient'


def _type_detail(ttype, td, score):
    """Return the summary fragment (without trailing period) for one tested condition.

    Args:
        ttype: Test-type key such as 'lbeo'.
        td: The dict produced by analyze_type() for that type.
        score: The validated movement score (a finite number or None).
    """
    detail = f"{TYPE_SHORT[ttype]}: {format_score(score)}/10"

    # Window degradation: report when the last window is > 2 points below the first.
    windows = td['windows']
    window_scores = [
        (wkey.replace('_second', 's'), windows[wkey]['score'])
        for wkey in ('3_second', '5_second', '7_second')
        if wkey in windows
    ]
    if len(window_scores) >= 2:
        (first_label, first_score), (last_label, last_score) = window_scores[0], window_scores[-1]
        if _is_number(first_score) and _is_number(last_score) and first_score - last_score > 2:
            detail += (
                f" (degrades from {format_score(first_score)} at {first_label}"
                f" to {format_score(last_score)} at {last_label})"
            )

    pct = td['stability'].get('percent_stable')
    if _is_number(pct):
        detail += f", {pct:.0f}% stable"

    max_dur = td.get('max_stable_duration')
    if _is_number(max_dur):
        detail += f", max stable {max_dur:.1f}s"

    penalty = td.get('duration_penalty', 0)
    if _is_number(penalty) and penalty > 0:
        detail += f" ({penalty}-point short-test penalty applied)"

    bonus = td.get('seven_second_bonus', 0)
    if _is_number(bonus) and bonus > 0:
        detail += f" (7s bonus: score capped at {bonus})"

    return detail


def build_analysis(test_data):
    """Build the complete deterministic clinical analysis for one test session.

    Args:
        test_data: Input payload dict. Keys read: ``results_json`` (a JSON
            string or an already-decoded dict keyed by 'lbeo'/'rbeo'/'lbec'/
            'rbec'), ``age``, ``complaints`` (free-text string, typically
            comma-separated; a list of strings is also accepted) and
            ``previous_tests`` (list of dicts with a ``score`` key).

    Returns:
        dict with ``summary`` (str), ``risk_level`` ('low'/'moderate'/'high'/
        'unknown'), ``trend`` (from compute_trend), ``icd10_codes`` (list of
        {'code', 'description'} dicts) and ``interventions`` (list of str).

    Raises:
        ValueError: if ``results_json`` is malformed JSON (json.JSONDecodeError)
            or does not decode to a JSON object.
    """
    all_types = ('lbeo', 'rbeo', 'lbec', 'rbec')

    # --- Parse inputs ---
    results = test_data.get('results_json')
    if isinstance(results, str):
        results = json.loads(results) if results.strip() else {}
    if results is None:  # missing key or JSON null: nothing was tested
        results = {}
    if not isinstance(results, dict):
        raise ValueError('results_json must be a JSON object keyed by test type')

    age = test_data.get('age', 'unknown')
    previous_tests = test_data.get('previous_tests') or []

    complaints = test_data.get('complaints')
    if isinstance(complaints, (list, tuple)):
        complaints = ', '.join(str(c) for c in complaints)
    complaints_lower = str(complaints or '').lower()
    # None of these keywords can occur in the 'None reported' placeholder, so
    # no separate "has complaints" check is needed.
    has_dizziness = 'dizziness' in complaints_lower
    has_numbness = 'numbness' in complaints_lower
    needs_assistance = 'requires assistance' in complaints_lower

    # --- Per-type extraction ---
    types_data = {t: analyze_type(results.get(t)) for t in all_types}
    tested_types = [t for t in all_types if types_data[t] is not None]
    untested_types = [t for t in all_types if types_data[t] is None]

    if not tested_types:
        return {
            'summary': f'{_age_phrase(age)} with no scorable balance test data available.',
            'risk_level': 'unknown',
            'trend': 'insufficient_data',
            'icd10_codes': [{'code': 'R26.89', 'description': 'Other abnormalities of gait and mobility'}],
            'interventions': ['Repeat balance assessment with adequate test duration'],
        }

    # Non-numeric / non-finite scores are treated as missing so they cannot
    # crash formatting or arithmetic, or skew min() further down.
    scores = {
        t: (types_data[t]['score'] if _is_number(types_data[t]['score']) else None)
        for t in tested_types
    }
    valid_scores = [s for s in scores.values() if s is not None]
    overall_score = min(valid_scores) if valid_scores else None  # overall = worst type
    worst_type = None
    if overall_score is not None:
        worst_type = next(t for t in tested_types if scores[t] == overall_score)

    fatigued_types = [t for t in tested_types if types_data[t]['fatigue']['flagged']]
    has_fatigue = bool(fatigued_types)

    # --- Summary ---
    summary_parts = []

    conditions_tested = ', '.join(TYPE_SHORT[t] for t in tested_types)
    opening = f"{_age_phrase(age)} tested on {conditions_tested}"
    if overall_score is None:
        opening += " with no valid movement score recorded."
    elif len(tested_types) == 1:
        opening += f" with {score_descriptor(overall_score)} balance ({format_score(overall_score)}/10)."
    else:
        opening += (
            f" with overall {score_descriptor(overall_score)} balance"
            f" (worst score {format_score(overall_score)}/10 on {TYPE_SHORT[worst_type]})."
        )
    if untested_types:
        conditions_missing = ', '.join(TYPE_SHORT[t] for t in untested_types)
        opening += f" {conditions_missing} not tested."
    summary_parts.append(opening)

    summary_parts.append(
        ' '.join(_type_detail(t, types_data[t], scores[t]) + '.' for t in tested_types)
    )

    if fatigued_types:
        fatigue_flags = []
        for t in fatigued_types:
            fat = types_data[t]['fatigue']
            drop = fat.get('drop_3to7')
            drop_str = f" ({format_score(drop)}-point drop)" if _is_number(drop) else ""
            fatigue_flags.append(f"{TYPE_SHORT[t]} shows {fat['pattern']} pattern{drop_str}")
        summary_parts.append('Fatigue: ' + '; '.join(fatigue_flags) + '.')

    # Left vs right asymmetry. Only pairs with a significant (>= 2 point)
    # score gap contribute a weaker side for the targeted intervention.
    lr_pairs = (('lbeo', 'rbeo', 'Eyes Open'), ('lbec', 'rbec', 'Eyes Closed'))
    weaker_sides = set()
    for left, right, condition in lr_pairs:
        if types_data[left] is None or types_data[right] is None:
            continue
        ls, rs = scores[left], scores[right]
        if ls is not None and rs is not None and abs(ls - rs) >= 2:
            if ls > rs:
                better, worse, weaker = left, right, 'right'
            else:
                better, worse, weaker = right, left, 'left'
            weaker_sides.add(weaker)
            summary_parts.append(
                f"Significant left/right asymmetry ({condition}): "
                f"{TYPE_SHORT[better]} ({format_score(max(ls, rs))}) vs "
                f"{TYPE_SHORT[worse]} ({format_score(min(ls, rs))}), "
                f"difference of {format_score(abs(ls - rs))} points."
            )

        # Stability asymmetry is independent of whether movement scores exist.
        ls_pct = types_data[left]['stability'].get('percent_stable')
        rs_pct = types_data[right]['stability'].get('percent_stable')
        if _is_number(ls_pct) and _is_number(rs_pct) and abs(ls_pct - rs_pct) >= 20:
            summary_parts.append(
                f"Stability asymmetry ({condition}): "
                f"{TYPE_SHORT[left]} {ls_pct:.0f}% vs {TYPE_SHORT[right]} {rs_pct:.0f}% stable."
            )
    has_asymmetry = bool(weaker_sides)

    # Eyes open vs eyes closed on the same leg.
    eo_ec_pairs = (('lbeo', 'lbec', 'Left'), ('rbeo', 'rbec', 'Right'))
    has_proprioceptive = False
    for eo, ec, side in eo_ec_pairs:
        if types_data[eo] is None or types_data[ec] is None:
            continue
        eo_s, ec_s = scores[eo], scores[ec]
        if eo_s is not None and ec_s is not None and eo_s - ec_s >= 2:
            has_proprioceptive = True
            summary_parts.append(
                f"{side} leg: eyes closed ({format_score(ec_s)}) significantly worse than eyes open "
                f"({format_score(eo_s)}), suggesting proprioceptive deficit."
            )

    symptom_concern = []
    if has_dizziness:
        symptom_concern.append('dizziness')
    if has_numbness:
        symptom_concern.append('numbness')
    if needs_assistance:
        symptom_concern.append('requires assistance for mobility')
    if symptom_concern:
        # Joined outside the f-string: reusing the outer quote character inside
        # an f-string expression is a SyntaxError before Python 3.12 (PEP 701).
        concerns = ', '.join(symptom_concern)
        summary_parts.append(f"Patient reports {concerns}, warranting clinical attention.")

    trend = compute_trend(overall_score, previous_tests)
    # Use the same numeric filter as the trend so a non-numeric stored score
    # cannot raise TypeError here.
    prev_scores = [
        t.get('score') for t in previous_tests
        if isinstance(t, dict) and _is_number(t.get('score'))
    ]
    if trend in ('improving', 'declining') and prev_scores:
        prev_avg = sum(prev_scores) / len(prev_scores)
        summary_parts.append(f"Trend: {trend} from previous average of {format_score(prev_avg)}.")
    elif trend == 'stable' and previous_tests:
        summary_parts.append("Trend: stable compared to previous assessments.")

    # --- Risk level ---
    risk = risk_level_from_score(overall_score)
    # Escalate when reported symptoms or fatigue are worse than the score suggests.
    if (has_dizziness or has_numbness) and risk == 'low':
        risk = 'moderate'
    if needs_assistance and risk in ('low', 'moderate'):
        risk = 'high'
    if has_fatigue and risk == 'low':
        risk = 'moderate'

    # --- ICD-10 codes ---
    icd10 = []
    if overall_score is not None and overall_score < 7:
        icd10.append({'code': 'R26.81', 'description': 'Unsteadiness on feet'})
    if has_fatigue:
        icd10.append({'code': 'R26.89', 'description': 'Other abnormalities of gait and mobility'})
    if risk == 'high':
        # NOTE(review): R29.6 is "Repeated falls" but is assigned here from the
        # computed risk alone; the input carries no fall history. Confirm this
        # coding is clinically intended.
        icd10.append({'code': 'R29.6', 'description': 'Repeated falls'})
    if has_asymmetry:
        icd10.append({'code': 'R26.2', 'description': 'Difficulty in walking, not elsewhere classified'})
    if has_proprioceptive:
        icd10.append({'code': 'R27.8', 'description': 'Other lack of coordination'})
    if has_dizziness:
        icd10.append({'code': 'R42', 'description': 'Dizziness and giddiness'})
    if has_numbness:
        icd10.append({'code': 'R20.0', 'description': 'Anesthesia of skin (numbness)'})
    if not icd10:
        # Screening code when nothing abnormal was found.
        icd10.append({'code': 'Z13.89', 'description': 'Encounter for screening for other disorder'})

    # --- Interventions ---
    interventions = []
    if overall_score is not None:
        if overall_score < 4:
            interventions.append('Urgent Physical Therapy referral for balance rehabilitation')
            interventions.append('Fall risk precautions and home safety assessment')
            interventions.append('Medication review for fall-risk contributing factors')
        elif overall_score < 7:
            interventions.append('Physical Therapy referral for balance training')
            interventions.append('Core strengthening and proprioceptive exercise program')

    if has_fatigue:
        interventions.append('Endurance training to address fatigue-related balance decline')

    if len(weaker_sides) == 1:
        (weaker,) = weaker_sides
        interventions.append(f'Targeted {weaker}-side strengthening to address asymmetry')
    elif weaker_sides:
        # Left is weaker in one condition and right in the other.
        interventions.append(
            'Targeted single-leg strengthening to address left/right asymmetry '
            '(weaker side differs between eyes-open and eyes-closed)'
        )

    if has_proprioceptive:
        interventions.append('Proprioceptive training (eyes-closed balance exercises)')
    if has_dizziness:
        interventions.append('Vestibular assessment to evaluate dizziness')
    if has_numbness:
        interventions.append('Neurological evaluation for peripheral neuropathy screening')
    if trend == 'declining':
        interventions.append('Increased monitoring frequency due to declining trend')

    # Follow-up timing (patients are tested monthly).
    if overall_score is not None:
        if overall_score < 4:
            interventions.append('Follow-up balance assessment in 2-4 weeks')
        elif overall_score < 7:
            interventions.append('Continue monthly balance assessments to track progress')
        else:
            interventions.append('Continue monthly balance assessments')

    if not interventions:
        interventions.append('Continue current activity level')
        interventions.append('Continue monthly balance assessments')

    return {
        'summary': ' '.join(summary_parts),
        'risk_level': risk,
        'trend': trend,
        'icd10_codes': icd10,
        'interventions': interventions,
    }


def main():
    """Command-line entry point.

    Reads the test payload (JSON) from ``--input`` and prints exactly one JSON
    object to stdout. On success this is
    ``{"success": true, "analysis": ..., "model": "py-analysis"}``. On any
    failure it is ``{"success": false, "error": "Analysis failed: ..."}``,
    and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Deterministic balance test clinical analysis')
    parser.add_argument('--input', required=True, help='Path to input JSON file')
    args = parser.parse_args()

    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) would mis-decode non-ASCII names or complaint text.
        with open(args.input, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        analysis = build_analysis(test_data)

        result = {
            'success': True,
            'analysis': analysis,
            'model': 'py-analysis',
        }
        print(json.dumps(result))

    except Exception as e:
        # Top-level boundary: every failure is reported as JSON so the PHP
        # caller (back_py-analysis.php) always receives a parseable response.
        error_result = {
            'success': False,
            'error': f'Analysis failed: {e}',
        }
        print(json.dumps(error_result))
        sys.exit(1)


if __name__ == '__main__':
    main()
