Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
631 lines
24 KiB
Python
631 lines
24 KiB
Python
"""Upload train_d1.py and config files to server."""
|
|
import paramiko, warnings
|
|
import os

# Silence library warnings emitted while importing/using paramiko.
warnings.filterwarnings("ignore")

# NOTE(security): the host, user and password below are hard-coded fallbacks
# kept only for backward compatibility. A plaintext root password in a
# committed file should be rotated; prefer the D1_SSH_* environment variables
# (or key-based authentication) instead.
SSH_HOST = os.environ.get("D1_SSH_HOST", '10.82.3.180')
SSH_PORT = int(os.environ.get("D1_SSH_PORT", 20083))
SSH_USER = os.environ.get("D1_SSH_USER", 'root')
SSH_PASSWORD = os.environ.get("D1_SSH_PASSWORD", 'm2dGcwyrhI')

client = paramiko.SSHClient()
# NOTE(security): AutoAddPolicy accepts any unknown host key without
# verification, so the connection is open to man-in-the-middle attacks.
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(SSH_HOST, port=SSH_PORT, username=SSH_USER,
               password=SSH_PASSWORD, timeout=30)
sftp = client.open_sftp()

# Remote workspace root and the project directory that receives the uploads.
ZSY = '/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy'
PROJ = ZSY + '/multimodal_affect'
|
|
|
|
# ─── scripts/train/train_d1.py ────────────────────────────────────────────
|
|
TRAIN_D1 = '''\
|
|
#!/usr/bin/env python3
|
|
# Phase 1 Direction 1 Training Script
|
|
# Stage A: Supervised pretraining with noise-aware confidence estimation
|
|
# Stage B: PPO-based adaptive fusion weight learning
|
|
#
|
|
# Launch:
|
|
# torchrun --nproc_per_node=4 scripts/train/train_d1.py \\
|
|
# --stage supervised --dataset IEMOCAP \\
|
|
# --config configs/d1/stage_a.yaml \\
|
|
# --output outputs/checkpoints/d1_stageA
|
|
|
|
import os, sys, argparse, yaml, time, logging
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.distributed as dist
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
from torch.cuda.amp import GradScaler, autocast
|
|
from sklearn.metrics import f1_score, accuracy_score
|
|
import wandb
|
|
|
|
ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
|
|
PROJ = os.path.join(ZSY, "multimodal_affect")
|
|
sys.path.insert(0, PROJ)
|
|
|
|
from src.data.dataset import MultimodalDataset, get_dataloader
|
|
from src.models.encoders import MultimodalEncoder
|
|
from src.models.classifier import EmotionClassifier
|
|
from src.rl.fusion_agent import ModalFusionAgent
|
|
from src.rl.reward import compute_reward
|
|
|
|
|
|
# ── Distributed helpers ───────────────────────────────────────────────────
|
|
|
|
def setup_ddp():
    """Join the NCCL process group and select this process's CUDA device.

    The device index is read from the LOCAL_RANK environment variable and
    defaults to 0 when the variable is absent.

    Returns:
        tuple: (local_rank, global_rank, world_size).
    """
    gpu_index = int(os.getenv("LOCAL_RANK", 0))
    dist.init_process_group("nccl")
    torch.cuda.set_device(gpu_index)
    global_rank = dist.get_rank()
    n_procs = dist.get_world_size()
    return gpu_index, global_rank, n_procs
|
|
|
|
def cleanup():
    """Tear down the default process group if one is active.

    Calling this when distributed support is unavailable, or when the group
    was never initialised (for example because start-up failed), is a no-op,
    so it is safe to use from a ``finally`` block without masking the
    original error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
def is_main(rank):
    """Return True when *rank* is global rank 0, the process that logs and saves."""
    main_rank = 0
    return rank == main_rank
|
|
|
|
def all_reduce_mean(val, device):
    """Average a scalar over all processes of the default process group.

    This is a collective call: every rank must invoke it.

    Args:
        val: Local scalar; converted with ``float()``.
        device: Device on which the temporary tensor is allocated.

    Returns:
        float: Sum of ``val`` across ranks divided by the world size.
    """
    n_procs = dist.get_world_size()
    buf = torch.tensor(float(val), device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    averaged = buf / n_procs
    return averaged.item()
|
|
|
|
def save_ckpt(state, path):
    """Serialise *state* to *path* with ``torch.save``, creating parent directories.

    Args:
        state: Object to serialise (here: a dict of state dicts and metadata).
        path: Destination file. A bare filename is written to the current
            working directory.
    """
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, path)
|
|
|
|
|
|
def _noisy_batch(dataset, variant, indices, device):
    """Gather one batch, preferring corrupted arrays from *variant*.

    For each of "text", "audio" and "vision" the array stored in *variant*
    under that key is used when present; otherwise the clean array of the
    same name on *dataset* is used. Labels always come from ``dataset.labels``.

    Args:
        dataset: Object exposing ``text``, ``audio``, ``vision`` and ``labels``
            as NumPy arrays (they are passed to ``torch.from_numpy``).
        variant: Mapping with optional per-modality replacement arrays.
        indices: Row indices to select.
        device: Target device for the resulting tensors.

    Returns:
        tuple: (text, audio, vision, labels) tensors on *device*.
    """
    modality_keys = ("text", "audio", "vision")
    arrays = [variant.get(key, getattr(dataset, key)) for key in modality_keys]
    arrays.append(dataset.labels)
    return tuple(torch.from_numpy(arr[indices]).to(device) for arr in arrays)
|
|
|
|
|
|
def _confidence_targets(variant_name, batch_size, device):
    """Build the (batch_size, 3) confidence target for a noisy variant.

    Columns listed for the variant are set to 0.1 and the rest stay at 0.9.
    Judging by the variant names, column 0 is text, 1 is audio and 2 is
    vision. A variant name that is not in the table is treated as corrupting
    all three columns.

    Args:
        variant_name: Name of the noise variant; compared via ``str()``.
        batch_size: Number of rows in the target.
        device: Device on which the target tensor is created.

    Returns:
        torch.Tensor: Float tensor of shape (batch_size, 3).
    """
    all_columns = (0, 1, 2)
    degraded_by_variant = dict(
        gaussian_light=all_columns,
        gaussian_heavy=all_columns,
        missing_audio=(1,),
        missing_visual=(2,),
        text_word_drop_30=(0,),
        audio_masking_50=(1,),
        realistic_mixed=all_columns,
        audio_time_mask=(1,),
    )
    degraded = degraded_by_variant.get(str(variant_name), all_columns)

    target = torch.full((batch_size, 3), 0.9, device=device)
    target[:, list(degraded)] = 0.1
    return target
|
|
|
|
|
|
# ── Evaluation ────────────────────────────────────────────────────────────
|
|
|
|
@torch.no_grad()
def evaluate(encoder, classifier, loader, device, agent=None):
    """Score the model on *loader* and return (weighted F1, accuracy).

    All given modules are switched to eval mode and are left in that mode.
    Without an agent the three modality features are averaged; with an agent
    they are combined using the weights returned by
    ``agent.get_action_and_value``.

    Args:
        encoder: Callable returning (text_feat, audio_feat, vision_feat, confs).
        classifier: Maps the fused feature to class logits.
        loader: Iterable of dict batches with keys "text", "audio", "vision"
            and "labels".
        device: Device on which the batches are evaluated.
        agent: Optional fusion agent.

    Returns:
        tuple: (weighted_f1, accuracy) as Python floats, computed over the
        samples this process iterated.
    """
    modules = [encoder, classifier]
    if agent is not None:
        modules.append(agent)
    for module in modules:
        module.eval()

    pred_chunks, label_chunks = [], []
    for batch in loader:
        text, audio, vision, labels = (
            batch[key].to(device) for key in ("text", "audio", "vision", "labels")
        )
        tf, af, vf, confs = encoder(text, audio, vision)

        if agent is None:
            fused = (tf + af + vf) / 3.0
        else:
            # Agent state = per-modality confidences + a squashed audio-spread feature.
            noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
            state = torch.cat([confs, noise_est], dim=-1)
            weights, *_ = agent.get_action_and_value(state)
            w_text = weights[:, 0:1]
            w_audio = weights[:, 1:2]
            w_vision = weights[:, 2:3]
            fused = w_text*tf + w_audio*af + w_vision*vf

        logits = classifier(fused)
        pred_chunks.append(logits.argmax(-1).cpu())
        label_chunks.append(labels.cpu())

    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()
    weighted_f1 = float(f1_score(y_true, y_pred, average="weighted", zero_division=0))
    accuracy = float(accuracy_score(y_true, y_pred))
    return weighted_f1, accuracy
|
|
|
|
|
|
# ── Stage A: Supervised pretraining ──────────────────────────────────────
|
|
|
|
def train_stage_a(args, cfg, local_rank, rank, world_size):
    """Run Stage A: supervised pretraining of the encoder and classifier under DDP.

    Each step minimises cross-entropy on the averaged modality features plus
    ``conf_weight`` times a BCE loss that pushes the encoder's per-modality
    confidences towards 0.9 (clean) or 0.1 (corrupted; see
    ``_confidence_targets``). With probability ``noise_prob`` the sampled
    batch is replaced by randomly indexed rows of one noisy variant.

    Args:
        args: Parsed CLI namespace; ``args.dataset`` and ``args.output`` are read.
        cfg: Stage A config mapping. Required keys: ``batch_size``, ``lr``,
            ``epochs``. Optional: ``wd``, ``proj_dim``, ``cls_hidden``,
            ``conf_weight``, ``noise_prob``.
        local_rank: CUDA device index used by this process.
        rank: Global rank; rank 0 logs and writes checkpoints.
        world_size: Number of processes (not used in the body).

    Returns:
        tuple: (encoder, classifier, dims) with the unwrapped modules and a
        dict of the dimensions needed to rebuild them.
    """
    device = torch.device(f"cuda:{local_rank}")
    rng = np.random.default_rng(42 + rank)

    data_dir = os.path.join(PROJ, "data", args.dataset.lower())
    noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")

    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
                                 noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")

    train_loader = get_dataloader(train_ds, cfg["batch_size"], distributed=True)
    val_loader = get_dataloader(val_ds, cfg["batch_size"], shuffle=False,
                                distributed=True, drop_last=False)

    text_dim = train_ds.text.shape[1]
    audio_dim = train_ds.audio.shape[1]
    vision_dim = train_ds.vision.shape[1]
    num_classes = int(train_ds.labels.max()) + 1
    proj_dim = cfg.get("proj_dim", 1024)

    encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim).to(device)
    classifier = EmotionClassifier(proj_dim, num_classes,
                                   hidden=cfg.get("cls_hidden", 512)).to(device)
    encoder = DDP(encoder, device_ids=[local_rank])
    classifier = DDP(classifier, device_ids=[local_rank])

    params = list(encoder.parameters()) + list(classifier.parameters())
    opt = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg.get("wd", 1e-4))
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=cfg["epochs"], eta_min=1e-5)
    scaler = GradScaler()

    conf_weight = cfg.get("conf_weight", 0.2)
    noise_prob = cfg.get("noise_prob", 0.4)
    best_wf1 = 0.0

    for epoch in range(cfg["epochs"]):
        train_loader.sampler.set_epoch(epoch)
        encoder.train()
        classifier.train()
        ep_loss = ep_ce = ep_conf = 0.0

        for batch in train_loader:
            text = batch["text"].to(device)
            audio = batch["audio"].to(device)
            vision = batch["vision"].to(device)
            labels = batch["labels"].to(device)
            B = text.size(0)

            # Noise injection: randomly replace with noisy variant.
            # The replacement rows are drawn at random from the whole training
            # set, not from the indices of the batch the sampler produced.
            use_noise = (rng.random() < noise_prob) and bool(train_ds.variant_names)
            if use_noise:
                vname = rng.choice(train_ds.variant_names)
                v = train_ds.noisy_variants[vname]
                ni = rng.integers(0, len(train_ds), size=B)
                text, audio, vision, labels = _noisy_batch(train_ds, v, ni, device)

            with autocast():
                tf, af, vf, confs = encoder(text, audio, vision)
                fused = (tf + af + vf) / 3.0
                logits = classifier(fused)
                ce_loss = F.cross_entropy(logits, labels)

            # Confidence target: noisy modalities -> 0.1, clean -> 0.9
            if use_noise:
                c_tgt = _confidence_targets(vname, B, device)
            else:
                c_tgt = torch.full((B, 3), 0.9, device=device)

            # F.binary_cross_entropy is not autocast-safe, so evaluate it in
            # float32 with autocast switched off.
            with autocast(enabled=False):
                conf_loss = F.binary_cross_entropy(confs.float(), c_tgt)
            loss = ce_loss + conf_weight * conf_loss

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(params, 1.0)
            scaler.step(opt)
            scaler.update()

            ep_loss += loss.item()
            ep_ce += ce_loss.item()
            ep_conf += conf_loss.item()

        sched.step()
        val_wf1, val_acc = evaluate(encoder.module, classifier.module,
                                    val_loader, device)
        # Both metrics are per-shard values; average them across ranks so the
        # logged numbers do not depend on rank 0's shard alone.
        val_wf1 = all_reduce_mean(val_wf1, device)
        val_acc = all_reduce_mean(val_acc, device)

        if is_main(rank):
            n = max(len(train_loader), 1)  # guard against an empty loader
            logging.info(
                f"[StageA] Epoch {epoch+1:3d}/{cfg['epochs']} | "
                f"loss={ep_loss/n:.4f} ce={ep_ce/n:.4f} conf={ep_conf/n:.4f} | "
                f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
            )
            wandb.log({"A/loss": ep_loss/n, "A/ce": ep_ce/n,
                       "A/conf": ep_conf/n, "A/val_wf1": val_wf1,
                       "A/val_acc": val_acc, "epoch": epoch + 1})

            if val_wf1 > best_wf1:
                best_wf1 = val_wf1
                save_ckpt({
                    "epoch": epoch + 1,
                    "encoder": encoder.module.state_dict(),
                    "classifier": classifier.module.state_dict(),
                    "val_wf1": val_wf1,
                    "text_dim": text_dim, "audio_dim": audio_dim,
                    "vision_dim": vision_dim, "num_classes": num_classes,
                    "proj_dim": proj_dim, "cfg": cfg,
                }, os.path.join(args.output, "best.ckpt"))
                logging.info(f"  -> New best WF1: {val_wf1:.4f}")

    if is_main(rank):
        logging.info(f"Stage A done. Best val WF1: {best_wf1:.4f}")
        save_ckpt({
            "epoch": cfg["epochs"],
            "encoder": encoder.module.state_dict(),
            "classifier": classifier.module.state_dict(),
            "text_dim": text_dim, "audio_dim": audio_dim,
            "vision_dim": vision_dim, "num_classes": num_classes,
            "proj_dim": proj_dim, "cfg": cfg,
        }, os.path.join(args.output, "last.ckpt"))

    dims = dict(text_dim=text_dim, audio_dim=audio_dim,
                vision_dim=vision_dim, num_classes=num_classes, proj_dim=proj_dim)
    return encoder.module, classifier.module, dims
|
|
|
|
|
|
# ── Stage B: PPO training ─────────────────────────────────────────────────
|
|
|
|
def collect_rollout(encoder, classifier, agent, dataset, device, rollout_size, cfg, prev_weights):
    """Sample ``rollout_size`` one-step experiences for a PPO update.

    Batches are drawn with replacement from *dataset*; with probability
    ``cfg["noise_prob"]`` (default 0.5) a batch uses one randomly chosen noisy
    variant. The agent's fusion weights are treated as the action and the
    reward comes from ``compute_reward``. All modules are put in eval mode and
    no gradients are recorded.

    Args:
        encoder: Callable returning (text_feat, audio_feat, vision_feat, confs).
        classifier: Maps fused features to logits.
        agent: Policy exposing ``get_action_and_value(state)``.
        dataset: Training dataset with array attributes and noisy variants.
        device: Device on which the rollout is computed.
        rollout_size: Total number of samples to collect.
        cfg: Stage B config mapping (``batch_size``, ``noise_prob``,
            ``reward_alpha``, ``reward_beta``, ``reward_gamma``).
        prev_weights: Mean fusion weights of the previous rollout, or None on
            the first call; forwarded to ``compute_reward``.

    Returns:
        dict: ``states``, ``actions``, ``log_probs``, ``values``, normalised
        ``rewards``, normalised ``advantages`` and ``mean_weights``.
    """
    encoder.eval()
    classifier.eval()
    agent.eval()
    bs = cfg.get("batch_size", 128)
    nprob = cfg.get("noise_prob", 0.5)
    rng = np.random.default_rng()
    states, actions, log_probs, values, rewards = [], [], [], [], []
    collected = 0

    with torch.no_grad():
        while collected < rollout_size:
            bsz = min(bs, rollout_size - collected)
            idx = rng.integers(0, len(dataset), size=bsz)

            # Pick the variant first so only the arrays that are actually used
            # get copied to the device; an empty variant means the clean data.
            variant = {}
            if rng.random() < nprob and dataset.variant_names:
                vname = rng.choice(dataset.variant_names)
                variant = dataset.noisy_variants[vname]
            text, audio, vision, labels = _noisy_batch(dataset, variant, idx, device)

            tf, af, vf, confs = encoder(text, audio, vision)
            noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
            state = torch.cat([confs, noise_est], dim=-1)   # (B, 4)

            weights, log_p, value, _ = agent.get_action_and_value(state)
            fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
            logits = classifier(fused)

            rew, _ = compute_reward(
                logits, labels, confs, weights, prev_weights,
                alpha=cfg.get("reward_alpha", 1.0),
                beta=cfg.get("reward_beta", 0.3),
                gamma=cfg.get("reward_gamma", 0.1),
            )

            states.append(state)
            actions.append(weights)
            log_probs.append(log_p)
            values.append(value.squeeze(-1))
            rewards.append(rew)
            collected += bsz

    states = torch.cat(states)
    actions = torch.cat(actions)
    log_probs = torch.cat(log_probs)
    values = torch.cat(values)
    rewards = torch.cat(rewards)

    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    # Move the value estimates to wherever the rewards live; forcing them onto
    # the CPU breaks the subtraction when the rewards are CUDA tensors.
    advantages = rewards - values.detach().to(rewards.device)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    return dict(states=states, actions=actions, log_probs=log_probs,
                values=values, rewards=rewards, advantages=advantages,
                mean_weights=actions.mean(0))
|
|
|
|
|
|
def ppo_update(agent, opt, rollout, cfg, device, scaler):
    """Run clipped-PPO optimisation epochs over one rollout.

    *agent* may be a plain module or a ``DistributedDataParallel`` wrapper.
    The wrapper does not expose the policy's custom ``evaluate`` method, so it
    is unwrapped here and gradients are averaged across ranks explicitly to
    keep the replicas identical.

    Args:
        agent: Policy (or DDP-wrapped policy) exposing ``evaluate(states, actions)``
            that returns (log_prob, value, entropy).
        opt: Optimizer over the policy parameters.
        rollout: Dict produced by ``collect_rollout``.
        cfg: Config mapping (``ppo_clip``, ``ppo_epochs_per_update``,
            ``ppo_mini_batch``, ``value_coef``, ``entropy_coef``).
        device: Device on which the update runs.
        scaler: ``GradScaler`` used for the mixed-precision backward pass.

    Returns:
        dict: Mean ``p_loss``, ``v_loss`` and ``entropy`` over all mini-batches.
    """
    eps = cfg.get("ppo_clip", 0.2)
    ppo_ep = cfg.get("ppo_epochs_per_update", 4)
    mb_size = cfg.get("ppo_mini_batch", 256)
    v_coef = cfg.get("value_coef", 0.5)
    ent_coef = cfg.get("entropy_coef", 0.01)

    wrapped = isinstance(agent, DDP)
    policy = agent.module if wrapped else agent
    sync_grads = (wrapped and dist.is_available() and dist.is_initialized()
                  and dist.get_world_size() > 1)
    world = dist.get_world_size() if sync_grads else 1

    states = rollout["states"].to(device)
    actions = rollout["actions"].to(device)
    old_lp = rollout["log_probs"].to(device)
    adv = rollout["advantages"].to(device)
    ret = rollout["rewards"].to(device)
    n = states.size(0)

    total_pl = total_vl = total_ent = 0.0
    cnt = 0
    policy.train()

    for _ in range(ppo_ep):
        perm = torch.randperm(n, device=device)
        for start in range(0, n, mb_size):
            idx = perm[start:start + mb_size]
            s = states[idx]
            a = actions[idx]
            olp = old_lp[idx]
            ad = adv[idx]
            r = ret[idx]

            with autocast():
                new_lp, val, ent = policy.evaluate(s, a)
                val = val.squeeze(-1)
                ratio = (new_lp - olp).exp()
                p_loss = -torch.min(ratio * ad,
                                    torch.clamp(ratio, 1-eps, 1+eps) * ad).mean()
                v_loss = F.mse_loss(val, r)
                e_loss = -ent.mean()
                loss = p_loss + v_coef * v_loss + ent_coef * e_loss

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            if sync_grads:
                # Average the still-scaled gradients before unscale_ so every
                # rank sees the same inf/nan state and applies the same step.
                # Assumes every rank runs the same number of mini-batches.
                for p in policy.parameters():
                    if p.grad is not None:
                        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                        p.grad.div_(world)
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
            scaler.step(opt)
            scaler.update()

            total_pl += p_loss.item()
            total_vl += v_loss.item()
            total_ent += ent.mean().item()
            cnt += 1

    cnt = max(cnt, 1)  # an empty rollout yields zeros instead of dividing by zero
    return dict(p_loss=total_pl/cnt, v_loss=total_vl/cnt, entropy=total_ent/cnt)
|
|
|
|
|
|
def train_stage_b(args, cfg, encoder, classifier, dims, local_rank, rank, world_size):
    """Run Stage B: PPO training of the fusion agent on a frozen encoder.

    Each iteration collects a rollout, runs a PPO update on the agent and, on
    every second iteration, takes one supervised step on the classifier using
    the agent's current fusion weights. Every ``eval_every`` iterations the
    model is validated and rank 0 saves the best checkpoint.

    Args:
        args: Parsed CLI namespace; ``args.dataset`` and ``args.output`` are read.
        cfg: Stage B config mapping (all keys optional, defaults in the body).
        encoder: Stage A encoder; its parameters are frozen here.
        classifier: Stage A classifier; stays trainable.
        dims: Dimension metadata copied into the checkpoint.
        local_rank: CUDA device index used by this process.
        rank: Global rank; rank 0 logs and writes checkpoints.
        world_size: Number of processes (not used in the body).
    """
    device = torch.device(f"cuda:{local_rank}")

    data_dir = os.path.join(PROJ, "data", args.dataset.lower())
    noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
                                 noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")
    val_loader = get_dataloader(val_ds, cfg.get("batch_size", 128),
                                shuffle=False, distributed=True, drop_last=False)

    # Freeze encoder (projectors + confidence estimators)
    for p in encoder.parameters():
        p.requires_grad_(False)
    encoder.to(device).eval()

    # Classifier: keep trainable (supervised component)
    classifier.to(device)
    cls_ddp = DDP(classifier, device_ids=[local_rank])
    opt_cls = torch.optim.AdamW(classifier.parameters(),
                                lr=cfg.get("cls_lr", 5e-5), weight_decay=1e-4)

    # RL agent
    agent = ModalFusionAgent(state_dim=4,
                             hidden=cfg.get("agent_hidden", 128)).to(device)
    agent = DDP(agent, device_ids=[local_rank])
    opt_agent = torch.optim.Adam(agent.parameters(), lr=cfg.get("rl_lr", 3e-4))

    scaler = GradScaler()

    rollout_size = cfg.get("rollout_steps", 512)
    n_updates = cfg.get("n_ppo_updates", 500)
    eval_every = cfg.get("eval_every", 10)
    best_wf1 = 0.0
    prev_weights = None

    for upd in range(n_updates):
        # Rollout collection
        rollout = collect_rollout(
            encoder, classifier, agent.module,
            train_ds, device, rollout_size, cfg, prev_weights,
        )
        prev_weights = rollout["mean_weights"].to(device)

        # PPO update
        ppo_info = ppo_update(agent, opt_agent, rollout, cfg, device, scaler)

        # Lightweight supervised classifier refresh (one mini-batch every 2 updates)
        if upd % 2 == 0:
            # collect_rollout() and evaluate() leave the classifier in eval
            # mode; switch back so this gradient step runs in training mode.
            cls_ddp.train()
            idx = np.random.randint(0, len(train_ds), cfg.get("batch_size", 128))
            text = torch.from_numpy(train_ds.text[idx]).to(device)
            audio = torch.from_numpy(train_ds.audio[idx]).to(device)
            vision = torch.from_numpy(train_ds.vision[idx]).to(device)
            labels = torch.from_numpy(train_ds.labels[idx]).to(device)
            with torch.no_grad():
                tf, af, vf, confs = encoder(text, audio, vision)
                noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
                state = torch.cat([confs, noise_est], dim=-1)
                weights, *_ = agent.module.get_action_and_value(state)
            fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
            with autocast():
                logits = cls_ddp(fused)
                loss = F.cross_entropy(logits, labels)
            opt_cls.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt_cls)
            scaler.update()

        # Evaluate
        if upd % eval_every == 0:
            val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device,
                                        agent=agent.module)
            # Per-shard metrics averaged across ranks (collective: all ranks call).
            val_wf1 = all_reduce_mean(val_wf1, device)
            val_acc = all_reduce_mean(val_acc, device)

            if is_main(rank):
                mean_rew = rollout["rewards"].mean().item()
                logging.info(
                    f"[StageB] PPO {upd:4d}/{n_updates} | "
                    f"rew={mean_rew:.4f} p={ppo_info['p_loss']:.4f} "
                    f"v={ppo_info['v_loss']:.4f} ent={ppo_info['entropy']:.4f} | "
                    f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
                )
                wandb.log({
                    "B/reward": mean_rew,
                    "B/p_loss": ppo_info["p_loss"],
                    "B/v_loss": ppo_info["v_loss"],
                    "B/entropy": ppo_info["entropy"],
                    "B/val_wf1": val_wf1,
                    "B/val_acc": val_acc,
                    "ppo_update": upd,
                })

                if val_wf1 > best_wf1:
                    best_wf1 = val_wf1
                    save_ckpt({
                        "update": upd,
                        "encoder": encoder.state_dict(),
                        "classifier": classifier.state_dict(),
                        "agent": agent.module.state_dict(),
                        "val_wf1": val_wf1,
                        **dims,
                    }, os.path.join(args.output, "best.ckpt"))
                    logging.info(f"  -> New best WF1: {val_wf1:.4f}")

    if is_main(rank):
        logging.info(f"Stage B done. Best val WF1: {best_wf1:.4f}")
|
|
|
|
|
|
# ── Main ──────────────────────────────────────────────────────────────────
|
|
|
|
def parse_args():
    """Parse the training command line from ``sys.argv``.

    Options: ``--stage`` (required; "supervised" or "rl"), ``--dataset``
    (default "IEMOCAP"), ``--config`` (required), ``--output`` (required) and
    ``--checkpoint`` (default None).

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--stage", dict(required=True, choices=["supervised", "rl"])),
        ("--dataset", dict(default="IEMOCAP")),
        ("--config", dict(required=True)),
        ("--output", dict(required=True)),
        ("--checkpoint", dict(default=None)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
|
|
def main():
    """Entry point: set up DDP, logging and wandb, then run the chosen stage.

    ``--stage supervised`` runs Stage A. ``--stage rl`` rebuilds the encoder
    and classifier from a Stage A checkpoint and runs Stage B. The wandb run
    and the process group are closed even when training raises.

    Raises:
        ValueError: If ``--stage rl`` is given without ``--checkpoint``.
    """
    args = parse_args()
    local_rank, rank, world_size = setup_ddp()
    device = torch.device(f"cuda:{local_rank}")

    try:
        os.makedirs(args.output, exist_ok=True)
        log_dir = os.path.join(PROJ, "outputs", "logs")
        os.makedirs(log_dir, exist_ok=True)

        stage_tag = "stageA" if args.stage == "supervised" else "stageB"
        handlers = [logging.StreamHandler()]
        if is_main(rank):
            # Only rank 0 appends to the shared log file.
            handlers.append(logging.FileHandler(
                os.path.join(log_dir, f"{stage_tag}.log"), mode="a"))
        logging.basicConfig(
            level=logging.INFO,
            format=f"[rank{rank}] %(asctime)s %(message)s",
            handlers=handlers,
        )

        with open(os.path.join(PROJ, args.config)) as f:
            cfg = yaml.safe_load(f)

        if is_main(rank):
            os.environ.setdefault("WANDB_MODE", "offline")
            wandb.init(
                project="multimodal_affect",
                name=f"d1_{args.stage}_{args.dataset}_{time.strftime('%m%d_%H%M')}",
                config={**cfg, "stage": args.stage, "dataset": args.dataset},
                dir=os.path.join(PROJ, "outputs"),
            )

        if args.stage == "supervised":
            train_stage_a(args, cfg, local_rank, rank, world_size)

        elif args.stage == "rl":
            if not args.checkpoint:
                raise ValueError("--checkpoint required for --stage rl")
            ckpt = torch.load(args.checkpoint, map_location=device)

            text_dim = ckpt["text_dim"]
            audio_dim = ckpt["audio_dim"]
            vision_dim = ckpt["vision_dim"]
            num_classes = ckpt["num_classes"]
            proj_dim = ckpt.get("proj_dim", 1024)
            dims = dict(text_dim=text_dim, audio_dim=audio_dim,
                        vision_dim=vision_dim, num_classes=num_classes,
                        proj_dim=proj_dim)

            # Rebuild the classifier with the hidden width Stage A trained
            # with (stored in the checkpoint's cfg, default 512 there);
            # otherwise load_state_dict fails on a shape mismatch.
            cls_hidden = (ckpt.get("cfg") or {}).get("cls_hidden", 512)

            encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim)
            classifier = EmotionClassifier(proj_dim, num_classes, hidden=cls_hidden)
            encoder.load_state_dict(ckpt["encoder"])
            classifier.load_state_dict(ckpt["classifier"])

            if is_main(rank):
                logging.info(
                    f"Loaded Stage A ckpt from {args.checkpoint} "
                    f"(val_wf1={ckpt.get('val_wf1', 0.0):.4f})"
                )
            train_stage_b(args, cfg, encoder, classifier, dims,
                          local_rank, rank, world_size)
    finally:
        # Always close the wandb run and the process group, including on
        # failure, so a crashed rank does not leave them open.
        if is_main(rank):
            wandb.finish()
        cleanup()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
'''
|
|
|
|
# ─── configs/d1/stage_a.yaml ──────────────────────────────────────────────
|
|
STAGE_A_YAML = '''\
|
|
# Stage A: Supervised pretraining
|
|
# Trains projection MLPs + confidence estimators + classifier
|
|
# with noise injection to teach confidence estimation
|
|
|
|
epochs: 50
|
|
batch_size: 128
|
|
lr: 2.0e-4
|
|
wd: 1.0e-4
|
|
proj_dim: 1024
|
|
cls_hidden: 512
|
|
conf_weight: 0.2 # BCE loss weight for confidence estimators
|
|
noise_prob: 0.4 # probability of injecting noisy batch
|
|
'''
|
|
|
|
# ─── configs/d1/stage_b.yaml ──────────────────────────────────────────────
|
|
STAGE_B_YAML = '''\
|
|
# Stage B: PPO-based adaptive fusion weight learning
|
|
# Encoder (projectors + confidence estimators) frozen from Stage A
|
|
# RL agent learns noise-adaptive fusion weights via PPO
|
|
|
|
batch_size: 128
|
|
proj_dim: 1024
|
|
|
|
# PPO
|
|
rollout_steps: 512 # experiences collected per PPO update
|
|
n_ppo_updates: 500 # total PPO update iterations
|
|
ppo_clip: 0.2
|
|
ppo_epochs_per_update: 4
|
|
ppo_mini_batch: 256
|
|
value_coef: 0.5
|
|
entropy_coef: 0.01
|
|
rl_lr: 3.0e-4
|
|
cls_lr: 5.0e-5
|
|
|
|
# Reward coefficients (R = alpha*(-CE) + beta*Consistency - gamma*Instability)
|
|
reward_alpha: 1.0
|
|
reward_beta: 0.3
|
|
reward_gamma: 0.1
|
|
|
|
# RL agent architecture
|
|
agent_hidden: 128
|
|
|
|
# Noise injection during rollout collection
|
|
noise_prob: 0.5
|
|
|
|
eval_every: 10 # evaluate every N PPO updates
|
|
'''
|
|
|
|
# Upload
|
|
import posixpath


def _ensure_remote_dir(sftp_client, remote_dir):
    """Create *remote_dir* and any missing parent directories over SFTP."""
    missing = []
    current = remote_dir
    # Walk upwards until an existing directory (or the root) is found.
    while current and current != '/':
        try:
            sftp_client.stat(current)
            break
        except IOError:
            missing.append(current)
            current = posixpath.dirname(current)
    # Create the missing levels from the outermost one inwards.
    for directory in reversed(missing):
        sftp_client.mkdir(directory)


# Remote destination path -> file content to write there.
uploads = {
    f"{PROJ}/scripts/train/train_d1.py": TRAIN_D1,
    f"{PROJ}/configs/d1/stage_a.yaml": STAGE_A_YAML,
    f"{PROJ}/configs/d1/stage_b.yaml": STAGE_B_YAML,
}

try:
    for path, content in uploads.items():
        # sftp.open(..., 'w') does not create parent directories.
        _ensure_remote_dir(sftp, posixpath.dirname(path))
        with sftp.open(path, 'w') as f:
            f.write(content)
        print(f" uploaded: {path.split('multimodal_affect/')[-1]}")
finally:
    # Release the SFTP channel and SSH connection even if an upload fails.
    sftp.close()
    client.close()

print("\nAll training files uploaded.")
|