Two-module pipeline for AI companion safety:
- Module B: context-aware risk detector with CrossAttention fusion.
- Module C: PPO-based adaptive intervention policy.

Includes the CompanionRisk Taxonomy (10 primary + 14 fine-grained labels), the dataset generation/annotation pipeline, training scripts, and an evaluation suite.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
"""
|
|
Step 3: Train Module B — Context-aware Risk Detector.
|
|
|
|
Usage:
|
|
python scripts/train_detector.py --config configs/detector_config.yaml
|
|
"""
|
|
|
|
import argparse
import os

import torch
import wandb
import yaml
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

from src.data.dataset import CompanionGuardDataset
from src.models.detector import CompanionRiskDetector
from src.utils.metrics import detection_metrics


def _build_dataset(path, tokenizer, data_cfg):
    """Build a ``CompanionGuardDataset`` for ``path`` using the limits in ``data_cfg``.

    The train and validation splits share every truncation setting, so they
    are read from one place instead of being repeated per split.
    """
    return CompanionGuardDataset(
        path,
        tokenizer,
        max_persona_len=data_cfg["max_persona_len"],
        max_context_len=data_cfg["max_context_len"],
        max_response_len=data_cfg["max_response_len"],
        max_history_turns=data_cfg["max_history_turns"],
    )


def _evaluate_and_checkpoint(model, val_loader, device, cfg, global_step, best_val_f1, use_wandb):
    """Score ``model`` on the validation set; save ``best.pt`` if the F1 improved.

    Args:
        model: the detector being trained.
        val_loader: validation DataLoader.
        device: device string the batches are moved to.
        cfg: full parsed config (``output.checkpoint_dir`` and the keys
            ``evaluate`` reads).
        global_step: optimizer step count, used for printing and logging.
        best_val_f1: best validation F1 seen so far.
        use_wandb: whether to log the validation score to wandb.

    Returns:
        The updated best validation F1: the new score if it is strictly
        higher than ``best_val_f1``, otherwise ``best_val_f1`` unchanged.
    """
    val_f1 = evaluate(model, val_loader, device, cfg)
    print(f"Step {global_step}: Val binary F1 = {val_f1:.4f}")
    if use_wandb:
        wandb.log({"val/binary_f1": val_f1, "step": global_step})

    if val_f1 <= best_val_f1:
        return best_val_f1

    checkpoint_dir = cfg["output"]["checkpoint_dir"]
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best.pt"))
    return val_f1


def main():
    """Train the context-aware risk detector (Module B) from a YAML config.

    Reads ``--config`` (default ``configs/detector_config.yaml``) and builds the
    train/val datasets and the detector. It then trains with AdamW plus a linear
    warmup/decay schedule. Validation binary risk F1 is scored every
    ``training.eval_steps`` optimizer steps and once more after the final step
    if that step was not already evaluated. The best-scoring weights are kept at
    ``<output.checkpoint_dir>/best.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/detector_config.yaml")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    data_cfg = cfg["data"]
    model_cfg = cfg["model"]
    train_cfg = cfg["training"]
    use_wandb = cfg["logging"]["use_wandb"]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    if use_wandb:
        wandb.init(
            project=cfg["logging"]["project"],
            name=cfg["logging"]["run_name"],
            config=cfg,
        )

    tokenizer = AutoTokenizer.from_pretrained(model_cfg["name"])
    train_ds = _build_dataset(data_cfg["train_path"], tokenizer, data_cfg)
    val_ds = _build_dataset(data_cfg["val_path"], tokenizer, data_cfg)

    train_loader = DataLoader(train_ds, batch_size=train_cfg["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=train_cfg["batch_size"])

    model = CompanionRiskDetector(
        model_name=model_cfg["name"],
        hidden_size=model_cfg["hidden_size"],
        num_heads=model_cfg["num_heads"],
        dropout=model_cfg["dropout"],
        use_lora=model_cfg["use_lora"],
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_cfg["lr"],
        weight_decay=train_cfg["weight_decay"],
    )
    total_steps = len(train_loader) * train_cfg["epochs"]
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=train_cfg["warmup_steps"],
        num_training_steps=total_steps,
    )

    # -inf (not 0.0) so the first evaluation always writes a checkpoint, even
    # when that early model scores F1 = 0.
    best_val_f1 = float("-inf")
    global_step = 0
    last_eval_step = 0

    for epoch in range(train_cfg["epochs"]):
        model.train()
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            logits = model(
                batch["persona_input_ids"], batch["persona_attention_mask"],
                batch["context_input_ids"], batch["context_attention_mask"],
                batch["response_input_ids"], batch["response_attention_mask"],
            )
            loss, loss_parts = model.compute_loss(
                logits,
                {"y_risk": batch["y_risk"], "l_risk": batch["l_risk"],
                 "c_primary": batch["c_primary"], "c_fine": batch["c_fine"]},
                weights=cfg["loss_weights"],
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_cfg["gradient_clip"])
            optimizer.step()
            scheduler.step()
            global_step += 1

            if use_wandb and global_step % 50 == 0:
                wandb.log({"train/loss": loss.item(), "step": global_step,
                           **{f"train/{k}": v.item() for k, v in loss_parts.items()}})

            if global_step % train_cfg["eval_steps"] == 0:
                best_val_f1 = _evaluate_and_checkpoint(
                    model, val_loader, device, cfg, global_step, best_val_f1, use_wandb
                )
                last_eval_step = global_step
                # evaluate() switches the model to eval mode; resume training mode.
                model.train()

        print(f"Epoch {epoch + 1}/{train_cfg['epochs']} done.")

    # The last periodic evaluation can fall short of the final step, or never
    # happen at all when eval_steps > total_steps. Score the final weights so
    # they are still considered for best.pt.
    if global_step != last_eval_step:
        best_val_f1 = _evaluate_and_checkpoint(
            model, val_loader, device, cfg, global_step, best_val_f1, use_wandb
        )

    print(f"Training complete. Best val binary F1: {best_val_f1:.4f}")

    if use_wandb:
        wandb.finish()


@torch.no_grad()
def evaluate(model, loader, device, cfg):
    """Return the binary F1 of ``model``'s risk predictions over ``loader``.

    Puts ``model`` in eval mode and does not switch it back. Callers that keep
    training must call ``model.train()`` themselves. Each batch is moved to
    ``device`` and fed to ``model.predict`` with ``evaluation.binary_threshold``
    from ``cfg``. Gold labels are the batch's ``y_risk`` cast to int, and
    predictions are ``predict(...)["y_risk"]``. Gradients are disabled for the
    whole call.

    Returns:
        scikit-learn binary F1 of the positive class, or 0 when it is
        undefined (``zero_division=0``).
    """
    # Positional inputs to model.predict, in the order it receives them.
    input_keys = (
        "persona_input_ids", "persona_attention_mask",
        "context_input_ids", "context_attention_mask",
        "response_input_ids", "response_attention_mask",
    )

    model.eval()
    gold, predicted = [], []

    for raw in loader:
        moved = {name: tensor.to(device) for name, tensor in raw.items()}
        outputs = model.predict(
            *(moved[key] for key in input_keys),
            binary_threshold=cfg["evaluation"]["binary_threshold"],
        )
        gold += moved["y_risk"].int().cpu().tolist()
        predicted += outputs["y_risk"].cpu().tolist()

    # sklearn is imported only once evaluation has actually run.
    from sklearn.metrics import f1_score

    return f1_score(gold, predicted, average="binary", zero_division=0)


if __name__ == "__main__":
|
|
main()
|