""" Step 4: Train Module C — RL Intervention Policy (PPO). Two-stage training: Stage 1: Behavior cloning warm-up from a_recommend labels Stage 2: PPO fine-tuning with multi-objective reward Usage: python scripts/train_intervention.py --config configs/intervention_config.yaml \ --train-data data/processed/train.jsonl """ import argparse import yaml import torch from pathlib import Path from transformers import AutoTokenizer from src.data.dataset import load_jsonl from src.models.detector import CompanionRiskDetector from src.models.intervention_agent import InterventionAgent from src.rl.companion_env import CompanionEnv from src.rl.ppo_trainer import PPOTrainer from src.utils.preprocessing import ( preprocess_samples_with_detector, build_bc_tensors, ) from src.utils.taxonomy import NUM_RISK_LEVELS, NUM_PRIMARY import wandb def get_obs_dim(detector_hidden: int) -> int: """Compute observation vector dimension.""" return 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1 def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", default="configs/intervention_config.yaml") parser.add_argument("--train-data", default="data/processed/train.jsonl") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Device: {device}") if cfg["logging"]["use_wandb"]: wandb.init( project=cfg["logging"]["project"], name=cfg["logging"]["run_name"], config=cfg, ) # Load detector detector_cfg = cfg["detector"] tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"]) detector = CompanionRiskDetector( model_name=detector_cfg["model_name"], hidden_size=detector_cfg["hidden_size"], ).to(device) ckpt_path = detector_cfg["checkpoint"] if Path(ckpt_path).exists(): detector.load_state_dict(torch.load(ckpt_path, map_location=device)) print(f"Detector loaded from {ckpt_path}") else: print(f"[WARN] Detector checkpoint not found at {ckpt_path}. Using random weights.") detector.eval() # Pre-process training data through the detector print(f"Loading training data: {args.train_data}") raw_samples = load_jsonl(args.train_data) print(f"Preprocessing {len(raw_samples)} samples with detector...") processed = preprocess_samples_with_detector( raw_samples, detector, tokenizer, device=device, binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5), ) detector_hidden = detector_cfg["hidden_size"] obs_dim = get_obs_dim(detector_hidden) print(f"Observation dimension: {obs_dim}") # Build the RL agent agent_cfg = cfg["agent"] agent = InterventionAgent( detector_hidden=detector_hidden, state_hidden=agent_cfg["state_hidden"], dropout=agent_cfg["dropout"], ).to(device) trainer = PPOTrainer( agent=agent, obs_dim=obs_dim, lr=cfg["ppo"]["lr"], clip_eps=cfg["ppo"]["clip_eps"], entropy_coef=cfg["ppo"]["entropy_coef"], value_coef=cfg["ppo"]["value_coef"], max_grad_norm=cfg["ppo"]["max_grad_norm"], gamma=cfg["ppo"]["gamma"], gae_lambda=cfg["ppo"]["gae_lambda"], n_epochs=cfg["ppo"]["n_epochs"], batch_size=cfg["ppo"]["batch_size"], buffer_size=cfg["ppo"]["n_rollout_steps"], device=device, use_wandb=cfg["logging"]["use_wandb"], ) # Stage 1: Behavior cloning warm-up bc_cfg = cfg.get("behavior_cloning", {}) if bc_cfg.get("enabled", True): print("\n=== Stage 1: Behavior Cloning Warm-up ===") obs_tensor, action_tensor = build_bc_tensors(processed, device=device) trainer.behavior_cloning_warmup( obs_tensor, action_tensor, n_epochs=bc_cfg.get("epochs", 5), lr=bc_cfg.get("lr", 1e-3), ) # Stage 2: PPO fine-tuning print("\n=== Stage 2: PPO Fine-tuning ===") env_cfg = cfg.get("environment", {}) env = CompanionEnv( samples=processed, detector_hidden=detector_hidden, reward_weights=cfg.get("reward"), max_turns=env_cfg.get("max_turns", 20), ) output_cfg = cfg["output"] Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True) trainer.train( env=env, total_timesteps=cfg["ppo"]["total_timesteps"], n_rollout_steps=cfg["ppo"]["n_rollout_steps"], checkpoint_dir=output_cfg["checkpoint_dir"], save_interval=output_cfg.get("save_interval", 10_000), ) print("Training complete.") if __name__ == "__main__": main()