chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
613
code/scripts/evaluate.py
Normal file
613
code/scripts/evaluate.py
Normal file
@@ -0,0 +1,613 @@
|
||||
"""
|
||||
Full evaluation script for CompanionGuard-RL.
|
||||
|
||||
Runs detection and intervention evaluations against multiple baselines:
|
||||
|
||||
Detection baselines:
|
||||
- L1a: Keyword detector
|
||||
- L1b: Regex detector
|
||||
- L1c: Combined keyword+regex
|
||||
- Ours: CompanionRiskDetector (Module B)
|
||||
|
||||
Intervention baselines:
|
||||
- Rule-based (l_risk >= 3 → REJECT)
|
||||
- Threshold policy (per-level mapping)
|
||||
- RL policy (Module C, Ours) [requires --agent-ckpt]
|
||||
|
||||
Usage (detection only, no RL agent yet):
|
||||
python scripts/evaluate.py \\
|
||||
--detector-ckpt checkpoints/detector/best.pt \\
|
||||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
|
||||
--config configs/detector_config.yaml
|
||||
|
||||
Usage (with RL intervention agent):
|
||||
python scripts/evaluate.py \\
|
||||
--detector-ckpt checkpoints/detector/best.pt \\
|
||||
--agent-ckpt checkpoints/intervention/final.pt \\
|
||||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
|
||||
--config configs/detector_config.yaml \\
|
||||
--intervention-config configs/intervention_config.yaml
|
||||
|
||||
Usage (human-annotated subset only):
|
||||
python scripts/evaluate.py \\
|
||||
--detector-ckpt checkpoints/detector/best.pt \\
|
||||
--source-filter human \\
|
||||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \\
|
||||
--config configs/detector_config.yaml
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
|
||||
from src.models.detector import CompanionRiskDetector
|
||||
from src.utils.metrics import detection_metrics, intervention_metrics
|
||||
from src.utils.baselines import (
|
||||
KeywordDetector,
|
||||
RegexDetector,
|
||||
CombinedRuleDetector,
|
||||
rule_based_intervention,
|
||||
threshold_intervention,
|
||||
)
|
||||
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
|
||||
from src.utils.taxonomy import (
|
||||
ACTION_NAME_TO_ID,
|
||||
NUM_RISK_LEVELS,
|
||||
NUM_PRIMARY,
|
||||
NUM_FINE,
|
||||
PRIMARY_CATEGORY_LIST,
|
||||
FINE_GRAINED_LABELS,
|
||||
INTERVENTION_ACTIONS,
|
||||
)
|
||||
|
||||
# ── Fine-grained label split ──────────────────────────────────────────────────
# The fine-grained labels are evaluated in two subsets:
#   PUBLIC    — 10 labels covering R1/R2/R5-R9; present in human-annotated subsets.
#   COMPANION — 4 labels specific to companion AI (R3/R4/R10); absent from the
#               human subsets and evaluated separately on the generated test
#               split (see paper §4.3).
PUBLIC_FINE_LABELS = [
    "DirectEncouragement", "MethodFacilitation",
    "RiskNormalization", "Romanticization",
    "FalseReassurance", "PseudoTherapy",
    "CrisisNonResponse", "BoundaryFailure",
    "AgeInappropriateIntimacy", "PrivacySolicitation",
]
COMPANION_FINE_LABELS = [
    "DependencyReinforcement", "IsolationReinforcement",
    "ManipulativeAttachment", "CoRumination",
]

# Column positions of each subset within FINE_GRAINED_LABELS, resolved once at
# import time so metric arrays can be sliced by column later.
PUBLIC_FINE_IDX = [
    FINE_GRAINED_LABELS.index(name) for name in PUBLIC_FINE_LABELS
]
COMPANION_FINE_IDX = [
    FINE_GRAINED_LABELS.index(name) for name in COMPANION_FINE_LABELS
]

# --label-filter value → fine-label column indices to keep.
LABEL_FILTER_MAP = dict(
    all=list(range(NUM_FINE)),
    public=PUBLIC_FINE_IDX,
    companion=COMPANION_FINE_IDX,
)
|
||||
|
||||
|
||||
# ── Data source filtering ─────────────────────────────────────────────────────

# Markers that identify samples coming from human-annotated datasets:
# prefixes of the sample `id`, and accepted values of the `source` field.
HUMAN_SOURCE_PREFIXES = ("suicide-", "cosafe-", "dices-")
HUMAN_SOURCE_NAMES = ("suicide_risk", "cosafe", "dices")


def filter_by_source(samples: List[Dict], source_filter: str) -> List[Dict]:
    """Select samples by provenance.

    A sample counts as human-annotated when EITHER its ``source`` field equals
    one of ``HUMAN_SOURCE_NAMES`` OR its ``id`` is a string starting with one
    of ``HUMAN_SOURCE_PREFIXES``. Missing or non-string ``id`` / ``source``
    values simply do not match (they never raise).

    Args:
        samples: Sample dicts; not modified.
        source_filter: One of
            ``'all'``       — return ``samples`` itself, unfiltered;
            ``'human'``     — keep only human-annotated samples;
            ``'generated'`` — keep only the remaining (LLM-generated) samples.
            Any other value yields an empty list.

    Returns:
        The input list itself for ``'all'``; otherwise a new list preserving
        the original sample order.
    """
    if source_filter == "all":
        return samples

    def _is_human(sample: Dict) -> bool:
        if sample.get("source", "") in HUMAN_SOURCE_NAMES:
            return True
        sample_id = sample.get("id", "")
        # Guard against ids that are None / ints: only strings can carry a prefix.
        return isinstance(sample_id, str) and sample_id.startswith(HUMAN_SOURCE_PREFIXES)

    if source_filter == "human":
        return [s for s in samples if _is_human(s)]
    if source_filter == "generated":
        return [s for s in samples if not _is_human(s)]
    # Unknown filter names select nothing.
    return []
|
||||
|
||||
|
||||
def get_obs_dim(detector_hidden: int) -> int:
    """Return the observation-vector length for a given detector hidden size.

    The length is the sum of two single slots, ``NUM_RISK_LEVELS`` slots,
    ``NUM_PRIMARY`` slots and two blocks of ``detector_hidden`` slots.
    NOTE(review): the meaning of each segment is defined by the observation
    builder elsewhere in the project — confirm there before relying on it.
    """
    fixed_slots = 2 + NUM_RISK_LEVELS + NUM_PRIMARY
    return fixed_slots + 2 * detector_hidden
|
||||
|
||||
|
||||
# ── Detection evaluation ──────────────────────────────────────────────────────
|
||||
|
||||
def run_neural_detection(
    model: CompanionRiskDetector,
    tokenizer,
    samples: List[Dict],
    cfg: Dict,
    device: str,
    label_filter: str = "all",
) -> Dict:
    """Run the neural detector on test samples and compute detection metrics.

    Each sample is shallow-copied, normalised with ``validate_and_normalize``,
    rendered into persona / context / response texts by ``format_conversation``,
    tokenised, and scored one at a time with ``model.predict``.

    Args:
        model: Detector to evaluate; switched to eval mode here.
        tokenizer: Callable tokenizer accepting ``max_length``, ``truncation``,
            ``padding`` and ``return_tensors`` keyword arguments.
        samples: Raw sample dicts (not modified).
        cfg: Config dict. Reads ``evaluation.binary_threshold`` (default 0.5),
            ``evaluation.fine_threshold`` (default 0.4) and, under ``data``,
            ``max_history_turns`` (5), ``max_persona_len`` (128),
            ``max_context_len`` (512), ``max_response_len`` (256).
        device: Torch device string the encoded inputs are moved to.
        label_filter: Key of ``LABEL_FILTER_MAP`` choosing which fine-label
            columns are scored; unknown keys fall back to all columns.

    Returns:
        The dict produced by ``detection_metrics`` plus
        ``"per_category_recall"`` (per primary category: total, detected,
        recall, miss_rate) and ``"label_filter"``.
    """
    # Function-scope import, hoisted out of the per-label loop: the name is not
    # part of this file's top-level taxonomy imports.
    from src.utils.taxonomy import label_to_index

    model.eval()

    eval_cfg = cfg.get("evaluation", {})
    data_cfg = cfg.get("data", {})
    binary_threshold = eval_cfg.get("binary_threshold", 0.5)
    fine_threshold = eval_cfg.get("fine_threshold", 0.4)

    # Loop-invariant settings, looked up once instead of once per sample.
    max_history_turns = data_cfg.get("max_history_turns", 5)
    max_persona_len = data_cfg.get("max_persona_len", 128)
    max_context_len = data_cfg.get("max_context_len", 512)
    max_response_len = data_cfg.get("max_response_len", 256)

    def _encode(text, max_len):
        """Tokenise one text, padded/truncated to exactly ``max_len`` tokens."""
        return tokenizer(
            text, max_length=max_len, truncation=True,
            padding="max_length", return_tensors="pt",
        )

    y_true, y_pred = [], []
    l_true, l_pred = [], []
    fine_true_list, fine_pred_list = [], []

    # Per-category tracking of how many samples were flagged as risky.
    cat_results = defaultdict(lambda: {"total": 0, "detected": 0})

    for raw_sample in samples:
        sample = validate_and_normalize(dict(raw_sample))
        texts = format_conversation(
            sample["persona"], sample["history"],
            sample["user_input"], sample["ai_response"],
            max_history_turns=max_history_turns,
        )

        p_enc = _encode(texts["persona_text"], max_persona_len)
        c_enc = _encode(texts["context_text"], max_context_len)
        r_enc = _encode(texts["response_text"], max_response_len)

        with torch.no_grad():
            preds = model.predict(
                p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
                c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
                r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
                binary_threshold=binary_threshold,
                fine_threshold=fine_threshold,
            )

        y_p = preds["y_risk"].item()
        y_true.append(int(sample["y_risk"]))
        y_pred.append(y_p)
        l_true.append(int(sample["l_risk"]))
        l_pred.append(preds["l_risk"].item())

        # Fine-grained multi-label ground truth as a multi-hot vector.
        # Labels outside the taxonomy are ignored; a None value counts as empty.
        gt_fine = np.zeros(NUM_FINE)
        for lbl in sample.get("c_fine") or []:
            if lbl in FINE_GRAINED_LABELS:
                gt_fine[label_to_index(lbl)] = 1.0
        fine_true_list.append(gt_fine)

        # Per the original author's note, predict() already applies sigmoid and
        # fine_threshold, and exposes the result under "c_fine" (there is no
        # "c_fine_logits" key). Flatten to a 1-D vector per sample.
        fine_pred_list.append(preds["c_fine"].cpu().numpy().flatten())

        # Per-category recall bookkeeping.
        cat = sample.get("c_primary", "None")
        if cat in PRIMARY_CATEGORY_LIST:
            cat_results[cat]["total"] += 1
            if y_p == 1:
                cat_results[cat]["detected"] += 1

    # Build the per-category recall dict; categories without samples report 0 recall.
    per_cat = {}
    for cat in PRIMARY_CATEGORY_LIST:
        total = cat_results[cat]["total"]
        detected = cat_results[cat]["detected"]
        recall = detected / total if total > 0 else 0.0
        per_cat[cat] = {
            "total": total,
            "detected": detected,
            "recall": round(recall, 4),
            "miss_rate": round(1 - recall, 4),
        }

    # Keep the arrays 2-D ([N, NUM_FINE]) even when there are no samples, so the
    # column slicing below cannot fail on a 1-D empty array.
    empty = np.zeros((0, NUM_FINE))
    fine_true_arr = np.array(fine_true_list) if fine_true_list else empty
    fine_pred_arr = np.array(fine_pred_list) if fine_pred_list else empty.copy()

    # Apply label filter — restrict fine-label columns to the selected subset.
    col_idx = LABEL_FILTER_MAP.get(label_filter, list(range(NUM_FINE)))
    if label_filter != "all" and len(col_idx) < NUM_FINE:
        fine_true_arr = fine_true_arr[:, col_idx]
        fine_pred_arr = fine_pred_arr[:, col_idx]

    metrics = detection_metrics(
        y_true, y_pred,
        l_true, l_pred,
        fine_true_arr, fine_pred_arr,
    )
    metrics["per_category_recall"] = per_cat
    metrics["label_filter"] = label_filter
    return metrics
|
||||
|
||||
|
||||
def run_keyword_detection(
    samples: List[Dict],
    detector,
    compute_cat_recall: bool = True,
) -> Dict:
    """Evaluate a rule-style detector (keyword / regex / combined) on samples.

    ``detector.detect`` is called on each sample's ``ai_response`` text (empty
    string when absent) and must return a mapping with ``"y_risk"`` and
    ``"l_risk"``. Binary and level predictions are scored with
    ``detection_metrics``.

    Args:
        samples: Sample dicts providing ``y_risk`` and ``l_risk`` ground truth.
        detector: Object exposing ``detect(text)``.
        compute_cat_recall: When true, add ``"per_category_recall"`` to the
            result (per primary category: total, detected, recall, miss_rate).

    Returns:
        The metrics dict from ``detection_metrics``, optionally extended.
    """
    gold_binary, pred_binary = [], []
    gold_level, pred_level = [], []
    tally = defaultdict(lambda: {"total": 0, "detected": 0})

    for sample in samples:
        outcome = detector.detect(sample.get("ai_response", ""))
        truth = int(sample["y_risk"])
        flagged = outcome["y_risk"]
        gold_binary.append(truth)
        pred_binary.append(flagged)
        gold_level.append(int(sample["l_risk"]))
        pred_level.append(outcome["l_risk"])

        category = sample.get("c_primary", "None")
        if category not in PRIMARY_CATEGORY_LIST:
            continue
        bucket = tally[category]
        bucket["total"] += 1
        if flagged == 1:
            bucket["detected"] += 1

    metrics = detection_metrics(gold_binary, pred_binary, gold_level, pred_level)
    if not compute_cat_recall:
        return metrics

    # One entry per known category; categories with no samples report recall 0.
    summary = {}
    for category in PRIMARY_CATEGORY_LIST:
        seen = tally[category]["total"]
        hits = tally[category]["detected"]
        rate = hits / seen if seen > 0 else 0.0
        summary[category] = {
            "total": seen,
            "detected": hits,
            "recall": round(rate, 4),
            "miss_rate": round(1 - rate, 4),
        }
    metrics["per_category_recall"] = summary
    return metrics
|
||||
|
||||
|
||||
# ── Intervention evaluation ───────────────────────────────────────────────────
|
||||
|
||||
def _collect_c_primary_idx(processed_samples: List[Dict]) -> List[int]:
    """Return the ground-truth primary-category index of every sample.

    When a sample's ``c_primary`` is a known category name it is converted with
    ``category_to_index``; otherwise the sample's ``c_primary_idx`` field is
    used, defaulting to 0.
    """
    from src.utils.taxonomy import category_to_index, PRIMARY_CATEGORY_LIST

    def _index_of(sample: Dict) -> int:
        label = sample.get("c_primary", "None")
        if label in PRIMARY_CATEGORY_LIST:
            return category_to_index(label)
        return int(sample.get("c_primary_idx", 0))

    return [_index_of(sample) for sample in processed_samples]
|
||||
|
||||
|
||||
def run_rl_intervention(agent, processed_samples: List[Dict], device: str) -> Dict:
    """Score a learned intervention agent on preprocessed samples.

    For every sample an observation is built with ``build_obs_vector``, given a
    leading batch dimension, and passed to ``agent.get_action`` in
    deterministic mode; the first element of the returned 4-tuple is taken as
    the chosen action. The sample's ``a_recommend`` name is mapped through
    ``ACTION_NAME_TO_ID`` (missing or unknown names map to 0).

    Returns:
        The dict produced by ``intervention_metrics``.
    """
    agent.eval()

    risk_flags = []
    risk_levels = []
    chosen_actions = []
    recommended_actions = []

    for sample in processed_samples:
        features = build_obs_vector(sample)
        obs = torch.FloatTensor(features).unsqueeze(0).to(device)
        with torch.no_grad():
            picked, _, _, _ = agent.get_action(obs, deterministic=True)

        risk_flags.append(int(sample["y_risk"]))
        risk_levels.append(int(sample["l_risk"]))
        chosen_actions.append(picked.item())
        recommended_name = sample.get("a_recommend", "PASS")
        recommended_actions.append(ACTION_NAME_TO_ID.get(recommended_name, 0))

    category_ids = _collect_c_primary_idx(processed_samples)
    return intervention_metrics(
        risk_flags, risk_levels, chosen_actions, recommended_actions, category_ids
    )
|
||||
|
||||
|
||||
def run_rule_intervention(processed_samples: List[Dict], policy_fn) -> Dict:
    """Score a rule-style intervention policy on preprocessed samples.

    ``policy_fn`` maps an integer risk level to an action. It is fed the
    detector's predicted level (``det_l_risk``), as in real deployment, and
    falls back to the ground-truth ``l_risk`` when no prediction is stored.

    Returns:
        The dict produced by ``intervention_metrics`` (called without
        recommended actions).
    """

    def _policy_input(sample: Dict) -> int:
        return int(sample.get("det_l_risk", sample["l_risk"]))

    risk_flags = [int(sample["y_risk"]) for sample in processed_samples]
    risk_levels = [int(sample["l_risk"]) for sample in processed_samples]
    chosen_actions = [policy_fn(_policy_input(sample)) for sample in processed_samples]
    category_ids = _collect_c_primary_idx(processed_samples)
    return intervention_metrics(
        risk_flags, risk_levels, chosen_actions, c_primary_idx=category_ids
    )
|
||||
|
||||
|
||||
# ── Pretty printing ───────────────────────────────────────────────────────────
|
||||
|
||||
def print_metrics(name: str, metrics: Dict):
    """Pretty-print one metrics dict to stdout under a banner titled ``name``.

    Known structured keys get dedicated layouts (``per_category_recall``,
    ``level_per_class_f1``, ``fine_per_label_f1``, ``per_level_action_dist``,
    ``per_category_action_dist``, ``exact_action_accuracy_by_level``). The
    ``label_filter`` key is never printed on its own line; it only selects the
    label names shown for ``fine_per_label_f1``. Remaining values are printed
    generically: floats with 4 decimals, flat numeric lists with 3 decimals,
    anything else (including nested or non-numeric lists) verbatim.

    Args:
        name: Title shown in the banner.
        metrics: Mapping of metric name to value.
    """
    companion_specific = {"R3", "R4", "R10"}
    level_names = {0: "Safe", 1: "Mild", 2: "Moderate", 3: "High", 4: "Critical"}
    rule = f" {'─'*51}"

    print(f"\n{'=' * 55}")
    print(f" {name}")
    print(f"{'=' * 55}")

    lf = metrics.get("label_filter", "all")

    for k, v in metrics.items():
        if k == "label_filter":
            continue  # only used to pick fine-label names below
        if k == "per_category_recall":
            print(rule)
            print(f" Per-category Recall (↑好,漏检率=1-recall):")
            print(rule)
            for cat, m in sorted(v.items()):
                flag = " ◀ companion特有" if cat in companion_specific else ""
                print(f" {cat:4s}: recall={m['recall']:.3f} miss={m['miss_rate']:.3f}"
                      f" (n={m['total']:3d}){flag}")
        elif k == "level_per_class_f1":
            print(rule)
            print(f" Per-level F1 (诊断 level_macro_f1 下降原因):")
            for i, f in enumerate(v):
                # Levels beyond the five named ones get a placeholder instead
                # of raising KeyError.
                print(f" L{i} {level_names.get(i, '?'):10s}: {f:.4f}")
        elif k == "fine_per_label_f1":
            # Resolve the label names lazily, and only for this key, so they
            # line up with the columns selected by the stored label filter.
            if lf == "public":
                active_fine_labels = PUBLIC_FINE_LABELS
            elif lf == "companion":
                active_fine_labels = COMPANION_FINE_LABELS
            else:
                active_fine_labels = FINE_GRAINED_LABELS
            print(rule)
            n_lbl = len(active_fine_labels)
            print(f" Per fine-label F1 ({n_lbl} labels, filter={lf}):")
            for lbl, f in zip(active_fine_labels, v):
                print(f" {lbl:35s}: {f:.4f}")
        elif k == "per_level_action_dist":
            print(rule)
            print(f" Per-level Action Distribution:")
            print(f" {'Level':15s} {'n':>5} PASS WARN RWRT REJT CRISIS")
            for lvl_name, lv in v.items():
                dist_str = " ".join(f"{x:.3f}" for x in lv["action_dist"])
                print(f" {lvl_name:15s} {lv['n']:>5} {dist_str}")
        elif k == "per_category_action_dist":
            print(rule)
            print(f" Per-category Action Distribution:")
            print(f" {'Category':6s} {'n':>5} PASS WARN RWRT REJT CRISIS")
            for cat_name, cv in v.items():
                dist_str = " ".join(f"{x:.3f}" for x in cv["action_dist"])
                print(f" {cat_name:6s} {cv['n']:>5} {dist_str}")
        elif k == "exact_action_accuracy_by_level":
            print(rule)
            print(f" Action Accuracy by Level (vs a_recommend):")
            for lvl_name, acc in v.items():
                print(f" {lvl_name:15s}: {acc:.4f}")
        elif isinstance(v, float):
            print(f" {k:35s}: {v:.4f}")
        elif isinstance(v, list):
            try:
                formatted = [f"{x:.3f}" for x in v]
            except (TypeError, ValueError):
                # Nested lists (e.g. a confusion matrix) or non-numeric items
                # cannot take a float format spec; show the value as-is.
                print(f" {k:35s}: {v}")
            else:
                print(f" {k:35s}: {formatted}")
        else:
            print(f" {k:35s}: {v}")
|
||||
|
||||
|
||||
# ── Main ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
    """Command-line entry point: run the full CompanionGuard-RL evaluation.

    Steps:
      1. Parse arguments and load the detector config (and the intervention
         config, but only when ``--agent-ckpt`` or ``--bc-ckpt`` is given).
      2. Load the test set and apply ``--source-filter``.
      3. Build the detector and load its checkpoint; a missing checkpoint only
         produces a warning and leaves the freshly initialised weights.
      4. Evaluate the three rule-based detectors and the neural detector, then
         print a comparison table.
      5. If an agent / BC checkpoint was given, preprocess the samples with the
         detector and evaluate the rule-based, threshold, BC-only and RL
         intervention policies.
      6. Write all results as JSON to ``--output``.

    Raises:
        SystemExit: with status 1 when the intervention config is required but
            missing, or when no samples remain after filtering, so that calling
            scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(description="CompanionGuard-RL 完整评估")
    parser.add_argument("--detector-ckpt", required=True,
                        help="检测器权重路径,例如 checkpoints/detector/best.pt")
    parser.add_argument("--agent-ckpt", default=None,
                        help="RL agent 权重路径(BC+PPO)")
    parser.add_argument("--bc-ckpt", default=None,
                        help="BC-only agent 权重路径(BC 阶段后保存,用于 ablation 对照)")
    parser.add_argument("--test-data",
                        default="data/processed/CompanionRisk-Bench/test.jsonl",
                        help="测试集路径")
    parser.add_argument("--config", default="configs/detector_config.yaml",
                        help="检测器配置路径")
    parser.add_argument("--intervention-config",
                        default="configs/intervention_config.yaml",
                        help="干预配置路径(仅 --agent-ckpt 存在时需要)")
    parser.add_argument("--output", default="experiments/eval_results.json",
                        help="结果保存路径")
    parser.add_argument("--source-filter", default="all",
                        choices=["all", "human", "generated"],
                        help=(
                            "样本过滤: "
                            "all=全部, "
                            "human=仅人工标注(suicide/cosafe/dices), "
                            "generated=仅LLM生成"
                        ))
    parser.add_argument("--label-filter", default="all",
                        choices=["all", "public", "companion"],
                        help=(
                            "细粒度标签集过滤 (fine_macro_f1 计算范围): "
                            "all=全部14标签, "
                            "public=10个通用标签(R1/R2/R5-R9,人工子集可用), "
                            "companion=4个companion专属标签(R3/R4/R10)"
                        ))
    args = parser.parse_args()

    # Explicit UTF-8 so configs with non-ASCII text load the same on every platform.
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Read the intervention config only when an agent is evaluated, so that
    # detection-only runs do not need the file to exist.
    int_cfg = {}
    run_intervention = bool(args.agent_ckpt or args.bc_ckpt)
    if run_intervention:
        if not Path(args.intervention_config).exists():
            print(f"[ERROR] agent-ckpt 指定但找不到干预配置: {args.intervention_config}")
            raise SystemExit(1)
        with open(args.intervention_config, encoding="utf-8") as f:
            int_cfg = yaml.safe_load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n{'═'*55}")
    print(f" CompanionGuard-RL 完整评估")
    print(f"{'═'*55}")
    print(f" Device : {device}")
    print(f" Test data : {args.test_data}")
    print(f" Source filter : {args.source_filter}")
    print(f" Label filter : {args.label_filter}")
    print(f" Detector ckpt : {args.detector_ckpt}")
    print(f" RL agent ckpt : {args.agent_ckpt or '(跳过)'}")
    print(f" BC-only ckpt : {args.bc_ckpt or '(跳过)'}")

    # ── Load the test set ───────────────────────────────────────────────────
    all_samples = load_jsonl(args.test_data)
    samples = filter_by_source(all_samples, args.source_filter)
    print(f"\n 样本总数: {len(all_samples)} → 过滤后: {len(samples)} "
          f"(filter={args.source_filter})")

    if not samples:
        print("[ERROR] 过滤后样本为空,检查 --source-filter 或数据集中的 source/id 字段")
        raise SystemExit(1)

    risky = sum(int(s["y_risk"]) for s in samples)
    print(f" 有风险: {risky} 安全: {len(samples)-risky}")

    # ── Load the detector ───────────────────────────────────────────────────
    model_cfg = cfg["model"]
    tokenizer = AutoTokenizer.from_pretrained(model_cfg["name"])
    detector = CompanionRiskDetector(
        model_name=model_cfg["name"],
        hidden_size=model_cfg["hidden_size"],
        num_heads=model_cfg["num_heads"],
        dropout=model_cfg["dropout"],
        use_lora=model_cfg["use_lora"],
    ).to(device)

    if Path(args.detector_ckpt).exists():
        detector.load_state_dict(
            torch.load(args.detector_ckpt, map_location=device, weights_only=True)
        )
        print(f" Detector loaded: {args.detector_ckpt}")
    else:
        # Deliberately non-fatal: evaluation continues with untrained weights.
        print(f"[WARN] Checkpoint 未找到: {args.detector_ckpt},使用随机权重")

    all_results = {
        "meta": {
            "test_file": str(args.test_data),
            "source_filter": args.source_filter,
            "label_filter": args.label_filter,
            "n_total": len(all_samples),
            "n_filtered": len(samples),
            "n_risky": risky,
        }
    }

    # ── Detection baselines ─────────────────────────────────────────────────
    print(f"\n{'─'*55}")
    print(" Detection Task(检测任务)")
    print(f"{'─'*55}")

    # (result key, "Running" banner, report title, detector class)
    rule_baselines = [
        ("L1a_keyword", "L1a — Keyword Detector",
         "L1a: Keyword Detector", KeywordDetector),
        ("L1b_regex", "L1b — Regex Detector",
         "L1b: Regex Detector", RegexDetector),
        ("L1c_combined", "L1c — Combined Keyword+Regex",
         "L1c: Combined Detector", CombinedRuleDetector),
    ]
    for result_key, banner, title, detector_cls in rule_baselines:
        print(f"\nRunning: {banner}")
        baseline_m = run_keyword_detection(samples, detector_cls())
        print_metrics(title, baseline_m)
        all_results[result_key] = baseline_m

    print("\nRunning: Ours — CompanionRiskDetector (Module B)")
    neural_m = run_neural_detection(
        detector, tokenizer, samples, cfg, device,
        label_filter=args.label_filter,
    )
    print_metrics("Ours: CompanionRiskDetector", neural_m)
    all_results["ours_detection"] = neural_m

    # ── Comparison table ────────────────────────────────────────────────────
    print(f"\n{'─'*55}")
    print(" Detection Summary(对比汇总)")
    print(f"{'─'*55}")
    print(f" {'Method':25s} {'BinF1':>7} {'Recall':>7} {'FNR':>7} {'LvlF1':>7}")
    print(f" {'─'*53}")
    summary_rows = [
        ("L1a_keyword", "L1a_keyword"),
        ("L1b_regex", "L1b_regex"),
        ("L1c_combined", "L1c_combined"),
        ("Ours", "ours_detection"),
    ]
    for tag, result_key in summary_rows:
        m = all_results[result_key]
        bf1 = m.get("binary_f1", float("nan"))
        rec = m.get("high_risk_recall", float("nan"))
        fnr = m.get("false_negative_rate", float("nan"))
        lvlf1 = m.get("level_macro_f1", float("nan"))
        print(f" {tag:25s} {bf1:7.4f} {rec:7.4f} {fnr:7.4f} {lvlf1:7.4f}")

    # ── Intervention evaluation (only with --agent-ckpt or --bc-ckpt) ───────
    if run_intervention:
        print(f"\n{'─'*55}")
        print(" Intervention Task(干预任务)")
        print(f"{'─'*55}")

        print(" Preprocessing samples with detector for RL state...")
        processed = preprocess_samples_with_detector(
            samples, detector, tokenizer, device=device,
            binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
        )

        from src.models.intervention_agent import InterventionAgent
        detector_hidden = cfg["model"]["hidden_size"]

        def _load_agent(ckpt_path: str) -> "InterventionAgent":
            """Build an agent and load ``ckpt_path`` if it exists (warn otherwise)."""
            ag = InterventionAgent(
                detector_hidden=detector_hidden,
                state_hidden=int_cfg["agent"]["state_hidden"],
                dropout=int_cfg["agent"]["dropout"],
            ).to(device)
            if Path(ckpt_path).exists():
                ag.load_state_dict(
                    torch.load(ckpt_path, map_location=device, weights_only=True)
                )
                print(f" Agent loaded: {ckpt_path}")
            else:
                print(f"[WARN] Checkpoint 未找到: {ckpt_path}")
            return ag

        print("\nRunning: Rule-based Intervention (l_risk≥3→REJECT)")
        rule_m = run_rule_intervention(processed, rule_based_intervention)
        print_metrics("Rule-based Policy", rule_m)
        all_results["baseline_rule"] = rule_m

        print("\nRunning: Threshold Intervention Baseline")
        thr_m = run_rule_intervention(processed, threshold_intervention)
        print_metrics("Threshold Policy", thr_m)
        all_results["baseline_threshold"] = thr_m

        if args.bc_ckpt:
            print("\nRunning: BC-only Intervention Policy (ablation)")
            bc_agent = _load_agent(args.bc_ckpt)
            bc_m = run_rl_intervention(bc_agent, processed, device)
            print_metrics("BC-only Policy (ablation)", bc_m)
            all_results["bc_only_intervention"] = bc_m

        if args.agent_ckpt:
            print("\nRunning: Ours — RL Intervention Policy (BC+PPO, Module C)")
            rl_agent = _load_agent(args.agent_ckpt)
            rl_m = run_rl_intervention(rl_agent, processed, device)
            print_metrics("Ours: RL Intervention Policy", rl_m)
            all_results["ours_intervention"] = rl_m

    # ── Save results ────────────────────────────────────────────────────────
    def _json_default(obj):
        """Serialise NumPy values as numbers/lists; stringify anything else."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return str(obj)

    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, default=_json_default, ensure_ascii=False)
    print(f"\n{'═'*55}")
    print(f" 结果已保存: {args.output}")
    print(f"{'═'*55}\n")


if __name__ == "__main__":
    main()
|
||||
| ||||