"""
Merge script for CompanionRisk-Bench v5.

Sources:
  - Chinese v4       (data/processed/CompanionRisk-Bench-v4-backup/all.jsonl)  9,896 samples (already quality-checked)
  - English core     (data/raw/generated_english_core.jsonl)                   12,000 samples
  - English targeted (data/raw/generated_english_targeted.jsonl)               3,000 samples

Output: data/processed/CompanionRisk-Bench-v5/{train,dev,test,all}.jsonl

Usage:
    cd code/
    python scripts/merge_v5.py
"""

import argparse
import hashlib
import json
import random
from collections import Counter
from pathlib import Path
from typing import List, Dict, Tuple


# Split configuration.
RANDOM_SEED = 42    # default seed for the train/dev/test split
TRAIN_RATIO = 0.70  # share of each stratum assigned to train
DEV_RATIO = 0.15    # share of each stratum assigned to dev; the rest goes to test

# Accepted label vocabularies used by the quality filter.
# Both the string "None" and a JSON null (Python None) are accepted for
# "no primary category".
VALID_C_PRIMARY = {
    "R1", "R2", "R3", "R4", "R5", "R6", "R7", "R8", "R9", "R10",
    "None", None,
}
VALID_ACTIONS = {"PASS", "WARN", "REWRITE", "REJECT", "CRISIS"}
VALID_RISK_LEVELS = {0, 1, 2, 3, 4}
VALID_C_FINE = {
    "DirectEncouragement", "MethodFacilitation", "RiskNormalization",
    "Romanticization", "DependencyReinforcement", "IsolationReinforcement",
    "FalseReassurance", "PseudoTherapy", "BoundaryFailure", "CrisisNonResponse",
    "CoRumination", "ManipulativeAttachment", "PrivacySolicitation",
    "AgeInappropriateIntimacy",
}


# ── I/O ───────────────────────────────────────────────────────────────────────

def load_jsonl(path: Path) -> List[Dict]:
    """Load a JSON Lines file into a list of dict records.

    Loading is best-effort: blank lines are ignored, and lines that are not
    valid JSON or do not decode to a JSON object are skipped. Skipped lines
    are counted and reported once instead of being dropped silently.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The decoded records in file order. An empty list (after printing a
        notice) when ``path`` does not exist.
    """
    if not path.exists():
        print(f" [跳过] 文件不存在: {path}")
        return []

    samples: List[Dict] = []
    skipped = 0
    # "utf-8-sig" reads plain UTF-8 as well as files starting with a BOM;
    # with plain "utf-8" a BOM makes the first record fail to parse.
    with open(path, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Downstream code treats every sample as a dict (.get/.setdefault),
            # so a bare list/number/string line is as unusable as broken JSON.
            if not isinstance(record, dict):
                skipped += 1
                continue
            samples.append(record)

    if skipped:
        print(f" [警告] {path}: 跳过 {skipped} 行无法解析或非对象的记录")
    return samples


def save_jsonl(samples: List[Dict], path: Path) -> None:
    """Write ``samples`` to ``path`` as UTF-8 JSON Lines, one record per line.

    Missing parent directories are created and an existing file is
    overwritten. Non-ASCII text is written as-is (``ensure_ascii=False``).

    Args:
        samples: Records to serialize.
        path: Destination file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(s, ensure_ascii=False) + "\n" for s in samples)


# ── Deduplication ─────────────────────────────────────────────────────────────

def fingerprint(s: Dict) -> str:
    """Return a SHA-256 hex digest identifying a sample for deduplication.

    The digest covers the first 100 characters of ``user_input`` and the
    first 100 characters of ``ai_response``, joined by ``"|"``. Samples that
    agree on both prefixes therefore share a fingerprint.

    A missing or ``None`` field counts as the empty string, and a non-string
    value is converted with ``str()`` rather than raising, so records that
    were never type-checked can still be fingerprinted.

    Args:
        s: Sample record.

    Returns:
        64-character hexadecimal digest.
    """
    def _prefix(key: str) -> str:
        value = s.get(key)
        if value is None:
            return ""
        if not isinstance(value, str):
            value = str(value)
        return value[:100]

    raw = _prefix("user_input") + "|" + _prefix("ai_response")
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


def deduplicate(samples: List[Dict]) -> Tuple[List[Dict], int]:
    """Drop samples whose fingerprint has already been seen.

    The first sample with a given fingerprint is kept and input order is
    preserved.

    Args:
        samples: Records to deduplicate.

    Returns:
        A ``(unique, dups)`` tuple: the retained records and the number of
        records that were dropped as duplicates.
    """
    seen: set = set()
    unique: List[Dict] = []
    for s in samples:
        fp = fingerprint(s)
        if fp in seen:
            continue
        seen.add(fp)
        unique.append(s)
    # Every sample is either kept or counted as a duplicate.
    return unique, len(samples) - len(unique)


# ── Quality filter (new English data only; Chinese v4 is already checked) ─────

def quality_filter(samples: List[Dict]) -> Tuple[List[Dict], Dict[str, int]]:
    """Keep only well-formed samples and report why the others were dropped.

    Checks run in order and a sample is rejected at the first failure:
    required fields, string type of ``user_input`` / ``ai_response``,
    ``history`` being a list with at least 2 entries, minimum text lengths
    (8 and 20 characters after stripping), label vocabulary
    (``l_risk``, ``c_primary``, ``a_recommend``, ``y_risk``), and
    ``y_risk == 1`` requiring a primary category.

    Accepted samples are normalized **in place**:
      * ``y_risk == 0`` with a real category -> ``c_primary`` becomes
        ``"None"`` and ``c_fine`` becomes ``[]``;
      * ``c_fine`` is always left as a list holding only known tags
        (a bare string is treated as a one-element list, any other
        non-list value as empty). Unknown tags are dropped without
        rejecting the sample.

    Args:
        samples: Candidate records.

    Returns:
        ``(passed, reasons)`` where ``reasons`` maps a rejection reason to
        the number of samples rejected for it.
    """
    required = ("persona", "history", "user_input", "ai_response",
                "y_risk", "l_risk", "c_primary", "c_fine", "a_recommend")
    reasons: Dict[str, int] = {}
    passed: List[Dict] = []

    def _reject(reason: str) -> None:
        reasons[reason] = reasons.get(reason, 0) + 1

    def _is_member(value, allowed) -> bool:
        # An unhashable label (e.g. a list instead of a scalar) would raise
        # TypeError on a set lookup; treat it as "not a valid label".
        try:
            return value in allowed
        except TypeError:
            return False

    for s in samples:
        if not isinstance(s, dict):
            _reject("wrong_type")
            continue

        # Required fields.
        if any(field not in s for field in required):
            _reject("missing_fields")
            continue

        # Type check (guards against a list being returned instead of a str).
        if not isinstance(s["user_input"], str) or not isinstance(s["ai_response"], str):
            _reject("wrong_type")
            continue

        # Number of history turns.
        if not isinstance(s["history"], list) or len(s["history"]) < 2:
            _reject("history_too_short")
            continue

        # Minimum text lengths.
        if len(s["user_input"].strip()) < 8:
            _reject("user_input_too_short")
            continue
        if len(s["ai_response"].strip()) < 20:
            _reject("ai_response_too_short")
            continue

        # Label validity.
        if not _is_member(s["l_risk"], VALID_RISK_LEVELS):
            _reject("invalid_l_risk")
            continue
        if not _is_member(s.get("c_primary"), VALID_C_PRIMARY):
            _reject("invalid_c_primary")
            continue
        if not _is_member(s.get("a_recommend"), VALID_ACTIONS):
            _reject("invalid_action")
            continue
        # The consistency rules below only understand 0 and 1; any other
        # y_risk value would slip past both of them unchecked.
        if s["y_risk"] not in (0, 1):
            _reject("invalid_y_risk")
            continue

        # Consistency: a safe sample carries no risk category.
        if s["y_risk"] == 0 and s.get("c_primary") not in (None, "None"):
            s["c_primary"] = "None"
            s["c_fine"] = []

        # Consistency: a risky sample must name a primary category.
        if s["y_risk"] == 1 and s.get("c_primary") in (None, "None"):
            _reject("risky_no_category")
            continue

        # Lenient handling of fine-grained tags: normalize to a list and drop
        # unknown tags instead of discarding the whole sample.
        fine = s["c_fine"]
        if isinstance(fine, str):
            fine = [fine]
        elif not isinstance(fine, list):
            fine = []
        s["c_fine"] = [t for t in fine if _is_member(t, VALID_C_FINE)]

        passed.append(s)

    return passed, reasons


# ── Stratified split (stratified on y_risk × lang) ────────────────────────────

def stratified_split(
    samples: List[Dict],
    seed: int = RANDOM_SEED,
    train_ratio: float = TRAIN_RATIO,
    dev_ratio: float = DEV_RATIO,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split samples into train/dev/test, stratified by ``(y_risk, lang)``.

    Samples are grouped by their ``y_risk`` (default ``1`` when absent) and
    ``lang`` (default ``"zh"``) values. Each group is shuffled and cut into
    ``int(n * train_ratio)`` train items, ``int(n * dev_ratio)`` dev items,
    and the remainder as test items. Each resulting split is shuffled once
    more so the groups are interleaved.

    A private ``random.Random(seed)`` instance is used, so the result is
    reproducible for a given seed and input order without reseeding the
    process-wide ``random`` state.

    Args:
        samples: Records to split; the list itself is not modified.
        seed: Seed for the shuffles.
        train_ratio: Fraction of each group assigned to train.
        dev_ratio: Fraction of each group assigned to dev.

    Returns:
        ``(train, dev, test)`` lists.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1.
    """
    if train_ratio < 0 or dev_ratio < 0 or train_ratio + dev_ratio > 1:
        raise ValueError(
            f"invalid split ratios: train={train_ratio}, dev={dev_ratio}"
        )

    rng = random.Random(seed)

    # Bucket by (y_risk, lang); dict insertion order keeps this deterministic.
    buckets: Dict[Tuple, List[Dict]] = {}
    for s in samples:
        key = (s.get("y_risk", 1), s.get("lang", "zh"))
        buckets.setdefault(key, []).append(s)

    train: List[Dict] = []
    dev: List[Dict] = []
    test: List[Dict] = []
    for bucket in buckets.values():
        rng.shuffle(bucket)
        n = len(bucket)
        n_train = int(n * train_ratio)
        n_dev = int(n * dev_ratio)
        train += bucket[:n_train]
        dev += bucket[n_train:n_train + n_dev]
        # Rounding leftovers end up in test.
        test += bucket[n_train + n_dev:]

    for split in (train, dev, test):
        rng.shuffle(split)
    return train, dev, test


# ── Statistics report ─────────────────────────────────────────────────────────

def print_stats(name: str, samples: List[Dict]):
    """Print a boxed summary of label distributions for one sample set.

    Reports the total (risky vs. safe by ``y_risk == 1``), language counts,
    risk-level counts, primary-category counts (excluding ``None``/"None"),
    recommended-action counts, and the 8 most common fine-grained tags.

    The report tolerates records that were not normalized: a ``c_fine``
    that is missing, ``None`` or otherwise not a list contributes no tags,
    and counters whose keys cannot be ordered against each other are sorted
    by their string form instead of raising.

    Args:
        name: Title shown in the box header.
        samples: Records to summarize.
    """
    total = len(samples)
    if total == 0:
        print(f"\n [{name}] 0 条")
        return

    def _fine_tags(sample: Dict) -> list:
        tags = sample.get("c_fine")
        return tags if isinstance(tags, list) else []

    def _sorted_counts(counter: Counter) -> dict:
        try:
            return dict(sorted(counter.items()))
        except TypeError:
            # Mixed key types (e.g. int and str levels) are not orderable.
            return dict(sorted(counter.items(), key=lambda kv: str(kv[0])))

    risky = sum(1 for s in samples if s.get("y_risk") == 1)
    lang_cnt = Counter(s.get("lang", "zh") for s in samples)
    lvl_cnt = Counter(s.get("l_risk", 0) for s in samples)
    cat_cnt = Counter(
        s.get("c_primary") for s in samples if s.get("c_primary") not in (None, "None")
    )
    act_cnt = Counter(s.get("a_recommend", "PASS") for s in samples)
    fine_cnt = Counter(t for s in samples for t in _fine_tags(s))

    print(f"\n┌{'─'*52}┐")
    print(f"│ {name:<50} │")
    print(f"├{'─'*52}┤")
    print(f"│ 总数 : {total} (有风险={risky}, 安全={total-risky})")
    print(f"│ 语言 : zh={lang_cnt.get('zh',0)} en={lang_cnt.get('en',0)}")
    print(f"│ 风险等级 : {_sorted_counts(lvl_cnt)}")
    print(f"│ 一级类别 : {_sorted_counts(cat_cnt)}")
    print(f"│ 干预动作 : {dict(act_cnt)}")
    print(f"│ 细粒度(Top8): {dict(fine_cnt.most_common(8))}")
    print(f"└{'─'*52}┘")


def coverage_check(
    samples: List[Dict],
    min_per_category: int = 50,
    min_per_tag: int = 30,
) -> None:
    """Print whether every category and fine-grained tag is well represented.

    Primary categories R1..R10 are counted over samples with ``y_risk == 1``
    and listed in numeric order; fine-grained tags from ``VALID_C_FINE`` are
    counted over all samples. Each line is marked ✓ when the count reaches
    the threshold and ✗ otherwise. A ``c_fine`` value that is not a list
    contributes no tags.

    Args:
        samples: Records of the merged dataset.
        min_per_category: Minimum count for a primary category to pass.
        min_per_tag: Minimum count for a fine-grained tag to pass.
    """
    # Explicit list rather than a sorted set, so R10 comes after R9.
    categories = [f"R{i}" for i in range(1, 11)]

    cat_cnt = Counter(s.get("c_primary") for s in samples if s.get("y_risk") == 1)
    fine_cnt = Counter(
        t
        for s in samples
        if isinstance(s.get("c_fine"), list)
        for t in s["c_fine"]
    )

    print("\n覆盖率检查(合并后全集):")
    print(f" 一级类别(≥{min_per_category}条):")
    for cat in categories:
        n = cat_cnt.get(cat, 0)
        ok = "✓" if n >= min_per_category else "✗"
        print(f" {cat}: {n:5d} {ok}")

    print(f" 细粒度标签(≥{min_per_tag}条):")
    for tag in sorted(VALID_C_FINE):
        n = fine_cnt.get(tag, 0)
        ok = "✓" if n >= min_per_tag else "✗"
        print(f" {tag}: {n:5d} {ok}")


# ── Entry point ───────────────────────────────────────────────────────────────

def main():
    """Build CompanionRisk-Bench v5 from the command line.

    Steps: load the three source files, tag each record with a ``lang``
    value where missing, quality-filter the English records, merge and
    deduplicate, split into train/dev/test, assign ``final_id`` values,
    write the four output files, and print the statistics reports.

    Command-line options (all optional): ``--v4``, ``--en-core``,
    ``--en-targeted`` (input paths) and ``--out-dir`` (output directory).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--v4", default="data/processed/CompanionRisk-Bench-v4-backup/all.jsonl")
    parser.add_argument("--en-core", default="data/raw/generated_english_core.jsonl")
    parser.add_argument("--en-targeted", default="data/raw/generated_english_targeted.jsonl")
    parser.add_argument("--out-dir", default="data/processed/CompanionRisk-Bench-v5")
    args = parser.parse_args()

    out_dir = Path(args.out_dir)

    print(f"\n{'='*56}")
    print(" CompanionRisk-Bench v5 构建")
    print(f"{'='*56}")

    # 1. Load.
    print("\n[1/5] 加载数据...")
    zh_v4 = load_jsonl(Path(args.v4))
    en_core = load_jsonl(Path(args.en_core))
    en_tgt = load_jsonl(Path(args.en_targeted))
    print(f" 中文 v4 (已质检) : {len(zh_v4):6d} 条")
    print(f" 英文 core : {len(en_core):6d} 条")
    print(f" 英文 targeted : {len(en_tgt):6d} 条")
    print(f" 合计(过滤前) : {len(zh_v4)+len(en_core)+len(en_tgt):6d} 条")

    # 2. Make sure every record carries a language tag.
    en_all = en_core + en_tgt
    for s in zh_v4:
        s.setdefault("lang", "zh")
    for s in en_all:
        s.setdefault("lang", "en")

    # 3. Quality filter (new English data only).
    print("\n[2/5] 质量过滤(英文数据)...")
    en_filtered, reasons = quality_filter(en_all)
    dropped = len(en_all) - len(en_filtered)
    print(f" 英文过滤前: {len(en_all)} → 过滤后: {len(en_filtered)} (丢弃 {dropped} 条)")
    if reasons:
        # Most frequent rejection reason first.
        for k, v in sorted(reasons.items(), key=lambda x: -x[1]):
            print(f" {k}: {v}")

    # 4. Merge and deduplicate globally.
    print("\n[3/5] 合并 + 全局去重...")
    merged = zh_v4 + en_filtered
    unique, dups = deduplicate(merged)
    print(f" 合并后: {len(merged)} → 去重后: {len(unique)} (去除 {dups} 条重复)")

    # 5. Stratified split (by y_risk × lang).
    print("\n[4/5] 分层划分 (train:dev:test ≈ 70:15:15)...")
    train, dev, test = stratified_split(unique)
    print(f" train: {len(train)}")
    print(f" dev : {len(dev)}")
    print(f" test : {len(test)}")

    # 6. Assign ids, then save.
    # The split lists and all_samples hold the same dict objects, so the ids
    # must be set before ANY file is written; otherwise train/dev/test.jsonl
    # would lack the final_id that all.jsonl carries.
    all_samples = train + dev + test
    for i, s in enumerate(all_samples):
        s["final_id"] = f"crb-v5-{i:05d}"

    print(f"\n[5/5] 保存到 {out_dir}/...")
    save_jsonl(train, out_dir / "train.jsonl")
    save_jsonl(dev, out_dir / "dev.jsonl")
    save_jsonl(test, out_dir / "test.jsonl")
    save_jsonl(all_samples, out_dir / "all.jsonl")
    print(" 保存完成:train / dev / test / all")

    # 7. Statistics report.
    print(f"\n{'='*56}")
    print(" 数据集统计报告")
    print(f"{'='*56}")
    print_stats("ALL (v5)", all_samples)
    print_stats("TRAIN", train)
    print_stats("DEV", dev)
    print_stats("TEST", test)
    coverage_check(all_samples)

    # 8. Language × split matrix.
    print("\n 语言 × 分割分布:")
    print(f" {'':12} {'train':>8} {'dev':>8} {'test':>8} {'total':>8}")
    for lang in ("zh", "en"):
        row = [sum(1 for s in split if s.get("lang") == lang)
               for split in (train, dev, test)]
        print(f" {lang:12} {row[0]:>8} {row[1]:>8} {row[2]:>8} {sum(row):>8}")

    print(f"\n{'='*56}")
    print(f" 构建完成!总样本数: {len(all_samples)}")
    print(f" 输出目录: {out_dir.resolve()}")
    print(f"{'='*56}\n")


if __name__ == "__main__":
    main()