# Two-module pipeline for AI companion safety:
#   - Module B: context-aware risk detector with CrossAttention fusion
#   - Module C: PPO-based adaptive intervention policy
#
# Includes the CompanionRisk Taxonomy (10 primary + 14 fine-grained labels),
# the dataset generation/annotation pipeline, training scripts, and an
# evaluation suite.
#
# Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
#
# File: 198 lines, 6.7 KiB, Python
"""
|
|
Step 4: Train Module C — RL Intervention Policy (PPO).
|
|
|
|
Two-stage training:
|
|
Stage 1: Behavior cloning warm-up from a_recommend labels
|
|
Stage 2: PPO fine-tuning with multi-objective reward
|
|
|
|
Usage:
|
|
python scripts/train_intervention.py --config configs/intervention_config.yaml
|
|
"""
|
|
|
|
import argparse
|
|
import yaml
|
|
import torch
|
|
import numpy as np
|
|
import wandb
|
|
from pathlib import Path
|
|
|
|
from src.data.dataset import load_jsonl
|
|
from src.models.detector import CompanionRiskDetector
|
|
from src.models.intervention_agent import InterventionAgent
|
|
from src.rl.companion_env import CompanionEnv
|
|
from src.rl.ppo_trainer import PPOTrainer
|
|
from src.utils.taxonomy import (
|
|
ACTION_NAME_TO_ID,
|
|
NUM_RISK_LEVELS,
|
|
NUM_PRIMARY,
|
|
category_to_index,
|
|
)
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
def preprocess_samples_with_detector(
    samples,
    detector,
    tokenizer,
    cfg,
    device,
    *,
    persona_max_len=128,
    context_max_len=512,
    response_max_len=256,
):
    """Run the detector on every sample and attach its predictions.

    For each sample the persona, context and response texts produced by
    ``format_conversation`` are tokenized to fixed lengths and passed to
    ``detector.predict``. The detector outputs are stored next to the
    original sample fields so the RL environment can build its state
    vector without calling the detector again.

    Args:
        samples: Iterable of dicts providing ``"persona"``, ``"history"``,
            ``"user_input"`` and ``"ai_response"``; ``"a_recommend"`` is
            optional and defaults to ``"PASS"``.
        detector: Model exposing ``eval()`` and ``predict(...)``. The value
            returned by ``predict`` is indexed with the keys ``"e_fused"``,
            ``"d_score"``, ``"l_risk"``, ``"c_primary_probs"`` and
            ``"c_primary"``.
        tokenizer: Callable tokenizer returning ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        cfg: Parsed configuration. Not read by this function; accepted so
            existing call sites keep working.
        device: Device the tokenized inputs are moved to before inference.
        persona_max_len: Token length the persona text is padded/truncated to.
        context_max_len: Token length the context text is padded/truncated to.
        response_max_len: Token length the response text is padded/truncated to.

    Returns:
        A list with one dict per input sample: the original fields plus
        ``d_score``, ``l_risk``, ``c_primary_probs``, ``c_primary_idx``,
        ``e_H_pool``, ``e_P_pool`` and ``a_recommend``.
    """
    from src.data.dataset import format_conversation

    def encode(text, max_len):
        """Tokenize ``text`` into a padded/truncated batch of size one."""
        return tokenizer(
            text,
            max_length=max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    # (key in the formatted conversation, token budget), in the positional
    # order expected by ``detector.predict``: persona, context, response.
    segments = (
        ("persona_text", persona_max_len),
        ("context_text", context_max_len),
        ("response_text", response_max_len),
    )

    processed = []
    detector.eval()

    # Inference only: enter no-grad once instead of once per sample.
    with torch.no_grad():
        for sample in samples:
            texts = format_conversation(
                sample["persona"],
                sample["history"],
                sample["user_input"],
                sample["ai_response"],
            )

            detector_inputs = []
            for key, max_len in segments:
                encoded = encode(texts[key], max_len)
                detector_inputs.append(encoded["input_ids"].to(device))
                detector_inputs.append(encoded["attention_mask"].to(device))

            preds = detector.predict(*detector_inputs)

            # Build persona/history pool embeddings (reuse e_fused as an
            # approximation for both).
            e_fused = preds["e_fused"].squeeze(0).cpu().numpy()

            processed.append({
                **sample,
                "d_score": preds["d_score"].item(),
                "l_risk": preds["l_risk"].item(),
                "c_primary_probs": (
                    preds["c_primary_probs"].squeeze(0).cpu().numpy().tolist()
                ),
                "c_primary_idx": preds["c_primary"].item(),
                # Two separate tolist() calls so the two entries never alias
                # the same list object.
                "e_H_pool": e_fused.tolist(),
                "e_P_pool": e_fused.tolist(),
                "a_recommend": sample.get("a_recommend", "PASS"),
            })

    return processed
|
|
|
|
|
|
def build_bc_tensors(processed_samples, obs_dim, device):
    """Build observation and expert-action tensors for behavior cloning.

    Each observation is the concatenation, in this order, of:

    * ``d_score`` (1 value),
    * a one-hot encoding of ``l_risk`` over ``NUM_RISK_LEVELS`` slots,
    * ``c_primary_probs``,
    * ``e_H_pool``,
    * ``e_P_pool``,
    * ``len(history) / 20.0`` (1 value; 0 when ``history`` is absent).

    The expert action is ``ACTION_NAME_TO_ID[a_recommend]``; names missing
    from the mapping fall back to action id 0.

    Args:
        processed_samples: Iterable of dicts carrying the keys listed above
            plus ``"a_recommend"``.
        obs_dim: Expected observation width. Not used while assembling the
            tensors; accepted so existing call sites keep working.
        device: Device on which both tensors are created.

    Returns:
        ``(obs_tensor, action_tensor)``: a float32 tensor with one row per
        sample and an int64 tensor with one action id per sample.

    Raises:
        ValueError: If ``processed_samples`` yields no samples.
    """
    obs_list, action_list = [], []

    for s in processed_samples:
        l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
        l_risk_oh[int(s["l_risk"])] = 1.0

        obs_list.append(
            np.concatenate([
                np.array([s["d_score"]], dtype=np.float32),
                l_risk_oh,
                np.array(s["c_primary_probs"], dtype=np.float32),
                np.array(s["e_H_pool"], dtype=np.float32),
                np.array(s["e_P_pool"], dtype=np.float32),
                np.array([len(s.get("history", [])) / 20.0], dtype=np.float32),
            ])
        )
        action_list.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0))

    if not obs_list:
        # np.stack([]) would raise the same exception type, but with a
        # message that says nothing about the actual cause.
        raise ValueError(
            "build_bc_tensors: no samples provided for behavior cloning"
        )

    # Explicit dtype/device instead of the legacy typed constructors.
    obs_tensor = torch.as_tensor(
        np.stack(obs_list), dtype=torch.float32, device=device
    )
    action_tensor = torch.tensor(action_list, dtype=torch.long, device=device)
    return obs_tensor, action_tensor
|
|
|
|
|
|
def main():
    """Train the intervention policy: optional BC warm-up, then PPO.

    Steps, in order:

    1. Parse ``--config`` and ``--train-data`` and load the YAML config.
    2. Optionally start a Weights & Biases run (``logging.use_wandb``).
    3. Load the detector checkpoint and run it over the training samples
       to obtain the per-sample features used by the RL environment.
    4. Build the ``InterventionAgent`` and the ``PPOTrainer``.
    5. Stage 1 (if ``behavior_cloning.enabled``): behavior-cloning warm-up.
    6. Stage 2: PPO fine-tuning in a ``CompanionEnv``.
    7. Save the final agent weights as ``final.pt`` in the checkpoint dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/intervention_config.yaml")
    parser.add_argument("--train-data", default="data/processed/train.jsonl")
    args = parser.parse_args()

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    use_wandb = cfg["logging"]["use_wandb"]
    if use_wandb:
        wandb.init(
            project=cfg["logging"]["project"],
            name=cfg["logging"]["run_name"],
            config=cfg,
        )

    # Load detector
    detector_cfg = cfg["detector"]
    tokenizer = AutoTokenizer.from_pretrained(detector_cfg["model_name"])
    detector = CompanionRiskDetector(
        model_name=detector_cfg["model_name"],
        hidden_size=detector_cfg["hidden_size"],
    ).to(device)
    # NOTE(security): torch.load unpickles the file; only point
    # ``detector.checkpoint`` at checkpoints from a trusted source.
    detector.load_state_dict(
        torch.load(detector_cfg["checkpoint"], map_location=device)
    )
    detector.eval()
    print("Detector loaded.")

    # Load and preprocess training data
    raw_samples = load_jsonl(args.train_data)
    print(f"Preprocessing {len(raw_samples)} samples with detector...")
    processed = preprocess_samples_with_detector(
        raw_samples, detector, tokenizer, cfg, device
    )

    # The detector is not used past this point; drop it so its weights do
    # not occupy accelerator memory during PPO training.
    del detector
    if device == "cuda":
        torch.cuda.empty_cache()

    detector_hidden = detector_cfg["hidden_size"]
    # d_score (1) + one-hot risk level + primary-category probs
    # + two pooled embeddings + normalized turn index (1).
    obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1

    # Build RL agent
    agent = InterventionAgent(
        detector_hidden=detector_hidden,
        state_hidden=cfg["agent"]["state_hidden"],
        dropout=cfg["agent"]["dropout"],
    )

    ppo_cfg = cfg["ppo"]
    trainer = PPOTrainer(
        agent=agent,
        obs_dim=obs_dim,
        lr=ppo_cfg["lr"],
        clip_eps=ppo_cfg["clip_eps"],
        entropy_coef=ppo_cfg["entropy_coef"],
        value_coef=ppo_cfg["value_coef"],
        max_grad_norm=ppo_cfg["max_grad_norm"],
        gamma=ppo_cfg["gamma"],
        gae_lambda=ppo_cfg["gae_lambda"],
        n_epochs=ppo_cfg["n_epochs"],
        batch_size=ppo_cfg["batch_size"],
        buffer_size=ppo_cfg["n_rollout_steps"],
        device=device,
        use_wandb=use_wandb,
    )

    # Stage 1: Behavior cloning warm-up
    bc_cfg = cfg["behavior_cloning"]
    if bc_cfg["enabled"]:
        print("Stage 1: Behavior cloning warm-up...")
        obs_tensor, action_tensor = build_bc_tensors(processed, obs_dim, device)
        trainer.behavior_cloning_warmup(
            obs_tensor,
            action_tensor,
            n_epochs=bc_cfg["epochs"],
            lr=bc_cfg["lr"],
        )

    # Stage 2: PPO fine-tuning
    print("Stage 2: PPO fine-tuning...")
    env = CompanionEnv(
        samples=processed,
        detector_hidden=detector_hidden,
        reward_weights=cfg["reward"],
        max_turns=cfg["environment"]["max_turns"],
    )

    checkpoint_dir = Path(cfg["output"]["checkpoint_dir"])
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    trainer.train(
        env=env,
        total_timesteps=ppo_cfg["total_timesteps"],
        n_rollout_steps=ppo_cfg["n_rollout_steps"],
        # Pass the configured value unchanged so the trainer receives the
        # same type it did before.
        checkpoint_dir=cfg["output"]["checkpoint_dir"],
        save_interval=cfg["output"]["save_interval"],
    )

    torch.save(agent.state_dict(), checkpoint_dir / "final.pt")
    print("Training complete.")
|
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    raise SystemExit(main())
|