Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
309 lines
12 KiB
Python
309 lines
12 KiB
Python
"""
|
|
Step 3: Train Module B — Context-aware Risk Detector.

Multi-GPU training via HuggingFace Accelerate (DDP, no NVLink required).
Mixed precision: BF16 (native on RTX 5090).

Usage (4 GPUs):
    accelerate launch --num_processes=4 --mixed_precision=bf16 \\
        scripts/train_detector.py --config configs/detector_config.yaml

Usage (single GPU for debugging):
    accelerate launch --num_processes=1 \\
        scripts/train_detector.py --config configs/detector_config.yaml

Or with torchrun:
    torchrun --nproc_per_node=4 scripts/train_detector.py \\
        --config configs/detector_config.yaml
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import json
|
|
import yaml
|
|
import torch
|
|
import numpy as np
|
|
from torch.utils.data import DataLoader, DistributedSampler
|
|
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
|
|
from accelerate import Accelerator
|
|
from accelerate.utils import set_seed, DistributedDataParallelKwargs
|
|
|
|
from src.data.dataset import CompanionGuardDataset, load_jsonl
|
|
from src.models.detector import CompanionRiskDetector
|
|
from src.utils.metrics import detection_metrics
|
|
from src.utils.taxonomy import FINE_GRAINED_LABELS, NUM_FINE
|
|
|
|
|
|
def compute_fine_pos_weight(
    train_path: str,
    device: "str | torch.device",
    min_weight: float = 1.0,
    max_weight: float = 30.0,
) -> torch.Tensor:
    """
    Compute per-label positive weights for BCEWithLogitsLoss on the fine-grained head.

        pos_weight[i] = (N_total - N_pos_i) / N_pos_i   (clipped to [min_weight, max_weight])

    This corrects heavy class imbalance for rare labels like Romanticization/CoRumination
    which have only ~3.7% positive rate without weighting.

    Args:
        train_path: Path handed to ``load_jsonl``; each loaded sample is a dict
            whose optional ``"c_fine"`` entry lists the fine-grained labels
            attached to that sample.
        device: Device on which the returned tensor is created. Both a string
            and a ``torch.device`` are accepted (``main`` passes
            ``accelerator.device``).
        min_weight: Lower clip bound. The default of 1 keeps frequent labels
            from being down-weighted below the negative class.
        max_weight: Upper clip bound. The default of 30 caps the weight of very
            rare labels to prevent unstable gradients.

    Returns:
        A float32 tensor with one weight per entry of ``FINE_GRAINED_LABELS``,
        in the same order.
    """
    from collections import Counter

    samples = load_jsonl(train_path)
    total = len(samples)

    # Single pass over the data instead of one full scan per label.
    # set() makes a sample count at most once per label even if its list
    # repeats the label; `or ()` tolerates a missing key and an explicit null.
    positives = Counter()
    for sample in samples:
        positives.update(set(sample.get("c_fine") or ()))

    weights = []
    for label in FINE_GRAINED_LABELS:
        n_pos = positives[label]
        # max(n_pos, 1) avoids division by zero for labels that never occur.
        raw = (total - n_pos) / max(n_pos, 1)
        weights.append(float(min(max(raw, min_weight), max_weight)))

    return torch.tensor(weights, dtype=torch.float32, device=device)
|
|
|
|
|
|
def make_loader(dataset, batch_size, accelerator, shuffle=True, num_workers=4):
    """Build a plain DataLoader; distribution is left to a later accelerator.prepare().

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch produced by this loader.
        accelerator: Accepted for call-site symmetry; not used in this function.
        shuffle: Whether to shuffle. The incomplete final batch is dropped
            exactly when shuffling is enabled.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The configured ``DataLoader`` with ``pin_memory`` enabled.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": True,
        # drop_last deliberately follows the shuffle flag.
        "drop_last": shuffle,
    }
    return DataLoader(dataset, **loader_options)
|
|
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, accelerator, binary_threshold=0.5):
    """Compute binary-risk F1 over ``loader``, gathering results from every process.

    Args:
        model: The (possibly wrapped) detector; it is switched to eval mode and
            left in that mode on return.
        loader: Iterable of batches. Each batch is indexed by the six
            ``*_input_ids`` / ``*_attention_mask`` keys plus ``"y_risk"``.
        accelerator: Used to unwrap the model, gather per-process tensors and
            identify the main process.
        binary_threshold: Forwarded to the model's ``predict`` method.

    Returns:
        The binary F1 score on the main process; ``0.0`` on every other process.
    """
    model.eval()
    gold, predicted = [], []

    for batch in loader:
        # predict() is called on the unwrapped module rather than the prepared wrapper.
        outputs = accelerator.unwrap_model(model).predict(
            batch["persona_input_ids"],
            batch["persona_attention_mask"],
            batch["context_input_ids"],
            batch["context_attention_mask"],
            batch["response_input_ids"],
            batch["response_attention_mask"],
            binary_threshold=binary_threshold,
        )

        # Collect labels first, then predictions, from all processes.
        gold += accelerator.gather_for_metrics(batch["y_risk"].int()).cpu().tolist()
        predicted += accelerator.gather_for_metrics(outputs["y_risk"]).cpu().tolist()

    if not accelerator.is_main_process:
        return 0.0

    # Imported lazily so only the main process needs scikit-learn at this point.
    from sklearn.metrics import f1_score

    return f1_score(gold, predicted, average="binary", zero_division=0)
|
|
|
|
|
|
def main():
    """Train the risk detector from a YAML config and save best/final checkpoints.

    Command line:
        --config  Path to the YAML config (default ``configs/detector_config.yaml``).

    Config sections read here: ``training``, ``logging``, ``model``, ``data``,
    ``loss_weights``, ``evaluation``, ``output`` and the optional
    ``fine_training``.

    Flow:
        1. Build the Accelerator (mixed precision, gradient accumulation,
           optional wandb tracking).
        2. Build tokenizer, datasets, loaders, model, AdamW and a linear
           warmup schedule, then hand them to ``accelerator.prepare``.
        3. Optionally compute per-label ``pos_weight`` for the fine-grained head.
        4. Train. ``global_step`` counts optimizer updates; every ``eval_steps``
           updates the validation F1 is computed and the best model is written
           to ``<checkpoint_dir>/best.pt``.
        5. Write ``<checkpoint_dir>/final.pt`` on the main process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/detector_config.yaml")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    train_cfg = cfg["training"]
    use_wandb = cfg["logging"]["use_wandb"]
    # Read up-front so a missing output section fails before training starts.
    checkpoint_dir = cfg["output"]["checkpoint_dir"]
    set_seed(train_cfg.get("seed", 42))

    # ── Accelerator setup ────────────────────────────────────────────────
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision=train_cfg.get("mixed_precision", "bf16"),
        gradient_accumulation_steps=train_cfg.get("gradient_accumulation_steps", 1),
        log_with="wandb" if use_wandb else None,
        kwargs_handlers=[ddp_kwargs],
    )

    accelerator.print(
        f"Running on {accelerator.num_processes} GPU(s), "
        f"mixed_precision={accelerator.mixed_precision}, "
        f"grad_accum={accelerator.gradient_accumulation_steps}"
    )

    # Tracker initialisation is only requested when wandb logging is enabled.
    if use_wandb:
        accelerator.init_trackers(
            project_name=cfg["logging"]["project"],
            config=cfg,
            init_kwargs={"wandb": {"name": cfg["logging"]["run_name"]}},
        )

    # ── Data ─────────────────────────────────────────────────────────────
    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
    data_cfg = cfg["data"]
    per_gpu_bs = train_cfg["per_gpu_batch_size"]
    num_workers = data_cfg.get("num_workers", 4)

    # Train and validation sets share the same truncation settings.
    dataset_kwargs = {
        "max_persona_len": data_cfg["max_persona_len"],
        "max_context_len": data_cfg["max_context_len"],
        "max_response_len": data_cfg["max_response_len"],
        "max_history_turns": data_cfg["max_history_turns"],
    }
    train_ds = CompanionGuardDataset(data_cfg["train_path"], tokenizer, **dataset_kwargs)
    val_ds = CompanionGuardDataset(data_cfg["val_path"], tokenizer, **dataset_kwargs)

    train_loader = make_loader(
        train_ds, per_gpu_bs, accelerator, shuffle=True, num_workers=num_workers
    )
    val_loader = make_loader(
        val_ds, per_gpu_bs, accelerator, shuffle=False, num_workers=num_workers
    )

    effective_batch = (
        per_gpu_bs
        * accelerator.num_processes
        * accelerator.gradient_accumulation_steps
    )
    accelerator.print(
        f"Dataset: {len(train_ds)} train / {len(val_ds)} val | "
        f"Effective batch size: {effective_batch}"
    )

    # ── Model ────────────────────────────────────────────────────────────
    model = CompanionRiskDetector(
        model_name=cfg["model"]["name"],
        hidden_size=cfg["model"]["hidden_size"],
        num_heads=cfg["model"]["num_heads"],
        dropout=cfg["model"]["dropout"],
        use_lora=cfg["model"]["use_lora"],
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(train_cfg["lr"]),
        weight_decay=float(train_cfg["weight_decay"]),
    )

    # Steps per epoch after accounting for gradient accumulation.
    # Measured on the loader *before* accelerator.prepare().
    steps_per_epoch = len(train_loader) // accelerator.gradient_accumulation_steps
    total_steps = steps_per_epoch * train_cfg["epochs"]

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=train_cfg["warmup_steps"],
        num_training_steps=total_steps,
    )

    # Prepare model, optimizer, loaders and scheduler for (distributed) training.
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )

    # ── Fine-label pos_weight (fixes all-negative bias for rare labels) ──
    fine_training_cfg = cfg.get("fine_training", {})
    use_fine_pos_weight = fine_training_cfg.get("use_pos_weight", False)
    fine_risky_only = fine_training_cfg.get("risky_only", False)
    fine_pos_weight = None
    if use_fine_pos_weight:
        fine_pos_weight = compute_fine_pos_weight(
            data_cfg["train_path"], device=accelerator.device
        )
        # Largest weights correspond to the rarest labels.
        accelerator.print(
            f"  Fine pos_weight (top-5 rare): "
            + ", ".join(
                f"{FINE_GRAINED_LABELS[i]}={fine_pos_weight[i]:.1f}"
                for i in fine_pos_weight.argsort(descending=True)[:5].tolist()
            )
        )
    if fine_risky_only:
        accelerator.print("  Fine loss: restricted to y_risk=1 samples (fine_risky_only=True)")

    # ── Training loop ────────────────────────────────────────────────────
    best_val_f1 = 0.0
    global_step = 0
    eval_steps = train_cfg["eval_steps"]
    binary_threshold = cfg["evaluation"]["binary_threshold"]

    for epoch in range(train_cfg["epochs"]):
        model.train()

        # Re-seed the sampler per epoch when it exposes set_epoch().
        if accelerator.num_processes > 1 and hasattr(train_loader.sampler, 'set_epoch'):
            train_loader.sampler.set_epoch(epoch)

        for batch in train_loader:
            with accelerator.accumulate(model):
                logits = model(
                    batch["persona_input_ids"], batch["persona_attention_mask"],
                    batch["context_input_ids"], batch["context_attention_mask"],
                    batch["response_input_ids"], batch["response_attention_mask"],
                )
                loss, loss_parts = accelerator.unwrap_model(model).compute_loss(
                    logits,
                    {
                        "y_risk": batch["y_risk"],
                        "l_risk": batch["l_risk"],
                        "c_primary": batch["c_primary"],
                        "c_fine": batch["c_fine"],
                    },
                    weights=cfg["loss_weights"],
                    fine_pos_weight=fine_pos_weight,
                    fine_risky_only=fine_risky_only,
                )

                accelerator.backward(loss)

                # Everything below runs once per optimizer update, i.e. only on
                # the micro-batch that completes an accumulation window. Keeping
                # logging and validation inside this guard prevents them from
                # re-firing while global_step is unchanged.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model.parameters(), train_cfg["gradient_clip"]
                    )
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                    # Log every 50 global steps.
                    if use_wandb and global_step % 50 == 0:
                        accelerator.log({
                            "train/loss": loss.item(),
                            "train/lr": scheduler.get_last_lr()[0],
                            "step": global_step,
                            **{f"train/{k}": v.item() for k, v in loss_parts.items()},
                        }, step=global_step)

                    # Periodic validation; evaluate() is called on every process
                    # because it gathers tensors across processes.
                    if global_step % eval_steps == 0:
                        val_f1 = evaluate(model, val_loader, accelerator, binary_threshold)
                        accelerator.print(
                            f"Step {global_step} | Val binary F1 = {val_f1:.4f}"
                        )

                        # Only the main process receives the real F1 from
                        # evaluate(), so only it tracks and saves the best model.
                        if accelerator.is_main_process:
                            if use_wandb:
                                accelerator.log(
                                    {"val/binary_f1": val_f1}, step=global_step
                                )
                            if val_f1 > best_val_f1:
                                best_val_f1 = val_f1
                                os.makedirs(checkpoint_dir, exist_ok=True)
                                ckpt_path = os.path.join(checkpoint_dir, "best.pt")
                                torch.save(
                                    accelerator.unwrap_model(model).state_dict(),
                                    ckpt_path,
                                )
                                accelerator.print(f"  → Saved best model: {ckpt_path}")

                        # evaluate() leaves the model in eval mode.
                        model.train()

        accelerator.print(
            f"Epoch {epoch + 1}/{train_cfg['epochs']} done. "
            f"Best Val F1 so far: {best_val_f1:.4f}"
        )

    # Save final model.
    # Let every process finish its last update before the weights are written.
    accelerator.wait_for_everyone()
    # Bound on every rank: the closing message formats it regardless of rank.
    final_path = os.path.join(checkpoint_dir, "final.pt")
    if accelerator.is_main_process:
        # The directory may not exist yet if no validation run ever improved
        # on the initial best F1 (or none ran at all).
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(accelerator.unwrap_model(model).state_dict(), final_path)
    accelerator.print(
        f"\nTraining complete. Best val binary F1: {best_val_f1:.4f}\n"
        f"Final model saved to {final_path}"
    )

    if use_wandb:
        accelerator.end_training()
|
|
|
|
|
|
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|