2026-05-09 17:21:11 +08:00
|
|
|
"""
|
2026-05-09 17:50:17 +08:00
|
|
|
Step 2: LLM judge pre-annotation and dataset split.

Reads raw generated JSONL, runs LLM annotation on each sample,
then splits into train/val/test sets with stratified sampling by y_risk.

Usage:
    # Annotate + split
    python scripts/annotate_data.py \\
        --input data/raw/generated.jsonl \\
        --output data/processed/annotated.jsonl \\
        --config configs/data_generation.yaml

    # Skip annotation (input is already labeled): normalize, save, then split
    python scripts/annotate_data.py \\
        --input data/raw/generated.jsonl \\
        --output data/processed/annotated.jsonl \\
        --skip-annotation
|
2026-05-09 17:21:11 +08:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import json
|
|
|
|
|
import random
|
2026-05-09 17:50:17 +08:00
|
|
|
import yaml
|
|
|
|
|
from collections import Counter
|
2026-05-09 17:21:11 +08:00
|
|
|
from pathlib import Path
|
2026-05-09 17:50:17 +08:00
|
|
|
from typing import List, Dict, Tuple
|
|
|
|
|
|
2026-05-09 17:21:11 +08:00
|
|
|
from src.data.llm_judge import LLMJudge
|
2026-05-09 17:50:17 +08:00
|
|
|
from src.data.dataset import load_jsonl, save_jsonl, validate_and_normalize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stratified_split(
    samples: List[Dict],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Split samples into train/val/test sets, stratified by ``y_risk``.

    Samples are partitioned into a risky group (``y_risk == 1``) and a safe
    group (``y_risk == 0``; a missing ``y_risk`` counts as 0). Each group is
    shuffled and cut with the same ratios, then the per-group pieces are
    merged and shuffled again, so every split keeps roughly the same
    risky/safe proportion as the input.

    Samples whose ``y_risk`` is neither 0 nor 1 are not placed in any split.

    Args:
        samples: Sample dicts; only the ``y_risk`` key is inspected.
        train_ratio: Fraction of each group assigned to the train split.
        val_ratio: Fraction of each group assigned to the validation split.
            The test split receives whatever remains.
        seed: Seed for the shuffles; the same seed and input give the same
            splits.

    Returns:
        A ``(train, val, test)`` tuple of sample lists.

    Raises:
        ValueError: If a ratio is negative or ``train_ratio + val_ratio``
            exceeds 1.
    """
    # Reject ratios that would otherwise silently yield malformed splits
    # (negative slice bounds, or train swallowing val/test). The small
    # tolerance absorbs float rounding in sums such as 0.9 + 0.1.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1 "
            f"(got train_ratio={train_ratio}, val_ratio={val_ratio})"
        )

    # A private generator yields the same shuffle sequence as seeding the
    # module-level RNG, without reseeding the global state that other code
    # in the process may rely on.
    rng = random.Random(seed)

    risky = [s for s in samples if s.get("y_risk", 0) == 1]
    safe = [s for s in samples if s.get("y_risk", 0) == 0]

    rng.shuffle(risky)
    rng.shuffle(safe)

    def _cut(group):
        # Sizes are truncated toward zero; the remainder goes to test.
        n_train = int(len(group) * train_ratio)
        n_val = int(len(group) * val_ratio)
        val_end = n_train + n_val
        return group[:n_train], group[n_train:val_end], group[val_end:]

    train_r, val_r, test_r = _cut(risky)
    train_s, val_s, test_s = _cut(safe)

    train = train_r + train_s
    val = val_r + val_s
    test = test_r + test_s

    # Shuffle again so the classes are interleaved instead of risky-first.
    rng.shuffle(train)
    rng.shuffle(val)
    rng.shuffle(test)

    return train, val, test
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_dataset_stats(name: str, samples: List[Dict]) -> None:
    """Print per-split statistics.

    Reports the sample count with the risky/safe breakdown, the distribution
    of ``l_risk`` values, the distribution of ``a_recommend`` values, and the
    five most common ``c_primary`` categories (when there are any).

    Args:
        name: Label shown in the header line (e.g. ``"train"``).
        samples: Sample dicts. Missing keys fall back to ``y_risk=0``,
            ``l_risk=0`` and ``a_recommend="PASS"``. Samples whose
            ``c_primary`` is missing, ``None`` or the string ``"None"`` are
            left out of the category tally.
    """
    total = len(samples)
    n_risky = sum(1 for s in samples if s.get("y_risk", 0) == 1)
    n_safe = total - n_risky

    def _pct(count: int) -> float:
        # A split can be empty (e.g. val on a very small dataset); report 0%
        # instead of dividing by zero.
        return count / total * 100 if total else 0.0

    level_counts = Counter(s.get("l_risk", 0) for s in samples)
    action_counts = Counter(s.get("a_recommend", "PASS") for s in samples)
    # Apply the same "None" default to the filter and to the counted value,
    # so samples without a category are not tallied under "None".
    category_counts = Counter(
        category
        for category in (s.get("c_primary", "None") for s in samples)
        if category is not None and category != "None"
    )

    levels = {k: level_counts[k] for k in sorted(level_counts)}

    print(f"\n [{name}] {total} samples "
          f"(risky={n_risky} {_pct(n_risky):.0f}%, "
          f"safe={n_safe} {_pct(n_safe):.0f}%)")
    print(f" Risk levels: {levels}")
    print(f" Actions: {dict(action_counts)}")
    if category_counts:
        top5 = dict(category_counts.most_common(5))
        print(f" Top categories: {top5}")
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Command-line entry point: optionally annotate the data, then split it.

    The mode is chosen by flags (``--split-only`` wins if both are given):

    * default: build an ``LLMJudge`` from the config and call
      ``annotate_from_file`` with ``--input`` / ``--output``.
    * ``--skip-annotation``: load ``--input``, pass every sample through
      ``validate_and_normalize`` and save the result to ``--output``.
    * ``--split-only``: load the samples from ``--output``; ``--input`` is
      not needed in this mode.

    The samples are then split with ``stratified_split`` and written as
    ``train.jsonl``, ``val.jsonl`` and ``test.jsonl`` next to ``--output``.
    Statistics are printed for the full set and for each split.

    Config keys read from the YAML file:

    * ``split.train``, ``split.val``, ``split.seed`` (defaults 0.8 / 0.1 /
      42). ``split.test`` is not passed on; the test split receives the
      remainder.
    * ``annotation.judge_model``, ``annotation.delay``,
      ``annotation.max_retries`` (annotation mode only).
    * ``api.type``, ``api.model`` (annotation mode only).

    Raises:
        SystemExit: With status 1 when there are no samples to split, or
            with argparse's usage error when ``--input`` is missing in a
            mode that reads it.
    """
    parser = argparse.ArgumentParser(
        description="Annotate and split CompanionGuard-RL dataset"
    )
    parser.add_argument(
        "--input",
        help="Input JSONL file path (required unless --split-only is given)",
    )
    parser.add_argument(
        "--output", default="data/processed/annotated.jsonl",
        help="Output annotated JSONL path"
    )
    parser.add_argument("--config", default="configs/data_generation.yaml")
    parser.add_argument(
        "--skip-annotation", action="store_true",
        help="Skip LLM annotation (use labels already in input file)"
    )
    parser.add_argument(
        "--split-only", action="store_true",
        help="Only perform dataset split (no annotation), reads from --output"
    )
    args = parser.parse_args()

    # --split-only never reads --input, so only demand it in the other modes.
    if not args.split_only and not args.input:
        parser.error("--input is required unless --split-only is given")

    with open(args.config, encoding="utf-8") as f:
        # safe_load returns None for an empty document; fall back to {} so
        # the .get() lookups below keep working.
        cfg = yaml.safe_load(f) or {}

    # Merge over the defaults so a partial "split" section does not raise
    # KeyError on a missing ratio.
    split_defaults = {"train": 0.8, "val": 0.1, "test": 0.1, "seed": 42}
    split_cfg = {**split_defaults, **(cfg.get("split") or {})}
    ann_cfg = cfg.get("annotation") or {}
    output_dir = Path(args.output).parent

    if not args.split_only:
        # NOTE(review): it is not visible here whether save_jsonl /
        # annotate_from_file create missing parent directories; ensuring the
        # directory exists is harmless either way.
        output_dir.mkdir(parents=True, exist_ok=True)

    # ── Step 1: Annotation ───────────────────────────────────────────────
    if args.split_only:
        samples = load_jsonl(args.output)
        print(f"Loaded {len(samples)} pre-annotated samples from {args.output}")
    elif args.skip_annotation:
        samples = load_jsonl(args.input)
        samples = [validate_and_normalize(s) for s in samples]
        save_jsonl(samples, args.output)
        print(f"Loaded and normalized {len(samples)} samples (annotation skipped)")
    else:
        judge = LLMJudge(
            api_type=cfg["api"]["type"],
            model=ann_cfg.get("judge_model", cfg["api"]["model"]),
        )
        samples = judge.annotate_from_file(
            input_path=args.input,
            output_path=args.output,
            delay=ann_cfg.get("delay", 0.3),
            max_retries=ann_cfg.get("max_retries", 2),
        )

    if not samples:
        print("[ERROR] No samples to process after annotation. Check logs above.")
        raise SystemExit(1)

    # ── Step 2: Dataset statistics ───────────────────────────────────────
    print(f"\n{'='*50}")
    print("Dataset statistics:")
    print_dataset_stats("ALL", samples)

    # ── Step 3: Stratified split ─────────────────────────────────────────
    train, val, test = stratified_split(
        samples,
        train_ratio=split_cfg["train"],
        val_ratio=split_cfg["val"],
        seed=split_cfg.get("seed", 42),
    )

    save_jsonl(train, output_dir / "train.jsonl")
    save_jsonl(val, output_dir / "val.jsonl")
    save_jsonl(test, output_dir / "test.jsonl")

    print_dataset_stats("train", train)
    print_dataset_stats("val", val)
    print_dataset_stats("test", test)

    print("\nSplit complete:")
    print(f" train : {len(train):4d} → {output_dir}/train.jsonl")
    print(f" val : {len(val):4d} → {output_dir}/val.jsonl")
    print(f" test : {len(test):4d} → {output_dir}/test.jsonl")
|
2026-05-09 17:21:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|