"""
2026-05-11
Dataset merge, filter and split script.

Features:
    1. Merge the self-built Chinese core set with the adapted public datasets
    2. Quality filtering (structural completeness, minimum length, label validity)
    3. Global SHA256 de-duplication
    4. Stratified 70/15/15 split into train/dev/test
    5. DICES (split_hint=test_only) is forced into the test set
    6. Print a statistics report

Usage (run once generated_core.jsonl has been fully generated):
    python scripts/merge_and_split.py

    # custom paths
    python scripts/merge_and_split.py \\
        --core    data/raw/generated_core.jsonl \\
        --suicide data/raw/adapted_suicide.jsonl \\
        --cosafe  data/raw/adapted_cosafe.jsonl \\
        --dices   data/raw/adapted_dices.jsonl \\
        --out-dir data/processed/CompanionRisk-Bench
"""

import argparse
import hashlib
import json
import random
from collections import Counter
from pathlib import Path
from typing import List, Dict, Optional, Set, Tuple

RANDOM_SEED = 42
TRAIN_RATIO = 0.70
DEV_RATIO = 0.15
TEST_RATIO = 0.15

# Sets of legal label values
VALID_C_PRIMARY = {"R1","R2","R3","R4","R5","R6","R7","R8","R9","R10","None",None}
VALID_ACTIONS = {"PASS","WARN","REWRITE","REJECT","CRISIS"}
VALID_RISK_LEVELS = {0,1,2,3,4}
VALID_C_FINE = {
    "DirectEncouragement","MethodFacilitation","RiskNormalization",
    "Romanticization","DependencyReinforcement","IsolationReinforcement",
    "FalseReassurance","PseudoTherapy","BoundaryFailure","CrisisNonResponse",
    "CoRumination","ManipulativeAttachment","PrivacySolicitation",
    "AgeInappropriateIntimacy",
}


# ── Utility functions ─────────────────────────────────────────────────────────

def load_jsonl(path: Path) -> List[Dict]:
    """Load a JSON Lines file into a list of dicts.

    Returns an empty list (after printing a notice) when ``path`` does not
    exist. Blank lines are ignored. Lines that are not valid JSON, or whose
    top-level value is not a JSON object, are skipped and their count is
    reported, so that data loss is visible instead of silent.
    """
    if not path.exists():
        print(f" [跳过] 文件不存在: {path}")
        return []
    samples: List[Dict] = []
    skipped = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Downstream code calls dict methods on every sample, so a bare
            # list/number/string on a line cannot be used as a sample.
            if not isinstance(record, dict):
                skipped += 1
                continue
            samples.append(record)
    if skipped:
        print(f" [警告] {path}: 跳过 {skipped} 行无法解析或非对象的记录")
    return samples


def save_jsonl(samples: List[Dict], path: Path):
    """Write ``samples`` to ``path`` as UTF-8 JSON Lines, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for s in samples:
            f.write(json.dumps(s, ensure_ascii=False) + "\n")


def fingerprint(s: Dict) -> str:
    """Return the SHA256 hex digest of the first 100 chars of user_input and ai_response.

    Missing, ``None`` or non-string fields are tolerated: ``None``/missing is
    treated as an empty string and other values are converted with ``str``.
    """
    user_input = s.get("user_input")
    ai_response = s.get("ai_response")
    user_text = "" if user_input is None else str(user_input)
    ai_text = "" if ai_response is None else str(ai_response)
    raw = user_text[:100] + "|" + ai_text[:100]
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


def _in_allowed(value, allowed) -> bool:
    """Membership test that treats unhashable values as "not a member"."""
    try:
        return value in allowed
    except TypeError:
        return False


# ── Quality filtering ─────────────────────────────────────────────────────────

def quality_filter(samples: List[Dict]) -> Tuple[List[Dict], Dict[str, int]]:
    """Drop samples that fail the quality checks.

    Returns ``(passed_samples, rejection_reason_counts)``. Each rejected sample
    is counted under the first check it fails.

    Samples that pass may be corrected in place (the input dicts are mutated):
    ``c_primary``/``c_fine`` are reset when ``y_risk == 0`` carries a category,
    unknown ``c_fine`` tags are removed, and a non-list ``c_fine`` is replaced
    by an empty list.
    """
    reasons: Counter = Counter()
    passed: List[Dict] = []
    required = ("persona", "history", "user_input", "ai_response",
                "y_risk", "l_risk", "c_primary", "c_fine", "a_recommend")

    for s in samples:
        # Required fields
        if any(field not in s for field in required):
            reasons["missing_fields"] += 1
            continue

        # Number of history turns
        if not isinstance(s["history"], list) or len(s["history"]) < 2:
            reasons["history_too_short"] += 1
            continue

        # Text fields must be strings before their length can be checked
        if not isinstance(s["user_input"], str) or not isinstance(s["ai_response"], str):
            reasons["invalid_text_type"] += 1
            continue

        # Minimum text length
        if len(s["user_input"].strip()) < 8:
            reasons["user_input_too_short"] += 1
            continue
        if len(s["ai_response"].strip()) < 20:
            reasons["ai_response_too_short"] += 1
            continue

        # Label validity
        if not _in_allowed(s["l_risk"], VALID_RISK_LEVELS):
            reasons["invalid_l_risk"] += 1
            continue
        if not _in_allowed(s.get("c_primary"), VALID_C_PRIMARY):
            reasons["invalid_c_primary"] += 1
            continue
        if not _in_allowed(s.get("a_recommend"), VALID_ACTIONS):
            reasons["invalid_action"] += 1
            continue
        # stratified_split only keeps y_risk 0 or 1; reject anything else here
        # so the loss is counted instead of disappearing at split time.
        if s["y_risk"] not in (0, 1):
            reasons["invalid_y_risk"] += 1
            continue

        # Consistency: with y_risk=0, c_primary should be None/"None".
        if s["y_risk"] == 0 and s.get("c_primary") not in (None, "None"):
            # Lenient: fix c_primary rather than dropping the sample.
            s["c_primary"] = "None"
            s["c_fine"] = []

        # With y_risk=1, c_primary must not be empty.
        if s["y_risk"] == 1 and s.get("c_primary") in (None, "None"):
            reasons["risky_no_category"] += 1
            continue

        # Lenient c_fine handling: strip illegal tags, keep the sample.
        if isinstance(s["c_fine"], list):
            s["c_fine"] = [t for t in s["c_fine"] if _in_allowed(t, VALID_C_FINE)]
        else:
            # The statistics code iterates c_fine as a list of tags.
            s["c_fine"] = []

        passed.append(s)

    return passed, dict(reasons)


# ── De-duplication ────────────────────────────────────────────────────────────

def deduplicate(
    samples: List[Dict],
    seen: Optional[Set[str]] = None,
) -> Tuple[List[Dict], int]:
    """Remove samples whose ``fingerprint`` has already been seen.

    The first occurrence is kept. ``seen`` is an optional set of fingerprints
    that is updated in place; passing the same set to several calls makes the
    de-duplication span several pools. Returns ``(unique_samples, dup_count)``.
    """
    if seen is None:
        seen = set()
    unique: List[Dict] = []
    dup_count = 0
    for s in samples:
        fp = fingerprint(s)
        if fp in seen:
            dup_count += 1
        else:
            seen.add(fp)
            unique.append(s)
    return unique, dup_count


# ── Stratified split ──────────────────────────────────────────────────────────

def stratified_split(
    samples: List[Dict],
    test_only: List[Dict],
    train_ratio: float = TRAIN_RATIO,
    dev_ratio: float = DEV_RATIO,
    seed: int = RANDOM_SEED,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split ``samples`` into train/dev/test, stratified by ``y_risk``.

    Every ``test_only`` sample goes straight into the test set. Within each
    stratum the train and dev sizes are truncated with ``int`` and the
    remainder goes to test. Samples whose ``y_risk`` is neither 0 nor 1 belong
    to no stratum and are not returned.

    A private ``random.Random(seed)`` is used, so the result is deterministic
    for a given seed and the process-wide RNG state is left untouched.
    """
    rng = random.Random(seed)

    risky = [s for s in samples if s.get("y_risk") == 1]
    safe = [s for s in samples if s.get("y_risk") == 0]

    rng.shuffle(risky)
    rng.shuffle(safe)

    def split_list(lst):
        n = len(lst)
        n_train = int(n * train_ratio)
        n_dev = int(n * dev_ratio)
        return lst[:n_train], lst[n_train:n_train + n_dev], lst[n_train + n_dev:]

    tr_r, dv_r, te_r = split_list(risky)
    tr_s, dv_s, te_s = split_list(safe)

    train = tr_r + tr_s
    dev = dv_r + dv_s
    test = te_r + te_s + test_only

    rng.shuffle(train)
    rng.shuffle(dev)
    rng.shuffle(test)

    return train, dev, test


# ── Statistics report ─────────────────────────────────────────────────────────

def print_stats(name: str, samples: List[Dict]):
    """Print a boxed summary of label, source and language counts for ``samples``."""
    total = len(samples)
    if total == 0:
        print(f"\n [{name}] 0 条")
        return

    risky = sum(1 for s in samples if s.get("y_risk") == 1)
    safe = total - risky

    lvl_cnt = Counter(s.get("l_risk", 0) for s in samples)
    cat_cnt = Counter(
        s.get("c_primary", "None") for s in samples
        if s.get("c_primary") not in (None, "None")
    )
    act_cnt = Counter(s.get("a_recommend", "PASS") for s in samples)
    src_cnt = Counter(s.get("source", "generated") for s in samples)
    lang_cnt = Counter(s.get("lang", "zh") for s in samples)

    # Fine-grained tag counts (a missing or non-list c_fine counts as no tags)
    fine_cnt: Counter = Counter()
    for s in samples:
        tags = s.get("c_fine")
        if isinstance(tags, list):
            fine_cnt.update(tags)

    print(f"\n┌{'─'*48}┐")
    print(f"│ {name:<46} │")
    print(f"├{'─'*48}┤")
    print(f"│ 总数 : {total:<6} (有风险={risky}, 安全={safe}){'':9}│")
    print(f"│ 语言 : {dict(lang_cnt)}")
    print(f"│ 数据来源 : {dict(src_cnt)}")
    print(f"│ 风险等级 : {dict(sorted(lvl_cnt.items()))}")
    print(f"│ 一级类别 : {dict(sorted(cat_cnt.items()))}")
    print(f"│ 干预动作 : {dict(act_cnt)}")
    if fine_cnt:
        print(f"│ 细粒度标签(Top8): {dict(fine_cnt.most_common(8))}")
    print(f"└{'─'*48}┘")


def coverage_check(samples: List[Dict]):
    """Print per-category and per-tag counts against the minimum thresholds.

    Primary categories R1..R10 are counted over samples with ``y_risk == 1``
    and need at least 50 samples; fine-grained tags need at least 30.
    """
    all_cats = {f"R{i}" for i in range(1, 11)}

    cat_cnt = Counter(s.get("c_primary") for s in samples if s.get("y_risk") == 1)
    fine_cnt: Counter = Counter()
    for s in samples:
        tags = s.get("c_fine")
        if isinstance(tags, list):
            fine_cnt.update(tags)

    print("\n覆盖率检查:")
    print(" 一级类别(最少50条):")
    for cat in sorted(all_cats):
        n = cat_cnt.get(cat, 0)
        ok = "✓" if n >= 50 else "✗ 不足"
        print(f" {cat}: {n:4d} {ok}")

    print(" 细粒度标签(最少30条):")
    # Use the shared constant so this check cannot drift from the filter.
    for tag in sorted(VALID_C_FINE):
        n = fine_cnt.get(tag, 0)
        ok = "✓" if n >= 30 else "✗ 不足"
        print(f" {tag}: {n:3d} {ok}")


# ── Entry point ───────────────────────────────────────────────────────────────

def main():
    """Run the pipeline: load, tag sources, filter, de-duplicate, split, save, report."""
    parser = argparse.ArgumentParser(description="合并、过滤、划分 CompanionRisk-Bench 数据集")
    parser.add_argument("--core", default="data/raw/generated_core.jsonl")
    parser.add_argument("--targeted", default="data/raw/generated_targeted.jsonl",
                        help="弱标签专项补充集(FalseReassurance/PseudoTherapy/IsolationReinforcement),"
                             "可选,文件不存在时自动跳过")
    parser.add_argument("--suicide", default="data/raw/adapted_suicide.jsonl")
    parser.add_argument("--cosafe", default="data/raw/adapted_cosafe.jsonl")
    parser.add_argument("--dices", default="data/raw/adapted_dices.jsonl")
    parser.add_argument("--out-dir", default="data/processed/CompanionRisk-Bench")
    parser.add_argument(
        "--min-core", type=int, default=3000,
        help="自建核心集最小样本数,不足时发出警告(默认3000)"
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)

    print(f"\n{'='*52}")
    print(f" CompanionRisk-Bench 数据集构建")
    print(f"{'='*52}")

    # ── 1. Load every subset ──────────────────────────────────────────────────
    print("\n[1/5] 加载原始数据...")
    core = load_jsonl(Path(args.core))
    targeted = load_jsonl(Path(args.targeted))  # optional supplement; [] when the file is absent
    suicide = load_jsonl(Path(args.suicide))
    cosafe = load_jsonl(Path(args.cosafe))
    dices = load_jsonl(Path(args.dices))

    print(f" 自建核心集 : {len(core):5d} 条")
    print(f" 弱标签专项 : {len(targeted):5d} 条 ({'已加载' if targeted else '文件不存在,跳过'})")
    print(f" Suicide 改造 : {len(suicide):5d} 条")
    print(f" CoSafe 改造 : {len(cosafe):5d} 条")
    print(f" DICES (test) : {len(dices):5d} 条")

    if len(core) < args.min_core:
        print(f"\n ⚠ 警告:自建核心集只有 {len(core)} 条,未达到 {args.min_core} 条目标。")
        print(f" 建议等 generate_siliconflow.py 运行完毕后再执行此脚本。")

    # ── 2. Tag source and language (existing values are kept) ─────────────────
    for s in core:
        s.setdefault("source", "generated")
        s.setdefault("lang", "zh")
    for s in targeted:
        s.setdefault("source", "generated")  # same origin as the core set; counted together
        s.setdefault("lang", "zh")
    for s in suicide:
        s.setdefault("source", "suicide_risk")
        s.setdefault("lang", "en")
    for s in cosafe:
        s.setdefault("source", "cosafe")
        s.setdefault("lang", "en")
    for s in dices:
        s.setdefault("source", "dices")
        s.setdefault("lang", "en")

    # DICES is forced into the test set
    test_only = dices
    main_pool = core + targeted + suicide + cosafe

    print(f"\n 主池总量 : {len(main_pool)} 条(core + targeted + suicide + cosafe)")
    print(f" Test-only池 : {len(test_only)} 条(DICES,直接入 test)")

    # ── 3. Quality filtering ──────────────────────────────────────────────────
    print("\n[2/5] 质量过滤...")
    filtered, reasons = quality_filter(main_pool)
    filtered_dices, dices_reasons = quality_filter(test_only)

    total_filtered = len(main_pool) - len(filtered)
    print(f" 主池过滤前: {len(main_pool)} → 过滤后: {len(filtered)} (丢弃 {total_filtered} 条)")
    if reasons:
        for k, v in sorted(reasons.items(), key=lambda x: -x[1]):
            print(f" {k}: {v}")
    dices_dropped = len(test_only) - len(filtered_dices)
    print(f" DICES 过滤前: {len(test_only)} → 过滤后: {len(filtered_dices)} (丢弃 {dices_dropped} 条)")
    for k, v in sorted(dices_reasons.items(), key=lambda x: -x[1]):
        print(f" {k}: {v}")

    # ── 4. Global de-duplication ──────────────────────────────────────────────
    print("\n[3/5] 全局去重...")
    # One fingerprint set is shared by both pools so the de-duplication is
    # global. The main pool goes first, so its result is independent of DICES;
    # a DICES sample that duplicates a main-pool sample is dropped, which keeps
    # the same text from landing in both train/dev and test.
    seen_fingerprints: Set[str] = set()
    unique, dup_count = deduplicate(filtered, seen_fingerprints)
    unique_dices, dup_dices = deduplicate(filtered_dices, seen_fingerprints)
    print(f" 去重前: {len(filtered)} → 去重后: {len(unique)} (去除 {dup_count} 条重复)")
    print(f" DICES 去重前: {len(filtered_dices)} → 去重后: {len(unique_dices)} (去除 {dup_dices} 条重复)")

    # ── 5. Stratified split ───────────────────────────────────────────────────
    print("\n[4/5] 分层划分 (train:dev:test = 70:15:15)...")
    train, dev, test = stratified_split(unique, unique_dices)

    print(f" train: {len(train)}")
    print(f" dev : {len(dev)}")
    print(f" test : {len(test)}")

    # Assign final IDs BEFORE anything is written, so the per-split files and
    # all.jsonl carry the same final_id for the same sample.
    all_samples = train + dev + test
    for i, s in enumerate(all_samples):
        s["final_id"] = f"crb-{i:05d}"

    # ── 6. Save ───────────────────────────────────────────────────────────────
    print(f"\n[5/5] 保存到 {out_dir}/...")
    save_jsonl(train, out_dir / "train.jsonl")
    save_jsonl(dev, out_dir / "dev.jsonl")
    save_jsonl(test, out_dir / "test.jsonl")

    # Also save the complete gold set (all data, with annotations)
    save_jsonl(all_samples, out_dir / "all.jsonl")

    print(f" 保存完成:train.jsonl / dev.jsonl / test.jsonl / all.jsonl")

    # ── 7. Statistics report ──────────────────────────────────────────────────
    print("\n" + "="*52)
    print(" 数据集统计报告")
    print("="*52)

    print_stats("ALL", all_samples)
    print_stats("TRAIN", train)
    print_stats("DEV", dev)
    print_stats("TEST", test)

    coverage_check(all_samples)

    print(f"\n{'='*52}")
    print(f" 构建完成!总样本数: {len(all_samples)}")
    print(f" 输出目录: {out_dir.resolve()}")
    print(f"{'='*52}\n")


if __name__ == "__main__":
    main()