Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
613 lines
25 KiB
Python
613 lines
25 KiB
Python
"""
|
||
Full evaluation script for CompanionGuard-RL.
|
||
|
||
Runs detection and intervention evaluations against multiple baselines:
|
||
|
||
Detection baselines:
|
||
- L1a: Keyword detector
|
||
- L1b: Regex detector
|
||
- L1c: Combined keyword+regex
|
||
- Ours: CompanionRiskDetector (Module B)
|
||
|
||
Intervention baselines:
|
||
- Rule-based (l_risk >= 3 → REJECT)
|
||
- Threshold policy (per-level mapping)
|
||
- RL policy (Module C, Ours) [requires --agent-ckpt]
|
||
|
||
Usage (detection only, no RL agent yet):
|
||
python scripts/evaluate.py \\
|
||
--detector-ckpt checkpoints/detector/best.pt \\
|
||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
|
||
--config configs/detector_config.yaml
|
||
|
||
Usage (with RL intervention agent):
|
||
python scripts/evaluate.py \\
|
||
--detector-ckpt checkpoints/detector/best.pt \\
|
||
--agent-ckpt checkpoints/intervention/final.pt \\
|
||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
|
||
--config configs/detector_config.yaml \\
|
||
--intervention-config configs/intervention_config.yaml
|
||
|
||
Usage (human-annotated subset only):
|
||
python scripts/evaluate.py \\
|
||
--detector-ckpt checkpoints/detector/best.pt \\
|
||
--source-filter human \\
|
||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
|
||
--config configs/detector_config.yaml
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import yaml
|
||
import torch
|
||
import numpy as np
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
from typing import List, Dict, Optional
|
||
from transformers import AutoTokenizer
|
||
|
||
from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
|
||
from src.models.detector import CompanionRiskDetector
|
||
from src.utils.metrics import detection_metrics, intervention_metrics
|
||
from src.utils.baselines import (
|
||
KeywordDetector,
|
||
RegexDetector,
|
||
CombinedRuleDetector,
|
||
rule_based_intervention,
|
||
threshold_intervention,
|
||
)
|
||
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
|
||
from src.utils.taxonomy import (
|
||
ACTION_NAME_TO_ID,
|
||
NUM_RISK_LEVELS,
|
||
NUM_PRIMARY,
|
||
NUM_FINE,
|
||
PRIMARY_CATEGORY_LIST,
|
||
FINE_GRAINED_LABELS,
|
||
INTERVENTION_ACTIONS,
|
||
)
|
||
|
||
# ── Fine-grained label split ──────────────────────────────────────────────────
# PUBLIC: 10 labels covering R1/R2/R5-R9 — present in human-annotated subsets.
# COMPANION: 4 labels specific to companion AI (R3/R4/R10) — not in human subsets;
# evaluated separately on generated test split (see paper §4.3).
PUBLIC_FINE_LABELS = [
    "DirectEncouragement",
    "MethodFacilitation",
    "RiskNormalization",
    "Romanticization",
    "FalseReassurance",
    "PseudoTherapy",
    "CrisisNonResponse",
    "BoundaryFailure",
    "AgeInappropriateIntimacy",
    "PrivacySolicitation",
]
COMPANION_FINE_LABELS = [
    "DependencyReinforcement",
    "IsolationReinforcement",
    "ManipulativeAttachment",
    "CoRumination",
]

# Column positions of each subset inside FINE_GRAINED_LABELS, computed once at
# import time so per-subset slicing of [N, NUM_FINE] arrays is a plain index.
PUBLIC_FINE_IDX = [
    FINE_GRAINED_LABELS.index(label_name) for label_name in PUBLIC_FINE_LABELS
]
COMPANION_FINE_IDX = [
    FINE_GRAINED_LABELS.index(label_name) for label_name in COMPANION_FINE_LABELS
]

# Maps each --label-filter choice to the fine-label columns it keeps.
LABEL_FILTER_MAP = {
    "all": list(range(NUM_FINE)),
    "public": PUBLIC_FINE_IDX,
    "companion": COMPANION_FINE_IDX,
}
|
||
|
||
|
||
# ── Source filtering ──────────────────────────────────────────────────────────

# Markers identifying samples that come from human-annotated datasets: either
# the sample's "source" field equals one of the names, or its "id" starts with
# one of the prefixes.
HUMAN_SOURCE_PREFIXES = ("suicide-", "cosafe-", "dices-")
HUMAN_SOURCE_NAMES = ("suicide_risk", "cosafe", "dices")


def _is_human_sample(sample: Dict) -> bool:
    """Return True when *sample* is marked as human-annotated by source or id."""
    if sample.get("source", "") in HUMAN_SOURCE_NAMES:
        return True
    raw_id = sample.get("id", "")
    # Ids are not guaranteed to be strings (e.g. integer ids in JSONL); coerce
    # so the prefix test cannot raise AttributeError.
    sample_id = raw_id if isinstance(raw_id, str) else str(raw_id)
    # str.startswith accepts a tuple of prefixes, so one call covers them all.
    return sample_id.startswith(HUMAN_SOURCE_PREFIXES)


def filter_by_source(samples: List[Dict], source_filter: str) -> List[Dict]:
    """Select samples by where they came from.

    A sample counts as human-annotated when its ``source`` field is one of
    ``HUMAN_SOURCE_NAMES`` or its ``id`` starts with one of
    ``HUMAN_SOURCE_PREFIXES``; every other sample counts as generated.

    Args:
        samples: Records to filter.
        source_filter: One of
            ``'all'`` — return the input list itself, unfiltered;
            ``'human'`` — keep only human-annotated samples;
            ``'generated'`` — keep only the remaining (LLM-generated) samples.

    Returns:
        The selected samples in their original order. Any other
        ``source_filter`` value yields an empty list.
    """
    if source_filter == "all":
        return samples
    if source_filter == "human":
        return [s for s in samples if _is_human_sample(s)]
    if source_filter == "generated":
        return [s for s in samples if not _is_human_sample(s)]
    return []
|
||
|
||
|
||
def get_obs_dim(detector_hidden: int) -> int:
    """Return the total observation-vector length for a given detector width.

    The length is the sum of five segments: one scalar, ``NUM_RISK_LEVELS``
    entries, ``NUM_PRIMARY`` entries, ``2 * detector_hidden`` entries and one
    trailing scalar.
    NOTE(review): the meaning of each segment is presumably fixed by
    ``build_obs_vector`` — confirm there before relying on the order.
    """
    segment_sizes = (1, NUM_RISK_LEVELS, NUM_PRIMARY, detector_hidden * 2, 1)
    return sum(segment_sizes)
|
||
|
||
|
||
# ── Detection evaluation ──────────────────────────────────────────────────────
|
||
|
||
def run_neural_detection(
    model: CompanionRiskDetector,
    tokenizer,
    samples: List[Dict],
    cfg: Dict,
    device: str,
    label_filter: str = "all",
) -> Dict:
    """Run the neural detector on test samples, compute binary + level + category metrics.

    Each sample is normalized, rendered into persona/context/response texts,
    tokenized and passed through ``model.predict`` one at a time.

    Args:
        model: Detector to evaluate; it is switched to eval mode here.
        tokenizer: Callable tokenizer, invoked with ``max_length``,
            ``truncation``, ``padding`` and ``return_tensors`` keyword arguments.
        samples: Raw test records; each one is shallow-copied before
            normalization.
        cfg: Parsed detector config. Reads ``evaluation.binary_threshold``
            (default 0.5), ``evaluation.fine_threshold`` (default 0.4),
            ``data.max_history_turns`` (default 5) and the
            ``data.max_persona_len`` / ``max_context_len`` / ``max_response_len``
            limits (defaults 128 / 512 / 256).
        device: Torch device string the encoded inputs are moved to.
        label_filter: Key of ``LABEL_FILTER_MAP`` choosing which fine-label
            columns are scored; unknown keys fall back to all columns.

    Returns:
        The dict produced by ``detection_metrics`` plus two extra entries:
        ``per_category_recall`` (per primary category: total, detected, recall,
        miss_rate) and ``label_filter``.
    """
    # Resolved once per call instead of once per fine label inside the loop.
    from src.utils.taxonomy import label_to_index

    model.eval()

    # Loop-invariant configuration, read once up front.
    eval_cfg = cfg.get("evaluation", {})
    binary_threshold = eval_cfg.get("binary_threshold", 0.5)
    fine_threshold = eval_cfg.get("fine_threshold", 0.4)
    data_cfg = cfg.get("data", {})
    max_history_turns = data_cfg.get("max_history_turns", 5)
    max_persona_len = data_cfg.get("max_persona_len", 128)
    max_context_len = data_cfg.get("max_context_len", 512)
    max_response_len = data_cfg.get("max_response_len", 256)

    def encode(text, max_len):
        """Tokenize one text, padded/truncated to exactly ``max_len`` tokens."""
        return tokenizer(
            text,
            max_length=max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    y_true, y_pred = [], []
    l_true, l_pred = [], []
    fine_true_list, fine_pred_list = [], []

    # Per-category tracking
    cat_results = defaultdict(lambda: {"total": 0, "detected": 0})

    for raw_sample in samples:
        sample = validate_and_normalize(dict(raw_sample))
        texts = format_conversation(
            sample["persona"],
            sample["history"],
            sample["user_input"],
            sample["ai_response"],
            max_history_turns=max_history_turns,
        )

        p_enc = encode(texts["persona_text"], max_persona_len)
        c_enc = encode(texts["context_text"], max_context_len)
        r_enc = encode(texts["response_text"], max_response_len)

        with torch.no_grad():
            preds = model.predict(
                p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
                c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
                r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
                binary_threshold=binary_threshold,
                fine_threshold=fine_threshold,
            )

        y_t = int(sample["y_risk"])
        y_p = preds["y_risk"].item()
        y_true.append(y_t)
        y_pred.append(y_p)
        l_true.append(int(sample["l_risk"]))
        l_pred.append(preds["l_risk"].item())

        # Fine-grained multi-label ground truth as a multi-hot vector; labels
        # outside FINE_GRAINED_LABELS are ignored.
        gt_fine = np.zeros(NUM_FINE)
        for lbl in sample.get("c_fine", []):
            if lbl in FINE_GRAINED_LABELS:
                gt_fine[label_to_index(lbl)] = 1.0
        fine_true_list.append(gt_fine)

        # Prediction: predict() already applies sigmoid + fine_threshold → binary tensor
        # Key is "c_fine", NOT "c_fine_logits" (that key does not exist in predict() output)
        pred_fine = preds["c_fine"].cpu().numpy().flatten()  # ensure 1D [NUM_FINE]
        fine_pred_list.append(pred_fine)

        # Per-category recall: count every sample of a known primary category
        # and how many of them the detector flagged as risky.
        cat = sample.get("c_primary", "None")
        if cat in PRIMARY_CATEGORY_LIST:
            cat_results[cat]["total"] += 1
            if y_p == 1:
                cat_results[cat]["detected"] += 1

    # Build per-category recall dict (categories with no samples report 0.0 recall)
    per_cat = {}
    for cat in PRIMARY_CATEGORY_LIST:
        total = cat_results[cat]["total"]
        detected = cat_results[cat]["detected"]
        recall = detected / total if total > 0 else 0.0
        per_cat[cat] = {
            "total": total,
            "detected": detected,
            "recall": round(recall, 4),
            "miss_rate": round(1 - recall, 4),
        }

    fine_true_arr = np.array(fine_true_list)  # [N, NUM_FINE]
    fine_pred_arr = np.array(fine_pred_list)  # [N, NUM_FINE]

    # Apply label filter — restrict fine-label columns to the selected subset
    col_idx = LABEL_FILTER_MAP.get(label_filter, list(range(NUM_FINE)))
    if label_filter != "all" and len(col_idx) < NUM_FINE:
        fine_true_arr = fine_true_arr[:, col_idx]
        fine_pred_arr = fine_pred_arr[:, col_idx]

    metrics = detection_metrics(
        y_true, y_pred,
        l_true, l_pred,
        fine_true_arr, fine_pred_arr,
    )
    metrics["per_category_recall"] = per_cat
    metrics["label_filter"] = label_filter
    return metrics
|
||
|
||
|
||
def run_keyword_detection(
    samples: List[Dict],
    detector,
    compute_cat_recall: bool = True,
) -> Dict:
    """Score a rule-based detector (keyword / regex / combined) on the samples.

    The detector sees only each sample's ``ai_response`` text (empty string when
    absent) and must return a mapping with ``y_risk`` and ``l_risk``.

    Args:
        samples: Test records carrying ``y_risk`` and ``l_risk`` ground truth.
        detector: Object exposing ``detect(text)``.
        compute_cat_recall: When true, attach ``per_category_recall`` to the
            result.

    Returns:
        The dict from ``detection_metrics`` (binary and level inputs only),
        optionally extended with ``per_category_recall``.
    """
    y_true, y_pred = [], []
    l_true, l_pred = [], []
    # Separate counters per primary category: samples seen / samples flagged.
    seen_per_cat = defaultdict(int)
    flagged_per_cat = defaultdict(int)

    for sample in samples:
        outcome = detector.detect(sample.get("ai_response", ""))
        truth = int(sample["y_risk"])
        predicted = outcome["y_risk"]
        y_true.append(truth)
        y_pred.append(predicted)
        l_true.append(int(sample["l_risk"]))
        l_pred.append(outcome["l_risk"])

        category = sample.get("c_primary", "None")
        if category not in PRIMARY_CATEGORY_LIST:
            continue
        seen_per_cat[category] += 1
        if predicted == 1:
            flagged_per_cat[category] += 1

    metrics = detection_metrics(y_true, y_pred, l_true, l_pred)
    if not compute_cat_recall:
        return metrics

    per_cat = {}
    for category in PRIMARY_CATEGORY_LIST:
        total = seen_per_cat[category]
        detected = flagged_per_cat[category]
        # Categories with no samples report a recall of 0.0.
        recall = detected / total if total > 0 else 0.0
        per_cat[category] = {
            "total": total,
            "detected": detected,
            "recall": round(recall, 4),
            "miss_rate": round(1 - recall, 4),
        }
    metrics["per_category_recall"] = per_cat
    return metrics
|
||
|
||
|
||
# ── Intervention evaluation ───────────────────────────────────────────────────
|
||
|
||
def _collect_c_primary_idx(processed_samples: List[Dict]) -> List[int]:
    """Extract ground-truth c_primary index for category-aware metrics.

    For each sample, a ``c_primary`` name found in ``PRIMARY_CATEGORY_LIST`` is
    converted with ``category_to_index``; otherwise the sample's own
    ``c_primary_idx`` is used, defaulting to 0.
    """
    from src.utils.taxonomy import category_to_index, PRIMARY_CATEGORY_LIST

    def resolve(sample: Dict) -> int:
        category = sample.get("c_primary", "None")
        if category in PRIMARY_CATEGORY_LIST:
            return category_to_index(category)
        return int(sample.get("c_primary_idx", 0))

    return [resolve(sample) for sample in processed_samples]
|
||
|
||
|
||
def run_rl_intervention(agent, processed_samples: List[Dict], device: str) -> Dict:
    """Evaluate a learned intervention policy on detector-preprocessed samples.

    The agent is put in eval mode and queried deterministically, one sample at
    a time, on the observation built by ``build_obs_vector``. The recommended
    action name of each sample (``a_recommend``, default ``"PASS"``) is mapped
    through ``ACTION_NAME_TO_ID``, with unknown names mapped to 0.

    Returns:
        The dict from ``intervention_metrics`` computed over true risk flags,
        true risk levels, chosen actions, recommended actions and primary
        category indices.
    """
    agent.eval()

    true_flags: List[int] = []
    true_levels: List[int] = []
    chosen_actions: List[int] = []
    recommended_actions: List[int] = []

    for sample in processed_samples:
        observation = torch.FloatTensor(build_obs_vector(sample))
        batch = observation.unsqueeze(0).to(device)  # add a leading batch axis
        with torch.no_grad():
            action, _, _, _ = agent.get_action(batch, deterministic=True)

        true_flags.append(int(sample["y_risk"]))
        true_levels.append(int(sample["l_risk"]))
        chosen_actions.append(action.item())
        recommended_name = sample.get("a_recommend", "PASS")
        recommended_actions.append(ACTION_NAME_TO_ID.get(recommended_name, 0))

    category_indices = _collect_c_primary_idx(processed_samples)
    return intervention_metrics(
        true_flags, true_levels, chosen_actions, recommended_actions, category_indices
    )
|
||
|
||
|
||
def run_rule_intervention(processed_samples: List[Dict], policy_fn) -> Dict:
    """Evaluate a fixed level-to-action policy on detector-preprocessed samples.

    Args:
        processed_samples: Records with ``y_risk`` and ``l_risk`` ground truth
            and, when available, the detector's ``det_l_risk`` prediction.
        policy_fn: Callable mapping an integer risk level to an action.

    Returns:
        The dict from ``intervention_metrics``; no recommended-action list is
        supplied for these baselines.
    """
    true_flags: List[int] = []
    true_levels: List[int] = []
    for sample in processed_samples:
        true_flags.append(int(sample["y_risk"]))
    for sample in processed_samples:
        true_levels.append(int(sample["l_risk"]))

    # Policy input uses det_l_risk (detector prediction, as in real deployment);
    # the ground-truth level is only a fallback when that field is absent.
    chosen_actions = []
    for sample in processed_samples:
        fallback_level = sample["l_risk"]
        observed_level = int(sample.get("det_l_risk", fallback_level))
        chosen_actions.append(policy_fn(observed_level))

    category_indices = _collect_c_primary_idx(processed_samples)
    return intervention_metrics(
        true_flags, true_levels, chosen_actions, c_primary_idx=category_indices
    )
|
||
|
||
|
||
# ── Pretty printing ───────────────────────────────────────────────────────────
|
||
|
||
def print_metrics(name: str, metrics: Dict):
    """Pretty-print one metrics dict under a banner titled *name*.

    Known structured keys get a dedicated layout: ``per_category_recall``,
    ``level_per_class_f1``, ``fine_per_label_f1``, ``per_level_action_dist``,
    ``per_category_action_dist`` and ``exact_action_accuracy_by_level``.
    The ``label_filter`` entry is not printed as a row; it only selects which
    fine-label names are paired with ``fine_per_label_f1``. Any other entry is
    printed as ``key: value`` — floats with 4 decimals, lists of numbers with
    3 decimals, everything else via its default string form.

    Args:
        name: Title shown in the banner.
        metrics: Mapping of metric name to value.
    """
    companion_specific = {"R3", "R4", "R10"}
    level_names = {0: "Safe", 1: "Mild", 2: "Moderate", 3: "High", 4: "Critical"}
    lf = metrics.get("label_filter", "all")

    print(f"\n{'=' * 55}")
    print(f" {name}")
    print(f"{'=' * 55}")

    for k, v in metrics.items():
        if k == "label_filter":
            continue  # only used to pick the fine-label names below
        if k == "per_category_recall":
            print(f" {'─'*51}")
            print(f" Per-category Recall (↑好,漏检率=1-recall):")
            print(f" {'─'*51}")
            for cat, m in sorted(v.items()):
                flag = " ◀ companion特有" if cat in companion_specific else ""
                print(f" {cat:4s}: recall={m['recall']:.3f} miss={m['miss_rate']:.3f}"
                      f" (n={m['total']:3d}){flag}")
        elif k == "level_per_class_f1":
            print(f" {'─'*51}")
            print(f" Per-level F1 (诊断 level_macro_f1 下降原因):")
            for i, f in enumerate(v):
                # A level index without a known name must not abort the report.
                print(f" L{i} {level_names.get(i, 'Unknown'):10s}: {f:.4f}")
        elif k == "fine_per_label_f1":
            # Resolve the label names only here, matching the columns that the
            # stored label_filter selected.
            if lf == "public":
                active_fine_labels = PUBLIC_FINE_LABELS
            elif lf == "companion":
                active_fine_labels = COMPANION_FINE_LABELS
            else:
                active_fine_labels = FINE_GRAINED_LABELS
            print(f" {'─'*51}")
            n_lbl = len(active_fine_labels)
            print(f" Per fine-label F1 ({n_lbl} labels, filter={lf}):")
            for lbl, f in zip(active_fine_labels, v):
                print(f" {lbl:35s}: {f:.4f}")
        elif k == "per_level_action_dist":
            print(f" {'─'*51}")
            print(f" Per-level Action Distribution:")
            print(f" {'Level':15s} {'n':>5} PASS WARN RWRT REJT CRISIS")
            for lvl_name, lv in v.items():
                dist_str = " ".join(f"{x:.3f}" for x in lv["action_dist"])
                print(f" {lvl_name:15s} {lv['n']:>5} {dist_str}")
        elif k == "per_category_action_dist":
            print(f" {'─'*51}")
            print(f" Per-category Action Distribution:")
            print(f" {'Category':6s} {'n':>5} PASS WARN RWRT REJT CRISIS")
            for cat_name, cv in v.items():
                dist_str = " ".join(f"{x:.3f}" for x in cv["action_dist"])
                print(f" {cat_name:6s} {cv['n']:>5} {dist_str}")
        elif k == "exact_action_accuracy_by_level":
            print(f" {'─'*51}")
            print(f" Action Accuracy by Level (vs a_recommend):")
            for lvl_name, acc in v.items():
                print(f" {lvl_name:15s}: {acc:.4f}")
        elif isinstance(v, float):
            print(f" {k:35s}: {v:.4f}")
        elif isinstance(v, list):
            # Lists of numbers are shown with 3 decimals; anything that cannot
            # be formatted that way (nested lists, strings, None) is printed
            # as-is instead of raising.
            try:
                formatted = [f"{x:.3f}" for x in v]
            except (TypeError, ValueError):
                formatted = v
            print(f" {k:35s}: {formatted}")
        else:
            print(f" {k:35s}: {v}")
|
||
|
||
|
||
# ── Main ──────────────────────────────────────────────────────────────────────
|
||
|
||
def main():
    """Command-line entry point: evaluate detectors and intervention policies.

    Steps, in order:
      1. Parse arguments and load the detector config (plus the intervention
         config when ``--agent-ckpt`` or ``--bc-ckpt`` is given).
      2. Load the test set and apply ``--source-filter``.
      3. Build the neural detector and load its checkpoint if the file exists
         (otherwise a warning is printed and the untrained weights are used).
      4. Run the keyword, regex, combined and neural detection evaluations and
         print a comparison table.
      5. When an agent checkpoint is given, run the rule-based, threshold,
         BC-only and RL intervention evaluations.
      6. Write all results to ``--output`` as JSON.

    Returns early, after printing an ``[ERROR]`` line, when the intervention
    config is required but missing or when filtering leaves no samples.
    """
    parser = argparse.ArgumentParser(description="CompanionGuard-RL 完整评估")
    parser.add_argument("--detector-ckpt", required=True,
                        help="检测器权重路径,例如 checkpoints/detector/best.pt")
    parser.add_argument("--agent-ckpt", default=None,
                        help="RL agent 权重路径(BC+PPO)")
    parser.add_argument("--bc-ckpt", default=None,
                        help="BC-only agent 权重路径(BC 阶段后保存,用于 ablation 对照)")
    parser.add_argument("--test-data",
                        default="data/processed/CompanionRisk-Bench/test.jsonl",
                        help="测试集路径")
    parser.add_argument("--config", default="configs/detector_config.yaml",
                        help="检测器配置路径")
    parser.add_argument("--intervention-config",
                        default="configs/intervention_config.yaml",
                        help="干预配置路径(仅 --agent-ckpt 存在时需要)")
    parser.add_argument("--output", default="experiments/eval_results.json",
                        help="结果保存路径")
    parser.add_argument("--source-filter", default="all",
                        choices=["all", "human", "generated"],
                        help=(
                            "样本过滤: "
                            "all=全部, "
                            "human=仅人工标注(suicide/cosafe/dices), "
                            "generated=仅LLM生成"
                        ))
    parser.add_argument("--label-filter", default="all",
                        choices=["all", "public", "companion"],
                        help=(
                            "细粒度标签集过滤 (fine_macro_f1 计算范围): "
                            "all=全部14标签, "
                            "public=10个通用标签(R1/R2/R5-R9,人工子集可用), "
                            "companion=4个companion专属标签(R3/R4/R10)"
                        ))
    args = parser.parse_args()

    # Read configs as UTF-8 explicitly so parsing does not depend on the
    # platform's locale encoding (the results file below is written as UTF-8).
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Only read the intervention config when an agent is actually evaluated, so
    # a missing file does not break detection-only runs.
    needs_agent = bool(args.agent_ckpt or args.bc_ckpt)
    int_cfg = {}
    if needs_agent:
        if not Path(args.intervention_config).exists():
            print(f"[ERROR] agent-ckpt 指定但找不到干预配置: {args.intervention_config}")
            return
        with open(args.intervention_config, encoding="utf-8") as f:
            int_cfg = yaml.safe_load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n{'═'*55}")
    print(" CompanionGuard-RL 完整评估")
    print(f"{'═'*55}")
    print(f" Device : {device}")
    print(f" Test data : {args.test_data}")
    print(f" Source filter : {args.source_filter}")
    print(f" Label filter : {args.label_filter}")
    print(f" Detector ckpt : {args.detector_ckpt}")
    print(f" RL agent ckpt : {args.agent_ckpt or '(跳过)'}")
    print(f" BC-only ckpt : {args.bc_ckpt or '(跳过)'}")

    # ── Load the test set ───────────────────────────────────────────────────
    all_samples = load_jsonl(args.test_data)
    samples = filter_by_source(all_samples, args.source_filter)
    print(f"\n 样本总数: {len(all_samples)} → 过滤后: {len(samples)} "
          f"(filter={args.source_filter})")

    if not samples:
        print("[ERROR] 过滤后样本为空,检查 --source-filter 或数据集中的 source/id 字段")
        return

    risky = sum(int(s["y_risk"]) for s in samples)
    print(f" 有风险: {risky} 安全: {len(samples)-risky}")

    # ── Load the detector ───────────────────────────────────────────────────
    model_cfg = cfg["model"]
    tokenizer = AutoTokenizer.from_pretrained(model_cfg["name"])
    detector = CompanionRiskDetector(
        model_name=model_cfg["name"],
        hidden_size=model_cfg["hidden_size"],
        num_heads=model_cfg["num_heads"],
        dropout=model_cfg["dropout"],
        use_lora=model_cfg["use_lora"],
    ).to(device)

    if Path(args.detector_ckpt).exists():
        detector.load_state_dict(
            torch.load(args.detector_ckpt, map_location=device, weights_only=True)
        )
        print(f" Detector loaded: {args.detector_ckpt}")
    else:
        # Best-effort: evaluation continues with the freshly initialised model.
        print(f"[WARN] Checkpoint 未找到: {args.detector_ckpt},使用随机权重")

    all_results = {
        "meta": {
            "test_file": str(args.test_data),
            "source_filter": args.source_filter,
            "label_filter": args.label_filter,
            "n_total": len(all_samples),
            "n_filtered": len(samples),
            "n_risky": risky,
        }
    }

    # ── Detection baselines ─────────────────────────────────────────────────
    print(f"\n{'─'*55}")
    print(" Detection Task(检测任务)")
    print(f"{'─'*55}")

    print("\nRunning: L1a — Keyword Detector")
    kw_m = run_keyword_detection(samples, KeywordDetector())
    print_metrics("L1a: Keyword Detector", kw_m)
    all_results["L1a_keyword"] = kw_m

    print("\nRunning: L1b — Regex Detector")
    re_m = run_keyword_detection(samples, RegexDetector())
    print_metrics("L1b: Regex Detector", re_m)
    all_results["L1b_regex"] = re_m

    print("\nRunning: L1c — Combined Keyword+Regex")
    cb_m = run_keyword_detection(samples, CombinedRuleDetector())
    print_metrics("L1c: Combined Detector", cb_m)
    all_results["L1c_combined"] = cb_m

    print("\nRunning: Ours — CompanionRiskDetector (Module B)")
    neural_m = run_neural_detection(
        detector, tokenizer, samples, cfg, device,
        label_filter=args.label_filter,
    )
    print_metrics("Ours: CompanionRiskDetector", neural_m)
    all_results["ours_detection"] = neural_m

    # ── Side-by-side comparison table ───────────────────────────────────────
    print(f"\n{'─'*55}")
    print(" Detection Summary(对比汇总)")
    print(f"{'─'*55}")
    print(f" {'Method':25s} {'BinF1':>7} {'Recall':>7} {'FNR':>7} {'LvlF1':>7}")
    print(f" {'─'*53}")
    summary_rows = [
        ("L1a_keyword", "L1a_keyword"),
        ("L1b_regex", "L1b_regex"),
        ("L1c_combined", "L1c_combined"),
        ("Ours", "ours_detection"),
    ]
    for tag, result_key in summary_rows:
        m = all_results[result_key]
        # Metrics a method does not report are shown as nan.
        bf1 = m.get("binary_f1", float("nan"))
        rec = m.get("high_risk_recall", float("nan"))
        fnr = m.get("false_negative_rate", float("nan"))
        lvlf1 = m.get("level_macro_f1", float("nan"))
        print(f" {tag:25s} {bf1:7.4f} {rec:7.4f} {fnr:7.4f} {lvlf1:7.4f}")

    # ── Intervention evaluation (only when agent_ckpt or bc_ckpt is given) ──
    if needs_agent:
        print(f"\n{'─'*55}")
        print(" Intervention Task(干预任务)")
        print(f"{'─'*55}")

        print(" Preprocessing samples with detector for RL state...")
        processed = preprocess_samples_with_detector(
            samples, detector, tokenizer, device=device,
            binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
        )

        # Imported here so detection-only runs do not need the agent module.
        from src.models.intervention_agent import InterventionAgent
        detector_hidden = cfg["model"]["hidden_size"]

        def _load_agent(ckpt_path: str) -> "InterventionAgent":
            """Build an agent and load *ckpt_path* if it exists (warn otherwise)."""
            ag = InterventionAgent(
                detector_hidden=detector_hidden,
                state_hidden=int_cfg["agent"]["state_hidden"],
                dropout=int_cfg["agent"]["dropout"],
            ).to(device)
            if Path(ckpt_path).exists():
                ag.load_state_dict(
                    torch.load(ckpt_path, map_location=device, weights_only=True)
                )
                print(f" Agent loaded: {ckpt_path}")
            else:
                print(f"[WARN] Checkpoint 未找到: {ckpt_path}")
            return ag

        print("\nRunning: Rule-based Intervention (l_risk≥3→REJECT)")
        rule_m = run_rule_intervention(processed, rule_based_intervention)
        print_metrics("Rule-based Policy", rule_m)
        all_results["baseline_rule"] = rule_m

        print("\nRunning: Threshold Intervention Baseline")
        thr_m = run_rule_intervention(processed, threshold_intervention)
        print_metrics("Threshold Policy", thr_m)
        all_results["baseline_threshold"] = thr_m

        if args.bc_ckpt:
            print("\nRunning: BC-only Intervention Policy (ablation)")
            bc_agent = _load_agent(args.bc_ckpt)
            bc_m = run_rl_intervention(bc_agent, processed, device)
            print_metrics("BC-only Policy (ablation)", bc_m)
            all_results["bc_only_intervention"] = bc_m

        if args.agent_ckpt:
            print("\nRunning: Ours — RL Intervention Policy (BC+PPO, Module C)")
            rl_agent = _load_agent(args.agent_ckpt)
            rl_m = run_rl_intervention(rl_agent, processed, device)
            print_metrics("Ours: RL Intervention Policy", rl_m)
            all_results["ours_intervention"] = rl_m

    # ── Save results ────────────────────────────────────────────────────────
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        # NOTE: default=str stringifies any value json cannot serialise natively.
        json.dump(all_results, f, indent=2, default=str, ensure_ascii=False)
    print(f"\n{'═'*55}")
    print(f" 结果已保存: {args.output}")
    print(f"{'═'*55}\n")
|
||
|
||
|
||
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||
|