""" 合并 CompanionRisk-Bench v5 来源: - 中文 v4 (data/processed/CompanionRisk-Bench-v4-backup/all.jsonl) 9,896 条(已质检) - 英文 core (data/raw/generated_english_core.jsonl) 12,000 条 - 英文 targeted (data/raw/generated_english_targeted.jsonl) 3,000 条 输出:data/processed/CompanionRisk-Bench-v5/{train,dev,test,all}.jsonl 用法: cd code/ python scripts/merge_v5.py """ import argparse import hashlib import json import random from collections import Counter from pathlib import Path from typing import List, Dict, Tuple RANDOM_SEED = 42 TRAIN_RATIO = 0.70 DEV_RATIO = 0.15 VALID_C_PRIMARY = {"R1","R2","R3","R4","R5","R6","R7","R8","R9","R10","None",None} VALID_ACTIONS = {"PASS","WARN","REWRITE","REJECT","CRISIS"} VALID_RISK_LEVELS = {0, 1, 2, 3, 4} VALID_C_FINE = { "DirectEncouragement","MethodFacilitation","RiskNormalization", "Romanticization","DependencyReinforcement","IsolationReinforcement", "FalseReassurance","PseudoTherapy","BoundaryFailure","CrisisNonResponse", "CoRumination","ManipulativeAttachment","PrivacySolicitation", "AgeInappropriateIntimacy", } # ── I/O ─────────────────────────────────────────────────────────────────────── def load_jsonl(path: Path) -> List[Dict]: if not path.exists(): print(f" [跳过] 文件不存在: {path}") return [] samples = [] with open(path, encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: samples.append(json.loads(line)) except json.JSONDecodeError: continue return samples def save_jsonl(samples: List[Dict], path: Path): path.parent.mkdir(parents=True, exist_ok=True) with open(path, "w", encoding="utf-8") as f: for s in samples: f.write(json.dumps(s, ensure_ascii=False) + "\n") # ── 去重 ────────────────────────────────────────────────────────────────────── def fingerprint(s: Dict) -> str: raw = s.get("user_input", "")[:100] + "|" + s.get("ai_response", "")[:100] return hashlib.sha256(raw.encode("utf-8")).hexdigest() def deduplicate(samples: List[Dict]) -> Tuple[List[Dict], int]: seen: set = set() unique = [] dups = 0 for s in samples: fp = fingerprint(s) if fp in seen: dups += 1 else: seen.add(fp) unique.append(s) return unique, dups # ── 质量过滤(仅用于新增英文数据;中文 v4 已质检) ─────────────────────────── def quality_filter(samples: List[Dict]) -> Tuple[List[Dict], Dict[str, int]]: reasons: Dict[str, int] = {} passed = [] for s in samples: # 必填字段 missing = [f for f in ("persona","history","user_input","ai_response", "y_risk","l_risk","c_primary","c_fine","a_recommend") if f not in s] if missing: reasons["missing_fields"] = reasons.get("missing_fields", 0) + 1 continue # 类型检查(防止 LLM 返回 list 而非 str) if not isinstance(s["user_input"], str) or not isinstance(s["ai_response"], str): reasons["wrong_type"] = reasons.get("wrong_type", 0) + 1 continue # 历史轮数 if not isinstance(s["history"], list) or len(s["history"]) < 2: reasons["history_too_short"] = reasons.get("history_too_short", 0) + 1 continue # 最短文本长度 if len(s["user_input"].strip()) < 8: reasons["user_input_too_short"] = reasons.get("user_input_too_short", 0) + 1 continue if len(s["ai_response"].strip()) < 20: reasons["ai_response_too_short"] = reasons.get("ai_response_too_short", 0) + 1 continue # 标签合法性 if s["l_risk"] not in VALID_RISK_LEVELS: reasons["invalid_l_risk"] = reasons.get("invalid_l_risk", 0) + 1 continue if s.get("c_primary") not in VALID_C_PRIMARY: reasons["invalid_c_primary"] = reasons.get("invalid_c_primary", 0) + 1 continue if s.get("a_recommend") not in VALID_ACTIONS: reasons["invalid_action"] = reasons.get("invalid_action", 0) + 1 continue # 逻辑一致性:y_risk=0 时修正 c_primary if s["y_risk"] == 0 and s.get("c_primary") not in (None, "None"): s["c_primary"] = "None" s["c_fine"] = [] # y_risk=1 时 c_primary 不能为空 if s["y_risk"] == 1 and s.get("c_primary") in (None, "None"): reasons["risky_no_category"] = reasons.get("risky_no_category", 0) + 1 continue # 过滤 c_fine 中的非法标签(宽容处理,不丢整条) if isinstance(s["c_fine"], list): s["c_fine"] = [t for t in s["c_fine"] if t in VALID_C_FINE] passed.append(s) return passed, reasons # ── 分层划分(按 y_risk × lang 双维度分层) ────────────────────────────────── def stratified_split( samples: List[Dict], seed: int = RANDOM_SEED, ) -> Tuple[List[Dict], List[Dict], List[Dict]]: random.seed(seed) # 按 (y_risk, lang) 分桶 buckets: Dict[Tuple, List[Dict]] = {} for s in samples: key = (s.get("y_risk", 1), s.get("lang", "zh")) buckets.setdefault(key, []).append(s) train, dev, test = [], [], [] for key, bucket in buckets.items(): random.shuffle(bucket) n = len(bucket) n_train = int(n * TRAIN_RATIO) n_dev = int(n * DEV_RATIO) train += bucket[:n_train] dev += bucket[n_train:n_train + n_dev] test += bucket[n_train + n_dev:] random.shuffle(train) random.shuffle(dev) random.shuffle(test) return train, dev, test # ── 统计报告 ────────────────────────────────────────────────────────────────── def print_stats(name: str, samples: List[Dict]): total = len(samples) if total == 0: print(f"\n [{name}] 0 条") return risky = sum(1 for s in samples if s.get("y_risk") == 1) lang_cnt = Counter(s.get("lang", "zh") for s in samples) lvl_cnt = Counter(s.get("l_risk", 0) for s in samples) cat_cnt = Counter( s.get("c_primary") for s in samples if s.get("c_primary") not in (None, "None") ) act_cnt = Counter(s.get("a_recommend", "PASS") for s in samples) fine_cnt = Counter(t for s in samples for t in s.get("c_fine", [])) print(f"\n┌{'─'*52}┐") print(f"│ {name:<50} │") print(f"├{'─'*52}┤") print(f"│ 总数 : {total} (有风险={risky}, 安全={total-risky})") print(f"│ 语言 : zh={lang_cnt.get('zh',0)} en={lang_cnt.get('en',0)}") print(f"│ 风险等级 : {dict(sorted(lvl_cnt.items()))}") print(f"│ 一级类别 : {dict(sorted(cat_cnt.items()))}") print(f"│ 干预动作 : {dict(act_cnt)}") print(f"│ 细粒度(Top8): {dict(fine_cnt.most_common(8))}") print(f"└{'─'*52}┘") def coverage_check(samples: List[Dict]): all_cats = {f"R{i}" for i in range(1, 11)} all_fines = VALID_C_FINE cat_cnt = Counter(s.get("c_primary") for s in samples if s.get("y_risk") == 1) fine_cnt = Counter(t for s in samples for t in s.get("c_fine", [])) print("\n覆盖率检查(合并后全集):") print(" 一级类别(≥50条):") for cat in sorted(all_cats): n = cat_cnt.get(cat, 0) ok = "✓" if n >= 50 else "✗" print(f" {cat}: {n:5d} {ok}") print(" 细粒度标签(≥30条):") for tag in sorted(all_fines): n = fine_cnt.get(tag, 0) ok = "✓" if n >= 30 else "✗" print(f" {tag}: {n:5d} {ok}") # ── 主入口 ──────────────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser() parser.add_argument("--v4", default="data/processed/CompanionRisk-Bench-v4-backup/all.jsonl") parser.add_argument("--en-core", default="data/raw/generated_english_core.jsonl") parser.add_argument("--en-targeted", default="data/raw/generated_english_targeted.jsonl") parser.add_argument("--out-dir", default="data/processed/CompanionRisk-Bench-v5") args = parser.parse_args() out_dir = Path(args.out_dir) print(f"\n{'='*56}") print(f" CompanionRisk-Bench v5 构建") print(f"{'='*56}") # 1. 加载 print("\n[1/5] 加载数据...") zh_v4 = load_jsonl(Path(args.v4)) en_core = load_jsonl(Path(args.en_core)) en_tgt = load_jsonl(Path(args.en_targeted)) print(f" 中文 v4 (已质检) : {len(zh_v4):6d} 条") print(f" 英文 core : {len(en_core):6d} 条") print(f" 英文 targeted : {len(en_tgt):6d} 条") print(f" 合计(过滤前) : {len(zh_v4)+len(en_core)+len(en_tgt):6d} 条") # 2. 标记语言字段(确保一致) for s in zh_v4: s.setdefault("lang", "zh") for s in en_core + en_tgt: s.setdefault("lang", "en") # 3. 质量过滤(仅对新英文数据) print("\n[2/5] 质量过滤(英文数据)...") en_all = en_core + en_tgt en_filtered, reasons = quality_filter(en_all) dropped = len(en_all) - len(en_filtered) print(f" 英文过滤前: {len(en_all)} → 过滤后: {len(en_filtered)} (丢弃 {dropped} 条)") if reasons: for k, v in sorted(reasons.items(), key=lambda x: -x[1]): print(f" {k}: {v}") # 4. 合并 + 全局去重 print("\n[3/5] 合并 + 全局去重...") merged = zh_v4 + en_filtered unique, dups = deduplicate(merged) print(f" 合并后: {len(merged)} → 去重后: {len(unique)} (去除 {dups} 条重复)") # 5. 分层划分(按 y_risk × lang) print("\n[4/5] 分层划分 (train:dev:test ≈ 70:15:15)...") train, dev, test = stratified_split(unique) print(f" train: {len(train)}") print(f" dev : {len(dev)}") print(f" test : {len(test)}") # 6. 保存 print(f"\n[5/5] 保存到 {out_dir}/...") save_jsonl(train, out_dir / "train.jsonl") save_jsonl(dev, out_dir / "dev.jsonl") save_jsonl(test, out_dir / "test.jsonl") all_samples = train + dev + test for i, s in enumerate(all_samples): s["final_id"] = f"crb-v5-{i:05d}" save_jsonl(all_samples, out_dir / "all.jsonl") print(f" 保存完成:train / dev / test / all") # 7. 统计报告 print(f"\n{'='*56}") print(f" 数据集统计报告") print(f"{'='*56}") print_stats("ALL (v5)", all_samples) print_stats("TRAIN", train) print_stats("DEV", dev) print_stats("TEST", test) coverage_check(all_samples) # 8. 语言 × 分割矩阵 print("\n 语言 × 分割分布:") print(f" {'':12} {'train':>8} {'dev':>8} {'test':>8} {'total':>8}") for lang in ("zh", "en"): row = [sum(1 for s in split if s.get("lang") == lang) for split in (train, dev, test)] print(f" {lang:12} {row[0]:>8} {row[1]:>8} {row[2]:>8} {sum(row):>8}") print(f"\n{'='*56}") print(f" 构建完成!总样本数: {len(all_samples)}") print(f" 输出目录: {out_dir.resolve()}") print(f"{'='*56}\n") if __name__ == "__main__": main()