Detection module (Module B): - detector.py: expose separate e_P_pool and e_H_pool for RL state; fix compute_loss to skip primary head when c_primary="None" - dataset.py: handle c_primary="None" safely; add validate_and_normalize Data pipeline: - data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe); systematic category→fine-label mapping; safe sample generation (25%); per-category risk level distribution; max_retries logic - llm_judge.py: incremental file writing; rate limiting; retry logic; annotate_from_file convenience method; consistency validation - annotate_data.py: stratified split by y_risk; dataset statistics report RL module (Module C): - ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple); fix action type passed to env.step; proper buffer reset and size tracking - companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with auto-reset; correct Gymnasium interface Shared utilities (new files): - src/utils/preprocessing.py: preprocess_samples_with_detector using separate e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up - src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b), CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention, LLMJudgePolicy for full baseline comparison Scripts: - train_intervention.py: use preprocessing module; separate e_H/e_P pools - evaluate.py: proper module imports (no circular scripts import); full multi-baseline comparison; save all results to JSON - generate_data.py: API key check; safe_ratio + max_retries CLI args Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
153 lines
4.7 KiB
Python
153 lines
4.7 KiB
Python
"""
|
|
Step 4: Train Module C — RL Intervention Policy (PPO).
|
|
|
|
Two-stage training:
|
|
Stage 1: Behavior cloning warm-up from a_recommend labels
|
|
Stage 2: PPO fine-tuning with multi-objective reward
|
|
|
|
Usage:
|
|
python scripts/train_intervention.py --config configs/intervention_config.yaml \
|
|
--train-data data/processed/train.jsonl
|
|
"""
|
|
|
|
import argparse
from pathlib import Path
from typing import Any, Dict

import torch
import wandb
import yaml
from transformers import AutoTokenizer

from src.data.dataset import load_jsonl
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.rl.companion_env import CompanionEnv
from src.rl.ppo_trainer import PPOTrainer
from src.utils.preprocessing import (
    build_bc_tensors,
    preprocess_samples_with_detector,
)
from src.utils.taxonomy import NUM_PRIMARY, NUM_RISK_LEVELS
|
|
|
|
|
|
def get_obs_dim(detector_hidden: int) -> int:
    """Return the length of the flat observation vector consumed by the policy.

    The total is the sum of the segments listed below. It must stay in sync with
    whatever assembles the observation (presumably ``build_obs_vector`` in
    ``src.utils.preprocessing`` — TODO confirm).
    """
    segment_sizes = (
        1,                     # leading single-value feature
        NUM_RISK_LEVELS,       # NUM_RISK_LEVELS entries
        NUM_PRIMARY,           # NUM_PRIMARY entries
        detector_hidden * 2,   # two hidden-size vectors (presumably e_P / e_H pools — TODO confirm)
        1,                     # trailing single-value feature
    )
    return sum(segment_sizes)
|
|
|
|
|
|
def _load_config(path: str) -> Dict[str, Any]:
    """Parse the YAML training config at *path*.

    Args:
        path: Filesystem path of the YAML config file.

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: If the file does not parse to a mapping (e.g. it is empty).
            Without this check the failure would surface later as an opaque
            ``TypeError`` on the first ``cfg[...]`` lookup.
    """
    with open(path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config {path!r} must contain a YAML mapping, got {type(cfg).__name__}."
        )
    return cfg


def _load_detector(detector_cfg: Dict[str, Any], device: str):
    """Build the Module B risk detector and its tokenizer.

    Weights are loaded from ``detector_cfg["checkpoint"]`` when that file
    exists. Otherwise a warning is printed and the freshly initialised weights
    are kept. The detector is switched to eval mode before returning.

    Args:
        detector_cfg: The ``detector`` config section (``model_name``,
            ``hidden_size``, ``checkpoint``).
        device: Torch device string the detector is moved to.

    Returns:
        ``(detector, tokenizer)``.
    """
    model_name = detector_cfg["model_name"]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    detector = CompanionRiskDetector(
        model_name=model_name,
        hidden_size=detector_cfg["hidden_size"],
    ).to(device)

    ckpt_path = detector_cfg["checkpoint"]
    if Path(ckpt_path).exists():
        # NOTE(security): torch.load unpickles the file and can execute arbitrary
        # code, so only point this at trusted checkpoints. If the file is a plain
        # state_dict, consider passing weights_only=True (torch>=1.13).
        detector.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Detector loaded from {ckpt_path}")
    else:
        print(f"[WARN] Detector checkpoint not found at {ckpt_path}. Using random weights.")

    detector.eval()
    return detector, tokenizer


def _build_trainer(agent, obs_dim: int, cfg: Dict[str, Any], device: str):
    """Construct the PPO trainer from the ``ppo`` and ``logging`` config sections."""
    ppo_cfg = cfg["ppo"]
    return PPOTrainer(
        agent=agent,
        obs_dim=obs_dim,
        lr=ppo_cfg["lr"],
        clip_eps=ppo_cfg["clip_eps"],
        entropy_coef=ppo_cfg["entropy_coef"],
        value_coef=ppo_cfg["value_coef"],
        max_grad_norm=ppo_cfg["max_grad_norm"],
        gamma=ppo_cfg["gamma"],
        gae_lambda=ppo_cfg["gae_lambda"],
        n_epochs=ppo_cfg["n_epochs"],
        batch_size=ppo_cfg["batch_size"],
        # Tied to n_rollout_steps, the same value later passed to trainer.train().
        buffer_size=ppo_cfg["n_rollout_steps"],
        device=device,
        use_wandb=cfg["logging"]["use_wandb"],
    )


def _train(cfg: Dict[str, Any], train_data: str, device: str) -> None:
    """Preprocess data with the detector, then run BC warm-up (optional) and PPO.

    Args:
        cfg: Full parsed training config.
        train_data: Path of the JSONL training file.
        device: Torch device string.

    Raises:
        SystemExit: If the training file contains no samples.
    """
    output_cfg = cfg["output"]
    checkpoint_dir = output_cfg["checkpoint_dir"]
    # Resolve and create the output directory up front, so a bad ``output``
    # section fails before the expensive preprocessing and BC stages.
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

    detector_cfg = cfg["detector"]
    detector, tokenizer = _load_detector(detector_cfg, device)

    # Pre-process training data through the detector.
    print(f"Loading training data: {train_data}")
    raw_samples = load_jsonl(train_data)
    if not raw_samples:
        raise SystemExit(f"No training samples found in {train_data}; aborting.")
    print(f"Preprocessing {len(raw_samples)} samples with detector...")

    processed = preprocess_samples_with_detector(
        raw_samples,
        detector,
        tokenizer,
        device=device,
        binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
    )

    detector_hidden = detector_cfg["hidden_size"]
    obs_dim = get_obs_dim(detector_hidden)
    print(f"Observation dimension: {obs_dim}")

    # Build the RL agent and its PPO trainer.
    agent_cfg = cfg["agent"]
    agent = InterventionAgent(
        detector_hidden=detector_hidden,
        state_hidden=agent_cfg["state_hidden"],
        dropout=agent_cfg["dropout"],
    ).to(device)
    trainer = _build_trainer(agent, obs_dim, cfg, device)

    # Stage 1: behavior cloning warm-up (on unless behavior_cloning.enabled is false).
    bc_cfg = cfg.get("behavior_cloning", {})
    if bc_cfg.get("enabled", True):
        print("\n=== Stage 1: Behavior Cloning Warm-up ===")
        obs_tensor, action_tensor = build_bc_tensors(processed, device=device)
        trainer.behavior_cloning_warmup(
            obs_tensor,
            action_tensor,
            n_epochs=bc_cfg.get("epochs", 5),
            lr=bc_cfg.get("lr", 1e-3),
        )

    # Stage 2: PPO fine-tuning.
    print("\n=== Stage 2: PPO Fine-tuning ===")
    env_cfg = cfg.get("environment", {})
    env = CompanionEnv(
        samples=processed,
        detector_hidden=detector_hidden,
        reward_weights=cfg.get("reward"),
        max_turns=env_cfg.get("max_turns", 20),
    )

    trainer.train(
        env=env,
        total_timesteps=cfg["ppo"]["total_timesteps"],
        n_rollout_steps=cfg["ppo"]["n_rollout_steps"],
        checkpoint_dir=checkpoint_dir,
        save_interval=output_cfg.get("save_interval", 10_000),
    )

    print("Training complete.")


def main():
    """CLI entry point: parse args, manage the optional W&B run, and train.

    The W&B run (if enabled) is always closed. It is marked failed
    (``exit_code=1``) when training raises, and finished normally otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Train the Module C RL intervention policy (BC warm-up + PPO)."
    )
    parser.add_argument(
        "--config",
        default="configs/intervention_config.yaml",
        help="YAML training config.",
    )
    parser.add_argument(
        "--train-data",
        default="data/processed/train.jsonl",
        help="JSONL file of training samples.",
    )
    args = parser.parse_args()

    cfg = _load_config(args.config)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    log_cfg = cfg["logging"]
    use_wandb = log_cfg["use_wandb"]
    if use_wandb:
        wandb.init(
            project=log_cfg["project"],
            name=log_cfg["run_name"],
            config=cfg,
        )

    try:
        _train(cfg, args.train_data, device)
    except BaseException:
        # Close the run explicitly and mark it failed, then let the error propagate.
        if use_wandb:
            wandb.finish(exit_code=1)
        raise
    else:
        if use_wandb:
            wandb.finish()


if __name__ == "__main__":
    main()
|