Files
CompanionGuard-RL/旧方向信息/scripts/preprocess/extract_mosi.py
zhangsiyuan bd1f51c496 chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-14 11:28:42 +08:00

242 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
CMU-MOSI feature extraction script.

Reads a CMU-MOSI pickle (or, as a fallback, raw WAV files), mean-pools each
modality over time, maps continuous sentiment scores to 3 classes
(negative / neutral / positive) and writes per-split NumPy arrays.

Supports two pickle formats:

Format A — CMU Multimodal SDK (aligned_50.pkl):
    data[split][modality][sample_id] = np.ndarray
    modalities: 'text', 'audio', 'vision', 'labels'
    splits: 'train', 'valid', 'test'

Format B — declare-lab flat array (mosi.pkl):
    data[split][modality] = np.ndarray of shape (N, dim) or (N, T, dim)
    modalities: 'glove' (text), 'covarep' (audio), 'facet' (visual), 'label'
    splits: 'train', 'valid', 'test'

The 'valid' split is written out as 'val'. The raw-WAV fallback writes only
train_audio.npy and train_labels.npy.

Output: $ZSY/multimodal_affect/data/mosi/ (override with --out_dir;
$ZSY defaults to /root when unset)
    {train,val,test}_{text,audio,vision,labels}.npy
    meta.json
"""
import os
import json
import argparse
import pickle
import numpy as np
from pathlib import Path
# (low, high, class_id) bins describing the 3-class sentiment mapping.
# NOTE(review): not referenced by sentiment_to_class() in the visible code, and
# the tuples do not say which bound is inclusive; sentiment_to_class() treats
# scores of exactly -1 and 1 as neutral.
SENTIMENT_BINS = [(-np.inf, -1, 0), (-1, 1, 1), (1, np.inf, 2)]
# Human-readable names indexed by the class ids from sentiment_to_class();
# written into meta.json.
LABEL_NAMES = ["negative", "neutral", "positive"]
def sentiment_to_class(score: float) -> int:
    """Map a continuous sentiment score (nominally in [-3, 3]) to a class id.

    Returns:
        0 (negative) when score < -1,
        1 (neutral)  when -1 <= score <= 1,
        2 (positive) otherwise.
    """
    # Every comparison with NaN is False, so a NaN score falls through to 2.
    return 0 if score < -1 else (1 if score <= 1 else 2)
def load_sdk_pickle(pkl_path: str):
    """Unpickle and return the object stored at *pkl_path*.

    ``encoding="latin1"`` allows Python 3 to read pickles written by
    Python 2, including pickled NumPy arrays.
    """
    # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
    with open(pkl_path, "rb") as handle:
        return pickle.load(handle, encoding="latin1")
def _pool_time(feat: np.ndarray) -> np.ndarray:
    """Mean-pool a (T, dim) array over time; flatten arrays of any other rank."""
    return feat.mean(axis=0) if feat.ndim == 2 else feat.flatten()


def extract_from_sdk(pkl_path: str, out_dir: Path):
    """Extract mean-pooled features from a CMU-SDK dict-of-dicts pickle.

    Expects ``data[split][modality][sample_id]`` with split in
    {'train', 'valid', 'test'} and modality in {'text', 'audio', 'vision',
    'labels'}. For each split that yields at least one labelled sample,
    writes ``{split}_{text,audio,vision,labels}.npy`` into *out_dir*
    ('valid' is written as 'val').

    A modality missing from a whole split is replaced by zero placeholders.
    Splits without 'labels' are skipped. Samples missing from a modality
    that the split does contain are skipped and counted in a warning.

    Args:
        pkl_path: Path to the pickle file.
        out_dir: Existing output directory.
    """
    data = load_sdk_pickle(pkl_path)
    split_map = {"train": "train", "valid": "val", "test": "test"}
    modalities = ("text", "audio", "vision")
    # Zero placeholders (T=1, dim) used when a modality is absent from a split.
    fallback_shapes = {"text": (1, 300), "audio": (1, 74), "vision": (1, 35)}
    for sdk_split, out_split in split_map.items():
        if sdk_split not in data:
            print(f" [skip] split '{sdk_split}' not in pickle")
            continue
        split_data = data[sdk_split]
        label_dict = split_data.get("labels")
        if label_dict is None:
            # Every sample needs a label; without them nothing can be written.
            print(f" [skip] split '{sdk_split}' has no 'labels'")
            continue
        # Sample ids come from 'text' when available, otherwise from 'labels'.
        ids = list(split_data.get("text", label_dict).keys())
        present = [m for m in modalities if m in split_data]
        pooled = {m: [] for m in modalities}
        labels = []
        n_incomplete = 0
        for sid in ids:
            lbl_raw = label_dict.get(sid)
            if lbl_raw is None:
                continue
            if any(sid not in split_data[m] for m in present):
                # A modality this split provides lacks the sample: skip it rather
                # than crash with KeyError or stack mismatched placeholder shapes.
                n_incomplete += 1
                continue
            score = float(np.array(lbl_raw).flatten()[0])
            for m in modalities:
                if m in split_data:
                    feat = np.array(split_data[m][sid], dtype=np.float32)
                else:
                    feat = np.zeros(fallback_shapes[m], dtype=np.float32)
                # temporal mean pooling
                pooled[m].append(_pool_time(feat))
            labels.append(sentiment_to_class(score))
        if n_incomplete:
            print(f" [warn] {sdk_split}: skipped {n_incomplete} samples missing a modality")
        if not labels:
            continue
        for m in modalities:
            np.save(out_dir / f"{out_split}_{m}.npy", np.stack(pooled[m]))
        np.save(out_dir / f"{out_split}_labels.npy", np.array(labels, dtype=np.int64))
        print(f" {out_split}: {len(labels)} samples")
def is_flat_format(data: dict) -> bool:
    """Return True if *data* looks like the declare-lab flat layout.

    The flat layout stores ``data[split][modality] = np.ndarray``; the CMU-SDK
    layout stores a dict per modality instead. The first split among 'train',
    'valid', 'test' that is a non-empty dict decides, based on the type of its
    first value. Empty or non-dict splits are skipped so they cannot raise.
    Returns False when no such split exists or *data* is not a dict.
    """
    if not isinstance(data, dict):
        return False
    for split in ("train", "valid", "test"):
        split_data = data.get(split)
        if not isinstance(split_data, dict) or not split_data:
            continue
        first_value = next(iter(split_data.values()))
        return isinstance(first_value, np.ndarray)
    return False
def extract_from_flat(pkl_path: str, out_dir: Path):
    """Extract features from a declare-lab flat pickle (mosi.pkl).

    Format: ``data[split][key] = np.ndarray`` of shape (N, dim) or (N, T, dim);
    3-D arrays are mean-pooled over T. Labels are continuous scores
    (nominally in [-3, 3]) discretised to 3 classes by sentiment_to_class().

    Key names are detected from the first split present among 'train',
    'valid' and 'test'. A missing modality is written as zeros. If no label
    key is found, every sample gets score 0 (class 1, neutral) and a warning
    is printed. Writes ``{split}_{text,audio,vision,labels}.npy`` into
    *out_dir* ('valid' is written as 'val').

    Raises:
        KeyError: if a split has neither a label key nor any modality key
            (its sample count cannot be determined), or lacks a key that
            was detected in the reference split.
    """
    with open(pkl_path, "rb") as f:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        data = pickle.load(f, encoding="latin1")
    split_map = {"train": "train", "valid": "val", "test": "test"}
    # Detect key names from the first available split, not only 'train':
    # a pickle without 'train' would otherwise leave every key as None.
    reference = next((data[s] for s in split_map if s in data), {})

    def find_key(candidates):
        """Return the first candidate key present in the reference split, else None."""
        return next((k for k in candidates if k in reference), None)

    def load_modality(sd, key, n, default_dim):
        """Return (N, dim) float32 features for *key*, or zeros when key is None."""
        if key is None:
            return np.zeros((n, default_dim), dtype=np.float32)
        feat = sd[key].astype(np.float32)
        # mean-pool the time dimension if present: (N, T, dim) -> (N, dim)
        return feat.mean(axis=1) if feat.ndim == 3 else feat

    # modality name aliases
    text_key = find_key(("glove", "text", "bert"))
    audio_key = find_key(("covarep", "audio", "opensmile"))
    vision_key = find_key(("facet", "vision", "visual"))
    label_key = find_key(("label", "labels", "Opinion Segment Labels"))
    print(f" Detected keys — text:{text_key} audio:{audio_key} vision:{vision_key} label:{label_key}")
    if label_key is None:
        print(" [warn] no label key found; all samples will be labelled neutral")
    for sdk_split, out_split in split_map.items():
        if sdk_split not in data:
            print(f" [skip] '{sdk_split}' not found")
            continue
        sd = data[sdk_split]
        if label_key is not None:
            labels_raw = sd[label_key].flatten()
        else:
            # No labels: size the all-zero score vector from any modality.
            size_key = text_key or audio_key or vision_key
            if size_key is None:
                raise KeyError("flat pickle has no recognised label or modality keys")
            labels_raw = np.zeros(len(sd[size_key]))
        labels = np.array([sentiment_to_class(float(s)) for s in labels_raw], dtype=np.int64)
        n = len(labels)
        text = load_modality(sd, text_key, n, 300)
        audio = load_modality(sd, audio_key, n, 74)
        vision = load_modality(sd, vision_key, n, 46)
        np.save(out_dir / f"{out_split}_text.npy", text)
        np.save(out_dir / f"{out_split}_audio.npy", audio)
        np.save(out_dir / f"{out_split}_vision.npy", vision)
        np.save(out_dir / f"{out_split}_labels.npy", labels)
        print(f" {out_split}: {n} samples text{text.shape} audio{audio.shape} vision{vision.shape}")
def extract_from_raw(raw_dir: Path, out_dir: Path):
    """Fallback: build crude audio features directly from WAV files.

    Each ``*.wav`` under *raw_dir* (searched recursively, sorted by path)
    becomes one sample. Its feature vector is ``[mean, std, max, min]`` of the
    mono waveform scaled to [-1, 1). No labels are available in this mode, so
    every sample gets class 1 (neutral). Writes only ``train_audio.npy`` and
    ``train_labels.npy`` into *out_dir*. Files that cannot be decoded, or that
    contain no frames, are reported and skipped.
    """
    import wave

    def load_wav_stdlib(path):
        """Read a PCM WAV file and return a mono float32 signal in [-1, 1)."""
        with wave.open(str(path), "rb") as f:
            n_ch = f.getnchannels()
            sw = f.getsampwidth()
            raw = f.readframes(f.getnframes())
        # The wave module only reads integer PCM, so every width is a signed
        # little-endian integer, except 8-bit, which is unsigned.
        if sw == 1:
            s = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
        elif sw == 2:
            s = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
        elif sw == 3:
            b = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
            v = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
            v = np.where(v >= 1 << 23, v - (1 << 24), v)  # sign-extend 24-bit
            s = v.astype(np.float32) / float(1 << 23)
        elif sw == 4:
            s = np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0
        else:
            raise ValueError(f"unsupported sample width: {sw} bytes")
        return s.reshape(-1, n_ch).mean(axis=1) if n_ch > 1 else s

    print("[raw mode] scanning", raw_dir)
    wav_files = sorted(raw_dir.rglob("*.wav"))
    if not wav_files:
        print(" No WAV files found under", raw_dir)
        return
    feats = []
    for wf in wav_files:
        try:
            sig = load_wav_stdlib(wf)
            if sig.size == 0:
                raise ValueError("no audio frames")
            feats.append(np.array([sig.mean(), sig.std(), sig.max(), sig.min()], dtype=np.float32))
        except Exception as e:  # best effort: report and skip undecodable files
            print(f" [warn] {wf.name}: {e}")
    if feats:
        np.save(out_dir / "train_audio.npy", np.stack(feats))
        # Raw files carry no labels; default every sample to class 1 (neutral).
        np.save(out_dir / "train_labels.npy", np.ones(len(feats), dtype=np.int64))
        print(f" Saved {len(feats)} raw samples")
def extract_mosi(data_root: str, out_dir: str):
    """Locate CMU-MOSI data under *data_root* and write features to *out_dir*.

    Tries a fixed list of pickle locations in order. The first one that exists
    is probed to choose the flat (declare-lab) or dict-of-dicts (CMU-SDK)
    extractor. If no pickle exists, falls back to raw WAVs under
    ``data_root/CMU_MOSI/Raw``. After extraction writes ``meta.json`` with the
    label names, task, source path and format ("flat", "sdk" or "raw"). If
    nothing is found, prints the tried paths and returns without writing
    meta.json.
    """
    data_root = Path(data_root)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    meta = {"label_names": LABEL_NAMES, "task": "sentiment-3class"}
    # try pickle candidates (both SDK and declare-lab flat formats)
    pkl_candidates = [
        data_root / "mosi.pkl",  # declare-lab flat
        data_root / "aligned_mosi.pkl",  # mmsdk aligned
        data_root / "CMU_MOSI" / "Processed" / "aligned_50.pkl",  # SDK standard
        data_root / "CMU_MOSI" / "Processed" / "unaligned_50.pkl",
        data_root / "mosi_data.pkl",
        data_root / "aligned_50.pkl",
    ]
    pkl = next((p for p in pkl_candidates if p.exists()), None)
    if pkl is not None:
        print(f"Found pickle: {pkl}")
        with open(pkl, "rb") as f:
            # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
            probe = pickle.load(f, encoding="latin1")
        flat = is_flat_format(probe)
        # The extractors reload the file themselves; drop the probe first so two
        # full copies of the dataset are not held in memory at the same time.
        del probe
        if flat:
            print(" Format: declare-lab flat array")
            extract_from_flat(str(pkl), out_dir)
        else:
            print(" Format: CMU-SDK dict-of-dicts")
            extract_from_sdk(str(pkl), out_dir)
        meta["source"] = str(pkl)
        meta["format"] = "flat" if flat else "sdk"
    else:
        raw_dir = data_root / "CMU_MOSI" / "Raw"
        if not raw_dir.exists():
            print(f"[error] No CMU-MOSI data found under {data_root}")
            print(" Tried:", [str(p) for p in pkl_candidates])
            return
        extract_from_raw(raw_dir, out_dir)
        meta["source"] = str(raw_dir)
        meta["format"] = "raw"
    with open(out_dir / "meta.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print("Done →", out_dir)
if __name__ == "__main__":
    # CLI entry point: parse paths and run the extraction.
    parser = argparse.ArgumentParser(description="Extract CMU-MOSI features to .npy files.")
    parser.add_argument("--data_root", required=True,
                        help="Dir containing mosi.pkl / aligned_50.pkl (or other supported "
                             "pickles) directly, or a CMU_MOSI/ subdirectory")
    parser.add_argument("--out_dir", default=None,
                        help="Output dir (default: $ZSY/multimodal_affect/data/mosi; "
                             "$ZSY defaults to /root)")
    args = parser.parse_args()
    zsy = os.environ.get("ZSY", "/root")
    out_dir = args.out_dir or f"{zsy}/multimodal_affect/data/mosi"
    extract_mosi(args.data_root, out_dir)