Skip to content

Commit

Permalink
fix: remove user environment variables from ChatLiteLLMModelComponent
Browse files Browse the repository at this point in the history
  • Loading branch information
berrytern committed Jun 27, 2024
1 parent 25e475c commit d0da23f
Showing 1 changed file with 18 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
FloatInput,
IntInput,
MessageInput,
Output,
SecretStrInput,
StrInput,
)
Expand Down Expand Up @@ -59,11 +58,20 @@ class ChatLiteLLMModelComponent(LCModelComponent):
required=False,
value=0.7,
),
DictInput(
name="kwargs",
display_name="Kwargs",
advanced=True,
required=False,
is_list=True,
value={},
),
DictInput(
name="model_kwargs",
display_name="Model kwargs",
advanced=True,
required=False,
is_list=True,
value={},
),
FloatInput(name="top_p", display_name="Top p", advanced=True, required=False, value=0.5),
Expand Down Expand Up @@ -133,11 +141,13 @@ def build_model(self) -> LanguageModel:
"OpenRouter": "openrouter_api_key",
}
# Set the API key based on the provider
api_keys: dict[str, Optional[str]] = {provider_map[self.provider]: self.api_key}
os.environ[self.provider.upper() + "_API_KEY"] = self.api_key

output = ChatLiteLLM(
model=self.model,
self.kwargs[self.provider] = self.api_key
self.model_kwargs["api_key"] = self.api_key
if self.provider == "Azure":
if "api_base" not in self.kwargs:
raise Exception("Missing api_base on kwargs")
llm = ChatLiteLLM(
model=f"{self.provider.lower()}/{self.model}",
client=None,
streaming=self.stream,
temperature=self.temperature,
Expand All @@ -147,7 +157,7 @@ def build_model(self) -> LanguageModel:
n=self.n,
max_tokens=self.max_tokens,
max_retries=self.max_retries,
**api_keys,
**self.kwargs
)

return output
return llm

0 comments on commit d0da23f

Please sign in to comment.