Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
397 lines
16 KiB
Python
397 lines
16 KiB
Python
"""
|
||
2026-05-11 Dataset merge, filter and split script.

What it does:
  1. Merges the self-built Chinese core set with the adapted public datasets
  2. Quality filtering (structural completeness, minimum length, label validity)
  3. Global SHA256 de-duplication
  4. Stratified 70/15/15 split into train/dev/test
  5. DICES (split_hint=test_only) is forced into the test set
  6. Prints a statistics report

Usage (run once generated_core.jsonl has been fully generated):
  python scripts/merge_and_split.py

  # Custom paths
  python scripts/merge_and_split.py \\
    --core data/raw/generated_core.jsonl \\
    --suicide data/raw/adapted_suicide.jsonl \\
    --cosafe data/raw/adapted_cosafe.jsonl \\
    --dices data/raw/adapted_dices.jsonl \\
    --out-dir data/processed/CompanionRisk-Bench
|
||
"""
|
||
|
||
import argparse
|
||
import hashlib
|
||
import json
|
||
import random
|
||
from collections import Counter
|
||
from pathlib import Path
|
||
from typing import List, Dict, Tuple
|
||
|
||
|
||
# Seed and proportions used by the stratified train/dev/test split.
RANDOM_SEED = 42
TRAIN_RATIO = 0.70
DEV_RATIO = 0.15
TEST_RATIO = 0.15

# Label vocabularies checked by the quality filter.
# Both the string "None" and a real None mean "no primary category".
VALID_C_PRIMARY = {f"R{i}" for i in range(1, 11)} | {"None", None}
VALID_ACTIONS = {"PASS", "WARN", "REWRITE", "REJECT", "CRISIS"}
VALID_RISK_LEVELS = set(range(5))
VALID_C_FINE = {
    "DirectEncouragement",
    "MethodFacilitation",
    "RiskNormalization",
    "Romanticization",
    "DependencyReinforcement",
    "IsolationReinforcement",
    "FalseReassurance",
    "PseudoTherapy",
    "BoundaryFailure",
    "CrisisNonResponse",
    "CoRumination",
    "ManipulativeAttachment",
    "PrivacySolicitation",
    "AgeInappropriateIntimacy",
}
|
||
|
||
|
||
# ── 工具函数 ───────────────────────────────────────────────────────────────────
|
||
|
||
def load_jsonl(path: Path) -> List[Dict]:
    """Load a JSON Lines file into a list of dict records.

    Blank lines are ignored. Lines that are not valid JSON, or whose top-level
    value is not a JSON object, are skipped; the number of skipped lines is
    printed so that data loss is visible instead of silent.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The parsed records in file order, or an empty list (after printing a
        notice) when ``path`` does not exist.
    """
    if not path.exists():
        print(f" [跳过] 文件不存在: {path}")
        return []

    samples: List[Dict] = []
    skipped = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # The rest of the pipeline calls dict methods on every record,
            # so valid JSON that is not an object is rejected here as well.
            if not isinstance(record, dict):
                skipped += 1
                continue
            samples.append(record)

    if skipped:
        print(f" [警告] {path}: 跳过 {skipped} 行无法解析或非对象的记录")
    return samples
|
||
|
||
|
||
def save_jsonl(samples: List[Dict], path: Path):
    """Write ``samples`` to ``path`` as JSON Lines, one object per line.

    Missing parent directories are created and an existing file is
    overwritten. The file is UTF-8 encoded and non-ASCII characters are
    written verbatim (``ensure_ascii=False``).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in samples
        )
|
||
|
||
|
||
def fingerprint(s: Dict) -> str:
    """Return the SHA256 hex digest used as the de-duplication key of a sample.

    The key is built from the first 100 characters of ``user_input`` and the
    first 100 characters of ``ai_response`` joined by ``"|"``, so two samples
    that differ only after those prefixes share a fingerprint.

    Missing or ``None`` fields count as empty strings and other non-string
    values are converted with ``str``, so malformed records do not raise.
    For string fields the digest is the same as before.

    Args:
        s: Sample dict; only ``user_input`` and ``ai_response`` are read.

    Returns:
        64-character lowercase hexadecimal digest.
    """
    def _prefix(key: str) -> str:
        value = s.get(key)
        if value is None:
            return ""
        if not isinstance(value, str):
            value = str(value)
        return value[:100]

    raw = _prefix("user_input") + "|" + _prefix("ai_response")
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||
|
||
|
||
# ── 质量过滤 ───────────────────────────────────────────────────────────────────
|
||
|
||
def quality_filter(samples: List[Dict]) -> Tuple[List[Dict], Dict[str, int]]:
    """Drop samples that fail structural or label checks.

    Checks run in the order below and a sample is rejected on the first one
    that fails; the reason key on the right is incremented:

      1. every required field is present         -> ``missing_fields``
      2. ``history`` is a list with >= 2 entries  -> ``history_too_short``
      3. both text fields are strings             -> ``invalid_text_type``
      4. stripped ``user_input`` has >= 8 chars   -> ``user_input_too_short``
      5. stripped ``ai_response`` has >= 20 chars -> ``ai_response_too_short``
      6. ``l_risk`` in ``VALID_RISK_LEVELS``      -> ``invalid_l_risk``
      7. ``c_primary`` in ``VALID_C_PRIMARY``     -> ``invalid_c_primary``
      8. ``a_recommend`` in ``VALID_ACTIONS``     -> ``invalid_action``
      9. ``y_risk`` equals 0 or 1                 -> ``invalid_y_risk``
     10. risky samples carry a category           -> ``risky_no_category``

    Lenient repairs are applied to surviving samples instead of dropping them
    (the sample dicts are modified in place):

      * ``y_risk == 0`` with a real category: ``c_primary`` becomes ``"None"``
        and ``c_fine`` becomes ``[]``.
      * ``c_fine`` is normalised to a list (a bare string becomes a one-item
        list, any other non-list becomes ``[]``) and tags outside
        ``VALID_C_FINE`` are removed.

    Args:
        samples: Raw sample dicts.

    Returns:
        ``(passed, reasons)`` where ``passed`` keeps the input order and
        ``reasons`` maps each rejection reason to its count.
    """
    required = ("persona", "history", "user_input", "ai_response",
                "y_risk", "l_risk", "c_primary", "c_fine", "a_recommend")

    def _allowed(value, vocabulary) -> bool:
        # Unhashable values (lists, dicts) can never be valid labels; report
        # them as invalid instead of letting the set lookup raise TypeError.
        try:
            return value in vocabulary
        except TypeError:
            return False

    def _reject_reason(s):
        if not isinstance(s, dict) or any(field not in s for field in required):
            return "missing_fields"
        if not isinstance(s["history"], list) or len(s["history"]) < 2:
            return "history_too_short"
        if not isinstance(s["user_input"], str) or not isinstance(s["ai_response"], str):
            return "invalid_text_type"
        if len(s["user_input"].strip()) < 8:
            return "user_input_too_short"
        if len(s["ai_response"].strip()) < 20:
            return "ai_response_too_short"
        if not _allowed(s["l_risk"], VALID_RISK_LEVELS):
            return "invalid_l_risk"
        if not _allowed(s["c_primary"], VALID_C_PRIMARY):
            return "invalid_c_primary"
        if not _allowed(s["a_recommend"], VALID_ACTIONS):
            return "invalid_action"
        # The stratified split only recognises y_risk 0 and 1; anything else
        # would otherwise vanish from every split without a trace.
        if s["y_risk"] not in (0, 1):
            return "invalid_y_risk"
        if s["y_risk"] == 1 and s["c_primary"] in (None, "None"):
            return "risky_no_category"
        return None

    reasons: Counter = Counter()
    passed: List[Dict] = []

    for s in samples:
        reason = _reject_reason(s)
        if reason is not None:
            reasons[reason] += 1
            continue

        # Consistency repair: a safe sample must not carry a risk category.
        if s["y_risk"] == 0 and s["c_primary"] not in (None, "None"):
            s["c_primary"] = "None"
            s["c_fine"] = []

        fine = s["c_fine"]
        if isinstance(fine, str):
            fine = [fine]
        elif not isinstance(fine, list):
            fine = []
        s["c_fine"] = [tag for tag in fine if _allowed(tag, VALID_C_FINE)]

        passed.append(s)

    return passed, dict(reasons)
|
||
|
||
|
||
# ── 去重 ───────────────────────────────────────────────────────────────────────
|
||
|
||
def deduplicate(samples: List[Dict]) -> Tuple[List[Dict], int]:
    """Remove samples whose fingerprint was already seen.

    The first occurrence of each fingerprint is kept and input order is
    preserved.

    Returns:
        ``(unique, dup_count)``: the surviving samples and the number of
        samples that were dropped as duplicates.
    """
    known = set()
    kept = []
    dropped = 0
    for sample in samples:
        key = fingerprint(sample)
        if key not in known:
            known.add(key)
            kept.append(sample)
            continue
        dropped += 1
    return kept, dropped
|
||
|
||
|
||
# ── 分层划分 ───────────────────────────────────────────────────────────────────
|
||
|
||
def stratified_split(
    samples: List[Dict],
    test_only: List[Dict],
    train_ratio: float = TRAIN_RATIO,
    dev_ratio: float = DEV_RATIO,
    seed: int = RANDOM_SEED,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split ``samples`` into train/dev/test, stratified by ``y_risk``.

    Each stratum is shuffled and cut independently: the first
    ``int(n * train_ratio)`` items go to train, the next
    ``int(n * dev_ratio)`` to dev and the remainder to test. Every
    ``test_only`` sample is appended to the test set. Each split is shuffled
    once more before being returned.

    Strata are ``y_risk == 1``, ``y_risk == 0`` and, so that no sample is
    silently lost, a third stratum with every other ``y_risk`` value. When
    that third stratum is empty the result is identical to the previous
    two-stratum behaviour for the same seed.

    A private ``random.Random(seed)`` is used, so the caller's global random
    state is left untouched.

    Args:
        samples: Pool to be split.
        test_only: Samples that must end up in the test set.
        train_ratio: Fraction of each stratum assigned to train.
        dev_ratio: Fraction of each stratum assigned to dev.
        seed: Seed for all shuffles.

    Returns:
        ``(train, dev, test)`` as new lists; the inputs are not modified.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1.
    """
    # Small tolerance so ratio pairs that sum to 1 up to float rounding pass.
    if train_ratio < 0 or dev_ratio < 0 or train_ratio + dev_ratio > 1 + 1e-9:
        raise ValueError(
            f"invalid split ratios: train_ratio={train_ratio}, dev_ratio={dev_ratio}"
        )

    rng = random.Random(seed)

    risky = [s for s in samples if s.get("y_risk") == 1]
    safe = [s for s in samples if s.get("y_risk") == 0]
    other = [s for s in samples if s.get("y_risk") not in (0, 1)]

    # Shuffle order (risky, safe, other) is fixed: it determines which random
    # numbers each stratum consumes and therefore the reproducible result.
    strata = (risky, safe, other)
    for stratum in strata:
        rng.shuffle(stratum)

    train: List[Dict] = []
    dev: List[Dict] = []
    test: List[Dict] = []
    for stratum in strata:
        n_train = int(len(stratum) * train_ratio)
        n_dev = int(len(stratum) * dev_ratio)
        train.extend(stratum[:n_train])
        dev.extend(stratum[n_train:n_train + n_dev])
        test.extend(stratum[n_train + n_dev:])
    test.extend(test_only)

    rng.shuffle(train)
    rng.shuffle(dev)
    rng.shuffle(test)

    return train, dev, test
|
||
|
||
|
||
# ── 统计报告 ───────────────────────────────────────────────────────────────────
|
||
|
||
def print_stats(name: str, samples: List[Dict]):
    """Print a boxed summary of one dataset split to stdout.

    The report lists the total size with the risky/safe breakdown, then the
    distributions of language, source, risk level, primary category and
    recommended action, and finally the eight most frequent fine-grained
    tags (only when at least one tag is present).

    Risk levels and categories are shown in natural order (``R2`` before
    ``R10``). A ``c_fine`` value that is not a list is ignored instead of
    crashing or being counted character by character.

    Args:
        name: Title shown in the box header.
        samples: Samples of the split; an empty list prints a one-line notice.
    """
    total = len(samples)
    if total == 0:
        print(f"\n [{name}] 0 条")
        return

    def _label_order(item):
        # Sort by length first so "R10" follows "R9"; going through str()
        # keeps keys of mixed types comparable.
        text = str(item[0])
        return len(text), text

    risky = sum(1 for s in samples if s.get("y_risk") == 1)
    safe = total - risky

    lvl_cnt = Counter(s.get("l_risk", 0) for s in samples)
    cat_cnt = Counter(
        s.get("c_primary", "None") for s in samples if s.get("c_primary") not in (None, "None")
    )
    act_cnt = Counter(s.get("a_recommend", "PASS") for s in samples)
    src_cnt = Counter(s.get("source", "generated") for s in samples)
    lang_cnt = Counter(s.get("lang", "zh") for s in samples)

    # Fine-grained tag frequencies; only well-formed tag lists are counted.
    fine_cnt = Counter()
    for s in samples:
        tags = s.get("c_fine")
        if isinstance(tags, list):
            fine_cnt.update(tag for tag in tags if isinstance(tag, str))

    print(f"\n┌{'─'*48}┐")
    print(f"│ {name:<46} │")
    print(f"├{'─'*48}┤")
    print(f"│ 总数 : {total:<6} (有风险={risky}, 安全={safe}){'':9}│")
    print(f"│ 语言 : {dict(lang_cnt)}")
    print(f"│ 数据来源 : {dict(src_cnt)}")
    print(f"│ 风险等级 : {dict(sorted(lvl_cnt.items(), key=_label_order))}")
    print(f"│ 一级类别 : {dict(sorted(cat_cnt.items(), key=_label_order))}")
    print(f"│ 干预动作 : {dict(act_cnt)}")
    if fine_cnt:
        print(f"│ 细粒度标签(Top8): {dict(fine_cnt.most_common(8))}")
    print(f"└{'─'*48}┘")
|
||
|
||
|
||
def coverage_check(samples: List[Dict], min_per_category: int = 50, min_per_fine: int = 30):
    """Print how many samples cover each primary category and fine-grained tag.

    Primary categories ``R1``..``R10`` are counted over risky samples only
    (``y_risk == 1``); fine-grained tags are counted over all samples. Each
    line is marked with a check mark when the count reaches its threshold
    and flagged as insufficient otherwise. Categories are listed in natural
    order (``R1``, ``R2``, ..., ``R10``).

    Args:
        samples: Samples to inspect.
        min_per_category: Minimum count required per primary category.
        min_per_fine: Minimum count required per fine-grained tag.
    """
    all_cats = [f"R{i}" for i in range(1, 11)]
    # Same members as the module-level VALID_C_FINE; keep the two in sync.
    all_fines = sorted({
        "DirectEncouragement", "MethodFacilitation", "RiskNormalization",
        "Romanticization", "DependencyReinforcement", "IsolationReinforcement",
        "FalseReassurance", "PseudoTherapy", "BoundaryFailure", "CrisisNonResponse",
        "CoRumination", "ManipulativeAttachment", "PrivacySolicitation",
        "AgeInappropriateIntimacy",
    })

    # Only string labels can match the names looked up below; skipping the
    # rest also keeps unhashable values from breaking the counters.
    cat_cnt = Counter(
        s.get("c_primary")
        for s in samples
        if s.get("y_risk") == 1 and isinstance(s.get("c_primary"), str)
    )
    fine_cnt = Counter()
    for s in samples:
        tags = s.get("c_fine")
        if isinstance(tags, list):
            fine_cnt.update(tag for tag in tags if isinstance(tag, str))

    print("\n覆盖率检查:")
    print(f" 一级类别(最少{min_per_category}条):")
    for cat in all_cats:
        n = cat_cnt.get(cat, 0)
        ok = "✓" if n >= min_per_category else "✗ 不足"
        print(f" {cat}: {n:4d} {ok}")

    print(f" 细粒度标签(最少{min_per_fine}条):")
    for tag in all_fines:
        n = fine_cnt.get(tag, 0)
        ok = "✓" if n >= min_per_fine else "✗ 不足"
        print(f" {tag}: {n:3d} {ok}")
|
||
|
||
|
||
# ── 主入口 ────────────────────────────────────────────────────────────────────
|
||
|
||
def main():
    """Command-line entry point: build the CompanionRisk-Bench splits.

    Steps:
      1. Load the core, targeted, suicide, cosafe and DICES ``.jsonl`` files
         (missing files yield empty lists).
      2. Fill in default ``source`` / ``lang`` fields per subset.
      3. Quality-filter the main pool and the DICES pool separately.
      4. De-duplicate globally. DICES is de-duplicated first and placed ahead
         of the main pool, so a main-pool sample that collides with a DICES
         sample is dropped and cannot leak a test item into train or dev.
      5. Stratified split of the main pool; DICES goes straight to test.
      6. Assign ``final_id`` to every sample, then write ``train.jsonl``,
         ``dev.jsonl``, ``test.jsonl`` and ``all.jsonl`` so that all four
         files carry the same IDs.
      7. Print per-split statistics and the coverage check.
    """
    parser = argparse.ArgumentParser(description="合并、过滤、划分 CompanionRisk-Bench 数据集")
    parser.add_argument("--core", default="data/raw/generated_core.jsonl")
    parser.add_argument("--targeted", default="data/raw/generated_targeted.jsonl",
                        help="弱标签专项补充集(FalseReassurance/PseudoTherapy/IsolationReinforcement),"
                             "可选,文件不存在时自动跳过")
    parser.add_argument("--suicide", default="data/raw/adapted_suicide.jsonl")
    parser.add_argument("--cosafe", default="data/raw/adapted_cosafe.jsonl")
    parser.add_argument("--dices", default="data/raw/adapted_dices.jsonl")
    parser.add_argument("--out-dir", default="data/processed/CompanionRisk-Bench")
    parser.add_argument(
        "--min-core", type=int, default=3000,
        help="自建核心集最小样本数,不足时发出警告(默认3000)"
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)

    def _print_reasons(reasons):
        # Most frequent rejection reason first.
        for k, v in sorted(reasons.items(), key=lambda x: -x[1]):
            print(f" {k}: {v}")

    print(f"\n{'='*52}")
    print(f" CompanionRisk-Bench 数据集构建")
    print(f"{'='*52}")

    # ── 1. Load every subset ────────────────────────────────────────────────
    print("\n[1/5] 加载原始数据...")

    core = load_jsonl(Path(args.core))
    targeted = load_jsonl(Path(args.targeted))  # optional; [] when the file is missing
    suicide = load_jsonl(Path(args.suicide))
    cosafe = load_jsonl(Path(args.cosafe))
    dices = load_jsonl(Path(args.dices))

    print(f" 自建核心集 : {len(core):5d} 条")
    print(f" 弱标签专项 : {len(targeted):5d} 条 ({'已加载' if targeted else '文件不存在,跳过'})")
    print(f" Suicide 改造 : {len(suicide):5d} 条")
    print(f" CoSafe 改造 : {len(cosafe):5d} 条")
    print(f" DICES (test) : {len(dices):5d} 条")

    if len(core) < args.min_core:
        print(f"\n ⚠ 警告:自建核心集只有 {len(core)} 条,未达到 {args.min_core} 条目标。")
        print(f" 建议等 generate_siliconflow.py 运行完毕后再执行此脚本。")

    # ── 2. Tag provenance (existing values win) ─────────────────────────────
    for subset, source, lang in (
        (core, "generated", "zh"),
        (targeted, "generated", "zh"),  # same origin as core, counted together
        (suicide, "suicide_risk", "en"),
        (cosafe, "cosafe", "en"),
        (dices, "dices", "en"),
    ):
        for s in subset:
            s.setdefault("source", source)
            s.setdefault("lang", lang)

    # DICES is forced into the test set.
    test_only = dices
    main_pool = core + targeted + suicide + cosafe

    print(f"\n 主池总量 : {len(main_pool)} 条(core + targeted + suicide + cosafe)")
    print(f" Test-only池 : {len(test_only)} 条(DICES,直接入 test)")

    # ── 3. Quality filter ───────────────────────────────────────────────────
    print("\n[2/5] 质量过滤...")
    filtered, reasons = quality_filter(main_pool)
    filtered_dices, dices_reasons = quality_filter(test_only)

    total_filtered = len(main_pool) - len(filtered)
    print(f" 主池过滤前: {len(main_pool)} → 过滤后: {len(filtered)} (丢弃 {total_filtered} 条)")
    if reasons:
        _print_reasons(reasons)

    dices_dropped = len(test_only) - len(filtered_dices)
    print(f" DICES 过滤前: {len(test_only)} → 过滤后: {len(filtered_dices)} (丢弃 {dices_dropped} 条)")
    if dices_reasons:
        _print_reasons(dices_reasons)

    # ── 4. Global de-duplication ────────────────────────────────────────────
    print("\n[3/5] 全局去重...")
    unique_dices, dices_dups = deduplicate(filtered_dices)
    # unique_dices has no internal duplicates, so all of it survives as the
    # prefix of `merged`; everything after that prefix is the main pool with
    # both its internal duplicates and its collisions with DICES removed.
    merged, dup_count = deduplicate(unique_dices + filtered)
    unique = merged[len(unique_dices):]
    print(f" 去重前: {len(filtered)} → 去重后: {len(unique)} (去除 {dup_count} 条重复)")
    print(f" DICES 去重前: {len(filtered_dices)} → 去重后: {len(unique_dices)} (去除 {dices_dups} 条重复)")

    # ── 5. Stratified split ─────────────────────────────────────────────────
    print("\n[4/5] 分层划分 (train:dev:test = 70:15:15)...")
    train, dev, test = stratified_split(unique, unique_dices)
    print(f" train: {len(train)}")
    print(f" dev : {len(dev)}")
    print(f" test : {len(test)}")

    # ── 6. Save ─────────────────────────────────────────────────────────────
    print(f"\n[5/5] 保存到 {out_dir}/...")

    # IDs are assigned before anything is written: the split lists share their
    # dict objects with all_samples, so every output file gets the same IDs.
    all_samples = train + dev + test
    for i, s in enumerate(all_samples):
        s["final_id"] = f"crb-{i:05d}"

    save_jsonl(train, out_dir / "train.jsonl")
    save_jsonl(dev, out_dir / "dev.jsonl")
    save_jsonl(test, out_dir / "test.jsonl")
    save_jsonl(all_samples, out_dir / "all.jsonl")

    print(f" 保存完成:train.jsonl / dev.jsonl / test.jsonl / all.jsonl")

    # ── 7. Statistics report ────────────────────────────────────────────────
    print("\n" + "="*52)
    print(" 数据集统计报告")
    print("="*52)
    print_stats("ALL", all_samples)
    print_stats("TRAIN", train)
    print_stats("DEV", dev)
    print_stats("TEST", test)

    coverage_check(all_samples)

    print(f"\n{'='*52}")
    print(f" 构建完成!总样本数: {len(all_samples)}")
    print(f" 输出目录: {out_dir.resolve()}")
    print(f"{'='*52}\n")
|
||
|
||
|
||
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|