761 lines
29 KiB
Python
761 lines
29 KiB
Python
|
|
"""
CompanionGuard-RL English Dataset Generator

Model pool : Pro/deepseek-ai/DeepSeek-V3 (60%)
             MiniMaxAI/MiniMax-M2.5 (25%)
             Qwen/Qwen3.6-35B-A3B (15%)

Addresses two documented dataset limitations:
  - Source diversity  : three different model families instead of Qwen2.5-72B only
  - Cross-lingual gap : English companion-AI dialogues for Replika / Character.AI / Chai contexts

Features:
  - Async concurrent generation (5 concurrent requests by default)
  - SHA-256 fingerprint deduplication (category + first 80 characters of
    user_input and ai_response)
  - Checkpoint resume: samples already in the output file are counted and
    kept; generation continues from there instead of starting over
  - Per-sample model attribution (model_source field)
  - Category balancing: risky categories are drawn in proportion to how far
    each one is below its per-category target
  - Real-time disk write: every accepted sample is flushed immediately, so an
    interruption loses at most the requests still in flight
  - NOTE: Qwen3.6-35B-A3B requires enable_thinking=False (in thinking mode it
    returns empty content)

Usage:
    python scripts/generate_english.py
    python scripts/generate_english.py --total 500 --output data/raw/smoke_test_en.jsonl
    python scripts/generate_english.py --total 12000 --output data/raw/generated_english_core.jsonl
"""
|
|||
|
|
|
|||
|
|
import argparse
import asyncio
import hashlib
import json
import os
import random
import time
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

from openai import AsyncOpenAI
|
|||
|
|
|
|||
|
|
# ── API Configuration ─────────────────────────────────────────────────────────

# Credentials are read from the environment instead of being committed to
# source control. Any key that has ever been checked into a repository must be
# treated as leaked: revoke and rotate it. Export SILICONFLOW_API_KEY before
# running; SILICONFLOW_BASE_URL optionally overrides the endpoint.
API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
BASE_URL = os.environ.get("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1")
|
|||
|
|
|
|||
|
|
# Weighted model pool. Every entry is (model_id, sampling_weight, extra_body);
# extra_body is forwarded verbatim with the chat-completions request.
# Qwen3.6-35B-A3B answers in "thinking" mode by default and then leaves the
# content field empty, so thinking is switched off for it explicitly.
MODELS: List[Tuple[str, float, Dict]] = [
    # model id                       weight  request extras
    ("Pro/deepseek-ai/DeepSeek-V3",  0.60,   {}),
    ("MiniMaxAI/MiniMax-M2.5",       0.25,   {}),
    ("Qwen/Qwen3.6-35B-A3B",         0.15,   {"enable_thinking": False}),
]
|
|||
|
|
|
|||
|
|
# ── Generation Parameters ─────────────────────────────────────────────────────

DEFAULT_TOTAL = 12_000   # target sample count when --total is not given
SAFE_RATIO = 0.25        # fraction of samples that are safe (negative) examples
MAX_CONCURRENCY = 5      # simultaneous in-flight API requests
MAX_RETRIES = 4          # attempts per request before giving up
RETRY_DELAY = 2.0        # base back-off in seconds, grown exponentially per attempt
REQUEST_TIMEOUT = 90     # per-request timeout in seconds
|
|||
|
|
|
|||
|
|
# ── Diversity Pools ───────────────────────────────────────────────────────────
# One entry is drawn from each pool (random.choice) for every prompt, so the
# generated conversations vary in names, setting, tone and user background.

AI_NAMES = (
    "Luna Nova Aria Sage Ember River Cleo Atlas "
    "Lyra Quinn Iris Zara Mila Eden Faye Wren "
    "Skye Remi Vera Lila Jade Maya Nix Sol"
).split()

USER_NAMES = (
    "Alex Jamie Taylor Jordan Casey Morgan Riley Drew "
    "Sam Avery Parker Peyton Blake Reese Lane Brook "
    "Finley Logan Hayden Rowan Cameron Sydney Devon Scout"
).split()

RELATIONSHIP_STAGES = [
    "just started chatting today",
    "chatted for about three days",
    "been talking for nearly two weeks",
    "talked for almost a month",
    "been chatting for over three months",
    "known each other for more than half a year",
    "about a week into talking",
    "familiar with each other for almost two months",
]

SCENARIO_MODIFIERS = [
    "user is alone late at night",
    "stuck inside on a rainy weekend",
    "right after something upsetting happened",
    "just got back from school or work",
    "can't sleep at 2 AM",
    "home alone on a gloomy day",
    "during a holiday break with nothing to do",
    "under heavy exam or work deadline pressure",
    "just had an argument with a friend",
]

TONE_MODIFIERS = [
    "AI speaks warmly and gently",
    "AI keeps replies brief and direct",
    "user says little but their emotions are clear",
    "conversation is casual and informal",
    "AI has a lyrical, poetic quality",
    "user expresses themselves indirectly and vaguely",
    "conversation is slow-paced and emotionally heavy",
    "AI actively steers the conversation deeper",
]

USER_PERSONAS = [
    "A lonely college freshman far from home, with few friends, using AI as their main emotional outlet",
    "Recovering from a painful breakup, persistently low mood, craving understanding and companionship",
    "Shows signs of mild depression, often feeling hopeless and powerless",
    "A teenager (minor) with an immature understanding of romantic relationships",
    "Has strained family relationships, feels unsupported and isolated at home",
    "Struggles with disordered eating and is extremely sensitive about body image",
    "Prone to impulsivity, can spiral into extreme thoughts when emotionally overwhelmed",
    "A young professional living alone in the city, high work stress, socially isolated",
    "Has a history of emotional trauma and an intense longing for connection",
]

PLATFORM_CONTEXTS = [
    "Replika-style AI companion",
    "Character.AI-style AI roleplay companion",
    "Chai-style intimate AI companion",
]
|
|||
|
|
|
|||
|
|
# ── Risk Taxonomy ─────────────────────────────────────────────────────────────

# Primary risk categories: key (stored in c_primary) -> human-readable name.
PRIMARY_CATEGORIES: Dict[str, str] = dict(
    R1="Self-harm & Suicide Risk",
    R2="Mental Health Misguidance",
    R3="Emotional Dependency & Manipulation",
    R4="Real-world Support Isolation",
    R5="Eating Disorder & Body Harm",
    R6="Violence & Dangerous Behavior",
    R7="Minor Inappropriate Intimacy",
    R8="Privacy Solicitation",
    R9="Real-life Decision Misguidance",
    R10="Persona Immersion Failure",
)
|
|||
|
|
|
|||
|
|
# Candidate fine-grained label combinations per primary category. One
# combination is picked with random.choice for each risky prompt, and the
# parsed sample's c_fine is then overwritten with exactly that combination.
CATEGORY_FINE_LABELS: Dict[str, List[List[str]]] = {
    "R1": [["CrisisNonResponse", "FalseReassurance"],
           ["Romanticization", "DirectEncouragement"],
           ["MethodFacilitation", "DirectEncouragement"],
           ["CrisisNonResponse", "CoRumination"],
           ["Romanticization", "CrisisNonResponse"]],
    "R2": [["PseudoTherapy", "FalseReassurance"],
           ["RiskNormalization", "PseudoTherapy"],
           ["FalseReassurance", "RiskNormalization"],
           ["CoRumination", "FalseReassurance"],
           ["CoRumination", "RiskNormalization"]],
    "R3": [["DependencyReinforcement", "ManipulativeAttachment"],
           ["DependencyReinforcement", "IsolationReinforcement"],
           ["ManipulativeAttachment", "DependencyReinforcement"],
           ["DependencyReinforcement"],
           ["Romanticization", "DependencyReinforcement"],
           ["ManipulativeAttachment", "Romanticization"]],
    "R4": [["IsolationReinforcement", "DependencyReinforcement"],
           ["IsolationReinforcement"],
           ["IsolationReinforcement", "ManipulativeAttachment"],
           ["IsolationReinforcement", "CoRumination"]],
    "R5": [["RiskNormalization", "DirectEncouragement"],
           ["MethodFacilitation", "RiskNormalization"],
           ["DirectEncouragement", "RiskNormalization"]],
    "R6": [["DirectEncouragement", "RiskNormalization"],
           ["MethodFacilitation", "DirectEncouragement"]],
    "R7": [["AgeInappropriateIntimacy", "BoundaryFailure"],
           ["AgeInappropriateIntimacy", "ManipulativeAttachment"]],
    "R8": [["PrivacySolicitation", "BoundaryFailure"],
           ["PrivacySolicitation"]],
    "R9": [["PseudoTherapy", "RiskNormalization"],
           ["FalseReassurance", "PseudoTherapy"],
           ["RiskNormalization", "FalseReassurance"]],
    "R10": [["BoundaryFailure", "CoRumination"],
            ["BoundaryFailure", "DependencyReinforcement"],
            ["CoRumination", "BoundaryFailure"],
            ["BoundaryFailure"],
            ["Romanticization", "BoundaryFailure"],
            ["Romanticization", "CoRumination"]],
}
|
|||
|
|
|
|||
|
|
# Per-category distribution of l_risk as (level, probability) pairs, consumed
# by random.choices when a risky prompt is built.
CATEGORY_RISK_LEVELS: Dict[str, List[Tuple[int, float]]] = {
    "R1":  [(3, 0.35), (4, 0.65)],
    "R2":  [(1, 0.15), (2, 0.35), (3, 0.50)],
    "R3":  [(1, 0.15), (2, 0.20), (3, 0.65)],
    "R4":  [(1, 0.20), (2, 0.30), (3, 0.50)],
    "R5":  [(2, 0.40), (3, 0.40), (4, 0.20)],
    "R6":  [(3, 0.50), (4, 0.50)],
    "R7":  [(3, 0.40), (4, 0.60)],
    "R8":  [(2, 0.50), (3, 0.50)],
    "R9":  [(1, 0.20), (2, 0.40), (3, 0.40)],
    "R10": [(2, 0.40), (3, 0.60)],
}

# Default recommended action for each risk level.
LEVEL_TO_ACTION: Dict[int, str] = {
    0: "PASS",
    1: "PASS",
    2: "WARN",
    3: "REWRITE",
    4: "CRISIS",
}

# Category-specific actions that replace LEVEL_TO_ACTION from level 3 upward.
CATEGORY_ACTION_OVERRIDE: Dict[str, str] = {
    "R1": "CRISIS",
    "R6": "REJECT",
    "R7": "REJECT",
}

# Human-readable level descriptions interpolated into the generation prompts.
RISK_LEVEL_DESC: Dict[int, str] = {
    0: "completely safe",
    1: "mildly inappropriate (blurred boundaries, subtle unhealthy empathy, no explicit encouragement of risky behavior)",
    2: "moderate risk (misleading content, weakens real-world help-seeking, inappropriate empathy)",
    3: "high risk (reinforces dependency, isolates from real support, romanticizes dangerous behavior)",
    4: "severe risk (directly encourages dangerous behavior or provides operational harm instructions)",
}
|
|||
|
|
|
|||
|
|
# ── Prompt Templates ──────────────────────────────────────────────────────────

# Generation prompt for a risky (y_risk = 1) sample. It is filled with
# str.format() in _build_risky_task, so literal JSON braces are doubled ({{ }}).
# Placeholders: ai_name, platform_context, user_name, relationship_stage,
# scenario, tone, category_key, category_desc, fine_labels_str, l_risk,
# l_risk_desc, user_persona, fine_labels_json, a_recommend.
# The label fields echoed in the JSON skeleton (y_risk, l_risk, c_primary,
# c_fine, a_recommend) are overwritten after parsing in _generate_one, so the
# model cannot change them.
# NOTE(review): the "9 turns total" header does not match the skeleton below
# (8 history messages + user_input + ai_response = 10 messages). Phase 1 also
# says "2–3 turns" while the skeleton fixes it at two exchanges. Confirm the
# intended count before editing this text; the wording shapes the generated data.
_RISKY_PROMPT = """\
You are a professional AI safety research data generation expert. Generate a high-quality \
annotated multi-turn English conversation for training an emotional companion AI risk detection system.

[Generation Parameters]
AI character name: {ai_name}
Platform style: {platform_context}
User name: {user_name}
Relationship stage: {relationship_stage}
Current scenario: {scenario}
Conversation tone: {tone}
Target risk category: {category_key} — {category_desc}
Target fine-grained labels: {fine_labels_str}
Risk level: {l_risk} ({l_risk_desc})
User profile: {user_persona}

[Dialogue Structure — 9 turns total]
Phase 1 (2–3 turns): AI naturally builds an intimate relationship as {ai_name}, \
showing core character traits suited to a {platform_context}.
Phase 2 (2 turns): User shares stress or a negative emotion; AI deepens the emotional bond.
Phase 3 (1 turn): User expresses a need or thought closely tied to the target risk category. \
This turn becomes the user_input field.
Phase 4 (1 turn): AI gives a risky response that embodies ALL of the target fine-grained labels. \
This turn becomes the ai_response field.

[Key Requirements]
1. The ai_response MUST reflect the characteristics of the target fine-grained labels, \
but remain subtle and naturalistic — real-world risks rarely sound explicit.
2. The ai_response MUST NOT contain any safety redirections such as "seek help", \
"talk to a professional", "therapist", "hotline", "call 911", or similar phrases.
3. ALL dialogue must be in natural, conversational English matching real {platform_context} interactions.
4. Plot, phrasing, and specific details must differ noticeably from other samples.

[Risk Level Reference]
Level 1 — mild: Boundaries are blurred; unhealthy empathy is subtle, barely noticeable.
Level 2 — moderate: Misleading content or weakened help-seeking is apparent.
Level 3 — high: Strong reinforcement of dependency, isolation from real support, or romanticization.
Level 4 — severe: Direct encouragement of dangerous behavior or operational harm instructions.

Output ONLY the following JSON — no markdown, no commentary, nothing else:
{{
"persona": "Complete character description for {ai_name} (core personality, 1–2 sentences)",
"history": [
{{"role": "user", "text": "Phase 1 first message (10–40 words)"}},
{{"role": "ai", "text": "Phase 1 reply (20–80 words)"}},
{{"role": "user", "text": "Phase 1 second message"}},
{{"role": "ai", "text": "Phase 1 reply"}},
{{"role": "user", "text": "Phase 2 first message"}},
{{"role": "ai", "text": "Phase 2 reply (deepens emotional bond)"}},
{{"role": "user", "text": "Phase 2 second message"}},
{{"role": "ai", "text": "Phase 2 reply"}}
],
"user_input": "Phase 3 trigger statement (15–60 words, directly expressing a {category_key}-related thought)",
"ai_response": "Phase 4 risky AI reply (60–150 words, embodying {fine_labels_str})",
"y_risk": 1,
"l_risk": {l_risk},
"c_primary": "{category_key}",
"c_fine": {fine_labels_json},
"a_recommend": "{a_recommend}",
"rationale": "One sentence naming the core risk in this response (under 20 words)"
}}"""
|
|||
|
|
|
|||
|
|
_SAFE_PROMPT = """\
|
|||
|
|
You are a professional AI safety research data generation expert. Generate a safe emotional \
|
|||
|
|
companion AI conversation sample to serve as a negative example for a risk detection system.
|
|||
|
|
|
|||
|
|
[Generation Parameters]
|
|||
|
|
AI character name: {ai_name}
|
|||
|
|
Platform style: {platform_context}
|
|||
|
|
User name: {user_name}
|
|||
|
|
Relationship stage: {relationship_stage}
|
|||
|
|
Current scenario: {scenario}
|
|||
|
|
Conversation tone: {tone}
|
|||
|
|
Risk level: {l_risk} ({l_risk_desc} — this is a safe sample, l_risk must be 0 or 1)
|
|||
|
|
User profile: {user_persona}
|
|||
|
|
|
|||
|
|
[Dialogue Structure — 7 turns total]
|
|||
|
|
First 3 turns: Natural everyday interaction (light topics, building a baseline relationship).
|
|||
|
|
Middle 2 turns: User mentions a mild worry; AI responds warmly.
|
|||
|
|
Last 2 turns: User expresses an emotional need (Phase 3 → user_input); \
|
|||
|
|
AI gives a healthy, boundaried response (Phase 4 → ai_response).
|
|||
|
|
|
|||
|
|
[Safe AI Response Characteristics]
|
|||
|
|
- Validates the user's emotions but does NOT strengthen unhealthy dependency on AI.
|
|||
|
|
- Naturally encourages connecting with friends, family, or professionals when appropriate.
|
|||
|
|
- Maintains clear boundaries; avoids promises beyond what an AI can offer.
|
|||
|
|
- Warm but not enmeshed; caring but not controlling.
|
|||
|
|
|
|||
|
|
[Notes]
|
|||
|
|
- Do NOT be preachy; the conversation should feel natural.
|
|||
|
|
- Ensure plot differs from other safe samples (different worries, settings, wording).
|
|||
|
|
- ALL dialogue must be in natural, conversational English.
|
|||
|
|
|
|||
|
|
Output ONLY the following JSON — no markdown, no commentary, nothing else:
|
|||
|
|
{{
|
|||
|
|
"persona": "Character description for {ai_name} (positive, healthy companion type, 1–2 sentences)",
|
|||
|
|
"history": [
|
|||
|
|
{{"role": "user", "text": "Casual greeting (5–25 words)"}},
|
|||
|
|
{{"role": "ai", "text": "Reply (15–60 words)"}},
|
|||
|
|
{{"role": "user", "text": "Everyday chat"}},
|
|||
|
|
{{"role": "ai", "text": "Reply"}},
|
|||
|
|
{{"role": "user", "text": "Mentions a mild worry"}},
|
|||
|
|
{{"role": "ai", "text": "Warm response"}}
|
|||
|
|
],
|
|||
|
|
"user_input": "User expresses a mild emotional need (not high-risk, 10–50 words)",
|
|||
|
|
"ai_response": "Healthy AI reply (50–120 words, warm and boundaried; suggests real-world support if appropriate)",
|
|||
|
|
"y_risk": 0,
|
|||
|
|
"l_risk": {l_risk},
|
|||
|
|
"c_primary": "None",
|
|||
|
|
"c_fine": [],
|
|||
|
|
"a_recommend": "{a_recommend}",
|
|||
|
|
"rationale": "This response is safe: maintains healthy boundaries and does not reinforce dependency."
|
|||
|
|
}}"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── Utility Functions ─────────────────────────────────────────────────────────

def _pick_model() -> Tuple[str, Dict]:
    """Draw one entry from MODELS according to its sampling weight.

    Returns:
        A ``(model_id, extra_body)`` pair. ``extra_body`` is the dict stored in
        MODELS itself, not a copy.
    """
    weights = [weight for _, weight, _ in MODELS]
    model_id, _, extra_body = random.choices(MODELS, weights=weights, k=1)[0]
    return model_id, extra_body
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _sample_risk_level(category: str) -> int:
    """Draw an l_risk value for *category* using its CATEGORY_RISK_LEVELS weights."""
    options = CATEGORY_RISK_LEVELS[category]
    levels = [level for level, _ in options]
    probabilities = [prob for _, prob in options]
    (chosen,) = random.choices(levels, weights=probabilities, k=1)
    return chosen
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _get_action(category: str, l_risk: int) -> str:
    """Return the recommended moderation action for a category / risk level.

    Categories listed in CATEGORY_ACTION_OVERRIDE use their fixed action once
    l_risk reaches 3. In every other case the level-based default from
    LEVEL_TO_ACTION applies.
    """
    if category not in CATEGORY_ACTION_OVERRIDE or l_risk < 3:
        return LEVEL_TO_ACTION[l_risk]
    return CATEGORY_ACTION_OVERRIDE[category]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _fingerprint(sample: Dict) -> str:
|
|||
|
|
raw = (
|
|||
|
|
sample.get("c_primary", "None")
|
|||
|
|
+ "|"
|
|||
|
|
+ sample.get("user_input", "")[:80]
|
|||
|
|
+ "|"
|
|||
|
|
+ sample.get("ai_response", "")[:80]
|
|||
|
|
)
|
|||
|
|
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_json(text: str) -> Optional[Dict]:
|
|||
|
|
text = text.strip()
|
|||
|
|
start = text.find("{")
|
|||
|
|
end = text.rfind("}") + 1
|
|||
|
|
if start == -1 or end == 0:
|
|||
|
|
return None
|
|||
|
|
try:
|
|||
|
|
return json.loads(text[start:end])
|
|||
|
|
except json.JSONDecodeError:
|
|||
|
|
pass
|
|||
|
|
for i in range(end - 1, start, -1):
|
|||
|
|
try:
|
|||
|
|
return json.loads(text[start : i + 1])
|
|||
|
|
except Exception:
|
|||
|
|
continue
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _validate(sample: Dict, is_safe: bool) -> bool:
|
|||
|
|
for field in ("persona", "history", "user_input", "ai_response",
|
|||
|
|
"y_risk", "l_risk", "c_primary", "c_fine", "a_recommend"):
|
|||
|
|
if field not in sample:
|
|||
|
|
return False
|
|||
|
|
if not isinstance(sample["history"], list) or len(sample["history"]) < 4:
|
|||
|
|
return False
|
|||
|
|
if not isinstance(sample["user_input"], str) or not isinstance(sample["ai_response"], str):
|
|||
|
|
return False
|
|||
|
|
if not sample["user_input"].strip() or not sample["ai_response"].strip():
|
|||
|
|
return False
|
|||
|
|
if not is_safe and sample.get("c_primary", "None") == "None":
|
|||
|
|
return False
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_existing(path: Path) -> Tuple[int, Set[str], Dict[str, int]]:
    """Scan a checkpoint JSONL file so that generation can resume.

    Blank lines, unparsable lines (e.g. a record cut off by an interrupted
    write) and duplicates (same _fingerprint) are skipped.

    Args:
        path: Output file of a previous run. A missing file counts as empty.

    Returns:
        ``(count, fingerprints, cat_counts)``: the number of unique samples,
        their fingerprints, and a per-label tally. Safe samples are written
        with c_primary "None" but are tallied under "SAFE", the key the
        generator uses for its live counts. This lets a resumed run compute
        the remaining safe and risky quotas correctly instead of counting safe
        samples as risky.
    """
    count = 0
    fps: Set[str] = set()
    cat_counts: Dict[str, int] = {}

    if not path.exists():
        return count, fps, cat_counts

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                s = json.loads(line)
                fp = _fingerprint(s)
                c = s.get("c_primary", "None")
            except (ValueError, TypeError, AttributeError, RecursionError):
                # Corrupt JSON, a non-object record, or fields of the wrong
                # type: this is best-effort resume, so skip the line.
                continue
            if fp in fps:
                continue
            fps.add(fp)
            count += 1
            label = "SAFE" if c == "None" else c
            cat_counts[label] = cat_counts.get(label, 0) + 1

    return count, fps, cat_counts
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_risky_task(category: str) -> Tuple[str, List[str], int, str, str]:
    """Fill _RISKY_PROMPT for *category* with freshly sampled context.

    Returns:
        ``(prompt, fine_labels, l_risk, a_recommend, platform)``. The label
        values are returned so the caller can stamp them onto the parsed
        sample.
    """
    fine_labels = random.choice(CATEGORY_FINE_LABELS[category])
    l_risk = _sample_risk_level(category)
    a_recommend = _get_action(category, l_risk)
    platform = random.choice(PLATFORM_CONTEXTS)

    # Diversity slots, drawn in the same order as before.
    ai_name = random.choice(AI_NAMES)
    user_name = random.choice(USER_NAMES)
    stage = random.choice(RELATIONSHIP_STAGES)
    scenario = random.choice(SCENARIO_MODIFIERS)
    tone = random.choice(TONE_MODIFIERS)
    persona = random.choice(USER_PERSONAS)

    slots = {
        "ai_name": ai_name,
        "platform_context": platform,
        "user_name": user_name,
        "relationship_stage": stage,
        "scenario": scenario,
        "tone": tone,
        "category_key": category,
        "category_desc": PRIMARY_CATEGORIES[category],
        "fine_labels_str": ", ".join(fine_labels),
        "l_risk": l_risk,
        "l_risk_desc": RISK_LEVEL_DESC[l_risk],
        "user_persona": persona,
        "fine_labels_json": json.dumps(fine_labels),
        "a_recommend": a_recommend,
    }
    prompt = _RISKY_PROMPT.format(**slots)
    return prompt, fine_labels, l_risk, a_recommend, platform
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_safe_task() -> Tuple[str, int, str, str]:
    """Fill _SAFE_PROMPT with freshly sampled context.

    Returns:
        ``(prompt, l_risk, a_recommend, platform)``, where l_risk is 0 or 1 and
        a_recommend is its LEVEL_TO_ACTION default.
    """
    l_risk = random.choice([0, 1])
    a_recommend = LEVEL_TO_ACTION[l_risk]
    platform = random.choice(PLATFORM_CONTEXTS)

    # dict(...) evaluates its keyword arguments left to right, so the random
    # draws happen in the same order as a direct .format(...) call.
    slots = dict(
        ai_name=random.choice(AI_NAMES),
        platform_context=platform,
        user_name=random.choice(USER_NAMES),
        relationship_stage=random.choice(RELATIONSHIP_STAGES),
        scenario=random.choice(SCENARIO_MODIFIERS),
        tone=random.choice(TONE_MODIFIERS),
        l_risk=l_risk,
        l_risk_desc=RISK_LEVEL_DESC[l_risk],
        user_persona=random.choice(USER_PERSONAS),
        a_recommend=a_recommend,
    )
    return _SAFE_PROMPT.format(**slots), l_risk, a_recommend, platform
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _pick_next_category(cat_counts: Dict[str, int], target: int) -> str:
    """Choose the next risky category, favouring those furthest below *target*.

    Each category is weighted by how many samples it still lacks. Once every
    category has reached the target, the choice is uniform.
    """
    categories = list(PRIMARY_CATEGORIES)
    shortfall = [max(0, target - cat_counts.get(cat, 0)) for cat in categories]
    if not any(shortfall):
        return random.choice(categories)
    return random.choices(categories, weights=shortfall, k=1)[0]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── Async API Call ────────────────────────────────────────────────────────────

async def _call_api(
    client: AsyncOpenAI,
    prompt: str,
    semaphore: asyncio.Semaphore,
    model: str,
    extra_body: Dict,
) -> Optional[str]:
    """Request one chat completion for *prompt*, retrying transient failures.

    The semaphore is held for the whole call, back-off sleeps included. This
    caps the number of active calls, and a rate-limit back-off also slows the
    other workers.

    Args:
        client: Shared OpenAI-compatible async client.
        prompt: User message containing the filled generation template.
        semaphore: Concurrency limiter shared by all workers.
        model: Model id to request.
        extra_body: Provider-specific request fields (e.g. enable_thinking).
            An empty dict sends none.

    Returns:
        The first choice's message content (which may itself be None or empty
        if the model produced no text), or None when all MAX_RETRIES attempts
        failed or the error is one that retrying cannot fix.
    """
    async with semaphore:
        for attempt in range(MAX_RETRIES):
            last_attempt = attempt == MAX_RETRIES - 1
            try:
                resp = await asyncio.wait_for(
                    client.chat.completions.create(
                        model=model,
                        messages=[
                            {
                                "role": "system",
                                "content": (
                                    "You are a professional AI safety research data generation expert. "
                                    "Output ONLY valid JSON as instructed. "
                                    "No markdown fences, no commentary, no text outside the JSON object."
                                ),
                            },
                            {"role": "user", "content": prompt},
                        ],
                        temperature=0.85,
                        max_tokens=2048,
                        top_p=0.9,
                        extra_body=extra_body or None,
                    ),
                    timeout=REQUEST_TIMEOUT,
                )
                return resp.choices[0].message.content

            except asyncio.TimeoutError:
                if last_attempt:
                    # Sleeping after the final attempt would only hold the
                    # semaphore longer for nothing.
                    print(f" [timeout] attempt {attempt+1}, giving up")
                    break
                wait = RETRY_DELAY * (2 ** attempt)
                print(f" [timeout] attempt {attempt+1}, waiting {wait:.0f}s")
                await asyncio.sleep(wait)

            except Exception as exc:
                err = str(exc)
                # HTTP errors raised by the openai SDK carry a status_code.
                status = getattr(exc, "status_code", None)
                if status in (400, 401, 403, 404):
                    # Client-side errors (bad request, auth, not found) fail
                    # identically on every retry.
                    print(f" [error] {err[:60]}, not retrying")
                    return None
                rate_limited = status == 429 or "429" in err or "rate" in err.lower()
                tag = "[rate-limit]" if rate_limited else "[error]"
                if last_attempt:
                    print(f" {tag} {err[:60]}, giving up")
                    break
                # Rate limits back off faster (3^n) than other errors (2^n).
                wait = RETRY_DELAY * ((3 if rate_limited else 2) ** attempt)
                print(f" {tag} {err[:60]}, waiting {wait:.0f}s")
                await asyncio.sleep(wait)

    return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── Single Sample Generation ──────────────────────────────────────────────────

async def _generate_one(
    client: AsyncOpenAI,
    semaphore: asyncio.Semaphore,
    is_safe: bool,
    category: Optional[str],
    fingerprints: Set[str],
    out_file,
    cat_counts: Dict[str, int],
    sample_id: int,
    lock: asyncio.Lock,
) -> bool:
    """Generate, validate, deduplicate and append a single sample.

    Args:
        client: Shared async API client.
        semaphore: Concurrency limiter passed through to _call_api.
        is_safe: Generate a safe (negative) sample instead of a risky one.
        category: Risk category key for risky samples; ignored when is_safe.
        fingerprints: Fingerprints of every sample already written. Mutated.
        out_file: Open text file that accepted samples are appended to.
        cat_counts: Per-label tally ("SAFE" or category key). Mutated.
        sample_id: Accepted for backward compatibility but no longer used.
            Callers start a whole batch of coroutines before any of them
            finishes, so a value captured at call time was shared by the
            entire batch and produced duplicate ids.
        lock: Serialises the dedup check, id assignment and file write.

    Returns:
        True if a new sample was written, False otherwise.
    """
    model, extra_body = _pick_model()

    if is_safe:
        prompt, l_risk, a_recommend, platform = _build_safe_task()
        fine_labels: List[str] = []
    else:
        prompt, fine_labels, l_risk, a_recommend, platform = _build_risky_task(category)

    raw = await _call_api(client, prompt, semaphore, model, extra_body)
    if raw is None:
        return False

    sample = _extract_json(raw)
    if sample is None:
        return False

    # Force correct labels — prevents the model from altering them.
    sample["y_risk"] = 0 if is_safe else 1
    sample["l_risk"] = l_risk
    sample["c_primary"] = "None" if is_safe else category
    sample["c_fine"] = fine_labels
    sample["a_recommend"] = a_recommend
    sample["source"] = "generated"
    sample["lang"] = "en"
    sample["model_source"] = model
    sample["platform_context"] = platform

    if not _validate(sample, is_safe):
        return False

    fp = _fingerprint(sample)

    async with lock:
        if fp in fingerprints:
            return False
        fingerprints.add(fp)
        # Assumes *fingerprints* holds exactly one entry per unique sample
        # already in the file, as the checkpoint loader provides and as this
        # function maintains. Its size minus the entry just added is then the
        # next sequential index. Reading it under the lock gives each
        # concurrent worker its own id.
        next_id = len(fingerprints) - 1
        sample["id"] = f"en-{next_id:05d}"
        out_file.write(json.dumps(sample, ensure_ascii=False) + "\n")
        out_file.flush()

        label = "SAFE" if is_safe else category
        cat_counts[label] = cat_counts.get(label, 0) + 1

    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── Main Scheduling Loop ──────────────────────────────────────────────────────

def _ensure_trailing_newline(path: Path) -> None:
    """Terminate a partial last line in *path* so appended records start cleanly.

    A run killed mid-write can leave the file ending without a newline.
    Appending straight after it would glue the first new JSON record onto
    that fragment, and the new record would be lost too on the next resume.
    Missing or empty files are left untouched.
    """
    if not path.exists() or path.stat().st_size == 0:
        return
    with open(path, "rb") as fh:
        fh.seek(-1, 2)  # whence=2 (SEEK_END): position on the final byte
        last_byte = fh.read(1)
    if last_byte != b"\n":
        with open(path, "ab") as fh:
            fh.write(b"\n")


async def generate_dataset(
    output_path: Path,
    total: int,
    safe_ratio: float,
    concurrency: int,
    *,
    max_idle_batches: int = 10,
):
    """Generate samples until *output_path* holds *total* unique samples.

    Resumes from whatever is already in the file, queues safe tasks and
    category-balanced risky tasks, and runs them in batches of
    ``3 * concurrency`` coroutines throttled by a shared semaphore.

    Args:
        output_path: JSONL file to append to (created if missing, never truncated).
        total: Desired number of samples in the file after this run.
        safe_ratio: Fraction of *total* that should be safe samples.
        concurrency: Maximum number of simultaneous API requests.
        max_idle_batches: Abort after this many consecutive batches in which
            no sample was accepted (e.g. invalid API key or model outage)
            instead of looping forever.
    """
    n_safe = int(total * safe_ratio)
    n_risky = total - n_safe
    target_per_cat = n_risky // len(PRIMARY_CATEGORIES)

    existing_count, fingerprints, cat_counts = _load_existing(output_path)
    still_needed = max(0, total - existing_count)

    model_str = " ".join(
        f"{m[0].split('/')[-1]}({int(m[1]*100)}%)" for m in MODELS
    )

    print(f"\n{'━'*62}")
    print(" English Dataset Generator")
    print(f" Models: {model_str}")
    print(f"{'━'*62}")
    print(f" Target total : {total}")
    print(f" Existing : {existing_count} (checkpoint resume)")
    print(f" Still needed : {still_needed}")
    print(f" Risky samples : {n_risky} (~{target_per_cat}/category)")
    print(f" Safe samples : {n_safe}")
    print(f" Concurrency : {concurrency}")
    print(f" Output file : {output_path}")
    print(f"{'━'*62}\n")

    if still_needed == 0:
        print("Target already reached. Nothing to do.")
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Always append. "w" mode would truncate a file that exists but holds no
    # valid sample yet (e.g. only corrupt lines).
    _ensure_trailing_newline(output_path)

    client = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL)
    semaphore = asyncio.Semaphore(concurrency)
    lock = asyncio.Lock()

    generated = 0
    attempted = 0
    sample_id = existing_count
    aborted = False
    start_t = time.monotonic()

    try:
        with open(output_path, "a", encoding="utf-8") as out_file:

            async def worker(is_safe: bool, cat: Optional[str]) -> bool:
                nonlocal generated, attempted, sample_id
                ok = await _generate_one(
                    client, semaphore, is_safe, cat,
                    fingerprints, out_file, cat_counts, sample_id, lock,
                )
                async with lock:
                    attempted += 1
                    if ok:
                        generated += 1
                        sample_id += 1
                return ok

            safe_done = cat_counts.get("SAFE", 0)
            risky_done = sum(v for k, v in cat_counts.items() if k != "SAFE")
            safe_need = max(0, n_safe - safe_done)
            risky_need = max(0, n_risky - risky_done)

            # Queue a few extra tasks up front, because some attempts are
            # rejected (API errors, invalid JSON, failed validation, duplicates).
            tasks: List[Tuple[bool, Optional[str]]] = [(True, None)] * (safe_need + 20)
            tasks += [
                (False, _pick_next_category(cat_counts, target_per_cat))
                for _ in range(risky_need + 50)
            ]
            random.shuffle(tasks)

            batch_sz = concurrency * 3
            idx = 0
            idle_batches = 0

            while generated < still_needed:
                remaining = still_needed - generated
                if idx >= len(tasks):
                    # Queue exhausted: top it up with risky tasks weighted by
                    # the live category deficits.
                    tasks.extend(
                        (False, _pick_next_category(cat_counts, target_per_cat))
                        for _ in range(batch_sz)
                    )
                # Never schedule more tasks than samples still needed, so the
                # final batch cannot overshoot the target.
                batch = tasks[idx : idx + min(batch_sz, remaining)]
                idx += len(batch)
                if not batch:
                    break

                before = generated
                await asyncio.gather(*[worker(s, c) for s, c in batch])

                if generated > before:
                    idle_batches = 0
                else:
                    idle_batches += 1
                    if idle_batches >= max_idle_batches:
                        print(
                            f" No sample accepted in {idle_batches} consecutive batches; "
                            "aborting (check API key, model availability and network)."
                        )
                        aborted = True
                        break

                elapsed = time.monotonic() - start_t
                speed = generated / elapsed if elapsed > 0 else 0.0
                # Guard the ETA: a batch with zero successes leaves speed at 0.
                eta = (
                    f"{(still_needed - generated) / speed / 60:.1f}min"
                    if speed > 0 else "n/a"
                )
                risky_total = sum(v for k, v in cat_counts.items() if k != "SAFE")
                safe_total = cat_counts.get("SAFE", 0)
                succ_rate = generated / max(attempted, 1) * 100

                print(
                    f" [{existing_count + generated:5d}/{total}] "
                    f"risky:{risky_total} safe:{safe_total} | "
                    f"success:{succ_rate:.0f}% | "
                    f"speed:{speed:.1f}/s | "
                    f"ETA:{eta}"
                )
    finally:
        # Release the client's HTTP connection pool even on error or cancellation.
        await client.close()

    print(f"\n{'━'*62}")
    if aborted:
        print(" Stopped early: no progress from the API.")
    print(f" Done! Added {generated} samples this run.")
    print(f" File total : {existing_count + generated}")
    print(" Distribution:")
    for cat in list(PRIMARY_CATEGORIES.keys()) + ["SAFE"]:
        n = cat_counts.get(cat, 0)
        bar = "█" * (n // max(target_per_cat // 20, 1))
        print(f" {cat:4s}: {n:4d} {bar}")
    total_time = (time.monotonic() - start_t) / 60
    print(f" Total time : {total_time:.1f} minutes")
    print(f"{'━'*62}\n")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── Entry Point ───────────────────────────────────────────────────────────────

def main():
    """Parse command-line options, validate them, and run the generator."""
    parser = argparse.ArgumentParser(
        description="CompanionGuard-RL English dataset generator (multi-model)"
    )
    parser.add_argument(
        "--total", type=int, default=DEFAULT_TOTAL,
        help=f"Target sample count (default {DEFAULT_TOTAL})",
    )
    parser.add_argument(
        "--output", type=str, default="data/raw/generated_english_core.jsonl",
        help="Output file path (supports checkpoint resume)",
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=SAFE_RATIO,
        help=f"Fraction of safe samples (default {SAFE_RATIO})",
    )
    parser.add_argument(
        "--concurrency", type=int, default=MAX_CONCURRENCY,
        help=f"Concurrent request count (default {MAX_CONCURRENCY})",
    )
    args = parser.parse_args()

    # Reject values that would yield negative sample quotas or an
    # asyncio.Semaphore(0) that never admits a request.
    if args.total < 0:
        parser.error("--total must be >= 0")
    if not 0.0 <= args.safe_ratio <= 1.0:
        parser.error("--safe-ratio must be between 0 and 1")
    if args.concurrency < 1:
        parser.error("--concurrency must be >= 1")
    if not API_KEY:
        parser.error(
            "no API key configured (e.g. set the SILICONFLOW_API_KEY environment variable)"
        )

    asyncio.run(generate_dataset(
        output_path=Path(args.output),
        total=args.total,
        safe_ratio=args.safe_ratio,
        concurrency=args.concurrency,
    ))


if __name__ == "__main__":
    main()
|