242 lines
9.2 KiB
Python
242 lines
9.2 KiB
Python
|
|
"""
|
|||
|
|
CMU-MOSI feature extraction script.

Supports two pickle formats:

  Format A – CMU Multimodal SDK (e.g. aligned_50.pkl):
    data[split][modality][sample_id] = np.ndarray
    modalities: 'text', 'audio', 'vision', 'labels'
    splits: 'train', 'valid', 'test'

  Format B – declare-lab flat array (mosi.pkl):
    data[split][modality] = np.ndarray, shape (N, dim) or (N, T, dim)
    modalities: 'glove' (text), 'covarep' (audio), 'facet' (visual), 'label'
    splits: 'train', 'valid', 'test'

Time axes are mean-pooled, and continuous sentiment scores are mapped to
3 classes (negative / neutral / positive). If no pickle is found, raw WAV
files under CMU_MOSI/Raw are reduced to simple audio statistics instead
(train split only, all labels neutral).

Output (default; override with --out_dir): $ZSY/multimodal_affect/data/mosi/
    {train,val,test}_{text,audio,vision,labels}.npy   ('valid' is saved as 'val')
    meta.json
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import json
|
|||
|
|
import argparse
|
|||
|
|
import pickle
|
|||
|
|
import numpy as np
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
# (lower, upper, class_index) triples describing the 3-class sentiment split.
# NOTE(review): sentiment_to_class() does not read this table; it applies the
# same -1 / 1 thresholds itself (score < -1 -> 0, -1 <= score <= 1 -> 1,
# score > 1 -> 2). Keep the two in sync if either changes.
SENTIMENT_BINS = [
    (-np.inf, -1, 0),
    (-1, 1, 1),
    (1, np.inf, 2),
]

# Human-readable name of each class index; position i names class i.
LABEL_NAMES = [
    "negative",
    "neutral",
    "positive",
]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def sentiment_to_class(
    score: float,
    *,
    neg_threshold: float = -1.0,
    pos_threshold: float = 1.0,
) -> int:
    """Map a continuous sentiment score (nominally in [-3, 3]) to a 3-class label.

    Class indices follow LABEL_NAMES: 0 = negative, 1 = neutral, 2 = positive.

    Args:
        score: Continuous sentiment score.
        neg_threshold: Scores strictly below this are negative (default -1).
        pos_threshold: Scores strictly above this are positive (default 1).
            Scores in the closed range [neg_threshold, pos_threshold] are
            neutral.

    Returns:
        The class index 0, 1 or 2.

    Raises:
        ValueError: if ``score`` is NaN. A NaN fails every comparison below
            and would otherwise fall through to the "positive" class.
    """
    if np.isnan(score):
        raise ValueError("sentiment score is NaN")
    if score < neg_threshold:
        return 0
    if score <= pos_threshold:
        return 1
    return 2
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_sdk_pickle(pkl_path: str):
    """Return the object stored in the CMU-SDK pickle at ``pkl_path``.

    ``encoding="latin1"`` is what the pickle docs prescribe for reading
    Python-2-era pickles that contain NumPy arrays.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point
    this at dataset files from a trusted source.
    """
    with open(pkl_path, "rb") as handle:
        return pickle.load(handle, encoding="latin1")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _pool_sample(feat: np.ndarray) -> np.ndarray:
    """Collapse one sample's features to a 1-D vector.

    A 2-D array is treated as (time, dim) and mean-pooled over axis 0; any
    other rank is flattened.
    """
    return feat.mean(axis=0) if feat.ndim == 2 else feat.flatten()


def extract_from_sdk(pkl_path: str, out_dir: Path):
    """Extract pooled features from a pre-aligned CMU-SDK pickle.

    Expects ``data[split][modality][sample_id]`` with splits 'train' /
    'valid' / 'test' and modalities 'text' / 'audio' / 'vision' / 'labels'.
    For each split present, writes into ``out_dir``:
    ``{split}_{text,audio,vision}.npy`` (float32, one pooled row per sample)
    and ``{split}_labels.npy`` (int64 3-class labels). 'valid' is written
    as 'val'.

    A modality absent from a split is filled with zero features of width
    300 (text), 74 (audio) or 35 (vision). Samples without a label are
    dropped. Samples missing from a modality that *is* present in the split
    are skipped and counted, instead of aborting the run with a KeyError.
    """
    data = load_sdk_pickle(pkl_path)

    split_map = {"train": "train", "valid": "val", "test": "test"}
    # Width of the zero placeholder used when a modality is absent.
    fallback_dims = {"text": 300, "audio": 74, "vision": 35}

    for sdk_split, out_split in split_map.items():
        if sdk_split not in data:
            print(f"  [skip] split '{sdk_split}' not in pickle")
            continue

        split_data = data[sdk_split]
        label_dict = split_data.get("labels")
        if label_dict is None:
            print(f"  [skip] split '{sdk_split}' has no 'labels'")
            continue

        # Sample order follows the text modality when present, else the labels.
        ids = list(split_data.get("text", label_dict).keys())
        if not ids:
            continue

        present = [m for m in fallback_dims if m in split_data]
        pooled = {m: [] for m in fallback_dims}
        labels = []
        n_missing = 0

        for sid in ids:
            lbl_raw = label_dict.get(sid)
            if lbl_raw is None:
                continue
            if any(sid not in split_data[m] for m in present):
                n_missing += 1
                continue

            score = float(np.array(lbl_raw).flatten()[0])
            label = sentiment_to_class(score)

            for m, dim in fallback_dims.items():
                if m in split_data:
                    feat = np.array(split_data[m][sid], dtype=np.float32)
                else:
                    feat = np.zeros((1, dim), dtype=np.float32)
                pooled[m].append(_pool_sample(feat))
            labels.append(label)

        if n_missing:
            print(f"  [warn] {out_split}: skipped {n_missing} samples missing a modality")
        if not labels:
            continue

        for m in fallback_dims:
            np.save(out_dir / f"{out_split}_{m}.npy", np.stack(pooled[m]))
        np.save(out_dir / f"{out_split}_labels.npy", np.array(labels, dtype=np.int64))
        print(f"  {out_split}: {len(labels)} samples")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def is_flat_format(data: dict) -> bool:
    """Return True if ``data`` looks like the declare-lab flat array format.

    The flat format is ``data[split][modality] = np.ndarray``. The CMU-SDK
    format instead nests a per-sample dict under each modality. The decision
    is made from the first value of the first split among 'train', 'valid'
    and 'test' that is present and is a non-empty dict. Empty or non-dict
    splits are skipped rather than raising IndexError / AttributeError.

    Returns:
        True for the flat format. False otherwise, including when ``data``
        is not a dict or no usable split exists.
    """
    if not isinstance(data, dict):
        return False
    for split in ("train", "valid", "test"):
        split_data = data.get(split)
        if not isinstance(split_data, dict) or not split_data:
            continue
        first_value = next(iter(split_data.values()))
        return isinstance(first_value, np.ndarray)
    return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _pool_batch(arr: np.ndarray) -> np.ndarray:
    """Mean-pool a rank-3 (N, T, dim) array over axis 1; other ranks pass through."""
    return arr.mean(axis=1) if arr.ndim == 3 else arr


def extract_from_flat(pkl_path: str, out_dir: Path):
    """Extract from a declare-lab flat pickle (mosi.pkl).

    Format: ``data[split]['glove'|'covarep'|'facet'|'label'] = np.ndarray``
    with samples on the first axis. Continuous labels (nominally in
    [-3, 3]) are discretised into 3 classes with ``sentiment_to_class``.
    Feature arrays with a time axis (rank 3) are mean-pooled over it.

    Modality keys are detected once, from the first of 'train' / 'valid' /
    'test' present in the pickle. A missing feature modality is zero-filled
    (widths 300 / 74 / 46 for text / audio / vision). A missing label key
    yields all-neutral labels.

    Writes ``{split}_{text,audio,vision,labels}.npy`` into ``out_dir``;
    'valid' is written as 'val'.

    Raises:
        ValueError: if a split has neither a label key nor any feature key
            to size it, or if a modality's sample count differs from the
            number of labels.
    """
    with open(pkl_path, "rb") as f:
        # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
        data = pickle.load(f, encoding="latin1")

    split_map = {"train": "train", "valid": "val", "test": "test"}

    # Detect modality-name aliases from the first available split (not only 'train').
    probe_split = next((s for s in split_map if s in data), None)
    available = set(data[probe_split].keys()) if probe_split is not None else set()

    def pick(candidates):
        return next((k for k in candidates if k in available), None)

    text_key = pick(("glove", "text", "bert"))
    audio_key = pick(("covarep", "audio", "opensmile"))
    vision_key = pick(("facet", "vision", "visual"))
    label_key = pick(("label", "labels", "Opinion Segment Labels"))

    print(f"  Detected keys — text:{text_key} audio:{audio_key} vision:{vision_key} label:{label_key}")

    for sdk_split, out_split in split_map.items():
        if sdk_split not in data:
            print(f"  [skip] '{sdk_split}' not found")
            continue
        sd = data[sdk_split]

        if label_key:
            labels_raw = sd[label_key].flatten()
        else:
            # No labels: size the all-neutral label vector from any feature array.
            size_key = text_key or audio_key or vision_key
            if size_key is None:
                raise ValueError(
                    f"split '{sdk_split}' in {pkl_path}: no label or feature key found"
                )
            labels_raw = np.zeros(len(sd[size_key]))
        labels = np.array([sentiment_to_class(float(s)) for s in labels_raw], dtype=np.int64)
        n = len(labels)

        text = _pool_batch(sd[text_key].astype(np.float32)) if text_key else np.zeros((n, 300), dtype=np.float32)
        audio = _pool_batch(sd[audio_key].astype(np.float32)) if audio_key else np.zeros((n, 74), dtype=np.float32)
        vision = _pool_batch(sd[vision_key].astype(np.float32)) if vision_key else np.zeros((n, 46), dtype=np.float32)

        feats = {"text": text, "audio": audio, "vision": vision}
        # Validate every modality before writing anything, so a bad split
        # does not leave a partial set of files behind.
        for name, arr in feats.items():
            if len(arr) != n:
                raise ValueError(
                    f"split '{sdk_split}': {name} has {len(arr)} samples but there are {n} labels"
                )
        for name, arr in feats.items():
            np.save(out_dir / f"{out_split}_{name}.npy", arr)
        np.save(out_dir / f"{out_split}_labels.npy", labels)
        print(f"  {out_split}: {n} samples  text{text.shape} audio{audio.shape} vision{vision.shape}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_from_raw(raw_dir: Path, out_dir: Path):
|
|||
|
|
"""Fallback: extract from raw files using local MFCC + hashed text."""
|
|||
|
|
import wave
|
|||
|
|
import struct
|
|||
|
|
|
|||
|
|
def load_wav_stdlib(path):
|
|||
|
|
with wave.open(str(path), "rb") as f:
|
|||
|
|
n_ch = f.getnchannels()
|
|||
|
|
sw = f.getsampwidth()
|
|||
|
|
raw = f.readframes(f.getnframes())
|
|||
|
|
if sw == 2:
|
|||
|
|
s = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768
|
|||
|
|
else:
|
|||
|
|
s = np.frombuffer(raw, dtype=np.float32)
|
|||
|
|
return s.reshape(-1, n_ch).mean(axis=1) if n_ch > 1 else s
|
|||
|
|
|
|||
|
|
print("[raw mode] scanning", raw_dir)
|
|||
|
|
wav_files = sorted(raw_dir.rglob("*.wav"))
|
|||
|
|
if not wav_files:
|
|||
|
|
print(" No WAV files found under", raw_dir)
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
data = []
|
|||
|
|
for wf in wav_files:
|
|||
|
|
try:
|
|||
|
|
sig = load_wav_stdlib(str(wf))
|
|||
|
|
feat = sig.mean(), sig.std(), sig.max(), sig.min()
|
|||
|
|
text_feat = np.array([hash(wf.stem) % 30522], dtype=np.float32)
|
|||
|
|
data.append((text_feat, np.array(feat, dtype=np.float32), 1)) # neutral default
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f" [warn] {wf.name}: {e}")
|
|||
|
|
|
|||
|
|
if data:
|
|||
|
|
np.save(out_dir / "train_audio.npy", np.stack([x[1] for x in data]))
|
|||
|
|
np.save(out_dir / "train_labels.npy", np.array([x[2] for x in data]))
|
|||
|
|
print(f" Saved {len(data)} raw samples")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_mosi(data_root: str, out_dir: str):
    """Locate CMU-MOSI data under ``data_root`` and write features to ``out_dir``.

    The candidate pickle paths below are tried in order. The first that
    exists is probed with ``is_flat_format`` to choose the declare-lab flat
    or the CMU-SDK extractor. If no pickle exists, the function falls back
    to raw WAV files under ``data_root/CMU_MOSI/Raw``.

    On success, writes ``meta.json`` (label names, task, source path,
    format) into ``out_dir``. If nothing is found, prints the paths tried
    and returns without writing it.
    """
    data_root = Path(data_root)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    meta = {"label_names": LABEL_NAMES, "task": "sentiment-3class"}

    # Pickle candidates, covering both the SDK and the declare-lab flat formats.
    pkl_candidates = [
        data_root / "mosi.pkl",                                   # declare-lab flat
        data_root / "aligned_mosi.pkl",                           # mmsdk aligned
        data_root / "CMU_MOSI" / "Processed" / "aligned_50.pkl",  # SDK standard
        data_root / "CMU_MOSI" / "Processed" / "unaligned_50.pkl",
        data_root / "mosi_data.pkl",
        data_root / "aligned_50.pkl",
    ]
    raw_dir = data_root / "CMU_MOSI" / "Raw"

    pkl = next((p for p in pkl_candidates if p.exists()), None)
    if pkl is not None:
        print(f"Found pickle: {pkl}")
        with open(pkl, "rb") as f:
            # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
            probe = pickle.load(f, encoding="latin1")
        is_flat = is_flat_format(probe)
        # The extractors re-load the pickle themselves. Drop the probe first
        # so the dataset is not held in memory twice.
        del probe

        if is_flat:
            print("  Format: declare-lab flat array")
            extract_from_flat(str(pkl), out_dir)
        else:
            print("  Format: CMU-SDK dict-of-dicts")
            extract_from_sdk(str(pkl), out_dir)
        meta["source"] = str(pkl)
        meta["format"] = "flat" if is_flat else "sdk"
    elif raw_dir.exists():
        extract_from_raw(raw_dir, out_dir)
        meta["source"] = str(raw_dir)
        meta["format"] = "raw"
    else:
        print(f"[error] No CMU-MOSI data found under {data_root}")
        print("  Tried:", [str(p) for p in pkl_candidates] + [str(raw_dir)])
        return

    with open(out_dir / "meta.json", "w") as f:
        json.dump(meta, f, indent=2)
    print("Done →", out_dir)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract CMU-MOSI features into per-split .npy files."
    )
    parser.add_argument(
        "--data_root",
        required=True,
        help="Dir holding a MOSI pickle (e.g. mosi.pkl, aligned_50.pkl) "
             "or a CMU_MOSI/ subdirectory",
    )
    parser.add_argument(
        "--out_dir",
        default=None,
        help="Output dir (default: $ZSY/multimodal_affect/data/mosi, "
             "with $ZSY defaulting to /root)",
    )
    args = parser.parse_args()

    # Default output location is rooted at $ZSY, falling back to /root.
    zsy = os.environ.get("ZSY", "/root")
    out_dir = args.out_dir or f"{zsy}/multimodal_affect/data/mosi"
    extract_mosi(args.data_root, out_dir)
|