Files
CompanionGuard-RL/scripts/annotate_data.py
wangyu 4a0e71fb23 refactor: complete full implementation replacing all placeholder/mock content
Detection module (Module B):
- detector.py: expose separate e_P_pool and e_H_pool for RL state;
  fix compute_loss to skip primary head when c_primary="None"
- dataset.py: handle c_primary="None" safely; add validate_and_normalize

Data pipeline:
- data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe);
  systematic category→fine-label mapping; safe sample generation (25%);
  per-category risk level distribution; max_retries logic
- llm_judge.py: incremental file writing; rate limiting; retry logic;
  annotate_from_file convenience method; consistency validation
- annotate_data.py: stratified split by y_risk; dataset statistics report

RL module (Module C):
- ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple);
  fix action type passed to env.step; proper buffer reset and size tracking
- companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with
  auto-reset; correct Gymnasium interface

Shared utilities (new files):
- src/utils/preprocessing.py: preprocess_samples_with_detector using separate
  e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up
- src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b),
  CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention,
  LLMJudgePolicy for full baseline comparison

Scripts:
- train_intervention.py: use preprocessing module; separate e_H/e_P pools
- evaluate.py: proper module imports (no circular scripts import);
  full multi-baseline comparison; save all results to JSON
- generate_data.py: API key check; safe_ratio + max_retries CLI args

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-09 17:50:17 +08:00

173 lines
5.8 KiB
Python

"""
Step 2: LLM judge pre-annotation and dataset split.
Reads raw generated JSONL, runs LLM annotation on each sample,
then splits into train/val/test sets with stratified sampling by y_risk.
Usage:
# Annotate + split
python scripts/annotate_data.py \\
--input data/raw/generated.jsonl \\
--output data/processed/annotated.jsonl \\
--config configs/data_generation.yaml
# Skip annotation (labels already in input): normalize, save, then split
python scripts/annotate_data.py \\
--input data/raw/generated.jsonl \\
--output data/processed/annotated.jsonl \\
--skip-annotation
"""
import argparse
import json
import random
import yaml
from collections import Counter
from pathlib import Path
from typing import List, Dict, Tuple
from src.data.llm_judge import LLMJudge
from src.data.dataset import load_jsonl, save_jsonl, validate_and_normalize
def stratified_split(
    samples: List[Dict],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Split samples into train/val/test sets, stratified by ``y_risk``.

    Samples whose ``y_risk`` equals 1 (risky) and those whose ``y_risk``
    equals 0 (safe; also the default when the key is missing) are shuffled
    and cut independently, so every split keeps roughly the same risky/safe
    proportion as the input. Samples whose ``y_risk`` equals neither 0 nor 1
    are not placed in any split.

    Args:
        samples: Sample dicts; only the ``y_risk`` key is inspected.
        train_ratio: Fraction of each stratum assigned to the train split.
        val_ratio: Fraction of each stratum assigned to the val split.
            Whatever remains after train and val goes to the test split.
        seed: Seed for the shuffles; the same seed and input give the same
            splits.

    Returns:
        A ``(train, val, test)`` tuple of lists. Per-stratum sizes are
        truncated with ``int()``, so rounding leftovers end up in test.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1 (negative ratios would make the slices overlap).
    """
    # Small tolerance so float sums such as 0.7 + 0.3 are never rejected.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at "
            f"most 1 (got train_ratio={train_ratio}, val_ratio={val_ratio})"
        )

    # Use a private generator instead of random.seed(): it produces the same
    # shuffles for a given seed but leaves the process-wide RNG state alone.
    rng = random.Random(seed)

    risky = [s for s in samples if s.get("y_risk", 0) == 1]
    safe = [s for s in samples if s.get("y_risk", 0) == 0]
    rng.shuffle(risky)
    rng.shuffle(safe)

    def split_list(lst):
        """Cut one shuffled stratum into (train, val, test) slices."""
        n = len(lst)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)
        return lst[:n_train], lst[n_train:n_train + n_val], lst[n_train + n_val:]

    train_r, val_r, test_r = split_list(risky)
    train_s, val_s, test_s = split_list(safe)

    train = train_r + train_s
    val = val_r + val_s
    test = test_r + test_s

    # Interleave risky and safe samples inside each split.
    rng.shuffle(train)
    rng.shuffle(val)
    rng.shuffle(test)
    return train, val, test
def print_dataset_stats(name: str, samples: List[Dict]) -> None:
    """
    Print summary statistics for one dataset split.

    Reports the sample count, the risky/safe breakdown (``y_risk == 1``
    versus everything else), the ``l_risk`` distribution, the
    ``a_recommend`` distribution and, when any exist, the five most common
    ``c_primary`` categories.

    Args:
        name: Label shown in the header line (e.g. ``"train"``).
        samples: Sample dicts of the split; may be empty.
    """
    total = len(samples)
    n_risky = sum(1 for s in samples if s.get("y_risk", 0) == 1)
    n_safe = total - n_risky

    def pct(count):
        # An empty split (e.g. val on a tiny dataset) must report 0% rather
        # than raise ZeroDivisionError.
        return count / total * 100 if total else 0.0

    level_counts = Counter(s.get("l_risk", 0) for s in samples)
    action_counts = Counter(s.get("a_recommend", "PASS") for s in samples)
    # Samples without a category (key missing, null, or the "None" marker)
    # are left out so that "None" never shows up as a top category.
    category_counts = Counter(
        s["c_primary"] for s in samples
        if s.get("c_primary") not in (None, "None")
    )

    print(f"\n [{name}] {total} samples "
          f"(risky={n_risky} {pct(n_risky):.0f}%, "
          f"safe={n_safe} {pct(n_safe):.0f}%)")
    levels = {k: level_counts[k] for k in sorted(level_counts)}
    print(f" Risk levels: {levels}")
    print(f" Actions: {dict(action_counts)}")
    if category_counts:
        print(f" Top categories: {dict(category_counts.most_common(5))}")
def main():
    """
    Command-line entry point: obtain labelled samples, report stats, split.

    Three modes, selected by flags:

    * default: run ``LLMJudge.annotate_from_file`` on ``--input`` and write
      the annotated samples to ``--output``;
    * ``--skip-annotation``: load ``--input``, pass every sample through
      ``validate_and_normalize`` and save the result to ``--output``;
    * ``--split-only``: load already annotated samples from ``--output``.

    In every mode the samples are then split with ``stratified_split`` and
    written as ``train.jsonl``, ``val.jsonl`` and ``test.jsonl`` next to
    ``--output``.

    Raises:
        SystemExit: With status 1 when no samples are available, or via
            ``parser.error`` when ``--input`` is missing in a mode that
            needs it.
    """
    parser = argparse.ArgumentParser(
        description="Annotate and split CompanionGuard-RL dataset"
    )
    # --input is checked after parsing instead of with required=True:
    # --split-only reads from --output and never touches the input file.
    parser.add_argument("--input", help="Input JSONL file path")
    parser.add_argument(
        "--output", default="data/processed/annotated.jsonl",
        help="Output annotated JSONL path"
    )
    parser.add_argument("--config", default="configs/data_generation.yaml")
    parser.add_argument(
        "--skip-annotation", action="store_true",
        help="Skip LLM annotation (use labels already in input file)"
    )
    parser.add_argument(
        "--split-only", action="store_true",
        help="Only perform dataset split (no annotation), reads from --output"
    )
    args = parser.parse_args()
    if not args.split_only and not args.input:
        parser.error("--input is required unless --split-only is given")

    with open(args.config, encoding="utf-8") as f:
        # safe_load returns None for an empty document; fall back to {} so
        # the lookups below use their defaults instead of crashing.
        cfg = yaml.safe_load(f) or {}
    split_cfg = cfg.get("split") or {}
    ann_cfg = cfg.get("annotation") or {}
    output_dir = Path(args.output).parent

    # ── Step 1: Annotation ───────────────────────────────────────────────
    if args.split_only:
        samples = load_jsonl(args.output)
        print(f"Loaded {len(samples)} pre-annotated samples from {args.output}")
    else:
        # Both remaining modes write args.output, so its directory must exist.
        output_dir.mkdir(parents=True, exist_ok=True)
        if args.skip_annotation:
            samples = load_jsonl(args.input)
            samples = [validate_and_normalize(s) for s in samples]
            save_jsonl(samples, args.output)
            print(f"Loaded and normalized {len(samples)} samples (annotation skipped)")
        else:
            judge = LLMJudge(
                api_type=cfg["api"]["type"],
                model=ann_cfg.get("judge_model", cfg["api"]["model"]),
            )
            samples = judge.annotate_from_file(
                input_path=args.input,
                output_path=args.output,
                delay=ann_cfg.get("delay", 0.3),
                max_retries=ann_cfg.get("max_retries", 2),
            )

    if not samples:
        print("[ERROR] No samples to process after annotation. Check logs above.")
        raise SystemExit(1)

    # ── Step 2: Dataset statistics ───────────────────────────────────────
    print(f"\n{'='*50}")
    print("Dataset statistics:")
    print_dataset_stats("ALL", samples)

    # ── Step 3: Stratified split ─────────────────────────────────────────
    # Per-key defaults so a partial "split" section does not raise KeyError.
    train, val, test = stratified_split(
        samples,
        train_ratio=split_cfg.get("train", 0.8),
        val_ratio=split_cfg.get("val", 0.1),
        seed=split_cfg.get("seed", 42),
    )
    splits = {"train": train, "val": val, "test": test}
    paths = {split_name: output_dir / f"{split_name}.jsonl" for split_name in splits}

    for split_name, subset in splits.items():
        save_jsonl(subset, paths[split_name])
    for split_name, subset in splits.items():
        print_dataset_stats(split_name, subset)

    print("\nSplit complete:")
    for split_name, subset in splits.items():
        # Separator added: the count used to be printed fused to the path.
        print(f" {split_name} : {len(subset):4d} -> {paths[split_name]}")
# Run the annotate-and-split pipeline only when this file is executed as a
# script; importing the module does not start it.
if __name__ == "__main__":
    main()