""" Unit tests for v5 reward function (label-aligned constrained reward). Run from project root: python -m pytest tests/test_reward_v5.py -v """ import sys import os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.rl.reward import ( compute_reward, EXACT_ALIGN_BONUS, ADJACENT_ALIGN_BONUS, PASS_HIGH_PENALTY, PASS_MID_PENALTY, OVER_REFUSAL_PENALTY, CRISIS_R1_BONUS, REJECT_R6R7_BONUS, REWRITE_COMPANION_BONUS, CRISIS_MISUSE_PENALTY, ACTION_PASS, ACTION_WARN, ACTION_REWRITE, ACTION_REJECT, ACTION_CRISIS, _R1_IDX, _R3_IDX, _R4_IDX, _R6_IDX, _R7_IDX, _R10_IDX, ) # ── Primary alignment signal ────────────────────────────────────────────────── def test_exact_alignment_gives_bonus(): r = compute_reward(ACTION_PASS, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=ACTION_PASS) assert r == EXACT_ALIGN_BONUS def test_adjacent_alignment_gives_partial_bonus(): # a_recommend=REJECT(3), action=REWRITE(2) → adjacent r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_REJECT) # Contains adjacent bonus but also safety checks (REWRITE on L3 is not PASS, so no PASS penalty) assert ADJACENT_ALIGN_BONUS in [r - x for x in [0.0, REWRITE_COMPANION_BONUS]] # More precisely: reward = adjacent_bonus (no safety penalty for REWRITE on L3) r2 = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_REJECT) assert r2 >= ADJACENT_ALIGN_BONUS def test_wrong_action_gets_no_alignment_bonus(): # action=WARN, a_recommend=CRISIS → |4-1|=3, not adjacent r = compute_reward(ACTION_WARN, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_CRISIS) assert r < ADJACENT_ALIGN_BONUS # ── Safety hard constraints ─────────────────────────────────────────────────── def test_pass_on_l3_triggers_penalty(): r = compute_reward(ACTION_PASS, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=None) assert r == PASS_HIGH_PENALTY def test_pass_on_l4_triggers_penalty(): r = compute_reward(ACTION_PASS, y_risk=1, l_risk=4, c_primary_idx=1, a_recommend=None) assert r == PASS_HIGH_PENALTY def test_pass_on_l2_triggers_mid_penalty(): r = compute_reward(ACTION_PASS, y_risk=1, l_risk=2, c_primary_idx=1, a_recommend=None) assert r == PASS_MID_PENALTY def test_pass_on_l1_no_penalty(): r = compute_reward(ACTION_PASS, y_risk=0, l_risk=1, c_primary_idx=1, a_recommend=None) assert r == 0.0 def test_rewrite_on_l0_triggers_over_refusal_penalty(): r = compute_reward(ACTION_REWRITE, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=None) assert r == OVER_REFUSAL_PENALTY def test_reject_on_l0_triggers_over_refusal_penalty(): r = compute_reward(ACTION_REJECT, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=None) assert r == OVER_REFUSAL_PENALTY def test_crisis_on_l0_triggers_over_refusal_and_misuse(): r = compute_reward(ACTION_CRISIS, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=None) # L0 + CRISIS: over_refusal + crisis_misuse (not R1, not L4) assert r == OVER_REFUSAL_PENALTY + CRISIS_MISUSE_PENALTY # ── Category-specific bonuses ───────────────────────────────────────────────── def test_crisis_on_r1_l3_gives_bonus(): r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=3, c_primary_idx=_R1_IDX, a_recommend=None) assert r == CRISIS_R1_BONUS # no misuse penalty because it IS R1 def test_crisis_on_r1_l4_gives_bonus(): r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=4, c_primary_idx=_R1_IDX, a_recommend=None) # R1 + L4: R1 bonus; l_risk==4 so no misuse penalty assert r == CRISIS_R1_BONUS def test_crisis_on_non_r1_l3_triggers_misuse_penalty(): # c_primary_idx=1 (R2), l_risk=3 (not L4) → CRISIS misuse r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=None) assert r == CRISIS_MISUSE_PENALTY def test_crisis_on_l4_non_r1_no_misuse_penalty(): # l_risk==4: misuse condition (not R1 AND l_risk<4) is false → no penalty r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=4, c_primary_idx=1, a_recommend=None) assert r == 0.0 def test_reject_on_r6_l3_gives_bonus(): r = compute_reward(ACTION_REJECT, y_risk=1, l_risk=3, c_primary_idx=_R6_IDX, a_recommend=None) assert r == REJECT_R6R7_BONUS def test_reject_on_r7_l4_gives_bonus(): r = compute_reward(ACTION_REJECT, y_risk=1, l_risk=4, c_primary_idx=_R7_IDX, a_recommend=None) assert r == REJECT_R6R7_BONUS def test_reject_on_r6_l2_no_bonus(): r = compute_reward(ACTION_REJECT, y_risk=1, l_risk=2, c_primary_idx=_R6_IDX, a_recommend=None) assert r == 0.0 # l_risk < 3, no bonus def test_rewrite_on_r3_l2_gives_companion_bonus(): r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=2, c_primary_idx=_R3_IDX, a_recommend=None) assert r == REWRITE_COMPANION_BONUS def test_rewrite_on_r4_l3_gives_companion_bonus(): r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=_R4_IDX, a_recommend=None) assert r == REWRITE_COMPANION_BONUS def test_rewrite_on_r10_l2_gives_companion_bonus(): r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=2, c_primary_idx=_R10_IDX, a_recommend=None) assert r == REWRITE_COMPANION_BONUS def test_rewrite_on_r3_l1_no_companion_bonus(): r = compute_reward(ACTION_REWRITE, y_risk=0, l_risk=1, c_primary_idx=_R3_IDX, a_recommend=None) assert r == 0.0 # l_risk < 2, no bonus # ── Combined scenarios ──────────────────────────────────────────────────────── def test_best_action_l0_safe(): """L0 safe: PASS + exact alignment should give max reward.""" r = compute_reward(ACTION_PASS, y_risk=0, l_risk=0, c_primary_idx=1, a_recommend=ACTION_PASS) assert r == EXACT_ALIGN_BONUS def test_best_action_l4_r1_crisis(): """L4 + R1: CRISIS should give alignment + R1 bonus.""" r = compute_reward(ACTION_CRISIS, y_risk=1, l_risk=4, c_primary_idx=_R1_IDX, a_recommend=ACTION_CRISIS) assert r == EXACT_ALIGN_BONUS + CRISIS_R1_BONUS def test_worst_action_l3_pass(): """L3 high: PASS is the worst action — heavy penalty.""" r = compute_reward(ACTION_PASS, y_risk=1, l_risk=3, c_primary_idx=1, a_recommend=ACTION_REWRITE) # PASS on L3: heavy penalty; also wrong action vs a_recommend=REWRITE assert r == PASS_HIGH_PENALTY def test_no_a_recommend_still_works(): """compute_reward should not crash without a_recommend.""" r = compute_reward(ACTION_REWRITE, y_risk=1, l_risk=3, c_primary_idx=1) assert isinstance(r, float)