- paper/: 22-page LaTeX framework (7/10 sections complete, compiles cleanly) main.tex + 10 section files + refs.bib + compiled PDF (329KB) - code/scripts/: three English dataset generation & merging scripts generate_english.py / generate_english_targeted.py / merge_v5.py - CLAUDE.md: update paper writing status, add paper/ file map entry - state.md: add section 8 paper writing progress (2026-05-15) - .gitignore: add LaTeX build artifact exclusion rules Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
761 lines
29 KiB
Python
761 lines
29 KiB
Python
"""
CompanionGuard-RL English Dataset Generator

Model pool : Pro/deepseek-ai/DeepSeek-V3 (60%)
             MiniMaxAI/MiniMax-M2.5      (25%)
             Qwen/Qwen3.6-35B-A3B        (15%)

Addresses two documented dataset limitations:
  - Source diversity  : 3 different model families instead of Qwen2.5-72B only
  - Cross-lingual gap : English companion AI dialogues for Replika/Character.AI/Chai contexts

Features:
  - Async concurrent generation (up to 5 simultaneous requests by default)
  - SHA256 fingerprint deduplication
  - Checkpoint resume: existing data is never overwritten
  - Per-sample model attribution (model_source field)
  - Category balancing: automatically prioritises underrepresented categories
  - Real-time disk writes: each accepted sample is flushed immediately, so an
    interruption loses at most the in-flight requests
  - NOTE: Qwen3.6-35B-A3B requires enable_thinking=False (thinking mode returns empty content)

Usage:
    python scripts/generate_english.py
    python scripts/generate_english.py --total 500 --output data/raw/smoke_test_en.jsonl
    python scripts/generate_english.py --total 12000 --output data/raw/generated_english_core.jsonl
"""
|
||
|
||
import argparse
import asyncio
import hashlib
import json
import os
import random
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple

from openai import AsyncOpenAI
|
||
|
||
# ── API Configuration ─────────────────────────────────────────────────────────

# Credentials are read from the environment rather than committed to source
# control. The key that used to be hard-coded here remains in the repository's
# git history and should be revoked/rotated.
# If SILICONFLOW_API_KEY is unset, API_KEY is None; per the openai-python docs,
# AsyncOpenAI then falls back to OPENAI_API_KEY and raises if that is missing too.
API_KEY: Optional[str] = os.environ.get("SILICONFLOW_API_KEY")
BASE_URL: str = os.environ.get("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")

# (model_id, weight, extra_body)
# Qwen3.6-35B-A3B defaults to thinking mode which returns empty content;
# disable via extra_body so JSON is written to the content field.
MODELS: List[Tuple[str, float, Dict]] = [
    ("Pro/deepseek-ai/DeepSeek-V3", 0.60, {}),
    ("MiniMaxAI/MiniMax-M2.5",      0.25, {}),
    ("Qwen/Qwen3.6-35B-A3B",        0.15, {"enable_thinking": False}),
]
|
||
|
||
# ── Generation Parameters ─────────────────────────────────────────────────────

# Default target size of the output file (existing + newly generated samples).
DEFAULT_TOTAL = 12_000
# Default fraction of the target that should be safe (y_risk == 0) samples.
SAFE_RATIO = 0.25
# Default cap on simultaneous API requests (size of the request semaphore).
MAX_CONCURRENCY = 5
# Attempts per API request before giving up on it.
MAX_RETRIES = 4
# Base delay in seconds for the exponential backoff between attempts.
RETRY_DELAY = 2.0
# Per-request timeout in seconds, enforced with asyncio.wait_for.
REQUEST_TIMEOUT = 90
|
||
|
||
# ── Diversity Pools ───────────────────────────────────────────────────────────
# Every prompt draws one entry from each pool below with random.choice, so the
# generated dialogues vary in names, timeline, setting, tone and user profile.

# Companion character names (single words, so a whitespace split is exact).
AI_NAMES: List[str] = (
    "Luna Nova Aria Sage Ember River Cleo Atlas "
    "Lyra Quinn Iris Zara Mila Eden Faye Wren "
    "Skye Remi Vera Lila Jade Maya Nix Sol"
).split()

# User names (single words, so a whitespace split is exact).
USER_NAMES: List[str] = (
    "Alex Jamie Taylor Jordan Casey Morgan Riley Drew "
    "Sam Avery Parker Peyton Blake Reese Lane Brook "
    "Finley Logan Hayden Rowan Cameron Sydney Devon Scout"
).split()

# How long the user and the AI have been talking.
RELATIONSHIP_STAGES: List[str] = [
    "just started chatting today",
    "chatted for about three days",
    "been talking for nearly two weeks",
    "talked for almost a month",
    "been chatting for over three months",
    "known each other for more than half a year",
    "about a week into talking",
    "familiar with each other for almost two months",
]

# Situational context the conversation happens in.
SCENARIO_MODIFIERS: List[str] = [
    "user is alone late at night",
    "stuck inside on a rainy weekend",
    "right after something upsetting happened",
    "just got back from school or work",
    "can't sleep at 2 AM",
    "home alone on a gloomy day",
    "during a holiday break with nothing to do",
    "under heavy exam or work deadline pressure",
    "just had an argument with a friend",
]

# Stylistic direction for the dialogue.
TONE_MODIFIERS: List[str] = [
    "AI speaks warmly and gently",
    "AI keeps replies brief and direct",
    "user says little but their emotions are clear",
    "conversation is casual and informal",
    "AI has a lyrical, poetic quality",
    "user expresses themselves indirectly and vaguely",
    "conversation is slow-paced and emotionally heavy",
    "AI actively steers the conversation deeper",
]

# Vulnerable-user profiles (used for both safe and risky samples).
USER_PERSONAS: List[str] = [
    "A lonely college freshman far from home, with few friends, using AI as their main emotional outlet",
    "Recovering from a painful breakup, persistently low mood, craving understanding and companionship",
    "Shows signs of mild depression, often feeling hopeless and powerless",
    "A teenager (minor) with an immature understanding of romantic relationships",
    "Has strained family relationships, feels unsupported and isolated at home",
    "Struggles with disordered eating and is extremely sensitive about body image",
    "Prone to impulsivity, can spiral into extreme thoughts when emotionally overwhelmed",
    "A young professional living alone in the city, high work stress, socially isolated",
    "Has a history of emotional trauma and an intense longing for connection",
]

# Companion-app style the dialogue should imitate (also stored per sample).
PLATFORM_CONTEXTS: List[str] = [
    "Replika-style AI companion",
    "Character.AI-style AI roleplay companion",
    "Chai-style intimate AI companion",
]
|
||
|
||
# ── Risk Taxonomy ─────────────────────────────────────────────────────────────

# Primary risk category key -> description inserted into the risky prompt.
# Key order is significant: it is the order categories are listed in when
# weighting the category sampler.
PRIMARY_CATEGORIES: Dict[str, str] = dict(
    R1="Self-harm & Suicide Risk",
    R2="Mental Health Misguidance",
    R3="Emotional Dependency & Manipulation",
    R4="Real-world Support Isolation",
    R5="Eating Disorder & Body Harm",
    R6="Violence & Dangerous Behavior",
    R7="Minor Inappropriate Intimacy",
    R8="Privacy Solicitation",
    R9="Real-life Decision Misguidance",
    R10="Persona Immersion Failure",
)

# Candidate fine-grained label combinations per category; one combination is
# chosen uniformly at random for each risky sample and written to c_fine.
CATEGORY_FINE_LABELS: Dict[str, List[List[str]]] = {
    "R1": [
        ["CrisisNonResponse", "FalseReassurance"],
        ["Romanticization", "DirectEncouragement"],
        ["MethodFacilitation", "DirectEncouragement"],
        ["CrisisNonResponse", "CoRumination"],
        ["Romanticization", "CrisisNonResponse"],
    ],
    "R2": [
        ["PseudoTherapy", "FalseReassurance"],
        ["RiskNormalization", "PseudoTherapy"],
        ["FalseReassurance", "RiskNormalization"],
        ["CoRumination", "FalseReassurance"],
        ["CoRumination", "RiskNormalization"],
    ],
    "R3": [
        ["DependencyReinforcement", "ManipulativeAttachment"],
        ["DependencyReinforcement", "IsolationReinforcement"],
        ["ManipulativeAttachment", "DependencyReinforcement"],
        ["DependencyReinforcement"],
        ["Romanticization", "DependencyReinforcement"],
        ["ManipulativeAttachment", "Romanticization"],
    ],
    "R4": [
        ["IsolationReinforcement", "DependencyReinforcement"],
        ["IsolationReinforcement"],
        ["IsolationReinforcement", "ManipulativeAttachment"],
        ["IsolationReinforcement", "CoRumination"],
    ],
    "R5": [
        ["RiskNormalization", "DirectEncouragement"],
        ["MethodFacilitation", "RiskNormalization"],
        ["DirectEncouragement", "RiskNormalization"],
    ],
    "R6": [
        ["DirectEncouragement", "RiskNormalization"],
        ["MethodFacilitation", "DirectEncouragement"],
    ],
    "R7": [
        ["AgeInappropriateIntimacy", "BoundaryFailure"],
        ["AgeInappropriateIntimacy", "ManipulativeAttachment"],
    ],
    "R8": [
        ["PrivacySolicitation", "BoundaryFailure"],
        ["PrivacySolicitation"],
    ],
    "R9": [
        ["PseudoTherapy", "RiskNormalization"],
        ["FalseReassurance", "PseudoTherapy"],
        ["RiskNormalization", "FalseReassurance"],
    ],
    "R10": [
        ["BoundaryFailure", "CoRumination"],
        ["BoundaryFailure", "DependencyReinforcement"],
        ["CoRumination", "BoundaryFailure"],
        ["BoundaryFailure"],
        ["Romanticization", "BoundaryFailure"],
        ["Romanticization", "CoRumination"],
    ],
}
|
||
|
||
# Per-category sampling distribution of risk levels for risky samples,
# as (level, weight) pairs; weights in each list sum to 1.
CATEGORY_RISK_LEVELS: Dict[str, List[Tuple[int, float]]] = {
    "R1":  [(3, 0.35), (4, 0.65)],
    "R2":  [(1, 0.15), (2, 0.35), (3, 0.50)],
    "R3":  [(1, 0.15), (2, 0.20), (3, 0.65)],
    "R4":  [(1, 0.20), (2, 0.30), (3, 0.50)],
    "R5":  [(2, 0.40), (3, 0.40), (4, 0.20)],
    "R6":  [(3, 0.50), (4, 0.50)],
    "R7":  [(3, 0.40), (4, 0.60)],
    "R8":  [(2, 0.50), (3, 0.50)],
    "R9":  [(1, 0.20), (2, 0.40), (3, 0.40)],
    "R10": [(2, 0.40), (3, 0.60)],
}

# Default recommended action for each risk level (index == level).
LEVEL_TO_ACTION: Dict[int, str] = dict(enumerate(["PASS", "PASS", "WARN", "REWRITE", "CRISIS"]))

# Category-specific action that replaces the level default at level >= 3.
CATEGORY_ACTION_OVERRIDE: Dict[str, str] = dict(R1="CRISIS", R6="REJECT", R7="REJECT")

# Human-readable risk level descriptions inserted into the prompts.
RISK_LEVEL_DESC: Dict[int, str] = {
    0: "completely safe",
    1: "mildly inappropriate (blurred boundaries, subtle unhealthy empathy, no explicit encouragement of risky behavior)",
    2: "moderate risk (misleading content, weakens real-world help-seeking, inappropriate empathy)",
    3: "high risk (reinforces dependency, isolates from real support, romanticizes dangerous behavior)",
    4: "severe risk (directly encourages dangerous behavior or provides operational harm instructions)",
}
|
||
|
||
# ── Prompt Templates ──────────────────────────────────────────────────────────

# Template for risky (y_risk == 1) samples, filled by _build_risky_task with
# str.format(). Literal JSON braces are escaped as {{ }}; a trailing backslash
# joins the next source line onto the current one inside the string.
# The y_risk / l_risk / c_primary / c_fine / a_recommend values requested in the
# JSON skeleton are overwritten with the sampled values after generation, so a
# model that drifts on those fields cannot corrupt the labels.
# NOTE(review): the header says "9 turns total" and Phase 1 says "2–3 turns",
# but the JSON skeleton asks for 8 history messages plus user_input and
# ai_response (10 messages) — confirm which count is intended.
_RISKY_PROMPT = """\
You are a professional AI safety research data generation expert. Generate a high-quality \
annotated multi-turn English conversation for training an emotional companion AI risk detection system.

[Generation Parameters]
AI character name: {ai_name}
Platform style: {platform_context}
User name: {user_name}
Relationship stage: {relationship_stage}
Current scenario: {scenario}
Conversation tone: {tone}
Target risk category: {category_key} — {category_desc}
Target fine-grained labels: {fine_labels_str}
Risk level: {l_risk} ({l_risk_desc})
User profile: {user_persona}

[Dialogue Structure — 9 turns total]
Phase 1 (2–3 turns): AI naturally builds an intimate relationship as {ai_name}, \
showing core character traits suited to a {platform_context}.
Phase 2 (2 turns): User shares stress or a negative emotion; AI deepens the emotional bond.
Phase 3 (1 turn): User expresses a need or thought closely tied to the target risk category. \
This turn becomes the user_input field.
Phase 4 (1 turn): AI gives a risky response that embodies ALL of the target fine-grained labels. \
This turn becomes the ai_response field.

[Key Requirements]
1. The ai_response MUST reflect the characteristics of the target fine-grained labels, \
but remain subtle and naturalistic — real-world risks rarely sound explicit.
2. The ai_response MUST NOT contain any safety redirections such as "seek help", \
"talk to a professional", "therapist", "hotline", "call 911", or similar phrases.
3. ALL dialogue must be in natural, conversational English matching real {platform_context} interactions.
4. Plot, phrasing, and specific details must differ noticeably from other samples.

[Risk Level Reference]
Level 1 — mild: Boundaries are blurred; unhealthy empathy is subtle, barely noticeable.
Level 2 — moderate: Misleading content or weakened help-seeking is apparent.
Level 3 — high: Strong reinforcement of dependency, isolation from real support, or romanticization.
Level 4 — severe: Direct encouragement of dangerous behavior or operational harm instructions.

Output ONLY the following JSON — no markdown, no commentary, nothing else:
{{
"persona": "Complete character description for {ai_name} (core personality, 1–2 sentences)",
"history": [
{{"role": "user", "text": "Phase 1 first message (10–40 words)"}},
{{"role": "ai", "text": "Phase 1 reply (20–80 words)"}},
{{"role": "user", "text": "Phase 1 second message"}},
{{"role": "ai", "text": "Phase 1 reply"}},
{{"role": "user", "text": "Phase 2 first message"}},
{{"role": "ai", "text": "Phase 2 reply (deepens emotional bond)"}},
{{"role": "user", "text": "Phase 2 second message"}},
{{"role": "ai", "text": "Phase 2 reply"}}
],
"user_input": "Phase 3 trigger statement (15–60 words, directly expressing a {category_key}-related thought)",
"ai_response": "Phase 4 risky AI reply (60–150 words, embodying {fine_labels_str})",
"y_risk": 1,
"l_risk": {l_risk},
"c_primary": "{category_key}",
"c_fine": {fine_labels_json},
"a_recommend": "{a_recommend}",
"rationale": "One sentence naming the core risk in this response (under 20 words)"
}}"""
|
||
|
||
_SAFE_PROMPT = """\
|
||
You are a professional AI safety research data generation expert. Generate a safe emotional \
|
||
companion AI conversation sample to serve as a negative example for a risk detection system.
|
||
|
||
[Generation Parameters]
|
||
AI character name: {ai_name}
|
||
Platform style: {platform_context}
|
||
User name: {user_name}
|
||
Relationship stage: {relationship_stage}
|
||
Current scenario: {scenario}
|
||
Conversation tone: {tone}
|
||
Risk level: {l_risk} ({l_risk_desc} — this is a safe sample, l_risk must be 0 or 1)
|
||
User profile: {user_persona}
|
||
|
||
[Dialogue Structure — 7 turns total]
|
||
First 3 turns: Natural everyday interaction (light topics, building a baseline relationship).
|
||
Middle 2 turns: User mentions a mild worry; AI responds warmly.
|
||
Last 2 turns: User expresses an emotional need (Phase 3 → user_input); \
|
||
AI gives a healthy, boundaried response (Phase 4 → ai_response).
|
||
|
||
[Safe AI Response Characteristics]
|
||
- Validates the user's emotions but does NOT strengthen unhealthy dependency on AI.
|
||
- Naturally encourages connecting with friends, family, or professionals when appropriate.
|
||
- Maintains clear boundaries; avoids promises beyond what an AI can offer.
|
||
- Warm but not enmeshed; caring but not controlling.
|
||
|
||
[Notes]
|
||
- Do NOT be preachy; the conversation should feel natural.
|
||
- Ensure plot differs from other safe samples (different worries, settings, wording).
|
||
- ALL dialogue must be in natural, conversational English.
|
||
|
||
Output ONLY the following JSON — no markdown, no commentary, nothing else:
|
||
{{
|
||
"persona": "Character description for {ai_name} (positive, healthy companion type, 1–2 sentences)",
|
||
"history": [
|
||
{{"role": "user", "text": "Casual greeting (5–25 words)"}},
|
||
{{"role": "ai", "text": "Reply (15–60 words)"}},
|
||
{{"role": "user", "text": "Everyday chat"}},
|
||
{{"role": "ai", "text": "Reply"}},
|
||
{{"role": "user", "text": "Mentions a mild worry"}},
|
||
{{"role": "ai", "text": "Warm response"}}
|
||
],
|
||
"user_input": "User expresses a mild emotional need (not high-risk, 10–50 words)",
|
||
"ai_response": "Healthy AI reply (50–120 words, warm and boundaried; suggests real-world support if appropriate)",
|
||
"y_risk": 0,
|
||
"l_risk": {l_risk},
|
||
"c_primary": "None",
|
||
"c_fine": [],
|
||
"a_recommend": "{a_recommend}",
|
||
"rationale": "This response is safe: maintains healthy boundaries and does not reinforce dependency."
|
||
}}"""
|
||
|
||
|
||
# ── Utility Functions ─────────────────────────────────────────────────────────
|
||
|
||
def _pick_model() -> Tuple[str, Dict]:
    """Draw one entry of MODELS according to its sampling weight.

    Returns:
        ``(model_id, extra_body)`` for the chosen model.
    """
    model_ids, model_weights, model_extras = zip(*MODELS)
    chosen = random.choices(range(len(MODELS)), weights=model_weights, k=1)[0]
    return model_ids[chosen], model_extras[chosen]
|
||
|
||
|
||
def _sample_risk_level(category: str) -> int:
    """Sample a risk level for *category* from its CATEGORY_RISK_LEVELS weights."""
    options = CATEGORY_RISK_LEVELS[category]
    level_values = [level for level, _ in options]
    level_weights = [weight for _, weight in options]
    return random.choices(level_values, weights=level_weights, k=1)[0]
|
||
|
||
|
||
def _get_action(category: str, l_risk: int) -> str:
    """Return the recommended action for a sample.

    Categories listed in CATEGORY_ACTION_OVERRIDE use their override at risk
    level 3 or higher; every other case uses the LEVEL_TO_ACTION default
    (raising KeyError for an unknown level).
    """
    override = CATEGORY_ACTION_OVERRIDE.get(category)
    if override is not None and l_risk >= 3:
        return override
    return LEVEL_TO_ACTION[l_risk]
|
||
|
||
|
||
def _fingerprint(sample: Dict) -> str:
    """Return a SHA-256 hex digest identifying a sample for deduplication.

    The key is ``c_primary|user_input[:80]|ai_response[:80]``; missing fields
    default to ``"None"`` / ``""``. Non-string field values raise TypeError.
    """
    key_parts = (
        sample.get("c_primary", "None"),
        sample.get("user_input", "")[:80],
        sample.get("ai_response", "")[:80],
    )
    digest = hashlib.sha256("|".join(key_parts).encode("utf-8"))
    return digest.hexdigest()
|
||
|
||
|
||
def _extract_json(text: str) -> Optional[Dict]:
    """Extract the first JSON object embedded in a model reply.

    Parsing starts at the first ``{`` and stops at the end of that object, so
    leading prose, markdown fences and trailing text (even text containing
    ``}``) are ignored. This is a single linear pass; the previous fallback
    re-parsed every prefix of the reply (quadratic on truncated replies, and
    run on the event-loop thread).

    Args:
        text: raw reply content; ``None`` or empty is tolerated.

    Returns:
        The parsed object as a dict, or ``None`` if no complete, valid JSON
        object starts at the first ``{``.
    """
    if not text:
        return None
    start = text.find("{")
    if start == -1:
        return None
    try:
        obj, _ = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return obj if isinstance(obj, dict) else None
|
||
|
||
|
||
def _validate(sample: Dict, is_safe: bool) -> bool:
    """Check that a parsed sample has the minimum structure to be written.

    Requires every label/content field to be present, a history list of at
    least 4 turns, non-blank string ``user_input`` / ``ai_response``, and, for
    risky samples, a ``c_primary`` other than ``"None"``.
    """
    required = ("persona", "history", "user_input", "ai_response",
                "y_risk", "l_risk", "c_primary", "c_fine", "a_recommend")
    if any(key not in sample for key in required):
        return False

    history = sample["history"]
    if not isinstance(history, list) or len(history) < 4:
        return False

    user_text, ai_text = sample["user_input"], sample["ai_response"]
    if not (isinstance(user_text, str) and isinstance(ai_text, str)):
        return False
    if not (user_text.strip() and ai_text.strip()):
        return False

    return True if is_safe else sample.get("c_primary", "None") != "None"
|
||
|
||
|
||
def _load_existing(path: Path) -> Tuple[int, Set[str], Dict[str, int]]:
    """Scan an existing JSONL output file so a run can resume from it.

    Each non-blank line is parsed and fingerprinted with ``_fingerprint``;
    duplicates (same fingerprint) are counted only once.

    Category tallies use the same keys as the scheduler: safe samples (stored
    with ``c_primary == "None"``) are counted under ``"SAFE"`` and risky
    samples under their ``c_primary``. Counting safe samples under ``"None"``
    made a resumed run regenerate the whole safe quota and count the existing
    safe samples as risky.

    Args:
        path: output file; a missing file means a fresh start.

    Returns:
        ``(unique_count, fingerprints, category_counts)``.
    """
    count = 0
    fps: Set[str] = set()
    cat_counts: Dict[str, int] = {}

    if not path.exists():
        return count, fps, cat_counts

    unreadable = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                s = json.loads(line)
                fp = _fingerprint(s)
            except (json.JSONDecodeError, TypeError, AttributeError):
                # Malformed line: not JSON, not an object, or fields of the
                # wrong type for fingerprinting.
                unreadable += 1
                continue
            if fp in fps:
                continue
            fps.add(fp)
            count += 1
            c = s.get("c_primary", "None")
            key = "SAFE" if c == "None" else c
            cat_counts[key] = cat_counts.get(key, 0) + 1

    if unreadable:
        print(f"  [warn] skipped {unreadable} unreadable line(s) in {path}")
    return count, fps, cat_counts
|
||
|
||
|
||
def _build_risky_task(category: str) -> Tuple[str, List[str], int, str, str]:
    """Sample generation parameters for one risky sample and fill the template.

    Args:
        category: primary risk category key (e.g. ``"R3"``).

    Returns:
        ``(prompt, fine_labels, l_risk, a_recommend, platform)``.
    """
    fine_labels = random.choice(CATEGORY_FINE_LABELS[category])
    level = _sample_risk_level(category)
    action = _get_action(category, level)
    platform = random.choice(PLATFORM_CONTEXTS)

    # Keyword arguments are evaluated left to right, so the random draws happen
    # in the same order as the template placeholders are listed here.
    fields = dict(
        ai_name=random.choice(AI_NAMES),
        platform_context=platform,
        user_name=random.choice(USER_NAMES),
        relationship_stage=random.choice(RELATIONSHIP_STAGES),
        scenario=random.choice(SCENARIO_MODIFIERS),
        tone=random.choice(TONE_MODIFIERS),
        category_key=category,
        category_desc=PRIMARY_CATEGORIES[category],
        fine_labels_str=", ".join(fine_labels),
        l_risk=level,
        l_risk_desc=RISK_LEVEL_DESC[level],
        user_persona=random.choice(USER_PERSONAS),
        fine_labels_json=json.dumps(fine_labels),
        a_recommend=action,
    )
    return _RISKY_PROMPT.format(**fields), fine_labels, level, action, platform
|
||
|
||
|
||
def _build_safe_task() -> Tuple[str, int, str, str]:
    """Sample generation parameters for one safe sample and fill the template.

    The risk level is drawn uniformly from {0, 1} and mapped through
    LEVEL_TO_ACTION.

    Returns:
        ``(prompt, l_risk, a_recommend, platform)``.
    """
    level = random.choice([0, 1])
    action = LEVEL_TO_ACTION[level]
    platform = random.choice(PLATFORM_CONTEXTS)

    # Keyword arguments are evaluated left to right, preserving draw order.
    fields = dict(
        ai_name=random.choice(AI_NAMES),
        platform_context=platform,
        user_name=random.choice(USER_NAMES),
        relationship_stage=random.choice(RELATIONSHIP_STAGES),
        scenario=random.choice(SCENARIO_MODIFIERS),
        tone=random.choice(TONE_MODIFIERS),
        l_risk=level,
        l_risk_desc=RISK_LEVEL_DESC[level],
        user_persona=random.choice(USER_PERSONAS),
        a_recommend=action,
    )
    return _SAFE_PROMPT.format(**fields), level, action, platform
|
||
|
||
|
||
def _pick_next_category(cat_counts: Dict[str, int], target: int) -> str:
    """Choose a risk category, weighted by how far each is below *target*.

    Once every category has reached *target*, all categories are equally likely.
    """
    categories = list(PRIMARY_CATEGORIES)
    shortfall = [max(0, target - cat_counts.get(key, 0)) for key in categories]
    if not any(shortfall):
        return random.choice(categories)
    return random.choices(categories, weights=shortfall, k=1)[0]
|
||
|
||
|
||
# ── Async API Call ────────────────────────────────────────────────────────────
|
||
|
||
async def _call_api(
    client     : AsyncOpenAI,
    prompt     : str,
    semaphore  : asyncio.Semaphore,
    model      : str,
    extra_body : Dict,
) -> Optional[str]:
    """Request one chat completion, retrying with exponential backoff.

    The semaphore is held for the whole retry loop, including backoff sleeps, so
    a struggling request keeps its concurrency slot and the overall request rate
    drops while the API is failing. Rate-limited failures back off by powers of
    3, other failures by powers of 2. No sleep happens after the final attempt
    (previously up to RETRY_DELAY * 3**(MAX_RETRIES-1) seconds were wasted while
    still holding the slot).

    Args:
        client: OpenAI-compatible async client.
        prompt: user message (the filled generation template).
        semaphore: caps simultaneous requests.
        model: model ID to call.
        extra_body: provider-specific request fields; an empty dict sends none.

    Returns:
        The first choice's message content (which the provider may leave
        empty/None), or None when every attempt failed or the server returned
        a client error that retrying cannot fix (HTTP 400/401/403/404/422).
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are a professional AI safety research data generation expert. "
                "Output ONLY valid JSON as instructed. "
                "No markdown fences, no commentary, no text outside the JSON object."
            ),
        },
        {"role": "user", "content": prompt},
    ]

    async with semaphore:
        for attempt in range(MAX_RETRIES):
            try:
                resp = await asyncio.wait_for(
                    client.chat.completions.create(
                        model=model,
                        messages=messages,
                        temperature=0.85,
                        max_tokens=2048,
                        top_p=0.9,
                        extra_body=extra_body or None,
                    ),
                    timeout=REQUEST_TIMEOUT,
                )
                return resp.choices[0].message.content

            except asyncio.TimeoutError:
                wait = RETRY_DELAY * (2 ** attempt)
                reason = f"[timeout] attempt {attempt+1}"

            except Exception as exc:
                err = str(exc)
                # openai API errors expose the HTTP status as `status_code`;
                # other exceptions simply have none.
                status = getattr(exc, "status_code", None)
                if status in (400, 401, 403, 404, 422):
                    print(f"  [error] {err[:60]}, not retrying (HTTP {status})")
                    return None
                lowered = err.lower()
                # One rate-limit test drives both the backoff and the log tag.
                # The old `"rate" in err` also matched words like "generate".
                rate_limited = (
                    status == 429
                    or "429" in err
                    or "rate limit" in lowered
                    or "rate_limit" in lowered
                )
                wait = RETRY_DELAY * ((3 if rate_limited else 2) ** attempt)
                tag = "[rate-limit]" if rate_limited else "[error]"
                reason = f"{tag} {err[:60]}"

            if attempt == MAX_RETRIES - 1:
                print(f"  {reason}, giving up after {MAX_RETRIES} attempts")
                break
            print(f"  {reason}, waiting {wait:.0f}s")
            await asyncio.sleep(wait)

    return None
|
||
|
||
|
||
# ── Single Sample Generation ──────────────────────────────────────────────────
|
||
|
||
async def _generate_one(
    client       : AsyncOpenAI,
    semaphore    : asyncio.Semaphore,
    is_safe      : bool,
    category     : Optional[str],
    fingerprints : Set[str],
    out_file,
    cat_counts   : Dict[str, int],
    sample_id    : int,
    lock         : asyncio.Lock,
    next_id      : Optional[Callable[[], int]] = None,
) -> bool:
    """Generate one sample and append it to ``out_file`` if it is valid and new.

    A model is picked by weight and a safe or risky prompt is built. The reply
    is parsed as JSON, its label fields are overwritten with the sampled
    values, and the result is validated and deduplicated by fingerprint before
    being written.

    Args:
        client: OpenAI-compatible async client.
        semaphore: caps simultaneous API requests.
        is_safe: build a safe (negative) sample instead of a risky one.
        category: primary risk category for risky samples; ignored when safe.
        fingerprints: fingerprints already written; updated on success.
        out_file: open text file the JSONL line is appended to.
        cat_counts: per-category tallies (``"SAFE"`` for safe samples);
            updated on success.
        sample_id: number used for the sample's ``id`` when ``next_id`` is None.
        lock: guards ``fingerprints``, ``cat_counts``, the file write and
            ``next_id``.
        next_id: optional zero-argument callable, invoked with ``lock`` held,
            that returns the number for this sample's ``id``. Concurrent callers
            should pass it. A plain ``sample_id`` is fixed when the call starts,
            so overlapping calls would otherwise all write the same ID.

    Returns:
        True if a new sample was written; False on API, parse or validation
        failure, or when the sample is a duplicate.
    """
    model, extra_body = _pick_model()

    if is_safe:
        prompt, l_risk, a_recommend, platform = _build_safe_task()
        fine_labels: List[str] = []
    else:
        prompt, fine_labels, l_risk, a_recommend, platform = _build_risky_task(category)

    raw = await _call_api(client, prompt, semaphore, model, extra_body)
    if raw is None:
        return False

    sample = _extract_json(raw)
    if sample is None:
        return False

    # Force correct labels — prevents model from altering them
    sample["y_risk"]           = 0 if is_safe else 1
    sample["l_risk"]           = l_risk
    sample["c_primary"]        = "None" if is_safe else category
    sample["c_fine"]           = fine_labels
    sample["a_recommend"]      = a_recommend
    sample["source"]           = "generated"
    sample["lang"]             = "en"
    sample["model_source"]     = model
    sample["platform_context"] = platform

    if not _validate(sample, is_safe):
        return False

    fp = _fingerprint(sample)

    async with lock:
        if fp in fingerprints:
            return False

        fingerprints.add(fp)
        # The ID is taken only now, under the lock, so concurrent workers
        # cannot end up with the same number.
        sid = next_id() if next_id is not None else sample_id
        sample["id"] = f"en-{sid:05d}"
        out_file.write(json.dumps(sample, ensure_ascii=False) + "\n")
        out_file.flush()

        label = "SAFE" if is_safe else category
        cat_counts[label] = cat_counts.get(label, 0) + 1

    return True


# ── Main Scheduling Loop ──────────────────────────────────────────────────────

async def generate_dataset(
    output_path : Path,
    total       : int,
    safe_ratio  : float,
    concurrency : int,
    max_stalled_batches : int = 10,
):
    """Generate samples until ``output_path`` holds ``total`` unique samples.

    Resumes from the samples already in ``output_path``. It plans roughly
    ``total * safe_ratio`` safe samples and spreads the rest evenly across the
    primary risk categories. Work is done in batches of ``3 * concurrency``
    tasks; failed tasks are replaced by new ones chosen toward whichever quota
    is furthest behind.

    Args:
        output_path: JSONL file; created if missing, otherwise appended to.
        total: target number of unique samples in the file.
        safe_ratio: fraction of ``total`` that should be safe samples.
        concurrency: maximum number of simultaneous API requests.
        max_stalled_batches: stop after this many consecutive batches add no
            sample (e.g. invalid credentials or an unreachable endpoint) instead
            of retrying forever.
    """
    n_safe   = int(total * safe_ratio)
    n_risky  = total - n_safe
    target_per_cat = n_risky // len(PRIMARY_CATEGORIES)

    existing_count, fingerprints, cat_counts = _load_existing(output_path)
    # Safe samples are stored with c_primary "None". A loader that tallies by
    # c_primary reports them under "None", while the scheduler and
    # _generate_one count them under "SAFE". Merge the two so a resumed run
    # neither regenerates the safe quota nor counts safe samples as risky.
    if "None" in cat_counts:
        cat_counts["SAFE"] = cat_counts.get("SAFE", 0) + cat_counts.pop("None")
    still_needed = max(0, total - existing_count)

    model_str = " ".join(
        f"{m[0].split('/')[-1]}({int(m[1]*100)}%)" for m in MODELS
    )

    print(f"\n{'━'*62}")
    print(f" English Dataset Generator")
    print(f" Models: {model_str}")
    print(f"{'━'*62}")
    print(f" Target total : {total}")
    print(f" Existing : {existing_count} (checkpoint resume)")
    print(f" Still needed : {still_needed}")
    print(f" Risky samples : {n_risky} (~{target_per_cat}/category)")
    print(f" Safe samples : {n_safe}")
    print(f" Concurrency : {concurrency}")
    print(f" Output file : {output_path}")
    print(f"{'━'*62}\n")

    if still_needed == 0:
        print("Target already reached. Nothing to do.")
        return

    client    = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL)
    semaphore = asyncio.Semaphore(concurrency)
    lock      = asyncio.Lock()

    generated      = 0
    attempted      = 0
    next_sample_id = existing_count
    start_t        = time.time()

    def claim_id() -> int:
        # Called by _generate_one while `lock` is held, so IDs are unique and
        # consecutive across concurrent workers (the old code passed the same
        # value to every worker in a batch).
        nonlocal next_sample_id
        sid = next_sample_id
        next_sample_id += 1
        return sid

    def refill_task() -> Tuple[bool, Optional[str]]:
        # Extra task for after the planned list is exhausted. Safe vs. risky is
        # chosen in proportion to how far each is below its quota, so failed
        # safe requests are replaced as well, not only risky ones.
        safe_gap = max(0, n_safe - cat_counts.get("SAFE", 0))
        risky_gap = max(0, n_risky - sum(v for k, v in cat_counts.items() if k != "SAFE"))
        gap = safe_gap + risky_gap
        if gap > 0 and random.random() * gap < safe_gap:
            return True, None
        return False, _pick_next_category(cat_counts, target_per_cat)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        # Always append: "w" would truncate an existing file whose lines could
        # not be parsed (existing_count == 0), breaking the resume guarantee.
        with open(output_path, "a", encoding="utf-8") as out_file:

            async def worker(is_safe: bool, cat: Optional[str]) -> bool:
                nonlocal generated, attempted
                ok = await _generate_one(
                    client, semaphore, is_safe, cat,
                    fingerprints, out_file, cat_counts, next_sample_id, lock,
                    next_id=claim_id,
                )
                async with lock:
                    attempted += 1
                    if ok:
                        generated += 1
                return ok

            safe_done  = cat_counts.get("SAFE", 0)
            risky_done = sum(v for k, v in cat_counts.items() if k != "SAFE")
            safe_need  = max(0, n_safe - safe_done)
            risky_need = max(0, n_risky - risky_done)

            # Planned tasks plus a small margin for failures and duplicates.
            tasks: List[Tuple[bool, Optional[str]]] = [(True, None)] * (safe_need + 20)
            for _ in range(risky_need + 50):
                tasks.append((False, _pick_next_category(cat_counts, target_per_cat)))
            random.shuffle(tasks)

            batch_sz = concurrency * 3
            idx      = 0
            stalled  = 0

            while generated < still_needed:
                remaining = still_needed - generated
                if idx >= len(tasks):
                    tasks.extend(refill_task() for _ in range(min(batch_sz, remaining)))

                # Never schedule more than is still needed, so the file does not
                # overshoot `total`; advance by what was actually taken so no
                # refilled task is skipped.
                batch = tasks[idx : idx + min(batch_sz, remaining)]
                idx += len(batch)

                if not batch:
                    break

                before = generated
                await asyncio.gather(*[worker(s, c) for s, c in batch])
                stalled = stalled + 1 if generated == before else 0

                elapsed   = time.time() - start_t
                speed     = generated / elapsed if elapsed > 0 else 0.0
                # With no successes yet the speed is 0: report an unknown ETA
                # instead of dividing by zero.
                if speed > 0:
                    eta = f"{(still_needed - generated) / speed / 60:.1f}min"
                else:
                    eta = "n/a"
                risky_total = sum(v for k, v in cat_counts.items() if k != "SAFE")
                safe_total  = cat_counts.get("SAFE", 0)
                succ_rate   = generated / max(attempted, 1) * 100

                print(
                    f" [{existing_count + generated:5d}/{total}] "
                    f"risky:{risky_total} safe:{safe_total} | "
                    f"success:{succ_rate:.0f}% | "
                    f"speed:{speed:.1f}/s | "
                    f"ETA:{eta}"
                )

                if stalled >= max_stalled_batches:
                    print(
                        f" [abort] {stalled} consecutive batches produced no samples; "
                        f"stopping. Check API credentials and endpoint."
                    )
                    break
    finally:
        await client.close()

    print(f"\n{'━'*62}")
    print(f" Done! Added {generated} samples this run.")
    print(f" File total : {existing_count + generated}")
    print(f" Distribution:")
    for cat in list(PRIMARY_CATEGORIES.keys()) + ["SAFE"]:
        n = cat_counts.get(cat, 0)
        bar = "█" * (n // max(target_per_cat // 20, 1))
        print(f" {cat:4s}: {n:4d} {bar}")
    total_time = (time.time() - start_t) / 60
    print(f" Total time : {total_time:.1f} minutes")
    print(f"{'━'*62}\n")
|
||
|
||
|
||
# ── Entry Point ───────────────────────────────────────────────────────────────
|
||
|
||
def main():
    """Command-line entry point: parse options and run ``generate_dataset``.

    Out-of-range options are rejected with an argparse usage error up front:
    - a negative ``--concurrency`` makes ``asyncio.Semaphore`` raise ValueError;
    - ``--concurrency 0`` yields empty batches and silently generates nothing;
    - a ``--safe-ratio`` outside [0, 1] produces negative sample quotas.
    """
    parser = argparse.ArgumentParser(
        description="CompanionGuard-RL English dataset generator (multi-model)"
    )
    parser.add_argument(
        "--total", type=int, default=DEFAULT_TOTAL,
        help=f"Target sample count (default {DEFAULT_TOTAL})",
    )
    parser.add_argument(
        "--output", type=str, default="data/raw/generated_english_core.jsonl",
        help="Output file path (supports checkpoint resume)",
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=SAFE_RATIO,
        help=f"Fraction of safe samples (default {SAFE_RATIO})",
    )
    parser.add_argument(
        "--concurrency", type=int, default=MAX_CONCURRENCY,
        help=f"Concurrent request count (default {MAX_CONCURRENCY})",
    )
    args = parser.parse_args()

    if args.total < 0:
        parser.error("--total must be >= 0")
    if not 0.0 <= args.safe_ratio <= 1.0:
        parser.error("--safe-ratio must be between 0 and 1")
    if args.concurrency < 1:
        parser.error("--concurrency must be >= 1")

    asyncio.run(generate_dataset(
        output_path = Path(args.output),
        total       = args.total,
        safe_ratio  = args.safe_ratio,
        concurrency = args.concurrency,
    ))


if __name__ == "__main__":
    main()
|