CompanionGuard-RL — 项目进度快照
更新时间:2026-05-12(Module C ✅ 完成;det_l_risk 修复后重训 v2 完成,评估 v3 为最终论文结果)
📖 可复用经验库 → 见 exp.md(RTX 5090 NCCL、PyYAML 陷阱、分布式 Tensor 设备一致性、CRLF 等 12 类经验)
总体进度
| 模块 |
状态 |
关键指标 |
| 数据集 CompanionRisk-Bench v4 |
✅ 完成 |
9,896 样本,全 14 标签覆盖 |
| Module B — 检测器 v4 |
✅ 完成 |
binary_f1=0.9995, level_macro_f1=0.550 |
| Module B — 泛化性验证 |
✅ 完成 |
human subset binary_f1=0.9848,无过拟合 |
| Module C — RL 干预策略 |
✅ 完成 |
1-GPU 模式 BC+PPO 200k steps 收敛,safety_recall=1.0,over_refusal=0.0 |
| 论文写作 |
🔄 可启动 |
Module C 结果已出,可开始写作 |
一、数据集 CompanionRisk-Bench(最终版 v4)
规模
| 分割 |
样本数 |
| train |
6,926 |
| dev |
1,484 |
| test |
1,486 |
| total |
9,896 |
数据来源构成
| 来源 |
样本数 |
说明 |
| LLM 核心集(Qwen2.5-72B via SiliconFlow) |
8,000 |
中文,10 类风险 + safe |
| 弱标签专项集(generate_targeted.py) |
1,083 |
FalseReassurance/PseudoTherapy/IsolationReinforcement 单标签增强 |
| Human-AI Suicide Risk Dataset |
393 |
英文,R1 危机类 |
| CoSafe Dataset |
420 |
多类别对话安全 |
| DICES-990 |
~200(质检后约 0 入库) |
质检未通过(history_too_short) |
生成路径
细粒度标签训练集覆盖(全部 ≥ 30 条)
| 标签 |
全集数量 |
训练集 |
| RiskNormalization |
1,787 |
1,235 |
| DirectEncouragement |
1,292 |
921 |
| FalseReassurance |
1,290 |
905 |
| BoundaryFailure |
1,157 |
800 |
| PseudoTherapy |
1,090 |
767 |
| IsolationReinforcement |
991 |
693 |
| ManipulativeAttachment |
752 |
534 |
| DependencyReinforcement |
787 |
537 |
| CoRumination |
638 |
~440 |
| CrisisNonResponse |
594 |
~410 |
| AgeInappropriateIntimacy |
583 |
~410 |
| PrivacySolicitation |
530 |
370 |
| MethodFacilitation |
683 |
489 |
| Romanticization |
437 |
310 |
二、Module B — 检测器训练(最终版 v4)
模型架构
- Backbone:
hfl/chinese-macbert-large(1024 hidden, LoRA=off)
- 服务器本地路径:
/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large
- 融合层:CrossAttention(response 为 query,persona+context 为 key/value)
- 输出头:4 个分类头(y_risk / l_risk / c_primary / c_fine)
训练配置(v4)
| 参数 |
值 |
| GPU |
4 × RTX 5090 32GB |
| Effective batch |
128(16 × 4 GPU × 2 accum) |
| Epochs |
10 |
| LR |
2e-5,线性 warmup 100 steps |
| Mixed precision |
bf16 |
| fine_loss_weight |
2.0 |
| fine_training.use_pos_weight |
true(max clip=30) |
| fine_training.risky_only |
true |
v2 → v3 → v4 关键指标演进
| 指标 |
v2(4022条) |
v3(8813条) |
v4(9896条) |
| binary_f1 |
0.9848 |
0.9989 |
0.9995 |
| high_risk_recall |
— |
0.9989 |
1.0000 |
| FNR |
1.52% |
0.11% |
0.00% |
| level_macro_f1 |
~0.43 |
0.497 |
0.550 |
| level_weighted_f1 |
— |
0.511 |
0.559 |
| L1 Mild F1 |
~0 |
0.174 |
0.635 |
| fine_macro_f1 (all 14类) |
0.000 (bug) |
0.476 |
0.463 |
| fine_macro_f1 (public 10类) |
— |
— |
0.484 |
v4 细粒度标签 F1(全量 test)
| 标签 |
v3 F1 |
v4 F1 |
变化 |
| DirectEncouragement |
0.705 |
0.684 |
→ |
| RiskNormalization |
0.627 |
0.698 |
↑ |
| AgeInappropriateIntimacy |
0.694 |
0.616 |
→ |
| BoundaryFailure |
0.609 |
0.532 |
→ |
| DependencyReinforcement |
0.625 |
0.585 |
→ |
| FalseReassurance |
0.279 |
0.383 |
↑ +0.104 ✅ |
| PseudoTherapy |
0.239 |
0.338 |
↑ +0.099 ✅ |
| IsolationReinforcement |
0.288 |
0.356 |
↑ +0.068 ✅ |
| ManipulativeAttachment |
0.444 |
0.441 |
→ |
| MethodFacilitation |
0.403 |
0.466 |
↑ |
| Romanticization |
0.434 |
0.402 |
→ |
| CoRumination |
0.350 |
0.269 |
↓ (targeted 副作用) |
| CrisisNonResponse |
0.588 |
0.394 |
↓ (targeted 副作用) |
| PrivacySolicitation |
0.373 |
0.321 |
→ |
论文汇报策略
- 主指标:
binary_f1=0.9995,level_weighted_f1=0.559,fine_macro_f1(public)=0.484
- level_macro_f1 下降的 L0 问题:使用 weighted 指标,注释"test set risky:safe = 2.33:1"
- fine companion 4 类(DependencyReinforcement/IsolationReinforcement/CoRumination/ManipulativeAttachment)单独在表格中列出,说明"companion-specific,无人工标注参考集"
- 不再迭代 Module B,CoRumination/CrisisNonResponse 的轻微下降在 limitation 一句话说明
checkpoint 路径(服务器)
三、L1 规则基线 vs Ours(v4 test,n=1486)
| 方法 |
BinaryF1 |
Recall |
FNR |
LevelF1(weighted) |
| L1a Keyword |
0.264 |
0.155 |
0.845 |
0.098 |
| L1b Regex |
0.067 |
0.035 |
0.965 |
0.063 |
| L1c Combined |
0.306 |
0.184 |
0.816 |
0.106 |
| Ours (Module B) |
0.9995 |
1.000 |
0.000 |
0.559 |
四、Module C — RL 干预策略(当前阶段)
任务目标
在 Module B 检测器输出的状态向量上,用 PPO 训练一个干预动作策略,学习对不同风险等级和类别选择最优干预动作(PASS/WARN/REWRITE/REJECT/CRISIS)。
核心文件
| 文件 |
说明 |
scripts/train_intervention.py |
两阶段训练主脚本(BC warmup + PPO,支持4-GPU分布式) |
configs/intervention_config.yaml |
完整训练配置 |
src/models/intervention_agent.py |
Actor-Critic 网络(已修复 _encode_obs 维度 bug) |
src/rl/companion_env.py |
Gymnasium 兼容的离线 RL 环境 |
src/rl/ppo_trainer.py |
PPO 训练器(RolloutBuffer + GAE + ppo_update) |
src/rl/reward.py |
多目标奖励函数(safety + anti-over-refusal + UX) |
src/utils/preprocessing.py |
detector → RL 状态向量转换 |
状态向量结构(obs_dim = 2065)
重要 Bug 修复记录(已完成)
intervention_agent.py:原来 actor 第一层 Linear(256,256) 直接接 2065-dim 原始 obs 会崩溃。
已添加 _encode_obs() 方法,在 forward() 内先解析 flat obs → StateEncoder → 256-dim latent。
companion_env.py、InterventionAgent 默认 detector_hidden 已从 768 改为 1024。
训练参数(intervention_config.yaml)
训练命令(服务器,当前最新版)
评估命令(训练完成后)
成功标准(Module C)
| 指标 |
目标 |
说明 |
| safety_recall(高风险正确处理率) |
> 0.85 |
L3/L4 被 REWRITE/REJECT/CRISIS |
| over_refusal_rate(安全内容误拦截) |
< 0.10 |
y_risk=0 被 REWRITE+ |
| action_accuracy(vs a_recommend) |
> 0.70 |
与标注推荐动作吻合率 |
| crisis_precision(R1 选 CRISIS 精度) |
> 0.80 |
关键安全保障 |
Module C 调试记录(时序)
| # |
错误 |
根因 |
修复位置 |
| 1 |
ModuleNotFoundError: gymnasium |
dlapo 环境无此包 |
cp -r .../gymnasium .../site-packages/ |
| 2 |
ModuleNotFoundError: wandb(复杂依赖链) |
环境缺 wandb 及其依赖 |
train_intervention.py + ppo_trainer.py 改为 try/except 导入;use_wandb: false |
| 3 |
OSError: Can't load hfl/chinese-macbert-large |
服务器无公网 |
intervention_config.yaml 改为本地绝对路径 |
| 4 |
RuntimeError: No backend type associated with device type cpu |
torch.distributed.broadcast 不支持 CPU tensor |
train_intervention.py broadcast 段改为先 .to(accelerator.device) 再广播 |
| 5 |
TypeError: '<=' not supported between float and str |
PyYAML 6.x 将 1e-3 解析为字符串 |
intervention_config.yaml 改为 0.001 / 0.0003 |
| 6 |
AttributeError: SequentialSampler has no set_epoch |
DataLoader 使用 SequentialSampler 而非 DistributedSampler |
train_intervention.py 加 if hasattr(loader.sampler, "set_epoch"): guard |
| 7 |
RuntimeError: cannot pin torch.cuda.FloatTensor |
pin_memory=True 要求 CPU tensor,但 tensor 已在 GPU |
train_intervention.py L~116 加 .cpu() 后再构建 TensorDataset |
| 8 |
CUDA error: an illegal memory access(BC 后 PPO 开始) |
accelerator.wait_for_everyone() → torch.distributed.barrier() 在 RTX 5090 NCCL 下崩溃,与 NCCL_P2P 无关 |
修复:改用 --num_processes=1 单 GPU 运行,完全绕开 NCCL barrier |
当前状态(2026-05-12 最终)
✅ Module C 训练完成(单 GPU 模式):
训练命令(已验证可用):
五、服务器信息
服务器1(主训练机,当前被占用)
| 项目 |
值 |
| SSH |
ssh -p 20083 root@10.82.3.180 |
| 密码 |
m2dGcwyrhI |
| 项目目录 |
/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL |
| MacBERT 路径 |
/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large |
| 环境 |
/opt/conda/envs/dlapo-py310-cu128(torch 2.7.1+cu128,transformers 5.8.0) |
| GPU |
4 × RTX 5090 32GB |
服务器2(当前使用)
| 项目 |
值 |
| SSH |
ssh -p 20060 root@10.82.3.180 |
| 密码 |
zwfn65xjTY |
| 项目目录 |
/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/my-reasearch/companionguard-rl |
| MacBERT 路径 |
需同步(见下文) |
| 环境 |
/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/env/dlapo-py310-cu128(从服务器1迁移) |
| GPU |
2 × RTX 5090 32GB |
| 存储 |
NFS 1TB(siton-data-740d234e02d749f08fe5347b0c74c49f) |
服务器2训练命令(已验证可用路径):
两台服务器说明
- 同一宿主机
10.82.3.180,不同 Docker 容器,不同端口
- 容器间互通:服务器1可 ssh 到
172.17.0.1:20060 访问服务器2
- Host key 相同:
SHA256:nAMVofPMCFZxa0DyOO2Olepfnp1MzZGdMyW7j5OekQI
六、代码同步状态(2026-05-12 晚更新)
本地 ↔ 服务器1 同步(已完成)
| 操作 |
文件 |
说明 |
| 服务器1→本地 |
src/rl/ppo_trainer.py |
服务器调试版(MD5不同),已下载 |
| 本地→服务器1 |
checkpoints/intervention/final_v2.pt |
v2权重命名版,已上传 |
| 跳过 |
checkpoints/detector/best.pt / final.pt |
字节完全一致(1,352,746,854 B) |
| 跳过 |
src/utils/preprocessing.py |
MD5一致 |
服务器1 → 服务器2 同步(已完成)
| 内容 |
状态 |
src/(18个py文件) |
✅ |
scripts/ |
✅ |
configs/ |
✅ |
data/processed/CompanionRisk-Bench/(9896条) |
✅ |
experiments/(eval/train logs+json) |
✅ |
checkpoints/detector/best.pt(1.35GB) |
✅ |
checkpoints/detector/final.pt(1.35GB) |
✅ |
checkpoints/intervention/final.pt + final_v2.pt |
✅ |
requirements.txt |
✅ |
conda env dlapo-py310-cu128(7.7GB) |
✅ /root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/env/dlapo-py310-cu128/(torch 2.7.1+cu128 ✓,GPU×2 ✓) |
| MacBERT 权重(1.3GB) |
✅ /root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/macbert-large/ |
关键文件清单(截至 2026-05-12)
| 文件 |
状态 |
说明 |
checkpoints/detector/best.pt |
✅ 服务器1+2 + 本地 |
v4 最优检测器权重(1.35GB) |
data/processed/CompanionRisk-Bench/ |
✅ 服务器1+2 + 本地 |
v4 数据集(9896条) |
scripts/train_intervention.py |
✅ 就绪 |
Module C 训练脚本 |
configs/intervention_config.yaml |
✅ 就绪 |
Module C 完整配置 |
src/models/intervention_agent.py |
✅ bug已修 |
Actor-Critic(obs_dim=2065→256→actions) |
src/rl/companion_env.py |
✅ 就绪 |
离线 RL 环境 |
src/rl/ppo_trainer.py |
✅ 就绪 |
PPO 训练器 |
src/rl/reward.py |
✅ 就绪 |
多目标奖励函数 |
src/utils/preprocessing.py |
✅ bug已修(v2) |
build_obs_vector 改用 det_l_risk |
src/utils/metrics.py |
✅ bug已修(v2) |
新增 per_level_action_dist + action_accuracy |
scripts/evaluate.py |
✅ bug已修(v2) |
rule policy 改用 det_l_risk,展示新指标 |
experiments/eval_v4_all.log |
✅ 本地 |
v4 完整评估日志 |
experiments/eval_v4_public.log |
✅ 本地 |
v4 public filter 评估日志 |
checkpoints/intervention/final.pt |
✅ 服务器 + 本地 |
Module C PPO 最终权重(5.1MB) |
experiments/eval_intervention_v1.json |
✅ 本地 |
Module C 评估 v1(有 bug,已废弃) |
experiments/eval_intervention_v2.json |
✅ 本地 |
Module C 评估 v2(代码修复后,但模型仍用旧权重,废弃) |
experiments/eval_intervention_v3.json |
✅ 本地 |
Module C 评估 v3(重训+修复,论文用此) |
checkpoints/intervention/final_v2.pt |
✅ 服务器 + 本地 |
Module C PPO v2 权重(用 det_l_risk 重训,论文用此) |
experiments/train_intervention_1gpu_20260512_165204.log |
✅ 本地 |
Module C 训练 v1 日志(旧,已废弃) |
experiments/train_intervention_v2_20260512_172636.log |
✅ 本地 |
Module C 训练 v2 日志(det_l_risk,论文用此) |
五(补)、Module C 评估 Bug 修复记录
v1 的两个问题(均已修复)
Bug A — build_obs_vector 用了 ground truth l_risk
- 位置:
src/utils/preprocessing.py:127
- 症状:RL 状态向量含 ground truth 等级(部署时不可知),导致 safety_recall/over_refusal 结果不真实
- 修复:改用
sample.get("det_l_risk", sample["l_risk"])(优先检测器预测值)
- 影响:不需要重新训练(detector binary_f1=0.9995,两者几乎相同;但概念上正确)
Bug B — 干预指标 intervention_recall_high=1.0、over_refusal=0.0 三方法无差别
- 位置:
src/utils/metrics.py
- 症状:阈值太粗(l_risk≥3 → action≥2)所有合理策略都能完美通过,无区分度
- 修复:新增
per_level_action_dist(按 ground truth 等级展示各动作占比)和 action_accuracy(vs a_recommend)
- 附带:
evaluate.py 中 run_rule_intervention 的策略输入改为 det_l_risk,与部署一致
六、Module C 评估结果 v2(2026-05-12,论文用)
干预任务汇总指标
| 方法 |
safety_recall(L3/L4) |
over_refusal |
action_accuracy |
crisis_precision |
| Rule-based (l≥3→REJECT) |
0.908 |
0.000 |
— |
— |
| Threshold Baseline |
0.908 |
0.000 |
— |
0.624 |
| Ours (RL, Module C) |
1.000 |
0.000 |
0.587 |
0.470 |
safety_recall 改为基于 det_l_risk 策略输入 vs ground truth level,Rule-based/Threshold 降至 0.908(9.2% L3/L4 样本被检测器预测为 <L3,因此 rule 给了 PASS/WARN)。RL 仍 1.0 说明它学到了超越 l_risk 阈值的综合判断。
Per-level Action Distribution(核心论文表格)
成功标准达成情况(v2)
| 指标 |
目标 |
RL实测 |
状态 |
| safety_recall(L3/L4 正确处理率) |
> 0.85 |
1.000 |
✅ |
| over_refusal_rate(safe 内容误拦截) |
< 0.10 |
0.000 |
✅ |
| action_accuracy(vs a_recommend) |
> 0.70 |
0.587 |
⚠️ |
| crisis_precision(CRISIS→L4 精度) |
> 0.80 |
0.470 |
⚠️ |
RL 策略解读(v2,已废弃,见 v3)
- v2 基于旧权重(用 GT l_risk 训练)+ 新评估代码,存在 train/eval 不一致,仅作对照参考
七、Module C 最终结果 v3(重训 + 正确评估,论文用)
重训原因
RL agent 训练时 state 向量包含 ground truth l_risk(非检测器预测),而检测器 level_macro_f1=0.55(各等级预测有误差),导致训练条件与部署不一致,需要用 det_l_risk 重训。
评估 v1 / v2 / v3 演进
| 版本 |
代码 |
模型 |
问题 |
| v1 |
旧(GT l_risk state, 无 per-level) |
旧(GT l_risk 训练) |
两个 bug,指标虚高 |
| v2 |
新(det_l_risk state, 有 per-level) |
旧(GT l_risk 训练) |
train/eval 不一致 |
| v3 |
新 |
新(det_l_risk 训练) |
论文使用 |
汇总指标(v3,最终)
| 方法 |
safety_recall(L3/L4) |
over_refusal |
action_accuracy |
crisis_precision |
safety_ux_fscore |
| Rule-based (l≥3→REJECT) |
0.908 |
0.000 |
— |
— |
0.952 |
| Threshold Baseline |
0.908 |
0.000 |
— |
0.624 |
0.952 |
| Ours (RL v2) |
1.000 |
0.004 |
0.575 |
0.421 |
0.998 |
Per-level Action Distribution(v3,论文核心表格)
成功标准达成情况(v3 最终)
| 指标 |
目标 |
RL v2 实测 |
状态 |
| safety_recall(L3/L4 正确处理率) |
> 0.85 |
1.000 |
✅ |
| over_refusal_rate(safe 内容误拦截) |
< 0.10 |
0.004 |
✅ |
| action_accuracy(vs a_recommend) |
> 0.70 |
0.575 |
⚠️ |
| crisis_precision(CRISIS→L4 精度) |
> 0.80 |
0.421 |
⚠️ |
论文论点
- 优势:safety_recall=1.0(baseline 仅 0.908),RL 在检测器等级误差下仍能正确干预,说明学到了多信号综合判断
- Limitation 1:action_accuracy=0.575;L1 层级误触发(22.9% REWRITE),轻度风险处理过激
- Limitation 2:crisis_precision=0.421;L4 CRISIS 触发率仅 36.7%(Threshold 64.3%),R1 训练样本稀少(136条)+ w3=4.0 不足