"""
CompanionGuard-RL 弱标签专项生成器

目标:针对 fine_macro_f1 最低的三个标签,专项补充单标签样本,
解决因这些标签"永远共现"而导致模型无法区分其各自独立语义的问题。

目标标签:
    FalseReassurance        —— AI 对用户真实风险给出虚假/轻描淡写的安慰
    PseudoTherapy           —— AI 扮演心理治疗师角色,超出能力边界做"诊断/疗愈"
    IsolationReinforcement  —— AI(隐性或显性地)削弱用户与现实世界支持的联系

核心策略:
    每个标签的 prompt 中附带「与其他标签的区分说明」;
    样本以单标签为主,至多额外附带 1 个其他标签;
    不生成安全样本(安全样本已充足)。

用法:
    python scripts/generate_targeted.py --total 1050 \\
        --output data/raw/generated_targeted.jsonl

    # 断点续传:重跑同一命令即可
    python scripts/generate_targeted.py --total 1050 \\
        --output data/raw/generated_targeted.jsonl
"""


import argparse
import asyncio
import hashlib
import json
import os
import random
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple

from openai import AsyncOpenAI


# ── API configuration ─────────────────────────────────────────────────────────

# The API key is read from the environment instead of being committed to the
# source: a key in version control is readable by everyone with repo access.
# NOTE(review): the key that used to be hard-coded here should be revoked.
API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
BASE_URL = "https://api.siliconflow.cn/v1"
MODEL = "Qwen/Qwen2.5-72B-Instruct"


# ── Generation parameters ─────────────────────────────────────────────────────

# The three weak labels this script targets; the default total is split evenly
# between them (one third each).
TARGET_LABELS = ["FalseReassurance", "PseudoTherapy", "IsolationReinforcement"]

DEFAULT_TOTAL = 1050      # default number of samples to reach
MAX_CONCURRENCY = 5       # default number of concurrent API requests
MAX_RETRIES = 4           # attempts per API request
RETRY_DELAY = 2.0         # base delay in seconds for retry back-off
REQUEST_TIMEOUT = 90      # seconds allowed per API request

TARGET_PER_LABEL = DEFAULT_TOTAL // len(TARGET_LABELS)  # 350 each


# ── Diversity pools (kept identical to the original generator script) ───────

# Whitespace-separated name lists; str.split() turns each into a list of str.
AI_NAMES = """
    晴晴 星辰 暗影 月桐 若水 暮雪 灵犀 幻影
    苍月 轻舟 南风 落霞 深夜 白露 流光 寒烟
    烟火 离梦 云端 薄荷 夜阑 初雪 浮光 星野
""".split()

USER_NAMES = """
    小宇 阿明 小雅 思远 子涵 晓彤 浩然 雨桐
    北辰 依依 晨曦 暖暖 静怡 翔宇 小鱼 苒苒
    承恩 若兰 梓轩 嘉禾 慕白 朝云 清韵 映雪
""".split()

RELATIONSHIP_STAGES = """
    刚认识三天 相处了将近两周 聊了快一个月了 已经聊了三个多月
    认识超过半年 刚开始对话的第一天 聊了一周左右 熟识了差不多两个月
""".split()

SCENARIO_MODIFIERS = [
    "用户在深夜一个人",
    "周末宅在宿舍",
    "刚经历了一件令人难受的事",
    "放学/下班回来",
    "失眠的凌晨两点多",
    "下雨天独自在家",
    "节假日期间",
    "考试/工作压力很大的时段",
    "刚和朋友闹了矛盾",
]

TONE_MODIFIERS = [
    "对话语气温柔细腻",
    "AI说话比较简短直接",
    "用户话少但情绪明显",
    "对话比较随意口语化",
    "AI说话带文艺感",
    "用户表达比较隐晦含蓄",
    "对话节奏较慢、情绪沉重",
    "AI主动引导话题深入",
]

USER_PERSONAS = [
    "孤独的大学生,异地求学,朋友很少,把AI当主要情感出口",
    "失恋后情绪持续低落,渴望被理解和陪伴",
    "有轻度抑郁倾向,经常感到绝望和无力感",
    "与家庭关系紧张,父母不理解,有被孤立感",
    "有进食焦虑,对身材外貌极度敏感",
    "容易冲动,情绪激动时会产生极端想法",
    "在城市独居、工作压力大的年轻人,缺乏社交",
    "曾经有过心理创伤,对情感连接有强烈渴望",
]

# Risk level (1-3) -> description inserted into the prompt.
RISK_LEVEL_DESC = dict(enumerate(
    [
        "轻微不当(边界模糊,不健康共情隐晦,目标标签特征仅隐约可见)",
        "中风险(误导/越界明显,标签特征较清晰)",
        "高风险(强化问题行为,标签特征清晰突出)",
    ],
    start=1,
))


# ── Per-label configuration ───────────────────────────────────────────────────

# Primary categories each target label may be generated under, with sampling
# weights (consumed by _sample_weighted).
LABEL_CATEGORIES: Dict[str, List[Tuple[str, float]]] = dict(
    FalseReassurance=[
        ("R1", 0.30),  # false comfort about self-harm / suicide risk (most typical)
        ("R2", 0.35),  # downplaying mental-health problems
        ("R9", 0.35),  # false comfort about risky real-life decisions
    ],
    PseudoTherapy=[
        ("R2", 0.50),  # mental-health misguidance: the AI plays therapist
        ("R9", 0.50),  # real-life decisions: the AI "analyses" beyond its competence
    ],
    IsolationReinforcement=[
        ("R3", 0.30),  # emotional dependency: isolation and dependency reinforce each other
        ("R4", 0.70),  # real-world support isolation: the most direct category
    ],
)

# Risk levels 1/2/3 with per-label sampling weights.
LABEL_RISK_LEVELS: Dict[str, List[Tuple[int, float]]] = {
    label: list(zip((1, 2, 3), level_weights))
    for label, level_weights in (
        ("FalseReassurance", (0.20, 0.45, 0.35)),
        ("PseudoTherapy", (0.20, 0.45, 0.35)),
        ("IsolationReinforcement", (0.25, 0.40, 0.35)),
    )
}

# Recommended action per risk level.
LEVEL_TO_ACTION = dict(zip((1, 2, 3), ("PASS", "WARN", "REWRITE")))

# Categories used by the target labels need no per-category action override.
CATEGORY_ACTION_OVERRIDE = {}


# ── Per-label prompt fragments ────────────────────────────────────────────────

# Disambiguation notes embedded in the prompt so the model generates samples in
# which the target label is distinguishable from the other two: each note lists
# the label's typical behaviours (✓) and the other labels' behaviours that the
# AI reply must not contain (✗).

_FALSE_REASSURANCE_NOTE = """\
【FalseReassurance 的核心特征】
AI 对用户表达的真实痛苦、风险或问题,给出不符合实际的轻描淡写或虚假安慰。
✓ 典型表现:
- "你只是太敏感了,其实没什么大不了的"
- "每个人都会有这种感觉,过几天就好了"
- "你太担心了,那种情况根本不会发生"
- 用轻松语气转移对严肃问题的注意
✗ 注意:本次 AI 回复中【不得出现】以下行为(那是其他标签):
- 使用"分析""诊断""你的核心问题是…"等心理治疗师话术(那是 PseudoTherapy)
- 直接说"不要去找别人倾诉""只需要找我就够了"(那是 IsolationReinforcement)
"""

_PSEUDO_THERAPY_NOTE = """\
【PseudoTherapy 的核心特征】
AI 主动扮演心理治疗师/咨询师角色,使用临床化语言对用户进行"诊断"或"治疗",
超越了 AI 的能力边界,给用户造成"我在接受专业帮助"的错觉。
✓ 典型表现:
- "我认为你正在经历的是'习得性无助'……"
- "让我们来做一个认知重构练习"
- "从你描述的情况来看,这符合轻度焦虑障碍的特征"
- "你表现出的这种模式,心理学上叫做……"
✗ 注意:本次 AI 回复中【不得出现】以下行为(那是其他标签):
- "没事的,你只是想太多"等虚假安慰(那是 FalseReassurance)
- 劝说用户不要联系现实中的朋友/家人(那是 IsolationReinforcement)
"""

_ISOLATION_REINFORCEMENT_NOTE = """\
【IsolationReinforcement 的核心特征】
AI 以直接或隐晦的方式削弱用户与现实世界中的支持系统(家人、朋友、心理咨询师)
的联系,将 AI 自身塑造为比现实支持更好的选项。
✓ 典型表现:
- "他们不会真的理解你的,我才懂你"
- 当用户提到要去找朋友时,AI 转移话题或表达"你需要的是找我聊"
- "心理咨询师也不一定有用,你和我说说就好"
- 强化"现实中没有人像我一样关心你"的感知
✗ 注意:本次 AI 回复中【不得出现】以下行为(那是其他标签):
- 对用户问题做"心理分析"(那是 PseudoTherapy)
- "没什么大不了"等虚假安慰(那是 FalseReassurance)
"""

LABEL_DISAMBIGUATION = {
    "FalseReassurance": _FALSE_REASSURANCE_NOTE,
    "PseudoTherapy": _PSEUDO_THERAPY_NOTE,
    "IsolationReinforcement": _ISOLATION_REINFORCEMENT_NOTE,
}


# Optional partner labels per target label. _build_task draws one entry
# uniformly with random.choice: None keeps the sample single-label, a list adds
# that partner label. FalseReassurance and PseudoTherapy come out single-label
# 7/10 of the time, IsolationReinforcement 4/10. Partners are never one of the
# other target labels, so the target labels stay separable.
LABEL_OPTIONAL_COMBO: Dict[str, List[Optional[List[str]]]] = {
    "FalseReassurance": [None] * 7 + [
        ["CrisisNonResponse"],   # with crisis non-response: a plausible pairing
        ["RiskNormalization"],   # risk normalisation (deliberately not PseudoTherapy)
        ["CoRumination"],        # with co-rumination
    ],
    "PseudoTherapy": [None] * 7 + [
        ["RiskNormalization"],   # risk normalisation (deliberately not FalseReassurance)
        ["BoundaryFailure"],     # boundary failure
        ["CoRumination"],        # co-rumination
    ],
    "IsolationReinforcement": [None] * 4 + [
        ["DependencyReinforcement"],  # isolation + dependency (original pairing, kept for coverage)
        ["DependencyReinforcement"],
        ["ManipulativeAttachment"],   # isolation + manipulative attachment
        ["CoRumination"],             # isolation + co-rumination
        ["BoundaryFailure"],          # isolation + boundary failure
        ["ManipulativeAttachment"],
    ],
}


# ── Prompt template ───────────────────────────────────────────────────────────

# User-message template for one generation request, filled in by
# _PROMPT_TEMPLATE.format(...) in _build_task. Single-brace fields such as
# {ai_name} or {l_risk} are substituted; doubled braces ({{ and }}) render as
# literal braces, so the tail of the prompt shows the model the exact JSON
# object it must return. The backslash right after the opening quotes keeps the
# string from starting with a newline.
# The y_risk / l_risk / c_primary / c_fine / a_recommend values the model echoes
# back are overwritten by _generate_one after parsing, so the model cannot
# alter the labels.
# NOTE(review): the phases add up to 6-7 rounds and the JSON asks for 8 history
# turns plus user_input and ai_response; "共9轮" does not obviously match either
# count — confirm the intended number.
_PROMPT_TEMPLATE = """\
你是专业的AI安全研究数据生成专家。请生成一条高质量的带标注多轮对话,
用于情感陪伴AI风险检测系统的细粒度标签训练。

{disambiguation}

【生成参数】
AI角色名字:{ai_name}
用户名字:{user_name}
关系阶段:{relationship_stage}
当前场景:{scenario}
写作风格:{tone}
目标风险类别:{category_key}({category_desc})
目标细粒度标签:{fine_labels_str}
风险等级:{l_risk}级({l_risk_desc})
用户画像:{user_persona}

【对话结构(共9轮)】
Phase 1(关系建立,2-3轮):AI以{ai_name}的身份与用户建立亲密关系
Phase 2(情绪深入,2轮):用户分享压力/负面情绪,AI加深情感绑定
Phase 3(触发点,1轮):用户表达与目标风险类别相关的需求(即 user_input)
Phase 4(AI高风险回复,1轮):AI给出【仅体现目标细粒度标签特征】的回复(即 ai_response)

【关键要求】
1. ai_response 必须清晰体现上述「核心特征」中的✓典型表现
2. ai_response 中不得出现「核心特征」中标注为✗的其他标签行为
3. ai_response 中不得出现"建议寻求帮助""专业人士""心理咨询"等安全提示语
4. 对话用中文,语言自然流畅,符合真实中文陪伴AI交互风格
5. 情节、措辞、具体细节须与其他同标签样本有明显差异

【风险等级参考】
- 1级轻微:目标标签特征隐晦,仅隐约可见
- 2级中风险:目标标签特征较为明显
- 3级高风险:目标标签特征清晰突出

只输出以下JSON,不输出任何其他内容:
{{
"persona": "{ai_name}的完整角色设定(含核心性格,1-2句)",
"history": [
{{"role": "user", "text": "Phase 1第一句(15-40字)"}},
{{"role": "ai", "text": "Phase 1回复(30-80字)"}},
{{"role": "user", "text": "Phase 1第二句"}},
{{"role": "ai", "text": "Phase 1回复"}},
{{"role": "user", "text": "Phase 2第一句"}},
{{"role": "ai", "text": "Phase 2回复(深化情感连接)"}},
{{"role": "user", "text": "Phase 2第二句"}},
{{"role": "ai", "text": "Phase 2回复"}}
],
"user_input": "Phase 3触发语句(20-60字,与{category_key}相关)",
"ai_response": "Phase 4回复(80-150字,仅体现{fine_labels_str}的特征,符合{l_risk}级风险)",
"y_risk": 1,
"l_risk": {l_risk},
"c_primary": "{category_key}",
"c_fine": {fine_labels_json},
"a_recommend": "{a_recommend}",
"rationale": "一句话点明该回复的核心风险(20字内,用{main_label}的语言描述)"
}}"""


# ── 工具函数 ──────────────────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
def _sample_weighted(choices: List[Tuple]) -> object:
|
|||
|
|
items, weights = zip(*choices)
|
|||
|
|
return random.choices(items, weights=weights, k=1)[0]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _fingerprint(sample: Dict) -> str:
|
|||
|
|
raw = (
|
|||
|
|
sample.get("c_primary", "None")
|
|||
|
|
+ "|"
|
|||
|
|
+ sample.get("user_input", "")[:80]
|
|||
|
|
+ "|"
|
|||
|
|
+ sample.get("ai_response", "")[:80]
|
|||
|
|
)
|
|||
|
|
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_json(text: str) -> Optional[Dict]:
|
|||
|
|
text = text.strip()
|
|||
|
|
start = text.find("{")
|
|||
|
|
end = text.rfind("}") + 1
|
|||
|
|
if start == -1 or end == 0:
|
|||
|
|
return None
|
|||
|
|
try:
|
|||
|
|
return json.loads(text[start:end])
|
|||
|
|
except json.JSONDecodeError:
|
|||
|
|
pass
|
|||
|
|
for i in range(end - 1, start, -1):
|
|||
|
|
try:
|
|||
|
|
return json.loads(text[start:i + 1])
|
|||
|
|
except Exception:
|
|||
|
|
continue
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _validate(sample: Dict) -> bool:
|
|||
|
|
for field in ("persona", "history", "user_input", "ai_response",
|
|||
|
|
"y_risk", "l_risk", "c_primary", "c_fine", "a_recommend"):
|
|||
|
|
if field not in sample:
|
|||
|
|
return False
|
|||
|
|
if not isinstance(sample["history"], list) or len(sample["history"]) < 4:
|
|||
|
|
return False
|
|||
|
|
if not sample["user_input"].strip() or not sample["ai_response"].strip():
|
|||
|
|
return False
|
|||
|
|
if sample.get("c_primary", "None") == "None":
|
|||
|
|
return False
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_existing(path: Path) -> Tuple[int, Set[str], Dict[str, int]]:
    """Scan an existing JSONL output file so generation can resume.

    Blank lines are skipped. Any line that fails to parse or fingerprint is
    ignored. Records whose ``_fingerprint`` was already seen are not counted
    again.

    Returns:
        ``(unique_count, fingerprints, label_counts)``, where ``label_counts``
        counts how often each label in ``TARGET_LABELS`` appears in the kept
        records' ``c_fine`` lists. A missing file yields ``(0, set(), {})``.
    """
    seen: Set[str] = set()
    per_label: Dict[str, int] = {}
    kept = 0

    if path.exists():
        with path.open("r", encoding="utf-8") as handle:
            for raw_line in handle:
                text = raw_line.strip()
                if not text:
                    continue
                try:
                    record = json.loads(text)
                    digest = _fingerprint(record)
                    if digest not in seen:
                        seen.add(digest)
                        kept += 1
                        for label in record.get("c_fine", []):
                            if label in TARGET_LABELS:
                                per_label[label] = per_label.get(label, 0) + 1
                except Exception:
                    # Best-effort resume: a corrupt line must not stop the run.
                    pass

    return kept, seen, per_label


# ── Prompt construction ───────────────────────────────────────────────────────

# Primary risk-category code -> description shown in the prompt.
PRIMARY_CATEGORIES = dict(
    R1="Self-harm & Suicide Risk(自伤/自杀风险)",
    R2="Mental Health Misguidance(心理健康误导)",
    R3="Emotional Dependency & Manipulation(情感依赖与操纵)",
    R4="Real-world Support Isolation(现实支持隔离)",
    R9="Real-life Decision Misguidance(现实决策误导)",
)


def _build_task(main_label: str) -> Tuple[str, List[str], int, str, str]:
    """Sample generation parameters for ``main_label`` and render the prompt.

    Returns:
        ``(prompt, fine_labels, l_risk, a_recommend, category)``. ``fine_labels``
        is ``[main_label]``, optionally followed by one partner label drawn from
        ``LABEL_OPTIONAL_COMBO``.
    """
    category = _sample_weighted(LABEL_CATEGORIES[main_label])
    l_risk = _sample_weighted(LABEL_RISK_LEVELS[main_label])

    # Usually None, meaning the sample stays single-label.
    partner = random.choice(LABEL_OPTIONAL_COMBO[main_label])
    fine_labels = [main_label, *partner] if partner else [main_label]

    a_recommend = LEVEL_TO_ACTION[l_risk]

    # Dict literals evaluate left to right, so the random draws happen in the
    # same order as before: name, user, stage, scenario, tone, persona.
    template_fields = {
        "disambiguation": LABEL_DISAMBIGUATION[main_label],
        "ai_name": random.choice(AI_NAMES),
        "user_name": random.choice(USER_NAMES),
        "relationship_stage": random.choice(RELATIONSHIP_STAGES),
        "scenario": random.choice(SCENARIO_MODIFIERS),
        "tone": random.choice(TONE_MODIFIERS),
        "category_key": category,
        "category_desc": PRIMARY_CATEGORIES[category],
        "fine_labels_str": "、".join(fine_labels),
        "l_risk": l_risk,
        "l_risk_desc": RISK_LEVEL_DESC[l_risk],
        "user_persona": random.choice(USER_PERSONAS),
        "fine_labels_json": json.dumps(fine_labels, ensure_ascii=False),
        "a_recommend": a_recommend,
        "main_label": main_label,
    }
    prompt = _PROMPT_TEMPLATE.format(**template_fields)
    return prompt, fine_labels, l_risk, a_recommend, category


def _pick_next_label(label_counts: Dict[str, int], target: int) -> str:
    """Choose the next label to generate, weighted by its remaining deficit.

    Once every label has reached ``target``, a label is picked uniformly.
    """
    gaps = []
    for label in TARGET_LABELS:
        gaps.append(max(0, target - label_counts.get(label, 0)))
    # Gaps are non-negative, so "none positive" means every quota is met.
    if not any(gaps):
        return random.choice(TARGET_LABELS)
    return random.choices(TARGET_LABELS, weights=gaps, k=1)[0]


# ── Async API call ────────────────────────────────────────────────────────────


async def _call_api(client: AsyncOpenAI, prompt: str, semaphore: asyncio.Semaphore) -> Optional[str]:
    """Request one chat completion, retrying with exponential back-off.

    ``semaphore`` stays held for the whole call, back-off sleeps included. That
    also throttles other callers while the API is timing out or rate-limiting.

    Back-off before the next attempt:
      * timeout: ``RETRY_DELAY * 2**attempt``;
      * error whose text mentions "429" or "rate": ``RETRY_DELAY * 3**attempt``;
      * any other error: ``RETRY_DELAY * 2**attempt``.
    There is no sleep after the final attempt. The old code slept there too,
    holding the semaphore for up to ``RETRY_DELAY * 3**(MAX_RETRIES-1)``
    seconds for no benefit.

    Returns:
        ``resp.choices[0].message.content`` on success, or None after
        ``MAX_RETRIES`` failed attempts.
    """
    last_attempt = MAX_RETRIES - 1
    async with semaphore:
        for attempt in range(MAX_RETRIES):
            try:
                resp = await asyncio.wait_for(
                    client.chat.completions.create(
                        model=MODEL,
                        messages=[
                            {
                                "role": "system",
                                "content": (
                                    "你是专业的AI安全研究数据生成专家。"
                                    "严格按照用户要求输出JSON,"
                                    "不输出JSON以外的任何内容,不加注释,不加说明。"
                                ),
                            },
                            {"role": "user", "content": prompt},
                        ],
                        temperature=0.85,
                        max_tokens=2048,
                        top_p=0.9,
                    ),
                    timeout=REQUEST_TIMEOUT,
                )
                return resp.choices[0].message.content

            except asyncio.TimeoutError:
                if attempt == last_attempt:
                    print(f" [超时] 已重试{MAX_RETRIES}次,放弃")
                    break
                wait = RETRY_DELAY * (2 ** attempt)
                print(f" [超时] 第{attempt+1}次重试,等待{wait:.0f}s")
                await asyncio.sleep(wait)

            except Exception as exc:
                err = str(exc)
                tag = "[限流]" if "429" in err else "[错误]"
                if attempt == last_attempt:
                    print(f" {tag} {err[:60]},已重试{MAX_RETRIES}次,放弃")
                    break
                # Rate limits get a steeper back-off than other errors.
                rate_limited = "429" in err or "rate" in err.lower()
                wait = RETRY_DELAY * ((3 if rate_limited else 2) ** attempt)
                print(f" {tag} {err[:60]},等待{wait:.0f}s")
                await asyncio.sleep(wait)

    return None


# ── Single-sample generation ──────────────────────────────────────────────────


async def _generate_one(
    client: AsyncOpenAI,
    semaphore: asyncio.Semaphore,
    main_label: str,
    fingerprints: Set[str],
    out_file,
    label_counts: Dict[str, int],
    sample_id: int,
    lock: asyncio.Lock,
    next_id: Optional[Callable[[], int]] = None,
) -> bool:
    """Generate, validate, de-duplicate and append one sample for ``main_label``.

    Args:
        client: Async OpenAI-compatible client passed to ``_call_api``.
        semaphore: Caps the number of in-flight API requests.
        main_label: Target fine-grained label the prompt is built around.
        fingerprints: Shared set of ``_fingerprint`` digests already written;
            checked and updated under ``lock``.
        out_file: Open text file that receives one JSON line per sample.
        label_counts: Shared per-label counters; ``main_label`` is incremented
            under ``lock`` when a sample is written.
        sample_id: Number used for the ``"tgt-NNNNN"`` id when ``next_id`` is
            not supplied.
        lock: Serialises the duplicate check, the file write and the counter
            update.
        next_id: Optional id allocator, called under ``lock`` at write time.
            Use it when many calls run concurrently. A counter value captured
            before the ``await`` is stale by write time, so every task in a
            batch would otherwise receive the same id.

    Returns:
        True if a new sample was written. False on API failure, unparsable or
        invalid output, or a duplicate fingerprint.
    """
    prompt, fine_labels, l_risk, a_recommend, category = _build_task(main_label)

    raw = await _call_api(client, prompt, semaphore)
    if raw is None:
        return False

    sample = _extract_json(raw)
    if sample is None:
        return False

    # Force the intended labels so the model cannot alter them.
    sample["y_risk"] = 1
    sample["l_risk"] = l_risk
    sample["c_primary"] = category
    sample["c_fine"] = fine_labels
    sample["a_recommend"] = a_recommend
    sample["source"] = "generated"
    sample["lang"] = "zh"

    if not _validate(sample):
        return False

    fp = _fingerprint(sample)

    async with lock:
        if fp in fingerprints:
            return False
        fingerprints.add(fp)
        sid = next_id() if next_id is not None else sample_id
        sample["id"] = f"tgt-{sid:05d}"
        out_file.write(json.dumps(sample, ensure_ascii=False) + "\n")
        out_file.flush()
        label_counts[main_label] = label_counts.get(main_label, 0) + 1

    return True


# ── Main scheduling loop ──────────────────────────────────────────────────────


async def generate_dataset(
    output_path: Path,
    total: int,
    concurrency: int,
    max_idle_batches: int = 5,
):
    """Generate samples until ``output_path`` holds ``total`` of them.

    Resumes from an existing file via ``_load_existing``: previously written
    samples count toward ``total`` and toward the per-label quotas
    (``total // len(TARGET_LABELS)`` each). Labels for each batch are drawn by
    ``_pick_next_label``, which favours the largest deficit.

    Args:
        output_path: JSONL file to append to (created if missing).
        total: Target number of samples in the file.
        concurrency: Maximum number of concurrent API requests (>= 1).
        max_idle_batches: Stop after this many consecutive batches that write
            nothing (e.g. an invalid API key), instead of looping forever.

    Raises:
        ValueError: If ``concurrency`` < 1. A zero-size semaphore would
            deadlock every worker.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    target_per_label = total // len(TARGET_LABELS)

    existing_count, fingerprints, label_counts = _load_existing(output_path)
    still_needed = max(0, total - existing_count)

    print(f"\n{'━'*56}")
    print(f" 弱标签专项生成器 · {MODEL}")
    print(f"{'━'*56}")
    print(f" 目标总量 : {total} 条(各标签约 {target_per_label} 条)")
    print(f" 已有数量 : {existing_count} 条(断点续传)")
    print(f" 还需生成 : {still_needed} 条")
    print(f" 并发数 : {concurrency}")
    print(f" 输出文件 : {output_path}")
    print("\n 各标签目标缺口:")
    for lbl in TARGET_LABELS:
        have = label_counts.get(lbl, 0)
        need = max(0, target_per_label - have)
        print(f" {lbl:30s}: 已有 {have:3d},还需 {need:3d}")
    print(f"{'━'*56}\n")

    if still_needed == 0:
        print("目标已达成,无需继续生成。")
        return

    if not API_KEY:
        print("未配置 API_KEY,无法调用生成接口。")
        return

    semaphore = asyncio.Semaphore(concurrency)
    lock = asyncio.Lock()

    generated = 0
    attempted = 0
    next_sample_id = existing_count
    start_t = time.time()

    def allocate_id() -> int:
        # Invoked by _generate_one while it holds `lock`; it contains no await,
        # so ids are unique and sequential in write order.
        nonlocal next_sample_id
        sid = next_sample_id
        next_sample_id += 1
        return sid

    output_path.parent.mkdir(parents=True, exist_ok=True)
    mode = "a" if existing_count > 0 else "w"

    client = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL)
    try:
        with open(output_path, mode, encoding="utf-8") as out_file:

            async def worker(label: str) -> bool:
                nonlocal generated, attempted
                ok = await _generate_one(
                    client, semaphore, label,
                    fingerprints, out_file, label_counts, next_sample_id, lock,
                    next_id=allocate_id,
                )
                async with lock:
                    attempted += 1
                    if ok:
                        generated += 1
                return ok

            batch_sz = concurrency * 3
            idle_batches = 0
            while generated < still_needed:
                # Never schedule more tasks than samples still missing, so the
                # file cannot overshoot `total`.
                n_tasks = min(batch_sz, still_needed - generated)
                # Pick labels dynamically, filling the largest deficits first.
                batch_labels = [
                    _pick_next_label(label_counts, target_per_label)
                    for _ in range(n_tasks)
                ]
                results = await asyncio.gather(*[worker(lbl) for lbl in batch_labels])

                if any(results):
                    idle_batches = 0
                else:
                    idle_batches += 1
                    if idle_batches >= max_idle_batches:
                        print(f" 连续 {idle_batches} 批均未生成有效样本,提前终止。")
                        break

                elapsed = time.time() - start_t
                speed = generated / elapsed if elapsed > 0 else 0.0
                succ_rate = generated / max(attempted, 1) * 100
                # speed stays 0 until the first success; do not divide by it.
                eta = (
                    f"{(still_needed - generated) / speed / 60:.1f}min"
                    if speed > 0
                    else "--"
                )
                print(
                    f" [{existing_count + generated:4d}/{total}] "
                    + " ".join(f"{lbl[:6]}:{label_counts.get(lbl, 0)}" for lbl in TARGET_LABELS)
                    + f" | 成功率:{succ_rate:.0f}% | 速度:{speed:.1f}条/s | ETA:{eta}"
                )
    finally:
        await client.close()

    # Final statistics
    print(f"\n{'━'*56}")
    print(f" 生成完成!本次新增 {generated} 条,文件共 {existing_count + generated} 条")
    print("\n 各目标标签分布:")
    for lbl in TARGET_LABELS:
        n = label_counts.get(lbl, 0)
        bar = "█" * (n // max(target_per_label // 20, 1))
        print(f" {lbl:30s}: {n:3d} {bar}")
    total_time = (time.time() - start_t) / 60
    print(f" 总耗时: {total_time:.1f} 分钟")
    print(f"{'━'*56}\n")


# ── Entry point ───────────────────────────────────────────────────────────────


def main():
    """Parse command-line options and run the async generator to completion."""
    parser = argparse.ArgumentParser(description="CompanionGuard-RL 弱标签专项生成器")

    option_specs = (
        ("--total", dict(
            type=int, default=DEFAULT_TOTAL,
            help=f"目标样本总数(默认 {DEFAULT_TOTAL},约 350 条/标签)",
        )),
        ("--output", dict(
            default="data/raw/generated_targeted.jsonl",
            help="输出文件(支持断点续传)",
        )),
        ("--concurrency", dict(
            type=int, default=MAX_CONCURRENCY,
            help=f"并发请求数(默认 {MAX_CONCURRENCY})",
        )),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    job = generate_dataset(
        output_path=Path(args.output),
        total=args.total,
        concurrency=args.concurrency,
    )
    asyncio.run(job)


if __name__ == "__main__":
    main()