Skip to content

Commit

Permalink
bugfix: Properly output a Tool from Glean Search (langflow-ai#3851)
Browse files Browse the repository at this point in the history
* bugfix: Properly output a Tool from Glean Search

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and diogocabral committed Nov 26, 2024
1 parent 0638214 commit a4677c7
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 81 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
from typing import Dict, List

from langflow.custom import Component
from langflow.io import DataInput, Output
Expand Down Expand Up @@ -49,7 +48,7 @@ def parse_transcription(self) -> Data:
self.status = error_message
return Data(data={"error": error_message})

def parse_with_speakers(self, utterances: List[Dict]) -> str:
def parse_with_speakers(self, utterances: list[dict]) -> str:
parsed_result = []
for utterance in utterances:
speaker = utterance["speaker"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import assemblyai as aai

from langflow.custom import Component
Expand Down Expand Up @@ -48,7 +46,7 @@ class AssemblyAIListTranscripts(Component):
Output(display_name="Transcript List", name="transcript_list", method="list_transcripts"),
]

def list_transcripts(self) -> List[Data]:
def list_transcripts(self) -> list[Data]:
aai.settings.api_key = self.api_key

params = aai.ListTranscriptParameters()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from langchain_core.prompts import HumanMessagePromptTemplate

from langchain_core.prompts import HumanMessagePromptTemplate

from langflow.custom import Component
from langflow.inputs import DefaultPromptField, SecretStrInput, StrInput
from langflow.io import Output
Expand Down
187 changes: 113 additions & 74 deletions src/backend/base/langflow/components/tools/GleanSearchAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from urllib.parse import urljoin

import httpx
from langchain.tools import StructuredTool
from langchain_core.pydantic_v1 import BaseModel
from pydantic.v1 import Field

from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.field_typing import Tool
Expand All @@ -28,90 +30,127 @@ class GleanSearchAPIComponent(LCToolComponent):
NestedDictInput(name="request_options", display_name="Request Options", required=False),
]

class GleanAPIWrapper(BaseModel):
"""
Wrapper around Glean API.
"""

glean_api_url: str
glean_access_token: str
act_as: str = "[email protected]" # TODO: Detect this

def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"

return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}

def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
results = self._search_api_results(query, **kwargs)

if len(results) == 0:
raise AssertionError("No good Glean Search Result was found")

return results

def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
results = self.results(query, **kwargs)

processed_results = []
for result in results:
if "title" in result:
result["snippets"] = result.get("snippets", [{"snippet": {"text": result["title"]}}])
if "text" not in result["snippets"][0]:
result["snippets"][0]["text"] = result["title"]

processed_results.append(result)

return processed_results

def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
request_details = self._prepare_request(query, **kwargs)

response = httpx.post(
request_details["url"],
json=request_details["payload"],
headers=request_details["headers"],
)

response.raise_for_status()
response_json = response.json()

return response_json.get("results", [])

@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)

class GleanSearchAPISchema(BaseModel):
query: str = Field(..., description="The search query")
page_size: int = Field(10, description="Maximum number of results to return")
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")

def build_tool(self) -> Tool:
wrapper = self._build_wrapper()
wrapper = self._build_wrapper(
glean_api_url=self.glean_api_url,
glean_access_token=self.glean_access_token,
)

return Tool(name="glean_search_api", description="Search with the Glean API", func=wrapper.run)
tool = StructuredTool.from_function(
name="glean_search_api",
description="Search Glean for relevant results.",
func=wrapper.run,
args_schema=self.GleanSearchAPISchema,
)

def run_model(self) -> Data | list[Data]:
wrapper = self._build_wrapper()
self.status = "Glean Search API Tool for Langchain"

results = wrapper.results(
query=self.query,
page_size=self.page_size,
request_options=self.request_options,
)
return tool

list_results = results.get("results", [])
def run_model(self) -> list[Data]:
tool = self.build_tool()

results = tool.run(
{
"query": self.query,
"page_size": self.page_size,
"request_options": self.request_options,
}
)

# Build the data
data = []
for result in list_results:
data.append(Data(data=result))
for result in results:
data.append(Data(data=result, text=result["snippets"][0]["text"]))

self.status = data

return data

def _build_wrapper(self):
class GleanAPIWrapper(BaseModel):
"""
Wrapper around Glean API.
"""

glean_api_url: str
glean_access_token: str
act_as: str = "[email protected]" # TODO: Detect this

def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"

return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}

def run(self, query: str, **kwargs: Any) -> str:
results = self.results(query, **kwargs)

return self._result_as_string(results)

def results(self, query: str, **kwargs: Any) -> dict:
results = self._search_api_results(query, **kwargs)

return results

def _search_api_results(self, query: str, **kwargs: Any) -> dict[str, Any]:
request_details = self._prepare_request(query, **kwargs)

response = httpx.post(
request_details["url"],
json=request_details["payload"],
headers=request_details["headers"],
)

response.raise_for_status()

return response.json()

@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)

return GleanAPIWrapper(glean_api_url=self.glean_api_url, glean_access_token=self.glean_access_token)
def _build_wrapper(
self,
glean_api_url: str,
glean_access_token: str,
):
return self.GleanAPIWrapper(
glean_api_url=glean_api_url,
glean_access_token=glean_access_token,
)

0 comments on commit a4677c7

Please sign in to comment.