chore: update CLAUDE.md paths + gitignore 旧方向信息/
- CLAUDE.md: rewrite as project reference (training done); fix all local paths (remove CompanionGuard-RL nesting in code/) - .gitignore: add 旧方向信息/ and untrack it from index Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
443
code/CLAUDE.md
443
code/CLAUDE.md
@@ -1,385 +1,154 @@
|
||||
# CompanionGuard-RL — 远程 4-GPU 训练任务指南
|
||||
# CompanionGuard-RL — 项目参考文档
|
||||
|
||||
> 本文件由 Claude Code 自动读取。请**严格按阶段顺序执行**,每阶段完成后打印一行 `=== Phase N done ===` 再继续。
|
||||
> 本文件由 Claude Code 自动读取。训练已全部完成,当前阶段:**论文写作**。
|
||||
|
||||
---
|
||||
|
||||
## 任务目标
|
||||
## 项目状态(2026-05-12)
|
||||
|
||||
在远程 GPU 服务器上完成 **Module B — Context-aware Risk Detector** 的 4-GPU 分布式训练,产出 `checkpoints/detector/best.pt`。
|
||||
| 模块 | 状态 | 关键指标 |
|
||||
|------|------|---------|
|
||||
| 数据集 CompanionRisk-Bench v4 | ✅ 完成 | 9,896 样本,全 14 标签覆盖 |
|
||||
| Module B — 检测器(MacBERT-large) | ✅ 完成 | binary_f1=0.9995, level_weighted_f1=0.559 |
|
||||
| Module C — RL 干预策略(PPO) | ✅ 完成 | safety_recall=1.0, over_refusal=0.004 |
|
||||
| 论文写作 | 🔄 进行中 | — |
|
||||
|
||||
详细结果见项目根目录 `../state.md`,踩坑经验见 `exp.md`,变更记录见 `change.md`。
|
||||
|
||||
---
|
||||
|
||||
## 本地目录结构
|
||||
|
||||
```
|
||||
D:\Myresearch\CompanionGuard-RL\
|
||||
├── code/ ← 本目录(源代码)
|
||||
│ ├── src/ ← 18 个核心 .py(models/ rl/ utils/)
|
||||
│ ├── scripts/ ← 训练/评估/数据生成脚本
|
||||
│ ├── configs/ ← 4 个 yaml 配置
|
||||
│ ├── checkpoints/ ← 模型权重(gitignored)
|
||||
│ │ ├── detector/best.pt ← Module B 论文权重(1.35GB)
|
||||
│ │ └── intervention/final_v2.pt ← Module C 论文权重
|
||||
│ ├── experiments/ ← 评估结果 JSON
|
||||
│ │ ├── eval_intervention_v3.json ← Module C 论文用
|
||||
│ │ └── eval_intervention_v4.json ← v3 重跑确认(数字相同)
|
||||
│ └── data/ ← 处理后数据(gitignored)
|
||||
├── data/ ← 原始数据集(gitignored)
|
||||
├── docs/ ← 研究文档
|
||||
├── state.md ← 项目进度快照(最新)
|
||||
└── experiments/ ← 根目录评估结果备份
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 服务器信息
|
||||
|
||||
### 服务器 1(主训练机)
|
||||
|
||||
| 项目 | 值 |
|
||||
|------|-----|
|
||||
| SSH 命令 | `ssh -p 20083 root@10.82.3.180` |
|
||||
|------|----|
|
||||
| SSH | `ssh -p 20083 root@10.82.3.180` |
|
||||
| 密码 | `m2dGcwyrhI` |
|
||||
| GPU | 4 × RTX 5090 32 GB |
|
||||
| 远程工作根目录 | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/` |
|
||||
| 远程项目目录(以下简称 `$PROJ`) | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL` |
|
||||
| 项目目录 | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL` |
|
||||
| MacBERT | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large` |
|
||||
| 环境 | `/opt/conda/envs/dlapo-py310-cu128`(torch 2.7.1+cu128) |
|
||||
| GPU | 4 × RTX 5090 32GB |
|
||||
|
||||
> **重要约束**:服务器 Docker 网络受限,**部分包无法直接 pip install / wget**。
|
||||
> 优先尝试国内镜像;若失败,改用本地下载 → scp 传输的离线方式。
|
||||
### 服务器 2(当前使用)
|
||||
|
||||
| 项目 | 值 |
|
||||
|------|----|
|
||||
| SSH | `ssh -p 20060 root@10.82.3.180` |
|
||||
| 密码 | `zwfn65xjTY` |
|
||||
| 项目目录 | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/my-reasearch/companionguard-rl` |
|
||||
| MacBERT | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/macbert-large` |
|
||||
| 环境 | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/env/dlapo-py310-cu128` |
|
||||
| GPU | 2 × RTX 5090 32GB |
|
||||
|
||||
> 两台服务器在同一宿主机 `10.82.3.180`,不同 Docker 容器。
|
||||
|
||||
---
|
||||
|
||||
## Phase 0 — 连接与环境探查
|
||||
|
||||
```bash
|
||||
# 探查可用资源(ssh 进入后逐条运行)
|
||||
nvidia-smi # 确认 4 块 GPU 都可见
|
||||
python3 --version || python --version
|
||||
which conda && conda --version || echo "no conda"
|
||||
pip3 --version || pip --version
|
||||
python3 -c "import torch; print(torch.__version__, torch.cuda.device_count())"
|
||||
python3 -c "import transformers; print(transformers.__version__)"
|
||||
python3 -c "import accelerate; print(accelerate.__version__)"
|
||||
```
|
||||
|
||||
记录以下信息用于后续决策:
|
||||
- `python` 命令是 `python3` 还是 `python`
|
||||
- torch 是否已安装,版本是否 ≥ 2.0
|
||||
- transformers / accelerate / peft 是否已安装
|
||||
- 是否有 conda
|
||||
|
||||
---
|
||||
|
||||
## Phase 1 — 项目文件传输
|
||||
|
||||
**在本地(Windows PowerShell / cmd)执行 scp,将代码与数据传到服务器。**
|
||||
## SCP 同步命令(本地 ↔ 服务器)
|
||||
|
||||
```powershell
|
||||
# 1-A 创建远程目录
|
||||
ssh -p 20083 root@10.82.3.180 "mkdir -p /root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL"
|
||||
# ===== 本地 → 服务器1(上传代码)=====
|
||||
$S1="root@10.82.3.180"
|
||||
$PROJ1="/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL"
|
||||
|
||||
# 1-B 传输源码目录(排除缓存与已有checkpoint)
|
||||
scp -P 20083 -r `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\src `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\scripts `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\configs `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\requirements.txt `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/
|
||||
D:\Myresearch\CompanionGuard-RL\code\src `
|
||||
D:\Myresearch\CompanionGuard-RL\code\scripts `
|
||||
D:\Myresearch\CompanionGuard-RL\code\configs `
|
||||
D:\Myresearch\CompanionGuard-RL\code\requirements.txt `
|
||||
${S1}:${PROJ1}/
|
||||
|
||||
# 1-C 传输数据集(约 30-50 MB)
|
||||
# 上传已处理数据
|
||||
scp -P 20083 -r `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\data `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/
|
||||
```
|
||||
D:\Myresearch\CompanionGuard-RL\code\data `
|
||||
${S1}:${PROJ1}/
|
||||
|
||||
**验证**(在服务器上):
|
||||
```bash
|
||||
cd $PROJ
|
||||
ls src/ scripts/ configs/ data/processed/CompanionRisk-Bench/
|
||||
wc -l data/processed/CompanionRisk-Bench/train.jsonl # 应为 2815
|
||||
wc -l data/processed/CompanionRisk-Bench/test.jsonl # 应为 605
|
||||
# ===== 服务器1 → 本地(取回结果)=====
|
||||
scp -P 20083 -r `
|
||||
${S1}:${PROJ1}/checkpoints `
|
||||
D:\Myresearch\CompanionGuard-RL\code\
|
||||
|
||||
scp -P 20083 -r `
|
||||
${S1}:${PROJ1}/experiments `
|
||||
D:\Myresearch\CompanionGuard-RL\code\
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Phase 2 — Python 依赖安装
|
||||
|
||||
### 2-A 先尝试国内镜像直接安装
|
||||
## 核心脚本用法
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||
torch transformers accelerate peft datasets tokenizers \
|
||||
scikit-learn tqdm pyyaml omegaconf jsonlines rich \
|
||||
openai anthropic wandb
|
||||
```
|
||||
|
||||
若上述命令报网络错误,转 **2-B(离线方式)**。
|
||||
|
||||
### 2-B 离线方式(若 2-A 失败)
|
||||
|
||||
**在本地 Windows 执行**(需要本地能访问 PyPI):
|
||||
|
||||
```powershell
|
||||
# 下载所有 wheel 到本地文件夹
|
||||
pip download -d D:\Myresearch\wheels --platform linux_x86_64 `
|
||||
--python-version 310 --only-binary=:all: `
|
||||
torch transformers accelerate peft scikit-learn tqdm `
|
||||
pyyaml omegaconf jsonlines rich
|
||||
|
||||
# 传输 wheels 到服务器
|
||||
scp -P 20083 -r D:\Myresearch\wheels `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/
|
||||
```
|
||||
|
||||
**在服务器上安装**:
|
||||
```bash
|
||||
pip3 install --no-index --find-links=/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/wheels \
|
||||
torch transformers accelerate peft scikit-learn tqdm pyyaml omegaconf jsonlines rich
|
||||
```
|
||||
|
||||
### 2-C 验证
|
||||
|
||||
```bash
|
||||
python3 -c "
|
||||
import torch, transformers, accelerate, peft, sklearn
|
||||
print('torch:', torch.__version__, '| cuda gpus:', torch.cuda.device_count())
|
||||
print('transformers:', transformers.__version__)
|
||||
print('accelerate:', accelerate.__version__)
|
||||
print('peft:', peft.__version__)
|
||||
"
|
||||
```
|
||||
|
||||
期望:`cuda gpus: 4`。
|
||||
|
||||
---
|
||||
|
||||
## Phase 3 — MacBERT 模型获取
|
||||
|
||||
模型名称:`hfl/chinese-macbert-large`(约 500 MB)。
|
||||
|
||||
### 3-A 优先:使用 HuggingFace 国内镜像
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
HF_ENDPOINT=https://hf-mirror.com python3 -c "
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
AutoTokenizer.from_pretrained('hfl/chinese-macbert-large')
|
||||
AutoModel.from_pretrained('hfl/chinese-macbert-large')
|
||||
print('MacBERT download OK')
|
||||
"
|
||||
```
|
||||
|
||||
若成功,跳过 3-B / 3-C。
|
||||
|
||||
### 3-B 备选:ModelScope 下载
|
||||
|
||||
```bash
|
||||
pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple modelscope
|
||||
python3 -c "
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download('hfl/chinese-macbert-large', cache_dir='$PROJ/model_cache')
|
||||
"
|
||||
```
|
||||
|
||||
若成功,修改 `configs/detector_config.yaml`:
|
||||
```
|
||||
model:
|
||||
name: "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/model_cache/hfl/chinese-macbert-large"
|
||||
```
|
||||
|
||||
### 3-C 最终备选:本地下载 → scp
|
||||
|
||||
**在本地 Windows 执行**:
|
||||
```powershell
|
||||
# 需要本地能访问 HuggingFace 或 hf-mirror
|
||||
pip install huggingface_hub
|
||||
python -c "
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download('hfl/chinese-macbert-large', local_dir='D:/Myresearch/macbert-large')
|
||||
"
|
||||
|
||||
# 传输到服务器
|
||||
scp -P 20083 -r D:\Myresearch\macbert-large `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large
|
||||
```
|
||||
|
||||
**在服务器上更新配置**(见下方 Phase 4)。
|
||||
|
||||
---
|
||||
|
||||
## Phase 4 — 配置确认(4-GPU Linux 专用)
|
||||
|
||||
服务器专用配置已预生成:`configs/detector_config_server.yaml`
|
||||
(`num_workers: 4`,`effective batch = 16 × 4 GPUs × 2 accum = 128`,`bf16`)。
|
||||
|
||||
**仅当 Phase 3-C(本地 scp 传输模型)时**,需要更新 model.name:
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
|
||||
# 仅在 Phase 3-C 时执行:将 model.name 改为本地路径
|
||||
sed -i 's|name: "hfl/chinese-macbert-large"|name: "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large"|' configs/detector_config_server.yaml
|
||||
|
||||
# 确认关键参数
|
||||
grep -E "num_workers|per_gpu_batch|gradient_accum|mixed_precision|name:" configs/detector_config_server.yaml
|
||||
```
|
||||
|
||||
Phase 3-A / 3-B 成功时无需修改,直接进入 Phase 5。
|
||||
|
||||
---
|
||||
|
||||
## Phase 5 — 启动 4-GPU 训练
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
mkdir -p experiments checkpoints/detector
|
||||
|
||||
# 推荐:accelerate launch(使用服务器专用配置)
|
||||
accelerate launch \
|
||||
--num_processes=4 \
|
||||
--mixed_precision=bf16 \
|
||||
--multi_gpu \
|
||||
scripts/train_detector.py \
|
||||
--config configs/detector_config_server.yaml \
|
||||
2>&1 | tee experiments/train_$(date +%Y%m%d_%H%M%S).log &
|
||||
|
||||
echo "Training PID: $!"
|
||||
```
|
||||
|
||||
若 `accelerate launch` 不可用,改用 torchrun:
|
||||
```bash
|
||||
torchrun --nproc_per_node=4 \
|
||||
scripts/train_detector.py \
|
||||
--config configs/detector_config_server.yaml \
|
||||
2>&1 | tee experiments/train_$(date +%Y%m%d_%H%M%S).log &
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Phase 6 — 监控与验证
|
||||
|
||||
训练启动后持续执行以下检查:
|
||||
|
||||
```bash
|
||||
# 6-A 查看实时日志(关键:前100步 loss 应在 1.0~3.0 之间下降)
|
||||
tail -f experiments/train_*.log
|
||||
|
||||
# 6-B GPU 利用率(4 块 GPU 利用率均应 >80%)
|
||||
watch -n 5 nvidia-smi
|
||||
|
||||
# 6-C 检查第一次验证输出(~100 global steps 后出现)
|
||||
# 期望 Val binary F1 > 0.40(超过 L1c 基线 0.410 是最低目标,目标 >0.80)
|
||||
|
||||
# 6-D 检查 checkpoint 保存
|
||||
ls -lh checkpoints/detector/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Phase 7 — 模型评估(验证 F1=0.9978 是否真实)
|
||||
|
||||
> **背景**:训练报告 Val Binary F1=0.9978,但该分数基于验证集(dev.jsonl),
|
||||
> 且验证集与训练集同为 LLM 生成,存在"同源过拟合"风险。
|
||||
> 本 Phase 用三组实验定位真实泛化能力。
|
||||
|
||||
### 7-A 全量 test 集评估
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
|
||||
# 重新评估检测器(Module B)
|
||||
python scripts/evaluate.py \
|
||||
--detector-ckpt checkpoints/detector/best.pt \
|
||||
--config configs/detector_config_server.yaml \
|
||||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \
|
||||
--source-filter all \
|
||||
--output experiments/eval_all.json
|
||||
```
|
||||
|
||||
重点观察:
|
||||
- `binary_f1` 是否仍接近 0.9978(若是,说明 test 集也被"污染")
|
||||
- `level_macro_f1`(l_risk 0-4 等级 F1)——这比 binary 难得多,若也完美则有问题
|
||||
- `fine_macro_f1`(14 类细粒度标签 F1)——最难任务,正常应在 0.5-0.8
|
||||
|
||||
### 7-B 仅人工标注子集(关键实验)
|
||||
|
||||
```bash
|
||||
# 重新评估干预策略(Module C)
|
||||
python scripts/evaluate.py \
|
||||
--detector-ckpt checkpoints/detector/best.pt \
|
||||
--config configs/detector_config_server.yaml \
|
||||
--agent-ckpt checkpoints/intervention/final_v2.pt \
|
||||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \
|
||||
--source-filter human \
|
||||
--output experiments/eval_human_only.json
|
||||
```
|
||||
|
||||
> 仅评估来自 DICES / CoSafe / Human-AI Suicide Risk 三个人工标注数据集的样本。
|
||||
> 这些样本来源不同于 LLM 生成,能真实反映泛化性。
|
||||
> **若此处 binary_f1 明显下降(<0.80),说明模型依赖 LLM 文体特征而非风险语义。**
|
||||
|
||||
### 7-C 查看 source 字段分布(调试用)
|
||||
|
||||
```bash
|
||||
# 确认 test.jsonl 中 source 字段的实际取值
|
||||
python3 -c "
|
||||
import json
|
||||
from collections import Counter
|
||||
samples = [json.loads(l) for l in open('data/processed/CompanionRisk-Bench/test.jsonl') if l.strip()]
|
||||
src_counter = Counter(s.get('source', s.get('id','?')[:10]) for s in samples)
|
||||
for k, v in sorted(src_counter.items(), key=lambda x: -x[1]):
|
||||
print(f' {k}: {v}')
|
||||
print(f'Total: {len(samples)}')
|
||||
"
|
||||
```
|
||||
|
||||
> 若输出发现所有样本都没有 source 字段,则 source-filter 用 id 前缀判断(evaluate.py 已处理)。
|
||||
> 把输出贴回来,若所有样本都是 LLM 生成(无人工标注),说明 test 集设计有问题。
|
||||
|
||||
### 7-D 结果判读标准
|
||||
|
||||
| 实验 | binary_f1 | 解释 |
|
||||
|------|-----------|------|
|
||||
| 7-A 全量 test | ~0.99 | test/dev 同源,无参考价值 |
|
||||
| 7-A 全量 test | ~0.80-0.90 | 合理,模型有真实泛化能力 |
|
||||
| 7-B 人工标注 | ~0.99 | **可信**,真实泛化优秀 |
|
||||
| 7-B 人工标注 | 0.60-0.75 | **同源过拟合确认**,需处理 |
|
||||
| 7-B 人工标注 | <0.60 | 严重过拟合,训练方案需调整 |
|
||||
|
||||
## Phase 9 — 取回结果
|
||||
|
||||
训练和评估完成后,将 checkpoint、日志和评估 JSON 传回本地:
|
||||
|
||||
```powershell
|
||||
# 在本地 Windows PowerShell 执行
|
||||
scp -P 20083 -r `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/checkpoints `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\
|
||||
|
||||
scp -P 20083 -r `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/experiments `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\
|
||||
|
||||
# 同时取回更新后的 evaluate.py(已修复 bug,含 source-filter 功能)
|
||||
scp -P 20083 `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/scripts/evaluate.py `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\scripts\
|
||||
--config configs/detector_config_server.yaml \
|
||||
--intervention-config configs/intervention_config.yaml \
|
||||
--output experiments/eval_intervention_v3.json
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 关键指标参考(训练目标)
|
||||
## 关键结果(论文用)
|
||||
|
||||
| 指标 | L1c 规则基线(下界) | MacBERT 目标 |
|
||||
|------|---------------------|--------------|
|
||||
| Binary F1 | 0.410 | **> 0.80** |
|
||||
| R1 recall(危机类) | 0.097 | **> 0.75** |
|
||||
| R9 recall | 0.091 | **> 0.70** |
|
||||
| FNR(漏检率) | 0.740 | **< 0.20** |
|
||||
### Module B — 检测器 v4
|
||||
|
||||
| 指标 | 值 |
|
||||
|------|----|
|
||||
| binary_f1 | **0.9995** |
|
||||
| high_risk_recall | **1.0000** |
|
||||
| FNR | **0.00%** |
|
||||
| level_weighted_f1 | **0.559** |
|
||||
| fine_macro_f1(public 10类) | **0.484** |
|
||||
|
||||
### Module C — RL 干预策略 v3(论文用,`eval_intervention_v3.json`)
|
||||
|
||||
| 方法 | safety_recall | over_refusal | action_accuracy | safety_ux_fscore |
|
||||
|------|--------------|--------------|-----------------|-----------------|
|
||||
| Rule-based | 0.908 | 0.000 | — | 0.952 |
|
||||
| Threshold | 0.908 | 0.000 | — | 0.952 |
|
||||
| **Ours (RL)** | **1.000** | **0.004** | **0.575** | **0.998** |
|
||||
|
||||
**使用权重**:`checkpoints/intervention/final_v2.pt`(用 `det_l_risk` 重训)
|
||||
|
||||
---
|
||||
|
||||
## 常见问题处理
|
||||
## 重要注意事项
|
||||
|
||||
### NCCL 通信报错
|
||||
```bash
|
||||
export NCCL_P2P_DISABLE=1
|
||||
export NCCL_IB_DISABLE=1
|
||||
# 再重新启动 accelerate launch
|
||||
```
|
||||
|
||||
### OOM(显存不足,不太可能:5090 32GB)
|
||||
在 `configs/detector_config.yaml` 中将 `per_gpu_batch_size: 16` 改为 `8`,`gradient_accumulation_steps: 4`。
|
||||
|
||||
### MacBERT 路径找不到
|
||||
检查 `~/.cache/huggingface/hub/` 或 `model_cache/` 目录,找到实际下载路径后更新 config 的 `model.name`。
|
||||
|
||||
### accelerate 找不到
|
||||
```bash
|
||||
pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple accelerate
|
||||
# 或用 torchrun 替代(见 Phase 5)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 文件清单(训练产出)
|
||||
|
||||
| 文件 | 描述 |
|
||||
|------|------|
|
||||
| `checkpoints/detector/best.pt` | 验证集 F1 最高的模型权重 |
|
||||
| `checkpoints/detector/final.pt` | 最后一个 epoch 的权重 |
|
||||
| `experiments/train_YYYYMMDD_HHMMSS.log` | 完整训练日志 |
|
||||
- **PyYAML 6.x 陷阱**:lr 值必须写 `0.001` 而非 `1e-3`(后者被解析为字符串)
|
||||
- **RTX 5090 NCCL**:多卡训练需 `NCCL_SHM_DISABLE=1 NCCL_P2P_DISABLE=1`;PPO 阶段用单卡绕开 barrier 问题
|
||||
- **det_l_risk vs l_risk**:评估和训练均须用检测器预测的 `det_l_risk`,不能用 ground truth `l_risk`
|
||||
- **obs_dim = 2065**:state 向量结构 `[d_score(1)|l_risk_onehot(5)|c_primary_probs(10)|e_H_pool(1024)|e_P_pool(1024)|t_norm(1)]`
|
||||
|
||||
Reference in New Issue
Block a user