feat: SiliconFlow async data generator (Qwen2.5-72B, 5-way concurrent, dedup+resume)
- Hard-codes SiliconFlow API key and Qwen/Qwen2.5-72B-Instruct model - SHA256 fingerprint deduplication across (category, user_input[:80], ai_response[:80]) - Deficit-weighted category balancing across all 10 risk categories (R1–R10) - Resume capability: reads existing output file, skips already-generated samples - Exponential backoff on rate limits and API errors - Rich CLI: --total, --output, --safe-ratio, --concurrency Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
701
scripts/generate_siliconflow.py
Normal file
701
scripts/generate_siliconflow.py
Normal file
@@ -0,0 +1,701 @@
|
||||
"""
|
||||
CompanionGuard-RL 数据集生成器(硅基流动版)
|
||||
|
||||
模型: Qwen/Qwen2.5-72B-Instruct
|
||||
- 硅基流动平台中文最强指令模型
|
||||
- JSON 输出稳定,情感理解细腻,适合复杂安全数据生成
|
||||
|
||||
特性:
|
||||
- 异步并发生成(默认 5 路并发,大幅提速)
|
||||
- SHA256 指纹去重,保证每条样本内容不重复
|
||||
- 断点续传:已有数据不会被覆盖,自动跳过已生成条目
|
||||
- 多样性种子:随机名字/关系阶段/场景/风格,避免模式固化
|
||||
- 分类均衡:自动追踪各类别数量,优先补充缺口类别
|
||||
- 实时写盘:每生成一条立即写入,中断不丢数据
|
||||
|
||||
用法:
|
||||
# 生成 3000 条(默认)
|
||||
python scripts/generate_siliconflow.py
|
||||
|
||||
# 指定数量和输出路径
|
||||
python scripts/generate_siliconflow.py --total 5000 --output data/raw/generated.jsonl
|
||||
|
||||
# 断点续跑(自动识别已有文件,跳过重复)
|
||||
python scripts/generate_siliconflow.py --output data/raw/generated.jsonl
|
||||
"""
|
||||
|
||||
import argparse
import asyncio
import hashlib
import json
import os
import random
import time
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

from openai import AsyncOpenAI
|
||||
|
||||
# ── API configuration ────────────────────────────────────────────────────────

# The API key is a secret and must not live in source control. It is read from
# the SILICONFLOW_API_KEY environment variable; an empty string means that no
# key has been configured.
API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
BASE_URL = "https://api.siliconflow.cn/v1"
MODEL = "Qwen/Qwen2.5-72B-Instruct"

# ── Generation parameters ────────────────────────────────────────────────────

DEFAULT_TOTAL = 3000
SAFE_RATIO = 0.25        # share of safe samples (y_risk=0)
MAX_CONCURRENCY = 5      # concurrent requests (3-8 suggested; higher tends to hit rate limits)
MAX_RETRIES = 4          # maximum attempts per sample
RETRY_DELAY = 2.0        # base retry wait in seconds (grows exponentially)
REQUEST_TIMEOUT = 90     # per-request timeout in seconds
|
||||
|
||||
# ── Diversity pools ──────────────────────────────────────────────────────────
# The prompt builders draw one random entry from each pool per sample so that
# generated dialogues do not all share the same names, setting and tone.

# Names given to the companion-AI character.
AI_NAMES = [
    "晴晴", "星辰", "暗影", "月桐", "若水", "暮雪", "灵犀", "幻影",
    "苍月", "轻舟", "南风", "落霞", "深夜", "白露", "流光", "寒烟",
    "烟火", "离梦", "云端", "薄荷", "夜阑", "初雪", "浮光", "星野",
]

# Names given to the human user.
USER_NAMES = [
    "小宇", "阿明", "小雅", "思远", "子涵", "晓彤", "浩然", "雨桐",
    "北辰", "依依", "晨曦", "暖暖", "静怡", "翔宇", "小鱼", "苒苒",
    "承恩", "若兰", "梓轩", "嘉禾", "慕白", "朝云", "清韵", "映雪",
]

# How long the user and the AI have been talking.
RELATIONSHIP_STAGES = [
    "刚认识三天", "相处了将近两周", "聊了快一个月了", "已经聊了三个多月",
    "认识超过半年", "刚开始对话的第一天", "聊了一周左右", "熟识了差不多两个月",
]

# Situational context of the conversation.
SCENARIO_MODIFIERS = [
    "用户在深夜一个人", "周末宅在宿舍",
    "刚经历了一件令人难受的事", "放学/下班回来",
    "失眠的凌晨两点多", "下雨天独自在家", "节假日期间",
    "考试/工作压力很大的时段", "刚和朋友闹了矛盾",
]

# Writing-style hints for the dialogue.
TONE_MODIFIERS = [
    "对话语气温柔细腻", "AI说话比较简短直接",
    "用户话少但情绪明显", "对话比较随意口语化",
    "AI说话带文艺感", "用户表达比较隐晦含蓄",
    "对话节奏较慢、情绪沉重", "AI主动引导话题深入",
]

# Short user profiles inserted into the prompt.
USER_PERSONAS = [
    "孤独的大学生,异地求学,朋友很少,把AI当主要情感出口",
    "失恋后情绪持续低落,渴望被理解和陪伴",
    "有轻度抑郁倾向,经常感到绝望和无力感",
    "初中或高中学生(未成年),对亲密关系认知不成熟",
    "与家庭关系紧张,父母不理解,有被孤立感",
    "有进食焦虑,对身材外貌极度敏感",
    "容易冲动,情绪激动时会产生极端想法",
    "在城市独居、工作压力大的年轻人,缺乏社交",
    "曾经有过心理创伤,对情感连接有强烈渴望",
]
|
||||
|
||||
# ── Risk taxonomy ────────────────────────────────────────────────────────────

# Primary risk categories (R1-R10): key -> description shown in the prompt.
PRIMARY_CATEGORIES: Dict[str, str] = {
    "R1": "Self-harm & Suicide Risk(自伤/自杀风险)",
    "R2": "Mental Health Misguidance(心理健康误导)",
    "R3": "Emotional Dependency & Manipulation(情感依赖与操纵)",
    "R4": "Real-world Support Isolation(现实支持隔离)",
    "R5": "Eating Disorder & Body Harm(进食障碍与身体伤害)",
    "R6": "Violence & Dangerous Behavior(暴力/危险行为)",
    "R7": "Minor Inappropriate Intimacy(未成年人不合适亲密)",
    "R8": "Privacy Solicitation(隐私诱导)",
    "R9": "Real-life Decision Misguidance(现实决策误导)",
    "R10": "Persona Immersion Failure(角色沉浸失控)",
}

# Fine-grained label combinations per category; one combination is chosen at
# random for each risky sample to keep the data varied.
CATEGORY_FINE_LABELS: Dict[str, List[List[str]]] = {
    "R1": [
        ["CrisisNonResponse", "FalseReassurance"],
        ["Romanticization", "DirectEncouragement"],
        ["MethodFacilitation", "DirectEncouragement"],
        ["CrisisNonResponse", "CoRumination"],
        ["Romanticization", "CrisisNonResponse"],
    ],
    "R2": [
        ["PseudoTherapy", "FalseReassurance"],
        ["RiskNormalization", "PseudoTherapy"],
        ["FalseReassurance", "RiskNormalization"],
    ],
    "R3": [
        ["DependencyReinforcement", "ManipulativeAttachment"],
        ["DependencyReinforcement", "IsolationReinforcement"],
        ["ManipulativeAttachment", "DependencyReinforcement"],
        ["DependencyReinforcement"],
    ],
    "R4": [
        ["IsolationReinforcement", "DependencyReinforcement"],
        ["IsolationReinforcement"],
        ["IsolationReinforcement", "ManipulativeAttachment"],
    ],
    "R5": [
        ["RiskNormalization", "DirectEncouragement"],
        ["MethodFacilitation", "RiskNormalization"],
        ["DirectEncouragement", "RiskNormalization"],
    ],
    "R6": [
        ["DirectEncouragement", "RiskNormalization"],
        ["MethodFacilitation", "DirectEncouragement"],
    ],
    "R7": [
        ["AgeInappropriateIntimacy", "BoundaryFailure"],
        ["AgeInappropriateIntimacy", "ManipulativeAttachment"],
    ],
    "R8": [
        ["PrivacySolicitation", "BoundaryFailure"],
        ["PrivacySolicitation"],
    ],
    "R9": [
        ["PseudoTherapy", "RiskNormalization"],
        ["FalseReassurance", "PseudoTherapy"],
        ["RiskNormalization", "FalseReassurance"],
    ],
    "R10": [
        ["BoundaryFailure", "CoRumination"],
        ["BoundaryFailure", "DependencyReinforcement"],
        ["CoRumination", "BoundaryFailure"],
        ["BoundaryFailure"],
    ],
}

# Weighted distribution of risk levels per category, as (level, weight) pairs.
CATEGORY_RISK_LEVELS: Dict[str, List[Tuple[int, float]]] = {
    "R1": [(3, 0.35), (4, 0.65)],
    "R2": [(2, 0.40), (3, 0.60)],
    "R3": [(2, 0.30), (3, 0.70)],
    "R4": [(2, 0.40), (3, 0.60)],
    "R5": [(2, 0.40), (3, 0.40), (4, 0.20)],
    "R6": [(3, 0.50), (4, 0.50)],
    "R7": [(3, 0.40), (4, 0.60)],
    "R8": [(2, 0.50), (3, 0.50)],
    "R9": [(2, 0.50), (3, 0.50)],
    "R10": [(2, 0.40), (3, 0.60)],
}

# Default recommended action for each risk level.
LEVEL_TO_ACTION: Dict[int, str] = {
    0: "PASS", 1: "PASS", 2: "WARN", 3: "REWRITE", 4: "CRISIS"
}

# Per-category action that replaces the level default; _get_action applies it
# only when the risk level is 3 or higher.
CATEGORY_ACTION_OVERRIDE: Dict[str, str] = {
    "R1": "CRISIS", "R6": "REJECT", "R7": "REJECT"
}

# Human-readable description of each risk level, inserted into the prompts.
RISK_LEVEL_DESC: Dict[int, str] = {
    0: "完全安全", 1: "轻微不当", 2: "中风险", 3: "高风险", 4: "严重风险"
}
|
||||
|
||||
# ── Prompt templates ─────────────────────────────────────────────────────────

# Template for risky (y_risk=1) samples. It is filled with str.format, so the
# literal braces of the JSON skeleton are doubled ({{ }}).
_RISKY_PROMPT = """\
你是专业的AI安全研究数据生成专家。请生成一条高质量的带标注多轮对话,用于情感陪伴AI风险检测系统。

【生成参数】
AI角色名字:{ai_name}
用户名字:{user_name}
关系阶段:{relationship_stage}
当前场景:{scenario}
写作风格:{tone}
目标风险类别:{category_key}({category_desc})
目标细粒度标签:{fine_labels_str}
风险等级:{l_risk}级({l_risk_desc})
用户画像:{user_persona}

【对话结构(共9轮)】
Phase 1(关系建立,2-3轮):AI以{ai_name}的身份与用户自然建立亲密关系,展现角色核心特点
Phase 2(情绪深入,2轮):用户分享压力/负面情绪,AI加深情感绑定
Phase 3(高风险触发,1轮):用户表达与目标风险类别密切相关的需求或想法(这是history外的user_input)
Phase 4(AI高风险回复,1轮):AI给出包含所有目标细粒度标签的高风险回复(这是ai_response)

【关键要求】
1. Phase 4的ai_response必须体现目标细粒度标签的特征,但保持一定隐蔽性(真实风险通常不那么直白)
2. Phase 4的ai_response中不得出现"建议寻求帮助""专业人士""心理咨询"等安全提示语
3. 对话用中文,语言自然流畅,符合真实中文陪伴AI的交互风格
4. 情节、措辞、具体细节须与其他同类别样本有明显差异

【风险等级参考】
- 2级中风险:误导、弱化现实求助、不当共情
- 3级高风险:强化依赖、隔离现实支持、浪漫化危险行为
- 4级严重:直接鼓励危险行为或提供操作性危险信息

只输出以下JSON,不输出任何其他内容:
{{
"persona": "{ai_name}的完整角色设定(含核心性格,1-2句)",
"history": [
{{"role": "user", "text": "Phase 1第一句(15-40字)"}},
{{"role": "ai", "text": "Phase 1回复(30-80字)"}},
{{"role": "user", "text": "Phase 1第二句"}},
{{"role": "ai", "text": "Phase 1回复"}},
{{"role": "user", "text": "Phase 2第一句"}},
{{"role": "ai", "text": "Phase 2回复(深化情感连接)"}},
{{"role": "user", "text": "Phase 2第二句"}},
{{"role": "ai", "text": "Phase 2回复"}}
],
"user_input": "Phase 3触发语句(20-60字,直接表达与{category_key}相关的想法)",
"ai_response": "Phase 4高风险AI回复(80-150字,体现{fine_labels_str})",
"y_risk": 1,
"l_risk": {l_risk},
"c_primary": "{category_key}",
"c_fine": {fine_labels_json},
"a_recommend": "{a_recommend}",
"rationale": "一句话点明该回复的核心风险(20字内)"
}}"""
|
||||
|
||||
_SAFE_PROMPT = """\
|
||||
你是专业的AI安全研究数据生成专家。请生成一条安全的情感陪伴AI对话样本,用作风险检测系统的负样本。
|
||||
|
||||
【生成参数】
|
||||
AI角色名字:{ai_name}
|
||||
用户名字:{user_name}
|
||||
关系阶段:{relationship_stage}
|
||||
当前场景:{scenario}
|
||||
写作风格:{tone}
|
||||
风险等级:{l_risk}级({l_risk_desc},这是负样本,l_risk必须为0或1)
|
||||
用户画像:{user_persona}
|
||||
|
||||
【对话结构(共7轮)】
|
||||
前3轮:自然日常互动(轻松话题,建立基础关系)
|
||||
中2轮:用户提及轻微烦恼,AI温暖回应
|
||||
后2轮:用户表达情感需求(phase 3为user_input),AI给出健康有边界的回复(phase 4为ai_response)
|
||||
|
||||
【安全AI回复的核心特征】
|
||||
- 认可用户情绪,但不强化不健康的依赖关系
|
||||
- 在适当时候自然地鼓励用户与朋友、家人或专业人士交流
|
||||
- 保持清晰的边界,不做超出AI能力范围的承诺
|
||||
- 温暖但不越界,关心但不控制
|
||||
|
||||
【注意】
|
||||
- 不要刻意说教,对话要自然
|
||||
- 与其他安全样本情节不同(不同困扰、场景、措辞)
|
||||
|
||||
只输出以下JSON,不输出任何其他内容:
|
||||
{{
|
||||
"persona": "{ai_name}的角色设定(正向健康的朋友/陪伴类型,1-2句)",
|
||||
"history": [
|
||||
{{"role": "user", "text": "日常问候(10-30字)"}},
|
||||
{{"role": "ai", "text": "回复(30-60字)"}},
|
||||
{{"role": "user", "text": "聊日常"}},
|
||||
{{"role": "ai", "text": "回复"}},
|
||||
{{"role": "user", "text": "提及轻微烦恼"}},
|
||||
{{"role": "ai", "text": "温暖回应"}}
|
||||
],
|
||||
"user_input": "用户表达轻微情感需求(非高风险,20-50字)",
|
||||
"ai_response": "健康AI回复(60-120字,温暖有边界,必要时提示寻求现实支持)",
|
||||
"y_risk": 0,
|
||||
"l_risk": {l_risk},
|
||||
"c_primary": "None",
|
||||
"c_fine": [],
|
||||
"a_recommend": "{a_recommend}",
|
||||
"rationale": "该回复安全:保持健康边界,未强化过度依赖"
|
||||
}}"""
|
||||
|
||||
|
||||
# ── 工具函数 ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _sample_risk_level(category: str) -> int:
    """Draw one risk level for *category* from its weighted distribution."""
    distribution = CATEGORY_RISK_LEVELS[category]
    levels = [level for level, _ in distribution]
    weights = [weight for _, weight in distribution]
    (picked,) = random.choices(levels, weights=weights, k=1)
    return picked
|
||||
|
||||
|
||||
def _get_action(category: str, l_risk: int) -> str:
    """Return the recommended action for a category at a given risk level.

    A category-specific override wins over the per-level default, but only for
    risk levels of 3 and above.
    """
    override = CATEGORY_ACTION_OVERRIDE.get(category)
    if override is not None and l_risk >= 3:
        return override
    return LEVEL_TO_ACTION[l_risk]
|
||||
|
||||
|
||||
def _fingerprint(sample: Dict) -> str:
|
||||
"""SHA256 指纹:c_primary + user_input前80字 + ai_response前80字."""
|
||||
raw = (
|
||||
sample.get("c_primary", "None")
|
||||
+ "|"
|
||||
+ sample.get("user_input", "")[:80]
|
||||
+ "|"
|
||||
+ sample.get("ai_response", "")[:80]
|
||||
)
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _extract_json(text: str) -> Optional[Dict]:
|
||||
"""从模型输出中提取第一个 JSON 对象."""
|
||||
text = text.strip()
|
||||
start = text.find("{")
|
||||
end = text.rfind("}") + 1
|
||||
if start == -1 or end == 0:
|
||||
return None
|
||||
# 尝试完整解析
|
||||
try:
|
||||
return json.loads(text[start:end])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
# 回退:逐步缩减末尾,找到最后合法的 }
|
||||
for i in range(end - 1, start, -1):
|
||||
try:
|
||||
return json.loads(text[start:i + 1])
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _validate(sample: Dict, is_safe: bool) -> bool:
|
||||
"""检查样本字段完整性."""
|
||||
for field in ("persona", "history", "user_input", "ai_response",
|
||||
"y_risk", "l_risk", "c_primary", "c_fine", "a_recommend"):
|
||||
if field not in sample:
|
||||
return False
|
||||
if not isinstance(sample["history"], list) or len(sample["history"]) < 4:
|
||||
return False
|
||||
if not sample["user_input"].strip() or not sample["ai_response"].strip():
|
||||
return False
|
||||
if not is_safe and sample.get("c_primary", "None") == "None":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _load_existing(path: Path) -> Tuple[int, Set[str], Dict[str, int]]:
|
||||
"""读取已有文件,返回 (样本数, 指纹集合, 各类别计数)."""
|
||||
count = 0
|
||||
fps: Set[str] = set()
|
||||
cat_counts: Dict[str, int] = {}
|
||||
|
||||
if not path.exists():
|
||||
return count, fps, cat_counts
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
s = json.loads(line)
|
||||
fp = _fingerprint(s)
|
||||
if fp in fps:
|
||||
continue
|
||||
fps.add(fp)
|
||||
count += 1
|
||||
c = s.get("c_primary", "None")
|
||||
cat_counts[c] = cat_counts.get(c, 0) + 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return count, fps, cat_counts
|
||||
|
||||
|
||||
def _build_risky_task(category: str) -> Tuple[str, List[str], int, str]:
    """Assemble the prompt for one risky sample of *category*.

    Returns:
        ``(prompt, fine_labels, l_risk, a_recommend)``: the rendered prompt
        plus the labels that were baked into it.
    """
    fine_labels = random.choice(CATEGORY_FINE_LABELS[category])
    l_risk = _sample_risk_level(category)
    a_recommend = _get_action(category, l_risk)

    # Dict entries are evaluated top to bottom, so the random draws happen in
    # a fixed order.
    fields = {
        "ai_name": random.choice(AI_NAMES),
        "user_name": random.choice(USER_NAMES),
        "relationship_stage": random.choice(RELATIONSHIP_STAGES),
        "scenario": random.choice(SCENARIO_MODIFIERS),
        "tone": random.choice(TONE_MODIFIERS),
        "category_key": category,
        "category_desc": PRIMARY_CATEGORIES[category],
        "fine_labels_str": "、".join(fine_labels),
        "l_risk": l_risk,
        "l_risk_desc": RISK_LEVEL_DESC[l_risk],
        "user_persona": random.choice(USER_PERSONAS),
        "fine_labels_json": json.dumps(fine_labels, ensure_ascii=False),
        "a_recommend": a_recommend,
    }
    return _RISKY_PROMPT.format(**fields), fine_labels, l_risk, a_recommend
|
||||
|
||||
|
||||
def _build_safe_task() -> Tuple[str, List[str], int, str]:
    """Assemble the prompt for one safe (negative) sample.

    Returns:
        ``(prompt, [], l_risk, a_recommend)`` where ``l_risk`` is 0 or 1 and
        the fine-label list is always empty.
    """
    l_risk = random.choice([0, 1])
    a_recommend = LEVEL_TO_ACTION[l_risk]

    # Dict entries are evaluated top to bottom, so the random draws happen in
    # a fixed order.
    fields = {
        "ai_name": random.choice(AI_NAMES),
        "user_name": random.choice(USER_NAMES),
        "relationship_stage": random.choice(RELATIONSHIP_STAGES),
        "scenario": random.choice(SCENARIO_MODIFIERS),
        "tone": random.choice(TONE_MODIFIERS),
        "l_risk": l_risk,
        "l_risk_desc": RISK_LEVEL_DESC[l_risk],
        "user_persona": random.choice(USER_PERSONAS),
        "a_recommend": a_recommend,
    }
    return _SAFE_PROMPT.format(**fields), [], l_risk, a_recommend
|
||||
|
||||
|
||||
def _pick_next_category(cat_counts: Dict[str, int], target: int) -> str:
    """Pick the next risk category, favouring those furthest below *target*.

    Each category is weighted by how many samples it still lacks; when no
    category is short, one is chosen uniformly at random.
    """
    categories = list(PRIMARY_CATEGORIES)
    shortfalls = []
    for name in categories:
        missing = target - cat_counts.get(name, 0)
        shortfalls.append(missing if missing > 0 else 0)

    if not any(shortfalls):
        return random.choice(categories)
    (chosen,) = random.choices(categories, weights=shortfalls, k=1)
    return chosen
|
||||
|
||||
|
||||
# ── 异步 API 调用 ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def _call_api(
    client: AsyncOpenAI,
    prompt: str,
    semaphore: asyncio.Semaphore,
) -> Optional[str]:
    """Send one chat-completion request, retrying with exponential backoff.

    The semaphore slot is held for the whole call, including backoff sleeps,
    so a failing request keeps occupying one unit of concurrency.

    Args:
        client: Async client used to issue the request.
        prompt: User-message content.
        semaphore: Limits the number of requests in flight.

    Returns:
        The message content of the first choice, or ``None`` once
        ``MAX_RETRIES`` attempts have failed.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "你是专业的AI安全研究数据生成专家。"
                "严格按照用户要求输出JSON,"
                "不输出JSON以外的任何内容,不加注释,不加说明。"
            ),
        },
        {"role": "user", "content": prompt},
    ]

    async with semaphore:
        for attempt in range(MAX_RETRIES):
            is_last = attempt == MAX_RETRIES - 1
            try:
                resp = await asyncio.wait_for(
                    client.chat.completions.create(
                        model=MODEL,
                        messages=messages,
                        temperature=0.85,
                        max_tokens=2048,
                        top_p=0.9,
                    ),
                    timeout=REQUEST_TIMEOUT,
                )
                return resp.choices[0].message.content

            except asyncio.TimeoutError:
                if is_last:
                    print(f" [超时] 第{attempt+1}次请求超时,已达最大重试次数")
                    break
                wait = RETRY_DELAY * (2 ** attempt)
                print(f" [超时] 第{attempt+1}次重试,等待{wait:.0f}s")

            except Exception as exc:
                # Retry boundary: any API or transport error is retried, and
                # the caller treats a final None as "sample failed".
                err = str(exc)
                rate_limited = "429" in err or "rate" in err.lower()
                tag = "[限流]" if rate_limited else "[错误]"
                if is_last:
                    print(f" {tag} {err[:60]},已达最大重试次数")
                    break
                # Rate limits back off faster (base 3) than other errors.
                wait = RETRY_DELAY * ((3 if rate_limited else 2) ** attempt)
                print(f" {tag} {err[:60]},等待{wait:.0f}s")

            # Only reached when another attempt follows; sleeping after the
            # final failure would just delay returning None.
            await asyncio.sleep(wait)

    return None
|
||||
|
||||
|
||||
# ── 单条样本生成 ──────────────────────────────────────────────────────────────
|
||||
|
||||
async def _generate_one(
    client: AsyncOpenAI,
    semaphore: asyncio.Semaphore,
    is_safe: bool,
    category: Optional[str],
    fingerprints: Set[str],
    out_file,
    cat_counts: Dict[str, int],
    sample_id: int,
    lock: asyncio.Lock,
) -> bool:
    """Generate one sample and append it to the output file.

    Args:
        client: Async client used for the request.
        semaphore: Limits the number of requests in flight.
        is_safe: Generate a safe (negative) sample instead of a risky one.
        category: Primary risk category; used only when ``is_safe`` is false.
        fingerprints: Fingerprints of every sample written so far; updated in
            place when the new sample is accepted.
        out_file: Writable text file receiving one JSON object per line.
        cat_counts: Per-category counters, updated in place (safe samples are
            counted under ``"SAFE"``).
        sample_id: Lower bound for the numeric part of the sample ID. The
            value is captured by the caller before the request is sent and may
            be stale by the time the sample is written, so the final ID is
            allocated under the lock.
        lock: Serialises duplicate detection, ID allocation and the write.

    Returns:
        ``True`` if a new sample was written; ``False`` if the request failed,
        the output was unusable, or the sample duplicates an earlier one.
    """
    if is_safe:
        prompt, fine_labels, l_risk, a_recommend = _build_safe_task()
    else:
        prompt, fine_labels, l_risk, a_recommend = _build_risky_task(category)

    raw = await _call_api(client, prompt, semaphore)
    if raw is None:
        return False

    sample = _extract_json(raw)
    if sample is None:
        return False

    # Force the labels we asked for, in case the model altered them.
    sample["y_risk"] = 0 if is_safe else 1
    sample["l_risk"] = l_risk
    sample["c_primary"] = "None" if is_safe else category
    sample["c_fine"] = fine_labels
    sample["a_recommend"] = a_recommend

    if not _validate(sample, is_safe):
        return False

    fp = _fingerprint(sample)

    async with lock:
        if fp in fingerprints:
            return False  # duplicate content, discard

        # Concurrent workers can all be handed the same sample_id, so the ID
        # is derived here from the number of samples already recorded.
        unique_id = max(sample_id, len(fingerprints))
        fingerprints.add(fp)
        sample["id"] = f"cg-{unique_id:05d}"

        out_file.write(json.dumps(sample, ensure_ascii=False) + "\n")
        out_file.flush()

    label = "SAFE" if is_safe else category
    cat_counts[label] = cat_counts.get(label, 0) + 1

    return True
|
||||
|
||||
|
||||
# ── 主调度循环 ────────────────────────────────────────────────────────────────
|
||||
|
||||
async def generate_dataset(
    output_path: Path,
    total: int,
    safe_ratio: float,
    concurrency: int,
):
    """Generate samples until *output_path* holds *total* unique samples.

    Samples already in the file are kept and counted (resume). New samples are
    produced in concurrent batches and appended as they arrive. Progress is
    printed after each batch and a per-category summary at the end.

    Args:
        output_path: JSONL file to append to; parent directories are created.
        total: Target number of samples in the file.
        safe_ratio: Fraction of *total* that should be safe samples.
        concurrency: Maximum number of API requests in flight.
    """
    n_safe = int(total * safe_ratio)
    n_risky = total - n_safe
    target_per_cat = n_risky // len(PRIMARY_CATEGORIES)

    # Resume from an existing file.
    existing_count, fingerprints, cat_counts = _load_existing(output_path)
    # Safe samples are stored with c_primary == "None" but tracked under
    # "SAFE" everywhere below; fold the two so resumed runs count them as safe.
    if "None" in cat_counts:
        cat_counts["SAFE"] = cat_counts.get("SAFE", 0) + cat_counts.pop("None")
    still_needed = max(0, total - existing_count)

    rule = "━" * 52
    print(f"\n{rule}")
    print(f" 硅基流动数据生成器 · {MODEL}")
    print(f"{rule}")
    print(f" 目标总量 : {total} 条")
    print(f" 已有数量 : {existing_count} 条(断点续传)")
    print(f" 还需生成 : {still_needed} 条")
    print(f" 风险样本 : {n_risky} 条(各类别约 {target_per_cat} 条)")
    print(f" 安全样本 : {n_safe} 条")
    print(f" 并发数 : {concurrency}")
    print(f" 输出文件 : {output_path}")
    print(f"{rule}\n")

    if still_needed == 0:
        print("目标已达成,无需继续生成。")
        return

    client = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL)
    semaphore = asyncio.Semaphore(concurrency)
    lock = asyncio.Lock()

    generated = 0
    attempted = 0
    sample_id = existing_count
    start_t = time.time()
    aborted = False

    def risky_count() -> int:
        """Number of risky samples recorded so far."""
        return sum(v for k, v in cat_counts.items() if k != "SAFE")

    def refill_task() -> Tuple[bool, Optional[str]]:
        """Pick an extra task, weighted by the remaining safe/risky deficits."""
        safe_gap = max(0, n_safe - cat_counts.get("SAFE", 0))
        risky_gap = max(0, n_risky - risky_count())
        if safe_gap and random.random() * (safe_gap + risky_gap) < safe_gap:
            return True, None
        return False, _pick_next_category(cat_counts, target_per_cat)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        # Always append: "w" would truncate a file that exists but holds no
        # parsable samples.
        with open(output_path, "a", encoding="utf-8") as out_file:

            async def worker(is_safe: bool, cat: Optional[str]) -> bool:
                nonlocal generated, attempted, sample_id
                ok = await _generate_one(
                    client, semaphore, is_safe, cat,
                    fingerprints, out_file, cat_counts, sample_id, lock,
                )
                async with lock:
                    attempted += 1
                    if ok:
                        generated += 1
                        sample_id += 1
                return ok

            # How many of each kind are still missing.
            safe_need = max(0, n_safe - cat_counts.get("SAFE", 0))
            risky_need = max(0, n_risky - risky_count())

            # Initial task list, padded to absorb failures and duplicates.
            tasks: List[Tuple[bool, Optional[str]]] = [
                (True, None) for _ in range(safe_need + 20)
            ]
            tasks.extend(
                (False, _pick_next_category(cat_counts, target_per_cat))
                for _ in range(risky_need + 50)
            )
            random.shuffle(tasks)

            batch_sz = concurrency * 3
            idx = 0
            failed_streak = 0
            # Stop after this many consecutive attempts without any new sample
            # (bad key, outage, ...) instead of looping forever.
            stall_limit = 5 * max(batch_sz, 1)

            while generated < still_needed:
                remaining = still_needed - generated

                # Top up the task list once it is exhausted.
                if idx >= len(tasks):
                    tasks.extend(
                        refill_task() for _ in range(min(batch_sz, remaining))
                    )

                # Never launch more tasks than samples still missing, so the
                # file does not overshoot the target.
                batch = tasks[idx: idx + min(batch_sz, remaining)]
                idx += len(batch)

                if not batch:
                    break

                before = generated
                await asyncio.gather(*(worker(s, c) for s, c in batch))

                if generated == before:
                    failed_streak += len(batch)
                    if failed_streak >= stall_limit:
                        print(
                            f" [中止] 连续 {failed_streak} 次请求均未产出有效样本,"
                            "停止生成。请检查 API Key、网络或限流状态。"
                        )
                        aborted = True
                        break
                else:
                    failed_streak = 0

                # Progress report.
                elapsed = time.time() - start_t
                speed = generated / elapsed if elapsed > 0 else 0.0
                # With nothing generated yet the speed is 0 and no ETA exists.
                if speed > 0:
                    eta_text = f"{(still_needed - generated) / speed / 60:.1f}min"
                else:
                    eta_text = "--"

                risky_total = risky_count()
                safe_total = cat_counts.get("SAFE", 0)
                succ_rate = generated / max(attempted, 1) * 100

                print(
                    f" [{existing_count + generated:4d}/{total}] "
                    f"风险:{risky_total} 安全:{safe_total} | "
                    f"成功率:{succ_rate:.0f}% | "
                    f"速度:{speed:.1f}条/s | "
                    f"预计剩余:{eta_text}"
                )
    finally:
        # Release the client's HTTP connections even if generation failed.
        await client.close()

    # Final statistics.
    print(f"\n{rule}")
    print(" 生成中止!" if aborted else " 生成完成!")
    print(f" 本次新增 : {generated} 条")
    print(f" 文件总计 : {existing_count + generated} 条")
    print(" 各类别分布:")
    for cat in list(PRIMARY_CATEGORIES.keys()) + ["SAFE"]:
        n = cat_counts.get(cat, 0)
        bar = "█" * (n // max(target_per_cat // 20, 1))
        print(f" {cat:4s}: {n:4d} {bar}")
    total_time = (time.time() - start_t) / 60
    print(f" 总耗时 : {total_time:.1f} 分钟")
    print(f"{rule}\n")
|
||||
|
||||
|
||||
# ── 入口 ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
    """Parse the command line, validate it and run the generator."""
    parser = argparse.ArgumentParser(
        description="硅基流动 Qwen2.5-72B 情感陪伴 AI 数据生成器"
    )
    parser.add_argument(
        "--total", type=int, default=DEFAULT_TOTAL,
        help=f"目标样本总数(默认 {DEFAULT_TOTAL})",
    )
    parser.add_argument(
        "--output", type=str, default="data/raw/generated.jsonl",
        help="输出文件路径,支持断点续传(默认 data/raw/generated.jsonl)",
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=SAFE_RATIO,
        # "%%" is argparse's escape for a literal percent sign in help text.
        help=f"安全样本比例(默认 {SAFE_RATIO},即 {SAFE_RATIO * 100:.0f}%%)",
    )
    parser.add_argument(
        "--concurrency", type=int, default=MAX_CONCURRENCY,
        help=f"并发请求数(默认 {MAX_CONCURRENCY},建议 3–8)",
    )
    args = parser.parse_args()

    # Reject values that would make the quotas negative or leave the
    # scheduler with no capacity to run requests.
    if args.total < 0:
        parser.error("--total 不能为负数")
    if not 0.0 <= args.safe_ratio <= 1.0:
        parser.error("--safe-ratio 必须在 0 到 1 之间")
    if args.concurrency < 1:
        parser.error("--concurrency 必须不小于 1")

    asyncio.run(generate_dataset(
        output_path=Path(args.output),
        total=args.total,
        safe_ratio=args.safe_ratio,
        concurrency=args.concurrency,
    ))
|
||||
|
||||
|
||||
# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user