Skip to content

Commit

Permalink
[autofix.ci] apply automated fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
autofix-ci[bot] authored Jul 4, 2024
1 parent cb4f0d0 commit e0b859d
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 97 deletions.
20 changes: 4 additions & 16 deletions src/backend/base/langflow/components/chains/ConversationChain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from langchain.chains import ConversationChain

from langflow.custom import Component
Expand All @@ -15,27 +13,17 @@ class ConversationChainComponent(Component):

inputs = [
MultilineInput(
name="input_value",
display_name="Input",
info="The input value to pass to the chain.",
required=True
),
HandleInput(
name="llm",
display_name="Language Model",
input_types=["LanguageModel"],
required=True
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
HandleInput(
name="memory",
display_name="Memory",
input_types=["BaseChatMemory"],
)
),
]

outputs = [
Output(display_name="Text", name="text", method="invoke_chain")
]
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]

def invoke_chain(self) -> Message:
if not self.memory:
Expand Down
18 changes: 4 additions & 14 deletions src/backend/base/langflow/components/chains/LLMCheckerChain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,17 @@ class LLMCheckerChainComponent(Component):

inputs = [
MultilineInput(
name="input_value",
display_name="Input",
info="The input value to pass to the chain.",
required=True
),
HandleInput(
name="llm",
display_name="Language Model",
input_types=["LanguageModel"],
required=True
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
]

outputs = [
Output(display_name="Text", name="text", method="invoke_chain")
]
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]

def invoke_chain(self) -> Message:
chain = LLMCheckerChain.from_llm(llm=self.llm)
response = chain.invoke({chain.input_key: self.input_value})
result = response.get(chain.output_key, "")
result = str(result)
self.status = result
return Message(text=result)
return Message(text=result)
16 changes: 3 additions & 13 deletions src/backend/base/langflow/components/chains/LLMMathChain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,12 @@ class LLMMathChainComponent(Component):

inputs = [
MultilineInput(
name="input_value",
display_name="Input",
info="The input value to pass to the chain.",
required=True
),
HandleInput(
name="llm",
display_name="Language Model",
input_types=["LanguageModel"],
required=True
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
]

outputs = [
Output(display_name="Text", name="text", method="invoke_chain")
]
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]

def invoke_chain(self) -> Message:
chain = LLMMathChain.from_llm(llm=self.llm)
Expand Down
31 changes: 7 additions & 24 deletions src/backend/base/langflow/components/chains/RetrievalQA.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.chains import RetrievalQA
from langflow.custom import Component
from langflow.field_typing import Message
from langflow.inputs import HandleInput, MultilineInput, BoolInput, DropdownInput
Expand All @@ -12,10 +12,7 @@ class RetrievalQAComponent(Component):

inputs = [
MultilineInput(
name="input_value",
display_name="Input",
info="The input value to pass to the chain.",
required=True
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
DropdownInput(
name="chain_type",
Expand All @@ -25,18 +22,8 @@ class RetrievalQAComponent(Component):
value="Stuff",
advanced=True,
),
HandleInput(
name="llm",
display_name="Language Model",
input_types=["LanguageModel"],
required=True
),
HandleInput(
name="retriever",
display_name="Retriever",
input_types=["Retriever"],
required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"], required=True),
HandleInput(
name="memory",
display_name="Memory",
Expand All @@ -46,12 +33,10 @@ class RetrievalQAComponent(Component):
name="return_source_documents",
display_name="Return Source Documents",
value=False,
)
),
]

outputs = [
Output(display_name="Text", name="text", method="invoke_chain")
]
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]

def invoke_chain(self) -> Message:
chain_type = self.chain_type.lower().replace(" ", "_")
Expand All @@ -66,7 +51,7 @@ def invoke_chain(self) -> Message:
memory=self.memory,
# always include to help debugging
#
return_source_documents=True
return_source_documents=True,
)

result = runnable.invoke({"query": self.input_value})
Expand All @@ -79,5 +64,3 @@ def invoke_chain(self) -> Message:
# put the entire result to debug history, query and content
self.status = {**result, "source_documents": source_docs, "output": result_str}
return result_str


37 changes: 7 additions & 30 deletions src/backend/base/langflow/components/chains/SQLGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,17 @@ class SQLGeneratorComponent(Component):

inputs = [
MultilineInput(
name="input_value",
display_name="Input",
info="The input value to pass to the chain.",
required=True
),
HandleInput(
name="llm",
display_name="Language Model",
input_types=["LanguageModel"],
required=True
),
HandleInput(
name="db",
display_name="SQLDatabase",
input_types=["SQLDatabase"],
required=True
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
HandleInput(name="db", display_name="SQLDatabase", input_types=["SQLDatabase"], required=True),
IntInput(
name="top_k",
display_name="Top K",
info="The number of results per select statement to return.",
value=5
),
MultilineInput(
name="prompt",
display_name="Prompt",
info="The prompt must contain `{question}`."
name="top_k", display_name="Top K", info="The number of results per select statement to return.", value=5
),
MultilineInput(name="prompt", display_name="Prompt", info="The prompt must contain `{question}`."),
]

outputs = [
Output(display_name="Text", name="text", method="invoke_chain")
]

outputs = [Output(display_name="Text", name="text", method="invoke_chain")]

def invoke_chain(self) -> Message:
if self.prompt:
Expand All @@ -70,4 +47,4 @@ def invoke_chain(self) -> Message:
response = query_writer.invoke({"question": self.input_value})
query = response.get("query")
self.status = query
return query
return query

0 comments on commit e0b859d

Please sign in to comment.