diff --git a/docs/usage/usage.md b/docs/usage/usage.md index 06c5b9c1c3..8a71d25ff4 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -442,12 +442,15 @@ There are times you may want to cache the embeddings so you can re-use them. Thi ```python # define your task(s) and model above as normal -... +task = mteb.get_task("LccSentimentClassification") +model = mteb.get_model("minishlab/M2V_base_glove_subword") +evaluation = mteb.MTEB(tasks=[task]) + # wrap the model with the cache wrapper from mteb.models.cache_wrapper import CachedEmbeddingWrapper -model_with_cached_emb = CachedEmbeddingWrapper(model, cache_path='') +model_with_cached_emb = CachedEmbeddingWrapper(model, cache_path='path_to_cache_dir') # run as normal -evaluation.run(model, ...) +evaluation.run(model_with_cached_emb) ``` If you want to directly access the cached embeddings (e.g. for subsequent analyses) follow this example: @@ -457,8 +460,8 @@ import numpy as np from mteb.models.cache_wrapper import TextVectorMap # Access the memory-mapped file and convert to array -vector_map = TextVectorMap("/AppsRetrieval") -vector_map.load(name="AppsRetrieval") +vector_map = TextVectorMap("path_to_cache_dir/LccSentimentClassification") +vector_map.load(name="LccSentimentClassification") vectors = np.asarray(vector_map.vectors) # Remove all "placeholders" in the embedding cache diff --git a/mteb/models/cache_wrapper.py b/mteb/models/cache_wrapper.py index 57b8954c7b..74c05562f0 100644 --- a/mteb/models/cache_wrapper.py +++ b/mteb/models/cache_wrapper.py @@ -155,8 +155,8 @@ def load(self, name: str | None = None) -> None: self.vectors = np.memmap( self.vectors_file, dtype="float32", mode="r+" ) - self.vectors = self.vectors.reshape(-1, self.vector_dim) - logger.info(f"Loaded vectors file with shape: {self.vectors.shape}") + self.vectors = self.vectors.reshape(-1, self.vector_dim) # type: ignore + logger.info(f"Loaded vectors file with shape: {self.vectors.shape}") # type: ignore else: logger.warning( "Vector dimension not set. Unable to load vectors file." @@ -214,22 +214,30 @@ def __init__(self, model: Encoder, cache_path: str | Path): logger.info("Initialized CachedEmbeddingWrapper") def encode( - self, texts: list[str], batch_size: int = 32, task_name: str = None, **kwargs + self, + texts: list[str], + batch_size: int = 32, + task_name: str | None = None, + **kwargs, ) -> np.ndarray: """Encode texts using the wrapped model, with caching""" + _task_name = task_name or "no_task_name" + try: results = [] uncached_texts = [] uncached_indices = [] # Initialize cache - if task_name not in self.cache_dict: - self.cache_dict[task_name] = TextVectorMap(self.cache_path / task_name) - self.cache_dict[task_name].load(name=task_name) + if _task_name not in self.cache_dict: + self.cache_dict[_task_name] = TextVectorMap( + self.cache_path / _task_name + ) + self.cache_dict[_task_name].load(name=_task_name) # Check cache for each text for i, text in enumerate(texts): - vector = self.cache_dict[task_name].get_vector(text) + vector = self.cache_dict[_task_name].get_vector(text) if vector is not None: results.append(vector) else: @@ -240,16 +248,19 @@ def encode( if uncached_texts: logger.info(f"Encoding {len(uncached_texts)} new texts") new_vectors = self._model.encode( - uncached_texts, batch_size=batch_size, **kwargs + uncached_texts, + batch_size=batch_size, + task_name=task_name, # type: ignore + **kwargs, ) if isinstance(new_vectors, torch.Tensor): new_vectors = new_vectors.cpu().numpy() # Add new vectors to cache for text, vector in zip(uncached_texts, new_vectors): - self.cache_dict[task_name].add(text, vector) + self.cache_dict[_task_name].add(text, vector) results.extend(new_vectors) - self.cache_dict[task_name].save() + self.cache_dict[_task_name].save() else: logger.info("All texts found in cache") @@ -287,7 +298,7 @@ def __getattr__(self, name: str) -> Any: def __dir__(self) -> list[str]: """Return all attributes from both this class and the wrapped model""" - return list(set(super().__dir__() + dir(self._model))) + return list(set(super().__dir__() + dir(self._model))) # type: ignore def __del__(self): self.close()