Files
CompanionGuard-RL/code/scripts/train_intervention.py
zhangsiyuan bd1f51c496 chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-14 11:28:42 +08:00

379 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Step 4: Train Module C — RL Intervention Policy (PPO).
Two-stage training:
Stage 1 (BC warm-up): behavior cloning on all 4 GPUs via Accelerate DDP
Stage 2 (PPO fine-tuning): single-GPU (GPU-0) offline RL — inherently sequential
Preprocessing (detector inference) is distributed across all 4 GPUs.
Usage (4 GPUs):
accelerate launch --num_processes=4 --mixed_precision=bf16 \\
scripts/train_intervention.py --config configs/intervention_config.yaml \\
--train-data data/processed/CompanionRisk-Bench/train.jsonl
Usage (single GPU):
accelerate launch --num_processes=1 \\
scripts/train_intervention.py --config configs/intervention_config.yaml
"""
import argparse
import os
import pickle
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler
from transformers import AutoTokenizer
from accelerate import Accelerator
from accelerate.utils import set_seed
from src.data.dataset import load_jsonl
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.rl.companion_env import CompanionEnv
from src.rl.ppo_trainer import PPOTrainer
from src.utils.preprocessing import (
preprocess_samples_with_detector,
build_bc_tensors,
)
from src.utils.taxonomy import NUM_RISK_LEVELS, NUM_PRIMARY
# Weights & Biases is an optional dependency: when it is not installed the
# module-level name ``wandb`` is None and experiment tracking is disabled.
try:
    import wandb
except ImportError:
    wandb = None
    _WANDB_AVAILABLE = False
else:
    _WANDB_AVAILABLE = True
def get_obs_dim(detector_hidden: int) -> int:
    """Return the length of the flat observation vector used by the agent.

    The total is the sum of two single scalar slots, ``NUM_RISK_LEVELS`` plus
    ``NUM_PRIMARY`` label slots, and two blocks of ``detector_hidden``
    features each.

    NOTE(review): what each slot contains is decided by the preprocessing /
    environment code, not here. Confirm against CompanionEnv before relying
    on a particular layout.

    Args:
        detector_hidden: Width of the detector hidden representation.

    Returns:
        The observation dimensionality as an int.
    """
    hidden_features = 2 * detector_hidden
    label_features = NUM_RISK_LEVELS + NUM_PRIMARY
    scalar_features = 2
    return scalar_features + label_features + hidden_features
def distributed_preprocess(
    raw_samples,
    detector,
    tokenizer,
    accelerator,
    binary_threshold: float = 0.5,
):
    """Run detector preprocessing sharded across all processes.

    Each process handles one contiguous shard of ``raw_samples``. The
    per-rank results are gathered in rank order, so the main process ends up
    with the processed samples in the original sample order.

    Args:
        raw_samples: Sliceable sequence of raw samples (indexed with
            ``raw_samples[start:end]``).
        detector: Detector model, expected to already be on
            ``accelerator.device``.
        tokenizer: Tokenizer passed through to
            ``preprocess_samples_with_detector``.
        accelerator: The active ``Accelerator``; supplies rank, world size
            and device.
        binary_threshold: Forwarded to ``preprocess_samples_with_detector``.

    Returns:
        On the main process, the full list of processed samples. On every
        other process, an empty list.
    """
    n = len(raw_samples)
    rank = accelerator.process_index
    world = accelerator.num_processes

    # Contiguous shards. The integer boundaries partition [0, n) exactly,
    # even when n is not divisible by world (some shards may then be empty).
    start = (n * rank) // world
    end = (n * (rank + 1)) // world
    local_samples = raw_samples[start:end]

    # Plain print (not accelerator.print, which only prints on the main
    # process) so that every rank reports the shard it owns.
    print(
        f"Preprocessing: rank {rank} handles samples [{start}, {end}) "
        f"({len(local_samples)} samples)",
        flush=True,
    )

    local_processed = preprocess_samples_with_detector(
        local_samples,
        detector,
        tokenizer,
        device=str(accelerator.device),
        binary_threshold=binary_threshold,
    )

    # With a single process there is nothing to gather. Also, Accelerate does
    # not initialise a process group in that case, so collectives would fail.
    if world == 1:
        return list(local_processed)

    # all_gather_object fills the list in rank order, which preserves the
    # original sample order once the shards are concatenated.
    all_shards = [None] * world
    torch.distributed.all_gather_object(all_shards, local_processed)

    if not accelerator.is_main_process:
        return []
    return [sample for shard in all_shards for sample in shard]
def _average_gradients_across_processes(module: nn.Module, world_size: int) -> None:
    """All-reduce every populated ``.grad`` of ``module`` to its mean over ranks.

    Assumes every rank runs the same computation graph, so the same set of
    parameters has a gradient on every rank. Parameters whose ``.grad`` is
    None are skipped.
    """
    for param in module.parameters():
        if param.grad is not None:
            torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.SUM)
            param.grad.div_(world_size)


def run_bc_warmup(
    agent: InterventionAgent,
    obs_tensor: torch.Tensor,
    action_tensor: torch.Tensor,
    cfg: dict,
    accelerator: Accelerator,
    use_wandb: bool = False,
):
    """Stage 1: behavior cloning of ``agent`` on (observation, action) pairs.

    Data-parallel over all processes. Each rank trains on its own
    ``DistributedSampler`` shard. Gradients are averaged explicitly across
    ranks after every backward pass, so all replicas stay identical.

    ``agent.behavior_clone_loss`` is a custom method rather than
    ``forward``. Calling it on a DDP-wrapped model's inner module would
    bypass DDP's gradient synchronisation, so the model is deliberately not
    wrapped in DDP here.

    Args:
        agent: The policy to train. It is moved to ``accelerator.device``.
        obs_tensor: Observations, one row per sample.
        action_tensor: Target actions aligned with ``obs_tensor``.
        cfg: Full config. Reads ``behavior_cloning.per_gpu_batch_size``
            (default 256), ``epochs`` (default 5) and ``lr`` (default 1e-3).
        accelerator: The active ``Accelerator``.
        use_wandb: Log per-epoch loss through ``accelerator.log`` on the main
            process.

    Returns:
        ``(agent, losses)``: the trained (unwrapped) agent and the list of
        per-epoch mean losses, averaged across ranks.
    """
    bc_cfg = cfg.get("behavior_cloning", {})
    per_gpu_bs = bc_cfg.get("per_gpu_batch_size", 256)
    n_epochs = bc_cfg.get("epochs", 5)
    lr = bc_cfg.get("lr", 1e-3)
    world = accelerator.num_processes
    device = accelerator.device

    dataset = TensorDataset(obs_tensor, action_tensor)
    # Shard explicitly and do NOT pass the loader through accelerator.prepare():
    # prepare() re-shards the batches, which on top of a DistributedSampler
    # would leave each rank with only ~1/world**2 of the data.
    sampler = None
    if world > 1:
        sampler = DistributedSampler(
            dataset,
            num_replicas=world,
            rank=accelerator.process_index,
            shuffle=True,
        )
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_bs,
        sampler=sampler,
        shuffle=(sampler is None),
        # Only CPU tensors can be pinned; pinning CUDA tensors raises.
        pin_memory=(obs_tensor.device.type == "cpu"),
        drop_last=False,
    )

    agent = agent.to(device)
    if world > 1:
        # Start every replica from rank 0's weights. DDP would otherwise do
        # this at wrap time. state_dict() tensors share storage with the
        # module, so the in-place broadcast updates the parameters.
        for tensor in agent.state_dict().values():
            torch.distributed.broadcast(tensor, src=0)

    # Preparing the optimizer keeps accelerator.backward() and the fp16 loss
    # scaler (if configured) consistent with optimizer.step().
    optimizer = accelerator.prepare(optim.Adam(agent.parameters(), lr=lr))

    losses = []
    for epoch in range(n_epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)  # different shuffle every epoch
        agent.train()
        epoch_loss = 0.0
        for obs_batch, act_batch in loader:
            obs_batch = obs_batch.to(device, non_blocking=True)
            act_batch = act_batch.to(device, non_blocking=True)
            # Apply the configured mixed precision explicitly, because the
            # loss is not computed through a prepared model's forward().
            with accelerator.autocast():
                loss = agent.behavior_clone_loss(obs_batch, act_batch)
            accelerator.backward(loss)
            if world > 1:
                # DistributedSampler pads every shard to the same length, so
                # all ranks reach this collective the same number of times.
                _average_gradients_across_processes(agent, world)
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / max(len(loader), 1)
        if world > 1:
            # Report the mean across ranks instead of rank 0's shard only.
            loss_t = torch.tensor(avg_loss, device=device)
            torch.distributed.all_reduce(loss_t, op=torch.distributed.ReduceOp.SUM)
            avg_loss = loss_t.item() / world
        losses.append(avg_loss)
        accelerator.print(f"[BC] Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}")
        if use_wandb and accelerator.is_main_process:
            accelerator.log({"bc/loss": avg_loss, "bc/epoch": epoch + 1})

    # Weights are identical on every rank (same init, averaged gradients).
    return agent, losses
# Location of the preprocessing cache. The paths are fixed and live in the
# node-local, shared /tmp directory.
_CACHE_PATH = Path("/tmp/companionguard_processed_cache.pkl")
_CACHE_META_PATH = Path("/tmp/companionguard_processed_cache.meta")


def _load_detector(detector_cfg: dict, accelerator):
    """Build the tokenizer and the detector on ``accelerator.device``, in eval mode.

    Loads ``detector_cfg["checkpoint"]`` if it exists; otherwise warns and
    keeps the randomly initialised weights.
    """
    tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"])
    detector = CompanionRiskDetector(
        model_name=detector_cfg["model_name"],
        hidden_size=detector_cfg["hidden_size"],
    ).to(accelerator.device)
    ckpt_path = detector_cfg["checkpoint"]
    if Path(ckpt_path).exists():
        detector.load_state_dict(
            torch.load(ckpt_path, map_location=accelerator.device)
        )
        accelerator.print(f"Detector loaded from {ckpt_path}")
    else:
        accelerator.print(f"[WARN] No detector checkpoint at {ckpt_path}. Using random weights.")
    detector.eval()
    return tokenizer, detector


def _preprocess_cache_key(train_data: str, n_samples: int, detector_cfg: dict,
                          binary_threshold: float) -> str:
    """Build a cache key covering every input this script feeds into preprocessing.

    The key includes the training file (path, mtime, sample count), the
    detector config, the detector checkpoint (path and mtime, or "missing")
    and the binary threshold. Retraining the detector or editing the data
    therefore invalidates the cache.
    """
    def _mtime(path) -> str:
        p = Path(path)
        return str(p.stat().st_mtime_ns) if p.exists() else "missing"

    return "|".join([
        str(train_data), _mtime(train_data), str(n_samples),
        str(detector_cfg["model_name"]), str(detector_cfg["hidden_size"]),
        str(detector_cfg["checkpoint"]), _mtime(detector_cfg["checkpoint"]),
        str(binary_threshold),
    ])


def _load_cached_processed(cache_key: str, accelerator):
    """Return cached preprocessed samples if the on-disk key matches, else None.

    A cache with a mismatching key is deleted. Intended to be called on the
    main process only.
    """
    if not (_CACHE_PATH.exists() and _CACHE_META_PATH.exists()):
        return None
    if _CACHE_META_PATH.read_text().strip() != cache_key:
        accelerator.print("[CACHE MISS] Cache key mismatch, re-preprocessing...")
        _CACHE_PATH.unlink(missing_ok=True)
        _CACHE_META_PATH.unlink(missing_ok=True)
        return None
    accelerator.print(f"[CACHE HIT] Loading preprocessed data from {_CACHE_PATH}")
    # SECURITY NOTE(review): pickle.load can execute arbitrary code, and /tmp
    # is writable by other users. Only reuse caches you created yourself.
    with open(_CACHE_PATH, "rb") as f:
        processed = pickle.load(f)
    accelerator.print(f"Loaded {len(processed)} cached samples.")
    return processed


def _save_processed_cache(processed, cache_key: str, accelerator) -> None:
    """Write the preprocessed samples and their key; the pickle is replaced atomically."""
    accelerator.print(f"Saving preprocessed cache to {_CACHE_PATH} ...")
    tmp_path = _CACHE_PATH.with_name(_CACHE_PATH.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump(processed, f)
    # Readers never observe a half-written pickle. The key is written last,
    # so a crash between the two steps just yields a cache miss next time.
    os.replace(tmp_path, _CACHE_PATH)
    _CACHE_META_PATH.write_text(cache_key)
    accelerator.print("Cache saved.")


def _get_processed(raw_samples, detector, tokenizer, accelerator,
                   binary_threshold: float, cache_key: str):
    """Return the preprocessed samples, using the disk cache when it is valid.

    Only the main process reads or writes the cache. The hit/miss decision is
    broadcast so that every rank takes the same branch; if ranks disagreed,
    some would enter the collective in ``distributed_preprocess`` and others
    would not, and the run would deadlock.
    Non-main processes always get an empty list.
    """
    processed = None
    if accelerator.is_main_process:
        processed = _load_cached_processed(cache_key, accelerator)

    cache_hit = [processed is not None]
    if accelerator.num_processes > 1:
        torch.distributed.broadcast_object_list(cache_hit, src=0)
    if cache_hit[0]:
        return processed if accelerator.is_main_process else []

    if accelerator.num_processes > 1:
        processed = distributed_preprocess(
            raw_samples, detector, tokenizer, accelerator, binary_threshold
        )
    else:
        processed = preprocess_samples_with_detector(
            raw_samples, detector, tokenizer,
            device=str(accelerator.device),
            binary_threshold=binary_threshold,
        )

    if accelerator.is_main_process and processed:
        _save_processed_cache(processed, cache_key, accelerator)
    return processed


def _build_bc_tensors_everywhere(processed, accelerator):
    """Build BC tensors on the main process and replicate them to every rank.

    The returned tensors live on ``accelerator.device`` on all ranks. The
    collectives run on device because Accelerate's multi-GPU process group
    uses NCCL, which rejects CPU tensors.
    """
    device = accelerator.device
    is_main = accelerator.is_main_process
    obs_tensor = action_tensor = None
    if is_main:
        obs_tensor, action_tensor = build_bc_tensors(processed, device="cpu")
        obs_tensor = obs_tensor.contiguous().to(device)
        action_tensor = action_tensor.contiguous().to(device)
    if accelerator.num_processes == 1:
        return obs_tensor, action_tensor

    # Send shapes and dtypes first so receivers can allocate matching buffers.
    meta = [None]
    if is_main:
        meta = [(tuple(obs_tensor.shape), obs_tensor.dtype,
                 tuple(action_tensor.shape), action_tensor.dtype)]
    torch.distributed.broadcast_object_list(meta, src=0)
    if not is_main:
        obs_shape, obs_dtype, act_shape, act_dtype = meta[0]
        obs_tensor = torch.empty(obs_shape, dtype=obs_dtype, device=device)
        action_tensor = torch.empty(act_shape, dtype=act_dtype, device=device)

    torch.distributed.broadcast(obs_tensor, src=0)
    torch.distributed.broadcast(action_tensor, src=0)
    return obs_tensor, action_tensor


def _build_agent(cfg: dict, detector_hidden: int):
    """Construct a fresh InterventionAgent from ``cfg["agent"]``."""
    agent_cfg = cfg["agent"]
    return InterventionAgent(
        detector_hidden=detector_hidden,
        state_hidden=agent_cfg["state_hidden"],
        dropout=agent_cfg["dropout"],
    )


def _run_ppo_stage(agent, processed, cfg: dict, obs_dim: int,
                   detector_hidden: int, device, use_wandb: bool) -> None:
    """Stage 2: offline PPO fine-tuning of ``agent`` on ``device`` (main process only)."""
    agent = agent.to(device)
    ppo_cfg = cfg["ppo"]
    trainer = PPOTrainer(
        agent=agent,
        obs_dim=obs_dim,
        lr=ppo_cfg["lr"],
        clip_eps=ppo_cfg["clip_eps"],
        entropy_coef=ppo_cfg["entropy_coef"],
        value_coef=ppo_cfg["value_coef"],
        max_grad_norm=ppo_cfg["max_grad_norm"],
        gamma=ppo_cfg["gamma"],
        gae_lambda=ppo_cfg["gae_lambda"],
        n_epochs=ppo_cfg["n_epochs"],
        batch_size=ppo_cfg["batch_size"],
        buffer_size=ppo_cfg["n_rollout_steps"],
        device=str(device),
        use_wandb=use_wandb,
    )
    env_cfg = cfg.get("environment", {})
    env = CompanionEnv(
        samples=processed,
        detector_hidden=detector_hidden,
        reward_weights=cfg.get("reward"),
        max_turns=env_cfg.get("max_turns", 20),
    )
    output_cfg = cfg["output"]
    Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
    trainer.train(
        env=env,
        total_timesteps=ppo_cfg["total_timesteps"],
        n_rollout_steps=ppo_cfg["n_rollout_steps"],
        checkpoint_dir=output_cfg["checkpoint_dir"],
        save_interval=output_cfg.get("save_interval", 10_000),
    )


def main():
    """Parse CLI args and run preprocessing, BC warm-up (all ranks) and PPO (main rank)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/intervention_config.yaml")
    parser.add_argument("--train-data",
                        default="data/processed/CompanionRisk-Bench/train.jsonl")
    args = parser.parse_args()
    with open(args.config) as f:
        cfg = yaml.safe_load(f)
    set_seed(42)

    # ── Accelerator for BC stage ─────────────────────────────────────────
    bc_cfg = cfg.get("behavior_cloning", {})
    use_wandb = cfg["logging"]["use_wandb"] and _WANDB_AVAILABLE
    accelerator = Accelerator(
        mixed_precision=bc_cfg.get("mixed_precision", "bf16"),
        gradient_accumulation_steps=1,
        log_with="wandb" if use_wandb else None,
    )
    accelerator.print(
        f"Running on {accelerator.num_processes} GPU(s), "
        f"mixed_precision={accelerator.mixed_precision}"
    )
    if use_wandb:
        accelerator.init_trackers(
            project_name=cfg["logging"]["project"],
            config=cfg,
            init_kwargs={"wandb": {"name": cfg["logging"]["run_name"]}},
        )

    # ── Detector (each process loads its own copy) ───────────────────────
    # Always built, even on a cache hit, so the seeded RNG stream consumed
    # before the agent is initialised is the same on every run.
    detector_cfg = cfg["detector"]
    tokenizer, detector = _load_detector(detector_cfg, accelerator)

    # ── Distributed preprocessing (with disk cache) ──────────────────────
    accelerator.print(f"Loading: {args.train_data}")
    raw_samples = load_jsonl(args.train_data)
    if not raw_samples:
        # Every rank reads the same file, so all ranks fail here together
        # instead of some hanging in a later collective.
        raise ValueError(f"No training samples found in {args.train_data}")
    accelerator.print(f"Preprocessing {len(raw_samples)} samples across {accelerator.num_processes} GPU(s)...")
    binary_threshold = cfg.get("evaluation", {}).get("binary_threshold", 0.5)
    cache_key = _preprocess_cache_key(
        args.train_data, len(raw_samples), detector_cfg, binary_threshold
    )
    processed = _get_processed(
        raw_samples, detector, tokenizer, accelerator, binary_threshold, cache_key
    )

    # The detector is only used for preprocessing; free its memory before training.
    del detector, tokenizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    detector_hidden = detector_cfg["hidden_size"]
    obs_dim = get_obs_dim(detector_hidden)
    accelerator.print(f"Observation dim: {obs_dim}")

    # ── Stage 1: Behavior Cloning (all GPUs) ────────────────────────────
    if bc_cfg.get("enabled", True):
        accelerator.print("\n=== Stage 1: Behavior Cloning Warm-up (all GPUs) ===")
        obs_tensor, action_tensor = _build_bc_tensors_everywhere(processed, accelerator)
        agent = _build_agent(cfg, detector_hidden)
        agent, _ = run_bc_warmup(agent, obs_tensor, action_tensor, cfg, accelerator, use_wandb=use_wandb)
        # Save BC-only checkpoint for ablation comparison
        if accelerator.is_main_process:
            output_cfg = cfg["output"]
            Path(output_cfg["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
            bc_ckpt_path = str(Path(output_cfg["checkpoint_dir"]) / "bc_only_v5.pt")
            torch.save(agent.state_dict(), bc_ckpt_path)
            accelerator.print(f"BC-only checkpoint saved: {bc_ckpt_path}")
    else:
        agent = _build_agent(cfg, detector_hidden)

    # ── Stage 2: PPO (main process only — inherently sequential) ─────────
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.print("\n=== Stage 2: PPO Fine-tuning (GPU-0 only) ===")
        _run_ppo_stage(
            agent, processed, cfg, obs_dim, detector_hidden,
            accelerator.device, use_wandb,
        )
    accelerator.print("Training complete.")
    if use_wandb:
        accelerator.end_training()


if __name__ == "__main__":
    main()