Files
CompanionGuard-RL/旧方向信息/scripts/preprocess/extract_mosi.py

242 lines
9.2 KiB
Python
Raw Normal View History

"""
CMU-MOSI feature extraction script.
Supports two pickle formats:
Format A CMU Multimodal SDK (aligned_50.pkl):
data[split][modality][sample_id] = np.ndarray
modalities: 'text', 'audio', 'vision', 'labels'
splits: 'train', 'valid', 'test'
Format B declare-lab flat array (mosi.pkl):
data[split][modality] = np.ndarray shape (N, dim)
modalities: 'glove'(text), 'covarep'(audio), 'facet'(visual), 'label'
splits: 'train', 'valid', 'test'
Output: $ZSY/multimodal_affect/data/mosi/
{train,val,test}_{text,audio,vision,labels}.npy
meta.json
"""
import os
import json
import argparse
import pickle
import numpy as np
from pathlib import Path
# (lower, upper, class_id) cut points on the continuous [-3, 3] sentiment score.
# NOTE(review): sentiment_to_class hard-codes the same thresholds rather than
# reading this table; the inclusive/exclusive edges are defined there.
SENTIMENT_BINS = [
    (-np.inf, -1, 0),
    (-1, 1, 1),
    (1, np.inf, 2),
]
# Human-readable class names, indexed by class id.
LABEL_NAMES = [
    "negative",
    "neutral",
    "positive",
]
def sentiment_to_class(score: float) -> int:
    """Map a continuous sentiment score (nominally [-3, 3]) to a 3-class label.

    Thresholds (class ids match the order of LABEL_NAMES):
        score < -1        -> 0 (negative)
        -1 <= score <= 1  -> 1 (neutral)
        score > 1         -> 2 (positive)

    Raises:
        ValueError: if *score* is NaN. A NaN fails both comparisons below and
            would otherwise be silently labelled as positive.
    """
    if np.isnan(score):
        raise ValueError(f"sentiment score is NaN: {score!r}")
    if score < -1:
        return 0
    if score <= 1:
        return 1
    return 2
def load_sdk_pickle(pkl_path: str):
    """Deserialize and return the object stored in the pickle at *pkl_path*.

    The pickle is decoded with ``encoding="latin1"``, which is what the pickle
    docs require for NumPy arrays pickled under Python 2.
    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(pkl_path, mode="rb") as stream:
        return pickle.load(stream, encoding="latin1")
def _pool_sample(feat, default_dim: int) -> np.ndarray:
    """Collapse one sample's modality feature into a 1-D float32 vector.

    ``None`` (modality absent) yields zeros of length *default_dim*. A 2-D
    array is treated as (time, dim) and mean-pooled over time; any other rank
    is flattened. Non-finite entries are zeroed first so that a single
    NaN/inf frame cannot poison the pooled mean.
    """
    if feat is None:
        return np.zeros(default_dim, dtype=np.float32)
    arr = np.nan_to_num(
        np.asarray(feat, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0
    )
    if arr.ndim != 2:
        return arr.reshape(-1)
    if arr.shape[0] == 0:
        # Zero time steps: avoid the NaN (and RuntimeWarning) of an empty mean.
        return np.zeros(arr.shape[1], dtype=np.float32)
    return arr.mean(axis=0)


def extract_from_sdk(pkl_path: str, out_dir: Path):
    """Extract mean-pooled features from a CMU-SDK dict-of-dicts pickle.

    Expects ``data[split][modality][sample_id]`` with splits train/valid/test
    and modalities text/audio/vision/labels. For each split this writes
    ``{split}_{text,audio,vision}.npy`` (float32, one row per sample) and
    ``{split}_labels.npy`` (int64 3-class ids) into *out_dir*; "valid" is
    written as "val".

    Splits that are absent, have no "labels" entry, or end up with no labelled
    samples are skipped. Samples with a missing or empty label are dropped.
    A modality absent from the split is filled with zero vectors.
    """
    data = load_sdk_pickle(pkl_path)
    split_map = {"train": "train", "valid": "val", "test": "test"}
    # Zero-vector width used when a whole modality is missing from a split.
    fallback_dims = {"text": 300, "audio": 74, "vision": 35}

    for sdk_split, out_split in split_map.items():
        if sdk_split not in data:
            print(f" [skip] split '{sdk_split}' not in pickle")
            continue
        split_data = data[sdk_split]
        if "labels" not in split_data:
            # Previously this raised KeyError on the first sample.
            print(f" [skip] split '{sdk_split}' has no 'labels'")
            continue
        label_dict = split_data["labels"]
        # Sample order follows the text modality when present.
        ids = list(split_data.get("text", label_dict).keys())
        if not ids:
            continue

        feats = {name: [] for name in fallback_dims}
        labels = []
        for sid in ids:
            lbl_raw = label_dict.get(sid)
            if lbl_raw is None:
                continue
            lbl_flat = np.asarray(lbl_raw).reshape(-1)
            if lbl_flat.size == 0:
                continue
            label = sentiment_to_class(float(lbl_flat[0]))
            pooled = {}
            for name, dim in fallback_dims.items():
                modality = split_data.get(name)
                pooled[name] = _pool_sample(
                    None if modality is None else modality[sid], dim
                )
            for name, vec in pooled.items():
                feats[name].append(vec)
            labels.append(label)

        if not labels:
            continue
        for name, rows in feats.items():
            np.save(out_dir / f"{out_split}_{name}.npy", np.stack(rows))
        np.save(out_dir / f"{out_split}_labels.npy", np.array(labels, dtype=np.int64))
        print(f" {out_split}: {len(labels)} samples")
def is_flat_format(data: dict) -> bool:
    """Return True if *data* looks like the declare-lab flat array layout.

    In the flat layout ``data[split][modality]`` is an ``np.ndarray``; in the
    CMU-SDK layout it is a per-sample dict. The first non-empty mapping among
    the train/valid/test splits is inspected via its first value. Inputs that
    are not mappings, or have no usable split, return False instead of raising.
    """
    from collections.abc import Mapping

    if not isinstance(data, Mapping):
        return False
    for split in ("train", "valid", "test"):
        split_data = data.get(split)
        # Skip absent/empty/non-mapping splits (previously IndexError/AttributeError).
        if not isinstance(split_data, Mapping) or not split_data:
            continue
        first_value = next(iter(split_data.values()))
        return isinstance(first_value, np.ndarray)
    return False
def _pool_batch(arr) -> np.ndarray:
    """Cast *arr* to float32, zero non-finite values, and mean-pool over time.

    A 3-D input is treated as (N, T, dim) and averaged over T. Lower-rank
    inputs only get the cast and cleanup. Zeroing NaN/inf first keeps a single
    bad frame from turning the pooled feature into NaN/inf.
    """
    out = np.nan_to_num(
        np.asarray(arr, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0
    )
    if out.ndim == 3:
        if out.shape[1] == 0:
            # No time steps: avoid the NaN (and RuntimeWarning) of an empty mean.
            return np.zeros((out.shape[0], out.shape[2]), dtype=np.float32)
        out = out.mean(axis=1)
    return out


def extract_from_flat(pkl_path: str, out_dir: Path):
    """Extract from a declare-lab flat pickle (mosi.pkl).

    Format: ``data[split][modality] = np.ndarray`` of shape (N, dim) or
    (N, T, dim). Modality keys are matched against known aliases using the
    first split present (train, valid, test). A modality that is not found is
    written as zeros. Labels are continuous scores in [-3, 3], discretised into
    3 classes by sentiment_to_class. If no label key exists, every sample is
    given score 0 (neutral).

    Writes ``{split}_{text,audio,vision,labels}.npy`` into *out_dir*
    ("valid" is written as "val").

    Raises:
        ValueError: if a split has neither a label key nor any feature key to
            size it, or if a feature array's row count differs from the number
            of labels (the saved arrays would otherwise be silently misaligned).
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")
    split_map = {"train": "train", "valid": "val", "test": "test"}

    # Detect modality names once, from the first split that exists
    # (only "train" used to be consulted).
    ref_split = next((s for s in split_map if s in data), None)
    available = set(data[ref_split].keys()) if ref_split is not None else set()

    def pick(aliases):
        return next((k for k in aliases if k in available), None)

    text_key = pick(("glove", "text", "bert"))
    audio_key = pick(("covarep", "audio", "opensmile"))
    vision_key = pick(("facet", "vision", "visual"))
    label_key = pick(("label", "labels", "Opinion Segment Labels"))
    print(f" Detected keys — text:{text_key} audio:{audio_key} vision:{vision_key} label:{label_key}")

    for sdk_split, out_split in split_map.items():
        if sdk_split not in data:
            print(f" [skip] '{sdk_split}' not found")
            continue
        sd = data[sdk_split]

        if label_key:
            labels_raw = np.asarray(sd[label_key]).reshape(-1)
        else:
            size_key = text_key or audio_key or vision_key
            if size_key is None:
                raise ValueError(
                    f"split '{sdk_split}' has no recognised label or feature keys: {list(sd)}"
                )
            labels_raw = np.zeros(len(sd[size_key]))
        labels = np.array([sentiment_to_class(float(s)) for s in labels_raw], dtype=np.int64)
        n = len(labels)

        # Zero-feature widths for modalities missing from the pickle.
        # NOTE(review): the SDK path uses 35 for vision, not 46 — confirm.
        text = _pool_batch(sd[text_key]) if text_key else np.zeros((n, 300), dtype=np.float32)
        audio = _pool_batch(sd[audio_key]) if audio_key else np.zeros((n, 74), dtype=np.float32)
        vision = _pool_batch(sd[vision_key]) if vision_key else np.zeros((n, 46), dtype=np.float32)

        for name, arr in (("text", text), ("audio", audio), ("vision", vision)):
            if arr.shape[0] != n:
                raise ValueError(
                    f"split '{sdk_split}': {name} has {arr.shape[0]} rows but {n} labels"
                )

        np.save(out_dir / f"{out_split}_text.npy", text)
        np.save(out_dir / f"{out_split}_audio.npy", audio)
        np.save(out_dir / f"{out_split}_vision.npy", vision)
        np.save(out_dir / f"{out_split}_labels.npy", labels)
        print(f" {out_split}: {n} samples text{text.shape} audio{audio.shape} vision{vision.shape}")
def extract_from_raw(raw_dir: Path, out_dir: Path):
    """Fallback: compute simple waveform statistics from raw WAV files.

    Every ``*.wav`` under *raw_dir* (recursive, sorted by path) is decoded to
    a mono float signal, and ``[mean, std, max, min]`` of that signal becomes
    its feature row. Writes ``train_audio.npy`` (float32, shape (N, 4)) and
    ``train_labels.npy`` (int64). No labels are read in this mode, so every
    sample gets the placeholder class 1 (neutral). Unreadable or empty files
    are reported and skipped.
    """
    import wave

    def load_wav_stdlib(path):
        """Read a PCM WAV file and return a mono float32 signal scaled to [-1, 1)."""
        with wave.open(str(path), "rb") as f:
            n_ch = f.getnchannels()
            sw = f.getsampwidth()
            raw = f.readframes(f.getnframes())
        # The wave module only reads integer PCM, so every width is an integer
        # format (4 bytes is int32, not float32). WAV samples are little-endian.
        if sw == 1:
            # 8-bit WAV PCM is unsigned with a 128 offset.
            s = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
        elif sw == 2:
            s = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
        elif sw == 3:
            b = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
            v = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
            v = np.where(v >= 1 << 23, v - (1 << 24), v)  # sign-extend 24-bit
            s = v.astype(np.float32) / float(1 << 23)
        elif sw == 4:
            s = np.frombuffer(raw, dtype="<i4").astype(np.float32) / float(1 << 31)
        else:
            raise ValueError(f"unsupported sample width: {sw} bytes")
        return s.reshape(-1, n_ch).mean(axis=1) if n_ch > 1 else s

    print("[raw mode] scanning", raw_dir)
    wav_files = sorted(raw_dir.rglob("*.wav"))
    if not wav_files:
        print(" No WAV files found under", raw_dir)
        return

    feats = []
    for wf in wav_files:
        try:
            sig = load_wav_stdlib(wf)
        except (wave.Error, EOFError, OSError, ValueError) as e:
            print(f" [warn] {wf.name}: {e}")
            continue
        if sig.size == 0:
            print(f" [warn] {wf.name}: empty audio")
            continue
        feats.append([sig.mean(), sig.std(), sig.max(), sig.min()])

    if feats:
        np.save(out_dir / "train_audio.npy", np.asarray(feats, dtype=np.float32))
        np.save(out_dir / "train_labels.npy", np.ones(len(feats), dtype=np.int64))
        print(f" Saved {len(feats)} raw samples")
def extract_mosi(data_root: str, out_dir: str):
    """Locate CMU-MOSI data under *data_root* and write features to *out_dir*.

    Known pickle filenames are tried in order. The first one that exists is
    probed with is_flat_format to choose the flat or CMU-SDK extractor. If no
    pickle exists, raw WAV files under ``CMU_MOSI/Raw`` are used. On success,
    ``meta.json`` (label names, task, source, and format when a pickle was
    used) is written to *out_dir*. If nothing is found, an error is printed
    and no meta.json is written.

    NOTE: pickle.load can execute arbitrary code; only point this at trusted
    dataset files.
    """
    data_root = Path(data_root)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    meta = {"label_names": LABEL_NAMES, "task": "sentiment-3class"}

    # Pickle candidates, covering both SDK and declare-lab flat formats.
    pkl_candidates = [
        data_root / "mosi.pkl",  # declare-lab flat
        data_root / "aligned_mosi.pkl",  # mmsdk aligned
        data_root / "CMU_MOSI" / "Processed" / "aligned_50.pkl",  # SDK standard
        data_root / "CMU_MOSI" / "Processed" / "unaligned_50.pkl",
        data_root / "mosi_data.pkl",
        data_root / "aligned_50.pkl",
    ]
    pkl = next((p for p in pkl_candidates if p.exists()), None)

    if pkl is not None:
        print(f"Found pickle: {pkl}")
        with open(pkl, "rb") as f:
            probe = pickle.load(f, encoding="latin1")
        flat = is_flat_format(probe)
        # The extractors re-load the pickle themselves; drop the probe first so
        # two full copies of the dataset are never held in memory at once.
        del probe
        if flat:
            print(" Format: declare-lab flat array")
            extract_from_flat(str(pkl), out_dir)
        else:
            print(" Format: CMU-SDK dict-of-dicts")
            extract_from_sdk(str(pkl), out_dir)
        meta["source"] = str(pkl)
        meta["format"] = "flat" if flat else "sdk"
    else:
        raw_dir = data_root / "CMU_MOSI" / "Raw"
        if not raw_dir.exists():
            print(f"[error] No CMU-MOSI data found under {data_root}")
            print(" Tried:", [str(p) for p in pkl_candidates])
            return
        extract_from_raw(raw_dir, out_dir)
        meta["source"] = str(raw_dir)

    with open(out_dir / "meta.json", "w") as f:
        json.dump(meta, f, indent=2)
    print("Done →", out_dir)
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_root",
        required=True,
        help="Dir containing CMU_MOSI/ subdirectory",
    )
    cli.add_argument("--out_dir", default=None)
    opts = cli.parse_args()

    # Without an explicit --out_dir, write under $ZSY (defaulting to /root).
    if opts.out_dir:
        target_dir = opts.out_dir
    else:
        zsy_root = os.environ.get("ZSY", "/root")
        target_dir = f"{zsy_root}/multimodal_affect/data/mosi"
    extract_mosi(opts.data_root, target_dir)