chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
522
旧方向信息/scripts/train_d1_fixed.py
Normal file
522
旧方向信息/scripts/train_d1_fixed.py
Normal file
@@ -0,0 +1,522 @@
|
||||
#!/usr/bin/env python3
|
||||
# Phase 1 Direction 1 Training Script (DataParallel edition)
|
||||
# Stage A: Supervised pretraining with noise-aware confidence estimation
|
||||
# Stage B: PPO-based adaptive fusion weight learning [v2: reward display + entropy fix]
|
||||
#
|
||||
# Launch:
|
||||
# python scripts/train/train_d1.py \
|
||||
# --stage supervised --dataset IEMOCAP \
|
||||
# --config configs/d1/stage_a.yaml \
|
||||
# --output outputs/checkpoints/d1_stageA
|
||||
|
||||
import os, sys, argparse, yaml, time, logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.amp import GradScaler, autocast
|
||||
from sklearn.metrics import f1_score, accuracy_score
|
||||
import wandb
|
||||
|
||||
ZSY = os.environ.get("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
|
||||
PROJ = os.path.join(ZSY, "multimodal_affect")
|
||||
sys.path.insert(0, PROJ)
|
||||
|
||||
from src.data.dataset import MultimodalDataset, get_dataloader
|
||||
from src.models.encoders import MultimodalEncoder
|
||||
from src.models.classifier import EmotionClassifier
|
||||
from src.rl.fusion_agent import ModalFusionAgent
|
||||
from src.rl.reward import compute_reward
|
||||
|
||||
|
||||
def save_ckpt(state, path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save(state, path)
|
||||
|
||||
|
||||
def _noisy_batch(dataset, variant, indices, device):
|
||||
"""Return a same-index multimodal batch; fall back to clean missing files."""
|
||||
text = variant.get("text", dataset.text)
|
||||
audio = variant.get("audio", dataset.audio)
|
||||
vision = variant.get("vision", dataset.vision)
|
||||
return (
|
||||
torch.from_numpy(text[indices]).to(device),
|
||||
torch.from_numpy(audio[indices]).to(device),
|
||||
torch.from_numpy(vision[indices]).to(device),
|
||||
torch.from_numpy(dataset.labels[indices]).to(device),
|
||||
)
|
||||
|
||||
|
||||
def _confidence_targets(variant_name, batch_size, device):
|
||||
"""Low confidence for modalities actually corrupted by the named variant."""
|
||||
target = torch.full((batch_size, 3), 0.9, device=device)
|
||||
noisy_map = {
|
||||
"gaussian_light": (0, 1, 2),
|
||||
"gaussian_heavy": (0, 1, 2),
|
||||
"missing_audio": (1,),
|
||||
"missing_visual": (2,),
|
||||
"text_word_drop_30": (0,),
|
||||
"audio_masking_50": (1,),
|
||||
"realistic_mixed": (0, 1, 2),
|
||||
"audio_time_mask": (1,),
|
||||
}
|
||||
for idx in noisy_map.get(str(variant_name), (0, 1, 2)):
|
||||
target[:, idx] = 0.1
|
||||
return target
|
||||
|
||||
|
||||
# ── Evaluation ────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(encoder, classifier, loader, device, agent=None):
|
||||
encoder.eval()
|
||||
classifier.eval()
|
||||
if agent is not None:
|
||||
agent.eval()
|
||||
all_preds, all_labels = [], []
|
||||
for batch in loader:
|
||||
text = batch["text"].to(device)
|
||||
audio = batch["audio"].to(device)
|
||||
vision = batch["vision"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
enc = encoder.module if hasattr(encoder, "module") else encoder
|
||||
cls = classifier.module if hasattr(classifier, "module") else classifier
|
||||
agt = (agent.module if hasattr(agent, "module") else agent) if agent else None
|
||||
tf, af, vf, confs = enc(text, audio, vision)
|
||||
if agt is not None:
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
weights, *_ = agt.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
else:
|
||||
fused = (tf + af + vf) / 3.0
|
||||
logits = cls(fused)
|
||||
all_preds.append(logits.argmax(-1).cpu())
|
||||
all_labels.append(labels.cpu())
|
||||
preds = torch.cat(all_preds).numpy()
|
||||
labels = torch.cat(all_labels).numpy()
|
||||
wf1 = float(f1_score(labels, preds, average="weighted", zero_division=0))
|
||||
acc = float(accuracy_score(labels, preds))
|
||||
return wf1, acc
|
||||
|
||||
|
||||
# ── Stage A: Supervised pretraining ──────────────────────────────────────
|
||||
|
||||
def train_stage_a(args, cfg, device, gpu_ids):
|
||||
rng = np.random.default_rng(42)
|
||||
|
||||
data_dir = os.path.join(PROJ, "data", args.dataset.lower())
|
||||
noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
|
||||
|
||||
train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
|
||||
noise_root=noise_root)
|
||||
val_ds = MultimodalDataset(data_dir, "val")
|
||||
|
||||
eff_bs = cfg["batch_size"] * len(gpu_ids)
|
||||
train_loader = get_dataloader(train_ds, eff_bs, distributed=False)
|
||||
val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
|
||||
distributed=False, drop_last=False)
|
||||
|
||||
text_dim = train_ds.text.shape[1]
|
||||
audio_dim = train_ds.audio.shape[1]
|
||||
vision_dim = train_ds.vision.shape[1]
|
||||
num_classes = int(train_ds.labels.max()) + 1
|
||||
proj_dim = cfg.get("proj_dim", 1024)
|
||||
|
||||
encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim).to(device)
|
||||
classifier = EmotionClassifier(proj_dim, num_classes,
|
||||
hidden=cfg.get("cls_hidden", 512)).to(device)
|
||||
|
||||
params = list(encoder.parameters()) + list(classifier.parameters())
|
||||
opt = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg.get("wd", 1e-4))
|
||||
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
opt, T_max=cfg["epochs"], eta_min=1e-5)
|
||||
scaler = GradScaler('cuda')
|
||||
|
||||
conf_weight = cfg.get("conf_weight", 0.2)
|
||||
noise_prob = cfg.get("noise_prob", 0.4)
|
||||
best_wf1 = 0.0
|
||||
|
||||
for epoch in range(cfg["epochs"]):
|
||||
encoder.train()
|
||||
classifier.train()
|
||||
ep_loss = ep_ce = ep_conf = 0.0
|
||||
|
||||
for batch in train_loader:
|
||||
text = batch["text"].to(device)
|
||||
audio = batch["audio"].to(device)
|
||||
vision = batch["vision"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
B = text.size(0)
|
||||
|
||||
use_noise = (rng.random() < noise_prob) and bool(train_ds.variant_names)
|
||||
if use_noise:
|
||||
vname = rng.choice(train_ds.variant_names)
|
||||
v = train_ds.noisy_variants[vname]
|
||||
ni = rng.integers(0, len(train_ds), size=B)
|
||||
text, audio, vision, labels = _noisy_batch(train_ds, v, ni, device)
|
||||
|
||||
with autocast('cuda'):
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
fused = (tf + af + vf) / 3.0
|
||||
logits = classifier(fused)
|
||||
ce_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
if use_noise:
|
||||
c_tgt = _confidence_targets(vname, B, device)
|
||||
else:
|
||||
c_tgt = torch.full((B, 3), 0.9, device=device)
|
||||
conf_loss = F.binary_cross_entropy(confs.float(), c_tgt.float())
|
||||
with autocast('cuda'):
|
||||
loss = ce_loss + conf_weight * conf_loss
|
||||
|
||||
opt.zero_grad(set_to_none=True)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(opt)
|
||||
nn.utils.clip_grad_norm_(params, 1.0)
|
||||
scaler.step(opt)
|
||||
scaler.update()
|
||||
|
||||
ep_loss += loss.item()
|
||||
ep_ce += ce_loss.item()
|
||||
ep_conf += conf_loss.item()
|
||||
|
||||
sched.step()
|
||||
val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device)
|
||||
|
||||
n = len(train_loader)
|
||||
logging.info(
|
||||
f"[StageA] Epoch {epoch+1:3d}/{cfg['epochs']} | "
|
||||
f"loss={ep_loss/n:.4f} ce={ep_ce/n:.4f} conf={ep_conf/n:.4f} | "
|
||||
f"val_wf1={val_wf1:.4f} acc={val_acc:.4f}"
|
||||
)
|
||||
wandb.log({"A/loss": ep_loss/n, "A/ce": ep_ce/n,
|
||||
"A/conf": ep_conf/n, "A/val_wf1": val_wf1,
|
||||
"A/val_acc": val_acc, "epoch": epoch + 1})
|
||||
|
||||
enc_state = encoder.module.state_dict() if hasattr(encoder, "module") else encoder.state_dict()
|
||||
cls_state = classifier.module.state_dict() if hasattr(classifier, "module") else classifier.state_dict()
|
||||
|
||||
if val_wf1 > best_wf1:
|
||||
best_wf1 = val_wf1
|
||||
save_ckpt({
|
||||
"epoch": epoch + 1,
|
||||
"encoder": enc_state,
|
||||
"classifier": cls_state,
|
||||
"val_wf1": val_wf1,
|
||||
"text_dim": text_dim, "audio_dim": audio_dim,
|
||||
"vision_dim": vision_dim, "num_classes": num_classes,
|
||||
"proj_dim": proj_dim, "cfg": cfg,
|
||||
}, os.path.join(args.output, "best.ckpt"))
|
||||
logging.info(f" -> New best WF1: {val_wf1:.4f}")
|
||||
|
||||
logging.info(f"Stage A done. Best val WF1: {best_wf1:.4f}")
|
||||
save_ckpt({
|
||||
"epoch": cfg["epochs"],
|
||||
"encoder": enc_state,
|
||||
"classifier": cls_state,
|
||||
"text_dim": text_dim, "audio_dim": audio_dim,
|
||||
"vision_dim": vision_dim, "num_classes": num_classes,
|
||||
"proj_dim": proj_dim, "cfg": cfg,
|
||||
}, os.path.join(args.output, "last.ckpt"))
|
||||
|
||||
enc_m = encoder.module if hasattr(encoder, "module") else encoder
|
||||
cls_m = classifier.module if hasattr(classifier, "module") else classifier
|
||||
dims = dict(text_dim=text_dim, audio_dim=audio_dim,
|
||||
vision_dim=vision_dim, num_classes=num_classes, proj_dim=proj_dim)
|
||||
return enc_m, cls_m, dims
|
||||
|
||||
|
||||
# ── Stage B: PPO training ─────────────────────────────────────────────────
|
||||
|
||||
def collect_rollout(encoder, classifier, agent, dataset, device, rollout_size, cfg, prev_weights):
|
||||
encoder.eval()
|
||||
classifier.eval()
|
||||
agent.eval()
|
||||
bs = cfg.get("batch_size", 128)
|
||||
nprob = cfg.get("noise_prob", 0.5)
|
||||
rng = np.random.default_rng()
|
||||
states, actions, log_probs, values, rewards = [], [], [], [], []
|
||||
collected = 0
|
||||
|
||||
with torch.no_grad():
|
||||
while collected < rollout_size:
|
||||
bsz = min(bs, rollout_size - collected)
|
||||
idx = rng.integers(0, len(dataset), size=bsz)
|
||||
|
||||
text = torch.from_numpy(dataset.text[idx]).to(device)
|
||||
audio = torch.from_numpy(dataset.audio[idx]).to(device)
|
||||
vision = torch.from_numpy(dataset.vision[idx]).to(device)
|
||||
labels = torch.from_numpy(dataset.labels[idx]).to(device)
|
||||
|
||||
if rng.random() < nprob and dataset.variant_names:
|
||||
vname = rng.choice(dataset.variant_names)
|
||||
v = dataset.noisy_variants[vname]
|
||||
text, audio, vision, labels = _noisy_batch(dataset, v, idx, device)
|
||||
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
|
||||
weights, log_p, value, _ = agent.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
logits = classifier(fused)
|
||||
|
||||
rew, _ = compute_reward(
|
||||
logits, labels, confs, weights, prev_weights,
|
||||
alpha=cfg.get("reward_alpha", 1.0),
|
||||
beta =cfg.get("reward_beta", 0.3),
|
||||
gamma=cfg.get("reward_gamma", 0.1),
|
||||
)
|
||||
|
||||
states.append(state)
|
||||
actions.append(weights)
|
||||
log_probs.append(log_p)
|
||||
values.append(value.squeeze(-1))
|
||||
rewards.append(rew)
|
||||
collected += bsz
|
||||
|
||||
states = torch.cat(states)
|
||||
actions = torch.cat(actions)
|
||||
log_probs = torch.cat(log_probs)
|
||||
values = torch.cat(values).cpu()
|
||||
rewards = torch.cat(rewards).cpu()
|
||||
|
||||
# FIX: save raw stats before normalization
|
||||
# The normalized mean is always 0 by construction — useless for logging
|
||||
raw_rew_mean = rewards.mean().item()
|
||||
raw_rew_std = rewards.std().item()
|
||||
|
||||
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
|
||||
advantages = rewards - values.detach()
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
return dict(states=states, actions=actions, log_probs=log_probs,
|
||||
values=values, rewards=rewards, advantages=advantages,
|
||||
mean_weights=actions.mean(0),
|
||||
raw_rew_mean=raw_rew_mean, raw_rew_std=raw_rew_std)
|
||||
|
||||
|
||||
def ppo_update(agent, opt, rollout, cfg, device, scaler):
|
||||
eps = cfg.get("ppo_clip", 0.2)
|
||||
ppo_ep = cfg.get("ppo_epochs_per_update", 4)
|
||||
mb_size = cfg.get("ppo_mini_batch", 256)
|
||||
v_coef = cfg.get("value_coef", 0.5)
|
||||
ent_coef = cfg.get("entropy_coef", 0.01)
|
||||
|
||||
states = rollout["states"].to(device)
|
||||
actions = rollout["actions"].to(device)
|
||||
old_lp = rollout["log_probs"].to(device)
|
||||
adv = rollout["advantages"].to(device)
|
||||
ret = rollout["rewards"].to(device)
|
||||
n = states.size(0)
|
||||
total_pl = total_vl = total_ent = cnt = 0.0
|
||||
agent.train()
|
||||
|
||||
for _ in range(ppo_ep):
|
||||
perm = torch.randperm(n, device=device)
|
||||
for start in range(0, n, mb_size):
|
||||
idx = perm[start:start + mb_size]
|
||||
s = states[idx]; a = actions[idx]
|
||||
olp = old_lp[idx]; ad = adv[idx]; r = ret[idx]
|
||||
with autocast('cuda'):
|
||||
new_lp, val, ent = agent.evaluate(s, a)
|
||||
val = val.squeeze(-1)
|
||||
ratio = (new_lp - olp).exp()
|
||||
p_loss = -torch.min(ratio*ad,
|
||||
torch.clamp(ratio, 1-eps, 1+eps)*ad).mean()
|
||||
v_loss = F.mse_loss(val, r)
|
||||
e_loss = -ent.mean()
|
||||
loss = p_loss + v_coef*v_loss + ent_coef*e_loss
|
||||
opt.zero_grad(set_to_none=True)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(opt)
|
||||
nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
|
||||
scaler.step(opt)
|
||||
scaler.update()
|
||||
total_pl += p_loss.item()
|
||||
total_vl += v_loss.item()
|
||||
total_ent += ent.mean().item()
|
||||
cnt += 1
|
||||
|
||||
return dict(p_loss=total_pl/cnt, v_loss=total_vl/cnt, entropy=total_ent/cnt)
|
||||
|
||||
|
||||
def train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids):
|
||||
data_dir = os.path.join(PROJ, "data", args.dataset.lower())
|
||||
noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
|
||||
train_ds = MultimodalDataset(data_dir, "train", load_noisy=True,
|
||||
noise_root=noise_root)
|
||||
val_ds = MultimodalDataset(data_dir, "val")
|
||||
eff_bs = cfg.get("batch_size", 128) * len(gpu_ids)
|
||||
val_loader = get_dataloader(val_ds, eff_bs, shuffle=False,
|
||||
distributed=False, drop_last=False)
|
||||
|
||||
for p in encoder.parameters():
|
||||
p.requires_grad_(False)
|
||||
encoder.to(device).eval()
|
||||
|
||||
classifier.to(device)
|
||||
opt_cls = torch.optim.AdamW(classifier.parameters(),
|
||||
lr=cfg.get("cls_lr", 5e-5), weight_decay=1e-4)
|
||||
|
||||
agent = ModalFusionAgent(state_dim=4,
|
||||
hidden=cfg.get("agent_hidden", 128)).to(device)
|
||||
opt_agent = torch.optim.Adam(agent.parameters(), lr=cfg.get("rl_lr", 3e-4))
|
||||
scaler = GradScaler()
|
||||
|
||||
rollout_size = cfg.get("rollout_steps", 512)
|
||||
n_updates = cfg.get("n_ppo_updates", 500)
|
||||
eval_every = cfg.get("eval_every", 10)
|
||||
best_wf1 = 0.0
|
||||
prev_weights = None
|
||||
|
||||
for upd in range(n_updates):
|
||||
rollout = collect_rollout(
|
||||
encoder, classifier, agent,
|
||||
train_ds, device, rollout_size, cfg, prev_weights,
|
||||
)
|
||||
prev_weights = rollout["mean_weights"].to(device)
|
||||
ppo_info = ppo_update(agent, opt_agent, rollout, cfg, device, scaler)
|
||||
|
||||
if upd % 2 == 0:
|
||||
idx = np.random.randint(0, len(train_ds), eff_bs)
|
||||
text = torch.from_numpy(train_ds.text[idx]).to(device)
|
||||
audio = torch.from_numpy(train_ds.audio[idx]).to(device)
|
||||
vision = torch.from_numpy(train_ds.vision[idx]).to(device)
|
||||
labels = torch.from_numpy(train_ds.labels[idx]).to(device)
|
||||
with torch.no_grad():
|
||||
tf, af, vf, confs = encoder(text, audio, vision)
|
||||
noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
|
||||
state = torch.cat([confs, noise_est], dim=-1)
|
||||
weights, *_ = agent.get_action_and_value(state)
|
||||
fused = weights[:, 0:1]*tf + weights[:, 1:2]*af + weights[:, 2:3]*vf
|
||||
with autocast('cuda'):
|
||||
logits = classifier(fused)
|
||||
loss = F.cross_entropy(logits, labels)
|
||||
opt_cls.zero_grad(set_to_none=True)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(opt_cls)
|
||||
scaler.update()
|
||||
|
||||
if upd % eval_every == 0:
|
||||
val_wf1, val_acc = evaluate(encoder, classifier, val_loader, device,
|
||||
agent=agent)
|
||||
# FIX: use raw (pre-normalization) reward for meaningful logging
|
||||
raw_rew = rollout["raw_rew_mean"]
|
||||
raw_std = rollout["raw_rew_std"]
|
||||
mw = rollout["mean_weights"] # mean fusion weights [text, audio, visual]
|
||||
logging.info(
|
||||
f"[StageB] PPO {upd:4d}/{n_updates} | "
|
||||
f"rew={raw_rew:.4f}(+/-{raw_std:.3f}) p={ppo_info['p_loss']:.4f} "
|
||||
f"v={ppo_info['v_loss']:.4f} ent={ppo_info['entropy']:.4f} | "
|
||||
f"val_wf1={val_wf1:.4f} | "
|
||||
f"w=[t:{mw[0]:.3f} a:{mw[1]:.3f} v:{mw[2]:.3f}]"
|
||||
)
|
||||
wandb.log({
|
||||
"B/reward": raw_rew,
|
||||
"B/rew_std": raw_std,
|
||||
"B/w_text": mw[0].item(),
|
||||
"B/w_audio": mw[1].item(),
|
||||
"B/w_visual": mw[2].item(),
|
||||
"B/p_loss": ppo_info["p_loss"],
|
||||
"B/v_loss": ppo_info["v_loss"],
|
||||
"B/entropy": ppo_info["entropy"],
|
||||
"B/val_wf1": val_wf1,
|
||||
"ppo_update": upd,
|
||||
})
|
||||
|
||||
if val_wf1 > best_wf1:
|
||||
best_wf1 = val_wf1
|
||||
save_ckpt({
|
||||
"update": upd,
|
||||
"encoder": encoder.state_dict(),
|
||||
"classifier": classifier.state_dict(),
|
||||
"agent": agent.state_dict(),
|
||||
"val_wf1": val_wf1,
|
||||
**dims,
|
||||
}, os.path.join(args.output, "best.ckpt"))
|
||||
logging.info(f" -> New best WF1: {val_wf1:.4f}")
|
||||
|
||||
logging.info(f"Stage B done. Best val WF1: {best_wf1:.4f}")
|
||||
|
||||
|
||||
# ── Main ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--stage", required=True, choices=["supervised", "rl"])
|
||||
p.add_argument("--dataset", default="IEMOCAP")
|
||||
p.add_argument("--config", required=True)
|
||||
p.add_argument("--output", required=True)
|
||||
p.add_argument("--checkpoint", default=None)
|
||||
p.add_argument("--gpus", default="0,1,2,3",
|
||||
help="Comma-separated GPU ids to use")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
gpu_ids = [int(g) for g in args.gpus.split(",")]
|
||||
device = torch.device(f"cuda:{gpu_ids[0]}")
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
log_dir = os.path.join(PROJ, "outputs", "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
stage_tag = "stageA" if args.stage == "supervised" else "stageB"
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(
|
||||
os.path.join(log_dir, f"{stage_tag}.log"), mode="a"),
|
||||
],
|
||||
)
|
||||
logging.info(f"Using GPUs: {gpu_ids} (DataParallel, primary: cuda:{gpu_ids[0]})")
|
||||
|
||||
with open(os.path.join(PROJ, args.config)) as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
os.environ.setdefault("WANDB_MODE", "offline")
|
||||
wandb.init(
|
||||
project="multimodal_affect",
|
||||
name=f"d1_{args.stage}_{args.dataset}_{time.strftime('%m%d_%H%M')}",
|
||||
config={**cfg, "stage": args.stage, "dataset": args.dataset,
|
||||
"gpus": gpu_ids},
|
||||
dir=os.path.join(PROJ, "outputs"),
|
||||
)
|
||||
|
||||
if args.stage == "supervised":
|
||||
train_stage_a(args, cfg, device, gpu_ids)
|
||||
|
||||
elif args.stage == "rl":
|
||||
if not args.checkpoint:
|
||||
raise ValueError("--checkpoint required for --stage rl")
|
||||
ckpt = torch.load(args.checkpoint, map_location=device)
|
||||
|
||||
text_dim = ckpt["text_dim"]
|
||||
audio_dim = ckpt["audio_dim"]
|
||||
vision_dim = ckpt["vision_dim"]
|
||||
num_classes = ckpt["num_classes"]
|
||||
proj_dim = ckpt.get("proj_dim", 1024)
|
||||
dims = dict(text_dim=text_dim, audio_dim=audio_dim,
|
||||
vision_dim=vision_dim, num_classes=num_classes,
|
||||
proj_dim=proj_dim)
|
||||
|
||||
encoder = MultimodalEncoder(text_dim, audio_dim, vision_dim, proj_dim)
|
||||
classifier = EmotionClassifier(proj_dim, num_classes)
|
||||
encoder.load_state_dict(ckpt["encoder"])
|
||||
classifier.load_state_dict(ckpt["classifier"])
|
||||
logging.info(
|
||||
f"Loaded Stage A ckpt: {args.checkpoint} "
|
||||
f"val_wf1={ckpt.get('val_wf1', 0.0):.4f}"
|
||||
)
|
||||
train_stage_b(args, cfg, encoder, classifier, dims, device, gpu_ids)
|
||||
|
||||
wandb.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user