Two-module pipeline for AI companion safety: - Module B: context-aware risk detector with CrossAttention fusion - Module C: PPO-based adaptive intervention policy Includes CompanionRisk Taxonomy (10 primary + 14 fine-grained labels), dataset generation/annotation pipeline, training scripts, and eval suite. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
194 lines
7.2 KiB
Python
194 lines
7.2 KiB
Python
"""
|
|
Evaluation script: run detection + intervention baselines and ours.
|
|
|
|
Usage:
|
|
python scripts/evaluate.py --detector-ckpt checkpoints/detector/best.pt \
|
|
--agent-ckpt checkpoints/intervention/final.pt \
|
|
--test-data data/processed/test.jsonl \
|
|
--config configs/detector_config.yaml
|
|
"""
|
|
|
|
import argparse
|
|
import yaml
|
|
import json
|
|
import torch
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from transformers import AutoTokenizer
|
|
|
|
from src.data.dataset import load_jsonl, format_conversation
|
|
from src.models.detector import CompanionRiskDetector
|
|
from src.models.intervention_agent import InterventionAgent
|
|
from src.utils.metrics import detection_metrics, intervention_metrics
|
|
from src.utils.taxonomy import (
|
|
ACTION_NAME_TO_ID,
|
|
INTERVENTION_ACTIONS,
|
|
NUM_RISK_LEVELS,
|
|
NUM_PRIMARY,
|
|
DEFAULT_ACTION_BY_LEVEL,
|
|
)
|
|
|
|
|
|
# ── Baselines ────────────────────────────────────────────────────────────────
|
|
|
|
def rule_based_policy(l_risk: int) -> int:
    """Fixed-rule baseline policy.

    Risk levels of 3 and above map to action id 3 (REJECT); every lower
    level maps to action id 0 (PASS).
    """
    if l_risk >= 3:
        return 3
    return 0
|
|
|
|
|
|
def threshold_policy(l_risk: int) -> int:
    """Threshold baseline: look up the action assigned to a risk level.

    The mapping is read from ``DEFAULT_ACTION_BY_LEVEL``; a level that the
    mapping does not contain raises whatever that container's subscript
    lookup raises.
    """
    default_action = DEFAULT_ACTION_BY_LEVEL[l_risk]
    return default_action
|
|
|
|
|
|
# ── Main evaluation ──────────────────────────────────────────────────────────
|
|
|
|
def run_detection_eval(
    model,
    tokenizer,
    samples,
    cfg,
    device,
    *,
    max_persona_len=128,
    max_context_len=512,
    max_response_len=256,
):
    """Run the detector over ``samples`` and score its predictions.

    Args:
        model: Detector exposing ``eval()`` and ``predict(...)``. ``predict``
            is called positionally with the (input_ids, attention_mask) pair
            of the persona, then the context, then the response, and must
            return a mapping with ``"y_risk"`` and ``"l_risk"`` entries.
            Each entry is read with ``.item()``, so it is presumably a
            single-element tensor -- TODO confirm against the detector.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            truncation=True, padding="max_length", return_tensors="pt")``
            whose result provides ``"input_ids"`` and ``"attention_mask"``.
        samples: Iterable of dicts with the keys ``persona``, ``history``,
            ``user_input``, ``ai_response``, ``y_risk`` and ``l_risk``.
        cfg: Not read by this function; kept so existing callers keep working.
        device: Target passed to ``.to(device)`` for every encoded tensor.
        max_persona_len: Truncation/padding length for the persona text.
        max_context_len: Truncation/padding length for the context text.
        max_response_len: Truncation/padding length for the response text.

    Returns:
        The result of ``detection_metrics(y_true, y_pred, l_true, l_pred)``.
    """
    model.eval()

    def encode(text, max_len):
        # Fixed-length padding keeps every sample at the same tensor width.
        return tokenizer(
            text,
            max_length=max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    # (text key, max length) in the positional order ``model.predict`` expects.
    segments = (
        ("persona_text", max_persona_len),
        ("context_text", max_context_len),
        ("response_text", max_response_len),
    )

    y_true, y_pred = [], []
    l_true, l_pred = [], []

    # Inference only: one no-grad scope for the whole pass instead of one per sample.
    with torch.no_grad():
        for sample in samples:
            texts = format_conversation(
                sample["persona"],
                sample["history"],
                sample["user_input"],
                sample["ai_response"],
            )

            model_inputs = []
            for text_key, max_len in segments:
                encoded = encode(texts[text_key], max_len)
                model_inputs.append(encoded["input_ids"].to(device))
                model_inputs.append(encoded["attention_mask"].to(device))

            preds = model.predict(*model_inputs)

            y_true.append(sample["y_risk"])
            y_pred.append(preds["y_risk"].item())
            l_true.append(sample["l_risk"])
            l_pred.append(preds["l_risk"].item())

    return detection_metrics(y_true, y_pred, l_true, l_pred)
|
|
|
|
|
|
def run_intervention_eval(agent, processed_samples, obs_dim, device):
    """Evaluate the intervention agent on detector-processed samples.

    For each sample an observation vector is assembled by concatenating, in
    order: the detector score, a one-hot encoding of the risk level
    (``NUM_RISK_LEVELS`` wide), the primary-category probabilities, the
    pooled ``e_H`` and ``e_P`` vectors, and the history length divided by
    20. The agent is then queried deterministically for an action.

    Args:
        agent: Policy exposing ``eval()`` and
            ``get_action(obs, deterministic=True)``; the latter must return
            exactly four values, the first being the action (read via
            ``.item()``).
        processed_samples: Iterable of dicts with the keys ``d_score``,
            ``l_risk``, ``c_primary_probs``, ``e_H_pool``, ``e_P_pool``,
            ``y_risk``, ``a_recommend`` and optionally ``history``.
        obs_dim: Not read by this function; kept for interface compatibility.
        device: Target passed to ``.to(device)`` for the observation tensor.

    Returns:
        The result of ``intervention_metrics(y_risk_true, l_risk_true,
        a_pred, a_recommend)``.
    """
    agent.eval()

    true_flags = []
    true_levels = []
    chosen_actions = []
    recommended_actions = []

    for record in processed_samples:
        score_part = np.array([record["d_score"]], dtype=np.float32)

        level = int(record["l_risk"])
        level_one_hot = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
        level_one_hot[level] = 1.0

        feature_parts = [
            score_part,
            level_one_hot,
            np.array(record["c_primary_probs"], dtype=np.float32),
            np.array(record["e_H_pool"], dtype=np.float32),
            np.array(record["e_P_pool"], dtype=np.float32),
            # Turn count scaled by 20; a missing history counts as zero turns.
            np.array([len(record.get("history", [])) / 20.0], dtype=np.float32),
        ]
        flat_features = np.concatenate(feature_parts)
        # Add a leading batch axis of size one before handing it to the agent.
        obs = torch.FloatTensor(flat_features).unsqueeze(0).to(device)

        with torch.no_grad():
            action, _, _, _ = agent.get_action(obs, deterministic=True)

        true_flags.append(record["y_risk"])
        true_levels.append(level)
        chosen_actions.append(action.item())
        # Unknown recommendation names fall back to action id 0.
        recommended_actions.append(ACTION_NAME_TO_ID.get(record["a_recommend"], 0))

    return intervention_metrics(
        true_flags, true_levels, chosen_actions, recommended_actions
    )
|
|
|
|
|
|
def _print_metrics(metrics, *, show_lists=False):
    """Print float metrics with 4 decimals and, if requested, list metrics with 3."""
    for name, value in metrics.items():
        if isinstance(value, float):
            print(f" {name}: {value:.4f}")
        elif show_lists and isinstance(value, list):
            print(f" {name}: {[f'{x:.3f}' for x in value]}")


def main():
    """Command-line entry point: evaluate the detector and, optionally, the policies.

    Steps:
        1. Parse the CLI arguments and load the detector YAML config.
        2. Load the test samples and the detector checkpoint, then print the
           detection metrics.
        3. If ``--agent-ckpt`` is given, load the intervention config and the
           agent checkpoint, then print metrics for the RL policy, the
           rule-based baseline and the threshold baseline.
        4. Write every computed metric set to ``experiments/eval_results.json``.

    The intervention config is only opened when an agent checkpoint is
    supplied, so a detection-only run does not require that file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--detector-ckpt", required=True)
    parser.add_argument("--agent-ckpt", default=None)
    parser.add_argument("--test-data", default="data/processed/test.jsonl")
    parser.add_argument("--config", default="configs/detector_config.yaml")
    parser.add_argument("--intervention-config", default="configs/intervention_config.yaml")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])

    samples = load_jsonl(args.test_data)
    print(f"Loaded {len(samples)} test samples.")

    # Detection evaluation
    detector = CompanionRiskDetector(
        model_name=cfg["model"]["name"],
        hidden_size=cfg["model"]["hidden_size"],
    ).to(device)
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))

    print("\n=== Detection Evaluation ===")
    det_metrics = run_detection_eval(detector, tokenizer, samples, cfg, device)
    _print_metrics(det_metrics)

    results = {"detection": det_metrics}

    # Intervention evaluation
    if args.agent_ckpt:
        from scripts.train_intervention import preprocess_samples_with_detector

        with open(args.intervention_config, encoding="utf-8") as f:
            int_cfg = yaml.safe_load(f)

        detector_hidden = cfg["model"]["hidden_size"]
        # score + one-hot level + primary probs + two pooled vectors + turn feature
        obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1

        processed = preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device)

        agent = InterventionAgent(
            detector_hidden=detector_hidden,
            state_hidden=int_cfg["agent"]["state_hidden"],
        ).to(device)
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))

        print("\n=== Intervention Evaluation: RL Policy (Ours) ===")
        int_metrics = run_intervention_eval(agent, processed, obs_dim, device)
        _print_metrics(int_metrics, show_lists=True)
        intervention_results = {"rl_policy": int_metrics}

        # Ground truth shared by both baselines; levels are cast to int exactly
        # as the RL evaluation path does, so all three policies see the same input.
        y_risk_true = [s["y_risk"] for s in processed]
        l_risk_true = [int(s["l_risk"]) for s in processed]

        baselines = (
            ("Rule-based Baseline", "rule_based", rule_based_policy),
            ("Threshold Baseline", "threshold", threshold_policy),
        )
        for title, result_key, policy in baselines:
            print(f"\n=== Intervention Evaluation: {title} ===")
            baseline_preds = [policy(level) for level in l_risk_true]
            baseline_metrics = intervention_metrics(y_risk_true, l_risk_true, baseline_preds)
            _print_metrics(baseline_metrics)
            intervention_results[result_key] = baseline_metrics

        # Persist the intervention metrics too, not only the detection ones.
        results["intervention"] = intervention_results

    # Save results
    Path("experiments").mkdir(exist_ok=True)
    with open("experiments/eval_results.json", "w", encoding="utf-8") as f:
        # default=str stringifies any metric value json cannot serialise natively.
        json.dump(results, f, indent=2, default=str)
    print("\nResults saved to experiments/eval_results.json")
|
|
|
|
|
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|