feat: ensuring that max_context_tokens is never larger than what is supported by models (#3519)

# Description

This change renames the per-model `context` field on `LLMConfig` to `max_context_tokens` and adds per-model `max_output_tokens` defaults to `LLMModelConfig._model_defaults`. `LLMEndpointConfig.set_llm_model_config` now lowers `max_context_tokens` and `max_output_tokens` whenever the requested values exceed what the selected model supports, reserving room for the model's maximum output when computing the usable context window. The unused `context_length` field is dropped, the endpoint defaults are raised to 10000 context / 4000 output tokens, and the GROQ defaults are updated (`llama-3.3-70b`, `llama-3.1-70b`).
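
A minimal usage sketch of the new behaviour (not part of this commit; it assumes `quivr_core` is importable, that an `LLMEndpointConfig` can be instantiated in your environment, and that `set_llm_model_config` can be called directly, though it may also run automatically when the config is created):

```python
from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig

# Ask for more tokens than gpt-4o supports (model defaults: 128000 context, 16384 output).
config = LLMEndpointConfig(
    supplier=DefaultModelSuppliers.OPENAI,
    model="gpt-4o",
    max_context_tokens=200_000,
    max_output_tokens=20_000,
)
config.set_llm_model_config()

print(config.max_context_tokens, config.max_output_tokens)
# Expected with this change: 111616 16384
# (context capped at 128000 - 16384 to leave room for the output; output capped at 16384)
```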

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
jacopo-chevallard authored Dec 11, 2024
1 parent e384a0a commit d6e0ed4
Showing 1 changed file with 100 additions and 33 deletions.
core/quivr_core/rag/entities/config.py (133 changes: 100 additions & 33 deletions)
@@ -75,89 +75,139 @@ class DefaultModelSuppliers(str, Enum):
 
 
 class LLMConfig(QuivrBaseConfig):
-    context: int | None = None
+    max_context_tokens: int | None = None
+    max_output_tokens: int | None = None
     tokenizer_hub: str | None = None
 
 
 class LLMModelConfig:
     _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
         DefaultModelSuppliers.OPENAI: {
-            "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
-            "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
-            "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
-            "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
+            "gpt-4o": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=16384,
+                tokenizer_hub="Xenova/gpt-4o",
+            ),
+            "gpt-4o-mini": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=16384,
+                tokenizer_hub="Xenova/gpt-4o",
+            ),
+            "gpt-4-turbo": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/gpt-4",
+            ),
+            "gpt-4": LLMConfig(
+                max_context_tokens=8192,
+                max_output_tokens=8192,
+                tokenizer_hub="Xenova/gpt-4",
+            ),
             "gpt-3.5-turbo": LLMConfig(
-                context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
+                max_context_tokens=16385,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/gpt-3.5-turbo",
             ),
             "text-embedding-3-large": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
             "text-embedding-3-small": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
             "text-embedding-ada-002": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
         },
         DefaultModelSuppliers.ANTHROPIC: {
             "claude-3-5-sonnet": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=8192,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-opus": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-sonnet": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-haiku": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-2-1": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-2-0": LLMConfig(
-                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=100000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-instant-1-2": LLMConfig(
-                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=100000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
         },
+        # Unclear for LLAMA models...
+        # see https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct/discussions/6
         DefaultModelSuppliers.META: {
             "llama-3.1": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3": LLMConfig(
-                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
+                max_context_tokens=8192,
+                max_output_tokens=2048,
+                tokenizer_hub="Xenova/llama3-tokenizer-new",
             ),
-            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
             "code-llama": LLMConfig(
-                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
+                max_context_tokens=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
             ),
         },
         DefaultModelSuppliers.GROQ: {
-            "llama-3.1": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
+            "llama-3.3-70b": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=32768,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
+            ),
+            "llama-3.1-70b": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=32768,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3": LLMConfig(
-                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
+                max_context_tokens=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
             ),
-            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
             "code-llama": LLMConfig(
-                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
+                max_context_tokens=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
             ),
         },
         DefaultModelSuppliers.MISTRAL: {
             "mistral-large": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/mistral-tokenizer-v3",
             ),
             "mistral-small": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/mistral-tokenizer-v3",
             ),
             "mistral-nemo": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer",
             ),
             "codestral": LLMConfig(
-                context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
             ),
         },
     }
@@ -193,13 +243,12 @@ def get_llm_model_config(
 class LLMEndpointConfig(QuivrBaseConfig):
     supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
     model: str = "gpt-4o"
-    context_length: int | None = None
     tokenizer_hub: str | None = None
     llm_base_url: str | None = None
     env_variable_name: str | None = None
     llm_api_key: str | None = None
-    max_context_tokens: int = 2000
-    max_output_tokens: int = 2000
+    max_context_tokens: int = 10000
+    max_output_tokens: int = 4000
     temperature: float = 0.7
     streaming: bool = True
     prompt: CustomPromptsModel | None = None
@@ -240,7 +289,25 @@ def set_llm_model_config(self):
             self.supplier, self.model
         )
         if llm_model_config:
-            self.context_length = llm_model_config.context
+            if llm_model_config.max_context_tokens:
+                _max_context_tokens = (
+                    llm_model_config.max_context_tokens
+                    - llm_model_config.max_output_tokens
+                    if llm_model_config.max_output_tokens
+                    else llm_model_config.max_context_tokens
+                )
+                if self.max_context_tokens > _max_context_tokens:
+                    logger.warning(
+                        f"Lowering max_context_tokens from {self.max_context_tokens} to {_max_context_tokens}"
+                    )
+                    self.max_context_tokens = _max_context_tokens
+            if llm_model_config.max_output_tokens:
+                if self.max_output_tokens > llm_model_config.max_output_tokens:
+                    logger.warning(
+                        f"Lowering max_output_tokens from {self.max_output_tokens} to {llm_model_config.max_output_tokens}"
+                    )
+                    self.max_output_tokens = llm_model_config.max_output_tokens
+
             self.tokenizer_hub = llm_model_config.tokenizer_hub
 
     def set_llm_model(self, model: str):
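
For reference, the clamping introduced in `set_llm_model_config` amounts to the small helper sketched below. This is a standalone illustration, not code from this commit, and `clamp_to_model_limits` is a hypothetical name used only here.

```python
def clamp_to_model_limits(
    requested_context: int,
    requested_output: int,
    model_max_context: int | None,
    model_max_output: int | None,
) -> tuple[int, int]:
    """Lower requested token budgets to what the selected model supports."""
    if model_max_context:
        # Reserve room for the answer: the usable context is the model's window
        # minus its maximum output, when the latter is known.
        budget = (
            model_max_context - model_max_output
            if model_max_output
            else model_max_context
        )
        requested_context = min(requested_context, budget)
    if model_max_output:
        requested_output = min(requested_output, model_max_output)
    return requested_context, requested_output


# gpt-4o defaults from the configuration above: 128000 context, 16384 output.
print(clamp_to_model_limits(200_000, 20_000, 128_000, 16_384))
# -> (111616, 16384)
```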
