Files
CompanionGuard-RL/code/scripts/merge_v5.py
zhangsiyuan 804ebd2f77 feat: add paper/ LaTeX draft, English data scripts, update progress docs
- paper/: 22-page LaTeX framework (7/10 sections complete, compiles cleanly)
  main.tex + 10 section files + refs.bib + compiled PDF (329KB)
- code/scripts/: three English dataset generation & merging scripts
  generate_english.py / generate_english_targeted.py / merge_v5.py
- CLAUDE.md: update paper writing status, add paper/ file map entry
- state.md: add section 8 paper writing progress (2026-05-15)
- .gitignore: add LaTeX build artifact exclusion rules

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-18 11:19:39 +08:00

324 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Merge CompanionRisk-Bench v5.

Sources (default paths; override with --v4 / --en-core / --en-targeted):
- Chinese v4 (data/processed/CompanionRisk-Bench-v4-backup/all.jsonl), 9,896 samples
  (already quality-checked, so it is not passed through the quality filter again)
- English core (data/raw/generated_english_core.jsonl), 12,000 samples
- English targeted (data/raw/generated_english_targeted.jsonl), 3,000 samples

Output (default; override with --out-dir):
    data/processed/CompanionRisk-Bench-v5/{train,dev,test,all}.jsonl

Usage:
    cd code/
    python scripts/merge_v5.py
"""
import argparse
import hashlib
import json
import random
from collections import Counter
from pathlib import Path
from typing import List, Dict, Tuple
# Seed for the train/dev/test shuffles, so the split is reproducible.
RANDOM_SEED = 42
# Fraction of each stratum sent to train and to dev; test receives the remainder.
TRAIN_RATIO, DEV_RATIO = 0.70, 0.15

# Legal top-level risk categories R1..R10; "no category" may appear either as
# the string "None" or as a JSON null.
VALID_C_PRIMARY = {f"R{k}" for k in range(1, 11)} | {"None", None}

# Legal recommended intervention actions.
VALID_ACTIONS = {"PASS", "WARN", "REWRITE", "REJECT", "CRISIS"}

# Legal risk-level values: the integers 0 through 4.
VALID_RISK_LEVELS = set(range(5))

# Legal fine-grained risk tags.
VALID_C_FINE = {
    "DirectEncouragement",
    "MethodFacilitation",
    "RiskNormalization",
    "Romanticization",
    "DependencyReinforcement",
    "IsolationReinforcement",
    "FalseReassurance",
    "PseudoTherapy",
    "BoundaryFailure",
    "CrisisNonResponse",
    "CoRumination",
    "ManipulativeAttachment",
    "PrivacySolicitation",
    "AgeInappropriateIntimacy",
}
# ── I/O ───────────────────────────────────────────────────────────────────────
def load_jsonl(path: Path) -> List[Dict]:
    """Load JSON-object records from a JSON Lines file.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped; how many were skipped is
    printed so that data loss is visible in the build log instead of silent.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        The decoded records in file order; an empty list if ``path`` does
        not exist.
    """
    if not path.exists():
        print(f" [跳过] 文件不存在: {path}")
        return []
    samples: List[Dict] = []
    bad_lines = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                bad_lines += 1
                continue
            # Callers use .get()/[] on every record, so only JSON objects are kept.
            if not isinstance(record, dict):
                bad_lines += 1
                continue
            samples.append(record)
    if bad_lines:
        print(f" [警告] {path}: 跳过 {bad_lines} 行无效 JSON / 非对象记录")
    return samples
def save_jsonl(samples: List[Dict], path: Path):
    """Write *samples* to *path* as JSON Lines: one object per line, UTF-8.

    Missing parent directories are created and an existing file is
    overwritten. Non-ASCII characters are written verbatim instead of being
    escaped to ASCII.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as out:
        out.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in samples)
# ── Deduplication ─────────────────────────────────────────────────────────────
def fingerprint(s: Dict) -> str:
    """Return a SHA-256 hex digest that identifies a sample for deduplication.

    Only the first 100 characters of ``user_input`` and of ``ai_response``
    are hashed. Samples that agree on both prefixes therefore count as
    duplicates even if they diverge later.

    A missing or ``None`` field is treated as an empty string, and any other
    non-string value is stringified. This keeps records that never went
    through the quality filter from crashing the merge.
    """
    def _text(key: str) -> str:
        value = s.get(key)
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    raw = _text("user_input")[:100] + "|" + _text("ai_response")[:100]
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def deduplicate(samples: List[Dict]) -> Tuple[List[Dict], int]:
    """Drop samples whose ``fingerprint`` has already been seen.

    The first occurrence wins, so input order decides which copy survives.

    Returns:
        ``(unique_samples, number_of_duplicates_removed)``.
    """
    kept = []
    fingerprints_seen = set()
    removed = 0
    for sample in samples:
        key = fingerprint(sample)
        if key in fingerprints_seen:
            removed += 1
            continue
        fingerprints_seen.add(key)
        kept.append(sample)
    return kept, removed
# ── Quality filter (new English data only; Chinese v4 is already checked) ─────
def _is_valid(value, allowed) -> bool:
    """Return ``value in allowed``, treating unhashable values as invalid.

    LLM output sometimes holds a list (or dict) where a scalar label is
    expected. A bare ``value in some_set`` raises TypeError on such a value
    and would abort the whole merge.
    """
    try:
        return value in allowed
    except TypeError:
        return False


def quality_filter(samples: List[Dict]) -> Tuple[List[Dict], Dict[str, int]]:
    """Keep only well-formed, consistently labelled samples.

    Checks run in order, and the first failing check names the drop reason:
    - the record is an object;
    - all required fields are present;
    - ``user_input`` and ``ai_response`` are strings;
    - ``history`` is a list with at least 2 entries;
    - minimum stripped text lengths (8 / 20 characters);
    - ``y_risk`` is 0 or 1;
    - ``l_risk``, ``c_primary`` and ``a_recommend`` are legal labels;
    - a risky sample (``y_risk == 1``) names a primary category.

    Kept samples are normalised IN PLACE:
    - a safe sample (``y_risk == 0``) has ``c_primary`` reset to ``"None"``
      and ``c_fine`` cleared;
    - ``c_fine`` is coerced to a list and illegal tags are removed, without
      dropping the sample.

    Returns:
        ``(passed_samples, {drop_reason: count})``.
    """
    required_fields = (
        "persona", "history", "user_input", "ai_response",
        "y_risk", "l_risk", "c_primary", "c_fine", "a_recommend",
    )
    reasons: Counter = Counter()
    passed: List[Dict] = []
    for s in samples:
        if not isinstance(s, dict):
            reasons["not_an_object"] += 1
            continue
        # Required fields
        if any(field not in s for field in required_fields):
            reasons["missing_fields"] += 1
            continue
        # Type check (guards against the LLM returning a list instead of a str)
        if not isinstance(s["user_input"], str) or not isinstance(s["ai_response"], str):
            reasons["wrong_type"] += 1
            continue
        # Number of history turns
        if not isinstance(s["history"], list) or len(s["history"]) < 2:
            reasons["history_too_short"] += 1
            continue
        # Minimum text lengths
        if len(s["user_input"].strip()) < 8:
            reasons["user_input_too_short"] += 1
            continue
        if len(s["ai_response"].strip()) < 20:
            reasons["ai_response_too_short"] += 1
            continue
        # Label validity. y_risk must be exactly 0 or 1; any other value
        # (e.g. "1" or 2) would skip both consistency checks below.
        if s["y_risk"] not in (0, 1):
            reasons["invalid_y_risk"] += 1
            continue
        if not _is_valid(s["l_risk"], VALID_RISK_LEVELS):
            reasons["invalid_l_risk"] += 1
            continue
        if not _is_valid(s["c_primary"], VALID_C_PRIMARY):
            reasons["invalid_c_primary"] += 1
            continue
        if not _is_valid(s["a_recommend"], VALID_ACTIONS):
            reasons["invalid_action"] += 1
            continue
        # Consistency: a safe sample carries no risk category.
        if s["y_risk"] == 0 and s["c_primary"] not in (None, "None"):
            s["c_primary"] = "None"
            s["c_fine"] = []
        # Consistency: a risky sample must name a primary category.
        if s["y_risk"] == 1 and s["c_primary"] in (None, "None"):
            reasons["risky_no_category"] += 1
            continue
        # Drop illegal fine-grained tags leniently instead of discarding the
        # sample. A bare string is wrapped and any other non-list becomes [],
        # so downstream counters never iterate a string character by character.
        fine = s["c_fine"]
        if isinstance(fine, str):
            fine = [fine]
        elif not isinstance(fine, list):
            fine = []
        s["c_fine"] = [t for t in fine if _is_valid(t, VALID_C_FINE)]
        passed.append(s)
    return passed, dict(reasons)
# ── Stratified split (strata: y_risk × lang) ──────────────────────────────────
def stratified_split(
    samples: List[Dict],
    seed: int = RANDOM_SEED,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split *samples* into train/dev/test, stratified by ``(y_risk, lang)``.

    Each ``(y_risk, lang)`` bucket is shuffled and cut after
    ``int(n * TRAIN_RATIO)`` and ``int(n * DEV_RATIO)`` more items. The
    remainder of each bucket goes to test, and the three splits are then
    shuffled. For bucketing, a missing ``y_risk`` defaults to 1 and a missing
    ``lang`` to ``"zh"``.

    A private ``random.Random(seed)`` drives every shuffle instead of
    reseeding the global ``random`` module. Calling this function therefore
    no longer disturbs other users of the module-level RNG. The resulting
    splits are the same as with ``random.seed(seed)`` plus module-level
    shuffles.

    Returns:
        ``(train, dev, test)``; the sample dicts themselves are not copied.
    """
    rng = random.Random(seed)
    buckets: Dict[Tuple, List[Dict]] = {}
    for s in samples:
        key = (s.get("y_risk", 1), s.get("lang", "zh"))
        buckets.setdefault(key, []).append(s)
    train: List[Dict] = []
    dev: List[Dict] = []
    test: List[Dict] = []
    for bucket in buckets.values():
        rng.shuffle(bucket)
        n = len(bucket)
        n_train = int(n * TRAIN_RATIO)
        n_dev = int(n * DEV_RATIO)
        train.extend(bucket[:n_train])
        dev.extend(bucket[n_train:n_train + n_dev])
        test.extend(bucket[n_train + n_dev:])
    for split in (train, dev, test):
        rng.shuffle(split)
    return train, dev, test
# ── Statistics report ─────────────────────────────────────────────────────────
def print_stats(name: str, samples: List[Dict]):
    """Print a boxed summary of the label distributions of one split.

    The summary reports:
    - total / risky / safe counts;
    - zh and en counts;
    - histograms of risk level, primary category and action;
    - the 8 most common fine-grained tags.
    """
    total = len(samples)
    if total == 0:
        print(f"\n [{name}] 0 条")
        return

    def _fine_tags(s: Dict) -> List:
        # c_fine should be a list; tolerate None/absent or a bare string (a
        # string would otherwise be counted character by character).
        tags = s.get("c_fine")
        if isinstance(tags, str):
            return [tags]
        return tags if isinstance(tags, list) else []

    risky = sum(1 for s in samples if s.get("y_risk") == 1)
    lang_cnt = Counter(s.get("lang", "zh") for s in samples)
    lvl_cnt = Counter(s.get("l_risk", 0) for s in samples)
    cat_cnt = Counter(
        s.get("c_primary") for s in samples if s.get("c_primary") not in (None, "None")
    )
    act_cnt = Counter(s.get("a_recommend", "PASS") for s in samples)
    fine_cnt = Counter(t for s in samples for t in _fine_tags(s))

    # Box frame: 52 rule characters between the corners matches the width of
    # the "│ {name:<50} │" title row.
    rule = "─" * 52
    print(f"\n┌{rule}┐")
    print(f"│ {name:<50} │")
    print(f"├{rule}┤")
    print(f"│ 总数 : {total} (有风险={risky}, 安全={total-risky})")
    print(f"│ 语言 : zh={lang_cnt.get('zh',0)} en={lang_cnt.get('en',0)}")
    print(f"│ 风险等级 : {dict(sorted(lvl_cnt.items()))}")
    print(f"│ 一级类别 : {dict(sorted(cat_cnt.items()))}")
    print(f"│ 干预动作 : {dict(act_cnt)}")
    print(f"│ 细粒度(Top8): {dict(fine_cnt.most_common(8))}")
    print(f"└{rule}┘")
def coverage_check(
    samples: List[Dict],
    min_primary: int = 50,
    min_fine: int = 30,
    fine_tags=None,
):
    """Print per-category and per-tag counts, flagging those below a minimum.

    Each line ends with ✓ when the count meets the minimum and ✗ otherwise.

    Args:
        samples: The merged dataset.
        min_primary: Minimum number of risky samples (``y_risk == 1``) each
            primary category R1..R10 should have.
        min_fine: Minimum number of samples each fine-grained tag should have.
        fine_tags: Tag vocabulary to check; ``None`` means ``VALID_C_FINE``.
    """
    if fine_tags is None:
        fine_tags = VALID_C_FINE
    # Numeric order (R1, R2, ..., R10) rather than lexicographic (R1, R10, R2, ...).
    primary_cats = [f"R{i}" for i in range(1, 11)]
    cat_cnt = Counter(s.get("c_primary") for s in samples if s.get("y_risk") == 1)
    fine_cnt: Counter = Counter()
    for s in samples:
        tags = s.get("c_fine")
        if isinstance(tags, str):  # a bare tag string counts as one tag
            tags = [tags]
        if isinstance(tags, list):
            fine_cnt.update(tags)

    print("\n覆盖率检查(合并后全集):")
    print(f" 一级类别≥{min_primary}条")
    for cat in primary_cats:
        n = cat_cnt.get(cat, 0)
        ok = "✓" if n >= min_primary else "✗"
        print(f" {cat}: {n:5d} {ok}")
    print(f" 细粒度标签≥{min_fine}条")
    for tag in sorted(fine_tags):
        n = fine_cnt.get(tag, 0)
        ok = "✓" if n >= min_fine else "✗"
        print(f" {tag}: {n:5d} {ok}")
# ── Entry point ───────────────────────────────────────────────────────────────
def main():
    """Build CompanionRisk-Bench v5 and write train/dev/test/all JSONL files.

    Pipeline:
    1. Load the three sources.
    2. Default each sample's ``lang`` field.
    3. Quality-filter the English data only.
    4. Merge and globally deduplicate. Chinese samples come first, so a
       Chinese sample wins a fingerprint collision.
    5. Stratified split.
    6. Assign ``final_id``, then save.
    7. Print statistics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--v4", default="data/processed/CompanionRisk-Bench-v4-backup/all.jsonl")
    parser.add_argument("--en-core", default="data/raw/generated_english_core.jsonl")
    parser.add_argument("--en-targeted", default="data/raw/generated_english_targeted.jsonl")
    parser.add_argument("--out-dir", default="data/processed/CompanionRisk-Bench-v5")
    parser.add_argument("--seed", type=int, default=RANDOM_SEED,
                        help="random seed for the stratified split")
    args = parser.parse_args()
    out_dir = Path(args.out_dir)
    print(f"\n{'='*56}")
    print(f" CompanionRisk-Bench v5 构建")
    print(f"{'='*56}")

    # 1. Load
    print("\n[1/5] 加载数据...")
    zh_v4 = load_jsonl(Path(args.v4))
    en_core = load_jsonl(Path(args.en_core))
    en_tgt = load_jsonl(Path(args.en_targeted))
    print(f" 中文 v4 (已质检) : {len(zh_v4):6d}")
    print(f" 英文 core : {len(en_core):6d}")
    print(f" 英文 targeted : {len(en_tgt):6d}")
    print(f" 合计(过滤前) : {len(zh_v4)+len(en_core)+len(en_tgt):6d}")

    # 2. Default the language tag (an existing "lang" value is kept)
    for s in zh_v4:
        s.setdefault("lang", "zh")
    for s in en_core + en_tgt:
        s.setdefault("lang", "en")

    # 3. Quality filter (new English data only; Chinese v4 is already checked)
    print("\n[2/5] 质量过滤(英文数据)...")
    en_all = en_core + en_tgt
    en_filtered, reasons = quality_filter(en_all)
    dropped = len(en_all) - len(en_filtered)
    print(f" 英文过滤前: {len(en_all)} → 过滤后: {len(en_filtered)} (丢弃 {dropped} 条)")
    if reasons:
        for k, v in sorted(reasons.items(), key=lambda x: -x[1]):
            print(f" {k}: {v}")

    # 4. Merge + global dedup
    print("\n[3/5] 合并 + 全局去重...")
    merged = zh_v4 + en_filtered
    unique, dups = deduplicate(merged)
    print(f" 合并后: {len(merged)} → 去重后: {len(unique)} (去除 {dups} 条重复)")

    # 5. Stratified split (by y_risk × lang)
    print("\n[4/5] 分层划分 (train:dev:test ≈ 70:15:15)...")
    train, dev, test = stratified_split(unique, seed=args.seed)
    print(f" train: {len(train)}")
    print(f" dev : {len(dev)}")
    print(f" test : {len(test)}")

    # 6. Assign final_id BEFORE writing anything, so every row in
    #    train/dev/test carries the same id as its copy in all.jsonl.
    print(f"\n[5/5] 保存到 {out_dir}/...")
    all_samples = train + dev + test
    for i, s in enumerate(all_samples):
        s["final_id"] = f"crb-v5-{i:05d}"
    save_jsonl(train, out_dir / "train.jsonl")
    save_jsonl(dev, out_dir / "dev.jsonl")
    save_jsonl(test, out_dir / "test.jsonl")
    save_jsonl(all_samples, out_dir / "all.jsonl")
    print(f" 保存完成train / dev / test / all")

    # 7. Statistics report
    print(f"\n{'='*56}")
    print(f" 数据集统计报告")
    print(f"{'='*56}")
    print_stats("ALL (v5)", all_samples)
    print_stats("TRAIN", train)
    print_stats("DEV", dev)
    print_stats("TEST", test)
    coverage_check(all_samples)

    # 8. Language × split matrix (only zh and en are tabulated)
    print("\n 语言 × 分割分布:")
    print(f" {'':12} {'train':>8} {'dev':>8} {'test':>8} {'total':>8}")
    for lang in ("zh", "en"):
        row = [sum(1 for s in split if s.get("lang") == lang)
               for split in (train, dev, test)]
        print(f" {lang:12} {row[0]:>8} {row[1]:>8} {row[2]:>8} {sum(row):>8}")
    print(f"\n{'='*56}")
    print(f" 构建完成!总样本数: {len(all_samples)}")
    print(f" 输出目录: {out_dir.resolve()}")
    print(f"{'='*56}\n")


if __name__ == "__main__":
    main()