refactor: complete full implementation replacing all placeholder/mock content
Detection module (Module B): - detector.py: expose separate e_P_pool and e_H_pool for RL state; fix compute_loss to skip primary head when c_primary="None" - dataset.py: handle c_primary="None" safely; add validate_and_normalize Data pipeline: - data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe); systematic category→fine-label mapping; safe sample generation (25%); per-category risk level distribution; max_retries logic - llm_judge.py: incremental file writing; rate limiting; retry logic; annotate_from_file convenience method; consistency validation - annotate_data.py: stratified split by y_risk; dataset statistics report RL module (Module C): - ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple); fix action type passed to env.step; proper buffer reset and size tracking - companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with auto-reset; correct Gymnasium interface Shared utilities (new files): - src/utils/preprocessing.py: preprocess_samples_with_detector using separate e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up - src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b), CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention, LLMJudgePolicy for full baseline comparison Scripts: - train_intervention.py: use preprocessing module; separate e_H/e_P pools - evaluate.py: proper module imports (no circular scripts import); full multi-baseline comparison; save all results to JSON - generate_data.py: API key check; safe_ratio + max_retries CLI args Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,80 +1,171 @@
|
||||
"""
|
||||
Step 2: LLM judge pre-annotation.
|
||||
Step 2: LLM judge pre-annotation and dataset split.
|
||||
|
||||
Reads raw generated JSONL, runs LLM annotation on each sample,
|
||||
then splits into train/val/test sets with stratified sampling by y_risk.
|
||||
|
||||
Usage:
|
||||
python scripts/annotate_data.py --input data/raw/generated.jsonl \
|
||||
--output data/processed/annotated.jsonl \
|
||||
--config configs/data_generation.yaml
|
||||
# Annotate + split
|
||||
python scripts/annotate_data.py \\
|
||||
--input data/raw/generated.jsonl \\
|
||||
--output data/processed/annotated.jsonl \\
|
||||
--config configs/data_generation.yaml
|
||||
|
||||
# Skip annotation (already annotated), only split
|
||||
python scripts/annotate_data.py \\
|
||||
--input data/raw/generated.jsonl \\
|
||||
--output data/processed/annotated.jsonl \\
|
||||
--skip-annotation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import yaml
|
||||
import random
|
||||
import yaml
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
from src.data.llm_judge import LLMJudge
|
||||
from src.data.dataset import load_jsonl
|
||||
from src.data.dataset import load_jsonl, save_jsonl, validate_and_normalize
|
||||
|
||||
|
||||
def split_dataset(samples, train_ratio=0.8, val_ratio=0.1, seed=42):
|
||||
def stratified_split(
|
||||
samples: List[Dict],
|
||||
train_ratio: float = 0.8,
|
||||
val_ratio: float = 0.1,
|
||||
seed: int = 42,
|
||||
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
||||
"""
|
||||
Stratified split by y_risk to ensure each split has balanced risk/safe samples.
|
||||
"""
|
||||
random.seed(seed)
|
||||
random.shuffle(samples)
|
||||
n = len(samples)
|
||||
n_train = int(n * train_ratio)
|
||||
n_val = int(n * val_ratio)
|
||||
return (
|
||||
samples[:n_train],
|
||||
samples[n_train: n_train + n_val],
|
||||
samples[n_train + n_val:],
|
||||
|
||||
risky = [s for s in samples if s.get("y_risk", 0) == 1]
|
||||
safe = [s for s in samples if s.get("y_risk", 0) == 0]
|
||||
|
||||
random.shuffle(risky)
|
||||
random.shuffle(safe)
|
||||
|
||||
def split_list(lst):
|
||||
n = len(lst)
|
||||
n_train = int(n * train_ratio)
|
||||
n_val = int(n * val_ratio)
|
||||
return lst[:n_train], lst[n_train:n_train + n_val], lst[n_train + n_val:]
|
||||
|
||||
train_r, val_r, test_r = split_list(risky)
|
||||
train_s, val_s, test_s = split_list(safe)
|
||||
|
||||
train = train_r + train_s
|
||||
val = val_r + val_s
|
||||
test = test_r + test_s
|
||||
|
||||
random.shuffle(train)
|
||||
random.shuffle(val)
|
||||
random.shuffle(test)
|
||||
|
||||
return train, val, test
|
||||
|
||||
|
||||
def print_dataset_stats(name: str, samples: List[Dict]):
|
||||
"""Print per-split statistics."""
|
||||
total = len(samples)
|
||||
n_risky = sum(1 for s in samples if s.get("y_risk", 0) == 1)
|
||||
n_safe = total - n_risky
|
||||
|
||||
level_counts = Counter(s.get("l_risk", 0) for s in samples)
|
||||
action_counts = Counter(s.get("a_recommend", "PASS") for s in samples)
|
||||
category_counts = Counter(
|
||||
s.get("c_primary", "None") for s in samples if s.get("c_primary") != "None"
|
||||
)
|
||||
|
||||
|
||||
def save_jsonl(samples, path):
|
||||
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
for s in samples:
|
||||
f.write(json.dumps(s, ensure_ascii=False) + "\n")
|
||||
print(f"Saved {len(samples)} samples → {path}")
|
||||
print(f"\n [{name}] {total} samples "
|
||||
f"(risky={n_risky} {n_risky/total*100:.0f}%, "
|
||||
f"safe={n_safe} {n_safe/total*100:.0f}%)")
|
||||
print(f" Risk levels: { {k: level_counts[k] for k in sorted(level_counts)} }")
|
||||
print(f" Actions: { dict(action_counts) }")
|
||||
if category_counts:
|
||||
top5 = category_counts.most_common(5)
|
||||
print(f" Top categories: { {k: v for k, v in top5} }")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", required=True)
|
||||
parser.add_argument("--output", default="data/processed/annotated.jsonl")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Annotate and split CompanionGuard-RL dataset"
|
||||
)
|
||||
parser.add_argument("--input", required=True, help="Input JSONL file path")
|
||||
parser.add_argument(
|
||||
"--output", default="data/processed/annotated.jsonl",
|
||||
help="Output annotated JSONL path"
|
||||
)
|
||||
parser.add_argument("--config", default="configs/data_generation.yaml")
|
||||
parser.add_argument("--skip-annotation", action="store_true",
|
||||
help="Skip LLM annotation (use existing labels)")
|
||||
parser.add_argument(
|
||||
"--skip-annotation", action="store_true",
|
||||
help="Skip LLM annotation (use labels already in input file)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-only", action="store_true",
|
||||
help="Only perform dataset split (no annotation), reads from --output"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
samples = load_jsonl(args.input)
|
||||
print(f"Loaded {len(samples)} samples from {args.input}")
|
||||
split_cfg = cfg.get("split", {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42})
|
||||
ann_cfg = cfg.get("annotation", {})
|
||||
output_dir = Path(args.output).parent
|
||||
|
||||
if not args.skip_annotation:
|
||||
# ── Step 1: Annotation ───────────────────────────────────────────────
|
||||
if args.split_only:
|
||||
samples = load_jsonl(args.output)
|
||||
print(f"Loaded {len(samples)} pre-annotated samples from {args.output}")
|
||||
elif args.skip_annotation:
|
||||
samples = load_jsonl(args.input)
|
||||
samples = [validate_and_normalize(s) for s in samples]
|
||||
save_jsonl(samples, args.output)
|
||||
print(f"Loaded and normalized {len(samples)} samples (annotation skipped)")
|
||||
else:
|
||||
judge = LLMJudge(
|
||||
api_type=cfg["api"]["type"],
|
||||
model=cfg["annotation"]["judge_model"],
|
||||
model=ann_cfg.get("judge_model", cfg["api"]["model"]),
|
||||
)
|
||||
samples = judge.annotate_from_file(
|
||||
input_path=args.input,
|
||||
output_path=args.output,
|
||||
delay=ann_cfg.get("delay", 0.3),
|
||||
max_retries=ann_cfg.get("max_retries", 2),
|
||||
)
|
||||
samples = judge.annotate_batch(samples, output_path=args.output)
|
||||
else:
|
||||
save_jsonl(samples, args.output)
|
||||
|
||||
split_cfg = cfg.get("split", {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42})
|
||||
train, val, test = split_dataset(
|
||||
if not samples:
|
||||
print("[ERROR] No samples to process after annotation. Check logs above.")
|
||||
raise SystemExit(1)
|
||||
|
||||
# ── Step 2: Dataset statistics ───────────────────────────────────────
|
||||
print(f"\n{'='*50}")
|
||||
print("Dataset statistics:")
|
||||
print_dataset_stats("ALL", samples)
|
||||
|
||||
# ── Step 3: Stratified split ─────────────────────────────────────────
|
||||
train, val, test = stratified_split(
|
||||
samples,
|
||||
train_ratio=split_cfg["train"],
|
||||
val_ratio=split_cfg["val"],
|
||||
seed=split_cfg.get("seed", 42),
|
||||
)
|
||||
|
||||
base = Path(args.output).parent
|
||||
save_jsonl(train, base / "train.jsonl")
|
||||
save_jsonl(val, base / "val.jsonl")
|
||||
save_jsonl(test, base / "test.jsonl")
|
||||
save_jsonl(train, output_dir / "train.jsonl")
|
||||
save_jsonl(val, output_dir / "val.jsonl")
|
||||
save_jsonl(test, output_dir / "test.jsonl")
|
||||
|
||||
print(f"Split: train={len(train)}, val={len(val)}, test={len(test)}")
|
||||
print_dataset_stats("train", train)
|
||||
print_dataset_stats("val", val)
|
||||
print_dataset_stats("test", test)
|
||||
|
||||
print(f"\nSplit complete:")
|
||||
print(f" train : {len(train):4d} → {output_dir}/train.jsonl")
|
||||
print(f" val : {len(val):4d} → {output_dir}/val.jsonl")
|
||||
print(f" test : {len(test):4d} → {output_dir}/test.jsonl")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user