refactor: complete full implementation replacing all placeholder/mock content

Detection module (Module B):
- detector.py: expose separate e_P_pool and e_H_pool for RL state;
  fix compute_loss to skip primary head when c_primary="None"
- dataset.py: handle c_primary="None" safely; add validate_and_normalize

Data pipeline:
- data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe);
  systematic category→fine-label mapping; safe sample generation (25%);
  per-category risk level distribution; max_retries logic
- llm_judge.py: incremental file writing; rate limiting; retry logic;
  annotate_from_file convenience method; consistency validation
- annotate_data.py: stratified split by y_risk; dataset statistics report

RL module (Module C):
- ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple);
  fix action type passed to env.step; proper buffer reset and size tracking
- companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with
  auto-reset; correct Gymnasium interface

Shared utilities (new files):
- src/utils/preprocessing.py: preprocess_samples_with_detector using separate
  e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up
- src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b),
  CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention,
  LLMJudgePolicy for full baseline comparison

Scripts:
- train_intervention.py: use preprocessing module; separate e_H/e_P pools
- evaluate.py: proper module imports (no circular scripts import);
  full multi-baseline comparison; save all results to JSON
- generate_data.py: API key check; safe_ratio + max_retries CLI args

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-09 17:50:17 +08:00
parent 7d4345c29d
commit 4a0e71fb23
13 changed files with 1829 additions and 564 deletions

View File

@@ -6,98 +6,32 @@ Two-stage training:
Stage 2: PPO fine-tuning with multi-objective reward
Usage:
python scripts/train_intervention.py --config configs/intervention_config.yaml
python scripts/train_intervention.py --config configs/intervention_config.yaml \
--train-data data/processed/train.jsonl
"""
import argparse
import yaml
import torch
import numpy as np
import wandb
from pathlib import Path
from transformers import AutoTokenizer
from src.data.dataset import load_jsonl
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.rl.companion_env import CompanionEnv
from src.rl.ppo_trainer import PPOTrainer
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
NUM_RISK_LEVELS,
NUM_PRIMARY,
category_to_index,
from src.utils.preprocessing import (
preprocess_samples_with_detector,
build_bc_tensors,
)
from transformers import AutoTokenizer
from src.utils.taxonomy import NUM_RISK_LEVELS, NUM_PRIMARY
import wandb
def preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device):
"""Run detector on all samples to extract state vectors for RL env."""
from src.data.dataset import format_conversation
processed = []
detector.eval()
for sample in samples:
texts = format_conversation(
sample["persona"],
sample["history"],
sample["user_input"],
sample["ai_response"],
)
def enc(text, max_len):
return tokenizer(
text, max_length=max_len, truncation=True,
padding="max_length", return_tensors="pt",
)
p_enc = enc(texts["persona_text"], 128)
c_enc = enc(texts["context_text"], 512)
r_enc = enc(texts["response_text"], 256)
with torch.no_grad():
preds = detector.predict(
p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
)
# Build persona/history pool embeddings (reuse e_fused as approximation)
e_fused = preds["e_fused"].squeeze(0).cpu().numpy()
processed.append({
**sample,
"d_score": preds["d_score"].item(),
"l_risk": preds["l_risk"].item(),
"c_primary_probs": preds["c_primary_probs"].squeeze(0).cpu().numpy().tolist(),
"c_primary_idx": preds["c_primary"].item(),
"e_H_pool": e_fused.tolist(),
"e_P_pool": e_fused.tolist(),
"a_recommend": sample.get("a_recommend", "PASS"),
})
return processed
def build_bc_tensors(processed_samples, obs_dim, device):
"""Build observation and expert action tensors for behavior cloning."""
obs_list, action_list = [], []
for s in processed_samples:
d_score = np.array([s["d_score"]], dtype=np.float32)
l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
l_risk_oh[int(s["l_risk"])] = 1.0
c_probs = np.array(s["c_primary_probs"], dtype=np.float32)
e_H = np.array(s["e_H_pool"], dtype=np.float32)
e_P = np.array(s["e_P_pool"], dtype=np.float32)
t_norm = np.array([len(s.get("history", [])) / 20.0], dtype=np.float32)
obs = np.concatenate([d_score, l_risk_oh, c_probs, e_H, e_P, t_norm])
obs_list.append(obs)
action_list.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0))
obs_tensor = torch.FloatTensor(np.stack(obs_list)).to(device)
action_tensor = torch.LongTensor(action_list).to(device)
return obs_tensor, action_tensor
def get_obs_dim(detector_hidden: int) -> int:
"""Compute observation vector dimension."""
return 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
def main():
@@ -110,7 +44,7 @@ def main():
cfg = yaml.safe_load(f)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"Device: {device}")
if cfg["logging"]["use_wandb"]:
wandb.init(
@@ -120,29 +54,46 @@ def main():
)
# Load detector
tokenizer = AutoTokenizer.from_pretrained(cfg["detector"]["model_name"])
detector_cfg = cfg["detector"]
tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"])
detector = CompanionRiskDetector(
model_name=cfg["detector"]["model_name"],
hidden_size=cfg["detector"]["hidden_size"],
model_name=detector_cfg["model_name"],
hidden_size=detector_cfg["hidden_size"],
).to(device)
detector.load_state_dict(torch.load(cfg["detector"]["checkpoint"], map_location=device))
detector.eval()
print("Detector loaded.")
# Load and preprocess training data
ckpt_path = detector_cfg["checkpoint"]
if Path(ckpt_path).exists():
detector.load_state_dict(torch.load(ckpt_path, map_location=device))
print(f"Detector loaded from {ckpt_path}")
else:
print(f"[WARN] Detector checkpoint not found at {ckpt_path}. Using random weights.")
detector.eval()
# Pre-process training data through the detector
print(f"Loading training data: {args.train_data}")
raw_samples = load_jsonl(args.train_data)
print(f"Preprocessing {len(raw_samples)} samples with detector...")
processed = preprocess_samples_with_detector(raw_samples, detector, tokenizer, cfg, device)
detector_hidden = cfg["detector"]["hidden_size"]
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
processed = preprocess_samples_with_detector(
raw_samples,
detector,
tokenizer,
device=device,
binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
)
# Build RL agent
detector_hidden = detector_cfg["hidden_size"]
obs_dim = get_obs_dim(detector_hidden)
print(f"Observation dimension: {obs_dim}")
# Build the RL agent
agent_cfg = cfg["agent"]
agent = InterventionAgent(
detector_hidden=detector_hidden,
state_hidden=cfg["agent"]["state_hidden"],
dropout=cfg["agent"]["dropout"],
)
state_hidden=agent_cfg["state_hidden"],
dropout=agent_cfg["dropout"],
).to(device)
trainer = PPOTrainer(
agent=agent,
@@ -162,34 +113,38 @@ def main():
)
# Stage 1: Behavior cloning warm-up
if cfg["behavior_cloning"]["enabled"]:
print("Stage 1: Behavior cloning warm-up...")
obs_tensor, action_tensor = build_bc_tensors(processed, obs_dim, device)
bc_cfg = cfg.get("behavior_cloning", {})
if bc_cfg.get("enabled", True):
print("\n=== Stage 1: Behavior Cloning Warm-up ===")
obs_tensor, action_tensor = build_bc_tensors(processed, device=device)
trainer.behavior_cloning_warmup(
obs_tensor, action_tensor,
n_epochs=cfg["behavior_cloning"]["epochs"],
lr=cfg["behavior_cloning"]["lr"],
obs_tensor,
action_tensor,
n_epochs=bc_cfg.get("epochs", 5),
lr=bc_cfg.get("lr", 1e-3),
)
# Stage 2: PPO fine-tuning
print("Stage 2: PPO fine-tuning...")
print("\n=== Stage 2: PPO Fine-tuning ===")
env_cfg = cfg.get("environment", {})
env = CompanionEnv(
samples=processed,
detector_hidden=detector_hidden,
reward_weights=cfg["reward"],
max_turns=cfg["environment"]["max_turns"],
reward_weights=cfg.get("reward"),
max_turns=env_cfg.get("max_turns", 20),
)
Path(cfg["output"]["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
output_cfg = cfg["output"]
Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
trainer.train(
env=env,
total_timesteps=cfg["ppo"]["total_timesteps"],
n_rollout_steps=cfg["ppo"]["n_rollout_steps"],
checkpoint_dir=cfg["output"]["checkpoint_dir"],
save_interval=cfg["output"]["save_interval"],
checkpoint_dir=output_cfg["checkpoint_dir"],
save_interval=output_cfg.get("save_interval", 10_000),
)
torch.save(agent.state_dict(), f"{cfg['output']['checkpoint_dir']}/final.pt")
print("Training complete.")