Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
206 lines
7.3 KiB
Python
206 lines
7.3 KiB
Python
"""
|
|
Unit tests for category-aware intervention metrics (v5).

Run from the project root:

    python -m pytest tests/test_intervention_metrics.py -v
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import math
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from src.utils.metrics import intervention_metrics
|
|
|
|
# Category index constants (matches PRIMARY_CATEGORY_LIST): each _Rk holds the
# zero-based position of category Rk in that list, i.e. k - 1.
_R1, _R2, _R3 = 0, 1, 2  # _R1: Self-harm & Suicide Risk
_R6 = 5  # Violence
_R7 = 6  # Minor Inappropriate Intimacy
_R10 = 9  # Persona Immersion Failure

# Intervention action ids, numbered 0..4 in this order. They are used both for
# predicted actions and for recommended actions in the tests below.
(
    ACTION_PASS,
    ACTION_WARN,
    ACTION_REWRITE,
    ACTION_REJECT,
    ACTION_CRISIS,
) = range(5)
|
|
|
|
|
|
# ── Basic safety metrics ──────────────────────────────────────────────────────
|
|
|
|
def test_safety_recall_all_correct():
    """Every L3/L4 sample receives REWRITE or stronger, so safety_recall is 1.0."""
    labels = [1] * 3
    levels = [3, 4, 3]
    actions = [ACTION_REWRITE, ACTION_CRISIS, ACTION_REJECT]
    metrics = intervention_metrics(labels, levels, actions)
    assert metrics["safety_recall"] == 1.0
|
|
|
|
|
|
def test_safety_recall_partial():
    """The L3 sample is PASSed and the L4 sample gets CRISIS, so safety_recall is 0.5."""
    metrics = intervention_metrics([1, 1], [3, 4], [ACTION_PASS, ACTION_CRISIS])
    assert metrics["safety_recall"] == 0.5
|
|
|
|
|
|
def test_over_refusal_zero():
    """L0 samples handled only with PASS/WARN give an over_refusal of 0.0."""
    safe_labels = [0, 0]
    safe_levels = [0, 0]
    mild_actions = [ACTION_PASS, ACTION_WARN]
    result = intervention_metrics(safe_labels, safe_levels, mild_actions)
    assert result["over_refusal"] == 0.0
|
|
|
|
|
|
def test_over_refusal_nonzero():
    """One of two L0 samples is REWRITten, so over_refusal is 0.5."""
    result = intervention_metrics([0] * 2, [0] * 2, [ACTION_PASS, ACTION_REWRITE])
    error = abs(result["over_refusal"] - 0.5)
    assert error < 1e-6
|
|
|
|
|
|
# ── Action accuracy ───────────────────────────────────────────────────────────
|
|
|
|
def test_action_accuracy_perfect():
    """Predicted actions equal the recommended ones everywhere, so action_accuracy is 1.0."""
    expected_actions = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE]
    metrics = intervention_metrics(
        [0, 1, 1],
        [0, 2, 3],
        list(expected_actions),  # separate list object for the predictions
        expected_actions,
    )
    assert metrics["action_accuracy"] == 1.0
|
|
|
|
|
|
def test_action_accuracy_none():
    """No predicted action matches its recommendation, so action_accuracy is 0.0."""
    predicted = [ACTION_WARN, ACTION_PASS]
    recommended = [ACTION_PASS, ACTION_REWRITE]
    metrics = intervention_metrics([0, 1], [0, 3], predicted, recommended)
    assert metrics["action_accuracy"] == 0.0
|
|
|
|
|
|
# ── crisis_precision_l4 ───────────────────────────────────────────────────────
|
|
|
|
def test_crisis_precision_l4_perfect():
    """Every CRISIS action lands on an L4 sample, so crisis_precision_l4 is 1.0."""
    n_samples = 2
    metrics = intervention_metrics(
        [1] * n_samples, [4] * n_samples, [ACTION_CRISIS] * n_samples
    )
    assert metrics["crisis_precision_l4"] == 1.0
|
|
|
|
|
|
def test_crisis_precision_l4_imperfect():
    """CRISIS is issued on one L3 and one L4 sample, so crisis_precision_l4 is 0.5."""
    levels = [3, 4]
    metrics = intervention_metrics([1, 1], levels, [ACTION_CRISIS] * len(levels))
    assert metrics["crisis_precision_l4"] == 0.5
|
|
|
|
|
|
def test_crisis_precision_l4_no_crisis_is_nan():
    """With no CRISIS predictions, crisis_precision_l4 is expected to be NaN."""
    actions = [ACTION_REWRITE, ACTION_REJECT]
    value = intervention_metrics([1, 1], [3, 4], actions)["crisis_precision_l4"]
    assert math.isnan(value)
|
|
|
|
|
|
# ── Category-aware metrics ────────────────────────────────────────────────────
|
|
|
|
def test_crisis_appropriateness_l4():
    """CRISIS counts as appropriate on L4, but not on L3 for a non-R1 category."""
    # Each row is one sample: (primary category, y_risk, risk level, predicted action).
    rows = [
        (_R2, 1, 4, ACTION_CRISIS),  # L4 -> appropriate
        (_R2, 1, 3, ACTION_CRISIS),  # non-R1 at L3 -> not appropriate
    ]
    categories, labels, levels, actions = (list(column) for column in zip(*rows))
    metrics = intervention_metrics(labels, levels, actions, c_primary_idx=categories)
    assert metrics["crisis_appropriateness"] == 0.5
|
|
|
|
|
|
def test_crisis_appropriateness_r1_l3():
    """CRISIS counts as appropriate for R1 at L3, but not for R2 at L3."""
    # Each row is one sample: (primary category, y_risk, risk level, predicted action).
    rows = [
        (_R1, 1, 3, ACTION_CRISIS),  # R1 + L3 -> appropriate
        (_R2, 1, 3, ACTION_CRISIS),  # R2 + L3 -> not appropriate
    ]
    categories, labels, levels, actions = (list(column) for column in zip(*rows))
    metrics = intervention_metrics(labels, levels, actions, c_primary_idx=categories)
    assert metrics["crisis_appropriateness"] == 0.5
|
|
|
|
|
|
def test_r1_high_critical_crisis_recall():
    """Of the R1 samples at L3/L4, one gets CRISIS and one does not, so recall is 0.5."""
    # Each row is one sample: (primary category, y_risk, risk level, predicted action).
    rows = [
        (_R1, 1, 3, ACTION_CRISIS),   # R1 + L3, CRISIS -> hit
        (_R1, 1, 4, ACTION_REWRITE),  # R1 + L4, REWRITE -> miss
        (_R2, 1, 4, ACTION_CRISIS),   # not R1 -> excluded from the metric
    ]
    categories, labels, levels, actions = (list(column) for column in zip(*rows))
    metrics = intervention_metrics(labels, levels, actions, c_primary_idx=categories)
    assert metrics["R1_high_critical_crisis_recall"] == 0.5
|
|
|
|
|
|
def test_reject_rate_r6r7_high():
    """R6/R7 samples at L3/L4: only the REJECT action counts toward the rate.

    There are three qualifying samples (R6+L3, R7+L4, R6+L3). They receive
    REJECT, REWRITE and CRISIS respectively, so the expected
    reject_rate_R6_R7_high is 1/3.
    """
    c_primary = [_R6, _R7, _R6]
    y_risk = [1, 1, 1]
    l_risk = [3, 4, 3]
    a_pred = [ACTION_REJECT, ACTION_REWRITE, ACTION_CRISIS]
    m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
    # Neither REWRITE nor CRISIS counts as a rejection: 1 REJECT out of 3 samples.
    assert math.isclose(
        m["reject_rate_R6_R7_high"], 1.0 / 3.0, rel_tol=0.0, abs_tol=1e-6
    )
|
|
|
|
|
|
def test_strong_intervention_rate_l1():
    """Two of four L1 samples receive REWRITE or stronger, so the L1 rate is 0.5."""
    levels = [1] * 4
    # NOTE(review): the third sample pairs y_risk=0 with level L1. Every other
    # test in this module pairs y_risk=0 only with L0, so confirm this is
    # intentional.
    labels = [1, 1, 0, 1]
    actions = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_CRISIS]
    metrics = intervention_metrics(labels, levels, actions, c_primary_idx=[_R2] * 4)
    assert metrics["strong_intervention_rate_L1"] == 0.5
|
|
|
|
|
|
def test_per_category_action_dist_present():
    """Passing c_primary_idx produces a per-category action distribution.

    All five samples are R1 at L4 and predicted as CRISIS, so the R1
    distribution is expected to put a share of 1.0 on ACTION_CRISIS.
    """
    count = 5
    metrics = intervention_metrics(
        [1] * count,
        [4] * count,
        [ACTION_CRISIS] * count,
        c_primary_idx=[_R1] * count,
    )
    assert "per_category_action_dist" in metrics
    per_category = metrics["per_category_action_dist"]
    assert "R1" in per_category
    crisis_share = per_category["R1"]["action_dist"][ACTION_CRISIS]
    assert abs(crisis_share - 1.0) < 1e-6
|
|
|
|
|
|
def test_no_category_idx_skips_cat_metrics():
    """Without c_primary_idx, the category-aware metric keys must be absent."""
    metrics = intervention_metrics([1, 1], [3, 4], [ACTION_CRISIS, ACTION_CRISIS])
    for key in (
        "R1_high_critical_crisis_recall",
        "crisis_appropriateness",
        "reject_rate_R6_R7_high",
    ):
        assert key not in metrics, key
|
|
|
|
|
|
# ── Per-level action distribution ────────────────────────────────────────────
|
|
|
|
def test_per_level_action_dist_keys():
    """per_level_action_dist has entries named L0_Safe and L4_Critical."""
    levels = list(range(5))
    labels = [0] + [1] * 4
    actions = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_REJECT, ACTION_CRISIS]
    per_level = intervention_metrics(labels, levels, actions)["per_level_action_dist"]
    assert "L0_Safe" in per_level
    assert "L4_Critical" in per_level
|