Add support for vllm (#673)
Dominastorm authored Apr 3, 2024
1 parent aa797d0 commit a925b05
Showing 2 changed files with 12 additions and 1 deletion.

uptrain/framework/base.py (3 additions, 0 deletions)
```diff
@@ -49,6 +49,9 @@ class Settings(BaseSettings):
         None, env="AZURE_API_VERSION"
     )
 
+    custom_llm_provider: t.Optional[str] = None
+    api_base: t.Optional[str] = None
+
     rpm_limit: int = 100
     tpm_limit: int = 90_000
     embedding_compute_method: t.Literal["local", "replicate", "api"] = "local"
```
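
The two new optional fields let a caller route evaluation calls to a self-hosted, OpenAI-compatible endpoint such as a vLLM server. A minimal usage sketch (not part of this commit; the model name, provider tag, and URL below are illustrative placeholders):

```python
# Hypothetical configuration pointing UpTrain at a local vLLM server.
# vLLM exposes an OpenAI-compatible API, by default at http://localhost:8000/v1.
from uptrain import Settings

settings = Settings(
    model="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model name
    custom_llm_provider="openai",                # treat the endpoint as OpenAI-compatible
    api_base="http://localhost:8000/v1",         # placeholder vLLM server URL
)
```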

uptrain/operators/language/llm.py (9 additions, 1 deletion)
The first hunk touches only the `self.aclient = None` line in `__init__` (the removed and added lines are textually identical, so this appears to be a whitespace-only change):

```diff
@@ -215,7 +215,7 @@ def __init__(self, settings: t.Optional[Settings] = None, aclient: t.Any = None)
         if (
             settings.model.startswith("ollama")
         ):
-            self.aclient = None
+            self.aclient = None
         self._rpm_limit = settings.check_and_get("rpm_limit")
         self._tpm_limit = settings.check_and_get("tpm_limit")
 
```
The second hunk reads the two new settings inside `make_payload`:

```diff
@@ -229,6 +229,10 @@ def make_payload(
         seed = self.settings.seed
         response_format = self.settings.response_format
 
+        # For vllm
+        custom_llm_provider = self.settings.custom_llm_provider
+        api_base = self.settings.api_base
+
         prefixes = ["anyscale/", "azure/", "together/"]
         for prefix in prefixes:
             model = model.replace(prefix, "")
```
The third hunk forwards them in the request data when they are set:

```diff
@@ -240,6 +244,10 @@ def make_payload(
             data["seed"] = seed
         if response_format is not None:
             data["response_format"] = response_format
+        if custom_llm_provider is not None:
+            data["custom_llm_provider"] = custom_llm_provider
+        if api_base is not None:
+            data["api_base"] = api_base
         return Payload(
             endpoint="chat.completions",
             data=data,
```
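
With both settings populated, the payload now carries `custom_llm_provider` and `api_base` alongside the usual chat-completion arguments. An illustrative view of the assembled data, assuming it is ultimately passed as keyword arguments to a litellm-style `completion` call (values reuse the placeholders from the sketch above):

```python
# Illustrative only: payload data for a vLLM-backed evaluation request.
# "custom_llm_provider" and "api_base" are the keys added by this commit.
data = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "messages": [{"role": "user", "content": "Evaluate this response."}],
    "custom_llm_provider": "openai",         # provider override for the LLM router
    "api_base": "http://localhost:8000/v1",  # base URL of the OpenAI-compatible server
}
```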
