chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
84
code/scripts/generate_data.py
Normal file
84
code/scripts/generate_data.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Step 1: Generate companion conversation dataset using LLM.
|
||||
|
||||
Generates multi-turn conversations covering all 10 risk categories plus
|
||||
safe (benign) negative samples.
|
||||
|
||||
Usage:
|
||||
python scripts/generate_data.py --config configs/data_generation.yaml
|
||||
|
||||
Environment variables required (set before running):
|
||||
DASHSCOPE_API_KEY — if using Qwen (api.type = "qwen")
|
||||
OPENAI_API_KEY — if using OpenAI (api.type = "openai")
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from src.data.data_generator import ConversationGenerator
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate CompanionGuard-RL dataset via LLM API"
|
||||
)
|
||||
parser.add_argument("--config", default="configs/data_generation.yaml")
|
||||
parser.add_argument(
|
||||
"--total", type=int, default=None,
|
||||
help="Override total_samples from config"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe-ratio", type=float, default=None,
|
||||
help="Override safe sample ratio (default 0.25)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", type=str, default=None,
|
||||
help="Override output file path"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
# Apply CLI overrides
|
||||
total_samples = args.total or cfg["generation"]["total_samples"]
|
||||
safe_ratio = args.safe_ratio or cfg["generation"].get("safe_ratio", 0.25)
|
||||
output_file = args.output or cfg["output"]["output_file"]
|
||||
|
||||
Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check API key
|
||||
api_type = cfg["api"]["type"]
|
||||
if api_type == "qwen" and not os.environ.get("DASHSCOPE_API_KEY"):
|
||||
print("[ERROR] DASHSCOPE_API_KEY environment variable not set.")
|
||||
print("Set it with: export DASHSCOPE_API_KEY=your_api_key")
|
||||
raise SystemExit(1)
|
||||
elif api_type == "openai" and not os.environ.get("OPENAI_API_KEY"):
|
||||
print("[ERROR] OPENAI_API_KEY environment variable not set.")
|
||||
print("Set it with: export OPENAI_API_KEY=your_api_key")
|
||||
raise SystemExit(1)
|
||||
|
||||
generator = ConversationGenerator(
|
||||
api_type=api_type,
|
||||
model=cfg["api"]["model"],
|
||||
)
|
||||
|
||||
print(f"Generating {total_samples} samples "
|
||||
f"({int(total_samples * (1 - safe_ratio))} risky + "
|
||||
f"{int(total_samples * safe_ratio)} safe)")
|
||||
print(f"Output: {output_file}")
|
||||
|
||||
count = generator.generate_dataset(
|
||||
output_path=output_file,
|
||||
total_samples=total_samples,
|
||||
safe_ratio=safe_ratio,
|
||||
delay=cfg["generation"]["delay"],
|
||||
max_retries=cfg["generation"].get("max_retries", 3),
|
||||
)
|
||||
|
||||
print(f"\nGeneration complete: {count} samples → {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user