chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
168
code/tests/test_reward_v5.py
Normal file
168
code/tests/test_reward_v5.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Unit tests for v5 reward function (label-aligned constrained reward).
|
||||
|
||||
Run from the code/ directory (the one that contains tests/):
|
||||
python -m pytest tests/test_reward_v5.py -v
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from src.rl.reward import (
|
||||
compute_reward,
|
||||
EXACT_ALIGN_BONUS, ADJACENT_ALIGN_BONUS,
|
||||
PASS_HIGH_PENALTY, PASS_MID_PENALTY, OVER_REFUSAL_PENALTY,
|
||||
CRISIS_R1_BONUS, REJECT_R6R7_BONUS, REWRITE_COMPANION_BONUS,
|
||||
CRISIS_MISUSE_PENALTY,
|
||||
ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_REJECT, ACTION_CRISIS,
|
||||
_R1_IDX, _R3_IDX, _R4_IDX, _R6_IDX, _R7_IDX, _R10_IDX,
|
||||
)
|
||||
|
||||
|
||||
# ── Primary alignment signal ──────────────────────────────────────────────────
|
||||
|
||||
def test_exact_alignment_gives_bonus():
    """Choosing PASS when PASS is recommended (L0) earns exactly EXACT_ALIGN_BONUS."""
    reward = compute_reward(
        ACTION_PASS,
        y_risk=0,
        l_risk=0,
        c_primary_idx=1,
        a_recommend=ACTION_PASS,
    )
    expected = EXACT_ALIGN_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_adjacent_alignment_gives_partial_bonus():
    """REWRITE when REJECT is recommended should earn the adjacent-alignment bonus.

    The chosen action (REWRITE=2) is one step away from the recommendation
    (REJECT=3). The reward is accepted when it equals ADJACENT_ALIGN_BONUS on
    its own, or ADJACENT_ALIGN_BONUS plus REWRITE_COMPANION_BONUS.
    """
    tolerance = 1e-9

    # A single call is enough: the arguments are fixed, so both checks below
    # inspect the same reward value.
    r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_REJECT)

    # Compare with a tolerance instead of exact membership: combining float
    # constants by addition/subtraction can differ in the last bits.
    candidates = [ADJACENT_ALIGN_BONUS, ADJACENT_ALIGN_BONUS + REWRITE_COMPANION_BONUS]
    assert any(abs(r - expected) <= tolerance for expected in candidates)

    # REWRITE on L3 is not PASS, so no PASS penalty should pull the reward
    # below the adjacent bonus.
    assert r >= ADJACENT_ALIGN_BONUS - tolerance
|
||||
|
||||
|
||||
def test_wrong_action_gets_no_alignment_bonus():
    """WARN when CRISIS is recommended must score below ADJACENT_ALIGN_BONUS."""
    # CRISIS (4) vs WARN (1): distance 3, so the pair is not adjacent.
    reward = compute_reward(
        ACTION_WARN,
        y_risk=1,
        l_risk=3,
        c_primary_idx=1,
        a_recommend=ACTION_CRISIS,
    )
    assert reward < ADJACENT_ALIGN_BONUS
|
||||
|
||||
|
||||
# ── Safety hard constraints ───────────────────────────────────────────────────
|
||||
|
||||
def test_pass_on_l3_triggers_penalty():
    """PASS at l_risk=3 with no recommendation yields exactly PASS_HIGH_PENALTY."""
    reward = compute_reward(
        ACTION_PASS,
        y_risk=1,
        l_risk=3,
        c_primary_idx=1,
        a_recommend=None,
    )
    expected = PASS_HIGH_PENALTY
    assert reward == expected
|
||||
|
||||
|
||||
def test_pass_on_l4_triggers_penalty():
    """PASS at l_risk=4 with no recommendation yields exactly PASS_HIGH_PENALTY."""
    reward = compute_reward(
        ACTION_PASS,
        y_risk=1,
        l_risk=4,
        c_primary_idx=1,
        a_recommend=None,
    )
    expected = PASS_HIGH_PENALTY
    assert reward == expected
|
||||
|
||||
|
||||
def test_pass_on_l2_triggers_mid_penalty():
    """PASS at l_risk=2 with no recommendation yields exactly PASS_MID_PENALTY."""
    reward = compute_reward(
        ACTION_PASS,
        y_risk=1,
        l_risk=2,
        c_primary_idx=1,
        a_recommend=None,
    )
    expected = PASS_MID_PENALTY
    assert reward == expected
|
||||
|
||||
|
||||
def test_pass_on_l1_no_penalty():
    """PASS at l_risk=1 with no recommendation is neutral: the reward is 0.0."""
    reward = compute_reward(
        ACTION_PASS,
        y_risk=0,
        l_risk=1,
        c_primary_idx=1,
        a_recommend=None,
    )
    expected = 0.0
    assert reward == expected
|
||||
|
||||
|
||||
def test_rewrite_on_l0_triggers_over_refusal_penalty():
    """REWRITE at l_risk=0 with no recommendation yields exactly OVER_REFUSAL_PENALTY."""
    reward = compute_reward(
        ACTION_REWRITE,
        y_risk=0,
        l_risk=0,
        c_primary_idx=1,
        a_recommend=None,
    )
    expected = OVER_REFUSAL_PENALTY
    assert reward == expected
|
||||
|
||||
|
||||
def test_reject_on_l0_triggers_over_refusal_penalty():
    """REJECT at l_risk=0 with no recommendation yields exactly OVER_REFUSAL_PENALTY."""
    reward = compute_reward(
        ACTION_REJECT,
        y_risk=0,
        l_risk=0,
        c_primary_idx=1,
        a_recommend=None,
    )
    expected = OVER_REFUSAL_PENALTY
    assert reward == expected
|
||||
|
||||
|
||||
def test_crisis_on_l0_triggers_over_refusal_and_misuse():
    """CRISIS at l_risk=0 on a non-R1 category stacks both penalties."""
    reward = compute_reward(
        ACTION_CRISIS,
        y_risk=0,
        l_risk=0,
        c_primary_idx=1,
        a_recommend=None,
    )
    # L0 + CRISIS is over-refusal, and since the sample is neither R1 nor L4
    # the crisis-misuse penalty is expected on top of it.
    expected = OVER_REFUSAL_PENALTY + CRISIS_MISUSE_PENALTY
    assert reward == expected
|
||||
|
||||
|
||||
# ── Category-specific bonuses ─────────────────────────────────────────────────
|
||||
|
||||
def test_crisis_on_r1_l3_gives_bonus():
    """CRISIS on an R1 sample at l_risk=3 yields exactly CRISIS_R1_BONUS."""
    reward = compute_reward(
        ACTION_CRISIS,
        y_risk=1,
        l_risk=3,
        c_primary_idx=_R1_IDX,
        a_recommend=None,
    )
    # The category is R1, so no misuse penalty is expected.
    expected = CRISIS_R1_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_crisis_on_r1_l4_gives_bonus():
    """CRISIS on an R1 sample at l_risk=4 yields exactly CRISIS_R1_BONUS."""
    reward = compute_reward(
        ACTION_CRISIS,
        y_risk=1,
        l_risk=4,
        c_primary_idx=_R1_IDX,
        a_recommend=None,
    )
    # R1 earns the bonus; with l_risk == 4 no misuse penalty is expected.
    expected = CRISIS_R1_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_crisis_on_non_r1_l3_triggers_misuse_penalty():
    """CRISIS on a non-R1 sample below L4 yields exactly CRISIS_MISUSE_PENALTY."""
    # c_primary_idx=1 is R2 and l_risk=3 is below L4, so CRISIS counts as misuse.
    reward = compute_reward(
        ACTION_CRISIS,
        y_risk=1,
        l_risk=3,
        c_primary_idx=1,
        a_recommend=None,
    )
    expected = CRISIS_MISUSE_PENALTY
    assert reward == expected
|
||||
|
||||
|
||||
def test_crisis_on_l4_non_r1_no_misuse_penalty():
    """CRISIS on a non-R1 sample at l_risk=4 is neutral: the reward is 0.0."""
    # Misuse requires (not R1) AND (l_risk < 4); at l_risk == 4 that is false.
    reward = compute_reward(
        ACTION_CRISIS,
        y_risk=1,
        l_risk=4,
        c_primary_idx=1,
        a_recommend=None,
    )
    expected = 0.0
    assert reward == expected
|
||||
|
||||
|
||||
def test_reject_on_r6_l3_gives_bonus():
    """REJECT on an R6 sample at l_risk=3 yields exactly REJECT_R6R7_BONUS."""
    reward = compute_reward(
        ACTION_REJECT,
        y_risk=1,
        l_risk=3,
        c_primary_idx=_R6_IDX,
        a_recommend=None,
    )
    expected = REJECT_R6R7_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_reject_on_r7_l4_gives_bonus():
    """REJECT on an R7 sample at l_risk=4 yields exactly REJECT_R6R7_BONUS."""
    reward = compute_reward(
        ACTION_REJECT,
        y_risk=1,
        l_risk=4,
        c_primary_idx=_R7_IDX,
        a_recommend=None,
    )
    expected = REJECT_R6R7_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_reject_on_r6_l2_no_bonus():
    """REJECT on an R6 sample at l_risk=2 earns no bonus: the reward is 0.0."""
    reward = compute_reward(
        ACTION_REJECT,
        y_risk=1,
        l_risk=2,
        c_primary_idx=_R6_IDX,
        a_recommend=None,
    )
    # l_risk < 3, so the R6/R7 reject bonus is not expected.
    expected = 0.0
    assert reward == expected
|
||||
|
||||
|
||||
def test_rewrite_on_r3_l2_gives_companion_bonus():
    """REWRITE on an R3 sample at l_risk=2 yields exactly REWRITE_COMPANION_BONUS."""
    reward = compute_reward(
        ACTION_REWRITE,
        y_risk=1,
        l_risk=2,
        c_primary_idx=_R3_IDX,
        a_recommend=None,
    )
    expected = REWRITE_COMPANION_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_rewrite_on_r4_l3_gives_companion_bonus():
    """REWRITE on an R4 sample at l_risk=3 yields exactly REWRITE_COMPANION_BONUS."""
    reward = compute_reward(
        ACTION_REWRITE,
        y_risk=1,
        l_risk=3,
        c_primary_idx=_R4_IDX,
        a_recommend=None,
    )
    expected = REWRITE_COMPANION_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_rewrite_on_r10_l2_gives_companion_bonus():
    """REWRITE on an R10 sample at l_risk=2 yields exactly REWRITE_COMPANION_BONUS."""
    reward = compute_reward(
        ACTION_REWRITE,
        y_risk=1,
        l_risk=2,
        c_primary_idx=_R10_IDX,
        a_recommend=None,
    )
    expected = REWRITE_COMPANION_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_rewrite_on_r3_l1_no_companion_bonus():
    """REWRITE on an R3 sample at l_risk=1 earns no bonus: the reward is 0.0."""
    reward = compute_reward(
        ACTION_REWRITE,
        y_risk=0,
        l_risk=1,
        c_primary_idx=_R3_IDX,
        a_recommend=None,
    )
    # l_risk < 2, so the companion rewrite bonus is not expected.
    expected = 0.0
    assert reward == expected
|
||||
|
||||
|
||||
# ── Combined scenarios ────────────────────────────────────────────────────────
|
||||
|
||||
def test_best_action_l0_safe():
    """Safe L0 sample: PASS that matches the recommendation earns EXACT_ALIGN_BONUS."""
    reward = compute_reward(
        ACTION_PASS,
        y_risk=0,
        l_risk=0,
        c_primary_idx=1,
        a_recommend=ACTION_PASS,
    )
    expected = EXACT_ALIGN_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_best_action_l4_r1_crisis():
    """L4 + R1: a recommended CRISIS earns the alignment bonus plus the R1 bonus."""
    reward = compute_reward(
        ACTION_CRISIS,
        y_risk=1,
        l_risk=4,
        c_primary_idx=_R1_IDX,
        a_recommend=ACTION_CRISIS,
    )
    expected = EXACT_ALIGN_BONUS + CRISIS_R1_BONUS
    assert reward == expected
|
||||
|
||||
|
||||
def test_worst_action_l3_pass():
    """L3 sample: PASS yields exactly PASS_HIGH_PENALTY even when REWRITE is recommended."""
    reward = compute_reward(
        ACTION_PASS,
        y_risk=1,
        l_risk=3,
        c_primary_idx=1,
        a_recommend=ACTION_REWRITE,
    )
    # PASS on L3 carries the heavy penalty, and PASS does not match the
    # REWRITE recommendation, so nothing is expected to offset it.
    expected = PASS_HIGH_PENALTY
    assert reward == expected
|
||||
|
||||
|
||||
def test_no_a_recommend_still_works():
    """Omitting a_recommend entirely must still produce a float reward."""
    reward = compute_reward(
        ACTION_REWRITE,
        y_risk=1,
        l_risk=3,
        c_primary_idx=1,
    )
    assert isinstance(reward, float)
|
||||
Reference in New Issue
Block a user