"""
Step 3: Train Module B — Context-aware Risk Detector.

Multi-GPU training via HuggingFace Accelerate (DDP, no NVLink required).
Mixed precision: BF16 (native on RTX 5090).

Usage (4 GPUs):
    accelerate launch --num_processes=4 --mixed_precision=bf16 \\
        scripts/train_detector.py --config configs/detector_config.yaml

Usage (single GPU for debugging):
    accelerate launch --num_processes=1 \\
        scripts/train_detector.py --config configs/detector_config.yaml

Or with torchrun:
    torchrun --nproc_per_node=4 scripts/train_detector.py \\
        --config configs/detector_config.yaml
"""

import argparse
import os
import json
import yaml
import torch
import numpy as np
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from accelerate import Accelerator
from accelerate.utils import set_seed, DistributedDataParallelKwargs

from src.data.dataset import CompanionGuardDataset, load_jsonl
from src.models.detector import CompanionRiskDetector
from src.utils.metrics import detection_metrics
from src.utils.taxonomy import FINE_GRAINED_LABELS, NUM_FINE


def compute_fine_pos_weight(train_path: str, device: str) -> torch.Tensor:
    """
    Compute per-label positive weights for BCEWithLogitsLoss on the fine-grained head.

    pos_weight[i] = (N_total - N_pos_i) / N_pos_i   (clipped to [1, 30])

    This corrects heavy class imbalance for rare labels like
    Romanticization/CoRumination which have only ~3.7% positive rate
    without weighting.

    Args:
        train_path: Path handed to ``load_jsonl``; each loaded sample is
            queried with ``.get("c_fine", [])`` for its fine-grained labels.
        device: Device on which the returned tensor is allocated. It is
            forwarded unchanged to ``torch.tensor``.

    Returns:
        A float32 tensor with one weight per entry of ``FINE_GRAINED_LABELS``,
        in the same order.
    """
    samples = load_jsonl(train_path)
    N = len(samples)
    pw = []
    for lbl in FINE_GRAINED_LABELS:
        pos = sum(1 for s in samples if lbl in s.get("c_fine", []))
        # max(pos, 1) avoids division by zero for labels with no positives.
        w = (N - pos) / max(pos, 1)
        pw.append(float(np.clip(w, 1.0, 30.0)))  # cap at 30 to prevent unstable gradients
    return torch.tensor(pw, dtype=torch.float32, device=device)


def make_loader(dataset, batch_size, accelerator, shuffle=True, num_workers=4):
    """Plain DataLoader — accelerator.prepare() adds DistributedSampler automatically.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Per-process batch size.
        accelerator: Not used inside this function; kept in the signature so
            existing call sites keep working.
        shuffle: Whether to shuffle. The same flag also drives ``drop_last``,
            so a shuffled loader drops its incomplete final batch while a
            non-shuffled one keeps every sample.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=shuffle,
    )


@torch.no_grad()
def evaluate(model, loader, accelerator, binary_threshold=0.5):
    """Evaluate on validation set across all processes, aggregate on main.

    Args:
        model: Prepared model; ``predict`` is called on the unwrapped module.
        loader: Validation loader yielding dict batches with the tokenized
            persona/context/response fields and ``y_risk``.
        accelerator: Accelerator used to unwrap the model and gather results.
        binary_threshold: Threshold forwarded to ``predict``.

    Returns:
        Binary F1 over the gathered predictions on the main process; ``0.0``
        on every other process.
    """
    model.eval()
    all_y_true, all_y_pred = [], []

    for batch in loader:
        preds = accelerator.unwrap_model(model).predict(
            batch["persona_input_ids"],
            batch["persona_attention_mask"],
            batch["context_input_ids"],
            batch["context_attention_mask"],
            batch["response_input_ids"],
            batch["response_attention_mask"],
            binary_threshold=binary_threshold,
        )
        # Gather predictions from all processes
        y_true_batch = accelerator.gather_for_metrics(batch["y_risk"].int())
        y_pred_batch = accelerator.gather_for_metrics(preds["y_risk"])
        all_y_true.extend(y_true_batch.cpu().tolist())
        all_y_pred.extend(y_pred_batch.cpu().tolist())

    if accelerator.is_main_process:
        from sklearn.metrics import f1_score

        return f1_score(all_y_true, all_y_pred, average="binary", zero_division=0)
    return 0.0


def main():
    """Train the risk detector as described by the YAML file given via ``--config``.

    Reads the config, builds the datasets, model, optimizer and scheduler,
    runs the training loop with periodic validation, writes ``best.pt``
    whenever validation binary F1 improves, and writes ``final.pt`` once
    training ends.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/detector_config.yaml")
    args = parser.parse_args()

    # Explicit encoding so parsing does not depend on the host locale.
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    train_cfg = cfg["training"]
    use_wandb = cfg["logging"]["use_wandb"]
    set_seed(train_cfg.get("seed", 42))

    # ── Accelerator setup ────────────────────────────────────────────────
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision=train_cfg.get("mixed_precision", "bf16"),
        gradient_accumulation_steps=train_cfg.get("gradient_accumulation_steps", 1),
        log_with="wandb" if use_wandb else None,
        kwargs_handlers=[ddp_kwargs],
    )

    accelerator.print(
        f"Running on {accelerator.num_processes} GPU(s), "
        f"mixed_precision={accelerator.mixed_precision}, "
        f"grad_accum={accelerator.gradient_accumulation_steps}"
    )

    # Init wandb only on main process
    if use_wandb:
        accelerator.init_trackers(
            project_name=cfg["logging"]["project"],
            config=cfg,
            init_kwargs={"wandb": {"name": cfg["logging"]["run_name"]}},
        )

    # ── Data ─────────────────────────────────────────────────────────────
    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
    data_cfg = cfg["data"]
    per_gpu_bs = train_cfg["per_gpu_batch_size"]
    num_workers = data_cfg.get("num_workers", 4)

    # Train and validation sets share the same truncation settings.
    dataset_kwargs = {
        "max_persona_len": data_cfg["max_persona_len"],
        "max_context_len": data_cfg["max_context_len"],
        "max_response_len": data_cfg["max_response_len"],
        "max_history_turns": data_cfg["max_history_turns"],
    }
    train_ds = CompanionGuardDataset(data_cfg["train_path"], tokenizer, **dataset_kwargs)
    val_ds = CompanionGuardDataset(data_cfg["val_path"], tokenizer, **dataset_kwargs)

    train_loader = make_loader(
        train_ds, per_gpu_bs, accelerator, shuffle=True, num_workers=num_workers
    )
    val_loader = make_loader(
        val_ds, per_gpu_bs, accelerator, shuffle=False, num_workers=num_workers
    )

    effective_batch = (
        per_gpu_bs
        * accelerator.num_processes
        * accelerator.gradient_accumulation_steps
    )
    accelerator.print(
        f"Dataset: {len(train_ds)} train / {len(val_ds)} val | "
        f"Effective batch size: {effective_batch}"
    )

    # ── Model ────────────────────────────────────────────────────────────
    model = CompanionRiskDetector(
        model_name=cfg["model"]["name"],
        hidden_size=cfg["model"]["hidden_size"],
        num_heads=cfg["model"]["num_heads"],
        dropout=cfg["model"]["dropout"],
        use_lora=cfg["model"]["use_lora"],
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(train_cfg["lr"]),
        weight_decay=float(train_cfg["weight_decay"]),
    )

    # Steps per epoch after accounting for gradient accumulation.
    # NOTE(review): len(train_loader) is read before accelerator.prepare(),
    # i.e. before the loader is sharded across processes. Presumably this is
    # intentional because the prepared scheduler advances once per process —
    # confirm against the Accelerate version in use.
    steps_per_epoch = len(train_loader) // accelerator.gradient_accumulation_steps
    total_steps = steps_per_epoch * train_cfg["epochs"]

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=train_cfg["warmup_steps"],
        num_training_steps=total_steps,
    )

    # Prepare: wraps model with DDP, DataLoaders with DistributedSampler
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )

    # ── Fine-label pos_weight (fixes all-negative bias for rare labels) ──
    fine_training_cfg = cfg.get("fine_training", {})
    use_fine_pos_weight = fine_training_cfg.get("use_pos_weight", False)
    fine_risky_only = fine_training_cfg.get("risky_only", False)

    fine_pos_weight = None
    if use_fine_pos_weight:
        fine_pos_weight = compute_fine_pos_weight(
            data_cfg["train_path"], device=accelerator.device
        )
        accelerator.print(
            f" Fine pos_weight (top-5 rare): "
            + ", ".join(
                f"{FINE_GRAINED_LABELS[i]}={fine_pos_weight[i]:.1f}"
                for i in fine_pos_weight.argsort(descending=True)[:5].tolist()
            )
        )
    if fine_risky_only:
        accelerator.print(" Fine loss: restricted to y_risk=1 samples (fine_risky_only=True)")

    # ── Training loop ────────────────────────────────────────────────────
    best_val_f1 = 0.0
    global_step = 0
    eval_steps = train_cfg["eval_steps"]
    binary_threshold = cfg["evaluation"]["binary_threshold"]
    checkpoint_dir = cfg["output"]["checkpoint_dir"]

    for epoch in range(train_cfg["epochs"]):
        model.train()

        # Update DistributedSampler epoch for proper shuffling
        if accelerator.num_processes > 1 and hasattr(train_loader.sampler, 'set_epoch'):
            train_loader.sampler.set_epoch(epoch)

        for batch in train_loader:
            with accelerator.accumulate(model):
                logits = model(
                    batch["persona_input_ids"],
                    batch["persona_attention_mask"],
                    batch["context_input_ids"],
                    batch["context_attention_mask"],
                    batch["response_input_ids"],
                    batch["response_attention_mask"],
                )
                loss, loss_parts = accelerator.unwrap_model(model).compute_loss(
                    logits,
                    {
                        "y_risk": batch["y_risk"],
                        "l_risk": batch["l_risk"],
                        "c_primary": batch["c_primary"],
                        "c_fine": batch["c_fine"],
                    },
                    weights=cfg["loss_weights"],
                    fine_pos_weight=fine_pos_weight,
                    fine_risky_only=fine_risky_only,
                )
                accelerator.backward(loss)

                # Clip only on iterations where gradients are synchronized.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model.parameters(), train_cfg["gradient_clip"]
                    )

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # NOTE(review): this advances once per loader batch, so with
            # gradient accumulation > 1 the logging/eval cadence below is
            # measured in micro-batches rather than optimizer updates —
            # confirm that is the intended meaning of `eval_steps`.
            global_step += 1

            # Log every 50 global steps (main process only)
            if use_wandb and global_step % 50 == 0:
                accelerator.log({
                    "train/loss": loss.item(),
                    "train/lr": scheduler.get_last_lr()[0],
                    "step": global_step,
                    **{f"train/{k}": v.item() for k, v in loss_parts.items()},
                }, step=global_step)

            # Periodic validation
            if global_step % eval_steps == 0:
                # Runs on every process: evaluate() gathers across ranks.
                val_f1 = evaluate(model, val_loader, accelerator, binary_threshold)
                accelerator.print(
                    f"Step {global_step} | Val binary F1 = {val_f1:.4f}"
                )

                if accelerator.is_main_process:
                    if use_wandb:
                        accelerator.log(
                            {"val/binary_f1": val_f1}, step=global_step
                        )
                    if val_f1 > best_val_f1:
                        best_val_f1 = val_f1
                        os.makedirs(checkpoint_dir, exist_ok=True)
                        ckpt_path = os.path.join(checkpoint_dir, "best.pt")
                        torch.save(
                            accelerator.unwrap_model(model).state_dict(),
                            ckpt_path,
                        )
                        accelerator.print(f" → Saved best model: {ckpt_path}")

                # evaluate() switched the model to eval mode; restore
                # training mode on every process before the next batch.
                model.train()

        accelerator.print(
            f"Epoch {epoch + 1}/{train_cfg['epochs']} done. "
            f"Best Val F1 so far: {best_val_f1:.4f}"
        )

    # Save final model.
    # The path is computed on every process because the closing message
    # below interpolates it regardless of rank.
    final_path = os.path.join(checkpoint_dir, "final.pt")
    if accelerator.is_main_process:
        # The directory is otherwise created only when a best checkpoint is
        # written; make sure it exists even if validation never improved or
        # never ran, so the final save cannot fail on a missing directory.
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(accelerator.unwrap_model(model).state_dict(), final_path)

    accelerator.print(
        f"\nTraining complete. Best val binary F1: {best_val_f1:.4f}\n"
        f"Final model saved to {final_path}"
    )

    if use_wandb:
        accelerator.end_training()


if __name__ == "__main__":
    main()