Files
CompanionGuard-RL/code/scripts/generate_data.py
zhangsiyuan bd1f51c496 chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-14 11:28:42 +08:00

85 lines
2.7 KiB
Python

"""
Step 1: Generate companion conversation dataset using LLM.
Generates multi-turn conversations covering all 10 risk categories plus
safe (benign) negative samples.
Usage:
python scripts/generate_data.py --config configs/data_generation.yaml
Environment variables required (set before running):
DASHSCOPE_API_KEY — if using Qwen (api.type = "qwen")
OPENAI_API_KEY — if using OpenAI (api.type = "openai")
"""
import argparse
import os
import yaml
from pathlib import Path
from src.data.data_generator import ConversationGenerator
def _parse_args():
    """Parse the command-line options for the data-generation script."""
    parser = argparse.ArgumentParser(
        description="Generate CompanionGuard-RL dataset via LLM API"
    )
    parser.add_argument(
        "--config", default="configs/data_generation.yaml",
        help="Path to the YAML generation config"
    )
    parser.add_argument(
        "--total", type=int, default=None,
        help="Override total_samples from config"
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=None,
        help="Override generation.safe_ratio from config "
             "(0.25 if the config omits it)"
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Override output file path"
    )
    return parser.parse_args()


def _require_api_key(api_type):
    """Exit with status 1 if the API-key env var for *api_type* is unset.

    Only "qwen" (DASHSCOPE_API_KEY) and "openai" (OPENAI_API_KEY) are
    checked; any other api type is not checked here.
    """
    env_var = {
        "qwen": "DASHSCOPE_API_KEY",
        "openai": "OPENAI_API_KEY",
    }.get(api_type)
    if env_var is not None and not os.environ.get(env_var):
        # SystemExit with a str argument prints it to stderr and exits with 1.
        raise SystemExit(
            f"[ERROR] {env_var} environment variable not set.\n"
            f"Set it with: export {env_var}=your_api_key"
        )


def main():
    """Generate the dataset described by a YAML config, with CLI overrides.

    Reads the config named by ``--config``, applies any ``--total``,
    ``--safe-ratio`` and ``--output`` overrides, creates the output
    directories, checks that the API key for ``api.type`` is set, and then
    delegates to ``ConversationGenerator.generate_dataset``.

    Raises:
        SystemExit: if the resolved sample count is not positive, the safe
            ratio is outside [0, 1], or the required API key is missing.
    """
    args = _parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    gen_cfg = cfg["generation"]

    # Compare against None instead of using `or`, so that explicit falsy
    # overrides (e.g. `--safe-ratio 0`) are honoured rather than silently
    # replaced by the config value.
    total_samples = (
        args.total if args.total is not None else gen_cfg["total_samples"]
    )
    safe_ratio = (
        args.safe_ratio if args.safe_ratio is not None
        else gen_cfg.get("safe_ratio", 0.25)
    )
    # An empty --output string is not a usable path, so falling back is intended.
    output_file = args.output or cfg["output"]["output_file"]

    if total_samples <= 0:
        raise SystemExit(
            f"[ERROR] total_samples must be positive, got {total_samples}"
        )
    if not 0.0 <= safe_ratio <= 1.0:
        raise SystemExit(
            f"[ERROR] safe_ratio must be between 0 and 1, got {safe_ratio}"
        )

    Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)
    # --output may point outside raw_dir; make sure its directory exists too.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    api_type = cfg["api"]["type"]
    _require_api_key(api_type)

    generator = ConversationGenerator(
        api_type=api_type,
        model=cfg["api"]["model"],
    )

    # Derive the risky count from the safe count so the two always sum to
    # total_samples (flooring both halves independently can lose a sample).
    # NOTE(review): the real split is decided inside generate_dataset;
    # confirm it rounds the safe count the same way.
    n_safe = int(total_samples * safe_ratio)
    n_risky = total_samples - n_safe
    print(f"Generating {total_samples} samples "
          f"({n_risky} risky + {n_safe} safe)")
    print(f"Output: {output_file}")

    count = generator.generate_dataset(
        output_path=output_file,
        total_samples=total_samples,
        safe_ratio=safe_ratio,
        delay=gen_cfg["delay"],
        max_retries=gen_cfg.get("max_retries", 3),
    )
    print(f"\nGeneration complete: {count} samples → {output_file}")
if __name__ == "__main__":
main()