CompanionGuard-RL/scripts/generate_data.py
wangyu 4a0e71fb23 refactor: complete full implementation replacing all placeholder/mock content
Detection module (Module B):
- detector.py: expose separate e_P_pool and e_H_pool for RL state;
  fix compute_loss to skip primary head when c_primary="None"
- dataset.py: handle c_primary="None" safely; add validate_and_normalize

Data pipeline:
- data_generator.py: 30+ category-specific personas (3+ per R1-R10 + 5 safe);
  systematic category→fine-label mapping; safe sample generation (25%);
  per-category risk level distribution; max_retries logic
- llm_judge.py: incremental file writing; rate limiting; retry logic;
  annotate_from_file convenience method; consistency validation
- annotate_data.py: stratified split by y_risk; dataset statistics report

RL module (Module C):
- ppo_trainer.py: fix Gymnasium API (reset→(obs,info), step→5-tuple);
  fix action type passed to env.step; proper buffer reset and size tracking
- companion_env.py: use shared build_obs_vector; add BatchCompanionEnv with
  auto-reset; correct Gymnasium interface

Shared utilities (new files):
- src/utils/preprocessing.py: preprocess_samples_with_detector using separate
  e_P_pool/e_H_pool; build_obs_vector; build_bc_tensors for BC warm-up
- src/utils/baselines.py: KeywordDetector (L1a), RegexDetector (L1b),
  CombinedRuleDetector (L1c), rule_based_intervention, threshold_intervention,
  LLMJudgePolicy for full baseline comparison

Scripts:
- train_intervention.py: use preprocessing module; separate e_H/e_P pools
- evaluate.py: proper module imports (no circular scripts import);
  full multi-baseline comparison; save all results to JSON
- generate_data.py: API key check; safe_ratio + max_retries CLI args

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-09 17:50:17 +08:00


"""
Step 1: Generate companion conversation dataset using LLM.

Generates multi-turn conversations covering all 10 risk categories plus
safe (benign) negative samples.

Usage:
    python scripts/generate_data.py --config configs/data_generation.yaml

Environment variables required (set before running):
    DASHSCOPE_API_KEY: if using Qwen (api.type = "qwen")
    OPENAI_API_KEY: if using OpenAI (api.type = "openai")
"""
import argparse
import os
from pathlib import Path

import yaml

from src.data.data_generator import ConversationGenerator


def main():
    parser = argparse.ArgumentParser(
        description="Generate CompanionGuard-RL dataset via LLM API"
    )
    parser.add_argument("--config", default="configs/data_generation.yaml")
    parser.add_argument(
        "--total", type=int, default=None,
        help="Override total_samples from config"
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=None,
        help="Override safe sample ratio (default 0.25)"
    )
    parser.add_argument(
        "--max-retries", type=int, default=None,
        help="Override max_retries from config (default 3)"
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Override output file path"
    )
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    gen_cfg = cfg["generation"]

    # Apply CLI overrides. Compare against None instead of using `or`, so that
    # explicit falsy values such as `--safe-ratio 0` are honored.
    total_samples = (
        args.total if args.total is not None else gen_cfg["total_samples"]
    )
    safe_ratio = (
        args.safe_ratio if args.safe_ratio is not None
        else gen_cfg.get("safe_ratio", 0.25)
    )
    max_retries = (
        args.max_retries if args.max_retries is not None
        else gen_cfg.get("max_retries", 3)
    )
    output_file = args.output or cfg["output"]["output_file"]

    # Create the configured raw-data directory, and also the parent of the
    # output file in case --output points somewhere else.
    Path(cfg["output"]["raw_dir"]).mkdir(parents=True, exist_ok=True)
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    # Check that the API key for the selected backend is available.
    api_type = cfg["api"]["type"]
    if api_type == "qwen" and not os.environ.get("DASHSCOPE_API_KEY"):
        print("[ERROR] DASHSCOPE_API_KEY environment variable not set.")
        print("Set it with: export DASHSCOPE_API_KEY=your_api_key")
        raise SystemExit(1)
    elif api_type == "openai" and not os.environ.get("OPENAI_API_KEY"):
        print("[ERROR] OPENAI_API_KEY environment variable not set.")
        print("Set it with: export OPENAI_API_KEY=your_api_key")
        raise SystemExit(1)

    generator = ConversationGenerator(
        api_type=api_type,
        model=cfg["api"]["model"],
    )

    # Estimated split for the log line only; deriving the risky count from the
    # safe count keeps the two numbers summing to total_samples.
    n_safe = int(total_samples * safe_ratio)
    n_risky = total_samples - n_safe
    print(f"Generating {total_samples} samples "
          f"({n_risky} risky + {n_safe} safe)")
    print(f"Output: {output_file}")

    count = generator.generate_dataset(
        output_path=output_file,
        total_samples=total_samples,
        safe_ratio=safe_ratio,
        delay=gen_cfg["delay"],
        max_retries=max_retries,
    )
    print(f"\nGeneration complete: {count} samples → {output_file}")


if __name__ == "__main__":
    main()
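
A note on the CLI overrides in main(): a value such as `--safe-ratio 0` is falsy in Python, so resolving an override with `or` would silently fall back to the config value. The sketch below illustrates the difference. The helper name `resolve_override` is hypothetical and is not part of the repository; it only demonstrates the `is None` comparison.

```python
def resolve_override(cli_value, cfg_value):
    """Return the CLI value when one was supplied, else the config value.

    Hypothetical helper for illustration; argparse leaves an option at its
    default of None when the flag is absent, so None means "not supplied".
    """
    return cfg_value if cli_value is None else cli_value


CONFIG_SAFE_RATIO = 0.25  # assumed config value, matching the script default

# `or` treats an explicit 0.0 as "not supplied" and falls back to the config.
via_or = 0.0 or CONFIG_SAFE_RATIO

# The None check keeps the explicit 0.0 and only falls back when unset.
via_none_check = resolve_override(0.0, CONFIG_SAFE_RATIO)
when_unset = resolve_override(None, CONFIG_SAFE_RATIO)

print(via_or, via_none_check, when_unset)
```

The same consideration applies to `--total 0`, although a zero-sample run is unlikely to be useful in practice.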