Files
CompanionGuard-RL/code/scripts/merge_and_split.py
zhangsiyuan bd1f51c496 chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-14 11:28:42 +08:00

397 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
2026-05-11 数据集合并、过滤、划分脚本
功能:
1. 合并自建中文核心集 + 公开数据集改造结果
2. 质量过滤(结构完整性、最低长度、标签合法性)
3. SHA256 全局去重
4. 按 70/15/15 分层划分 train/dev/test
5. DICESsplit_hint=test_only强制进 test 集
6. 输出统计报告
用法(等 generated_core.jsonl 生成完毕后运行):
python scripts/merge_and_split.py
# 自定义路径
python scripts/merge_and_split.py \\
--core data/raw/generated_core.jsonl \\
--suicide data/raw/adapted_suicide.jsonl \\
--cosafe data/raw/adapted_cosafe.jsonl \\
--dices data/raw/adapted_dices.jsonl \\
--out-dir data/processed/CompanionRisk-Bench
"""
import argparse
import hashlib
import json
import random
from collections import Counter
from pathlib import Path
from typing import List, Dict, Tuple
# Seed and split proportions; TRAIN_RATIO, DEV_RATIO and RANDOM_SEED are the
# defaults of stratified_split(). TEST_RATIO is not referenced by the visible
# code: the test share is whatever remains after train and dev are cut.
RANDOM_SEED = 42
TRAIN_RATIO = 0.70
DEV_RATIO = 0.15
TEST_RATIO = 0.15

# Allowed label values.
# Primary categories R1..R10, plus both spellings of "no category".
VALID_C_PRIMARY = {f"R{i}" for i in range(1, 11)} | {"None", None}
VALID_ACTIONS = set("PASS WARN REWRITE REJECT CRISIS".split())
VALID_RISK_LEVELS = set(range(5))
VALID_C_FINE = {
    "DirectEncouragement",
    "MethodFacilitation",
    "RiskNormalization",
    "Romanticization",
    "DependencyReinforcement",
    "IsolationReinforcement",
    "FalseReassurance",
    "PseudoTherapy",
    "BoundaryFailure",
    "CrisisNonResponse",
    "CoRumination",
    "ManipulativeAttachment",
    "PrivacySolicitation",
    "AgeInappropriateIntimacy",
}
# ── Utility functions ──────────────────────────────────────────────────────────
def load_jsonl(path: Path) -> List[Dict]:
    """Read a JSON Lines file into a list of dict records.

    Args:
        path: File to read (UTF-8). A missing file is not an error: a notice
            is printed and an empty list is returned.

    Returns:
        The parsed JSON objects, in file order. Blank lines are ignored.
        Lines that are not valid JSON, or whose top-level value is not a JSON
        object, are skipped; the number skipped is reported once at the end.
    """
    if not path.exists():
        print(f" [跳过] 文件不存在: {path}")
        return []
    samples = []
    skipped = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Callers treat every record as a dict (setdefault / key access),
            # so a bare list, number or string must not be passed through.
            if not isinstance(record, dict):
                skipped += 1
                continue
            samples.append(record)
    if skipped:
        # Surface dropped lines instead of losing them silently.
        print(f" [警告] {path}: 跳过 {skipped} 行无法解析或非对象的记录")
    return samples
def save_jsonl(samples: List[Dict], path: Path):
    """Write ``samples`` to ``path`` as UTF-8 JSON Lines, one object per line.

    Missing parent directories are created and an existing file is
    overwritten. Non-ASCII characters are written verbatim rather than as
    ``\\uXXXX`` escapes.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in samples
        )
def fingerprint(s: Dict) -> str:
    """Return the SHA-256 hex digest used as a sample's de-duplication key.

    The key is built from the first 100 characters of ``user_input`` and the
    first 100 characters of ``ai_response`` (each defaulting to an empty
    string when the key is absent), joined by ``"|"`` and encoded as UTF-8.
    """
    user_head = s.get("user_input", "")[:100]
    reply_head = s.get("ai_response", "")[:100]
    payload = "|".join((user_head, reply_head))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
# ── Quality filtering ──────────────────────────────────────────────────────────
def _is_allowed(value, allowed) -> bool:
    """Membership test that treats unhashable values (list/dict) as not allowed."""
    try:
        return value in allowed
    except TypeError:
        return False


def quality_filter(samples: List[Dict]) -> Tuple[List[Dict], Dict[str, int]]:
    """Drop samples that fail structural or label checks.

    Checks run in order and a sample is counted under the first one it fails:
    record type, required fields, history length (a list with at least 2
    items), text type, minimum stripped text length (8 for ``user_input``,
    20 for ``ai_response``), ``y_risk`` in {0, 1}, then ``l_risk`` /
    ``c_primary`` / ``a_recommend`` against the module-level ``VALID_*`` sets.

    Lenient repairs, applied in place on the sample dict:
      * ``y_risk == 0`` with a real category: ``c_primary`` becomes ``"None"``
        and ``c_fine`` becomes ``[]`` instead of discarding the sample.
      * ``c_fine`` of ``None`` becomes ``[]``; unknown tags inside a
        ``c_fine`` list are removed rather than rejecting the sample.

    Args:
        samples: Raw sample records. Accepted records are mutated in place.

    Returns:
        ``(passed, reasons)`` where ``passed`` holds the accepted samples in
        input order and ``reasons`` maps a rejection reason to its count.
    """
    required = (
        "persona", "history", "user_input", "ai_response",
        "y_risk", "l_risk", "c_primary", "c_fine", "a_recommend",
    )
    reasons: Dict[str, int] = {}
    passed = []

    def reject(reason: str) -> None:
        reasons[reason] = reasons.get(reason, 0) + 1

    for s in samples:
        if not isinstance(s, dict):
            reject("not_an_object")
            continue
        # Required fields.
        if any(field not in s for field in required):
            reject("missing_fields")
            continue
        # Number of history turns.
        if not isinstance(s["history"], list) or len(s["history"]) < 2:
            reject("history_too_short")
            continue
        # Text fields must be strings before .strip() can be used on them.
        if not isinstance(s["user_input"], str) or not isinstance(s["ai_response"], str):
            reject("invalid_text")
            continue
        # Minimum text length.
        if len(s["user_input"].strip()) < 8:
            reject("user_input_too_short")
            continue
        if len(s["ai_response"].strip()) < 20:
            reject("ai_response_too_short")
            continue
        # y_risk must be binary: stratified_split() only keeps samples whose
        # y_risk equals 0 or 1, so anything else would vanish without a trace.
        if not _is_allowed(s["y_risk"], (0, 1)):
            reject("invalid_y_risk")
            continue
        # Label validity.
        if not _is_allowed(s["l_risk"], VALID_RISK_LEVELS):
            reject("invalid_l_risk")
            continue
        if not _is_allowed(s.get("c_primary"), VALID_C_PRIMARY):
            reject("invalid_c_primary")
            continue
        if not _is_allowed(s.get("a_recommend"), VALID_ACTIONS):
            reject("invalid_action")
            continue
        # Consistency: a safe sample should not carry a category.
        if s["y_risk"] == 0 and s.get("c_primary") not in (None, "None"):
            # Lenient: repair the labels instead of dropping the sample.
            s["c_primary"] = "None"
            s["c_fine"] = []
        # A risky sample must name a category.
        if s["y_risk"] == 1 and s.get("c_primary") in (None, "None"):
            reject("risky_no_category")
            continue
        # c_fine must end up as a list of known tags; downstream statistics
        # iterate over it.
        if s["c_fine"] is None:
            s["c_fine"] = []
        if not isinstance(s["c_fine"], list):
            reject("invalid_c_fine")
            continue
        # Lenient: strip unknown tags, keep the sample.
        s["c_fine"] = [t for t in s["c_fine"] if _is_allowed(t, VALID_C_FINE)]
        passed.append(s)
    return passed, reasons
# ── De-duplication ─────────────────────────────────────────────────────────────
def deduplicate(samples: List[Dict]) -> Tuple[List[Dict], int]:
    """Keep the first sample for each fingerprint.

    Returns:
        ``(unique, dup_count)``: the surviving samples in their original
        order, and how many later samples were discarded as duplicates.
    """
    first_seen: Dict[str, Dict] = {}
    total = 0
    for sample in samples:
        total += 1
        # setdefault keeps the earliest sample for a fingerprint; dicts
        # preserve insertion order, so first-occurrence order is retained.
        first_seen.setdefault(fingerprint(sample), sample)
    unique = list(first_seen.values())
    return unique, total - len(unique)
# ── Stratified split ───────────────────────────────────────────────────────────
def stratified_split(
    samples: List[Dict],
    test_only: List[Dict],
    train_ratio: float = TRAIN_RATIO,
    dev_ratio: float = DEV_RATIO,
    seed: int = RANDOM_SEED,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split ``samples`` into train/dev/test, stratified by ``y_risk``.

    Risky (``y_risk == 1``) and safe (``y_risk == 0``) samples are shuffled
    and cut separately so each split keeps roughly the same risky/safe mix.
    Train and dev sizes are truncated with ``int()``, so any remainder lands
    in test. ``test_only`` samples are appended to test unconditionally.

    A private ``random.Random(seed)`` is used, so the result is reproducible
    for a given seed and the module-level ``random`` state is left untouched.

    Args:
        samples: Pool to split.
        test_only: Samples that must appear only in the test split.
        train_ratio: Fraction of each stratum assigned to train.
        dev_ratio: Fraction of each stratum assigned to dev.
        seed: Seed for the shuffles.

    Returns:
        ``(train, dev, test)``, each shuffled.

    Raises:
        ValueError: If a ratio is negative or ``train_ratio + dev_ratio``
            exceeds 1.
    """
    # Small tolerance so float sums such as 0.7 + 0.3 are not rejected.
    if train_ratio < 0 or dev_ratio < 0 or train_ratio + dev_ratio > 1 + 1e-9:
        raise ValueError(
            f"invalid split ratios: train_ratio={train_ratio}, dev_ratio={dev_ratio}"
        )

    rng = random.Random(seed)
    risky = [s for s in samples if s.get("y_risk") == 1]
    safe = [s for s in samples if s.get("y_risk") == 0]

    # Samples in neither stratum end up in no split; say so instead of
    # losing them silently.
    unassigned = len(samples) - len(risky) - len(safe)
    if unassigned:
        print(f" [警告] {unassigned} 条样本的 y_risk 既不是 0 也不是 1,未进入任何划分")

    rng.shuffle(risky)
    rng.shuffle(safe)

    def split_list(lst):
        n = len(lst)
        n_train = int(n * train_ratio)
        n_dev = int(n * dev_ratio)
        return lst[:n_train], lst[n_train:n_train + n_dev], lst[n_train + n_dev:]

    tr_r, dv_r, te_r = split_list(risky)
    tr_s, dv_s, te_s = split_list(safe)
    train = tr_r + tr_s
    dev = dv_r + dv_s
    test = te_r + te_s + list(test_only)
    rng.shuffle(train)
    rng.shuffle(dev)
    rng.shuffle(test)
    return train, dev, test
# ── Statistics report ──────────────────────────────────────────────────────────
def print_stats(name: str, samples: List[Dict]):
    """Print a summary table for one named collection of samples.

    Reports the total with risky/safe counts, then the distributions of
    language, source, risk level, primary category, recommended action and,
    when any are present, the eight most common fine-grained tags. An empty
    collection prints a single "0" line and nothing else.
    """
    total = len(samples)
    if not total:
        print(f"\n [{name}] 0 条")
        return

    risky = len([s for s in samples if s.get("y_risk") == 1])
    safe = total - risky

    levels = Counter(s.get("l_risk", 0) for s in samples)
    # Only real categories are tallied; None / "None" mean "no category".
    categories = Counter(
        s["c_primary"] for s in samples if s.get("c_primary") not in (None, "None")
    )
    actions = Counter(s.get("a_recommend", "PASS") for s in samples)
    sources = Counter(s.get("source", "generated") for s in samples)
    languages = Counter(s.get("lang", "zh") for s in samples)
    # Fine-grained tag frequencies across all samples.
    fine_tags = Counter(tag for s in samples for tag in s.get("c_fine", []))

    rule = '' * 48
    print(f"\n{rule}")
    print(f"{name:<46}")
    print(rule)
    print(f"│ 总数 : {total:<6} (有风险={risky}, 安全={safe}){'':9}")
    print(f"│ 语言 : {dict(languages)}")
    print(f"│ 数据来源 : {dict(sources)}")
    print(f"│ 风险等级 : {dict(sorted(levels.items()))}")
    print(f"│ 一级类别 : {dict(sorted(categories.items()))}")
    print(f"│ 干预动作 : {dict(actions)}")
    if fine_tags:
        print(f"│ 细粒度标签(Top8): {dict(fine_tags.most_common(8))}")
    print(rule)
def coverage_check(
    samples: List[Dict],
    min_per_category: int = 50,
    min_per_fine: int = 30,
):
    """Print per-label sample counts and flag labels below a minimum.

    Primary categories (R1..R10) are counted over risky samples only
    (``y_risk == 1``); fine-grained tags are counted over every sample.
    A label whose count is below its threshold is marked as insufficient.

    Args:
        samples: Samples to inspect.
        min_per_category: Minimum count required per primary category.
        min_per_fine: Minimum count required per fine-grained tag.
    """
    all_cats = {f"R{i}" for i in range(1, 11)}
    # Same tag list as the module-level VALID_C_FINE set.
    all_fines = {
        "DirectEncouragement", "MethodFacilitation", "RiskNormalization",
        "Romanticization", "DependencyReinforcement", "IsolationReinforcement",
        "FalseReassurance", "PseudoTherapy", "BoundaryFailure", "CrisisNonResponse",
        "CoRumination", "ManipulativeAttachment", "PrivacySolicitation",
        "AgeInappropriateIntimacy",
    }
    cat_cnt = Counter(s.get("c_primary") for s in samples if s.get("y_risk") == 1)
    fine_cnt = Counter()
    for s in samples:
        # `or []` tolerates a missing key as well as an explicit None.
        fine_cnt.update(s.get("c_fine") or [])

    print("\n覆盖率检查:")
    print(f" 一级类别最少{min_per_category}条")
    for cat in sorted(all_cats):
        n = cat_cnt.get(cat, 0)
        ok = "" if n >= min_per_category else "✗ 不足"
        print(f" {cat}: {n:4d} {ok}")
    print(f" 细粒度标签最少{min_per_fine}条")
    for tag in sorted(all_fines):
        n = fine_cnt.get(tag, 0)
        ok = "" if n >= min_per_fine else "✗ 不足"
        print(f" {tag}: {n:3d} {ok}")
# ── Entry point ────────────────────────────────────────────────────────────────
def main():
    """Build the CompanionRisk-Bench splits from the raw JSONL subsets.

    Pipeline: load subsets -> tag source/lang defaults -> quality filter ->
    de-duplicate -> stratified split (DICES goes to test only) -> assign
    ``final_id`` -> write train/dev/test/all JSONL files -> print statistics.

    De-duplication covers DICES as well: duplicates inside DICES are removed,
    and main-pool samples whose fingerprint matches a DICES sample are dropped
    so the same text cannot sit in both train/dev and the test-only portion.
    ``final_id`` is assigned before any file is written, so the split files
    and ``all.jsonl`` carry the same ids.
    """
    parser = argparse.ArgumentParser(description="合并、过滤、划分 CompanionRisk-Bench 数据集")
    parser.add_argument("--core", default="data/raw/generated_core.jsonl")
    parser.add_argument("--targeted", default="data/raw/generated_targeted.jsonl",
                        help="弱标签专项补充集FalseReassurance/PseudoTherapy/IsolationReinforcement"
                             "可选,文件不存在时自动跳过")
    parser.add_argument("--suicide", default="data/raw/adapted_suicide.jsonl")
    parser.add_argument("--cosafe", default="data/raw/adapted_cosafe.jsonl")
    parser.add_argument("--dices", default="data/raw/adapted_dices.jsonl")
    parser.add_argument("--out-dir", default="data/processed/CompanionRisk-Bench")
    parser.add_argument(
        "--min-core", type=int, default=3000,
        help="自建核心集最小样本数不足时发出警告默认3000"
    )
    args = parser.parse_args()
    out_dir = Path(args.out_dir)

    print(f"\n{'='*52}")
    print(f" CompanionRisk-Bench 数据集构建")
    print(f"{'='*52}")

    # ── 1. Load every subset ──────────────────────────────────────────────────
    print("\n[1/5] 加载原始数据...")
    core = load_jsonl(Path(args.core))
    targeted = load_jsonl(Path(args.targeted))  # optional; [] when the file is missing
    suicide = load_jsonl(Path(args.suicide))
    cosafe = load_jsonl(Path(args.cosafe))
    dices = load_jsonl(Path(args.dices))
    print(f" 自建核心集 : {len(core):5d}")
    print(f" 弱标签专项 : {len(targeted):5d} 条 ({'已加载' if targeted else '文件不存在,跳过'})")
    print(f" Suicide 改造 : {len(suicide):5d}")
    print(f" CoSafe 改造 : {len(cosafe):5d}")
    print(f" DICES (test) : {len(dices):5d}")
    if len(core) < args.min_core:
        print(f"\n ⚠ 警告:自建核心集只有 {len(core)} 条,未达到 {args.min_core} 条目标。")
        print(f" 建议等 generate_siliconflow.py 运行完毕后再执行此脚本。")

    # ── 2. Tag source / language (existing values are kept) ───────────────────
    subset_defaults = (
        (core, "generated", "zh"),
        (targeted, "generated", "zh"),  # same origin as the core set, counted together
        (suicide, "suicide_risk", "en"),
        (cosafe, "cosafe", "en"),
        (dices, "dices", "en"),
    )
    for subset, source, lang in subset_defaults:
        for s in subset:
            s.setdefault("source", source)
            s.setdefault("lang", lang)

    # DICES is forced into the test split.
    test_only = dices
    main_pool = core + targeted + suicide + cosafe
    print(f"\n 主池总量 : {len(main_pool)}core + targeted + suicide + cosafe")
    print(f" Test-only池 : {len(test_only)}DICES直接入 test")

    # ── 3. Quality filter ─────────────────────────────────────────────────────
    print("\n[2/5] 质量过滤...")
    filtered, reasons = quality_filter(main_pool)
    filtered_dices, dices_reasons = quality_filter(test_only)
    total_filtered = len(main_pool) - len(filtered)
    print(f" 主池过滤前: {len(main_pool)} → 过滤后: {len(filtered)} (丢弃 {total_filtered} 条)")
    if reasons:
        for k, v in sorted(reasons.items(), key=lambda x: -x[1]):
            print(f" {k}: {v}")
    # Report DICES drops too, so test-only losses are visible.
    dices_dropped = len(test_only) - len(filtered_dices)
    if dices_dropped:
        print(f" DICES过滤前: {len(test_only)} → 过滤后: {len(filtered_dices)} (丢弃 {dices_dropped} 条)")
        for k, v in sorted(dices_reasons.items(), key=lambda x: -x[1]):
            print(f" {k}: {v}")

    # ── 4. Global de-duplication ──────────────────────────────────────────────
    print("\n[3/5] 全局去重...")
    unique, dup_count = deduplicate(filtered)
    print(f" 去重前: {len(filtered)} → 去重后: {len(unique)} (去除 {dup_count} 条重复)")
    unique_dices, dices_dup = deduplicate(filtered_dices)
    # A main-pool sample that duplicates a DICES sample is removed from the
    # main pool, keeping the test-only copy, so train/dev cannot leak into test.
    dices_fps = {fingerprint(s) for s in unique_dices}
    leak_free = [s for s in unique if fingerprint(s) not in dices_fps]
    overlap = len(unique) - len(leak_free)
    unique = leak_free
    if dices_dup or overlap:
        print(f" DICES去重: 去除 {dices_dup} 条重复; 主池中与DICES重复的 {overlap} 条已移除")

    # ── 5. Stratified split ───────────────────────────────────────────────────
    print("\n[4/5] 分层划分 (train:dev:test = 70:15:15)...")
    train, dev, test = stratified_split(unique, unique_dices)
    print(f" train: {len(train)}")
    print(f" dev : {len(dev)}")
    print(f" test : {len(test)}")

    # Assign ids before writing anything: the split lists share their dict
    # objects with all_samples, so every output file carries the same ids.
    all_samples = train + dev + test
    for i, s in enumerate(all_samples):
        s["final_id"] = f"crb-{i:05d}"

    # ── 6. Save ───────────────────────────────────────────────────────────────
    print(f"\n[5/5] 保存到 {out_dir}/...")
    save_jsonl(train, out_dir / "train.jsonl")
    save_jsonl(dev, out_dir / "dev.jsonl")
    save_jsonl(test, out_dir / "test.jsonl")
    # Also save the complete gold set (all data, with annotations).
    save_jsonl(all_samples, out_dir / "all.jsonl")
    print(f" 保存完成train.jsonl / dev.jsonl / test.jsonl / all.jsonl")

    # ── 7. Statistics report ──────────────────────────────────────────────────
    print("\n" + "="*52)
    print(" 数据集统计报告")
    print("="*52)
    print_stats("ALL", all_samples)
    print_stats("TRAIN", train)
    print_stats("DEV", dev)
    print_stats("TEST", test)
    coverage_check(all_samples)
    print(f"\n{'='*52}")
    print(f" 构建完成!总样本数: {len(all_samples)}")
    print(f" 输出目录: {out_dir.resolve()}")
    print(f"{'='*52}\n")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()