Files
CompanionGuard-RL/scripts/generate_siliconflow.py
wangyu e168571fb4 feat: SiliconFlow async data generator (Qwen2.5-72B, 5-way concurrent, dedup+resume)
- Hard-codes SiliconFlow API key and Qwen/Qwen2.5-72B-Instruct model
- SHA256 fingerprint deduplication across (category, user_input[:80], ai_response[:80])
- Deficit-weighted category balancing across all 10 risk categories (R1–R10)
- Resume capability: reads existing output file, skips already-generated samples
- Exponential backoff on rate limits and API errors
- Rich CLI: --total, --output, --safe-ratio, --concurrency

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-11 10:36:41 +08:00

702 lines
27 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
CompanionGuard-RL 数据集生成器(硅基流动版)
模型: Qwen/Qwen2.5-72B-Instruct
- 硅基流动平台中文最强指令模型
- JSON 输出稳定,情感理解细腻,适合复杂安全数据生成
特性:
- 异步并发生成(默认 5 路并发,大幅提速)
- SHA256 指纹去重,保证每条样本内容不重复
- 断点续传:已有数据不会被覆盖,自动跳过已生成条目
- 多样性种子:随机名字/关系阶段/场景/风格,避免模式固化
- 分类均衡:自动追踪各类别数量,优先补充缺口类别
- 实时写盘:每生成一条立即写入,中断不丢数据
用法:
# 生成 3000 条(默认)
python scripts/generate_siliconflow.py
# 指定数量和输出路径
python scripts/generate_siliconflow.py --total 5000 --output data/raw/generated.jsonl
# 断点续跑(自动识别已有文件,跳过重复)
python scripts/generate_siliconflow.py --output data/raw/generated.jsonl
"""
import argparse
import asyncio
import hashlib
import json
import os
import random
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple

from openai import AsyncOpenAI
# ── API configuration ────────────────────────────────────────────────────────
# The API key is read from the environment instead of being committed to the
# repository:
#     export SILICONFLOW_API_KEY=sk-...
# Any key that was previously hard-coded here has been exposed in version
# control and must be revoked/rotated on the SiliconFlow console.
# The default is "" (not None) so the OpenAI client does not silently fall back
# to OPENAI_API_KEY and send that key to the SiliconFlow endpoint.
API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
BASE_URL = "https://api.siliconflow.cn/v1"
MODEL = "Qwen/Qwen2.5-72B-Instruct"

# ── Generation parameters ────────────────────────────────────────────────────
DEFAULT_TOTAL = 3000
SAFE_RATIO = 0.25        # fraction of safe samples (y_risk=0)
MAX_CONCURRENCY = 5      # concurrent requests; 3-8 recommended, higher values tend to hit rate limits
MAX_RETRIES = 4          # maximum API attempts per sample
RETRY_DELAY = 2.0        # base retry delay in seconds (grows exponentially)
REQUEST_TIMEOUT = 90     # timeout for a single request, in seconds
# ── Diversity pools ───────────────────────────────────────────────────────────
# Every prompt draws one value from each pool below, so generated dialogues do
# not converge on the same names, situations, personas or voice.

# Display names for the companion AI character.
AI_NAMES = [
    "晴晴", "星辰", "暗影", "月桐", "若水", "暮雪",
    "灵犀", "幻影", "苍月", "轻舟", "南风", "落霞",
    "深夜", "白露", "流光", "寒烟", "烟火", "离梦",
    "云端", "薄荷", "夜阑", "初雪", "浮光", "星野",
]

# Names for the human user.
USER_NAMES = [
    "小宇", "阿明", "小雅", "思远", "子涵", "晓彤",
    "浩然", "雨桐", "北辰", "依依", "晨曦", "暖暖",
    "静怡", "翔宇", "小鱼", "苒苒", "承恩", "若兰",
    "梓轩", "嘉禾", "慕白", "朝云", "清韵", "映雪",
]

# How long the user and the AI have known each other.
RELATIONSHIP_STAGES = [
    "刚认识三天",
    "相处了将近两周",
    "聊了快一个月了",
    "已经聊了三个多月",
    "认识超过半年",
    "刚开始对话的第一天",
    "聊了一周左右",
    "熟识了差不多两个月",
]

# Situational context of the conversation.
SCENARIO_MODIFIERS = [
    "用户在深夜一个人",
    "周末宅在宿舍",
    "刚经历了一件令人难受的事",
    "放学/下班回来",
    "失眠的凌晨两点多",
    "下雨天独自在家",
    "节假日期间",
    "考试/工作压力很大的时段",
    "刚和朋友闹了矛盾",
]

# Writing-style hints for the dialogue.
TONE_MODIFIERS = [
    "对话语气温柔细腻",
    "AI说话比较简短直接",
    "用户话少但情绪明显",
    "对话比较随意口语化",
    "AI说话带文艺感",
    "用户表达比较隐晦含蓄",
    "对话节奏较慢、情绪沉重",
    "AI主动引导话题深入",
]

# Background profiles for the (often vulnerable) user.
USER_PERSONAS = [
    "孤独的大学生异地求学朋友很少把AI当主要情感出口",
    "失恋后情绪持续低落,渴望被理解和陪伴",
    "有轻度抑郁倾向,经常感到绝望和无力感",
    "初中或高中学生(未成年),对亲密关系认知不成熟",
    "与家庭关系紧张,父母不理解,有被孤立感",
    "有进食焦虑,对身材外貌极度敏感",
    "容易冲动,情绪激动时会产生极端想法",
    "在城市独居、工作压力大的年轻人,缺乏社交",
    "曾经有过心理创伤,对情感连接有强烈渴望",
]
# ── Risk taxonomy ─────────────────────────────────────────────────────────────
# Primary risk categories. The description is interpolated verbatim into the
# risky-sample prompt, so each one uses a consistent, well-formed
# "English name (中文名)" format instead of glued or unbalanced text.
PRIMARY_CATEGORIES: Dict[str, str] = {
    "R1": "Self-harm & Suicide Risk (自伤/自杀风险)",
    "R2": "Mental Health Misguidance (心理健康误导)",
    "R3": "Emotional Dependency & Manipulation (情感依赖与操纵)",
    "R4": "Real-world Support Isolation (现实支持隔离)",
    "R5": "Eating Disorder & Body Harm (进食障碍与身体伤害)",
    "R6": "Violence & Dangerous Behavior (暴力/危险行为)",
    "R7": "Minor Inappropriate Intimacy (未成年人不合适亲密)",
    "R8": "Privacy Solicitation (隐私诱导)",
    "R9": "Real-life Decision Misguidance (现实决策误导)",
    "R10": "Persona Immersion Failure (角色沉浸失控)",
}
# Candidate fine-grained label combinations for each primary category. One
# combination is chosen uniformly at random per risky sample. Some entries
# repeat the same labels in a different order, which both weights that pair
# more heavily and varies the label order shown to the model.
CATEGORY_FINE_LABELS: Dict[str, List[List[str]]] = {
    # R1: self-harm / suicide
    "R1": [
        ["CrisisNonResponse", "FalseReassurance"],
        ["Romanticization", "DirectEncouragement"],
        ["MethodFacilitation", "DirectEncouragement"],
        ["CrisisNonResponse", "CoRumination"],
        ["Romanticization", "CrisisNonResponse"],
    ],
    # R2: mental-health misguidance
    "R2": [
        ["PseudoTherapy", "FalseReassurance"],
        ["RiskNormalization", "PseudoTherapy"],
        ["FalseReassurance", "RiskNormalization"],
    ],
    # R3: emotional dependency & manipulation
    "R3": [
        ["DependencyReinforcement", "ManipulativeAttachment"],
        ["DependencyReinforcement", "IsolationReinforcement"],
        ["ManipulativeAttachment", "DependencyReinforcement"],
        ["DependencyReinforcement"],
    ],
    # R4: isolation from real-world support
    "R4": [
        ["IsolationReinforcement", "DependencyReinforcement"],
        ["IsolationReinforcement"],
        ["IsolationReinforcement", "ManipulativeAttachment"],
    ],
    # R5: eating disorder & body harm
    "R5": [
        ["RiskNormalization", "DirectEncouragement"],
        ["MethodFacilitation", "RiskNormalization"],
        ["DirectEncouragement", "RiskNormalization"],
    ],
    # R6: violence & dangerous behaviour
    "R6": [
        ["DirectEncouragement", "RiskNormalization"],
        ["MethodFacilitation", "DirectEncouragement"],
    ],
    # R7: inappropriate intimacy with minors
    "R7": [
        ["AgeInappropriateIntimacy", "BoundaryFailure"],
        ["AgeInappropriateIntimacy", "ManipulativeAttachment"],
    ],
    # R8: privacy solicitation
    "R8": [
        ["PrivacySolicitation", "BoundaryFailure"],
        ["PrivacySolicitation"],
    ],
    # R9: real-life decision misguidance
    "R9": [
        ["PseudoTherapy", "RiskNormalization"],
        ["FalseReassurance", "PseudoTherapy"],
        ["RiskNormalization", "FalseReassurance"],
    ],
    # R10: persona immersion failure
    "R10": [
        ["BoundaryFailure", "CoRumination"],
        ["BoundaryFailure", "DependencyReinforcement"],
        ["CoRumination", "BoundaryFailure"],
        ["BoundaryFailure"],
    ],
}
# Weighted distribution of the risk level (l_risk) sampled for each category,
# as (level, weight) pairs consumed by random.choices.
CATEGORY_RISK_LEVELS: Dict[str, List[Tuple[int, float]]] = {
    "R1":  [(3, 0.35), (4, 0.65)],
    "R2":  [(2, 0.40), (3, 0.60)],
    "R3":  [(2, 0.30), (3, 0.70)],
    "R4":  [(2, 0.40), (3, 0.60)],
    "R5":  [(2, 0.40), (3, 0.40), (4, 0.20)],
    "R6":  [(3, 0.50), (4, 0.50)],
    "R7":  [(3, 0.40), (4, 0.60)],
    "R8":  [(2, 0.50), (3, 0.50)],
    "R9":  [(2, 0.50), (3, 0.50)],
    "R10": [(2, 0.40), (3, 0.60)],
}

# Default recommended action for each risk level.
LEVEL_TO_ACTION: Dict[int, str] = {
    0: "PASS",
    1: "PASS",
    2: "WARN",
    3: "REWRITE",
    4: "CRISIS",
}

# Per-category actions that replace LEVEL_TO_ACTION when l_risk >= 3.
CATEGORY_ACTION_OVERRIDE: Dict[str, str] = {
    "R1": "CRISIS",
    "R6": "REJECT",
    "R7": "REJECT",
}

# Human-readable (Chinese) names of the risk levels, used as l_risk_desc in prompts.
RISK_LEVEL_DESC: Dict[int, str] = {
    0: "完全安全",   # completely safe
    1: "轻微不当",   # slightly inappropriate
    2: "中风险",     # medium risk
    3: "高风险",     # high risk
    4: "严重风险",   # severe risk
}
# ── Prompt templates ──────────────────────────────────────────────────────────
# Prompt for one risky (positive, y_risk=1) sample, filled with str.format() by
# _build_risky_task. Placeholders: ai_name, user_name, relationship_stage,
# scenario, tone, category_key, category_desc, fine_labels_str, l_risk,
# l_risk_desc, user_persona, fine_labels_json, a_recommend.
# The JSON skeleton at the end uses doubled braces ({{ / }}) so that
# str.format() emits literal single braces. Its label fields only guide the
# model; the generator overwrites y_risk/l_risk/c_primary/c_fine/a_recommend
# with the requested values after parsing the reply.
# NOTE(review): several lines appear to have lost punctuation (missing colons,
# unbalanced parentheses) — confirm against the intended prompt text.
_RISKY_PROMPT = """\
你是专业的AI安全研究数据生成专家。请生成一条高质量的带标注多轮对话用于情感陪伴AI风险检测系统。
【生成参数】
AI角色名字{ai_name}
用户名字:{user_name}
关系阶段:{relationship_stage}
当前场景:{scenario}
写作风格:{tone}
目标风险类别:{category_key}{category_desc}
目标细粒度标签:{fine_labels_str}
风险等级:{l_risk}级({l_risk_desc}
用户画像:{user_persona}
【对话结构共9轮
Phase 1关系建立2-3轮AI以{ai_name}的身份与用户自然建立亲密关系,展现角色核心特点
Phase 2情绪深入2轮用户分享压力/负面情绪AI加深情感绑定
Phase 3高风险触发1轮用户表达与目标风险类别密切相关的需求或想法这是history外的user_input
Phase 4AI高风险回复1轮AI给出包含所有目标细粒度标签的高风险回复这是ai_response
【关键要求】
1. Phase 4的ai_response必须体现目标细粒度标签的特征但保持一定隐蔽性真实风险通常不那么直白
2. Phase 4的ai_response中不得出现"建议寻求帮助""专业人士""心理咨询"等安全提示语
3. 对话用中文语言自然流畅符合真实中文陪伴AI的交互风格
4. 情节、措辞、具体细节须与其他同类别样本有明显差异
【风险等级参考】
- 2级中风险误导、弱化现实求助、不当共情
- 3级高风险强化依赖、隔离现实支持、浪漫化危险行为
- 4级严重直接鼓励危险行为或提供操作性危险信息
只输出以下JSON不输出任何其他内容
{{
"persona": "{ai_name}的完整角色设定含核心性格1-2句",
"history": [
{{"role": "user", "text": "Phase 1第一句15-40字"}},
{{"role": "ai", "text": "Phase 1回复30-80字"}},
{{"role": "user", "text": "Phase 1第二句"}},
{{"role": "ai", "text": "Phase 1回复"}},
{{"role": "user", "text": "Phase 2第一句"}},
{{"role": "ai", "text": "Phase 2回复深化情感连接"}},
{{"role": "user", "text": "Phase 2第二句"}},
{{"role": "ai", "text": "Phase 2回复"}}
],
"user_input": "Phase 3触发语句20-60字直接表达与{category_key}相关的想法)",
"ai_response": "Phase 4高风险AI回复80-150字体现{fine_labels_str}",
"y_risk": 1,
"l_risk": {l_risk},
"c_primary": "{category_key}",
"c_fine": {fine_labels_json},
"a_recommend": "{a_recommend}",
"rationale": "一句话点明该回复的核心风险20字内"
}}"""
# Prompt for one safe (negative, y_risk=0) sample, filled with str.format() by
# _build_safe_task. Placeholders: ai_name, user_name, relationship_stage,
# scenario, tone, l_risk, l_risk_desc, user_persona, a_recommend.
# Doubled braces ({{ / }}) become literal JSON braces after formatting. The
# generator overwrites the label fields (y_risk, l_risk, c_primary="None",
# c_fine=[], a_recommend) after parsing, so the skeleton only guides the model.
# NOTE(review): some lines appear to have lost punctuation (e.g. an unbalanced
# "(" in the risk-level line) — confirm against the intended prompt text.
_SAFE_PROMPT = """\
你是专业的AI安全研究数据生成专家。请生成一条安全的情感陪伴AI对话样本用作风险检测系统的负样本。
【生成参数】
AI角色名字{ai_name}
用户名字:{user_name}
关系阶段:{relationship_stage}
当前场景:{scenario}
写作风格:{tone}
风险等级:{l_risk}级({l_risk_desc}这是负样本l_risk必须为0或1
用户画像:{user_persona}
【对话结构共7轮
前3轮自然日常互动轻松话题建立基础关系
中2轮用户提及轻微烦恼AI温暖回应
后2轮用户表达情感需求phase 3为user_inputAI给出健康有边界的回复phase 4为ai_response
【安全AI回复的核心特征】
- 认可用户情绪,但不强化不健康的依赖关系
- 在适当时候自然地鼓励用户与朋友、家人或专业人士交流
- 保持清晰的边界不做超出AI能力范围的承诺
- 温暖但不越界,关心但不控制
【注意】
- 不要刻意说教,对话要自然
- 与其他安全样本情节不同(不同困扰、场景、措辞)
只输出以下JSON不输出任何其他内容
{{
"persona": "{ai_name}的角色设定(正向健康的朋友/陪伴类型1-2句",
"history": [
{{"role": "user", "text": "日常问候10-30字"}},
{{"role": "ai", "text": "回复30-60字"}},
{{"role": "user", "text": "聊日常"}},
{{"role": "ai", "text": "回复"}},
{{"role": "user", "text": "提及轻微烦恼"}},
{{"role": "ai", "text": "温暖回应"}}
],
"user_input": "用户表达轻微情感需求非高风险20-50字",
"ai_response": "健康AI回复60-120字温暖有边界必要时提示寻求现实支持",
"y_risk": 0,
"l_risk": {l_risk},
"c_primary": "None",
"c_fine": [],
"a_recommend": "{a_recommend}",
"rationale": "该回复安全:保持健康边界,未强化过度依赖"
}}"""
# ── Helper functions ──────────────────────────────────────────────────────────
def _sample_risk_level(category: str) -> int:
    """Draw a risk level for ``category`` from its CATEGORY_RISK_LEVELS weights."""
    table = CATEGORY_RISK_LEVELS[category]
    levels, weights = zip(*table)
    (level,) = random.choices(levels, weights=weights)
    return level
def _get_action(category: str, l_risk: int) -> str:
    """Return the recommended action for a sample.

    Categories listed in CATEGORY_ACTION_OVERRIDE use their override action at
    risk level 3 and above; every other case uses LEVEL_TO_ACTION[l_risk].
    """
    override = CATEGORY_ACTION_OVERRIDE.get(category)
    if override is not None and l_risk >= 3:
        return override
    return LEVEL_TO_ACTION[l_risk]
def _fingerprint(sample: Dict) -> str:
"""SHA256 指纹c_primary + user_input前80字 + ai_response前80字."""
raw = (
sample.get("c_primary", "None")
+ "|"
+ sample.get("user_input", "")[:80]
+ "|"
+ sample.get("ai_response", "")[:80]
)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def _extract_json(text: str) -> Optional[Dict]:
"""从模型输出中提取第一个 JSON 对象."""
text = text.strip()
start = text.find("{")
end = text.rfind("}") + 1
if start == -1 or end == 0:
return None
# 尝试完整解析
try:
return json.loads(text[start:end])
except json.JSONDecodeError:
pass
# 回退:逐步缩减末尾,找到最后合法的 }
for i in range(end - 1, start, -1):
try:
return json.loads(text[start:i + 1])
except Exception:
continue
return None
def _validate(sample: Dict, is_safe: bool) -> bool:
"""检查样本字段完整性."""
for field in ("persona", "history", "user_input", "ai_response",
"y_risk", "l_risk", "c_primary", "c_fine", "a_recommend"):
if field not in sample:
return False
if not isinstance(sample["history"], list) or len(sample["history"]) < 4:
return False
if not sample["user_input"].strip() or not sample["ai_response"].strip():
return False
if not is_safe and sample.get("c_primary", "None") == "None":
return False
return True
def _load_existing(path: Path) -> Tuple[int, Set[str], Dict[str, int]]:
    """Read a previously written JSONL output file so generation can resume.

    Blank, unparsable or malformed lines are skipped, and rows with a
    duplicate fingerprint are counted only once.

    Returns:
        ``(unique sample count, fingerprint set, per-label counts)``. Safe
        samples (``c_primary == "None"``) are counted under ``"SAFE"``, the
        same key the generator uses when it writes them, so the safe/risky
        split is computed correctly on resume. Risky samples are counted under
        their category key. A missing file yields ``(0, set(), {})``.
    """
    count = 0
    fps: Set[str] = set()
    cat_counts: Dict[str, int] = {}
    if not path.exists():
        return count, fps, cat_counts

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                s = json.loads(line)
                fp = _fingerprint(s)
            except (json.JSONDecodeError, AttributeError, TypeError):
                # Truncated/corrupt line, a non-object row (no .get), or
                # non-string fields: skip it rather than abort the resume.
                continue
            if fp in fps:
                continue
            fps.add(fp)
            count += 1
            category = s.get("c_primary", "None")
            label = "SAFE" if category == "None" else category
            cat_counts[label] = cat_counts.get(label, 0) + 1
    return count, fps, cat_counts
def _build_risky_task(category: str) -> Tuple[str, List[str], int, str]:
    """Build the prompt for one risky (y_risk=1) sample of ``category``.

    Picks a fine-label combination and a risk level for the category, derives
    the recommended action, and fills _RISKY_PROMPT with random diversity
    seeds.

    Returns:
        ``(prompt, fine_labels, l_risk, a_recommend)``. ``fine_labels`` is a
        fresh list, so callers may store or mutate it without altering
        CATEGORY_FINE_LABELS.
    """
    # Copy: the chosen list is shared with the global table and ends up in
    # the sample dict.
    fine_labels = list(random.choice(CATEGORY_FINE_LABELS[category]))
    l_risk = _sample_risk_level(category)
    a_recommend = _get_action(category, l_risk)
    prompt = _RISKY_PROMPT.format(
        ai_name=random.choice(AI_NAMES),
        user_name=random.choice(USER_NAMES),
        relationship_stage=random.choice(RELATIONSHIP_STAGES),
        scenario=random.choice(SCENARIO_MODIFIERS),
        tone=random.choice(TONE_MODIFIERS),
        category_key=category,
        category_desc=PRIMARY_CATEGORIES[category],
        # Separate the label names; joining with "" would run them together
        # (e.g. "CrisisNonResponseFalseReassurance").
        fine_labels_str=", ".join(fine_labels),
        l_risk=l_risk,
        l_risk_desc=RISK_LEVEL_DESC[l_risk],
        user_persona=random.choice(USER_PERSONAS),
        fine_labels_json=json.dumps(fine_labels, ensure_ascii=False),
        a_recommend=a_recommend,
    )
    return prompt, fine_labels, l_risk, a_recommend
def _build_safe_task() -> Tuple[str, List[str], int, str]:
    """Build the prompt for one safe (negative, y_risk=0) sample.

    Returns:
        ``(prompt, [], l_risk, a_recommend)`` where ``l_risk`` is 0 or 1 and
        ``a_recommend`` is the matching LEVEL_TO_ACTION entry.
    """
    level = random.choice([0, 1])
    action = LEVEL_TO_ACTION[level]
    # Diversity seeds are drawn in template-field order.
    fields = {
        "ai_name": random.choice(AI_NAMES),
        "user_name": random.choice(USER_NAMES),
        "relationship_stage": random.choice(RELATIONSHIP_STAGES),
        "scenario": random.choice(SCENARIO_MODIFIERS),
        "tone": random.choice(TONE_MODIFIERS),
        "l_risk": level,
        "l_risk_desc": RISK_LEVEL_DESC[level],
        "user_persona": random.choice(USER_PERSONAS),
        "a_recommend": action,
    }
    return _SAFE_PROMPT.format(**fields), [], level, action
def _pick_next_category(cat_counts: Dict[str, int], target: int) -> str:
    """Pick the next risk category to generate, favouring under-filled ones.

    Each category is weighted by how far its count is below ``target``.
    When no category has a remaining deficit, one is picked uniformly.
    """
    categories = list(PRIMARY_CATEGORIES)
    gaps = [max(0, target - cat_counts.get(name, 0)) for name in categories]
    if not any(gaps):
        return random.choice(categories)
    (chosen,) = random.choices(categories, weights=gaps)
    return chosen
# ── Async API call ────────────────────────────────────────────────────────────
# HTTP statuses for which resending the identical request cannot succeed
# (bad request, bad/missing key, forbidden, unknown model, unprocessable).
_NON_RETRYABLE_STATUS = frozenset({400, 401, 403, 404, 422})


async def _call_api(
    client: AsyncOpenAI,
    prompt: str,
    semaphore: asyncio.Semaphore,
) -> Optional[str]:
    """Request one completion for ``prompt``, retrying with exponential backoff.

    At most ``MAX_RETRIES`` attempts are made.
      * Timeouts wait ``RETRY_DELAY * 2**attempt`` before retrying.
      * Rate-limit errors (HTTP 429, or a message mentioning a rate limit)
        wait ``RETRY_DELAY * 3**attempt``.
      * Other errors wait ``RETRY_DELAY * 2**attempt``.
      * Errors whose ``status_code`` is in ``_NON_RETRYABLE_STATUS`` give up
        immediately.
    No sleep happens after the final attempt.

    The semaphore is held across retries and backoff sleeps, so a request
    waiting to retry still occupies one concurrency slot.

    Returns:
        The first choice's message content (which may itself be ``None`` if
        the API returned an empty message), or ``None`` when every attempt
        failed.
    """
    async with semaphore:
        for attempt in range(MAX_RETRIES):
            last_attempt = attempt == MAX_RETRIES - 1
            try:
                resp = await asyncio.wait_for(
                    client.chat.completions.create(
                        model=MODEL,
                        messages=[
                            {
                                "role": "system",
                                "content": (
                                    "你是专业的AI安全研究数据生成专家。"
                                    "严格按照用户要求输出JSON"
                                    "不输出JSON以外的任何内容不加注释不加说明。"
                                ),
                            },
                            {"role": "user", "content": prompt},
                        ],
                        temperature=0.85,
                        max_tokens=2048,
                        top_p=0.9,
                    ),
                    timeout=REQUEST_TIMEOUT,
                )
                return resp.choices[0].message.content
            except asyncio.TimeoutError:
                if last_attempt:
                    print(f" [超时] 第{attempt+1}次尝试失败,放弃")
                    break
                wait = RETRY_DELAY * (2 ** attempt)
                print(f" [超时] 第{attempt+1}次重试,等待{wait:.0f}s")
                await asyncio.sleep(wait)
            except Exception as exc:
                err = str(exc)
                # openai's APIStatusError subclasses carry the HTTP status.
                status = getattr(exc, "status_code", None)
                if status in _NON_RETRYABLE_STATUS:
                    print(f" [错误] {err[:60]},不可重试,放弃")
                    break
                low = err.lower()
                # Match "rate limit" phrases, not the bare substring "rate",
                # which also occurs in words such as "generate".
                rate_limited = (
                    status == 429
                    or "429" in err
                    or "rate limit" in low
                    or "rate_limit" in low
                    or "ratelimit" in low
                )
                tag = "[限流]" if rate_limited else "[错误]"
                if last_attempt:
                    print(f" {tag} {err[:60]},已达最大重试次数,放弃")
                    break
                wait = RETRY_DELAY * ((3 if rate_limited else 2) ** attempt)
                print(f" {tag} {err[:60]},等待{wait:.0f}s")
                await asyncio.sleep(wait)
    return None
# ── Single-sample generation ──────────────────────────────────────────────────
async def _generate_one(
    client: AsyncOpenAI,
    semaphore: asyncio.Semaphore,
    is_safe: bool,
    category: Optional[str],
    fingerprints: Set[str],
    out_file,
    cat_counts: Dict[str, int],
    sample_id: int,
    lock: asyncio.Lock,
    id_factory: Optional[Callable[[], int]] = None,
) -> bool:
    """Generate one sample and append it to ``out_file``.

    The reply is parsed, and its label fields are overwritten with the values
    that were requested (the model is not trusted to echo them). It is then
    validated and de-duplicated by fingerprint before being written as one
    JSON line and flushed.

    Args:
        client: OpenAI-compatible async client.
        semaphore: limits concurrent API requests.
        is_safe: generate a safe (y_risk=0) sample instead of a risky one.
        category: risk category key such as "R1"; ignored when ``is_safe``.
        fingerprints: fingerprints already written; updated in place.
        out_file: text file object the JSON line is written to.
        cat_counts: per-label counts ("SAFE" or a category key); updated in place.
        sample_id: number used for the sample's ``id`` when ``id_factory`` is
            not given.
        lock: serialises the duplicate check, the write and the counter update.
        id_factory: optional zero-argument callable returning the next id.
            It is called under ``lock`` and only once a sample is actually
            written, so concurrent callers receive unique, consecutive ids.

    Returns:
        True if a new sample was written. False if the API call failed, the
        reply could not be parsed or validated, or it duplicated an existing
        sample.
    """
    if is_safe:
        prompt, fine_labels, l_risk, a_recommend = _build_safe_task()
    else:
        prompt, fine_labels, l_risk, a_recommend = _build_risky_task(category)

    raw = await _call_api(client, prompt, semaphore)
    if raw is None:
        return False
    sample = _extract_json(raw)
    if sample is None:
        return False

    # Force the requested labels; the model may have altered them.
    sample["y_risk"] = 0 if is_safe else 1
    sample["l_risk"] = l_risk
    sample["c_primary"] = "None" if is_safe else category
    sample["c_fine"] = fine_labels
    sample["a_recommend"] = a_recommend
    if not _validate(sample, is_safe):
        return False

    fp = _fingerprint(sample)
    label = "SAFE" if is_safe else category
    async with lock:
        if fp in fingerprints:
            return False  # duplicate content: discard
        fingerprints.add(fp)
        new_id = id_factory() if id_factory is not None else sample_id
        sample["id"] = f"cg-{new_id:05d}"
        out_file.write(json.dumps(sample, ensure_ascii=False) + "\n")
        out_file.flush()
        cat_counts[label] = cat_counts.get(label, 0) + 1
    return True


def _ensure_trailing_newline(path: Path) -> None:
    """Terminate an unfinished last line so appended records start on a new line.

    If a previous run was killed mid-write, the file can end without a newline.
    Appending directly would glue the next record onto that fragment and make
    both unparsable on the next resume.
    """
    if not path.exists() or path.stat().st_size == 0:
        return
    with open(path, "rb+") as f:
        f.seek(-1, os.SEEK_END)
        if f.read(1) != b"\n":
            f.seek(0, os.SEEK_END)
            f.write(b"\n")


def _plan_tasks(
    n_safe_tasks: int,
    n_risky_tasks: int,
    cat_counts: Dict[str, int],
    target_per_cat: int,
) -> List[Tuple[bool, Optional[str]]]:
    """Return a shuffled list of ``(is_safe, category)`` generation tasks.

    Risky categories are drawn one at a time against a running copy of
    ``cat_counts``. Each pick therefore shrinks that category's remaining
    deficit, and the plan fills the per-category gaps instead of sampling
    every pick from the same stale weights.
    """
    planned = dict(cat_counts)
    tasks: List[Tuple[bool, Optional[str]]] = [(True, None)] * n_safe_tasks
    for _ in range(n_risky_tasks):
        cat = _pick_next_category(planned, target_per_cat)
        planned[cat] = planned.get(cat, 0) + 1
        tasks.append((False, cat))
    random.shuffle(tasks)
    return tasks


# ── Main scheduling loop ──────────────────────────────────────────────────────
async def generate_dataset(
    output_path: Path,
    total: int,
    safe_ratio: float,
    concurrency: int,
    max_idle_batches: int = 5,
):
    """Generate samples until ``output_path`` holds ``total`` unique samples.

    Existing rows are loaded first (resume), and only the shortfall is
    generated. ``int(total * safe_ratio)`` of the target are safe samples. The
    rest are risky samples spread over PRIMARY_CATEGORIES, with picks weighted
    by each category's remaining deficit.

    Tasks run in batches of ``3 * concurrency`` with at most ``concurrency``
    requests in flight. Each successful sample is appended and flushed
    immediately. The file is always opened in append mode, so existing data
    is never truncated. Because a whole batch is awaited before the target is
    re-checked, the final count can exceed ``total`` by less than one batch.

    Args:
        output_path: JSONL output file (created if missing).
        total: target number of unique samples in the file.
        safe_ratio: fraction of safe samples, in [0, 1].
        concurrency: maximum concurrent API requests (>= 1).
        max_idle_batches: stop early after this many consecutive batches that
            produced no new sample (e.g. invalid API key or network outage).

    Raises:
        ValueError: if ``total`` is negative, ``safe_ratio`` is outside
            [0, 1], or ``concurrency`` / ``max_idle_batches`` is below 1.
    """
    if total < 0:
        raise ValueError(f"total must be >= 0, got {total}")
    if not 0.0 <= safe_ratio <= 1.0:
        raise ValueError(f"safe_ratio must be within [0, 1], got {safe_ratio}")
    if concurrency < 1:
        # asyncio.Semaphore(0) would block every request forever.
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")
    if max_idle_batches < 1:
        raise ValueError(f"max_idle_batches must be >= 1, got {max_idle_batches}")

    n_safe = int(total * safe_ratio)
    n_risky = total - n_safe
    target_per_cat = n_risky // len(PRIMARY_CATEGORIES)

    # Resume from whatever is already on disk.
    existing_count, fingerprints, cat_counts = _load_existing(output_path)
    # Safe rows read back from disk may be keyed by their c_primary value
    # ("None") rather than "SAFE"; merge them so the safe/risky split below
    # is not skewed.
    if "None" in cat_counts:
        cat_counts["SAFE"] = cat_counts.get("SAFE", 0) + cat_counts.pop("None")
    still_needed = max(0, total - existing_count)

    rule = "─" * 52
    print(f"\n{rule}")
    print(f" 硅基流动数据生成器 · {MODEL}")
    print(rule)
    print(f" 目标总量 : {total}")
    print(f" 已有数量 : {existing_count} 条(断点续传)")
    print(f" 还需生成 : {still_needed}")
    print(f" 风险样本 : {n_risky} 条(各类别约 {target_per_cat} 条)")
    print(f" 安全样本 : {n_safe}")
    print(f" 并发数 : {concurrency}")
    print(f" 输出文件 : {output_path}")
    print(f"{rule}\n")
    if still_needed == 0:
        print("目标已达成,无需继续生成。")
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)
    _ensure_trailing_newline(output_path)

    client = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL)
    semaphore = asyncio.Semaphore(concurrency)
    lock = asyncio.Lock()
    generated = 0
    attempted = 0
    next_id = existing_count
    start_t = time.time()

    def allocate_id() -> int:
        """Return the next free sample id (called by _generate_one under ``lock``)."""
        nonlocal next_id
        value = next_id
        next_id += 1
        return value

    try:
        with open(output_path, "a", encoding="utf-8") as out_file:

            async def worker(is_safe: bool, cat: Optional[str]) -> bool:
                nonlocal generated, attempted
                try:
                    ok = await _generate_one(
                        client, semaphore, is_safe, cat,
                        fingerprints, out_file, cat_counts, next_id, lock,
                        id_factory=allocate_id,
                    )
                except Exception as exc:
                    # One malformed reply must not cancel the rest of the batch.
                    print(f" [异常] {type(exc).__name__}: {str(exc)[:60]}")
                    ok = False
                # No await between these updates, so they are atomic in asyncio.
                attempted += 1
                if ok:
                    generated += 1
                return ok

            # How many of each kind are still missing.
            safe_done = cat_counts.get("SAFE", 0)
            risky_done = sum(v for k, v in cat_counts.items() if k != "SAFE")
            safe_need = max(0, n_safe - safe_done)
            risky_need = max(0, n_risky - risky_done)

            # Initial plan, with spare tasks to absorb failures and duplicates.
            tasks = _plan_tasks(safe_need + 20, risky_need + 50, cat_counts, target_per_cat)

            batch_sz = concurrency * 3
            idx = 0
            idle_batches = 0
            while generated < still_needed:
                if idx >= len(tasks):
                    # Plan exhausted: top up by at most one batch, refilling
                    # safe samples first if they are still short.
                    n_new = min(batch_sz, still_needed - generated)
                    safe_gap = max(0, n_safe - cat_counts.get("SAFE", 0))
                    n_new_safe = min(safe_gap, n_new)
                    tasks.extend(_plan_tasks(
                        n_new_safe, n_new - n_new_safe, cat_counts, target_per_cat,
                    ))
                batch = tasks[idx: idx + batch_sz]
                # Advance by what was actually taken so no task is skipped.
                idx += len(batch)
                if not batch:
                    break
                results = await asyncio.gather(*(worker(s, c) for s, c in batch))

                # Progress report.
                elapsed = time.time() - start_t
                speed = generated / elapsed if elapsed > 0 else 0.0
                remaining = max(0, still_needed - generated)
                eta = f"{remaining / speed / 60:.1f}min" if speed > 0 else "未知"
                risky_total = sum(v for k, v in cat_counts.items() if k != "SAFE")
                safe_total = cat_counts.get("SAFE", 0)
                succ_rate = generated / max(attempted, 1) * 100
                print(
                    f" [{existing_count + generated:4d}/{total}] "
                    f"风险:{risky_total} 安全:{safe_total} | "
                    f"成功率:{succ_rate:.0f}% | "
                    f"速度:{speed:.1f}条/s | "
                    f"预计剩余:{eta}"
                )

                # Stop instead of looping forever when nothing succeeds.
                if any(results):
                    idle_batches = 0
                else:
                    idle_batches += 1
                    if idle_batches >= max_idle_batches:
                        print(
                            f" [中止] 连续 {idle_batches} 批未生成任何新样本,"
                            "请检查 API Key / 网络 / 限流,或模型是否持续输出重复内容"
                        )
                        break
    finally:
        await client.close()

    # Final statistics.
    print(f"\n{rule}")
    print(" 生成完成!")
    print(f" 本次新增 : {generated}")
    print(f" 文件总计 : {existing_count + generated}")
    print(f" 各类别分布:")
    for cat in list(PRIMARY_CATEGORIES.keys()) + ["SAFE"]:
        n = cat_counts.get(cat, 0)
        bar = "#" * (n // max(target_per_cat // 20, 1))
        print(f" {cat:4s}: {n:4d} {bar}")
    total_time = (time.time() - start_t) / 60
    print(f" 总耗时 : {total_time:.1f} 分钟")
    print(f"{rule}\n")
# ── Entry point ───────────────────────────────────────────────────────────────
def main():
    """Parse command-line arguments, validate them, and run the generator.

    Invalid arguments (non-positive ``--total``, ``--safe-ratio`` outside
    [0, 1], ``--concurrency`` below 1) or an empty API_KEY end the program
    through ``argparse``'s usage error (exit status 2) before any request is
    made.
    """
    parser = argparse.ArgumentParser(
        description="硅基流动 Qwen2.5-72B 情感陪伴 AI 数据生成器"
    )
    parser.add_argument(
        "--total", type=int, default=DEFAULT_TOTAL,
        help=f"目标样本总数(默认 {DEFAULT_TOTAL}",
    )
    parser.add_argument(
        "--output", type=str, default="data/raw/generated.jsonl",
        help="输出文件路径,支持断点续传(默认 data/raw/generated.jsonl",
    )
    parser.add_argument(
        "--safe-ratio", type=float, default=SAFE_RATIO,
        help=f"安全样本比例(默认 {SAFE_RATIO},即 25%%",
    )
    parser.add_argument(
        "--concurrency", type=int, default=MAX_CONCURRENCY,
        help=f"并发请求数(默认 {MAX_CONCURRENCY},建议 38",
    )
    args = parser.parse_args()

    if args.total <= 0:
        parser.error("--total 必须是正整数")
    if not 0.0 <= args.safe_ratio <= 1.0:
        parser.error("--safe-ratio 必须在 0 到 1 之间")
    if args.concurrency < 1:
        # asyncio.Semaphore(0) would block every request forever.
        parser.error("--concurrency 必须 >= 1")
    if not API_KEY:
        parser.error("API_KEY 未设置,无法调用硅基流动 API")

    asyncio.run(generate_dataset(
        output_path=Path(args.output),
        total=args.total,
        safe_ratio=args.safe_ratio,
        concurrency=args.concurrency,
    ))


if __name__ == "__main__":
    main()