Files
CompanionGuard-RL/scripts/evaluate.py
wangyu 4a0e71fb23 refactor: complete full implementation replacing all placeholder/mock content
Detection module (Module B):
- detector.py: expose separate e_P_pool and e_H_pool for RL state;
  fix compute_loss to skip primary head when c_primary="None"
- dataset.py: handle c_primary="None" safely; add validate_and_normalize

Data pipeline:
- data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe);
  systematic category→fine-label mapping; safe sample generation (25%);
  per-category risk level distribution; max_retries logic
- llm_judge.py: incremental file writing; rate limiting; retry logic;
  annotate_from_file convenience method; consistency validation
- annotate_data.py: stratified split by y_risk; dataset statistics report

RL module (Module C):
- ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple);
  fix action type passed to env.step; proper buffer reset and size tracking
- companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with
  auto-reset; correct Gymnasium interface

Shared utilities (new files):
- src/utils/preprocessing.py: preprocess_samples_with_detector using separate
  e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up
- src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b),
  CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention,
  LLMJudgePolicy for full baseline comparison

Scripts:
- train_intervention.py: use preprocessing module; separate e_H/e_P pools
- evaluate.py: proper module imports (no circular scripts import);
  full multi-baseline comparison; save all results to JSON
- generate_data.py: API key check; safe_ratio + max_retries CLI args

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-09 17:50:17 +08:00

273 lines
10 KiB
Python

"""
Full evaluation script for CompanionGuard-RL.
Runs detection and intervention evaluations against multiple baselines:
Detection baselines:
- L1a: Keyword detector
- L1b: Regex detector
- L1c: Combined keyword+regex
- Ours: CompanionRiskDetector (Module B)
Intervention baselines (evaluated only when --agent-ckpt is given):
- Rule-based (l_risk >= 3 → REJECT)
- Threshold policy (per-level mapping)
- RL policy (Module C, Ours)
Usage:
python scripts/evaluate.py \\
--detector-ckpt checkpoints/detector/best.pt \\
--agent-ckpt checkpoints/intervention/final.pt \\
--test-data data/processed/test.jsonl \\
--config configs/detector_config.yaml \\
--intervention-config configs/intervention_config.yaml
"""
import argparse
import json
import yaml
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict
from transformers import AutoTokenizer
from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.utils.metrics import detection_metrics, intervention_metrics
from src.utils.baselines import (
KeywordDetector,
RegexDetector,
CombinedRuleDetector,
rule_based_intervention,
threshold_intervention,
)
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
NUM_RISK_LEVELS,
NUM_PRIMARY,
PRIMARY_CATEGORY_LIST,
FINE_GRAINED_LABELS,
)
def get_obs_dim(detector_hidden: int) -> int:
    """Return the length of the flat observation vector for the intervention agent.

    The size is ``2 + NUM_RISK_LEVELS + NUM_PRIMARY + 2 * detector_hidden``.

    NOTE(review): presumably the two ``detector_hidden``-wide slots are the
    pooled e_P / e_H detector embeddings, and the remaining slots hold the
    risk-level / primary-category encodings plus two scalar features. Confirm
    against ``build_obs_vector``, whose output length this should match.
    """
    pooled_embeddings = 2 * detector_hidden
    scalar_features = 2
    return scalar_features + NUM_RISK_LEVELS + NUM_PRIMARY + pooled_embeddings
# ── Detection evaluation ──────────────────────────────────────────────────
def run_neural_detection(
    model: CompanionRiskDetector,
    tokenizer,
    samples: List[Dict],
    cfg: Dict,
    device: str,
) -> Dict:
    """Run the neural detector on test samples and compute detection metrics.

    Each sample is copied, normalized with ``validate_and_normalize`` and
    rendered into persona / context / response texts with
    ``format_conversation``. Each text is tokenized with fixed-length padding
    and truncation, and the sample is scored with ``model.predict``, one
    sample at a time.

    Args:
        model: Detector under evaluation; switched to eval mode here.
        tokenizer: Tokenizer matching the detector backbone.
        samples: Raw test samples; not mutated (each is shallow-copied first).
        cfg: Detector config. Optionally reads ``data.max_history_turns``,
            ``data.max_persona_len``, ``data.max_context_len``,
            ``data.max_response_len`` and ``evaluation.binary_threshold``,
            falling back to 5 / 128 / 512 / 256 / 0.5.
        device: Torch device string the token tensors are moved to.

    Returns:
        The dict produced by ``detection_metrics`` for gold vs. predicted
        ``y_risk`` and ``l_risk`` labels.
    """
    model.eval()

    # Loop-invariant configuration, resolved once. ``or {}`` also tolerates a
    # section that is present in the YAML but empty (parsed as None).
    data_cfg = cfg.get("data") or {}
    eval_cfg = cfg.get("evaluation") or {}
    binary_threshold = eval_cfg.get("binary_threshold", 0.5)
    max_history_turns = data_cfg.get("max_history_turns", 5)
    max_persona_len = data_cfg.get("max_persona_len", 128)
    max_context_len = data_cfg.get("max_context_len", 512)
    max_response_len = data_cfg.get("max_response_len", 256)

    def encode(text, max_len):
        """Tokenize one field and return (input_ids, attention_mask) on ``device``."""
        enc = tokenizer(
            text, max_length=max_len, truncation=True,
            padding="max_length", return_tensors="pt",
        )
        return enc["input_ids"].to(device), enc["attention_mask"].to(device)

    y_true, y_pred = [], []
    l_true, l_pred = [], []
    for raw_sample in samples:
        sample = validate_and_normalize(dict(raw_sample))
        texts = format_conversation(
            sample["persona"], sample["history"],
            sample["user_input"], sample["ai_response"],
            max_history_turns=max_history_turns,
        )
        p_ids, p_mask = encode(texts["persona_text"], max_persona_len)
        c_ids, c_mask = encode(texts["context_text"], max_context_len)
        r_ids, r_mask = encode(texts["response_text"], max_response_len)
        with torch.no_grad():
            preds = model.predict(
                p_ids, p_mask,
                c_ids, c_mask,
                r_ids, r_mask,
                binary_threshold=binary_threshold,
            )
        y_true.append(int(sample["y_risk"]))
        y_pred.append(preds["y_risk"].item())
        l_true.append(int(sample["l_risk"]))
        l_pred.append(preds["l_risk"].item())
    return detection_metrics(y_true, y_pred, l_true, l_pred)
def run_keyword_detection(samples: List[Dict], detector) -> Dict:
    """Score a rule-based detector (keyword / regex / combined) on ``samples``.

    Only the sample's ``ai_response`` text (``""`` if absent) is passed to
    ``detector.detect``. Its result dict must contain ``"y_risk"`` and
    ``"l_risk"``.

    Each sample is first copied and passed through ``validate_and_normalize``,
    the same step ``run_neural_detection`` applies. This way the rule
    baselines and the neural detector are scored against identically
    normalized gold labels.

    Args:
        samples: Raw test samples; not mutated.
        detector: Object exposing ``detect(text) -> dict``.

    Returns:
        The dict produced by ``detection_metrics``.
    """
    y_true, y_pred = [], []
    l_true, l_pred = [], []
    for raw_sample in samples:
        sample = validate_and_normalize(dict(raw_sample))
        result = detector.detect(sample.get("ai_response", ""))
        y_true.append(int(sample["y_risk"]))
        y_pred.append(result["y_risk"])
        l_true.append(int(sample["l_risk"]))
        l_pred.append(result["l_risk"])
    return detection_metrics(y_true, y_pred, l_true, l_pred)
# ── Intervention evaluation ───────────────────────────────────────────────
def run_rl_intervention(
    agent: InterventionAgent,
    processed_samples: List[Dict],
    device: str,
) -> Dict:
    """Evaluate the RL intervention policy with deterministic (greedy) actions.

    Each processed sample is turned into an observation with
    ``build_obs_vector`` and fed to the agent as a batch of one.

    Args:
        agent: Intervention policy; switched to eval mode here.
        processed_samples: Output of ``preprocess_samples_with_detector``.
            Each must carry ``y_risk`` / ``l_risk`` gold labels and may carry
            an ``a_recommend`` action name (defaults to ``"PASS"``).
        device: Torch device string the observations are created on.

    Returns:
        The dict produced by ``intervention_metrics``, including the
        recommended actions.
    """
    agent.eval()
    y_risk_true, l_risk_true, a_pred, a_recommend = [], [], [], []
    for s in processed_samples:
        # torch.as_tensor accepts lists and ndarrays of any float dtype, casts
        # explicitly, and creates the tensor directly on the target device
        # (unlike the legacy torch.FloatTensor constructor + .to()).
        obs = torch.as_tensor(
            build_obs_vector(s), dtype=torch.float32, device=device,
        ).unsqueeze(0)
        with torch.no_grad():
            action, _, _, _ = agent.get_action(obs, deterministic=True)
        y_risk_true.append(int(s["y_risk"]))
        l_risk_true.append(int(s["l_risk"]))
        a_pred.append(action.item())
        # Unknown or missing action names map to id 0.
        a_recommend.append(ACTION_NAME_TO_ID.get(s.get("a_recommend", "PASS"), 0))
    return intervention_metrics(y_risk_true, l_risk_true, a_pred, a_recommend)
def run_rule_intervention(
    processed_samples: List[Dict],
    policy_fn,
    level_key: str = "l_risk",
) -> Dict:
    """Evaluate a rule-based intervention policy that maps a risk level to an action.

    Recommended actions are passed to ``intervention_metrics`` exactly as in
    ``run_rl_intervention``. This makes the baseline metrics directly
    comparable with the RL policy's.

    Args:
        processed_samples: Samples with ``y_risk`` / ``l_risk`` gold labels
            and an optional ``a_recommend`` action name (defaults to
            ``"PASS"``).
        policy_fn: Callable receiving ``int(sample[level_key])``. Its return
            value is used as the predicted action (presumably an action id,
            like the RL path's predictions).
        level_key: Sample key whose value is fed to ``policy_fn``. The default
            ``"l_risk"`` is the same key used below as the gold risk level, so
            the policy acts on oracle labels.
            NOTE(review): for a detector-driven comparison with the RL policy,
            pass the key holding the detector-predicted level. TODO: confirm
            its name in ``preprocess_samples_with_detector``.

    Returns:
        The dict produced by ``intervention_metrics``.
    """
    y_risk_true, l_risk_true, a_pred, a_recommend = [], [], [], []
    for s in processed_samples:
        y_risk_true.append(int(s["y_risk"]))
        l_risk_true.append(int(s["l_risk"]))
        a_pred.append(policy_fn(int(s[level_key])))
        # Unknown or missing action names map to id 0, matching the RL path.
        a_recommend.append(ACTION_NAME_TO_ID.get(s.get("a_recommend", "PASS"), 0))
    return intervention_metrics(y_risk_true, l_risk_true, a_pred, a_recommend)
def _format_metric_sequence(values):
    """Format real numbers in a (possibly nested) list/tuple to 3 decimals; keep other items as-is."""
    formatted = []
    for item in values:
        if isinstance(item, (list, tuple)):
            # Nested sequences, e.g. a confusion matrix, are formatted element-wise.
            formatted.append(_format_metric_sequence(item))
        elif isinstance(item, (int, float, np.integer, np.floating)) and not isinstance(item, bool):
            formatted.append(f"{item:.3f}")
        else:
            # Non-numeric entries (None, strings, ...) would make ``:.3f`` raise.
            formatted.append(item)
    return formatted


def print_metrics(name: str, metrics: Dict):
    """Pretty-print a metrics dict under a banner titled ``name``.

    Float scalars (including NumPy floating scalars) are printed with 4
    decimals. Lists, tuples and NumPy arrays have their numeric entries
    printed with 3 decimals (nested sequences element-wise). Every other
    value is printed as-is.
    """
    rule = "=" * 50
    print(f"\n{rule}")
    print(f" {name}")
    print(rule)
    for key, value in metrics.items():
        if isinstance(value, np.ndarray):
            value = value.tolist()
        if isinstance(value, (float, np.floating)):
            text = f"{value:.4f}"
        elif isinstance(value, (list, tuple)):
            text = str(_format_metric_sequence(value))
        else:
            text = value
        print(f" {key!s:35}: {text}")
def _load_yaml(path: str) -> Dict:
    """Parse a YAML config file with ``yaml.safe_load``.

    The file is read as UTF-8 explicitly so configs containing non-ASCII text
    do not depend on the platform's default encoding.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def _json_default(obj):
    """Fallback encoder for values ``json`` cannot serialize natively.

    Objects exposing ``tolist()`` (NumPy scalars/arrays, torch tensors) are
    converted to plain Python numbers/lists so they stay numeric in the output
    file; anything else is stringified.
    """
    to_list = getattr(obj, "tolist", None)
    if callable(to_list):
        return to_list()
    return str(obj)


def main():
    """Evaluate rule-based and neural detectors and, optionally, intervention policies.

    Every metrics table is printed and all results are written to ``--output``
    as JSON. Intervention policies (rule-based, threshold, RL) are evaluated
    only when ``--agent-ckpt`` is given. A missing checkpoint file only
    triggers a warning, and the model is evaluated with its freshly
    initialized weights.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate CompanionGuard-RL detectors and intervention policies.",
    )
    parser.add_argument("--detector-ckpt", required=True)
    parser.add_argument("--agent-ckpt", default=None)
    parser.add_argument("--test-data", default="data/processed/test.jsonl")
    parser.add_argument("--config", default="configs/detector_config.yaml")
    parser.add_argument("--intervention-config", default="configs/intervention_config.yaml")
    parser.add_argument("--output", default="experiments/eval_results.json")
    args = parser.parse_args()

    cfg = _load_yaml(args.config)
    # The intervention config is only needed for the RL agent: load it up front
    # (failing fast) when an agent is requested, and skip it otherwise.
    int_cfg = _load_yaml(args.intervention_config) if args.agent_ckpt else None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    model_cfg = cfg["model"]
    tokenizer = AutoTokenizer.from_pretrained(model_cfg["name"])
    samples = load_jsonl(args.test_data)
    print(f"Loaded {len(samples)} test samples.")

    # ── Neural detector ──────────────────────────────────────────────────
    detector = CompanionRiskDetector(
        model_name=model_cfg["name"],
        hidden_size=model_cfg["hidden_size"],
        num_heads=model_cfg["num_heads"],
        dropout=model_cfg["dropout"],
        use_lora=model_cfg["use_lora"],
    ).to(device)
    if Path(args.detector_ckpt).exists():
        # NOTE(review): torch.load unpickles the file — only evaluate trusted checkpoints.
        detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
        print(f"Detector loaded from {args.detector_ckpt}")
    else:
        print(f"[WARN] Checkpoint not found: {args.detector_ckpt}. Using random weights.")

    # ── Detection evaluation ─────────────────────────────────────────────
    all_results = {}
    # (progress label, table title, results key, detector class)
    rule_detectors = [
        ("L1a — Keyword Detector", "L1a: Keyword Detector", "L1a_keyword", KeywordDetector),
        ("L1b — Regex Detector", "L1b: Regex Detector", "L1b_regex", RegexDetector),
        ("L1c — Combined Keyword+Regex", "L1c: Combined Detector", "L1c_combined", CombinedRuleDetector),
    ]
    for run_label, table_title, result_key, detector_cls in rule_detectors:
        print(f"\nRunning: {run_label}")
        metrics = run_keyword_detection(samples, detector_cls())
        print_metrics(table_title, metrics)
        all_results[result_key] = metrics

    print("\nRunning: Ours — Neural Detector (Module B)")
    neural_metrics = run_neural_detection(detector, tokenizer, samples, cfg, device)
    print_metrics("Ours: CompanionRiskDetector", neural_metrics)
    all_results["ours_detection"] = neural_metrics

    # ── Intervention evaluation ──────────────────────────────────────────
    if args.agent_ckpt:
        print("\n\nPreprocessing test samples with detector for RL evaluation...")
        processed = preprocess_samples_with_detector(
            samples, detector, tokenizer, device=device,
            binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
        )

        # Build RL agent
        detector_hidden = model_cfg["hidden_size"]
        # Sanity-check that the observation builder and get_obs_dim agree
        # before the agent sees any observation.
        expected_obs_dim = get_obs_dim(detector_hidden)
        if processed:
            actual_obs_dim = np.asarray(build_obs_vector(processed[0])).size
            if actual_obs_dim != expected_obs_dim:
                print(
                    f"[WARN] build_obs_vector produced {actual_obs_dim} features, "
                    f"but get_obs_dim() expects {expected_obs_dim}."
                )
        agent = InterventionAgent(
            detector_hidden=detector_hidden,
            state_hidden=int_cfg["agent"]["state_hidden"],
            dropout=int_cfg["agent"]["dropout"],
        ).to(device)
        if Path(args.agent_ckpt).exists():
            # NOTE(review): torch.load unpickles the file — only evaluate trusted checkpoints.
            agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
            print(f"Agent loaded from {args.agent_ckpt}")
        else:
            print(f"[WARN] Agent checkpoint not found: {args.agent_ckpt}. Using random weights.")

        print("\nRunning: Rule-based Intervention Baseline (l_risk >= 3 → REJECT)")
        rule_m = run_rule_intervention(processed, rule_based_intervention)
        print_metrics("Rule-based Policy", rule_m)
        all_results["baseline_rule"] = rule_m

        print("\nRunning: Threshold Intervention Baseline")
        thr_m = run_rule_intervention(processed, threshold_intervention)
        print_metrics("Threshold Policy", thr_m)
        all_results["baseline_threshold"] = thr_m

        print("\nRunning: Ours — RL Intervention Policy (Module C)")
        rl_m = run_rl_intervention(agent, processed, device)
        print_metrics("Ours: RL Intervention Policy", rl_m)
        all_results["ours_intervention"] = rl_m

    # ── Save results ─────────────────────────────────────────────────────
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        # _json_default keeps NumPy/torch numbers numeric instead of str()-ing them.
        json.dump(all_results, f, indent=2, default=_json_default, ensure_ascii=False)
    print(f"\nAll results saved to {args.output}")
if __name__ == "__main__":
    # Script entry point: evaluation runs only when executed directly, not on import.
    main()