Skip to content

Commit

Permalink
fix: Ollama Model Component build config updates and formats info to …
Browse files Browse the repository at this point in the history
…prevent issues in DSLF. (#5978)

* Update ollama.py

* update ib build config

* [autofix.ci] apply automated fixes

* improves stability of build config.

* [autofix.ci] apply automated fixes

* Update ollama.py

* test: update ChatOllamaComponent test to validate Ollama URL handling

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
  • Loading branch information
3 people authored Jan 28, 2025
1 parent 3163a56 commit b2ef231
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
24 changes: 15 additions & 9 deletions src/backend/base/langflow/components/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ChatOllamaComponent(LCModelComponent):
MessageTextInput(
name="base_url",
display_name="Base URL",
info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
info="Endpoint of the Ollama API.",
value="",
),
DropdownInput(
Expand Down Expand Up @@ -114,16 +114,16 @@ class ChatOllamaComponent(LCModelComponent):
MessageTextInput(
name="system", display_name="System", info="System to use for generating text.", advanced=True
),
MessageTextInput(
name="template", display_name="Template", info="Template to use for generating text.", advanced=True
),
BoolInput(
name="tool_model_enabled",
display_name="Tool Model Enabled",
info="Whether to enable tool calling in the model.",
value=True,
real_time_refresh=True,
),
MessageTextInput(
name="template", display_name="Template", info="Template to use for generating text.", advanced=True
),
*LCModelComponent._base_inputs,
]

Expand Down Expand Up @@ -160,12 +160,12 @@ def build_model(self) -> LanguageModel: # type: ignore[type-var]
"temperature": self.temperature or None,
"stop": self.stop_tokens.split(",") if self.stop_tokens else None,
"system": self.system,
"template": self.template,
"tfs_z": self.tfs_z or None,
"timeout": self.timeout or None,
"top_k": self.top_k or None,
"top_p": self.top_p or None,
"verbose": self.verbose,
"template": self.template,
}

# Remove parameters with None values
Expand All @@ -185,7 +185,7 @@ def build_model(self) -> LanguageModel: # type: ignore[type-var]
async def is_valid_ollama_url(self, url: str) -> bool:
try:
async with httpx.AsyncClient() as client:
return (await client.get(f"{url}/api/tags")).status_code == HTTP_STATUS_OK
return (await client.get(urljoin(url, "api/tags"))).status_code == HTTP_STATUS_OK
except httpx.RequestError:
return False

Expand All @@ -208,14 +208,20 @@ async def update_build_config(self, build_config: dict, field_value: Any, field_
build_config["mirostat_eta"]["value"] = 0.1
build_config["mirostat_tau"]["value"] = 5

if field_name in {"base_url", "model_name"} and not await self.is_valid_ollama_url(field_value):
if field_name in {"base_url", "model_name"} and not await self.is_valid_ollama_url(
build_config["base_url"].get("value", "")
):
# Check if any URL in the list is valid
valid_url = ""
for url in URL_LIST:
if await self.is_valid_ollama_url(url):
valid_url = url
break
build_config["base_url"]["value"] = valid_url
if valid_url != "":
build_config["base_url"]["value"] = valid_url
else:
msg = "No valid Ollama URL found."
raise ValueError(msg)
if field_name in {"model_name", "base_url", "tool_model_enabled"}:
if await self.is_valid_ollama_url(self.base_url):
tool_model_enabled = build_config["tool_model_enabled"].get("value", False) or self.tool_model_enabled
Expand All @@ -241,7 +247,7 @@ async def update_build_config(self, build_config: dict, field_value: Any, field_

async def get_model(self, base_url_value: str, tool_model_enabled: bool | None = None) -> list[str]:
try:
url = urljoin(base_url_value, "/api/tags")
url = urljoin(base_url_value, "api/tags")
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ async def test_update_build_config_model_name(mock_get, component):
field_value = None
field_name = "model_name"

updated_config = await component.update_build_config(build_config, field_value, field_name)

assert updated_config["model_name"]["options"] == []
with pytest.raises(ValueError, match="No valid Ollama URL found"):
await component.update_build_config(build_config, field_value, field_name)


async def test_update_build_config_keep_alive(component):
Expand Down

0 comments on commit b2ef231

Please sign in to comment.