chore: update CLAUDE.md paths + gitignore 旧方向信息/
- CLAUDE.md: rewrite as project reference (training done); fix all local paths (remove CompanionGuard-RL nesting in code/)
- .gitignore: add 旧方向信息/ and untrack it from index

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
3
.gitignore
vendored
@@ -28,6 +28,9 @@ sync_v*.zip
# === Large experiment logs ===
code/experiments/*.log

# === Archive of the previous research direction ===
旧方向信息/

# === Tool configuration (user-local) ===
.claude/
.playwright-mcp/
443
code/CLAUDE.md
@@ -1,385 +1,154 @@
# CompanionGuard-RL: Remote 4-GPU Training Task Guide
# CompanionGuard-RL: Project Reference

> This file is read automatically by Claude Code. **Follow the phases strictly in order**; after each phase, print the line `=== Phase N done ===` before continuing.
> This file is read automatically by Claude Code. All training is complete; the current phase is **paper writing**.

---

## Task Goal
## Project Status (2026-05-12)

Complete the 4-GPU distributed training of **Module B (Context-aware Risk Detector)** on the remote GPU server and produce `checkpoints/detector/best.pt`.
| Module | Status | Key metrics |
|------|------|---------|
| Dataset CompanionRisk-Bench v4 | ✅ Done | 9,896 samples; all 14 labels covered |
| Module B: detector (MacBERT-large) | ✅ Done | binary_f1=0.9995, level_weighted_f1=0.559 |
| Module C: RL intervention policy (PPO) | ✅ Done | safety_recall=1.0, over_refusal=0.004 |
| Paper writing | 🔄 In progress | - |

Detailed results are in `../state.md` at the project root, lessons learned are in `exp.md`, and the change log is in `change.md`.

---

## Local Directory Layout

```
D:\Myresearch\CompanionGuard-RL\
├── code/ ← this directory (source code)
│ ├── src/ ← 18 core .py files (models/ rl/ utils/)
│ ├── scripts/ ← training / evaluation / data-generation scripts
│ ├── configs/ ← 4 YAML configs
│ ├── checkpoints/ ← model weights (gitignored)
│ │ ├── detector/best.pt ← Module B weights used in the paper (1.35 GB)
│ │ └── intervention/final_v2.pt ← Module C weights used in the paper
│ ├── experiments/ ← evaluation result JSON files
│ │ ├── eval_intervention_v3.json ← used for Module C in the paper
│ │ └── eval_intervention_v4.json ← re-run of v3 for confirmation (identical numbers)
│ └── data/ ← processed data (gitignored)
├── data/ ← raw datasets (gitignored)
├── docs/ ← research documents
├── state.md ← project progress snapshot (latest)
└── experiments/ ← backup of evaluation results at the project root
```
|
||||
|
||||
---
|
||||
|
||||
## 服务器信息
|
||||
|
||||
### 服务器 1(主训练机)
|
||||
|
||||
| 项目 | 值 |
|
||||
|------|-----|
|
||||
| SSH 命令 | `ssh -p 20083 root@10.82.3.180` |
|
||||
|------|----|
|
||||
| SSH | `ssh -p 20083 root@10.82.3.180` |
|
||||
| 密码 | `m2dGcwyrhI` |
|
||||
| GPU | 4 × RTX 5090 32 GB |
|
||||
| 远程工作根目录 | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/` |
|
||||
| 远程项目目录(以下简称 `$PROJ`) | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL` |
|
||||
| 项目目录 | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL` |
|
||||
| MacBERT | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large` |
|
||||
| 环境 | `/opt/conda/envs/dlapo-py310-cu128`(torch 2.7.1+cu128) |
|
||||
| GPU | 4 × RTX 5090 32GB |
|
||||
|
||||
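> **Important constraint**: the server's Docker network is restricted, so **some packages cannot be fetched directly with pip install / wget**.
> Try a mainland-China mirror first; if that fails, download locally and transfer the files with scp (offline install).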
|
||||
### 服务器 2(当前使用)
|
||||
|
||||
| 项目 | 值 |
|
||||
|------|----|
|
||||
| SSH | `ssh -p 20060 root@10.82.3.180` |
|
||||
| 密码 | `zwfn65xjTY` |
|
||||
| 项目目录 | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/my-reasearch/companionguard-rl` |
|
||||
| MacBERT | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/macbert-large` |
|
||||
| 环境 | `/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/env/dlapo-py310-cu128` |
|
||||
| GPU | 2 × RTX 5090 32GB |
|
||||
|
||||
> Both servers run on the same host `10.82.3.180`, in different Docker containers.
|
||||
|
||||
---
|
||||
|
||||
## Phase 0: Connection and Environment Check

```bash
# Probe the available resources (run one command at a time after ssh-ing in)
nvidia-smi # confirm that all 4 GPUs are visible
python3 --version || python --version
which conda && conda --version || echo "no conda"
pip3 --version || pip --version
python3 -c "import torch; print(torch.__version__, torch.cuda.device_count())"
python3 -c "import transformers; print(transformers.__version__)"
python3 -c "import accelerate; print(accelerate.__version__)"
```

Record the following for later decisions:
- whether the Python command is `python3` or `python`
- whether torch is installed and its version is ≥ 2.0
- whether transformers / accelerate / peft are installed
- whether conda is available
|
||||
|
||||
---
|
||||
|
||||
## Phase 1 — 项目文件传输
|
||||
|
||||
**在本地(Windows PowerShell / cmd)执行 scp,将代码与数据传到服务器。**
|
||||
## SCP 同步命令(本地 ↔ 服务器)
|
||||
|
||||
```powershell
|
||||
# 1-A 创建远程目录
|
||||
ssh -p 20083 root@10.82.3.180 "mkdir -p /root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL"
|
||||
# ===== 本地 → 服务器1(上传代码)=====
|
||||
$S1="root@10.82.3.180"
|
||||
$PROJ1="/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL"
|
||||
|
||||
# 1-B 传输源码目录(排除缓存与已有checkpoint)
|
||||
scp -P 20083 -r `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\src `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\scripts `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\configs `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\requirements.txt `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/
|
||||
D:\Myresearch\CompanionGuard-RL\code\src `
|
||||
D:\Myresearch\CompanionGuard-RL\code\scripts `
|
||||
D:\Myresearch\CompanionGuard-RL\code\configs `
|
||||
D:\Myresearch\CompanionGuard-RL\code\requirements.txt `
|
||||
${S1}:${PROJ1}/
|
||||
|
||||
# 1-C 传输数据集(约 30-50 MB)
|
||||
# 上传已处理数据
|
||||
scp -P 20083 -r `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\data `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/
|
||||
```
|
||||
D:\Myresearch\CompanionGuard-RL\code\data `
|
||||
${S1}:${PROJ1}/
|
||||
|
||||
**验证**(在服务器上):
|
||||
```bash
|
||||
cd $PROJ
|
||||
ls src/ scripts/ configs/ data/processed/CompanionRisk-Bench/
|
||||
wc -l data/processed/CompanionRisk-Bench/train.jsonl # 应为 2815
|
||||
wc -l data/processed/CompanionRisk-Bench/test.jsonl # 应为 605
|
||||
# ===== 服务器1 → 本地(取回结果)=====
|
||||
scp -P 20083 -r `
|
||||
${S1}:${PROJ1}/checkpoints `
|
||||
D:\Myresearch\CompanionGuard-RL\code\
|
||||
|
||||
scp -P 20083 -r `
|
||||
${S1}:${PROJ1}/experiments `
|
||||
D:\Myresearch\CompanionGuard-RL\code\
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Phase 2 — Python 依赖安装
|
||||
|
||||
### 2-A 先尝试国内镜像直接安装
|
||||
## 核心脚本用法
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||
torch transformers accelerate peft datasets tokenizers \
|
||||
scikit-learn tqdm pyyaml omegaconf jsonlines rich \
|
||||
openai anthropic wandb
|
||||
```
|
||||
|
||||
若上述命令报网络错误,转 **2-B(离线方式)**。
|
||||
|
||||
### 2-B 离线方式(若 2-A 失败)
|
||||
|
||||
**在本地 Windows 执行**(需要本地能访问 PyPI):
|
||||
|
||||
```powershell
|
||||
# 下载所有 wheel 到本地文件夹
|
||||
pip download -d D:\Myresearch\wheels --platform linux_x86_64 `
|
||||
--python-version 310 --only-binary=:all: `
|
||||
torch transformers accelerate peft scikit-learn tqdm `
|
||||
pyyaml omegaconf jsonlines rich
|
||||
|
||||
# 传输 wheels 到服务器
|
||||
scp -P 20083 -r D:\Myresearch\wheels `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/
|
||||
```
|
||||
|
||||
**在服务器上安装**:
|
||||
```bash
|
||||
pip3 install --no-index --find-links=/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/wheels \
|
||||
torch transformers accelerate peft scikit-learn tqdm pyyaml omegaconf jsonlines rich
|
||||
```
|
||||
|
||||
### 2-C 验证
|
||||
|
||||
```bash
|
||||
python3 -c "
|
||||
import torch, transformers, accelerate, peft, sklearn
|
||||
print('torch:', torch.__version__, '| cuda gpus:', torch.cuda.device_count())
|
||||
print('transformers:', transformers.__version__)
|
||||
print('accelerate:', accelerate.__version__)
|
||||
print('peft:', peft.__version__)
|
||||
"
|
||||
```
|
||||
|
||||
期望:`cuda gpus: 4`。
|
||||
|
||||
---
|
||||
|
||||
## Phase 3 — MacBERT 模型获取
|
||||
|
||||
模型名称:`hfl/chinese-macbert-large`(约 500 MB)。
|
||||
|
||||
### 3-A 优先:使用 HuggingFace 国内镜像
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
HF_ENDPOINT=https://hf-mirror.com python3 -c "
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
AutoTokenizer.from_pretrained('hfl/chinese-macbert-large')
|
||||
AutoModel.from_pretrained('hfl/chinese-macbert-large')
|
||||
print('MacBERT download OK')
|
||||
"
|
||||
```
|
||||
|
||||
若成功,跳过 3-B / 3-C。
|
||||
|
||||
### 3-B 备选:ModelScope 下载
|
||||
|
||||
```bash
|
||||
pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple modelscope
|
||||
python3 -c "
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download('hfl/chinese-macbert-large', cache_dir='$PROJ/model_cache')
|
||||
"
|
||||
```
|
||||
|
||||
若成功,修改 `configs/detector_config.yaml`:
|
||||
```
|
||||
model:
|
||||
name: "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/model_cache/hfl/chinese-macbert-large"
|
||||
```
|
||||
|
||||
### 3-C 最终备选:本地下载 → scp
|
||||
|
||||
**在本地 Windows 执行**:
|
||||
```powershell
|
||||
# 需要本地能访问 HuggingFace 或 hf-mirror
|
||||
pip install huggingface_hub
|
||||
python -c "
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download('hfl/chinese-macbert-large', local_dir='D:/Myresearch/macbert-large')
|
||||
"
|
||||
|
||||
# 传输到服务器
|
||||
scp -P 20083 -r D:\Myresearch\macbert-large `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large
|
||||
```
|
||||
|
||||
**在服务器上更新配置**(见下方 Phase 4)。
|
||||
|
||||
---
|
||||
|
||||
## Phase 4 — 配置确认(4-GPU Linux 专用)
|
||||
|
||||
A server-specific config has been pre-generated: `configs/detector_config_server.yaml`
(`num_workers: 4`, `effective batch = 16 × 4 GPUs × 2 accum = 128`, `bf16`).
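
The effective batch size quoted above is the product of three settings. A minimal sketch of that arithmetic, assuming the key names `per_gpu_batch_size` and `gradient_accumulation_steps` used elsewhere in this guide; the helper itself is hypothetical and not part of the repository:

```python
def effective_batch_size(per_gpu_batch_size: int,
                         num_gpus: int,
                         gradient_accumulation_steps: int) -> int:
    """Number of samples that contribute to one optimizer step."""
    for name, value in (("per_gpu_batch_size", per_gpu_batch_size),
                        ("num_gpus", num_gpus),
                        ("gradient_accumulation_steps", gradient_accumulation_steps)):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    return per_gpu_batch_size * num_gpus * gradient_accumulation_steps


# Server config from this guide: 16 per GPU, 4 GPUs, 2 accumulation steps.
print(effective_batch_size(16, 4, 2))
# OOM fallback from the troubleshooting section: 8 per GPU, 4 accumulation steps.
print(effective_batch_size(8, 4, 4))
```

Both calls give 128, which is why halving the per-GPU batch while doubling accumulation leaves the optimization schedule unchanged.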
|
||||
|
||||
**仅当 Phase 3-C(本地 scp 传输模型)时**,需要更新 model.name:
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
|
||||
# 仅在 Phase 3-C 时执行:将 model.name 改为本地路径
|
||||
sed -i 's|name: "hfl/chinese-macbert-large"|name: "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/macbert-large"|' configs/detector_config_server.yaml
|
||||
|
||||
# 确认关键参数
|
||||
grep -E "num_workers|per_gpu_batch|gradient_accum|mixed_precision|name:" configs/detector_config_server.yaml
|
||||
```
|
||||
|
||||
Phase 3-A / 3-B 成功时无需修改,直接进入 Phase 5。
|
||||
|
||||
---
|
||||
|
||||
## Phase 5 — 启动 4-GPU 训练
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
mkdir -p experiments checkpoints/detector
|
||||
|
||||
# 推荐:accelerate launch(使用服务器专用配置)
|
||||
accelerate launch \
|
||||
--num_processes=4 \
|
||||
--mixed_precision=bf16 \
|
||||
--multi_gpu \
|
||||
scripts/train_detector.py \
|
||||
--config configs/detector_config_server.yaml \
|
||||
2>&1 | tee experiments/train_$(date +%Y%m%d_%H%M%S).log &
|
||||
|
||||
echo "Training PID: $!"
|
||||
```
|
||||
|
||||
若 `accelerate launch` 不可用,改用 torchrun:
|
||||
```bash
|
||||
torchrun --nproc_per_node=4 \
|
||||
scripts/train_detector.py \
|
||||
--config configs/detector_config_server.yaml \
|
||||
2>&1 | tee experiments/train_$(date +%Y%m%d_%H%M%S).log &
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Phase 6 — 监控与验证
|
||||
|
||||
训练启动后持续执行以下检查:
|
||||
|
||||
```bash
|
||||
# 6-A 查看实时日志(关键:前100步 loss 应在 1.0~3.0 之间下降)
|
||||
tail -f experiments/train_*.log
|
||||
|
||||
# 6-B GPU 利用率(4 块 GPU 利用率均应 >80%)
|
||||
watch -n 5 nvidia-smi
|
||||
|
||||
# 6-C 检查第一次验证输出(~100 global steps 后出现)
|
||||
# 期望 Val binary F1 > 0.40(超过 L1c 基线 0.410 是最低目标,目标 >0.80)
|
||||
|
||||
# 6-D 检查 checkpoint 保存
|
||||
ls -lh checkpoints/detector/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
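## Phase 7: Model Evaluation (checking whether F1=0.9978 is real)

> **Background**: training reported Val Binary F1=0.9978, but that score is computed on the validation set (dev.jsonl),
> and the validation and training sets are both LLM-generated, so there is a risk of "same-source overfitting".
> This phase uses three experiments to pin down the true generalization ability.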
|
||||
|
||||
### 7-A 全量 test 集评估
|
||||
|
||||
```bash
|
||||
cd $PROJ
|
||||
|
||||
# 重新评估检测器(Module B)
|
||||
python scripts/evaluate.py \
|
||||
--detector-ckpt checkpoints/detector/best.pt \
|
||||
--config configs/detector_config_server.yaml \
|
||||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \
|
||||
--source-filter all \
|
||||
--output experiments/eval_all.json
|
||||
```
|
||||
|
||||
重点观察:
|
||||
- `binary_f1` 是否仍接近 0.9978(若是,说明 test 集也被"污染")
|
||||
- `level_macro_f1`(l_risk 0-4 等级 F1)——这比 binary 难得多,若也完美则有问题
|
||||
- `fine_macro_f1`(14 类细粒度标签 F1)——最难任务,正常应在 0.5-0.8
|
||||
|
||||
### 7-B 仅人工标注子集(关键实验)
|
||||
|
||||
```bash
|
||||
# 重新评估干预策略(Module C)
|
||||
python scripts/evaluate.py \
|
||||
--detector-ckpt checkpoints/detector/best.pt \
|
||||
--config configs/detector_config_server.yaml \
|
||||
--agent-ckpt checkpoints/intervention/final_v2.pt \
|
||||
--test-data data/processed/CompanionRisk-Bench/test.jsonl \
|
||||
--source-filter human \
|
||||
--output experiments/eval_human_only.json
|
||||
```
|
||||
|
||||
> 仅评估来自 DICES / CoSafe / Human-AI Suicide Risk 三个人工标注数据集的样本。
|
||||
> 这些样本来源不同于 LLM 生成,能真实反映泛化性。
|
||||
> **若此处 binary_f1 明显下降(<0.80),说明模型依赖 LLM 文体特征而非风险语义。**
|
||||
|
||||
### 7-C 查看 source 字段分布(调试用)
|
||||
|
||||
```bash
|
||||
# 确认 test.jsonl 中 source 字段的实际取值
|
||||
python3 -c "
|
||||
import json
|
||||
from collections import Counter
|
||||
samples = [json.loads(l) for l in open('data/processed/CompanionRisk-Bench/test.jsonl') if l.strip()]
|
||||
src_counter = Counter(s.get('source', s.get('id','?')[:10]) for s in samples)
|
||||
for k, v in sorted(src_counter.items(), key=lambda x: -x[1]):
|
||||
print(f' {k}: {v}')
|
||||
print(f'Total: {len(samples)}')
|
||||
"
|
||||
```
|
||||
|
||||
> 若输出发现所有样本都没有 source 字段,则 source-filter 用 id 前缀判断(evaluate.py 已处理)。
|
||||
> 把输出贴回来,若所有样本都是 LLM 生成(无人工标注),说明 test 集设计有问题。
|
||||
|
||||
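### 7-D Interpreting the Results

| Experiment | binary_f1 | Interpretation |
|------|-----------|------|
| 7-A full test set | ~0.99 | test and dev share a source; not informative |
| 7-A full test set | ~0.80-0.90 | plausible; the model generalizes |
| 7-B human-annotated | ~0.99 | **trustworthy**; genuinely strong generalization |
| 7-B human-annotated | 0.60-0.75 | **same-source overfitting confirmed**; needs fixing |
| 7-B human-annotated | <0.60 | severe overfitting; the training setup must change |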
|
||||
|
||||
## Phase 9 — 取回结果
|
||||
|
||||
训练和评估完成后,将 checkpoint、日志和评估 JSON 传回本地:
|
||||
|
||||
```powershell
|
||||
# 在本地 Windows PowerShell 执行
|
||||
scp -P 20083 -r `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/checkpoints `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\
|
||||
|
||||
scp -P 20083 -r `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/experiments `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\
|
||||
|
||||
# 同时取回更新后的 evaluate.py(已修复 bug,含 source-filter 功能)
|
||||
scp -P 20083 `
|
||||
root@10.82.3.180:/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL/scripts/evaluate.py `
|
||||
D:\Myresearch\CompanionGuard-RL\code\CompanionGuard-RL\scripts\
|
||||
--config configs/detector_config_server.yaml \
|
||||
--intervention-config configs/intervention_config.yaml \
|
||||
--output experiments/eval_intervention_v3.json
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Metric Reference (training targets)
## Key Results (for the paper)

| Metric | L1c rule baseline (lower bound) | MacBERT target |
|------|---------------------|--------------|
| Binary F1 | 0.410 | **> 0.80** |
| R1 recall (crisis class) | 0.097 | **> 0.75** |
| R9 recall | 0.091 | **> 0.70** |
| FNR (miss rate) | 0.740 | **< 0.20** |
### Module B: Detector v4

| Metric | Value |
|------|----|
| binary_f1 | **0.9995** |
| high_risk_recall | **1.0000** |
| FNR | **0.00%** |
| level_weighted_f1 | **0.559** |
| fine_macro_f1 (10 public classes) | **0.484** |
|
||||
|
||||
### Module C: RL Intervention Policy v3 (for the paper, `eval_intervention_v3.json`)

| Method | safety_recall | over_refusal | action_accuracy | safety_ux_fscore |
|
||||
|------|--------------|--------------|-----------------|-----------------|
|
||||
| Rule-based | 0.908 | 0.000 | — | 0.952 |
|
||||
| Threshold | 0.908 | 0.000 | — | 0.952 |
|
||||
| **Ours (RL)** | **1.000** | **0.004** | **0.575** | **0.998** |
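
The `safety_ux_fscore` column is not defined in this file. The reported values are consistent with a harmonic mean of `safety_recall` and `1 - over_refusal`; the sketch below is written under that assumption and is not taken from `evaluate.py`:

```python
def safety_ux_fscore(safety_recall: float, over_refusal: float) -> float:
    """Harmonic mean of safety recall and (1 - over-refusal rate).

    Assumed definition, inferred from the table values only.
    """
    ux = 1.0 - over_refusal  # share of benign turns that were not refused
    if safety_recall + ux == 0.0:
        return 0.0
    return 2.0 * safety_recall * ux / (safety_recall + ux)


# Rows copied from the table above.
for name, recall, refusal in [("Rule-based", 0.908, 0.000),
                              ("Ours (RL)", 1.000, 0.004)]:
    print(name, round(safety_ux_fscore(recall, refusal), 3))
```

Under this assumption the function reproduces the 0.952 and 0.998 entries in the table.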
|
||||
|
||||
**Weights used**: `checkpoints/intervention/final_v2.pt` (retrained with `det_l_risk`)
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
## Important Notes
|
||||
|
||||
### NCCL 通信报错
|
||||
```bash
|
||||
export NCCL_P2P_DISABLE=1
|
||||
export NCCL_IB_DISABLE=1
|
||||
# 再重新启动 accelerate launch
|
||||
```
|
||||
|
||||
### OOM(显存不足,不太可能:5090 32GB)
|
||||
在 `configs/detector_config.yaml` 中将 `per_gpu_batch_size: 16` 改为 `8`,`gradient_accumulation_steps: 4`。
|
||||
|
||||
### MacBERT 路径找不到
|
||||
检查 `~/.cache/huggingface/hub/` 或 `model_cache/` 目录,找到实际下载路径后更新 config 的 `model.name`。
|
||||
|
||||
### accelerate 找不到
|
||||
```bash
|
||||
pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple accelerate
|
||||
# 或用 torchrun 替代(见 Phase 5)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 文件清单(训练产出)
|
||||
|
||||
| 文件 | 描述 |
|
||||
|------|------|
|
||||
| `checkpoints/detector/best.pt` | 验证集 F1 最高的模型权重 |
|
||||
| `checkpoints/detector/final.pt` | 最后一个 epoch 的权重 |
|
||||
| `experiments/train_YYYYMMDD_HHMMSS.log` | 完整训练日志 |
|
||||
- **PyYAML 6.x pitfall**: write lr as `0.001`, not `1e-3` (the latter is parsed as a string)
- **RTX 5090 NCCL**: multi-GPU training needs `NCCL_SHM_DISABLE=1 NCCL_P2P_DISABLE=1`; the PPO stage runs on a single GPU to sidestep the barrier problem
- **det_l_risk vs l_risk**: both evaluation and training must use the detector-predicted `det_l_risk`, never the ground-truth `l_risk`
- **obs_dim = 2065**: state vector layout `[d_score(1)|l_risk_onehot(5)|c_primary_probs(10)|e_H_pool(1024)|e_P_pool(1024)|t_norm(1)]`
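
The `obs_dim = 2065` note can be checked mechanically. A minimal sketch, assuming the field names and widths listed in that note; `STATE_LAYOUT`, `field_slices`, and `build_state` are hypothetical helpers, not the names used in `src/`:

```python
from collections import OrderedDict

# Field widths copied from the obs_dim note, in concatenation order.
STATE_LAYOUT = OrderedDict([
    ("d_score", 1),
    ("l_risk_onehot", 5),
    ("c_primary_probs", 10),
    ("e_H_pool", 1024),
    ("e_P_pool", 1024),
    ("t_norm", 1),
])
OBS_DIM = sum(STATE_LAYOUT.values())


def field_slices(layout):
    """Map each field name to the slice it occupies in the flat state vector."""
    slices, start = {}, 0
    for name, width in layout.items():
        slices[name] = slice(start, start + width)
        start += width
    return slices


def build_state(parts):
    """Concatenate the per-field values into one flat list, validating widths."""
    state = []
    for name, width in STATE_LAYOUT.items():
        values = list(parts[name])
        if len(values) != width:
            raise ValueError(f"{name}: expected {width} values, got {len(values)}")
        state.extend(values)
    return state


print(OBS_DIM)
print(field_slices(STATE_LAYOUT)["e_P_pool"])
```

Keeping the layout in one table makes a mismatch between the detector output and the policy input fail loudly instead of silently shifting fields.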
|
||||
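
For the PyYAML note, one defensive option is to coerce numeric hyperparameters after loading, so that a config written as `1e-3` still works. A minimal sketch using only the standard library; `as_float` is a hypothetical helper, and the `loaded` dictionary merely simulates what the loader would hand back for the two spellings:

```python
def as_float(value):
    """Return a float whether the loader produced a number or a string like '1e-3'."""
    if isinstance(value, bool):
        # bool is a subclass of int; reject it so `lr: yes` does not become 1.0.
        raise TypeError("booleans are not valid numeric hyperparameters")
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        return float(value.strip())  # raises ValueError on non-numeric text
    raise TypeError(f"cannot interpret {value!r} as a float")


# Simulated loader output: `lr: 1e-3` arrives as a string, `lr: 0.001` as a float.
loaded = {"lr_scientific": "1e-3", "lr_decimal": 0.001}
print(as_float(loaded["lr_scientific"]) == as_float(loaded["lr_decimal"]))
```

Writing `0.001` in the YAML, as the note recommends, remains the simpler fix; the helper only guards against a config that slips through.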
|
||||
@@ -1,617 +0,0 @@
|
||||
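# Multimodal Emotion Model Optimization: Research Execution Plan

> Purpose: this is the single master working document for the project and records the "currently valid plan". It should directly guide experiment execution, code development, baseline selection, result acceptance, and paper writing. Historical judgments, abandoned plans, and reasons for changes go only into the change log and are not repeated here.
>
> Current version: v3.1
> Last updated: 2026-04-24
> Current main line: D2 (RL dialogue graph structure optimization) is the main contribution; D1 (RL adaptive modality fusion) is kept as an auxiliary experiment, robustness analysis, and negative-result discussion.
> Current restriction: this round only fixes documentation and code conventions; no GPU training is started.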
|
||||
|
||||
## 1. 研究目标
|
||||
|
||||
本课题研究多模态情感识别中的动态决策问题。核心问题不是简单把文本、音频、视觉拼接起来,而是在不同样本、不同噪声、不同对话上下文下,让模型学会“当前应该信任哪种模态、应该参考哪些历史发言”。
|
||||
|
||||
最终论文建议围绕一个统一叙事组织:
|
||||
|
||||
> Reinforcement Learning for Adaptive Multimodal Emotion Recognition: From Modality Fusion to Conversation Graph Topology
|
||||
|
||||
中文表述:
|
||||
|
||||
> 面向多模态情感识别的强化学习动态决策框架:从模态融合权重到对话图拓扑优化。
|
||||
|
||||
两个技术层次:
|
||||
|
||||
1. D1 话语级动态融合:RL agent 根据每路模态的可靠性动态分配 text/audio/vision 权重。
|
||||
2. D2 对话级动态图结构:RL agent 根据当前发言状态动态选择上下文窗口、历史发言边权重和说话人关系。
|
||||
|
||||
当前优先级:
|
||||
|
||||
| 方向 | 定位 | 当前状态 | 优先级 |
|
||||
|---|---|---|---|
|
||||
| D1 RL 自适应模态融合 | 辅助实验、鲁棒性分析、负结果讨论 | bug 已修,待 GPU 空闲后重跑 | 中 |
|
||||
| D2 RL 对话图结构优化 | 主贡献、主实验、论文核心 | 等待 COGMEN 依赖与公平基线 | 高 |
|
||||
|
||||
## 2. 当前资源与目录
|
||||
|
||||
服务器:
|
||||
|
||||
| 项目 | 内容 |
|
||||
|---|---|
|
||||
| SSH | `ssh -p 20083 root@10.82.3.180` |
|
||||
| `$ZSY` | `/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy` |
|
||||
| 项目目录 | `$ZSY/multimodal_affect` |
|
||||
| 环境 | `$ZSY/envs/multimodal_affect` |
|
||||
| 数据盘 | 1TB,当前约 788GB 可用 |
|
||||
| GPU | 4 x RTX 5090,当前暂不占用 |
|
||||
|
||||
本地目录:
|
||||
|
||||
| 项目 | 路径 |
|
||||
|---|---|
|
||||
| 本地研究目录 | `D:\Myresearch\多模态情感模型优化` |
|
||||
| 主方案 | `2026-04-22-研究执行方案.md` |
|
||||
| 变更日志 | `2026-04-24-变更日志.md` |
|
||||
|
||||
服务器数据状态:
|
||||
|
||||
| 数据 | 当前状态 | 用途 | 注意事项 |
|
||||
|---|---|---|---|
|
||||
| IEMOCAP 预提取特征 | `data/iemocap`,GloVe 300 + COVAREP 74 + FACET 35 | D1 修复后重跑 | 扁平特征,不适合 D2 图结构 |
|
||||
| IEMOCAP 原始数据 | `data/raw/IEMOCAP_full_release`,已解压 Session1-5 | D2 重建对话级样本、必要时提取新特征 | 不覆盖现有 `data/iemocap` |
|
||||
| COGMEN IEMOCAP 4 类数据 | `baselines/COGMEN/data/iemocap_4` | D2 公平基线首选入口 | 需要 PyG 等依赖 |
|
||||
| MELD CSV | `data/meld`,当前主要是文本/标签 | 后续 D2 泛化实验 | 已解压 MELD.Raw,可补音频/视频 |
|
||||
| MELD.Raw | `data/raw/MELD/MELD.Raw`,已解压 | 后续多模态特征 | 暂不运行耗时提取 |
|
||||
| MOSI | `data/mosi` | D1 或补充实验 | audio 中有少量 `-inf`,使用前必须清洗 |
|
||||
| IEMOCAP noisy | `data/iemocap_noisy` | D1 鲁棒性实验 | 已重新生成三模态噪声文件 |
|
||||
|
||||
## 3. 当前代码状态
|
||||
|
||||
已修复内容:
|
||||
|
||||
| 模块 | 文件 | 修复内容 | 状态 |
|
||||
|---|---|---|---|
|
||||
| 噪声生成 | `scripts/preprocess/generate_noise.py` | 统一输出 `*_vision.npy`,兼容 `visual` 配置别名 | 已同步本地与服务器 |
|
||||
| 数据加载 | `src/data/dataset.py` | noisy variant 加载 text/audio/vision,缺失文件回退 clean 同索引模态 | 已同步服务器 |
|
||||
| D1 训练 | `scripts/train_d1_fixed.py`、服务器 `scripts/train/train_d1.py` | 噪声 batch 使用同一索引加载 text/audio/vision/labels | 已同步 |
|
||||
| D1 评测 | `scripts/run_eval_ablation.py`、服务器 `scripts/eval/eval_d1.py` | 噪声评测替换三模态,修复最后 batch 索引 | 已同步 |
|
||||
| 噪声数据 | 服务器 `data/iemocap_noisy` | 8 个变体均含 train/val/test 的 text/audio/vision/labels | 已验证 |
|
||||
|
||||
已完成检查:
|
||||
|
||||
- 本地相关 Python 文件 `py_compile` 通过。
|
||||
- 服务器相关 Python 文件 `py_compile` 通过。
|
||||
- 服务器噪声数据形状检查通过。
|
||||
- 未启动任何训练。
|
||||
|
||||
旧结果使用规则:
|
||||
|
||||
| 结果 | 是否可用于论文正式表格 | 用途 |
|
||||
|---|---|---|
|
||||
| 修复前 D1 Stage A/B 数值 | 否 | 仅用于说明 bug 发现前的诊断过程 |
|
||||
| 修复前 D1 噪声鲁棒性数值 | 否 | 视觉噪声相关结果无效 |
|
||||
| 修复后重新训练结果 | 待产生 | 可作为正式 D1 结果 |
|
||||
| COGMEN 论文引用数值 | 不能作为主表唯一基线 | 可放 Related Work 或参考表 |
|
||||
| 本地复现 COGMEN-Ours | 待产生 | D2 主基线 |
|
||||
|
||||
## 4. 总体技术路线
|
||||
|
||||
整体路线分为两个层次:
|
||||
|
||||
```text
|
||||
输入:text / audio / vision 多模态情感数据
|
||||
|
||||
D1:话语级动态融合
|
||||
预提取特征 -> 模态 projector -> 置信度/不确定性状态 -> RL fusion agent -> 分类
|
||||
目标:验证 RL 是否能在噪声/缺失模态下优于固定融合
|
||||
|
||||
D2:对话级动态图结构
|
||||
对话序列 -> COGMEN 基线图 -> RL window/topology agent -> 动态 GNN -> ERC 分类
|
||||
目标:学习“该看哪些历史发言”和“边权重应该多强”
|
||||
```
|
||||
|
||||
研究策略:
|
||||
|
||||
1. 先保证 D1 训练和评测口径正确,重跑后决定其论文位置。
|
||||
2. D2 必须先跑通 COGMEN-Ours,建立公平基线。
|
||||
3. D2 先做小动作空间的 `RL-Window`,再做完整 `RL-Topo-Soft`。
|
||||
4. 最终论文以 D2 主结果为核心,D1 作为辅助实验或 analysis 章节。
|
||||
|
||||
## 5. D1 方案:RL 自适应模态融合
|
||||
|
||||
### 5.1 任务定义
|
||||
|
||||
输入为单条 utterance 的三模态特征:
|
||||
|
||||
| 模态 | 当前特征 | 维度 |
|
||||
|---|---|---|
|
||||
| Text | GloVe 平均/预提取文本特征 | 300 |
|
||||
| Audio | COVAREP | 74 |
|
||||
| Vision | FACET | 35 |
|
||||
|
||||
输出为 4 类 IEMOCAP 情感分类。
|
||||
|
||||
当前 D1 不声称使用 BERT/Wav2Vec2/ResNet。若未来要写“大模型特征提取”,必须另做特征提取与重跑。
|
||||
|
||||
### 5.2 当前模型结构
|
||||
|
||||
```text
|
||||
text/audio/vision feature
|
||||
-> modality projector: low-dim feature -> 1024-d shared space
|
||||
-> confidence estimator: 每个模态输出一个置信度
|
||||
-> RL fusion agent: 输出三路融合权重
|
||||
-> weighted fusion
|
||||
-> MLP classifier
|
||||
```
|
||||
|
||||
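Stage A: supervised pre-training.

- Train the projector, confidence estimator, and classifier with uniform fusion.
- Inject noise during training.
- Confidence targets follow the modality that was actually corrupted: 0.1 for a corrupted modality, 0.9 for a clean one.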
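
The confidence-target rule in Stage A can be sketched as follows. The 0.1 / 0.9 values come from the list above; the function name and the modality order `text, audio, vision` are assumptions for illustration:

```python
MODALITIES = ("text", "audio", "vision")


def confidence_targets(corrupted, noisy_target=0.1, clean_target=0.9):
    """Per-modality regression targets for the confidence estimator.

    `corrupted` is the set of modalities that received injected noise
    for this sample.
    """
    unknown = set(corrupted) - set(MODALITIES)
    if unknown:
        raise ValueError(f"unknown modalities: {sorted(unknown)}")
    return [noisy_target if m in corrupted else clean_target for m in MODALITIES]


print(confidence_targets({"audio"}))
print(confidence_targets(set()))
```

Deriving the targets from the set of modalities that were really corrupted is what keeps the estimator from learning a fixed per-modality prior.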
|
||||
|
||||
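Stage B: PPO learning of the fusion weights.

- The encoder is frozen.
- The agent outputs three fusion weights from the state.
- The classifier may be lightly refreshed.
- This version is `RL-Decoupled`; it is used to expose the limits of two-stage RL.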
|
||||
|
||||
### 5.3 D1 State, Action, Reward

Current state:

```text
s = [conf_text, conf_audio, conf_vision, noise_est]
```

Action:

```text
a = [w_text, w_audio, w_vision], sum(a)=1
```

Current reward:

```text
R = alpha * (-CE) + beta * consistency(weights, confidence) - gamma * instability
```
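
The action constraint and reward above can be made concrete with a small sketch. The plan does not define `consistency` or `instability`; here they are assumed to be the cosine similarity between weights and confidences, and the L1 change from the previous step's weights. The coefficient defaults are placeholders, not tuned values:

```python
import math


def softmax(logits):
    """Turn raw agent outputs into fusion weights that sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu > 0 and nv > 0 else 0.0


def d1_reward(ce, weights, confidence, prev_weights,
              alpha=1.0, beta=0.1, gamma=0.1):
    """R = alpha * (-CE) + beta * consistency - gamma * instability."""
    consistency = cosine(weights, confidence)  # assumed definition
    instability = sum(abs(w - p) for w, p in zip(weights, prev_weights))  # assumed
    return alpha * (-ce) + beta * consistency - gamma * instability


weights = softmax([2.0, 0.5, -1.0])  # [w_text, w_audio, w_vision]
print(round(sum(weights), 6))
print(d1_reward(ce=0.5, weights=weights,
                confidence=[0.9, 0.9, 0.1], prev_weights=weights))
```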
|
||||
|
||||
后续增强状态候选:
|
||||
|
||||
| 状态项 | 说明 |
|
||||
|---|---|
|
||||
| 三路 confidence | 现有 |
|
||||
| 三路单模态预测熵 | 衡量每个模态自身不确定性 |
|
||||
| 三对跨模态 cosine similarity | 衡量模态一致性 |
|
||||
| 三路 feature norm / variance | 检测异常特征 |
|
||||
| 分模态 noise estimate | 替代单一 `audio.std` |
|
||||
|
||||
目标是从 4 维状态扩展到 12-16 维状态,前提是 `RL-Decoupled` 重跑后仍有继续价值。
|
||||
|
||||
### 5.4 D1 噪声实验
|
||||
|
||||
当前 IEMOCAP 噪声变体:
|
||||
|
||||
| 变体 | 噪声设计 | 应影响模态 |
|
||||
|---|---|---|
|
||||
| `gaussian_light` | 轻度高斯噪声 | text/audio/vision |
|
||||
| `gaussian_heavy` | 重度高斯噪声 | text/audio/vision |
|
||||
| `missing_audio` | 音频全置零 | audio |
|
||||
| `missing_visual` | 视觉全置零 | vision |
|
||||
| `text_word_drop_30` | 文本 30% dropout 或特征置零 | text |
|
||||
| `audio_masking_50` | 音频 50% 维度遮蔽 | audio |
|
||||
| `realistic_mixed` | 文本轻度损坏 + 音频噪声 + 视觉遮挡 | text/audio/vision |
|
||||
| `audio_time_mask` | 音频样本级时间遮蔽 | audio |
|
||||
|
||||
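Acceptance requirements:

- Every variant must have `text/audio/vision/labels` for `train/val/test`.
- At evaluation time, a noisy file replaces the corresponding modality when it exists; only when it is missing does the loader fall back to the clean modality at the same index.
- Swapping modalities across samples within a batch is not allowed.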
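
The fallback rule above amounts to a per-modality path lookup. A minimal sketch, assuming files are named `{split}_{modality}.npy` and that each noise variant has its own subdirectory; both conventions are guesses based on the `*_vision.npy` pattern mentioned in this plan, and `resolve_modality_file` is a hypothetical helper:

```python
from pathlib import Path


def resolve_modality_file(clean_dir, noisy_dir, variant, split, modality):
    """Prefer the noisy file for this variant; otherwise use the clean file.

    Because the clean and noisy arrays share the same sample order, the
    fallback keeps every sample aligned with its own labels.
    """
    name = f"{split}_{modality}.npy"
    noisy_path = Path(noisy_dir) / variant / name
    if noisy_path.exists():
        return noisy_path
    return Path(clean_dir) / name


def resolve_all(clean_dir, noisy_dir, variant, split):
    """Resolve the three modalities for one split of one noise variant."""
    return {m: resolve_modality_file(clean_dir, noisy_dir, variant, split, m)
            for m in ("text", "audio", "vision")}
```

For a variant such as `missing_audio` that only stores a noisy audio file, `resolve_all` would return the noisy audio path and the clean text and vision paths.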
|
||||
|
||||
### 5.5 D1 对比基线
|
||||
|
||||
D1 不以 COGMEN 为主基线。建议基线:
|
||||
|
||||
| 类别 | 方法 |
|
||||
|---|---|
|
||||
| 基础融合 | concat + MLP、Fixed-Equal、Fixed-Learned-Weight |
|
||||
| 动态融合 | Attention Fusion、Gated Fusion |
|
||||
| 多模态情感经典方法 | TFN、LMF、MFN、MulT、MISA、Self-MM |
|
||||
| 鲁棒/缺失模态 | GCNet、MMIN、M2R2、GSDNet 或同类方法 |
|
||||
| 本方法 | RL-Decoupled、RL-Joint、RL-Joint-16dim |
|
||||
|
||||
最小可执行矩阵:
|
||||
|
||||
| 变体 | 必做 | 说明 |
|
||||
|---|---|---|
|
||||
| Stage-A-Only | 是 | 修复后监督基线 |
|
||||
| Fixed-Equal | 是 | 固定融合 |
|
||||
| Gated-Fusion | 是 | 非 RL 动态融合 |
|
||||
| RL-Decoupled | 是 | 当前两阶段 RL |
|
||||
| RL-Joint | 视结果 | 如果 decoupled 仍弱,做联合训练 |
|
||||
| RL-Joint-16dim | 视结果 | 状态增强版本 |
|
||||
|
||||
### 5.6 D1 后续执行顺序
|
||||
|
||||
触发条件:GPU 空闲。
|
||||
|
||||
1. 确认 `data/iemocap_noisy` 三模态文件完整。
|
||||
2. 重跑 Stage A。
|
||||
3. 重跑 Stage B `RL-Decoupled`。
|
||||
4. 重新运行修复后的 `eval_d1.py`。
|
||||
5. 汇总 clean 与 8 种噪声下的 WF1/Acc。
|
||||
6. 判断是否进入 `RL-Joint`。
|
||||
|
||||
D1 成功标准:
|
||||
|
||||
- clean test 不显著低于 Stage-A-Only。
|
||||
- 至少在多数噪声场景中优于 Fixed-Equal 或 Gated-Fusion。
|
||||
- 若无法达成,D1 写为负结果与设计反思,不作为主贡献。
|
||||
|
||||
## 6. D2 方案:RL 对话图结构优化
|
||||
|
||||
### 6.1 任务定义
|
||||
|
||||
D2 是论文主线。目标是在多模态对话情感识别中,让 RL agent 动态决定当前发言应参考哪些历史发言。
|
||||
|
||||
输入为一个 dialogue:
|
||||
|
||||
```text
|
||||
u1, u2, ..., ut
|
||||
每个 ui 有 text/audio/vision feature、speaker id、label
|
||||
```
|
||||
|
||||
输出为每个 utterance 的情感类别。
|
||||
|
||||
核心问题:
|
||||
|
||||
```text
|
||||
固定图结构:每个当前发言使用固定窗口或固定全局连接
|
||||
改进目标:根据当前语义、说话人、历史情绪变化,动态选择上下文边
|
||||
```
|
||||
|
||||
### 6.2 D2 与 COGMEN 的关系
|
||||
|
||||
COGMEN 是 D2 的直接前身,不跳过。
|
||||
|
||||
当前策略:
|
||||
|
||||
1. 先跑官方 COGMEN `iemocap_4`,得到 `COGMEN-Ours`。
|
||||
2. 在同一数据、同一 split、同一指标下实现固定窗口变体。
|
||||
3. 再加入 RL window selector。
|
||||
4. 最后实现 soft topology agent。
|
||||
|
||||
不能只引用 COGMEN 论文表格,因为:
|
||||
|
||||
- 论文数值和当前本地环境、特征、依赖版本可能不同。
|
||||
- D2 的改进必须建立在可复现同口径基线上。
|
||||
- 后续固定窗口和 RL 图结构需要直接改 COGMEN 图构建逻辑。
|
||||
|
||||
### 6.3 D2 环境依赖
|
||||
|
||||
当前阻塞:
|
||||
|
||||
| 依赖 | 状态 |
|
||||
|---|---|
|
||||
| `torch_geometric` | 未安装 |
|
||||
| `torch_scatter` | 未安装 |
|
||||
| `torch_sparse` | 未安装 |
|
||||
| `sentence_transformers` | 未安装 |
|
||||
| `comet_ml` | 未安装 |
|
||||
|
||||
解决方案:
|
||||
|
||||
1. 优先离线安装与 `torch 2.7.1+cu128` 兼容的 PyG wheels。
|
||||
2. 如果 PyG wheel 不好配,先用 COGMEN 官方预处理数据做评测脚本适配,减少训练依赖。
|
||||
3. `comet_ml` 只作为日志依赖,必要时 stub 或禁用。
|
||||
4. `sentence_transformers` 若只评测官方 pkl,可避免重新提取 SBERT 特征;若训练流程强依赖则离线安装。
|
||||
|
||||
### 6.4 D2 阶段一:COGMEN-Ours
|
||||
|
||||
目标:跑通本地 COGMEN 4 类 IEMOCAP。
|
||||
|
||||
任务:
|
||||
|
||||
1. 阅读并定位 COGMEN 图构建逻辑:
|
||||
- `cogmen/model/COGMEN.py`
|
||||
- `cogmen/model/GNN.py`
|
||||
- `cogmen/Dataset.py`
|
||||
- `preprocess.py`
|
||||
2. 确认官方 `data_iemocap_4.pkl` 与 `IEMOCAP_features_4.pkl` 的结构。
|
||||
3. 安装或绕过缺失依赖。
|
||||
4. 跑原始评测,记录 WF1/Acc。
|
||||
5. 保存基线结果到 `outputs/results/d2_cogmen_ours.json`。
|
||||
|
||||
验收:
|
||||
|
||||
- 可以在服务器上复现实验。
|
||||
- 有固定 random seed。
|
||||
- 有完整命令、日志和结果文件。
|
||||
|
||||
### 6.5 D2 Stage 2: Fixed-Window Baselines

Purpose: confirm how much the context window matters in COGMEN's fixed graph structure.

Experiments:

| Variant | Description |
|---|---|
| `COGMEN-Ours` | Original COGMEN |
| `FixedWin-3` | Connect only the previous 3 utterances |
| `FixedWin-5` | Connect only the previous 5 utterances |
| `FixedWin-7` | Connect only the previous 7 utterances |
| `FixedWin-All` | Connect the full history |
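
The fixed-window variants differ only in which history edges are built. A minimal sketch of that edge construction, independent of COGMEN's actual graph code; the function name and the `(source, target)` edge convention are assumptions:

```python
def fixed_window_edges(num_utterances, window=None):
    """Directed edges (i, t) from history utterance i to current utterance t.

    `window=None` corresponds to FixedWin-All; an integer k keeps only the
    k most recent history utterances for each t.
    """
    edges = []
    for t in range(num_utterances):
        start = 0 if window is None else max(0, t - window)
        edges.extend((i, t) for i in range(start, t))
    return edges


print(fixed_window_edges(5, window=3))
print(len(fixed_window_edges(5)))
```

With 5 utterances, `window=3` drops only the edge `(0, 4)` compared with the full history, which shows why window effects should be reported by dialogue length.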
|
||||
|
||||
输出:
|
||||
|
||||
- 每个窗口的 WF1/Acc。
|
||||
- 不同对话长度下的性能分组。
|
||||
- 后续 RL action space 的依据。
|
||||
|
||||
### 6.6 D2 Stage 3: RL-Window

This is the minimum viable version of the contribution.

State design:

```text
s_t = concat(
h_t, current utterance representation
mean(h_1...h_{t-1}), mean over the history
speaker_embedding, speaker embedding
delta(h_t, h_{t-1}) estimated emotion change
)
```

Action space:

```text
a_t in {3, 5, 7, all}
```

Reward:

```text
R = task_reward + lambda_sparse * sparsity_bonus + lambda_stable * stability_bonus
```

Implementable version:

| Reward term | Implementation |
|---|---|
| task_reward | batch-level `-CE`, or improvement in validation WF1 |
| sparsity_bonus | shorter windows are sparser, but must not sacrifice task performance |
| stability_bonus | window choices for adjacent utterances should not jitter sharply |
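
One way the reward terms in the table could be computed is sketched below. The plan only describes the terms qualitatively, so the concrete formulas (fraction of history left unused, negative normalized change in window size) and the lambda defaults are assumptions:

```python
def effective_window(action, history_len):
    """Map an action from {3, 5, 7, 'all'} to the number of history utterances used."""
    if action == "all":
        return history_len
    return min(int(action), history_len)


def rl_window_reward(ce, action, prev_action, history_len,
                     lambda_sparse=0.05, lambda_stable=0.05):
    """R = task_reward + lambda_sparse * sparsity_bonus + lambda_stable * stability_bonus."""
    task_reward = -ce  # batch-level -CE, as in the table
    if history_len == 0:
        return task_reward  # the first utterance has no window to choose
    used = effective_window(action, history_len)
    prev_used = effective_window(prev_action, history_len)
    sparsity_bonus = 1.0 - used / history_len  # assumed: share of history skipped
    stability_bonus = -abs(used - prev_used) / history_len  # assumed: penalize jitter
    return task_reward + lambda_sparse * sparsity_bonus + lambda_stable * stability_bonus


print(rl_window_reward(ce=0.4, action=5, prev_action=5, history_len=10))
print(rl_window_reward(ce=0.4, action="all", prev_action=3, history_len=10))
```

Keeping both lambdas small relative to the task term reflects the table's requirement that sparsity must not be bought at the cost of task performance.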
|
||||
|
||||
验收:
|
||||
|
||||
- `RL-Window` 超过 `COGMEN-Ours` 或超过最优固定窗口。
|
||||
- 至少有 3 个对话案例显示 RL 选择了不同窗口。
|
||||
- 若没有提升,分析是否状态信息不足或奖励过稀疏。
|
||||
|
||||
### 6.7 D2 Stage 4: RL-Topo-Soft

The full dynamic graph structure.

Action:

```text
For the current utterance u_t, output a continuous weight w_{i,t} in [0,1] for each history edge e_{i,t}
```

Why a soft topology:

- Hard 0/1 edges are not differentiable, and the PPO signal is sparse.
- Soft edges let the classification loss shape the edge weights through message passing.
- Interpretability is preserved: visualizing the edge weights is enough.

Candidate model:

```text
history/current sequence
-> transformer/state encoder
-> edge scorer
-> weighted graph message passing
-> classifier
```
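
The edge scorer and weighted message passing steps can be illustrated in plain Python. This is a toy sketch, not the planned model: the scorer is assumed to be a sigmoid of a scaled dot product, and aggregation is assumed to be a weighted mean over history vectors:

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def soft_edge_weights(history, current):
    """One weight w_{i,t} in [0, 1] per history utterance i."""
    scale = math.sqrt(len(current))
    return [sigmoid(sum(a * b for a, b in zip(h, current)) / scale)
            for h in history]


def aggregate(history, weights, dim):
    """Weighted mean of history vectors; zero vector when there is no history."""
    total = sum(weights)
    if not history or total == 0.0:
        return [0.0] * dim
    return [sum(w * h[d] for w, h in zip(weights, history)) / total
            for d in range(dim)]


history = [[1.0, 0.0], [0.0, 0.0]]
current = [1.0, 0.0]
weights = soft_edge_weights(history, current)
print(weights)
print(aggregate(history, weights, dim=2))
```

Because every step is a smooth function of the scores, a gradient from the classifier can reach the scorer, which is the argument the list above makes for soft edges.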
|
||||
|
||||
奖励:
|
||||
|
||||
```text
|
||||
R = alpha * task_reward
|
||||
+ beta * sparsity
|
||||
+ gamma * emotion_coherence
|
||||
- eta * graph_instability
|
||||
```
|
||||
|
||||
解释性输出:
|
||||
|
||||
- 每个 utterance 选择的 top-k 历史边。
|
||||
- 同说话人/不同说话人边权统计。
|
||||
- 情绪转折点前后的图结构变化。
|
||||
- 噪声 utterance 是否被降低边权。
|
||||
|
||||
### 6.8 D2 主基线矩阵
|
||||
|
||||
论文主表建议至少覆盖:
|
||||
|
||||
| 类别 | 方法 | 必要性 |
|
||||
|---|---|---|
|
||||
| 经典 ERC | DialogueRNN | 基础对话上下文基线 |
|
||||
| 图 ERC | DialogueGCN | 图结构基础基线 |
|
||||
| 多模态图 | MMGCN | 多模态 ERC 图基线 |
|
||||
| 直接前身 | COGMEN | 必做 |
|
||||
| 近年图方法 | M3Net | 建议 |
|
||||
| 近年图方法 | AdaGIN | 建议 |
|
||||
| 近年图方法 | DER-GCN | 建议 |
|
||||
| 直接竞争 | DGODE | 必做 |
|
||||
| 补充竞争 | GASMER | 可做 |
|
||||
| 本文 | FixedWin / RL-Window / RL-Topo-Soft | 主结果 |
|
||||
|
||||
R1-Omni、AffectGPT-R1、EMO-RL、AffectAgent 放 Related Work,不强行放主表,除非任务设置、数据集、指标完全对齐。
|
||||
|
||||
## 7. 数据集与指标
|
||||
|
||||
### 7.1 数据集
|
||||
|
||||
| 数据集 | 用途 | 当前策略 |
|
||||
|---|---|---|
|
||||
| IEMOCAP 4-class | 主数据集 | D1 用扁平特征;D2 用 COGMEN 对话级数据 |
|
||||
| MELD | 泛化验证 | 先补齐音频/视频或使用已有对话结构 |
|
||||
| CMU-MOSI | D1 可选补充 | 使用前清洗 audio 的 `-inf` |
|
||||
|
||||
### 7.2 Metrics

Primary metrics:

- Weighted F1
- Accuracy
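
For reference, the two primary metrics can be computed without external dependencies. A minimal sketch; in practice a library implementation would normally be used, and the function names here are illustrative:

```python
from collections import Counter


def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights proportional to class support."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for label, n in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
        fn = n - tp
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        score += (n / total) * f1
    return score


def accuracy(y_true, y_pred):
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    return sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)


y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]
print(round(weighted_f1(y_true, y_pred), 4), accuracy(y_true, y_pred))
```

Weighted F1 follows the class distribution of the test set, so it should be reported together with Macro F1 when classes are imbalanced.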
|
||||
|
||||
辅助指标:
|
||||
|
||||
- Macro F1
|
||||
- per-class precision/recall/F1
|
||||
- 噪声场景下相对下降率
|
||||
- 多 seed 均值与标准差
|
||||
- paired t-test 或 bootstrap confidence interval
|
||||
|
||||
实验规范:
|
||||
|
||||
- 主结果至少 3 个 seed。
|
||||
- IEMOCAP 必须明确 4-class 或 6-class。
|
||||
- 所有表格必须标明特征来源和 split。
|
||||
- COGMEN 论文引用数值和本地复现数值分开写。
|
||||
|
||||
## 8. 当前时间线
|
||||
|
||||
### 已完成
|
||||
|
||||
| 阶段 | 内容 | 状态 |
|
||||
|---|---|---|
|
||||
| P0 | 环境搭建、数据上传、初步特征整理 | 完成 |
|
||||
| P0-data | IEMOCAP/MELD 原始数据解压 | 完成 |
|
||||
| P1-fix | D1 噪声生成、训练采样、评测 bug 修复 | 完成 |
|
||||
| P1-check | 本地/服务器语法检查、噪声数据形状检查 | 完成 |
|
||||
|
||||
### 下一阶段
|
||||
|
||||
| 阶段 | 内容 | 触发条件 | 预计耗时 |
|
||||
|---|---|---|---|
|
||||
| P1-rerun | D1 修复后 Stage A/B 重跑 | GPU 空闲 | 0.5-1 天 |
|
||||
| P2-env | 安装/适配 COGMEN 依赖 | CPU/环境窗口 | 0.5-1 天 |
|
||||
| P2-base | 跑 COGMEN-Ours | 依赖就绪 | 0.5 天 |
|
||||
| P2-win | 固定窗口与 RL-Window | COGMEN-Ours 完成 | 2-4 天 |
|
||||
| P2-topo | RL-Topo-Soft | RL-Window 有效 | 1-3 周 |
|
||||
| P3 | 多 seed、消融、可视化、写作 | 主结果稳定 | 3-4 周 |
|
||||
|
||||
## 9. Draft Follow-up Commands

Run these only once the GPU is idle.

D1 rerun:

```bash
# Export ZSY first: the cd below depends on it.
export ZSY=/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy
export PYTHONPATH=$ZSY/multimodal_affect
export WANDB_MODE=offline
cd $ZSY/multimodal_affect

$ZSY/envs/multimodal_affect/bin/python scripts/preprocess/generate_noise.py \
  --config configs/noise_configs.yaml \
  --data_dir data/iemocap \
  --out_dir data/iemocap_noisy

$ZSY/envs/multimodal_affect/bin/python scripts/train/train_d1.py \
  --stage supervised \
  --dataset IEMOCAP \
  --config configs/d1/stage_a.yaml \
  --output outputs/checkpoints/d1_stageA_fixed \
  --gpus 0

$ZSY/envs/multimodal_affect/bin/python scripts/train/train_d1.py \
  --stage rl \
  --dataset IEMOCAP \
  --checkpoint outputs/checkpoints/d1_stageA_fixed/best.ckpt \
  --config configs/d1/stage_b.yaml \
  --output outputs/checkpoints/d1_stageB_fixed \
  --gpus 0
```
|
||||
|
||||
D2 COGMEN-Ours:

```bash
# Assumes ZSY is already exported as in the D1 block.
cd $ZSY/multimodal_affect/baselines/COGMEN

# Once dependencies are ready:
$ZSY/envs/multimodal_affect/bin/python eval.py \
  --dataset iemocap_4 \
  --modalities atv
```

Note: the commands above are drafts; check GPU usage and dependency status before running them.
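
The GPU check before launching could be scripted roughly as follows. This is a sketch under assumptions: it relies on `nvidia-smi` being on the PATH and returning numeric values for every field, the helper names are hypothetical, and the idle thresholds (500 MiB, 5% utilization) are illustrative values rather than project settings.

```python
import subprocess

# Query GPU index, used memory (MiB) and utilization (%) as plain CSV.
QUERY = [
    "nvidia-smi",
    "--query-gpu=index,memory.used,utilization.gpu",
    "--format=csv,noheader,nounits",
]


def parse_gpu_status(csv_text):
    """Parse nvidia-smi CSV rows into dicts."""
    rows = []
    for line in csv_text.strip().splitlines():
        index, mem_used, util = (field.strip() for field in line.split(","))
        rows.append({
            "index": int(index),
            "mem_used_mib": int(mem_used),
            "util_pct": int(util),
        })
    return rows


def idle_gpus(rows, max_mem_mib=500, max_util_pct=5):
    """Indices of GPUs below both thresholds (thresholds are assumptions)."""
    return [
        r["index"]
        for r in rows
        if r["mem_used_mib"] <= max_mem_mib and r["util_pct"] <= max_util_pct
    ]


def query_gpus():
    """Run nvidia-smi and return the parsed rows."""
    out = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    return parse_gpu_status(out.stdout)
```

Calling `idle_gpus(query_gpus())` on the server returns the GPU indices that look free; an empty list means the run should wait.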
|
||||
|
||||
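## 10. Risks and Mitigations

| Risk | Probability | Impact | Mitigation |
|---|---|---|---|
| D1 still shows no gain after the fixes | Medium | D1 cannot count as a contribution | Position it in the paper as a negative result and a design reflection |
| COGMEN dependencies are hard to install | Medium | D2 is delayed | Offline wheels, disable comet, prioritize the evaluation path |
| RL-Window does not beat the fixed window | Medium | Insufficient novelty | Adjust the state and reward; analyze results grouped by dialogue length |
| RL-Topo-Soft training is unstable | Medium | Main results at risk | Warm-start from RL-Window; do supervised pretraining of the soft edges first |
| MELD multimodal feature extraction is slow | Medium | Generalization experiments are delayed | Complete IEMOCAP first; treat MELD as supplementary |
| Too many baselines blow the schedule | High | Writing is delayed | Secure COGMEN, DGODE, DialogueGCN, and MMGCN first |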
|
||||
|
||||
## 11. Draft Paper Structure

```text
1. Introduction
   - 多模态情感识别中的动态选择问题 (dynamic selection in multimodal emotion recognition)
   - 模态噪声与对话上下文冗余 (modality noise and redundant dialogue context)
   - RL 作为动态决策机制 (RL as a dynamic decision mechanism)

2. Related Work
   - Multimodal sentiment/emotion recognition
   - Emotion recognition in conversation
   - Graph-based ERC
   - RL for affective computing

3. Method
   - Overview
   - D1: RL adaptive fusion
   - D2: RL graph topology optimization
   - Training objective and reward design

4. Experiments
   - Datasets and metrics
   - Baselines
   - Main results on IEMOCAP
   - Generalization on MELD
   - Ablation study
   - Robustness under noise/missing modality
   - Visualization and case study

5. Analysis
   - Why two-stage D1 is limited
   - How RL graph policy changes with dialogue context
   - Error analysis by emotion class and speaker dynamics

6. Conclusion
```
|
||||
|
||||
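## 12. Current Execution Principles

Mandatory rules:

- The main plan is maintained only in this document; reasons for changes go only into the log.
- Old D1 results do not enter the paper's official result tables.
- Do not overwrite the pre-extracted features in `data/iemocap`.
- Do not train the D2 graph model on the flat `data/iemocap` features.
- Do not rely solely on the numbers reported in the COGMEN paper for the main comparison.
- While the GPU is busy, limit work to documentation, code, CPU-side checks, and dependency preparation.

Top priorities for the next step:

1. Once the GPU is free, rerun the fixed D1 to confirm the conclusions of the auxiliary experiment.
2. Prepare the COGMEN dependencies and get `COGMEN-Ours` running.
3. Implement `FixedWin` and `RL-Window` to obtain the first reportable D2 results as soon as possible.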
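
The plan does not define `FixedWin` in detail. As a rough sketch, assuming it means a dialogue graph in which every utterance is connected to its neighbors inside a fixed past/future window (the usual window-based construction in graph ERC models), the edge list could be built as below. The function name and the default window sizes are hypothetical.

```python
def fixed_window_edges(num_utterances, past=4, future=4):
    """Directed edges (i, j): utterance i attends to j within the window.

    Self-loops are included because j ranges over i - past .. i + future.
    """
    edges = []
    for i in range(num_utterances):
        lo = max(0, i - past)
        hi = min(num_utterances - 1, i + future)
        for j in range(lo, hi + 1):
            edges.append((i, j))
    return edges
```

Under this reading, an `RL-Window` policy would choose `past` and `future` per dialogue or per utterance instead of fixing them in advance, which keeps the two variants directly comparable.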
|
||||
@@ -1,17 +0,0 @@
|
||||
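# Research Change Log

> Purpose: this document records only key changes to the main plan. It contains no experiment details and does not replace the execution plan. The currently valid plan is `2026-04-22-研究执行方案.md`.

| Date | Type | Change | Impact | Status |
|---|---|---|---|---|
| 2026-04-24 | Docs | Restructured the main working document as the "currently valid plan", keeping the research goals, D1/D2 roadmap, baseline matrix, phase plan, and acceptance criteria | Avoids mixing the old and new plans | Done |
| 2026-04-24 | Docs | Removed `2026-04-23-现状分析与推进方案.md`; the Markdown working documents now consist only of the main plan and this log | Keeps a single documentation source | Done |
| 2026-04-24 | Strategy | D2 (RL dialogue-graph structure optimization) promoted to the main line; D1 (RL fusion) becomes an auxiliary experiment and robustness analysis | Paper focus shifts to dialogue-level dynamic graph structure | Done |
| 2026-04-24 | Baseline | COGMEN stays a mandatory D2 baseline, but citing only the paper's numbers is no longer acceptable; a `COGMEN-Ours` reproduction is required | Fairer comparison | Done |
| 2026-04-24 | Baseline | D1 no longer uses COGMEN as its main baseline; it is compared with fixed fusion, gated fusion, attention fusion, and missing-modality-robust methods | Avoids an unfair utterance-level vs. dialogue-level comparison | Done |
| 2026-04-24 | Code | Fixed the `visual/vision` naming mismatch in noise generation; outputs are unified as `*_vision.npy` | Visual-noise experiments now actually take effect | Done |
| 2026-04-24 | Code | Fixed the cross-modality sample mismatch in D1 noisy training batches | D1 must be rerun after the fix | Done |
| 2026-04-24 | Code | Fixed D1 noise evaluation replacing only text/audio, and corrected the last-batch indexing | Noise-robustness results are trustworthy only after the fix | Done |
| 2026-04-24 | Data | Regenerated the 8 tri-modal IEMOCAP noise variants on the server | `data/iemocap_noisy` is ready for the rerun | Done |
| 2026-04-24 | Data | Finished extracting the raw IEMOCAP Sessions and MELD.Raw on the server | Dialogue-level samples can be rebuilt and MELD features supplemented later | Done |
| 2026-04-24 | Validation | Relevant scripts pass syntax checks locally and on the server; the server-side noisy-data shape check passes | No GPU training was started | Done |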
|
||||
@@ -1,121 +0,0 @@
|
||||
# Multimodal noise configuration for robustness experiments (P0-4)
# Each variant defines noise applied per modality.
# Run:  python scripts/preprocess/generate_noise.py \
#         --config configs/noise_configs.yaml \
#         --data_dir $ZSY/multimodal_affect/data/iemocap

variants:

  # ── Variant 1: Light Gaussian across all modalities ───────────────────
  - name: gaussian_light
    description: "Mild Gaussian noise on all modalities (σ=0.05)"
    noise:
      text:
        type: gaussian
        intensity: 0.05
      audio:
        type: gaussian
        intensity: 0.05
      visual:
        type: gaussian
        intensity: 0.05

  # ── Variant 2: Heavy Gaussian ─────────────────────────────────────────
  - name: gaussian_heavy
    description: "Strong Gaussian noise on all modalities (σ=0.20)"
    noise:
      text:
        type: gaussian
        intensity: 0.20
      audio:
        type: gaussian
        intensity: 0.20
      visual:
        type: gaussian
        intensity: 0.20

  # ── Variant 3: Missing modality – audio dropped ───────────────────────
  - name: missing_audio
    description: "Audio features zeroed out (100% drop)"
    noise:
      text:
        type: gaussian
        intensity: 0.0
      audio:
        type: masking
        intensity: 1.0   # mask ALL dims
      visual:
        type: gaussian
        intensity: 0.0

  # ── Variant 4: Missing modality – visual dropped ──────────────────────
  - name: missing_visual
    description: "Visual features zeroed out"
    noise:
      text:
        type: gaussian
        intensity: 0.0
      audio:
        type: gaussian
        intensity: 0.0
      visual:
        type: missing_modality
        intensity: 1.0
|
||||
|
||||
  # ── Variant 5: Text word-drop ─────────────────────────────────────────
  - name: text_word_drop_30
    description: "30% token dropout in text modality"
    noise:
      text:
        type: word_drop
        intensity: 0.30
      audio:
        type: gaussian
        intensity: 0.0
      visual:
        type: gaussian
        intensity: 0.0

  # ── Variant 6: Audio feature masking ─────────────────────────────────
  - name: audio_masking_50
    description: "50% of audio feature dimensions masked"
    noise:
      text:
        type: gaussian
        intensity: 0.0
      audio:
        type: masking
        intensity: 0.50
      visual:
        type: gaussian
        intensity: 0.0

  # ── Variant 7: Realistic mixed noise ─────────────────────────────────
  - name: realistic_mixed
    description: >
      Realistic scenario: moderate text corruption, noisy audio,
      partial visual occlusion
    noise:
      text:
        type: word_drop
        intensity: 0.15
      audio:
        type: gaussian
        intensity: 0.10
      visual:
        type: occlusion
        intensity: 0.25

  # ── Variant 8: Audio time-step masking (dropped frames) ──────────────
  - name: audio_time_mask
    description: "Audio time-step masking (30%) simulates dropped frames"
    noise:
      text:
        type: gaussian
        intensity: 0.0
      audio:
        type: time_mask
        intensity: 0.30
      visual:
        type: gaussian
        intensity: 0.0
|
||||
@@ -1,67 +0,0 @@
|
||||
import torch, numpy as np, os, subprocess
from pathlib import Path
import sys

ZSY = '/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy'
DATA = Path(ZSY + '/multimodal_affect/data')
ok = []; fail = []

def chk(name, cond, detail=''):
    mark = 'OK ' if cond else 'FAIL'
    print(f' [{mark}] {name}: {detail}')
    (ok if cond else fail).append(name)

print('=== Phase 0 Acceptance Check ===')

# GPU
g = torch.cuda.device_count()
chk('4x GPU', g == 4, str(g) + ' GPUs')

# packages (a missing package raises ImportError here, before any chk call)
import transformers, librosa, timm, sklearn, wandb, stable_baselines3, gymnasium
for name, mod in [
    ('transformers', transformers), ('librosa', librosa), ('timm', timm),
    ('sklearn', sklearn), ('wandb', wandb), ('sb3', stable_baselines3), ('gymnasium', gymnasium)
]:
    chk(name, True, getattr(mod, '__version__', 'ok'))

# datasets
for ds, mods in [
    ('iemocap', ['text', 'audio', 'vision', 'labels']),
    ('mosi', ['text', 'audio', 'vision', 'labels']),
    ('meld', ['text', 'labels']),
]:
    all_ok = all((DATA / ds / (s + '_' + m + '.npy')).exists()
                 for s in ['train', 'val', 'test'] for m in mods)
    if all_ok:
        sample = np.load(str(DATA / ds / 'train_labels.npy'))
        chk(ds + ' features', True, 'train N=' + str(sample.shape[0]))
    else:
        missing = [s + '_' + m for s in ['train', 'val', 'test'] for m in mods
                   if not (DATA / ds / (s + '_' + m + '.npy')).exists()]
        chk(ds + ' features', False, 'missing: ' + str(missing[:3]))

# noise variants
noisy = DATA / 'iemocap_noisy'
n_var = len([x for x in noisy.iterdir() if x.is_dir()]) if noisy.exists() else 0
chk('noise 8 variants', n_var == 8, str(n_var) + '/8')

# git
r = subprocess.run(['git', '-C', ZSY + '/multimodal_affect', 'log', '--oneline', '-3'],
                   capture_output=True, text=True)
log_oneline = r.stdout.strip().replace('\n', ' | ')
chk('git history', r.returncode == 0, log_oneline)

# disk
r2 = subprocess.run(['df', '-h', ZSY], capture_output=True, text=True)
for line in r2.stdout.splitlines()[1:]:
    parts = line.split()
    if len(parts) >= 5:  # parts[4] (Use%) is read below, so 5 fields are needed
        chk('disk free', True, 'avail=' + parts[3] + ' use=' + parts[4])

print()
print('Result: ' + str(len(ok)) + ' OK / ' + str(len(fail)) + ' FAIL')
if not fail:
    print('Phase 0 PASS - ready for Phase 1')
else:
    print('FAIL items: ' + str(fail))
|
||||
@@ -1,287 +0,0 @@
|
||||
"""
|
||||
IEMOCAP feature extraction script.

Expected dataset structure (matches the paths built in extract_iemocap):
  $DATA_ROOT/IEMOCAP_full_release/
    Session1/ ... Session5/
      dialog/
        EmoEvaluation/   -> label files (.txt)
        transcriptions/  -> transcript files (.txt)
      sentences/
        wav/<dialog_id>/<utterance_id>.wav  -> utterance wav files

Output: $ZSY/multimodal_affect/data/iemocap/
  {train,val,test}_text.npy    shape: (N, seq_len) token ids (or (N, 768) if model available)
  {train,val,test}_audio.npy   shape: (N, 40) MFCC means
  {train,val,test}_labels.npy  shape: (N,) int labels
  label_map.json
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import wave
|
||||
import struct
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# ── constants ──────────────────────────────────────────────────────────────
|
||||
EMOTION_MAP = {"ang": 0, "hap": 1, "exc": 1, "sad": 2, "neu": 3} # exc merged into hap
|
||||
SESSIONS = ["Session1", "Session2", "Session3", "Session4", "Session5"]
|
||||
LABEL_NAMES = ["angry", "happy", "sad", "neutral"]
|
||||
SAMPLE_RATE = 16000
|
||||
N_MFCC = 40
|
||||
SEED = 42
|
||||
|
||||
|
||||
# ── audio utilities (no libsndfile needed) ─────────────────────────────────
def _load_wav_stdlib(path: str):
    """Load WAV with stdlib wave module → float32 mono array."""
    with wave.open(path, "rb") as f:
        n_channels = f.getnchannels()
        sampwidth = f.getsampwidth()
        n_frames = f.getnframes()
        raw = f.readframes(n_frames)

    if sampwidth == 2:
        samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
    elif sampwidth == 4:
        samples = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
    else:
        raise ValueError(f"Unsupported sample width: {sampwidth}")

    if n_channels > 1:
        samples = samples.reshape(-1, n_channels).mean(axis=1)
    return samples


def _load_audio(path: str):
    """Try av → stdlib wave, return float32 mono array."""
    try:
        import av
        container = av.open(path)
        stream = next(s for s in container.streams if s.type == "audio")
        chunks = []
        for packet in container.demux(stream):
            for frame in packet.decode():
                arr = frame.to_ndarray()
                if arr.ndim == 2:
                    arr = arr.mean(axis=0)
                chunks.append(arr.astype(np.float32))
        container.close()
        if chunks:
            return np.concatenate(chunks)
    except Exception:
        pass
    return _load_wav_stdlib(path)
|
||||
|
||||
|
||||
# ── MFCC (librosa if available, else pure-NumPy DCT fallback) ─────────────
def _framing(signal, frame_len, hop_len):
    # Zero-pad signals shorter than one frame so at least one frame exists.
    if len(signal) < frame_len:
        signal = np.pad(signal, (0, frame_len - len(signal)))
    n_frames = 1 + (len(signal) - frame_len) // hop_len
    idx = np.arange(frame_len)[None, :] + hop_len * np.arange(n_frames)[:, None]
    return signal[idx]


def compute_mfcc(signal: np.ndarray, sr: int = SAMPLE_RATE,
                 n_mfcc: int = N_MFCC, n_fft: int = 512,
                 hop_length: int = 160, n_mels: int = 40) -> np.ndarray:
    """MFCC via librosa when it works, else a minimal NumPy fallback → (T, n_mfcc)."""
    try:
        import librosa
        # librosa may use audioread / av backend
        mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=n_mfcc,
                                    n_fft=n_fft, hop_length=hop_length,
                                    n_mels=n_mels)
        return mfcc.T  # (T, n_mfcc)
    except Exception:
        pass

    # pure-numpy fallback
    frame_len = n_fft
    frames = _framing(signal, frame_len, hop_len=hop_length)
    window = np.hanning(frame_len)
    frames = frames * window[None, :]

    mag = np.abs(np.fft.rfft(frames, n=n_fft))
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)

    # mel filterbank
    def hz2mel(f): return 2595 * np.log10(1 + f / 700)
    def mel2hz(m): return 700 * (10 ** (m / 2595) - 1)
    mel_low, mel_high = hz2mel(80), hz2mel(sr / 2)
    mel_pts = np.linspace(mel_low, mel_high, n_mels + 2)
    hz_pts = mel2hz(mel_pts)
    bins = np.floor((n_fft + 1) * hz_pts / sr).astype(int)

    fbank = np.zeros((n_mels, n_fft // 2 + 1))
    for m in range(1, n_mels + 1):
        lo, ctr, hi = bins[m - 1], bins[m], bins[m + 1]
        fbank[m - 1, lo:ctr] = (np.arange(lo, ctr) - lo) / (ctr - lo + 1e-8)
        fbank[m - 1, ctr:hi] = (hi - np.arange(ctr, hi)) / (hi - ctr + 1e-8)

    mel_energy = np.dot(mag ** 2, fbank.T)
    log_mel = np.log(np.maximum(mel_energy, 1e-9))

    # DCT-II
    n = np.arange(n_mfcc)[:, None]
    k = np.arange(n_mels)[None, :]
    dct = np.cos(np.pi * n * (2 * k + 1) / (2 * n_mels))
    mfcc = np.dot(log_mel, dct.T)
    return mfcc  # (T, n_mfcc)


def mfcc_features(wav_path: str) -> np.ndarray:
    """Return mean MFCC over time → shape (n_mfcc,)."""
    sig = _load_audio(wav_path)
    mfcc = compute_mfcc(sig)
    return mfcc.mean(axis=0)
|
||||
|
||||
|
||||
# ── text tokenisation ──────────────────────────────────────────────────────
def get_text_features(text: str, tokenizer=None, model=None,
                      max_len: int = 64) -> np.ndarray:
    """Return [CLS] embedding (768-d) or a hashed token-id vector (max_len,)."""
    if tokenizer is not None and model is not None:
        import torch
        enc = tokenizer(text, return_tensors="pt", truncation=True,
                        max_length=max_len, padding="max_length")
        with torch.no_grad():
            out = model(**enc)
        return out.last_hidden_state[:, 0, :].squeeze(0).cpu().numpy()

    # simple token-id fallback (word hash)
    # Built-in hash() is salted per process for str, so ids would differ
    # between runs; crc32 is deterministic.
    import zlib
    tokens = text.lower().split()[:max_len]
    ids = [zlib.crc32(t.encode("utf-8")) % 30522 for t in tokens]
    ids += [0] * (max_len - len(ids))
    return np.array(ids, dtype=np.int32)
|
||||
|
||||
|
||||
# ── label parsing ──────────────────────────────────────────────────────────
def parse_label_file(label_path: str) -> dict:
    """Return dict: utterance_id → emotion string."""
    labels = {}
    with open(label_path, encoding="utf-8") as f:
        for line in f:
            if line.startswith("["):
                parts = line.strip().split("\t")
                if len(parts) >= 2:
                    uid = parts[1].strip()
                    emo = parts[2].strip().lower() if len(parts) > 2 else "xxx"
                    labels[uid] = emo
    return labels


def parse_transcription_file(trans_path: str) -> dict:
    """Return dict: utterance_id → text."""
    texts = {}
    with open(trans_path, encoding="utf-8") as f:
        for line in f:
            m = re.match(r"^(\w+)\s*\[.*?\]\s*:\s*(.+)$", line.strip())
            if m:
                texts[m.group(1)] = m.group(2).strip()
    return texts
|
||||
|
||||
|
||||
# ── main extraction ────────────────────────────────────────────────────────
def extract_iemocap(data_root: str, out_dir: str,
                    use_transformer: bool = False,
                    model_name: str = "roberta-base",
                    val_sessions: list = None,
                    test_sessions: list = None):
    data_root = Path(data_root)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if val_sessions is None:
        val_sessions = ["Session4"]
    if test_sessions is None:
        test_sessions = ["Session5"]

    tokenizer, model = None, None
    if use_transformer:
        from transformers import AutoTokenizer, AutoModel
        print(f"Loading {model_name} …")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        model.eval()

    splits = {"train": [], "val": [], "test": []}

    for session in SESSIONS:
        sess_dir = data_root / "IEMOCAP_full_release" / session
        if not sess_dir.exists():
            print(f" [skip] {sess_dir} not found")
            continue

        emo_dir = sess_dir / "dialog" / "EmoEvaluation"
        trans_dir = sess_dir / "dialog" / "transcriptions"
        wav_dir = sess_dir / "sentences" / "wav"

        if session in test_sessions:
            split = "test"
        elif session in val_sessions:
            split = "val"
        else:
            split = "train"

        for label_file in sorted(emo_dir.glob("*.txt")):
            labels = parse_label_file(str(label_file))
            dialog_id = label_file.stem

            trans_file = trans_dir / (dialog_id + ".txt")
            texts = parse_transcription_file(str(trans_file)) if trans_file.exists() else {}

            for uid, emo in labels.items():
                if emo not in EMOTION_MAP:
                    continue
                label = EMOTION_MAP[emo]
                text = texts.get(uid, "")

                wav_path = wav_dir / dialog_id / (uid + ".wav")
                if not wav_path.exists():
                    continue

                try:
                    audio_feat = mfcc_features(str(wav_path))
                    text_feat = get_text_features(text, tokenizer, model)
                    splits[split].append((text_feat, audio_feat, label))
                except Exception as e:
                    print(f" [warn] {uid}: {e}")

        print(f" {session} → {split}: {len(splits[split])} so far")

    label_map = {i: name for i, name in enumerate(LABEL_NAMES)}
    with open(out_dir / "label_map.json", "w") as f:
        json.dump(label_map, f, indent=2)

    for split, items in splits.items():
        if not items:
            print(f" [warn] {split} is empty")
            continue
        text_arr = np.stack([x[0] for x in items])
        audio_arr = np.stack([x[1] for x in items])
        label_arr = np.array([x[2] for x in items], dtype=np.int64)
        np.save(out_dir / f"{split}_text.npy", text_arr)
        np.save(out_dir / f"{split}_audio.npy", audio_arr)
        np.save(out_dir / f"{split}_labels.npy", label_arr)
        print(f" Saved {split}: text {text_arr.shape}, audio {audio_arr.shape}, labels {label_arr.shape}")

    print("Done →", out_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", required=True,
                        help="Parent dir containing IEMOCAP_full_release/")
    parser.add_argument("--out_dir", default=None)
    parser.add_argument("--use_transformer", action="store_true")
    parser.add_argument("--model_name", default="roberta-base")
    args = parser.parse_args()

    zsy = os.environ.get("ZSY", "/root")
    out_dir = args.out_dir or f"{zsy}/multimodal_affect/data/iemocap"
    extract_iemocap(args.data_root, out_dir,
                    use_transformer=args.use_transformer,
                    model_name=args.model_name)
|
||||
@@ -1,200 +0,0 @@
|
||||
"""
|
||||
MELD (Multimodal EmotionLines Dataset) feature extraction.

Dataset structure:
  $DATA_ROOT/MELD.Raw/
    train_sent_emo.csv
    dev_sent_emo.csv
    test_sent_emo.csv
    train/ dev/ test/   → subdirs with mp4 clips
      dia{N}_utt{M}.mp4

CSV columns:
  Sr No., Utterance, Speaker, Emotion, Sentiment,
  Dialogue_ID, Utterance_ID, Season, Episode, StartTime, EndTime

Emotions: neutral, surprise, fear, sadness, joy, disgust, anger

Output: $ZSY/multimodal_affect/data/meld/
  {train,val,test}_{text,audio,labels}.npy
  label_map.json
|
||||
"""
|
||||
|
||||
import os
|
||||
import csv
|
||||
import json
|
||||
import argparse
|
||||
import numpy as np
|
||||
import wave
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
EMOTION_MAP = {
|
||||
"neutral": 0, "surprise": 1, "fear": 2,
|
||||
"sadness": 3, "joy": 4, "disgust": 5, "anger": 6,
|
||||
}
|
||||
LABEL_NAMES = ["neutral", "surprise", "fear", "sadness", "joy", "disgust", "anger"]
|
||||
N_MFCC = 40
|
||||
|
||||
|
||||
# ── audio loading ──────────────────────────────────────────────────────────
def _load_audio_bytes(path: str) -> np.ndarray:
    """Load MP4/MP3 audio via av (16000 zeros on failure); other files via stdlib wave."""
    path = str(path)
    if path.endswith(".mp4") or path.endswith(".mp3"):
        try:
            import av
            container = av.open(path)
            stream = next((s for s in container.streams if s.type == "audio"), None)
            if stream is None:
                return np.zeros(16000, dtype=np.float32)
            chunks = []
            for pkt in container.demux(stream):
                for frame in pkt.decode():
                    arr = frame.to_ndarray()
                    if arr.ndim == 2:
                        arr = arr.mean(axis=0)
                    chunks.append(arr.astype(np.float32))
            container.close()
            if chunks:
                return np.concatenate(chunks)
        except Exception as e:
            print(f" av failed for {path}: {e}")
        return np.zeros(16000, dtype=np.float32)

    # WAV via stdlib
    with wave.open(path, "rb") as f:
        n_ch = f.getnchannels()
        sw = f.getsampwidth()
        raw = f.readframes(f.getnframes())
    if sw == 2:
        sig = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768
    elif sw == 4:
        sig = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2**31
    else:
        sig = np.frombuffer(raw, dtype=np.float32)
    return sig.reshape(-1, n_ch).mean(axis=1) if n_ch > 1 else sig


def _compute_mfcc_mean(signal: np.ndarray, sr: int = 16000) -> np.ndarray:
    # Note: sr is assumed; the decoded signal is not resampled to it.
    try:
        import librosa
        mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=N_MFCC)
        return mfcc.mean(axis=1)
    except Exception:
        pass
    # energy-based fallback
    rms = float(np.sqrt(np.mean(signal ** 2) + 1e-9))
    feat = np.zeros(N_MFCC, dtype=np.float32)
    feat[0] = rms
    return feat
|
||||
|
||||
|
||||
# ── text features ──────────────────────────────────────────────────────────
def _text_features(text: str, max_len: int = 64) -> np.ndarray:
    # Built-in hash() is salted per process for str, so ids would differ
    # between runs; crc32 is deterministic.
    import zlib
    tokens = text.lower().split()[:max_len]
    ids = [zlib.crc32(t.encode("utf-8")) % 30522 for t in tokens]
    ids += [0] * (max_len - len(ids))
    return np.array(ids, dtype=np.int32)
|
||||
|
||||
|
||||
# ── csv parsing ────────────────────────────────────────────────────────────
def read_csv(csv_path: str):
    records = []
    with open(csv_path, encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            records.append(row)
    return records


def extract_split(csv_path: str, clip_dir: Path, out_prefix: Path,
                  has_video: bool = True):
    records = read_csv(csv_path)
    texts, audios, labels_list = [], [], []

    for rec in records:
        emo = rec.get("Emotion", "").strip().lower()
        if emo not in EMOTION_MAP:
            continue
        label = EMOTION_MAP[emo]

        utterance = rec.get("Utterance", "").strip()
        dia_id = rec.get("Dialogue_ID", "").strip()
        utt_id = rec.get("Utterance_ID", "").strip()

        # find audio
        audio_feat = np.zeros(N_MFCC, dtype=np.float32)
        if has_video and clip_dir.exists():
            clip_name = f"dia{dia_id}_utt{utt_id}.mp4"
            clip_path = clip_dir / clip_name
            if clip_path.exists():
                try:
                    sig = _load_audio_bytes(str(clip_path))
                    audio_feat = _compute_mfcc_mean(sig)
                except Exception as e:
                    print(f" [warn] {clip_name}: {e}")

        text_feat = _text_features(utterance)
        texts.append(text_feat)
        audios.append(audio_feat)
        labels_list.append(label)

    if not labels_list:
        print(f" [warn] no valid records in {csv_path}")
        return

    split = out_prefix.name
    base = out_prefix.parent
    np.save(base / f"{split}_text.npy", np.stack(texts))
    np.save(base / f"{split}_audio.npy", np.stack(audios))
    np.save(base / f"{split}_labels.npy", np.array(labels_list, dtype=np.int64))
    print(f" {split}: {len(labels_list)} samples, "
          f"text {np.stack(texts).shape}, audio {np.stack(audios).shape}")


def extract_meld(data_root: str, out_dir: str):
    data_root = Path(data_root)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    meld_root = data_root / "MELD.Raw"
    if not meld_root.exists():
        meld_root = data_root  # maybe already inside MELD.Raw

    csv_map = {
        "train": "train_sent_emo.csv",
        "val": "dev_sent_emo.csv",
        "test": "test_sent_emo.csv",
    }
    dir_map = {
        "train": "train",
        "val": "dev",
        "test": "test",
    }

    for split, csv_name in csv_map.items():
        csv_path = meld_root / csv_name
        if not csv_path.exists():
            print(f" [skip] {csv_path} not found")
            continue
        clip_dir = meld_root / dir_map[split]
        extract_split(str(csv_path), clip_dir, out_dir / split, has_video=clip_dir.exists())

    label_map = {i: n for i, n in enumerate(LABEL_NAMES)}
    with open(out_dir / "label_map.json", "w") as f:
        json.dump(label_map, f, indent=2)

    print("Done →", out_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", required=True,
                        help="Dir containing MELD.Raw/ (or already inside it)")
    parser.add_argument("--out_dir", default=None)
    args = parser.parse_args()

    zsy = os.environ.get("ZSY", "/root")
    out_dir = args.out_dir or f"{zsy}/multimodal_affect/data/meld"
    extract_meld(args.data_root, out_dir)
|
||||
@@ -1,241 +0,0 @@
|
||||
"""
|
||||
CMU-MOSI feature extraction script.

Supports two pickle formats:

  Format A – CMU Multimodal SDK (aligned_50.pkl):
    data[split][modality][sample_id] = np.ndarray
    modalities: 'text', 'audio', 'vision', 'labels'
    splits: 'train', 'valid', 'test'

  Format B – declare-lab flat array (mosi.pkl):
    data[split][modality] = np.ndarray  shape (N, dim)
    modalities: 'glove'(text), 'covarep'(audio), 'facet'(visual), 'label'
    splits: 'train', 'valid', 'test'

Output: $ZSY/multimodal_affect/data/mosi/
  {train,val,test}_{text,audio,vision,labels}.npy
  meta.json
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import pickle
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SENTIMENT_BINS = [(-np.inf, -1, 0), (-1, 1, 1), (1, np.inf, 2)]
|
||||
LABEL_NAMES = ["negative", "neutral", "positive"]
|
||||
|
||||
|
||||
def sentiment_to_class(score: float) -> int:
    """Continuous sentiment [-3,3] → 3-class label."""
    if score < -1:
        return 0
    if score <= 1:
        return 1
    return 2


def load_sdk_pickle(pkl_path: str):
    """Load CMU-SDK aligned pickle."""
    with open(pkl_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")
    return data


def extract_from_sdk(pkl_path: str, out_dir: Path):
    """Extract from pre-aligned CMU-SDK pickle."""
    data = load_sdk_pickle(pkl_path)

    split_map = {"train": "train", "valid": "val", "test": "test"}

    for sdk_split, out_split in split_map.items():
        if sdk_split not in data:
            print(f" [skip] split '{sdk_split}' not in pickle")
            continue

        split_data = data[sdk_split]
        ids = list(split_data.get("text", split_data.get("labels", {})).keys())
        if not ids:
            continue

        texts, audios, visions, labels = [], [], [], []
        for sid in ids:
            lbl_raw = split_data["labels"].get(sid)
            if lbl_raw is None:
                continue
            score = float(np.array(lbl_raw).flatten()[0])
            label = sentiment_to_class(score)

            text = np.array(split_data["text"][sid], dtype=np.float32) if "text" in split_data else np.zeros((1, 300), dtype=np.float32)
            audio = np.array(split_data["audio"][sid], dtype=np.float32) if "audio" in split_data else np.zeros((1, 74), dtype=np.float32)
            vision = np.array(split_data["vision"][sid], dtype=np.float32) if "vision" in split_data else np.zeros((1, 35), dtype=np.float32)

            # temporal mean pooling
            texts.append(text.mean(axis=0) if text.ndim == 2 else text.flatten())
            audios.append(audio.mean(axis=0) if audio.ndim == 2 else audio.flatten())
            visions.append(vision.mean(axis=0) if vision.ndim == 2 else vision.flatten())
            labels.append(label)

        if not labels:
            continue

        np.save(out_dir / f"{out_split}_text.npy", np.stack(texts))
        np.save(out_dir / f"{out_split}_audio.npy", np.stack(audios))
        np.save(out_dir / f"{out_split}_vision.npy", np.stack(visions))
        np.save(out_dir / f"{out_split}_labels.npy", np.array(labels, dtype=np.int64))
        print(f" {out_split}: {len(labels)} samples")


def is_flat_format(data: dict) -> bool:
    """Detect declare-lab flat array format: data[split][modality] = np.ndarray."""
    for split in ("train", "valid", "test"):
        if split in data:
            v = list(data[split].values())[0]
            return isinstance(v, np.ndarray)
    return False
|
||||
|
||||
|
||||
def extract_from_flat(pkl_path: str, out_dir: Path):
    """Extract from declare-lab flat pickle (mosi.pkl).

    Format: data[split]['glove'|'covarep'|'facet'|'label'] = np.ndarray (N, dim)
    Labels are continuous scores in [-3, 3]; discretised into 3 classes.
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    split_map = {"train": "train", "valid": "val", "test": "test"}
    # modality name aliases
    text_key = next((k for k in ("glove", "text", "bert") if k in list(data.get("train", {}).keys())), None)
    audio_key = next((k for k in ("covarep", "audio", "opensmile") if k in list(data.get("train", {}).keys())), None)
    vision_key = next((k for k in ("facet", "vision", "visual") if k in list(data.get("train", {}).keys())), None)
    label_key = next((k for k in ("label", "labels", "Opinion Segment Labels") if k in list(data.get("train", {}).keys())), None)

    print(f" Detected keys — text:{text_key} audio:{audio_key} vision:{vision_key} label:{label_key}")

    for sdk_split, out_split in split_map.items():
        if sdk_split not in data:
            print(f" [skip] '{sdk_split}' not found")
            continue
        sd = data[sdk_split]

        labels_raw = sd[label_key].flatten() if label_key else np.zeros(len(sd[text_key or audio_key]))
        labels = np.array([sentiment_to_class(float(s)) for s in labels_raw], dtype=np.int64)
        n = len(labels)

        text = sd[text_key].astype(np.float32) if text_key else np.zeros((n, 300), dtype=np.float32)
        audio = sd[audio_key].astype(np.float32) if audio_key else np.zeros((n, 74), dtype=np.float32)
        vision = sd[vision_key].astype(np.float32) if vision_key else np.zeros((n, 46), dtype=np.float32)

        # mean-pool time dimension if present: (N, T, dim) → (N, dim)
        if text.ndim == 3:
            text = text.mean(axis=1)
        if audio.ndim == 3:
            audio = audio.mean(axis=1)
        if vision.ndim == 3:
            vision = vision.mean(axis=1)

        np.save(out_dir / f"{out_split}_text.npy", text)
        np.save(out_dir / f"{out_split}_audio.npy", audio)
        np.save(out_dir / f"{out_split}_vision.npy", vision)
        np.save(out_dir / f"{out_split}_labels.npy", labels)
        print(f" {out_split}: {n} samples text{text.shape} audio{audio.shape} vision{vision.shape}")
|
||||
|
||||
|
||||
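def extract_from_raw(raw_dir: Path, out_dir: Path):
    """Fallback: extract from raw WAV files.

    Audio features are four waveform statistics (mean, std, max, min), not MFCCs.
    Every sample gets the neutral label (1); only train_audio.npy and
    train_labels.npy are written.
    """
    import wave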
|
||||
|
||||
def load_wav_stdlib(path):
|
||||
with wave.open(str(path), "rb") as f:
|
||||
n_ch = f.getnchannels()
|
||||
sw = f.getsampwidth()
|
||||
raw = f.readframes(f.getnframes())
|
||||
if sw == 2:
|
||||
s = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768
|
||||
else:
|
||||
s = np.frombuffer(raw, dtype=np.float32)
|
||||
return s.reshape(-1, n_ch).mean(axis=1) if n_ch > 1 else s
|
||||
|
||||
print("[raw mode] scanning", raw_dir)
|
||||
wav_files = sorted(raw_dir.rglob("*.wav"))
|
||||
if not wav_files:
|
||||
print(" No WAV files found under", raw_dir)
|
||||
return
|
||||
|
||||
data = []
|
||||
for wf in wav_files:
|
||||
try:
|
||||
sig = load_wav_stdlib(str(wf))
|
||||
feat = sig.mean(), sig.std(), sig.max(), sig.min()
|
||||
text_feat = np.array([hash(wf.stem) % 30522], dtype=np.float32)
|
||||
data.append((text_feat, np.array(feat, dtype=np.float32), 1)) # neutral default
|
||||
except Exception as e:
|
||||
print(f" [warn] {wf.name}: {e}")
|
||||
|
||||
if data:
|
||||
np.save(out_dir / "train_audio.npy", np.stack([x[1] for x in data]))
|
||||
np.save(out_dir / "train_labels.npy", np.array([x[2] for x in data]))
|
||||
print(f" Saved {len(data)} raw samples")
|
||||
|
||||
|
||||
def extract_mosi(data_root: str, out_dir: str):
|
||||
data_root = Path(data_root)
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
meta = {"label_names": LABEL_NAMES, "task": "sentiment-3class"}
|
||||
|
||||
# try pickle candidates (both SDK and declare-lab flat formats)
|
||||
pkl_candidates = [
|
||||
data_root / "mosi.pkl", # declare-lab flat
|
||||
data_root / "aligned_mosi.pkl", # mmsdk aligned
|
||||
data_root / "CMU_MOSI" / "Processed" / "aligned_50.pkl", # SDK standard
|
||||
data_root / "CMU_MOSI" / "Processed" / "unaligned_50.pkl",
|
||||
data_root / "mosi_data.pkl",
|
||||
data_root / "aligned_50.pkl",
|
||||
]
|
||||
for pkl in pkl_candidates:
|
||||
if pkl.exists():
|
||||
print(f"Found pickle: {pkl}")
|
||||
with open(pkl, "rb") as f:
|
||||
probe = pickle.load(f, encoding="latin1")
|
||||
if is_flat_format(probe):
|
||||
print(" Format: declare-lab flat array")
|
||||
extract_from_flat(str(pkl), out_dir)
|
||||
else:
|
||||
print(" Format: CMU-SDK dict-of-dicts")
|
||||
extract_from_sdk(str(pkl), out_dir)
|
||||
meta["source"] = str(pkl)
|
||||
meta["format"] = "flat" if is_flat_format(probe) else "sdk"
|
||||
break
|
||||
else:
|
||||
raw_dir = data_root / "CMU_MOSI" / "Raw"
|
||||
if raw_dir.exists():
|
||||
extract_from_raw(raw_dir, out_dir)
|
||||
meta["source"] = str(raw_dir)
|
||||
else:
|
||||
print(f"[error] No CMU-MOSI data found under {data_root}")
|
||||
print(" Tried:", [str(p) for p in pkl_candidates])
|
||||
return
|
||||
|
||||
with open(out_dir / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
print("Done →", out_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data_root", required=True,
help="Dir containing a MOSI pickle (e.g. mosi.pkl, aligned_mosi.pkl) or a CMU_MOSI/ subdirectory")
|
||||
parser.add_argument("--out_dir", default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
zsy = os.environ.get("ZSY", "/root")
|
||||
out_dir = args.out_dir or f"{zsy}/multimodal_affect/data/mosi"
|
||||
extract_mosi(args.data_root, out_dir)
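
The flat-format extractor above maps continuous MOSI sentiment scores in [-3, 3] to three classes through `sentiment_to_class`, which is defined outside this excerpt. A minimal sketch of what such a mapping could look like is given here. The thresholds (±0.5), the class order (0 = negative, 1 = neutral, 2 = positive, chosen so that 1 matches the "neutral default" used in raw mode), and the name `sentiment_to_class_sketch` are all assumptions, not the repository's actual definition.

```python
def sentiment_to_class_sketch(score: float, neutral_band: float = 0.5) -> int:
    """Map a continuous sentiment score to 0 (negative), 1 (neutral), 2 (positive).

    Hypothetical thresholds: scores within +/- neutral_band count as neutral.
    """
    if score < -neutral_band:
        return 0
    if score > neutral_band:
        return 2
    return 1


if __name__ == "__main__":
    # Walk through a few scores across the [-3, 3] range.
    for s in (-3.0, -0.2, 0.0, 1.4):
        print(s, "->", sentiment_to_class_sketch(s))
```

Whatever thresholds the real helper uses, keeping them in one function means the flat and SDK extraction paths label samples identically.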
|
||||
@@ -1,242 +0,0 @@
|
||||
"""
|
||||
P0-4: Multimodal noise generation for robustness experiments.
|
||||
|
||||
Supports three modalities: text, audio, visual.
|
||||
Each modality has configurable noise types and intensity levels.
|
||||
|
||||
Usage:
|
||||
python generate_noise.py --config configs/noise_configs.yaml \
|
||||
--data_dir $ZSY/multimodal_affect/data/iemocap \
|
||||
--out_dir $ZSY/multimodal_affect/data/iemocap_noisy
|
||||
|
||||
Config schema → see configs/noise_configs.yaml
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import yaml
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
RNG = np.random.default_rng(42)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════
|
||||
# TEXT NOISE
|
||||
# ═══════════════════════════════════════════════════════
|
||||
|
||||
def _word_drop(ids: np.ndarray, drop_rate: float) -> np.ndarray:
    """Randomly zero out token ids (simulates word deletion)."""
    mask = RNG.random(ids.shape) < drop_rate
    return np.where(mask, 0, ids)


def _word_swap(ids: np.ndarray, swap_rate: float) -> np.ndarray:
    """Swap each adjacent token pair with probability swap_rate (single left-to-right pass)."""
    ids = ids.copy()
    n = len(ids)
    for i in range(n - 1):
        if RNG.random() < swap_rate:
            ids[i], ids[i + 1] = ids[i + 1], ids[i]
    return ids


def _random_replace(ids: np.ndarray, replace_rate: float, vocab_size: int = 30522) -> np.ndarray:
    """Replace non-padding tokens (id != 0) with random vocab ids."""
    ids = ids.copy()
    mask = RNG.random(ids.shape) < replace_rate
    rand_ids = RNG.integers(1, vocab_size, size=ids.shape)
    return np.where(mask & (ids != 0), rand_ids, ids)
|
||||
|
||||
|
||||
def add_text_noise(features: np.ndarray, cfg: Dict) -> np.ndarray:
|
||||
"""Apply text noise to an array of token-id features (N, seq_len)."""
|
||||
noise_type = cfg.get("type", "word_drop")
|
||||
intensity = float(cfg.get("intensity", 0.1))
|
||||
|
||||
if noise_type == "word_drop":
|
||||
return np.stack([_word_drop(row, intensity) for row in features])
|
||||
if noise_type == "word_swap":
|
||||
return np.stack([_word_swap(row, intensity) for row in features])
|
||||
if noise_type == "random_replace":
|
||||
return np.stack([_random_replace(row, intensity) for row in features])
|
||||
if noise_type == "gaussian":
|
||||
# for embedding features (N, dim) not token ids
|
||||
noise = RNG.standard_normal(features.shape).astype(np.float32)
|
||||
return features + intensity * noise
|
||||
raise ValueError(f"Unknown text noise type: {noise_type}")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════
|
||||
# AUDIO NOISE
|
||||
# ═══════════════════════════════════════════════════════
|
||||
|
||||
def add_audio_noise(features: np.ndarray, cfg: Dict) -> np.ndarray:
|
||||
"""Apply noise to audio feature matrix (N, n_mfcc)."""
|
||||
noise_type = cfg.get("type", "gaussian")
|
||||
intensity = float(cfg.get("intensity", 0.05))
|
||||
|
||||
if noise_type == "gaussian":
|
||||
noise = RNG.standard_normal(features.shape).astype(np.float32)
|
||||
return features + intensity * noise * features.std(axis=0, keepdims=True)
|
||||
|
||||
if noise_type == "masking":
|
||||
# mask entire feature dimensions (simulates missing mic)
|
||||
features = features.copy()
|
||||
n_mask = max(1, int(features.shape[1] * intensity))
|
||||
dims = RNG.choice(features.shape[1], n_mask, replace=False)
|
||||
features[:, dims] = 0.0
|
||||
return features
|
||||
|
||||
if noise_type == "time_mask":
|
||||
# mask random samples (simulates packet loss for temporal features)
|
||||
features = features.copy()
|
||||
n_mask = max(1, int(features.shape[0] * intensity))
|
||||
rows = RNG.choice(features.shape[0], n_mask, replace=False)
|
||||
features[rows, :] = 0.0
|
||||
return features
|
||||
|
||||
if noise_type == "scale":
|
||||
# random amplitude scaling
|
||||
scale = 1.0 + intensity * (RNG.random(features.shape[0]) - 0.5) * 2
|
||||
return features * scale[:, None]
|
||||
|
||||
raise ValueError(f"Unknown audio noise type: {noise_type}")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════
|
||||
# VISUAL NOISE (operates on feature vectors, not pixels)
|
||||
# ═══════════════════════════════════════════════════════
|
||||
|
||||
def add_visual_noise(features: np.ndarray, cfg: Dict) -> np.ndarray:
|
||||
"""Apply noise to visual feature matrix (N, feat_dim)."""
|
||||
noise_type = cfg.get("type", "gaussian")
|
||||
intensity = float(cfg.get("intensity", 0.1))
|
||||
|
||||
if noise_type == "gaussian":
|
||||
noise = RNG.standard_normal(features.shape).astype(np.float32)
|
||||
return features + intensity * noise
|
||||
|
||||
if noise_type == "dropout":
|
||||
mask = (RNG.random(features.shape) > intensity).astype(np.float32)
|
||||
return features * mask
|
||||
|
||||
if noise_type == "occlusion":
|
||||
# zero out a contiguous block of feature dims
|
||||
features = features.copy()
|
||||
start = RNG.integers(0, max(1, features.shape[1] - 1))
|
||||
length = max(1, int(features.shape[1] * intensity))
|
||||
features[:, start:start + length] = 0.0
|
||||
return features
|
||||
|
||||
if noise_type == "missing_modality":
|
||||
# simulate completely missing video frames
|
||||
features = features.copy()
|
||||
n_missing = max(1, int(len(features) * intensity))
|
||||
idx = RNG.choice(len(features), n_missing, replace=False)
|
||||
features[idx, :] = 0.0
|
||||
return features
|
||||
|
||||
raise ValueError(f"Unknown visual noise type: {noise_type}")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════
|
||||
# COMBINED MULTIMODAL NOISE
|
||||
# ═══════════════════════════════════════════════════════
|
||||
|
||||
MODALITY_SPECS = [
|
||||
("text", ("text",), add_text_noise),
|
||||
("audio", ("audio",), add_audio_noise),
|
||||
# Dataset files use *_vision.npy. Older configs used "visual", so keep it
|
||||
# as an input alias but always write the canonical "vision" filename.
|
||||
("vision", ("vision", "visual"), add_visual_noise),
|
||||
]
|
||||
|
||||
|
||||
def _get_modality_cfg(noise_cfg: Dict, aliases: tuple) -> Dict:
    for name in aliases:
        if name in noise_cfg:
            return noise_cfg[name]
    return noise_cfg.get("default", {})
|
||||
|
||||
|
||||
def apply_noise_config(data_dir: Path, out_dir: Path, noise_cfg: Dict,
                       splits: Optional[list] = None):
|
||||
"""Apply noise config to all splits and modalities found in data_dir."""
|
||||
if splits is None:
|
||||
splits = ["train", "val", "test"]
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for split in splits:
|
||||
for modality, aliases, fn in MODALITY_SPECS:
|
||||
src = data_dir / f"{split}_{modality}.npy"
|
||||
if not src.exists():
|
||||
continue
|
||||
|
||||
features = np.load(str(src))
|
||||
mod_cfg = _get_modality_cfg(noise_cfg, aliases)
|
||||
|
||||
if mod_cfg:
|
||||
noisy = fn(features.astype(np.float32), mod_cfg)
|
||||
else:
|
||||
noisy = features.astype(np.float32).copy()
|
||||
dst = out_dir / f"{split}_{modality}.npy"
|
||||
np.save(str(dst), noisy)
|
||||
print(f" {split}/{modality}: {features.shape} → {dst.name}")
|
||||
|
||||
# copy labels unchanged
|
||||
label_src = data_dir / f"{split}_labels.npy"
|
||||
if label_src.exists():
|
||||
import shutil
|
||||
shutil.copy2(str(label_src), str(out_dir / f"{split}_labels.npy"))
|
||||
|
||||
# copy metadata
|
||||
for meta_file in ["label_map.json", "meta.json"]:
|
||||
src = data_dir / meta_file
|
||||
if src.exists():
|
||||
import shutil
|
||||
shutil.copy2(str(src), str(out_dir / meta_file))
|
||||
|
||||
|
||||
def generate_noise_variants(data_dir: str, out_base: str, config: Dict):
|
||||
"""Generate multiple noise variants as defined in config."""
|
||||
data_dir = Path(data_dir)
|
||||
out_base = Path(out_base)
|
||||
|
||||
variants = config.get("variants", [])
|
||||
if not variants:
|
||||
# single-variant mode: apply config directly
|
||||
apply_noise_config(data_dir, out_base, config.get("noise", {}))
|
||||
return
|
||||
|
||||
for variant in variants:
|
||||
name = variant["name"]
|
||||
noise_cfg = variant["noise"]
|
||||
out_dir = out_base / name
|
||||
print(f"\n[Variant: {name}]")
|
||||
apply_noise_config(data_dir, out_dir, noise_cfg)
|
||||
with open(out_dir / "noise_config.json", "w") as f:
|
||||
json.dump(variant, f, indent=2)
|
||||
|
||||
print(f"\nAll variants saved under {out_base}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", required=True,
|
||||
help="Path to noise_configs.yaml")
|
||||
parser.add_argument("--data_dir", required=True,
|
||||
help="Dir with {split}_{modality}.npy files")
|
||||
parser.add_argument("--out_dir", default=None,
|
||||
help="Output base dir (default: data_dir + '_noisy')")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config, encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
zsy = os.environ.get("ZSY", "/root")
|
||||
out_dir = args.out_dir or args.data_dir.rstrip("/") + "_noisy"
|
||||
generate_noise_variants(args.data_dir, out_dir, config)
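
`generate_noise_variants` expects a config with a `variants` list, where each entry carries a `name` and a per-modality `noise` mapping, and `_get_modality_cfg` resolves the `vision`/`visual` alias with a `default` fallback. The sketch below mirrors that lookup on a hypothetical config. The variant names and intensities are illustrative only and are not taken from `configs/noise_configs.yaml`; `resolve_modality_cfg` and `plan` are names introduced here for the example.

```python
from typing import Dict, Tuple

# Hypothetical config; the real schema lives in configs/noise_configs.yaml.
EXAMPLE_CONFIG = {
    "variants": [
        {"name": "example_visual_dropout",
         "noise": {"visual": {"type": "dropout", "intensity": 0.3}}},
        {"name": "example_all_gaussian",
         "noise": {"default": {"type": "gaussian", "intensity": 0.1}}},
    ]
}

ALIASES: Dict[str, Tuple[str, ...]] = {
    "text": ("text",),
    "audio": ("audio",),
    "vision": ("vision", "visual"),  # legacy "visual" key is still accepted
}


def resolve_modality_cfg(noise_cfg: Dict, aliases: Tuple[str, ...]) -> Dict:
    """Return the first matching alias entry, else the 'default' entry, else {}."""
    for name in aliases:
        if name in noise_cfg:
            return noise_cfg[name]
    return noise_cfg.get("default", {})


def plan(config: Dict) -> Dict[str, Dict[str, Dict]]:
    """For each variant, list the noise settings each modality would receive."""
    return {
        v["name"]: {m: resolve_modality_cfg(v["noise"], a) for m, a in ALIASES.items()}
        for v in config["variants"]
    }


if __name__ == "__main__":
    for name, per_modality in plan(EXAMPLE_CONFIG).items():
        print(name, per_modality)
```

An empty mapping for a modality corresponds to the branch in `apply_noise_config` that copies the clean features through unchanged.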
|
||||
@@ -1,115 +0,0 @@
|
||||
#!/bin/bash
|
||||
# ─────────────────────────────────────────────────────────────────────────────
# server_unpack_and_extract.sh
# Server side: one-step unpack + feature extraction script
# Prerequisite: data already uploaded to $ZSY/multimodal_affect/data/raw/
#
# Directory layout:
# IEMOCAP zip: $ZSY/multimodal_affect/data/raw/IEMOCAP/*.zip
# MELD tar.gz: $ZSY/multimodal_affect/data/raw/MELD/MELD.Raw.tar.gz
# MOSI pkl: $ZSY/multimodal_affect/data/raw/MOSI/aligned_mosi.pkl
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
set -e
|
||||
source /root/.bashrc_zsy 2>/dev/null || true
|
||||
|
||||
ZSY=${ZSY:-/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy}
|
||||
PROJ=$ZSY/multimodal_affect
|
||||
RAW=$PROJ/data/raw
|
||||
PY=$ZSY/envs/multimodal_affect/bin/python
|
||||
|
||||
echo "=========================================="
|
||||
echo " Unpack & Extract — $(date)"
|
||||
echo " PROJ=$PROJ"
|
||||
echo "=========================================="
|
||||
|
||||
# ── IEMOCAP: unzip ───────────────────────────────────────────────────────────
|
||||
IEMOCAP_RAW=$RAW/IEMOCAP
|
||||
IEMOCAP_DEST=$RAW/IEMOCAP_full_release
|
||||
|
||||
if [ -d "$IEMOCAP_DEST/Session1" ]; then
|
||||
echo "[skip] IEMOCAP already unpacked at $IEMOCAP_DEST"
|
||||
elif ls "$IEMOCAP_RAW"/*.zip 1>/dev/null 2>&1; then
|
||||
echo "[IEMOCAP] Unzipping..."
|
||||
mkdir -p "$IEMOCAP_DEST"
|
||||
for zf in "$IEMOCAP_RAW"/*.zip; do
|
||||
echo " unzip $zf"
|
||||
unzip -q "$zf" -d "$IEMOCAP_DEST"
|
||||
done
|
||||
echo "[IEMOCAP] Unzip done. Sessions:"
|
||||
ls "$IEMOCAP_DEST/"
|
||||
else
|
||||
echo "[IEMOCAP] WARNING: no zip files found in $IEMOCAP_RAW"
|
||||
fi
|
||||
|
||||
# ── MELD: extract tar.gz ─────────────────────────────────────────────────────
|
||||
MELD_RAW=$RAW/MELD
|
||||
MELD_DEST=$MELD_RAW/MELD.Raw
|
||||
|
||||
if [ -d "$MELD_DEST" ]; then
|
||||
echo "[skip] MELD already unpacked at $MELD_DEST"
|
||||
elif [ -f "$MELD_RAW/MELD.Raw.tar.gz" ]; then
|
||||
echo "[MELD] Extracting tar.gz (~10.8GB, takes a few minutes)..."
|
||||
tar -xzf "$MELD_RAW/MELD.Raw.tar.gz" -C "$MELD_RAW"
|
||||
echo "[MELD] Extract done."
|
||||
ls "$MELD_RAW/"
|
||||
else
|
||||
echo "[MELD] WARNING: MELD.Raw.tar.gz not found in $MELD_RAW"
|
||||
echo " CSV-only mode will be used (no audio features)"
|
||||
fi
|
||||
|
||||
# ── Feature extraction ───────────────────────────────────────────────────────
|
||||
cd "$PROJ"
|
||||
|
||||
echo ""
|
||||
echo "=== Feature Extraction ==="
|
||||
|
||||
# IEMOCAP
|
||||
if [ -d "$IEMOCAP_DEST/Session1" ]; then
|
||||
echo "[extract] IEMOCAP..."
|
||||
$PY scripts/preprocess/extract_iemocap.py \
|
||||
--data_root "$RAW" \
|
||||
--out_dir "$PROJ/data/iemocap"
|
||||
echo "[done] IEMOCAP features → $PROJ/data/iemocap"
|
||||
else
|
||||
echo "[skip] IEMOCAP not ready"
|
||||
fi
|
||||
|
||||
# MOSI
|
||||
MOSI_PKL=$RAW/MOSI/aligned_mosi.pkl
|
||||
if [ -f "$MOSI_PKL" ]; then
|
||||
echo "[extract] CMU-MOSI..."
|
||||
$PY scripts/preprocess/extract_mosi.py \
|
||||
--data_root "$RAW/MOSI" \
|
||||
--out_dir "$PROJ/data/mosi"
|
||||
echo "[done] MOSI features → $PROJ/data/mosi"
|
||||
else
|
||||
echo "[skip] MOSI aligned_mosi.pkl not found at $MOSI_PKL"
|
||||
fi
|
||||
|
||||
# MELD
|
||||
if [ -d "$MELD_DEST" ] || ls "$MELD_RAW"/*.csv 1>/dev/null 2>&1; then
|
||||
echo "[extract] MELD..."
|
||||
$PY scripts/preprocess/extract_meld.py \
|
||||
--data_root "$MELD_RAW" \
|
||||
--out_dir "$PROJ/data/meld"
|
||||
echo "[done] MELD features → $PROJ/data/meld"
|
||||
else
|
||||
echo "[skip] MELD data not ready"
|
||||
fi
|
||||
|
||||
# ── Noise generation (runs once IEMOCAP features are in place) ───────────────
|
||||
if [ -f "$PROJ/data/iemocap/train_labels.npy" ]; then
|
||||
echo ""
|
||||
echo "=== Noise Generation (8 variants) ==="
|
||||
$PY scripts/preprocess/generate_noise.py \
|
||||
--config configs/noise_configs.yaml \
|
||||
--data_dir "$PROJ/data/iemocap" \
|
||||
--out_dir "$PROJ/data/iemocap_noisy"
|
||||
echo "[done] Noisy variants → $PROJ/data/iemocap_noisy"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " ALL DONE — $(date)"
|
||||
echo "=========================================="
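
The unpack steps in the shell script are idempotent: each archive is skipped when its destination marker (for example `Session1` under the IEMOCAP destination) already exists. For readers who prefer Python, that guard can be sketched with the standard library as follows. `unpack_once` and its return values are hypothetical and not part of the repository, and the sketch covers only the zip case.

```python
import tempfile
import zipfile
from pathlib import Path


def unpack_once(archive: Path, dest: Path, marker: str) -> str:
    """Extract a zip into dest unless dest/marker already exists.

    Returns "skipped", "missing", or "unpacked" so callers can log the outcome.
    """
    if (dest / marker).exists():
        return "skipped"
    if not archive.is_file():
        return "missing"
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    return "unpacked"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        archive = root / "demo.zip"
        # Build a tiny archive that contains the marker directory.
        with zipfile.ZipFile(archive, "w") as zf:
            zf.writestr("Session1/readme.txt", "demo")
        dest = root / "out"
        print(unpack_once(archive, dest, "Session1"))  # first call extracts
        print(unpack_once(archive, dest, "Session1"))  # second call is a no-op
```

Checking the marker before the archive lets a rerun succeed even after the original archive has been deleted to save disk space.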
|
||||
@@ -1,296 +0,0 @@
|
||||
"""
|
||||
Upload the test evaluation + D1-4 ablation script (scripts/eval/eval_d1.py) to the server.
|
||||
Uses Stage B v1 checkpoint (best val WF1=0.7291).
|
||||
"""
|
||||
import paramiko, warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
ZSY = '/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy'
|
||||
PROJ = ZSY + '/multimodal_affect'
|
||||
ENV = ZSY + '/envs/multimodal_affect/bin/python'
|
||||
|
||||
# ── eval_d1.py ────────────────────────────────────────────────────────────
|
||||
EVAL_SCRIPT = r'''#!/usr/bin/env python3
|
||||
"""
|
||||
Evaluate Direction-1 checkpoint on test set.
|
||||
Also runs ablation variants: Fixed-Equal, and Stage-A-Only (or RL-agent-removed uniform fusion when --stage_a_ckpt is not given), plus a noise-robustness sweep on the test set.
|
||||
|
||||
Usage:
|
||||
python scripts/eval/eval_d1.py \
|
||||
--checkpoint outputs/checkpoints/d1_stageB/best_v1.ckpt \
|
||||
--dataset IEMOCAP \
|
||||
--gpu 0
|
||||
"""
|
||||
import os, sys, argparse, json, csv, logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from sklearn.metrics import f1_score, accuracy_score, classification_report
|
||||
|
||||
ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
|
||||
PROJ = os.path.join(ZSY, "multimodal_affect")
|
||||
sys.path.insert(0, PROJ)
|
||||
|
||||
from src.data.dataset import MultimodalDataset, get_dataloader
|
||||
from src.models.encoders import MultimodalEncoder
|
||||
from src.models.classifier import EmotionClassifier
|
||||
from src.rl.fusion_agent import ModalFusionAgent
|
||||
from src.rl.reward import compute_reward
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def predict(encoder, classifier, loader, device, agent=None, fixed_weights=None):
|
||||
encoder.eval(); classifier.eval()
|
||||
if agent: agent.eval()
|
||||
preds, labels_all = [], []
|
||||
for batch in loader:
|
||||
text = batch["text"].to(device)
|
||||
audio = batch["audio"].to(device)
|
||||
vision = batch["vision"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
if agent is not None:
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
weights, *_ = agent.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
elif fixed_weights is not None:
|
||||
w = torch.tensor(fixed_weights, device=device).view(1, 3)
|
||||
fused = w[:, 0:1]*tf + w[:, 1:2]*af + w[:, 2:3]*vf
|
||||
else:
|
||||
fused = (tf + af + vf) / 3.0
|
||||
logits = classifier(fused)
|
||||
preds.append(logits.argmax(-1).cpu())
|
||||
labels_all.append(labels.cpu())
|
||||
p = torch.cat(preds).numpy()
|
||||
l = torch.cat(labels_all).numpy()
|
||||
return p, l
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_noisy(encoder, classifier, loader, device, variant_data, agent=None, fixed_weights=None):
|
||||
"""Run inference with a noisy variant, replacing any modalities it provides."""
|
||||
encoder.eval(); classifier.eval()
|
||||
if agent: agent.eval()
|
||||
preds, labels_all = [], []
|
||||
arrays = {k: torch.from_numpy(v).float() for k, v in variant_data.items()}
|
||||
cursor = 0
|
||||
for batch in loader:
|
||||
bsz = batch["text"].size(0)
|
||||
text = (arrays["text"][cursor:cursor+bsz] if "text" in arrays else batch["text"]).to(device)
|
||||
audio = (arrays["audio"][cursor:cursor+bsz] if "audio" in arrays else batch["audio"]).to(device)
|
||||
vision = (arrays["vision"][cursor:cursor+bsz] if "vision" in arrays else batch["vision"]).to(device)
|
||||
cursor += bsz
|
||||
labels = batch["labels"].to(device)
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
if agent is not None:
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
weights, *_ = agent.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
elif fixed_weights is not None:
|
||||
w = torch.tensor(fixed_weights, device=device).view(1, 3)
|
||||
fused = w[:, 0:1]*tf + w[:, 1:2]*af + w[:, 2:3]*vf
|
||||
else:
|
||||
fused = (tf + af + vf) / 3.0
|
||||
logits = classifier(fused)
|
||||
preds.append(logits.argmax(-1).cpu())
|
||||
labels_all.append(labels.cpu())
|
||||
p = torch.cat(preds).numpy()
|
||||
l = torch.cat(labels_all).numpy()
|
||||
return p, l
|
||||
|
||||
|
||||
def metrics(preds, labels, split="test"):
    wf1 = float(f1_score(labels, preds, average="weighted", zero_division=0))
    acc = float(accuracy_score(labels, preds))
    return {"split": split, "wf1": round(wf1, 4), "acc": round(acc, 4)}
|
||||
|
||||
|
||||
def load_model(ckpt_path, device):
|
||||
ckpt = torch.load(ckpt_path, map_location=device)
|
||||
td, ad, vd = ckpt["text_dim"], ckpt["audio_dim"], ckpt["vision_dim"]
|
||||
nc = ckpt["num_classes"]
|
||||
pd = ckpt.get("proj_dim", 1024)
|
||||
enc = MultimodalEncoder(td, ad, vd, pd)
|
||||
cls = EmotionClassifier(pd, nc)
|
||||
enc.load_state_dict(ckpt["encoder"])
|
||||
cls.load_state_dict(ckpt["classifier"])
|
||||
enc.to(device).eval(); cls.to(device).eval()
|
||||
agent = None
|
||||
if "agent" in ckpt:
|
||||
agent = ModalFusionAgent(state_dim=4, hidden=128)
|
||||
agent.load_state_dict(ckpt["agent"])
|
||||
agent.to(device).eval()
|
||||
return enc, cls, agent, ckpt
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--checkpoint", required=True)
|
||||
p.add_argument("--stage_a_ckpt", default=None,
|
||||
help="Stage A ckpt for ablations that need encoder+classifier only")
|
||||
p.add_argument("--dataset", default="IEMOCAP")
|
||||
p.add_argument("--gpu", default="0")
|
||||
p.add_argument("--out_json", default=None)
|
||||
p.add_argument("--out_csv", default=None)
|
||||
args = p.parse_args()
|
||||
|
||||
device = torch.device(f"cuda:{args.gpu}")
|
||||
data_dir = os.path.join(PROJ, "data", args.dataset.lower())
|
||||
noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
|
||||
NOISE_VARIANTS = [
|
||||
"gaussian_light", "gaussian_heavy", "missing_audio",
|
||||
"missing_visual", "text_word_drop_30", "audio_masking_50",
|
||||
"realistic_mixed", "audio_time_mask",
|
||||
]
|
||||
|
||||
# Datasets
|
||||
val_ds = MultimodalDataset(data_dir, "val")
|
||||
test_ds = MultimodalDataset(data_dir, "test")
|
||||
val_loader = get_dataloader(val_ds, 128, shuffle=False, drop_last=False)
|
||||
test_loader = get_dataloader(test_ds, 128, shuffle=False, drop_last=False)
|
||||
|
||||
# Load Stage B v1 checkpoint (encoder + classifier + agent)
|
||||
enc, cls, agent, ckpt = load_model(args.checkpoint, device)
|
||||
logging.info(f"Loaded: {args.checkpoint} val_wf1={ckpt.get('val_wf1',0):.4f}")
|
||||
|
||||
results = {}
|
||||
|
||||
# ── 1. Main evaluation: val + test ────────────────────────────────────
|
||||
logging.info("=== Main Evaluation (Stage B RL-Full) ===")
|
||||
for split, loader in [("val", val_loader), ("test", test_loader)]:
|
||||
ds = val_ds if split == "val" else test_ds
|
||||
preds, labels = predict(enc, cls, loader, device, agent=agent)
|
||||
m = metrics(preds, labels, split)
|
||||
results[f"RL-Full_{split}"] = m
|
||||
logging.info(f" [{split}] WF1={m['wf1']:.4f} Acc={m['acc']:.4f}")
|
||||
if split == "test":
|
||||
rpt = classification_report(labels, preds,
|
||||
target_names=[str(i) for i in range(ckpt["num_classes"])],
|
||||
zero_division=0)
|
||||
logging.info(f"\n{rpt}")
|
||||
|
||||
# ── 2. Ablation A: Fixed-Equal (uniform weights, Stage B classifier) ──
|
||||
logging.info("=== Ablation: Fixed-Equal ===")
|
||||
for split, loader in [("val", val_loader), ("test", test_loader)]:
|
||||
preds, labels = predict(enc, cls, loader, device,
|
||||
fixed_weights=[1/3, 1/3, 1/3])
|
||||
m = metrics(preds, labels, split)
|
||||
results[f"Fixed-Equal_{split}"] = m
|
||||
logging.info(f" [{split}] WF1={m['wf1']:.4f} Acc={m['acc']:.4f}")
|
||||
|
||||
# ── 3. Ablation B: Stage A only (no RL, trained classifier w/ uniform fusion) ─
|
||||
if args.stage_a_ckpt:
|
||||
logging.info("=== Ablation: Stage-A-Only ===")
|
||||
enc_a, cls_a, _, ckpt_a = load_model(args.stage_a_ckpt, device)
|
||||
for split, loader in [("val", val_loader), ("test", test_loader)]:
|
||||
preds, labels = predict(enc_a, cls_a, loader, device)
|
||||
m = metrics(preds, labels, split)
|
||||
results[f"StageA-Only_{split}"] = m
|
||||
logging.info(f" [{split}] WF1={m['wf1']:.4f} Acc={m['acc']:.4f}")
|
||||
else:
|
||||
# estimate from Stage A ckpt embedded in Stage B (same encoder/classifier)
|
||||
# just run with agent=None (uniform fusion) using Stage B encoder+classifier
|
||||
logging.info("=== Ablation: RL-Agent-Removed (Stage B enc+cls, uniform fusion) ===")
|
||||
for split, loader in [("val", val_loader), ("test", test_loader)]:
|
||||
preds, labels = predict(enc, cls, loader, device, agent=None)
|
||||
m = metrics(preds, labels, split)
|
||||
results[f"NoRL-UniformFusion_{split}"] = m
|
||||
logging.info(f" [{split}] WF1={m['wf1']:.4f} Acc={m['acc']:.4f}")
|
||||
|
||||
# ── 4. Noise robustness evaluation ────────────────────────────────────
|
||||
logging.info("=== Noise Robustness (test set) ===")
|
||||
for vname in NOISE_VARIANTS:
|
||||
vdir = os.path.join(noise_root, vname)
|
||||
paths = {
|
||||
"text": os.path.join(vdir, "test_text.npy"),
|
||||
"audio": os.path.join(vdir, "test_audio.npy"),
|
||||
"vision": os.path.join(vdir, "test_vision.npy"),
|
||||
}
|
||||
available = {m: p for m, p in paths.items() if os.path.exists(p)}
|
||||
if not available:
|
||||
logging.info(f" [{vname}] SKIP (no noisy modality files)")
|
||||
continue
|
||||
missing = sorted(set(paths) - set(available))
|
||||
if missing:
|
||||
logging.warning(f" [{vname}] missing noisy files for {missing}; clean same-index modality will be used")
|
||||
vdata = {m: np.load(p).astype(np.float32) for m, p in available.items()}
|
||||
# RL-Full under noise
|
||||
preds_rl, labels = predict_noisy(enc, cls, test_loader, device, vdata, agent=agent)
|
||||
wf1_rl = float(f1_score(labels, preds_rl, average="weighted", zero_division=0))
|
||||
# Fixed-Equal under noise
|
||||
preds_fx, _ = predict_noisy(enc, cls, test_loader, device, vdata,
|
||||
fixed_weights=[1/3, 1/3, 1/3])
|
||||
wf1_fx = float(f1_score(labels, preds_fx, average="weighted", zero_division=0))
|
||||
results[f"noise_{vname}_RL-Full"] = round(wf1_rl, 4)
|
||||
results[f"noise_{vname}_Fixed-Equal"] = round(wf1_fx, 4)
|
||||
pct = (1 - wf1_rl / max(wf1_fx, 1e-6)) * 100 # relative degradation vs fixed
|
||||
logging.info(f" [{vname}] RL={wf1_rl:.4f} Fixed={wf1_fx:.4f} "
f"RL_rel_drop_vs_fixed={pct:+.1f}%")
|
||||
|
||||
# ── 5. Save results ───────────────────────────────────────────────────
|
||||
os.makedirs(os.path.join(PROJ, "outputs", "results"), exist_ok=True)
|
||||
out_json = args.out_json or os.path.join(PROJ, "outputs", "results", "d1_eval.json")
|
||||
out_csv = args.out_csv or os.path.join(PROJ, "outputs", "results", "d1_ablation.csv")
|
||||
|
||||
with open(out_json, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
logging.info(f"Results saved to {out_json}")
|
||||
|
||||
# CSV for ablation table
|
||||
rows = []
|
||||
for variant in ["RL-Full", "Fixed-Equal", "NoRL-UniformFusion", "StageA-Only"]:
|
||||
row = {"variant": variant}
|
||||
for split in ["val", "test"]:
|
||||
k = f"{variant}_{split}"
|
||||
if k in results:
|
||||
row[f"{split}_wf1"] = results[k]["wf1"]
|
||||
row[f"{split}_acc"] = results[k]["acc"]
|
||||
if "val_wf1" in row:
|
||||
rows.append(row)
|
||||
if rows:
|
||||
with open(out_csv, "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=["variant","val_wf1","val_acc","test_wf1","test_acc"])
|
||||
writer.writeheader()
|
||||
writer.writerows(rows)
|
||||
logging.info(f"Ablation CSV saved to {out_csv}")
|
||||
|
||||
# Noise robustness summary
|
||||
logging.info("\n=== Noise Robustness Summary ===")
|
||||
clean_rl = results.get("RL-Full_test", {}).get("wf1", 0)
|
||||
clean_fx = results.get("Fixed-Equal_test", {}).get("wf1", 0)
|
||||
for vname in NOISE_VARIANTS:
|
||||
rl_k = f"noise_{vname}_RL-Full"
|
||||
fx_k = f"noise_{vname}_Fixed-Equal"
|
||||
if rl_k in results and fx_k in results:
|
||||
rl = results[rl_k]; fx = results[fx_k]
|
||||
rl_drop = (clean_rl - rl) / max(clean_rl, 1e-6) * 100
|
||||
fx_drop = (clean_fx - fx) / max(clean_fx, 1e-6) * 100
|
||||
logging.info(f" {vname:22s} RL_drop={rl_drop:+5.1f}% Fixed_drop={fx_drop:+5.1f}%")
|
||||
|
||||
logging.info("Evaluation complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
'''
|
||||
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
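# The password is read from the environment instead of being hard-coded; the variable name D1_SSH_PASSWORD is a placeholder.
client.connect('10.82.3.180', port=20083, username='root', password=__import__('os').environ['D1_SSH_PASSWORD'], timeout=30)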
|
||||
sftp = client.open_sftp()
|
||||
|
||||
# Make eval dir
|
||||
_, o, e = client.exec_command(f'mkdir -p {PROJ}/scripts/eval', timeout=10)
|
||||
o.read(); e.read()
|
||||
|
||||
sftp.putfo(__import__('io').BytesIO(EVAL_SCRIPT.encode()), PROJ + '/scripts/eval/eval_d1.py')
|
||||
print("uploaded: scripts/eval/eval_d1.py")
|
||||
|
||||
sftp.close()
|
||||
client.close()
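
The noise-robustness summary in `eval_d1.py` reports, for each variant, the relative drop in weighted F1 against the clean test score, computed as `(clean - noisy) / max(clean, 1e-6) * 100`. A small worked sketch of that arithmetic follows. The scores in the example are made-up numbers, not results from this project, and `relative_drop` is a hypothetical helper name.

```python
def relative_drop(clean: float, noisy: float, eps: float = 1e-6) -> float:
    """Percentage drop of `noisy` relative to `clean`; positive means degradation."""
    return (clean - noisy) / max(clean, eps) * 100.0


if __name__ == "__main__":
    # Made-up scores purely for illustration.
    clean_rl, noisy_rl = 0.80, 0.72
    clean_fx, noisy_fx = 0.80, 0.60
    print(f"RL_drop={relative_drop(clean_rl, noisy_rl):+5.1f}%")
    print(f"Fixed_drop={relative_drop(clean_fx, noisy_fx):+5.1f}%")
```

Because each method is normalised by its own clean score, the two percentages stay comparable even when the clean baselines differ.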
|
||||
@@ -1,522 +0,0 @@
|
||||
#!/usr/bin/env python3
# Phase 1 Direction 1 Training Script (DataParallel edition)
# Stage A: Supervised pretraining with noise-aware confidence estimation
# Stage B: PPO-based adaptive fusion weight learning [v2: reward display + entropy fix]
#
# Launch:
#   python scripts/train/train_d1.py \
#       --stage supervised --dataset IEMOCAP \
#       --config configs/d1/stage_a.yaml \
#       --output outputs/checkpoints/d1_stageA

import os, sys, argparse, yaml, time, logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from sklearn.metrics import f1_score, accuracy_score
import wandb

ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
PROJ = os.path.join(ZSY, "multimodal_affect")
sys.path.insert(0, PROJ)

from src.data.dataset import MultimodalDataset, get_dataloader
from src.models.encoders import MultimodalEncoder
from src.models.classifier import EmotionClassifier
from src.rl.fusion_agent import ModalFusionAgent
from src.rl.reward import compute_reward


def save_ckpt(state, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(state, path)


def _noisy_batch(dataset, variant, indices, device):
    """Return a same-index multimodal batch.

    Modalities that have no noisy file for this variant fall back to the clean arrays.
    """
    text = variant.get("text", dataset.text)
    audio = variant.get("audio", dataset.audio)
    vision = variant.get("vision", dataset.vision)
    return (
        torch.from_numpy(text[indices]).to(device),
        torch.from_numpy(audio[indices]).to(device),
        torch.from_numpy(vision[indices]).to(device),
        torch.from_numpy(dataset.labels[indices]).to(device),
    )


def _confidence_targets(variant_name, batch_size, device):
    """Target 0.1 for the modalities corrupted by the named variant, 0.9 for the rest.

    Columns are ordered (text, audio, vision). Unknown variants mark all three as noisy.
    """
    target = torch.full((batch_size, 3), 0.9, device=device)
    noisy_map = {
        "gaussian_light": (0, 1, 2),
        "gaussian_heavy": (0, 1, 2),
        "missing_audio": (1,),
        "missing_visual": (2,),
        "text_word_drop_30": (0,),
        "audio_masking_50": (1,),
        "realistic_mixed": (0, 1, 2),
        "audio_time_mask": (1,),
    }
    for idx in noisy_map.get(str(variant_name), (0, 1, 2)):
        target[:, idx] = 0.1
    return target


# ── Evaluation ────────────────────────────────────────────────────────────

@torch.no_grad()
def evaluate(encoder, classifier, loader, device, agent=None):
    encoder.eval()
    classifier.eval()
    if agent is not None:
        agent.eval()
    all_preds, all_labels = [], []
    for batch in loader:
        text = batch["text"].to(device)
        audio = batch["audio"].to(device)
        vision = batch["vision"].to(device)
        labels = batch["labels"].to(device)
        enc = encoder.module if hasattr(encoder, "module") else encoder
        cls = classifier.module if hasattr(classifier, "module") else classifier
        agt = (agent.module if hasattr(agent, "module") else agent) if agent else None
        tf, af, vf, confs = enc(text, audio, vision)
        if agt is not None:
            noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
            state = torch.cat([confs, noise_est], dim=-1)
            weights, *_ = agt.get_action_and_value(state)
            fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
        else:
            fused = (tf + af + vf) / 3.0
        logits = cls(fused)
        all_preds.append(logits.argmax(-1).cpu())
        all_labels.append(labels.cpu())
    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()
    wf1 = float(f1_score(labels, preds, average="weighted", zero_division=0))
    acc = float(accuracy_score(labels, preds))
    return wf1, acc


# ── Stage A: Supervised pretraining ──────────────────────────────────────

def train_stage_a(args, cfg, device, gpu_ids):
    rng = np.random.default_rng(42)

    data_dir = os.path.join(PROJ, "data", args.dataset.lower())
    noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")

    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
                                 noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")

    eff_bs = cfg["batch_size"] * len(gpu_ids)
    train_loader = get_dataloader(train_ds, eff_bs, distributed=False)
    val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
                                distributed=False, drop_last=False)

    text_dim = train_ds.text.shape[1]
    audio_dim = train_ds.audio.shape[1]
    vision_dim = train_ds.vision.shape[1]
    num_classes = int(train_ds.labels.max()) + 1
    proj_dim = cfg.get("proj_dim", 1024)

    encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim).to(device)
    classifier = EmotionClassifier(proj_dim, num_classes,
                                   hidden=cfg.get("cls_hidden", 512)).to(device)

    params = list(encoder.parameters()) + list(classifier.parameters())
    opt = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg.get("wd", 1e-4))
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=cfg["epochs"], eta_min=1e-5)
    scaler = GradScaler('cuda')

    conf_weight = cfg.get("conf_weight", 0.2)
    noise_prob = cfg.get("noise_prob", 0.4)
    best_wf1 = 0.0

    for epoch in range(cfg["epochs"]):
        encoder.train()
        classifier.train()
        ep_loss = ep_ce = ep_conf = 0.0

        for batch in train_loader:
            text = batch["text"].to(device)
            audio = batch["audio"].to(device)
            vision = batch["vision"].to(device)
            labels = batch["labels"].to(device)
            B = text.size(0)

            use_noise = (rng.random() < noise_prob) and bool(train_ds.variant_names)
            if use_noise:
                vname = rng.choice(train_ds.variant_names)
                v = train_ds.noisy_variants[vname]
                ni = rng.integers(0, len(train_ds), size=B)
                text, audio, vision, labels = _noisy_batch(train_ds, v, ni, device)

            with autocast('cuda'):
                tf, af, vf, confs = encoder(text, audio, vision)
                fused = (tf + af + vf) / 3.0
                logits = classifier(fused)
                ce_loss = F.cross_entropy(logits, labels)

            if use_noise:
                c_tgt = _confidence_targets(vname, B, device)
            else:
                c_tgt = torch.full((B, 3), 0.9, device=device)
            # BCE on probabilities is computed in float32, outside autocast.
            conf_loss = F.binary_cross_entropy(confs.float(), c_tgt.float())
            with autocast('cuda'):
                loss = ce_loss + conf_weight * conf_loss

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(params, 1.0)
            scaler.step(opt)
            scaler.update()

            ep_loss += loss.item()
            ep_ce += ce_loss.item()
            ep_conf += conf_loss.item()

        sched.step()
        val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device)

        n = len(train_loader)
        logging.info(
            f"[StageA] Epoch {epoch+1:3d}/{cfg['epochs']} | "
            f"loss={ep_loss/n:.4f} ce={ep_ce/n:.4f} conf={ep_conf/n:.4f} | "
            f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
        )
        wandb.log({"A/loss": ep_loss/n, "A/ce": ep_ce/n,
                   "A/conf": ep_conf/n, "A/val_wf1": val_wf1,
                   "A/val_acc": val_acc, "epoch": epoch + 1})

        enc_state = encoder.module.state_dict() if hasattr(encoder, "module") else encoder.state_dict()
        cls_state = classifier.module.state_dict() if hasattr(classifier, "module") else classifier.state_dict()

        if val_wf1 > best_wf1:
            best_wf1 = val_wf1
            save_ckpt({
                "epoch": epoch + 1,
                "encoder": enc_state,
                "classifier": cls_state,
                "val_wf1": val_wf1,
                "text_dim": text_dim, "audio_dim": audio_dim,
                "vision_dim": vision_dim, "num_classes": num_classes,
                "proj_dim": proj_dim, "cfg": cfg,
            }, os.path.join(args.output, "best.ckpt"))
            logging.info(f"  -> New best WF1: {val_wf1:.4f}")

    logging.info(f"Stage A done. Best val WF1: {best_wf1:.4f}")
    save_ckpt({
        "epoch": cfg["epochs"],
        "encoder": enc_state,
        "classifier": cls_state,
        "text_dim": text_dim, "audio_dim": audio_dim,
        "vision_dim": vision_dim, "num_classes": num_classes,
        "proj_dim": proj_dim, "cfg": cfg,
    }, os.path.join(args.output, "last.ckpt"))

    enc_m = encoder.module if hasattr(encoder, "module") else encoder
    cls_m = classifier.module if hasattr(classifier, "module") else classifier
    dims = dict(text_dim=text_dim, audio_dim=audio_dim,
                vision_dim=vision_dim, num_classes=num_classes, proj_dim=proj_dim)
    return enc_m, cls_m, dims


# ── Stage B: PPO training ─────────────────────────────────────────────────

def collect_rollout(encoder, classifier, agent, dataset, device, rollout_size, cfg, prev_weights):
    encoder.eval()
    classifier.eval()
    agent.eval()
    bs = cfg.get("batch_size", 128)
    nprob = cfg.get("noise_prob", 0.5)
    rng = np.random.default_rng()
    states, actions, log_probs, values, rewards = [], [], [], [], []
    collected = 0

    with torch.no_grad():
        while collected < rollout_size:
            bsz = min(bs, rollout_size - collected)
            idx = rng.integers(0, len(dataset), size=bsz)

            text = torch.from_numpy(dataset.text[idx]).to(device)
            audio = torch.from_numpy(dataset.audio[idx]).to(device)
            vision = torch.from_numpy(dataset.vision[idx]).to(device)
            labels = torch.from_numpy(dataset.labels[idx]).to(device)

            if rng.random() < nprob and dataset.variant_names:
                vname = rng.choice(dataset.variant_names)
                v = dataset.noisy_variants[vname]
                text, audio, vision, labels = _noisy_batch(dataset, v, idx, device)

            tf, af, vf, confs = encoder(text, audio, vision)
            noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
            state = torch.cat([confs, noise_est], dim=-1)

            weights, log_p, value, _ = agent.get_action_and_value(state)
            fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
            logits = classifier(fused)

            rew, _ = compute_reward(
                logits, labels, confs, weights, prev_weights,
                alpha=cfg.get("reward_alpha", 1.0),
                beta=cfg.get("reward_beta", 0.3),
                gamma=cfg.get("reward_gamma", 0.1),
            )

            states.append(state)
            actions.append(weights)
            log_probs.append(log_p)
            values.append(value.squeeze(-1))
            rewards.append(rew)
            collected += bsz

    states = torch.cat(states)
    actions = torch.cat(actions)
    log_probs = torch.cat(log_probs)
    values = torch.cat(values).cpu()
    rewards = torch.cat(rewards).cpu()

    # FIX: save raw stats before normalization.
    # The normalized mean is always 0 by construction, so it is useless for logging.
    raw_rew_mean = rewards.mean().item()
    raw_rew_std = rewards.std().item()

    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    advantages = rewards - values.detach()
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    return dict(states=states, actions=actions, log_probs=log_probs,
                values=values, rewards=rewards, advantages=advantages,
                mean_weights=actions.mean(0),
                raw_rew_mean=raw_rew_mean, raw_rew_std=raw_rew_std)


def ppo_update(agent, opt, rollout, cfg, device, scaler):
    eps = cfg.get("ppo_clip", 0.2)
    ppo_ep = cfg.get("ppo_epochs_per_update", 4)
    mb_size = cfg.get("ppo_mini_batch", 256)
    v_coef = cfg.get("value_coef", 0.5)
    ent_coef = cfg.get("entropy_coef", 0.01)

    states = rollout["states"].to(device)
    actions = rollout["actions"].to(device)
    old_lp = rollout["log_probs"].to(device)
    adv = rollout["advantages"].to(device)
    ret = rollout["rewards"].to(device)
    n = states.size(0)
    total_pl = total_vl = total_ent = cnt = 0.0
    agent.train()

    for _ in range(ppo_ep):
        perm = torch.randperm(n, device=device)
        for start in range(0, n, mb_size):
            idx = perm[start:start + mb_size]
            s = states[idx]; a = actions[idx]
            olp = old_lp[idx]; ad = adv[idx]; r = ret[idx]
            with autocast('cuda'):
                new_lp, val, ent = agent.evaluate(s, a)
                val = val.squeeze(-1)
                ratio = (new_lp - olp).exp()
                # Clipped surrogate objective, negated so that it can be minimized.
                p_loss = -torch.min(ratio*ad,
                                    torch.clamp(ratio, 1-eps, 1+eps)*ad).mean()
                v_loss = F.mse_loss(val, r)
                e_loss = -ent.mean()
                loss = p_loss + v_coef*v_loss + ent_coef*e_loss
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            scaler.step(opt)
            scaler.update()
            total_pl += p_loss.item()
            total_vl += v_loss.item()
            total_ent += ent.mean().item()
            cnt += 1

    return dict(p_loss=total_pl/cnt, v_loss=total_vl/cnt, entropy=total_ent/cnt)


def train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids):
    data_dir = os.path.join(PROJ, "data", args.dataset.lower())
    noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
                                 noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")
    eff_bs = cfg.get("batch_size", 128) * len(gpu_ids)
    val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
                                distributed=False, drop_last=False)

    # The encoder is frozen in Stage B; only the classifier and the agent are updated.
    for p in encoder.parameters():
        p.requires_grad_(False)
    encoder.to(device).eval()

    classifier.to(device)
    opt_cls = torch.optim.AdamW(classifier.parameters(),
                                lr=cfg.get("cls_lr", 5e-5), weight_decay=1e-4)

    agent = ModalFusionAgent(state_dim=4,
                             hidden=cfg.get("agent_hidden", 128)).to(device)
    opt_agent = torch.optim.Adam(agent.parameters(), lr=cfg.get("rl_lr", 3e-4))
    scaler = GradScaler()

    rollout_size = cfg.get("rollout_steps", 512)
    n_updates = cfg.get("n_ppo_updates", 500)
    eval_every = cfg.get("eval_every", 10)
    best_wf1 = 0.0
    prev_weights = None

    for upd in range(n_updates):
        rollout = collect_rollout(
            encoder, classifier, agent,
            train_ds, device, rollout_size, cfg, prev_weights,
        )
        prev_weights = rollout["mean_weights"].to(device)
        ppo_info = ppo_update(agent, opt_agent, rollout, cfg, device, scaler)

        # Every second update, fine-tune the classifier on agent-weighted fusion.
        if upd % 2 == 0:
            idx = np.random.randint(0, len(train_ds), eff_bs)
            text = torch.from_numpy(train_ds.text[idx]).to(device)
            audio = torch.from_numpy(train_ds.audio[idx]).to(device)
            vision = torch.from_numpy(train_ds.vision[idx]).to(device)
            labels = torch.from_numpy(train_ds.labels[idx]).to(device)
            with torch.no_grad():
                tf, af, vf, confs = encoder(text, audio, vision)
                noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
                state = torch.cat([confs, noise_est], dim=-1)
                weights, *_ = agent.get_action_and_value(state)
                fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
            with autocast('cuda'):
                logits = classifier(fused)
                loss = F.cross_entropy(logits, labels)
            opt_cls.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt_cls)
            scaler.update()

        if upd % eval_every == 0:
            val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device,
                                        agent=agent)
            # FIX: use raw (pre-normalization) reward for meaningful logging
            raw_rew = rollout["raw_rew_mean"]
            raw_std = rollout["raw_rew_std"]
            mw = rollout["mean_weights"]  # mean fusion weights [text, audio, visual]
            logging.info(
                f"[StageB] PPO {upd:4d}/{n_updates} | "
                f"rew={raw_rew:.4f}(+/-{raw_std:.3f}) p={ppo_info['p_loss']:.4f} "
                f"v={ppo_info['v_loss']:.4f} ent={ppo_info['entropy']:.4f} | "
                f"val_wf1={val_wf1:.4f} | "
                f"w=[t:{mw[0]:.3f} a:{mw[1]:.3f} v:{mw[2]:.3f}]"
            )
            wandb.log({
                "B/reward": raw_rew,
                "B/rew_std": raw_std,
                "B/w_text": mw[0].item(),
                "B/w_audio": mw[1].item(),
                "B/w_visual": mw[2].item(),
                "B/p_loss": ppo_info["p_loss"],
                "B/v_loss": ppo_info["v_loss"],
                "B/entropy": ppo_info["entropy"],
                "B/val_wf1": val_wf1,
                "ppo_update": upd,
            })

            if val_wf1 > best_wf1:
                best_wf1 = val_wf1
                save_ckpt({
                    "update": upd,
                    "encoder": encoder.state_dict(),
                    "classifier": classifier.state_dict(),
                    "agent": agent.state_dict(),
                    "val_wf1": val_wf1,
                    **dims,
                }, os.path.join(args.output, "best.ckpt"))
                logging.info(f"  -> New best WF1: {val_wf1:.4f}")

    logging.info(f"Stage B done. Best val WF1: {best_wf1:.4f}")


# ── Main ──────────────────────────────────────────────────────────────────

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--stage", required=True, choices=["supervised", "rl"])
    p.add_argument("--dataset", default="IEMOCAP")
    p.add_argument("--config", required=True)
    p.add_argument("--output", required=True)
    p.add_argument("--checkpoint", default=None)
    p.add_argument("--gpus", default="0,1,2,3",
                   help="Comma-separated GPU ids to use")
    return p.parse_args()


def main():
    args = parse_args()
    gpu_ids = [int(g) for g in args.gpus.split(",")]
    device = torch.device(f"cuda:{gpu_ids[0]}")

    os.makedirs(args.output, exist_ok=True)
    log_dir = os.path.join(PROJ, "outputs", "logs")
    os.makedirs(log_dir, exist_ok=True)

    stage_tag = "stageA" if args.stage == "supervised" else "stageB"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(message)s",
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                os.path.join(log_dir, f"{stage_tag}.log"), mode="a"),
        ],
    )
    logging.info(f"Using GPUs: {gpu_ids} (DataParallel, primary: cuda:{gpu_ids[0]})")

    with open(os.path.join(PROJ, args.config)) as f:
        cfg = yaml.safe_load(f)

    os.environ.setdefault("WANDB_MODE", "offline")
    wandb.init(
        project="multimodal_affect",
        name=f"d1_{args.stage}_{args.dataset}_{time.strftime('%m%d_%H%M')}",
        config={**cfg, "stage": args.stage, "dataset": args.dataset,
                "gpus": gpu_ids},
        dir=os.path.join(PROJ, "outputs"),
    )

    if args.stage == "supervised":
        train_stage_a(args, cfg, device, gpu_ids)

    elif args.stage == "rl":
        if not args.checkpoint:
            raise ValueError("--checkpoint required for --stage rl")
        ckpt = torch.load(args.checkpoint, map_location=device)

        text_dim = ckpt["text_dim"]
        audio_dim = ckpt["audio_dim"]
        vision_dim = ckpt["vision_dim"]
        num_classes = ckpt["num_classes"]
        proj_dim = ckpt.get("proj_dim", 1024)
        dims = dict(text_dim=text_dim, audio_dim=audio_dim,
                    vision_dim=vision_dim, num_classes=num_classes,
                    proj_dim=proj_dim)

        encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim)
        # Rebuild the classifier with the hidden width used in Stage A (stored in the
        # checkpoint's "cfg"); otherwise load_state_dict fails when cls_hidden != 512.
        cls_hidden = ckpt.get("cfg", {}).get("cls_hidden", 512)
        classifier = EmotionClassifier(proj_dim, num_classes, hidden=cls_hidden)
        encoder.load_state_dict(ckpt["encoder"])
        classifier.load_state_dict(ckpt["classifier"])
        logging.info(
            f"Loaded Stage A ckpt: {args.checkpoint} "
            f"val_wf1={ckpt.get('val_wf1', 0.0):.4f}"
        )
        train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids)

    wandb.finish()


if __name__ == "__main__":
|
||||
main()
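
Review note: the policy loss in `ppo_update` is the PPO clipped surrogate, applied here to single-step rollouts (each sample is one fusion decision). A minimal, framework-free sketch of the same formula is given below for checking values by hand. The function name `clipped_surrogate` and the list-based interface are assumptions made for illustration; the training script itself works on tensors.

```python
import math


def clipped_surrogate(new_log_probs, old_log_probs, advantages, eps=0.2):
    """Mean PPO clipped policy loss over a batch (lower is better).

    For each sample: ratio = exp(new_lp - old_lp), and the objective is
    min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A). The loss is its negated mean.
    """
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        total += min(ratio * adv, clipped * adv)
    return -total / len(advantages)
```

With a positive advantage the gain from raising an action's probability is capped at `1 + eps`; with a negative advantage the unclipped term is kept, so the penalty for making a bad action more likely is not capped.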
|
||||
@@ -1,266 +0,0 @@
|
||||
"""Upload all Phase 1 implementation files to the server."""
|
||||
import os
import paramiko, warnings
warnings.filterwarnings("ignore")

client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# The password is read from the D1_SSH_PASSWORD environment variable (a name chosen
# here) so that the secret is not committed together with the script.
client.connect('10.82.3.180', port=20083, username='root',
               password=os.environ['D1_SSH_PASSWORD'], timeout=30)
sftp = client.open_sftp()

ZSY = '/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy'
PROJ = ZSY + '/multimodal_affect'

files = {}

# ─── src/data/dataset.py ──────────────────────────────────────────────────
files['src/data/dataset.py'] = (
    'import os\n'
    'import numpy as np\n'
    'import torch\n'
    'from torch.utils.data import Dataset, DataLoader\n'
    'from torch.utils.data.distributed import DistributedSampler\n'
    '\n'
    'NOISE_VARIANTS = [\n'
    '    "gaussian_light", "gaussian_heavy", "missing_audio",\n'
    '    "missing_visual", "text_word_drop_30", "audio_masking_50",\n'
    '    "realistic_mixed", "audio_time_mask",\n'
    ']\n'
    '\n'
    '\n'
    'class MultimodalDataset(Dataset):\n'
    '    def __init__(self, data_dir, split, load_noisy=False, noise_root=None):\n'
    '        self.split = split\n'
    '        self.text = np.load(f"{data_dir}/{split}_text.npy").astype(np.float32)\n'
    '        self.audio = np.load(f"{data_dir}/{split}_audio.npy").astype(np.float32)\n'
    '        self.vision = np.load(f"{data_dir}/{split}_vision.npy").astype(np.float32)\n'
    '        self.labels = np.load(f"{data_dir}/{split}_labels.npy").astype(np.int64)\n'
    '\n'
    '        self.noisy_variants = {}\n'
    '        if load_noisy and noise_root:\n'
    '            for v in NOISE_VARIANTS:\n'
    '                vd = os.path.join(noise_root, v)\n'
    '                tf = os.path.join(vd, f"{split}_text.npy")\n'
    '                af = os.path.join(vd, f"{split}_audio.npy")\n'
    '                vf = os.path.join(vd, f"{split}_vision.npy")\n'
    '                if os.path.exists(tf) or os.path.exists(af) or os.path.exists(vf):\n'
    '                    self.noisy_variants[v] = {\n'
    '                        "text": np.load(tf).astype(np.float32) if os.path.exists(tf) else self.text,\n'
    '                        "audio": np.load(af).astype(np.float32) if os.path.exists(af) else self.audio,\n'
    '                        "vision": np.load(vf).astype(np.float32) if os.path.exists(vf) else self.vision,\n'
    '                    }\n'
    '        self.variant_names = sorted(self.noisy_variants.keys())\n'
    '\n'
    '    def __len__(self):\n'
    '        return len(self.labels)\n'
    '\n'
    '    def __getitem__(self, idx):\n'
    '        return {\n'
    '            "text": torch.from_numpy(self.text[idx].copy()),\n'
    '            "audio": torch.from_numpy(self.audio[idx].copy()),\n'
    '            "vision": torch.from_numpy(self.vision[idx].copy()),\n'
    '            "labels": torch.tensor(int(self.labels[idx])),\n'
    '        }\n'
    '\n'
    '\n'
    'def get_dataloader(ds, batch_size, shuffle=True, distributed=False,\n'
    '                   num_workers=4, drop_last=True):\n'
    '    sampler = DistributedSampler(ds, shuffle=shuffle) if distributed else None\n'
    '    return DataLoader(\n'
    '        ds, batch_size=batch_size,\n'
    '        shuffle=(shuffle and sampler is None),\n'
    '        sampler=sampler, num_workers=num_workers,\n'
    '        pin_memory=True, drop_last=drop_last,\n'
    '    )\n'
)

# ─── src/models/encoders.py ───────────────────────────────────────────────
files['src/models/encoders.py'] = (
    'import torch\n'
    'import torch.nn as nn\n'
    '\n'
    '\n'
    'class ModalityProjector(nn.Module):\n'
    '    # Project low-dim pre-extracted features to shared proj_dim space\n'
    '    def __init__(self, in_dim: int, proj_dim: int = 1024):\n'
    '        super().__init__()\n'
    '        mid = max(in_dim * 4, 256)\n'
    '        self.net = nn.Sequential(\n'
    '            nn.Linear(in_dim, mid),\n'
    '            nn.LayerNorm(mid),\n'
    '            nn.GELU(),\n'
    '            nn.Dropout(0.1),\n'
    '            nn.Linear(mid, proj_dim),\n'
    '            nn.LayerNorm(proj_dim),\n'
    '        )\n'
    '\n'
    '    def forward(self, x: torch.Tensor) -> torch.Tensor:\n'
    '        return self.net(x)\n'
    '\n'
    '\n'
    'class ConfidenceEstimator(nn.Module):\n'
    '    # Lightweight MLP: proj_dim -> noise-quality confidence scalar in (0, 1)\n'
    '    def __init__(self, proj_dim: int = 1024, hidden: int = 256):\n'
    '        super().__init__()\n'
    '        self.net = nn.Sequential(\n'
    '            nn.Linear(proj_dim, hidden),\n'
    '            nn.ReLU(),\n'
    '            nn.Dropout(0.1),\n'
    '            nn.Linear(hidden, 1),\n'
    '            nn.Sigmoid(),\n'
    '        )\n'
    '\n'
    '    def forward(self, x: torch.Tensor) -> torch.Tensor:\n'
    '        return self.net(x).squeeze(-1)\n'
    '\n'
    '\n'
    'class MultimodalEncoder(nn.Module):\n'
    '    # Three-branch projector + three per-modality confidence estimators\n'
    '    def __init__(self,\n'
    '                 text_dim: int = 300,\n'
    '                 audio_dim: int = 74,\n'
    '                 vision_dim: int = 35,\n'
    '                 proj_dim: int = 1024):\n'
    '        super().__init__()\n'
    '        self.text_proj = ModalityProjector(text_dim, proj_dim)\n'
    '        self.audio_proj = ModalityProjector(audio_dim, proj_dim)\n'
    '        self.vision_proj = ModalityProjector(vision_dim, proj_dim)\n'
    '        self.text_conf = ConfidenceEstimator(proj_dim)\n'
    '        self.audio_conf = ConfidenceEstimator(proj_dim)\n'
    '        self.vision_conf = ConfidenceEstimator(proj_dim)\n'
    '\n'
    '    def forward(self, text, audio, vision):\n'
    '        tf = self.text_proj(text)\n'
    '        af = self.audio_proj(audio)\n'
    '        vf = self.vision_proj(vision)\n'
    '        confs = torch.stack([\n'
    '            self.text_conf(tf),\n'
    '            self.audio_conf(af),\n'
    '            self.vision_conf(vf),\n'
    '        ], dim=1)  # (B, 3)\n'
    '        return tf, af, vf, confs\n'
)

# ─── src/models/classifier.py ─────────────────────────────────────────────
files['src/models/classifier.py'] = (
    'import torch.nn as nn\n'
    '\n'
    '\n'
    'class EmotionClassifier(nn.Module):\n'
    '    def __init__(self, in_dim: int = 1024, num_classes: int = 4,\n'
    '                 hidden: int = 512, dropout: float = 0.3):\n'
    '        super().__init__()\n'
    '        self.net = nn.Sequential(\n'
    '            nn.Linear(in_dim, hidden),\n'
    '            nn.LayerNorm(hidden),\n'
    '            nn.GELU(),\n'
    '            nn.Dropout(dropout),\n'
    '            nn.Linear(hidden, hidden // 2),\n'
    '            nn.GELU(),\n'
    '            nn.Dropout(dropout),\n'
    '            nn.Linear(hidden // 2, num_classes),\n'
    '        )\n'
    '\n'
    '    def forward(self, x):\n'
    '        return self.net(x)\n'
)

# ─── src/rl/fusion_agent.py ───────────────────────────────────────────────
files['src/rl/fusion_agent.py'] = (
    'import torch\n'
    'import torch.nn as nn\n'
    'import torch.nn.functional as F\n'
    'from torch.distributions import Dirichlet\n'
    '\n'
    '\n'
    'class ModalFusionAgent(nn.Module):\n'
    '    # PPO Actor-Critic for RL-adaptive modality fusion\n'
    '    # State s = [conf_text, conf_audio, conf_visual, noise_est] (R^4)\n'
    '    # Action a = fusion weights from Dirichlet distribution (simplex R^3)\n'
    '\n'
    '    def __init__(self, state_dim: int = 4, hidden: int = 128):\n'
    '        super().__init__()\n'
    '        self.actor = nn.Sequential(\n'
    '            nn.Linear(state_dim, hidden), nn.Tanh(),\n'
    '            nn.Linear(hidden, hidden), nn.Tanh(),\n'
    '            nn.Linear(hidden, 3),\n'
    '        )\n'
    '        self.critic = nn.Sequential(\n'
    '            nn.Linear(state_dim, hidden), nn.Tanh(),\n'
    '            nn.Linear(hidden, hidden), nn.Tanh(),\n'
    '            nn.Linear(hidden, 1),\n'
    '        )\n'
    '\n'
    '    def _concentration(self, state: torch.Tensor) -> torch.Tensor:\n'
    '        return F.softplus(self.actor(state)) + 1e-3\n'
    '\n'
    '    def get_action_and_value(self, state: torch.Tensor):\n'
    '        conc = self._concentration(state)\n'
    '        dist = Dirichlet(conc)\n'
    '        weights = dist.rsample()\n'
    '        log_p = dist.log_prob(weights)\n'
    '        value = self.critic(state)\n'
    '        entropy = dist.entropy()\n'
    '        return weights, log_p, value, entropy\n'
    '\n'
    '    def evaluate(self, state: torch.Tensor, weights: torch.Tensor):\n'
    '        # Recompute log-prob and value for stored actions (PPO update)\n'
    '        conc = self._concentration(state)\n'
    '        dist = Dirichlet(conc)\n'
    '        log_p = dist.log_prob(weights.clamp(1e-6, 1 - 1e-6))\n'
    '        value = self.critic(state)\n'
    '        entropy = dist.entropy()\n'
    '        return log_p, value, entropy\n'
)

# ─── src/rl/reward.py ─────────────────────────────────────────────────────
files['src/rl/reward.py'] = (
    'import torch\n'
    'import torch.nn.functional as F\n'
    'from sklearn.metrics import f1_score\n'
    '\n'
    '\n'
    'def compute_reward(logits, labels, confs, weights, prev_weights,\n'
    '                   alpha: float = 1.0,\n'
    '                   beta: float = 0.3,\n'
    '                   gamma: float = 0.1):\n'
    '    # Per-sample reward: R = alpha*(-CE) + beta*Consistency - gamma*Instability\n'
    '    neg_ce = -F.cross_entropy(logits, labels, reduction="none")\n'
    '\n'
    '    w_norm = F.normalize(weights, p=1, dim=-1)\n'
    '    c_norm = F.normalize(confs, p=1, dim=-1)\n'
    '    consistency = (w_norm * c_norm).sum(dim=-1)\n'
    '\n'
    '    if prev_weights is not None:\n'
    '        delta = weights - prev_weights.unsqueeze(0).expand_as(weights)\n'
    '        instability = torch.norm(delta, p=2, dim=-1)\n'
    '    else:\n'
    '        instability = torch.zeros_like(neg_ce)\n'
    '\n'
    '    reward = alpha * neg_ce + beta * consistency - gamma * instability\n'
    '\n'
    '    with torch.no_grad():\n'
    '        wf1 = float(f1_score(\n'
    '            labels.cpu().numpy(),\n'
    '            logits.argmax(-1).cpu().numpy(),\n'
    '            average="weighted", zero_division=0,\n'
    '        ))\n'
    '\n'
    '    info = {\n'
    '        "wf1": wf1,\n'
    '        "consistency": consistency.mean().item(),\n'
    '        "instability": instability.mean().item(),\n'
    '        "neg_ce": neg_ce.mean().item(),\n'
    '    }\n'
    '    return reward, info\n'
)

# Upload
for rel_path, content in files.items():
    remote_path = f"{PROJ}/{rel_path}"
    with sftp.open(remote_path, 'w') as f:
        f.write(content)
    print(f"  uploaded: {rel_path}")

sftp.close()
client.close()
print("\nAll src files uploaded.")
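
Review note: `compute_reward` in `src/rl/reward.py` combines three per-sample terms, R = alpha*(-CE) + beta*Consistency - gamma*Instability. A small plain-Python sketch of that formula for one sample is given below so the terms can be checked by hand. It is an illustration under stated assumptions: `sample_reward` is a hypothetical name, it takes class probabilities rather than logits, and it assumes the weight and confidence vectors have a non-zero L1 norm.

```python
import math


def sample_reward(probs, label, confs, weights, prev_weights=None,
                  alpha=1.0, beta=0.3, gamma=0.1):
    """Reward for one sample, mirroring the three terms of compute_reward.

    probs:        class probabilities (already softmaxed), assumed to sum to 1
    confs:        per-modality confidences (text, audio, vision)
    weights:      fusion weights for the same three modalities
    prev_weights: mean fusion weights of the previous rollout, or None
    """
    # Negative cross-entropy of the true class.
    neg_ce = math.log(probs[label])

    # Consistency: dot product of the L1-normalized weights and confidences.
    w_sum = sum(abs(w) for w in weights)
    c_sum = sum(abs(c) for c in confs)
    consistency = sum((w / w_sum) * (c / c_sum) for w, c in zip(weights, confs))

    # Instability: L2 distance to the previous mean weights (zero on the first rollout).
    if prev_weights is None:
        instability = 0.0
    else:
        instability = math.sqrt(sum((w - p) ** 2 for w, p in zip(weights, prev_weights)))

    return alpha * neg_ce + beta * consistency - gamma * instability
```

When all three confidences are equal, the consistency term is 1/3 for any weights on the simplex, so in that case the reward is driven by the classification term and the instability penalty alone.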
|
||||
@@ -1,578 +0,0 @@
|
||||
"""Upload revised train_d1.py using DataParallel (4-GPU, no DDP/NCCL needed)."""
|
||||
import os
import paramiko, warnings
warnings.filterwarnings("ignore")

client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# The password is read from the D1_SSH_PASSWORD environment variable (a name chosen
# here) so that the secret is not committed together with the script.
client.connect('10.82.3.180', port=20083, username='root',
               password=os.environ['D1_SSH_PASSWORD'], timeout=30)
sftp = client.open_sftp()

ZSY = '/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy'
PROJ = ZSY + '/multimodal_affect'

TRAIN_D1 = '''\
|
||||
#!/usr/bin/env python3
|
||||
# Phase 1 Direction 1 Training Script (DataParallel edition)
|
||||
# Stage A: Supervised pretraining with noise-aware confidence estimation
|
||||
# Stage B: PPO-based adaptive fusion weight learning
|
||||
#
|
||||
# Uses nn.DataParallel (4 GPUs, single process, no NCCL needed)
|
||||
#
|
||||
# Launch:
|
||||
# python scripts/train/train_d1.py \\
|
||||
# --stage supervised --dataset IEMOCAP \\
|
||||
# --config configs/d1/stage_a.yaml \\
|
||||
# --output outputs/checkpoints/d1_stageA
|
||||
|
||||
import os, sys, argparse, yaml, time, logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from sklearn.metrics import f1_score, accuracy_score
|
||||
import wandb
|
||||
|
||||
ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
|
||||
PROJ = os.path.join(ZSY, "multimodal_affect")
|
||||
sys.path.insert(0, PROJ)
|
||||
|
||||
from src.data.dataset import MultimodalDataset, get_dataloader
|
||||
from src.models.encoders import MultimodalEncoder
|
||||
from src.models.classifier import EmotionClassifier
|
||||
from src.rl.fusion_agent import ModalFusionAgent
|
||||
from src.rl.reward import compute_reward
|
||||
|
||||
|
||||
def save_ckpt(state, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(state, path)


def _noisy_batch(dataset, variant, indices, device):
    text = variant.get("text", dataset.text)
    audio = variant.get("audio", dataset.audio)
    vision = variant.get("vision", dataset.vision)
    return (
        torch.from_numpy(text[indices]).to(device),
        torch.from_numpy(audio[indices]).to(device),
        torch.from_numpy(vision[indices]).to(device),
        torch.from_numpy(dataset.labels[indices]).to(device),
    )


def _confidence_targets(variant_name, batch_size, device):
    target = torch.full((batch_size, 3), 0.9, device=device)
    noisy_map = {
        "gaussian_light": (0, 1, 2),
        "gaussian_heavy": (0, 1, 2),
        "missing_audio": (1,),
        "missing_visual": (2,),
        "text_word_drop_30": (0,),
        "audio_masking_50": (1,),
        "realistic_mixed": (0, 1, 2),
        "audio_time_mask": (1,),
    }
    for idx in noisy_map.get(str(variant_name), (0, 1, 2)):
        target[:, idx] = 0.1
    return target
|
||||
|
||||
|
||||
# ── Evaluation ────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(encoder, classifier, loader, device, agent=None):
|
||||
encoder.eval()
|
||||
classifier.eval()
|
||||
if agent is not None:
|
||||
agent.eval()
|
||||
all_preds, all_labels = [], []
|
||||
for batch in loader:
|
||||
text = batch["text"].to(device)
|
||||
audio = batch["audio"].to(device)
|
||||
vision = batch["vision"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
# DataParallel: call module directly for eval (avoids scatter overhead)
|
||||
enc = encoder.module if hasattr(encoder, "module") else encoder
|
||||
cls = classifier.module if hasattr(classifier, "module") else classifier
|
||||
agt = (agent.module if hasattr(agent, "module") else agent) if agent else None
|
||||
tf, af, vf, confs = enc(text, audio, vision)
|
||||
if agt is not None:
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
weights, *_ = agt.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
else:
|
||||
fused = (tf + af + vf) / 3.0
|
||||
logits = cls(fused)
|
||||
all_preds.append(logits.argmax(-1).cpu())
|
||||
all_labels.append(labels.cpu())
|
||||
preds = torch.cat(all_preds).numpy()
|
||||
labels = torch.cat(all_labels).numpy()
|
||||
wf1 = float(f1_score(labels, preds, average="weighted", zero_division=0))
|
||||
acc = float(accuracy_score(labels, preds))
|
||||
return wf1, acc
|
||||
|
||||
|
||||
# ── Stage A: Supervised pretraining ──────────────────────────────────────
|
||||
|
||||
def train_stage_a(args, cfg, device, gpu_ids):
|
||||
rng = np.random.default_rng(42)
|
||||
|
||||
data_dir = os.path.join(PROJ, "data", args.dataset.lower())
|
||||
noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
|
||||
|
||||
train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
|
||||
noise_root=noise_root)
|
||||
val_ds = MultimodalDataset(data_dir, "val")
|
||||
|
||||
# Increase batch_size proportional to # GPUs for DataParallel
|
||||
eff_bs = cfg["batch_size"] * len(gpu_ids)
|
||||
train_loader = get_dataloader(train_ds, eff_bs, distributed=False)
|
||||
val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
|
||||
distributed=False, drop_last=False)
|
||||
|
||||
text_dim = train_ds.text.shape[1]
|
||||
audio_dim = train_ds.audio.shape[1]
|
||||
vision_dim = train_ds.vision.shape[1]
|
||||
num_classes = int(train_ds.labels.max()) + 1
|
||||
proj_dim = cfg.get("proj_dim", 1024)
|
||||
|
||||
encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim).to(device)
|
||||
classifier = EmotionClassifier(proj_dim, num_classes,
|
||||
hidden=cfg.get("cls_hidden", 512)).to(device)
|
||||
|
||||
if len(gpu_ids) > 1:
|
||||
encoder = nn.DataParallel(encoder, device_ids=gpu_ids)
|
||||
classifier = nn.DataParallel(classifier, device_ids=gpu_ids)
|
||||
|
||||
params = list(encoder.parameters()) + list(classifier.parameters())
|
||||
opt = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg.get("wd", 1e-4))
|
||||
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
opt, T_max=cfg["epochs"], eta_min=1e-5)
|
||||
scaler = GradScaler()
|
||||
|
||||
conf_weight = cfg.get("conf_weight", 0.2)
|
||||
noise_prob = cfg.get("noise_prob", 0.4)
|
||||
best_wf1 = 0.0
|
||||
|
||||
for epoch in range(cfg["epochs"]):
|
||||
encoder.train()
|
||||
classifier.train()
|
||||
ep_loss = ep_ce = ep_conf = 0.0
|
||||
|
||||
for batch in train_loader:
|
||||
text = batch["text"].to(device)
|
||||
audio = batch["audio"].to(device)
|
||||
vision = batch["vision"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
B = text.size(0)
|
||||
|
||||
# Noise injection
|
||||
use_noise = (rng.random() < noise_prob) and bool(train_ds.variant_names)
|
||||
if use_noise:
|
||||
vname = rng.choice(train_ds.variant_names)
|
||||
v = train_ds.noisy_variants[vname]
|
||||
ni = rng.integers(0, len(train_ds), size=B)
|
||||
text, audio, vision, labels = _noisy_batch(train_ds, v, ni, device)
|
||||
|
||||
with autocast():
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
fused = (tf + af + vf) / 3.0
|
||||
logits = classifier(fused)
|
||||
ce_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
if use_noise:
|
||||
c_tgt = _confidence_targets(vname, B, device)
|
||||
else:
|
||||
c_tgt = torch.full((B, 3), 0.9, device=device)
|
||||
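# binary_cross_entropy raises inside autocast regions, so compute it in float32 with autocast disabled
with autocast(enabled=False): conf_loss = F.binary_cross_entropy(confs.float(), c_tgt.float())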
|
||||
loss = ce_loss + conf_weight * conf_loss
|
||||
|
||||
opt.zero_grad(set_to_none=True)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(opt)
|
||||
nn.utils.clip_grad_norm_(params, 1.0)
|
||||
scaler.step(opt)
|
||||
scaler.update()
|
||||
|
||||
ep_loss += loss.item()
|
||||
ep_ce += ce_loss.item()
|
||||
ep_conf += conf_loss.item()
|
||||
|
||||
sched.step()
|
||||
val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device)
|
||||
|
||||
n = len(train_loader)
|
||||
logging.info(
|
||||
f"[StageA] Epoch {epoch+1:3d}/{cfg['epochs']} | "
|
||||
f"loss={ep_loss/n:.4f} ce={ep_ce/n:.4f} conf={ep_conf/n:.4f} | "
|
||||
f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
|
||||
)
|
||||
wandb.log({"A/loss": ep_loss/n, "A/ce": ep_ce/n,
|
||||
"A/conf": ep_conf/n, "A/val_wf1": val_wf1,
|
||||
"A/val_acc": val_acc, "epoch": epoch + 1})
|
||||
|
||||
enc_state = encoder.module.state_dict() if hasattr(encoder, "module") else encoder.state_dict()
|
||||
cls_state = classifier.module.state_dict() if hasattr(classifier, "module") else classifier.state_dict()
|
||||
|
||||
if val_wf1 > best_wf1:
|
||||
best_wf1 = val_wf1
|
||||
save_ckpt({
|
||||
"epoch": epoch + 1,
|
||||
"encoder": enc_state,
|
||||
"classifier": cls_state,
|
||||
"val_wf1": val_wf1,
|
||||
"text_dim": text_dim, "audio_dim": audio_dim,
|
||||
"vision_dim": vision_dim, "num_classes": num_classes,
|
||||
"proj_dim": proj_dim, "cfg": cfg,
|
||||
}, os.path.join(args.output, "best.ckpt"))
|
||||
logging.info(f" -> New best WF1: {val_wf1:.4f}")
|
||||
|
||||
logging.info(f"Stage A done. Best val WF1: {best_wf1:.4f}")
|
||||
save_ckpt({
|
||||
"epoch": cfg["epochs"],
|
||||
"encoder": enc_state,
|
||||
"classifier": cls_state,
|
||||
"text_dim": text_dim, "audio_dim": audio_dim,
|
||||
"vision_dim": vision_dim, "num_classes": num_classes,
|
||||
"proj_dim": proj_dim, "cfg": cfg,
|
||||
}, os.path.join(args.output, "last.ckpt"))
|
||||
|
||||
enc_m = encoder.module if hasattr(encoder, "module") else encoder
|
||||
cls_m = classifier.module if hasattr(classifier, "module") else classifier
|
||||
dims = dict(text_dim=text_dim, audio_dim=audio_dim,
|
||||
vision_dim=vision_dim, num_classes=num_classes, proj_dim=proj_dim)
|
||||
return enc_m, cls_m, dims
|
||||
|
||||
|
||||
# ── Stage B: PPO training ─────────────────────────────────────────────────
|
||||
|
||||
def collect_rollout(encoder, classifier, agent, dataset, device, rollout_size, cfg, prev_weights):
|
||||
encoder.eval()
|
||||
classifier.eval()
|
||||
agent.eval()
|
||||
bs = cfg.get("batch_size", 128)
|
||||
nprob = cfg.get("noise_prob", 0.5)
|
||||
rng = np.random.default_rng()
|
||||
states, actions, log_probs, values, rewards = [], [], [], [], []
|
||||
collected = 0
|
||||
|
||||
with torch.no_grad():
|
||||
while collected < rollout_size:
|
||||
bsz = min(bs, rollout_size - collected)
|
||||
idx = rng.integers(0, len(dataset), size=bsz)
|
||||
|
||||
text = torch.from_numpy(dataset.text[idx]).to(device)
|
||||
audio = torch.from_numpy(dataset.audio[idx]).to(device)
|
||||
vision = torch.from_numpy(dataset.vision[idx]).to(device)
|
||||
labels = torch.from_numpy(dataset.labels[idx]).to(device)
|
||||
|
||||
if rng.random() < nprob and dataset.variant_names:
|
||||
vname = rng.choice(dataset.variant_names)
|
||||
v = dataset.noisy_variants[vname]
|
||||
text, audio, vision, labels = _noisy_batch(dataset, v, idx, device)
|
||||
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
|
||||
weights, log_p, value, _ = agent.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
logits = classifier(fused)
|
||||
|
||||
rew, _ = compute_reward(
|
||||
logits, labels, confs, weights, prev_weights,
|
||||
alpha=cfg.get("reward_alpha", 1.0),
|
||||
beta =cfg.get("reward_beta", 0.3),
|
||||
gamma=cfg.get("reward_gamma", 0.1),
|
||||
)
|
||||
|
||||
states.append(state)
|
||||
actions.append(weights)
|
||||
log_probs.append(log_p)
|
||||
values.append(value.squeeze(-1))
|
||||
rewards.append(rew)
|
||||
collected += bsz
|
||||
|
||||
states = torch.cat(states)
|
||||
actions = torch.cat(actions)
|
||||
log_probs = torch.cat(log_probs)
|
||||
values = torch.cat(values)
|
||||
rewards = torch.cat(rewards)
|
||||
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
|
||||
advantages = rewards - values.detach()  # keep both tensors on the same device (no .cpu())
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
return dict(states=states, actions=actions, log_probs=log_probs,
|
||||
values=values, rewards=rewards, advantages=advantages,
|
||||
mean_weights=actions.mean(0))
|
||||
|
||||
|
||||
def ppo_update(agent, opt, rollout, cfg, device, scaler):
    eps = cfg.get("ppo_clip", 0.2)
    ppo_ep = cfg.get("ppo_epochs_per_update", 4)
    mb_size = cfg.get("ppo_mini_batch", 256)
    v_coef = cfg.get("value_coef", 0.5)
    ent_coef = cfg.get("entropy_coef", 0.01)

    states = rollout["states"].to(device)
    actions = rollout["actions"].to(device)
    old_lp = rollout["log_probs"].to(device)
    adv = rollout["advantages"].to(device)
    ret = rollout["rewards"].to(device)
    n = states.size(0)
    total_pl = total_vl = total_ent = cnt = 0.0
    agent.train()

    for _ in range(ppo_ep):
        perm = torch.randperm(n, device=device)
        for start in range(0, n, mb_size):
            idx = perm[start:start + mb_size]
            s = states[idx]; a = actions[idx]
            olp = old_lp[idx]; ad = adv[idx]; r = ret[idx]
            with autocast():
                new_lp, val, ent = agent.evaluate(s, a)
                val = val.squeeze(-1)
                ratio = (new_lp - olp).exp()
                p_loss = -torch.min(ratio*ad,
                                    torch.clamp(ratio, 1-eps, 1+eps)*ad).mean()
                v_loss = F.mse_loss(val, r)
                e_loss = -ent.mean()
                loss = p_loss + v_coef*v_loss + ent_coef*e_loss
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            scaler.step(opt)
            scaler.update()
            total_pl += p_loss.item()
            total_vl += v_loss.item()
            total_ent += ent.mean().item()
            cnt += 1

    return dict(p_loss=total_pl/cnt, v_loss=total_vl/cnt, entropy=total_ent/cnt)
|
||||
|
||||
|
||||
def train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids):
|
||||
data_dir = os.path.join(PROJ, "data", args.dataset.lower())
|
||||
noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
|
||||
train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
|
||||
noise_root=noise_root)
|
||||
val_ds = MultimodalDataset(data_dir, "val")
|
||||
eff_bs = cfg.get("batch_size", 128) * len(gpu_ids)
|
||||
val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
|
||||
distributed=False, drop_last=False)
|
||||
|
||||
# Freeze encoder
|
||||
for p in encoder.parameters():
|
||||
p.requires_grad_(False)
|
||||
encoder.to(device).eval()
|
||||
|
||||
# Classifier: keep trainable
|
||||
classifier.to(device)
|
||||
opt_cls = torch.optim.AdamW(classifier.parameters(),
|
||||
lr=cfg.get("cls_lr", 5e-5), weight_decay=1e-4)
|
||||
|
||||
# RL agent (small, DataParallel not needed for 4-dim input)
|
||||
agent = ModalFusionAgent(state_dim=4,
|
||||
hidden=cfg.get("agent_hidden", 128)).to(device)
|
||||
opt_agent = torch.optim.Adam(agent.parameters(), lr=cfg.get("rl_lr", 3e-4))
|
||||
scaler = GradScaler()
|
||||
|
||||
rollout_size = cfg.get("rollout_steps", 512)
|
||||
n_updates = cfg.get("n_ppo_updates", 500)
|
||||
eval_every = cfg.get("eval_every", 10)
|
||||
best_wf1 = 0.0
|
||||
prev_weights = None
|
||||
|
||||
for upd in range(n_updates):
|
||||
rollout = collect_rollout(
|
||||
encoder, classifier, agent,
|
||||
train_ds, device, rollout_size, cfg, prev_weights,
|
||||
)
|
||||
prev_weights = rollout["mean_weights"].to(device)
|
||||
ppo_info = ppo_update(agent, opt_agent, rollout, cfg, device, scaler)
|
||||
|
||||
# Classifier supervised refresh
|
||||
if upd % 2 == 0:
|
||||
idx = np.random.randint(0, len(train_ds), eff_bs)
|
||||
text = torch.from_numpy(train_ds.text[idx]).to(device)
|
||||
audio = torch.from_numpy(train_ds.audio[idx]).to(device)
|
||||
vision = torch.from_numpy(train_ds.vision[idx]).to(device)
|
||||
labels = torch.from_numpy(train_ds.labels[idx]).to(device)
|
||||
with torch.no_grad():
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
weights, *_ = agent.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
classifier.train()  # collect_rollout() and evaluate() leave the classifier in eval mode
with autocast():
|
||||
logits = classifier(fused)
|
||||
loss = F.cross_entropy(logits, labels)
|
||||
opt_cls.zero_grad(set_to_none=True)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(opt_cls)
|
||||
scaler.update()
|
||||
|
||||
if upd % eval_every == 0:
|
||||
val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device,
|
||||
agent=agent)
|
||||
mean_rew = rollout["rewards"].mean().item()
|
||||
logging.info(
|
||||
f"[StageB] PPO {upd:4d}/{n_updates} | "
|
||||
f"rew={mean_rew:.4f} p={ppo_info['p_loss']:.4f} "
|
||||
f"v={ppo_info['v_loss']:.4f} ent={ppo_info['entropy']:.4f} | "
|
||||
f"val_wf1={val_wf1:.4f}"
|
||||
)
|
||||
wandb.log({
|
||||
"B/reward": mean_rew,
|
||||
"B/p_loss": ppo_info["p_loss"],
|
||||
"B/v_loss": ppo_info["v_loss"],
|
||||
"B/entropy": ppo_info["entropy"],
|
||||
"B/val_wf1": val_wf1,
|
||||
"ppo_update": upd,
|
||||
})
|
||||
|
||||
if val_wf1 > best_wf1:
|
||||
best_wf1 = val_wf1
|
||||
save_ckpt({
|
||||
"update": upd,
|
||||
"encoder": encoder.state_dict(),
|
||||
"classifier": classifier.state_dict(),
|
||||
"agent": agent.state_dict(),
|
||||
"val_wf1": val_wf1,
|
||||
**dims,
|
||||
}, os.path.join(args.output, "best.ckpt"))
|
||||
logging.info(f" -> New best WF1: {val_wf1:.4f}")
|
||||
|
||||
logging.info(f"Stage B done. Best val WF1: {best_wf1:.4f}")
|
||||
|
||||
|
||||
# ── Main ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--stage", required=True, choices=["supervised", "rl"])
|
||||
p.add_argument("--dataset", default="IEMOCAP")
|
||||
p.add_argument("--config", required=True)
|
||||
p.add_argument("--output", required=True)
|
||||
p.add_argument("--checkpoint", default=None)
|
||||
p.add_argument("--gpus", default="0,1,2,3",
|
||||
help="Comma-separated GPU ids to use")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
gpu_ids = [int(g) for g in args.gpus.split(",")]
|
||||
device = torch.device(f"cuda:{gpu_ids[0]}")
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
log_dir = os.path.join(PROJ, "outputs", "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
stage_tag = "stageA" if args.stage == "supervised" else "stageB"
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(
|
||||
os.path.join(log_dir, f"{stage_tag}.log"), mode="a"),
|
||||
],
|
||||
)
|
||||
logging.info(f"Using GPUs: {gpu_ids} (DataParallel, primary: cuda:{gpu_ids[0]})")
|
||||
|
||||
with open(os.path.join(PROJ, args.config)) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
os.environ.setdefault("WANDB_MODE", "offline")
|
||||
wandb.init(
|
||||
project="multimodal_affect",
|
||||
name=f"d1_{args.stage}_{args.dataset}_{time.strftime('%m%d_%H%M')}",
|
||||
config={**cfg, "stage": args.stage, "dataset": args.dataset,
|
||||
"gpus": gpu_ids},
|
||||
dir=os.path.join(PROJ, "outputs"),
|
||||
)
|
||||
|
||||
if args.stage == "supervised":
|
||||
train_stage_a(args, cfg, device, gpu_ids)
|
||||
|
||||
elif args.stage == "rl":
|
||||
if not args.checkpoint:
|
||||
raise ValueError("--checkpoint required for --stage rl")
|
||||
ckpt = torch.load(args.checkpoint, map_location=device)
|
||||
|
||||
text_dim = ckpt["text_dim"]
|
||||
audio_dim = ckpt["audio_dim"]
|
||||
vision_dim = ckpt["vision_dim"]
|
||||
num_classes = ckpt["num_classes"]
|
||||
proj_dim = ckpt.get("proj_dim", 1024)
|
||||
dims = dict(text_dim=text_dim, audio_dim=audio_dim,
|
||||
vision_dim=vision_dim, num_classes=num_classes,
|
||||
proj_dim=proj_dim)
|
||||
|
||||
encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim)
|
||||
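# Rebuild the classifier with the hidden size Stage A used, so load_state_dict matches the saved weights
classifier = EmotionClassifier(proj_dim, num_classes, hidden=ckpt.get("cfg", {}).get("cls_hidden", 512))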
|
||||
encoder.load_state_dict(ckpt["encoder"])
|
||||
classifier.load_state_dict(ckpt["classifier"])
|
||||
logging.info(
|
||||
f"Loaded Stage A ckpt: {args.checkpoint} "
|
||||
f"val_wf1={ckpt.get('val_wf1', 0.0):.4f}"
|
||||
)
|
||||
train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids)
|
||||
|
||||
wandb.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
'''
|
||||
|
||||
with sftp.open(f'{PROJ}/scripts/train/train_d1.py', 'w') as f:
|
||||
f.write(TRAIN_D1)
|
||||
print("Uploaded train_d1.py (DataParallel edition)")
|
||||
|
||||
# Also update the launch script
|
||||
LAUNCH = f'''#!/bin/bash
|
||||
set -e
|
||||
export ZSY={ZSY}
|
||||
export WANDB_MODE=offline
|
||||
export PYTHONPATH={PROJ}
|
||||
cd {PROJ}
|
||||
|
||||
mkdir -p outputs/checkpoints/d1_stageA
|
||||
mkdir -p outputs/checkpoints/d1_stageB
|
||||
mkdir -p outputs/logs
|
||||
|
||||
echo "[$(date)] Starting Stage A (4-GPU DataParallel, 50 epochs)"
|
||||
|
||||
{ZSY}/envs/multimodal_affect/bin/python3 scripts/train/train_d1.py \\
|
||||
--stage supervised \\
|
||||
--dataset IEMOCAP \\
|
||||
--config configs/d1/stage_a.yaml \\
|
||||
--output outputs/checkpoints/d1_stageA \\
|
||||
--gpus 0,1,2,3 \\
|
||||
2>&1 | tee outputs/logs/stage_a_stdout.log
|
||||
|
||||
echo "[$(date)] Stage A done. Starting Stage B (PPO, 500 updates)"
|
||||
|
||||
{ZSY}/envs/multimodal_affect/bin/python3 scripts/train/train_d1.py \\
|
||||
--stage rl \\
|
||||
--dataset IEMOCAP \\
|
||||
--checkpoint outputs/checkpoints/d1_stageA/best.ckpt \\
|
||||
--config configs/d1/stage_b.yaml \\
|
||||
--output outputs/checkpoints/d1_stageB \\
|
||||
--gpus 0,1,2,3 \\
|
||||
2>&1 | tee outputs/logs/stage_b_stdout.log
|
||||
|
||||
echo "[$(date)] All training complete!"
|
||||
'''
|
||||
with sftp.open(f'{PROJ}/run_d1.sh', 'w') as f:
|
||||
f.write(LAUNCH)
|
||||
print("Updated run_d1.sh")
|
||||
|
||||
sftp.close()
|
||||
client.close()
|
||||
print("Done.")
|
||||
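The policy term in `ppo_update` is the standard clipped PPO surrogate: the probability ratio is the exponential of the log-probability difference, and the loss takes the more pessimistic of the clipped and unclipped products with the advantage. The following is a minimal pure-Python sketch of that one term, assuming plain lists of floats instead of tensors. `clipped_policy_loss` is a hypothetical helper name, and `eps=0.2` mirrors the `ppo_clip` default in the script.

```python
import math
from typing import Sequence


def clipped_policy_loss(
    new_log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    eps: float = 0.2,
) -> float:
    """Mean clipped PPO policy loss over a mini-batch (illustrative sketch)."""
    if not (len(new_log_probs) == len(old_log_probs) == len(advantages)):
        raise ValueError("inputs must have the same length")
    if len(advantages) == 0:
        raise ValueError("mini-batch must not be empty")

    terms = []
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)                 # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)   # clamp to [1 - eps, 1 + eps]
        terms.append(min(ratio * adv, clipped * adv))     # pessimistic bound
    return -sum(terms) / len(terms)                       # negate: optimiser minimises


if __name__ == "__main__":
    # The ratio is 2.0 for both samples; only the positive advantage gets clipped.
    log2 = math.log(2.0)
    print(clipped_policy_loss([log2, log2], [0.0, 0.0], [1.0, -1.0]))
```

With a positive advantage the clip caps how much a single update can be rewarded for raising an action's probability. With a negative advantage the unclipped product is already the smaller one, so the penalty is kept in full.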
@@ -1,630 +0,0 @@
|
||||
"""Upload train_d1.py and config files to server."""
|
||||
import paramiko, warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
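import os  # SSH_PASSWORD is a hypothetical variable name; keep credentials out of source control
client.connect('10.82.3.180', port=20083, username='root', password=os.environ['SSH_PASSWORD'], timeout=30)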
|
||||
sftp = client.open_sftp()
|
||||
|
||||
ZSY = '/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy'
|
||||
PROJ = ZSY + '/multimodal_affect'
|
||||
|
||||
# ─── scripts/train/train_d1.py ────────────────────────────────────────────
|
||||
TRAIN_D1 = '''\
|
||||
#!/usr/bin/env python3
|
||||
# Phase 1 Direction 1 Training Script
|
||||
# Stage A: Supervised pretraining with noise-aware confidence estimation
|
||||
# Stage B: PPO-based adaptive fusion weight learning
|
||||
#
|
||||
# Launch:
|
||||
# torchrun --nproc_per_node=4 scripts/train/train_d1.py \\
|
||||
# --stage supervised --dataset IEMOCAP \\
|
||||
# --config configs/d1/stage_a.yaml \\
|
||||
# --output outputs/checkpoints/d1_stageA
|
||||
|
||||
import os, sys, argparse, yaml, time, logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from sklearn.metrics import f1_score, accuracy_score
|
||||
import wandb
|
||||
|
||||
ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
|
||||
PROJ = os.path.join(ZSY, "multimodal_affect")
|
||||
sys.path.insert(0, PROJ)
|
||||
|
||||
from src.data.dataset import MultimodalDataset, get_dataloader
|
||||
from src.models.encoders import MultimodalEncoder
|
||||
from src.models.classifier import EmotionClassifier
|
||||
from src.rl.fusion_agent import ModalFusionAgent
|
||||
from src.rl.reward import compute_reward
|
||||
|
||||
|
||||
# ── Distributed helpers ───────────────────────────────────────────────────
|
||||
|
||||
def setup_ddp():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    return local_rank, dist.get_rank(), dist.get_world_size()


def cleanup():
    dist.destroy_process_group()


def is_main(rank):
    return rank == 0


def all_reduce_mean(val, device):
    t = torch.tensor(float(val), device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / dist.get_world_size()).item()
|
||||
|
||||
def save_ckpt(state, path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save(state, path)
|
||||
|
||||
|
||||
def _noisy_batch(dataset, variant, indices, device):
|
||||
text = variant.get("text", dataset.text)
|
||||
audio = variant.get("audio", dataset.audio)
|
||||
vision = variant.get("vision", dataset.vision)
|
||||
return (
|
||||
torch.from_numpy(text[indices]).to(device),
|
||||
torch.from_numpy(audio[indices]).to(device),
|
||||
torch.from_numpy(vision[indices]).to(device),
|
||||
torch.from_numpy(dataset.labels[indices]).to(device),
|
||||
)
|
||||
|
||||
|
||||
def _confidence_targets(variant_name, batch_size, device):
|
||||
target = torch.full((batch_size, 3), 0.9, device=device)
|
||||
noisy_map = {
|
||||
"gaussian_light": (0, 1, 2),
|
||||
"gaussian_heavy": (0, 1, 2),
|
||||
"missing_audio": (1,),
|
||||
"missing_visual": (2,),
|
||||
"text_word_drop_30": (0,),
|
||||
"audio_masking_50": (1,),
|
||||
"realistic_mixed": (0, 1, 2),
|
||||
"audio_time_mask": (1,),
|
||||
}
|
||||
for idx in noisy_map.get(str(variant_name), (0, 1, 2)):
|
||||
target[:, idx] = 0.1
|
||||
return target
|
||||
|
||||
|
||||
# ── Evaluation ────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(encoder, classifier, loader, device, agent=None):
|
||||
encoder.eval()
|
||||
classifier.eval()
|
||||
if agent is not None:
|
||||
agent.eval()
|
||||
all_preds, all_labels = [], []
|
||||
for batch in loader:
|
||||
text = batch["text"].to(device)
|
||||
audio = batch["audio"].to(device)
|
||||
vision = batch["vision"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
if agent is not None:
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
weights, *_ = agent.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
else:
|
||||
fused = (tf + af + vf) / 3.0
|
||||
logits = classifier(fused)
|
||||
all_preds.append(logits.argmax(-1).cpu())
|
||||
all_labels.append(labels.cpu())
|
||||
preds = torch.cat(all_preds).numpy()
|
||||
labels = torch.cat(all_labels).numpy()
|
||||
wf1 = float(f1_score(labels, preds, average="weighted", zero_division=0))
|
||||
acc = float(accuracy_score(labels, preds))
|
||||
return wf1, acc
|
||||
|
||||
|
||||
# ── Stage A: Supervised pretraining ──────────────────────────────────────
|
||||
|
||||
def train_stage_a(args, cfg, local_rank, rank, world_size):
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
rng = np.random.default_rng(42 + rank)
|
||||
|
||||
data_dir = os.path.join(PROJ, "data", args.dataset.lower())
|
||||
noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
|
||||
|
||||
train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
|
||||
noise_root=noise_root)
|
||||
val_ds = MultimodalDataset(data_dir, "val")
|
||||
|
||||
train_loader = get_dataloader(train_ds, cfg["batch_size"], distributed=True)
|
||||
val_loader = get_dataloader(val_ds, cfg["batch_size"], shuffle=False,
|
||||
distributed=True, drop_last=False)
|
||||
|
||||
text_dim = train_ds.text.shape[1]
|
||||
audio_dim = train_ds.audio.shape[1]
|
||||
vision_dim = train_ds.vision.shape[1]
|
||||
num_classes = int(train_ds.labels.max()) + 1
|
||||
proj_dim = cfg.get("proj_dim", 1024)
|
||||
|
||||
encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim).to(device)
|
||||
classifier = EmotionClassifier(proj_dim, num_classes,
|
||||
hidden=cfg.get("cls_hidden", 512)).to(device)
|
||||
encoder = DDP(encoder, device_ids=[local_rank])
|
||||
classifier = DDP(classifier, device_ids=[local_rank])
|
||||
|
||||
params = list(encoder.parameters()) + list(classifier.parameters())
|
||||
opt = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg.get("wd", 1e-4))
|
||||
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
opt, T_max=cfg["epochs"], eta_min=1e-5)
|
||||
scaler = GradScaler()
|
||||
|
||||
conf_weight = cfg.get("conf_weight", 0.2)
|
||||
noise_prob = cfg.get("noise_prob", 0.4)
|
||||
best_wf1 = 0.0
|
||||
|
||||
for epoch in range(cfg["epochs"]):
|
||||
train_loader.sampler.set_epoch(epoch)
|
||||
encoder.train()
|
||||
classifier.train()
|
||||
ep_loss = ep_ce = ep_conf = 0.0
|
||||
|
||||
for batch in train_loader:
|
||||
text = batch["text"].to(device)
|
||||
audio = batch["audio"].to(device)
|
||||
vision = batch["vision"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
B = text.size(0)
|
||||
|
||||
# Noise injection: randomly replace with noisy variant
|
||||
use_noise = (rng.random() < noise_prob) and bool(train_ds.variant_names)
|
||||
if use_noise:
|
||||
vname = rng.choice(train_ds.variant_names)
|
||||
v = train_ds.noisy_variants[vname]
|
||||
ni = rng.integers(0, len(train_ds), size=B)
|
||||
text, audio, vision, labels = _noisy_batch(train_ds, v, ni, device)
|
||||
|
||||
with autocast():
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
fused = (tf + af + vf) / 3.0
|
||||
logits = classifier(fused)
|
||||
ce_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
# Confidence target: noisy modalities -> 0.1, clean -> 0.9
|
||||
if use_noise:
|
||||
c_tgt = _confidence_targets(vname, B, device)
|
||||
else:
|
||||
c_tgt = torch.full((B, 3), 0.9, device=device)
|
||||
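# binary_cross_entropy raises inside autocast regions, so compute it in float32 with autocast disabled
with autocast(enabled=False): conf_loss = F.binary_cross_entropy(confs.float(), c_tgt.float())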
|
||||
loss = ce_loss + conf_weight * conf_loss
|
||||
|
||||
opt.zero_grad(set_to_none=True)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(opt)
|
||||
nn.utils.clip_grad_norm_(params, 1.0)
|
||||
scaler.step(opt)
|
||||
scaler.update()
|
||||
|
||||
ep_loss += loss.item()
|
||||
ep_ce += ce_loss.item()
|
||||
ep_conf += conf_loss.item()
|
||||
|
||||
sched.step()
|
||||
val_wf1, val_acc = evaluate(encoder.module, classifier.module,
|
||||
val_loader, device)
|
||||
val_wf1 = all_reduce_mean(val_wf1, device)
|
||||
|
||||
if is_main(rank):
|
||||
n = len(train_loader)
|
||||
logging.info(
|
||||
f"[StageA] Epoch {epoch+1:3d}/{cfg['epochs']} | "
|
||||
f"loss={ep_loss/n:.4f} ce={ep_ce/n:.4f} conf={ep_conf/n:.4f} | "
|
||||
f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
|
||||
)
|
||||
wandb.log({"A/loss": ep_loss/n, "A/ce": ep_ce/n,
|
||||
"A/conf": ep_conf/n, "A/val_wf1": val_wf1,
|
||||
"A/val_acc": val_acc, "epoch": epoch + 1})
|
||||
|
||||
if val_wf1 > best_wf1:
|
||||
best_wf1 = val_wf1
|
||||
save_ckpt({
|
||||
"epoch": epoch + 1,
|
||||
"encoder": encoder.module.state_dict(),
|
||||
"classifier": classifier.module.state_dict(),
|
||||
"val_wf1": val_wf1,
|
||||
"text_dim": text_dim, "audio_dim": audio_dim,
|
||||
"vision_dim": vision_dim, "num_classes": num_classes,
|
||||
"proj_dim": proj_dim, "cfg": cfg,
|
||||
}, os.path.join(args.output, "best.ckpt"))
|
||||
logging.info(f" -> New best WF1: {val_wf1:.4f}")
|
||||
|
||||
if is_main(rank):
|
||||
logging.info(f"Stage A done. Best val WF1: {best_wf1:.4f}")
|
||||
save_ckpt({
|
||||
"epoch": cfg["epochs"],
|
||||
"encoder": encoder.module.state_dict(),
|
||||
"classifier": classifier.module.state_dict(),
|
||||
"text_dim": text_dim, "audio_dim": audio_dim,
|
||||
"vision_dim": vision_dim, "num_classes": num_classes,
|
||||
"proj_dim": proj_dim, "cfg": cfg,
|
||||
}, os.path.join(args.output, "last.ckpt"))
|
||||
|
||||
dims = dict(text_dim=text_dim, audio_dim=audio_dim,
|
||||
vision_dim=vision_dim, num_classes=num_classes, proj_dim=proj_dim)
|
||||
return encoder.module, classifier.module, dims
|
||||
|
||||
|
||||
# ── Stage B: PPO training ─────────────────────────────────────────────────

def collect_rollout(encoder, classifier, agent, dataset, device, rollout_size, cfg, prev_weights):
    encoder.eval()
    classifier.eval()
    agent.eval()
    bs = cfg.get("batch_size", 128)
    nprob = cfg.get("noise_prob", 0.5)
    rng = np.random.default_rng()
    states, actions, log_probs, values, rewards = [], [], [], [], []
    collected = 0

    with torch.no_grad():
        while collected < rollout_size:
            bsz = min(bs, rollout_size - collected)
            idx = rng.integers(0, len(dataset), size=bsz)

            text = torch.from_numpy(dataset.text[idx]).to(device)
            audio = torch.from_numpy(dataset.audio[idx]).to(device)
            vision = torch.from_numpy(dataset.vision[idx]).to(device)
            labels = torch.from_numpy(dataset.labels[idx]).to(device)

            # With probability nprob, swap the clean batch for a noisy variant
            if rng.random() < nprob and dataset.variant_names:
                vname = rng.choice(dataset.variant_names)
                v = dataset.noisy_variants[vname]
                text, audio, vision, labels = _noisy_batch(dataset, v, idx, device)

            tf, af, vf, confs = encoder(text, audio, vision)
            noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
            state = torch.cat([confs, noise_est], dim=-1)  # (B, 4)

            # The action is a set of per-modality fusion weights
            weights, log_p, value, _ = agent.get_action_and_value(state)
            fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
            logits = classifier(fused)

            rew, _ = compute_reward(
                logits, labels, confs, weights, prev_weights,
                alpha=cfg.get("reward_alpha", 1.0),
                beta=cfg.get("reward_beta", 0.3),
                gamma=cfg.get("reward_gamma", 0.1),
            )

            states.append(state)
            actions.append(weights)
            log_probs.append(log_p)
            values.append(value.squeeze(-1))
            rewards.append(rew)
            collected += bsz

    states = torch.cat(states)
    actions = torch.cat(actions)
    log_probs = torch.cat(log_probs)
    values = torch.cat(values)
    rewards = torch.cat(rewards)

    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    # Keep both operands on the same device; a bare .cpu() here fails
    # whenever compute_reward returns CUDA tensors.
    advantages = rewards - values.detach().to(rewards.device)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    return dict(states=states, actions=actions, log_probs=log_probs,
                values=values, rewards=rewards, advantages=advantages,
                mean_weights=actions.mean(0))


def ppo_update(agent, opt, rollout, cfg, device, scaler):
    eps = cfg.get("ppo_clip", 0.2)
    ppo_ep = cfg.get("ppo_epochs_per_update", 4)
    mb_size = cfg.get("ppo_mini_batch", 256)
    v_coef = cfg.get("value_coef", 0.5)
    ent_coef = cfg.get("entropy_coef", 0.01)

    states = rollout["states"].to(device)
    actions = rollout["actions"].to(device)
    old_lp = rollout["log_probs"].to(device)
    adv = rollout["advantages"].to(device)
    ret = rollout["rewards"].to(device)
    n = states.size(0)

    total_pl = total_vl = total_ent = cnt = 0.0
    agent.train()

    # DDP only proxies forward(); custom methods such as evaluate() live on
    # the wrapped module. Note that calling it directly bypasses DDP.forward(),
    # so gradients are not all-reduced across ranks.
    policy = agent.module if hasattr(agent, "module") else agent

    for _ in range(ppo_ep):
        perm = torch.randperm(n, device=device)
        for start in range(0, n, mb_size):
            idx = perm[start:start + mb_size]
            s = states[idx]; a = actions[idx]
            olp = old_lp[idx]; ad = adv[idx]; r = ret[idx]

            with autocast():
                new_lp, val, ent = policy.evaluate(s, a)
                val = val.squeeze(-1)
                ratio = (new_lp - olp).exp()
                # Clipped surrogate objective
                p_loss = -torch.min(ratio * ad,
                                    torch.clamp(ratio, 1-eps, 1+eps) * ad).mean()
                v_loss = F.mse_loss(val, r)
                e_loss = -ent.mean()
                loss = p_loss + v_coef * v_loss + ent_coef * e_loss

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            scaler.step(opt)
            scaler.update()

            total_pl += p_loss.item()
            total_vl += v_loss.item()
            total_ent += ent.mean().item()
            cnt += 1

    return dict(p_loss=total_pl/cnt, v_loss=total_vl/cnt, entropy=total_ent/cnt)


def train_stage_b(args, cfg, encoder, classifier, dims, local_rank, rank, world_size):
    device = torch.device(f"cuda:{local_rank}")

    data_dir = os.path.join(PROJ, "data", args.dataset.lower())
    noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
                                 noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")
    val_loader = get_dataloader(val_ds, cfg.get("batch_size", 128),
                                shuffle=False, distributed=True, drop_last=False)

    # Freeze encoder (projectors + confidence estimators)
    for p in encoder.parameters():
        p.requires_grad_(False)
    encoder.to(device).eval()

    # Classifier: keep trainable (supervised component)
    classifier.to(device)
    cls_ddp = DDP(classifier, device_ids=[local_rank])
    opt_cls = torch.optim.AdamW(classifier.parameters(),
                                lr=cfg.get("cls_lr", 5e-5), weight_decay=1e-4)

    # RL agent
    agent = ModalFusionAgent(state_dim=4,
                             hidden=cfg.get("agent_hidden", 128)).to(device)
    agent = DDP(agent, device_ids=[local_rank])
    opt_agent = torch.optim.Adam(agent.parameters(), lr=cfg.get("rl_lr", 3e-4))

    scaler = GradScaler()

    rollout_size = cfg.get("rollout_steps", 512)
    n_updates = cfg.get("n_ppo_updates", 500)
    eval_every = cfg.get("eval_every", 10)
    best_wf1 = 0.0
    prev_weights = None

    for upd in range(n_updates):
        # Rollout collection
        rollout = collect_rollout(
            encoder, classifier, agent.module,
            train_ds, device, rollout_size, cfg, prev_weights,
        )
        prev_weights = rollout["mean_weights"].to(device)

        # PPO update
        ppo_info = ppo_update(agent, opt_agent, rollout, cfg, device, scaler)

        # Lightweight supervised classifier refresh (one mini-batch every 2 updates)
        if upd % 2 == 0:
            idx = np.random.randint(0, len(train_ds), cfg.get("batch_size", 128))
            text = torch.from_numpy(train_ds.text[idx]).to(device)
            audio = torch.from_numpy(train_ds.audio[idx]).to(device)
            vision = torch.from_numpy(train_ds.vision[idx]).to(device)
            labels = torch.from_numpy(train_ds.labels[idx]).to(device)
            with torch.no_grad():
                tf, af, vf, confs = encoder(text, audio, vision)
                noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
                state = torch.cat([confs, noise_est], dim=-1)
                weights, *_ = agent.module.get_action_and_value(state)
                fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
            with autocast():
                logits = cls_ddp(fused)
                loss = F.cross_entropy(logits, labels)
            opt_cls.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt_cls)
            scaler.update()

        # Evaluate
        if upd % eval_every == 0:
            val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device,
                                        agent=agent.module)
            val_wf1 = all_reduce_mean(val_wf1, device)

            if is_main(rank):
                mean_rew = rollout["rewards"].mean().item()
                logging.info(
                    f"[StageB] PPO {upd:4d}/{n_updates} | "
                    f"rew={mean_rew:.4f} p={ppo_info['p_loss']:.4f} "
                    f"v={ppo_info['v_loss']:.4f} ent={ppo_info['entropy']:.4f} | "
                    f"val_wf1={val_wf1:.4f}"
                )
                wandb.log({
                    "B/reward": mean_rew,
                    "B/p_loss": ppo_info["p_loss"],
                    "B/v_loss": ppo_info["v_loss"],
                    "B/entropy": ppo_info["entropy"],
                    "B/val_wf1": val_wf1,
                    "ppo_update": upd,
                })

                if val_wf1 > best_wf1:
                    best_wf1 = val_wf1
                    save_ckpt({
                        "update": upd,
                        "encoder": encoder.state_dict(),
                        "classifier": classifier.state_dict(),
                        "agent": agent.module.state_dict(),
                        "val_wf1": val_wf1,
                        **dims,
                    }, os.path.join(args.output, "best.ckpt"))
                    logging.info(f" -> New best WF1: {val_wf1:.4f}")

    if is_main(rank):
        logging.info(f"Stage B done. Best val WF1: {best_wf1:.4f}")


# ── Main ──────────────────────────────────────────────────────────────────

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--stage", required=True, choices=["supervised", "rl"])
    p.add_argument("--dataset", default="IEMOCAP")
    p.add_argument("--config", required=True)
    p.add_argument("--output", required=True)
    p.add_argument("--checkpoint", default=None)
    return p.parse_args()


def main():
    args = parse_args()
    local_rank, rank, world_size = setup_ddp()
    device = torch.device(f"cuda:{local_rank}")

    os.makedirs(args.output, exist_ok=True)
    log_dir = os.path.join(PROJ, "outputs", "logs")
    os.makedirs(log_dir, exist_ok=True)

    stage_tag = "stageA" if args.stage == "supervised" else "stageB"
    handlers = [logging.StreamHandler()]
    if is_main(rank):
        handlers.append(logging.FileHandler(
            os.path.join(log_dir, f"{stage_tag}.log"), mode="a"))
    logging.basicConfig(
        level=logging.INFO,
        format=f"[rank{rank}] %(asctime)s %(message)s",
        handlers=handlers,
    )

    with open(os.path.join(PROJ, args.config)) as f:
        cfg = yaml.safe_load(f)

    if is_main(rank):
        os.environ.setdefault("WANDB_MODE", "offline")
        wandb.init(
            project="multimodal_affect",
            name=f"d1_{args.stage}_{args.dataset}_{time.strftime('%m%d_%H%M')}",
            config={**cfg, "stage": args.stage, "dataset": args.dataset},
            dir=os.path.join(PROJ, "outputs"),
        )

    if args.stage == "supervised":
        train_stage_a(args, cfg, local_rank, rank, world_size)

    elif args.stage == "rl":
        if not args.checkpoint:
            raise ValueError("--checkpoint required for --stage rl")
        ckpt = torch.load(args.checkpoint, map_location=device)

        text_dim = ckpt["text_dim"]
        audio_dim = ckpt["audio_dim"]
        vision_dim = ckpt["vision_dim"]
        num_classes = ckpt["num_classes"]
        proj_dim = ckpt.get("proj_dim", 1024)
        dims = dict(text_dim=text_dim, audio_dim=audio_dim,
                    vision_dim=vision_dim, num_classes=num_classes,
                    proj_dim=proj_dim)

        encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim)
        classifier = EmotionClassifier(proj_dim, num_classes)
        encoder.load_state_dict(ckpt["encoder"])
        classifier.load_state_dict(ckpt["classifier"])

        if is_main(rank):
            logging.info(
                f"Loaded Stage A ckpt from {args.checkpoint} "
                f"(val_wf1={ckpt.get('val_wf1', 0.0):.4f})"
            )
        train_stage_b(args, cfg, encoder, classifier, dims,
                      local_rank, rank, world_size)

    if is_main(rank):
        wandb.finish()
    cleanup()


if __name__ == "__main__":
    main()
'''

# ─── configs/d1/stage_a.yaml ──────────────────────────────────────────────
STAGE_A_YAML = '''\
# Stage A: Supervised pretraining
# Trains projection MLPs + confidence estimators + classifier
# with noise injection to teach confidence estimation

epochs: 50
batch_size: 128
lr: 2.0e-4
wd: 1.0e-4
proj_dim: 1024
cls_hidden: 512
conf_weight: 0.2  # BCE loss weight for confidence estimators
noise_prob: 0.4   # probability of injecting noisy batch
'''

# ─── configs/d1/stage_b.yaml ──────────────────────────────────────────────
STAGE_B_YAML = '''\
# Stage B: PPO-based adaptive fusion weight learning
# Encoder (projectors + confidence estimators) frozen from Stage A
# RL agent learns noise-adaptive fusion weights via PPO

batch_size: 128
proj_dim: 1024

# PPO
rollout_steps: 512  # experiences collected per PPO update
n_ppo_updates: 500  # total PPO update iterations
ppo_clip: 0.2
ppo_epochs_per_update: 4
ppo_mini_batch: 256
value_coef: 0.5
entropy_coef: 0.01
rl_lr: 3.0e-4
cls_lr: 5.0e-5

# Reward coefficients (R = alpha*(-CE) + beta*Consistency - gamma*Instability)
reward_alpha: 1.0
reward_beta: 0.3
reward_gamma: 0.1

# RL agent architecture
agent_hidden: 128

# Noise injection during rollout collection
noise_prob: 0.5

eval_every: 10  # evaluate every N PPO updates
'''

# Upload
uploads = {
    f"{PROJ}/scripts/train/train_d1.py": TRAIN_D1,
    f"{PROJ}/configs/d1/stage_a.yaml": STAGE_A_YAML,
    f"{PROJ}/configs/d1/stage_b.yaml": STAGE_B_YAML,
}

for path, content in uploads.items():
    with sftp.open(path, 'w') as f:
        f.write(content)
    print(f" uploaded: {path.split('multimodal_affect/')[-1]}")

sftp.close()
client.close()
print("\nAll training files uploaded.")
Binary file not shown.
BIN
旧方向信息/执行摘要.pdf
Binary file not shown.