2 changes: 0 additions & 2 deletions mteb/abstasks/TaskMetadata.py
@@ -60,7 +60,6 @@
"Tumor detection",
"Duplicate Detection",
"Gender Clustering",
"Voice Emotion Clustering",
]

TASK_DOMAIN = Literal[
@@ -128,7 +127,6 @@


TASK_CATEGORY = Literal[
"a2a", # Audio-to-audio
"s2s", # Sentence-to-sentence
"s2p", # Sentence-to-paragraph
"p2p", # Paragraph-to-paragraph
48 changes: 9 additions & 39 deletions mteb/evaluation/evaluators/Audio/ClusteringEvaluator.py
@@ -3,20 +3,18 @@
import logging
from typing import Any

import numpy as np
import sklearn
import sklearn.cluster
from datasets import Audio
from scipy.optimize import linear_sum_assignment
from sklearn import metrics
import random
from sklearn.decomposition import PCA

from mteb.encoder_interface import Encoder
from mteb.evaluation.evaluators.Evaluator import Evaluator

logger = logging.getLogger(__name__)


class AudioClusteringEvaluator(Evaluator):
def __init__(
self,
@@ -25,41 +23,16 @@ def __init__(
task_name: str | None = None,
clustering_batch_size: int = 500,
limit: int | None = None,
cluster_algo: str = "KMeans",
**kwargs,
):
super().__init__(**kwargs)
if limit is not None:
audio = audio[:limit]
labels = labels[:limit]

random.seed(42)
combined = list(zip(audio, labels))
random.shuffle(combined)
audio, labels = map(list, zip(*combined))

self.audio = audio
self.labels = labels
self.clustering_batch_size = clustering_batch_size
self.task_name = task_name
self.cluster_algo = cluster_algo

def __clustering__(self):
if self.cluster_algo == "Kmeans":
logger.info("Fitting Mini-Batch K-Means model...")
clustering_model = sklearn.cluster.MiniBatchKMeans(
n_clusters=len(set(self.labels)),
batch_size=self.clustering_batch_size,
n_init="auto",
)
elif self.cluster_algo == "DBSCAN":
# need to plot out the distribution of the embeddings to decide on parameters for DBSCAN
logger.info("Fitting DBSCAN model...")
clustering_model = sklearn.cluster.DBSCAN(eps=0.5, min_samples=5, metric="euclidean")
elif self.cluster_algo == "Agg":
logger.info("Fitting Agglomerative model...")
clustering_model = sklearn.cluster.AgglomerativeClustering(n_clusters=len(set(self.labels)),linkage='average', metric='cosine')
return clustering_model

def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any] = {}):
if "batch_size" not in encode_kwargs:
@@ -71,13 +44,13 @@ def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any] = {}):
)

logger.info("Fitting Mini-Batch K-Means model...")

pca = PCA(n_components=200)
audio_embeddings = pca.fit_transform(audio_embeddings)

clustering_output = self.__clustering__()
clustering_output.fit(audio_embeddings)
cluster_assignment = clustering_output.labels_
clustering_model = sklearn.cluster.MiniBatchKMeans(
n_clusters=len(set(self.labels)),
batch_size=self.clustering_batch_size,
n_init="auto",
)
clustering_model.fit(audio_embeddings)
cluster_assignment = clustering_model.labels_

logger.info("Evaluating...")
v_measure = metrics.cluster.v_measure_score(self.labels, cluster_assignment)
@@ -88,8 +61,6 @@ def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any] = {}):

matrix = metrics.confusion_matrix(self.labels, cluster_assignment)

silhouette = float(metrics.silhouette_score(audio_embeddings, cluster_assignment, metric='euclidean'))
print(self.cluster_algo)
# get linear sum assignment
row_ind, col_ind = linear_sum_assignment(matrix, maximize=True)
total_correct = matrix[row_ind, col_ind].sum()
@@ -100,5 +71,4 @@ def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any] = {}):
"nmi": nmi,
"ari": ari,
"cluster_accuracy": clustering_accuracy,
"silhouette": silhouette,
}
}
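
For reference, the cluster-accuracy metric kept by this evaluator maps predicted cluster ids to gold labels with a Hungarian assignment over the confusion matrix. A minimal standalone sketch of that computation (the toy labels below are illustrative, not taken from the benchmark):

import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn import metrics

# Toy gold labels and raw cluster ids from any clustering model.
labels = np.array([0, 0, 1, 1, 2, 2])
cluster_assignment = np.array([2, 2, 0, 0, 1, 1])

# Rows are gold labels, columns are predicted cluster ids.
matrix = metrics.confusion_matrix(labels, cluster_assignment)

# Find the label-to-cluster mapping that maximizes correct assignments,
# then score it as plain accuracy.
row_ind, col_ind = linear_sum_assignment(matrix, maximize=True)
cluster_accuracy = matrix[row_ind, col_ind].sum() / len(labels)
print(cluster_accuracy)  # 1.0: every cluster maps cleanly onto one label
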
6 changes: 0 additions & 6 deletions mteb/models/overview.py
@@ -72,9 +72,6 @@
voyage_models,
voyage_v,
wav2vec_models,
wavlm_models,
whisper_models,
qwen_models
)

logger = logging.getLogger(__name__)
@@ -141,9 +138,6 @@
voyage_models,
fa_models,
wav2vec_models,
wavlm_models,
whisper_models,
qwen_models
]
MODEL_REGISTRY = {}

81 changes: 0 additions & 81 deletions mteb/models/qwen_models.py

This file was deleted.

18 changes: 4 additions & 14 deletions mteb/models/wav2vec_models.py
@@ -1,8 +1,7 @@
from __future__ import annotations

from functools import partial
from mteb.models.wrapper import Wrapper
from mteb.encoder_interface import PromptType, AudioEncoder

import numpy as np
import torch
from datasets import Audio
@@ -11,6 +10,7 @@
from mteb.encoder_interface import AudioEncoder, PromptType
from mteb.model_meta import ModelMeta


class Wav2vec2Wrapper(AudioEncoder):
def __init__(
self,
@@ -48,11 +48,10 @@ def get_audio_embeddings(
audio_data,
sampling_rate=sampling_rates[0],
padding=True,

return_tensors="pt"
return_tensors="pt",
)

if hasattr(self, 'device') and self.device:
if self.device:
inputs = {k: v.to(self.device) for k, v in inputs.items()}

# Get embeddings
@@ -64,7 +63,6 @@
)

hidden_states = outputs.hidden_states[-1]

batch_embeddings = hidden_states.mean(dim=1).cpu().numpy()
all_embeddings.append(batch_embeddings)

@@ -90,7 +88,6 @@ def encode(
),
name="facebook/wav2vec2-base",
languages=["en"],

open_weights=True,
revision="0b5b8e868dd84f03fd87d01f9c4ff0f080fecfe8",
release_date="2020-10-26",
@@ -118,7 +115,6 @@ def encode(
),
name="facebook/wav2vec2-base-960h",
languages=["en"],

open_weights=True,
revision="22aad52d435eb6dbaf354bdad9b0da84ce7d6156",
release_date="2020-10-26",
@@ -134,7 +130,6 @@ def encode(
public_training_code=None,
public_training_data=None,
training_datasets=None,

modalities=["audio"],
)

@@ -147,7 +142,6 @@ def encode(
),
name="facebook/wav2vec2-large",
languages=["en"],

open_weights=True,
revision="312b2410566b698c7a649068d413b2067848bd75",
release_date="2020-10-26",
@@ -163,7 +157,6 @@ def encode(
public_training_code=None,
public_training_data=None,
training_datasets=None,

modalities=["audio"],
)

@@ -176,7 +169,6 @@ def encode(
),
name="facebook/wav2vec2-large-xlsr-53",
languages=["en"],

open_weights=True,
revision="c3f9d884181a224a6ac87bf8885c84d1cff3384f",
release_date="2020-10-26",
@@ -204,7 +196,6 @@ def encode(
),
name="facebook/wav2vec2-lv-60-espeak-cv-ft",
languages=["en"],

open_weights=True,
revision="ae45363bf3413b374fecd9dc8bc1df0e24c3b7f4",
release_date="2020-10-26",
@@ -222,4 +213,3 @@
training_datasets=None,
modalities=["audio"],
)
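
The wrapper above mean-pools the final hidden states of a wav2vec 2.0 model to produce one embedding per clip. A rough standalone equivalent, assuming the Hugging Face transformers API and synthetic 16 kHz audio in place of the wrapper's batched dataset input:

import torch
from transformers import AutoFeatureExtractor, Wav2Vec2Model

model_name = "facebook/wav2vec2-base"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2Model.from_pretrained(model_name)

waveform = torch.randn(16000).numpy()  # one second of synthetic 16 kHz audio

inputs = feature_extractor(
    waveform, sampling_rate=16000, padding=True, return_tensors="pt"
)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# Mean-pool the last hidden state over the time axis: one vector per clip.
embedding = outputs.hidden_states[-1].mean(dim=1)
print(embedding.shape)  # torch.Size([1, 768]) for the base model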
