"""
Unit tests for category-aware intervention metrics (v5).

Run from project root:
    python -m pytest tests/test_intervention_metrics.py -v
"""
import math
import os
import sys

# Put the project root (the parent of this tests/ directory) on sys.path so
# that ``src.`` imports resolve regardless of how pytest is invoked.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.utils.metrics import intervention_metrics  # noqa: E402  (needs sys.path tweak above)

# Category index constants (matches PRIMARY_CATEGORY_LIST)
_R1 = 0   # Self-harm & Suicide Risk
_R2 = 1
_R3 = 2
_R6 = 5   # Violence
_R7 = 6   # Minor Inappropriate Intimacy
_R10 = 9  # Persona Immersion Failure

# Action indices, as passed to intervention_metrics via a_pred / a_recommend.
ACTION_PASS = 0
ACTION_WARN = 1
ACTION_REWRITE = 2
ACTION_REJECT = 3
ACTION_CRISIS = 4

# Absolute tolerance used for every floating-point metric comparison.
_TOL = 1e-6


def _approx(actual, expected, tol=_TOL):
    """Return True if *actual* is within absolute tolerance *tol* of *expected*.

    Metric values are ratios, so exact ``==`` is fragile for values such as
    1/3; routing every comparison through one helper keeps tests consistent.
    """
    return math.isclose(actual, expected, rel_tol=0.0, abs_tol=tol)


# ── Basic safety metrics ──────────────────────────────────────────────────────

def test_safety_recall_all_correct():
    """All L3/L4 samples get REWRITE+ → safety_recall = 1.0."""
    y_risk = [1, 1, 1]
    l_risk = [3, 4, 3]
    a_pred = [ACTION_REWRITE, ACTION_CRISIS, ACTION_REJECT]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert _approx(m["safety_recall"], 1.0)


def test_safety_recall_partial():
    """One of two L3/L4 samples is PASS → recall = 0.5."""
    y_risk = [1, 1]
    l_risk = [3, 4]
    a_pred = [ACTION_PASS, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert _approx(m["safety_recall"], 0.5)


def test_over_refusal_zero():
    """No L0 samples get REWRITE+ → over_refusal = 0.0."""
    y_risk = [0, 0]
    l_risk = [0, 0]
    a_pred = [ACTION_PASS, ACTION_WARN]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert _approx(m["over_refusal"], 0.0)


def test_over_refusal_nonzero():
    """One L0 sample gets REWRITE → over_refusal = 0.5."""
    y_risk = [0, 0]
    l_risk = [0, 0]
    a_pred = [ACTION_PASS, ACTION_REWRITE]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert _approx(m["over_refusal"], 0.5)


# ── Action accuracy ───────────────────────────────────────────────────────────

def test_action_accuracy_perfect():
    """Predicted action equals the recommended action everywhere → 1.0."""
    y_risk = [0, 1, 1]
    l_risk = [0, 2, 3]
    a_pred = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE]
    a_recommend = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE]
    m = intervention_metrics(y_risk, l_risk, a_pred, a_recommend)
    assert _approx(m["action_accuracy"], 1.0)


def test_action_accuracy_none():
    """Predicted action never equals the recommended action → 0.0."""
    y_risk = [0, 1]
    l_risk = [0, 3]
    a_pred = [ACTION_WARN, ACTION_PASS]
    a_recommend = [ACTION_PASS, ACTION_REWRITE]
    m = intervention_metrics(y_risk, l_risk, a_pred, a_recommend)
    assert _approx(m["action_accuracy"], 0.0)


# ── crisis_precision_l4 ───────────────────────────────────────────────────────

def test_crisis_precision_l4_perfect():
    """All CRISIS actions are on L4 → crisis_precision_l4 = 1.0."""
    y_risk = [1, 1]
    l_risk = [4, 4]
    a_pred = [ACTION_CRISIS, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert _approx(m["crisis_precision_l4"], 1.0)


def test_crisis_precision_l4_imperfect():
    """One CRISIS on L3, one on L4 → precision = 0.5."""
    y_risk = [1, 1]
    l_risk = [3, 4]
    a_pred = [ACTION_CRISIS, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert _approx(m["crisis_precision_l4"], 0.5)


def test_crisis_precision_l4_no_crisis_is_nan():
    """No CRISIS actions → crisis_precision_l4 should be nan."""
    y_risk = [1, 1]
    l_risk = [3, 4]
    a_pred = [ACTION_REWRITE, ACTION_REJECT]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert math.isnan(m["crisis_precision_l4"])


# ── Category-aware metrics ────────────────────────────────────────────────────

def test_crisis_appropriateness_l4():
    """CRISIS on L4 is appropriate; CRISIS on a non-R1 L3 sample is not → 0.5."""
    c_primary = [_R2, _R2]
    y_risk = [1, 1]
    l_risk = [4, 3]
    a_pred = [ACTION_CRISIS, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
    # L4 is appropriate; L3 non-R1 is not
    assert _approx(m["crisis_appropriateness"], 0.5)


def test_crisis_appropriateness_r1_l3():
    """CRISIS on R1+L3 is appropriate; CRISIS on R2+L3 is not → 0.5."""
    c_primary = [_R1, _R2]
    y_risk = [1, 1]
    l_risk = [3, 3]
    a_pred = [ACTION_CRISIS, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
    # R1+L3 is appropriate; R2+L3 is not
    assert _approx(m["crisis_appropriateness"], 0.5)


def test_r1_high_critical_crisis_recall():
    """R1+L3/L4: one chooses CRISIS, one does not → recall = 0.5."""
    c_primary = [_R1, _R1, _R2]
    y_risk = [1, 1, 1]
    l_risk = [3, 4, 4]
    a_pred = [ACTION_CRISIS, ACTION_REWRITE, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
    # R1+L3 → CRISIS ✓; R1+L4 → REWRITE ✗; R2+L4 is not R1 so excluded
    assert _approx(m["R1_high_critical_crisis_recall"], 0.5)


def test_reject_rate_r6r7_high():
    """R6/R7 samples at L3/L4: only REJECT counts (REWRITE and CRISIS do not)."""
    c_primary = [_R6, _R7, _R6]
    y_risk = [1, 1, 1]
    l_risk = [3, 4, 3]
    a_pred = [ACTION_REJECT, ACTION_REWRITE, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
    # 1 REJECT out of 3 R6/R7 with L3/L4
    assert _approx(m["reject_rate_R6_R7_high"], 1.0 / 3.0)


def test_strong_intervention_rate_l1():
    """L1: 2 out of 4 get REWRITE+ → rate = 0.5."""
    c_primary = [_R2, _R2, _R2, _R2]
    # NOTE(review): the third sample has y_risk=0 but l_risk=1; presumably the
    # L1 rate is keyed on l_risk only — confirm this mix is intentional.
    y_risk = [1, 1, 0, 1]
    l_risk = [1, 1, 1, 1]
    a_pred = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
    assert _approx(m["strong_intervention_rate_L1"], 0.5)


def test_per_category_action_dist_present():
    """per_category_action_dist should be computed when c_primary_idx is provided."""
    c_primary = [_R1, _R1, _R1, _R1, _R1]  # 5 R1 samples
    y_risk = [1, 1, 1, 1, 1]
    l_risk = [4, 4, 4, 4, 4]
    a_pred = [ACTION_CRISIS] * 5
    m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
    assert "per_category_action_dist" in m
    assert "R1" in m["per_category_action_dist"]
    dist = m["per_category_action_dist"]["R1"]["action_dist"]
    assert _approx(dist[ACTION_CRISIS], 1.0)


def test_no_category_idx_skips_cat_metrics():
    """Without c_primary_idx, category-aware metrics should be absent."""
    y_risk = [1, 1]
    l_risk = [3, 4]
    a_pred = [ACTION_CRISIS, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert "R1_high_critical_crisis_recall" not in m
    assert "crisis_appropriateness" not in m
    assert "reject_rate_R6_R7_high" not in m


# ── Per-level action distribution ─────────────────────────────────────────────

def test_per_level_action_dist_keys():
    """One sample per level L0–L4: per_level_action_dist is keyed by level name."""
    y_risk = [0, 1, 1, 1, 1]
    l_risk = [0, 1, 2, 3, 4]
    a_pred = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_REJECT, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred)
    assert "L0_Safe" in m["per_level_action_dist"]
    assert "L4_Critical" in m["per_level_action_dist"]