- paper/: 22-page LaTeX framework (7/10 sections complete, compiles cleanly) main.tex + 10 section files + refs.bib + compiled PDF (329KB) - code/scripts/: three English dataset generation & merging scripts generate_english.py / generate_english_targeted.py / merge_v5.py - CLAUDE.md: update paper writing status, add paper/ file map entry - state.md: add section 8 paper writing progress (2026-05-15) - .gitignore: add LaTeX build artifact exclusion rules Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
324 lines
12 KiB
Python
324 lines
12 KiB
Python
"""
|
||
合并 CompanionRisk-Bench v5
|
||
|
||
来源:
|
||
- 中文 v4 (data/processed/CompanionRisk-Bench-v4-backup/all.jsonl) 9,896 条(已质检)
|
||
- 英文 core (data/raw/generated_english_core.jsonl) 12,000 条
|
||
- 英文 targeted (data/raw/generated_english_targeted.jsonl) 3,000 条
|
||
|
||
输出:data/processed/CompanionRisk-Bench-v5/{train,dev,test,all}.jsonl
|
||
|
||
用法:
|
||
cd code/
|
||
python scripts/merge_v5.py
|
||
"""
|
||
|
||
import argparse
|
||
import hashlib
|
||
import json
|
||
import random
|
||
from collections import Counter
|
||
from pathlib import Path
|
||
from typing import List, Dict, Tuple
|
||
|
||
|
||
# Seed for the stratified train/dev/test split, so the split is reproducible.
RANDOM_SEED = 42
# Per-stratum split proportions; the test split receives whatever remains (~15%).
TRAIN_RATIO = 0.70
DEV_RATIO = 0.15

# Allowed primary risk categories R1..R10. Both the string "None" and the
# Python value None are accepted to mean "no category".
VALID_C_PRIMARY = {f"R{idx}" for idx in range(1, 11)} | {"None", None}
# Allowed recommended intervention actions.
VALID_ACTIONS = {"PASS", "WARN", "REWRITE", "REJECT", "CRISIS"}
# Allowed integer risk levels 0..4.
VALID_RISK_LEVELS = set(range(5))
# Allowed fine-grained risk tags for the c_fine field.
VALID_C_FINE = {
    "DirectEncouragement",
    "MethodFacilitation",
    "RiskNormalization",
    "Romanticization",
    "DependencyReinforcement",
    "IsolationReinforcement",
    "FalseReassurance",
    "PseudoTherapy",
    "BoundaryFailure",
    "CrisisNonResponse",
    "CoRumination",
    "ManipulativeAttachment",
    "PrivacySolicitation",
    "AgeInappropriateIntimacy",
}
|
||
|
||
|
||
# ── I/O ───────────────────────────────────────────────────────────────────────
|
||
|
||
def load_jsonl(path: Path) -> List[Dict]:
    """Load a JSON Lines file into a list of records.

    A missing file is reported and yields an empty list, so optional sources
    can simply be absent. Blank lines are ignored. Lines that are not valid
    JSON, or that decode to something other than a JSON object, are skipped,
    and the number skipped is reported instead of being dropped silently.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The decoded records (one dict per valid line), in file order.
    """
    if not path.exists():
        print(f" [跳过] 文件不存在: {path}")
        return []

    samples: List[Dict] = []
    bad_lines = 0
    with open(path, encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                bad_lines += 1
                continue
            # Callers treat every sample as a dict (.get / .setdefault / []),
            # so a line holding a list, number or string would crash later.
            if not isinstance(record, dict):
                bad_lines += 1
                continue
            samples.append(record)

    if bad_lines:
        print(f" [警告] 跳过 {bad_lines} 行无效记录: {path}")
    return samples
|
||
|
||
|
||
def save_jsonl(samples: List[Dict], path: Path):
    """Write *samples* to *path* as JSON Lines (UTF-8, one object per line).

    Missing parent directories are created; an existing file is overwritten.
    Non-ASCII text is written as-is rather than as ``\\u`` escapes.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in samples
        )
|
||
|
||
|
||
# ── 去重 ──────────────────────────────────────────────────────────────────────
|
||
|
||
def fingerprint(s: Dict, prefix_len: int = 100) -> str:
    """Return a SHA-256 hex digest identifying a sample for deduplication.

    The key is built from the first *prefix_len* characters of ``user_input``
    and of ``ai_response``. Samples that agree on both prefixes are therefore
    treated as duplicates even if they differ further on.

    Args:
        s: Sample dict. A missing or ``None`` text field counts as ``""``;
            a non-string value is converted with ``str()`` instead of raising.
        prefix_len: Number of leading characters taken from each field.
            ``None`` uses the full text.

    Returns:
        A 64-character lowercase hexadecimal digest.
    """
    def _prefix(key: str) -> str:
        value = s.get(key)
        if value is None:
            return ""
        if not isinstance(value, str):
            value = str(value)
        return value[:prefix_len]

    # JSON-encode the pair so the two fields cannot run into each other:
    # a plain "a|b" join made ("ab|", "c") and ("ab", "|c") collide.
    raw = json.dumps([_prefix("user_input"), _prefix("ai_response")], ensure_ascii=False)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||
|
||
|
||
def deduplicate(samples: List[Dict]) -> Tuple[List[Dict], int]:
    """Drop samples whose ``fingerprint`` was already seen, keeping first occurrences.

    Returns:
        ``(unique, dups)``: the retained samples in their original order, and
        the number of samples discarded as duplicates.
    """
    kept: List[Dict] = []
    known: set = set()
    duplicates = 0
    for sample in samples:
        key = fingerprint(sample)
        if key not in known:
            known.add(key)
            kept.append(sample)
            continue
        duplicates += 1
    return kept, duplicates
|
||
|
||
|
||
# ── 质量过滤(仅用于新增英文数据;中文 v4 已质检) ───────────────────────────
|
||
|
||
def _is_allowed(value, allowed) -> bool:
    """Return True if *value* is a member of *allowed*; unhashable values never are."""
    try:
        return value in allowed
    except TypeError:
        # LLM output sometimes puts a list/dict where a scalar label belongs;
        # set membership on such a value raises instead of returning False.
        return False


def quality_filter(samples: List[Dict]) -> Tuple[List[Dict], Dict[str, int]]:
    """Drop malformed generated samples and normalise the ones that survive.

    Intended for newly generated (English) data only. Each rejected sample
    increments a counter keyed by the first check it fails. Surviving samples
    are modified in place:

    * when ``y_risk == 0``, any category label is cleared
      (``c_primary = "None"``, ``c_fine = []``);
    * ``c_fine`` is coerced to a list (a bare string becomes a one-element
      list, any other non-list value becomes ``[]``), and tags outside
      ``VALID_C_FINE`` are removed rather than rejecting the whole sample.

    Args:
        samples: Raw sample dicts.

    Returns:
        ``(passed, reasons)``: the surviving samples in input order, and a
        mapping from rejection reason to count.
    """
    reasons: Dict[str, int] = {}
    passed: List[Dict] = []

    def reject(reason: str) -> None:
        reasons[reason] = reasons.get(reason, 0) + 1

    required = ("persona", "history", "user_input", "ai_response",
                "y_risk", "l_risk", "c_primary", "c_fine", "a_recommend")
    # y_risk is a binary flag: 0 = safe, 1 = risky (see the consistency rules below).
    valid_y_risk = {0, 1}

    for s in samples:
        # Required fields
        if any(field not in s for field in required):
            reject("missing_fields")
            continue

        # Type check (the LLM sometimes returns a list instead of a str)
        if not isinstance(s["user_input"], str) or not isinstance(s["ai_response"], str):
            reject("wrong_type")
            continue

        # The dialogue history must be a list with at least two entries
        if not isinstance(s["history"], list) or len(s["history"]) < 2:
            reject("history_too_short")
            continue

        # Minimum text lengths after stripping whitespace
        if len(s["user_input"].strip()) < 8:
            reject("user_input_too_short")
            continue
        if len(s["ai_response"].strip()) < 20:
            reject("ai_response_too_short")
            continue

        # Label validity; _is_allowed also rejects unhashable label values
        if not _is_allowed(s["y_risk"], valid_y_risk):
            reject("invalid_y_risk")
            continue
        if not _is_allowed(s["l_risk"], VALID_RISK_LEVELS):
            reject("invalid_l_risk")
            continue
        if not _is_allowed(s.get("c_primary"), VALID_C_PRIMARY):
            reject("invalid_c_primary")
            continue
        if not _is_allowed(s.get("a_recommend"), VALID_ACTIONS):
            reject("invalid_action")
            continue

        # Consistency: a safe sample (y_risk=0) carries no category labels
        if s["y_risk"] == 0 and s.get("c_primary") not in (None, "None"):
            s["c_primary"] = "None"
            s["c_fine"] = []

        # Consistency: a risky sample (y_risk=1) must name a primary category
        if s["y_risk"] == 1 and s.get("c_primary") in (None, "None"):
            reject("risky_no_category")
            continue

        # Normalise c_fine to a list and drop unknown tags, keeping the sample.
        # Without this, a string c_fine would later be counted char-by-char
        # and a None c_fine would crash statistics code iterating over it.
        fine = s["c_fine"]
        if isinstance(fine, str):
            fine = [fine]
        elif not isinstance(fine, list):
            fine = []
        s["c_fine"] = [tag for tag in fine if _is_allowed(tag, VALID_C_FINE)]

        passed.append(s)

    return passed, reasons
|
||
|
||
|
||
# ── 分层划分(按 y_risk × lang 双维度分层) ──────────────────────────────────
|
||
|
||
def stratified_split(
    samples: List[Dict],
    seed: int = RANDOM_SEED,
    train_ratio: float = TRAIN_RATIO,
    dev_ratio: float = DEV_RATIO,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split samples into train/dev/test, stratified by ``(y_risk, lang)``.

    Samples are bucketed by ``(y_risk, lang)``. A missing ``y_risk`` counts as
    1 and a missing ``lang`` as ``"zh"``. Each bucket is shuffled. Its first
    ``int(n * train_ratio)`` items go to train, the next
    ``int(n * dev_ratio)`` to dev, and the remainder to test. Each split is
    then shuffled again.

    A private ``random.Random(seed)`` is used instead of reseeding the global
    ``random`` module. It yields the same shuffle sequence for the same seed,
    but no longer changes global RNG state as a side effect.

    Args:
        samples: Samples to split. The list itself is not modified.
        seed: RNG seed; the same input and seed always give the same split.
        train_ratio: Fraction of each bucket assigned to train.
        dev_ratio: Fraction of each bucket assigned to dev.

    Returns:
        ``(train, dev, test)`` lists.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more than 1.
    """
    if train_ratio < 0 or dev_ratio < 0 or train_ratio + dev_ratio > 1:
        raise ValueError(
            f"invalid split ratios: train={train_ratio}, dev={dev_ratio}"
        )

    rng = random.Random(seed)

    # Bucket by (y_risk, lang)
    buckets: Dict[Tuple, List[Dict]] = {}
    for s in samples:
        key = (s.get("y_risk", 1), s.get("lang", "zh"))
        buckets.setdefault(key, []).append(s)

    train: List[Dict] = []
    dev: List[Dict] = []
    test: List[Dict] = []
    for bucket in buckets.values():
        rng.shuffle(bucket)
        n = len(bucket)
        n_train = int(n * train_ratio)
        n_dev = int(n * dev_ratio)
        train += bucket[:n_train]
        dev += bucket[n_train:n_train + n_dev]
        test += bucket[n_train + n_dev:]

    rng.shuffle(train)
    rng.shuffle(dev)
    rng.shuffle(test)
    return train, dev, test
|
||
|
||
|
||
# ── 统计报告 ──────────────────────────────────────────────────────────────────
|
||
|
||
def print_stats(name: str, samples: List[Dict]):
    """Print a boxed summary of label distributions for one set of samples.

    Reports:
    * the total, with risky (``y_risk == 1``) vs. other counts;
    * language counts, where a missing ``lang`` counts as ``"zh"``;
    * risk levels;
    * primary categories, excluding ``None``/``"None"``, in natural R1..R10 order;
    * recommended actions, where a missing action counts as ``"PASS"``;
    * the eight most common fine-grained tags.

    An empty list prints a one-line notice instead.
    """
    total = len(samples)
    if total == 0:
        print(f"\n [{name}] 0 条")
        return

    def fine_tags(sample: Dict) -> List:
        # c_fine should be a list of tags. Tolerate a bare string (one tag) and
        # None/other values (no tags) instead of crashing, or counting the
        # characters of a string one by one.
        tags = sample.get("c_fine")
        if isinstance(tags, str):
            return [tags]
        return tags if isinstance(tags, list) else []

    risky = sum(1 for s in samples if s.get("y_risk") == 1)
    lang_cnt = Counter(s.get("lang", "zh") for s in samples)
    lvl_cnt = Counter(s.get("l_risk", 0) for s in samples)
    cat_cnt = Counter(
        s.get("c_primary") for s in samples if s.get("c_primary") not in (None, "None")
    )
    act_cnt = Counter(s.get("a_recommend", "PASS") for s in samples)
    fine_cnt = Counter(t for s in samples for t in fine_tags(s))

    # Natural order (R1, R2, ..., R10) instead of lexical (R1, R10, R2, ...).
    cat_sorted = dict(
        sorted(cat_cnt.items(), key=lambda kv: (len(str(kv[0])), str(kv[0])))
    )

    print(f"\n┌{'─'*52}┐")
    print(f"│ {name:<50} │")
    print(f"├{'─'*52}┤")
    print(f"│ 总数 : {total} (有风险={risky}, 安全={total-risky})")
    print(f"│ 语言 : zh={lang_cnt.get('zh',0)} en={lang_cnt.get('en',0)}")
    print(f"│ 风险等级 : {dict(sorted(lvl_cnt.items()))}")
    print(f"│ 一级类别 : {cat_sorted}")
    print(f"│ 干预动作 : {dict(act_cnt)}")
    print(f"│ 细粒度(Top8): {dict(fine_cnt.most_common(8))}")
    print(f"└{'─'*52}┘")
|
||
|
||
|
||
def coverage_check(samples: List[Dict], min_primary: int = 50, min_fine: int = 30,
                   fine_labels=None):
    """Print per-label sample counts and flag labels below a minimum count.

    Primary categories R1..R10 are counted over risky samples only
    (``y_risk == 1``) and listed in natural order. Fine-grained tags are
    counted over all samples. A ``c_fine`` that is a bare string counts as
    one tag; ``None`` or another non-list value counts as no tags.

    Args:
        samples: Samples to inspect.
        min_primary: Minimum count for a primary category to be marked ✓.
        min_fine: Minimum count for a fine-grained tag to be marked ✓.
        fine_labels: Fine-grained tags to report; ``None`` means the
            module's ``VALID_C_FINE``.
    """
    if fine_labels is None:
        fine_labels = VALID_C_FINE
    all_cats = [f"R{i}" for i in range(1, 11)]

    def fine_tags(sample: Dict) -> List:
        # Tolerate a malformed c_fine instead of crashing or counting characters.
        tags = sample.get("c_fine")
        if isinstance(tags, str):
            return [tags]
        return tags if isinstance(tags, list) else []

    cat_cnt = Counter(s.get("c_primary") for s in samples if s.get("y_risk") == 1)
    fine_cnt = Counter(t for s in samples for t in fine_tags(s))

    print("\n覆盖率检查(合并后全集):")
    print(f" 一级类别(≥{min_primary}条):")
    for cat in all_cats:
        n = cat_cnt.get(cat, 0)
        ok = "✓" if n >= min_primary else "✗"
        print(f" {cat}: {n:5d} {ok}")

    print(f" 细粒度标签(≥{min_fine}条):")
    for tag in sorted(fine_labels):
        n = fine_cnt.get(tag, 0)
        ok = "✓" if n >= min_fine else "✗"
        print(f" {tag}: {n:5d} {ok}")
|
||
|
||
|
||
# ── 主入口 ────────────────────────────────────────────────────────────────────
|
||
|
||
def main():
    """Build CompanionRisk-Bench v5 from the Chinese v4 set and new English data.

    Pipeline:
    1. Load the three sources.
    2. Tag each sample with ``lang`` when it is missing.
    3. Quality-filter the English data only.
    4. Merge and deduplicate.
    5. Split, stratified by ``(y_risk, lang)``.
    6. Assign sequential ``final_id`` values.
    7. Write train/dev/test/all JSONL files.
    8. Print summary statistics.

    ``final_id`` is assigned before any file is written, so the same ID
    appears in the split files and in ``all.jsonl``. Previously it was set
    only after the splits had been saved, and the split files had no IDs.
    """
    parser = argparse.ArgumentParser(description="Merge CompanionRisk-Bench v5")
    parser.add_argument("--v4", default="data/processed/CompanionRisk-Bench-v4-backup/all.jsonl")
    parser.add_argument("--en-core", default="data/raw/generated_english_core.jsonl")
    parser.add_argument("--en-targeted", default="data/raw/generated_english_targeted.jsonl")
    parser.add_argument("--out-dir", default="data/processed/CompanionRisk-Bench-v5")
    parser.add_argument("--seed", type=int, default=RANDOM_SEED,
                        help="random seed for the stratified split")
    args = parser.parse_args()

    out_dir = Path(args.out_dir)

    print(f"\n{'='*56}")
    print(" CompanionRisk-Bench v5 构建")
    print(f"{'='*56}")

    # 1. Load
    print("\n[1/5] 加载数据...")
    zh_v4 = load_jsonl(Path(args.v4))
    en_core = load_jsonl(Path(args.en_core))
    en_tgt = load_jsonl(Path(args.en_targeted))
    print(f" 中文 v4 (已质检) : {len(zh_v4):6d} 条")
    print(f" 英文 core : {len(en_core):6d} 条")
    print(f" 英文 targeted : {len(en_tgt):6d} 条")
    print(f" 合计(过滤前) : {len(zh_v4)+len(en_core)+len(en_tgt):6d} 条")

    # 2. Ensure every sample carries a language tag (existing tags are kept)
    for s in zh_v4:
        s.setdefault("lang", "zh")
    for s in en_core + en_tgt:
        s.setdefault("lang", "en")

    # 3. Quality filter (new English data only; the Chinese v4 data was already checked)
    print("\n[2/5] 质量过滤(英文数据)...")
    en_all = en_core + en_tgt
    en_filtered, reasons = quality_filter(en_all)
    dropped = len(en_all) - len(en_filtered)
    print(f" 英文过滤前: {len(en_all)} → 过滤后: {len(en_filtered)} (丢弃 {dropped} 条)")
    if reasons:
        for k, v in sorted(reasons.items(), key=lambda x: -x[1]):
            print(f" {k}: {v}")

    # 4. Merge + global dedup (Chinese first, so its copy wins on collisions)
    print("\n[3/5] 合并 + 全局去重...")
    merged = zh_v4 + en_filtered
    unique, dups = deduplicate(merged)
    print(f" 合并后: {len(merged)} → 去重后: {len(unique)} (去除 {dups} 条重复)")

    # 5. Stratified split by y_risk × lang
    print("\n[4/5] 分层划分 (train:dev:test ≈ 70:15:15)...")
    train, dev, test = stratified_split(unique, seed=args.seed)
    print(f" train: {len(train)}")
    print(f" dev : {len(dev)}")
    print(f" test : {len(test)}")

    # 6. Assign global IDs BEFORE saving. The split lists share the same dict
    #    objects as all_samples, so every output file gets the same final_id.
    all_samples = train + dev + test
    for i, s in enumerate(all_samples):
        s["final_id"] = f"crb-v5-{i:05d}"

    print(f"\n[5/5] 保存到 {out_dir}/...")
    save_jsonl(train, out_dir / "train.jsonl")
    save_jsonl(dev, out_dir / "dev.jsonl")
    save_jsonl(test, out_dir / "test.jsonl")
    save_jsonl(all_samples, out_dir / "all.jsonl")
    print(" 保存完成:train / dev / test / all")

    # 7. Statistics report
    print(f"\n{'='*56}")
    print(" 数据集统计报告")
    print(f"{'='*56}")
    print_stats("ALL (v5)", all_samples)
    print_stats("TRAIN", train)
    print_stats("DEV", dev)
    print_stats("TEST", test)
    coverage_check(all_samples)

    # 8. Language × split matrix
    print("\n 语言 × 分割分布:")
    print(f" {'':12} {'train':>8} {'dev':>8} {'test':>8} {'total':>8}")
    for lang in ("zh", "en"):
        row = [sum(1 for s in split if s.get("lang") == lang)
               for split in (train, dev, test)]
        print(f" {lang:12} {row[0]:>8} {row[1]:>8} {row[2]:>8} {sum(row):>8}")

    print(f"\n{'='*56}")
    print(f" 构建完成!总样本数: {len(all_samples)}")
    print(f" 输出目录: {out_dir.resolve()}")
    print(f"{'='*56}\n")
|
||
|
||
|
||
if __name__ == "__main__":
    # Build the dataset only when executed as a script, not when imported.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
|