#!/usr/bin/env python3
# Phase 1 Direction 1 Training Script (DataParallel edition)
#   Stage A: Supervised pretraining with noise-aware confidence estimation
#   Stage B: PPO-based adaptive fusion weight learning [v2: reward display + entropy fix]
#
# Launch:
#   python scripts/train/train_d1.py \
#       --stage supervised --dataset IEMOCAP \
#       --config configs/d1/stage_a.yaml \
#       --output outputs/checkpoints/d1_stageA
#
# NOTE(review): although this is labelled the "DataParallel edition" and the
# effective batch size is multiplied by the number of GPU ids, nothing in this
# file wraps a model in nn.DataParallel, so all compute runs on the first GPU.
# The `.module` unwrapping below is kept so wrapped models would still work.
# Confirm whether the wrap was dropped on purpose.

import argparse
import logging
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import yaml
from sklearn.metrics import accuracy_score, f1_score
from torch.amp import GradScaler, autocast

ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
PROJ = os.path.join(ZSY, "multimodal_affect")
# The project root must be importable before the `src.*` imports below.
sys.path.insert(0, PROJ)

from src.data.dataset import MultimodalDataset, get_dataloader  # noqa: E402
from src.models.encoders import MultimodalEncoder  # noqa: E402
from src.models.classifier import EmotionClassifier  # noqa: E402
from src.rl.fusion_agent import ModalFusionAgent  # noqa: E402
from src.rl.reward import compute_reward  # noqa: E402

# Modality columns (0 = text, 1 = audio, 2 = vision) corrupted by each named
# noise variant. Unknown variants are treated as corrupting all three.
_NOISY_MODALITIES = {
    "gaussian_light": (0, 1, 2),
    "gaussian_heavy": (0, 1, 2),
    "missing_audio": (1,),
    "missing_visual": (2,),
    "text_word_drop_30": (0,),
    "audio_masking_50": (1,),
    "realistic_mixed": (0, 1, 2),
    "audio_time_mask": (1,),
}


def save_ckpt(state, path):
    """Serialize ``state`` to ``path`` with ``torch.save``.

    The parent directory is created when ``path`` has one; a bare file name
    (empty dirname) is written to the current directory instead of failing
    inside ``os.makedirs("")``.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, path)


def _unwrap(module):
    """Return ``module.module`` for wrapper modules, else ``module`` itself."""
    return module.module if hasattr(module, "module") else module


def _agent_state(confs, audio):
    """Build the agent input: encoder confidences plus one audio-noise column.

    The noise column is the sigmoid of the per-sample standard deviation of
    ``audio`` along its last dimension.
    """
    noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
    return torch.cat([confs, noise_est], dim=-1)


def _fuse(weights, tf, af, vf):
    """Weighted sum of the three modality features, one weight column each."""
    return weights[:, 0:1] * tf + weights[:, 1:2] * af + weights[:, 2:3] * vf


def _noisy_batch(dataset, variant, indices, device):
    """Return a same-index multimodal batch; fall back to clean missing files.

    ``variant`` is a mapping that may hold ``"text"``, ``"audio"`` and/or
    ``"vision"`` arrays. Any modality it lacks is taken from the clean arrays
    on ``dataset``, so an empty mapping yields a fully clean batch. Labels
    always come from ``dataset.labels``.

    Returns:
        ``(text, audio, vision, labels)`` tensors on ``device``.
    """
    text = variant.get("text", dataset.text)
    audio = variant.get("audio", dataset.audio)
    vision = variant.get("vision", dataset.vision)
    return (
        torch.from_numpy(text[indices]).to(device),
        torch.from_numpy(audio[indices]).to(device),
        torch.from_numpy(vision[indices]).to(device),
        torch.from_numpy(dataset.labels[indices]).to(device),
    )


def _confidence_targets(variant_name, batch_size, device):
    """Low confidence for modalities actually corrupted by the named variant.

    Returns a ``(batch_size, 3)`` tensor filled with 0.9, with 0.1 in the
    columns the variant corrupts. Unknown variant names mark all three
    columns as corrupted.
    """
    target = torch.full((batch_size, 3), 0.9, device=device)
    for idx in _NOISY_MODALITIES.get(str(variant_name), (0, 1, 2)):
        target[:, idx] = 0.1
    return target


# ── Evaluation ────────────────────────────────────────────────────────────

@torch.no_grad()
def evaluate(encoder, classifier, loader, device, agent=None):
    """Evaluate on ``loader`` and return ``(weighted_f1, accuracy)``.

    All given modules are switched to eval mode and left there; callers that
    keep training must restore train mode themselves. With an ``agent`` the
    modality features are fused with the agent's weights, otherwise they are
    averaged with equal weight.
    """
    encoder.eval()
    classifier.eval()
    if agent is not None:
        agent.eval()

    # Unwrap once rather than once per batch.
    enc = _unwrap(encoder)
    cls = _unwrap(classifier)
    agt = _unwrap(agent) if agent is not None else None

    all_preds, all_labels = [], []
    for batch in loader:
        text = batch["text"].to(device)
        audio = batch["audio"].to(device)
        vision = batch["vision"].to(device)
        labels = batch["labels"].to(device)

        tf, af, vf, confs = enc(text, audio, vision)
        if agt is not None:
            weights, *_ = agt.get_action_and_value(_agent_state(confs, audio))
            fused = _fuse(weights, tf, af, vf)
        else:
            fused = (tf + af + vf) / 3.0

        logits = cls(fused)
        all_preds.append(logits.argmax(-1).cpu())
        all_labels.append(labels.cpu())

    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()
    wf1 = float(f1_score(labels, preds, average="weighted", zero_division=0))
    acc = float(accuracy_score(labels, preds))
    return wf1, acc


# ── Stage A: Supervised pretraining ──────────────────────────────────────

def train_stage_a(args, cfg, device, gpu_ids):
    """Supervised pretraining of the encoder and classifier.

    The loss is cross-entropy plus ``conf_weight`` times a binary
    cross-entropy on the encoder's confidence outputs. With probability
    ``noise_prob`` a loader batch is replaced by an equally sized batch drawn
    at random indices from one noisy variant.

    Writes ``best.ckpt`` (highest validation weighted F1) and ``last.ckpt``
    under ``args.output``.

    Returns:
        ``(encoder, classifier, dims)`` with the modules unwrapped and
        ``dims`` holding the constructor dimensions.
    """
    rng = np.random.default_rng(42)
    data_dir = os.path.join(PROJ, "data", args.dataset.lower())
    noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")

    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True, noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")

    eff_bs = cfg["batch_size"] * len(gpu_ids)
    train_loader = get_dataloader(train_ds, eff_bs, distributed=False)
    val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
                                distributed=False, drop_last=False)

    text_dim = train_ds.text.shape[1]
    audio_dim = train_ds.audio.shape[1]
    vision_dim = train_ds.vision.shape[1]
    num_classes = int(train_ds.labels.max()) + 1
    proj_dim = cfg.get("proj_dim", 1024)
    dims = dict(text_dim=text_dim, audio_dim=audio_dim, vision_dim=vision_dim,
                num_classes=num_classes, proj_dim=proj_dim)

    encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim).to(device)
    classifier = EmotionClassifier(proj_dim, num_classes,
                                   hidden=cfg.get("cls_hidden", 512)).to(device)

    params = list(encoder.parameters()) + list(classifier.parameters())
    opt = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg.get("wd", 1e-4))
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=cfg["epochs"], eta_min=1e-5)
    scaler = GradScaler('cuda')

    conf_weight = cfg.get("conf_weight", 0.2)
    noise_prob = cfg.get("noise_prob", 0.4)
    best_wf1 = 0.0
    # Guard the per-epoch averages against an empty loader.
    n = max(len(train_loader), 1)

    for epoch in range(cfg["epochs"]):
        # evaluate() leaves both modules in eval mode, so reset every epoch.
        encoder.train()
        classifier.train()
        ep_loss = ep_ce = ep_conf = 0.0

        for batch in train_loader:
            B = batch["text"].size(0)
            # rng.random() is drawn on every step, even without variants, so
            # the seeded random stream does not depend on the dataset.
            use_noise = (rng.random() < noise_prob) and bool(train_ds.variant_names)
            if use_noise:
                vname = rng.choice(train_ds.variant_names)
                v = train_ds.noisy_variants[vname]
                ni = rng.integers(0, len(train_ds), size=B)
                text, audio, vision, labels = _noisy_batch(train_ds, v, ni, device)
            else:
                # Copy the loader batch to the device only when it is used.
                text = batch["text"].to(device)
                audio = batch["audio"].to(device)
                vision = batch["vision"].to(device)
                labels = batch["labels"].to(device)

            with autocast('cuda'):
                tf, af, vf, confs = encoder(text, audio, vision)
                fused = (tf + af + vf) / 3.0
                logits = classifier(fused)
                ce_loss = F.cross_entropy(logits, labels)

            if use_noise:
                c_tgt = _confidence_targets(vname, B, device)
            else:
                c_tgt = torch.full((B, 3), 0.9, device=device)
            # Computed outside autocast on float32 inputs.
            conf_loss = F.binary_cross_entropy(confs.float(), c_tgt.float())

            with autocast('cuda'):
                loss = ce_loss + conf_weight * conf_loss

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true gradients.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(params, 1.0)
            scaler.step(opt)
            scaler.update()

            ep_loss += loss.item()
            ep_ce += ce_loss.item()
            ep_conf += conf_loss.item()

        sched.step()
        val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device)

        logging.info(
            f"[StageA] Epoch {epoch+1:3d}/{cfg['epochs']} | "
            f"loss={ep_loss/n:.4f} ce={ep_ce/n:.4f} conf={ep_conf/n:.4f} | "
            f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
        )
        wandb.log({"A/loss": ep_loss/n, "A/ce": ep_ce/n, "A/conf": ep_conf/n,
                   "A/val_wf1": val_wf1, "A/val_acc": val_acc, "epoch": epoch + 1})

        if val_wf1 > best_wf1:
            best_wf1 = val_wf1
            save_ckpt({
                "epoch": epoch + 1,
                "encoder": _unwrap(encoder).state_dict(),
                "classifier": _unwrap(classifier).state_dict(),
                "val_wf1": val_wf1,
                **dims,
                "cfg": cfg,
            }, os.path.join(args.output, "best.ckpt"))
            logging.info(f" -> New best WF1: {val_wf1:.4f}")

    logging.info(f"Stage A done. Best val WF1: {best_wf1:.4f}")

    # State dicts are taken after the loop so this also works for epochs == 0.
    enc_m = _unwrap(encoder)
    cls_m = _unwrap(classifier)
    save_ckpt({
        "epoch": cfg["epochs"],
        "encoder": enc_m.state_dict(),
        "classifier": cls_m.state_dict(),
        **dims,
        "cfg": cfg,
    }, os.path.join(args.output, "last.ckpt"))

    return enc_m, cls_m, dims


# ── Stage B: PPO training ─────────────────────────────────────────────────

def collect_rollout(encoder, classifier, agent, dataset, device,
                    rollout_size, cfg, prev_weights):
    """Sample ``rollout_size`` one-step transitions under the current policy.

    Batches are drawn at random indices. With probability ``noise_prob`` a
    batch uses one randomly chosen noisy variant. All three modules are put
    in eval mode and left there.

    Returns:
        A dict with ``states``, ``actions``, ``log_probs`` (on ``device``),
        ``values``, normalized ``rewards`` and ``advantages`` (on CPU),
        ``mean_weights`` (mean action over the rollout), and the
        pre-normalization ``raw_rew_mean`` / ``raw_rew_std``.
    """
    encoder.eval()
    classifier.eval()
    agent.eval()

    bs = cfg.get("batch_size", 128)
    nprob = cfg.get("noise_prob", 0.5)
    rng = np.random.default_rng()

    states, actions, log_probs, values, rewards = [], [], [], [], []
    collected = 0

    with torch.no_grad():
        while collected < rollout_size:
            bsz = min(bs, rollout_size - collected)
            idx = rng.integers(0, len(dataset), size=bsz)

            # An empty variant makes _noisy_batch return the clean batch, so
            # only one batch is ever built and copied to the device.
            variant = {}
            if rng.random() < nprob and bool(dataset.variant_names):
                vname = rng.choice(dataset.variant_names)
                variant = dataset.noisy_variants[vname]
            text, audio, vision, labels = _noisy_batch(dataset, variant, idx, device)

            tf, af, vf, confs = encoder(text, audio, vision)
            state = _agent_state(confs, audio)
            weights, log_p, value, _ = agent.get_action_and_value(state)

            logits = classifier(_fuse(weights, tf, af, vf))
            rew, _ = compute_reward(
                logits, labels, confs, weights, prev_weights,
                alpha=cfg.get("reward_alpha", 1.0),
                beta=cfg.get("reward_beta", 0.3),
                gamma=cfg.get("reward_gamma", 0.1),
            )

            states.append(state)
            actions.append(weights)
            log_probs.append(log_p)
            values.append(value.squeeze(-1))
            rewards.append(rew)
            collected += bsz

    states = torch.cat(states)
    actions = torch.cat(actions)
    log_probs = torch.cat(log_probs)
    values = torch.cat(values).cpu()
    rewards = torch.cat(rewards).cpu()

    # Record raw statistics before normalizing: the normalized mean is always
    # 0 by construction, which makes it useless for logging.
    raw_rew_mean = rewards.mean().item()
    raw_rew_std = rewards.std().item()

    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    advantages = rewards - values.detach()
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    return dict(states=states, actions=actions, log_probs=log_probs,
                values=values, rewards=rewards, advantages=advantages,
                mean_weights=actions.mean(0),
                raw_rew_mean=raw_rew_mean, raw_rew_std=raw_rew_std)


def ppo_update(agent, opt, rollout, cfg, device, scaler):
    """Run clipped-PPO minibatch updates on ``agent`` over one rollout.

    The normalized rewards in ``rollout["rewards"]`` serve as the value
    targets. The agent is left in train mode.

    Returns:
        A dict with the mean ``p_loss``, ``v_loss`` and ``entropy`` over all
        minibatch steps.
    """
    eps = cfg.get("ppo_clip", 0.2)
    ppo_ep = cfg.get("ppo_epochs_per_update", 4)
    mb_size = cfg.get("ppo_mini_batch", 256)
    v_coef = cfg.get("value_coef", 0.5)
    ent_coef = cfg.get("entropy_coef", 0.01)

    states = rollout["states"].to(device)
    actions = rollout["actions"].to(device)
    old_lp = rollout["log_probs"].to(device)
    adv = rollout["advantages"].to(device)
    ret = rollout["rewards"].to(device)
    n = states.size(0)

    total_pl = total_vl = total_ent = 0.0
    cnt = 0
    agent.train()

    for _ in range(ppo_ep):
        perm = torch.randperm(n, device=device)
        for start in range(0, n, mb_size):
            idx = perm[start:start + mb_size]
            s, a = states[idx], actions[idx]
            olp, ad, r = old_lp[idx], adv[idx], ret[idx]

            with autocast('cuda'):
                new_lp, val, ent = agent.evaluate(s, a)
                val = val.squeeze(-1)
                ratio = (new_lp - olp).exp()
                p_loss = -torch.min(
                    ratio * ad,
                    torch.clamp(ratio, 1 - eps, 1 + eps) * ad,
                ).mean()
                v_loss = F.mse_loss(val, r)
                # Negated so that minimizing the loss maximizes entropy.
                e_loss = -ent.mean()
                loss = p_loss + v_coef * v_loss + ent_coef * e_loss

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            scaler.step(opt)
            scaler.update()

            total_pl += p_loss.item()
            total_vl += v_loss.item()
            total_ent += ent.mean().item()
            cnt += 1

    # Guard against a rollout that produced no minibatches.
    denom = max(cnt, 1)
    return dict(p_loss=total_pl / denom, v_loss=total_vl / denom,
                entropy=total_ent / denom)


def train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids):
    """PPO training of the fusion agent on top of a frozen encoder.

    Each update collects a rollout and runs ``ppo_update``. On every second
    update the classifier takes one cross-entropy step on a clean batch fused
    with agent weights. Every ``eval_every`` updates the model is validated
    and ``best.ckpt`` is written under ``args.output`` when validation
    weighted F1 improves.
    """
    data_dir = os.path.join(PROJ, "data", args.dataset.lower())
    noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")

    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True, noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")

    eff_bs = cfg.get("batch_size", 128) * len(gpu_ids)
    val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
                                distributed=False, drop_last=False)

    # The encoder stays frozen and in eval mode for all of Stage B.
    for p in encoder.parameters():
        p.requires_grad_(False)
    encoder.to(device).eval()
    classifier.to(device)

    opt_cls = torch.optim.AdamW(classifier.parameters(),
                                lr=cfg.get("cls_lr", 5e-5), weight_decay=1e-4)
    agent = ModalFusionAgent(state_dim=4,
                             hidden=cfg.get("agent_hidden", 128)).to(device)
    opt_agent = torch.optim.Adam(agent.parameters(), lr=cfg.get("rl_lr", 3e-4))
    scaler = GradScaler('cuda')

    rollout_size = cfg.get("rollout_steps", 512)
    n_updates = cfg.get("n_ppo_updates", 500)
    eval_every = cfg.get("eval_every", 10)
    best_wf1 = 0.0
    prev_weights = None

    for upd in range(n_updates):
        rollout = collect_rollout(
            encoder, classifier, agent, train_ds,
            device, rollout_size, cfg, prev_weights,
        )
        prev_weights = rollout["mean_weights"].to(device)
        ppo_info = ppo_update(agent, opt_agent, rollout, cfg, device, scaler)

        if upd % 2 == 0:
            # collect_rollout() and evaluate() leave the classifier in eval
            # mode; restore train mode for its own optimization step.
            classifier.train()
            idx = np.random.randint(0, len(train_ds), eff_bs)
            # An empty variant selects the clean arrays.
            text, audio, vision, labels = _noisy_batch(train_ds, {}, idx, device)
            with torch.no_grad():
                tf, af, vf, confs = encoder(text, audio, vision)
                weights, *_ = agent.get_action_and_value(_agent_state(confs, audio))
                fused = _fuse(weights, tf, af, vf)
            with autocast('cuda'):
                logits = classifier(fused)
                loss = F.cross_entropy(logits, labels)
            opt_cls.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt_cls)
            scaler.update()

        if upd % eval_every == 0:
            val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device, agent=agent)
            # Log the raw (pre-normalization) reward so the value is meaningful.
            raw_rew = rollout["raw_rew_mean"]
            raw_std = rollout["raw_rew_std"]
            mw = rollout["mean_weights"]  # mean fusion weights [text, audio, visual]
            logging.info(
                f"[StageB] PPO {upd:4d}/{n_updates} | "
                f"rew={raw_rew:.4f}(+/-{raw_std:.3f}) p={ppo_info['p_loss']:.4f} "
                f"v={ppo_info['v_loss']:.4f} ent={ppo_info['entropy']:.4f} | "
                f"val_wf1={val_wf1:.4f} | "
                f"w=[t:{mw[0]:.3f} a:{mw[1]:.3f} v:{mw[2]:.3f}]"
            )
            wandb.log({
                "B/reward": raw_rew,
                "B/rew_std": raw_std,
                "B/w_text": mw[0].item(),
                "B/w_audio": mw[1].item(),
                "B/w_visual": mw[2].item(),
                "B/p_loss": ppo_info["p_loss"],
                "B/v_loss": ppo_info["v_loss"],
                "B/entropy": ppo_info["entropy"],
                "B/val_wf1": val_wf1,
                "ppo_update": upd,
            })
            if val_wf1 > best_wf1:
                best_wf1 = val_wf1
                save_ckpt({
                    "update": upd,
                    "encoder": encoder.state_dict(),
                    "classifier": classifier.state_dict(),
                    "agent": agent.state_dict(),
                    "val_wf1": val_wf1,
                    **dims,
                }, os.path.join(args.output, "best.ckpt"))
                logging.info(f" -> New best WF1: {val_wf1:.4f}")

    logging.info(f"Stage B done. Best val WF1: {best_wf1:.4f}")


# ── Main ──────────────────────────────────────────────────────────────────

def parse_args():
    """Parse the command-line arguments of the training script."""
    p = argparse.ArgumentParser()
    p.add_argument("--stage", required=True, choices=["supervised", "rl"])
    p.add_argument("--dataset", default="IEMOCAP")
    p.add_argument("--config", required=True)
    p.add_argument("--output", required=True)
    p.add_argument("--checkpoint", default=None)
    p.add_argument("--gpus", default="0,1,2,3",
                   help="Comma-separated GPU ids to use")
    return p.parse_args()


def main():
    """Entry point: set up logging and wandb, then run the requested stage.

    Raises:
        ValueError: if ``--stage rl`` is given without ``--checkpoint``.
    """
    args = parse_args()
    gpu_ids = [int(g) for g in args.gpus.split(",")]
    device = torch.device(f"cuda:{gpu_ids[0]}")

    os.makedirs(args.output, exist_ok=True)
    log_dir = os.path.join(PROJ, "outputs", "logs")
    os.makedirs(log_dir, exist_ok=True)

    stage_tag = "stageA" if args.stage == "supervised" else "stageB"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(message)s",
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                os.path.join(log_dir, f"{stage_tag}.log"), mode="a"),
        ],
    )
    logging.info(f"Using GPUs: {gpu_ids} (DataParallel, primary: cuda:{gpu_ids[0]})")

    with open(os.path.join(PROJ, args.config)) as f:
        cfg = yaml.safe_load(f)

    os.environ.setdefault("WANDB_MODE", "offline")
    wandb.init(
        project="multimodal_affect",
        name=f"d1_{args.stage}_{args.dataset}_{time.strftime('%m%d_%H%M')}",
        config={**cfg, "stage": args.stage, "dataset": args.dataset, "gpus": gpu_ids},
        dir=os.path.join(PROJ, "outputs"),
    )

    # Close the wandb run even if training raises.
    try:
        if args.stage == "supervised":
            train_stage_a(args, cfg, device, gpu_ids)
        elif args.stage == "rl":
            if not args.checkpoint:
                raise ValueError("--checkpoint required for --stage rl")
            ckpt = torch.load(args.checkpoint, map_location=device)
            text_dim = ckpt["text_dim"]
            audio_dim = ckpt["audio_dim"]
            vision_dim = ckpt["vision_dim"]
            num_classes = ckpt["num_classes"]
            proj_dim = ckpt.get("proj_dim", 1024)
            dims = dict(text_dim=text_dim, audio_dim=audio_dim, vision_dim=vision_dim,
                        num_classes=num_classes, proj_dim=proj_dim)

            encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim)

            # Rebuild the classifier with the hidden width Stage A trained
            # with, taken from the config stored in the checkpoint. Otherwise
            # load_state_dict fails for any non-default `cls_hidden`.
            # Checkpoints without a stored config keep the class default.
            cls_kwargs = {}
            stage_a_cfg = ckpt.get("cfg")
            if isinstance(stage_a_cfg, dict):
                cls_kwargs["hidden"] = stage_a_cfg.get("cls_hidden", 512)
            classifier = EmotionClassifier(proj_dim, num_classes, **cls_kwargs)

            encoder.load_state_dict(ckpt["encoder"])
            classifier.load_state_dict(ckpt["classifier"])
            logging.info(
                f"Loaded Stage A ckpt: {args.checkpoint} "
                f"val_wf1={ckpt.get('val_wf1', 0.0):.4f}"
            )
            train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids)
    finally:
        wandb.finish()


if __name__ == "__main__":
    main()