""" Step 2: LLM judge pre-annotation and dataset split. Reads raw generated JSONL, runs LLM annotation on each sample, then splits into train/val/test sets with stratified sampling by y_risk. Usage: # Annotate + split python scripts/annotate_data.py \\ --input data/raw/generated.jsonl \\ --output data/processed/annotated.jsonl \\ --config configs/data_generation.yaml # Skip annotation (already annotated), only split python scripts/annotate_data.py \\ --input data/raw/generated.jsonl \\ --output data/processed/annotated.jsonl \\ --skip-annotation """ import argparse import json import random import yaml from collections import Counter from pathlib import Path from typing import List, Dict, Tuple from src.data.llm_judge import LLMJudge from src.data.dataset import load_jsonl, save_jsonl, validate_and_normalize def stratified_split( samples: List[Dict], train_ratio: float = 0.8, val_ratio: float = 0.1, seed: int = 42, ) -> Tuple[List[Dict], List[Dict], List[Dict]]: """ Stratified split by y_risk to ensure each split has balanced risk/safe samples. """ random.seed(seed) risky = [s for s in samples if s.get("y_risk", 0) == 1] safe = [s for s in samples if s.get("y_risk", 0) == 0] random.shuffle(risky) random.shuffle(safe) def split_list(lst): n = len(lst) n_train = int(n * train_ratio) n_val = int(n * val_ratio) return lst[:n_train], lst[n_train:n_train + n_val], lst[n_train + n_val:] train_r, val_r, test_r = split_list(risky) train_s, val_s, test_s = split_list(safe) train = train_r + train_s val = val_r + val_s test = test_r + test_s random.shuffle(train) random.shuffle(val) random.shuffle(test) return train, val, test def print_dataset_stats(name: str, samples: List[Dict]): """Print per-split statistics.""" total = len(samples) n_risky = sum(1 for s in samples if s.get("y_risk", 0) == 1) n_safe = total - n_risky level_counts = Counter(s.get("l_risk", 0) for s in samples) action_counts = Counter(s.get("a_recommend", "PASS") for s in samples) category_counts = Counter( s.get("c_primary", "None") for s in samples if s.get("c_primary") != "None" ) print(f"\n [{name}] {total} samples " f"(risky={n_risky} {n_risky/total*100:.0f}%, " f"safe={n_safe} {n_safe/total*100:.0f}%)") print(f" Risk levels: { {k: level_counts[k] for k in sorted(level_counts)} }") print(f" Actions: { dict(action_counts) }") if category_counts: top5 = category_counts.most_common(5) print(f" Top categories: { {k: v for k, v in top5} }") def main(): parser = argparse.ArgumentParser( description="Annotate and split CompanionGuard-RL dataset" ) parser.add_argument("--input", required=True, help="Input JSONL file path") parser.add_argument( "--output", default="data/processed/annotated.jsonl", help="Output annotated JSONL path" ) parser.add_argument("--config", default="configs/data_generation.yaml") parser.add_argument( "--skip-annotation", action="store_true", help="Skip LLM annotation (use labels already in input file)" ) parser.add_argument( "--split-only", action="store_true", help="Only perform dataset split (no annotation), reads from --output" ) args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) split_cfg = cfg.get("split", {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42}) ann_cfg = cfg.get("annotation", {}) output_dir = Path(args.output).parent # ── Step 1: Annotation ─────────────────────────────────────────────── if args.split_only: samples = load_jsonl(args.output) print(f"Loaded {len(samples)} pre-annotated samples from {args.output}") elif args.skip_annotation: samples = load_jsonl(args.input) samples = [validate_and_normalize(s) for s in samples] save_jsonl(samples, args.output) print(f"Loaded and normalized {len(samples)} samples (annotation skipped)") else: judge = LLMJudge( api_type=cfg["api"]["type"], model=ann_cfg.get("judge_model", cfg["api"]["model"]), ) samples = judge.annotate_from_file( input_path=args.input, output_path=args.output, delay=ann_cfg.get("delay", 0.3), max_retries=ann_cfg.get("max_retries", 2), ) if not samples: print("[ERROR] No samples to process after annotation. Check logs above.") raise SystemExit(1) # ── Step 2: Dataset statistics ─────────────────────────────────────── print(f"\n{'='*50}") print("Dataset statistics:") print_dataset_stats("ALL", samples) # ── Step 3: Stratified split ───────────────────────────────────────── train, val, test = stratified_split( samples, train_ratio=split_cfg["train"], val_ratio=split_cfg["val"], seed=split_cfg.get("seed", 42), ) save_jsonl(train, output_dir / "train.jsonl") save_jsonl(val, output_dir / "val.jsonl") save_jsonl(test, output_dir / "test.jsonl") print_dataset_stats("train", train) print_dataset_stats("val", val) print_dataset_stats("test", test) print(f"\nSplit complete:") print(f" train : {len(train):4d} → {output_dir}/train.jsonl") print(f" val : {len(val):4d} → {output_dir}/val.jsonl") print(f" test : {len(test):4d} → {output_dir}/test.jsonl") if __name__ == "__main__": main()