2026-05-09 17:21:11 +08:00
|
|
|
"""
Full evaluation script for CompanionGuard-RL.

Runs detection and intervention evaluations against multiple baselines:

Detection baselines:
  - L1a: Keyword detector
  - L1b: Regex detector
  - L1c: Combined keyword+regex
  - Ours: CompanionRiskDetector (Module B)

Intervention baselines:
  - Rule-based (l_risk >= 3 → REJECT)
  - Threshold policy (per-level mapping)
  - RL policy (Module C, Ours)

Usage:
    python scripts/evaluate.py \\
        --detector-ckpt checkpoints/detector/best.pt \\
        --agent-ckpt checkpoints/intervention/final.pt \\
        --test-data data/processed/test.jsonl \\
        --config configs/detector_config.yaml \\
        --intervention-config configs/intervention_config.yaml
"""
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import json
|
2026-05-09 17:50:17 +08:00
|
|
|
import yaml
|
2026-05-09 17:21:11 +08:00
|
|
|
import torch
|
|
|
|
|
import numpy as np
|
|
|
|
|
from pathlib import Path
|
2026-05-09 17:50:17 +08:00
|
|
|
from typing import List, Dict
|
2026-05-09 17:21:11 +08:00
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
2026-05-09 17:50:17 +08:00
|
|
|
from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
|
2026-05-09 17:21:11 +08:00
|
|
|
from src.models.detector import CompanionRiskDetector
|
|
|
|
|
from src.models.intervention_agent import InterventionAgent
|
|
|
|
|
from src.utils.metrics import detection_metrics, intervention_metrics
|
2026-05-09 17:50:17 +08:00
|
|
|
from src.utils.baselines import (
|
|
|
|
|
KeywordDetector,
|
|
|
|
|
RegexDetector,
|
|
|
|
|
CombinedRuleDetector,
|
|
|
|
|
rule_based_intervention,
|
|
|
|
|
threshold_intervention,
|
|
|
|
|
)
|
|
|
|
|
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
|
2026-05-09 17:21:11 +08:00
|
|
|
from src.utils.taxonomy import (
|
|
|
|
|
ACTION_NAME_TO_ID,
|
|
|
|
|
NUM_RISK_LEVELS,
|
|
|
|
|
NUM_PRIMARY,
|
2026-05-09 17:50:17 +08:00
|
|
|
PRIMARY_CATEGORY_LIST,
|
|
|
|
|
FINE_GRAINED_LABELS,
|
2026-05-09 17:21:11 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
2026-05-09 17:50:17 +08:00
|
|
|
def get_obs_dim(detector_hidden: int) -> int:
    """Return the length of the RL observation vector for a detector width.

    The total is the sum of five terms: one leading scalar slot,
    ``NUM_RISK_LEVELS`` slots, ``NUM_PRIMARY`` slots, two blocks of
    ``detector_hidden`` values each, and one trailing scalar slot.

    NOTE(review): presumably this mirrors the layout produced by
    ``build_obs_vector`` — confirm against that helper.

    Args:
        detector_hidden: Hidden size of the detector.

    Returns:
        The observation dimensionality as an integer.
    """
    scalar_slots = 2  # the leading "1 +" and the trailing "+ 1"
    categorical_slots = NUM_RISK_LEVELS + NUM_PRIMARY
    embedding_slots = 2 * detector_hidden
    return scalar_slots + categorical_slots + embedding_slots
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
|
|
|
|
2026-05-09 17:50:17 +08:00
|
|
|
# ── Detection evaluation ──────────────────────────────────────────────────
|
2026-05-09 17:21:11 +08:00
|
|
|
|
2026-05-09 17:50:17 +08:00
|
|
|
def run_neural_detection(
    model: CompanionRiskDetector,
    tokenizer,
    samples: List[Dict],
    cfg: Dict,
    device: str,
) -> Dict:
    """Run the neural detector on test samples and compute metrics.

    Each sample is copied, passed through ``validate_and_normalize``, split
    into persona / context / response texts with ``format_conversation``,
    tokenized to fixed-length tensors and fed to ``model.predict`` one at a
    time (batch size 1).

    Args:
        model: Detector under evaluation; put into eval mode here.
        tokenizer: Callable tokenizer, invoked with ``max_length``,
            ``truncation``, ``padding`` and ``return_tensors``.
        samples: Raw test records. The originals are not mutated; a shallow
            copy is normalized instead.
        cfg: Detector configuration. Reads ``evaluation.binary_threshold``
            (default 0.5) and, under ``data``, ``max_history_turns`` (5),
            ``max_persona_len`` (128), ``max_context_len`` (512) and
            ``max_response_len`` (256).
        device: Torch device string the encoded tensors are moved to.

    Returns:
        The result of ``detection_metrics`` over the true and predicted
        binary risk labels and risk levels.
    """
    model.eval()

    # ``or {}`` also covers a section that exists in the YAML but is empty
    # (parsed as None), which would otherwise fail on ``.get``.
    eval_cfg = cfg.get("evaluation") or {}
    data_cfg = cfg.get("data") or {}

    # Loop-invariant settings are resolved once instead of once per sample.
    binary_threshold = eval_cfg.get("binary_threshold", 0.5)
    max_history_turns = data_cfg.get("max_history_turns", 5)
    max_persona_len = data_cfg.get("max_persona_len", 128)
    max_context_len = data_cfg.get("max_context_len", 512)
    max_response_len = data_cfg.get("max_response_len", 256)

    def enc(text, max_len):
        # Pad to a fixed length so every field has a predictable shape.
        return tokenizer(
            text, max_length=max_len, truncation=True,
            padding="max_length", return_tensors="pt",
        )

    y_true, y_pred = [], []
    l_true, l_pred = [], []

    for sample in samples:
        sample = validate_and_normalize(dict(sample))
        texts = format_conversation(
            sample["persona"], sample["history"],
            sample["user_input"], sample["ai_response"],
            max_history_turns=max_history_turns,
        )

        p_enc = enc(texts["persona_text"], max_persona_len)
        c_enc = enc(texts["context_text"], max_context_len)
        r_enc = enc(texts["response_text"], max_response_len)

        with torch.no_grad():
            preds = model.predict(
                p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
                c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
                r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
                binary_threshold=binary_threshold,
            )

        y_true.append(int(sample["y_risk"]))
        y_pred.append(preds["y_risk"].item())
        l_true.append(int(sample["l_risk"]))
        l_pred.append(preds["l_risk"].item())

    return detection_metrics(y_true, y_pred, l_true, l_pred)
|
|
|
|
|
|
|
|
|
|
|
2026-05-09 17:50:17 +08:00
|
|
|
def run_keyword_detection(samples: List[Dict], detector) -> Dict:
    """Score a rule-style detector on the test samples.

    The detector only sees each sample's ``ai_response`` text (an empty
    string when the key is absent) and must return a mapping with
    ``y_risk`` and ``l_risk`` entries.

    NOTE(review): unlike ``run_neural_detection``, samples are not passed
    through ``validate_and_normalize`` here — confirm that is intended.

    Args:
        samples: Test records carrying ``y_risk`` and ``l_risk`` labels.
        detector: Object exposing ``detect(text)``.

    Returns:
        The result of ``detection_metrics`` over true/predicted labels.
    """
    rows = []
    for record in samples:
        outcome = detector.detect(record.get("ai_response", ""))
        rows.append((
            int(record["y_risk"]),
            outcome["y_risk"],
            int(record["l_risk"]),
            outcome["l_risk"],
        ))

    gold_binary = [row[0] for row in rows]
    pred_binary = [row[1] for row in rows]
    gold_level = [row[2] for row in rows]
    pred_level = [row[3] for row in rows]
    return detection_metrics(gold_binary, pred_binary, gold_level, pred_level)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Intervention evaluation ───────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
def run_rl_intervention(
    agent: InterventionAgent,
    processed_samples: List[Dict],
    device: str,
) -> Dict:
    """Score the RL intervention policy on preprocessed samples.

    For every sample an observation is built with ``build_obs_vector``,
    given a leading batch axis, and the agent's deterministic action is
    recorded next to the sample's labels.

    Args:
        agent: Policy under evaluation; put into eval mode here.
        processed_samples: Records accepted by ``build_obs_vector`` that
            also carry ``y_risk`` and ``l_risk``. A missing ``a_recommend``
            is treated as ``"PASS"``; names absent from
            ``ACTION_NAME_TO_ID`` map to id 0.
        device: Torch device string the observation is moved to.

    Returns:
        The result of ``intervention_metrics`` over the labels, the chosen
        actions and the recommended actions.
    """
    agent.eval()

    gold_binary, gold_level = [], []
    chosen_actions, reference_actions = [], []

    for record in processed_samples:
        features = build_obs_vector(record)
        obs = torch.FloatTensor(features).unsqueeze(0).to(device)
        with torch.no_grad():
            picked, _, _, _ = agent.get_action(obs, deterministic=True)

        gold_binary.append(int(record["y_risk"]))
        gold_level.append(int(record["l_risk"]))
        chosen_actions.append(picked.item())

        recommended_name = record.get("a_recommend", "PASS")
        reference_actions.append(ACTION_NAME_TO_ID.get(recommended_name, 0))

    return intervention_metrics(
        gold_binary, gold_level, chosen_actions, reference_actions
    )
|
|
|
|
|
|
|
|
|
|
|
2026-05-09 17:50:17 +08:00
|
|
|
def run_rule_intervention(processed_samples: List[Dict], policy_fn) -> Dict:
    """Score a fixed (non-learned) intervention policy.

    The policy is called once per sample with that sample's integer
    ``l_risk`` and returns the action to take.

    NOTE(review): unlike ``run_rl_intervention``, no recommended-action
    list is handed to ``intervention_metrics`` — confirm that the baselines
    are meant to be reported without it.

    Args:
        processed_samples: Records carrying ``y_risk`` and ``l_risk``.
        policy_fn: Callable mapping an integer risk level to an action.

    Returns:
        The result of ``intervention_metrics`` over labels and actions.
    """
    gold_binary = list(map(int, (rec["y_risk"] for rec in processed_samples)))
    gold_level = [int(rec["l_risk"]) for rec in processed_samples]
    chosen_actions = list(map(policy_fn, gold_level))
    return intervention_metrics(gold_binary, gold_level, chosen_actions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_metrics(name: str, metrics: Dict):
    """Print a metrics dictionary under a banner headed by ``name``.

    Formatting rules per value:
      * ``float``  -> four decimal places.
      * ``list``   -> every element with three decimal places; if any
        element cannot be formatted as a number (nested lists such as a
        confusion matrix, strings, ``None``) the list is printed unchanged
        instead of raising.
      * anything else -> printed as-is.

    Args:
        name: Title shown between the two separator lines.
        metrics: Mapping from metric name (string) to value.
    """
    banner = "=" * 50
    print(f"\n{banner}")
    print(f" {name}")
    print(banner)

    for key, value in metrics.items():
        if isinstance(value, float):
            rendered = f"{value:.4f}"
        elif isinstance(value, list):
            try:
                rendered = [f"{item:.3f}" for item in value]
            except (TypeError, ValueError):
                # Non-numeric elements: show the raw list rather than
                # aborting the whole report.
                rendered = value
        else:
            rendered = value
        print(f" {key:35s}: {rendered}")
|
|
|
|
|
|
|
|
|
|
|
2026-05-09 17:21:11 +08:00
|
|
|
def main():
    """Evaluate the detectors and, optionally, the intervention policies.

    Steps:
      1. Parse CLI arguments and load both YAML configs.
      2. Build the neural detector and load its checkpoint if the file
         exists; otherwise only a warning is printed and no weights are
         loaded.
      3. Evaluate the three rule-style detectors and the neural detector.
      4. Only when ``--agent-ckpt`` is given: preprocess the samples with
         the detector, build the agent (loading its checkpoint if the file
         exists), then evaluate the rule-based, threshold and RL policies.
      5. Write every collected metrics dict to ``--output`` as JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--detector-ckpt", required=True)
    parser.add_argument("--agent-ckpt", default=None)
    parser.add_argument("--test-data", default="data/processed/test.jsonl")
    parser.add_argument("--config", default="configs/detector_config.yaml")
    parser.add_argument("--intervention-config", default="configs/intervention_config.yaml")
    parser.add_argument("--output", default="experiments/eval_results.json")
    args = parser.parse_args()

    # Explicit UTF-8 so config parsing does not depend on the platform's
    # locale encoding.
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    with open(args.intervention_config, encoding="utf-8") as f:
        int_cfg = yaml.safe_load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
    samples = load_jsonl(args.test_data)
    print(f"Loaded {len(samples)} test samples.")

    # ── Neural detector ──────────────────────────────────────────────────
    detector = CompanionRiskDetector(
        model_name=cfg["model"]["name"],
        hidden_size=cfg["model"]["hidden_size"],
        num_heads=cfg["model"]["num_heads"],
        dropout=cfg["model"]["dropout"],
        use_lora=cfg["model"]["use_lora"],
    ).to(device)

    if Path(args.detector_ckpt).exists():
        detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
        print(f"Detector loaded from {args.detector_ckpt}")
    else:
        print(f"[WARN] Checkpoint not found: {args.detector_ckpt}. Using random weights.")

    # ── Detection evaluation ─────────────────────────────────────────────
    all_results = {}

    # (progress line, report title, result key, detector class); each
    # detector is instantiated right before it is evaluated.
    rule_detectors = (
        ("\nRunning: L1a — Keyword Detector", "L1a: Keyword Detector",
         "L1a_keyword", KeywordDetector),
        ("\nRunning: L1b — Regex Detector", "L1b: Regex Detector",
         "L1b_regex", RegexDetector),
        ("\nRunning: L1c — Combined Keyword+Regex", "L1c: Combined Detector",
         "L1c_combined", CombinedRuleDetector),
    )
    for progress_line, title, result_key, detector_cls in rule_detectors:
        print(progress_line)
        rule_metrics = run_keyword_detection(samples, detector_cls())
        print_metrics(title, rule_metrics)
        all_results[result_key] = rule_metrics

    print("\nRunning: Ours — Neural Detector (Module B)")
    neural_metrics = run_neural_detection(detector, tokenizer, samples, cfg, device)
    print_metrics("Ours: CompanionRiskDetector", neural_metrics)
    all_results["ours_detection"] = neural_metrics

    # ── Intervention evaluation ──────────────────────────────────────────
    if args.agent_ckpt:
        print("\n\nPreprocessing test samples with detector for RL evaluation...")
        # ``or {}`` tolerates an ``evaluation:`` section that is present
        # but empty in the YAML (parsed as None).
        eval_cfg = cfg.get("evaluation") or {}
        processed = preprocess_samples_with_detector(
            samples, detector, tokenizer, device=device,
            binary_threshold=eval_cfg.get("binary_threshold", 0.5),
        )

        # Build RL agent
        detector_hidden = cfg["model"]["hidden_size"]
        agent = InterventionAgent(
            detector_hidden=detector_hidden,
            state_hidden=int_cfg["agent"]["state_hidden"],
            dropout=int_cfg["agent"]["dropout"],
        ).to(device)

        if Path(args.agent_ckpt).exists():
            agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
            print(f"Agent loaded from {args.agent_ckpt}")
        else:
            print(f"[WARN] Agent checkpoint not found: {args.agent_ckpt}. Using random weights.")

        print("\nRunning: Rule-based Intervention Baseline (l_risk >= 3 → REJECT)")
        rule_m = run_rule_intervention(processed, rule_based_intervention)
        print_metrics("Rule-based Policy", rule_m)
        all_results["baseline_rule"] = rule_m

        print("\nRunning: Threshold Intervention Baseline")
        thr_m = run_rule_intervention(processed, threshold_intervention)
        print_metrics("Threshold Policy", thr_m)
        all_results["baseline_threshold"] = thr_m

        print("\nRunning: Ours — RL Intervention Policy (Module C)")
        rl_m = run_rl_intervention(agent, processed, device)
        print_metrics("Ours: RL Intervention Policy", rl_m)
        all_results["ours_intervention"] = rl_m

    # ── Save results ─────────────────────────────────────────────────────
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        # default=str stringifies any value json cannot encode natively.
        json.dump(all_results, f, indent=2, default=str, ensure_ascii=False)
    print(f"\nAll results saved to {args.output}")
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|