chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
200
旧方向信息/scripts/preprocess/extract_meld.py
Normal file
200
旧方向信息/scripts/preprocess/extract_meld.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
MELD (Multimodal EmotionLines Dataset) feature extraction.
|
||||
|
||||
Dataset structure:
|
||||
$DATA_ROOT/MELD.Raw/
|
||||
train_sent_emo.csv
|
||||
dev_sent_emo.csv
|
||||
test_sent_emo.csv
|
||||
train/ dev/ test/ → subdirs with mp4 clips
|
||||
dia{N}_utt{M}.mp4
|
||||
|
||||
CSV columns:
|
||||
Sr No., Utterance, Speaker, Emotion, Sentiment,
|
||||
Dialogue_ID, Utterance_ID, Season, Episode, StartTime, EndTime
|
||||
|
||||
Emotions: neutral, surprise, fear, sadness, joy, disgust, anger
|
||||
|
||||
Output: $ZSY/multimodal_affect/data/meld/
|
||||
{train,val,test}_{text,audio,labels}.npy
|
||||
label_map.json
|
||||
"""
|
||||
|
||||
import os
|
||||
import csv
|
||||
import json
|
||||
import argparse
|
||||
import numpy as np
|
||||
import wave
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
EMOTION_MAP = {
|
||||
"neutral": 0, "surprise": 1, "fear": 2,
|
||||
"sadness": 3, "joy": 4, "disgust": 5, "anger": 6,
|
||||
}
|
||||
LABEL_NAMES = ["neutral", "surprise", "fear", "sadness", "joy", "disgust", "anger"]
|
||||
N_MFCC = 40
|
||||
|
||||
|
||||
# ── audio loading ──────────────────────────────────────────────────────────
|
||||
def _load_audio_bytes(path: str) -> np.ndarray:
|
||||
"""Load audio from WAV or MP4 via av; fall back to wave stdlib."""
|
||||
path = str(path)
|
||||
if path.endswith(".mp4") or path.endswith(".mp3"):
|
||||
try:
|
||||
import av
|
||||
container = av.open(path)
|
||||
stream = next((s for s in container.streams if s.type == "audio"), None)
|
||||
if stream is None:
|
||||
return np.zeros(16000, dtype=np.float32)
|
||||
chunks = []
|
||||
for pkt in container.demux(stream):
|
||||
for frame in pkt.decode():
|
||||
arr = frame.to_ndarray()
|
||||
if arr.ndim == 2:
|
||||
arr = arr.mean(axis=0)
|
||||
chunks.append(arr.astype(np.float32))
|
||||
container.close()
|
||||
if chunks:
|
||||
return np.concatenate(chunks)
|
||||
except Exception as e:
|
||||
print(f" av failed for {path}: {e}")
|
||||
return np.zeros(16000, dtype=np.float32)
|
||||
|
||||
# WAV via stdlib
|
||||
with wave.open(path, "rb") as f:
|
||||
n_ch = f.getnchannels()
|
||||
sw = f.getsampwidth()
|
||||
raw = f.readframes(f.getnframes())
|
||||
if sw == 2:
|
||||
sig = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768
|
||||
elif sw == 4:
|
||||
sig = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2**31
|
||||
else:
|
||||
sig = np.frombuffer(raw, dtype=np.float32)
|
||||
return sig.reshape(-1, n_ch).mean(axis=1) if n_ch > 1 else sig
|
||||
|
||||
|
||||
def _compute_mfcc_mean(signal: np.ndarray, sr: int = 16000) -> np.ndarray:
|
||||
try:
|
||||
import librosa
|
||||
mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=N_MFCC)
|
||||
return mfcc.mean(axis=1)
|
||||
except Exception:
|
||||
pass
|
||||
# energy-based fallback
|
||||
rms = float(np.sqrt(np.mean(signal ** 2) + 1e-9))
|
||||
feat = np.zeros(N_MFCC, dtype=np.float32)
|
||||
feat[0] = rms
|
||||
return feat
|
||||
|
||||
|
||||
# ── text features ──────────────────────────────────────────────────────────
|
||||
def _text_features(text: str, max_len: int = 64) -> np.ndarray:
|
||||
tokens = text.lower().split()[:max_len]
|
||||
ids = [hash(t) % 30522 for t in tokens]
|
||||
ids += [0] * (max_len - len(ids))
|
||||
return np.array(ids, dtype=np.int32)
|
||||
|
||||
|
||||
# ── csv parsing ────────────────────────────────────────────────────────────
|
||||
def read_csv(csv_path: str):
|
||||
records = []
|
||||
with open(csv_path, encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
records.append(row)
|
||||
return records
|
||||
|
||||
|
||||
def extract_split(csv_path: str, clip_dir: Path, out_prefix: Path,
|
||||
has_video: bool = True):
|
||||
records = read_csv(csv_path)
|
||||
texts, audios, labels_list = [], [], []
|
||||
|
||||
for rec in records:
|
||||
emo = rec.get("Emotion", "").strip().lower()
|
||||
if emo not in EMOTION_MAP:
|
||||
continue
|
||||
label = EMOTION_MAP[emo]
|
||||
|
||||
utterance = rec.get("Utterance", "").strip()
|
||||
dia_id = rec.get("Dialogue_ID", "").strip()
|
||||
utt_id = rec.get("Utterance_ID", "").strip()
|
||||
|
||||
# find audio
|
||||
audio_feat = np.zeros(N_MFCC, dtype=np.float32)
|
||||
if has_video and clip_dir.exists():
|
||||
clip_name = f"dia{dia_id}_utt{utt_id}.mp4"
|
||||
clip_path = clip_dir / clip_name
|
||||
if clip_path.exists():
|
||||
try:
|
||||
sig = _load_audio_bytes(str(clip_path))
|
||||
audio_feat = _compute_mfcc_mean(sig)
|
||||
except Exception as e:
|
||||
print(f" [warn] {clip_name}: {e}")
|
||||
|
||||
text_feat = _text_features(utterance)
|
||||
texts.append(text_feat)
|
||||
audios.append(audio_feat)
|
||||
labels_list.append(label)
|
||||
|
||||
if not labels_list:
|
||||
print(f" [warn] no valid records in {csv_path}")
|
||||
return
|
||||
|
||||
split = out_prefix.name
|
||||
base = out_prefix.parent
|
||||
np.save(base / f"{split}_text.npy", np.stack(texts))
|
||||
np.save(base / f"{split}_audio.npy", np.stack(audios))
|
||||
np.save(base / f"{split}_labels.npy", np.array(labels_list, dtype=np.int64))
|
||||
print(f" {split}: {len(labels_list)} samples, "
|
||||
f"text {np.stack(texts).shape}, audio {np.stack(audios).shape}")
|
||||
|
||||
|
||||
def extract_meld(data_root: str, out_dir: str):
|
||||
data_root = Path(data_root)
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
meld_root = data_root / "MELD.Raw"
|
||||
if not meld_root.exists():
|
||||
meld_root = data_root # maybe already inside MELD.Raw
|
||||
|
||||
csv_map = {
|
||||
"train": "train_sent_emo.csv",
|
||||
"val": "dev_sent_emo.csv",
|
||||
"test": "test_sent_emo.csv",
|
||||
}
|
||||
dir_map = {
|
||||
"train": "train",
|
||||
"val": "dev",
|
||||
"test": "test",
|
||||
}
|
||||
|
||||
for split, csv_name in csv_map.items():
|
||||
csv_path = meld_root / csv_name
|
||||
if not csv_path.exists():
|
||||
print(f" [skip] {csv_path} not found")
|
||||
continue
|
||||
clip_dir = meld_root / dir_map[split]
|
||||
extract_split(str(csv_path), clip_dir, out_dir / split, has_video=clip_dir.exists())
|
||||
|
||||
label_map = {i: n for i, n in enumerate(LABEL_NAMES)}
|
||||
with open(out_dir / "label_map.json", "w") as f:
|
||||
json.dump(label_map, f, indent=2)
|
||||
|
||||
print("Done →", out_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data_root", required=True,
|
||||
help="Dir containing MELD.Raw/ (or already inside it)")
|
||||
parser.add_argument("--out_dir", default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
zsy = os.environ.get("ZSY", "/root")
|
||||
out_dir = args.out_dir or f"{zsy}/multimodal_affect/data/meld"
|
||||
extract_meld(args.data_root, out_dir)
|
||||
Reference in New Issue
Block a user