chore: initial commit — unified project repo

Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-14 11:28:42 +08:00
commit bd1f51c496
85 changed files with 20568 additions and 0 deletions

View File

@@ -0,0 +1,578 @@
"""Upload revised train_d1.py using DataParallel (4-GPU, no DDP/NCCL needed)."""
import paramiko, warnings
warnings.filterwarnings("ignore")
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect('10.82.3.180', port=20083, username='root', password='m2dGcwyrhI', timeout=30)
sftp = client.open_sftp()
ZSY = '/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy'
PROJ = ZSY + '/multimodal_affect'
TRAIN_D1 = '''\
#!/usr/bin/env python3
# Phase 1 Direction 1 Training Script (DataParallel edition)
# Stage A: Supervised pretraining with noise-aware confidence estimation
# Stage B: PPO-based adaptive fusion weight learning
#
# Uses nn.DataParallel (4 GPUs, single process, no NCCL needed)
#
# Launch:
# python scripts/train/train_d1.py \\
# --stage supervised --dataset IEMOCAP \\
# --config configs/d1/stage_a.yaml \\
# --output outputs/checkpoints/d1_stageA
import os, sys, argparse, yaml, time, logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import f1_score, accuracy_score
import wandb
ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
PROJ = os.path.join(ZSY, "multimodal_affect")
sys.path.insert(0, PROJ)
from src.data.dataset import MultimodalDataset, get_dataloader
from src.models.encoders import MultimodalEncoder
from src.models.classifier import EmotionClassifier
from src.rl.fusion_agent import ModalFusionAgent
from src.rl.reward import compute_reward
def save_ckpt(state, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(state, path)
def _noisy_batch(dataset, variant, indices, device):
text = variant.get("text", dataset.text)
audio = variant.get("audio", dataset.audio)
vision = variant.get("vision", dataset.vision)
return (
torch.from_numpy(text[indices]).to(device),
torch.from_numpy(audio[indices]).to(device),
torch.from_numpy(vision[indices]).to(device),
torch.from_numpy(dataset.labels[indices]).to(device),
)
def _confidence_targets(variant_name, batch_size, device):
target = torch.full((batch_size, 3), 0.9, device=device)
noisy_map = {
"gaussian_light": (0, 1, 2),
"gaussian_heavy": (0, 1, 2),
"missing_audio": (1,),
"missing_visual": (2,),
"text_word_drop_30": (0,),
"audio_masking_50": (1,),
"realistic_mixed": (0, 1, 2),
"audio_time_mask": (1,),
}
for idx in noisy_map.get(str(variant_name), (0, 1, 2)):
target[:, idx] = 0.1
return target
# ── Evaluation ────────────────────────────────────────────────────────────
@torch.no_grad()
def evaluate(encoder, classifier, loader, device, agent=None):
encoder.eval()
classifier.eval()
if agent is not None:
agent.eval()
all_preds, all_labels = [], []
for batch in loader:
text = batch["text"].to(device)
audio = batch["audio"].to(device)
vision = batch["vision"].to(device)
labels = batch["labels"].to(device)
# DataParallel: call module directly for eval (avoids scatter overhead)
enc = encoder.module if hasattr(encoder, "module") else encoder
cls = classifier.module if hasattr(classifier, "module") else classifier
agt = (agent.module if hasattr(agent, "module") else agent) if agent else None
tf, af, vf, confs = enc(text, audio, vision)
if agt is not None:
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
state = torch.cat([confs, noise_est], dim=-1)
weights, *_ = agt.get_action_and_value(state)
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
else:
fused = (tf + af + vf) / 3.0
logits = cls(fused)
all_preds.append(logits.argmax(-1).cpu())
all_labels.append(labels.cpu())
preds = torch.cat(all_preds).numpy()
labels = torch.cat(all_labels).numpy()
wf1 = float(f1_score(labels, preds, average="weighted", zero_division=0))
acc = float(accuracy_score(labels, preds))
return wf1, acc
# ── Stage A: Supervised pretraining ──────────────────────────────────────
def train_stage_a(args, cfg, device, gpu_ids):
rng = np.random.default_rng(42)
data_dir = os.path.join(PROJ, "data", args.dataset.lower())
noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
noise_root=noise_root)
val_ds = MultimodalDataset(data_dir, "val")
# Increase batch_size proportional to # GPUs for DataParallel
eff_bs = cfg["batch_size"] * len(gpu_ids)
train_loader = get_dataloader(train_ds, eff_bs, distributed=False)
val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
distributed=False, drop_last=False)
text_dim = train_ds.text.shape[1]
audio_dim = train_ds.audio.shape[1]
vision_dim = train_ds.vision.shape[1]
num_classes = int(train_ds.labels.max()) + 1
proj_dim = cfg.get("proj_dim", 1024)
encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim).to(device)
classifier = EmotionClassifier(proj_dim, num_classes,
hidden=cfg.get("cls_hidden", 512)).to(device)
if len(gpu_ids) > 1:
encoder = nn.DataParallel(encoder, device_ids=gpu_ids)
classifier = nn.DataParallel(classifier, device_ids=gpu_ids)
params = list(encoder.parameters()) + list(classifier.parameters())
opt = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg.get("wd", 1e-4))
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
opt, T_max=cfg["epochs"], eta_min=1e-5)
scaler = GradScaler()
conf_weight = cfg.get("conf_weight", 0.2)
noise_prob = cfg.get("noise_prob", 0.4)
best_wf1 = 0.0
for epoch in range(cfg["epochs"]):
encoder.train()
classifier.train()
ep_loss = ep_ce = ep_conf = 0.0
for batch in train_loader:
text = batch["text"].to(device)
audio = batch["audio"].to(device)
vision = batch["vision"].to(device)
labels = batch["labels"].to(device)
B = text.size(0)
# Noise injection
use_noise = (rng.random() < noise_prob) and bool(train_ds.variant_names)
if use_noise:
vname = rng.choice(train_ds.variant_names)
v = train_ds.noisy_variants[vname]
ni = rng.integers(0, len(train_ds), size=B)
text, audio, vision, labels = _noisy_batch(train_ds, v, ni, device)
with autocast():
tf, af, vf, confs = encoder(text, audio, vision)
fused = (tf + af + vf) / 3.0
logits = classifier(fused)
ce_loss = F.cross_entropy(logits, labels)
if use_noise:
c_tgt = _confidence_targets(vname, B, device)
else:
c_tgt = torch.full((B, 3), 0.9, device=device)
conf_loss = F.binary_cross_entropy(confs, c_tgt)
loss = ce_loss + conf_weight * conf_loss
opt.zero_grad(set_to_none=True)
scaler.scale(loss).backward()
scaler.unscale_(opt)
nn.utils.clip_grad_norm_(params, 1.0)
scaler.step(opt)
scaler.update()
ep_loss += loss.item()
ep_ce += ce_loss.item()
ep_conf += conf_loss.item()
sched.step()
val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device)
n = len(train_loader)
logging.info(
f"[StageA] Epoch {epoch+1:3d}/{cfg['epochs']} | "
f"loss={ep_loss/n:.4f} ce={ep_ce/n:.4f} conf={ep_conf/n:.4f} | "
f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
)
wandb.log({"A/loss": ep_loss/n, "A/ce": ep_ce/n,
"A/conf": ep_conf/n, "A/val_wf1": val_wf1,
"A/val_acc": val_acc, "epoch": epoch + 1})
enc_state = encoder.module.state_dict() if hasattr(encoder, "module") else encoder.state_dict()
cls_state = classifier.module.state_dict() if hasattr(classifier, "module") else classifier.state_dict()
if val_wf1 > best_wf1:
best_wf1 = val_wf1
save_ckpt({
"epoch": epoch + 1,
"encoder": enc_state,
"classifier": cls_state,
"val_wf1": val_wf1,
"text_dim": text_dim, "audio_dim": audio_dim,
"vision_dim": vision_dim, "num_classes": num_classes,
"proj_dim": proj_dim, "cfg": cfg,
}, os.path.join(args.output, "best.ckpt"))
logging.info(f" -> New best WF1: {val_wf1:.4f}")
logging.info(f"Stage A done. Best val WF1: {best_wf1:.4f}")
save_ckpt({
"epoch": cfg["epochs"],
"encoder": enc_state,
"classifier": cls_state,
"text_dim": text_dim, "audio_dim": audio_dim,
"vision_dim": vision_dim, "num_classes": num_classes,
"proj_dim": proj_dim, "cfg": cfg,
}, os.path.join(args.output, "last.ckpt"))
enc_m = encoder.module if hasattr(encoder, "module") else encoder
cls_m = classifier.module if hasattr(classifier, "module") else classifier
dims = dict(text_dim=text_dim, audio_dim=audio_dim,
vision_dim=vision_dim, num_classes=num_classes, proj_dim=proj_dim)
return enc_m, cls_m, dims
# ── Stage B: PPO training ─────────────────────────────────────────────────
def collect_rollout(encoder, classifier, agent, dataset, device, rollout_size, cfg, prev_weights):
encoder.eval()
classifier.eval()
agent.eval()
bs = cfg.get("batch_size", 128)
nprob = cfg.get("noise_prob", 0.5)
rng = np.random.default_rng()
states, actions, log_probs, values, rewards = [], [], [], [], []
collected = 0
with torch.no_grad():
while collected < rollout_size:
bsz = min(bs, rollout_size - collected)
idx = rng.integers(0, len(dataset), size=bsz)
text = torch.from_numpy(dataset.text[idx]).to(device)
audio = torch.from_numpy(dataset.audio[idx]).to(device)
vision = torch.from_numpy(dataset.vision[idx]).to(device)
labels = torch.from_numpy(dataset.labels[idx]).to(device)
if rng.random() < nprob and dataset.variant_names:
vname = rng.choice(dataset.variant_names)
v = dataset.noisy_variants[vname]
text, audio, vision, labels = _noisy_batch(dataset, v, idx, device)
tf, af, vf, confs = encoder(text, audio, vision)
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
state = torch.cat([confs, noise_est], dim=-1)
weights, log_p, value, _ = agent.get_action_and_value(state)
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
logits = classifier(fused)
rew, _ = compute_reward(
logits, labels, confs, weights, prev_weights,
alpha=cfg.get("reward_alpha", 1.0),
beta =cfg.get("reward_beta", 0.3),
gamma=cfg.get("reward_gamma", 0.1),
)
states.append(state)
actions.append(weights)
log_probs.append(log_p)
values.append(value.squeeze(-1))
rewards.append(rew)
collected += bsz
states = torch.cat(states)
actions = torch.cat(actions)
log_probs = torch.cat(log_probs)
values = torch.cat(values)
rewards = torch.cat(rewards)
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
advantages = rewards - values.detach().cpu()
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
return dict(states=states, actions=actions, log_probs=log_probs,
values=values, rewards=rewards, advantages=advantages,
mean_weights=actions.mean(0))
def ppo_update(agent, opt, rollout, cfg, device, scaler):
eps = cfg.get("ppo_clip", 0.2)
ppo_ep = cfg.get("ppo_epochs_per_update", 4)
mb_size = cfg.get("ppo_mini_batch", 256)
v_coef = cfg.get("value_coef", 0.5)
ent_coef = cfg.get("entropy_coef", 0.01)
states = rollout["states"].to(device)
actions = rollout["actions"].to(device)
old_lp = rollout["log_probs"].to(device)
adv = rollout["advantages"].to(device)
ret = rollout["rewards"].to(device)
n = states.size(0)
total_pl = total_vl = total_ent = cnt = 0.0
agent.train()
for _ in range(ppo_ep):
perm = torch.randperm(n, device=device)
for start in range(0, n, mb_size):
idx = perm[start:start + mb_size]
s = states[idx]; a = actions[idx]
olp = old_lp[idx]; ad = adv[idx]; r = ret[idx]
with autocast():
new_lp, val, ent = agent.evaluate(s, a)
val = val.squeeze(-1)
ratio = (new_lp - olp).exp()
p_loss = -torch.min(ratio*ad,
torch.clamp(ratio, 1-eps, 1+eps)*ad).mean()
v_loss = F.mse_loss(val, r)
e_loss = -ent.mean()
loss = p_loss + v_coef*v_loss + ent_coef*e_loss
opt.zero_grad(set_to_none=True)
scaler.scale(loss).backward()
scaler.unscale_(opt)
nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
scaler.step(opt)
scaler.update()
total_pl += p_loss.item()
total_vl += v_loss.item()
total_ent += ent.mean().item()
cnt += 1
return dict(p_loss=total_pl/cnt, v_loss=total_vl/cnt, entropy=total_ent/cnt)
def train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids):
data_dir = os.path.join(PROJ, "data", args.dataset.lower())
noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
noise_root=noise_root)
val_ds = MultimodalDataset(data_dir, "val")
eff_bs = cfg.get("batch_size", 128) * len(gpu_ids)
val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
distributed=False, drop_last=False)
# Freeze encoder
for p in encoder.parameters():
p.requires_grad_(False)
encoder.to(device).eval()
# Classifier: keep trainable
classifier.to(device)
opt_cls = torch.optim.AdamW(classifier.parameters(),
lr=cfg.get("cls_lr", 5e-5), weight_decay=1e-4)
# RL agent (small, DataParallel not needed for 4-dim input)
agent = ModalFusionAgent(state_dim=4,
hidden=cfg.get("agent_hidden", 128)).to(device)
opt_agent = torch.optim.Adam(agent.parameters(), lr=cfg.get("rl_lr", 3e-4))
scaler = GradScaler()
rollout_size = cfg.get("rollout_steps", 512)
n_updates = cfg.get("n_ppo_updates", 500)
eval_every = cfg.get("eval_every", 10)
best_wf1 = 0.0
prev_weights = None
for upd in range(n_updates):
rollout = collect_rollout(
encoder, classifier, agent,
train_ds, device, rollout_size, cfg, prev_weights,
)
prev_weights = rollout["mean_weights"].to(device)
ppo_info = ppo_update(agent, opt_agent, rollout, cfg, device, scaler)
# Classifier supervised refresh
if upd % 2 == 0:
idx = np.random.randint(0, len(train_ds), eff_bs)
text = torch.from_numpy(train_ds.text[idx]).to(device)
audio = torch.from_numpy(train_ds.audio[idx]).to(device)
vision = torch.from_numpy(train_ds.vision[idx]).to(device)
labels = torch.from_numpy(train_ds.labels[idx]).to(device)
with torch.no_grad():
tf, af, vf, confs = encoder(text, audio, vision)
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
state = torch.cat([confs, noise_est], dim=-1)
weights, *_ = agent.get_action_and_value(state)
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
with autocast():
logits = classifier(fused)
loss = F.cross_entropy(logits, labels)
opt_cls.zero_grad(set_to_none=True)
scaler.scale(loss).backward()
scaler.step(opt_cls)
scaler.update()
if upd % eval_every == 0:
val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device,
agent=agent)
mean_rew = rollout["rewards"].mean().item()
logging.info(
f"[StageB] PPO {upd:4d}/{n_updates} | "
f"rew={mean_rew:.4f} p={ppo_info['p_loss']:.4f} "
f"v={ppo_info['v_loss']:.4f} ent={ppo_info['entropy']:.4f} | "
f"val_wf1={val_wf1:.4f}"
)
wandb.log({
"B/reward": mean_rew,
"B/p_loss": ppo_info["p_loss"],
"B/v_loss": ppo_info["v_loss"],
"B/entropy": ppo_info["entropy"],
"B/val_wf1": val_wf1,
"ppo_update": upd,
})
if val_wf1 > best_wf1:
best_wf1 = val_wf1
save_ckpt({
"update": upd,
"encoder": encoder.state_dict(),
"classifier": classifier.state_dict(),
"agent": agent.state_dict(),
"val_wf1": val_wf1,
**dims,
}, os.path.join(args.output, "best.ckpt"))
logging.info(f" -> New best WF1: {val_wf1:.4f}")
logging.info(f"Stage B done. Best val WF1: {best_wf1:.4f}")
# ── Main ──────────────────────────────────────────────────────────────────
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--stage", required=True, choices=["supervised", "rl"])
p.add_argument("--dataset", default="IEMOCAP")
p.add_argument("--config", required=True)
p.add_argument("--output", required=True)
p.add_argument("--checkpoint", default=None)
p.add_argument("--gpus", default="0,1,2,3",
help="Comma-separated GPU ids to use")
return p.parse_args()
def main():
args = parse_args()
gpu_ids = [int(g) for g in args.gpus.split(",")]
device = torch.device(f"cuda:{gpu_ids[0]}")
os.makedirs(args.output, exist_ok=True)
log_dir = os.path.join(PROJ, "outputs", "logs")
os.makedirs(log_dir, exist_ok=True)
stage_tag = "stageA" if args.stage == "supervised" else "stageB"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(message)s",
handlers=[
logging.StreamHandler(),
logging.FileHandler(
os.path.join(log_dir, f"{stage_tag}.log"), mode="a"),
],
)
logging.info(f"Using GPUs: {gpu_ids} (DataParallel, primary: cuda:{gpu_ids[0]})")
with open(os.path.join(PROJ, args.config)) as f:
cfg = yaml.safe_load(f)
os.environ.setdefault("WANDB_MODE", "offline")
wandb.init(
project="multimodal_affect",
name=f"d1_{args.stage}_{args.dataset}_{time.strftime('%m%d_%H%M')}",
config={**cfg, "stage": args.stage, "dataset": args.dataset,
"gpus": gpu_ids},
dir=os.path.join(PROJ, "outputs"),
)
if args.stage == "supervised":
train_stage_a(args, cfg, device, gpu_ids)
elif args.stage == "rl":
if not args.checkpoint:
raise ValueError("--checkpoint required for --stage rl")
ckpt = torch.load(args.checkpoint, map_location=device)
text_dim = ckpt["text_dim"]
audio_dim = ckpt["audio_dim"]
vision_dim = ckpt["vision_dim"]
num_classes = ckpt["num_classes"]
proj_dim = ckpt.get("proj_dim", 1024)
dims = dict(text_dim=text_dim, audio_dim=audio_dim,
vision_dim=vision_dim, num_classes=num_classes,
proj_dim=proj_dim)
encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim)
classifier = EmotionClassifier(proj_dim, num_classes)
encoder.load_state_dict(ckpt["encoder"])
classifier.load_state_dict(ckpt["classifier"])
logging.info(
f"Loaded Stage A ckpt: {args.checkpoint} "
f"val_wf1={ckpt.get('val_wf1', 0.0):.4f}"
)
train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids)
wandb.finish()
if __name__ == "__main__":
main()
'''
with sftp.open(f'{PROJ}/scripts/train/train_d1.py', 'w') as f:
f.write(TRAIN_D1)
print("Uploaded train_d1.py (DataParallel edition)")
# Also update the launch script
LAUNCH = f'''#!/bin/bash
set -e
export ZSY={ZSY}
export WANDB_MODE=offline
export PYTHONPATH={PROJ}
cd {PROJ}
mkdir -p outputs/checkpoints/d1_stageA
mkdir -p outputs/checkpoints/d1_stageB
mkdir -p outputs/logs
echo "[$(date)] Starting Stage A (4-GPU DataParallel, 50 epochs)"
{ZSY}/envs/multimodal_affect/bin/python3 scripts/train/train_d1.py \\
--stage supervised \\
--dataset IEMOCAP \\
--config configs/d1/stage_a.yaml \\
--output outputs/checkpoints/d1_stageA \\
--gpus 0,1,2,3 \\
2>&1 | tee outputs/logs/stage_a_stdout.log
echo "[$(date)] Stage A done. Starting Stage B (PPO, 500 updates)"
{ZSY}/envs/multimodal_affect/bin/python3 scripts/train/train_d1.py \\
--stage rl \\
--dataset IEMOCAP \\
--checkpoint outputs/checkpoints/d1_stageA/best.ckpt \\
--config configs/d1/stage_b.yaml \\
--output outputs/checkpoints/d1_stageB \\
--gpus 0,1,2,3 \\
2>&1 | tee outputs/logs/stage_b_stdout.log
echo "[$(date)] All training complete!"
'''
with sftp.open(f'{PROJ}/run_d1.sh', 'w') as f:
f.write(LAUNCH)
print("Updated run_d1.sh")
sftp.close()
client.close()
print("Done.")