feat: initial CompanionGuard-RL framework

Two-module pipeline for AI companion safety:
- Module B: context-aware risk detector with CrossAttention fusion
- Module C: PPO-based adaptive intervention policy

Includes CompanionRisk Taxonomy (10 primary + 14 fine-grained labels),
dataset generation/annotation pipeline, training scripts, and eval suite.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-09 17:21:11 +08:00
commit 7d4345c29d
29 changed files with 3317 additions and 0 deletions

81
scripts/annotate_data.py Normal file
View File

@@ -0,0 +1,81 @@
"""
Step 2: LLM judge pre-annotation.
Usage:
python scripts/annotate_data.py --input data/raw/generated.jsonl \
--output data/processed/annotated.jsonl \
--config configs/data_generation.yaml
"""
import argparse
import json
import yaml
import random
from pathlib import Path
from src.data.llm_judge import LLMJudge
from src.data.dataset import load_jsonl
def split_dataset(samples, train_ratio=0.8, val_ratio=0.1, seed=42):
random.seed(seed)
random.shuffle(samples)
n = len(samples)
n_train = int(n * train_ratio)
n_val = int(n * val_ratio)
return (
samples[:n_train],
samples[n_train: n_train + n_val],
samples[n_train + n_val:],
)
def save_jsonl(samples, path):
Path(path).parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
for s in samples:
f.write(json.dumps(s, ensure_ascii=False) + "\n")
print(f"Saved {len(samples)} samples → {path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
parser.add_argument("--output", default="data/processed/annotated.jsonl")
parser.add_argument("--config", default="configs/data_generation.yaml")
parser.add_argument("--skip-annotation", action="store_true",
help="Skip LLM annotation (use existing labels)")
args = parser.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
samples = load_jsonl(args.input)
print(f"Loaded {len(samples)} samples from {args.input}")
if not args.skip_annotation:
judge = LLMJudge(
api_type=cfg["api"]["type"],
model=cfg["annotation"]["judge_model"],
)
samples = judge.annotate_batch(samples, output_path=args.output)
else:
save_jsonl(samples, args.output)
split_cfg = cfg.get("split", {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42})
train, val, test = split_dataset(
samples,
train_ratio=split_cfg["train"],
val_ratio=split_cfg["val"],
seed=split_cfg.get("seed", 42),
)
base = Path(args.output).parent
save_jsonl(train, base / "train.jsonl")
save_jsonl(val, base / "val.jsonl")
save_jsonl(test, base / "test.jsonl")
print(f"Split: train={len(train)}, val={len(val)}, test={len(test)}")
if __name__ == "__main__":
main()

193
scripts/evaluate.py Normal file
View File

@@ -0,0 +1,193 @@
"""
Evaluation script: run detection + intervention baselines and ours.
Usage:
python scripts/evaluate.py --detector-ckpt checkpoints/detector/best.pt \
--agent-ckpt checkpoints/intervention/final.pt \
--test-data data/processed/test.jsonl \
--config configs/detector_config.yaml
"""
import argparse
import yaml
import json
import torch
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer
from src.data.dataset import load_jsonl, format_conversation
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.utils.metrics import detection_metrics, intervention_metrics
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
INTERVENTION_ACTIONS,
NUM_RISK_LEVELS,
NUM_PRIMARY,
DEFAULT_ACTION_BY_LEVEL,
)
# ── Baselines ────────────────────────────────────────────────────────────────
def rule_based_policy(l_risk: int) -> int:
"""Baseline: fixed rule — l_risk >= 3 → REJECT, else PASS."""
return 3 if l_risk >= 3 else 0
def threshold_policy(l_risk: int) -> int:
"""Baseline: threshold mapping per risk level."""
return DEFAULT_ACTION_BY_LEVEL[l_risk]
# ── Main evaluation ──────────────────────────────────────────────────────────
def run_detection_eval(model, tokenizer, samples, cfg, device):
model.eval()
y_true, y_pred = [], []
l_true, l_pred = [], []
for sample in samples:
texts = format_conversation(
sample["persona"], sample["history"],
sample["user_input"], sample["ai_response"],
)
def enc(text, max_len):
return tokenizer(text, max_length=max_len, truncation=True,
padding="max_length", return_tensors="pt")
p_enc = enc(texts["persona_text"], 128)
c_enc = enc(texts["context_text"], 512)
r_enc = enc(texts["response_text"], 256)
with torch.no_grad():
preds = model.predict(
p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
)
y_true.append(sample["y_risk"])
y_pred.append(preds["y_risk"].item())
l_true.append(sample["l_risk"])
l_pred.append(preds["l_risk"].item())
return detection_metrics(y_true, y_pred, l_true, l_pred)
def run_intervention_eval(agent, processed_samples, obs_dim, device):
agent.eval()
y_risk_true, l_risk_true, a_pred, a_recommend = [], [], [], []
for s in processed_samples:
d_score = np.array([s["d_score"]], dtype=np.float32)
l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
l_risk_oh[int(s["l_risk"])] = 1.0
c_probs = np.array(s["c_primary_probs"], dtype=np.float32)
e_H = np.array(s["e_H_pool"], dtype=np.float32)
e_P = np.array(s["e_P_pool"], dtype=np.float32)
t_norm = np.array([len(s.get("history", [])) / 20.0], dtype=np.float32)
obs = torch.FloatTensor(
np.concatenate([d_score, l_risk_oh, c_probs, e_H, e_P, t_norm])
).unsqueeze(0).to(device)
with torch.no_grad():
action, _, _, _ = agent.get_action(obs, deterministic=True)
y_risk_true.append(s["y_risk"])
l_risk_true.append(int(s["l_risk"]))
a_pred.append(action.item())
a_recommend.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0))
return intervention_metrics(y_risk_true, l_risk_true, a_pred, a_recommend)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--detector-ckpt", required=True)
parser.add_argument("--agent-ckpt", default=None)
parser.add_argument("--test-data", default="data/processed/test.jsonl")
parser.add_argument("--config", default="configs/detector_config.yaml")
parser.add_argument("--intervention-config", default="configs/intervention_config.yaml")
args = parser.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
with open(args.intervention_config) as f:
int_cfg = yaml.safe_load(f)
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
samples = load_jsonl(args.test_data)
print(f"Loaded {len(samples)} test samples.")
# Detection evaluation
detector = CompanionRiskDetector(
model_name=cfg["model"]["name"],
hidden_size=cfg["model"]["hidden_size"],
).to(device)
detector.load_state_dict(torch.load(args.detector_ckpt, map_location=device))
print("\n=== Detection Evaluation ===")
det_metrics = run_detection_eval(detector, tokenizer, samples, cfg, device)
for k, v in det_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
# Intervention evaluation
if args.agent_ckpt:
from scripts.train_intervention import preprocess_samples_with_detector
detector_hidden = cfg["model"]["hidden_size"]
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
processed = preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device)
agent = InterventionAgent(
detector_hidden=detector_hidden,
state_hidden=int_cfg["agent"]["state_hidden"],
).to(device)
agent.load_state_dict(torch.load(args.agent_ckpt, map_location=device))
print("\n=== Intervention Evaluation: RL Policy (Ours) ===")
int_metrics = run_intervention_eval(agent, processed, obs_dim, device)
for k, v in int_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
elif isinstance(v, list):
print(f" {k}: {[f'{x:.3f}' for x in v]}")
print("\n=== Intervention Evaluation: Rule-based Baseline ===")
rule_preds = [rule_based_policy(s["l_risk"]) for s in processed]
rule_metrics = intervention_metrics(
[s["y_risk"] for s in processed],
[s["l_risk"] for s in processed],
rule_preds,
)
for k, v in rule_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
print("\n=== Intervention Evaluation: Threshold Baseline ===")
thr_preds = [threshold_policy(s["l_risk"]) for s in processed]
thr_metrics = intervention_metrics(
[s["y_risk"] for s in processed],
[s["l_risk"] for s in processed],
thr_preds,
)
for k, v in thr_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
# Save results
results = {"detection": det_metrics}
Path("experiments").mkdir(exist_ok=True)
with open("experiments/eval_results.json", "w") as f:
json.dump(results, f, indent=2, default=str)
print("\nResults saved to experiments/eval_results.json")
if __name__ == "__main__":
main()

40
scripts/generate_data.py Normal file
View File

@@ -0,0 +1,40 @@
"""
Step 1: Generate companion conversation dataset using LLM.
Usage:
python scripts/generate_data.py --config configs/data_generation.yaml
"""
import argparse
import yaml
from pathlib import Path
from src.data.data_generator import ConversationGenerator
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", default="configs/data_generation.yaml")
args = parser.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)
generator = ConversationGenerator(
api_type=cfg["api"]["type"],
model=cfg["api"]["model"],
)
count = generator.generate_dataset(
output_path=cfg["output"]["output_file"],
total_samples=cfg["generation"]["total_samples"],
samples_per_category=cfg["generation"]["samples_per_category"],
delay=cfg["generation"]["delay"],
)
print(f"Generated {count} samples → {cfg['output']['output_file']}")
if __name__ == "__main__":
main()

150
scripts/train_detector.py Normal file
View File

@@ -0,0 +1,150 @@
"""
Step 3: Train Module B — Context-aware Risk Detector.
Usage:
python scripts/train_detector.py --config configs/detector_config.yaml
"""
import argparse
import yaml
import torch
import wandb
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from src.data.dataset import CompanionGuardDataset
from src.models.detector import CompanionRiskDetector
from src.utils.metrics import detection_metrics
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", default="configs/detector_config.yaml")
args = parser.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if cfg["logging"]["use_wandb"]:
wandb.init(
project=cfg["logging"]["project"],
name=cfg["logging"]["run_name"],
config=cfg,
)
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["name"])
train_ds = CompanionGuardDataset(
cfg["data"]["train_path"], tokenizer,
max_persona_len=cfg["data"]["max_persona_len"],
max_context_len=cfg["data"]["max_context_len"],
max_response_len=cfg["data"]["max_response_len"],
max_history_turns=cfg["data"]["max_history_turns"],
)
val_ds = CompanionGuardDataset(
cfg["data"]["val_path"], tokenizer,
max_persona_len=cfg["data"]["max_persona_len"],
max_context_len=cfg["data"]["max_context_len"],
max_response_len=cfg["data"]["max_response_len"],
max_history_turns=cfg["data"]["max_history_turns"],
)
train_loader = DataLoader(train_ds, batch_size=cfg["training"]["batch_size"], shuffle=True)
val_loader = DataLoader(val_ds, batch_size=cfg["training"]["batch_size"])
model = CompanionRiskDetector(
model_name=cfg["model"]["name"],
hidden_size=cfg["model"]["hidden_size"],
num_heads=cfg["model"]["num_heads"],
dropout=cfg["model"]["dropout"],
use_lora=cfg["model"]["use_lora"],
).to(device)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=cfg["training"]["lr"],
weight_decay=cfg["training"]["weight_decay"],
)
total_steps = len(train_loader) * cfg["training"]["epochs"]
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=cfg["training"]["warmup_steps"],
num_training_steps=total_steps,
)
best_val_f1 = 0.0
global_step = 0
for epoch in range(cfg["training"]["epochs"]):
model.train()
for batch in train_loader:
batch = {k: v.to(device) for k, v in batch.items()}
logits = model(
batch["persona_input_ids"], batch["persona_attention_mask"],
batch["context_input_ids"], batch["context_attention_mask"],
batch["response_input_ids"], batch["response_attention_mask"],
)
loss, loss_parts = model.compute_loss(
logits,
{"y_risk": batch["y_risk"], "l_risk": batch["l_risk"],
"c_primary": batch["c_primary"], "c_fine": batch["c_fine"]},
weights=cfg["loss_weights"],
)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
model.parameters(), cfg["training"]["gradient_clip"]
)
optimizer.step()
scheduler.step()
global_step += 1
if cfg["logging"]["use_wandb"] and global_step % 50 == 0:
wandb.log({"train/loss": loss.item(), "step": global_step,
**{f"train/{k}": v.item() for k, v in loss_parts.items()}})
if global_step % cfg["training"]["eval_steps"] == 0:
val_f1 = evaluate(model, val_loader, device, cfg)
print(f"Step {global_step}: Val binary F1 = {val_f1:.4f}")
if val_f1 > best_val_f1:
best_val_f1 = val_f1
import os
os.makedirs(cfg["output"]["checkpoint_dir"], exist_ok=True)
torch.save(
model.state_dict(),
f"{cfg['output']['checkpoint_dir']}/best.pt"
)
model.train()
print(f"Epoch {epoch + 1}/{cfg['training']['epochs']} done.")
print(f"Training complete. Best val binary F1: {best_val_f1:.4f}")
@torch.no_grad()
def evaluate(model, loader, device, cfg):
model.eval()
all_y_true, all_y_pred = [], []
for batch in loader:
batch = {k: v.to(device) for k, v in batch.items()}
preds = model.predict(
batch["persona_input_ids"], batch["persona_attention_mask"],
batch["context_input_ids"], batch["context_attention_mask"],
batch["response_input_ids"], batch["response_attention_mask"],
binary_threshold=cfg["evaluation"]["binary_threshold"],
)
all_y_true.extend(batch["y_risk"].int().cpu().tolist())
all_y_pred.extend(preds["y_risk"].cpu().tolist())
from sklearn.metrics import f1_score
return f1_score(all_y_true, all_y_pred, average="binary", zero_division=0)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,197 @@
"""
Step 4: Train Module C — RL Intervention Policy (PPO).
Two-stage training:
Stage 1: Behavior cloning warm-up from a_recommend labels
Stage 2: PPO fine-tuning with multi-objective reward
Usage:
python scripts/train_intervention.py --config configs/intervention_config.yaml
"""
import argparse
import yaml
import torch
import numpy as np
import wandb
from pathlib import Path
from src.data.dataset import load_jsonl
from src.models.detector import CompanionRiskDetector
from src.models.intervention_agent import InterventionAgent
from src.rl.companion_env import CompanionEnv
from src.rl.ppo_trainer import PPOTrainer
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
NUM_RISK_LEVELS,
NUM_PRIMARY,
category_to_index,
)
from transformers import AutoTokenizer
def preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device):
"""Run detector on all samples to extract state vectors for RL env."""
from src.data.dataset import format_conversation
processed = []
detector.eval()
for sample in samples:
texts = format_conversation(
sample["persona"],
sample["history"],
sample["user_input"],
sample["ai_response"],
)
def enc(text, max_len):
return tokenizer(
text, max_length=max_len, truncation=True,
padding="max_length", return_tensors="pt",
)
p_enc = enc(texts["persona_text"], 128)
c_enc = enc(texts["context_text"], 512)
r_enc = enc(texts["response_text"], 256)
with torch.no_grad():
preds = detector.predict(
p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device),
c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device),
r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device),
)
# Build persona/history pool embeddings (reuse e_fused as approximation)
e_fused = preds["e_fused"].squeeze(0).cpu().numpy()
processed.append({
**sample,
"d_score": preds["d_score"].item(),
"l_risk": preds["l_risk"].item(),
"c_primary_probs": preds["c_primary_probs"].squeeze(0).cpu().numpy().tolist(),
"c_primary_idx": preds["c_primary"].item(),
"e_H_pool": e_fused.tolist(),
"e_P_pool": e_fused.tolist(),
"a_recommend": sample.get("a_recommend", "PASS"),
})
return processed
def build_bc_tensors(processed_samples, obs_dim, device):
"""Build observation and expert action tensors for behavior cloning."""
obs_list, action_list = [], []
for s in processed_samples:
d_score = np.array([s["d_score"]], dtype=np.float32)
l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
l_risk_oh[int(s["l_risk"])] = 1.0
c_probs = np.array(s["c_primary_probs"], dtype=np.float32)
e_H = np.array(s["e_H_pool"], dtype=np.float32)
e_P = np.array(s["e_P_pool"], dtype=np.float32)
t_norm = np.array([len(s.get("history", [])) / 20.0], dtype=np.float32)
obs = np.concatenate([d_score, l_risk_oh, c_probs, e_H, e_P, t_norm])
obs_list.append(obs)
action_list.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0))
obs_tensor = torch.FloatTensor(np.stack(obs_list)).to(device)
action_tensor = torch.LongTensor(action_list).to(device)
return obs_tensor, action_tensor
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", default="configs/intervention_config.yaml")
parser.add_argument("--train-data", default="data/processed/train.jsonl")
args = parser.parse_args()
with open(args.config) as f:
cfg = yaml.safe_load(f)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if cfg["logging"]["use_wandb"]:
wandb.init(
project=cfg["logging"]["project"],
name=cfg["logging"]["run_name"],
config=cfg,
)
# Load detector
tokenizer = AutoTokenizer.from_pretrained(cfg["detector"]["model_name"])
detector = CompanionRiskDetector(
model_name=cfg["detector"]["model_name"],
hidden_size=cfg["detector"]["hidden_size"],
).to(device)
detector.load_state_dict(torch.load(cfg["detector"]["checkpoint"], map_location=device))
detector.eval()
print("Detector loaded.")
# Load and preprocess training data
raw_samples = load_jsonl(args.train_data)
print(f"Preprocessing {len(raw_samples)} samples with detector...")
processed = preprocess_samples_with_detector(raw_samples, detector, tokenizer, cfg, device)
detector_hidden = cfg["detector"]["hidden_size"]
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
# Build RL agent
agent = InterventionAgent(
detector_hidden=detector_hidden,
state_hidden=cfg["agent"]["state_hidden"],
dropout=cfg["agent"]["dropout"],
)
trainer = PPOTrainer(
agent=agent,
obs_dim=obs_dim,
lr=cfg["ppo"]["lr"],
clip_eps=cfg["ppo"]["clip_eps"],
entropy_coef=cfg["ppo"]["entropy_coef"],
value_coef=cfg["ppo"]["value_coef"],
max_grad_norm=cfg["ppo"]["max_grad_norm"],
gamma=cfg["ppo"]["gamma"],
gae_lambda=cfg["ppo"]["gae_lambda"],
n_epochs=cfg["ppo"]["n_epochs"],
batch_size=cfg["ppo"]["batch_size"],
buffer_size=cfg["ppo"]["n_rollout_steps"],
device=device,
use_wandb=cfg["logging"]["use_wandb"],
)
# Stage 1: Behavior cloning warm-up
if cfg["behavior_cloning"]["enabled"]:
print("Stage 1: Behavior cloning warm-up...")
obs_tensor, action_tensor = build_bc_tensors(processed, obs_dim, device)
trainer.behavior_cloning_warmup(
obs_tensor, action_tensor,
n_epochs=cfg["behavior_cloning"]["epochs"],
lr=cfg["behavior_cloning"]["lr"],
)
# Stage 2: PPO fine-tuning
print("Stage 2: PPO fine-tuning...")
env = CompanionEnv(
samples=processed,
detector_hidden=detector_hidden,
reward_weights=cfg["reward"],
max_turns=cfg["environment"]["max_turns"],
)
Path(cfg["output"]["checkpoint_dir"]).mkdir(parents=True, exist_ok=True)
trainer.train(
env=env,
total_timesteps=cfg["ppo"]["total_timesteps"],
n_rollout_steps=cfg["ppo"]["n_rollout_steps"],
checkpoint_dir=cfg["output"]["checkpoint_dir"],
save_interval=cfg["output"]["save_interval"],
)
torch.save(agent.state_dict(), f"{cfg['output']['checkpoint_dir']}/final.pt")
print("Training complete.")
if __name__ == "__main__":
main()