6 changes: 5 additions & 1 deletion mteb/models/ast_model.py
@@ -93,13 +93,17 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> torch.Tensor:
         processed_audio = self._process_audio(audio)
         all_embeddings = []
 
         with torch.no_grad():
-            for i in tqdm(range(0, len(processed_audio), batch_size)):
+            for i in tqdm(
+                range(0, len(processed_audio), batch_size),
+                disable=not show_progress_bar,
+            ):
                 batch = processed_audio[i : i + batch_size]
 
                 # AST processes raw waveforms directly through its feature extractor
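The change above is the whole pattern: tqdm already ships a `disable` flag that turns the bar into a no-op iterator, so each model only needs to thread `show_progress_bar` through to it. A minimal sketch of the pattern in isolation (the `encode_batches` helper and the dummy data are illustrative, not part of the PR):

import torch
from tqdm import tqdm

def encode_batches(
    items: list[torch.Tensor],
    batch_size: int = 4,
    show_progress_bar: bool = True,
) -> list[torch.Tensor]:
    """Batch an input list; the progress bar renders only when requested."""
    outputs: list[torch.Tensor] = []
    with torch.no_grad():
        for i in tqdm(
            range(0, len(items), batch_size),
            disable=not show_progress_bar,  # tqdm's built-in off switch
        ):
            batch = items[i : i + batch_size]
            # Stand-in for a real model forward pass.
            outputs.extend(t.mean(dim=-1, keepdim=True) for t in batch)
    return outputs

# No bar is drawn when callers (e.g. evaluation harnesses) opt out:
embeddings = encode_batches(
    [torch.randn(16_000) for _ in range(10)], show_progress_bar=False
)

The same edit repeats nearly verbatim in every file below, so only per-file specifics are noted from here on.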
7 changes: 6 additions & 1 deletion mteb/models/clap_models.py
@@ -109,13 +109,16 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> np.ndarray:
         all_features = []
         processed_audio = self._process_audio(audio)
 
         for i in tqdm(
-            range(0, len(processed_audio), batch_size), desc="Processing audio batches"
+            range(0, len(processed_audio), batch_size),
+            desc="Processing audio batches",
+            disable=not show_progress_bar,
         ):
             batch = processed_audio[i : i + batch_size]
             batch_arrays = [tensor.numpy() for tensor in batch]
@@ -125,6 +128,8 @@ def get_audio_embeddings(
                 sampling_rate=self.sampling_rate,
                 return_tensors="pt",
                 padding=True,
+                truncation=True,
+                max_length=30 * self.sampling_rate,  # 30 seconds max
             )
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
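Here the 30-second cap rides on the Hugging Face feature extractor instead of manual slicing: `truncation=True` plus a `max_length` expressed in samples clips each waveform before padding. A sketch of the behaviour using `Wav2Vec2FeatureExtractor` (the checkpoint name and 16 kHz rate are stand-ins; the PR assumes CLAP's processor accepts the same keywords):

import numpy as np
from transformers import Wav2Vec2FeatureExtractor

sampling_rate = 16_000
extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

long_clip = np.random.randn(45 * sampling_rate).astype(np.float32)   # 45 s
short_clip = np.random.randn(5 * sampling_rate).astype(np.float32)   # 5 s

inputs = extractor(
    [long_clip, short_clip],
    sampling_rate=sampling_rate,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=30 * sampling_rate,  # cap every clip at 30 s of samples
)
# The long clip is cut to 30 s; the short one is padded up to that length.
print(inputs["input_values"].shape)  # torch.Size([2, 480000])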
15 changes: 13 additions & 2 deletions mteb/models/cnn14_model.py
@@ -92,7 +92,14 @@ def _handle_batch(
     def _convert_audio(self, audio: AudioData) -> torch.Tensor:
         if isinstance(audio, np.ndarray):
             audio = torch.from_numpy(audio)
-        return audio.squeeze()
+        audio = audio.squeeze()
+
+        # Apply audio truncation (30 seconds max)
+        max_length = 30 * self.sampling_rate  # 30 seconds
+        if audio.shape[-1] > max_length:
+            audio = audio[..., :max_length]
+
+        return audio
 
     def _load_audio_file(self, path: str) -> torch.Tensor:
         waveform, sample_rate = torchaudio.load(path)
@@ -113,13 +120,17 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> torch.Tensor:
         processed_audio = self._process_audio(audio)
         all_embeddings = []
 
         with torch.no_grad():
-            for i in tqdm(range(0, len(processed_audio), batch_size)):
+            for i in tqdm(
+                range(0, len(processed_audio), batch_size),
+                disable=not show_progress_bar,
+            ):
                 batch = processed_audio[i : i + batch_size]
 
                 # Convert batch to tensors and move to device
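CNN14 has no Hugging Face processor in front of it, so the cap is applied to the raw tensor. Slicing the last dimension keeps the guard shape-agnostic: mono `(num_samples,)` and channel-first `(channels, num_samples)` tensors both work, and short clips pass through unchanged. A standalone sketch (the 32 kHz rate matches CNN14's usual configuration but is an assumption here):

import torch

def truncate_waveform(
    audio: torch.Tensor, sampling_rate: int = 32_000, max_seconds: int = 30
) -> torch.Tensor:
    """Clip a waveform to at most max_seconds, whatever its channel layout."""
    max_length = max_seconds * sampling_rate
    if audio.shape[-1] > max_length:
        audio = audio[..., :max_length]  # Ellipsis leaves leading dims intact
    return audio

mono = torch.randn(40 * 32_000)        # 40 s, mono
stereo = torch.randn(2, 40 * 32_000)   # 40 s, stereo
assert truncate_waveform(mono).shape == (30 * 32_000,)
assert truncate_waveform(stereo).shape == (2, 30 * 32_000)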
8 changes: 7 additions & 1 deletion mteb/models/data2vec_models.py
@@ -103,13 +103,17 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> torch.Tensor:
         processed_audio = self._process_audio(audio)
         all_embeddings = []
 
         with torch.no_grad():
-            for i in tqdm(range(0, len(processed_audio), batch_size)):
+            for i in tqdm(
+                range(0, len(processed_audio), batch_size),
+                disable=not show_progress_bar,
+            ):
                 batch = processed_audio[i : i + batch_size]
 
                 # Pre-process audio
@@ -125,6 +129,8 @@ def get_audio_embeddings(
                     sampling_rate=self.sampling_rate,
                     return_tensors="pt",
                     padding="longest",
+                    truncation=True,
+                    max_length=30 * self.sampling_rate,  # 30 seconds max
                     return_attention_mask=True,
                 ).to(self.device)
 
8 changes: 7 additions & 1 deletion mteb/models/encodec_model.py
@@ -106,13 +106,17 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> torch.Tensor:
         processed_audio = self._process_audio(audio)
         all_embeddings = []
 
         with torch.no_grad():
-            for i in tqdm(range(0, len(processed_audio), batch_size)):
+            for i in tqdm(
+                range(0, len(processed_audio), batch_size),
+                disable=not show_progress_bar,
+            ):
                 batch = processed_audio[i : i + batch_size]
 
                 # Process audio through EnCodec's processor
@@ -121,6 +125,8 @@ def get_audio_embeddings(
                     sampling_rate=self.sampling_rate,
                     return_tensors="pt",
                     padding=True,
+                    truncation=True,
+                    max_length=30 * self.sampling_rate,  # 30 seconds max
                 ).to(self.device)
 
                 # Get the latent representations directly from the encoder
8 changes: 7 additions & 1 deletion mteb/models/hubert_models.py
@@ -103,13 +103,17 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> torch.Tensor:
         processed_audio = self._process_audio(audio)
         all_embeddings = []
 
         with torch.no_grad():
-            for i in tqdm(range(0, len(processed_audio), batch_size)):
+            for i in tqdm(
+                range(0, len(processed_audio), batch_size),
+                disable=not show_progress_bar,
+            ):
                 batch = processed_audio[i : i + batch_size]
 
                 # Pre-process like Wav2Vec2
@@ -125,6 +129,8 @@ def get_audio_embeddings(
                     sampling_rate=self.sampling_rate,
                     return_tensors="pt",
                     padding="longest",
+                    truncation=True,
+                    max_length=30 * self.sampling_rate,  # 30 seconds max
                     return_attention_mask=True,
                 ).to(self.device)
 
8 changes: 7 additions & 1 deletion mteb/models/mctct_model.py
@@ -167,13 +167,17 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> torch.Tensor:
         processed_audio = self._process_audio(audio)
         all_embeddings = []
 
         with torch.no_grad():
-            for i in tqdm(range(0, len(processed_audio), batch_size)):
+            for i in tqdm(
+                range(0, len(processed_audio), batch_size),
+                disable=not show_progress_bar,
+            ):
                 batch = processed_audio[i : i + batch_size]
 
                 # Process each audio in the batch
@@ -182,6 +186,8 @@ def get_audio_embeddings(
                     sampling_rate=self.sampling_rate,
                     return_tensors="pt",
                     padding=True,
+                    truncation=True,
+                    max_length=30 * self.sampling_rate,  # 30 seconds max
                 ).to(self.device)
 
                 # Get embeddings from the model
8 changes: 7 additions & 1 deletion mteb/models/mms_models.py
@@ -121,13 +121,17 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> torch.Tensor:
         processed_audio = self._process_audio(audio)
         all_embeddings = []
 
         with torch.no_grad():
-            for i in tqdm(range(0, len(processed_audio), batch_size)):
+            for i in tqdm(
+                range(0, len(processed_audio), batch_size),
+                disable=not show_progress_bar,
+            ):
                 batch = processed_audio[i : i + batch_size]
 
                 batch_tensor = self._pad_audio_batch(batch)
@@ -142,6 +146,8 @@ def get_audio_embeddings(
                     sampling_rate=self.sampling_rate,
                     return_tensors="pt",
                     padding="longest",
+                    truncation=True,
+                    max_length=30 * self.sampling_rate,  # 30 seconds max
                     return_attention_mask=True,
                 ).to(self.device)
 
25 changes: 23 additions & 2 deletions mteb/models/msclap_models.py
@@ -93,6 +93,11 @@ def _handle_batch(
                 )
                 audio_array = resampler(audio_array)
 
+                # Apply audio truncation (30 seconds max)
+                max_length = 30 * self.sampling_rate  # 30 seconds
+                if audio_array.shape[-1] > max_length:
+                    audio_array = audio_array[..., :max_length]
+
                 # Only squeeze here, don't call _convert_audio again
                 waveforms.append(audio_array.squeeze())
             elif "path" in item:
@@ -107,14 +112,27 @@ def _handle_batch(
     def _convert_audio(self, audio: AudioData) -> torch.Tensor:
         if isinstance(audio, np.ndarray):
             audio = torch.from_numpy(audio)
-        return audio.squeeze().float()  # Ensure float32
+        audio = audio.squeeze().float()  # Ensure float32
+
+        # Apply audio truncation (30 seconds max)
+        max_length = 30 * self.sampling_rate  # 30 seconds
+        if audio.shape[-1] > max_length:
+            audio = audio[..., :max_length]
+
+        return audio
 
     def _load_audio_file(self, path: str) -> torch.Tensor:
         waveform, sample_rate = torchaudio.load(path)
         waveform = waveform.float()  # Ensure float32
         if sample_rate != self.sampling_rate:
             resampler = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)
             waveform = resampler(waveform)
+
+        # Apply audio truncation (30 seconds max)
+        max_length = 30 * self.sampling_rate  # 30 seconds
+        if waveform.shape[-1] > max_length:
+            waveform = waveform[..., :max_length]
+
         return waveform.squeeze()
 
     def get_audio_embeddings(
@@ -124,13 +142,16 @@ def get_audio_embeddings(
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
+        show_progress_bar: bool = True,
         **kwargs: Any,
     ) -> np.ndarray:
         all_features = []
         processed_audio = self._process_audio(audio)
 
         for i in tqdm(
-            range(0, len(processed_audio), batch_size), desc="Processing audio batches"
+            range(0, len(processed_audio), batch_size),
+            desc="Processing audio batches",
+            disable=not show_progress_bar,
         ):
             batch = processed_audio[i : i + batch_size]
 
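Note the ordering inside `_load_audio_file`: resample first, truncate second, so the `30 * self.sampling_rate` budget is always measured in target-rate samples (truncating before resampling would cut a different duration whenever the source rate differs). A sketch of that pipeline (the path argument and 44.1 kHz target are placeholders, not MSCLAP's confirmed rate):

import torch
import torchaudio

def load_clip(
    path: str, target_rate: int = 44_100, max_seconds: int = 30
) -> torch.Tensor:
    waveform, source_rate = torchaudio.load(path)
    waveform = waveform.float()  # float32 for downstream models
    if source_rate != target_rate:
        # Resample first so the sample budget below is in target-rate units.
        waveform = torchaudio.transforms.Resample(source_rate, target_rate)(waveform)
    max_length = max_seconds * target_rate
    if waveform.shape[-1] > max_length:
        waveform = waveform[..., :max_length]
    return waveform.squeeze()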
15 changes: 13 additions & 2 deletions mteb/models/muq_mulan_model.py
@@ -90,7 +90,14 @@ def _convert_audio(self, audio: AudioData) -> torch.Tensor:
         """Convert audio data to torch tensor."""
         if isinstance(audio, np.ndarray):
             audio = torch.from_numpy(audio)
-        return audio.squeeze().float()  # Ensure float32
+        audio = audio.squeeze().float()  # Ensure float32
+
+        # Apply audio truncation (30 seconds max)
+        max_length = 30 * self.target_sampling_rate  # 30 seconds
+        if audio.shape[-1] > max_length:
+            audio = audio[..., :max_length]
+
+        return audio
 
     def _load_audio_file(self, path: str) -> torch.Tensor:
         """Load audio file and resample to target sampling rate."""
@@ -109,17 +116,21 @@ def get_audio_embeddings(
         self,
         audio: AudioBatch,
         *,
+        show_progress_bar: bool = True,
         task_name: str | None = None,
         prompt_type: PromptType | None = None,
         batch_size: int = 4,
         **kwargs: Any,
     ) -> np.ndarray:
         """Get audio embeddings using MuQ-MuLan."""
         all_features = []
+
         processed_audio = self._process_audio(audio)
 
         for i in tqdm(
-            range(0, len(processed_audio), batch_size), desc="Processing audio batches"
+            range(0, len(processed_audio), batch_size),
+            desc="Processing audio batches",
+            disable=not show_progress_bar,
         ):
             batch = processed_audio[i : i + batch_size]
             batch_features = self._process_audio_batch(batch)
24 changes: 21 additions & 3 deletions mteb/models/qwen2_models.py
@@ -8,6 +8,7 @@
 import torch
 import torchaudio
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
 
 from mteb.encoder_interface import AudioBatch, AudioData, PromptType
@@ -81,7 +82,14 @@ def _convert_audio_from_numpy(self, audio: AudioData) -> torch.Tensor:
            audio = torch.from_numpy(audio)
        if audio.ndim == 2:
            audio = audio.mean(dim=0)
-       return audio.squeeze()
+       audio = audio.squeeze()
+
+       # Apply audio truncation (30 seconds max)
+       max_length = 30 * self.sampling_rate  # 30 seconds
+       if audio.shape[-1] > max_length:
+           audio = audio[..., :max_length]
+
+       return audio
 
    def _load_audio_file(self, path: str) -> torch.Tensor:
        waveform, sr = torchaudio.load(path)
@@ -90,7 +98,14 @@ def _load_audio_file(self, path: str) -> torch.Tensor:
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
            waveform = resampler(waveform)
-       return waveform.squeeze()
+       waveform = waveform.squeeze()
+
+       # Apply audio truncation (30 seconds max)
+       max_length = 30 * self.sampling_rate  # 30 seconds
+       if waveform.shape[-1] > max_length:
+           waveform = waveform[..., :max_length]
+
+       return waveform
 
    def _pad_audio_batch(self, batch: list[torch.Tensor]) -> torch.Tensor:
        max_len = max(w.shape[0] for w in batch)
@@ -104,13 +119,17 @@ def get_audio_embeddings(
        task_name: str | None = None,
        prompt_type: PromptType | None = None,
        batch_size: int = 4,
+       show_progress_bar: bool = True,
        **kwargs: Any,
    ) -> torch.Tensor:
        processed = self._process_audio(audio)
        embeddings_list: list[torch.Tensor] = []
 
        with torch.no_grad():
-           for i in range(0, len(processed), batch_size):
+           for i in tqdm(
+               range(0, len(processed), batch_size), disable=not show_progress_bar
+           ):
                batch = processed[i : i + batch_size]
 
                audio_list = [w.numpy() for w in batch]
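Because every clip is now truncated before batching, the visible `_pad_audio_batch` helper only ever pads up to the 30-second bound. A sketch of the right-padding idiom its one visible line suggests (the body shown is a guess extrapolated from `max_len = max(w.shape[0] for w in batch)`, not the PR's actual implementation):

import torch
import torch.nn.functional as F

def pad_audio_batch(batch: list[torch.Tensor]) -> torch.Tensor:
    """Zero-pad 1-D waveforms on the right and stack into (batch, max_len)."""
    max_len = max(w.shape[0] for w in batch)
    padded = [F.pad(w, (0, max_len - w.shape[0])) for w in batch]
    return torch.stack(padded)

batch = [torch.randn(16_000), torch.randn(48_000)]
print(pad_audio_batch(batch).shape)  # torch.Size([2, 48000])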