85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
|
|
"""
|
||
|
|
Step 1: Generate companion conversation dataset using an LLM.

Generates multi-turn conversations covering all 10 risk categories plus
safe (benign) negative samples.

Usage:
    python scripts/generate_data.py --config configs/data_generation.yaml

Environment variables required (set before running):
    DASHSCOPE_API_KEY — if using Qwen (api.type = "qwen")
    OPENAI_API_KEY — if using OpenAI (api.type = "openai")
|
||
|
|
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import os
|
||
|
|
import yaml
|
||
|
|
from pathlib import Path
|
||
|
|
from src.data.data_generator import ConversationGenerator
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Generate the conversation dataset described by a YAML config file.

    Steps, in order:
      1. Parse ``--config``, ``--total``, ``--safe-ratio`` and ``--output``.
      2. Load the YAML config with ``yaml.safe_load``.
      3. Apply the CLI overrides on top of the config values.
      4. Create the ``output.raw_dir`` directory if it does not exist.
      5. Check that the API key environment variable matching ``api.type``
         is set ("qwen" -> DASHSCOPE_API_KEY, "openai" -> OPENAI_API_KEY).
      6. Call ``ConversationGenerator.generate_dataset`` and print the count
         it returns.

    Config keys read: ``generation.total_samples``, ``generation.safe_ratio``
    (optional, default 0.25), ``generation.delay``, ``generation.max_retries``
    (optional, default 3), ``output.output_file``, ``output.raw_dir``,
    ``api.type`` and ``api.model``.

    Raises:
        SystemExit: with status 1 when the API key required by the configured
            ``api.type`` is missing from the environment.
    """
    parser = argparse.ArgumentParser(
        description="Generate CompanionGuard-RL dataset via LLM API"
    )
    parser.add_argument("--config", default="configs/data_generation.yaml")
    parser.add_argument(
        "--total", type=int, default=None,
        help="Override total_samples from config"
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=None,
        help="Override safe sample ratio (default 0.25)"
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Override output file path"
    )
    args = parser.parse_args()

    # Explicit encoding so a config with non-ASCII text loads the same way on
    # every platform instead of depending on the locale's default codec.
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Apply CLI overrides.
    # Compare against None rather than relying on truthiness: `--safe-ratio 0`
    # and `--total 0` are legitimate explicit values, and `x or default` would
    # silently replace them with the config value.
    if args.total is not None:
        total_samples = args.total
    else:
        total_samples = cfg["generation"]["total_samples"]

    if args.safe_ratio is not None:
        safe_ratio = args.safe_ratio
    else:
        safe_ratio = cfg["generation"].get("safe_ratio", 0.25)

    # An empty --output string is not a usable path, so it is treated the same
    # as "not given" and the config value is used.
    output_file = args.output or cfg["output"]["output_file"]

    Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)

    # Check API key before constructing the generator so a missing key is
    # reported with a setup hint.
    api_type = cfg["api"]["type"]
    if api_type == "qwen" and not os.environ.get("DASHSCOPE_API_KEY"):
        print("[ERROR] DASHSCOPE_API_KEY environment variable not set.")
        print("Set it with: export DASHSCOPE_API_KEY=your_api_key")
        raise SystemExit(1)
    elif api_type == "openai" and not os.environ.get("OPENAI_API_KEY"):
        print("[ERROR] OPENAI_API_KEY environment variable not set.")
        print("Set it with: export OPENAI_API_KEY=your_api_key")
        raise SystemExit(1)

    generator = ConversationGenerator(
        api_type=api_type,
        model=cfg["api"]["model"],
    )

    print(f"Generating {total_samples} samples "
          f"({int(total_samples * (1 - safe_ratio))} risky + "
          f"{int(total_samples * safe_ratio)} safe)")
    print(f"Output: {output_file}")

    count = generator.generate_dataset(
        output_path=output_file,
        total_samples=total_samples,
        safe_ratio=safe_ratio,
        delay=cfg["generation"]["delay"],
        max_retries=cfg["generation"].get("max_retries", 3),
    )

    print(f"\nGeneration complete: {count} samples → {output_file}")
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: run the generator only when this file is executed
# directly (e.g. `python scripts/generate_data.py`), not when it is imported.
if __name__ == "__main__":
    main()
|