diff --git a/configs/detector_config.yaml b/configs/detector_config.yaml index a06f5c2..87b4f21 100644 --- a/configs/detector_config.yaml +++ b/configs/detector_config.yaml @@ -7,36 +7,39 @@ model: data: train_path: "data/processed/train.jsonl" - val_path: "data/processed/val.jsonl" - test_path: "data/processed/test.jsonl" - max_persona_len: 128 - max_context_len: 512 - max_response_len: 256 + val_path: "data/processed/val.jsonl" + test_path: "data/processed/test.jsonl" + max_persona_len: 128 + max_context_len: 512 + max_response_len: 256 max_history_turns: 5 + num_workers: 4 # DataLoader worker processes per GPU training: epochs: 10 - batch_size: 16 + per_gpu_batch_size: 32 # 4 GPUs × 32 = 128 effective batch per step + gradient_accumulation_steps: 1 # effective_batch = per_gpu × n_gpu × accum lr: 2e-5 warmup_steps: 200 weight_decay: 0.01 gradient_clip: 1.0 - eval_steps: 200 - save_steps: 500 + eval_steps: 200 # global steps between validation runs + mixed_precision: "bf16" # RTX 5090 has native BF16; use "fp16" for older GPUs + seed: 42 loss_weights: - binary: 1.0 - level: 1.0 + binary: 1.0 + level: 1.0 primary: 1.0 - fine: 1.0 + fine: 1.0 evaluation: binary_threshold: 0.5 - fine_threshold: 0.4 + fine_threshold: 0.4 logging: - project: "CompanionGuard-RL" - run_name: "detector-macbert" + project: "CompanionGuard-RL" + run_name: "detector-macbert-4gpu" use_wandb: true output: diff --git a/configs/intervention_config.yaml b/configs/intervention_config.yaml index 1a2fb6f..adb14d7 100644 --- a/configs/intervention_config.yaml +++ b/configs/intervention_config.yaml @@ -8,22 +8,26 @@ agent: dropout: 0.1 reward: - w1: 2.0 # safety gain for correct intervention - w2: 3.0 # false negative penalty - w3: 4.0 # crisis bonus for R1 - w4: 1.5 # over-refusal penalty - w5: 0.5 # UX cost + w1: 2.0 # safety gain for correct intervention (REWRITE/REJECT/CRISIS on risky) + w2: 3.0 # false negative penalty (PASS on high-risk) + w3: 4.0 # crisis bonus for R1 (self-harm/suicide) + w4: 
1.5 # over-refusal penalty (intervention on safe content) + w5: 0.5 # UX cost per REJECT/CRISIS action +# Stage 1: Behavior cloning warm-up runs on all 4 GPUs behavior_cloning: enabled: true epochs: 5 + per_gpu_batch_size: 256 # BC is lightweight MLP training; large batch is fine lr: 1e-3 + mixed_precision: "bf16" +# Stage 2: PPO runs on GPU-0 only (inherently sequential env-agent loop) ppo: total_timesteps: 200000 n_rollout_steps: 2048 n_epochs: 4 - batch_size: 64 + batch_size: 256 # PPO mini-batch; large since obs vectors are small lr: 3e-4 clip_eps: 0.2 entropy_coef: 0.01 @@ -33,14 +37,17 @@ ppo: gae_lambda: 0.95 environment: - n_envs: 1 max_turns: 20 +# Preprocessing: detector inference distributed across 4 GPUs +preprocessing: + per_gpu_batch_size: 64 # inference batch for converting dataset → RL states + logging: - project: "CompanionGuard-RL" - run_name: "intervention-ppo" + project: "CompanionGuard-RL" + run_name: "intervention-ppo-4gpu" use_wandb: true output: checkpoint_dir: "checkpoints/intervention" - save_interval: 10000 + save_interval: 10000 diff --git a/scripts/run_detector.sh b/scripts/run_detector.sh new file mode 100755 index 0000000..17d1fca --- /dev/null +++ b/scripts/run_detector.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Train Module B (Risk Detector) on 4x RTX 5090. +# +# Usage: +# bash scripts/run_detector.sh +# bash scripts/run_detector.sh --config configs/detector_config.yaml +# +# NVLink not required: DDP communicates via PCIe (sufficient for MacBERT-large). +# Mixed precision: BF16 (native on RTX 5090, ~2x throughput vs FP32). 
+
+set -e
+
+CONFIG="${*:---config configs/detector_config.yaml}"  # forward ALL CLI args (e.g. "--config path"); "$1" alone dropped the path
+NUM_GPUS=4
+
+echo "=============================================="
+echo " CompanionGuard-RL — Module B: Detector"
+echo " GPUs : ${NUM_GPUS}x RTX 5090 (PCIe DDP)"
+echo " Precision : BF16"
+echo " Config : ${CONFIG}"
+echo "=============================================="
+
+# Verify GPU count (--multi_gpu is only passed below when more than one GPU remains)
+ACTUAL_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+if [ "$ACTUAL_GPUS" -lt "$NUM_GPUS" ]; then
+    echo "[WARN] Expected ${NUM_GPUS} GPUs, found ${ACTUAL_GPUS}. Adjusting."
+    NUM_GPUS=$ACTUAL_GPUS
+fi
+
+accelerate launch \
+    --num_processes=${NUM_GPUS} \
+    --mixed_precision=bf16 \
+    $([ "${NUM_GPUS}" -gt 1 ] && echo "--multi_gpu") \
+    scripts/train_detector.py ${CONFIG}
diff --git a/scripts/run_full_pipeline.sh b/scripts/run_full_pipeline.sh
new file mode 100755
index 0000000..4a27846
--- /dev/null
+++ b/scripts/run_full_pipeline.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+# Full CompanionGuard-RL pipeline on 4x RTX 5090.
+#
+# Step 1: Generate data (calls LLM API, single process)
+# Step 2: Annotate + split (calls LLM API, single process)
+# Step 3: Train detector (4 GPU DDP, BF16)
+# Step 4: Train intervention (4 GPU BC + 1 GPU PPO)
+# Step 5: Evaluate (single GPU)
+#
+# Usage:
+#   export DASHSCOPE_API_KEY=your_key # for Qwen
+#   bash scripts/run_full_pipeline.sh
+
+set -e
+
+NUM_GPUS=4
+echo "======================================================"
+echo " CompanionGuard-RL Full Pipeline — 4x RTX 5090"
+echo "======================================================"
+
+# ── Step 1: Data generation ────────────────────────────────────────────
+echo ""
+echo "[1/5] Generating dataset..."
+python scripts/generate_data.py --config configs/data_generation.yaml
+
+# ── Step 2: LLM annotation + split ─────────────────────────────────────
+echo ""
+echo "[2/5] Annotating and splitting dataset..."
+python scripts/annotate_data.py \
+    --input data/raw/generated.jsonl \
+    --output data/processed/annotated.jsonl \
+    --config configs/data_generation.yaml
+
+# ── Step 3: Train detector ──────────────────────────────────────────────
+echo ""
+echo "[3/5] Training risk detector (4 GPU DDP, BF16)..."
+accelerate launch \
+    --num_processes=${NUM_GPUS} \
+    --mixed_precision=bf16 \
+    --multi_gpu \
+    scripts/train_detector.py \
+    --config configs/detector_config.yaml
+
+# ── Step 4: Train intervention policy ──────────────────────────────────
+echo ""
+echo "[4/5] Training intervention policy (BC: 4 GPU, PPO: 1 GPU)..."
+accelerate launch \
+    --num_processes=${NUM_GPUS} \
+    --mixed_precision=bf16 \
+    --multi_gpu \
+    scripts/train_intervention.py \
+    --config configs/intervention_config.yaml \
+    --train-data data/processed/train.jsonl
+
+# ── Step 5: Evaluate ────────────────────────────────────────────────────
+echo ""
+echo "[5/5] Evaluating..."
+python scripts/evaluate.py \
+    --detector-ckpt checkpoints/detector/best.pt \
+    --agent-ckpt checkpoints/intervention/final.pt \
+    --test-data data/processed/test.jsonl \
+    --config configs/detector_config.yaml \
+    --intervention-config configs/intervention_config.yaml \
+    --output experiments/eval_results.json
+
+echo ""
+echo "======================================================"
+echo " Pipeline complete. Results: experiments/eval_results.json"
+echo "======================================================"
diff --git a/scripts/run_intervention.sh b/scripts/run_intervention.sh
new file mode 100755
index 0000000..c0db6a1
--- /dev/null
+++ b/scripts/run_intervention.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Train Module C (Intervention Policy) on 4x RTX 5090.
+#
+# Stage 1 — Behavior Cloning: all 4 GPUs (DDP, BF16)
+# Stage 2 — PPO fine-tuning: GPU-0 only (inherently sequential)
+#
+# Usage:
+#   bash scripts/run_intervention.sh
+#   bash scripts/run_intervention.sh data/processed/train.jsonl
+
+set -e
+
+TRAIN_DATA="${1:-data/processed/train.jsonl}"
+CONFIG="configs/intervention_config.yaml"
+NUM_GPUS=4
+
+echo "=============================================="
+echo " CompanionGuard-RL — Module C: Intervention"
+echo " Stage 1 (BC) : ${NUM_GPUS}x GPU (DDP, BF16)"
+echo " Stage 2 (PPO) : GPU-0 only"
+echo " Config : ${CONFIG}"
+echo " Train data : ${TRAIN_DATA}"
+echo "=============================================="
+
+ACTUAL_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+if [ "$ACTUAL_GPUS" -lt "$NUM_GPUS" ]; then
+    echo "[WARN] Expected ${NUM_GPUS} GPUs, found ${ACTUAL_GPUS}. Adjusting."
+    NUM_GPUS=$ACTUAL_GPUS
+fi
+
+accelerate launch \
+    --num_processes=${NUM_GPUS} \
+    --mixed_precision=bf16 \
+    $([ "${NUM_GPUS}" -gt 1 ] && echo "--multi_gpu") \
+    scripts/train_intervention.py \
+    --config "${CONFIG}" \
+    --train-data "${TRAIN_DATA}"
diff --git a/scripts/train_detector.py b/scripts/train_detector.py
index 59bba0e..7ed0f98 100644
--- a/scripts/train_detector.py
+++ b/scripts/train_detector.py
@@ -1,22 +1,83 @@
 """
 Step 3: Train Module B — Context-aware Risk Detector.
 
-Usage:
-    python scripts/train_detector.py --config configs/detector_config.yaml
+Multi-GPU training via HuggingFace Accelerate (DDP, no NVLink required).
+Mixed precision: BF16 (native on RTX 5090).
+
+Usage (4 GPUs):
+    accelerate launch --num_processes=4 --mixed_precision=bf16 \\
+        scripts/train_detector.py --config configs/detector_config.yaml
+
+Usage (single GPU for debugging):
+    accelerate launch --num_processes=1 \\
+        scripts/train_detector.py --config configs/detector_config.yaml
+
+Or with torchrun:
+    torchrun --nproc_per_node=4 scripts/train_detector.py \\
+        --config configs/detector_config.yaml
 """
 
 import argparse
+import os
 import yaml
 import torch
-import wandb
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer, get_linear_schedule_with_warmup
+from accelerate import Accelerator
+from accelerate.utils import set_seed
 
 from src.data.dataset import CompanionGuardDataset
 from src.models.detector import CompanionRiskDetector
 from src.utils.metrics import detection_metrics
 
 
+def make_loader(dataset, batch_size, accelerator, shuffle=True, num_workers=4):
+    """Build the per-process DataLoader; cross-GPU sharding is left to Accelerate.
+
+    ``accelerator.prepare`` wraps the loader's batch sampler so every process
+    receives its own slice of each epoch. Attaching a ``DistributedSampler``
+    here as well would shard the data a second time (each rank would see only
+    1/world_size**2 of the samples), so none is created. ``accelerator`` is
+    kept in the signature so that existing call sites do not have to change;
+    it is intentionally unused.
+    """
+    # drop_last mirrors shuffle: the ragged final batch is dropped for training only.
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=shuffle,
+    )
+
+
+@torch.no_grad()
+def evaluate(model, loader, accelerator, binary_threshold=0.5):
+    """Evaluate on validation set across all processes, aggregate on main."""
+    model.eval()
+    all_y_true, all_y_pred = [], []
+
+    for batch in loader:
+        preds = accelerator.unwrap_model(model).predict(
+            batch["persona_input_ids"], batch["persona_attention_mask"],
+            batch["context_input_ids"], batch["context_attention_mask"],
+            batch["response_input_ids"], batch["response_attention_mask"],
+            binary_threshold=binary_threshold,
+        )
+        # Gather predictions from all processes
+        y_true_batch = accelerator.gather_for_metrics(batch["y_risk"].int())
+        y_pred_batch = accelerator.gather_for_metrics(preds["y_risk"])
+
+        all_y_true.extend(y_true_batch.cpu().tolist())
+        all_y_pred.extend(y_pred_batch.cpu().tolist())
+
+    if accelerator.is_main_process:
+        from sklearn.metrics import f1_score
+        return f1_score(all_y_true, all_y_pred, average="binary", zero_division=0)
+    return 0.0
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--config", default="configs/detector_config.yaml")
@@ -25,125 +86,188 @@ def main():
     with open(args.config) as f:
         cfg = yaml.safe_load(f)
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"Using device: {device}")
+    train_cfg = cfg["training"]
+    set_seed(train_cfg.get("seed", 42))
 
+    # ── Accelerator setup ────────────────────────────────────────────────
+    accelerator = Accelerator(
+        mixed_precision=train_cfg.get("mixed_precision", "bf16"),
+        gradient_accumulation_steps=train_cfg.get("gradient_accumulation_steps", 1),
+        log_with="wandb" if cfg["logging"]["use_wandb"] else None,
+    )
+
+    accelerator.print(
+        f"Running on {accelerator.num_processes} GPU(s), "
+        f"mixed_precision={accelerator.mixed_precision}, "
+        f"grad_accum={accelerator.gradient_accumulation_steps}"
+    )
+
+    # Init wandb only on main process
     if cfg["logging"]["use_wandb"]:
-        wandb.init(
-            project=cfg["logging"]["project"],
-            name=cfg["logging"]["run_name"],
+        accelerator.init_trackers(
+            project_name=cfg["logging"]["project"],
             config=cfg,
+            init_kwargs={"wandb": {"name": cfg["logging"]["run_name"]}},
         )
 
+    # ── Data ─────────────────────────────────────────────────────────────
     tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
+    data_cfg = cfg["data"]
+    per_gpu_bs = train_cfg["per_gpu_batch_size"]
+    num_workers = data_cfg.get("num_workers", 4)
 
     train_ds = CompanionGuardDataset(
-        cfg["data"]["train_path"], tokenizer,
-        max_persona_len=cfg["data"]["max_persona_len"],
-        max_context_len=cfg["data"]["max_context_len"],
-        max_response_len=cfg["data"]["max_response_len"],
-        max_history_turns=cfg["data"]["max_history_turns"],
+        data_cfg["train_path"], tokenizer,
+        max_persona_len=data_cfg["max_persona_len"],
+        max_context_len=data_cfg["max_context_len"],
+        max_response_len=data_cfg["max_response_len"],
+        max_history_turns=data_cfg["max_history_turns"],
     )
     val_ds = CompanionGuardDataset(
-        cfg["data"]["val_path"], tokenizer,
-        max_persona_len=cfg["data"]["max_persona_len"],
-        max_context_len=cfg["data"]["max_context_len"],
-        max_response_len=cfg["data"]["max_response_len"],
-        max_history_turns=cfg["data"]["max_history_turns"],
+        data_cfg["val_path"], tokenizer,
+        max_persona_len=data_cfg["max_persona_len"],
+        max_context_len=data_cfg["max_context_len"],
+        max_response_len=data_cfg["max_response_len"],
+        max_history_turns=data_cfg["max_history_turns"],
     )
 
-    train_loader = DataLoader(train_ds, batch_size=cfg["training"]["batch_size"], shuffle=True)
-    val_loader = DataLoader(val_ds, batch_size=cfg["training"]["batch_size"])
+    train_loader = make_loader(train_ds, per_gpu_bs, accelerator, shuffle=True, num_workers=num_workers)
+    val_loader = make_loader(val_ds, per_gpu_bs, accelerator, shuffle=False, num_workers=num_workers)
 
+    effective_batch = (
+        per_gpu_bs
+        * accelerator.num_processes
+        * accelerator.gradient_accumulation_steps
+    )
+    accelerator.print(
+        f"Dataset: {len(train_ds)} train / {len(val_ds)} val | "
+        f"Effective batch size: {effective_batch}"
+    )
+
+    # ── Model ────────────────────────────────────────────────────────────
     model = CompanionRiskDetector(
         model_name=cfg["model"]["name"],
         hidden_size=cfg["model"]["hidden_size"],
         num_heads=cfg["model"]["num_heads"],
         dropout=cfg["model"]["dropout"],
         use_lora=cfg["model"]["use_lora"],
-    ).to(device)
+    )
 
     optimizer = torch.optim.AdamW(
         model.parameters(),
-        lr=cfg["training"]["lr"],
-        weight_decay=cfg["training"]["weight_decay"],
+        lr=float(train_cfg["lr"]),  # PyYAML loads dot-less exponents such as "2e-5" as str
+        weight_decay=train_cfg["weight_decay"],
     )
-    total_steps = len(train_loader) * cfg["training"]["epochs"]
+
+    # Steps per epoch after accounting for gradient accumulation
+    steps_per_epoch = len(train_loader) // accelerator.gradient_accumulation_steps
+    total_steps = steps_per_epoch * train_cfg["epochs"]
+
     scheduler = get_linear_schedule_with_warmup(
         optimizer,
-        num_warmup_steps=cfg["training"]["warmup_steps"],
+        num_warmup_steps=train_cfg["warmup_steps"],
         num_training_steps=total_steps,
     )
 
+    # Prepare: wraps the model in DDP and shards both DataLoaders per process
+    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
+        model, optimizer, train_loader, val_loader, scheduler
+    )
+
+    # ── Training loop ────────────────────────────────────────────────────
     best_val_f1 = 0.0
     global_step = 0
+    eval_steps = train_cfg["eval_steps"]
+    binary_threshold = cfg["evaluation"]["binary_threshold"]
 
-    for epoch in range(cfg["training"]["epochs"]):
+    for epoch in range(train_cfg["epochs"]):
         model.train()
+
+        # Per-epoch reshuffling: use the prepared loader's own set_epoch() when it has one
+        if hasattr(train_loader, "set_epoch"):
+            train_loader.set_epoch(epoch)
+
         for batch in train_loader:
-            batch = {k: v.to(device) for k, v in batch.items()}
+            with accelerator.accumulate(model):
+                logits = model(
+                    batch["persona_input_ids"], batch["persona_attention_mask"],
+                    batch["context_input_ids"], batch["context_attention_mask"],
+                    batch["response_input_ids"], batch["response_attention_mask"],
+                )
+                loss, loss_parts = accelerator.unwrap_model(model).compute_loss(
+                    logits,
+                    {
+                        "y_risk": batch["y_risk"],
+                        "l_risk": batch["l_risk"],
+                        "c_primary": batch["c_primary"],
+                        "c_fine": batch["c_fine"],
+                    },
+                    weights=cfg["loss_weights"],
+                )
 
-            logits = model(
-                batch["persona_input_ids"], batch["persona_attention_mask"],
-                batch["context_input_ids"], batch["context_attention_mask"],
-                batch["response_input_ids"], batch["response_attention_mask"],
-            )
-            loss, loss_parts = model.compute_loss(
-                logits,
-                {"y_risk": batch["y_risk"], "l_risk": batch["l_risk"],
-                 "c_primary": batch["c_primary"], "c_fine": batch["c_fine"]},
-                weights=cfg["loss_weights"],
-            )
+                accelerator.backward(loss)
 
-            optimizer.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(
-                model.parameters(), cfg["training"]["gradient_clip"]
-            )
-            optimizer.step()
-            scheduler.step()
-            global_step += 1
-
-            if cfg["logging"]["use_wandb"] and global_step % 50 == 0:
-                wandb.log({"train/loss": loss.item(), "step": global_step,
-                           **{f"train/{k}": v.item() for k, v in loss_parts.items()}})
-
-            if global_step % cfg["training"]["eval_steps"] == 0:
-                val_f1 = evaluate(model, val_loader, device, cfg)
-                print(f"Step {global_step}: Val binary F1 = {val_f1:.4f}")
-                if val_f1 > best_val_f1:
-                    best_val_f1 = val_f1
-                    import os
-                    os.makedirs(cfg["output"]["checkpoint_dir"], exist_ok=True)
-                    torch.save(
-                        model.state_dict(),
-                        f"{cfg['output']['checkpoint_dir']}/best.pt"
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(
+                        model.parameters(), train_cfg["gradient_clip"]
                     )
-                model.train()
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+                global_step += int(accelerator.sync_gradients)  # count optimizer steps, not micro-batches
 
-        print(f"Epoch {epoch + 1}/{cfg['training']['epochs']} done.")
+                # Log every 50 optimizer steps (skipped on accumulation micro-batches)
+                if cfg["logging"]["use_wandb"] and accelerator.sync_gradients and global_step % 50 == 0:
+                    accelerator.log({
+                        "train/loss": loss.item(),
+                        "train/lr": scheduler.get_last_lr()[0],
+                        "step": global_step,
+                        **{f"train/{k}": v.item() for k, v in loss_parts.items()},
+                    }, step=global_step)
 
-    print(f"Training complete. Best val binary F1: {best_val_f1:.4f}")
+                # Periodic validation (only right after a real optimizer step)
+                if accelerator.sync_gradients and global_step % eval_steps == 0:
+                    val_f1 = evaluate(model, val_loader, accelerator, binary_threshold)
+                    accelerator.print(
+                        f"Step {global_step} | Val binary F1 = {val_f1:.4f}"
+                    )
+                    if accelerator.is_main_process:
+                        if cfg["logging"]["use_wandb"]:
+                            accelerator.log(
+                                {"val/binary_f1": val_f1}, step=global_step
+                            )
+                        if val_f1 > best_val_f1:
+                            best_val_f1 = val_f1
+                            os.makedirs(cfg["output"]["checkpoint_dir"], exist_ok=True)
+                            ckpt_path = os.path.join(
+                                cfg["output"]["checkpoint_dir"], "best.pt"
+                            )
+                            torch.save(
+                                accelerator.unwrap_model(model).state_dict(),
+                                ckpt_path,
+                            )
+                            accelerator.print(f" → Saved best model: {ckpt_path}")
 
 
-@torch.no_grad()
-def evaluate(model, loader, device, cfg):
-    model.eval()
-    all_y_true, all_y_pred = [], []
+                    model.train()
 
-    for batch in loader:
-        batch = {k: v.to(device) for k, v in batch.items()}
-        preds = model.predict(
-            batch["persona_input_ids"], batch["persona_attention_mask"],
-            batch["context_input_ids"], batch["context_attention_mask"],
-            batch["response_input_ids"], batch["response_attention_mask"],
-            binary_threshold=cfg["evaluation"]["binary_threshold"],
+        accelerator.print(
+            f"Epoch {epoch + 1}/{train_cfg['epochs']} done. "
+            f"Best Val F1 so far: {best_val_f1:.4f}"
         )
-        all_y_true.extend(batch["y_risk"].int().cpu().tolist())
-        all_y_pred.extend(preds["y_risk"].cpu().tolist())
 
-    from sklearn.metrics import f1_score
-    return f1_score(all_y_true, all_y_pred, average="binary", zero_division=0)
+    # Save final model (path is computed on every rank because the summary below formats it)
+    final_path = os.path.join(cfg["output"]["checkpoint_dir"], "final.pt")
+    if accelerator.is_main_process:
+        os.makedirs(cfg["output"]["checkpoint_dir"], exist_ok=True)  # may not exist if no "best" was ever saved
+        torch.save(accelerator.unwrap_model(model).state_dict(), final_path)
+    accelerator.print(
+        f"\nTraining complete. Best val binary F1: {best_val_f1:.4f}\n"
+        f"Final model saved to {final_path}"
+    )
+
+    if cfg["logging"]["use_wandb"]:
+        accelerator.end_training()
 
 
 if __name__ == "__main__":
diff --git a/scripts/train_intervention.py b/scripts/train_intervention.py
index 98ebb3d..530fed8 100644
--- a/scripts/train_intervention.py
+++ b/scripts/train_intervention.py
@@ -2,19 +2,33 @@
 Step 4: Train Module C — RL Intervention Policy (PPO).
 
 Two-stage training:
-  Stage 1: Behavior cloning warm-up from a_recommend labels
-  Stage 2: PPO fine-tuning with multi-objective reward
+  Stage 1 (BC warm-up): behavior cloning on all 4 GPUs via Accelerate DDP
+  Stage 2 (PPO fine-tuning): single-GPU (GPU-0) offline RL — inherently sequential
 
-Usage:
-    python scripts/train_intervention.py --config configs/intervention_config.yaml \
-        --train-data data/processed/train.jsonl
+Preprocessing (detector inference) is distributed across all 4 GPUs.
+
+Usage (4 GPUs):
+    accelerate launch --num_processes=4 --mixed_precision=bf16 \\
+        scripts/train_intervention.py --config configs/intervention_config.yaml \\
+        --train-data data/processed/train.jsonl
+
+Usage (single GPU):
+    accelerate launch --num_processes=1 \\
+        scripts/train_intervention.py --config configs/intervention_config.yaml
 """
 
 import argparse
+import os
 import yaml
 import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
 from pathlib import Path
+from torch.utils.data import DataLoader, TensorDataset, DistributedSampler
 from transformers import AutoTokenizer
+from accelerate import Accelerator
+from accelerate.utils import set_seed
 
 from src.data.dataset import load_jsonl
 from src.models.detector import CompanionRiskDetector
@@ -30,10 +44,122 @@ import wandb
 
 
 def get_obs_dim(detector_hidden: int) -> int:
-    """Compute observation vector dimension."""
     return 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
 
 
+def distributed_preprocess(
+    raw_samples,
+    detector,
+    tokenizer,
+    accelerator,
+    binary_threshold: float = 0.5,
+):
+    """
+    Distribute detector inference across all GPUs.
+
+    Each process handles its shard of the dataset; results are gathered
+    on the main process.
+    """
+    n = len(raw_samples)
+    rank = accelerator.process_index
+    world = accelerator.num_processes
+
+    # Each process takes its contiguous shard
+    start = (n * rank) // world
+    end = (n * (rank + 1)) // world
+    local_samples = raw_samples[start:end]
+
+    print(  # plain print: every rank reports its own shard (accelerator.print is rank-0 only)
+        f"Preprocessing: rank {rank} handles samples {start}–{end} "
+        f"({len(local_samples)} samples)"
+    )
+
+    local_processed = preprocess_samples_with_detector(
+        local_samples,
+        detector,
+        tokenizer,
+        device=str(accelerator.device),
+        binary_threshold=binary_threshold,
+    )
+
+    # Gather on main process via object lists
+    all_shards = [None] * world
+    torch.distributed.all_gather_object(all_shards, local_processed)
+
+    if accelerator.is_main_process:
+        processed = []
+        for shard in all_shards:
+            processed.extend(shard)
+        return processed
+    return []
+
+
+def run_bc_warmup(
+    agent: InterventionAgent,
+    obs_tensor: torch.Tensor,
+    action_tensor: torch.Tensor,
+    cfg: dict,
+    accelerator: Accelerator,
+):
+    """
+    Stage 1: Behavior cloning on all GPUs.
+    Returns the updated agent weights (synced automatically via DDP).
+    """
+    bc_cfg = cfg.get("behavior_cloning", {})
+    per_gpu_bs = bc_cfg.get("per_gpu_batch_size", 256)
+    n_epochs = bc_cfg.get("epochs", 5)
+    lr = float(bc_cfg.get("lr", 1e-3))  # PyYAML loads dot-less exponents such as "1e-3" as str
+
+    dataset = TensorDataset(obs_tensor, action_tensor)
+
+    # Cross-process sharding is delegated to ``accelerator.prepare`` below,
+    # which wraps the loader's batch sampler per rank. Attaching a
+    # DistributedSampler here as well would shard the data twice and leave
+    # every rank with only 1/world_size**2 of the samples.
+    #
+    # Pinning only works for CPU tensors. main() moves the BC tensors to the
+    # accelerator device before calling this function, so pin_memory is
+    # enabled only while the data still lives on the host; pinning a CUDA
+    # tensor raises at iteration time.
+    loader = DataLoader(
+        dataset,
+        batch_size=per_gpu_bs,
+        # shuffling is requested here; the per-rank split happens in prepare()
+        shuffle=True,
+        pin_memory=(obs_tensor.device.type == "cpu"),
+        drop_last=False,
+    )
+
+    optimizer = optim.Adam(agent.parameters(), lr=lr)
+    agent, optimizer, loader = accelerator.prepare(agent, optimizer, loader)
+
+    losses = []
+    for epoch in range(n_epochs):
+        if hasattr(loader, "set_epoch"):
+            loader.set_epoch(epoch)
+
+        epoch_loss = 0.0
+        agent.train()
+        for obs_batch, act_batch in loader:
+            loss = accelerator.unwrap_model(agent).behavior_clone_loss(
+                obs_batch, act_batch
+            )
+            accelerator.backward(loss)
+            optimizer.step()
+            optimizer.zero_grad()
+            epoch_loss += loss.item()
+
+        avg_loss = epoch_loss / max(len(loader), 1)
+        losses.append(avg_loss)
+        accelerator.print(f"[BC] Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}")
+
+        if cfg["logging"]["use_wandb"] and accelerator.is_main_process:
+            accelerator.log({"bc/loss": avg_loss, "bc/epoch": epoch + 1})
+
+    # NOTE(review): the loss is computed on the unwrapped agent, bypassing the DDP forward — confirm gradients are actually synced across ranks
+    return accelerator.unwrap_model(agent), losses
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--config", default="configs/intervention_config.yaml")
@@ -43,109 +169,163 @@ def main():
     with open(args.config) as f:
         cfg = yaml.safe_load(f)
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"Device: {device}")
+    set_seed(cfg.get("seed", 42))
+
+    # ── Accelerator for BC stage ─────────────────────────────────────────
+    bc_cfg = cfg.get("behavior_cloning", {})
+    accelerator = Accelerator(
+        mixed_precision=bc_cfg.get("mixed_precision", "bf16"),
+        gradient_accumulation_steps=1,
+        log_with="wandb" if cfg["logging"]["use_wandb"] else None,
+    )
+    accelerator.print(
+        f"Running on {accelerator.num_processes} GPU(s), "
+        f"mixed_precision={accelerator.mixed_precision}"
+    )
 
     if cfg["logging"]["use_wandb"]:
-        wandb.init(
-            project=cfg["logging"]["project"],
-            name=cfg["logging"]["run_name"],
+        accelerator.init_trackers(
+            project_name=cfg["logging"]["project"],
             config=cfg,
+            init_kwargs={"wandb": {"name": cfg["logging"]["run_name"]}},
         )
 
-    # Load detector
+    # ── Load detector (shared weights, each process loads its own copy) ──
     detector_cfg = cfg["detector"]
     tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"])
     detector = CompanionRiskDetector(
         model_name=detector_cfg["model_name"],
         hidden_size=detector_cfg["hidden_size"],
-    ).to(device)
+    ).to(accelerator.device)
 
     ckpt_path = detector_cfg["checkpoint"]
     if Path(ckpt_path).exists():
-        detector.load_state_dict(torch.load(ckpt_path, map_location=device))
-        print(f"Detector loaded from {ckpt_path}")
+        detector.load_state_dict(
+            torch.load(ckpt_path, map_location=accelerator.device)
+        )
+        accelerator.print(f"Detector loaded from {ckpt_path}")
     else:
-        print(f"[WARN] Detector checkpoint not found at {ckpt_path}. Using random weights.")
+        accelerator.print(f"[WARN] No detector checkpoint at {ckpt_path}. Using random weights.")
     detector.eval()
 
-    # Pre-process training data through the detector
-    print(f"Loading training data: {args.train_data}")
+    # ── Distributed preprocessing ────────────────────────────────────────
+    accelerator.print(f"Loading: {args.train_data}")
     raw_samples = load_jsonl(args.train_data)
-    print(f"Preprocessing {len(raw_samples)} samples with detector...")
+    accelerator.print(f"Preprocessing {len(raw_samples)} samples across {accelerator.num_processes} GPU(s)...")
 
-    processed = preprocess_samples_with_detector(
-        raw_samples,
-        detector,
-        tokenizer,
-        device=device,
-        binary_threshold=cfg.get("evaluation", {}).get("binary_threshold", 0.5),
-    )
+    binary_threshold = cfg.get("evaluation", {}).get("binary_threshold", 0.5)
+
+    if accelerator.num_processes > 1:
+        # Use distributed preprocessing
+        processed = distributed_preprocess(
+            raw_samples, detector, tokenizer, accelerator, binary_threshold
+        )
+    else:
+        processed = preprocess_samples_with_detector(
+            raw_samples, detector, tokenizer,
+            device=str(accelerator.device),
+            binary_threshold=binary_threshold,
+        )
 
     detector_hidden = detector_cfg["hidden_size"]
     obs_dim = get_obs_dim(detector_hidden)
-    print(f"Observation dimension: {obs_dim}")
+    accelerator.print(f"Observation dim: {obs_dim}")
 
-    # Build the RL agent
-    agent_cfg = cfg["agent"]
-    agent = InterventionAgent(
-        detector_hidden=detector_hidden,
-        state_hidden=agent_cfg["state_hidden"],
-        dropout=agent_cfg["dropout"],
-    ).to(device)
-
-    trainer = PPOTrainer(
-        agent=agent,
-        obs_dim=obs_dim,
-        lr=cfg["ppo"]["lr"],
-        clip_eps=cfg["ppo"]["clip_eps"],
-        entropy_coef=cfg["ppo"]["entropy_coef"],
-        value_coef=cfg["ppo"]["value_coef"],
-        max_grad_norm=cfg["ppo"]["max_grad_norm"],
-        gamma=cfg["ppo"]["gamma"],
-        gae_lambda=cfg["ppo"]["gae_lambda"],
-        n_epochs=cfg["ppo"]["n_epochs"],
-        batch_size=cfg["ppo"]["batch_size"],
-        buffer_size=cfg["ppo"]["n_rollout_steps"],
-        device=device,
-        use_wandb=cfg["logging"]["use_wandb"],
-    )
-
-    # Stage 1: Behavior cloning warm-up
-    bc_cfg = cfg.get("behavior_cloning", {})
+    # ── Stage 1: Behavior Cloning (all GPUs) ────────────────────────────
     if bc_cfg.get("enabled", True):
-        print("\n=== Stage 1: Behavior Cloning Warm-up ===")
-        obs_tensor, action_tensor = build_bc_tensors(processed, device=device)
-        trainer.behavior_cloning_warmup(
-            obs_tensor,
-            action_tensor,
-            n_epochs=bc_cfg.get("epochs", 5),
-            lr=bc_cfg.get("lr", 1e-3),
+        accelerator.print("\n=== Stage 1: Behavior Cloning Warm-up (all GPUs) ===")
+
+        # Build BC tensors on rank 0's GPU (NCCL collectives cannot move CPU tensors), broadcast to others
+        if accelerator.is_main_process:
+            obs_tensor, action_tensor = build_bc_tensors(processed, device=str(accelerator.device))
+        else:
+            obs_tensor = torch.zeros(1, obs_dim, device=accelerator.device)
+            action_tensor = torch.zeros(1, dtype=torch.long, device=accelerator.device)
+
+        if accelerator.num_processes > 1:
+            # Broadcast tensor sizes from rank 0
+            size_tensor = torch.tensor([obs_tensor.shape[0]], dtype=torch.long, device=accelerator.device)
+            torch.distributed.broadcast(size_tensor, src=0)
+            n_samples = size_tensor.item()
+
+            if not accelerator.is_main_process:
+                obs_tensor = torch.zeros(n_samples, obs_dim, device=accelerator.device)
+                action_tensor = torch.zeros(n_samples, dtype=torch.long, device=accelerator.device)
+
+            # Broadcast data from rank 0 (receive buffers assume float32 obs / int64 actions — TODO confirm against build_bc_tensors)
+            torch.distributed.broadcast(obs_tensor, src=0)
+            torch.distributed.broadcast(action_tensor, src=0)
+
+        obs_tensor = obs_tensor.to(accelerator.device)
+        action_tensor = action_tensor.to(accelerator.device)
+
+        agent = InterventionAgent(
+            detector_hidden=detector_hidden,
+            state_hidden=cfg["agent"]["state_hidden"],
+            dropout=cfg["agent"]["dropout"],
         )
 
-    # Stage 2: PPO fine-tuning
-    print("\n=== Stage 2: PPO Fine-tuning ===")
-    env_cfg = cfg.get("environment", {})
-    env = CompanionEnv(
-        samples=processed,
-        detector_hidden=detector_hidden,
-        reward_weights=cfg.get("reward"),
-        max_turns=env_cfg.get("max_turns", 20),
-    )
+        agent, _ = run_bc_warmup(agent, obs_tensor, action_tensor, cfg, accelerator)
 
-    output_cfg = cfg["output"]
-    Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
+    else:
+        agent = InterventionAgent(
+            detector_hidden=detector_hidden,
+            state_hidden=cfg["agent"]["state_hidden"],
+            dropout=cfg["agent"]["dropout"],
+        )
 
-    trainer.train(
-        env=env,
-        total_timesteps=cfg["ppo"]["total_timesteps"],
-        n_rollout_steps=cfg["ppo"]["n_rollout_steps"],
-        checkpoint_dir=output_cfg["checkpoint_dir"],
-        save_interval=output_cfg.get("save_interval", 10_000),
-    )
+    # ── Stage 2: PPO (main process only — inherently sequential) ─────────
+    accelerator.wait_for_everyone()
 
-    print("Training complete.")
+    if accelerator.is_main_process:
+        accelerator.print("\n=== Stage 2: PPO Fine-tuning (GPU-0 only) ===")
+
+        # Move agent to GPU-0
+        device = accelerator.device
+        agent = agent.to(device)
+
+        ppo_cfg = cfg["ppo"]
+        trainer = PPOTrainer(
+            agent=agent,
+            obs_dim=obs_dim,
+            lr=float(ppo_cfg["lr"]),  # PyYAML loads dot-less exponents such as "3e-4" as str
+            clip_eps=ppo_cfg["clip_eps"],
+            entropy_coef=ppo_cfg["entropy_coef"],
+            value_coef=ppo_cfg["value_coef"],
+            max_grad_norm=ppo_cfg["max_grad_norm"],
+            gamma=ppo_cfg["gamma"],
+            gae_lambda=ppo_cfg["gae_lambda"],
+            n_epochs=ppo_cfg["n_epochs"],
+            batch_size=ppo_cfg["batch_size"],
+            buffer_size=ppo_cfg["n_rollout_steps"],
+            device=str(device),
+            use_wandb=cfg["logging"]["use_wandb"],
+        )
+
+        env_cfg = cfg.get("environment", {})
+        env = CompanionEnv(
+            samples=processed,
+            detector_hidden=detector_hidden,
+            reward_weights=cfg.get("reward"),
+            max_turns=env_cfg.get("max_turns", 20),
+        )
+
+        output_cfg = cfg["output"]
+        Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
+
+        trainer.train(
+            env=env,
+            total_timesteps=ppo_cfg["total_timesteps"],
+            n_rollout_steps=ppo_cfg["n_rollout_steps"],
+            checkpoint_dir=output_cfg["checkpoint_dir"],
+            save_interval=output_cfg.get("save_interval", 10_000),
+        )
+
+    accelerator.print("Training complete.")
+
+    if cfg["logging"]["use_wandb"]:
+        accelerator.end_training()
 
 
 if __name__ == "__main__":