Files
CompanionGuard-RL/旧方向信息/scripts/train_d1_fixed.py

523 lines
21 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
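# Phase 1 Direction 1 Training Script (DataParallel edition)
# NOTE: nothing in this script wraps the models in nn.DataParallel; all work runs
# on the first id in --gpus, while several batch sizes are scaled by the number
# of ids given.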
# Stage A: Supervised pretraining with noise-aware confidence estimation
# Stage B: PPO-based adaptive fusion weight learning [v2: reward display + entropy fix]
#
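# Launch (Stage A):
# python scripts/train/train_d1.py \
# --stage supervised --dataset IEMOCAP \
# --config configs/d1/stage_a.yaml \
# --output outputs/checkpoints/d1_stageA
# Stage B uses --stage rl and additionally requires --checkpoint pointing to a
# Stage A checkpoint (e.g. outputs/checkpoints/d1_stageA/best.ckpt).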
import os, sys, argparse, yaml, time, logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from sklearn.metrics import f1_score, accuracy_score
import wandb
# Root directory that contains the project; override it with the ZSY
# environment variable (the default looks like a machine-specific path).
ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
PROJ = os.path.join(ZSY, "multimodal_affect")
# Put the project root first on sys.path so the `src.*` imports that follow
# resolve to this project's package; must run before those imports.
sys.path.insert(0, PROJ)
from src.data.dataset import MultimodalDataset, get_dataloader
from src.models.encoders import MultimodalEncoder
from src.models.classifier import EmotionClassifier
from src.rl.fusion_agent import ModalFusionAgent
from src.rl.reward import compute_reward
def save_ckpt(state, path):
    """Serialize ``state`` to ``path`` with ``torch.save``.

    Parent directories are created when ``path`` has a directory component;
    for a bare filename the directory step is skipped, because
    ``os.makedirs("")`` raises ``FileNotFoundError``.

    The data is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``. An interrupted save therefore never leaves a
    truncated checkpoint at ``path``.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)
def _noisy_batch(dataset, variant, indices, device):
"""Return a same-index multimodal batch; fall back to clean missing files."""
text = variant.get("text", dataset.text)
audio = variant.get("audio", dataset.audio)
vision = variant.get("vision", dataset.vision)
return (
torch.from_numpy(text[indices]).to(device),
torch.from_numpy(audio[indices]).to(device),
torch.from_numpy(vision[indices]).to(device),
torch.from_numpy(dataset.labels[indices]).to(device),
)
def _confidence_targets(variant_name, batch_size, device):
"""Low confidence for modalities actually corrupted by the named variant."""
target = torch.full((batch_size, 3), 0.9, device=device)
noisy_map = {
"gaussian_light": (0, 1, 2),
"gaussian_heavy": (0, 1, 2),
"missing_audio": (1,),
"missing_visual": (2,),
"text_word_drop_30": (0,),
"audio_masking_50": (1,),
"realistic_mixed": (0, 1, 2),
"audio_time_mask": (1,),
}
for idx in noisy_map.get(str(variant_name), (0, 1, 2)):
target[:, idx] = 0.1
return target
# ── Evaluation ────────────────────────────────────────────────────────────
@torch.no_grad()
def evaluate(encoder, classifier, loader, device, agent=None):
    """Compute weighted F1 and accuracy of the fused classifier on ``loader``.

    Args:
        encoder: Module returning ``(text_feat, audio_feat, vision_feat, confs)``.
            A wrapper exposing ``.module`` (e.g. ``nn.DataParallel``) is unwrapped.
        classifier: Module mapping fused features to class logits (unwrapped
            the same way).
        loader: Iterable of dicts with ``"text"``, ``"audio"``, ``"vision"``
            and ``"labels"`` tensors.
        device: Device the batches are moved to.
        agent: Optional fusion agent. When given, per-sample fusion weights
            come from ``agent.get_action_and_value``; otherwise the three
            modality features are averaged.

    Returns:
        ``(weighted_f1, accuracy)`` as Python floats.

    Raises:
        ValueError: if ``loader`` yields no batches.

    The modules are switched to eval mode for the pass. Afterwards, each
    module's previous ``training`` flag is restored with ``Module.train(mode)``,
    which applies the flag recursively.
    """
    # Unwrap ``.module`` wrappers once rather than on every batch.
    enc = encoder.module if hasattr(encoder, "module") else encoder
    cls = classifier.module if hasattr(classifier, "module") else classifier
    agt = None
    if agent is not None:
        agt = agent.module if hasattr(agent, "module") else agent

    modules = [m for m in (encoder, classifier, agent) if m is not None]
    previous_modes = [m.training for m in modules]
    for m in modules:
        m.eval()

    all_preds, all_labels = [], []
    try:
        for batch in loader:
            text = batch["text"].to(device)
            audio = batch["audio"].to(device)
            vision = batch["vision"].to(device)
            labels = batch["labels"].to(device)
            tf, af, vf, confs = enc(text, audio, vision)
            if agt is not None:
                # Agent state: encoder confidences plus a sigmoid of the
                # per-sample std of the audio input.
                # NOTE(review): if get_action_and_value samples its action,
                # this evaluation is stochastic — confirm that is intended.
                noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
                state = torch.cat([confs, noise_est], dim=-1)
                weights, *_ = agt.get_action_and_value(state)
                fused = weights[:, 0:1] * tf + weights[:, 1:2] * af + weights[:, 2:3] * vf
            else:
                fused = (tf + af + vf) / 3.0
            logits = cls(fused)
            all_preds.append(logits.argmax(-1).cpu())
            all_labels.append(labels.cpu())
    finally:
        for m, was_training in zip(modules, previous_modes):
            m.train(was_training)

    if not all_preds:
        raise ValueError("evaluate() received an empty loader")
    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()
    wf1 = float(f1_score(labels, preds, average="weighted", zero_division=0))
    acc = float(accuracy_score(labels, preds))
    return wf1, acc
# ── Stage A: Supervised pretraining ──────────────────────────────────────
def train_stage_a(args, cfg, device, gpu_ids):
    """Stage A: supervised pretraining of the encoder and classifier.

    Each step fuses the three modality features by a plain mean, applies the
    classifier and minimizes ``cross_entropy + conf_weight * BCE(confs, target)``.
    With probability ``cfg["noise_prob"]`` (default 0.4), the clean batch is
    replaced by a same-size batch drawn at random indices from a randomly
    chosen noisy variant. The confidence target is then 0.1 for the
    modalities that variant corrupts (see ``_confidence_targets``) and 0.9
    otherwise. Clean batches use 0.9 everywhere.

    Args:
        args: Parsed CLI args; ``dataset`` and ``output`` are used.
        cfg: Stage A config. Required keys: ``batch_size``, ``lr``, ``epochs``.
        device: Device for models and batches.
        gpu_ids: GPU ids from ``--gpus``. Only their count is used here, to
            scale the loader batch size.

    Returns:
        ``(encoder, classifier, dims)``. ``dims`` holds the sizes needed to
        rebuild both modules.

    Side effects:
        Writes ``best.ckpt`` (highest validation weighted F1) and
        ``last.ckpt`` under ``args.output``, and logs per-epoch metrics to
        logging and wandb.

    NOTE(review): the modules are never wrapped in ``nn.DataParallel`` here,
    so all computation runs on ``device`` even when several GPU ids are given.
    """
    rng = np.random.default_rng(42)
    dataset_key = args.dataset.lower()
    data_dir = os.path.join(PROJ, "data", dataset_key)
    noise_root = os.path.join(PROJ, "data", f"{dataset_key}_noisy")
    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
                                 noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")
    eff_bs = cfg["batch_size"] * len(gpu_ids)
    train_loader = get_dataloader(train_ds, eff_bs, distributed=False)
    val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
                                distributed=False, drop_last=False)

    text_dim = train_ds.text.shape[1]
    audio_dim = train_ds.audio.shape[1]
    vision_dim = train_ds.vision.shape[1]
    num_classes = int(train_ds.labels.max()) + 1
    proj_dim = cfg.get("proj_dim", 1024)
    dims = dict(text_dim=text_dim, audio_dim=audio_dim,
                vision_dim=vision_dim, num_classes=num_classes, proj_dim=proj_dim)

    encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim).to(device)
    classifier = EmotionClassifier(proj_dim, num_classes,
                                   hidden=cfg.get("cls_hidden", 512)).to(device)
    # Unwrapped handles used for checkpointing and the return value.
    enc_m = encoder.module if hasattr(encoder, "module") else encoder
    cls_m = classifier.module if hasattr(classifier, "module") else classifier

    params = list(encoder.parameters()) + list(classifier.parameters())
    opt = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg.get("wd", 1e-4))
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=cfg["epochs"], eta_min=1e-5)
    scaler = GradScaler('cuda')
    conf_weight = cfg.get("conf_weight", 0.2)
    noise_prob = cfg.get("noise_prob", 0.4)
    # Start at -inf so the first evaluated epoch always writes best.ckpt, even
    # if its weighted F1 is exactly 0.0; Stage B needs a checkpoint to load.
    best_wf1 = float("-inf")

    for epoch in range(cfg["epochs"]):
        encoder.train()
        classifier.train()
        ep_loss = ep_ce = ep_conf = 0.0
        for batch in train_loader:
            text = batch["text"].to(device)
            audio = batch["audio"].to(device)
            vision = batch["vision"].to(device)
            labels = batch["labels"].to(device)
            B = text.size(0)
            use_noise = (rng.random() < noise_prob) and bool(train_ds.variant_names)
            if use_noise:
                vname = rng.choice(train_ds.variant_names)
                variant = train_ds.noisy_variants[vname]
                noisy_idx = rng.integers(0, len(train_ds), size=B)
                text, audio, vision, labels = _noisy_batch(train_ds, variant, noisy_idx, device)
                c_tgt = _confidence_targets(vname, B, device)
            else:
                c_tgt = torch.full((B, 3), 0.9, device=device)

            with autocast('cuda'):
                tf, af, vf, confs = encoder(text, audio, vision)
                logits = classifier((tf + af + vf) / 3.0)
                ce_loss = F.cross_entropy(logits, labels)
            # binary_cross_entropy is not autocast-safe: compute it in fp32
            # outside the autocast region.
            conf_loss = F.binary_cross_entropy(confs.float(), c_tgt.float())
            loss = ce_loss + conf_weight * conf_loss

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(params, 1.0)
            scaler.step(opt)
            scaler.update()
            ep_loss += loss.item()
            ep_ce += ce_loss.item()
            ep_conf += conf_loss.item()

        sched.step()
        val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device)
        # Guard against an empty loader when averaging the epoch losses.
        n = max(len(train_loader), 1)
        logging.info(
            f"[StageA] Epoch {epoch+1:3d}/{cfg['epochs']} | "
            f"loss={ep_loss/n:.4f} ce={ep_ce/n:.4f} conf={ep_conf/n:.4f} | "
            f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
        )
        wandb.log({"A/loss": ep_loss/n, "A/ce": ep_ce/n,
                   "A/conf": ep_conf/n, "A/val_wf1": val_wf1,
                   "A/val_acc": val_acc, "epoch": epoch + 1})
        if val_wf1 > best_wf1:
            best_wf1 = val_wf1
            save_ckpt({
                "epoch": epoch + 1,
                "encoder": enc_m.state_dict(),
                "classifier": cls_m.state_dict(),
                "val_wf1": val_wf1,
                **dims,
                "cfg": cfg,
            }, os.path.join(args.output, "best.ckpt"))
            logging.info(f" -> New best WF1: {val_wf1:.4f}")

    logging.info(f"Stage A done. Best val WF1: {best_wf1:.4f}")
    # State dicts are taken here rather than inside the loop, so this also
    # works when cfg["epochs"] == 0.
    save_ckpt({
        "epoch": cfg["epochs"],
        "encoder": enc_m.state_dict(),
        "classifier": cls_m.state_dict(),
        **dims,
        "cfg": cfg,
    }, os.path.join(args.output, "last.ckpt"))
    return enc_m, cls_m, dims
# ── Stage B: PPO training ─────────────────────────────────────────────────
def collect_rollout(encoder, classifier, agent, dataset, device, rollout_size, cfg, prev_weights):
    """Collect ``rollout_size`` one-step transitions for a PPO update.

    Batches of up to ``cfg["batch_size"]`` (default 128) random dataset
    indices are drawn. With probability ``cfg["noise_prob"]`` (default 0.5),
    a batch is read from one randomly chosen noisy variant instead of the
    clean arrays. The agent state is the encoder confidences concatenated
    with a sigmoid of the per-sample std of the audio input. The sampled
    weights fuse the three modality features, and the classifier output is
    scored by ``compute_reward``.

    Returns:
        dict containing:
            - ``states``, ``actions``, ``log_probs``: on ``device``.
            - ``values``, standardized ``rewards``, standardized
              ``advantages``: on CPU.
            - ``mean_weights``: mean action over the rollout.
            - ``raw_rew_mean`` / ``raw_rew_std``: reward statistics before
              standardization (for logging).

    Raises:
        ValueError: if ``rollout_size`` or the configured batch size is not
            positive. Otherwise the loop would never finish or ``torch.cat``
            would fail.
    """
    bs = cfg.get("batch_size", 128)
    nprob = cfg.get("noise_prob", 0.5)
    if rollout_size <= 0:
        raise ValueError(f"rollout_size must be positive, got {rollout_size}")
    if bs <= 0:
        raise ValueError(f"batch_size must be positive, got {bs}")
    encoder.eval()
    classifier.eval()
    agent.eval()
    rng = np.random.default_rng()

    def _standardize(x):
        # std() of a single element is NaN; only centre the data in that case.
        if x.numel() < 2:
            return x - x.mean()
        return (x - x.mean()) / (x.std() + 1e-8)

    states, actions, log_probs, values, rewards = [], [], [], [], []
    collected = 0
    with torch.no_grad():
        while collected < rollout_size:
            bsz = min(bs, rollout_size - collected)
            idx = rng.integers(0, len(dataset), size=bsz)
            # An empty variant makes _noisy_batch fall back to the clean
            # arrays, so each batch is copied to the device only once.
            variant = {}
            if rng.random() < nprob and dataset.variant_names:
                vname = rng.choice(dataset.variant_names)
                variant = dataset.noisy_variants[vname]
            text, audio, vision, labels = _noisy_batch(dataset, variant, idx, device)

            tf, af, vf, confs = encoder(text, audio, vision)
            noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
            state = torch.cat([confs, noise_est], dim=-1)
            weights, log_p, value, _ = agent.get_action_and_value(state)
            fused = weights[:, 0:1] * tf + weights[:, 1:2] * af + weights[:, 2:3] * vf
            logits = classifier(fused)
            rew, _ = compute_reward(
                logits, labels, confs, weights, prev_weights,
                alpha=cfg.get("reward_alpha", 1.0),
                beta=cfg.get("reward_beta", 0.3),
                gamma=cfg.get("reward_gamma", 0.1),
            )
            states.append(state)
            actions.append(weights)
            log_probs.append(log_p)
            # reshape(-1) handles both (bsz, 1) and (bsz,) value outputs. With
            # squeeze(-1), a (1,) value from a 1-sample batch became 0-d and
            # broke torch.cat.
            values.append(value.reshape(-1))
            rewards.append(rew)
            collected += bsz

    states = torch.cat(states)
    actions = torch.cat(actions)
    log_probs = torch.cat(log_probs)
    values = torch.cat(values).cpu()
    rewards = torch.cat(rewards).cpu()
    # Keep the raw statistics for logging: the standardized mean is 0 by
    # construction.
    raw_rew_mean = rewards.mean().item()
    raw_rew_std = rewards.std().item() if rewards.numel() > 1 else 0.0
    rewards = _standardize(rewards)
    advantages = _standardize(rewards - values)
    return dict(states=states, actions=actions, log_probs=log_probs,
                values=values, rewards=rewards, advantages=advantages,
                mean_weights=actions.mean(0),
                raw_rew_mean=raw_rew_mean, raw_rew_std=raw_rew_std)
def ppo_update(agent, opt, rollout, cfg, device, scaler):
    """Run clipped-PPO optimisation epochs over one collected rollout.

    Args:
        agent: Module exposing ``evaluate(states, actions) ->
            (log_probs, values, entropy)``.
        opt: Optimizer over the agent's parameters.
        rollout: dict from ``collect_rollout``. Uses ``states``, ``actions``,
            ``log_probs``, ``advantages``, and ``rewards`` (the value target).
        cfg: Stage B config. Optional keys: ``ppo_clip``,
            ``ppo_epochs_per_update``, ``ppo_mini_batch``, ``value_coef``,
            ``entropy_coef``.
        device: Device to run the update on.
        scaler: Gradient scaler providing ``scale``/``unscale_``/``step``/
            ``update``.

    Returns:
        dict with mean ``p_loss``, ``v_loss`` and ``entropy`` over all
        minibatches. Each value is 0.0 when no minibatch ran (empty rollout or
        zero epochs).

    The agent is evaluated in fp32, without autocast. ``collect_rollout``
    computes the old log-probs without autocast, so re-evaluating them under
    fp16 autocast would make the importance ratio differ from 1 even before
    any parameter update.
    """
    eps = cfg.get("ppo_clip", 0.2)
    ppo_ep = cfg.get("ppo_epochs_per_update", 4)
    mb_size = cfg.get("ppo_mini_batch", 256)
    v_coef = cfg.get("value_coef", 0.5)
    ent_coef = cfg.get("entropy_coef", 0.01)
    states = rollout["states"].to(device)
    actions = rollout["actions"].to(device)
    old_lp = rollout["log_probs"].to(device)
    adv = rollout["advantages"].to(device)
    ret = rollout["rewards"].to(device)
    n = states.size(0)

    total_pl = total_vl = total_ent = 0.0
    n_minibatches = 0
    agent.train()
    for _ in range(ppo_ep):
        perm = torch.randperm(n, device=device)
        for start in range(0, n, mb_size):
            idx = perm[start:start + mb_size]
            new_lp, val, ent = agent.evaluate(states[idx], actions[idx])
            # reshape(-1) works for (mb, 1) and (mb,) value heads, including mb == 1.
            val = val.reshape(-1)
            ad = adv[idx]
            ratio = (new_lp - old_lp[idx]).exp()
            clipped_ratio = torch.clamp(ratio, 1 - eps, 1 + eps)
            p_loss = -torch.min(ratio * ad, clipped_ratio * ad).mean()
            v_loss = F.mse_loss(val, ret[idx])
            ent_mean = ent.mean()
            # The entropy bonus is subtracted to encourage exploration.
            loss = p_loss + v_coef * v_loss - ent_coef * ent_mean

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            scaler.step(opt)
            scaler.update()

            total_pl += p_loss.item()
            total_vl += v_loss.item()
            total_ent += ent_mean.item()
            n_minibatches += 1

    if n_minibatches == 0:
        return dict(p_loss=0.0, v_loss=0.0, entropy=0.0)
    return dict(p_loss=total_pl / n_minibatches,
                v_loss=total_vl / n_minibatches,
                entropy=total_ent / n_minibatches)
def _finetune_classifier_step(encoder, classifier, agent, dataset, opt_cls, scaler,
                              batch_size, device):
    """Take one cross-entropy step on the classifier using agent-weighted fusion.

    A clean batch of ``batch_size`` random indices is drawn with numpy's
    global RNG. Encoder and agent run under ``no_grad``, so only the
    classifier receives gradients. The classifier is switched to training
    mode for this step because ``collect_rollout`` leaves it in eval mode.
    """
    idx = np.random.randint(0, len(dataset), batch_size)
    text = torch.from_numpy(dataset.text[idx]).to(device)
    audio = torch.from_numpy(dataset.audio[idx]).to(device)
    vision = torch.from_numpy(dataset.vision[idx]).to(device)
    labels = torch.from_numpy(dataset.labels[idx]).to(device)
    with torch.no_grad():
        tf, af, vf, confs = encoder(text, audio, vision)
        noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
        state = torch.cat([confs, noise_est], dim=-1)
        weights, *_ = agent.get_action_and_value(state)
        fused = weights[:, 0:1] * tf + weights[:, 1:2] * af + weights[:, 2:3] * vf
    classifier.train()
    with autocast('cuda'):
        logits = classifier(fused)
        loss = F.cross_entropy(logits, labels)
    opt_cls.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(opt_cls)
    scaler.update()


def train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids):
    """Stage B: learn per-sample fusion weights with PPO on a frozen encoder.

    The encoder's parameters are frozen and it is kept in eval mode. Each
    update does the following:
      - collect a rollout;
      - run PPO on the ``ModalFusionAgent`` (state_dim=4);
      - on every second update, take one classifier fine-tuning step on a
        clean batch;
      - every ``cfg["eval_every"]`` updates, evaluate with agent-weighted
        fusion, log to logging/wandb, and save ``best.ckpt`` under
        ``args.output`` when the validation weighted F1 improves.

    Args:
        args: Parsed CLI args; ``dataset`` and ``output`` are used.
        cfg: Stage B config (all keys optional).
        encoder, classifier: Modules restored from a Stage A checkpoint.
        dims: Model dimensions, stored in the saved checkpoint.
        device: Training device.
        gpu_ids: GPU ids; only their count is used, to scale batch sizes.
    """
    dataset_key = args.dataset.lower()
    data_dir = os.path.join(PROJ, "data", dataset_key)
    noise_root = os.path.join(PROJ, "data", f"{dataset_key}_noisy")
    train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
                                 noise_root=noise_root)
    val_ds = MultimodalDataset(data_dir, "val")
    eff_bs = cfg.get("batch_size", 128) * len(gpu_ids)
    val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
                                distributed=False, drop_last=False)

    for p in encoder.parameters():
        p.requires_grad_(False)
    encoder.to(device).eval()
    classifier.to(device)
    opt_cls = torch.optim.AdamW(classifier.parameters(),
                                lr=cfg.get("cls_lr", 5e-5), weight_decay=1e-4)
    agent = ModalFusionAgent(state_dim=4,
                             hidden=cfg.get("agent_hidden", 128)).to(device)
    opt_agent = torch.optim.Adam(agent.parameters(), lr=cfg.get("rl_lr", 3e-4))
    scaler = GradScaler()
    rollout_size = cfg.get("rollout_steps", 512)
    n_updates = cfg.get("n_ppo_updates", 500)
    eval_every = cfg.get("eval_every", 10)
    # Start at -inf so the first evaluation always writes best.ckpt.
    best_wf1 = float("-inf")
    prev_weights = None

    for upd in range(n_updates):
        rollout = collect_rollout(
            encoder, classifier, agent,
            train_ds, device, rollout_size, cfg, prev_weights,
        )
        prev_weights = rollout["mean_weights"].to(device)
        ppo_info = ppo_update(agent, opt_agent, rollout, cfg, device, scaler)

        if upd % 2 == 0:
            _finetune_classifier_step(encoder, classifier, agent, train_ds,
                                      opt_cls, scaler, eff_bs, device)

        if upd % eval_every != 0:
            continue
        val_wf1, _ = evaluate(encoder, classifier, val_loader, device, agent=agent)
        # Log the raw (pre-standardization) reward; the standardized mean is 0.
        raw_rew = rollout["raw_rew_mean"]
        raw_std = rollout["raw_rew_std"]
        mw = rollout["mean_weights"]  # mean fusion weights [text, audio, visual]
        logging.info(
            f"[StageB] PPO {upd:4d}/{n_updates} | "
            f"rew={raw_rew:.4f}(+/-{raw_std:.3f}) p={ppo_info['p_loss']:.4f} "
            f"v={ppo_info['v_loss']:.4f} ent={ppo_info['entropy']:.4f} | "
            f"val_wf1={val_wf1:.4f} | "
            f"w=[t:{mw[0]:.3f} a:{mw[1]:.3f} v:{mw[2]:.3f}]"
        )
        wandb.log({
            "B/reward": raw_rew,
            "B/rew_std": raw_std,
            "B/w_text": mw[0].item(),
            "B/w_audio": mw[1].item(),
            "B/w_visual": mw[2].item(),
            "B/p_loss": ppo_info["p_loss"],
            "B/v_loss": ppo_info["v_loss"],
            "B/entropy": ppo_info["entropy"],
            "B/val_wf1": val_wf1,
            "ppo_update": upd,
        })
        if val_wf1 > best_wf1:
            best_wf1 = val_wf1
            save_ckpt({
                "update": upd,
                "encoder": encoder.state_dict(),
                "classifier": classifier.state_dict(),
                "agent": agent.state_dict(),
                "val_wf1": val_wf1,
                **dims,
            }, os.path.join(args.output, "best.ckpt"))
            logging.info(f" -> New best WF1: {val_wf1:.4f}")
    logging.info(f"Stage B done. Best val WF1: {best_wf1:.4f}")
# ── Main ──────────────────────────────────────────────────────────────────
def parse_args():
    """Parse the command line for the D1 training script.

    ``--stage``, ``--config`` and ``--output`` are mandatory. ``--gpus`` is a
    comma-separated id list, returned unparsed as a string.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--stage", {"required": True, "choices": ["supervised", "rl"]}),
        ("--dataset", {"default": "IEMOCAP"}),
        ("--config", {"required": True}),
        ("--output", {"required": True}),
        ("--checkpoint", {"default": None}),
        ("--gpus", {"default": "0,1,2,3",
                    "help": "Comma-separated GPU ids to use"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def _load_stage_a_checkpoint(path, device):
    """Rebuild the Stage A encoder and classifier from a checkpoint.

    The classifier hidden size is read from the checkpoint's saved config
    (``cfg["cls_hidden"]``). The fallback is 512, the same default
    ``train_stage_a`` uses, so the rebuilt module matches the saved weights.

    Returns:
        ``(encoder, classifier, dims, ckpt)``.
    """
    ckpt = torch.load(path, map_location=device)
    dims = dict(text_dim=ckpt["text_dim"], audio_dim=ckpt["audio_dim"],
                vision_dim=ckpt["vision_dim"], num_classes=ckpt["num_classes"],
                proj_dim=ckpt.get("proj_dim", 1024))
    saved_cfg = ckpt.get("cfg") or {}
    encoder = MultimodalEncoder(dims["text_dim"], dims["audio_dim"],
                                dims["vision_dim"], dims["proj_dim"])
    classifier = EmotionClassifier(dims["proj_dim"], dims["num_classes"],
                                   hidden=saved_cfg.get("cls_hidden", 512))
    encoder.load_state_dict(ckpt["encoder"])
    classifier.load_state_dict(ckpt["classifier"])
    return encoder, classifier, dims, ckpt


def main():
    """CLI entry point: set up logging and wandb, then run the requested stage.

    ``--stage supervised`` runs Stage A. ``--stage rl`` rebuilds the encoder
    and classifier from ``--checkpoint`` and runs Stage B.

    Raises:
        ValueError: if ``--gpus`` lists no ids, or if ``--stage rl`` is given
            without ``--checkpoint``. Both are checked before any directory,
            log file or wandb run is created.
    """
    args = parse_args()
    gpu_ids = [int(g) for g in args.gpus.split(",") if g.strip()]
    if not gpu_ids:
        raise ValueError("--gpus must list at least one GPU id")
    if args.stage == "rl" and not args.checkpoint:
        raise ValueError("--checkpoint required for --stage rl")
    device = torch.device(f"cuda:{gpu_ids[0]}")

    os.makedirs(args.output, exist_ok=True)
    log_dir = os.path.join(PROJ, "outputs", "logs")
    os.makedirs(log_dir, exist_ok=True)
    stage_tag = "stageA" if args.stage == "supervised" else "stageB"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(message)s",
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(
                os.path.join(log_dir, f"{stage_tag}.log"), mode="a"),
        ],
    )
    logging.info(f"Using GPUs: {gpu_ids} (DataParallel, primary: cuda:{gpu_ids[0]})")
    with open(os.path.join(PROJ, args.config)) as f:
        cfg = yaml.safe_load(f)

    os.environ.setdefault("WANDB_MODE", "offline")
    wandb.init(
        project="multimodal_affect",
        name=f"d1_{args.stage}_{args.dataset}_{time.strftime('%m%d_%H%M')}",
        config={**cfg, "stage": args.stage, "dataset": args.dataset,
                "gpus": gpu_ids},
        dir=os.path.join(PROJ, "outputs"),
    )
    try:
        if args.stage == "supervised":
            train_stage_a(args, cfg, device, gpu_ids)
        elif args.stage == "rl":
            encoder, classifier, dims, ckpt = _load_stage_a_checkpoint(
                args.checkpoint, device)
            logging.info(
                f"Loaded Stage A ckpt: {args.checkpoint} "
                f"val_wf1={ckpt.get('val_wf1', 0.0):.4f}"
            )
            train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids)
    except BaseException:
        # Close the wandb run as failed rather than leaving it dangling.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()
# Script entry point: run main() only when the file is executed directly.
if __name__ == "__main__":
    main()