2026-05-09 17:21:11 +08:00
|
|
|
|
"""
|
|
|
|
|
|
Step 4: Train Module C — RL Intervention Policy (PPO).
|
|
|
|
|
|
|
|
|
|
|
|
Two-stage training:
|
2026-05-09 17:56:13 +08:00
|
|
|
|
Stage 1 (BC warm-up): behavior cloning on all 4 GPUs via Accelerate DDP
|
|
|
|
|
|
Stage 2 (PPO fine-tuning): single-GPU (GPU-0) offline RL — inherently sequential
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
2026-05-09 17:56:13 +08:00
|
|
|
|
Preprocessing (detector inference) is distributed across all 4 GPUs.
|
|
|
|
|
|
|
|
|
|
|
|
Usage (4 GPUs):
|
|
|
|
|
|
accelerate launch --num_processes=4 --mixed_precision=bf16 \\
|
|
|
|
|
|
scripts/train_intervention.py --config configs/intervention_config.yaml \\
|
|
|
|
|
|
--train-data data/processed/train.jsonl
|
|
|
|
|
|
|
|
|
|
|
|
Usage (single GPU):
|
|
|
|
|
|
accelerate launch --num_processes=1 \\
|
|
|
|
|
|
scripts/train_intervention.py --config configs/intervention_config.yaml
|
2026-05-09 17:21:11 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
2026-05-09 17:56:13 +08:00
|
|
|
|
import os
|
2026-05-09 17:21:11 +08:00
|
|
|
|
import yaml
|
|
|
|
|
|
import torch
|
2026-05-09 17:56:13 +08:00
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
import torch.optim as optim
|
|
|
|
|
|
import numpy as np
|
2026-05-09 17:21:11 +08:00
|
|
|
|
from pathlib import Path
|
2026-05-09 17:56:13 +08:00
|
|
|
|
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler
|
2026-05-09 17:50:17 +08:00
|
|
|
|
from transformers import AutoTokenizer
|
2026-05-09 17:56:13 +08:00
|
|
|
|
from accelerate import Accelerator
|
|
|
|
|
|
from accelerate.utils import set_seed
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
from src.data.dataset import load_jsonl
|
|
|
|
|
|
from src.models.detector import CompanionRiskDetector
|
|
|
|
|
|
from src.models.intervention_agent import InterventionAgent
|
|
|
|
|
|
from src.rl.companion_env import CompanionEnv
|
|
|
|
|
|
from src.rl.ppo_trainer import PPOTrainer
|
2026-05-09 17:50:17 +08:00
|
|
|
|
from src.utils.preprocessing import (
|
|
|
|
|
|
preprocess_samples_with_detector,
|
|
|
|
|
|
build_bc_tensors,
|
2026-05-09 17:21:11 +08:00
|
|
|
|
)
|
2026-05-09 17:50:17 +08:00
|
|
|
|
from src.utils.taxonomy import NUM_RISK_LEVELS, NUM_PRIMARY
|
|
|
|
|
|
import wandb
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-05-09 17:50:17 +08:00
|
|
|
|
def get_obs_dim(detector_hidden: int) -> int:
    """Return the length of the flat observation vector used by the agent.

    The total is the sum of five segment widths: one scalar slot,
    ``NUM_RISK_LEVELS`` slots, ``NUM_PRIMARY`` slots, two blocks of
    ``detector_hidden`` features each, and one trailing scalar slot.

    NOTE(review): what each segment encodes is decided by the code that
    builds observations (not visible here) — confirm against it.

    Args:
        detector_hidden: Hidden size of the detector whose features are
            embedded in the observation.

    Returns:
        The observation dimensionality as an ``int``.
    """
    segment_widths = (
        1,
        NUM_RISK_LEVELS,
        NUM_PRIMARY,
        2 * detector_hidden,
        1,
    )
    return sum(segment_widths)
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-05-09 17:56:13 +08:00
|
|
|
|
def distributed_preprocess(
    raw_samples,
    detector,
    tokenizer,
    accelerator,
    binary_threshold: float = 0.5,
):
    """
    Distribute detector inference across all processes.

    Each process runs ``preprocess_samples_with_detector`` on one contiguous
    shard of ``raw_samples``. The shards are then exchanged with
    ``all_gather_object`` and concatenated in rank order, so the result keeps
    the original sample order.

    Args:
        raw_samples: Sliceable sequence of raw samples (must support ``len``).
        detector: Detector model forwarded to the preprocessing helper.
        tokenizer: Tokenizer forwarded to the preprocessing helper.
        accelerator: Object exposing ``process_index``, ``num_processes``,
            ``device``, ``is_main_process`` and ``print``.
        binary_threshold: Threshold forwarded to the preprocessing helper.

    Returns:
        On the main process: the list of all processed samples.
        On every other process: an empty list (callers only consume the
        result on the main process).
    """
    n = len(raw_samples)
    rank = accelerator.process_index
    world = accelerator.num_processes

    # Contiguous shard boundaries; integer arithmetic guarantees the shards
    # tile [0, n) exactly, with sizes differing by at most one.
    start = (n * rank) // world
    end = (n * (rank + 1)) // world
    local_samples = raw_samples[start:end]

    accelerator.print(
        f"Preprocessing: rank {rank} handles samples {start}–{end} "
        f"({len(local_samples)} samples)"
    )

    local_processed = preprocess_samples_with_detector(
        local_samples,
        detector,
        tokenizer,
        device=str(accelerator.device),
        binary_threshold=binary_threshold,
    )

    # With a single process the shard already is the whole dataset, and no
    # process group may exist, so the collective below would raise.
    if world == 1:
        return list(local_processed)

    # all_gather_object is a collective: every rank must call it, and every
    # rank receives all shards even though only the main process uses them.
    all_shards = [None] * world
    torch.distributed.all_gather_object(all_shards, local_processed)

    if not accelerator.is_main_process:
        return []

    processed = []
    for shard in all_shards:
        processed.extend(shard)
    return processed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_bc_warmup(
    agent: InterventionAgent,
    obs_tensor: torch.Tensor,
    action_tensor: torch.Tensor,
    cfg: dict,
    accelerator: Accelerator,
):
    """
    Stage 1: behavior cloning on all processes.

    Every process trains on its own shard of ``(obs_tensor, action_tensor)``
    and gradients are averaged across processes after each backward pass, so
    all replicas apply identical updates.

    Args:
        agent: Policy to train; must expose ``behavior_clone_loss(obs, act)``.
        obs_tensor: Observations, first dimension indexes samples.
        action_tensor: Target actions aligned with ``obs_tensor``.
        cfg: Full config. Reads the optional ``behavior_cloning`` section
            (``per_gpu_batch_size``, ``epochs``, ``lr``) and
            ``cfg["logging"]["use_wandb"]``.
        accelerator: Accelerate handle used for device placement, backward,
            cross-process reduction and logging.

    Returns:
        ``(agent, losses)``: the unwrapped agent and the list of per-epoch
        mean losses as measured on the calling process.
    """
    bc_cfg = cfg.get("behavior_cloning", {})
    per_gpu_bs = bc_cfg.get("per_gpu_batch_size", 256)
    n_epochs = bc_cfg.get("epochs", 5)
    lr = bc_cfg.get("lr", 1e-3)

    world = accelerator.num_processes
    device = accelerator.device
    dataset = TensorDataset(obs_tensor, action_tensor)

    sampler = None
    if world > 1:
        # DistributedSampler pads the index list so every rank sees the same
        # number of batches, which the per-step collective below relies on.
        sampler = DistributedSampler(
            dataset,
            num_replicas=world,
            rank=accelerator.process_index,
            shuffle=True,
        )

    loader = DataLoader(
        dataset,
        batch_size=per_gpu_bs,
        sampler=sampler,
        shuffle=(sampler is None),
        # Only CPU tensors can be pinned; the caller may already have moved
        # the data to the GPU, in which case pinning would raise.
        pin_memory=(obs_tensor.device.type == "cpu" and device.type == "cuda"),
        drop_last=False,
    )

    optimizer = optim.Adam(agent.parameters(), lr=lr)
    # The loader is intentionally NOT passed to prepare(): it is already
    # sharded by DistributedSampler, and prepare() would shard it again and
    # replace the sampler that set_epoch() must be called on.
    agent, optimizer = accelerator.prepare(agent, optimizer)

    # behavior_clone_loss is a custom method, so it has to be called on the
    # unwrapped module. That bypasses the DDP wrapper's forward(), which is
    # what arms DDP's gradient all-reduce — hence the explicit averaging below.
    policy = accelerator.unwrap_model(agent)
    trainable = [p for p in policy.parameters() if p.requires_grad]

    losses = []
    for epoch in range(n_epochs):
        if sampler is not None:
            # Reseed the shuffle so each epoch uses a different permutation.
            sampler.set_epoch(epoch)

        epoch_loss = 0.0
        agent.train()
        for obs_batch, act_batch in loader:
            obs_batch = obs_batch.to(device, non_blocking=True)
            act_batch = act_batch.to(device, non_blocking=True)

            loss = policy.behavior_clone_loss(obs_batch, act_batch)
            accelerator.backward(loss)

            if world > 1:
                # Average gradients across processes so replicas stay in sync.
                for param in trainable:
                    if param.grad is not None:
                        param.grad.copy_(
                            accelerator.reduce(param.grad, reduction="mean")
                        )

            optimizer.step()
            optimizer.zero_grad()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / max(len(loader), 1)
        losses.append(avg_loss)
        accelerator.print(f"[BC] Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}")

        if cfg["logging"]["use_wandb"] and accelerator.is_main_process:
            accelerator.log({"bc/loss": avg_loss, "bc/epoch": epoch + 1})

    # Return the unwrapped agent (weights are consistent across all processes)
    return accelerator.unwrap_model(agent), losses
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-05-09 17:21:11 +08:00
|
|
|
|
def main():
    """Run the two-stage intervention-policy training pipeline.

    Steps, in order:
      1. Parse ``--config`` / ``--train-data`` and load the YAML config.
      2. Create the Accelerator (and the W&B tracker when enabled).
      3. Load the detector and preprocess the training samples, sharded
         across processes when more than one is running.
      4. Stage 1 (optional): behavior-cloning warm-up on all processes.
      5. Stage 2: PPO fine-tuning on the main process only.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/intervention_config.yaml")
    parser.add_argument("--train-data", default="data/processed/train.jsonl")
    args = parser.parse_args()

    # Explicit encoding so the config parses the same on every platform.
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    set_seed(42)

    # ── Accelerator for BC stage ─────────────────────────────────────────
    bc_cfg = cfg.get("behavior_cloning", {})
    use_wandb = cfg["logging"]["use_wandb"]
    accelerator = Accelerator(
        mixed_precision=bc_cfg.get("mixed_precision", "bf16"),
        gradient_accumulation_steps=1,
        log_with="wandb" if use_wandb else None,
    )
    accelerator.print(
        f"Running on {accelerator.num_processes} GPU(s), "
        f"mixed_precision={accelerator.mixed_precision}"
    )

    if use_wandb:
        accelerator.init_trackers(
            project_name=cfg["logging"]["project"],
            config=cfg,
            init_kwargs={"wandb": {"name": cfg["logging"]["run_name"]}},
        )

    # ── Load detector (shared weights, each process loads its own copy) ──
    detector_cfg = cfg["detector"]
    tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"])
    detector = CompanionRiskDetector(
        model_name=detector_cfg["model_name"],
        hidden_size=detector_cfg["hidden_size"],
    ).to(accelerator.device)

    ckpt_path = detector_cfg["checkpoint"]
    if Path(ckpt_path).exists():
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        detector.load_state_dict(
            torch.load(ckpt_path, map_location=accelerator.device)
        )
        accelerator.print(f"Detector loaded from {ckpt_path}")
    else:
        accelerator.print(f"[WARN] No detector checkpoint at {ckpt_path}. Using random weights.")

    detector.eval()

    # ── Distributed preprocessing ────────────────────────────────────────
    accelerator.print(f"Loading: {args.train_data}")
    raw_samples = load_jsonl(args.train_data)
    accelerator.print(f"Preprocessing {len(raw_samples)} samples across {accelerator.num_processes} GPU(s)...")

    binary_threshold = cfg.get("evaluation", {}).get("binary_threshold", 0.5)

    if accelerator.num_processes > 1:
        # Use distributed preprocessing; only the main process gets the
        # gathered result, every other process gets an empty list.
        processed = distributed_preprocess(
            raw_samples, detector, tokenizer, accelerator, binary_threshold
        )
    else:
        processed = preprocess_samples_with_detector(
            raw_samples, detector, tokenizer,
            device=str(accelerator.device),
            binary_threshold=binary_threshold,
        )

    detector_hidden = detector_cfg["hidden_size"]
    obs_dim = get_obs_dim(detector_hidden)
    accelerator.print(f"Observation dim: {obs_dim}")

    # ── Stage 1: Behavior Cloning (all GPUs) ────────────────────────────
    bc_enabled = bc_cfg.get("enabled", True)
    if bc_enabled:
        accelerator.print("\n=== Stage 1: Behavior Cloning Warm-up (all GPUs) ===")

        device = accelerator.device
        obs_tensor = None
        action_tensor = None

        # Only the main process holds the processed samples, so it builds
        # the BC tensors and the others receive them via broadcast.
        if accelerator.is_main_process:
            obs_tensor, action_tensor = build_bc_tensors(processed, device="cpu")
            # Move to the accelerator device BEFORE any collective: the NCCL
            # backend used for multi-GPU runs only operates on CUDA tensors.
            obs_tensor = obs_tensor.to(device).contiguous()
            action_tensor = action_tensor.to(device).contiguous()

        if accelerator.num_processes > 1:
            # Broadcast the row count first so receivers can allocate buffers.
            n_rows = obs_tensor.shape[0] if accelerator.is_main_process else 0
            size_tensor = torch.tensor([n_rows], dtype=torch.long, device=device)
            torch.distributed.broadcast(size_tensor, src=0)
            n_samples = int(size_tensor.item())

            if not accelerator.is_main_process:
                # NOTE(review): receive buffers assume build_bc_tensors yields
                # float32 observations of width obs_dim and int64 actions —
                # confirm against that helper.
                obs_tensor = torch.zeros(n_samples, obs_dim, device=device)
                action_tensor = torch.zeros(n_samples, dtype=torch.long, device=device)

            # Broadcast data from rank 0 to all processes
            torch.distributed.broadcast(obs_tensor, src=0)
            torch.distributed.broadcast(action_tensor, src=0)

    agent = InterventionAgent(
        detector_hidden=detector_hidden,
        state_hidden=cfg["agent"]["state_hidden"],
        dropout=cfg["agent"]["dropout"],
    )

    if bc_enabled:
        agent, _ = run_bc_warmup(agent, obs_tensor, action_tensor, cfg, accelerator)

    # ── Stage 2: PPO (main process only — inherently sequential) ─────────
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        accelerator.print("\n=== Stage 2: PPO Fine-tuning (GPU-0 only) ===")

        # Move agent to GPU-0
        device = accelerator.device
        agent = agent.to(device)

        ppo_cfg = cfg["ppo"]
        trainer = PPOTrainer(
            agent=agent,
            obs_dim=obs_dim,
            lr=ppo_cfg["lr"],
            clip_eps=ppo_cfg["clip_eps"],
            entropy_coef=ppo_cfg["entropy_coef"],
            value_coef=ppo_cfg["value_coef"],
            max_grad_norm=ppo_cfg["max_grad_norm"],
            gamma=ppo_cfg["gamma"],
            gae_lambda=ppo_cfg["gae_lambda"],
            n_epochs=ppo_cfg["n_epochs"],
            batch_size=ppo_cfg["batch_size"],
            buffer_size=ppo_cfg["n_rollout_steps"],
            device=str(device),
            use_wandb=use_wandb,
        )

        env_cfg = cfg.get("environment", {})
        env = CompanionEnv(
            samples=processed,
            detector_hidden=detector_hidden,
            reward_weights=cfg.get("reward"),
            max_turns=env_cfg.get("max_turns", 20),
        )

        output_cfg = cfg["output"]
        Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)

        trainer.train(
            env=env,
            total_timesteps=ppo_cfg["total_timesteps"],
            n_rollout_steps=ppo_cfg["n_rollout_steps"],
            checkpoint_dir=output_cfg["checkpoint_dir"],
            save_interval=output_cfg.get("save_interval", 10_000),
        )

        accelerator.print("Training complete.")

    if use_wandb:
        accelerator.end_training()
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the training pipeline only when this file is executed as a script
# (e.g. through `accelerate launch`), not when it is imported.
if __name__ == "__main__":
    main()
|