"""
Full evaluation script for CompanionGuard-RL.

Runs detection and intervention evaluations against multiple baselines:

Detection baselines:
  - L1a: Keyword detector
  - L1b: Regex detector
  - L1c: Combined keyword+regex
  - Ours: CompanionRiskDetector (Module B)

Intervention baselines:
  - Rule-based (l_risk >= 3 → REJECT)
  - Threshold policy (per-level mapping)
  - BC-only policy (ablation)    [requires --bc-ckpt]
  - RL policy (Module C, Ours)   [requires --agent-ckpt]

Usage (detection only, no RL agent yet):
    python scripts/evaluate.py \\
        --detector-ckpt checkpoints/detector/best.pt \\
        --test-data data/processed/CompanionRisk-Bench/test.jsonl \\
        --config configs/detector_config.yaml

Usage (with RL intervention agent):
    python scripts/evaluate.py \\
        --detector-ckpt checkpoints/detector/best.pt \\
        --agent-ckpt checkpoints/intervention/final.pt \\
        --test-data data/processed/CompanionRisk-Bench/test.jsonl \\
        --config configs/detector_config.yaml \\
        --intervention-config configs/intervention_config.yaml

Usage (human-annotated subset only):
    python scripts/evaluate.py \\
        --detector-ckpt checkpoints/detector/best.pt \\
        --source-filter human \\
        --test-data data/processed/CompanionRisk-Bench/test.jsonl \\
        --config configs/detector_config.yaml

Usage (companion-specific fine labels, scored on the generated split):
    python scripts/evaluate.py \\
        --detector-ckpt checkpoints/detector/best.pt \\
        --source-filter generated \\
        --label-filter companion \\
        --test-data data/processed/CompanionRisk-Bench/test.jsonl \\
        --config configs/detector_config.yaml
"""
import argparse
import json
import yaml
import torch
import numpy as np
from collections import defaultdict
from pathlib import Path
from typing import List, Dict, Optional

from transformers import AutoTokenizer

from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
from src.models.detector import CompanionRiskDetector
from src.utils.metrics import detection_metrics, intervention_metrics
from src.utils.baselines import (
    KeywordDetector,
    RegexDetector,
    CombinedRuleDetector,
    rule_based_intervention,
    threshold_intervention,
)
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
from src.utils.taxonomy import (
    ACTION_NAME_TO_ID,
    NUM_RISK_LEVELS,
    NUM_PRIMARY,
    NUM_FINE,
    PRIMARY_CATEGORY_LIST,
    FINE_GRAINED_LABELS,
    INTERVENTION_ACTIONS,
)


# ── Fine-grained label split ──────────────────────────────────────────────────
# PUBLIC: 10 labels covering R1/R2/R5-R9 — present in human-annotated subsets.
# COMPANION: 4 labels specific to companion AI (R3/R4/R10) — not in human subsets;
# evaluated separately on the generated test split (see paper §4.3).
PUBLIC_FINE_LABELS = [
    "DirectEncouragement",
    "MethodFacilitation",
    "RiskNormalization",
    "Romanticization",
    "FalseReassurance",
    "PseudoTherapy",
    "CrisisNonResponse",
    "BoundaryFailure",
    "AgeInappropriateIntimacy",
    "PrivacySolicitation",
]
COMPANION_FINE_LABELS = [
    "DependencyReinforcement",
    "IsolationReinforcement",
    "ManipulativeAttachment",
    "CoRumination",
]

# Pre-computed column indices into FINE_GRAINED_LABELS, used to slice the
# [N, NUM_FINE] fine-label matrices down to a label subset.
# Raises ValueError at import time if a listed label is missing from the taxonomy.
PUBLIC_FINE_IDX = [FINE_GRAINED_LABELS.index(lbl) for lbl in PUBLIC_FINE_LABELS]
COMPANION_FINE_IDX = [FINE_GRAINED_LABELS.index(lbl) for lbl in COMPANION_FINE_LABELS]

LABEL_FILTER_MAP = {
    "all": list(range(NUM_FINE)),
    "public": PUBLIC_FINE_IDX,
    "companion": COMPANION_FINE_IDX,
}


# ── Source filtering ──────────────────────────────────────────────────────────
# Markers that identify samples coming from human-annotated datasets:
# either an `id` prefix or an exact `source` field value.
HUMAN_SOURCE_PREFIXES = ("suicide-", "cosafe-", "dices-")
HUMAN_SOURCE_NAMES = ("suicide_risk", "cosafe", "dices")

_VALID_SOURCE_FILTERS = ("all", "human", "generated")


def _is_human_sample(sample: Dict) -> bool:
    """Return True if *sample* originates from a human-annotated dataset.

    A sample is human-annotated when its ``source`` field exactly equals one of
    HUMAN_SOURCE_NAMES, or its ``id`` starts with one of HUMAN_SOURCE_PREFIXES.
    Missing/``None`` fields are treated as empty strings and non-string values
    are stringified, so malformed records never raise here.
    """
    src = str(sample.get("source") or "")
    sid = str(sample.get("id") or "")
    return src in HUMAN_SOURCE_NAMES or sid.startswith(HUMAN_SOURCE_PREFIXES)


def filter_by_source(samples: List[Dict], source_filter: str) -> List[Dict]:
    """Select samples by data origin.

    Args:
        samples: Test records (dicts), optionally carrying ``source`` and ``id``.
        source_filter:
            ``'all'``       — return *samples* unchanged (same list object);
            ``'human'``     — keep only samples from human-annotated datasets;
            ``'generated'`` — keep only LLM-generated (non-human) samples.

    Returns:
        The filtered list of samples, in their original order.

    Raises:
        ValueError: if *source_filter* is not one of the values above
            (previously this silently returned an empty list).
    """
    if source_filter not in _VALID_SOURCE_FILTERS:
        raise ValueError(
            f"Unknown source_filter {source_filter!r}; "
            f"expected one of {_VALID_SOURCE_FILTERS}"
        )
    if source_filter == "all":
        return samples

    want_human = source_filter == "human"
    return [s for s in samples if _is_human_sample(s) == want_human]


def get_obs_dim(detector_hidden: int) -> int:
    """Return the RL observation size for a detector with hidden size *detector_hidden*.

    The size is ``1 + NUM_RISK_LEVELS + NUM_PRIMARY + 2 * detector_hidden + 1``.
    NOTE(review): this must stay in sync with the vector layout produced by
    ``build_obs_vector`` — confirm there when changing either.
    """
    return 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
# ── Detection evaluation ──────────────────────────────────────────────────────


def _count_category_hit(cat_results: Dict, sample: Dict, predicted_risk) -> None:
    """Record one sample in the per-category detection counters.

    Only samples whose ``c_primary`` (default ``"None"``) is in
    PRIMARY_CATEGORY_LIST are counted. ``detected`` is incremented when the
    predicted binary risk equals 1.

    NOTE(review): ``total`` counts every sample of the category regardless of
    ``y_risk``; "recall" therefore assumes only risky samples carry a primary
    category from PRIMARY_CATEGORY_LIST — confirm against the dataset.
    """
    cat = sample.get("c_primary", "None")
    if cat in PRIMARY_CATEGORY_LIST:
        counts = cat_results[cat]
        counts["total"] += 1
        if predicted_risk == 1:
            counts["detected"] += 1


def _per_category_recall(cat_results: Dict) -> Dict[str, Dict]:
    """Convert ``{cat: {"total", "detected"}}`` counters into recall rows.

    Every category in PRIMARY_CATEGORY_LIST gets a row, so result tables have a
    stable shape. ``recall = detected / total`` (0.0 when ``total == 0``, which
    also makes ``miss_rate`` 1.0 for empty categories) and
    ``miss_rate = 1 - recall``; both are rounded to 4 decimals.
    """
    per_cat = {}
    for cat in PRIMARY_CATEGORY_LIST:
        counts = cat_results.get(cat, {"total": 0, "detected": 0})
        total = counts["total"]
        detected = counts["detected"]
        recall = detected / total if total > 0 else 0.0
        per_cat[cat] = {
            "total": total,
            "detected": detected,
            "recall": round(recall, 4),
            "miss_rate": round(1 - recall, 4),
        }
    return per_cat


def run_neural_detection(
    model: CompanionRiskDetector,
    tokenizer,
    samples: List[Dict],
    cfg: Dict,
    device: str,
    label_filter: str = "all",
) -> Dict:
    """Run the neural detector on test samples; compute binary/level/fine metrics.

    Args:
        model: Detector exposing ``predict(...)`` returning ``y_risk``, ``l_risk``
            (single-element tensors) and ``c_fine`` (already sigmoid-ed and
            thresholded at ``fine_threshold``).
        tokenizer: HuggingFace-style tokenizer used for persona/context/response.
        samples: Raw test records; each is copied and passed through
            ``validate_and_normalize`` before use.
        cfg: Detector config; reads ``evaluation.binary_threshold`` (0.5),
            ``evaluation.fine_threshold`` (0.4) and ``data.max_*`` lengths.
        device: Torch device string for the input tensors.
        label_filter: ``'all'``, ``'public'`` or ``'companion'``; restricts the
            fine-label columns scored by ``detection_metrics``.

    Returns:
        ``detection_metrics(...)`` output plus ``per_category_recall`` and
        ``label_filter``.
    """
    from src.utils.taxonomy import label_to_index

    model.eval()
    y_true, y_pred = [], []
    l_true, l_pred = [], []
    fine_true_list, fine_pred_list = [], []

    eval_cfg = cfg.get("evaluation", {})
    binary_threshold = eval_cfg.get("binary_threshold", 0.5)
    fine_threshold = eval_cfg.get("fine_threshold", 0.4)

    data_cfg = cfg.get("data", {})
    max_history_turns = data_cfg.get("max_history_turns", 5)
    max_persona_len = data_cfg.get("max_persona_len", 128)
    max_context_len = data_cfg.get("max_context_len", 512)
    max_response_len = data_cfg.get("max_response_len", 256)

    known_fine_labels = set(FINE_GRAINED_LABELS)
    cat_results = defaultdict(lambda: {"total": 0, "detected": 0})

    def enc(text, max_len):
        return tokenizer(
            text,
            max_length=max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    for sample in samples:
        sample = validate_and_normalize(dict(sample))
        texts = format_conversation(
            sample["persona"],
            sample["history"],
            sample["user_input"],
            sample["ai_response"],
            max_history_turns=max_history_turns,
        )
        p_enc = enc(texts["persona_text"], max_persona_len)
        c_enc = enc(texts["context_text"], max_context_len)
        r_enc = enc(texts["response_text"], max_response_len)

        with torch.no_grad():
            preds = model.predict(
                p_enc["input_ids"].to(device),
                p_enc["attention_mask"].to(device),
                c_enc["input_ids"].to(device),
                c_enc["attention_mask"].to(device),
                r_enc["input_ids"].to(device),
                r_enc["attention_mask"].to(device),
                binary_threshold=binary_threshold,
                fine_threshold=fine_threshold,
            )

        y_p = preds["y_risk"].item()
        y_true.append(int(sample["y_risk"]))
        y_pred.append(y_p)
        l_true.append(int(sample["l_risk"]))
        l_pred.append(preds["l_risk"].item())

        # Fine-grained multi-label ground truth; unknown label names are ignored.
        gt_fine = np.zeros(NUM_FINE)
        for lbl in sample.get("c_fine") or []:
            if lbl in known_fine_labels:
                gt_fine[label_to_index(lbl)] = 1.0
        fine_true_list.append(gt_fine)

        # predict() already applies sigmoid + fine_threshold → binary tensor.
        # The key is "c_fine", NOT "c_fine_logits" (that key is not returned).
        fine_pred_list.append(preds["c_fine"].cpu().numpy().flatten())

        _count_category_hit(cat_results, sample, y_p)

    # Keep a 2-D [N, NUM_FINE] shape even for N == 0 so column slicing works.
    if fine_true_list:
        fine_true_arr = np.array(fine_true_list)
        fine_pred_arr = np.array(fine_pred_list)
    else:
        fine_true_arr = np.zeros((0, NUM_FINE))
        fine_pred_arr = np.zeros((0, NUM_FINE))

    # Restrict fine-label columns to the selected subset.
    col_idx = LABEL_FILTER_MAP.get(label_filter, list(range(NUM_FINE)))
    if label_filter != "all" and len(col_idx) < NUM_FINE:
        fine_true_arr = fine_true_arr[:, col_idx]
        fine_pred_arr = fine_pred_arr[:, col_idx]

    metrics = detection_metrics(
        y_true, y_pred, l_true, l_pred,
        fine_true_arr, fine_pred_arr,
    )
    metrics["per_category_recall"] = _per_category_recall(cat_results)
    metrics["label_filter"] = label_filter
    return metrics


def run_keyword_detection(
    samples: List[Dict],
    detector,
    compute_cat_recall: bool = True,
) -> Dict:
    """Evaluate a rule-based detector (``detector.detect(text)``) on AI responses.

    ``detector.detect`` must return a dict with ``y_risk`` and ``l_risk``.
    Returns ``detection_metrics`` on binary/level predictions, plus
    ``per_category_recall`` when *compute_cat_recall* is True.
    """
    y_true, y_pred = [], []
    l_true, l_pred = [], []
    cat_results = defaultdict(lambda: {"total": 0, "detected": 0})

    for s in samples:
        result = detector.detect(s.get("ai_response", ""))
        y_p = result["y_risk"]
        y_true.append(int(s["y_risk"]))
        y_pred.append(y_p)
        l_true.append(int(s["l_risk"]))
        l_pred.append(result["l_risk"])
        _count_category_hit(cat_results, s, y_p)

    metrics = detection_metrics(y_true, y_pred, l_true, l_pred)
    if compute_cat_recall:
        metrics["per_category_recall"] = _per_category_recall(cat_results)
    return metrics


# ── Intervention evaluation ───────────────────────────────────────────────────


def _collect_c_primary_idx(processed_samples: List[Dict]) -> List[int]:
    """Return the ground-truth primary-category index of each sample.

    Uses ``category_to_index(c_primary)`` when the name is a known category,
    otherwise falls back to the sample's ``c_primary_idx`` field (default 0).
    """
    from src.utils.taxonomy import category_to_index

    result = []
    for s in processed_samples:
        cat = s.get("c_primary", "None")
        if cat in PRIMARY_CATEGORY_LIST:
            result.append(category_to_index(cat))
        else:
            result.append(int(s.get("c_primary_idx", 0)))
    return result


def _collect_a_recommend(processed_samples: List[Dict]) -> List[int]:
    """Map each sample's recommended action name to its id (missing/unknown → 0)."""
    return [
        ACTION_NAME_TO_ID.get(s.get("a_recommend", "PASS"), 0)
        for s in processed_samples
    ]


def run_rl_intervention(agent, processed_samples: List[Dict], device: str) -> Dict:
    """Evaluate a learned intervention agent with deterministic actions.

    Each sample's observation comes from ``build_obs_vector`` and is fed to
    ``agent.get_action(obs, deterministic=True)`` one sample at a time.
    """
    agent.eval()
    y_risk_true, l_risk_true, a_pred = [], [], []

    for s in processed_samples:
        obs = torch.FloatTensor(build_obs_vector(s)).unsqueeze(0).to(device)
        with torch.no_grad():
            action, _, _, _ = agent.get_action(obs, deterministic=True)
        y_risk_true.append(int(s["y_risk"]))
        l_risk_true.append(int(s["l_risk"]))
        a_pred.append(action.item())

    return intervention_metrics(
        y_risk_true,
        l_risk_true,
        a_pred,
        _collect_a_recommend(processed_samples),
        _collect_c_primary_idx(processed_samples),
    )


def run_rule_intervention(processed_samples: List[Dict], policy_fn) -> Dict:
    """Evaluate a rule policy ``policy_fn(l_risk) -> action_id``.

    The policy input is the detector-predicted ``det_l_risk`` (as in real
    deployment). If that field is missing the gold ``l_risk`` is used instead;
    such samples give the baseline oracle input, so a warning is printed.
    ``a_recommend`` is passed to ``intervention_metrics`` exactly as for the RL
    policy, so baselines and RL are scored on the same metric set.
    """
    y_risk_true = [int(s["y_risk"]) for s in processed_samples]
    l_risk_true = [int(s["l_risk"]) for s in processed_samples]

    n_oracle = sum(1 for s in processed_samples if "det_l_risk" not in s)
    if n_oracle:
        print(f"[WARN] {n_oracle}/{len(processed_samples)} samples lack 'det_l_risk'; "
              f"rule policy falls back to ground-truth l_risk for them")

    a_pred = [policy_fn(int(s.get("det_l_risk", s["l_risk"]))) for s in processed_samples]

    return intervention_metrics(
        y_risk_true,
        l_risk_true,
        a_pred,
        _collect_a_recommend(processed_samples),
        c_primary_idx=_collect_c_primary_idx(processed_samples),
    )


# ── Pretty printing ───────────────────────────────────────────────────────────


def _format_list_item(x) -> str:
    """Format one list element to 3 decimals, falling back to ``str`` for non-numbers."""
    try:
        return f"{x:.3f}"
    except (TypeError, ValueError):
        return str(x)


def print_metrics(name: str, metrics: Dict):
    """Pretty-print a metrics dict produced by the run_* functions above.

    Known structured keys (per-category recall, per-level F1, per-fine-label F1,
    action distributions, per-level accuracy) get dedicated tables; other
    values are printed generically. ``label_filter`` selects which fine-label
    names are paired with ``fine_per_label_f1``.
    """
    companion_specific = {"R3", "R4", "R10"}
    print(f"\n{'=' * 55}")
    print(f" {name}")
    print(f"{'=' * 55}")

    level_names = {0: "Safe", 1: "Mild", 2: "Moderate", 3: "High", 4: "Critical"}

    # Fine-label names must match the columns scored under label_filter.
    lf = metrics.get("label_filter", "all")
    if lf == "public":
        active_fine_labels = PUBLIC_FINE_LABELS
    elif lf == "companion":
        active_fine_labels = COMPANION_FINE_LABELS
    else:
        active_fine_labels = FINE_GRAINED_LABELS

    for k, v in metrics.items():
        if k == "label_filter":
            continue  # shown inside the fine-label F1 section header
        if k == "per_category_recall":
            print(f" {'─'*51}")
            print(f" Per-category Recall (↑好,漏检率=1-recall):")
            print(f" {'─'*51}")
            for cat, m in sorted(v.items()):
                flag = " ◀ companion特有" if cat in companion_specific else ""
                print(f" {cat:4s}: recall={m['recall']:.3f} miss={m['miss_rate']:.3f}"
                      f" (n={m['total']:3d}){flag}")
        elif k == "level_per_class_f1":
            print(f" {'─'*51}")
            print(f" Per-level F1 (诊断 level_macro_f1 下降原因):")
            for i, f in enumerate(v):
                print(f" L{i} {level_names.get(i, f'Level{i}'):10s}: {f:.4f}")
        elif k == "fine_per_label_f1":
            print(f" {'─'*51}")
            n_lbl = len(active_fine_labels)
            print(f" Per fine-label F1 ({n_lbl} labels, filter={lf}):")
            for lbl, f in zip(active_fine_labels, v):
                print(f" {lbl:35s}: {f:.4f}")
        elif k == "per_level_action_dist":
            print(f" {'─'*51}")
            print(f" Per-level Action Distribution:")
            print(f" {'Level':15s} {'n':>5} PASS WARN RWRT REJT CRISIS")
            for lvl_name, lv in v.items():
                dist_str = " ".join(f"{x:.3f}" for x in lv["action_dist"])
                print(f" {lvl_name:15s} {lv['n']:>5} {dist_str}")
        elif k == "per_category_action_dist":
            print(f" {'─'*51}")
            print(f" Per-category Action Distribution:")
            print(f" {'Category':6s} {'n':>5} PASS WARN RWRT REJT CRISIS")
            for cat_name, cv in v.items():
                dist_str = " ".join(f"{x:.3f}" for x in cv["action_dist"])
                print(f" {cat_name:6s} {cv['n']:>5} {dist_str}")
        elif k == "exact_action_accuracy_by_level":
            print(f" {'─'*51}")
            print(f" Action Accuracy by Level (vs a_recommend):")
            for lvl_name, acc in v.items():
                print(f" {lvl_name:15s}: {acc:.4f}")
        elif isinstance(v, float):
            print(f" {k:35s}: {v:.4f}")
        elif isinstance(v, list):
            # Non-numeric elements (e.g. nested confusion-matrix rows) used to raise.
            formatted = [_format_list_item(x) for x in v]
            print(f" {k:35s}: {formatted}")
        else:
            print(f" {k:35s}: {v}")


# ── Main ──────────────────────────────────────────────────────────────────────


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser."""
    parser = argparse.ArgumentParser(description="CompanionGuard-RL 完整评估")
    parser.add_argument("--detector-ckpt", required=True,
                        help="检测器权重路径,例如 checkpoints/detector/best.pt")
    parser.add_argument("--agent-ckpt", default=None,
                        help="RL agent 权重路径(BC+PPO)")
    parser.add_argument("--bc-ckpt", default=None,
                        help="BC-only agent 权重路径(BC 阶段后保存,用于 ablation 对照)")
    parser.add_argument("--test-data", default="data/processed/CompanionRisk-Bench/test.jsonl",
                        help="测试集路径")
    parser.add_argument("--config", default="configs/detector_config.yaml",
                        help="检测器配置路径")
    parser.add_argument("--intervention-config", default="configs/intervention_config.yaml",
                        help="干预配置路径(仅 --agent-ckpt 存在时需要)")
    parser.add_argument("--output", default="experiments/eval_results.json",
                        help="结果保存路径")
    parser.add_argument("--source-filter", default="all",
                        choices=["all", "human", "generated"],
                        help=(
                            "样本过滤: "
                            "all=全部, "
                            "human=仅人工标注(suicide/cosafe/dices), "
                            "generated=仅LLM生成"
                        ))
    parser.add_argument("--label-filter", default="all",
                        choices=["all", "public", "companion"],
                        help=(
                            "细粒度标签集过滤 (fine_macro_f1 计算范围): "
                            "all=全部14标签, "
                            "public=10个通用标签(R1/R2/R5-R9,人工子集可用), "
                            "companion=4个companion专属标签(R3/R4/R10)"
                        ))
    return parser


def _load_yaml(path) -> Dict:
    """Load a YAML config as UTF-8 (configs may contain non-ASCII comments)."""
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def _load_detector(cfg: Dict, ckpt_path: str, device: str) -> CompanionRiskDetector:
    """Build the detector from ``cfg['model']`` and load weights if *ckpt_path* exists.

    A missing checkpoint only prints a warning; the detector keeps its freshly
    initialised weights.
    """
    model_cfg = cfg["model"]
    detector = CompanionRiskDetector(
        model_name=model_cfg["name"],
        hidden_size=model_cfg["hidden_size"],
        num_heads=model_cfg["num_heads"],
        dropout=model_cfg["dropout"],
        use_lora=model_cfg["use_lora"],
    ).to(device)
    if Path(ckpt_path).exists():
        detector.load_state_dict(
            torch.load(ckpt_path, map_location=device, weights_only=True)
        )
        print(f" Detector loaded: {ckpt_path}")
    else:
        print(f"[WARN] Checkpoint 未找到: {ckpt_path},使用随机权重")
    return detector


def _load_agent(ckpt_path: str, detector_hidden: int, int_cfg: Dict, device: str):
    """Build an InterventionAgent from ``int_cfg['agent']``; load weights if present.

    A missing checkpoint only prints a warning (the agent keeps its initial weights).
    """
    from src.models.intervention_agent import InterventionAgent

    agent = InterventionAgent(
        detector_hidden=detector_hidden,
        state_hidden=int_cfg["agent"]["state_hidden"],
        dropout=int_cfg["agent"]["dropout"],
    ).to(device)
    if Path(ckpt_path).exists():
        agent.load_state_dict(
            torch.load(ckpt_path, map_location=device, weights_only=True)
        )
        print(f" Agent loaded: {ckpt_path}")
    else:
        print(f"[WARN] Checkpoint 未找到: {ckpt_path}")
    return agent


def _print_detection_summary(all_results: Dict) -> None:
    """Print a one-row-per-method comparison table of the detection results."""
    print(f"\n{'─'*55}")
    print(" Detection Summary(对比汇总)")
    print(f"{'─'*55}")
    print(f" {'Method':25s} {'BinF1':>7} {'Recall':>7} {'FNR':>7} {'LvlF1':>7}")
    print(f" {'─'*53}")
    rows = (
        ("L1a_keyword", "L1a_keyword"),
        ("L1b_regex", "L1b_regex"),
        ("L1c_combined", "L1c_combined"),
        ("Ours", "ours_detection"),
    )
    for tag, key in rows:
        m = all_results[key]
        bf1 = m.get("binary_f1", float("nan"))
        rec = m.get("high_risk_recall", float("nan"))
        fnr = m.get("false_negative_rate", float("nan"))
        lvlf1 = m.get("level_macro_f1", float("nan"))
        print(f" {tag:25s} {bf1:7.4f} {rec:7.4f} {fnr:7.4f} {lvlf1:7.4f}")


def _json_default(obj):
    """``json.dump`` fallback: numpy/torch arrays → lists, numpy scalars → Python
    scalars, anything else → ``str`` (the previous catch-all behaviour)."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, torch.Tensor):
        return obj.tolist()
    return str(obj)


def main():
    """CLI entry point: run detection baselines + ours, optionally intervention
    baselines + BC/RL agents, and save all metrics to ``--output`` as JSON.

    Exits with status 1 on configuration errors (missing intervention config,
    empty sample set after filtering).
    """
    args = _build_arg_parser().parse_args()

    cfg = _load_yaml(args.config)

    # Read the intervention config only when an agent is evaluated, so that
    # detection-only runs do not need the file to exist.
    int_cfg = {}
    if args.agent_ckpt or args.bc_ckpt:
        if not Path(args.intervention_config).exists():
            print(f"[ERROR] agent-ckpt 指定但找不到干预配置: {args.intervention_config}")
            raise SystemExit(1)
        int_cfg = _load_yaml(args.intervention_config)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"\n{'═'*55}")
    print(f" CompanionGuard-RL 完整评估")
    print(f"{'═'*55}")
    print(f" Device : {device}")
    print(f" Test data : {args.test_data}")
    print(f" Source filter : {args.source_filter}")
    print(f" Label filter : {args.label_filter}")
    print(f" Detector ckpt : {args.detector_ckpt}")
    print(f" RL agent ckpt : {args.agent_ckpt or '(跳过)'}")
    print(f" BC-only ckpt : {args.bc_ckpt or '(跳过)'}")

    # ── Load the test set ────────────────────────────────────────────────
    all_samples = load_jsonl(args.test_data)
    samples = filter_by_source(all_samples, args.source_filter)
    print(f"\n 样本总数: {len(all_samples)} → 过滤后: {len(samples)} "
          f"(filter={args.source_filter})")

    if not samples:
        print("[ERROR] 过滤后样本为空,检查 --source-filter 或数据集中的 source/id 字段")
        raise SystemExit(1)

    risky = sum(int(s["y_risk"]) for s in samples)
    print(f" 有风险: {risky} 安全: {len(samples)-risky}")

    # ── Load models ──────────────────────────────────────────────────────
    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
    detector = _load_detector(cfg, args.detector_ckpt, device)

    all_results = {
        "meta": {
            "test_file": str(args.test_data),
            "source_filter": args.source_filter,
            "label_filter": args.label_filter,
            "n_total": len(all_samples),
            "n_filtered": len(samples),
            "n_risky": risky,
        }
    }

    # ── Detection baselines ──────────────────────────────────────────────
    print(f"\n{'─'*55}")
    print(" Detection Task(检测任务)")
    print(f"{'─'*55}")

    rule_detectors = (
        ("L1a_keyword", "\nRunning: L1a — Keyword Detector",
         "L1a: Keyword Detector", KeywordDetector),
        ("L1b_regex", "\nRunning: L1b — Regex Detector",
         "L1b: Regex Detector", RegexDetector),
        ("L1c_combined", "\nRunning: L1c — Combined Keyword+Regex",
         "L1c: Combined Detector", CombinedRuleDetector),
    )
    for result_key, banner, title, detector_cls in rule_detectors:
        print(banner)
        rule_det_m = run_keyword_detection(samples, detector_cls())
        print_metrics(title, rule_det_m)
        all_results[result_key] = rule_det_m

    print("\nRunning: Ours — CompanionRiskDetector (Module B)")
    neural_m = run_neural_detection(
        detector, tokenizer, samples, cfg, device,
        label_filter=args.label_filter,
    )
    print_metrics("Ours: CompanionRiskDetector", neural_m)
    all_results["ours_detection"] = neural_m

    _print_detection_summary(all_results)

    # ── Intervention evaluation (only with --agent-ckpt or --bc-ckpt) ────
    if args.agent_ckpt or args.bc_ckpt:
        print(f"\n{'─'*55}")
        print(" Intervention Task(干预任务)")
        print(f"{'─'*55}")

        print(" Preprocessing samples with detector for RL state...")
        processed = preprocess_samples_with_detector(
            samples, detector, tokenizer, device=device,
            binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
        )
        detector_hidden = cfg["model"]["hidden_size"]

        print("\nRunning: Rule-based Intervention (l_risk≥3→REJECT)")
        rule_m = run_rule_intervention(processed, rule_based_intervention)
        print_metrics("Rule-based Policy", rule_m)
        all_results["baseline_rule"] = rule_m

        print("\nRunning: Threshold Intervention Baseline")
        thr_m = run_rule_intervention(processed, threshold_intervention)
        print_metrics("Threshold Policy", thr_m)
        all_results["baseline_threshold"] = thr_m

        if args.bc_ckpt:
            print("\nRunning: BC-only Intervention Policy (ablation)")
            bc_agent = _load_agent(args.bc_ckpt, detector_hidden, int_cfg, device)
            bc_m = run_rl_intervention(bc_agent, processed, device)
            print_metrics("BC-only Policy (ablation)", bc_m)
            all_results["bc_only_intervention"] = bc_m

        if args.agent_ckpt:
            print("\nRunning: Ours — RL Intervention Policy (BC+PPO, Module C)")
            rl_agent = _load_agent(args.agent_ckpt, detector_hidden, int_cfg, device)
            rl_m = run_rl_intervention(rl_agent, processed, device)
            print_metrics("Ours: RL Intervention Policy", rl_m)
            all_results["ours_intervention"] = rl_m

    # ── Save results ─────────────────────────────────────────────────────
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, default=_json_default, ensure_ascii=False)

    print(f"\n{'═'*55}")
    print(f" 结果已保存: {args.output}")
    print(f"{'═'*55}\n")


if __name__ == "__main__":
    main()