Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -442,12 +442,15 @@ There are times you may want to cache the embeddings so you can re-use them. Thi

```python
# define your task(s) and model above as normal
...
task = mteb.get_task("LccSentimentClassification")
model = mteb.get_model("minishlab/M2V_base_glove_subword")
evaluation = mteb.MTEB(tasks=[task])

# wrap the model with the cache wrapper
from mteb.models.cache_wrapper import CachedEmbeddingWrapper
model_with_cached_emb = CachedEmbeddingWrapper(model, cache_path='<path_to_cache_dir>')
model_with_cached_emb = CachedEmbeddingWrapper(model, cache_path='path_to_cache_dir')
# run as normal
evaluation.run(model, ...)
evaluation.run(model_with_cached_emb)
```

If you want to directly access the cached embeddings (e.g., for subsequent analyses), follow this example:
Expand All @@ -457,8 +460,8 @@ import numpy as np
from mteb.models.cache_wrapper import TextVectorMap

# Access the memory-mapped file and convert to array
vector_map = TextVectorMap("<path_to_cache_dir>/AppsRetrieval")
vector_map.load(name="AppsRetrieval")
vector_map = TextVectorMap("path_to_cache_dir/LccSentimentClassification")
vector_map.load(name="LccSentimentClassification")
vectors = np.asarray(vector_map.vectors)

# Remove all "placeholders" in the embedding cache
Expand Down
33 changes: 22 additions & 11 deletions mteb/models/cache_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def load(self, name: str | None = None) -> None:
self.vectors = np.memmap(
self.vectors_file, dtype="float32", mode="r+"
)
self.vectors = self.vectors.reshape(-1, self.vector_dim)
logger.info(f"Loaded vectors file with shape: {self.vectors.shape}")
self.vectors = self.vectors.reshape(-1, self.vector_dim) # type: ignore
logger.info(f"Loaded vectors file with shape: {self.vectors.shape}") # type: ignore
else:
logger.warning(
"Vector dimension not set. Unable to load vectors file."
Expand Down Expand Up @@ -214,22 +214,30 @@ def __init__(self, model: Encoder, cache_path: str | Path):
logger.info("Initialized CachedEmbeddingWrapper")

def encode(
self, texts: list[str], batch_size: int = 32, task_name: str = None, **kwargs
self,
texts: list[str],
batch_size: int = 32,
task_name: str | None = None,
**kwargs,
) -> np.ndarray:
"""Encode texts using the wrapped model, with caching"""
_task_name = task_name or "no_task_name"

try:
results = []
uncached_texts = []
uncached_indices = []

# Initialize cache
if task_name not in self.cache_dict:
self.cache_dict[task_name] = TextVectorMap(self.cache_path / task_name)
self.cache_dict[task_name].load(name=task_name)
if _task_name not in self.cache_dict:
self.cache_dict[_task_name] = TextVectorMap(
self.cache_path / _task_name
)
self.cache_dict[_task_name].load(name=_task_name)

# Check cache for each text
for i, text in enumerate(texts):
vector = self.cache_dict[task_name].get_vector(text)
vector = self.cache_dict[_task_name].get_vector(text)
if vector is not None:
results.append(vector)
else:
Expand All @@ -240,16 +248,19 @@ def encode(
if uncached_texts:
logger.info(f"Encoding {len(uncached_texts)} new texts")
new_vectors = self._model.encode(
uncached_texts, batch_size=batch_size, **kwargs
uncached_texts,
batch_size=batch_size,
task_name=task_name, # type: ignore
**kwargs,
)
if isinstance(new_vectors, torch.Tensor):
new_vectors = new_vectors.cpu().numpy()

# Add new vectors to cache
for text, vector in zip(uncached_texts, new_vectors):
self.cache_dict[task_name].add(text, vector)
self.cache_dict[_task_name].add(text, vector)
results.extend(new_vectors)
self.cache_dict[task_name].save()
self.cache_dict[_task_name].save()
else:
logger.info("All texts found in cache")

Expand Down Expand Up @@ -287,7 +298,7 @@ def __getattr__(self, name: str) -> Any:

def __dir__(self) -> list[str]:
"""Return all attributes from both this class and the wrapped model"""
return list(set(super().__dir__() + dir(self._model)))
return list(set(super().__dir__() + dir(self._model))) # type: ignore

def __del__(self):
self.close()
Expand Down