"""
Step 4: Train Module C — RL Intervention Policy (PPO).

Two-stage training:
  Stage 1 (BC warm-up): behavior cloning on all GPUs via Accelerate DDP
  Stage 2 (PPO fine-tuning): single-GPU (GPU-0) offline RL — inherently sequential

Preprocessing (detector inference) is distributed across all GPUs.

Usage (4 GPUs):
    accelerate launch --num_processes=4 --mixed_precision=bf16 \\
        scripts/train_intervention.py --config configs/intervention_config.yaml \\
        --train-data data/processed/CompanionRisk-Bench/train.jsonl

Usage (single GPU):
    accelerate launch --num_processes=1 \\
        scripts/train_intervention.py --config configs/intervention_config.yaml
"""

import argparse
import os
import pickle
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler
from transformers import AutoTokenizer
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, set_seed

from src.data.dataset import load_jsonl
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.rl.companion_env import CompanionEnv
from src.rl.ppo_trainer import PPOTrainer
from src.utils.preprocessing import (
    preprocess_samples_with_detector,
    build_bc_tensors,
)
from src.utils.taxonomy import NUM_RISK_LEVELS, NUM_PRIMARY

try:
    import wandb
    _WANDB_AVAILABLE = True
except ImportError:
    wandb = None
    _WANDB_AVAILABLE = False


def get_obs_dim(detector_hidden: int) -> int:
    """Return the length of the flat observation vector fed to the agent.

    The size is ``1 + NUM_RISK_LEVELS + NUM_PRIMARY + 2 * detector_hidden + 1``.
    NOTE(review): the meaning of the two scalar slots and of the doubled hidden
    block is defined by the environment/preprocessing code, not here — confirm
    against ``CompanionEnv`` before changing this formula.
    """
    return 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1


def distributed_preprocess(
    raw_samples,
    detector,
    tokenizer,
    accelerator,
    binary_threshold: float = 0.5,
):
    """
    Distribute detector inference across all processes.

    Each process runs ``preprocess_samples_with_detector`` on its contiguous
    shard of ``raw_samples``; the shards are exchanged with
    ``torch.distributed.all_gather_object`` and concatenated in rank order.

    Args:
        raw_samples: Sliceable sequence of raw samples.
        detector: Detector model passed through to the preprocessing helper.
        tokenizer: Tokenizer passed through to the preprocessing helper.
        accelerator: The active ``Accelerator`` (supplies rank, world size, device).
        binary_threshold: Threshold forwarded to the preprocessing helper.

    Returns:
        The full processed list on the main process, and an empty list on every
        other process (they do not need the data afterwards).
    """
    n = len(raw_samples)
    rank = accelerator.process_index
    world = accelerator.num_processes

    # Each process takes its contiguous shard; integer arithmetic guarantees the
    # shards tile [0, n) exactly with no gaps or overlaps.
    start = (n * rank) // world
    end = (n * (rank + 1)) // world
    local_samples = raw_samples[start:end]

    accelerator.print(
        f"Preprocessing: rank {rank} handles samples {start}–{end} "
        f"({len(local_samples)} samples)"
    )

    local_processed = preprocess_samples_with_detector(
        local_samples,
        detector,
        tokenizer,
        device=str(accelerator.device),
        binary_threshold=binary_threshold,
    )

    if world == 1:
        # No process group is needed (or necessarily initialised) for one process.
        return local_processed

    # Collective call: every rank must reach this line.
    all_shards = [None] * world
    torch.distributed.all_gather_object(all_shards, local_processed)

    if accelerator.is_main_process:
        processed = []
        for shard in all_shards:
            processed.extend(shard)
        return processed
    return []


class _BCLossModule(nn.Module):
    """Expose ``agent.behavior_clone_loss`` as ``forward``.

    DistributedDataParallel only arms its gradient all-reduce when the wrapped
    module's ``forward`` is invoked. Calling ``behavior_clone_loss`` directly on
    the unwrapped agent therefore trains every rank independently. Routing the
    loss through ``forward`` lets DDP (and Accelerate's autocast) see the call.
    """

    def __init__(self, agent: nn.Module):
        super().__init__()
        self.agent = agent

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        return self.agent.behavior_clone_loss(obs, actions)


def run_bc_warmup(
    agent: InterventionAgent,
    obs_tensor: torch.Tensor,
    action_tensor: torch.Tensor,
    cfg: dict,
    accelerator: Accelerator,
    use_wandb: bool = False,
):
    """
    Stage 1: behavior cloning on all processes.

    Args:
        agent: Agent to train in place; must provide ``behavior_clone_loss``.
        obs_tensor: CPU tensor of observations (first dim indexes samples).
        action_tensor: CPU tensor of target actions aligned with ``obs_tensor``.
        cfg: Full config; reads ``behavior_cloning.{per_gpu_batch_size,epochs,lr}``.
        accelerator: The active ``Accelerator``.
        use_wandb: When true, log per-epoch loss through ``accelerator.log``.

    Returns:
        ``(agent, losses)`` — the same agent object (now on the accelerator
        device) and the list of per-epoch mean losses seen by this process.

    NOTE(review): only the loss wrapper is passed to ``accelerator.prepare``,
    so ``agent.forward`` itself is presumably left unpatched and later stages
    run it in full precision — confirm this is the intended Stage-2 behavior.
    """
    bc_cfg = cfg.get("behavior_cloning", {})
    per_gpu_bs = bc_cfg.get("per_gpu_batch_size", 256)
    n_epochs = bc_cfg.get("epochs", 5)
    lr = bc_cfg.get("lr", 1e-3)

    dataset = TensorDataset(obs_tensor, action_tensor)
    # Do NOT attach a DistributedSampler here: accelerator.prepare() already
    # shards the batch sampler across processes, so a manual sampler would
    # shard the data twice and each epoch would skip most of the dataset.
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_bs,
        shuffle=True,
        pin_memory=accelerator.device.type == "cuda",
        drop_last=False,
    )

    optimizer = optim.Adam(agent.parameters(), lr=lr)
    bc_module, optimizer, loader = accelerator.prepare(
        _BCLossModule(agent), optimizer, loader
    )

    losses = []
    for epoch in range(n_epochs):
        bc_module.train()
        epoch_loss = 0.0
        for obs_batch, act_batch in loader:
            # Go through the (possibly DDP-wrapped) module so grads are synced.
            loss = bc_module(obs_batch, act_batch)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / max(len(loader), 1)
        losses.append(avg_loss)
        accelerator.print(f"[BC] Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}")

        if use_wandb and accelerator.is_main_process:
            accelerator.log({"bc/loss": avg_loss, "bc/epoch": epoch + 1})

    # `agent` shares its parameters with the prepared wrapper, so it already
    # holds the trained weights.
    return agent, losses


def _build_cache_key(train_data, n_samples, ckpt_path, binary_threshold) -> str:
    """Build a cache key covering every input that changes the preprocessed data."""
    ckpt = Path(ckpt_path)
    ckpt_stamp = ckpt.stat().st_mtime_ns if ckpt.exists() else "none"
    return f"{train_data}:{n_samples}:{ckpt_path}:{ckpt_stamp}:{binary_threshold}"


def _load_preprocessed_cache(cache_path, cache_meta_path, cache_key, accelerator):
    """Return the cached samples when the cache matches ``cache_key``, else ``None``."""
    try:
        cached_key = cache_meta_path.read_text().strip()
    except FileNotFoundError:
        # Covers both "never cached" and "another rank just removed a stale cache".
        return None

    if cached_key != cache_key:
        accelerator.print("[CACHE MISS] Cache key mismatch, re-preprocessing...")
        # Only one process deletes, so ranks cannot race each other on unlink.
        if accelerator.is_main_process:
            cache_path.unlink(missing_ok=True)
            cache_meta_path.unlink(missing_ok=True)
        return None

    accelerator.print(f"[CACHE HIT] Loading preprocessed data from {cache_path}")
    try:
        # SECURITY NOTE: unpickling executes arbitrary code. The cache lives in
        # world-writable /tmp under a predictable name; only run this on hosts
        # where other local users are trusted, or move the cache elsewhere.
        with open(cache_path, "rb") as f:
            processed = pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        accelerator.print("[CACHE MISS] Cache file unreadable, re-preprocessing...")
        return None
    accelerator.print(f"Loaded {len(processed)} cached samples.")
    return processed


def _save_preprocessed_cache(processed, cache_path, cache_meta_path, cache_key):
    """Write the cache atomically: data first, then the key that validates it."""
    tmp_path = cache_path.with_name(cache_path.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump(processed, f)
    tmp_path.replace(cache_path)
    cache_meta_path.write_text(cache_key)


def _broadcast_bc_tensors(obs_tensor, action_tensor, accelerator):
    """Broadcast the BC tensors from rank 0 and return CPU copies on every rank.

    Non-main ranks may pass ``None``. Shapes and dtypes are sent first so the
    receive buffers always match. The payload is broadcast on
    ``accelerator.device`` because the NCCL backend cannot broadcast CPU tensors.
    """
    device = accelerator.device
    meta = [None]
    if accelerator.is_main_process:
        meta = [[(tuple(t.shape), t.dtype) for t in (obs_tensor, action_tensor)]]
    torch.distributed.broadcast_object_list(meta, src=0)

    received = []
    for local, (shape, dtype) in zip((obs_tensor, action_tensor), meta[0]):
        if accelerator.is_main_process:
            buf = local.to(device).contiguous()
        else:
            buf = torch.empty(shape, dtype=dtype, device=device)
        torch.distributed.broadcast(buf, src=0)
        received.append(buf.cpu())
    return received[0], received[1]


def _build_agent(cfg: dict, detector_hidden: int) -> InterventionAgent:
    """Construct a fresh ``InterventionAgent`` from the ``agent`` config section."""
    return InterventionAgent(
        detector_hidden=detector_hidden,
        state_hidden=cfg["agent"]["state_hidden"],
        dropout=cfg["agent"]["dropout"],
    )


def main():
    """Run preprocessing, the optional BC warm-up, and PPO fine-tuning."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/intervention_config.yaml")
    parser.add_argument("--train-data", default="data/processed/CompanionRisk-Bench/train.jsonl")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    set_seed(42)

    # ── Accelerator for BC stage ─────────────────────────────────────────
    bc_cfg = cfg.get("behavior_cloning", {})
    use_wandb = bool(cfg.get("logging", {}).get("use_wandb", False)) and _WANDB_AVAILABLE

    accelerator = Accelerator(
        mixed_precision=bc_cfg.get("mixed_precision", "bf16"),
        gradient_accumulation_steps=1,
        log_with="wandb" if use_wandb else None,
        # The BC loss may not touch every agent parameter (e.g. heads only used
        # by PPO); without this flag DDP errors out on the unused ones.
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
    )

    accelerator.print(
        f"Running on {accelerator.num_processes} GPU(s), "
        f"mixed_precision={accelerator.mixed_precision}"
    )

    if use_wandb:
        accelerator.init_trackers(
            project_name=cfg["logging"]["project"],
            config=cfg,
            init_kwargs={"wandb": {"name": cfg["logging"]["run_name"]}},
        )

    # ── Load detector (shared weights, each process loads its own copy) ──
    detector_cfg = cfg["detector"]
    tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"])
    detector = CompanionRiskDetector(
        model_name=detector_cfg["model_name"],
        hidden_size=detector_cfg["hidden_size"],
    ).to(accelerator.device)

    ckpt_path = detector_cfg["checkpoint"]
    if Path(ckpt_path).exists():
        detector.load_state_dict(
            torch.load(ckpt_path, map_location=accelerator.device)
        )
        accelerator.print(f"Detector loaded from {ckpt_path}")
    else:
        accelerator.print(f"[WARN] No detector checkpoint at {ckpt_path}. Using random weights.")
    detector.eval()

    # ── Distributed preprocessing (with disk cache) ──────────────────────
    accelerator.print(f"Loading: {args.train_data}")
    raw_samples = load_jsonl(args.train_data)
    accelerator.print(f"Preprocessing {len(raw_samples)} samples across {accelerator.num_processes} GPU(s)...")

    binary_threshold = cfg.get("evaluation", {}).get("binary_threshold", 0.5)

    # The key covers the data path, sample count, detector checkpoint and the
    # threshold, so changing any of them invalidates the cached features.
    cache_path = Path("/tmp/companionguard_processed_cache.pkl")
    cache_meta_path = Path("/tmp/companionguard_processed_cache.meta")
    cache_key = _build_cache_key(
        args.train_data, len(raw_samples), ckpt_path, binary_threshold
    )

    processed = _load_preprocessed_cache(
        cache_path, cache_meta_path, cache_key, accelerator
    )

    if processed is None:
        if accelerator.num_processes > 1:
            # Use distributed preprocessing
            processed = distributed_preprocess(
                raw_samples, detector, tokenizer, accelerator, binary_threshold
            )
        else:
            processed = preprocess_samples_with_detector(
                raw_samples,
                detector,
                tokenizer,
                device=str(accelerator.device),
                binary_threshold=binary_threshold,
            )
        # Save cache (main process only; other ranks hold an empty list)
        if accelerator.is_main_process and processed:
            accelerator.print(f"Saving preprocessed cache to {cache_path} ...")
            _save_preprocessed_cache(processed, cache_path, cache_meta_path, cache_key)
            accelerator.print("Cache saved.")

    detector_hidden = detector_cfg["hidden_size"]
    obs_dim = get_obs_dim(detector_hidden)
    accelerator.print(f"Observation dim: {obs_dim}")

    # ── Stage 1: Behavior Cloning (all GPUs) ────────────────────────────
    agent = _build_agent(cfg, detector_hidden)
    if bc_cfg.get("enabled", True):
        accelerator.print("\n=== Stage 1: Behavior Cloning Warm-up (all GPUs) ===")

        # Build BC tensors on main process, broadcast to others
        obs_tensor = action_tensor = None
        if accelerator.is_main_process:
            obs_tensor, action_tensor = build_bc_tensors(processed, device="cpu")

        if accelerator.num_processes > 1:
            obs_tensor, action_tensor = _broadcast_bc_tensors(
                obs_tensor, action_tensor, accelerator
            )

        # Keep tensors on CPU: DataLoader(pin_memory=True) requires CPU tensors.
        # accelerator.prepare() moves batches to the correct device during training.
        obs_tensor = obs_tensor.cpu()
        action_tensor = action_tensor.cpu()

        agent, _ = run_bc_warmup(agent, obs_tensor, action_tensor, cfg, accelerator, use_wandb=use_wandb)

        # Save BC-only checkpoint for ablation comparison
        if accelerator.is_main_process:
            output_cfg = cfg["output"]
            Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
            bc_ckpt_path = str(Path(output_cfg["checkpoint_dir"]) / "bc_only_v5.pt")
            torch.save(agent.state_dict(), bc_ckpt_path)
            accelerator.print(f"BC-only checkpoint saved: {bc_ckpt_path}")

    # ── Stage 2: PPO (main process only — inherently sequential) ─────────
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        accelerator.print("\n=== Stage 2: PPO Fine-tuning (GPU-0 only) ===")

        # Move agent to GPU-0
        device = accelerator.device
        agent = agent.to(device)

        ppo_cfg = cfg["ppo"]
        trainer = PPOTrainer(
            agent=agent,
            obs_dim=obs_dim,
            lr=ppo_cfg["lr"],
            clip_eps=ppo_cfg["clip_eps"],
            entropy_coef=ppo_cfg["entropy_coef"],
            value_coef=ppo_cfg["value_coef"],
            max_grad_norm=ppo_cfg["max_grad_norm"],
            gamma=ppo_cfg["gamma"],
            gae_lambda=ppo_cfg["gae_lambda"],
            n_epochs=ppo_cfg["n_epochs"],
            batch_size=ppo_cfg["batch_size"],
            buffer_size=ppo_cfg["n_rollout_steps"],
            device=str(device),
            use_wandb=use_wandb,
        )

        env_cfg = cfg.get("environment", {})
        env = CompanionEnv(
            samples=processed,
            detector_hidden=detector_hidden,
            reward_weights=cfg.get("reward"),
            max_turns=env_cfg.get("max_turns", 20),
            enable_category_reward=cfg.get("reward", {}).get("enable_category_reward", True),
        )

        output_cfg = cfg["output"]
        Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)

        trainer.train(
            env=env,
            total_timesteps=ppo_cfg["total_timesteps"],
            n_rollout_steps=ppo_cfg["n_rollout_steps"],
            checkpoint_dir=output_cfg["checkpoint_dir"],
            save_interval=output_cfg.get("save_interval", 10_000),
        )
        accelerator.print("Training complete.")

    if use_wandb:
        accelerator.end_training()


if __name__ == "__main__":
    main()