Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
297 lines
13 KiB
Python
297 lines
13 KiB
Python
"""
|
|
Upload and launch test evaluation + D1-4 ablation experiments on server.
|
|
Uses Stage B v1 checkpoint (best val WF1=0.7291).
|
|
"""
|
|
import getpass
import io
import os
import warnings

import paramiko
|
|
# Suppress all warnings for this one-off upload script.
warnings.filterwarnings('ignore')

# Remote directory layout: workspace root, project checkout, and the
# Python interpreter of the project's environment.
ZSY = '/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy'
PROJ = f'{ZSY}/multimodal_affect'
ENV = f'{ZSY}/envs/multimodal_affect/bin/python'
|
|
|
|
# ── eval_d1.py ────────────────────────────────────────────────────────────
|
|
EVAL_SCRIPT = r'''#!/usr/bin/env python3
|
|
"""
|
|
Evaluate Direction-1 checkpoint on test set.
|
|
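Also runs ablation variants: Fixed-Equal (uniform weights), Stage-A-only (or, when no
Stage A checkpoint is given, the Stage B encoder+classifier with the RL agent removed),
and a noise-robustness sweep over the noisy test variants.
The rl-nonoise, rl-noc (beta=0) and rl-nostab (gamma=0) variants are not implemented in this script.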
|
|
|
|
Usage:
|
|
python scripts/eval/eval_d1.py \
|
|
--checkpoint outputs/checkpoints/d1_stageB/best_v1.ckpt \
|
|
--dataset IEMOCAP \
|
|
--gpu 0
|
|
"""
|
|
import os, sys, argparse, json, csv, logging
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from sklearn.metrics import f1_score, accuracy_score, classification_report
|
|
|
|
# Workspace root: taken from the ZSY environment variable when set,
# otherwise the default server path.
ZSY = os.getenv("ZSY", "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy")
PROJ = os.path.join(ZSY, "multimodal_affect")

# Put the project root first on the import path so `src.*` resolves.
sys.path[:0] = [PROJ]
|
|
|
|
from src.data.dataset import MultimodalDataset, get_dataloader
|
|
from src.models.encoders import MultimodalEncoder
|
|
from src.models.classifier import EmotionClassifier
|
|
from src.rl.fusion_agent import ModalFusionAgent
|
|
from src.rl.reward import compute_reward
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
|
|
|
|
|
|
_MODALITIES = ("text", "audio", "vision")


def _fuse_features(tf, af, vf, confs, audio, agent, fixed_w):
    """Combine the three modality features into one fused tensor.

    The agent takes priority over ``fixed_w``; with neither, the three
    features are averaged with equal weight.
    """
    if agent is not None:
        # Agent state = encoder confidences + a sigmoid-squashed std of the
        # audio input along its last dimension.
        noise_est = audio.std(dim=-1, keepdim=True).sigmoid()
        state = torch.cat([confs, noise_est], dim=-1)
        weights, *_ = agent.get_action_and_value(state)
        return weights[:, 0:1] * tf + weights[:, 1:2] * af + weights[:, 2:3] * vf
    if fixed_w is not None:
        return fixed_w[:, 0:1] * tf + fixed_w[:, 1:2] * af + fixed_w[:, 2:3] * vf
    return (tf + af + vf) / 3.0


def _run_inference(encoder, classifier, batches, device, agent, fixed_weights):
    """Shared inference loop over ``(text, audio, vision, labels)`` tuples.

    Returns ``(preds, labels)`` as NumPy arrays concatenated over all batches.
    """
    encoder.eval()
    classifier.eval()
    if agent is not None:
        agent.eval()

    # The fixed weights are only consulted when no agent is given; build the
    # tensor once instead of once per batch.
    fixed_w = None
    if agent is None and fixed_weights is not None:
        fixed_w = torch.tensor(fixed_weights, device=device).view(1, 3)

    preds, labels_all = [], []
    for text, audio, vision, labels in batches:
        text = text.to(device)
        audio = audio.to(device)
        vision = vision.to(device)
        tf, af, vf, confs = encoder(text, audio, vision)
        fused = _fuse_features(tf, af, vf, confs, audio, agent, fixed_w)
        logits = classifier(fused)
        preds.append(logits.argmax(-1).cpu())
        labels_all.append(labels.cpu())
    return torch.cat(preds).numpy(), torch.cat(labels_all).numpy()


def _noisy_batches(loader, arrays):
    """Yield loader batches with modalities swapped for their noisy arrays.

    The noisy arrays are consumed by a running row offset, so the loader must
    iterate the split in the same order the arrays were generated in.

    Raises:
        ValueError: if a noisy array has fewer or more rows than the loader
            yields samples (a misaligned array would otherwise be broadcast
            against the clean modalities and silently corrupt the metrics).
    """
    cursor = 0
    for batch in loader:
        bsz = batch["text"].size(0)
        picked = {}
        for name in _MODALITIES:
            if name not in arrays:
                picked[name] = batch[name]
                continue
            chunk = arrays[name][cursor:cursor + bsz]
            if chunk.size(0) != bsz:
                raise ValueError(
                    f"noisy '{name}' array has {arrays[name].size(0)} rows, "
                    f"but the loader needs rows {cursor}..{cursor + bsz - 1}"
                )
            picked[name] = chunk
        cursor += bsz
        yield picked["text"], picked["audio"], picked["vision"], batch["labels"]

    for name, arr in arrays.items():
        if arr.size(0) != cursor:
            raise ValueError(
                f"noisy '{name}' array has {arr.size(0)} rows, "
                f"but the loader yielded {cursor} samples"
            )


@torch.no_grad()
def predict(encoder, classifier, loader, device, agent=None, fixed_weights=None):
    """Run inference over ``loader`` and return ``(preds, labels)`` arrays.

    Args:
        encoder: callable returning ``(text_feat, audio_feat, vision_feat, confs)``.
        classifier: callable mapping the fused feature to class logits.
        loader: iterable of dict batches with "text", "audio", "vision", "labels".
        device: device the inputs are moved to.
        agent: optional fusion agent; when given, its weights are used.
        fixed_weights: optional 3 weights (text, audio, vision), used only
            when ``agent`` is None. With neither, fusion is a plain average.
    """
    batches = (
        (batch["text"], batch["audio"], batch["vision"], batch["labels"])
        for batch in loader
    )
    return _run_inference(encoder, classifier, batches, device, agent, fixed_weights)


@torch.no_grad()
def predict_noisy(encoder, classifier, loader, device, variant_data, agent=None, fixed_weights=None):
    """Run inference with a noisy variant, replacing any modalities it provides.

    ``variant_data`` maps a modality name ("text", "audio", "vision") to a
    NumPy array holding the noisy inputs for the whole split, row-aligned
    with the (unshuffled) loader. Modalities it does not provide fall back to
    the clean batch tensors. Labels always come from the loader.

    Raises:
        ValueError: if a provided array's row count does not match the number
            of samples the loader yields.
    """
    arrays = {
        name: torch.from_numpy(variant_data[name]).float()
        for name in _MODALITIES
        if name in variant_data
    }
    batches = _noisy_batches(loader, arrays)
    return _run_inference(encoder, classifier, batches, device, agent, fixed_weights)
|
|
|
|
|
|
def metrics(preds, labels, split="test"):
    """Summarise predictions as weighted F1 and accuracy.

    Returns a dict with the split name plus "wf1" and "acc", both rounded
    to four decimals. Classes with no predictions count as zero in the F1.
    """
    weighted_f1 = f1_score(labels, preds, average="weighted", zero_division=0)
    accuracy = accuracy_score(labels, preds)
    summary = {"split": split}
    summary["wf1"] = round(float(weighted_f1), 4)
    summary["acc"] = round(float(accuracy), 4)
    return summary
|
|
|
|
|
|
def load_model(ckpt_path, device):
    """Rebuild encoder, classifier and (if present) fusion agent from a checkpoint.

    The checkpoint must provide "text_dim", "audio_dim", "vision_dim",
    "num_classes", "encoder" and "classifier"; "proj_dim" defaults to 1024.
    An agent is only built when the checkpoint has an "agent" entry.

    Returns:
        ``(encoder, classifier, agent_or_None, checkpoint_dict)`` with every
        module moved to ``device`` and switched to eval mode.
    """
    ckpt = torch.load(ckpt_path, map_location=device)

    input_dims = [ckpt[key] for key in ("text_dim", "audio_dim", "vision_dim")]
    n_classes = ckpt["num_classes"]
    proj_dim = ckpt.get("proj_dim", 1024)

    encoder = MultimodalEncoder(*input_dims, proj_dim)
    classifier = EmotionClassifier(proj_dim, n_classes)
    encoder.load_state_dict(ckpt["encoder"])
    classifier.load_state_dict(ckpt["classifier"])
    for module in (encoder, classifier):
        module.to(device).eval()

    if "agent" not in ckpt:
        return encoder, classifier, None, ckpt

    fusion_agent = ModalFusionAgent(state_dim=4, hidden=128)
    fusion_agent.load_state_dict(ckpt["agent"])
    fusion_agent.to(device).eval()
    return encoder, classifier, fusion_agent, ckpt
|
|
|
|
|
|
def _evaluate_splits(tag, enc, cls, loaders, device, results, **predict_kwargs):
    """Score one variant on every split and store it under ``{tag}_{split}``.

    Returns a dict mapping split name to its ``(preds, labels)`` arrays so the
    caller can build further reports without re-running inference.
    """
    outputs = {}
    for split, loader in loaders:
        preds, labels = predict(enc, cls, loader, device, **predict_kwargs)
        m = metrics(preds, labels, split)
        results[f"{tag}_{split}"] = m
        logging.info(f" [{split}] WF1={m['wf1']:.4f} Acc={m['acc']:.4f}")
        outputs[split] = (preds, labels)
    return outputs


def main():
    """Evaluate a Direction-1 checkpoint and write the result files.

    Steps: (1) RL-Full on val and test, (2) Fixed-Equal ablation,
    (3) Stage-A-only ablation, or the Stage B model with the agent removed
    when ``--stage_a_ckpt`` is not given, (4) noise-robustness sweep on the
    test set, (5) JSON and CSV dumps plus a logged robustness summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--stage_a_ckpt", default=None,
                        help="Stage A ckpt for ablations that need encoder+classifier only")
    parser.add_argument("--dataset", default="IEMOCAP")
    parser.add_argument("--gpu", default="0")
    parser.add_argument("--out_json", default=None)
    parser.add_argument("--out_csv", default=None)
    args = parser.parse_args()

    # Fall back to CPU instead of crashing on machines without CUDA.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
    else:
        logging.warning("CUDA is not available; falling back to CPU")
        device = torch.device("cpu")

    data_dir = os.path.join(PROJ, "data", args.dataset.lower())
    noise_root = os.path.join(PROJ, "data", f"{args.dataset.lower()}_noisy")
    NOISE_VARIANTS = [
        "gaussian_light", "gaussian_heavy", "missing_audio",
        "missing_visual", "text_word_drop_30", "audio_masking_50",
        "realistic_mixed", "audio_time_mask",
    ]

    # Datasets (unshuffled, nothing dropped, so the noisy arrays stay row-aligned)
    val_ds = MultimodalDataset(data_dir, "val")
    test_ds = MultimodalDataset(data_dir, "test")
    val_loader = get_dataloader(val_ds, 128, shuffle=False, drop_last=False)
    test_loader = get_dataloader(test_ds, 128, shuffle=False, drop_last=False)
    loaders = [("val", val_loader), ("test", test_loader)]

    # Load Stage B v1 checkpoint (encoder + classifier + agent)
    enc, cls, agent, ckpt = load_model(args.checkpoint, device)
    logging.info(f"Loaded: {args.checkpoint} val_wf1={ckpt.get('val_wf1',0):.4f}")

    results = {}

    # ── 1. Main evaluation: val + test ────────────────────────────────────
    logging.info("=== Main Evaluation (Stage B RL-Full) ===")
    main_out = _evaluate_splits("RL-Full", enc, cls, loaders, device, results, agent=agent)
    test_preds, test_labels = main_out["test"]
    class_ids = list(range(ckpt["num_classes"]))
    # Pass `labels` explicitly so the report does not fail when a class is
    # absent from the test split.
    rpt = classification_report(test_labels, test_preds,
                                labels=class_ids,
                                target_names=[str(i) for i in class_ids],
                                zero_division=0)
    logging.info(f"\n{rpt}")

    # ── 2. Ablation A: Fixed-Equal (uniform weights, Stage B classifier) ──
    logging.info("=== Ablation: Fixed-Equal ===")
    _evaluate_splits("Fixed-Equal", enc, cls, loaders, device, results,
                     fixed_weights=[1/3, 1/3, 1/3])

    # ── 3. Ablation B: Stage A only (no RL, trained classifier w/ uniform fusion) ─
    if args.stage_a_ckpt:
        logging.info("=== Ablation: Stage-A-Only ===")
        enc_a, cls_a, _, _ = load_model(args.stage_a_ckpt, device)
        _evaluate_splits("StageA-Only", enc_a, cls_a, loaders, device, results)
    else:
        # No Stage A checkpoint: run the Stage B encoder+classifier with
        # agent=None, i.e. uniform fusion.
        logging.info("=== Ablation: RL-Agent-Removed (Stage B enc+cls, uniform fusion) ===")
        _evaluate_splits("NoRL-UniformFusion", enc, cls, loaders, device, results, agent=None)

    # ── 4. Noise robustness evaluation ────────────────────────────────────
    logging.info("=== Noise Robustness (test set) ===")
    for vname in NOISE_VARIANTS:
        vdir = os.path.join(noise_root, vname)
        paths = {
            "text": os.path.join(vdir, "test_text.npy"),
            "audio": os.path.join(vdir, "test_audio.npy"),
            "vision": os.path.join(vdir, "test_vision.npy"),
        }
        available = {mod: path for mod, path in paths.items() if os.path.exists(path)}
        if not available:
            logging.info(f" [{vname}] SKIP (no noisy modality files)")
            continue
        missing = sorted(set(paths) - set(available))
        if missing:
            logging.warning(f" [{vname}] missing noisy files for {missing}; clean same-index modality will be used")
        vdata = {mod: np.load(path).astype(np.float32) for mod, path in available.items()}

        # RL-Full under noise
        preds_rl, labels = predict_noisy(enc, cls, test_loader, device, vdata, agent=agent)
        wf1_rl = float(f1_score(labels, preds_rl, average="weighted", zero_division=0))
        # Fixed-Equal under noise
        preds_fx, _ = predict_noisy(enc, cls, test_loader, device, vdata,
                                    fixed_weights=[1/3, 1/3, 1/3])
        wf1_fx = float(f1_score(labels, preds_fx, average="weighted", zero_division=0))

        results[f"noise_{vname}_RL-Full"] = round(wf1_rl, 4)
        results[f"noise_{vname}_Fixed-Equal"] = round(wf1_fx, 4)
        # Gap of RL relative to Fixed-Equal under the SAME noise (positive
        # means RL is worse). Degradation vs. clean is reported in the
        # summary at the end.
        pct = (1 - wf1_rl / max(wf1_fx, 1e-6)) * 100
        logging.info(f" [{vname}] RL={wf1_rl:.4f} Fixed={wf1_fx:.4f} "
                     f"RL_deficit_vs_Fixed={pct:+.1f}%")

    # ── 5. Save results ───────────────────────────────────────────────────
    results_dir = os.path.join(PROJ, "outputs", "results")
    os.makedirs(results_dir, exist_ok=True)
    out_json = args.out_json or os.path.join(results_dir, "d1_eval.json")
    out_csv = args.out_csv or os.path.join(results_dir, "d1_ablation.csv")
    # Custom output paths may point outside the default results directory.
    for out_path in (out_json, out_csv):
        parent = os.path.dirname(out_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    with open(out_json, "w") as f:
        json.dump(results, f, indent=2)
    logging.info(f"Results saved to {out_json}")

    # CSV for ablation table (only variants that were actually evaluated)
    rows = []
    for variant in ["RL-Full", "Fixed-Equal", "NoRL-UniformFusion", "StageA-Only"]:
        row = {"variant": variant}
        for split in ["val", "test"]:
            k = f"{variant}_{split}"
            if k in results:
                row[f"{split}_wf1"] = results[k]["wf1"]
                row[f"{split}_acc"] = results[k]["acc"]
        if "val_wf1" in row:
            rows.append(row)
    if rows:
        with open(out_csv, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=["variant","val_wf1","val_acc","test_wf1","test_acc"])
            writer.writeheader()
            writer.writerows(rows)
        logging.info(f"Ablation CSV saved to {out_csv}")

    # Noise robustness summary: relative drop of each model vs. its own clean test score
    logging.info("\n=== Noise Robustness Summary ===")
    clean_rl = results.get("RL-Full_test", {}).get("wf1", 0)
    clean_fx = results.get("Fixed-Equal_test", {}).get("wf1", 0)
    for vname in NOISE_VARIANTS:
        rl_k = f"noise_{vname}_RL-Full"
        fx_k = f"noise_{vname}_Fixed-Equal"
        if rl_k in results and fx_k in results:
            rl = results[rl_k]
            fx = results[fx_k]
            rl_drop = (clean_rl - rl) / max(clean_rl, 1e-6) * 100
            fx_drop = (clean_fx - fx) / max(clean_fx, 1e-6) * 100
            logging.info(f" {vname:22s} RL_drop={rl_drop:+5.1f}% Fixed_drop={fx_drop:+5.1f}%")

    logging.info("Evaluation complete.")


if __name__ == "__main__":
    main()
|
|
'''
|
|
|
|
# Connection settings. Host, port and user default to the known server but can
# be overridden through the environment. The password is never stored in the
# source: it comes from D1_SSH_PASSWORD or an interactive prompt.
ssh_host = os.environ.get('D1_SSH_HOST', '10.82.3.180')
ssh_port = int(os.environ.get('D1_SSH_PORT', '20083'))
ssh_user = os.environ.get('D1_SSH_USER', 'root')
ssh_password = os.environ.get('D1_SSH_PASSWORD') or getpass.getpass(
    f'SSH password for {ssh_user}@{ssh_host}: ')

client = paramiko.SSHClient()
# NOTE(review): AutoAddPolicy accepts unknown host keys without verification;
# kept for compatibility, consider pinning the server's host key instead.
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
    client.connect(ssh_host, port=ssh_port, username=ssh_user,
                   password=ssh_password, timeout=30)

    # Make eval dir, and fail loudly if the remote mkdir did not succeed.
    remote_dir = PROJ + '/scripts/eval'
    _, o, e = client.exec_command(f'mkdir -p {remote_dir}', timeout=10)
    o.read()
    mkdir_err = e.read().decode(errors='replace').strip()
    if o.channel.recv_exit_status() != 0:
        raise RuntimeError(f'remote mkdir failed for {remote_dir}: {mkdir_err}')

    sftp = client.open_sftp()
    try:
        sftp.putfo(io.BytesIO(EVAL_SCRIPT.encode()), remote_dir + '/eval_d1.py')
        print("uploaded: scripts/eval/eval_d1.py")
    finally:
        sftp.close()
finally:
    # Always release the SSH connection, even when the upload fails.
    client.close()
|