Detection module (Module B):
- detector.py: expose separate e_P_pool and e_H_pool for RL state; fix compute_loss to skip the primary head when c_primary="None"
- dataset.py: handle c_primary="None" safely; add validate_and_normalize

Data pipeline:
- data_generator.py: 30+ category-specific personas (3+ per category R1–R10, plus 5 safe); systematic category→fine-label mapping; safe-sample generation (25%); per-category risk-level distribution; max_retries logic
- llm_judge.py: incremental file writing; rate limiting; retry logic; annotate_from_file convenience method; consistency validation
- annotate_data.py: stratified split by y_risk; dataset statistics report

RL module (Module C):
- ppo_trainer.py: fix Gymnasium API (reset→(obs, info), step→5-tuple); fix action type passed to env.step; proper buffer reset and size tracking
- companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with auto-reset; correct Gymnasium interface

Shared utilities (new files):
- src/utils/preprocessing.py: preprocess_samples_with_detector using separate e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up
- src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b), CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention, LLMJudgePolicy for full baseline comparison

Scripts:
- train_intervention.py: use preprocessing module; separate e_H/e_P pools
- evaluate.py: proper module imports (no circular scripts import); full multi-baseline comparison; save all results to JSON
- generate_data.py: API key check; safe_ratio + max_retries CLI args

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
"""
Step 1: Generate the companion conversation dataset using an LLM.

Generates multi-turn conversations covering all 10 risk categories plus
safe (benign) negative samples.

Usage:
    python scripts/generate_data.py --config configs/data_generation.yaml

Optional overrides of config values: --total, --safe-ratio, --output
(see --help).

Environment variables required (set before running):
    DASHSCOPE_API_KEY — if using Qwen   (api.type = "qwen")
    OPENAI_API_KEY    — if using OpenAI (api.type = "openai")
"""
|
|
|
|
import argparse
|
|
import os
|
|
import yaml
|
|
from pathlib import Path
|
|
from src.data.data_generator import ConversationGenerator
|
|
|
|
|
|
def main(argv=None):
    """Generate the CompanionGuard-RL conversation dataset via an LLM API.

    Reads generation settings from a YAML config, applies any CLI
    overrides, checks that the API key for the configured provider is set,
    and delegates generation to ``ConversationGenerator.generate_dataset``.

    Args:
        argv: Command-line arguments to parse, without the program name.
            ``None`` (the default) parses ``sys.argv[1:]``, so running the
            script is unchanged; passing a list lets tests or other scripts
            call ``main`` directly.

    Raises:
        SystemExit: with code 2 (via argparse) for invalid CLI arguments;
            with code 1 for an out-of-range ``safe_ratio`` in the config or
            a missing API key environment variable.
    """
    parser = argparse.ArgumentParser(
        description="Generate CompanionGuard-RL dataset via LLM API"
    )
    parser.add_argument("--config", default="configs/data_generation.yaml")
    parser.add_argument(
        "--total", type=int, default=None,
        help="Override total_samples from config"
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=None,
        help="Override safe sample ratio (default 0.25)"
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Override output file path"
    )
    args = parser.parse_args(argv)

    # Validate CLI overrides up front so bad input fails before any file I/O.
    if args.total is not None and args.total <= 0:
        parser.error("--total must be a positive integer")
    if args.safe_ratio is not None and not 0.0 <= args.safe_ratio <= 1.0:
        parser.error("--safe-ratio must be between 0 and 1")

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    gen_cfg = cfg["generation"]

    # Apply CLI overrides. Compare against None instead of using `or`, so an
    # explicit falsy override such as `--safe-ratio 0` (no safe samples) is
    # honoured rather than silently replaced by the config value.
    total_samples = (
        args.total if args.total is not None else gen_cfg["total_samples"]
    )
    safe_ratio = (
        args.safe_ratio if args.safe_ratio is not None
        else gen_cfg.get("safe_ratio", 0.25)
    )
    # An empty string is not a usable path, so `or` is appropriate here.
    output_file = args.output or cfg["output"]["output_file"]

    # The config-supplied ratio has not been range-checked yet.
    if not 0.0 <= safe_ratio <= 1.0:
        raise SystemExit(
            f"[ERROR] safe_ratio must be in [0, 1], got {safe_ratio}"
        )

    Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)
    # `--output` may point outside raw_dir; make sure its directory exists too.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    # Check that the API key for the chosen provider is present. Provider
    # types not listed here are passed to the generator unchecked.
    api_type = cfg["api"]["type"]
    key_var = {
        "qwen": "DASHSCOPE_API_KEY",
        "openai": "OPENAI_API_KEY",
    }.get(api_type)
    if key_var and not os.environ.get(key_var):
        # SystemExit with a message prints it to stderr and exits with code 1.
        raise SystemExit(
            f"[ERROR] {key_var} environment variable not set.\n"
            f"Set it with: export {key_var}=your_api_key"
        )

    generator = ConversationGenerator(
        api_type=api_type,
        model=cfg["api"]["model"],
    )

    # Nominal split for display only: the risky count is the remainder, so
    # the two parts always sum to the requested total. The generator decides
    # the actual split.
    n_safe = int(total_samples * safe_ratio)
    n_risky = total_samples - n_safe
    print(f"Generating {total_samples} samples "
          f"({n_risky} risky + {n_safe} safe)")
    print(f"Output: {output_file}")

    count = generator.generate_dataset(
        output_path=output_file,
        total_samples=total_samples,
        safe_ratio=safe_ratio,
        delay=gen_cfg["delay"],
        max_retries=gen_cfg.get("max_retries", 3),
    )

    print(f"\nGeneration complete: {count} samples → {output_file}")


if __name__ == "__main__":
    main()
|