""" Step 4: Train Module C — RL Intervention Policy (PPO). Two-stage training: Stage 1 (BC warm-up): behavior cloning on all 4 GPUs via Accelerate DDP Stage 2 (PPO fine-tuning): single-GPU (GPU-0) offline RL — inherently sequential Preprocessing (detector inference) is distributed across all 4 GPUs. Usage (4 GPUs): accelerate launch --num_processes=4 --mixed_precision=bf16 \\ scripts/train_intervention.py --config configs/intervention_config.yaml \\ --train-data data/processed/train.jsonl Usage (single GPU): accelerate launch --num_processes=1 \\ scripts/train_intervention.py --config configs/intervention_config.yaml """ import argparse import os import yaml import torch import torch.nn as nn import torch.optim as optim import numpy as np from pathlib import Path from torch.utils.data import DataLoader, TensorDataset, DistributedSampler from transformers import AutoTokenizer from accelerate import Accelerator from accelerate.utils import set_seed from src.data.dataset import load_jsonl from src.models.detector import CompanionRiskDetector from src.models.intervention_agent import InterventionAgent from src.rl.companion_env import CompanionEnv from src.rl.ppo_trainer import PPOTrainer from src.utils.preprocessing import ( preprocess_samples_with_detector, build_bc_tensors, ) from src.utils.taxonomy import NUM_RISK_LEVELS, NUM_PRIMARY import wandb def get_obs_dim(detector_hidden: int) -> int: return 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1 def distributed_preprocess( raw_samples, detector, tokenizer, accelerator, binary_threshold: float = 0.5, ): """ Distribute detector inference across all GPUs. Each process handles its shard of the dataset; results are gathered on the main process. """ n = len(raw_samples) rank = accelerator.process_index world = accelerator.num_processes # Each process takes its contiguous shard start = (n * rank) // world end = (n * (rank + 1)) // world local_samples = raw_samples[start:end] accelerator.print( f"Preprocessing: rank {rank} handles samples {start}–{end} " f"({len(local_samples)} samples)" ) local_processed = preprocess_samples_with_detector( local_samples, detector, tokenizer, device=str(accelerator.device), binary_threshold=binary_threshold, ) # Gather on main process via object lists all_shards = [None] * world torch.distributed.all_gather_object(all_shards, local_processed) if accelerator.is_main_process: processed = [] for shard in all_shards: processed.extend(shard) return processed return [] def run_bc_warmup( agent: InterventionAgent, obs_tensor: torch.Tensor, action_tensor: torch.Tensor, cfg: dict, accelerator: Accelerator, ): """ Stage 1: Behavior cloning on all GPUs. Returns the updated agent weights (synced automatically via DDP). """ bc_cfg = cfg.get("behavior_cloning", {}) per_gpu_bs = bc_cfg.get("per_gpu_batch_size", 256) n_epochs = bc_cfg.get("epochs", 5) lr = bc_cfg.get("lr", 1e-3) dataset = TensorDataset(obs_tensor, action_tensor) sampler = None if accelerator.num_processes > 1: sampler = DistributedSampler( dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True, ) loader = DataLoader( dataset, batch_size=per_gpu_bs, sampler=sampler, shuffle=(sampler is None), pin_memory=True, drop_last=False, ) optimizer = optim.Adam(agent.parameters(), lr=lr) agent, optimizer, loader = accelerator.prepare(agent, optimizer, loader) losses = [] for epoch in range(n_epochs): if accelerator.num_processes > 1: loader.sampler.set_epoch(epoch) epoch_loss = 0.0 agent.train() for obs_batch, act_batch in loader: loss = accelerator.unwrap_model(agent).behavior_clone_loss( obs_batch, act_batch ) accelerator.backward(loss) optimizer.step() optimizer.zero_grad() epoch_loss += loss.item() avg_loss = epoch_loss / max(len(loader), 1) losses.append(avg_loss) accelerator.print(f"[BC] Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}") if cfg["logging"]["use_wandb"] and accelerator.is_main_process: accelerator.log({"bc/loss": avg_loss, "bc/epoch": epoch + 1}) # Return the unwrapped agent (weights are consistent across all processes) return accelerator.unwrap_model(agent), losses def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", default="configs/intervention_config.yaml") parser.add_argument("--train-data", default="data/processed/train.jsonl") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) set_seed(42) # ── Accelerator for BC stage ───────────────────────────────────────── bc_cfg = cfg.get("behavior_cloning", {}) accelerator = Accelerator( mixed_precision=bc_cfg.get("mixed_precision", "bf16"), gradient_accumulation_steps=1, log_with="wandb" if cfg["logging"]["use_wandb"] else None, ) accelerator.print( f"Running on {accelerator.num_processes} GPU(s), " f"mixed_precision={accelerator.mixed_precision}" ) if cfg["logging"]["use_wandb"]: accelerator.init_trackers( project_name=cfg["logging"]["project"], config=cfg, init_kwargs={"wandb": {"name": cfg["logging"]["run_name"]}}, ) # ── Load detector (shared weights, each process loads its own copy) ── detector_cfg = cfg["detector"] tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"]) detector = CompanionRiskDetector( model_name=detector_cfg["model_name"], hidden_size=detector_cfg["hidden_size"], ).to(accelerator.device) ckpt_path = detector_cfg["checkpoint"] if Path(ckpt_path).exists(): detector.load_state_dict( torch.load(ckpt_path, map_location=accelerator.device) ) accelerator.print(f"Detector loaded from {ckpt_path}") else: accelerator.print(f"[WARN] No detector checkpoint at {ckpt_path}. Using random weights.") detector.eval() # ── Distributed preprocessing ──────────────────────────────────────── accelerator.print(f"Loading: {args.train_data}") raw_samples = load_jsonl(args.train_data) accelerator.print(f"Preprocessing {len(raw_samples)} samples across {accelerator.num_processes} GPU(s)...") binary_threshold = cfg.get("evaluation", {}).get("binary_threshold", 0.5) if accelerator.num_processes > 1: # Use distributed preprocessing processed = distributed_preprocess( raw_samples, detector, tokenizer, accelerator, binary_threshold ) else: processed = preprocess_samples_with_detector( raw_samples, detector, tokenizer, device=str(accelerator.device), binary_threshold=binary_threshold, ) detector_hidden = detector_cfg["hidden_size"] obs_dim = get_obs_dim(detector_hidden) accelerator.print(f"Observation dim: {obs_dim}") # ── Stage 1: Behavior Cloning (all GPUs) ──────────────────────────── if bc_cfg.get("enabled", True): accelerator.print("\n=== Stage 1: Behavior Cloning Warm-up (all GPUs) ===") # Build BC tensors on main process, broadcast to others if accelerator.is_main_process: obs_tensor, action_tensor = build_bc_tensors(processed, device="cpu") else: obs_tensor = torch.zeros(1, obs_dim) action_tensor = torch.zeros(1, dtype=torch.long) if accelerator.num_processes > 1: # Broadcast tensor sizes from rank 0 size_tensor = torch.tensor([obs_tensor.shape[0]], dtype=torch.long) torch.distributed.broadcast(size_tensor, src=0) n_samples = size_tensor.item() if not accelerator.is_main_process: obs_tensor = torch.zeros(n_samples, obs_dim) action_tensor = torch.zeros(n_samples, dtype=torch.long) # Broadcast data from rank 0 to all processes torch.distributed.broadcast(obs_tensor, src=0) torch.distributed.broadcast(action_tensor, src=0) obs_tensor = obs_tensor.to(accelerator.device) action_tensor = action_tensor.to(accelerator.device) agent = InterventionAgent( detector_hidden=detector_hidden, state_hidden=cfg["agent"]["state_hidden"], dropout=cfg["agent"]["dropout"], ) agent, _ = run_bc_warmup(agent, obs_tensor, action_tensor, cfg, accelerator) else: agent = InterventionAgent( detector_hidden=detector_hidden, state_hidden=cfg["agent"]["state_hidden"], dropout=cfg["agent"]["dropout"], ) # ── Stage 2: PPO (main process only — inherently sequential) ───────── accelerator.wait_for_everyone() if accelerator.is_main_process: accelerator.print("\n=== Stage 2: PPO Fine-tuning (GPU-0 only) ===") # Move agent to GPU-0 device = accelerator.device agent = agent.to(device) ppo_cfg = cfg["ppo"] trainer = PPOTrainer( agent=agent, obs_dim=obs_dim, lr=ppo_cfg["lr"], clip_eps=ppo_cfg["clip_eps"], entropy_coef=ppo_cfg["entropy_coef"], value_coef=ppo_cfg["value_coef"], max_grad_norm=ppo_cfg["max_grad_norm"], gamma=ppo_cfg["gamma"], gae_lambda=ppo_cfg["gae_lambda"], n_epochs=ppo_cfg["n_epochs"], batch_size=ppo_cfg["batch_size"], buffer_size=ppo_cfg["n_rollout_steps"], device=str(device), use_wandb=cfg["logging"]["use_wandb"], ) env_cfg = cfg.get("environment", {}) env = CompanionEnv( samples=processed, detector_hidden=detector_hidden, reward_weights=cfg.get("reward"), max_turns=env_cfg.get("max_turns", 20), ) output_cfg = cfg["output"] Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True) trainer.train( env=env, total_timesteps=ppo_cfg["total_timesteps"], n_rollout_steps=ppo_cfg["n_rollout_steps"], checkpoint_dir=output_cfg["checkpoint_dir"], save_interval=output_cfg.get("save_interval", 10_000), ) accelerator.print("Training complete.") if cfg["logging"]["use_wandb"]: accelerator.end_training() if __name__ == "__main__": main()