# Merged code repo (CompanionGuard-RL) into single project-level git.
# Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
# Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.
# Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
# 201 lines, 6.7 KiB, Python
"""
MELD (Multimodal EmotionLines Dataset) feature extraction.

Dataset structure:
  $DATA_ROOT/MELD.Raw/
    train_sent_emo.csv
    dev_sent_emo.csv
    test_sent_emo.csv
    train/ dev/ test/   → subdirs with mp4 clips
      dia{N}_utt{M}.mp4

CSV columns:
  Sr No., Utterance, Speaker, Emotion, Sentiment,
  Dialogue_ID, Utterance_ID, Season, Episode, StartTime, EndTime

Emotions: neutral, surprise, fear, sadness, joy, disgust, anger

Output: $ZSY/multimodal_affect/data/meld/
  {train,val,test}_{text,audio,labels}.npy
  label_map.json
"""
|
|
|
|
import argparse
import csv
import json
import os
import wave
import zlib
from pathlib import Path

import numpy as np
|
|
|
|
|
|
# Canonical emotion order: the position of each name is its integer class id.
# This list is the single source of truth; EMOTION_MAP is derived from it so
# the name -> id and id -> name tables can never drift apart.
LABEL_NAMES = ["neutral", "surprise", "fear", "sadness", "joy", "disgust", "anger"]

# Lower-case emotion name -> integer class id (neutral=0 ... anger=6).
EMOTION_MAP = {name: idx for idx, name in enumerate(LABEL_NAMES)}

# Length of the per-utterance audio feature vector (number of MFCC coefficients).
N_MFCC = 40
|
|
|
|
|
|
# ── audio loading ──────────────────────────────────────────────────────────
|
|
def _load_audio_bytes(path: str) -> np.ndarray:
    """Decode an audio file into a mono float32 waveform.

    ``.mp4`` / ``.mp3`` files (extension matched case-insensitively) are
    decoded with PyAV, which is imported lazily because it is an optional
    dependency. Every other path is read as PCM WAV with the stdlib ``wave``
    module.

    No resampling is performed: samples come back at the file's own rate.

    Args:
        path: Location of the audio or video file (anything ``str()`` accepts).

    Returns:
        1-D float32 array. Integer PCM is scaled to roughly [-1, 1) and
        channels are averaged. For ``.mp4`` / ``.mp3`` input, 16000 zeros are
        returned when PyAV is unavailable, decoding fails, or the file holds
        no audio.

    Raises:
        ValueError: For a WAV sample width other than 1, 2, 3 or 4 bytes.
        wave.Error, OSError: If a non-mp4/mp3 path cannot be read as WAV.
    """
    path = str(path)
    if path.lower().endswith((".mp4", ".mp3")):
        try:
            import av  # optional dependency, only needed for compressed media

            chunks = []
            container = av.open(path)
            try:
                stream = next(
                    (s for s in container.streams if s.type == "audio"), None
                )
                if stream is not None:
                    for pkt in container.demux(stream):
                        for frame in pkt.decode():
                            arr = frame.to_ndarray()
                            if arr.dtype.kind in "iu":
                                # Scale integer PCM like the WAV branch does so
                                # both code paths share one amplitude range.
                                half = float(2 ** (8 * arr.dtype.itemsize - 1))
                                arr = arr.astype(np.float32)
                                if frame.to_ndarray().dtype.kind == "u":
                                    arr = arr - half  # unsigned PCM is offset-binary
                                arr = arr / np.float32(half)
                            else:
                                arr = arr.astype(np.float32)
                            if arr.ndim == 2:
                                arr = arr.mean(axis=0)
                            chunks.append(arr.astype(np.float32))
            finally:
                # Always release the demuxer, including on the no-audio path
                # and when decoding raises part-way through.
                container.close()
            if chunks:
                return np.concatenate(chunks)
        except Exception as e:
            # Best effort: a broken clip must not abort a whole extraction run.
            print(f"  av failed for {path}: {e}")
        return np.zeros(16000, dtype=np.float32)

    # WAV via stdlib
    with wave.open(path, "rb") as f:
        n_ch = f.getnchannels()
        sw = f.getsampwidth()
        raw = f.readframes(f.getnframes())

    # WAV PCM is little-endian; 8-bit samples are unsigned, wider ones signed.
    if sw == 1:
        sig = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    elif sw == 2:
        sig = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768
    elif sw == 3:
        # No native 24-bit dtype: assemble each sample from its three bytes,
        # then sign-extend from bit 23.
        b = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
        val = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
        val = np.where(val >= (1 << 23), val - (1 << 24), val)
        sig = val.astype(np.float32) / np.float32(2 ** 23)
    elif sw == 4:
        sig = np.frombuffer(raw, dtype="<i4").astype(np.float32) / np.float32(2 ** 31)
    else:
        raise ValueError(f"unsupported WAV sample width: {sw} bytes")

    return sig.reshape(-1, n_ch).mean(axis=1) if n_ch > 1 else sig
|
|
|
|
|
|
def _compute_mfcc_mean(signal: np.ndarray, sr: int = 16000) -> np.ndarray:
    """Summarise a waveform as one fixed-length feature vector.

    With librosa available, the result is the MFCC matrix (``N_MFCC``
    coefficients) averaged over time. Without it, or when librosa fails on
    this signal, the vector is all zeros except element 0, which holds the
    signal's RMS energy.

    Args:
        signal: 1-D waveform samples.
        sr: Sample rate passed to librosa. The caller is responsible for it
            matching ``signal``; nothing here resamples.

    Returns:
        Array of length ``N_MFCC``. An empty ``signal`` yields all zeros
        rather than NaN.
    """
    signal = np.asarray(signal)
    if signal.size == 0:
        # The mean of an empty array is NaN, which would poison saved features.
        return np.zeros(N_MFCC, dtype=np.float32)

    try:
        import librosa  # optional dependency

        mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=N_MFCC)
        return mfcc.mean(axis=1)
    except ImportError:
        # Expected when librosa is not installed: use the fallback quietly.
        pass
    except Exception as e:
        # A real failure on this signal: still fall back, but say so.
        print(f"  [warn] MFCC extraction failed, using RMS fallback: {e}")

    # energy-based fallback
    rms = float(np.sqrt(np.mean(signal ** 2) + 1e-9))
    feat = np.zeros(N_MFCC, dtype=np.float32)
    feat[0] = rms
    return feat
|
|
|
|
|
|
# ── text features ──────────────────────────────────────────────────────────
|
|
def _text_features(text: str, max_len: int = 64) -> np.ndarray:
    """Turn an utterance into a fixed-length vector of hashed token ids.

    The text is lower-cased and split on whitespace. Each token is mapped to
    an id in ``[1, 30521]`` with CRC-32, which is stable across processes and
    machines. The built-in ``hash()`` is salted per interpreter run for
    strings, so ids derived from it could not be reproduced after extraction.
    Id 0 is reserved for padding and is never assigned to a token.

    Args:
        text: Raw utterance.
        max_len: Output length; longer inputs are truncated, shorter ones are
            right-padded with 0.

    Returns:
        int32 array of shape ``(max_len,)``.
    """
    vocab_size = 30522
    tokens = text.lower().split()[:max_len]
    ids = [
        1 + zlib.crc32(tok.encode("utf-8")) % (vocab_size - 1)
        for tok in tokens
    ]
    ids += [0] * (max_len - len(ids))
    return np.array(ids, dtype=np.int32)
|
|
|
|
|
|
# ── csv parsing ────────────────────────────────────────────────────────────
|
|
def read_csv(csv_path: str) -> list:
    """Read a CSV file with a header row into a list of row dicts.

    Args:
        csv_path: Path of the CSV file.

    Returns:
        One dict per data row, keyed by the header names, in file order.
    """
    # newline="" lets the csv module do its own line handling, so line breaks
    # inside quoted fields survive unchanged. "utf-8-sig" reads plain UTF-8
    # identically but drops a leading BOM that would otherwise be glued onto
    # the first column name.
    with open(csv_path, encoding="utf-8-sig", newline="") as f:
        return list(csv.DictReader(f))
|
|
|
|
|
|
def extract_split(csv_path: str, clip_dir: Path, out_prefix: Path,
                  has_video: bool = True):
    """Extract text, audio and label arrays for one dataset split.

    Rows whose ``Emotion`` value is not a key of ``EMOTION_MAP`` are skipped.
    For each kept row the ``Utterance`` text is encoded with
    ``_text_features``. If ``dia{Dialogue_ID}_utt{Utterance_ID}.mp4`` exists
    in ``clip_dir`` its audio is summarised with ``_compute_mfcc_mean``;
    otherwise, or when decoding raises, the audio feature is a zero vector of
    length ``N_MFCC``.

    Writes ``<name>_text.npy``, ``<name>_audio.npy`` and ``<name>_labels.npy``
    into ``out_prefix.parent``, where ``<name>`` is ``out_prefix.name``.
    Nothing is written when no row has a known emotion.

    Args:
        csv_path: Annotation CSV for the split.
        clip_dir: Directory searched for the split's mp4 clips.
        out_prefix: ``<output dir>/<split name>``; used only to derive the
            output directory and the file-name prefix.
        has_video: Set to False to skip audio lookup entirely.
    """
    records = read_csv(csv_path)
    # Loop-invariant, so test the directory once rather than once per row.
    use_audio = has_video and clip_dir.exists()
    texts, audios, labels_list = [], [], []
    n_missing = 0

    for rec in records:
        # csv.DictReader stores None for fields absent from a short row, so
        # `rec.get(key, "")` can still return None; `or ""` covers that case.
        emo = (rec.get("Emotion") or "").strip().lower()
        if emo not in EMOTION_MAP:
            continue
        label = EMOTION_MAP[emo]

        utterance = (rec.get("Utterance") or "").strip()
        dia_id = (rec.get("Dialogue_ID") or "").strip()
        utt_id = (rec.get("Utterance_ID") or "").strip()

        # find audio
        audio_feat = np.zeros(N_MFCC, dtype=np.float32)
        if use_audio:
            clip_name = f"dia{dia_id}_utt{utt_id}.mp4"
            clip_path = clip_dir / clip_name
            if clip_path.exists():
                try:
                    sig = _load_audio_bytes(str(clip_path))
                    audio_feat = _compute_mfcc_mean(sig)
                except Exception as e:
                    # Best effort: one bad clip keeps its zero vector and the
                    # rest of the split is still extracted.
                    print(f"  [warn] {clip_name}: {e}")
            else:
                n_missing += 1

        texts.append(_text_features(utterance))
        audios.append(audio_feat)
        labels_list.append(label)

    if not labels_list:
        print(f"  [warn] no valid records in {csv_path}")
        return

    split = out_prefix.name
    base = out_prefix.parent
    # Make the function usable on its own, not only after a caller has
    # created the output directory.
    base.mkdir(parents=True, exist_ok=True)

    # Stack once and reuse for both saving and the summary line.
    text_arr = np.stack(texts)
    audio_arr = np.stack(audios)
    np.save(base / f"{split}_text.npy", text_arr)
    np.save(base / f"{split}_audio.npy", audio_arr)
    np.save(base / f"{split}_labels.npy", np.array(labels_list, dtype=np.int64))
    print(f"  {split}: {len(labels_list)} samples, "
          f"text {text_arr.shape}, audio {audio_arr.shape}")
    if n_missing:
        # Surface silent degradation: these rows carry all-zero audio features.
        print(f"  [warn] {split}: {n_missing} clip(s) not found in {clip_dir}; "
              f"zero audio features used")
|
|
|
|
|
|
def extract_meld(data_root: str, out_dir: str):
    """Extract features for every available MELD split and write the label map.

    ``data_root`` may be the directory that contains ``MELD.Raw/`` or that
    directory itself. The train, val and test splits are processed in turn
    from ``train_sent_emo.csv``, ``dev_sent_emo.csv`` and
    ``test_sent_emo.csv``; a split whose CSV is missing is skipped. Clips are
    looked up in the ``train``, ``dev`` and ``test`` sub-directories.

    Args:
        data_root: Dataset location, as described above.
        out_dir: Destination directory; created if needed. Receives the
            per-split ``.npy`` files and ``label_map.json``.
    """
    data_root = Path(data_root)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    meld_root = data_root / "MELD.Raw"
    if not meld_root.exists():
        meld_root = data_root  # maybe already inside MELD.Raw

    # output split name -> (annotation CSV, clip sub-directory)
    splits = {
        "train": ("train_sent_emo.csv", "train"),
        "val": ("dev_sent_emo.csv", "dev"),
        "test": ("test_sent_emo.csv", "test"),
    }

    for split, (csv_name, dir_name) in splits.items():
        csv_path = meld_root / csv_name
        if not csv_path.exists():
            print(f"  [skip] {csv_path} not found")
            continue
        clip_dir = meld_root / dir_name
        has_clips = clip_dir.exists()
        if not has_clips:
            # Without this message the split would quietly be written with
            # all-zero audio features.
            print(f"  [warn] {clip_dir} not found; "
                  f"{split} audio features will be zeros")
        extract_split(str(csv_path), clip_dir, out_dir / split, has_video=has_clips)

    # Class id -> emotion name. json.dump writes the integer keys as strings.
    label_map = dict(enumerate(LABEL_NAMES))
    with open(out_dir / "label_map.json", "w", encoding="utf-8") as f:
        json.dump(label_map, f, indent=2)

    print("Done →", out_dir)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. --data_root is mandatory; --out_dir is
    # optional and, when omitted or empty, defaults to a path under $ZSY
    # (or /root if that variable is unset).
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_root",
        required=True,
        help="Dir containing MELD.Raw/ (or already inside it)",
    )
    cli.add_argument("--out_dir", default=None)
    opts = cli.parse_args()

    if opts.out_dir:
        target = opts.out_dir
    else:
        home = os.environ.get("ZSY", "/root")
        target = f"{home}/multimodal_affect/data/meld"

    extract_meld(opts.data_root, target)
|