Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Enhance tools with enums and improved error handling #4493

Merged
merged 11 commits into from
Nov 12, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from langchain.tools import StructuredTool
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from langflow.base.langchain_utilities.model import LCToolComponent
Expand Down Expand Up @@ -38,14 +39,18 @@ def build_tool(self) -> Tool:
wrapper = self._build_wrapper()

def search_func(query: str, max_results: int = 5, max_snippet_length: int = 100) -> list[dict[str, Any]]:
full_results = wrapper.run(f"{query} (site:*)")
result_list = full_results.split("\n")[:max_results]
limited_results = []
for result in result_list:
limited_result = {
"snippet": result[:max_snippet_length],
}
limited_results.append(limited_result)
try:
full_results = wrapper.run(f"{query} (site:*)")
result_list = full_results.split("\n")[:max_results]
limited_results = []
for result in result_list:
limited_result = {
"snippet": result[:max_snippet_length],
}
limited_results.append(limited_result)
except Exception as e:
msg = f"Error in DuckDuckGo Search: {e!s}"
raise ToolException(msg) from e
return limited_results

tool = StructuredTool.from_function(
Expand All @@ -67,5 +72,5 @@ def run_model(self) -> list[Data]:
}
)
data_list = [Data(data=result, text=result.get("snippet", "")) for result in results]
self.status = data_list
self.status = data_list # type: ignore[assignment]
return data_list
162 changes: 84 additions & 78 deletions src/backend/base/langflow/components/tools/glean_search_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from urllib.parse import urljoin

import httpx
from langchain.tools import StructuredTool
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import StructuredTool, ToolException
from pydantic import BaseModel
from pydantic.v1 import Field

from langflow.base.langchain_utilities.model import LCToolComponent
Expand All @@ -13,64 +13,54 @@
from langflow.schema import Data


class GleanSearchAPIComponent(LCToolComponent):
display_name = "Glean Search API"
description = "Call Glean Search API"
name = "GleanAPI"
class GleanSearchAPISchema(BaseModel):
query: str = Field(..., description="The search query")
page_size: int = Field(10, description="Maximum number of results to return")
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")

inputs = [
StrInput(
name="glean_api_url",
display_name="Glean API URL",
required=True,
),
SecretStrInput(name="glean_access_token", display_name="Glean Access Token", required=True),
MultilineInput(name="query", display_name="Query", required=True),
IntInput(name="page_size", display_name="Page Size", value=10),
NestedDictInput(name="request_options", display_name="Request Options", required=False),
]

class GleanAPIWrapper(BaseModel):
"""Wrapper around Glean API."""

glean_api_url: str
glean_access_token: str
act_as: str = "[email protected]" # TODO: Detect this

def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"

return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}

def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
results = self._search_api_results(query, **kwargs)
class GleanAPIWrapper(BaseModel):
"""Wrapper around Glean API."""

if len(results) == 0:
msg = "No good Glean Search Result was found"
raise AssertionError(msg)
glean_api_url: str
glean_access_token: str
act_as: str = "[email protected]" # TODO: Detect this

return results

def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"

return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}

def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
results = self._search_api_results(query, **kwargs)

if len(results) == 0:
msg = "No good Glean Search Result was found"
raise AssertionError(msg)

return results

def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
try:
results = self.results(query, **kwargs)

processed_results = []
Expand All @@ -80,32 +70,48 @@ def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
if "text" not in result["snippets"][0]:
result["snippets"][0]["text"] = result["title"]

processed_results.append(result)
processed_results.append(result)
except Exception as e:
error_message = f"Error in Glean Search API: {e!s}"
raise ToolException(error_message) from e

return processed_results
return processed_results

def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
request_details = self._prepare_request(query, **kwargs)
def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
request_details = self._prepare_request(query, **kwargs)

response = httpx.post(
request_details["url"],
json=request_details["payload"],
headers=request_details["headers"],
)
response = httpx.post(
request_details["url"],
json=request_details["payload"],
headers=request_details["headers"],
)

response.raise_for_status()
response_json = response.json()

response.raise_for_status()
response_json = response.json()
return response_json.get("results", [])

return response_json.get("results", [])
@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)

@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)

class GleanSearchAPISchema(BaseModel):
query: str = Field(..., description="The search query")
page_size: int = Field(10, description="Maximum number of results to return")
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")
class GleanSearchAPIComponent(LCToolComponent):
display_name = "Glean Search API"
description = "Call Glean Search API"
name = "GleanAPI"

inputs = [
StrInput(
name="glean_api_url",
display_name="Glean API URL",
required=True,
),
SecretStrInput(name="glean_access_token", display_name="Glean Access Token", required=True),
MultilineInput(name="query", display_name="Query", required=True),
IntInput(name="page_size", display_name="Page Size", value=10),
NestedDictInput(name="request_options", display_name="Request Options", required=False),
]

def build_tool(self) -> Tool:
wrapper = self._build_wrapper(
Expand All @@ -117,7 +123,7 @@ def build_tool(self) -> Tool:
name="glean_search_api",
description="Search Glean for relevant results.",
func=wrapper.run,
args_schema=self.GleanSearchAPISchema,
args_schema=GleanSearchAPISchema,
)

self.status = "Glean Search API Tool for Langchain"
Expand All @@ -137,7 +143,7 @@ def run_model(self) -> list[Data]:

# Build the data
data = [Data(data=result, text=result["snippets"][0]["text"]) for result in results]
self.status = data
self.status = data # type: ignore[assignment]

return data

Expand All @@ -146,7 +152,7 @@ def _build_wrapper(
glean_api_url: str,
glean_access_token: str,
):
return self.GleanAPIWrapper(
return GleanAPIWrapper(
glean_api_url=glean_api_url,
glean_access_token=glean_access_token,
)
72 changes: 47 additions & 25 deletions src/backend/base/langflow/components/tools/serp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from langchain.tools import StructuredTool
from langchain_community.utilities.serpapi import SerpAPIWrapper
from langchain_core.tools import ToolException
from loguru import logger
from pydantic import BaseModel, Field

Expand All @@ -11,6 +12,23 @@
from langflow.schema import Data


class SerpAPISchema(BaseModel):
"""Schema for SerpAPI search parameters."""

query: str = Field(..., description="The search query")
params: dict[str, Any] | None = Field(
default={
"engine": "google",
"google_domain": "google.com",
"gl": "us",
"hl": "en",
},
description="Additional search parameters",
)
max_results: int = Field(5, description="Maximum number of results to return")
max_snippet_length: int = Field(100, description="Maximum length of each result snippet")


class SerpAPIComponent(LCToolComponent):
display_name = "Serp Search API"
description = "Call Serp Search API with result limiting"
Expand All @@ -27,46 +45,50 @@ class SerpAPIComponent(LCToolComponent):
IntInput(name="max_snippet_length", display_name="Max Snippet Length", value=100, advanced=True),
]

class SerpAPISchema(BaseModel):
query: str = Field(..., description="The search query")
params: dict[str, Any] | None = Field(default_factory=dict, description="Additional search parameters")
max_results: int = Field(5, description="Maximum number of results to return")
max_snippet_length: int = Field(100, description="Maximum length of each result snippet")

def _build_wrapper(self) -> SerpAPIWrapper:
if self.search_params:
def _build_wrapper(self, params: dict[str, Any] | None = None) -> SerpAPIWrapper:
"""Build a SerpAPIWrapper with the provided parameters."""
params = params or {}
if params:
return SerpAPIWrapper(
serpapi_api_key=self.serpapi_api_key,
params=self.search_params,
params=params,
)
return SerpAPIWrapper(serpapi_api_key=self.serpapi_api_key)

def build_tool(self) -> Tool:
wrapper = self._build_wrapper()
wrapper = self._build_wrapper(self.search_params) # noqa: F841

def search_func(
query: str, params: dict[str, Any] | None = None, max_results: int = 5, max_snippet_length: int = 100
) -> list[dict[str, Any]]:
params = params or {}
full_results = wrapper.results(query, **params)
organic_results = full_results.get("organic_results", [])[:max_results]

limited_results = []
for result in organic_results:
limited_result = {
"title": result.get("title", "")[:max_snippet_length],
"link": result.get("link", ""),
"snippet": result.get("snippet", "")[:max_snippet_length],
}
limited_results.append(limited_result)

try:
# rebuild the wrapper if params are provided
if params:
wrapper = self._build_wrapper(params)

full_results = wrapper.results(query)
organic_results = full_results.get("organic_results", [])[:max_results]

limited_results = []
for result in organic_results:
limited_result = {
"title": result.get("title", "")[:max_snippet_length],
"link": result.get("link", ""),
"snippet": result.get("snippet", "")[:max_snippet_length],
}
limited_results.append(limited_result)

except Exception as e:
error_message = f"Error in SerpAPI search: {e!s}"
logger.debug(error_message)
raise ToolException(error_message) from e
return limited_results

tool = StructuredTool.from_function(
name="serp_search_api",
description="Search for recent results using SerpAPI with result limiting",
func=search_func,
args_schema=self.SerpAPISchema,
args_schema=SerpAPISchema,
)

self.status = "SerpAPI Tool created"
Expand All @@ -91,5 +113,5 @@ def run_model(self) -> list[Data]:
self.status = f"Error: {e}"
return [Data(data={"error": str(e)}, text=str(e))]

self.status = data_list
self.status = data_list # type: ignore[assignment]
return data_list
Loading
Loading