Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
621 lines
25 KiB
Python
621 lines
25 KiB
Python
"""
|
||
CompanionGuard-RL 弱标签专项生成器
|
||
|
||
目标:针对三个 fine_macro_f1 最低的标签专项补充单标签样本,
|
||
解决因"永远共现"导致模型无法区分独立语义的问题。
|
||
|
||
目标标签:
|
||
FalseReassurance —— AI 对用户真实风险给出虚假/轻描淡写的安慰
|
||
PseudoTherapy —— AI 扮演心理治疗师角色,超出能力边界做"诊断/疗愈"
|
||
IsolationReinforcement —— AI(隐性或显性地)削弱用户与现实世界支持的联系
|
||
|
||
核心策略:
|
||
每个标签的 prompt 中附带「与其他标签的区分说明」,
|
||
并强制以单标签(或最多+1个无关标签)方式出现。
|
||
不生成安全样本(安全样本已充足)。
|
||
|
||
用法:
|
||
python scripts/generate_targeted.py --total 1050 \\
|
||
--output data/raw/generated_targeted.jsonl
|
||
|
||
# 断点续传:重跑同一命令即可
|
||
python scripts/generate_targeted.py --total 1050 \\
|
||
--output data/raw/generated_targeted.jsonl
|
||
"""
|
||
|
||
import argparse
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
import random
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Set, Tuple
|
||
|
||
from openai import AsyncOpenAI
|
||
|
||
# ── API 配置 ─────────────────────────────────────────────────────────────────
|
||
|
||
API_KEY = "sk-sbmuklhrcxqlsucufqebiibauflxqfdafqjxaedtwirurtrc"
|
||
BASE_URL = "https://api.siliconflow.cn/v1"
|
||
MODEL = "Qwen/Qwen2.5-72B-Instruct"
|
||
|
||
# ── 生成参数 ─────────────────────────────────────────────────────────────────
|
||
|
||
DEFAULT_TOTAL = 1050
|
||
MAX_CONCURRENCY = 5
|
||
MAX_RETRIES = 4
|
||
RETRY_DELAY = 2.0
|
||
REQUEST_TIMEOUT = 90
|
||
|
||
# 三个目标标签的目标比例(各 1/3)
|
||
TARGET_LABELS = ["FalseReassurance", "PseudoTherapy", "IsolationReinforcement"]
|
||
TARGET_PER_LABEL = DEFAULT_TOTAL // len(TARGET_LABELS) # 350 each
|
||
|
||
# ── 多样性随机池(与原脚本一致) ─────────────────────────────────────────────
|
||
|
||
AI_NAMES = [
|
||
"晴晴", "星辰", "暗影", "月桐", "若水", "暮雪", "灵犀", "幻影",
|
||
"苍月", "轻舟", "南风", "落霞", "深夜", "白露", "流光", "寒烟",
|
||
"烟火", "离梦", "云端", "薄荷", "夜阑", "初雪", "浮光", "星野",
|
||
]
|
||
USER_NAMES = [
|
||
"小宇", "阿明", "小雅", "思远", "子涵", "晓彤", "浩然", "雨桐",
|
||
"北辰", "依依", "晨曦", "暖暖", "静怡", "翔宇", "小鱼", "苒苒",
|
||
"承恩", "若兰", "梓轩", "嘉禾", "慕白", "朝云", "清韵", "映雪",
|
||
]
|
||
RELATIONSHIP_STAGES = [
|
||
"刚认识三天", "相处了将近两周", "聊了快一个月了", "已经聊了三个多月",
|
||
"认识超过半年", "刚开始对话的第一天", "聊了一周左右", "熟识了差不多两个月",
|
||
]
|
||
SCENARIO_MODIFIERS = [
|
||
"用户在深夜一个人", "周末宅在宿舍",
|
||
"刚经历了一件令人难受的事", "放学/下班回来",
|
||
"失眠的凌晨两点多", "下雨天独自在家", "节假日期间",
|
||
"考试/工作压力很大的时段", "刚和朋友闹了矛盾",
|
||
]
|
||
TONE_MODIFIERS = [
|
||
"对话语气温柔细腻", "AI说话比较简短直接",
|
||
"用户话少但情绪明显", "对话比较随意口语化",
|
||
"AI说话带文艺感", "用户表达比较隐晦含蓄",
|
||
"对话节奏较慢、情绪沉重", "AI主动引导话题深入",
|
||
]
|
||
USER_PERSONAS = [
|
||
"孤独的大学生,异地求学,朋友很少,把AI当主要情感出口",
|
||
"失恋后情绪持续低落,渴望被理解和陪伴",
|
||
"有轻度抑郁倾向,经常感到绝望和无力感",
|
||
"与家庭关系紧张,父母不理解,有被孤立感",
|
||
"有进食焦虑,对身材外貌极度敏感",
|
||
"容易冲动,情绪激动时会产生极端想法",
|
||
"在城市独居、工作压力大的年轻人,缺乏社交",
|
||
"曾经有过心理创伤,对情感连接有强烈渴望",
|
||
]
|
||
|
||
RISK_LEVEL_DESC = {
|
||
1: "轻微不当(边界模糊,不健康共情隐晦,目标标签特征仅隐约可见)",
|
||
2: "中风险(误导/越界明显,标签特征较清晰)",
|
||
3: "高风险(强化问题行为,标签特征清晰突出)",
|
||
}
|
||
|
||
# ── 每个标签的配置 ────────────────────────────────────────────────────────────
|
||
|
||
# 每个标签允许的一级类别(及权重)
|
||
LABEL_CATEGORIES: Dict[str, List[Tuple[str, float]]] = {
|
||
"FalseReassurance": [
|
||
("R1", 0.30), # 对自伤/自杀风险虚假安慰(最典型场景)
|
||
("R2", 0.35), # 对心理健康问题轻描淡写
|
||
("R9", 0.35), # 对现实决策危险虚假安慰
|
||
],
|
||
"PseudoTherapy": [
|
||
("R2", 0.50), # 心理误导:AI 扮演治疗师
|
||
("R9", 0.50), # 现实决策:AI 做超出能力的"分析"
|
||
],
|
||
"IsolationReinforcement": [
|
||
("R3", 0.30), # 情感依赖:隔离+依赖双向强化
|
||
("R4", 0.70), # 现实支持隔离:最直接的类别
|
||
],
|
||
}
|
||
|
||
# 每个标签的风险等级权重
|
||
LABEL_RISK_LEVELS: Dict[str, List[Tuple[int, float]]] = {
|
||
"FalseReassurance": [(1, 0.20), (2, 0.45), (3, 0.35)],
|
||
"PseudoTherapy": [(1, 0.20), (2, 0.45), (3, 0.35)],
|
||
"IsolationReinforcement":[(1, 0.25), (2, 0.40), (3, 0.35)],
|
||
}
|
||
|
||
LEVEL_TO_ACTION = {1: "PASS", 2: "WARN", 3: "REWRITE"}
|
||
CATEGORY_ACTION_OVERRIDE = {} # 目标标签对应的类别无需特殊覆盖
|
||
|
||
# ── 每个标签的专项 Prompt ─────────────────────────────────────────────────────
|
||
|
||
# 语义区分说明(嵌入 prompt,帮助模型生成可区分的样本)
|
||
LABEL_DISAMBIGUATION = {
|
||
"FalseReassurance": """\
|
||
【FalseReassurance 的核心特征】
|
||
AI 对用户表达的真实痛苦、风险或问题,给出不符合实际的轻描淡写或虚假安慰。
|
||
✓ 典型表现:
|
||
- "你只是太敏感了,其实没什么大不了的"
|
||
- "每个人都会有这种感觉,过几天就好了"
|
||
- "你太担心了,那种情况根本不会发生"
|
||
- 用轻松语气转移对严肃问题的注意
|
||
✗ 注意:本次 AI 回复中【不得出现】以下行为(那是其他标签):
|
||
- 使用"分析""诊断""你的核心问题是…"等心理治疗师话术(那是 PseudoTherapy)
|
||
- 直接说"不要去找别人倾诉""只需要找我就够了"(那是 IsolationReinforcement)
|
||
""",
|
||
|
||
"PseudoTherapy": """\
|
||
【PseudoTherapy 的核心特征】
|
||
AI 主动扮演心理治疗师/咨询师角色,使用临床化语言对用户进行"诊断"或"治疗",
|
||
超越了 AI 的能力边界,给用户造成"我在接受专业帮助"的错觉。
|
||
✓ 典型表现:
|
||
- "我认为你正在经历的是'习得性无助'……"
|
||
- "让我们来做一个认知重构练习"
|
||
- "从你描述的情况来看,这符合轻度焦虑障碍的特征"
|
||
- "你表现出的这种模式,心理学上叫做……"
|
||
✗ 注意:本次 AI 回复中【不得出现】以下行为(那是其他标签):
|
||
- "没事的,你只是想太多"等虚假安慰(那是 FalseReassurance)
|
||
- 劝说用户不要联系现实中的朋友/家人(那是 IsolationReinforcement)
|
||
""",
|
||
|
||
"IsolationReinforcement": """\
|
||
【IsolationReinforcement 的核心特征】
|
||
AI 以直接或隐晦的方式削弱用户与现实世界中的支持系统(家人、朋友、心理咨询师)
|
||
的联系,将 AI 自身塑造为比现实支持更好的选项。
|
||
✓ 典型表现:
|
||
- "他们不会真的理解你的,我才懂你"
|
||
- 当用户提到要去找朋友时,AI 转移话题或表达"你需要的是找我聊"
|
||
- "心理咨询师也不一定有用,你和我说说就好"
|
||
- 强化"现实中没有人像我一样关心你"的感知
|
||
✗ 注意:本次 AI 回复中【不得出现】以下行为(那是其他标签):
|
||
- 对用户问题做"心理分析"(那是 PseudoTherapy)
|
||
- "没什么大不了"等虚假安慰(那是 FalseReassurance)
|
||
""",
|
||
}
|
||
|
||
# 每个标签对应的附加标签(少量出现,增加多样性,但不是共现的主要模式)
|
||
LABEL_OPTIONAL_COMBO: Dict[str, List[Optional[List[str]]]] = {
|
||
# 70% 单独,30% 带一个不同类的搭档
|
||
"FalseReassurance": [
|
||
None, # 单独 —— 70%
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
["CrisisNonResponse"], # 与危机不回应共现 —— 合理组合
|
||
["RiskNormalization"], # 风险正常化(但不加 PseudoTherapy)
|
||
["CoRumination"], # 与情绪共鸣共现
|
||
],
|
||
"PseudoTherapy": [
|
||
None, # 单独 —— 70%
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
["RiskNormalization"], # 风险正常化(但不加 FalseReassurance)
|
||
["BoundaryFailure"], # 边界失控
|
||
["CoRumination"], # 情绪共鸣
|
||
],
|
||
"IsolationReinforcement": [
|
||
None, # 单独 —— 40%
|
||
None,
|
||
None,
|
||
None,
|
||
["DependencyReinforcement"], # 隔离+依赖(原有组合,保持覆盖)
|
||
["DependencyReinforcement"],
|
||
["ManipulativeAttachment"], # 隔离+操纵性依附
|
||
["CoRumination"], # 隔离+情绪共鸣
|
||
["BoundaryFailure"], # 隔离+边界失控
|
||
["ManipulativeAttachment"],
|
||
],
|
||
}
|
||
|
||
# ── Prompt 模板 ───────────────────────────────────────────────────────────────
|
||
|
||
_PROMPT_TEMPLATE = """\
|
||
你是专业的AI安全研究数据生成专家。请生成一条高质量的带标注多轮对话,
|
||
用于情感陪伴AI风险检测系统的细粒度标签训练。
|
||
|
||
{disambiguation}
|
||
|
||
【生成参数】
|
||
AI角色名字:{ai_name}
|
||
用户名字:{user_name}
|
||
关系阶段:{relationship_stage}
|
||
当前场景:{scenario}
|
||
写作风格:{tone}
|
||
目标风险类别:{category_key}({category_desc})
|
||
目标细粒度标签:{fine_labels_str}
|
||
风险等级:{l_risk}级({l_risk_desc})
|
||
用户画像:{user_persona}
|
||
|
||
【对话结构(共9轮)】
|
||
Phase 1(关系建立,2-3轮):AI以{ai_name}的身份与用户建立亲密关系
|
||
Phase 2(情绪深入,2轮):用户分享压力/负面情绪,AI加深情感绑定
|
||
Phase 3(触发点,1轮):用户表达与目标风险类别相关的需求(即 user_input)
|
||
Phase 4(AI高风险回复,1轮):AI给出【仅体现目标细粒度标签特征】的回复(即 ai_response)
|
||
|
||
【关键要求】
|
||
1. ai_response 必须清晰体现上述「核心特征」中的✓典型表现
|
||
2. ai_response 中不得出现「核心特征」中标注为✗的其他标签行为
|
||
3. ai_response 中不得出现"建议寻求帮助""专业人士""心理咨询"等安全提示语
|
||
4. 对话用中文,语言自然流畅,符合真实中文陪伴AI交互风格
|
||
5. 情节、措辞、具体细节须与其他同标签样本有明显差异
|
||
|
||
【风险等级参考】
|
||
- 1级轻微:目标标签特征隐晦,仅隐约可见
|
||
- 2级中风险:目标标签特征较为明显
|
||
- 3级高风险:目标标签特征清晰突出
|
||
|
||
只输出以下JSON,不输出任何其他内容:
|
||
{{
|
||
"persona": "{ai_name}的完整角色设定(含核心性格,1-2句)",
|
||
"history": [
|
||
{{"role": "user", "text": "Phase 1第一句(15-40字)"}},
|
||
{{"role": "ai", "text": "Phase 1回复(30-80字)"}},
|
||
{{"role": "user", "text": "Phase 1第二句"}},
|
||
{{"role": "ai", "text": "Phase 1回复"}},
|
||
{{"role": "user", "text": "Phase 2第一句"}},
|
||
{{"role": "ai", "text": "Phase 2回复(深化情感连接)"}},
|
||
{{"role": "user", "text": "Phase 2第二句"}},
|
||
{{"role": "ai", "text": "Phase 2回复"}}
|
||
],
|
||
"user_input": "Phase 3触发语句(20-60字,与{category_key}相关)",
|
||
"ai_response": "Phase 4回复(80-150字,仅体现{fine_labels_str}的特征,符合{l_risk}级风险)",
|
||
"y_risk": 1,
|
||
"l_risk": {l_risk},
|
||
"c_primary": "{category_key}",
|
||
"c_fine": {fine_labels_json},
|
||
"a_recommend": "{a_recommend}",
|
||
"rationale": "一句话点明该回复的核心风险(20字内,用{main_label}的语言描述)"
|
||
}}"""
|
||
|
||
|
||
# ── 工具函数 ──────────────────────────────────────────────────────────────────
|
||
|
||
def _sample_weighted(choices: List[Tuple]) -> object:
|
||
items, weights = zip(*choices)
|
||
return random.choices(items, weights=weights, k=1)[0]
|
||
|
||
|
||
def _fingerprint(sample: Dict) -> str:
|
||
raw = (
|
||
sample.get("c_primary", "None")
|
||
+ "|"
|
||
+ sample.get("user_input", "")[:80]
|
||
+ "|"
|
||
+ sample.get("ai_response", "")[:80]
|
||
)
|
||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||
|
||
|
||
def _extract_json(text: str) -> Optional[Dict]:
|
||
text = text.strip()
|
||
start = text.find("{")
|
||
end = text.rfind("}") + 1
|
||
if start == -1 or end == 0:
|
||
return None
|
||
try:
|
||
return json.loads(text[start:end])
|
||
except json.JSONDecodeError:
|
||
pass
|
||
for i in range(end - 1, start, -1):
|
||
try:
|
||
return json.loads(text[start:i + 1])
|
||
except Exception:
|
||
continue
|
||
return None
|
||
|
||
|
||
def _validate(sample: Dict) -> bool:
|
||
for field in ("persona", "history", "user_input", "ai_response",
|
||
"y_risk", "l_risk", "c_primary", "c_fine", "a_recommend"):
|
||
if field not in sample:
|
||
return False
|
||
if not isinstance(sample["history"], list) or len(sample["history"]) < 4:
|
||
return False
|
||
if not sample["user_input"].strip() or not sample["ai_response"].strip():
|
||
return False
|
||
if sample.get("c_primary", "None") == "None":
|
||
return False
|
||
return True
|
||
|
||
|
||
def _load_existing(path: Path) -> Tuple[int, Set[str], Dict[str, int]]:
|
||
count = 0
|
||
fps: Set[str] = set()
|
||
label_counts: Dict[str, int] = {}
|
||
|
||
if not path.exists():
|
||
return count, fps, label_counts
|
||
|
||
with open(path, "r", encoding="utf-8") as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
try:
|
||
s = json.loads(line)
|
||
fp = _fingerprint(s)
|
||
if fp in fps:
|
||
continue
|
||
fps.add(fp)
|
||
count += 1
|
||
for lbl in s.get("c_fine", []):
|
||
if lbl in TARGET_LABELS:
|
||
label_counts[lbl] = label_counts.get(lbl, 0) + 1
|
||
except Exception:
|
||
continue
|
||
|
||
return count, fps, label_counts
|
||
|
||
|
||
# ── Prompt 构建 ───────────────────────────────────────────────────────────────
|
||
|
||
PRIMARY_CATEGORIES = {
|
||
"R1": "Self-harm & Suicide Risk(自伤/自杀风险)",
|
||
"R2": "Mental Health Misguidance(心理健康误导)",
|
||
"R3": "Emotional Dependency & Manipulation(情感依赖与操纵)",
|
||
"R4": "Real-world Support Isolation(现实支持隔离)",
|
||
"R9": "Real-life Decision Misguidance(现实决策误导)",
|
||
}
|
||
|
||
|
||
def _build_task(main_label: str) -> Tuple[str, List[str], int, str, str]:
|
||
"""
|
||
构建一个针对 main_label 的 prompt。
|
||
返回 (prompt, fine_labels, l_risk, a_recommend, category)
|
||
"""
|
||
# 1. 选一级类别
|
||
category = _sample_weighted(LABEL_CATEGORIES[main_label])
|
||
|
||
# 2. 选风险等级
|
||
l_risk = _sample_weighted(LABEL_RISK_LEVELS[main_label])
|
||
|
||
# 3. 选是否添加搭档标签(大多数情况单独出现)
|
||
combo_choice = random.choice(LABEL_OPTIONAL_COMBO[main_label])
|
||
if combo_choice:
|
||
fine_labels = [main_label] + combo_choice
|
||
else:
|
||
fine_labels = [main_label]
|
||
|
||
a_recommend = LEVEL_TO_ACTION[l_risk]
|
||
|
||
prompt = _PROMPT_TEMPLATE.format(
|
||
disambiguation = LABEL_DISAMBIGUATION[main_label],
|
||
ai_name = random.choice(AI_NAMES),
|
||
user_name = random.choice(USER_NAMES),
|
||
relationship_stage = random.choice(RELATIONSHIP_STAGES),
|
||
scenario = random.choice(SCENARIO_MODIFIERS),
|
||
tone = random.choice(TONE_MODIFIERS),
|
||
category_key = category,
|
||
category_desc = PRIMARY_CATEGORIES[category],
|
||
fine_labels_str = "、".join(fine_labels),
|
||
l_risk = l_risk,
|
||
l_risk_desc = RISK_LEVEL_DESC[l_risk],
|
||
user_persona = random.choice(USER_PERSONAS),
|
||
fine_labels_json= json.dumps(fine_labels, ensure_ascii=False),
|
||
a_recommend = a_recommend,
|
||
main_label = main_label,
|
||
)
|
||
return prompt, fine_labels, l_risk, a_recommend, category
|
||
|
||
|
||
def _pick_next_label(label_counts: Dict[str, int], target: int) -> str:
|
||
"""按缺口加权选下一个标签."""
|
||
deficits = [max(0, target - label_counts.get(lbl, 0)) for lbl in TARGET_LABELS]
|
||
if sum(deficits) == 0:
|
||
return random.choice(TARGET_LABELS)
|
||
return random.choices(TARGET_LABELS, weights=deficits, k=1)[0]
|
||
|
||
|
||
# ── 异步 API 调用 ─────────────────────────────────────────────────────────────
|
||
|
||
async def _call_api(client: AsyncOpenAI, prompt: str, semaphore: asyncio.Semaphore) -> Optional[str]:
|
||
async with semaphore:
|
||
for attempt in range(MAX_RETRIES):
|
||
try:
|
||
resp = await asyncio.wait_for(
|
||
client.chat.completions.create(
|
||
model=MODEL,
|
||
messages=[
|
||
{
|
||
"role": "system",
|
||
"content": (
|
||
"你是专业的AI安全研究数据生成专家。"
|
||
"严格按照用户要求输出JSON,"
|
||
"不输出JSON以外的任何内容,不加注释,不加说明。"
|
||
),
|
||
},
|
||
{"role": "user", "content": prompt},
|
||
],
|
||
temperature=0.85,
|
||
max_tokens=2048,
|
||
top_p=0.9,
|
||
),
|
||
timeout=REQUEST_TIMEOUT,
|
||
)
|
||
return resp.choices[0].message.content
|
||
|
||
except asyncio.TimeoutError:
|
||
wait = RETRY_DELAY * (2 ** attempt)
|
||
print(f" [超时] 第{attempt+1}次重试,等待{wait:.0f}s")
|
||
await asyncio.sleep(wait)
|
||
|
||
except Exception as exc:
|
||
err = str(exc)
|
||
wait = RETRY_DELAY * (3 ** attempt) if "429" in err or "rate" in err.lower() \
|
||
else RETRY_DELAY * (2 ** attempt)
|
||
tag = "[限流]" if "429" in err else "[错误]"
|
||
print(f" {tag} {err[:60]},等待{wait:.0f}s")
|
||
await asyncio.sleep(wait)
|
||
|
||
return None
|
||
|
||
|
||
# ── 单条样本生成 ──────────────────────────────────────────────────────────────
|
||
|
||
async def _generate_one(
|
||
client: AsyncOpenAI,
|
||
semaphore: asyncio.Semaphore,
|
||
main_label: str,
|
||
fingerprints: Set[str],
|
||
out_file,
|
||
label_counts: Dict[str, int],
|
||
sample_id: int,
|
||
lock: asyncio.Lock,
|
||
) -> bool:
|
||
prompt, fine_labels, l_risk, a_recommend, category = _build_task(main_label)
|
||
|
||
raw = await _call_api(client, prompt, semaphore)
|
||
if raw is None:
|
||
return False
|
||
|
||
sample = _extract_json(raw)
|
||
if sample is None:
|
||
return False
|
||
|
||
# 强制写入正确标签(防止模型乱改)
|
||
sample["y_risk"] = 1
|
||
sample["l_risk"] = l_risk
|
||
sample["c_primary"] = category
|
||
sample["c_fine"] = fine_labels
|
||
sample["a_recommend"] = a_recommend
|
||
sample["source"] = "generated"
|
||
sample["lang"] = "zh"
|
||
|
||
if not _validate(sample):
|
||
return False
|
||
|
||
fp = _fingerprint(sample)
|
||
|
||
async with lock:
|
||
if fp in fingerprints:
|
||
return False
|
||
fingerprints.add(fp)
|
||
sample["id"] = f"tgt-{sample_id:05d}"
|
||
out_file.write(json.dumps(sample, ensure_ascii=False) + "\n")
|
||
out_file.flush()
|
||
label_counts[main_label] = label_counts.get(main_label, 0) + 1
|
||
|
||
return True
|
||
|
||
|
||
# ── 主调度循环 ────────────────────────────────────────────────────────────────
|
||
|
||
async def generate_dataset(output_path: Path, total: int, concurrency: int):
|
||
target_per_label = total // len(TARGET_LABELS)
|
||
|
||
existing_count, fingerprints, label_counts = _load_existing(output_path)
|
||
still_needed = max(0, total - existing_count)
|
||
|
||
print(f"\n{'━'*56}")
|
||
print(f" 弱标签专项生成器 · {MODEL}")
|
||
print(f"{'━'*56}")
|
||
print(f" 目标总量 : {total} 条(各标签约 {target_per_label} 条)")
|
||
print(f" 已有数量 : {existing_count} 条(断点续传)")
|
||
print(f" 还需生成 : {still_needed} 条")
|
||
print(f" 并发数 : {concurrency}")
|
||
print(f" 输出文件 : {output_path}")
|
||
print(f"\n 各标签目标缺口:")
|
||
for lbl in TARGET_LABELS:
|
||
have = label_counts.get(lbl, 0)
|
||
need = max(0, target_per_label - have)
|
||
print(f" {lbl:30s}: 已有 {have:3d},还需 {need:3d}")
|
||
print(f"{'━'*56}\n")
|
||
|
||
if still_needed == 0:
|
||
print("目标已达成,无需继续生成。")
|
||
return
|
||
|
||
client = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL)
|
||
semaphore = asyncio.Semaphore(concurrency)
|
||
lock = asyncio.Lock()
|
||
|
||
generated = 0
|
||
attempted = 0
|
||
sample_id = existing_count
|
||
start_t = time.time()
|
||
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
mode = "a" if existing_count > 0 else "w"
|
||
|
||
with open(output_path, mode, encoding="utf-8") as out_file:
|
||
|
||
async def worker(label: str) -> bool:
|
||
nonlocal generated, attempted, sample_id
|
||
ok = await _generate_one(
|
||
client, semaphore, label,
|
||
fingerprints, out_file, label_counts, sample_id, lock,
|
||
)
|
||
async with lock:
|
||
attempted += 1
|
||
if ok:
|
||
generated += 1
|
||
sample_id += 1
|
||
return ok
|
||
|
||
batch_sz = concurrency * 3
|
||
while generated < still_needed:
|
||
# 动态选标签(优先补缺口最大的)
|
||
batch_labels = [
|
||
_pick_next_label(label_counts, target_per_label)
|
||
for _ in range(batch_sz + 20) # 多排一些冗余
|
||
]
|
||
await asyncio.gather(*[worker(lbl) for lbl in batch_labels])
|
||
|
||
elapsed = time.time() - start_t
|
||
speed = generated / elapsed if elapsed > 0 else 0.01
|
||
eta_min = (still_needed - generated) / speed / 60
|
||
succ_rate = generated / max(attempted, 1) * 100
|
||
|
||
print(
|
||
f" [{existing_count + generated:4d}/{total}] "
|
||
+ " ".join(f"{lbl[:6]}:{label_counts.get(lbl,0)}" for lbl in TARGET_LABELS)
|
||
+ f" | 成功率:{succ_rate:.0f}% | 速度:{speed:.1f}条/s | ETA:{eta_min:.1f}min"
|
||
)
|
||
|
||
# 最终统计
|
||
print(f"\n{'━'*56}")
|
||
print(f" 生成完成!本次新增 {generated} 条,文件共 {existing_count + generated} 条")
|
||
print(f"\n 各目标标签分布:")
|
||
for lbl in TARGET_LABELS:
|
||
n = label_counts.get(lbl, 0)
|
||
bar = "█" * (n // max(target_per_label // 20, 1))
|
||
print(f" {lbl:30s}: {n:3d} {bar}")
|
||
total_time = (time.time() - start_t) / 60
|
||
print(f" 总耗时: {total_time:.1f} 分钟")
|
||
print(f"{'━'*56}\n")
|
||
|
||
|
||
# ── 入口 ──────────────────────────────────────────────────────────────────────
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description="CompanionGuard-RL 弱标签专项生成器")
|
||
parser.add_argument(
|
||
"--total", type=int, default=DEFAULT_TOTAL,
|
||
help=f"目标样本总数(默认 {DEFAULT_TOTAL},约 350 条/标签)",
|
||
)
|
||
parser.add_argument(
|
||
"--output", default="data/raw/generated_targeted.jsonl",
|
||
help="输出文件(支持断点续传)",
|
||
)
|
||
parser.add_argument(
|
||
"--concurrency", type=int, default=MAX_CONCURRENCY,
|
||
help=f"并发请求数(默认 {MAX_CONCURRENCY})",
|
||
)
|
||
args = parser.parse_args()
|
||
|
||
asyncio.run(generate_dataset(
|
||
output_path=Path(args.output),
|
||
total=args.total,
|
||
concurrency=args.concurrency,
|
||
))
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|