# File: CompanionGuard-RL/旧方向信息/scripts/preprocess/extract_meld.py
# (header copied from the repository web view: 201 lines, 6.7 KiB, Python)
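"""
MELD (Multimodal EmotionLines Dataset) feature extraction.

Expected dataset layout:
    $DATA_ROOT/MELD.Raw/
        train_sent_emo.csv
        dev_sent_emo.csv
        test_sent_emo.csv
        train/  dev/  test/      subdirectories with one mp4 clip per utterance,
            dia{N}_utt{M}.mp4    named by Dialogue_ID (N) and Utterance_ID (M)

CSV columns:
    Sr No., Utterance, Speaker, Emotion, Sentiment,
    Dialogue_ID, Utterance_ID, Season, Episode, StartTime, EndTime

Emotions (label ids 0-6, in this order):
    neutral, surprise, fear, sadness, joy, disgust, anger

Output: $ZSY/multimodal_affect/data/meld/
    {train,val,test}_{text,audio,labels}.npy   (the CSV "dev" split is saved as "val")
        text:   hashed token ids, one fixed-length row per utterance
        audio:  mean MFCC vector per utterance (zeros when no clip is available)
        labels: int64 emotion ids
    label_map.json                             (label id -> emotion name)
"""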
import argparse
import csv
import json
import os
import wave
import zlib
from pathlib import Path

import numpy as np
# Emotion vocabulary; a name's position in this list is its integer class id.
LABEL_NAMES = ["neutral", "surprise", "fear", "sadness", "joy", "disgust", "anger"]
# Reverse lookup: emotion name -> class id (same insertion order as LABEL_NAMES).
EMOTION_MAP = {name: idx for idx, name in enumerate(LABEL_NAMES)}
# Number of MFCC coefficients, i.e. the width of every audio feature vector.
N_MFCC = 40
# ── audio loading ──────────────────────────────────────────────────────────
def _load_audio_bytes(path: str, target_sr: "int | None" = 16000) -> np.ndarray:
    """Load a clip's audio track as a mono float32 signal.

    ``.mp4`` / ``.mp3`` files are decoded with PyAV (imported lazily); any
    other path is read as a PCM WAV file with the stdlib ``wave`` module.

    Args:
        path: Path to the media file (anything ``str()``-able).
        target_sr: Sample rate the returned signal is resampled to, so that it
            matches the ``sr`` later handed to the MFCC step (16 kHz by
            default). ``None`` keeps the file's native rate.

    Returns:
        1-D float32 signal, roughly in [-1, 1]. For mp4/mp3 input, one second
        of silence (``target_sr`` zeros, 16000 if ``target_sr`` is None) is
        returned when PyAV is unavailable, decoding fails, or the file has no
        audio stream.

    Raises:
        wave.Error, ValueError: unreadable or unsupported WAV input.
    """
    path = str(path)
    silence = np.zeros(target_sr or 16000, dtype=np.float32)
    if path.endswith((".mp4", ".mp3")):
        try:
            decoded = _decode_with_av(path)
        except Exception as e:  # best effort: caller gets silence, not a crash
            print(f" av failed for {path}: {e}")
            return silence
        if decoded is None:
            return silence
        sig, sr = decoded
    else:
        sig, sr = _read_wav(path)
    if target_sr is not None:
        sig = _resample_linear(sig, sr, target_sr)
    return sig


def _decode_with_av(path: str):
    """Decode the first audio stream with PyAV; return (mono, sample_rate) or None if no audio."""
    import av

    container = av.open(path)
    try:  # always release the container, including on early return / error
        stream = next((s for s in container.streams if s.type == "audio"), None)
        if stream is None:
            return None
        chunks = []
        sr = None
        for pkt in container.demux(stream):
            for frame in pkt.decode():
                arr = _pcm_to_float(frame.to_ndarray())
                if frame.format.is_planar:
                    # Planar layout: one row per channel -> average the rows.
                    mono = arr.mean(axis=0) if arr.ndim == 2 else arr
                else:
                    # Packed layout: channels interleaved in a single row of
                    # samples * channels values -> de-interleave, then average.
                    n_ch = max(1, arr.size // max(frame.samples, 1))
                    mono = arr.reshape(-1, n_ch).mean(axis=1)
                chunks.append(mono.astype(np.float32))
                if sr is None:
                    sr = frame.sample_rate
    finally:
        container.close()
    if not chunks:
        return None
    return np.concatenate(chunks), sr


def _read_wav(path: str):
    """Read a PCM WAV file and return (mono float32 signal, sample rate)."""
    with wave.open(path, "rb") as f:
        n_ch = f.getnchannels()
        sw = f.getsampwidth()
        sr = f.getframerate()
        raw = f.readframes(f.getnframes())
    if sw == 1:
        # 8-bit WAV samples are unsigned.
        sig = _pcm_to_float(np.frombuffer(raw, dtype=np.uint8))
    elif sw == 2:
        sig = _pcm_to_float(np.frombuffer(raw, dtype="<i2"))
    elif sw == 3:
        # 24-bit little-endian: prepend a zero low byte to get int32 (sample << 8).
        triples = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3)
        padded = np.zeros((triples.shape[0], 4), dtype=np.uint8)
        padded[:, 1:] = triples
        sig = _pcm_to_float(padded.view("<i4").reshape(-1))
    elif sw == 4:
        sig = _pcm_to_float(np.frombuffer(raw, dtype="<i4"))
    else:
        raise ValueError(f"unsupported WAV sample width: {sw} bytes")
    if n_ch > 1:
        sig = sig.reshape(-1, n_ch).mean(axis=1).astype(np.float32)
    return sig, sr


def _pcm_to_float(arr: np.ndarray) -> np.ndarray:
    """Scale integer PCM samples to float32 in [-1, 1); float input is only cast."""
    if arr.dtype.kind == "f":
        return arr.astype(np.float32)
    info = np.iinfo(arr.dtype)
    if info.min == 0:
        # Unsigned PCM is offset-binary, centred at mid-range.
        mid = (int(info.max) + 1) / 2
        return ((arr.astype(np.float64) - mid) / mid).astype(np.float32)
    return (arr.astype(np.float64) / (int(info.max) + 1)).astype(np.float32)


def _resample_linear(sig: np.ndarray, src_sr, dst_sr: int) -> np.ndarray:
    """Resample *sig* from src_sr to dst_sr by linear interpolation.

    NOTE: no anti-aliasing filter is applied when downsampling; a polyphase
    resampler would be more accurate but needs an extra dependency.
    """
    if not src_sr or src_sr == dst_sr or sig.size == 0:
        return sig.astype(np.float32, copy=False)
    n_out = int(round(sig.size * dst_sr / src_sr))
    positions = np.arange(n_out) * (src_sr / dst_sr)
    return np.interp(positions, np.arange(sig.size), sig).astype(np.float32)
def _compute_mfcc_mean(signal: np.ndarray, sr: int = 16000) -> np.ndarray:
    """Summarise a clip as the time-average of its N_MFCC MFCC coefficients.

    Uses librosa when it can be imported. Otherwise, or if librosa raises, a
    placeholder vector is returned: entry 0 is the signal RMS and the rest are
    zero. The two paths yield features on very different scales, so the
    fallback is announced (once when librosa is missing, per clip when it
    fails) instead of being silently mixed into the dataset.

    Args:
        signal: Audio samples; cast to float32 before processing.
        sr: Sample rate of *signal* in Hz.

    Returns:
        float32 array of shape (N_MFCC,). All zeros for an empty signal
        (the RMS of nothing is undefined and previously produced NaN).
    """
    signal = np.asarray(signal, dtype=np.float32)
    if signal.size == 0:
        return np.zeros(N_MFCC, dtype=np.float32)
    try:
        import librosa
    except ImportError:
        librosa = None
    if librosa is not None:
        try:
            mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=N_MFCC)
            return mfcc.mean(axis=1).astype(np.float32)
        except Exception as e:
            print(f" [warn] librosa MFCC failed ({e}); using RMS fallback")
    elif not getattr(_compute_mfcc_mean, "_warned_no_librosa", False):
        print(" [warn] librosa not available; audio features fall back to RMS only")
        _compute_mfcc_mean._warned_no_librosa = True
    # energy-based fallback
    rms = float(np.sqrt(np.mean(signal ** 2) + 1e-9))
    feat = np.zeros(N_MFCC, dtype=np.float32)
    feat[0] = rms
    return feat
# ── text features ──────────────────────────────────────────────────────────
def _text_features(text: str, max_len: int = 64, vocab_size: int = 30522) -> np.ndarray:
    """Map an utterance to a fixed-length vector of hashed token ids.

    Tokens are the lower-cased, whitespace-separated words of *text*,
    truncated to *max_len* and right-padded with 0. Ids come from CRC-32,
    which (unlike the built-in ``hash`` on str, randomized per process via
    PYTHONHASHSEED) is stable, so features are reproducible across runs.
    Real tokens are mapped into ``1 .. vocab_size - 1`` so id 0 is reserved
    for padding and never collides with a word.

    Args:
        text: Utterance text.
        max_len: Output length (number of token slots).
        vocab_size: Size of the id space, including the padding id 0.

    Returns:
        int32 array of shape (max_len,).
    """
    tokens = text.lower().split()[:max_len]
    ids = [zlib.crc32(tok.encode("utf-8")) % (vocab_size - 1) + 1 for tok in tokens]
    ids.extend([0] * (max_len - len(ids)))
    return np.array(ids, dtype=np.int32)
# ── csv parsing ────────────────────────────────────────────────────────────
def read_csv(csv_path: str):
    """Read a CSV file into a list of row dicts keyed by its header row.

    The file is opened with ``newline=""``, as the csv module requires, so
    quoted fields that contain line breaks (e.g. multi-line utterances) are
    kept intact. It is also opened with ``utf-8-sig``, so a leading byte-order
    mark does not leak into the first column name; BOM-less UTF-8 files read
    exactly as with plain ``utf-8``.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        list of dicts, one per data row.
    """
    with open(csv_path, encoding="utf-8-sig", newline="") as f:
        return list(csv.DictReader(f))
def extract_split(csv_path: str, clip_dir: Path, out_prefix: Path,
                  has_video: bool = True):
    """Extract text/audio features and labels for one MELD split and save them.

    Rows whose ``Emotion`` (lower-cased) is not in EMOTION_MAP are skipped.
    Each kept row yields:
      * text: ``_text_features(Utterance)``
      * audio: MFCC mean of ``clip_dir / f"dia{Dialogue_ID}_utt{Utterance_ID}.mp4"``,
        or an N_MFCC-wide zero vector when audio is disabled, the clip is
        missing, or processing raises. Missing/failed counts are printed so
        zero-filled audio is not silent.
      * label: ``EMOTION_MAP[emotion]``

    Writes ``{split}_text.npy``, ``{split}_audio.npy`` and ``{split}_labels.npy``
    (int64) into ``out_prefix.parent``, where ``split = out_prefix.name``.
    Nothing is written when the CSV has no valid rows.

    Args:
        csv_path: Path to the split's CSV file.
        clip_dir: Directory holding the split's mp4 clips.
        out_prefix: Output directory joined with the split name (e.g. ``out/train``).
        has_video: Whether to look for audio clips at all.
    """
    records = read_csv(csv_path)
    # Decide once whether clips can be read, instead of re-checking per row.
    use_audio = has_video and clip_dir.exists()
    if has_video and not use_audio:
        print(f" [warn] clip dir {clip_dir} not found; audio features will be zeros")
    texts, audios, labels_list = [], [], []
    n_missing = n_failed = 0
    for rec in records:
        emo = rec.get("Emotion", "").strip().lower()
        if emo not in EMOTION_MAP:
            continue
        utterance = rec.get("Utterance", "").strip()
        dia_id = rec.get("Dialogue_ID", "").strip()
        utt_id = rec.get("Utterance_ID", "").strip()
        audio_feat = np.zeros(N_MFCC, dtype=np.float32)
        if use_audio:
            clip_name = f"dia{dia_id}_utt{utt_id}.mp4"
            clip_path = clip_dir / clip_name
            if clip_path.exists():
                try:
                    sig = _load_audio_bytes(str(clip_path))
                    audio_feat = _compute_mfcc_mean(sig)
                except Exception as e:
                    n_failed += 1
                    print(f" [warn] {clip_name}: {e}")
            else:
                n_missing += 1
        texts.append(_text_features(utterance))
        audios.append(audio_feat)
        labels_list.append(EMOTION_MAP[emo])
    if not labels_list:
        print(f" [warn] no valid records in {csv_path}")
        return
    split = out_prefix.name
    base = out_prefix.parent
    # Stack each array once and reuse it for saving and for the summary line.
    text_arr = np.stack(texts)
    audio_arr = np.stack(audios)
    np.save(base / f"{split}_text.npy", text_arr)
    np.save(base / f"{split}_audio.npy", audio_arr)
    np.save(base / f"{split}_labels.npy", np.array(labels_list, dtype=np.int64))
    print(f" {split}: {len(labels_list)} samples, "
          f"text {text_arr.shape}, audio {audio_arr.shape}")
    if n_missing or n_failed:
        print(f" {split}: {n_missing} clips missing, {n_failed} raised errors "
              "(zero audio features used for these)")
def extract_meld(data_root: str, out_dir: str):
    """Extract features for every available MELD split and write label_map.json.

    Looks for the split CSVs in ``data_root/MELD.Raw`` (or in *data_root*
    itself when that subdirectory does not exist). Splits whose CSV is absent
    are skipped with a message. The CSV "dev" split is written as "val". When
    no clip directory is found for a split, a warning is printed and that
    split's audio features are all zeros.

    Args:
        data_root: Directory containing ``MELD.Raw/``, or ``MELD.Raw`` itself.
        out_dir: Output directory (created if needed).
    """
    data_root = Path(data_root)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    meld_root = data_root / "MELD.Raw"
    if not meld_root.exists():
        meld_root = data_root  # maybe already inside MELD.Raw
    csv_map = {
        "train": "train_sent_emo.csv",
        "val": "dev_sent_emo.csv",
        "test": "test_sent_emo.csv",
    }
    # Clip-directory candidates per split, tried in order. The first name is
    # the train/dev/test layout this script has always used. The second is an
    # alternative naming -- NOTE(review): assumed to match how the official
    # MELD.Raw archives unpack; confirm against your data.
    dir_candidates = {
        "train": ("train", "train_splits"),
        "val": ("dev", "dev_splits_complete"),
        "test": ("test", "output_repeated_splits_test"),
    }
    for split, csv_name in csv_map.items():
        csv_path = meld_root / csv_name
        if not csv_path.exists():
            print(f" [skip] {csv_path} not found")
            continue
        candidates = [meld_root / name for name in dir_candidates[split]]
        clip_dir = next((d for d in candidates if d.is_dir()), candidates[0])
        has_video = clip_dir.is_dir()
        if not has_video:
            tried = ", ".join(str(d) for d in candidates)
            print(f" [warn] no clip dir for {split} (tried {tried}); "
                  "audio features will be zeros")
        extract_split(str(csv_path), clip_dir, out_dir / split, has_video=has_video)
    # json turns the int keys into strings: {"0": "neutral", ...}
    label_map = dict(enumerate(LABEL_NAMES))
    with open(out_dir / "label_map.json", "w", encoding="utf-8") as f:
        json.dump(label_map, f, indent=2)
    print("Done →", out_dir)
if __name__ == "__main__":
    # Command-line entry point: parse the paths and run the full extraction.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_root",
        required=True,
        help="Dir containing MELD.Raw/ (or already inside it)",
    )
    cli.add_argument("--out_dir", default=None)
    opts = cli.parse_args()
    # Without --out_dir, output goes under $ZSY (falling back to /root when unset).
    default_out = "{}/multimodal_affect/data/meld".format(os.environ.get("ZSY", "/root"))
    extract_meld(opts.data_root, opts.out_dir if opts.out_dir else default_out)