feat: add paper/ LaTeX draft, English data scripts, update progress docs

- paper/: 22-page LaTeX framework (7/10 sections complete, compiles cleanly)
  main.tex + 10 section files + refs.bib + compiled PDF (329KB)
- code/scripts/: three English dataset generation & merging scripts
  generate_english.py / generate_english_targeted.py / merge_v5.py
- CLAUDE.md: update paper writing status, add paper/ file map entry
- state.md: add section 8 paper writing progress (2026-05-15)
- .gitignore: add LaTeX build artifact exclusion rules

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-18 11:19:39 +08:00
parent b50cf395ab
commit 804ebd2f77
19 changed files with 3047 additions and 3 deletions

323
code/scripts/merge_v5.py Normal file
View File

@@ -0,0 +1,323 @@
"""
Merge CompanionRisk-Bench v5.

Sources:
- Chinese v4 (data/processed/CompanionRisk-Bench-v4-backup/all.jsonl): 9,896 samples (already quality-checked)
- English core (data/raw/generated_english_core.jsonl): 12,000 samples
- English targeted (data/raw/generated_english_targeted.jsonl): 3,000 samples

Output: data/processed/CompanionRisk-Bench-v5/{train,dev,test,all}.jsonl

Usage:
    cd code/
    python scripts/merge_v5.py
"""
import argparse
import hashlib
import json
import random
from collections import Counter
from pathlib import Path
from typing import List, Dict, Tuple
# Default seed for the shuffles in stratified_split, so splits are reproducible.
RANDOM_SEED = 42
# Fractions of each stratum assigned to train and dev; whatever remains goes to test.
TRAIN_RATIO = 0.70
DEV_RATIO = 0.15
# Accepted c_primary values. Both the string "None" and an actual None (JSON null)
# are accepted as "no category".
VALID_C_PRIMARY = {"R1","R2","R3","R4","R5","R6","R7","R8","R9","R10","None",None}
# Accepted a_recommend values.
VALID_ACTIONS = {"PASS","WARN","REWRITE","REJECT","CRISIS"}
# Accepted l_risk values.
VALID_RISK_LEVELS = {0, 1, 2, 3, 4}
# Accepted fine-grained tags; quality_filter removes c_fine entries outside this set,
# and coverage_check reports a count for every tag listed here.
VALID_C_FINE = {
    "DirectEncouragement","MethodFacilitation","RiskNormalization",
    "Romanticization","DependencyReinforcement","IsolationReinforcement",
    "FalseReassurance","PseudoTherapy","BoundaryFailure","CrisisNonResponse",
    "CoRumination","ManipulativeAttachment","PrivacySolicitation",
    "AgeInappropriateIntimacy",
}
# ── I/O ───────────────────────────────────────────────────────────────────────
def load_jsonl(path: Path) -> List[Dict]:
    """Load a JSON Lines file into a list of dict records.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped; the number of skipped
    lines is reported on stdout so data loss is visible instead of silent.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The decoded records in file order. An empty list is returned (after
        printing a notice) when the file does not exist.
    """
    if not path.exists():
        print(f" [跳过] 文件不存在: {path}")
        return []

    samples: List[Dict] = []
    skipped = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Downstream code calls dict methods on every record, so a bare
            # list/str/number line would only crash later; drop it here.
            if not isinstance(record, dict):
                skipped += 1
                continue
            samples.append(record)

    if skipped:
        print(f" [警告] 跳过 {skipped} 行无效记录: {path}")
    return samples
def save_jsonl(samples: List[Dict], path: Path):
    """Write *samples* to *path* as JSON Lines, one object per line.

    Missing parent directories are created, an existing file is overwritten,
    and non-ASCII characters are written as-is (UTF-8) rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in samples
        )
# ── 去重 ──────────────────────────────────────────────────────────────────────
def fingerprint(s: Dict) -> str:
    """Return a SHA-256 hex digest identifying a sample for de-duplication.

    The digest covers the first 100 characters of ``user_input`` and the first
    100 characters of ``ai_response``, joined by ``"|"``. Samples that agree on
    both prefixes therefore share a fingerprint even if they differ later on.

    A missing or ``None`` field is treated as an empty string and any other
    non-string value is converted with ``str()``, so records that have not
    been type-checked cannot raise here.
    """
    def head(key: str) -> str:
        value = s.get(key)
        if value is None:
            return ""
        if not isinstance(value, str):
            value = str(value)
        return value[:100]

    raw = head("user_input") + "|" + head("ai_response")
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def deduplicate(samples: List[Dict]) -> Tuple[List[Dict], int]:
    """Drop samples whose fingerprint has already been seen.

    The first occurrence of each fingerprint is kept and input order is
    preserved.

    Returns:
        A pair ``(unique_samples, number_of_duplicates_removed)``.
    """
    known: set = set()
    kept: List[Dict] = []
    removed = 0
    for sample in samples:
        digest = fingerprint(sample)
        if digest in known:
            removed += 1
            continue
        known.add(digest)
        kept.append(sample)
    return kept, removed
# ── 质量过滤(仅用于新增英文数据;中文 v4 已质检) ───────────────────────────
def _label_in(value, allowed) -> bool:
    """Return whether *value* is in *allowed*; unhashable values never are."""
    try:
        return value in allowed
    except TypeError:
        # A list/dict where a scalar label was expected cannot be looked up in
        # a set; treat it as an invalid label instead of crashing.
        return False


def quality_filter(samples: List[Dict]) -> Tuple[List[Dict], Dict[str, int]]:
    """Keep only samples that pass structural and label sanity checks.

    Each sample is checked in order; the first failing check decides the
    rejection reason:

    * ``missing_fields`` - a required key is absent
    * ``wrong_type`` - ``user_input`` / ``ai_response`` is not a string
    * ``history_too_short`` - ``history`` is not a list with at least 2 items
    * ``user_input_too_short`` - fewer than 8 characters after stripping
    * ``ai_response_too_short`` - fewer than 20 characters after stripping
    * ``invalid_y_risk`` - ``y_risk`` is neither 0 nor 1
    * ``invalid_l_risk`` / ``invalid_c_primary`` / ``invalid_action`` - the
      label is outside its allowed set
    * ``risky_no_category`` - ``y_risk`` is 1 but no primary category is given

    Accepted samples are repaired in place (the input dicts are mutated):
    when ``y_risk`` is 0, a non-empty ``c_primary`` is reset to ``"None"`` and
    ``c_fine`` is cleared; ``c_fine`` is always normalised to a list holding
    only tags from ``VALID_C_FINE`` (a single valid tag given as a bare string
    is wrapped, any other non-list value becomes an empty list).

    Returns:
        ``(accepted_samples, rejection_counts_by_reason)``.
    """
    required = ("persona", "history", "user_input", "ai_response",
                "y_risk", "l_risk", "c_primary", "c_fine", "a_recommend")
    reasons: Dict[str, int] = {}
    passed: List[Dict] = []

    def reject(reason: str) -> None:
        reasons[reason] = reasons.get(reason, 0) + 1

    for s in samples:
        # Required fields.
        if any(field not in s for field in required):
            reject("missing_fields")
            continue
        # Type check (guards against the LLM returning a list instead of a str).
        if not isinstance(s["user_input"], str) or not isinstance(s["ai_response"], str):
            reject("wrong_type")
            continue
        # Number of history turns.
        if not isinstance(s["history"], list) or len(s["history"]) < 2:
            reject("history_too_short")
            continue
        # Minimum text lengths.
        if len(s["user_input"].strip()) < 8:
            reject("user_input_too_short")
            continue
        if len(s["ai_response"].strip()) < 20:
            reject("ai_response_too_short")
            continue
        # The binary risk flag must really be 0 or 1; otherwise neither of the
        # consistency rules below would apply and the sample would slip through.
        if s["y_risk"] not in (0, 1):
            reject("invalid_y_risk")
            continue
        # Label validity.
        if not _label_in(s["l_risk"], VALID_RISK_LEVELS):
            reject("invalid_l_risk")
            continue
        if not _label_in(s["c_primary"], VALID_C_PRIMARY):
            reject("invalid_c_primary")
            continue
        if not _label_in(s["a_recommend"], VALID_ACTIONS):
            reject("invalid_action")
            continue
        # Consistency: a safe sample must not carry a risk category.
        if s["y_risk"] == 0 and s["c_primary"] not in (None, "None"):
            s["c_primary"] = "None"
            s["c_fine"] = []
        # Consistency: a risky sample must name a primary category.
        if s["y_risk"] == 1 and s["c_primary"] in (None, "None"):
            reject("risky_no_category")
            continue
        # Drop unknown fine-grained tags leniently instead of discarding the
        # whole sample, and make sure c_fine always ends up as a list.
        c_fine = s["c_fine"]
        if isinstance(c_fine, list):
            s["c_fine"] = [t for t in c_fine if _label_in(t, VALID_C_FINE)]
        elif isinstance(c_fine, str) and _label_in(c_fine, VALID_C_FINE):
            s["c_fine"] = [c_fine]
        else:
            s["c_fine"] = []
        passed.append(s)
    return passed, reasons
# ── 分层划分(按 y_risk × lang 双维度分层) ──────────────────────────────────
def stratified_split(
    samples: List[Dict],
    seed: int = RANDOM_SEED,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split *samples* into train/dev/test, stratified by ``(y_risk, lang)``.

    Samples are grouped by the pair ``(y_risk, lang)`` (defaulting to ``1`` and
    ``"zh"`` when a key is absent). Each group is shuffled and cut into
    ``int(n * TRAIN_RATIO)`` train items, ``int(n * DEV_RATIO)`` dev items and
    the remainder as test items; the three resulting lists are then shuffled
    again. Because the counts are truncated, very small groups lean towards
    test (a group of one sample goes entirely to test).

    A private ``random.Random(seed)`` is used, so the result is reproducible
    for a given seed and input order without reseeding the process-wide
    ``random`` module state.

    Args:
        samples: Records to split; the list itself is not modified.
        seed: Seed for the shuffles.

    Returns:
        ``(train, dev, test)``.
    """
    rng = random.Random(seed)

    # Bucket by (y_risk, lang); dict insertion order keeps this deterministic.
    buckets: Dict[Tuple, List[Dict]] = {}
    for s in samples:
        key = (s.get("y_risk", 1), s.get("lang", "zh"))
        buckets.setdefault(key, []).append(s)

    train: List[Dict] = []
    dev: List[Dict] = []
    test: List[Dict] = []
    for bucket in buckets.values():
        rng.shuffle(bucket)
        n = len(bucket)
        n_train = int(n * TRAIN_RATIO)
        n_dev = int(n * DEV_RATIO)
        train.extend(bucket[:n_train])
        dev.extend(bucket[n_train:n_train + n_dev])
        test.extend(bucket[n_train + n_dev:])

    # Mix the strata so that no split is ordered by bucket.
    for part in (train, dev, test):
        rng.shuffle(part)
    return train, dev, test
# ── 统计报告 ──────────────────────────────────────────────────────────────────
def print_stats(name: str, samples: List[Dict]):
    """Print a summary of label distributions for one split.

    Reports the total (risky vs. safe by ``y_risk == 1``), the zh/en language
    counts, and the distributions of ``l_risk``, ``c_primary`` (ignoring
    ``None`` / ``"None"``), ``a_recommend`` and the eight most common
    ``c_fine`` tags. For an empty split only a one-line notice is printed.

    Args:
        name: Title shown at the top of the report.
        samples: Records of the split to summarise.
    """
    total = len(samples)
    if total == 0:
        print(f"\n [{name}] 0 条")
        return

    def fine_tags(sample: Dict) -> List[str]:
        # c_fine may be missing, None or not a list in records that never
        # went through quality_filter; count only real string tags.
        tags = sample.get("c_fine")
        if not isinstance(tags, list):
            return []
        return [t for t in tags if isinstance(t, str)]

    risky = sum(1 for s in samples if s.get("y_risk") == 1)
    lang_cnt = Counter(s.get("lang", "zh") for s in samples)
    lvl_cnt = Counter(s.get("l_risk", 0) for s in samples)
    cat_cnt = Counter(
        s.get("c_primary") for s in samples if s.get("c_primary") not in (None, "None")
    )
    act_cnt = Counter(s.get("a_recommend", "PASS") for s in samples)
    fine_cnt = Counter(t for s in samples for t in fine_tags(s))

    # Horizontal rule framing the report; repeating an empty string here would
    # print nothing at all.
    rule = "─" * 52
    print(f"\n{rule}")
    print(f"{name:<50}")
    print(f"{rule}")
    print(f"│ 总数 : {total} (有风险={risky}, 安全={total-risky})")
    print(f"│ 语言 : zh={lang_cnt.get('zh',0)} en={lang_cnt.get('en',0)}")
    print(f"│ 风险等级 : {dict(sorted(lvl_cnt.items()))}")
    print(f"│ 一级类别 : {dict(sorted(cat_cnt.items()))}")
    print(f"│ 干预动作 : {dict(act_cnt)}")
    print(f"│ 细粒度(Top8): {dict(fine_cnt.most_common(8))}")
    print(f"{rule}")
def coverage_check(samples: List[Dict], min_primary: int = 50, min_fine: int = 30):
    """Print per-label sample counts and flag labels that are under-covered.

    For every primary category R1..R10 the number of risky samples
    (``y_risk == 1``) with that ``c_primary`` is printed; for every tag in
    ``VALID_C_FINE`` the number of occurrences in ``c_fine`` is printed. Each
    line ends with a check mark when the count reaches the threshold and a
    cross otherwise, so gaps are visible at a glance.

    Args:
        samples: The merged dataset to inspect.
        min_primary: Minimum acceptable count per primary category.
        min_fine: Minimum acceptable count per fine-grained tag.
    """
    def fine_tags(sample: Dict) -> List[str]:
        # Tolerate a missing, None or non-list c_fine in unfiltered records.
        tags = sample.get("c_fine")
        if not isinstance(tags, list):
            return []
        return [t for t in tags if isinstance(t, str)]

    def mark(count: int, threshold: int) -> str:
        return "✓" if count >= threshold else "✗"

    all_cats = {f"R{i}" for i in range(1, 11)}
    all_fines = VALID_C_FINE
    cat_cnt = Counter(s.get("c_primary") for s in samples if s.get("y_risk") == 1)
    fine_cnt = Counter(t for s in samples for t in fine_tags(s))

    print("\n覆盖率检查(合并后全集):")
    print(f" 一级类别≥{min_primary}条")
    for cat in sorted(all_cats):
        n = cat_cnt.get(cat, 0)
        print(f" {cat}: {n:5d} {mark(n, min_primary)}")
    print(f" 细粒度标签≥{min_fine}条")
    for tag in sorted(all_fines):
        n = fine_cnt.get(tag, 0)
        print(f" {tag}: {n:5d} {mark(n, min_fine)}")
# ── 主入口 ────────────────────────────────────────────────────────────────────
def main():
    """Build CompanionRisk-Bench v5 from the Chinese v4 set and new English data.

    Steps: load the three input files, tag each record with a ``lang`` field,
    quality-filter the English records only, merge and de-duplicate, split
    into train/dev/test stratified by ``(y_risk, lang)``, assign ``final_id``
    values, write ``train``/``dev``/``test``/``all`` JSONL files, and print
    summary statistics.

    Input and output locations can be overridden with ``--v4``, ``--en-core``,
    ``--en-targeted`` and ``--out-dir``; the defaults are relative paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--v4", default="data/processed/CompanionRisk-Bench-v4-backup/all.jsonl")
    parser.add_argument("--en-core", default="data/raw/generated_english_core.jsonl")
    parser.add_argument("--en-targeted", default="data/raw/generated_english_targeted.jsonl")
    parser.add_argument("--out-dir", default="data/processed/CompanionRisk-Bench-v5")
    args = parser.parse_args()
    out_dir = Path(args.out_dir)

    print(f"\n{'='*56}")
    print(f" CompanionRisk-Bench v5 构建")
    print(f"{'='*56}")

    # 1. Load.
    print("\n[1/5] 加载数据...")
    zh_v4 = load_jsonl(Path(args.v4))
    en_core = load_jsonl(Path(args.en_core))
    en_tgt = load_jsonl(Path(args.en_targeted))
    print(f" 中文 v4 (已质检) : {len(zh_v4):6d}")
    print(f" 英文 core : {len(en_core):6d}")
    print(f" 英文 targeted : {len(en_tgt):6d}")
    print(f" 合计(过滤前) : {len(zh_v4)+len(en_core)+len(en_tgt):6d}")

    # 2. Tag the language field; an existing value is left untouched.
    for s in zh_v4:
        s.setdefault("lang", "zh")
    for s in en_core + en_tgt:
        s.setdefault("lang", "en")

    # 3. Quality filter (new English data only).
    print("\n[2/5] 质量过滤(英文数据)...")
    en_all = en_core + en_tgt
    en_filtered, reasons = quality_filter(en_all)
    dropped = len(en_all) - len(en_filtered)
    print(f" 英文过滤前: {len(en_all)} → 过滤后: {len(en_filtered)} (丢弃 {dropped} 条)")
    if reasons:
        for k, v in sorted(reasons.items(), key=lambda x: -x[1]):
            print(f" {k}: {v}")

    # 4. Merge and de-duplicate across both languages. Chinese records come
    #    first, so they win when a fingerprint collides.
    print("\n[3/5] 合并 + 全局去重...")
    merged = zh_v4 + en_filtered
    unique, dups = deduplicate(merged)
    print(f" 合并后: {len(merged)} → 去重后: {len(unique)} (去除 {dups} 条重复)")

    # 5. Stratified split by (y_risk, lang).
    print("\n[4/5] 分层划分 (train:dev:test ≈ 70:15:15)...")
    train, dev, test = stratified_split(unique)
    print(f" train: {len(train)}")
    print(f" dev : {len(dev)}")
    print(f" test : {len(test)}")

    # 6. Assign IDs before anything is written: the split lists and
    #    all_samples share the same dict objects, so doing this first gives a
    #    record the same final_id in its split file and in all.jsonl.
    all_samples = train + dev + test
    for i, s in enumerate(all_samples):
        s["final_id"] = f"crb-v5-{i:05d}"

    # 7. Save.
    print(f"\n[5/5] 保存到 {out_dir}/...")
    save_jsonl(train, out_dir / "train.jsonl")
    save_jsonl(dev, out_dir / "dev.jsonl")
    save_jsonl(test, out_dir / "test.jsonl")
    save_jsonl(all_samples, out_dir / "all.jsonl")
    print(f" 保存完成train / dev / test / all")

    # 8. Statistics report.
    print(f"\n{'='*56}")
    print(f" 数据集统计报告")
    print(f"{'='*56}")
    print_stats("ALL (v5)", all_samples)
    print_stats("TRAIN", train)
    print_stats("DEV", dev)
    print_stats("TEST", test)
    coverage_check(all_samples)

    # 9. Language x split matrix.
    print("\n 语言 × 分割分布:")
    print(f" {'':12} {'train':>8} {'dev':>8} {'test':>8} {'total':>8}")
    for lang in ("zh", "en"):
        row = [sum(1 for s in split if s.get("lang") == lang)
               for split in (train, dev, test)]
        print(f" {lang:12} {row[0]:>8} {row[1]:>8} {row[2]:>8} {sum(row):>8}")

    print(f"\n{'='*56}")
    print(f" 构建完成!总样本数: {len(all_samples)}")
    print(f" 输出目录: {out_dir.resolve()}")
    print(f"{'='*56}\n")
# Run the build only when executed as a script, not when imported.
if __name__ == "__main__":
    main()