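# CompanionGuard-RL: Project Progress Snapshot

**Last updated: 2026-05-12 (Module C ✅ complete; after the det_l_risk fix the agent was retrained (v2), and evaluation v3 is the final result for the paper)**

> 📖 **Reusable lessons library** → see [`exp.md`](exp.md) (12 categories of lessons, including RTX 5090 NCCL, PyYAML pitfalls, device consistency of distributed tensors, and CRLF)

---
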
## Overall Progress

| Module | Status | Key metrics |
|------|------|---------|
| Dataset CompanionRisk-Bench v4 | ✅ Done | 9,896 samples; all 14 labels covered |
| Module B: Detector v4 | ✅ **Done** | binary_f1=0.9995, level_macro_f1=0.550 |
| Module B: Generalization check | ✅ Done | human subset binary_f1=0.9848, no overfitting |
| Module C: RL intervention policy | ✅ **Done** | Single-GPU BC + PPO converged in 200k steps; safety_recall=1.0, over_refusal=0.004 (eval v3) |
| Paper writing | 🔄 **Ready to start** | Module C results are in; writing can begin |

---

## 1. Dataset CompanionRisk-Bench (final v4)

### Size

| Split | Samples |
|------|--------|
| train | 6,926 |
| dev | 1,484 |
| test | 1,486 |
| **total** | **9,896** |

### Data sources

| Source | Samples | Notes |
|------|--------|------|
| LLM core set (Qwen2.5-72B via SiliconFlow) | 8,000 | Chinese; 10 risk categories + safe |
| Targeted set for weak labels (generate_targeted.py) | 1,083 | Single-label augmentation for FalseReassurance / PseudoTherapy / IsolationReinforcement |
| Human-AI Suicide Risk Dataset | 393 | English; R1 crisis category |
| CoSafe Dataset | 420 | Multi-category dialogue safety |
| DICES-990 | ~200 (≈0 admitted after QC) | Failed QC (history_too_short) |

### Generation pipeline

```
scripts/generate_siliconflow.py  → data/raw/generated_core.jsonl      (8,000 samples)
scripts/generate_targeted.py     → data/raw/generated_targeted.jsonl  (1,083 samples)
scripts/adapt_public_datasets.py → data/raw/adapted_*.jsonl
scripts/merge_and_split.py       → data/processed/CompanionRisk-Bench/{train,dev,test,all}.jsonl
```

### Fine-grained label coverage in the training set (all ≥ 30)

| Label | Full set | Train |
|------|---------|--------|
| RiskNormalization | 1,787 | 1,235 |
| DirectEncouragement | 1,292 | 921 |
| FalseReassurance | 1,290 | 905 |
| BoundaryFailure | 1,157 | 800 |
| PseudoTherapy | 1,090 | 767 |
| IsolationReinforcement | 991 | 693 |
| ManipulativeAttachment | 752 | 534 |
| DependencyReinforcement | 787 | 537 |
| CoRumination | 638 | ~440 |
| CrisisNonResponse | 594 | ~410 |
| AgeInappropriateIntimacy | 583 | ~410 |
| PrivacySolicitation | 530 | 370 |
| MethodFacilitation | 683 | 489 |
| Romanticization | 437 | 310 |

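To re-check the "all ≥ 30" claim after regenerating data, a minimal sketch like the following counts fine-grained labels in a split. It assumes (unverified) that each JSONL line stores its fine labels as a list under a `c_fine` key, a name borrowed from the detector's output head; `fine_label_counts` and `under_covered` are hypothetical helpers, not repo code.

```python
import json
from collections import Counter
from pathlib import Path
from typing import Iterable, List

# The 14 fine-grained labels from the coverage table.
FINE_LABELS = (
    "RiskNormalization", "DirectEncouragement", "FalseReassurance", "BoundaryFailure",
    "PseudoTherapy", "IsolationReinforcement", "ManipulativeAttachment",
    "DependencyReinforcement", "CoRumination", "CrisisNonResponse",
    "AgeInappropriateIntimacy", "PrivacySolicitation", "MethodFacilitation", "Romanticization",
)


def fine_label_counts(path: Path, key: str = "c_fine") -> Counter:
    """Count fine labels in a JSONL split (assumes `key` holds a list of label strings)."""
    counts: Counter = Counter()
    with path.open(encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue  # skip blank lines
            counts.update(json.loads(line).get(key, []))
    return counts


def under_covered(counts: Counter, labels: Iterable[str] = FINE_LABELS,
                  minimum: int = 30) -> List[str]:
    """Labels with fewer than `minimum` occurrences (labels never seen count as 0)."""
    return sorted(label for label in labels if counts.get(label, 0) < minimum)


# Usage (hypothetical schema):
# under_covered(fine_label_counts(Path("data/processed/CompanionRisk-Bench/train.jsonl")))
# returns [] when every label has at least 30 training samples.
```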
---

## 2. Module B: Detector Training (final v4)

### Model architecture

- **Backbone**: `hfl/chinese-macbert-large` (1024 hidden, LoRA=off)
  - Local path on the server: `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large`
- **Fusion layer**: CrossAttention (response as query, persona + context as key/value)
- **Output heads**: 4 classification heads (y_risk / l_risk / c_primary / c_fine)

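The fusion step can be illustrated with a single-head, projection-free version of scaled dot-product cross-attention in plain Python. This is a sketch of the general technique, not the repo's CrossAttention module: learned Q/K/V projections, multiple heads, masking and batching are all omitted, and persona and context token vectors are simply concatenated to form the key/value memory.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: List[Vector], keys: List[Vector],
                    values: List[Vector]) -> List[Vector]:
    """Single-head scaled dot-product cross-attention.

    queries: response token vectors (T_r x d)
    keys / values: persona + context token vectors (T_c x d and T_c x d_v)
    Returns one fused vector per response token (T_r x d_v).
    """
    if len(keys) != len(values):
        raise ValueError("keys and values must have the same length")
    scale = 1.0 / math.sqrt(len(keys[0]))
    d_v = len(values[0])
    fused = []
    for q in queries:
        # Similarity of this response token to every persona/context token.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        weights = softmax(scores)
        # Weighted average of the value vectors.
        fused.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_v)])
    return fused


# Usage: memory = persona_vecs + context_vecs
#        fused = cross_attention(response_vecs, memory, memory)
```

Keeping the response as the query means the output has one row per response token, so downstream heads score the reply in light of persona and context rather than the other way round.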
### Training configuration (v4)

| Parameter | Value |
|------|----|
| GPU | 4 × RTX 5090 32GB |
| Effective batch | 128 (16 per GPU × 4 GPUs × 2 accumulation steps) |
| Epochs | 10 |
| LR | 2e-5, linear warmup over 100 steps |
| Mixed precision | bf16 |
| fine_loss_weight | 2.0 |
| fine_training.use_pos_weight | true (max clip=30) |
| fine_training.risky_only | true |

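Three quantities in this table are easy to misread, so here is a small sketch. The effective batch is per-GPU batch × GPUs × accumulation steps. The table only states a 100-step linear warmup; the linear decay to zero after it is an assumption, as is the pos_weight rule (negatives / positives per label, capped at `max clip=30`), which is the common recipe for BCE-style losses but is not confirmed for this repo. All function names are hypothetical.

```python
import math
from typing import Dict


def effective_batch(per_gpu: int, n_gpus: int, grad_accum: int) -> int:
    """Samples contributing to one optimizer step."""
    return per_gpu * n_gpus * grad_accum


def warmup_linear_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to base_lr, then (assumed) linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


def clipped_pos_weight(pos_counts: Dict[str, int], n_samples: int,
                       max_clip: float = 30.0) -> Dict[str, float]:
    """Per-label pos_weight = negatives / positives, capped at max_clip.

    With fine_training.risky_only=true, n_samples would presumably be the
    number of risky training samples (assumption).
    """
    weights: Dict[str, float] = {}
    for label, pos in pos_counts.items():
        weights[label] = max_clip if pos <= 0 else min(max_clip, (n_samples - pos) / pos)
    return weights


BATCH = effective_batch(per_gpu=16, n_gpus=4, grad_accum=2)  # 128
STEPS_PER_EPOCH = math.ceil(6926 / BATCH)  # about 55; exact value depends on sharding/drop_last
TOTAL_STEPS = 10 * STEPS_PER_EPOCH
```

Under these assumptions the 100 warmup steps cover roughly the first two of the ten epochs, which is worth keeping in mind when comparing early-epoch dev scores.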
### Key metrics across v2 → v3 → v4

| Metric | v2 (4,022 samples) | v3 (8,813 samples) | v4 (9,896 samples) |
|-------------|-------------|-------------|-------------|
| binary_f1 | 0.9848 | 0.9989 | **0.9995** |
| high_risk_recall | n/a | 0.9989 | **1.0000** |
| FNR | 1.52% | 0.11% | **0.00%** |
| level_macro_f1 | ~0.43 | 0.497 | **0.550** |
| level_weighted_f1 | n/a | 0.511 | **0.559** |
| L1 Mild F1 | ~0 | 0.174 | **0.635** |
| fine_macro_f1 (all 14 labels) | 0.000 (bug) | 0.476 | 0.463 |
| fine_macro_f1 (10 public labels) | n/a | n/a | **0.484** |

### v4 fine-grained label F1 (full test set)

| Label | v3 F1 | v4 F1 | Change |
|------|-------|-------|------|
| DirectEncouragement | 0.705 | 0.684 | → -0.021 |
| RiskNormalization | 0.627 | 0.698 | ↑ +0.071 |
| AgeInappropriateIntimacy | 0.694 | 0.616 | ↓ -0.078 |
| BoundaryFailure | 0.609 | 0.532 | ↓ -0.077 |
| DependencyReinforcement | 0.625 | 0.585 | ↓ -0.040 |
| FalseReassurance | 0.279 | **0.383** | ↑ +0.104 ✅ |
| PseudoTherapy | 0.239 | **0.338** | ↑ +0.099 ✅ |
| IsolationReinforcement | 0.288 | **0.356** | ↑ +0.068 ✅ |
| ManipulativeAttachment | 0.444 | 0.441 | → -0.003 |
| MethodFacilitation | 0.403 | 0.466 | ↑ +0.063 |
| Romanticization | 0.434 | 0.402 | ↓ -0.032 |
| CoRumination | 0.350 | 0.269 | ↓ -0.081 (side effect of the targeted set) |
| CrisisNonResponse | 0.588 | 0.394 | ↓ -0.194 (side effect of the targeted set) |
| PrivacySolicitation | 0.373 | 0.321 | ↓ -0.052 |

### Reporting strategy for the paper

- **Headline metrics**: `binary_f1=0.9995`, `level_weighted_f1=0.559`, `fine_macro_f1(public)=0.484`
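- The L0 class drags down level_macro_f1: report the weighted metric instead, with the note "test set risky:safe ≈ 5.3:1" (1,249 risky vs. 237 safe, per the per-level counts in the Module C tables; the earlier note said 2.33:1, which does not match these counts)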
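- List the 4 companion-specific fine labels (DependencyReinforcement / IsolationReinforcement / CoRumination / ManipulativeAttachment) separately in the table, noting "companion-specific; no human-annotated reference set"
- No further Module B iterations; the drops in CoRumination (-0.081) and CrisisNonResponse (-0.194) get one sentence in the limitations
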
### Checkpoint path (server)

```
checkpoints/detector/best.pt    ← weights used (frozen) by Module C
```

---

## 3. L1 Rule-Based Baselines vs. Ours (v4 test, n=1486)

| Method | BinaryF1 | Recall | FNR | LevelF1 (weighted) |
|------|---------|--------|-----|-------------------|
| L1a Keyword | 0.264 | 0.155 | 0.845 | 0.098 |
| L1b Regex | 0.067 | 0.035 | 0.965 | 0.063 |
| L1c Combined | 0.306 | 0.184 | 0.816 | 0.106 |
| **Ours (Module B)** | **0.9995** | **1.000** | **0.000** | **0.559** |

---

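## 4. Module C: RL Intervention Policy

### Objective

On top of the state vector built from Module B detector outputs, train an intervention policy with PPO that learns to pick the best intervention action (PASS/WARN/REWRITE/REJECT/CRISIS) for each risk level and category.
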
### Core files

| File | Description |
|------|------|
| `scripts/train_intervention.py` | Main two-stage training script (BC warm-up + PPO; supports 4-GPU distributed training) |
| `configs/intervention_config.yaml` | Full training configuration |
| `src/models/intervention_agent.py` | Actor-Critic network (`_encode_obs` dimension bug fixed) |
| `src/rl/companion_env.py` | Gymnasium-compatible offline RL environment |
| `src/rl/ppo_trainer.py` | PPO trainer (RolloutBuffer + GAE + ppo_update) |
| `src/rl/reward.py` | Multi-objective reward (safety + anti-over-refusal + UX) |
| `src/utils/preprocessing.py` | Converts detector outputs into the RL state vector |

### State vector layout (obs_dim = 2065)

```
[d_score(1) | l_risk_onehot(5) | c_primary_probs(10) | e_H_pool(1024) | e_P_pool(1024) | t_norm(1)]
= 1 + 5 + 10 + 1024 + 1024 + 1 = 2065
```

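A plain-Python sketch of how the flat observation is laid out. Component names and widths come from the layout above; `build_obs` and `one_hot` are hypothetical helpers, not functions from `src/utils/preprocessing.py`. Checking every width at assembly time catches mismatches such as a 768-dim embedding being fed where 1024 is expected.

```python
from typing import Dict, List, Sequence

# (name, width) in concatenation order, as in the layout above.
OBS_LAYOUT = [
    ("d_score", 1),
    ("l_risk_onehot", 5),
    ("c_primary_probs", 10),
    ("e_H_pool", 1024),
    ("e_P_pool", 1024),
    ("t_norm", 1),
]
OBS_DIM = sum(width for _, width in OBS_LAYOUT)  # 2065


def one_hot(index: int, size: int) -> List[float]:
    """One-hot encoding, e.g. for the 5 risk levels."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def build_obs(parts: Dict[str, Sequence[float]]) -> List[float]:
    """Concatenate named components into one flat observation, checking each width."""
    obs: List[float] = []
    for name, width in OBS_LAYOUT:
        values = list(parts[name])
        if len(values) != width:
            raise ValueError(f"{name}: expected {width} values, got {len(values)}")
        obs.extend(values)
    return obs
```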
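### Important bug fixes (done)

- `intervention_agent.py`: the first `actor` layer, `Linear(256, 256)`, originally received the raw 2065-dim obs directly and crashed on the shape mismatch.
  Added an `_encode_obs()` method: inside `forward()`, the flat obs is first parsed into its components, passed through the StateEncoder, and mapped to a 256-dim latent.
- The default `detector_hidden` in `companion_env.py` and `InterventionAgent` was changed from 768 to 1024 (matching the 1024 hidden size of MacBERT-large).
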
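The first half of the `_encode_obs()` fix, splitting the flat obs back into its components before the StateEncoder, can be sketched in plain Python. This is an illustrative stand-in, not the repo's code: the real method works on batched tensors, and the StateEncoder that produces the 256-dim latent is not reproduced. `split_obs` and `actor_accepts` are hypothetical names.

```python
from itertools import accumulate
from typing import Dict, List, Sequence

# Same (name, width) layout as the state-vector section.
OBS_LAYOUT = [("d_score", 1), ("l_risk_onehot", 5), ("c_primary_probs", 10),
              ("e_H_pool", 1024), ("e_P_pool", 1024), ("t_norm", 1)]
ACTOR_IN_FEATURES = 256  # first actor layer is Linear(256, 256)


def split_obs(obs: Sequence[float]) -> Dict[str, List[float]]:
    """Slice a flat observation into named components using cumulative offsets."""
    widths = [w for _, w in OBS_LAYOUT]
    expected = sum(widths)
    if len(obs) != expected:
        raise ValueError(f"expected obs of length {expected}, got {len(obs)}")
    ends = list(accumulate(widths))
    starts = [end - w for end, w in zip(ends, widths)]
    return {name: list(obs[s:e]) for (name, _), s, e in zip(OBS_LAYOUT, starts, ends)}


def actor_accepts(features: int) -> bool:
    """The first actor layer only accepts inputs of width ACTOR_IN_FEATURES."""
    return features == ACTOR_IN_FEATURES
```

`actor_accepts(2065)` is false, which is the original crash in one line; after the StateEncoder the latent has the 256 features the actor expects.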
### Training parameters (intervention_config.yaml)

```yaml
behavior_cloning:
  enabled: true
  epochs: 5
  per_gpu_batch_size: 256
  lr: 0.001          # originally written as 1e-3, which PyYAML loads as a string (debug record #5)

ppo:
  total_timesteps: 200000
  n_rollout_steps: 2048
  n_epochs: 4
  batch_size: 256
  lr: 0.0003         # originally 3e-4 (same PyYAML issue)
  clip_eps: 0.2
  entropy_coef: 0.01
  gamma: 0.99
  gae_lambda: 0.95

reward:
  w1: 2.0   # safety gain
  w2: 3.0   # false negative penalty
  w3: 4.0   # crisis bonus (R1)
  w4: 1.5   # over-refusal penalty
  w5: 0.5   # UX cost
```

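For reference, `gamma` and `gae_lambda` enter through generalized advantage estimation. The stdlib sketch below is the textbook GAE recursion, not the code in `src/rl/ppo_trainer.py`; it assumes `dones[t]` marks an episode ending after step `t` and bootstraps the step after the rollout from `last_value`. It also shows where the final step count in the training log comes from: with `n_rollout_steps: 2048`, reaching `total_timesteps: 200000` takes ceil(200000 / 2048) = 98 rollouts.

```python
import math
from typing import List, Sequence, Tuple


def compute_gae(rewards: Sequence[float], values: Sequence[float], dones: Sequence[bool],
                last_value: float, gamma: float = 0.99,
                lam: float = 0.95) -> Tuple[List[float], List[float]]:
    """Generalized Advantage Estimation over one rollout.

    dones[t] is True when the episode ended after step t; last_value = V(s_T)
    bootstraps the step after the rollout. Returns (advantages, returns).
    """
    n = len(rewards)
    advantages = [0.0] * n
    gae = 0.0
    for t in reversed(range(n)):
        next_value = last_value if t == n - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0
        # TD error; the bootstrap term is dropped when the episode ended at t.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Discounted sum of future TD errors, reset at episode boundaries.
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    returns = [adv + v for adv, v in zip(advantages, values)]
    return advantages, returns


# Rollouts needed for the configured budget: 98, i.e. 98 * 2048 = 200704 env steps,
# which matches "Steps 200704/200000" in the training log.
N_UPDATES = math.ceil(200_000 / 2048)
```

If each episode in the offline environment is a single decision (not stated in this document), every `dones[t]` is true and the advantage reduces to `r_t - V(s_t)`, so `gae_lambda` has no effect.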
### Training command (server, 4-GPU version; superseded)

```bash
cd /root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL

# Note: both SHM and P2P must be disabled; otherwise NCCL on RTX 5090 fails with "CUDA illegal memory access"
CUDA_VISIBLE_DEVICES=0,1,2,3 NCCL_SHM_DISABLE=1 NCCL_P2P_DISABLE=1 \
/opt/conda/envs/dlapo-py310-cu128/bin/accelerate launch \
  --num_processes=4 --mixed_precision=bf16 \
  scripts/train_intervention.py \
  --config configs/intervention_config.yaml \
  --train-data data/processed/CompanionRisk-Bench/train.jsonl \
  > experiments/train_intervention_$(date +%Y%m%d_%H%M%S).log 2>&1 &
```

> Even with these flags, this 4-GPU launch still crashes once PPO starts (debug record #8 below). The working single-GPU command is under "Current status".

### Evaluation command (after training)

```bash
python scripts/evaluate.py \
  --detector-ckpt checkpoints/detector/best.pt \
  --agent-ckpt checkpoints/intervention/final.pt \
  --test-data data/processed/CompanionRisk-Bench/test.jsonl \
  --config configs/detector_config_server.yaml \
  --intervention-config configs/intervention_config.yaml \
  --output experiments/eval_intervention_v1.json
```

### Success criteria (Module C)

| Metric | Target | Notes |
|------|------|------|
| safety_recall (share of high-risk samples handled correctly) | > 0.85 | L3/L4 samples receive REWRITE/REJECT/CRISIS |
| over_refusal_rate (safe content wrongly blocked) | < 0.10 | y_risk=0 samples receive REWRITE or stronger |
| action_accuracy (vs. a_recommend) | > 0.70 | Agreement with the annotated recommended action |
| crisis_precision (share of CRISIS actions on L4 samples) | > 0.80 | Key safety guarantee |

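The four criteria reduce to simple ratios. A sketch under stated assumptions: actions are encoded 0..4 in PASS/WARN/REWRITE/REJECT/CRISIS order (so "REWRITE or stronger" means `action >= 2`), levels are ground-truth 0..4 with L0 meaning y_risk=0, and `crisis_precision` is the share of CRISIS actions that land on L4 samples, which is the reading the per-level tables later in this document reproduce. Function names are illustrative, not the API of `src/utils/metrics.py`.

```python
from enum import IntEnum
from typing import Optional, Sequence


class Action(IntEnum):
    # Assumed ordering; "REWRITE or stronger" == action >= REWRITE.
    PASS = 0
    WARN = 1
    REWRITE = 2
    REJECT = 3
    CRISIS = 4


def _ratio(num: int, den: int) -> Optional[float]:
    return num / den if den else None  # undefined when there is nothing to score


def safety_recall(levels: Sequence[int], actions: Sequence[int]) -> Optional[float]:
    """Share of ground-truth L3/L4 samples handled with REWRITE, REJECT or CRISIS."""
    high = [a for lvl, a in zip(levels, actions) if lvl >= 3]
    return _ratio(sum(a >= Action.REWRITE for a in high), len(high))


def over_refusal_rate(levels: Sequence[int], actions: Sequence[int]) -> Optional[float]:
    """Share of safe (L0) samples given REWRITE or stronger."""
    safe = [a for lvl, a in zip(levels, actions) if lvl == 0]
    return _ratio(sum(a >= Action.REWRITE for a in safe), len(safe))


def action_accuracy(actions: Sequence[int], recommended: Sequence[int]) -> Optional[float]:
    """Agreement with the annotated a_recommend."""
    return _ratio(sum(a == r for a, r in zip(actions, recommended)), len(actions))


def crisis_precision(levels: Sequence[int], actions: Sequence[int]) -> Optional[float]:
    """Share of CRISIS actions that fall on ground-truth L4 samples."""
    crisis_levels = [lvl for lvl, a in zip(levels, actions) if a == Action.CRISIS]
    return _ratio(sum(lvl == 4 for lvl in crisis_levels), len(crisis_levels))
```

Note that safety_recall and over_refusal_rate are scored against the ground-truth level even when the policy itself only sees the detector's prediction.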
### Module C debug log (chronological)

| # | Error | Root cause | Fix |
|---|------|------|---------|
| 1 | `ModuleNotFoundError: gymnasium` | The dlapo environment lacks this package | `cp -r .../gymnasium .../site-packages/` |
| 2 | `ModuleNotFoundError: wandb` (complex dependency chain) | The environment lacks wandb and its dependencies | `train_intervention.py` and `ppo_trainer.py` now import wandb inside `try/except`; set `use_wandb: false` |
| 3 | `OSError: Can't load hfl/chinese-macbert-large` | The server has no internet access | `intervention_config.yaml` now uses the local absolute path |
| 4 | `RuntimeError: No backend type associated with device type cpu` | `torch.distributed.broadcast` under the NCCL backend does not support CPU tensors | In `train_intervention.py`, the broadcast section now calls `.to(accelerator.device)` before broadcasting |
| 5 | `TypeError: '<=' not supported between float and str` | PyYAML 6.x parses `1e-3` as a string | `intervention_config.yaml` now uses `0.001` / `0.0003` |
| 6 | `AttributeError: SequentialSampler has no set_epoch` | The DataLoader uses a SequentialSampler rather than a DistributedSampler | `train_intervention.py` adds an `if hasattr(loader.sampler, "set_epoch"):` guard |
| 7 | `RuntimeError: cannot pin torch.cuda.FloatTensor` | `pin_memory=True` requires CPU tensors, but the tensors were already on the GPU | `train_intervention.py` L~116 calls `.cpu()` before building the TensorDataset |
| 8 | **`CUDA error: an illegal memory access`** (when PPO starts after BC) | `accelerator.wait_for_everyone()` → `torch.distributed.barrier()` crashes under NCCL on RTX 5090; unrelated to NCCL_P2P | **Workaround**: run on a single GPU with `--num_processes=1`, bypassing the NCCL barrier entirely |

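Background for row 5: PyYAML resolves floats with the YAML 1.1 pattern, which requires a `.` in the mantissa, so `yaml.safe_load("lr: 1e-3")` returns `{'lr': '1e-3'}`, a string. Writing `0.001` (as done) or `1.0e-3` avoids this. A defensive coercion step after loading is another option; the helper below is a hypothetical sketch, not part of `train_intervention.py`.

```python
from typing import Any, Dict, Iterable, Mapping


def coerce_floats(section: Mapping[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
    """Copy a config section, converting string values of `keys` to float.

    Guards against YAML 1.1 loaders returning '1e-3' as a string; float()
    raises ValueError for values that are not numbers at all.
    """
    out = dict(section)
    for key in keys:
        if isinstance(out.get(key), str):
            out[key] = float(out[key])
    return out


# What PyYAML hands back for the original `lr: 1e-3` / `lr: 3e-4` entries:
raw_cfg = {"behavior_cloning": {"lr": "1e-3", "epochs": 5},
           "ppo": {"lr": "3e-4", "clip_eps": 0.2}}
cfg = {name: coerce_floats(sec, ["lr", "clip_eps", "entropy_coef"])
       for name, sec in raw_cfg.items()}
```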
### Current status (2026-05-12, final)

**✅ Module C training complete (single-GPU mode):**

```
Running on 1 GPU(s), mixed_precision=bf16
=== Stage 1: Behavior Cloning (1 GPU) ===        ← BC converges normally
=== Stage 2: PPO Fine-tuning (GPU-0) ===
[PPO] Update 98 | Steps 200704/200000            ← PPO done
Training complete. Final model: checkpoints/intervention/final.pt
```

**Training command (verified):**

```bash
cd /root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL
export PYTHONPATH=$PWD
CUDA_VISIBLE_DEVICES=0 \
/opt/conda/envs/dlapo-py310-cu128/bin/accelerate launch \
  --num_processes=1 --mixed_precision=bf16 \
  scripts/train_intervention.py \
  --config configs/intervention_config.yaml \
  --train-data data/processed/CompanionRisk-Bench/train.jsonl
```

---

## 5. Server Information

### Server 1 (main training machine, currently occupied)

| Item | Value |
|------|----|
| SSH | `ssh -p 20083 root@10.82.3.180` |
| Password | `m2dGcwyrhI` |
| Project directory | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL` |
| MacBERT path | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large` |
| Environment | `/opt/conda/envs/dlapo-py310-cu128` (torch 2.7.1+cu128, transformers 5.8.0) |
| GPU | 4 × RTX 5090 32GB |

### Server 2 (currently in use)

| Item | Value |
|------|----|
| SSH | `ssh -p 20060 root@10.82.3.180` |
| Password | `zwfn65xjTY` |
| Project directory | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/my-reasearch/companionguard-rl` |
| MacBERT path | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/macbert-large` (synced; see the sync table below) |
| Environment | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/env/dlapo-py310-cu128` (migrated from server 1) |
| GPU | 2 × RTX 5090 32GB |
| Storage | NFS 1TB (`siton-data-740d234e02d749f08fe5347b0c74c49f`) |

> **Server 2 training command (verified paths):**
>
> ```bash
> PROJ=/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/my-reasearch/companionguard-rl
> PY=/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/env/dlapo-py310-cu128/bin
> cd $PROJ && export PYTHONPATH=$PROJ
>
> # Single GPU (recommended; avoids NCCL)
> CUDA_VISIBLE_DEVICES=0 $PY/accelerate launch --num_processes=1 --mixed_precision=bf16 \
>   scripts/train_intervention.py --config configs/intervention_config.yaml \
>   --train-data data/processed/CompanionRisk-Bench/train.jsonl
>
> # In detector_config_server.yaml, set model.name to:
> # /root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/macbert-large
> ```

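### Notes on the two servers

- Both run on the same host `10.82.3.180`, in different Docker containers on different ports
- The containers can reach each other: from server 1, server 2 is reachable over SSH at `172.17.0.1`, port `20060` (`ssh -p 20060 root@172.17.0.1`)
- Identical host key: `SHA256:nAMVofPMCFZxa0DyOO2Olepfnp1MzZGdMyW7j5OekQI`

---
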
## 6. Code Sync Status (updated evening of 2026-05-12)

### Local ↔ Server 1 sync (done)

| Action | File | Notes |
|------|------|------|
| Server 1 → local | `src/rl/ppo_trainer.py` | The server's debugged version (MD5 differed); downloaded |
| Local → server 1 | `checkpoints/intervention/final_v2.pt` | v2 weights under the new filename; uploaded |
| Skipped | `checkpoints/detector/best.pt / final.pt` | Byte-identical (1,352,746,854 B) |
| Skipped | `src/utils/preprocessing.py` | MD5 identical |

### Server 1 → Server 2 sync (done)

| Content | Status |
|------|------|
| `src/` (18 .py files) | ✅ |
| `scripts/` | ✅ |
| `configs/` | ✅ |
| `data/processed/CompanionRisk-Bench/` (9,896 samples) | ✅ |
| `experiments/` (eval/train logs + JSON) | ✅ |
| `checkpoints/detector/best.pt` (1.35GB) | ✅ |
| `checkpoints/detector/final.pt` (1.35GB) | ✅ |
| `checkpoints/intervention/final.pt + final_v2.pt` | ✅ |
| `requirements.txt` | ✅ |
| conda env `dlapo-py310-cu128` (7.7GB) | ✅ `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/env/dlapo-py310-cu128/` (torch 2.7.1+cu128 ✓, GPU×2 ✓) |
| MacBERT weights (1.3GB) | ✅ `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/macbert-large/` |

### Key file inventory (as of 2026-05-12)

| File | Status | Notes |
|------|------|------|
| `checkpoints/detector/best.pt` | ✅ Servers 1+2 + local | Best v4 detector weights (1.35GB) |
| `data/processed/CompanionRisk-Bench/` | ✅ Servers 1+2 + local | v4 dataset (9,896 samples) |
| `scripts/train_intervention.py` | ✅ Ready | Module C training script |
| `configs/intervention_config.yaml` | ✅ Ready | Full Module C configuration |
| `src/models/intervention_agent.py` | ✅ Bug fixed | Actor-Critic (obs_dim=2065→256→actions) |
| `src/rl/companion_env.py` | ✅ Ready | Offline RL environment |
| `src/rl/ppo_trainer.py` | ✅ Ready | PPO trainer |
| `src/rl/reward.py` | ✅ Ready | Multi-objective reward |
| `src/utils/preprocessing.py` | ✅ Bug fixed (v2) | build_obs_vector now uses det_l_risk |
| `src/utils/metrics.py` | ✅ Bug fixed (v2) | Adds per_level_action_dist + action_accuracy |
| `scripts/evaluate.py` | ✅ Bug fixed (v2) | Rule policy now uses det_l_risk; reports the new metrics |
| `experiments/eval_v4_all.log` | ✅ Local | Full v4 evaluation log |
| `experiments/eval_v4_public.log` | ✅ Local | v4 evaluation log with the public-label filter |
| `checkpoints/intervention/final.pt` | ✅ Server + local | Module C PPO final weights (5.1MB) |
| `experiments/eval_intervention_v1.json` | ✅ Local | Module C evaluation v1 (buggy; obsolete) |
| `experiments/eval_intervention_v2.json` | ✅ Local | Module C evaluation v2 (after the code fix, but still with the old weights; obsolete) |
| `experiments/eval_intervention_v3.json` | ✅ Local | Module C evaluation v3 (retrained + fixed; **used in the paper**) |
| `checkpoints/intervention/final_v2.pt` | ✅ Server + local | Module C PPO v2 weights (retrained on det_l_risk; **used in the paper**) |
| `experiments/train_intervention_1gpu_20260512_165204.log` | ✅ Local | Module C training v1 log (old; obsolete) |
| `experiments/train_intervention_v2_20260512_172636.log` | ✅ Local | Module C training v2 log (det_l_risk; **used in the paper**) |

---

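## 7. Module C Evaluation Bug-Fix Record

### Two problems in v1 (both fixed)

**Bug A: `build_obs_vector` used the ground-truth `l_risk`**

- **Location**: `src/utils/preprocessing.py:127`
- **Symptom**: the RL state vector contained the ground-truth level, which is unknown at deployment, so the safety_recall/over_refusal results were unrealistic
- **Fix**: switched to `sample.get("det_l_risk", sample["l_risk"])` (prefers the detector's prediction)
- **Impact**: initially judged not to require retraining, on the grounds that binary_f1=0.9995 made the two nearly identical. That reasoning did not hold: `l_risk` is the 5-level label, where the detector is much weaker (level_macro_f1=0.550), and the agent was later retrained on `det_l_risk` (see the v3 section below)
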
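One caveat about the fix as written: `sample.get("det_l_risk", sample["l_risk"])` evaluates `sample["l_risk"]` before the lookup, so it raises `KeyError` on unlabeled deployment samples even when `det_l_risk` is present, and it silently falls back to the ground truth when the prediction is missing, which is exactly the leak Bug A was about. A stricter helper could look like this (the name `state_level` and the `allow_gt_fallback` flag are hypothetical):

```python
from typing import Any, Mapping


def state_level(sample: Mapping[str, Any], allow_gt_fallback: bool = False) -> int:
    """Risk level to feed into the RL state.

    Prefers the detector's prediction (`det_l_risk`). Falling back to the
    ground-truth `l_risk` must be requested explicitly, so a missing prediction
    cannot quietly leak labels into training or evaluation.
    """
    if "det_l_risk" in sample:
        return int(sample["det_l_risk"])
    if allow_gt_fallback and "l_risk" in sample:
        return int(sample["l_risk"])
    raise KeyError("sample has no det_l_risk and ground-truth fallback is disabled")
```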
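**Bug B: intervention metrics `intervention_recall_high`=1.0 and `over_refusal`=0.0 for all three methods, with no separation**

- **Location**: `src/utils/metrics.py`
- **Symptom**: the criterion was too coarse (l_risk≥3 → action≥2); every reasonable policy passed it perfectly, so it could not discriminate between methods
- **Fix**: added `per_level_action_dist` (share of each action per ground-truth level) and `action_accuracy` (vs. a_recommend)
- **Also**: in `evaluate.py`, `run_rule_intervention` now takes `det_l_risk` as its policy input, matching deployment
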
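The per-level metric is a grouped, normalized count. The stdlib sketch below mirrors the name `per_level_action_dist` but is not the code in `src/utils/metrics.py`; it assumes action ids 0..4 in PASS/WARN/REWRITE/REJECT/CRISIS order and ground-truth levels 0..4, and `format_dist` is a hypothetical helper that prints rows in the layout of the tables that follow.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Sequence

ACTIONS = ["PASS", "WARN", "RWRT", "REJT", "CRISIS"]  # action ids 0..4 (assumed)
LEVELS = ["L0_Safe", "L1_Mild", "L2_Moderate", "L3_High", "L4_Critical"]


def per_level_action_dist(levels: Sequence[int],
                          actions: Sequence[int]) -> Dict[int, List[float]]:
    """For each ground-truth level, the share of samples that received each action."""
    by_level: Dict[int, Counter] = defaultdict(Counter)
    for level, action in zip(levels, actions):
        by_level[level][action] += 1
    dist: Dict[int, List[float]] = {}
    for level, counts in sorted(by_level.items()):
        n = sum(counts.values())
        dist[level] = [counts[a] / n for a in range(len(ACTIONS))]
    return dist


def format_dist(levels: Sequence[int], actions: Sequence[int]) -> str:
    """Render the distribution as 'Level n PASS WARN RWRT REJT CRISIS' rows."""
    n_by_level = Counter(levels)
    lines = [f"{'Level':<11}{'n':>4}" + "".join(f"{a:>7}" for a in ACTIONS)]
    for level, shares in per_level_action_dist(levels, actions).items():
        lines.append(f"{LEVELS[level]:<11}{n_by_level[level]:>4}"
                     + "".join(f"{s:>7.3f}" for s in shares))
    return "\n".join(lines)
```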
---

## 8. Module C Evaluation Results v2 (2026-05-12; superseded by v3)

### Summary metrics for the intervention task

| Method | safety_recall (L3/L4) | over_refusal | action_accuracy | crisis_precision |
|------|---------------------|--------------|-----------------|-----------------|
| Rule-based (l≥3→REJECT) | 0.908 | 0.000 | n/a | n/a |
| Threshold Baseline | 0.908 | 0.000 | n/a | 0.624 |
| **Ours (RL, Module C)** | **1.000** | **0.000** | **0.587** | 0.470 |

> safety_recall now feeds `det_l_risk` to the policy and scores the result against the ground-truth level. Rule-based and Threshold drop to 0.908 because the detector predicts <L3 for 9.2% of L3/L4 samples, so these rules give them PASS/WARN. RL stays at 1.0, indicating it learned a combined judgment that goes beyond thresholding l_risk.

### Per-level action distribution (core paper table)

```
Method: Rule-based (l_risk≥3→REJECT)
Level         n   PASS   WARN   RWRT   REJT  CRISIS
L0_Safe     237  1.000  0.000  0.000  0.000   0.000
L1_Mild     280  0.918  0.000  0.000  0.082   0.000
L2_Moderate 317  0.420  0.000  0.000  0.580   0.000
L3_High     456  0.114  0.000  0.000  0.886   0.000
L4_Critical 196  0.041  0.000  0.000  0.959   0.000

Method: Threshold Baseline
Level         n   PASS   WARN   RWRT   REJT  CRISIS
L0_Safe     237  1.000  0.000  0.000  0.000   0.000
L1_Mild     280  0.843  0.075  0.082  0.000   0.000
L2_Moderate 317  0.044  0.375  0.552  0.000   0.028
L3_High     456  0.009  0.105  0.739  0.000   0.147
L4_Critical 196  0.000  0.041  0.316  0.000   0.643

Method: Ours (RL)
Level         n   PASS   WARN   RWRT   REJT  CRISIS
L0_Safe     237  0.983  0.017  0.000  0.000   0.000
L1_Mild     280  0.754  0.004  0.218  0.000   0.025
L2_Moderate 317  0.000  0.000  0.915  0.000   0.085
L3_High     456  0.000  0.000  0.879  0.000   0.121
L4_Critical 196  0.000  0.000  0.597  0.000   0.403
```

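The aggregate metrics in the summary table can be recomputed from this per-level distribution (count × share), which is a useful consistency check. The sketch below copies the three tables; because shares are rounded to three decimals, recomputed values agree only to about three decimals. Column order is PASS/WARN/RWRT/REJT/CRISIS, and `crisis_precision` is taken as CRISIS→L4 precision; it returns `None` when a method never chooses CRISIS, matching the "n/a" for the rule-based baseline.

```python
from typing import List, Optional

N = [237, 280, 317, 456, 196]  # test-set samples per ground-truth level L0..L4
RULE = [[1.000, 0.000, 0.000, 0.000, 0.000],
        [0.918, 0.000, 0.000, 0.082, 0.000],
        [0.420, 0.000, 0.000, 0.580, 0.000],
        [0.114, 0.000, 0.000, 0.886, 0.000],
        [0.041, 0.000, 0.000, 0.959, 0.000]]
THRESHOLD = [[1.000, 0.000, 0.000, 0.000, 0.000],
             [0.843, 0.075, 0.082, 0.000, 0.000],
             [0.044, 0.375, 0.552, 0.000, 0.028],
             [0.009, 0.105, 0.739, 0.000, 0.147],
             [0.000, 0.041, 0.316, 0.000, 0.643]]
RL = [[0.983, 0.017, 0.000, 0.000, 0.000],
      [0.754, 0.004, 0.218, 0.000, 0.025],
      [0.000, 0.000, 0.915, 0.000, 0.085],
      [0.000, 0.000, 0.879, 0.000, 0.121],
      [0.000, 0.000, 0.597, 0.000, 0.403]]
REWRITE, CRISIS = 2, 4  # column indices


def safety_recall(shares: List[List[float]], n: List[int] = N) -> float:
    """Share of L3/L4 samples handled with REWRITE, REJECT or CRISIS."""
    handled = sum(n[lvl] * sum(shares[lvl][REWRITE:]) for lvl in (3, 4))
    return handled / (n[3] + n[4])


def over_refusal(shares: List[List[float]]) -> float:
    """Share of L0 samples given REWRITE or stronger."""
    return sum(shares[0][REWRITE:])


def crisis_precision(shares: List[List[float]], n: List[int] = N) -> Optional[float]:
    """Share of CRISIS actions that fall on L4 samples (None if CRISIS is never used)."""
    crisis = [n[lvl] * shares[lvl][CRISIS] for lvl in range(5)]
    total = sum(crisis)
    return crisis[4] / total if total else None
```

Rounded to three decimals, this reproduces safety_recall 0.908 for both baselines and 1.000 for RL, and crisis_precision 0.624 (Threshold) and 0.470 (RL).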
### Success criteria status (v2)

| Metric | Target | RL measured | Status |
|------|------|------|------|
| safety_recall (share of L3/L4 handled correctly) | > 0.85 | **1.000** | ✅ |
| over_refusal_rate (safe content wrongly blocked) | < 0.10 | **0.000** | ✅ |
| action_accuracy (vs. a_recommend) | > 0.70 | **0.587** | ⚠️ |
| crisis_precision (CRISIS→L4 precision) | > 0.80 | **0.470** | ⚠️ |

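### RL policy interpretation (v2; obsolete, see v3)

- v2 pairs the old weights (trained on GT l_risk) with the new evaluation code, so training and evaluation are inconsistent; kept for reference only

---

## 9. Module C Final Results v3 (retrained + corrected evaluation; used in the paper)
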
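### Why retrain

During training, the RL agent's state vector contained the ground-truth `l_risk` rather than the detector's prediction. Because the detector's level_macro_f1 is only 0.55 (level predictions are frequently off), training conditions did not match deployment, so the agent had to be retrained on `det_l_risk`.
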
### Evaluation versions v1 / v2 / v3

| Version | Code | Model | Issue |
|------|------|------|------|
| v1 | Old (GT l_risk state, no per-level metrics) | Old (trained on GT l_risk) | Two bugs; inflated metrics |
| v2 | New (det_l_risk state, per-level metrics) | Old (trained on GT l_risk) | Train/eval mismatch |
| **v3** | New | New (trained on det_l_risk) | **Used in the paper** |

### Summary metrics (v3, final)

| Method | safety_recall (L3/L4) | over_refusal | action_accuracy | crisis_precision | safety_ux_fscore |
|------|---------------------|--------------|-----------------|-----------------|-----------------|
| Rule-based (l≥3→REJECT) | 0.908 | 0.000 | n/a | n/a | 0.952 |
| Threshold Baseline | 0.908 | 0.000 | n/a | 0.624 | 0.952 |
| **Ours (RL v2)** | **1.000** | **0.004** | **0.575** | 0.421 | **0.998** |

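The `safety_ux_fscore` column is not defined in this document, but every value in the table is reproduced by the harmonic mean of safety_recall and (1 - over_refusal); for example, 2 × 0.908 × 1.0 / 1.908 ≈ 0.952. Assuming that definition (an inference from the numbers, not confirmed against `src/utils/metrics.py`):

```python
def safety_ux_fscore(safety_recall: float, over_refusal: float) -> float:
    """Harmonic mean of safety_recall and (1 - over_refusal); assumed definition."""
    ux = 1.0 - over_refusal
    if safety_recall + ux == 0:
        return 0.0
    return 2 * safety_recall * ux / (safety_recall + ux)


for name, recall, refusal in [("Rule-based", 0.908, 0.000),
                              ("Threshold", 0.908, 0.000),
                              ("Ours (RL v2)", 1.000, 0.004)]:
    print(f"{name:<14}{safety_ux_fscore(recall, refusal):.3f}")  # prints 0.952, 0.952, 0.998
```

Like an F1 score, the harmonic mean punishes a policy that buys safety_recall with heavy over-refusal (or the reverse) more than an arithmetic mean would.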
### Per-level action distribution (v3, core paper table)

```
Method: Rule-based (l_risk≥3→REJECT)                   Method: Threshold Baseline
Level         n   PASS   WARN   RWRT   REJT  CRISIS    Level         n   PASS   WARN   RWRT   REJT  CRISIS
L0_Safe     237  1.000  0.000  0.000  0.000   0.000    L0_Safe     237  1.000  0.000  0.000  0.000   0.000
L1_Mild     280  0.918  0.000  0.000  0.082   0.000    L1_Mild     280  0.843  0.075  0.082  0.000   0.000
L2_Moderate 317  0.420  0.000  0.000  0.580   0.000    L2_Moderate 317  0.044  0.375  0.552  0.000   0.028
L3_High     456  0.114  0.000  0.000  0.886   0.000    L3_High     456  0.009  0.105  0.739  0.000   0.147
L4_Critical 196  0.041  0.000  0.000  0.959   0.000    L4_Critical 196  0.000  0.041  0.316  0.000   0.643

Method: Ours (RL v2, retrained)
Level         n   PASS   WARN   RWRT   REJT  CRISIS
L0_Safe     237  0.987  0.008  0.004  0.000   0.000   ← over_refusal 0.4% (REWRITE)
L1_Mild     280  0.729  0.011  0.229  0.000   0.032   ← some mild samples over-triggered (limitation)
L2_Moderate 317  0.000  0.000  0.902  0.000   0.098   ← REWRITE dominates ✓
L3_High     456  0.000  0.000  0.871  0.000   0.129   ← REWRITE dominates ✓
L4_Critical 196  0.000  0.000  0.633  0.000   0.367   ← CRISIS rate too low (limitation)
```

### Success criteria status (v3, final)

| Metric | Target | RL v2 measured | Status |
|------|------|-----------|------|
| safety_recall (share of L3/L4 handled correctly) | > 0.85 | **1.000** | ✅ |
| over_refusal_rate (safe content wrongly blocked) | < 0.10 | **0.004** | ✅ |
| action_accuracy (vs. a_recommend) | > 0.70 | **0.575** | ⚠️ |
| crisis_precision (CRISIS→L4 precision) | > 0.80 | **0.421** | ⚠️ |

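### Paper arguments

- **Strength**: safety_recall=1.0 vs. 0.908 for the baselines. Despite detector level errors, RL still routes every L3/L4 sample to REWRITE or CRISIS, indicating it learned to combine multiple signals rather than thresholding l_risk
- **Limitation 1**: action_accuracy=0.575; over-triggering at L1 (22.9% REWRITE), i.e. mild risks are handled too aggressively
- **Limitation 2**: crisis_precision=0.421; CRISIS fires on only 36.7% of L4 samples (Threshold: 64.3%), attributed to scarce R1 training samples (136) and a crisis bonus (w3=4.0) that is too weak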