Skip to content

Commit

Permalink
refactor: Enhance tools with enums and improved error handling (#4493)
Browse files Browse the repository at this point in the history
* fix: Enhance extract_class_name function to identify Component subclasses

* Add TODO for improving Component inheritance check in validate.py

* Add YahooFinanceMethod enum and improve error handling in Yahoo Finance tool

- Introduced YahooFinanceMethod enum to standardize method options.
- Updated YahooFinanceSchema to use the new enum for method selection.
- Enhanced error handling by raising ToolException on data retrieval failure.
- Refactored method handling in _yahoo_finance_tool to use enum values.

* Enhance TavilySearchToolComponent with Enums and Improved Error Handling

- Introduced `TavilySearchDepth` and `TavilySearchTopic` enums for better type safety and clarity.
- Updated `TavilySearchSchema` to use enums for `search_depth` and `topic` fields.
- Added validation for enum values in `run_model` and `_tavily_search` methods.
- Improved error handling by raising `ToolException` for HTTP and unexpected errors.
- Updated dropdown inputs to use enum options directly.

* Add error handling and parameter flexibility to SerpAPI tool

- Introduced `ToolException` for improved error handling in SerpAPI searches.
- Added `SerpAPISchema` for structured search parameters.
- Modified `_build_wrapper` to accept dynamic parameters.
- Enhanced `search_func` to rebuild wrapper with new parameters and handle exceptions.

* feat: Enhance Glean Search API integration

Refactor the API wrapper and schema for better clarity and maintainability. Improve error handling for search results and streamline request preparation.

* Add error handling to DuckDuckGo search function using ToolException

---------

Co-authored-by: Eric Hare <[email protected]>
  • Loading branch information
ogabrielluiz and erichare authored Nov 12, 2024
1 parent bbaec2b commit 7dfce1d
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from langchain.tools import StructuredTool
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from langflow.base.langchain_utilities.model import LCToolComponent
Expand Down Expand Up @@ -38,14 +39,18 @@ def build_tool(self) -> Tool:
wrapper = self._build_wrapper()

def search_func(query: str, max_results: int = 5, max_snippet_length: int = 100) -> list[dict[str, Any]]:
    # Runs the DuckDuckGo wrapper captured from the enclosing build_tool scope and
    # returns at most `max_results` entries, one per line of the wrapper's text
    # output, each truncated to `max_snippet_length` characters.
    # Any failure is re-raised as ToolException with the original error chained.
    try:
        raw_text = wrapper.run(f"{query} (site:*)")
        lines = raw_text.split("\n")[:max_results]
        return [{"snippet": line[:max_snippet_length]} for line in lines]
    except Exception as e:
        msg = f"Error in DuckDuckGo Search: {e!s}"
        raise ToolException(msg) from e

tool = StructuredTool.from_function(
Expand All @@ -67,5 +72,5 @@ def run_model(self) -> list[Data]:
}
)
data_list = [Data(data=result, text=result.get("snippet", "")) for result in results]
self.status = data_list
self.status = data_list # type: ignore[assignment]
return data_list
162 changes: 84 additions & 78 deletions src/backend/base/langflow/components/tools/glean_search_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from urllib.parse import urljoin

import httpx
from langchain.tools import StructuredTool
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import StructuredTool, ToolException
from pydantic import BaseModel
from pydantic.v1 import Field

from langflow.base.langchain_utilities.model import LCToolComponent
Expand All @@ -13,64 +13,54 @@
from langflow.schema import Data


class GleanSearchAPIComponent(LCToolComponent):
display_name = "Glean Search API"
description = "Call Glean Search API"
name = "GleanAPI"
class GleanSearchAPISchema(BaseModel):
    # Argument schema passed as `args_schema` when the Glean search tool is built.
    # NOTE(review): the imports shown in this diff take BaseModel from `pydantic`
    # but Field from `pydantic.v1` — confirm the two are meant to be mixed; a v1
    # FieldInfo may not be recognised as field metadata by a v2 model.

    # Free-text search query (required).
    query: str = Field(..., description="The search query")
    # Upper bound on the number of results requested.
    page_size: int = Field(default=10, description="Maximum number of results to return")
    # Extra options forwarded in the request payload; defaults to an empty dict.
    request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")

inputs = [
StrInput(
name="glean_api_url",
display_name="Glean API URL",
required=True,
),
SecretStrInput(name="glean_access_token", display_name="Glean Access Token", required=True),
MultilineInput(name="query", display_name="Query", required=True),
IntInput(name="page_size", display_name="Page Size", value=10),
NestedDictInput(name="request_options", display_name="Request Options", required=False),
]

class GleanAPIWrapper(BaseModel):
"""Wrapper around Glean API."""

glean_api_url: str
glean_access_token: str
act_as: str = "[email protected]" # TODO: Detect this

def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"

return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}

def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
results = self._search_api_results(query, **kwargs)
class GleanAPIWrapper(BaseModel):
"""Wrapper around Glean API."""

if len(results) == 0:
msg = "No good Glean Search Result was found"
raise AssertionError(msg)
glean_api_url: str
glean_access_token: str
act_as: str = "[email protected]" # TODO: Detect this

return results

def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"

return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}

def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
    """Return the raw Glean search hits for ``query``.

    Keyword arguments are forwarded unchanged to ``_search_api_results``.

    Raises:
        AssertionError: If the search returned an empty result list.
    """
    hits = self._search_api_results(query, **kwargs)

    if len(hits) > 0:
        return hits

    msg = "No good Glean Search Result was found"
    raise AssertionError(msg)

def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
try:
results = self.results(query, **kwargs)

processed_results = []
Expand All @@ -80,32 +70,48 @@ def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
if "text" not in result["snippets"][0]:
result["snippets"][0]["text"] = result["title"]

processed_results.append(result)
processed_results.append(result)
except Exception as e:
error_message = f"Error in Glean Search API: {e!s}"
raise ToolException(error_message) from e

return processed_results
return processed_results

def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
    """POST the search request to Glean and return the ``results`` list.

    The request is built by ``_prepare_request``; keyword arguments are
    forwarded to it unchanged. A non-success HTTP status raises via
    ``raise_for_status``. If the response JSON has no ``results`` key an
    empty list is returned.
    """
    req = self._prepare_request(query, **kwargs)

    response = httpx.post(req["url"], json=req["payload"], headers=req["headers"])
    # Surface HTTP errors before attempting to decode the body.
    response.raise_for_status()

    body = response.json()
    return body.get("results", [])
@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)

@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)

class GleanSearchAPISchema(BaseModel):
query: str = Field(..., description="The search query")
page_size: int = Field(10, description="Maximum number of results to return")
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")
class GleanSearchAPIComponent(LCToolComponent):
display_name = "Glean Search API"
description = "Call Glean Search API"
name = "GleanAPI"

inputs = [
StrInput(
name="glean_api_url",
display_name="Glean API URL",
required=True,
),
SecretStrInput(name="glean_access_token", display_name="Glean Access Token", required=True),
MultilineInput(name="query", display_name="Query", required=True),
IntInput(name="page_size", display_name="Page Size", value=10),
NestedDictInput(name="request_options", display_name="Request Options", required=False),
]

def build_tool(self) -> Tool:
wrapper = self._build_wrapper(
Expand All @@ -117,7 +123,7 @@ def build_tool(self) -> Tool:
name="glean_search_api",
description="Search Glean for relevant results.",
func=wrapper.run,
args_schema=self.GleanSearchAPISchema,
args_schema=GleanSearchAPISchema,
)

self.status = "Glean Search API Tool for Langchain"
Expand All @@ -137,7 +143,7 @@ def run_model(self) -> list[Data]:

# Build the data
data = [Data(data=result, text=result["snippets"][0]["text"]) for result in results]
self.status = data
self.status = data # type: ignore[assignment]

return data

Expand All @@ -146,7 +152,7 @@ def _build_wrapper(
glean_api_url: str,
glean_access_token: str,
):
return self.GleanAPIWrapper(
return GleanAPIWrapper(
glean_api_url=glean_api_url,
glean_access_token=glean_access_token,
)
72 changes: 47 additions & 25 deletions src/backend/base/langflow/components/tools/serp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from langchain.tools import StructuredTool
from langchain_community.utilities.serpapi import SerpAPIWrapper
from langchain_core.tools import ToolException
from loguru import logger
from pydantic import BaseModel, Field

Expand All @@ -11,6 +12,23 @@
from langflow.schema import Data


class SerpAPISchema(BaseModel):
    """Schema for SerpAPI search parameters."""

    # Free-text search query (required).
    query: str = Field(..., description="The search query")
    # Engine/locale parameters used when the caller supplies none.
    params: dict[str, Any] | None = Field(
        default=dict(
            engine="google",
            google_domain="google.com",
            gl="us",
            hl="en",
        ),
        description="Additional search parameters",
    )
    # Cap on the number of organic results returned by the tool.
    max_results: int = Field(default=5, description="Maximum number of results to return")
    # Cap on the length of each returned title/snippet string.
    max_snippet_length: int = Field(default=100, description="Maximum length of each result snippet")


class SerpAPIComponent(LCToolComponent):
display_name = "Serp Search API"
description = "Call Serp Search API with result limiting"
Expand All @@ -27,46 +45,50 @@ class SerpAPIComponent(LCToolComponent):
IntInput(name="max_snippet_length", display_name="Max Snippet Length", value=100, advanced=True),
]

class SerpAPISchema(BaseModel):
query: str = Field(..., description="The search query")
params: dict[str, Any] | None = Field(default_factory=dict, description="Additional search parameters")
max_results: int = Field(5, description="Maximum number of results to return")
max_snippet_length: int = Field(100, description="Maximum length of each result snippet")

def _build_wrapper(self, params: dict[str, Any] | None = None) -> SerpAPIWrapper:
    """Create a SerpAPIWrapper bound to this component's API key.

    Args:
        params: Optional search parameters. When empty or ``None`` the wrapper
            is created without a ``params`` argument.

    Returns:
        A new ``SerpAPIWrapper`` instance.
    """
    if not params:
        return SerpAPIWrapper(serpapi_api_key=self.serpapi_api_key)
    return SerpAPIWrapper(
        serpapi_api_key=self.serpapi_api_key,
        params=params,
    )

def build_tool(self) -> Tool:
wrapper = self._build_wrapper()
wrapper = self._build_wrapper(self.search_params) # noqa: F841

def search_func(
    query: str, params: dict[str, Any] | None = None, max_results: int = 5, max_snippet_length: int = 100
) -> list[dict[str, Any]]:
    """Run a SerpAPI search and return trimmed organic results.

    Args:
        query: The search query.
        params: Optional search parameters. When non-empty, a fresh wrapper is
            built with them; otherwise the wrapper created by the enclosing
            ``build_tool`` is used.
        max_results: Maximum number of organic results to return.
        max_snippet_length: Maximum length of each returned title and snippet.

    Returns:
        A list of dicts with the keys ``title``, ``link`` and ``snippet``.

    Raises:
        ToolException: If building the wrapper, the search, or result
            processing fails; the original exception is chained.
    """
    try:
        # Use a separate local name here. Assigning to `wrapper` inside this
        # function would make it a local variable, leaving it unbound (and
        # raising UnboundLocalError) whenever `params` is empty or None.
        active_wrapper = self._build_wrapper(params) if params else wrapper

        full_results = active_wrapper.results(query)
        organic_results = full_results.get("organic_results", [])[:max_results]

        limited_results = [
            {
                "title": result.get("title", "")[:max_snippet_length],
                "link": result.get("link", ""),
                "snippet": result.get("snippet", "")[:max_snippet_length],
            }
            for result in organic_results
        ]
    except Exception as e:
        error_message = f"Error in SerpAPI search: {e!s}"
        logger.debug(error_message)
        raise ToolException(error_message) from e
    return limited_results

tool = StructuredTool.from_function(
name="serp_search_api",
description="Search for recent results using SerpAPI with result limiting",
func=search_func,
args_schema=self.SerpAPISchema,
args_schema=SerpAPISchema,
)

self.status = "SerpAPI Tool created"
Expand All @@ -91,5 +113,5 @@ def run_model(self) -> list[Data]:
self.status = f"Error: {e}"
return [Data(data={"error": str(e)}, text=str(e))]

self.status = data_list
self.status = data_list # type: ignore[assignment]
return data_list
Loading

0 comments on commit 7dfce1d

Please sign in to comment.