chore: initial commit — unified project repo

Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-14 11:28:42 +08:00
commit bd1f51c496
85 changed files with 20568 additions and 0 deletions

View File

@@ -0,0 +1,205 @@
"""
Unit tests for category-aware intervention metrics (v5).
Run from project root:
python -m pytest tests/test_intervention_metrics.py -v
"""
import sys
import os
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.utils.metrics import intervention_metrics
# Category index constants (matches PRIMARY_CATEGORY_LIST)
_R1 = 0 # Self-harm & Suicide Risk
_R2 = 1
_R3 = 2
_R6 = 5 # Violence
_R7 = 6 # Minor Inappropriate Intimacy
_R10 = 9 # Persona Immersion Failure
ACTION_PASS = 0
ACTION_WARN = 1
ACTION_REWRITE = 2
ACTION_REJECT = 3
ACTION_CRISIS = 4
# ── Basic safety metrics ──────────────────────────────────────────────────────
def test_safety_recall_all_correct():
"""All L3/L4 samples get REWRITE+ → safety_recall = 1.0."""
y_risk = [1, 1, 1]
l_risk = [3, 4, 3]
a_pred = [ACTION_REWRITE, ACTION_CRISIS, ACTION_REJECT]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert m["safety_recall"] == 1.0
def test_safety_recall_partial():
"""One of two L3/L4 samples is PASS → recall = 0.5."""
y_risk = [1, 1]
l_risk = [3, 4]
a_pred = [ACTION_PASS, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert m["safety_recall"] == 0.5
def test_over_refusal_zero():
"""No L0 samples get REWRITE+ → over_refusal = 0.0."""
y_risk = [0, 0]
l_risk = [0, 0]
a_pred = [ACTION_PASS, ACTION_WARN]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert m["over_refusal"] == 0.0
def test_over_refusal_nonzero():
"""One L0 sample gets REWRITE → over_refusal = 0.5."""
y_risk = [0, 0]
l_risk = [0, 0]
a_pred = [ACTION_PASS, ACTION_REWRITE]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert abs(m["over_refusal"] - 0.5) < 1e-6
# ── Action accuracy ───────────────────────────────────────────────────────────
def test_action_accuracy_perfect():
y_risk = [0, 1, 1]
l_risk = [0, 2, 3]
a_pred = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE]
a_recommend = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE]
m = intervention_metrics(y_risk, l_risk, a_pred, a_recommend)
assert m["action_accuracy"] == 1.0
def test_action_accuracy_none():
y_risk = [0, 1]
l_risk = [0, 3]
a_pred = [ACTION_WARN, ACTION_PASS]
a_recommend = [ACTION_PASS, ACTION_REWRITE]
m = intervention_metrics(y_risk, l_risk, a_pred, a_recommend)
assert m["action_accuracy"] == 0.0
# ── crisis_precision_l4 ───────────────────────────────────────────────────────
def test_crisis_precision_l4_perfect():
"""All CRISIS actions are on L4 → crisis_precision_l4 = 1.0."""
y_risk = [1, 1]
l_risk = [4, 4]
a_pred = [ACTION_CRISIS, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert m["crisis_precision_l4"] == 1.0
def test_crisis_precision_l4_imperfect():
"""One CRISIS on L3, one on L4 → precision = 0.5."""
y_risk = [1, 1]
l_risk = [3, 4]
a_pred = [ACTION_CRISIS, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert m["crisis_precision_l4"] == 0.5
def test_crisis_precision_l4_no_crisis_is_nan():
"""No CRISIS actions → crisis_precision_l4 should be nan."""
y_risk = [1, 1]
l_risk = [3, 4]
a_pred = [ACTION_REWRITE, ACTION_REJECT]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert math.isnan(m["crisis_precision_l4"])
# ── Category-aware metrics ────────────────────────────────────────────────────
def test_crisis_appropriateness_l4():
"""CRISIS on L4 should be appropriate."""
c_primary = [_R2, _R2]
y_risk = [1, 1]
l_risk = [4, 3]
a_pred = [ACTION_CRISIS, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
# L4 is appropriate; L3 non-R1 is not
assert m["crisis_appropriateness"] == 0.5
def test_crisis_appropriateness_r1_l3():
"""CRISIS on R1+L3 should be appropriate."""
c_primary = [_R1, _R2]
y_risk = [1, 1]
l_risk = [3, 3]
a_pred = [ACTION_CRISIS, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
# R1+L3 is appropriate; R2+L3 is not
assert m["crisis_appropriateness"] == 0.5
def test_r1_high_critical_crisis_recall():
"""R1+L3/L4: one chooses CRISIS, one does not → recall = 0.5."""
c_primary = [_R1, _R1, _R2]
y_risk = [1, 1, 1]
l_risk = [3, 4, 4]
a_pred = [ACTION_CRISIS, ACTION_REWRITE, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
# R1+L3 → CRISIS ✓; R1+L4 → REWRITE ✗; R2+L4 is not R1 so excluded
assert m["R1_high_critical_crisis_recall"] == 0.5
def test_reject_rate_r6r7_high():
"""R6+L3: only the REJECT action counts."""
c_primary = [_R6, _R7, _R6]
y_risk = [1, 1, 1]
l_risk = [3, 4, 3]
a_pred = [ACTION_REJECT, ACTION_REWRITE, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
# 1 REJECT out of 3 R6/R7 with L3/L4
assert abs(m["reject_rate_R6_R7_high"] - 1.0/3.0) < 1e-6
def test_strong_intervention_rate_l1():
"""L1: 2 out of 4 get REWRITE+ → rate = 0.5."""
c_primary = [_R2, _R2, _R2, _R2]
y_risk = [1, 1, 0, 1]
l_risk = [1, 1, 1, 1]
a_pred = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
assert m["strong_intervention_rate_L1"] == 0.5
def test_per_category_action_dist_present():
"""per_category_action_dist should be computed when c_primary_idx is provided."""
c_primary = [_R1, _R1, _R1, _R1, _R1] # 5 R1 samples
y_risk = [1, 1, 1, 1, 1]
l_risk = [4, 4, 4, 4, 4]
a_pred = [ACTION_CRISIS] * 5
m = intervention_metrics(y_risk, l_risk, a_pred, c_primary_idx=c_primary)
assert "per_category_action_dist" in m
assert "R1" in m["per_category_action_dist"]
dist = m["per_category_action_dist"]["R1"]["action_dist"]
assert abs(dist[ACTION_CRISIS] - 1.0) < 1e-6
def test_no_category_idx_skips_cat_metrics():
"""Without c_primary_idx, category-aware metrics should be absent."""
y_risk = [1, 1]
l_risk = [3, 4]
a_pred = [ACTION_CRISIS, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert "R1_high_critical_crisis_recall" not in m
assert "crisis_appropriateness" not in m
assert "reject_rate_R6_R7_high" not in m
# ── Per-level action distribution ────────────────────────────────────────────
def test_per_level_action_dist_keys():
y_risk = [0, 1, 1, 1, 1]
l_risk = [0, 1, 2, 3, 4]
a_pred = [ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_REJECT, ACTION_CRISIS]
m = intervention_metrics(y_risk, l_risk, a_pred)
assert "L0_Safe" in m["per_level_action_dist"]
assert "L4_Critical" in m["per_level_action_dist"]

View File

@@ -0,0 +1,168 @@
"""
Unit tests for v5 reward function (label-aligned constrained reward).
Run from project root:
python -m pytest tests/test_reward_v5.py -v
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.rl.reward import (
compute_reward,
EXACT_ALIGN_BONUS, ADJACENT_ALIGN_BONUS,
PASS_HIGH_PENALTY, PASS_MID_PENALTY, OVER_REFUSAL_PENALTY,
CRISIS_R1_BONUS, REJECT_R6R7_BONUS, REWRITE_COMPANION_BONUS,
CRISIS_MISUSE_PENALTY,
ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_REJECT, ACTION_CRISIS,
_R1_IDX, _R3_IDX, _R4_IDX, _R6_IDX, _R7_IDX, _R10_IDX,
)
# ── Primary alignment signal ──────────────────────────────────────────────────
def test_exact_alignment_gives_bonus():
r = compute_reward(ACTION_PASS, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=ACTION_PASS)
assert r == EXACT_ALIGN_BONUS
def test_adjacent_alignment_gives_partial_bonus():
# a_recommend=REJECT(3), action=REWRITE(2) → adjacent
r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_REJECT)
# Contains adjacent bonus but also safety checks (REWRITE on L3 is not PASS, so no PASS penalty)
assert ADJACENT_ALIGN_BONUS in [r - x for x in [0.0, REWRITE_COMPANION_BONUS]]
# More precisely: reward = adjacent_bonus (no safety penalty for REWRITE on L3)
r2 = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_REJECT)
assert r2 >= ADJACENT_ALIGN_BONUS
def test_wrong_action_gets_no_alignment_bonus():
# action=WARN, a_recommend=CRISIS → |4-1|=3, not adjacent
r = compute_reward(ACTION_WARN, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_CRISIS)
assert r < ADJACENT_ALIGN_BONUS
# ── Safety hard constraints ───────────────────────────────────────────────────
def test_pass_on_l3_triggers_penalty():
r = compute_reward(ACTION_PASS, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=None)
assert r == PASS_HIGH_PENALTY
def test_pass_on_l4_triggers_penalty():
r = compute_reward(ACTION_PASS, y_risk=1, l_risk=4, c_primary_idx=1, a_recommend=None)
assert r == PASS_HIGH_PENALTY
def test_pass_on_l2_triggers_mid_penalty():
r = compute_reward(ACTION_PASS, y_risk=1, l_risk=2, c_primary_idx=1, a_recommend=None)
assert r == PASS_MID_PENALTY
def test_pass_on_l1_no_penalty():
r = compute_reward(ACTION_PASS, y_risk=0, l_risk=1, c_primary_idx=1, a_recommend=None)
assert r == 0.0
def test_rewrite_on_l0_triggers_over_refusal_penalty():
r = compute_reward(ACTION_REWRITE, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=None)
assert r == OVER_REFUSAL_PENALTY
def test_reject_on_l0_triggers_over_refusal_penalty():
r = compute_reward(ACTION_REJECT, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=None)
assert r == OVER_REFUSAL_PENALTY
def test_crisis_on_l0_triggers_over_refusal_and_misuse():
r = compute_reward(ACTION_CRISIS, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=None)
# L0 + CRISIS: over_refusal + crisis_misuse (not R1, not L4)
assert r == OVER_REFUSAL_PENALTY + CRISIS_MISUSE_PENALTY
# ── Category-specific bonuses ─────────────────────────────────────────────────
def test_crisis_on_r1_l3_gives_bonus():
r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=3, c_primary_idx=_R1_IDX, a_recommend=None)
assert r == CRISIS_R1_BONUS # no misuse penalty because it IS R1
def test_crisis_on_r1_l4_gives_bonus():
r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=4, c_primary_idx=_R1_IDX, a_recommend=None)
# R1 + L4: R1 bonus; l_risk==4 so no misuse penalty
assert r == CRISIS_R1_BONUS
def test_crisis_on_non_r1_l3_triggers_misuse_penalty():
# c_primary_idx=1 (R2), l_risk=3 (not L4) → CRISIS misuse
r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=None)
assert r == CRISIS_MISUSE_PENALTY
def test_crisis_on_l4_non_r1_no_misuse_penalty():
# l_risk==4: misuse condition (not R1 AND l_risk<4) is false → no penalty
r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=4, c_primary_idx=1, a_recommend=None)
assert r == 0.0
def test_reject_on_r6_l3_gives_bonus():
r = compute_reward(ACTION_REJECT, y_risk=1, l_risk=3, c_primary_idx=_R6_IDX, a_recommend=None)
assert r == REJECT_R6R7_BONUS
def test_reject_on_r7_l4_gives_bonus():
r = compute_reward(ACTION_REJECT, y_risk=1, l_risk=4, c_primary_idx=_R7_IDX, a_recommend=None)
assert r == REJECT_R6R7_BONUS
def test_reject_on_r6_l2_no_bonus():
r = compute_reward(ACTION_REJECT, y_risk=1, l_risk=2, c_primary_idx=_R6_IDX, a_recommend=None)
assert r == 0.0 # l_risk < 3, no bonus
def test_rewrite_on_r3_l2_gives_companion_bonus():
r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=2, c_primary_idx=_R3_IDX, a_recommend=None)
assert r == REWRITE_COMPANION_BONUS
def test_rewrite_on_r4_l3_gives_companion_bonus():
r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=_R4_IDX, a_recommend=None)
assert r == REWRITE_COMPANION_BONUS
def test_rewrite_on_r10_l2_gives_companion_bonus():
r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=2, c_primary_idx=_R10_IDX, a_recommend=None)
assert r == REWRITE_COMPANION_BONUS
def test_rewrite_on_r3_l1_no_companion_bonus():
r = compute_reward(ACTION_REWRITE, y_risk=0, l_risk=1, c_primary_idx=_R3_IDX, a_recommend=None)
assert r == 0.0 # l_risk < 2, no bonus
# ── Combined scenarios ────────────────────────────────────────────────────────
def test_best_action_l0_safe():
"""L0 safe: PASS + exact alignment should give max reward."""
r = compute_reward(ACTION_PASS, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=ACTION_PASS)
assert r == EXACT_ALIGN_BONUS
def test_best_action_l4_r1_crisis():
"""L4 + R1: CRISIS should give alignment + R1 bonus."""
r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=4, c_primary_idx=_R1_IDX, a_recommend=ACTION_CRISIS)
assert r == EXACT_ALIGN_BONUS + CRISIS_R1_BONUS
def test_worst_action_l3_pass():
"""L3 high: PASS is the worst action — heavy penalty."""
r = compute_reward(ACTION_PASS, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_REWRITE)
# PASS on L3: heavy penalty; also wrong action vs a_recommend=REWRITE
assert r == PASS_HIGH_PENALTY
def test_no_a_recommend_still_works():
"""compute_reward should not crash without a_recommend."""
r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=1)
assert isinstance(r, float)