refactor: complete full implementation replacing all placeholder/mock content
Detection module (Module B): - detector.py: expose separate e_P_pool and e_H_pool for RL state; fix compute_loss to skip primary head when c_primary="None" - dataset.py: handle c_primary="None" safely; add validate_and_normalize Data pipeline: - data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe); systematic category→fine-label mapping; safe sample generation (25%); per-category risk level distribution; max_retries logic - llm_judge.py: incremental file writing; rate limiting; retry logic; annotate_from_file convenience method; consistency validation - annotate_data.py: stratified split by y_risk; dataset statistics report RL module (Module C): - ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple); fix action type passed to env.step; proper buffer reset and size tracking - companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with auto-reset; correct Gymnasium interface Shared utilities (new files): - src/utils/preprocessing.py: preprocess_samples_with_detector using separate e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up - src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b), CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention, LLMJudgePolicy for full baseline comparison Scripts: - train_intervention.py: use preprocessing module; separate e_H/e_P pools - evaluate.py: proper module imports (no circular scripts import); full multi-baseline comparison; save all results to JSON - generate_data.py: API key check; safe_ratio + max_retries CLI args Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,80 +1,171 @@
|
||||
"""
|
||||
Step 2: LLM judge pre-annotation.
|
||||
Step 2: LLM judge pre-annotation and dataset split.
|
||||
|
||||
Reads raw generated JSONL, runs LLM annotation on each sample,
|
||||
then splits into train/val/test sets with stratified sampling by y_risk.
|
||||
|
||||
Usage:
|
||||
python scripts/annotate_data.py --input data/raw/generated.jsonl \
|
||||
--output data/processed/annotated.jsonl \
|
||||
--config configs/data_generation.yaml
|
||||
# Annotate + split
|
||||
python scripts/annotate_data.py \\
|
||||
--input data/raw/generated.jsonl \\
|
||||
--output data/processed/annotated.jsonl \\
|
||||
--config configs/data_generation.yaml
|
||||
|
||||
# Skip annotation (already annotated), only split
|
||||
python scripts/annotate_data.py \\
|
||||
--input data/raw/generated.jsonl \\
|
||||
--output data/processed/annotated.jsonl \\
|
||||
--skip-annotation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import yaml
|
||||
import random
|
||||
import yaml
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
from src.data.llm_judge import LLMJudge
|
||||
from src.data.dataset import load_jsonl
|
||||
from src.data.dataset import load_jsonl, save_jsonl, validate_and_normalize
|
||||
|
||||
|
||||
def split_dataset(samples, train_ratio=0.8, val_ratio=0.1, seed=42):
|
||||
def stratified_split(
|
||||
samples: List[Dict],
|
||||
train_ratio: float = 0.8,
|
||||
val_ratio: float = 0.1,
|
||||
seed: int = 42,
|
||||
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
||||
"""
|
||||
Stratified split by y_risk to ensure each split has balanced risk/safe samples.
|
||||
"""
|
||||
random.seed(seed)
|
||||
random.shuffle(samples)
|
||||
n = len(samples)
|
||||
n_train = int(n * train_ratio)
|
||||
n_val = int(n * val_ratio)
|
||||
return (
|
||||
samples[:n_train],
|
||||
samples[n_train: n_train + n_val],
|
||||
samples[n_train + n_val:],
|
||||
|
||||
risky = [s for s in samples if s.get("y_risk", 0) == 1]
|
||||
safe = [s for s in samples if s.get("y_risk", 0) == 0]
|
||||
|
||||
random.shuffle(risky)
|
||||
random.shuffle(safe)
|
||||
|
||||
def split_list(lst):
|
||||
n = len(lst)
|
||||
n_train = int(n * train_ratio)
|
||||
n_val = int(n * val_ratio)
|
||||
return lst[:n_train], lst[n_train:n_train + n_val], lst[n_train + n_val:]
|
||||
|
||||
train_r, val_r, test_r = split_list(risky)
|
||||
train_s, val_s, test_s = split_list(safe)
|
||||
|
||||
train = train_r + train_s
|
||||
val = val_r + val_s
|
||||
test = test_r + test_s
|
||||
|
||||
random.shuffle(train)
|
||||
random.shuffle(val)
|
||||
random.shuffle(test)
|
||||
|
||||
return train, val, test
|
||||
|
||||
|
||||
def print_dataset_stats(name: str, samples: List[Dict]):
|
||||
"""Print per-split statistics."""
|
||||
total = len(samples)
|
||||
n_risky = sum(1 for s in samples if s.get("y_risk", 0) == 1)
|
||||
n_safe = total - n_risky
|
||||
|
||||
level_counts = Counter(s.get("l_risk", 0) for s in samples)
|
||||
action_counts = Counter(s.get("a_recommend", "PASS") for s in samples)
|
||||
category_counts = Counter(
|
||||
s.get("c_primary", "None") for s in samples if s.get("c_primary") != "None"
|
||||
)
|
||||
|
||||
|
||||
def save_jsonl(samples, path):
|
||||
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
for s in samples:
|
||||
f.write(json.dumps(s, ensure_ascii=False) + "\n")
|
||||
print(f"Saved {len(samples)} samples → {path}")
|
||||
print(f"\n [{name}] {total} samples "
|
||||
f"(risky={n_risky} {n_risky/total*100:.0f}%, "
|
||||
f"safe={n_safe} {n_safe/total*100:.0f}%)")
|
||||
print(f" Risk levels: { {k: level_counts[k] for k in sorted(level_counts)} }")
|
||||
print(f" Actions: { dict(action_counts) }")
|
||||
if category_counts:
|
||||
top5 = category_counts.most_common(5)
|
||||
print(f" Top categories: { {k: v for k, v in top5} }")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", required=True)
|
||||
parser.add_argument("--output", default="data/processed/annotated.jsonl")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Annotate and split CompanionGuard-RL dataset"
|
||||
)
|
||||
parser.add_argument("--input", required=True, help="Input JSONL file path")
|
||||
parser.add_argument(
|
||||
"--output", default="data/processed/annotated.jsonl",
|
||||
help="Output annotated JSONL path"
|
||||
)
|
||||
parser.add_argument("--config", default="configs/data_generation.yaml")
|
||||
parser.add_argument("--skip-annotation", action="store_true",
|
||||
help="Skip LLM annotation (use existing labels)")
|
||||
parser.add_argument(
|
||||
"--skip-annotation", action="store_true",
|
||||
help="Skip LLM annotation (use labels already in input file)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-only", action="store_true",
|
||||
help="Only perform dataset split (no annotation), reads from --output"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
samples = load_jsonl(args.input)
|
||||
print(f"Loaded {len(samples)} samples from {args.input}")
|
||||
split_cfg = cfg.get("split", {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42})
|
||||
ann_cfg = cfg.get("annotation", {})
|
||||
output_dir = Path(args.output).parent
|
||||
|
||||
if not args.skip_annotation:
|
||||
# ── Step 1: Annotation ───────────────────────────────────────────────
|
||||
if args.split_only:
|
||||
samples = load_jsonl(args.output)
|
||||
print(f"Loaded {len(samples)} pre-annotated samples from {args.output}")
|
||||
elif args.skip_annotation:
|
||||
samples = load_jsonl(args.input)
|
||||
samples = [validate_and_normalize(s) for s in samples]
|
||||
save_jsonl(samples, args.output)
|
||||
print(f"Loaded and normalized {len(samples)} samples (annotation skipped)")
|
||||
else:
|
||||
judge = LLMJudge(
|
||||
api_type=cfg["api"]["type"],
|
||||
model=cfg["annotation"]["judge_model"],
|
||||
model=ann_cfg.get("judge_model", cfg["api"]["model"]),
|
||||
)
|
||||
samples = judge.annotate_from_file(
|
||||
input_path=args.input,
|
||||
output_path=args.output,
|
||||
delay=ann_cfg.get("delay", 0.3),
|
||||
max_retries=ann_cfg.get("max_retries", 2),
|
||||
)
|
||||
samples = judge.annotate_batch(samples, output_path=args.output)
|
||||
else:
|
||||
save_jsonl(samples, args.output)
|
||||
|
||||
split_cfg = cfg.get("split", {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42})
|
||||
train, val, test = split_dataset(
|
||||
if not samples:
|
||||
print("[ERROR] No samples to process after annotation. Check logs above.")
|
||||
raise SystemExit(1)
|
||||
|
||||
# ── Step 2: Dataset statistics ───────────────────────────────────────
|
||||
print(f"\n{'='*50}")
|
||||
print("Dataset statistics:")
|
||||
print_dataset_stats("ALL", samples)
|
||||
|
||||
# ── Step 3: Stratified split ─────────────────────────────────────────
|
||||
train, val, test = stratified_split(
|
||||
samples,
|
||||
train_ratio=split_cfg["train"],
|
||||
val_ratio=split_cfg["val"],
|
||||
seed=split_cfg.get("seed", 42),
|
||||
)
|
||||
|
||||
base = Path(args.output).parent
|
||||
save_jsonl(train, base / "train.jsonl")
|
||||
save_jsonl(val, base / "val.jsonl")
|
||||
save_jsonl(test, base / "test.jsonl")
|
||||
save_jsonl(train, output_dir / "train.jsonl")
|
||||
save_jsonl(val, output_dir / "val.jsonl")
|
||||
save_jsonl(test, output_dir / "test.jsonl")
|
||||
|
||||
print(f"Split: train={len(train)}, val={len(val)}, test={len(test)}")
|
||||
print_dataset_stats("train", train)
|
||||
print_dataset_stats("val", val)
|
||||
print_dataset_stats("test", test)
|
||||
|
||||
print(f"\nSplit complete:")
|
||||
print(f" train : {len(train):4d} → {output_dir}/train.jsonl")
|
||||
print(f" val : {len(val):4d} → {output_dir}/val.jsonl")
|
||||
print(f" test : {len(test):4d} → {output_dir}/test.jsonl")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,109 +1,168 @@
|
||||
"""
|
||||
Evaluation script: run detection + intervention baselines and ours.
|
||||
Full evaluation script for CompanionGuard-RL.
|
||||
|
||||
Runs detection and intervention evaluations against multiple baselines:
|
||||
|
||||
Detection baselines:
|
||||
- L1a: Keyword detector
|
||||
- L1b: Regex detector
|
||||
- L1c: Combined keyword+regex
|
||||
- Ours: CompanionRiskDetector (Module B)
|
||||
|
||||
Intervention baselines:
|
||||
- Rule-based (l_risk >= 3 → REJECT)
|
||||
- Threshold policy (per-level mapping)
|
||||
- RL policy (Module C, Ours)
|
||||
|
||||
Usage:
|
||||
python scripts/evaluate.py --detector-ckpt checkpoints/detector/best.pt \
|
||||
--agent-ckpt checkpoints/intervention/final.pt \
|
||||
--test-data data/processed/test.jsonl \
|
||||
--config configs/detector_config.yaml
|
||||
python scripts/evaluate.py \\
|
||||
--detector-ckpt checkpoints/detector/best.pt \\
|
||||
--agent-ckpt checkpoints/intervention/final.pt \\
|
||||
--test-data data/processed/test.jsonl \\
|
||||
--config configs/detector_config.yaml \\
|
||||
--intervention-config configs/intervention_config.yaml
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import yaml
|
||||
import json
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Dict
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from src.data.dataset import load_jsonl, format_conversation
|
||||
from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
|
||||
from src.models.detector import CompanionRiskDetector
|
||||
from src.models.intervention_agent import InterventionAgent
|
||||
from src.utils.metrics import detection_metrics, intervention_metrics
|
||||
from src.utils.baselines import (
|
||||
KeywordDetector,
|
||||
RegexDetector,
|
||||
CombinedRuleDetector,
|
||||
rule_based_intervention,
|
||||
threshold_intervention,
|
||||
)
|
||||
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
|
||||
from src.utils.taxonomy import (
|
||||
ACTION_NAME_TO_ID,
|
||||
INTERVENTION_ACTIONS,
|
||||
NUM_RISK_LEVELS,
|
||||
NUM_PRIMARY,
|
||||
DEFAULT_ACTION_BY_LEVEL,
|
||||
PRIMARY_CATEGORY_LIST,
|
||||
FINE_GRAINED_LABELS,
|
||||
)
|
||||
|
||||
|
||||
# ── Baselines ────────────────────────────────────────────────────────────────
|
||||
|
||||
def rule_based_policy(l_risk: int) -> int:
|
||||
"""Baseline: fixed rule — l_risk >= 3 → REJECT, else PASS."""
|
||||
return 3 if l_risk >= 3 else 0
|
||||
def get_obs_dim(detector_hidden: int) -> int:
|
||||
return 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
|
||||
|
||||
|
||||
def threshold_policy(l_risk: int) -> int:
|
||||
"""Baseline: threshold mapping per risk level."""
|
||||
return DEFAULT_ACTION_BY_LEVEL[l_risk]
|
||||
# ── Detection evaluation ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── Main evaluation ──────────────────────────────────────────────────────────
|
||||
|
||||
def run_detection_eval(model, tokenizer, samples, cfg, device):
|
||||
def run_neural_detection(
|
||||
model: CompanionRiskDetector,
|
||||
tokenizer,
|
||||
samples: List[Dict],
|
||||
cfg: Dict,
|
||||
device: str,
|
||||
) -> Dict:
|
||||
"""Run the neural detector on test samples and compute metrics."""
|
||||
model.eval()
|
||||
y_true, y_pred = [], []
|
||||
l_true, l_pred = [], []
|
||||
|
||||
binary_threshold = cfg.get("evaluation", {}).get("binary_threshold", 0.5)
|
||||
|
||||
for sample in samples:
|
||||
sample = validate_and_normalize(dict(sample))
|
||||
texts = format_conversation(
|
||||
sample["persona"], sample["history"],
|
||||
sample["user_input"], sample["ai_response"],
|
||||
max_history_turns=cfg.get("data", {}).get("max_history_turns", 5),
|
||||
)
|
||||
|
||||
def enc(text, max_len):
|
||||
return tokenizer(text, max_length=max_len, truncation=True,
|
||||
padding="max_length", return_tensors="pt")
|
||||
return tokenizer(
|
||||
text, max_length=max_len, truncation=True,
|
||||
padding="max_length", return_tensors="pt",
|
||||
)
|
||||
|
||||
p_enc = enc(texts["persona_text"], 128)
|
||||
c_enc = enc(texts["context_text"], 512)
|
||||
r_enc = enc(texts["response_text"], 256)
|
||||
p_enc = enc(texts["persona_text"], cfg.get("data", {}).get("max_persona_len", 128))
|
||||
c_enc = enc(texts["context_text"], cfg.get("data", {}).get("max_context_len", 512))
|
||||
r_enc = enc(texts["response_text"], cfg.get("data", {}).get("max_response_len", 256))
|
||||
|
||||
with torch.no_grad():
|
||||
preds = model.predict(
|
||||
p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
|
||||
c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
|
||||
r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
|
||||
binary_threshold=binary_threshold,
|
||||
)
|
||||
|
||||
y_true.append(sample["y_risk"])
|
||||
y_true.append(int(sample["y_risk"]))
|
||||
y_pred.append(preds["y_risk"].item())
|
||||
l_true.append(sample["l_risk"])
|
||||
l_true.append(int(sample["l_risk"]))
|
||||
l_pred.append(preds["l_risk"].item())
|
||||
|
||||
return detection_metrics(y_true, y_pred, l_true, l_pred)
|
||||
|
||||
|
||||
def run_intervention_eval(agent, processed_samples, obs_dim, device):
|
||||
def run_keyword_detection(samples: List[Dict], detector) -> Dict:
|
||||
y_true, y_pred = [], []
|
||||
l_true, l_pred = [], []
|
||||
for s in samples:
|
||||
result = detector.detect(s.get("ai_response", ""))
|
||||
y_true.append(int(s["y_risk"]))
|
||||
y_pred.append(result["y_risk"])
|
||||
l_true.append(int(s["l_risk"]))
|
||||
l_pred.append(result["l_risk"])
|
||||
return detection_metrics(y_true, y_pred, l_true, l_pred)
|
||||
|
||||
|
||||
# ── Intervention evaluation ───────────────────────────────────────────────
|
||||
|
||||
def run_rl_intervention(
|
||||
agent: InterventionAgent,
|
||||
processed_samples: List[Dict],
|
||||
device: str,
|
||||
) -> Dict:
|
||||
agent.eval()
|
||||
y_risk_true, l_risk_true, a_pred, a_recommend = [], [], [], []
|
||||
|
||||
for s in processed_samples:
|
||||
d_score = np.array([s["d_score"]], dtype=np.float32)
|
||||
l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
|
||||
l_risk_oh[int(s["l_risk"])] = 1.0
|
||||
c_probs = np.array(s["c_primary_probs"], dtype=np.float32)
|
||||
e_H = np.array(s["e_H_pool"], dtype=np.float32)
|
||||
e_P = np.array(s["e_P_pool"], dtype=np.float32)
|
||||
t_norm = np.array([len(s.get("history", [])) / 20.0], dtype=np.float32)
|
||||
obs = torch.FloatTensor(
|
||||
np.concatenate([d_score, l_risk_oh, c_probs, e_H, e_P, t_norm])
|
||||
).unsqueeze(0).to(device)
|
||||
|
||||
obs = torch.FloatTensor(build_obs_vector(s)).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
action, _, _, _ = agent.get_action(obs, deterministic=True)
|
||||
|
||||
y_risk_true.append(s["y_risk"])
|
||||
y_risk_true.append(int(s["y_risk"]))
|
||||
l_risk_true.append(int(s["l_risk"]))
|
||||
a_pred.append(action.item())
|
||||
a_recommend.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0))
|
||||
a_recommend.append(ACTION_NAME_TO_ID.get(s.get("a_recommend", "PASS"), 0))
|
||||
|
||||
return intervention_metrics(y_risk_true, l_risk_true, a_pred, a_recommend)
|
||||
|
||||
|
||||
def run_rule_intervention(processed_samples: List[Dict], policy_fn) -> Dict:
|
||||
y_risk_true = [int(s["y_risk"]) for s in processed_samples]
|
||||
l_risk_true = [int(s["l_risk"]) for s in processed_samples]
|
||||
a_pred = [policy_fn(int(s["l_risk"])) for s in processed_samples]
|
||||
return intervention_metrics(y_risk_true, l_risk_true, a_pred)
|
||||
|
||||
|
||||
def print_metrics(name: str, metrics: Dict):
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f" {name}")
|
||||
print(f"{'=' * 50}")
|
||||
for k, v in metrics.items():
|
||||
if isinstance(v, float):
|
||||
print(f" {k:35s}: {v:.4f}")
|
||||
elif isinstance(v, list):
|
||||
formatted = [f"{x:.3f}" for x in v]
|
||||
print(f" {k:35s}: {formatted}")
|
||||
else:
|
||||
print(f" {k:35s}: {v}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--detector-ckpt", required=True)
|
||||
@@ -111,6 +170,7 @@ def main():
|
||||
parser.add_argument("--test-data", default="data/processed/test.jsonl")
|
||||
parser.add_argument("--config", default="configs/detector_config.yaml")
|
||||
parser.add_argument("--intervention-config", default="configs/intervention_config.yaml")
|
||||
parser.add_argument("--output", default="experiments/eval_results.json")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config) as f:
|
||||
@@ -119,74 +179,93 @@ def main():
|
||||
int_cfg = yaml.safe_load(f)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
|
||||
print(f"Device: {device}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
|
||||
samples = load_jsonl(args.test_data)
|
||||
print(f"Loaded {len(samples)} test samples.")
|
||||
|
||||
# Detection evaluation
|
||||
# ── Neural detector ──────────────────────────────────────────────────
|
||||
detector = CompanionRiskDetector(
|
||||
model_name=cfg["model"]["name"],
|
||||
hidden_size=cfg["model"]["hidden_size"],
|
||||
num_heads=cfg["model"]["num_heads"],
|
||||
dropout=cfg["model"]["dropout"],
|
||||
use_lora=cfg["model"]["use_lora"],
|
||||
).to(device)
|
||||
detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
|
||||
|
||||
print("\n=== Detection Evaluation ===")
|
||||
det_metrics = run_detection_eval(detector, tokenizer, samples, cfg, device)
|
||||
for k, v in det_metrics.items():
|
||||
if isinstance(v, float):
|
||||
print(f" {k}: {v:.4f}")
|
||||
if Path(args.detector_ckpt).exists():
|
||||
detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
|
||||
print(f"Detector loaded from {args.detector_ckpt}")
|
||||
else:
|
||||
print(f"[WARN] Checkpoint not found: {args.detector_ckpt}. Using random weights.")
|
||||
|
||||
# Intervention evaluation
|
||||
# ── Detection evaluation ─────────────────────────────────────────────
|
||||
all_results = {}
|
||||
|
||||
print("\nRunning: L1a — Keyword Detector")
|
||||
kw_metrics = run_keyword_detection(samples, KeywordDetector())
|
||||
print_metrics("L1a: Keyword Detector", kw_metrics)
|
||||
all_results["L1a_keyword"] = kw_metrics
|
||||
|
||||
print("\nRunning: L1b — Regex Detector")
|
||||
re_metrics = run_keyword_detection(samples, RegexDetector())
|
||||
print_metrics("L1b: Regex Detector", re_metrics)
|
||||
all_results["L1b_regex"] = re_metrics
|
||||
|
||||
print("\nRunning: L1c — Combined Keyword+Regex")
|
||||
cb_metrics = run_keyword_detection(samples, CombinedRuleDetector())
|
||||
print_metrics("L1c: Combined Detector", cb_metrics)
|
||||
all_results["L1c_combined"] = cb_metrics
|
||||
|
||||
print("\nRunning: Ours — Neural Detector (Module B)")
|
||||
neural_metrics = run_neural_detection(detector, tokenizer, samples, cfg, device)
|
||||
print_metrics("Ours: CompanionRiskDetector", neural_metrics)
|
||||
all_results["ours_detection"] = neural_metrics
|
||||
|
||||
# ── Intervention evaluation ──────────────────────────────────────────
|
||||
if args.agent_ckpt:
|
||||
from scripts.train_intervention import preprocess_samples_with_detector
|
||||
print("\n\nPreprocessing test samples with detector for RL evaluation...")
|
||||
processed = preprocess_samples_with_detector(
|
||||
samples, detector, tokenizer, device=device,
|
||||
binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
|
||||
)
|
||||
|
||||
# Build RL agent
|
||||
detector_hidden = cfg["model"]["hidden_size"]
|
||||
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
|
||||
|
||||
processed = preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device)
|
||||
|
||||
obs_dim = get_obs_dim(detector_hidden)
|
||||
agent = InterventionAgent(
|
||||
detector_hidden=detector_hidden,
|
||||
state_hidden=int_cfg["agent"]["state_hidden"],
|
||||
dropout=int_cfg["agent"]["dropout"],
|
||||
).to(device)
|
||||
agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
|
||||
|
||||
print("\n=== Intervention Evaluation: RL Policy (Ours) ===")
|
||||
int_metrics = run_intervention_eval(agent, processed, obs_dim, device)
|
||||
for k, v in int_metrics.items():
|
||||
if isinstance(v, float):
|
||||
print(f" {k}: {v:.4f}")
|
||||
elif isinstance(v, list):
|
||||
print(f" {k}: {[f'{x:.3f}' for x in v]}")
|
||||
if Path(args.agent_ckpt).exists():
|
||||
agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
|
||||
print(f"Agent loaded from {args.agent_ckpt}")
|
||||
else:
|
||||
print(f"[WARN] Agent checkpoint not found: {args.agent_ckpt}. Using random weights.")
|
||||
|
||||
print("\n=== Intervention Evaluation: Rule-based Baseline ===")
|
||||
rule_preds = [rule_based_policy(s["l_risk"]) for s in processed]
|
||||
rule_metrics = intervention_metrics(
|
||||
[s["y_risk"] for s in processed],
|
||||
[s["l_risk"] for s in processed],
|
||||
rule_preds,
|
||||
)
|
||||
for k, v in rule_metrics.items():
|
||||
if isinstance(v, float):
|
||||
print(f" {k}: {v:.4f}")
|
||||
print("\nRunning: Rule-based Intervention Baseline (l_risk >= 3 → REJECT)")
|
||||
rule_m = run_rule_intervention(processed, rule_based_intervention)
|
||||
print_metrics("Rule-based Policy", rule_m)
|
||||
all_results["baseline_rule"] = rule_m
|
||||
|
||||
print("\n=== Intervention Evaluation: Threshold Baseline ===")
|
||||
thr_preds = [threshold_policy(s["l_risk"]) for s in processed]
|
||||
thr_metrics = intervention_metrics(
|
||||
[s["y_risk"] for s in processed],
|
||||
[s["l_risk"] for s in processed],
|
||||
thr_preds,
|
||||
)
|
||||
for k, v in thr_metrics.items():
|
||||
if isinstance(v, float):
|
||||
print(f" {k}: {v:.4f}")
|
||||
print("\nRunning: Threshold Intervention Baseline")
|
||||
thr_m = run_rule_intervention(processed, threshold_intervention)
|
||||
print_metrics("Threshold Policy", thr_m)
|
||||
all_results["baseline_threshold"] = thr_m
|
||||
|
||||
# Save results
|
||||
results = {"detection": det_metrics}
|
||||
Path("experiments").mkdir(exist_ok=True)
|
||||
with open("experiments/eval_results.json", "w") as f:
|
||||
json.dump(results, f, indent=2, default=str)
|
||||
print("\nResults saved to experiments/eval_results.json")
|
||||
print("\nRunning: Ours — RL Intervention Policy (Module C)")
|
||||
rl_m = run_rl_intervention(agent, processed, device)
|
||||
print_metrics("Ours: RL Intervention Policy", rl_m)
|
||||
all_results["ours_intervention"] = rl_m
|
||||
|
||||
# ── Save results ─────────────────────────────────────────────────────
|
||||
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(all_results, f, indent=2, default=str, ensure_ascii=False)
|
||||
print(f"\nAll results saved to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,39 +1,83 @@
|
||||
"""
|
||||
Step 1: Generate companion conversation dataset using LLM.
|
||||
|
||||
Generates multi-turn conversations covering all 10 risk categories plus
|
||||
safe (benign) negative samples.
|
||||
|
||||
Usage:
|
||||
python scripts/generate_data.py --config configs/data_generation.yaml
|
||||
|
||||
Environment variables required (set before running):
|
||||
DASHSCOPE_API_KEY — if using Qwen (api.type = "qwen")
|
||||
OPENAI_API_KEY — if using OpenAI (api.type = "openai")
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from src.data.data_generator import ConversationGenerator
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate CompanionGuard-RL dataset via LLM API"
|
||||
)
|
||||
parser.add_argument("--config", default="configs/data_generation.yaml")
|
||||
parser.add_argument(
|
||||
"--total", type=int, default=None,
|
||||
help="Override total_samples from config"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe-ratio", type=float, default=None,
|
||||
help="Override safe sample ratio (default 0.25)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", type=str, default=None,
|
||||
help="Override output file path"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
# Apply CLI overrides
|
||||
total_samples = args.total or cfg["generation"]["total_samples"]
|
||||
safe_ratio = args.safe_ratio or cfg["generation"].get("safe_ratio", 0.25)
|
||||
output_file = args.output or cfg["output"]["output_file"]
|
||||
|
||||
Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check API key
|
||||
api_type = cfg["api"]["type"]
|
||||
if api_type == "qwen" and not os.environ.get("DASHSCOPE_API_KEY"):
|
||||
print("[ERROR] DASHSCOPE_API_KEY environment variable not set.")
|
||||
print("Set it with: export DASHSCOPE_API_KEY=your_api_key")
|
||||
raise SystemExit(1)
|
||||
elif api_type == "openai" and not os.environ.get("OPENAI_API_KEY"):
|
||||
print("[ERROR] OPENAI_API_KEY environment variable not set.")
|
||||
print("Set it with: export OPENAI_API_KEY=your_api_key")
|
||||
raise SystemExit(1)
|
||||
|
||||
generator = ConversationGenerator(
|
||||
api_type=cfg["api"]["type"],
|
||||
api_type=api_type,
|
||||
model=cfg["api"]["model"],
|
||||
)
|
||||
|
||||
print(f"Generating {total_samples} samples "
|
||||
f"({int(total_samples * (1 - safe_ratio))} risky + "
|
||||
f"{int(total_samples * safe_ratio)} safe)")
|
||||
print(f"Output: {output_file}")
|
||||
|
||||
count = generator.generate_dataset(
|
||||
output_path=cfg["output"]["output_file"],
|
||||
total_samples=cfg["generation"]["total_samples"],
|
||||
samples_per_category=cfg["generation"]["samples_per_category"],
|
||||
output_path=output_file,
|
||||
total_samples=total_samples,
|
||||
safe_ratio=safe_ratio,
|
||||
delay=cfg["generation"]["delay"],
|
||||
max_retries=cfg["generation"].get("max_retries", 3),
|
||||
)
|
||||
|
||||
print(f"Generated {count} samples → {cfg['output']['output_file']}")
|
||||
print(f"\nGeneration complete: {count} samples → {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,98 +6,32 @@ Two-stage training:
|
||||
Stage 2: PPO fine-tuning with multi-objective reward
|
||||
|
||||
Usage:
|
||||
python scripts/train_intervention.py --config configs/intervention_config.yaml
|
||||
python scripts/train_intervention.py --config configs/intervention_config.yaml \
|
||||
--train-data data/processed/train.jsonl
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
import wandb
|
||||
from pathlib import Path
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from src.data.dataset import load_jsonl
|
||||
from src.models.detector import CompanionRiskDetector
|
||||
from src.models.intervention_agent import InterventionAgent
|
||||
from src.rl.companion_env import CompanionEnv
|
||||
from src.rl.ppo_trainer import PPOTrainer
|
||||
from src.utils.taxonomy import (
|
||||
ACTION_NAME_TO_ID,
|
||||
NUM_RISK_LEVELS,
|
||||
NUM_PRIMARY,
|
||||
category_to_index,
|
||||
from src.utils.preprocessing import (
|
||||
preprocess_samples_with_detector,
|
||||
build_bc_tensors,
|
||||
)
|
||||
from transformers import AutoTokenizer
|
||||
from src.utils.taxonomy import NUM_RISK_LEVELS, NUM_PRIMARY
|
||||
import wandb
|
||||
|
||||
|
||||
def preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device):
|
||||
"""Run detector on all samples to extract state vectors for RL env."""
|
||||
from src.data.dataset import format_conversation
|
||||
|
||||
processed = []
|
||||
detector.eval()
|
||||
|
||||
for sample in samples:
|
||||
texts = format_conversation(
|
||||
sample["persona"],
|
||||
sample["history"],
|
||||
sample["user_input"],
|
||||
sample["ai_response"],
|
||||
)
|
||||
|
||||
def enc(text, max_len):
|
||||
return tokenizer(
|
||||
text, max_length=max_len, truncation=True,
|
||||
padding="max_length", return_tensors="pt",
|
||||
)
|
||||
|
||||
p_enc = enc(texts["persona_text"], 128)
|
||||
c_enc = enc(texts["context_text"], 512)
|
||||
r_enc = enc(texts["response_text"], 256)
|
||||
|
||||
with torch.no_grad():
|
||||
preds = detector.predict(
|
||||
p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
|
||||
c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
|
||||
r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
|
||||
)
|
||||
|
||||
# Build persona/history pool embeddings (reuse e_fused as approximation)
|
||||
e_fused = preds["e_fused"].squeeze(0).cpu().numpy()
|
||||
|
||||
processed.append({
|
||||
**sample,
|
||||
"d_score": preds["d_score"].item(),
|
||||
"l_risk": preds["l_risk"].item(),
|
||||
"c_primary_probs": preds["c_primary_probs"].squeeze(0).cpu().numpy().tolist(),
|
||||
"c_primary_idx": preds["c_primary"].item(),
|
||||
"e_H_pool": e_fused.tolist(),
|
||||
"e_P_pool": e_fused.tolist(),
|
||||
"a_recommend": sample.get("a_recommend", "PASS"),
|
||||
})
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
def build_bc_tensors(processed_samples, obs_dim, device):
|
||||
"""Build observation and expert action tensors for behavior cloning."""
|
||||
obs_list, action_list = [], []
|
||||
|
||||
for s in processed_samples:
|
||||
d_score = np.array([s["d_score"]], dtype=np.float32)
|
||||
l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
|
||||
l_risk_oh[int(s["l_risk"])] = 1.0
|
||||
c_probs = np.array(s["c_primary_probs"], dtype=np.float32)
|
||||
e_H = np.array(s["e_H_pool"], dtype=np.float32)
|
||||
e_P = np.array(s["e_P_pool"], dtype=np.float32)
|
||||
t_norm = np.array([len(s.get("history", [])) / 20.0], dtype=np.float32)
|
||||
obs = np.concatenate([d_score, l_risk_oh, c_probs, e_H, e_P, t_norm])
|
||||
obs_list.append(obs)
|
||||
action_list.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0))
|
||||
|
||||
obs_tensor = torch.FloatTensor(np.stack(obs_list)).to(device)
|
||||
action_tensor = torch.LongTensor(action_list).to(device)
|
||||
return obs_tensor, action_tensor
|
||||
def get_obs_dim(detector_hidden: int) -> int:
|
||||
"""Compute observation vector dimension."""
|
||||
return 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
|
||||
|
||||
|
||||
def main():
|
||||
@@ -110,7 +44,7 @@ def main():
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Using device: {device}")
|
||||
print(f"Device: {device}")
|
||||
|
||||
if cfg["logging"]["use_wandb"]:
|
||||
wandb.init(
|
||||
@@ -120,29 +54,46 @@ def main():
|
||||
)
|
||||
|
||||
# Load detector
|
||||
tokenizer = AutoTokenizer.from_pretrained(cfg["detector"]["model_name"])
|
||||
detector_cfg = cfg["detector"]
|
||||
tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"])
|
||||
detector = CompanionRiskDetector(
|
||||
model_name=cfg["detector"]["model_name"],
|
||||
hidden_size=cfg["detector"]["hidden_size"],
|
||||
model_name=detector_cfg["model_name"],
|
||||
hidden_size=detector_cfg["hidden_size"],
|
||||
).to(device)
|
||||
detector.load_state_dict(torch.load(cfg["detector"]["checkpoint"], map_location=device))
|
||||
detector.eval()
|
||||
print("Detector loaded.")
|
||||
|
||||
# Load and preprocess training data
|
||||
ckpt_path = detector_cfg["checkpoint"]
|
||||
if Path(ckpt_path).exists():
|
||||
detector.load_state_dict(torch.load(ckpt_path, map_location=device))
|
||||
print(f"Detector loaded from {ckpt_path}")
|
||||
else:
|
||||
print(f"[WARN] Detector checkpoint not found at {ckpt_path}. Using random weights.")
|
||||
|
||||
detector.eval()
|
||||
|
||||
# Pre-process training data through the detector
|
||||
print(f"Loading training data: {args.train_data}")
|
||||
raw_samples = load_jsonl(args.train_data)
|
||||
print(f"Preprocessing {len(raw_samples)} samples with detector...")
|
||||
processed = preprocess_samples_with_detector(raw_samples, detector, tokenizer, cfg, device)
|
||||
|
||||
detector_hidden = cfg["detector"]["hidden_size"]
|
||||
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
|
||||
processed = preprocess_samples_with_detector(
|
||||
raw_samples,
|
||||
detector,
|
||||
tokenizer,
|
||||
device=device,
|
||||
binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
|
||||
)
|
||||
|
||||
# Build RL agent
|
||||
detector_hidden = detector_cfg["hidden_size"]
|
||||
obs_dim = get_obs_dim(detector_hidden)
|
||||
print(f"Observation dimension: {obs_dim}")
|
||||
|
||||
# Build the RL agent
|
||||
agent_cfg = cfg["agent"]
|
||||
agent = InterventionAgent(
|
||||
detector_hidden=detector_hidden,
|
||||
state_hidden=cfg["agent"]["state_hidden"],
|
||||
dropout=cfg["agent"]["dropout"],
|
||||
)
|
||||
state_hidden=agent_cfg["state_hidden"],
|
||||
dropout=agent_cfg["dropout"],
|
||||
).to(device)
|
||||
|
||||
trainer = PPOTrainer(
|
||||
agent=agent,
|
||||
@@ -162,34 +113,38 @@ def main():
|
||||
)
|
||||
|
||||
# Stage 1: Behavior cloning warm-up
|
||||
if cfg["behavior_cloning"]["enabled"]:
|
||||
print("Stage 1: Behavior cloning warm-up...")
|
||||
obs_tensor, action_tensor = build_bc_tensors(processed, obs_dim, device)
|
||||
bc_cfg = cfg.get("behavior_cloning", {})
|
||||
if bc_cfg.get("enabled", True):
|
||||
print("\n=== Stage 1: Behavior Cloning Warm-up ===")
|
||||
obs_tensor, action_tensor = build_bc_tensors(processed, device=device)
|
||||
trainer.behavior_cloning_warmup(
|
||||
obs_tensor, action_tensor,
|
||||
n_epochs=cfg["behavior_cloning"]["epochs"],
|
||||
lr=cfg["behavior_cloning"]["lr"],
|
||||
obs_tensor,
|
||||
action_tensor,
|
||||
n_epochs=bc_cfg.get("epochs", 5),
|
||||
lr=bc_cfg.get("lr", 1e-3),
|
||||
)
|
||||
|
||||
# Stage 2: PPO fine-tuning
|
||||
print("Stage 2: PPO fine-tuning...")
|
||||
print("\n=== Stage 2: PPO Fine-tuning ===")
|
||||
env_cfg = cfg.get("environment", {})
|
||||
env = CompanionEnv(
|
||||
samples=processed,
|
||||
detector_hidden=detector_hidden,
|
||||
reward_weights=cfg["reward"],
|
||||
max_turns=cfg["environment"]["max_turns"],
|
||||
reward_weights=cfg.get("reward"),
|
||||
max_turns=env_cfg.get("max_turns", 20),
|
||||
)
|
||||
|
||||
Path(cfg["output"]["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
|
||||
output_cfg = cfg["output"]
|
||||
Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
trainer.train(
|
||||
env=env,
|
||||
total_timesteps=cfg["ppo"]["total_timesteps"],
|
||||
n_rollout_steps=cfg["ppo"]["n_rollout_steps"],
|
||||
checkpoint_dir=cfg["output"]["checkpoint_dir"],
|
||||
save_interval=cfg["output"]["save_interval"],
|
||||
checkpoint_dir=output_cfg["checkpoint_dir"],
|
||||
save_interval=output_cfg.get("save_interval", 10_000),
|
||||
)
|
||||
|
||||
torch.save(agent.state_dict(), f"{cfg['output']['checkpoint_dir']}/final.pt")
|
||||
print("Training complete.")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user