-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: Enhance tools with enums and improved error handling (#4493)
* fix: Enhance extract_class_name function to identify Component subclasses
* Add TODO for improving Component inheritance check in validate.py
* Add YahooFinanceMethod enum and improve error handling in Yahoo Finance tool
  - Introduced YahooFinanceMethod enum to standardize method options.
  - Updated YahooFinanceSchema to use the new enum for method selection.
  - Enhanced error handling by raising ToolException on data retrieval failure.
  - Refactored method handling in _yahoo_finance_tool to use enum values.
* Enhance TavilySearchToolComponent with Enums and Improved Error Handling
  - Introduced `TavilySearchDepth` and `TavilySearchTopic` enums for better type safety and clarity.
  - Updated `TavilySearchSchema` to use enums for `search_depth` and `topic` fields.
  - Added validation for enum values in `run_model` and `_tavily_search` methods.
  - Improved error handling by raising `ToolException` for HTTP and unexpected errors.
  - Updated dropdown inputs to use enum options directly.
* Add error handling and parameter flexibility to SerpAPI tool
  - Introduced `ToolException` for improved error handling in SerpAPI searches.
  - Added `SerpAPISchema` for structured search parameters.
  - Modified `_build_wrapper` to accept dynamic parameters.
  - Enhanced `search_func` to rebuild wrapper with new parameters and handle exceptions.
* feat: Enhance Glean Search API integration
  Refactor the API wrapper and schema for better clarity and maintainability. Improve error handling for search results and streamline request preparation.
* Add error handling to DuckDuckGo search function using ToolException

---------

Co-authored-by: Eric Hare <[email protected]>
- Loading branch information
1 parent
bbaec2b
commit 7dfce1d
Showing
5 changed files
with
259 additions
and
167 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,8 @@ | |
from urllib.parse import urljoin | ||
|
||
import httpx | ||
from langchain.tools import StructuredTool | ||
from langchain_core.pydantic_v1 import BaseModel | ||
from langchain_core.tools import StructuredTool, ToolException | ||
from pydantic import BaseModel | ||
from pydantic.v1 import Field | ||
|
||
from langflow.base.langchain_utilities.model import LCToolComponent | ||
|
@@ -13,64 +13,54 @@ | |
from langflow.schema import Data | ||
|
||
|
||
class GleanSearchAPIComponent(LCToolComponent): | ||
display_name = "Glean Search API" | ||
description = "Call Glean Search API" | ||
name = "GleanAPI" | ||
class GleanSearchAPISchema(BaseModel): | ||
query: str = Field(..., description="The search query") | ||
page_size: int = Field(10, description="Maximum number of results to return") | ||
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options") | ||
|
||
inputs = [ | ||
StrInput( | ||
name="glean_api_url", | ||
display_name="Glean API URL", | ||
required=True, | ||
), | ||
SecretStrInput(name="glean_access_token", display_name="Glean Access Token", required=True), | ||
MultilineInput(name="query", display_name="Query", required=True), | ||
IntInput(name="page_size", display_name="Page Size", value=10), | ||
NestedDictInput(name="request_options", display_name="Request Options", required=False), | ||
] | ||
|
||
class GleanAPIWrapper(BaseModel): | ||
"""Wrapper around Glean API.""" | ||
|
||
glean_api_url: str | ||
glean_access_token: str | ||
act_as: str = "[email protected]" # TODO: Detect this | ||
|
||
def _prepare_request( | ||
self, | ||
query: str, | ||
page_size: int = 10, | ||
request_options: dict[str, Any] | None = None, | ||
) -> dict: | ||
# Ensure there's a trailing slash | ||
url = self.glean_api_url | ||
if not url.endswith("/"): | ||
url += "/" | ||
|
||
return { | ||
"url": urljoin(url, "search"), | ||
"headers": { | ||
"Authorization": f"Bearer {self.glean_access_token}", | ||
"X-Scio-ActAs": self.act_as, | ||
}, | ||
"payload": { | ||
"query": query, | ||
"pageSize": page_size, | ||
"requestOptions": request_options, | ||
}, | ||
} | ||
|
||
def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: | ||
results = self._search_api_results(query, **kwargs) | ||
class GleanAPIWrapper(BaseModel): | ||
"""Wrapper around Glean API.""" | ||
|
||
if len(results) == 0: | ||
msg = "No good Glean Search Result was found" | ||
raise AssertionError(msg) | ||
glean_api_url: str | ||
glean_access_token: str | ||
act_as: str = "[email protected]" # TODO: Detect this | ||
|
||
return results | ||
|
||
def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: | ||
def _prepare_request( | ||
self, | ||
query: str, | ||
page_size: int = 10, | ||
request_options: dict[str, Any] | None = None, | ||
) -> dict: | ||
# Ensure there's a trailing slash | ||
url = self.glean_api_url | ||
if not url.endswith("/"): | ||
url += "/" | ||
|
||
return { | ||
"url": urljoin(url, "search"), | ||
"headers": { | ||
"Authorization": f"Bearer {self.glean_access_token}", | ||
"X-Scio-ActAs": self.act_as, | ||
}, | ||
"payload": { | ||
"query": query, | ||
"pageSize": page_size, | ||
"requestOptions": request_options, | ||
}, | ||
} | ||
|
||
def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: | ||
results = self._search_api_results(query, **kwargs) | ||
|
||
if len(results) == 0: | ||
msg = "No good Glean Search Result was found" | ||
raise AssertionError(msg) | ||
|
||
return results | ||
|
||
def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: | ||
try: | ||
results = self.results(query, **kwargs) | ||
|
||
processed_results = [] | ||
|
@@ -80,32 +70,48 @@ def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: | |
if "text" not in result["snippets"][0]: | ||
result["snippets"][0]["text"] = result["title"] | ||
|
||
processed_results.append(result) | ||
processed_results.append(result) | ||
except Exception as e: | ||
error_message = f"Error in Glean Search API: {e!s}" | ||
raise ToolException(error_message) from e | ||
|
||
return processed_results | ||
return processed_results | ||
|
||
def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: | ||
request_details = self._prepare_request(query, **kwargs) | ||
def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: | ||
request_details = self._prepare_request(query, **kwargs) | ||
|
||
response = httpx.post( | ||
request_details["url"], | ||
json=request_details["payload"], | ||
headers=request_details["headers"], | ||
) | ||
response = httpx.post( | ||
request_details["url"], | ||
json=request_details["payload"], | ||
headers=request_details["headers"], | ||
) | ||
|
||
response.raise_for_status() | ||
response_json = response.json() | ||
|
||
response.raise_for_status() | ||
response_json = response.json() | ||
return response_json.get("results", []) | ||
|
||
return response_json.get("results", []) | ||
@staticmethod | ||
def _result_as_string(result: dict) -> str: | ||
return json.dumps(result, indent=4) | ||
|
||
@staticmethod | ||
def _result_as_string(result: dict) -> str: | ||
return json.dumps(result, indent=4) | ||
|
||
class GleanSearchAPISchema(BaseModel): | ||
query: str = Field(..., description="The search query") | ||
page_size: int = Field(10, description="Maximum number of results to return") | ||
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options") | ||
class GleanSearchAPIComponent(LCToolComponent): | ||
display_name = "Glean Search API" | ||
description = "Call Glean Search API" | ||
name = "GleanAPI" | ||
|
||
inputs = [ | ||
StrInput( | ||
name="glean_api_url", | ||
display_name="Glean API URL", | ||
required=True, | ||
), | ||
SecretStrInput(name="glean_access_token", display_name="Glean Access Token", required=True), | ||
MultilineInput(name="query", display_name="Query", required=True), | ||
IntInput(name="page_size", display_name="Page Size", value=10), | ||
NestedDictInput(name="request_options", display_name="Request Options", required=False), | ||
] | ||
|
||
def build_tool(self) -> Tool: | ||
wrapper = self._build_wrapper( | ||
|
@@ -117,7 +123,7 @@ def build_tool(self) -> Tool: | |
name="glean_search_api", | ||
description="Search Glean for relevant results.", | ||
func=wrapper.run, | ||
args_schema=self.GleanSearchAPISchema, | ||
args_schema=GleanSearchAPISchema, | ||
) | ||
|
||
self.status = "Glean Search API Tool for Langchain" | ||
|
@@ -137,7 +143,7 @@ def run_model(self) -> list[Data]: | |
|
||
# Build the data | ||
data = [Data(data=result, text=result["snippets"][0]["text"]) for result in results] | ||
self.status = data | ||
self.status = data # type: ignore[assignment] | ||
|
||
return data | ||
|
||
|
@@ -146,7 +152,7 @@ def _build_wrapper( | |
glean_api_url: str, | ||
glean_access_token: str, | ||
): | ||
return self.GleanAPIWrapper( | ||
return GleanAPIWrapper( | ||
glean_api_url=glean_api_url, | ||
glean_access_token=glean_access_token, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.