Merged
2 changes: 2 additions & 0 deletions mteb/models/overview.py
@@ -113,6 +113,7 @@
vlm2vec_models,
voyage_models,
voyage_v,
wav2clip_model,
wav2vec2_models,
wavlm_models,
whisper_models,
@@ -215,6 +216,7 @@
clap_models,
wavlm_models,
whisper_models,
wav2clip_model,
seed_models,
qwen2_models,
yamnet_models,
191 changes: 191 additions & 0 deletions mteb/models/wav2clip_model.py
@@ -0,0 +1,191 @@
from __future__ import annotations

from collections.abc import Iterable
from functools import partial
from typing import Any

import numpy as np
import torch
import torchaudio
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

from mteb.encoder_interface import AudioBatch, AudioData, PromptType
from mteb.model_meta import ModelMeta
from mteb.requires_package import requires_package


class Wav2ClipZeroShotWrapper:
def __init__(
self,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
**kwargs: Any,
):
requires_package(self, "wav2clip", "pip install 'mteb[wav2clip]'")
import wav2clip

self.wav2clip = wav2clip
# audio side
self.device = device
self.audio_model = self.wav2clip.get_model().to(device)
self.sampling_rate = 16_000

# text side (CLIP)
self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
self.clip_processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch32"
)

def _handle_batch(
self, batch: AudioData | Iterable[tuple[AudioData, str]]
) -> list[torch.Tensor]:
waveforms: list[torch.Tensor] = []

if isinstance(batch, tuple): # Handle (audio, metadata) tuples
items = [batch]
else:
items = batch

for item in items:
# dict with array and sampling_rate
if isinstance(item, dict) and "array" in item:
audio = item["array"]
if isinstance(audio, np.ndarray):
tensor = torch.from_numpy(audio)
elif isinstance(audio, list):
tensor = torch.tensor(audio, dtype=torch.float32)
else:
tensor = audio # assume it's already a torch.Tensor
tensor = tensor.float().squeeze()
if item.get("sampling_rate", self.sampling_rate) != self.sampling_rate:
resampler = torchaudio.transforms.Resample(
item["sampling_rate"], self.sampling_rate
)
tensor = resampler(tensor)
waveforms.append(tensor)

# dict with path
elif isinstance(item, dict) and "path" in item:
waveform, sr = torchaudio.load(item["path"])
tensor = waveform.float().squeeze()
if sr != self.sampling_rate:
resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
tensor = resampler(tensor)
waveforms.append(tensor)

# direct numpy or torch
elif isinstance(item, (np.ndarray, torch.Tensor, list)):
if isinstance(item, np.ndarray):
tensor = torch.from_numpy(item)
elif isinstance(item, list):
tensor = torch.tensor(item, dtype=torch.float32)
else:
tensor = item
waveforms.append(tensor.float().squeeze())

# file path string
elif isinstance(item, str):
waveform, sr = torchaudio.load(item)
tensor = waveform.float().squeeze()
if sr != self.sampling_rate:
resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
tensor = resampler(tensor)
                waveforms.append(tensor)

            else:
                raise TypeError(f"Unsupported audio item type: {type(item)}")

        return waveforms

def get_audio_embeddings(
self,
audio: AudioBatch,
**kwargs: Any,
) -> np.ndarray:
all_embeddings = []

if isinstance(audio, DataLoader):
# Process each batch separately
for batch in tqdm(audio, desc="Processing audio batches"):
batch_embeddings = []

# Process each item in the batch individually
wavs = self._handle_batch(batch)
for wav in wavs:
# Process one audio at a time to avoid memory issues
wav_np = wav.unsqueeze(0).cpu().numpy() # Add batch dimension
Comment on lines +112 to +114
Collaborator


What are the batch limits for your machine? We should use batch inference here, and suggest a lower batch size, but still allow batch inference (y'know, if people have random H200s lying around 😂).

embed = self.wav2clip.embed_audio(wav_np, self.audio_model)

# Normalize
norm = np.linalg.norm(embed, axis=-1, keepdims=True)
normalized_embed = embed / norm
batch_embeddings.append(normalized_embed)

all_embeddings.extend(batch_embeddings)

return np.vstack(all_embeddings)
else:
# Process single batch - still do it item by item
wavs = self._handle_batch(audio)
for wav in wavs:
# Process one audio at a time
wav_np = wav.unsqueeze(0).cpu().numpy() # Add batch dimension
embed = self.wav2clip.embed_audio(wav_np, self.audio_model)

# Normalize
norm = np.linalg.norm(embed, axis=-1, keepdims=True)
normalized_embed = embed / norm
all_embeddings.append(normalized_embed)

return np.vstack(all_embeddings)

def get_text_embeddings(
self,
texts: list[str],
**kwargs: Any,
) -> np.ndarray:
inputs = self.clip_processor(text=texts, return_tensors="pt", padding=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}

with torch.no_grad():
text_features = self.clip.get_text_features(**inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

return text_features.cpu().numpy()

def encode(
self,
inputs: AudioBatch | list[str],
*,
task_name: str,
prompt_type: PromptType | None = None,
**kwargs: Any,
) -> np.ndarray:
        # A DataLoader is not indexable, so only treat plain lists of strings as text input.
        if isinstance(inputs, list) and inputs and isinstance(inputs[0], str):
            return self.get_text_embeddings(inputs)
        return self.get_audio_embeddings(inputs)


wav2clip_zero = ModelMeta(
loader=partial(Wav2ClipZeroShotWrapper),
name="lyrebird/wav2clip",
languages=["eng-Latn"],
revision="N/A",
release_date="2022-03-15",
modalities=["audio", "text"],
n_parameters=163_000_000, # wav2clip: 11.7M + CLIP: 151.3M ≈ 163M
memory_usage_mb=622, # wav2clip: 44.65MB + CLIP: 577.08MB ≈ 622MB
max_tokens=None,
embed_dim=512,
license="mit",
open_weights=True,
framework=["PyTorch"],
reference="https://github.com/descriptinc/lyrebird-wav2clip",
similarity_fn_name="cosine",
use_instructions=False,
public_training_code="https://github.com/descriptinc/lyrebird-wav2clip",
public_training_data="https://github.com/descriptinc/lyrebird-wav2clip#data",
training_datasets={
# "AudioSet": ["https://research.google.com/audioset/"],
# "FreeSound": ["https://freesound.org/"],
# "BBC Sound Effects": ["https://sound-effects.bbcrewind.co.uk/"],
},
)
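
Regarding the reviewer's batch-inference suggestion above: a minimal sketch of how `get_audio_embeddings` could embed padded mini-batches instead of one clip at a time. The helper name `embed_in_batches`, the `batch_size` parameter, and the assumption that `wav2clip.embed_audio` accepts a stacked `(batch, samples)` array are illustrative, not part of this PR.

```python
import numpy as np
from torch.nn.utils.rnn import pad_sequence


def embed_in_batches(wavs, wav2clip_module, audio_model, batch_size: int = 8) -> np.ndarray:
    """Embed a list of 1-D waveform tensors in padded mini-batches and L2-normalize.

    Assumes wav2clip_module.embed_audio accepts a (batch, samples) float32 array;
    lower batch_size on memory-constrained GPUs.
    """
    embeddings = []
    for start in range(0, len(wavs), batch_size):
        chunk = wavs[start : start + batch_size]
        # Zero-pad each clip in the chunk to the longest one so they stack into a single array.
        batch = pad_sequence([w.float() for w in chunk], batch_first=True)
        embed = wav2clip_module.embed_audio(batch.cpu().numpy(), audio_model)
        embed = embed / np.linalg.norm(embed, axis=-1, keepdims=True)
        embeddings.append(embed)
    return np.vstack(embeddings)
```

If `embed_audio` turns out to handle only single clips, the same loop can fall back to per-item calls inside each chunk, keeping the `batch_size` knob for memory control.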