refactor: complete full implementation replacing all placeholder/mock content

Detection module (Module B):
- detector.py: expose separate e_P_pool and e_H_pool for RL state;
  fix compute_loss to skip primary head when c_primary="None"
- dataset.py: handle c_primary="None" safely; add validate_and_normalize

Data pipeline:
- data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe);
  systematic category→fine-label mapping; safe sample generation (25%);
  per-category risk level distribution; max_retries logic
- llm_judge.py: incremental file writing; rate limiting; retry logic;
  annotate_from_file convenience method; consistency validation
- annotate_data.py: stratified split by y_risk; dataset statistics report

RL module (Module C):
- ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple);
  fix action type passed to env.step; proper buffer reset and size tracking
- companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with
  auto-reset; correct Gymnasium interface

Shared utilities (new files):
- src/utils/preprocessing.py: preprocess_samples_with_detector using separate
  e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up
- src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b),
  CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention,
  LLMJudgePolicy for full baseline comparison

Scripts:
- train_intervention.py: use preprocessing module; separate e_H/e_P pools
- evaluate.py: proper module imports (no circular scripts import);
  full multi-baseline comparison; save all results to JSON
- generate_data.py: API key check; safe_ratio + max_retries CLI args

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-09 17:50:17 +08:00
parent 7d4345c29d
commit 4a0e71fb23
13 changed files with 1829 additions and 564 deletions

View File

@@ -1,109 +1,168 @@
"""
Evaluation script: run detection + intervention baselines and ours.
Full evaluation script for CompanionGuard-RL.
Runs detection and intervention evaluations against multiple baselines:
Detection baselines:
- L1a: Keyword detector
- L1b: Regex detector
- L1c: Combined keyword+regex
- Ours: CompanionRiskDetector (Module B)
Intervention baselines:
- Rule-based (l_risk >= 3 → REJECT)
- Threshold policy (per-level mapping)
- RL policy (Module C, Ours)
Usage:
python scripts/evaluate.py --detector-ckpt checkpoints/detector/best.pt \
--agent-ckpt checkpoints/intervention/final.pt \
--test-data data/processed/test.jsonl \
--config configs/detector_config.yaml
python scripts/evaluate.py \\
--detector-ckpt checkpoints/detector/best.pt \\
--agent-ckpt checkpoints/intervention/final.pt \\
--test-data data/processed/test.jsonl \\
--config configs/detector_config.yaml \\
--intervention-config configs/intervention_config.yaml
"""
import argparse
import yaml
import json
import yaml
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict
from transformers import AutoTokenizer
from src.data.dataset import load_jsonl, format_conversation
from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.utils.metrics import detection_metrics, intervention_metrics
from src.utils.baselines import (
KeywordDetector,
RegexDetector,
CombinedRuleDetector,
rule_based_intervention,
threshold_intervention,
)
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
INTERVENTION_ACTIONS,
NUM_RISK_LEVELS,
NUM_PRIMARY,
DEFAULT_ACTION_BY_LEVEL,
PRIMARY_CATEGORY_LIST,
FINE_GRAINED_LABELS,
)
# ── Baselines ────────────────────────────────────────────────────────────────
def rule_based_policy(l_risk: int) -> int:
    """Fixed-rule baseline mapping a risk level to an action id.

    Args:
        l_risk: Risk level of the sample.

    Returns:
        3 (REJECT) when ``l_risk`` is 3 or higher, otherwise 0 (PASS).
    """
    if l_risk >= 3:
        return 3
    return 0
def get_obs_dim(detector_hidden: int) -> int:
    """Return the observation-vector length for a given detector hidden size.

    The total is two single-value slots, ``NUM_RISK_LEVELS`` slots,
    ``NUM_PRIMARY`` slots, and two blocks of ``detector_hidden`` slots.
    NOTE(review): presumably the two hidden-size blocks are the pooled
    history/persona embeddings and the single slots are the detector score
    and a normalized turn count — confirm against the observation builder.
    """
    single_value_slots = 2
    pooled_slots = detector_hidden * 2
    return single_value_slots + NUM_RISK_LEVELS + NUM_PRIMARY + pooled_slots
def threshold_policy(l_risk: int) -> int:
    """Threshold baseline: look up the action assigned to a risk level.

    Args:
        l_risk: Risk level used as the key into ``DEFAULT_ACTION_BY_LEVEL``.

    Returns:
        The entry of ``DEFAULT_ACTION_BY_LEVEL`` for ``l_risk``.
    """
    action_for_level = DEFAULT_ACTION_BY_LEVEL[l_risk]
    return action_for_level
# ── Detection evaluation ──────────────────────────────────────────────────
# ── Main evaluation ──────────────────────────────────────────────────────────
# NOTE(review): this span carries two function headers (run_detection_eval and
# run_neural_detection) and several near-duplicate statements. It looks like the
# old and new revisions of one function shown together — confirm which lines
# are live in the real file before relying on it.
def run_detection_eval(model, tokenizer, samples, cfg, device):
def run_neural_detection(
model: CompanionRiskDetector,
tokenizer,
samples: List[Dict],
cfg: Dict,
device: str,
) -> Dict:
"""Run the neural detector on test samples and compute metrics."""
model.eval()
y_true, y_pred = [], []
l_true, l_pred = [], []
# Decision threshold comes from cfg["evaluation"]["binary_threshold"], default 0.5.
binary_threshold = cfg.get("evaluation", {}).get("binary_threshold", 0.5)
# For each sample: build the three text views, tokenize them, run the model,
# and record true vs. predicted y_risk and l_risk.
for sample in samples:
sample = validate_and_normalize(dict(sample))
texts = format_conversation(
sample["persona"], sample["history"],
sample["user_input"], sample["ai_response"],
max_history_turns=cfg.get("data", {}).get("max_history_turns", 5),
)
# Tokenize to a fixed length (truncate + pad to max_len) and return PyTorch tensors.
def enc(text, max_len):
return tokenizer(text, max_length=max_len, truncation=True,
padding="max_length", return_tensors="pt")
return tokenizer(
text, max_length=max_len, truncation=True,
padding="max_length", return_tensors="pt",
)
# Max lengths: literal 128/512/256 in one variant; read from cfg["data"]
# with those same values as defaults in the other.
p_enc = enc(texts["persona_text"], 128)
c_enc = enc(texts["context_text"], 512)
r_enc = enc(texts["response_text"], 256)
p_enc = enc(texts["persona_text"], cfg.get("data", {}).get("max_persona_len", 128))
c_enc = enc(texts["context_text"], cfg.get("data", {}).get("max_context_len", 512))
r_enc = enc(texts["response_text"], cfg.get("data", {}).get("max_response_len", 256))
# Inference only: gradients are disabled around model.predict.
with torch.no_grad():
preds = model.predict(
p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
binary_threshold=binary_threshold,
)
# Labels are appended raw in one variant and cast with int() in the other.
y_true.append(sample["y_risk"])
y_true.append(int(sample["y_risk"]))
y_pred.append(preds["y_risk"].item())
l_true.append(sample["l_risk"])
l_true.append(int(sample["l_risk"]))
l_pred.append(preds["l_risk"].item())
# Metrics are computed by detection_metrics from the four collected lists.
return detection_metrics(y_true, y_pred, l_true, l_pred)
def run_intervention_eval(agent, processed_samples, obs_dim, device):
def run_keyword_detection(samples: List[Dict], detector) -> Dict:
    """Score a rule-style detector on the samples' AI responses.

    Args:
        samples: Records providing ``y_risk`` and ``l_risk`` labels and,
            optionally, ``ai_response`` (an empty string is used if absent).
        detector: Object whose ``detect(text)`` returns a mapping with
            ``y_risk`` and ``l_risk`` entries.

    Returns:
        The result of ``detection_metrics`` over true and predicted labels.
    """
    rows = []
    for sample in samples:
        verdict = detector.detect(sample.get("ai_response", ""))
        # Tuple order: true y_risk, predicted y_risk, true l_risk, predicted l_risk.
        rows.append((
            int(sample["y_risk"]),
            verdict["y_risk"],
            int(sample["l_risk"]),
            verdict["l_risk"],
        ))
    # Split the per-sample tuples back into four parallel lists.
    y_true, y_pred, l_true, l_pred = (
        [row[col] for row in rows] for col in range(4)
    )
    return detection_metrics(y_true, y_pred, l_true, l_pred)
# ── Intervention evaluation ───────────────────────────────────────────────
# NOTE(review): this span contains both a hand-built observation and a
# build_obs_vector(s) call, plus duplicated append statements. It looks like old
# and new revisions of the same function shown together — confirm which
# lines are live in the real file.
def run_rl_intervention(
agent: InterventionAgent,
processed_samples: List[Dict],
device: str,
) -> Dict:
agent.eval()
y_risk_true, l_risk_true, a_pred, a_recommend = [], [], [], []
for s in processed_samples:
# Hand-built observation: [d_score, one-hot l_risk, c_primary_probs,
# e_H_pool, e_P_pool, t_norm], all float32.
d_score = np.array([s["d_score"]], dtype=np.float32)
l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
l_risk_oh[int(s["l_risk"])] = 1.0
c_probs = np.array(s["c_primary_probs"], dtype=np.float32)
e_H = np.array(s["e_H_pool"], dtype=np.float32)
e_P = np.array(s["e_P_pool"], dtype=np.float32)
# t_norm is the number of history entries divided by 20.0 (0 if "history" is absent).
t_norm = np.array([len(s.get("history", [])) / 20.0], dtype=np.float32)
obs = torch.FloatTensor(
np.concatenate([d_score, l_risk_oh, c_probs, e_H, e_P, t_norm])
).unsqueeze(0).to(device)
# Alternative variant: the observation is produced by build_obs_vector(s);
# unsqueeze(0) adds a leading dimension of size 1 before moving to device.
obs = torch.FloatTensor(build_obs_vector(s)).unsqueeze(0).to(device)
# get_action is called with deterministic=True under no_grad; only the
# first of its four return values (the action) is used.
with torch.no_grad():
action, _, _, _ = agent.get_action(obs, deterministic=True)
y_risk_true.append(s["y_risk"])
y_risk_true.append(int(s["y_risk"]))
l_risk_true.append(int(s["l_risk"]))
a_pred.append(action.item())
# Action names not found in ACTION_NAME_TO_ID map to id 0; one variant
# also substitutes "PASS" when the sample has no a_recommend key.
a_recommend.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0))
a_recommend.append(ACTION_NAME_TO_ID.get(s.get("a_recommend", "PASS"), 0))
return intervention_metrics(y_risk_true, l_risk_true, a_pred, a_recommend)
def run_rule_intervention(processed_samples: List[Dict], policy_fn) -> Dict:
    """Evaluate a rule-style intervention policy on preprocessed samples.

    Args:
        processed_samples: Records providing ``y_risk`` and ``l_risk``.
        policy_fn: Callable taking an integer risk level and returning the
            chosen action.

    Returns:
        The result of ``intervention_metrics`` over the true risk flags,
        true risk levels, and the policy's actions.
    """
    risk_flags = []
    for sample in processed_samples:
        risk_flags.append(int(sample["y_risk"]))

    risk_levels = []
    for sample in processed_samples:
        risk_levels.append(int(sample["l_risk"]))

    # The policy only ever sees the integer risk level of each sample.
    chosen_actions = []
    for sample in processed_samples:
        chosen_actions.append(policy_fn(int(sample["l_risk"])))

    return intervention_metrics(risk_flags, risk_levels, chosen_actions)
def _format_metric_item(item) -> str:
    """Render one list element with three decimals, or ``str(item)`` if it cannot be float-formatted."""
    try:
        return f"{item:.3f}"
    except (TypeError, ValueError):
        # Nested lists, strings, None, etc. reject the ".3f" format spec.
        return str(item)


def print_metrics(name: str, metrics: Dict):
    """Print a titled, aligned report of metric values to stdout.

    Args:
        name: Title printed between two 50-character ``=`` rules.
        metrics: Mapping of metric name to value. Floats are printed with
            four decimals, lists element-wise with three decimals, and any
            other value unchanged.

    List elements that do not accept a float format spec (for example the
    rows of a nested list, strings, or ``None``) are printed via ``str()``
    instead of raising, and non-string keys are converted with ``str()``
    before padding, so a single unusual metric cannot abort the report.
    """
    rule = "=" * 50
    print(f"\n{rule}")
    print(f" {name}")
    print(f"{rule}")
    for k, v in metrics.items():
        label = str(k)  # "{:35s}" only accepts strings
        if isinstance(v, float):
            print(f" {label:35s}: {v:.4f}")
        elif isinstance(v, list):
            formatted = [_format_metric_item(x) for x in v]
            print(f" {label:35s}: {formatted}")
        else:
            print(f" {label:35s}: {v}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--detector-ckpt", required=True)
@@ -111,6 +170,7 @@ def main():
parser.add_argument("--test-data", default="data/processed/test.jsonl")
parser.add_argument("--config", default="configs/detector_config.yaml")
parser.add_argument("--intervention-config", default="configs/intervention_config.yaml")
parser.add_argument("--output", default="experiments/eval_results.json")
args = parser.parse_args()
with open(args.config) as f:
@@ -119,74 +179,93 @@ def main():
int_cfg = yaml.safe_load(f)
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
print(f"Device: {device}")
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
samples = load_jsonl(args.test_data)
print(f"Loaded {len(samples)} test samples.")
# Detection evaluation
# ── Neural detector ──────────────────────────────────────────────────
detector = CompanionRiskDetector(
model_name=cfg["model"]["name"],
hidden_size=cfg["model"]["hidden_size"],
num_heads=cfg["model"]["num_heads"],
dropout=cfg["model"]["dropout"],
use_lora=cfg["model"]["use_lora"],
).to(device)
detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
print("\n=== Detection Evaluation ===")
det_metrics = run_detection_eval(detector, tokenizer, samples, cfg, device)
for k, v in det_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
if Path(args.detector_ckpt).exists():
detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
print(f"Detector loaded from {args.detector_ckpt}")
else:
print(f"[WARN] Checkpoint not found: {args.detector_ckpt}. Using random weights.")
# Intervention evaluation
# ── Detection evaluation ─────────────────────────────────────────────
all_results = {}
print("\nRunning: L1a — Keyword Detector")
kw_metrics = run_keyword_detection(samples, KeywordDetector())
print_metrics("L1a: Keyword Detector", kw_metrics)
all_results["L1a_keyword"] = kw_metrics
print("\nRunning: L1b — Regex Detector")
re_metrics = run_keyword_detection(samples, RegexDetector())
print_metrics("L1b: Regex Detector", re_metrics)
all_results["L1b_regex"] = re_metrics
print("\nRunning: L1c — Combined Keyword+Regex")
cb_metrics = run_keyword_detection(samples, CombinedRuleDetector())
print_metrics("L1c: Combined Detector", cb_metrics)
all_results["L1c_combined"] = cb_metrics
print("\nRunning: Ours — Neural Detector (Module B)")
neural_metrics = run_neural_detection(detector, tokenizer, samples, cfg, device)
print_metrics("Ours: CompanionRiskDetector", neural_metrics)
all_results["ours_detection"] = neural_metrics
# ── Intervention evaluation ──────────────────────────────────────────
if args.agent_ckpt:
from scripts.train_intervention import preprocess_samples_with_detector
print("\n\nPreprocessing test samples with detector for RL evaluation...")
processed = preprocess_samples_with_detector(
samples, detector, tokenizer, device=device,
binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
)
# Build RL agent
detector_hidden = cfg["model"]["hidden_size"]
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
processed = preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device)
obs_dim = get_obs_dim(detector_hidden)
agent = InterventionAgent(
detector_hidden=detector_hidden,
state_hidden=int_cfg["agent"]["state_hidden"],
dropout=int_cfg["agent"]["dropout"],
).to(device)
agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
print("\n=== Intervention Evaluation: RL Policy (Ours) ===")
int_metrics = run_intervention_eval(agent, processed, obs_dim, device)
for k, v in int_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
elif isinstance(v, list):
print(f" {k}: {[f'{x:.3f}' for x in v]}")
if Path(args.agent_ckpt).exists():
agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
print(f"Agent loaded from {args.agent_ckpt}")
else:
print(f"[WARN] Agent checkpoint not found: {args.agent_ckpt}. Using random weights.")
print("\n=== Intervention Evaluation: Rule-based Baseline ===")
rule_preds = [rule_based_policy(s["l_risk"]) for s in processed]
rule_metrics = intervention_metrics(
[s["y_risk"] for s in processed],
[s["l_risk"] for s in processed],
rule_preds,
)
for k, v in rule_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
print("\nRunning: Rule-based Intervention Baseline (l_risk >= 3 → REJECT)")
rule_m = run_rule_intervention(processed, rule_based_intervention)
print_metrics("Rule-based Policy", rule_m)
all_results["baseline_rule"] = rule_m
print("\n=== Intervention Evaluation: Threshold Baseline ===")
thr_preds = [threshold_policy(s["l_risk"]) for s in processed]
thr_metrics = intervention_metrics(
[s["y_risk"] for s in processed],
[s["l_risk"] for s in processed],
thr_preds,
)
for k, v in thr_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
print("\nRunning: Threshold Intervention Baseline")
thr_m = run_rule_intervention(processed, threshold_intervention)
print_metrics("Threshold Policy", thr_m)
all_results["baseline_threshold"] = thr_m
# Save results
results = {"detection": det_metrics}
Path("experiments").mkdir(exist_ok=True)
with open("experiments/eval_results.json", "w") as f:
json.dump(results, f, indent=2, default=str)
print("\nResults saved to experiments/eval_results.json")
print("\nRunning: Ours — RL Intervention Policy (Module C)")
rl_m = run_rl_intervention(agent, processed, device)
print_metrics("Ours: RL Intervention Policy", rl_m)
all_results["ours_intervention"] = rl_m
# ── Save results ─────────────────────────────────────────────────────
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
with open(args.output, "w", encoding="utf-8") as f:
json.dump(all_results, f, indent=2, default=str, ensure_ascii=False)
print(f"\nAll results saved to {args.output}")
if __name__ == "__main__":