Files
CompanionGuard-RL/code/scripts/evaluate.py
zhangsiyuan bd1f51c496 chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-14 11:28:42 +08:00

613 lines
25 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Full evaluation script for CompanionGuard-RL.
Runs detection and intervention evaluations against multiple baselines:
Detection baselines:
- L1a: Keyword detector
- L1b: Regex detector
- L1c: Combined keyword+regex
- Ours: CompanionRiskDetector (Module B)
Intervention baselines:
- Rule-based (l_risk >= 3 → REJECT)
- Threshold policy (per-level mapping)
- RL policy (Module C, Ours) [requires --agent-ckpt]
- BC-only policy (ablation of Module C) [requires --bc-ckpt]
Usage (detection only, no RL agent yet):
python scripts/evaluate.py \\
--detector-ckpt checkpoints/detector/best.pt \\
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
--config configs/detector_config.yaml
Usage (with RL intervention agent):
python scripts/evaluate.py \\
--detector-ckpt checkpoints/detector/best.pt \\
--agent-ckpt checkpoints/intervention/final.pt \\
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
--config configs/detector_config.yaml \\
--intervention-config configs/intervention_config.yaml
Usage (human-annotated subset only):
python scripts/evaluate.py \\
--detector-ckpt checkpoints/detector/best.pt \\
--source-filter human \\
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
--config configs/detector_config.yaml
"""
import argparse
import json
import yaml
import torch
import numpy as np
from collections import defaultdict
from pathlib import Path
from typing import List, Dict, Optional
from transformers import AutoTokenizer
from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
from src.models.detector import CompanionRiskDetector
from src.utils.metrics import detection_metrics, intervention_metrics
from src.utils.baselines import (
KeywordDetector,
RegexDetector,
CombinedRuleDetector,
rule_based_intervention,
threshold_intervention,
)
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
NUM_RISK_LEVELS,
NUM_PRIMARY,
NUM_FINE,
PRIMARY_CATEGORY_LIST,
FINE_GRAINED_LABELS,
INTERVENTION_ACTIONS,
)
# ── Fine-grained label split ──────────────────────────────────────────────────
# The fine-grained labels are evaluated in two groups:
#   PUBLIC    — 10 labels covering R1/R2/R5-R9; these also occur in the
#               human-annotated subsets.
#   COMPANION — 4 labels specific to companion AI (R3/R4/R10); they do not
#               occur in the human subsets and are evaluated separately on the
#               generated test split (see paper §4.3).
PUBLIC_FINE_LABELS = [
    "DirectEncouragement",
    "MethodFacilitation",
    "RiskNormalization",
    "Romanticization",
    "FalseReassurance",
    "PseudoTherapy",
    "CrisisNonResponse",
    "BoundaryFailure",
    "AgeInappropriateIntimacy",
    "PrivacySolicitation",
]
COMPANION_FINE_LABELS = [
    "DependencyReinforcement",
    "IsolationReinforcement",
    "ManipulativeAttachment",
    "CoRumination",
]

# Column positions of each group inside FINE_GRAINED_LABELS, used to slice the
# [N, NUM_FINE] fine-label matrices. list.index raises ValueError at import
# time if a label listed above is missing from the taxonomy.
PUBLIC_FINE_IDX = list(map(FINE_GRAINED_LABELS.index, PUBLIC_FINE_LABELS))
COMPANION_FINE_IDX = list(map(FINE_GRAINED_LABELS.index, COMPANION_FINE_LABELS))

# --label-filter value -> fine-label column indices that are kept.
LABEL_FILTER_MAP = dict(
    all=list(range(NUM_FINE)),
    public=PUBLIC_FINE_IDX,
    companion=COMPANION_FINE_IDX,
)

# ── Data-source filtering ─────────────────────────────────────────────────────
# Markers of human-annotated samples: a sample counts as human-annotated when
# its "source" field equals one of HUMAN_SOURCE_NAMES or its "id" starts with
# one of HUMAN_SOURCE_PREFIXES.
HUMAN_SOURCE_PREFIXES = ("suicide-", "cosafe-", "dices-")
HUMAN_SOURCE_NAMES = ("suicide_risk", "cosafe", "dices")
def filter_by_source(samples: List[Dict], source_filter: str) -> List[Dict]:
    """Select test samples by data origin.

    A sample counts as human-annotated when its ``"source"`` field is one of
    ``HUMAN_SOURCE_NAMES`` or its ``"id"`` starts with one of
    ``HUMAN_SOURCE_PREFIXES``. Every other sample counts as LLM-generated.

    Args:
        samples: Test samples (dicts).
        source_filter: One of the following values.
            - ``"all"`` returns ``samples`` itself, unchanged.
            - ``"human"`` keeps only human-annotated samples.
            - ``"generated"`` keeps only the remaining samples.

    Returns:
        The selected samples. This is a new list unless ``source_filter`` is ``"all"``.

    Raises:
        ValueError: If ``source_filter`` is not one of the values above.
            Previously such a value silently produced an empty list.
    """
    if source_filter == "all":
        return samples
    if source_filter not in ("human", "generated"):
        raise ValueError(
            f"unknown source_filter {source_filter!r}; "
            "expected 'all', 'human' or 'generated'"
        )
    want_human = source_filter == "human"
    filtered = []
    for s in samples:
        # `or ""` also covers fields that are present but null in the JSONL;
        # str() guards against non-string ids (e.g. integers).
        src = s.get("source") or ""
        sid = str(s.get("id") or "")
        is_human = src in HUMAN_SOURCE_NAMES or sid.startswith(HUMAN_SOURCE_PREFIXES)
        if is_human == want_human:
            filtered.append(s)
    return filtered
def get_obs_dim(detector_hidden: int) -> int:
    """Return the length of the flat RL observation vector.

    The length is ``1 + NUM_RISK_LEVELS + NUM_PRIMARY + 2 * detector_hidden + 1``.

    NOTE(review): the meaning of each segment is defined by ``build_obs_vector``
    in src.utils.preprocessing, which is not visible here. Keep the two in sync.
    """
    segment_sizes = (1, NUM_RISK_LEVELS, NUM_PRIMARY, 2 * detector_hidden, 1)
    return sum(segment_sizes)
# ── Detection evaluation ──────────────────────────────────────────────────────
def run_neural_detection(
    model: CompanionRiskDetector,
    tokenizer,
    samples: List[Dict],
    cfg: Dict,
    device: str,
    label_filter: str = "all",
) -> Dict:
    """Run the neural detector on test samples and compute detection metrics.

    Processing per sample:
      1. Shallow-copy the sample and pass it through ``validate_and_normalize``.
      2. Render it with ``format_conversation``.
      3. Encode the persona, context and response texts separately, each
         padded/truncated to its configured maximum length.
      4. Score it with ``model.predict``.

    Args:
        model: Detector whose ``predict`` returns ``y_risk``, ``l_risk`` and
            a thresholded multi-hot ``c_fine`` tensor.
        tokenizer: Tokenizer used to encode the three text fields.
        samples: Raw test samples.
        cfg: Detector config. Reads ``evaluation.binary_threshold``
            (default 0.5), ``evaluation.fine_threshold`` (default 0.4) and
            ``data.max_history_turns`` / ``max_persona_len`` /
            ``max_context_len`` / ``max_response_len``.
        device: Torch device the input tensors are moved to.
        label_filter: ``"all"``, ``"public"`` or ``"companion"``. Restricts
            the fine-label columns passed to ``detection_metrics``. Unknown
            values evaluate all columns but are still recorded verbatim.

    Returns:
        The ``detection_metrics`` result extended with ``per_category_recall``
        (recall / miss rate per primary category) and ``label_filter``.
    """
    # Imported once here instead of inside the per-label loop below.
    from src.utils.taxonomy import label_to_index

    model.eval()
    # `or {}` also tolerates a YAML section that is present but empty (None).
    eval_cfg = cfg.get("evaluation") or {}
    data_cfg = cfg.get("data") or {}
    binary_threshold = eval_cfg.get("binary_threshold", 0.5)
    fine_threshold = eval_cfg.get("fine_threshold", 0.4)
    max_history_turns = data_cfg.get("max_history_turns", 5)
    max_persona_len = data_cfg.get("max_persona_len", 128)
    max_context_len = data_cfg.get("max_context_len", 512)
    max_response_len = data_cfg.get("max_response_len", 256)

    def enc(text, max_len):
        # Fixed-length padding per field, batch of one.
        return tokenizer(
            text, max_length=max_len, truncation=True,
            padding="max_length", return_tensors="pt",
        )

    y_true, y_pred = [], []
    l_true, l_pred = [], []
    fine_true_list, fine_pred_list = [], []
    # Per primary category: number of samples and number flagged as risky.
    cat_results = defaultdict(lambda: {"total": 0, "detected": 0})

    for raw_sample in samples:
        sample = validate_and_normalize(dict(raw_sample))
        texts = format_conversation(
            sample["persona"], sample["history"],
            sample["user_input"], sample["ai_response"],
            max_history_turns=max_history_turns,
        )
        p_enc = enc(texts["persona_text"], max_persona_len)
        c_enc = enc(texts["context_text"], max_context_len)
        r_enc = enc(texts["response_text"], max_response_len)
        with torch.no_grad():
            preds = model.predict(
                p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
                c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
                r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
                binary_threshold=binary_threshold,
                fine_threshold=fine_threshold,
            )

        y_t = int(sample["y_risk"])
        y_p = preds["y_risk"].item()
        y_true.append(y_t)
        y_pred.append(y_p)
        l_true.append(int(sample["l_risk"]))
        l_pred.append(preds["l_risk"].item())

        # Ground-truth multi-hot vector. Labels outside the taxonomy are ignored.
        gt_fine = np.zeros(NUM_FINE)
        for lbl in sample.get("c_fine") or []:
            if lbl in FINE_GRAINED_LABELS:
                gt_fine[label_to_index(lbl)] = 1.0
        fine_true_list.append(gt_fine)
        # predict() already applies sigmoid + fine_threshold, so "c_fine" is a
        # binary tensor (the output has no "c_fine_logits" key). Flatten it to
        # 1-D [NUM_FINE].
        fine_pred_list.append(preds["c_fine"].cpu().numpy().flatten())

        cat = sample.get("c_primary", "None")
        if cat in PRIMARY_CATEGORY_LIST:
            cat_results[cat]["total"] += 1
            if y_p == 1:
                cat_results[cat]["detected"] += 1

    per_cat = {}
    for cat in PRIMARY_CATEGORY_LIST:
        total = cat_results[cat]["total"]
        detected = cat_results[cat]["detected"]
        recall = detected / total if total > 0 else 0.0
        per_cat[cat] = {
            "total": total,
            "detected": detected,
            "recall": round(recall, 4),
            "miss_rate": round(1 - recall, 4),
        }

    fine_true_arr = np.array(fine_true_list)  # [N, NUM_FINE]
    fine_pred_arr = np.array(fine_pred_list)  # [N, NUM_FINE]
    # Restrict the fine-label columns to the selected subset. "all" keeps
    # every column, and unknown filters are not in the map.
    col_idx = LABEL_FILTER_MAP.get(label_filter)
    if col_idx is not None and len(col_idx) < NUM_FINE:
        fine_true_arr = fine_true_arr[:, col_idx]
        fine_pred_arr = fine_pred_arr[:, col_idx]

    metrics = detection_metrics(
        y_true, y_pred,
        l_true, l_pred,
        fine_true_arr, fine_pred_arr,
    )
    metrics["per_category_recall"] = per_cat
    metrics["label_filter"] = label_filter
    return metrics
def run_keyword_detection(
    samples: List[Dict],
    detector,
    compute_cat_recall: bool = True,
) -> Dict:
    """Evaluate a rule-based detector (keyword / regex / combined).

    Only ``sample["ai_response"]`` (default ``""``) is passed to
    ``detector.detect``. Unlike ``run_neural_detection``, samples here are
    used as-is, without ``validate_and_normalize``.

    Args:
        samples: Test samples carrying ``y_risk`` and ``l_risk`` labels.
        detector: Object whose ``detect(text)`` returns a dict with
            ``y_risk`` and ``l_risk``.
        compute_cat_recall: When true, add ``per_category_recall`` (recall
            and miss rate per primary category).

    Returns:
        The ``detection_metrics`` result (binary + level metrics only; no
        fine-label arrays are passed).
    """
    y_true, y_pred, l_true, l_pred = [], [], [], []
    tally = defaultdict(lambda: {"total": 0, "detected": 0})

    for sample in samples:
        verdict = detector.detect(sample.get("ai_response", ""))
        truth = int(sample["y_risk"])
        predicted = verdict["y_risk"]
        y_true.append(truth)
        y_pred.append(predicted)
        l_true.append(int(sample["l_risk"]))
        l_pred.append(verdict["l_risk"])

        category = sample.get("c_primary", "None")
        if category not in PRIMARY_CATEGORY_LIST:
            continue
        tally[category]["total"] += 1
        if predicted == 1:
            tally[category]["detected"] += 1

    metrics = detection_metrics(y_true, y_pred, l_true, l_pred)
    if not compute_cat_recall:
        return metrics

    recall_table = {}
    for category in PRIMARY_CATEGORY_LIST:
        counts = tally[category]
        n_total, n_hit = counts["total"], counts["detected"]
        rec = n_hit / n_total if n_total > 0 else 0.0
        recall_table[category] = {
            "total": n_total,
            "detected": n_hit,
            "recall": round(rec, 4),
            "miss_rate": round(1 - rec, 4),
        }
    metrics["per_category_recall"] = recall_table
    return metrics
# ── Intervention evaluation ───────────────────────────────────────────────────
def _collect_c_primary_idx(processed_samples: List[Dict]) -> List[int]:
    """Return one ground-truth primary-category index per sample.

    When ``c_primary`` names a known category, the index comes from
    ``category_to_index``. Otherwise the sample's own ``c_primary_idx`` field
    is used (default 0). The result feeds the category-aware parts of
    ``intervention_metrics``.
    """
    from src.utils.taxonomy import category_to_index, PRIMARY_CATEGORY_LIST

    def index_of(sample) -> int:
        name = sample.get("c_primary", "None")
        if name not in PRIMARY_CATEGORY_LIST:
            return int(sample.get("c_primary_idx", 0))
        return category_to_index(name)

    return [index_of(sample) for sample in processed_samples]
def run_rl_intervention(agent, processed_samples: List[Dict], device: str) -> Dict:
    """Score a learned intervention policy using its deterministic actions.

    Args:
        agent: Policy exposing ``eval()`` and
            ``get_action(obs, deterministic=True)``. The latter returns a
            4-tuple whose first element is the action tensor.
        processed_samples: Detector-preprocessed samples that
            ``build_obs_vector`` can turn into observations.
        device: Torch device the observation tensor is moved to.

    Returns:
        The ``intervention_metrics`` result, computed from:
          - gold ``y_risk`` and ``l_risk``;
          - the predicted action ids;
          - the recommended action ids, i.e. ``a_recommend`` mapped through
            ``ACTION_NAME_TO_ID`` (a missing field means "PASS"; an unmapped
            name becomes id 0);
          - the gold primary-category indices.
    """
    agent.eval()

    def greedy_action(sample):
        obs = torch.FloatTensor(build_obs_vector(sample))
        obs = obs.unsqueeze(0).to(device)  # batch of one
        with torch.no_grad():
            action, _, _, _ = agent.get_action(obs, deterministic=True)
        return action.item()

    gold_flags, gold_levels, chosen, recommended = [], [], [], []
    for sample in processed_samples:
        picked = greedy_action(sample)
        gold_flags.append(int(sample["y_risk"]))
        gold_levels.append(int(sample["l_risk"]))
        chosen.append(picked)
        recommended.append(ACTION_NAME_TO_ID.get(sample.get("a_recommend", "PASS"), 0))

    return intervention_metrics(
        gold_flags, gold_levels, chosen, recommended,
        _collect_c_primary_idx(processed_samples),
    )
def run_rule_intervention(processed_samples: List[Dict], policy_fn) -> Dict:
    """Score a fixed rule/threshold policy on detector-preprocessed samples.

    ``policy_fn`` receives the detector's predicted level ``det_l_risk``, as a
    deployed system would. When a sample lacks ``det_l_risk``, the gold
    ``l_risk`` is passed instead.
    NOTE(review): that fallback would let the baseline see gold labels.
    Confirm that preprocessing always sets ``det_l_risk``.

    No recommended actions are passed to ``intervention_metrics`` here.
    """
    gold_flags = [int(sample["y_risk"]) for sample in processed_samples]
    gold_levels = [int(sample["l_risk"]) for sample in processed_samples]

    chosen = []
    for sample in processed_samples:
        observed_level = sample.get("det_l_risk", sample["l_risk"])
        chosen.append(policy_fn(int(observed_level)))

    return intervention_metrics(
        gold_flags, gold_levels, chosen,
        c_primary_idx=_collect_c_primary_idx(processed_samples),
    )
# ── Pretty printing ───────────────────────────────────────────────────────────
def print_metrics(name: str, metrics: Dict):
    """Pretty-print one method's metrics dictionary to stdout.

    These keys get a dedicated sub-table:
      - ``per_category_recall``
      - ``level_per_class_f1``
      - ``fine_per_label_f1``
      - ``per_level_action_dist``
      - ``per_category_action_dist``
      - ``exact_action_accuracy_by_level``

    Any other value is printed by type:
      - a float (including numpy floating scalars): 4 decimals;
      - a list: element-wise with 3 decimals, falling back to ``str`` for
        elements that cannot be formatted as numbers (e.g. nested lists);
      - anything else: ``str``.

    The ``label_filter`` entry is not printed. It selects which label names
    ("public", "companion", otherwise all) annotate ``fine_per_label_f1``.

    Args:
        name: Heading printed above the metrics.
        metrics: Metrics dict as produced by the run_* helpers of this script.
    """
    companion_specific = {"R3", "R4", "R10"}
    level_names = {0: "Safe", 1: "Mild", 2: "Moderate", 3: "High", 4: "Critical"}
    # Visible rule between sub-tables (the previous ''*51 always printed an
    # empty line).
    rule = "─" * 51

    def fmt3(x) -> str:
        # f"{x:.3f}" raises TypeError for non-numeric items such as nested
        # lists; print those verbatim instead of crashing.
        try:
            return f"{x:.3f}"
        except (TypeError, ValueError):
            return str(x)

    print(f"\n{'=' * 55}")
    print(f" {name}")
    print(f"{'=' * 55}")
    for k, v in metrics.items():
        if k == "label_filter":
            continue
        if k == "per_category_recall":
            print(f" {rule}")
            print(f" Per-category Recall (↑好,漏检率=1-recall):")
            print(f" {rule}")
            for cat, m in sorted(v.items()):
                flag = " ◀ companion特有" if cat in companion_specific else ""
                print(f" {cat:4s}: recall={m['recall']:.3f} miss={m['miss_rate']:.3f}"
                      f" (n={m['total']:3d}){flag}")
        elif k == "level_per_class_f1":
            print(f" {rule}")
            print(f" Per-level F1 (诊断 level_macro_f1 下降原因):")
            for i, score in enumerate(v):
                # .get: do not crash if more levels are reported than named.
                print(f" L{i} {level_names.get(i, '?'):10s}: {score:.4f}")
        elif k == "fine_per_label_f1":
            lf = metrics.get("label_filter", "all")
            if lf == "public":
                active_fine_labels = PUBLIC_FINE_LABELS
            elif lf == "companion":
                active_fine_labels = COMPANION_FINE_LABELS
            else:
                active_fine_labels = FINE_GRAINED_LABELS
            print(f" {rule}")
            print(f" Per fine-label F1 ({len(active_fine_labels)} labels, filter={lf}):")
            for lbl, score in zip(active_fine_labels, v):
                print(f" {lbl:35s}: {score:.4f}")
        elif k == "per_level_action_dist":
            print(f" {rule}")
            print(f" Per-level Action Distribution:")
            print(f" {'Level':15s} {'n':>5} PASS WARN RWRT REJT CRISIS")
            for lvl_name, lv in v.items():
                dist_str = " ".join(f"{x:.3f}" for x in lv["action_dist"])
                print(f" {lvl_name:15s} {lv['n']:>5} {dist_str}")
        elif k == "per_category_action_dist":
            print(f" {rule}")
            print(f" Per-category Action Distribution:")
            print(f" {'Category':6s} {'n':>5} PASS WARN RWRT REJT CRISIS")
            for cat_name, cv in v.items():
                dist_str = " ".join(f"{x:.3f}" for x in cv["action_dist"])
                print(f" {cat_name:6s} {cv['n']:>5} {dist_str}")
        elif k == "exact_action_accuracy_by_level":
            print(f" {rule}")
            print(f" Action Accuracy by Level (vs a_recommend):")
            for lvl_name, acc in v.items():
                print(f" {lvl_name:15s}: {acc:.4f}")
        elif isinstance(v, (float, np.floating)):
            print(f" {k:35s}: {v:.4f}")
        elif isinstance(v, list):
            formatted = [fmt3(x) for x in v]
            print(f" {k:35s}: {formatted}")
        else:
            print(f" {k:35s}: {v}")
# ── Main ──────────────────────────────────────────────────────────────────────
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="CompanionGuard-RL 完整评估")
    parser.add_argument("--detector-ckpt", required=True,
                        help="检测器权重路径,例如 checkpoints/detector/best.pt")
    parser.add_argument("--agent-ckpt", default=None,
                        help="RL agent 权重路径BC+PPO")
    parser.add_argument("--bc-ckpt", default=None,
                        help="BC-only agent 权重路径BC 阶段后保存,用于 ablation 对照)")
    parser.add_argument("--test-data",
                        default="data/processed/CompanionRisk-Bench/test.jsonl",
                        help="测试集路径")
    parser.add_argument("--config", default="configs/detector_config.yaml",
                        help="检测器配置路径")
    parser.add_argument("--intervention-config",
                        default="configs/intervention_config.yaml",
                        help="干预配置路径(仅 --agent-ckpt 存在时需要)")
    parser.add_argument("--output", default="experiments/eval_results.json",
                        help="结果保存路径")
    parser.add_argument("--source-filter", default="all",
                        choices=["all", "human", "generated"],
                        help=(
                            "样本过滤: "
                            "all=全部, "
                            "human=仅人工标注(suicide/cosafe/dices), "
                            "generated=仅LLM生成"
                        ))
    parser.add_argument("--label-filter", default="all",
                        choices=["all", "public", "companion"],
                        help=(
                            "细粒度标签集过滤 (fine_macro_f1 计算范围): "
                            "all=全部14标签, "
                            "public=10个通用标签(R1/R2/R5-R9人工子集可用), "
                            "companion=4个companion专属标签(R3/R4/R10)"
                        ))
    return parser.parse_args()


def _print_detection_summary(all_results: Dict) -> None:
    """Print one comparison row per detection method.

    Metrics missing from a method's result print as ``nan``.
    """
    rule = "─" * 55
    nan = float("nan")
    print(f"\n{rule}")
    print(" Detection Summary对比汇总")
    print(rule)
    print(f" {'Method':25s} {'BinF1':>7} {'Recall':>7} {'FNR':>7} {'LvlF1':>7}")
    print(f" {'─' * 53}")
    for tag, key in (
        ("L1a_keyword", "L1a_keyword"),
        ("L1b_regex", "L1b_regex"),
        ("L1c_combined", "L1c_combined"),
        ("Ours", "ours_detection"),
    ):
        m = all_results[key]
        print(
            f" {tag:25s} {m.get('binary_f1', nan):7.4f} "
            f"{m.get('high_risk_recall', nan):7.4f} "
            f"{m.get('false_negative_rate', nan):7.4f} "
            f"{m.get('level_macro_f1', nan):7.4f}"
        )


def main():
    """Command-line entry point for the full evaluation.

    Steps:
      1. Load the detector config, plus the intervention config when
         --agent-ckpt or --bc-ckpt is given.
      2. Load the test set and apply --source-filter.
      3. Evaluate the keyword/regex/combined baselines and the neural
         detector, then print a summary table.
      4. If an agent checkpoint is given, preprocess the samples with the
         detector and evaluate the rule, threshold, BC-only and RL policies.
      5. Write every result to --output as JSON.

    Exits with status 1 (previously a normal return, i.e. status 0) when the
    intervention config is missing or the source filter leaves no samples, so
    callers and shell pipelines can detect the failure.
    """
    args = _parse_args()
    # Visible rule line (the previous ''*55 always printed an empty line).
    rule = "─" * 55

    # Explicit UTF-8 so reading does not depend on the platform locale.
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Read the intervention config only when an RL/BC policy is evaluated, so
    # detection-only runs work without that file.
    int_cfg = {}
    if args.agent_ckpt or args.bc_ckpt:
        if not Path(args.intervention_config).exists():
            print(f"[ERROR] agent-ckpt 指定但找不到干预配置: {args.intervention_config}")
            raise SystemExit(1)
        with open(args.intervention_config, encoding="utf-8") as f:
            int_cfg = yaml.safe_load(f) or {}

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n{rule}")
    print(" CompanionGuard-RL 完整评估")
    print(rule)
    print(f" Device : {device}")
    print(f" Test data : {args.test_data}")
    print(f" Source filter : {args.source_filter}")
    print(f" Label filter : {args.label_filter}")
    print(f" Detector ckpt : {args.detector_ckpt}")
    print(f" RL agent ckpt : {args.agent_ckpt or '(跳过)'}")
    print(f" BC-only ckpt : {args.bc_ckpt or '(跳过)'}")

    # ── Load and filter the test set ──────────────────────────────────────
    all_samples = load_jsonl(args.test_data)
    samples = filter_by_source(all_samples, args.source_filter)
    print(f"\n 样本总数: {len(all_samples)} → 过滤后: {len(samples)} "
          f"filter={args.source_filter}")
    if not samples:
        print("[ERROR] 过滤后样本为空,检查 --source-filter 或数据集中的 source/id 字段")
        raise SystemExit(1)
    risky = sum(int(s["y_risk"]) for s in samples)
    print(f" 有风险: {risky} 安全: {len(samples)-risky}")

    # ── Load the detector ─────────────────────────────────────────────────
    model_cfg = cfg["model"]
    tokenizer = AutoTokenizer.from_pretrained(model_cfg["name"])
    detector = CompanionRiskDetector(
        model_name=model_cfg["name"],
        hidden_size=model_cfg["hidden_size"],
        num_heads=model_cfg["num_heads"],
        dropout=model_cfg["dropout"],
        use_lora=model_cfg["use_lora"],
    ).to(device)
    if Path(args.detector_ckpt).exists():
        detector.load_state_dict(
            torch.load(args.detector_ckpt, map_location=device, weights_only=True)
        )
        print(f" Detector loaded: {args.detector_ckpt}")
    else:
        # Deliberate: continue with random weights (e.g. smoke tests), but warn.
        print(f"[WARN] Checkpoint 未找到: {args.detector_ckpt},使用随机权重")

    all_results = {
        "meta": {
            "test_file": str(args.test_data),
            "source_filter": args.source_filter,
            "label_filter": args.label_filter,
            "n_total": len(all_samples),
            "n_filtered": len(samples),
            "n_risky": risky,
        }
    }

    # ── Detection baselines ───────────────────────────────────────────────
    print(f"\n{rule}")
    print(" Detection Task检测任务")
    print(rule)
    for run_label, title, key, detector_cls in (
        ("L1a — Keyword Detector", "L1a: Keyword Detector", "L1a_keyword", KeywordDetector),
        ("L1b — Regex Detector", "L1b: Regex Detector", "L1b_regex", RegexDetector),
        ("L1c — Combined Keyword+Regex", "L1c: Combined Detector", "L1c_combined",
         CombinedRuleDetector),
    ):
        print(f"\nRunning: {run_label}")
        baseline_m = run_keyword_detection(samples, detector_cls())
        print_metrics(title, baseline_m)
        all_results[key] = baseline_m

    print("\nRunning: Ours — CompanionRiskDetector (Module B)")
    neural_m = run_neural_detection(
        detector, tokenizer, samples, cfg, device,
        label_filter=args.label_filter,
    )
    print_metrics("Ours: CompanionRiskDetector", neural_m)
    all_results["ours_detection"] = neural_m

    _print_detection_summary(all_results)

    # ── Intervention evaluation (only with --agent-ckpt / --bc-ckpt) ──────
    if args.agent_ckpt or args.bc_ckpt:
        print(f"\n{rule}")
        print(" Intervention Task干预任务")
        print(rule)
        print(" Preprocessing samples with detector for RL state...")
        processed = preprocess_samples_with_detector(
            samples, detector, tokenizer, device=device,
            binary_threshold=(cfg.get("evaluation") or {}).get("binary_threshold", 0.5),
        )

        from src.models.intervention_agent import InterventionAgent
        detector_hidden = model_cfg["hidden_size"]

        def _load_agent(ckpt_path: str) -> "InterventionAgent":
            # Build the agent from the intervention config and load weights if
            # the checkpoint exists; otherwise warn and keep random weights.
            ag = InterventionAgent(
                detector_hidden=detector_hidden,
                state_hidden=int_cfg["agent"]["state_hidden"],
                dropout=int_cfg["agent"]["dropout"],
            ).to(device)
            if Path(ckpt_path).exists():
                ag.load_state_dict(
                    torch.load(ckpt_path, map_location=device, weights_only=True)
                )
                print(f" Agent loaded: {ckpt_path}")
            else:
                print(f"[WARN] Checkpoint 未找到: {ckpt_path}")
            return ag

        print("\nRunning: Rule-based Intervention (l_risk≥3→REJECT)")
        rule_m = run_rule_intervention(processed, rule_based_intervention)
        print_metrics("Rule-based Policy", rule_m)
        all_results["baseline_rule"] = rule_m

        print("\nRunning: Threshold Intervention Baseline")
        thr_m = run_rule_intervention(processed, threshold_intervention)
        print_metrics("Threshold Policy", thr_m)
        all_results["baseline_threshold"] = thr_m

        if args.bc_ckpt:
            print("\nRunning: BC-only Intervention Policy (ablation)")
            bc_agent = _load_agent(args.bc_ckpt)
            bc_m = run_rl_intervention(bc_agent, processed, device)
            print_metrics("BC-only Policy (ablation)", bc_m)
            all_results["bc_only_intervention"] = bc_m

        if args.agent_ckpt:
            print("\nRunning: Ours — RL Intervention Policy (BC+PPO, Module C)")
            rl_agent = _load_agent(args.agent_ckpt)
            rl_m = run_rl_intervention(rl_agent, processed, device)
            print_metrics("Ours: RL Intervention Policy", rl_m)
            all_results["ours_intervention"] = rl_m

    # ── Save results ──────────────────────────────────────────────────────
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        # default=str stringifies anything JSON can't encode (e.g. numpy values).
        json.dump(all_results, f, indent=2, default=str, ensure_ascii=False)
    print(f"\n{rule}")
    print(f" 结果已保存: {args.output}")
    print(f"{rule}\n")


if __name__ == "__main__":
    main()