refactor: complete full implementation replacing all placeholder/mock content

Detection module (Module B):
- detector.py: expose separate e_P_pool and e_H_pool for RL state;
  fix compute_loss to skip primary head when c_primary="None"
- dataset.py: handle c_primary="None" safely; add validate_and_normalize

Data pipeline:
- data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe);
  systematic category→fine-label mapping; safe sample generation (25%);
  per-category risk level distribution; max_retries logic
- llm_judge.py: incremental file writing; rate limiting; retry logic;
  annotate_from_file convenience method; consistency validation
- annotate_data.py: stratified split by y_risk; dataset statistics report

RL module (Module C):
- ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple);
  fix action type passed to env.step; proper buffer reset and size tracking
- companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with
  auto-reset; correct Gymnasium interface

Shared utilities (new files):
- src/utils/preprocessing.py: preprocess_samples_with_detector using separate
  e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up
- src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b),
  CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention,
  LLMJudgePolicy for full baseline comparison

Scripts:
- train_intervention.py: use preprocessing module; separate e_H/e_P pools
- evaluate.py: proper module imports (no circular scripts import);
  full multi-baseline comparison; save all results to JSON
- generate_data.py: API key check; --total / --safe-ratio / --output CLI overrides;
  safe_ratio and max_retries (from config) passed to generate_dataset

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-09 17:50:17 +08:00
parent 7d4345c29d
commit 4a0e71fb23
13 changed files with 1829 additions and 564 deletions

View File

@@ -1,80 +1,171 @@
"""
Step 2: LLM judge pre-annotation.
Step 2: LLM judge pre-annotation and dataset split.
Reads raw generated JSONL, runs LLM annotation on each sample,
then splits into train/val/test sets with stratified sampling by y_risk.
Usage:
python scripts/annotate_data.py --input data/raw/generated.jsonl \
--output data/processed/annotated.jsonl \
--config configs/data_generation.yaml
# Annotate + split
python scripts/annotate_data.py \\
--input data/raw/generated.jsonl \\
--output data/processed/annotated.jsonl \\
--config configs/data_generation.yaml
# Skip annotation (already annotated), only split
python scripts/annotate_data.py \\
--input data/raw/generated.jsonl \\
--output data/processed/annotated.jsonl \\
--skip-annotation
"""
import argparse
import json
import yaml
import random
import yaml
from collections import Counter
from pathlib import Path
from typing import List, Dict, Tuple
from src.data.llm_judge import LLMJudge
from src.data.dataset import load_jsonl
from src.data.dataset import load_jsonl, save_jsonl, validate_and_normalize
def stratified_split(
    samples: List[Dict],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Split samples into train/val/test, stratified by the binary ``y_risk`` label.

    The risky stratum (``y_risk == 1``) and the safe stratum (everything else)
    are shuffled and cut independently with the same ratios, then merged and
    shuffled again, so each split keeps roughly the overall risky/safe balance.

    Args:
        samples: Sample dicts. A missing ``y_risk`` is read as 0. The input
            list itself is not reordered.
        train_ratio: Fraction of each stratum assigned to train.
        val_ratio: Fraction of each stratum assigned to val; the remainder of
            each stratum goes to test.
        seed: Seed passed to ``random.seed`` (the module-level generator is
            reseeded, so results are reproducible for a given seed).

    Returns:
        ``(train, val, test)`` lists that together contain every input sample
        exactly once.
    """
    random.seed(seed)

    # Partition on "is risky" rather than on two equality tests: a sample whose
    # y_risk is neither 0 nor 1 (None, 2, "1", ...) must not vanish from every
    # split. It is treated as safe, matching how the per-split statistics
    # count anything that is not y_risk == 1.
    risky: List[Dict] = []
    safe: List[Dict] = []
    for s in samples:
        (risky if s.get("y_risk", 0) == 1 else safe).append(s)

    random.shuffle(risky)
    random.shuffle(safe)

    def split_list(lst):
        # Sizes are floored, so rounding leftovers fall into the test part.
        n = len(lst)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)
        return lst[:n_train], lst[n_train:n_train + n_val], lst[n_train + n_val:]

    train_r, val_r, test_r = split_list(risky)
    train_s, val_s, test_s = split_list(safe)

    train = train_r + train_s
    val = val_r + val_s
    test = test_r + test_s

    # Merge order would otherwise put all risky samples first in each split.
    random.shuffle(train)
    random.shuffle(val)
    random.shuffle(test)
    return train, val, test
def print_dataset_stats(name: str, samples: List[Dict]):
    """
    Print per-split statistics.

    Reports the sample count, risky/safe balance, the distribution of
    ``l_risk`` and ``a_recommend`` values, and the five most common
    ``c_primary`` categories.

    Args:
        name: Label printed in the header line (e.g. the split name).
        samples: Sample dicts; missing ``y_risk``/``l_risk`` are read as 0 and
            a missing ``a_recommend`` as "PASS". May be empty.
    """
    total = len(samples)
    n_risky = sum(1 for s in samples if s.get("y_risk", 0) == 1)
    n_safe = total - n_risky

    def pct(count: int) -> float:
        # An empty split (possible for small datasets) must not divide by zero.
        return count / total * 100 if total else 0.0

    level_counts = Counter(s.get("l_risk", 0) for s in samples)
    action_counts = Counter(s.get("a_recommend", "PASS") for s in samples)
    # Samples with no category (key absent, null, or the "None" marker) are
    # excluded rather than being tallied under a "None" pseudo-category.
    category_counts = Counter(
        s["c_primary"] for s in samples if s.get("c_primary") not in (None, "None")
    )

    levels = {k: level_counts[k] for k in sorted(level_counts)}
    print(f"\n [{name}] {total} samples "
          f"(risky={n_risky} {pct(n_risky):.0f}%, "
          f"safe={n_safe} {pct(n_safe):.0f}%)")
    print(f" Risk levels: {levels}")
    print(f" Actions: {dict(action_counts)}")
    if category_counts:
        top5 = dict(category_counts.most_common(5))
        print(f" Top categories: {top5}")


def save_jsonl(samples, path):
    """Write samples to ``path`` as UTF-8 JSON lines, creating parent directories."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for s in samples:
            f.write(json.dumps(s, ensure_ascii=False) + "\n")
    print(f"Saved {len(samples)} samples → {path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
parser.add_argument("--output", default="data/processed/annotated.jsonl")
parser = argparse.ArgumentParser(
description="Annotate and split CompanionGuard-RL dataset"
)
parser.add_argument("--input", required=True, help="Input JSONL file path")
parser.add_argument(
"--output", default="data/processed/annotated.jsonl",
help="Output annotated JSONL path"
)
parser.add_argument("--config", default="configs/data_generation.yaml")
parser.add_argument("--skip-annotation", action="store_true",
help="Skip LLM annotation (use existing labels)")
parser.add_argument(
"--skip-annotation", action="store_true",
help="Skip LLM annotation (use labels already in input file)"
)
parser.add_argument(
"--split-only", action="store_true",
help="Only perform dataset split (no annotation), reads from --output"
)
args = parser.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
samples = load_jsonl(args.input)
print(f"Loaded {len(samples)} samples from {args.input}")
split_cfg = cfg.get("split", {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42})
ann_cfg = cfg.get("annotation", {})
output_dir = Path(args.output).parent
if not args.skip_annotation:
# ── Step 1: Annotation ───────────────────────────────────────────────
if args.split_only:
samples = load_jsonl(args.output)
print(f"Loaded {len(samples)} pre-annotated samples from {args.output}")
elif args.skip_annotation:
samples = load_jsonl(args.input)
samples = [validate_and_normalize(s) for s in samples]
save_jsonl(samples, args.output)
print(f"Loaded and normalized {len(samples)} samples (annotation skipped)")
else:
judge = LLMJudge(
api_type=cfg["api"]["type"],
model=cfg["annotation"]["judge_model"],
model=ann_cfg.get("judge_model", cfg["api"]["model"]),
)
samples = judge.annotate_from_file(
input_path=args.input,
output_path=args.output,
delay=ann_cfg.get("delay", 0.3),
max_retries=ann_cfg.get("max_retries", 2),
)
samples = judge.annotate_batch(samples, output_path=args.output)
else:
save_jsonl(samples, args.output)
split_cfg = cfg.get("split", {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42})
train, val, test = split_dataset(
if not samples:
print("[ERROR] No samples to process after annotation. Check logs above.")
raise SystemExit(1)
# ── Step 2: Dataset statistics ───────────────────────────────────────
print(f"\n{'='*50}")
print("Dataset statistics:")
print_dataset_stats("ALL", samples)
# ── Step 3: Stratified split ─────────────────────────────────────────
train, val, test = stratified_split(
samples,
train_ratio=split_cfg["train"],
val_ratio=split_cfg["val"],
seed=split_cfg.get("seed", 42),
)
base = Path(args.output).parent
save_jsonl(train, base / "train.jsonl")
save_jsonl(val, base / "val.jsonl")
save_jsonl(test, base / "test.jsonl")
save_jsonl(train, output_dir / "train.jsonl")
save_jsonl(val, output_dir / "val.jsonl")
save_jsonl(test, output_dir / "test.jsonl")
print(f"Split: train={len(train)}, val={len(val)}, test={len(test)}")
print_dataset_stats("train", train)
print_dataset_stats("val", val)
print_dataset_stats("test", test)
print(f"\nSplit complete:")
print(f" train : {len(train):4d}{output_dir}/train.jsonl")
print(f" val : {len(val):4d}{output_dir}/val.jsonl")
print(f" test : {len(test):4d}{output_dir}/test.jsonl")
if __name__ == "__main__":

View File

@@ -1,109 +1,168 @@
"""
Evaluation script: run detection + intervention baselines and ours.
Full evaluation script for CompanionGuard-RL.
Runs detection and intervention evaluations against multiple baselines:
Detection baselines:
- L1a: Keyword detector
- L1b: Regex detector
- L1c: Combined keyword+regex
- Ours: CompanionRiskDetector (Module B)
Intervention baselines:
- Rule-based (l_risk >= 3 → REJECT)
- Threshold policy (per-level mapping)
- RL policy (Module C, Ours)
Usage:
python scripts/evaluate.py --detector-ckpt checkpoints/detector/best.pt \
--agent-ckpt checkpoints/intervention/final.pt \
--test-data data/processed/test.jsonl \
--config configs/detector_config.yaml
python scripts/evaluate.py \\
--detector-ckpt checkpoints/detector/best.pt \\
--agent-ckpt checkpoints/intervention/final.pt \\
--test-data data/processed/test.jsonl \\
--config configs/detector_config.yaml \\
--intervention-config configs/intervention_config.yaml
"""
import argparse
import yaml
import json
import yaml
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict
from transformers import AutoTokenizer
from src.data.dataset import load_jsonl, format_conversation
from src.data.dataset import load_jsonl, format_conversation, validate_and_normalize
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.utils.metrics import detection_metrics, intervention_metrics
from src.utils.baselines import (
KeywordDetector,
RegexDetector,
CombinedRuleDetector,
rule_based_intervention,
threshold_intervention,
)
from src.utils.preprocessing import preprocess_samples_with_detector, build_obs_vector
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
INTERVENTION_ACTIONS,
NUM_RISK_LEVELS,
NUM_PRIMARY,
DEFAULT_ACTION_BY_LEVEL,
PRIMARY_CATEGORY_LIST,
FINE_GRAINED_LABELS,
)
# ── Baselines ────────────────────────────────────────────────────────────────
def rule_based_policy(l_risk: int) -> int:
    """Fixed-rule baseline: action id 3 (REJECT) when l_risk >= 3, otherwise 0 (PASS)."""
    if l_risk >= 3:
        return 3
    return 0
def get_obs_dim(detector_hidden: int) -> int:
    """Return the observation-vector length for a detector of the given hidden size.

    The total is two scalar slots, NUM_RISK_LEVELS, NUM_PRIMARY and two blocks
    of ``detector_hidden`` values.
    """
    scalar_slots = 1 + 1
    embedding_slots = detector_hidden * 2
    return scalar_slots + NUM_RISK_LEVELS + NUM_PRIMARY + embedding_slots
def threshold_policy(l_risk: int) -> int:
    """Threshold baseline: look up the action for ``l_risk`` in DEFAULT_ACTION_BY_LEVEL."""
    action = DEFAULT_ACTION_BY_LEVEL[l_risk]
    return action
# ── Detection evaluation ──────────────────────────────────────────────────
# ── Main evaluation ──────────────────────────────────────────────────────────
def run_detection_eval(model, tokenizer, samples, cfg, device):
def run_neural_detection(
model: CompanionRiskDetector,
tokenizer,
samples: List[Dict],
cfg: Dict,
device: str,
) -> Dict:
"""Run the neural detector on test samples and compute metrics."""
model.eval()
y_true, y_pred = [], []
l_true, l_pred = [], []
binary_threshold = cfg.get("evaluation", {}).get("binary_threshold", 0.5)
for sample in samples:
sample = validate_and_normalize(dict(sample))
texts = format_conversation(
sample["persona"], sample["history"],
sample["user_input"], sample["ai_response"],
max_history_turns=cfg.get("data", {}).get("max_history_turns", 5),
)
def enc(text, max_len):
return tokenizer(text, max_length=max_len, truncation=True,
padding="max_length", return_tensors="pt")
return tokenizer(
text, max_length=max_len, truncation=True,
padding="max_length", return_tensors="pt",
)
p_enc = enc(texts["persona_text"], 128)
c_enc = enc(texts["context_text"], 512)
r_enc = enc(texts["response_text"], 256)
p_enc = enc(texts["persona_text"], cfg.get("data", {}).get("max_persona_len", 128))
c_enc = enc(texts["context_text"], cfg.get("data", {}).get("max_context_len", 512))
r_enc = enc(texts["response_text"], cfg.get("data", {}).get("max_response_len", 256))
with torch.no_grad():
preds = model.predict(
p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
binary_threshold=binary_threshold,
)
y_true.append(sample["y_risk"])
y_true.append(int(sample["y_risk"]))
y_pred.append(preds["y_risk"].item())
l_true.append(sample["l_risk"])
l_true.append(int(sample["l_risk"]))
l_pred.append(preds["l_risk"].item())
return detection_metrics(y_true, y_pred, l_true, l_pred)
def run_intervention_eval(agent, processed_samples, obs_dim, device):
def run_keyword_detection(samples: List[Dict], detector) -> Dict:
    """Evaluate a rule-style detector on the samples' AI responses.

    Args:
        samples: Sample dicts carrying ``y_risk`` and ``l_risk`` labels; a
            missing ``ai_response`` is passed to the detector as "".
        detector: Object whose ``detect(text)`` returns a mapping with
            ``y_risk`` and ``l_risk`` entries.

    Returns:
        Whatever ``detection_metrics`` returns for the collected labels and
        predictions.
    """
    # One (gold binary, predicted binary, gold level, predicted level) row per sample.
    rows = []
    for sample in samples:
        verdict = detector.detect(sample.get("ai_response", ""))
        rows.append((
            int(sample["y_risk"]),
            verdict["y_risk"],
            int(sample["l_risk"]),
            verdict["l_risk"],
        ))

    # Transpose the rows back into the four parallel lists the metric expects.
    y_true, y_pred, l_true, l_pred = (
        [row[col] for row in rows] for col in range(4)
    )
    return detection_metrics(y_true, y_pred, l_true, l_pred)
# ── Intervention evaluation ───────────────────────────────────────────────
def run_rl_intervention(
agent: InterventionAgent,
processed_samples: List[Dict],
device: str,
) -> Dict:
agent.eval()
y_risk_true, l_risk_true, a_pred, a_recommend = [], [], [], []
for s in processed_samples:
d_score = np.array([s["d_score"]], dtype=np.float32)
l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
l_risk_oh[int(s["l_risk"])] = 1.0
c_probs = np.array(s["c_primary_probs"], dtype=np.float32)
e_H = np.array(s["e_H_pool"], dtype=np.float32)
e_P = np.array(s["e_P_pool"], dtype=np.float32)
t_norm = np.array([len(s.get("history", [])) / 20.0], dtype=np.float32)
obs = torch.FloatTensor(
np.concatenate([d_score, l_risk_oh, c_probs, e_H, e_P, t_norm])
).unsqueeze(0).to(device)
obs = torch.FloatTensor(build_obs_vector(s)).unsqueeze(0).to(device)
with torch.no_grad():
action, _, _, _ = agent.get_action(obs, deterministic=True)
y_risk_true.append(s["y_risk"])
y_risk_true.append(int(s["y_risk"]))
l_risk_true.append(int(s["l_risk"]))
a_pred.append(action.item())
a_recommend.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0))
a_recommend.append(ACTION_NAME_TO_ID.get(s.get("a_recommend", "PASS"), 0))
return intervention_metrics(y_risk_true, l_risk_true, a_pred, a_recommend)
def run_rule_intervention(processed_samples: List[Dict], policy_fn) -> Dict:
    """Evaluate a level-based intervention policy.

    Args:
        processed_samples: Sample dicts carrying ``y_risk`` and ``l_risk``.
        policy_fn: Callable applied to each integer ``l_risk``; its results
            are passed on as the predicted actions.

    Returns:
        Whatever ``intervention_metrics`` returns for the labels and actions.
    """
    y_risk_true = []
    for sample in processed_samples:
        y_risk_true.append(int(sample["y_risk"]))

    risk_levels = [int(sample["l_risk"]) for sample in processed_samples]
    chosen_actions = list(map(policy_fn, risk_levels))
    return intervention_metrics(y_risk_true, risk_levels, chosen_actions)
def print_metrics(name: str, metrics: Dict):
    """Pretty-print a metrics mapping under a banner.

    Floats are shown with 4 decimals, list/tuple values element-wise with 3
    decimals, and anything else with its default string form.

    Args:
        name: Title printed between the banner rules.
        metrics: Mapping of metric name to value.
    """

    def _fmt_item(x) -> str:
        # Elements that do not support the "f" format code (nested lists such
        # as matrix rows, strings, None) fall back to str() instead of
        # aborting the whole report.
        try:
            return f"{x:.3f}"
        except (TypeError, ValueError):
            return str(x)

    print(f"\n{'=' * 50}")
    print(f" {name}")
    print(f"{'=' * 50}")
    for k, v in metrics.items():
        label = str(k)  # non-string keys cannot take the "35s" format spec
        if isinstance(v, float):
            print(f" {label:35s}: {v:.4f}")
        elif isinstance(v, (list, tuple)):
            formatted = [_fmt_item(x) for x in v]
            print(f" {label:35s}: {formatted}")
        else:
            print(f" {label:35s}: {v}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--detector-ckpt", required=True)
@@ -111,6 +170,7 @@ def main():
parser.add_argument("--test-data", default="data/processed/test.jsonl")
parser.add_argument("--config", default="configs/detector_config.yaml")
parser.add_argument("--intervention-config", default="configs/intervention_config.yaml")
parser.add_argument("--output", default="experiments/eval_results.json")
args = parser.parse_args()
with open(args.config) as f:
@@ -119,74 +179,93 @@ def main():
int_cfg = yaml.safe_load(f)
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
print(f"Device: {device}")
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
samples = load_jsonl(args.test_data)
print(f"Loaded {len(samples)} test samples.")
# Detection evaluation
# ── Neural detector ──────────────────────────────────────────────────
detector = CompanionRiskDetector(
model_name=cfg["model"]["name"],
hidden_size=cfg["model"]["hidden_size"],
num_heads=cfg["model"]["num_heads"],
dropout=cfg["model"]["dropout"],
use_lora=cfg["model"]["use_lora"],
).to(device)
detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
print("\n=== Detection Evaluation ===")
det_metrics = run_detection_eval(detector, tokenizer, samples, cfg, device)
for k, v in det_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
if Path(args.detector_ckpt).exists():
detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
print(f"Detector loaded from {args.detector_ckpt}")
else:
print(f"[WARN] Checkpoint not found: {args.detector_ckpt}. Using random weights.")
# Intervention evaluation
# ── Detection evaluation ─────────────────────────────────────────────
all_results = {}
print("\nRunning: L1a — Keyword Detector")
kw_metrics = run_keyword_detection(samples, KeywordDetector())
print_metrics("L1a: Keyword Detector", kw_metrics)
all_results["L1a_keyword"] = kw_metrics
print("\nRunning: L1b — Regex Detector")
re_metrics = run_keyword_detection(samples, RegexDetector())
print_metrics("L1b: Regex Detector", re_metrics)
all_results["L1b_regex"] = re_metrics
print("\nRunning: L1c — Combined Keyword+Regex")
cb_metrics = run_keyword_detection(samples, CombinedRuleDetector())
print_metrics("L1c: Combined Detector", cb_metrics)
all_results["L1c_combined"] = cb_metrics
print("\nRunning: Ours — Neural Detector (Module B)")
neural_metrics = run_neural_detection(detector, tokenizer, samples, cfg, device)
print_metrics("Ours: CompanionRiskDetector", neural_metrics)
all_results["ours_detection"] = neural_metrics
# ── Intervention evaluation ──────────────────────────────────────────
if args.agent_ckpt:
from scripts.train_intervention import preprocess_samples_with_detector
print("\n\nPreprocessing test samples with detector for RL evaluation...")
processed = preprocess_samples_with_detector(
samples, detector, tokenizer, device=device,
binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
)
# Build RL agent
detector_hidden = cfg["model"]["hidden_size"]
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
processed = preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device)
obs_dim = get_obs_dim(detector_hidden)
agent = InterventionAgent(
detector_hidden=detector_hidden,
state_hidden=int_cfg["agent"]["state_hidden"],
dropout=int_cfg["agent"]["dropout"],
).to(device)
agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
print("\n=== Intervention Evaluation: RL Policy (Ours) ===")
int_metrics = run_intervention_eval(agent, processed, obs_dim, device)
for k, v in int_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
elif isinstance(v, list):
print(f" {k}: {[f'{x:.3f}' for x in v]}")
if Path(args.agent_ckpt).exists():
agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
print(f"Agent loaded from {args.agent_ckpt}")
else:
print(f"[WARN] Agent checkpoint not found: {args.agent_ckpt}. Using random weights.")
print("\n=== Intervention Evaluation: Rule-based Baseline ===")
rule_preds = [rule_based_policy(s["l_risk"]) for s in processed]
rule_metrics = intervention_metrics(
[s["y_risk"] for s in processed],
[s["l_risk"] for s in processed],
rule_preds,
)
for k, v in rule_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
print("\nRunning: Rule-based Intervention Baseline (l_risk >= 3 → REJECT)")
rule_m = run_rule_intervention(processed, rule_based_intervention)
print_metrics("Rule-based Policy", rule_m)
all_results["baseline_rule"] = rule_m
print("\n=== Intervention Evaluation: Threshold Baseline ===")
thr_preds = [threshold_policy(s["l_risk"]) for s in processed]
thr_metrics = intervention_metrics(
[s["y_risk"] for s in processed],
[s["l_risk"] for s in processed],
thr_preds,
)
for k, v in thr_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
print("\nRunning: Threshold Intervention Baseline")
thr_m = run_rule_intervention(processed, threshold_intervention)
print_metrics("Threshold Policy", thr_m)
all_results["baseline_threshold"] = thr_m
# Save results
results = {"detection": det_metrics}
Path("experiments").mkdir(exist_ok=True)
with open("experiments/eval_results.json", "w") as f:
json.dump(results, f, indent=2, default=str)
print("\nResults saved to experiments/eval_results.json")
print("\nRunning: Ours — RL Intervention Policy (Module C)")
rl_m = run_rl_intervention(agent, processed, device)
print_metrics("Ours: RL Intervention Policy", rl_m)
all_results["ours_intervention"] = rl_m
# ── Save results ─────────────────────────────────────────────────────
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
with open(args.output, "w", encoding="utf-8") as f:
json.dump(all_results, f, indent=2, default=str, ensure_ascii=False)
print(f"\nAll results saved to {args.output}")
if __name__ == "__main__":

View File

@@ -1,39 +1,83 @@
"""
Step 1: Generate companion conversation dataset using LLM.
Generates multi-turn conversations covering all 10 risk categories plus
safe (benign) negative samples.
Usage:
python scripts/generate_data.py --config configs/data_generation.yaml
Environment variables required (set before running):
DASHSCOPE_API_KEY — if using Qwen (api.type = "qwen")
OPENAI_API_KEY — if using OpenAI (api.type = "openai")
"""
import argparse
import os
import yaml
from pathlib import Path
from src.data.data_generator import ConversationGenerator
def main():
    """Generate the dataset described by the YAML config via the LLM API.

    Command-line options ``--total``, ``--safe-ratio`` and ``--output``
    override the corresponding config values when given. The process exits
    with status 1 if the API key environment variable required by
    ``api.type`` ("qwen" -> DASHSCOPE_API_KEY, "openai" -> OPENAI_API_KEY)
    is not set.
    """
    parser = argparse.ArgumentParser(
        description="Generate CompanionGuard-RL dataset via LLM API"
    )
    parser.add_argument("--config", default="configs/data_generation.yaml")
    parser.add_argument(
        "--total", type=int, default=None,
        help="Override total_samples from config"
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=None,
        help="Override safe sample ratio (default 0.25)"
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Override output file path"
    )
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    gen_cfg = cfg["generation"]

    # Apply CLI overrides. Compare against None instead of relying on
    # truthiness: "--safe-ratio 0" and "--total 0" are legitimate values that
    # an `or` fallback would silently replace with the config value.
    total_samples = args.total if args.total is not None else gen_cfg["total_samples"]
    safe_ratio = (
        args.safe_ratio if args.safe_ratio is not None
        else gen_cfg.get("safe_ratio", 0.25)
    )
    output_file = args.output or cfg["output"]["output_file"]

    if total_samples < 0:
        parser.error(f"total sample count must be >= 0, got {total_samples}")
    if not 0.0 <= safe_ratio <= 1.0:
        parser.error(f"safe ratio must be within [0, 1], got {safe_ratio}")

    Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)
    # An overridden --output may point outside raw_dir; make sure its parent exists.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    # Check API key
    api_type = cfg["api"]["type"]
    if api_type == "qwen" and not os.environ.get("DASHSCOPE_API_KEY"):
        print("[ERROR] DASHSCOPE_API_KEY environment variable not set.")
        print("Set it with: export DASHSCOPE_API_KEY=your_api_key")
        raise SystemExit(1)
    elif api_type == "openai" and not os.environ.get("OPENAI_API_KEY"):
        print("[ERROR] OPENAI_API_KEY environment variable not set.")
        print("Set it with: export OPENAI_API_KEY=your_api_key")
        raise SystemExit(1)

    generator = ConversationGenerator(
        api_type=api_type,
        model=cfg["api"]["model"],
    )

    print(f"Generating {total_samples} samples "
          f"({int(total_samples * (1 - safe_ratio))} risky + "
          f"{int(total_samples * safe_ratio)} safe)")
    print(f"Output: {output_file}")

    count = generator.generate_dataset(
        output_path=output_file,
        total_samples=total_samples,
        safe_ratio=safe_ratio,
        delay=gen_cfg["delay"],
        max_retries=gen_cfg.get("max_retries", 3),
    )
    print(f"\nGeneration complete: {count} samples → {output_file}")
if __name__ == "__main__":

View File

@@ -6,98 +6,32 @@ Two-stage training:
Stage 2: PPO fine-tuning with multi-objective reward
Usage:
python scripts/train_intervention.py --config configs/intervention_config.yaml
python scripts/train_intervention.py --config configs/intervention_config.yaml \
--train-data data/processed/train.jsonl
"""
import argparse
import yaml
import torch
import numpy as np
import wandb
from pathlib import Path
from transformers import AutoTokenizer
from src.data.dataset import load_jsonl
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.rl.companion_env import CompanionEnv
from src.rl.ppo_trainer import PPOTrainer
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
NUM_RISK_LEVELS,
NUM_PRIMARY,
category_to_index,
from src.utils.preprocessing import (
preprocess_samples_with_detector,
build_bc_tensors,
)
from transformers import AutoTokenizer
from src.utils.taxonomy import NUM_RISK_LEVELS, NUM_PRIMARY
import wandb
def preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device):
    """Run the detector over every sample and attach its outputs as RL state features.

    Each returned dict is the input sample extended (or overridden) with
    ``d_score``, ``l_risk``, ``c_primary_probs``, ``c_primary_idx``,
    ``e_H_pool``, ``e_P_pool`` and ``a_recommend`` (defaulting to "PASS").
    ``cfg`` is accepted but not read here.
    """
    from src.data.dataset import format_conversation

    # (key in the formatted-text dict, max token length), in the order the
    # encodings are handed to detector.predict.
    segments = (
        ("persona_text", 128),
        ("context_text", 512),
        ("response_text", 256),
    )

    def tokenize(text, max_len):
        return tokenizer(
            text, max_length=max_len, truncation=True,
            padding="max_length", return_tensors="pt",
        )

    detector.eval()
    processed = []
    for sample in samples:
        texts = format_conversation(
            sample["persona"],
            sample["history"],
            sample["user_input"],
            sample["ai_response"],
        )
        encodings = [tokenize(texts[key], max_len) for key, max_len in segments]

        with torch.no_grad():
            model_inputs = []
            for encoded in encodings:
                model_inputs.append(encoded["input_ids"].to(device))
                model_inputs.append(encoded["attention_mask"].to(device))
            preds = detector.predict(*model_inputs)

        # Persona/history pool embeddings both reuse e_fused as an approximation.
        fused = preds["e_fused"].squeeze(0).cpu().numpy()
        detector_features = {
            "d_score": preds["d_score"].item(),
            "l_risk": preds["l_risk"].item(),
            "c_primary_probs": preds["c_primary_probs"].squeeze(0).cpu().numpy().tolist(),
            "c_primary_idx": preds["c_primary"].item(),
            "e_H_pool": fused.tolist(),
            "e_P_pool": fused.tolist(),
            "a_recommend": sample.get("a_recommend", "PASS"),
        }
        processed.append({**sample, **detector_features})
    return processed
def build_bc_tensors(processed_samples, obs_dim, device):
    """Assemble observation and expert-action tensors for behavior cloning.

    Each observation concatenates, in order: ``d_score``, a one-hot of
    ``l_risk`` over NUM_RISK_LEVELS, ``c_primary_probs``, ``e_H_pool``,
    ``e_P_pool`` and the history length divided by 20. The expert action is
    the id of ``a_recommend`` (0 when the name is unknown). ``obs_dim`` is
    accepted but not used.

    Returns:
        ``(obs_tensor, action_tensor)`` as a FloatTensor and a LongTensor on
        ``device``.
    """

    def as_f32(values):
        return np.array(values, dtype=np.float32)

    observations = []
    expert_actions = []
    for sample in processed_samples:
        score = as_f32([sample["d_score"]])
        level_onehot = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
        level_onehot[int(sample["l_risk"])] = 1.0
        parts = [
            score,
            level_onehot,
            as_f32(sample["c_primary_probs"]),
            as_f32(sample["e_H_pool"]),
            as_f32(sample["e_P_pool"]),
            as_f32([len(sample.get("history", [])) / 20.0]),
        ]
        observations.append(np.concatenate(parts))
        expert_actions.append(ACTION_NAME_TO_ID.get(sample["a_recommend"], 0))

    stacked = np.stack(observations)
    return (
        torch.FloatTensor(stacked).to(device),
        torch.LongTensor(expert_actions).to(device),
    )
def get_obs_dim(detector_hidden: int) -> int:
    """Return the observation-vector dimension for the given detector hidden size."""
    # Two scalar slots plus two blocks of detector_hidden values, alongside
    # the NUM_RISK_LEVELS and NUM_PRIMARY blocks.
    fixed_slots = 1 + 1
    embedding_slots = detector_hidden * 2
    return fixed_slots + NUM_RISK_LEVELS + NUM_PRIMARY + embedding_slots
def main():
@@ -110,7 +44,7 @@ def main():
cfg = yaml.safe_load(f)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"Device: {device}")
if cfg["logging"]["use_wandb"]:
wandb.init(
@@ -120,29 +54,46 @@ def main():
)
# Load detector
tokenizer = AutoTokenizer.from_pretrained(cfg["detector"]["model_name"])
detector_cfg = cfg["detector"]
tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"])
detector = CompanionRiskDetector(
model_name=cfg["detector"]["model_name"],
hidden_size=cfg["detector"]["hidden_size"],
model_name=detector_cfg["model_name"],
hidden_size=detector_cfg["hidden_size"],
).to(device)
detector.load_state_dict(torch.load(cfg["detector"]["checkpoint"], map_location=device))
detector.eval()
print("Detector loaded.")
# Load and preprocess training data
ckpt_path = detector_cfg["checkpoint"]
if Path(ckpt_path).exists():
detector.load_state_dict(torch.load(ckpt_path, map_location=device))
print(f"Detector loaded from {ckpt_path}")
else:
print(f"[WARN] Detector checkpoint not found at {ckpt_path}. Using random weights.")
detector.eval()
# Pre-process training data through the detector
print(f"Loading training data: {args.train_data}")
raw_samples = load_jsonl(args.train_data)
print(f"Preprocessing {len(raw_samples)} samples with detector...")
processed = preprocess_samples_with_detector(raw_samples, detector, tokenizer, cfg, device)
detector_hidden = cfg["detector"]["hidden_size"]
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
processed = preprocess_samples_with_detector(
raw_samples,
detector,
tokenizer,
device=device,
binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
)
# Build RL agent
detector_hidden = detector_cfg["hidden_size"]
obs_dim = get_obs_dim(detector_hidden)
print(f"Observation dimension: {obs_dim}")
# Build the RL agent
agent_cfg = cfg["agent"]
agent = InterventionAgent(
detector_hidden=detector_hidden,
state_hidden=cfg["agent"]["state_hidden"],
dropout=cfg["agent"]["dropout"],
)
state_hidden=agent_cfg["state_hidden"],
dropout=agent_cfg["dropout"],
).to(device)
trainer = PPOTrainer(
agent=agent,
@@ -162,34 +113,38 @@ def main():
)
# Stage 1: Behavior cloning warm-up
if cfg["behavior_cloning"]["enabled"]:
print("Stage 1: Behavior cloning warm-up...")
obs_tensor, action_tensor = build_bc_tensors(processed, obs_dim, device)
bc_cfg = cfg.get("behavior_cloning", {})
if bc_cfg.get("enabled", True):
print("\n=== Stage 1: Behavior Cloning Warm-up ===")
obs_tensor, action_tensor = build_bc_tensors(processed, device=device)
trainer.behavior_cloning_warmup(
obs_tensor, action_tensor,
n_epochs=cfg["behavior_cloning"]["epochs"],
lr=cfg["behavior_cloning"]["lr"],
obs_tensor,
action_tensor,
n_epochs=bc_cfg.get("epochs", 5),
lr=bc_cfg.get("lr", 1e-3),
)
# Stage 2: PPO fine-tuning
print("Stage 2: PPO fine-tuning...")
print("\n=== Stage 2: PPO Fine-tuning ===")
env_cfg = cfg.get("environment", {})
env = CompanionEnv(
samples=processed,
detector_hidden=detector_hidden,
reward_weights=cfg["reward"],
max_turns=cfg["environment"]["max_turns"],
reward_weights=cfg.get("reward"),
max_turns=env_cfg.get("max_turns", 20),
)
Path(cfg["output"]["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
output_cfg = cfg["output"]
Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
trainer.train(
env=env,
total_timesteps=cfg["ppo"]["total_timesteps"],
n_rollout_steps=cfg["ppo"]["n_rollout_steps"],
checkpoint_dir=cfg["output"]["checkpoint_dir"],
save_interval=cfg["output"]["save_interval"],
checkpoint_dir=output_cfg["checkpoint_dir"],
save_interval=output_cfg.get("save_interval", 10_000),
)
torch.save(agent.state_dict(), f"{cfg['output']['checkpoint_dir']}/final.pt")
print("Training complete.")