Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions mteb/models/msclap_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class MSClapWrapper:
def __init__(
self,
model_name: str = "microsoft/msclap",
model_name: str = "microsoft/msclap-2023",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
**kwargs: Any,
):
Expand Down Expand Up @@ -188,7 +188,7 @@ def _process_audio_batch(self, batch) -> list[np.ndarray]:
audio_tensor = audio_tensor.to(self.device)
# Get embeddings using the internal audio encoder
with torch.no_grad():
# Use the internal method
# Use the internal method: [0] the audio emebdding, [1] has output class probabilities
audio_features = self.model.clap.audio_encoder(audio_tensor)[0]

# Normalize embeddings
Expand Down Expand Up @@ -221,15 +221,13 @@ def get_text_embeddings(

def encode(
self,
inputs: AudioBatch | list[str],
inputs: list[str],
*,
task_name: str,
prompt_type: PromptType | None = None,
**kwargs: Any,
) -> np.ndarray:
if isinstance(inputs[0], str):
return self.get_text_embeddings(inputs)
return self.get_audio_embeddings(inputs)
return self.get_text_embeddings(inputs, **kwargs)


# Microsoft CLAP Model metadata
Expand Down
6 changes: 2 additions & 4 deletions mteb/models/muq_mulan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,14 @@ def get_text_embeddings(

def encode(
self,
inputs: AudioBatch | list[str],
inputs: list[str],
*,
task_name: str,
prompt_type: PromptType | None = None,
**kwargs: Any,
) -> np.ndarray:
"""Encode inputs (audio or text) into embeddings."""
if isinstance(inputs[0], str):
return self.get_text_embeddings(inputs)
return self.get_audio_embeddings(inputs)
return self.get_text_embeddings(inputs, **kwargs)

def calc_similarity(
self, audio_embeds: np.ndarray, text_embeds: np.ndarray
Expand Down
6 changes: 2 additions & 4 deletions mteb/models/speecht5_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,13 @@ def get_text_embeddings(

def encode(
self,
inputs: AudioBatch,
inputs: list[str],
*,
task_name: str,
prompt_type: PromptType | None = None,
**kwargs: Any,
) -> np.ndarray:
if isinstance(inputs[0], str):
return self.get_text_embeddings(inputs).numpy()
return self.get_audio_embeddings(inputs).numpy()
return self.get_text_embeddings(inputs, **kwargs).numpy()


# ASR model - Optimized for Speech Recognition tasks
Expand Down
14 changes: 6 additions & 8 deletions mteb/models/wav2clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def __init__(
**kwargs: Any,
):
requires_package(self, "wav2clip", "pip install 'mteb[wav2clip]'")
import wav2clip
from wav2clip import embed_audio, get_model

self.wav2clip = wav2clip
self.embed_audio = embed_audio
# audio side
self.device = device
self.audio_model = self.wav2clip.get_model().to(device)
self.audio_model = get_model().to(device)
self.sampling_rate = 16_000

# text side (CLIP)
Expand Down Expand Up @@ -112,7 +112,7 @@ def get_audio_embeddings(
for wav in wavs:
# Process one audio at a time to avoid memory issues
wav_np = wav.unsqueeze(0).cpu().numpy() # Add batch dimension
embed = self.wav2clip.embed_audio(wav_np, self.audio_model)
embed = self.embed_audio(wav_np, self.audio_model)

# Normalize
norm = np.linalg.norm(embed, axis=-1, keepdims=True)
Expand Down Expand Up @@ -153,15 +153,13 @@ def get_text_embeddings(

def encode(
self,
inputs: AudioBatch | list[str],
inputs: list[str],
*,
task_name: str,
prompt_type: PromptType | None = None,
**kwargs: Any,
) -> np.ndarray:
if isinstance(inputs[0], str):
return self.get_text_embeddings(inputs)
return self.get_audio_embeddings(inputs)
return self.get_text_embeddings(inputs, **kwargs)


wav2clip_zero = ModelMeta(
Expand Down