diff --git a/mteb/models/voyage_models.py b/mteb/models/voyage_models.py
index 2d3b3b100f..8a222e5d9d 100644
--- a/mteb/models/voyage_models.py
+++ b/mteb/models/voyage_models.py
@@ -16,6 +16,24 @@
     # synthetic data
 }
 
+# Total token limits per model based on VoyageAI documentation
+VOYAGE_TOTAL_TOKEN_LIMITS = {
+    "voyage-3.5-lite": 1_000_000,
+    "voyage-3.5": 320_000,
+    "voyage-2": 320_000,
+    "voyage-3-large": 120_000,
+    "voyage-code-3": 120_000,
+    "voyage-large-2-instruct": 120_000,
+    "voyage-finance-2": 120_000,
+    "voyage-multilingual-2": 120_000,
+    "voyage-law-2": 120_000,
+    "voyage-large-2": 120_000,
+    "voyage-3": 120_000,
+    "voyage-3-lite": 120_000,
+    "voyage-code-2": 120_000,
+    "voyage-3-m-exp": 120_000,
+}
+
 
 def token_limit(max_tpm: int, interval: int = 60):
     limit_interval_start_ts = time.time()
@@ -75,6 +93,7 @@ def __init__(
         max_retries: int = 5,
         max_rpm: int = 300,
         max_tpm: int = 1_000_000,
+        max_tokens: int | None = None,
         model_prompts: dict[str, str] | None = None,
         **kwargs,
     ) -> None:
@@ -85,17 +104,32 @@
         self._embed_func = rate_limit(max_rpm)(token_limit(max_tpm)(self._client.embed))
         self._model_name = model_name
         self._max_tpm = max_tpm
+        self._max_tokens = max_tokens
         self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
 
+    def _calculate_default_batch_size(self) -> int:
+        """Calculate the default batch size based on total token limit and context length.
+
+        Formula: floor(total_token_limit / context_length)
+        """
+        if self._max_tokens is None:
+            return 32  # fallback to original default
+
+        total_token_limit = VOYAGE_TOTAL_TOKEN_LIMITS.get(self._model_name, 120_000)
+        return max(1, total_token_limit // self._max_tokens)
+
     def encode(
         self,
         sentences: list[str],
         *,
-        batch_size: int = 32,
+        batch_size: int | None = None,
         task_name: str,
         prompt_type: PromptType | None = None,
         **kwargs: Any,
     ) -> np.ndarray:
+        if batch_size is None:
+            batch_size = self._calculate_default_batch_size()
+
         prompt_name = self.get_prompt_name(self.model_prompts, task_name, prompt_type)
         input_type = self.model_prompts.get(prompt_name, "document")
@@ -149,6 +183,7 @@ def _batched_encode(
     loader=partial(
         VoyageWrapper,
         model_name="voyage-3.5",
+        max_tokens=32000,
         model_prompts=model_prompts,
     ),
     max_tokens=32000,
@@ -174,6 +209,7 @@ def _batched_encode(
     loader=partial(  # type: ignore
         VoyageWrapper,
         model_name="voyage-large-2-instruct",
+        max_tokens=16000,
         model_prompts=model_prompts,
     ),
     max_tokens=16000,
@@ -199,6 +235,7 @@ def _batched_encode(
     loader=partial(  # type: ignore
         VoyageWrapper,
         model_name="voyage-finance-2",
+        max_tokens=32000,
         model_prompts=model_prompts,
     ),
     max_tokens=32000,
@@ -224,6 +261,7 @@ def _batched_encode(
     loader=partial(  # type: ignore
         VoyageWrapper,
         model_name="voyage-law-2",
+        max_tokens=16000,
         model_prompts=model_prompts,
     ),
     max_tokens=16000,
@@ -249,6 +287,7 @@ def _batched_encode(
     loader=partial(  # type: ignore
         VoyageWrapper,
         model_name="voyage-code-2",
+        max_tokens=16000,
         model_prompts=model_prompts,
     ),
     max_tokens=16000,
@@ -274,6 +313,7 @@ def _batched_encode(
     loader=partial(  # type: ignore
         VoyageWrapper,
         model_name="voyage-code-3",
+        max_tokens=32000,
         model_prompts=model_prompts,
     ),
     max_tokens=32000,
@@ -300,6 +340,7 @@ def _batched_encode(
     loader=partial(  # type: ignore
         VoyageWrapper,
         model_name="voyage-large-2",
+        max_tokens=16000,
         model_prompts=model_prompts,
     ),
     max_tokens=16000,
@@ -325,6 +366,7 @@ def _batched_encode(
     loader=partial(  # type: ignore
         VoyageWrapper,
         model_name="voyage-2",
+        max_tokens=4000,
         model_prompts=model_prompts,
     ),
     max_tokens=4000,
@@ -349,6 +391,7 @@ def _batched_encode(
     loader=partial(  # type: ignore
         VoyageWrapper,
         model_name="voyage-multilingual-2",
+        max_tokens=32000,
         model_prompts=model_prompts,
     ),
     max_tokens=32000,
@@ -374,6 +417,7 @@ def _batched_encode(
     loader=partial(
         VoyageWrapper,
         model_name="voyage-3",
+        max_tokens=32000,
         model_prompts=model_prompts,
     ),
     max_tokens=32000,
@@ -399,6 +443,7 @@ def _batched_encode(
     loader=partial(
         VoyageWrapper,
         model_name="voyage-3-lite",
+        max_tokens=32000,
         model_prompts=model_prompts,
     ),
     max_tokens=32000,
@@ -424,6 +469,7 @@ def _batched_encode(
     loader=partial(
         VoyageWrapper,
         model_name="voyage-3-m-exp",
+        max_tokens=32000,
         model_prompts=model_prompts,
     ),
     max_tokens=32000,
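
For reviewers, a minimal standalone sketch of the batch-size rule this patch introduces: floor(total_token_limit / context_length), clamped to at least 1. The names TOTAL_TOKEN_LIMITS and default_batch_size are illustrative stand-ins; only the constants mirror the diff.

    # Sketch of _calculate_default_batch_size, outside the wrapper class.
    TOTAL_TOKEN_LIMITS = {
        "voyage-3.5": 320_000,
        "voyage-2": 320_000,
        "voyage-3.5-lite": 1_000_000,
    }

    def default_batch_size(model_name: str, max_tokens: int | None) -> int:
        if max_tokens is None:
            return 32  # legacy fixed default when the context length is unknown
        total = TOTAL_TOKEN_LIMITS.get(model_name, 120_000)  # conservative fallback
        return max(1, total // max_tokens)

    assert default_batch_size("voyage-3.5", 32000) == 10  # 320_000 // 32_000
    assert default_batch_size("voyage-2", 4000) == 80     # 320_000 // 4_000
    assert default_batch_size("voyage-x", 16000) == 7     # 120_000 // 16_000 (fallback)
    assert default_batch_size("voyage-3.5", None) == 32   # original default retained

So a model with a large context window gets a smaller batch size, keeping each request under the per-request total-token limit while callers can still pass an explicit batch_size to override.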