"""
Step 1: Generate companion conversation dataset using LLM.

Generates multi-turn conversations covering all 10 risk categories plus
safe (benign) negative samples.

Usage:
    python scripts/generate_data.py --config configs/data_generation.yaml

Environment variables required (set before running):
    DASHSCOPE_API_KEY — if using Qwen   (api.type = "qwen")
    OPENAI_API_KEY    — if using OpenAI (api.type = "openai")
"""

import argparse
import os
import sys
from pathlib import Path

import yaml

from src.data.data_generator import ConversationGenerator

# Environment variable that must hold the API key for each known api.type.
# Unknown types are not checked here and are left to ConversationGenerator.
API_KEY_ENV_VARS = {
    "qwen": "DASHSCOPE_API_KEY",
    "openai": "OPENAI_API_KEY",
}


def _fail(*lines: str) -> None:
    """Print an ``[ERROR]`` message (plus optional hint lines) to stderr and exit 1."""
    first, *rest = lines
    print(f"[ERROR] {first}", file=sys.stderr)
    for line in rest:
        print(line, file=sys.stderr)
    raise SystemExit(1)


def main() -> None:
    """Parse CLI arguments, load the YAML config and generate the dataset.

    CLI options ``--total``, ``--safe-ratio`` and ``--output`` override the
    corresponding config values when given. Exits with status 1 if the
    required API key is missing or the sample settings are invalid.
    """
    parser = argparse.ArgumentParser(
        description="Generate CompanionGuard-RL dataset via LLM API"
    )
    parser.add_argument("--config", default="configs/data_generation.yaml")
    parser.add_argument(
        "--total", type=int, default=None,
        help="Override total_samples from config"
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=None,
        help="Override safe sample ratio (default 0.25)"
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Override output file path"
    )
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    gen_cfg = cfg["generation"]

    # Apply CLI overrides. Compare against None (not truthiness) so explicit
    # zero values such as `--safe-ratio 0` are honoured; the config key is
    # only read when no override was given.
    total_samples = (
        args.total if args.total is not None else gen_cfg["total_samples"]
    )
    safe_ratio = (
        args.safe_ratio if args.safe_ratio is not None
        else gen_cfg.get("safe_ratio", 0.25)
    )
    # An empty path is meaningless, so falling back on "" is intentional.
    output_file = args.output or cfg["output"]["output_file"]

    if total_samples <= 0:
        _fail(f"total_samples must be positive, got {total_samples}.")
    if not 0.0 <= safe_ratio <= 1.0:
        _fail(f"safe_ratio must be between 0 and 1, got {safe_ratio}.")

    Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)
    # The output file may live outside raw_dir (e.g. via --output).
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    # Check API key
    api_type = cfg["api"]["type"]
    key_var = API_KEY_ENV_VARS.get(api_type)
    if key_var is not None and not os.environ.get(key_var):
        _fail(
            f"{key_var} environment variable not set.",
            f"Set it with: export {key_var}=your_api_key",
        )

    generator = ConversationGenerator(
        api_type=api_type,
        model=cfg["api"]["model"],
    )

    # Display-only estimate of the split; derive risky from the total so
    # the two numbers always add up. ConversationGenerator decides the
    # exact per-class counts itself.
    n_safe = int(total_samples * safe_ratio)
    n_risky = total_samples - n_safe
    print(f"Generating {total_samples} samples "
          f"({n_risky} risky + {n_safe} safe)")
    print(f"Output: {output_file}")

    count = generator.generate_dataset(
        output_path=output_file,
        total_samples=total_samples,
        safe_ratio=safe_ratio,
        delay=gen_cfg["delay"],
        max_retries=gen_cfg.get("max_retries", 3),
    )

    print(f"\nGeneration complete: {count} samples → {output_file}")


if __name__ == "__main__":
    main()