""" Step 3: Train Module B — Context-aware Risk Detector. Usage: python scripts/train_detector.py --config configs/detector_config.yaml """ import argparse import yaml import torch import wandb from torch.utils.data import DataLoader from transformers import AutoTokenizer, get_linear_schedule_with_warmup from src.data.dataset import CompanionGuardDataset from src.models.detector import CompanionRiskDetector from src.utils.metrics import detection_metrics def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", default="configs/detector_config.yaml") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") if cfg["logging"]["use_wandb"]: wandb.init( project=cfg["logging"]["project"], name=cfg["logging"]["run_name"], config=cfg, ) tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"]) train_ds = CompanionGuardDataset( cfg["data"]["train_path"], tokenizer, max_persona_len=cfg["data"]["max_persona_len"], max_context_len=cfg["data"]["max_context_len"], max_response_len=cfg["data"]["max_response_len"], max_history_turns=cfg["data"]["max_history_turns"], ) val_ds = CompanionGuardDataset( cfg["data"]["val_path"], tokenizer, max_persona_len=cfg["data"]["max_persona_len"], max_context_len=cfg["data"]["max_context_len"], max_response_len=cfg["data"]["max_response_len"], max_history_turns=cfg["data"]["max_history_turns"], ) train_loader = DataLoader(train_ds, batch_size=cfg["training"]["batch_size"], shuffle=True) val_loader = DataLoader(val_ds, batch_size=cfg["training"]["batch_size"]) model = CompanionRiskDetector( model_name=cfg["model"]["name"], hidden_size=cfg["model"]["hidden_size"], num_heads=cfg["model"]["num_heads"], dropout=cfg["model"]["dropout"], use_lora=cfg["model"]["use_lora"], ).to(device) optimizer = torch.optim.AdamW( model.parameters(), lr=cfg["training"]["lr"], weight_decay=cfg["training"]["weight_decay"], ) total_steps = len(train_loader) * cfg["training"]["epochs"] scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=cfg["training"]["warmup_steps"], num_training_steps=total_steps, ) best_val_f1 = 0.0 global_step = 0 for epoch in range(cfg["training"]["epochs"]): model.train() for batch in train_loader: batch = {k: v.to(device) for k, v in batch.items()} logits = model( batch["persona_input_ids"], batch["persona_attention_mask"], batch["context_input_ids"], batch["context_attention_mask"], batch["response_input_ids"], batch["response_attention_mask"], ) loss, loss_parts = model.compute_loss( logits, {"y_risk": batch["y_risk"], "l_risk": batch["l_risk"], "c_primary": batch["c_primary"], "c_fine": batch["c_fine"]}, weights=cfg["loss_weights"], ) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_( model.parameters(), cfg["training"]["gradient_clip"] ) optimizer.step() scheduler.step() global_step += 1 if cfg["logging"]["use_wandb"] and global_step % 50 == 0: wandb.log({"train/loss": loss.item(), "step": global_step, **{f"train/{k}": v.item() for k, v in loss_parts.items()}}) if global_step % cfg["training"]["eval_steps"] == 0: val_f1 = evaluate(model, val_loader, device, cfg) print(f"Step {global_step}: Val binary F1 = {val_f1:.4f}") if val_f1 > best_val_f1: best_val_f1 = val_f1 import os os.makedirs(cfg["output"]["checkpoint_dir"], exist_ok=True) torch.save( model.state_dict(), f"{cfg['output']['checkpoint_dir']}/best.pt" ) model.train() print(f"Epoch {epoch + 1}/{cfg['training']['epochs']} done.") print(f"Training complete. Best val binary F1: {best_val_f1:.4f}") @torch.no_grad() def evaluate(model, loader, device, cfg): model.eval() all_y_true, all_y_pred = [], [] for batch in loader: batch = {k: v.to(device) for k, v in batch.items()} preds = model.predict( batch["persona_input_ids"], batch["persona_attention_mask"], batch["context_input_ids"], batch["context_attention_mask"], batch["response_input_ids"], batch["response_attention_mask"], binary_threshold=cfg["evaluation"]["binary_threshold"], ) all_y_true.extend(batch["y_risk"].int().cpu().tolist()) all_y_pred.extend(preds["y_risk"].cpu().tolist()) from sklearn.metrics import f1_score return f1_score(all_y_true, all_y_pred, average="binary", zero_division=0) if __name__ == "__main__": main()