Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add maximum iterations limit in Graph start method. #3336

Merged
merged 4 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from langflow.graph.edge.schema import EdgeData
from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult
from langflow.graph.graph.schema import GraphData, GraphDump, StartConfigDict, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.state_model import create_state_model_from_graph
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.graph.utils import find_start_component_id, process_flow, should_continue, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.schema import NodeData
Expand Down Expand Up @@ -247,7 +247,7 @@ def add_component_edge(self, source_id: str, output_input_tuple: Tuple[str, str]
}
self._add_edge(edge_data)

async def async_start(self, inputs: Optional[List[dict]] = None):
async def async_start(self, inputs: Optional[List[dict]] = None, max_iterations: Optional[int] = None):
if not self._prepared:
raise ValueError("Graph not prepared. Call prepare() first.")
# The idea is for this to return a generator that yields the result of
Expand All @@ -256,17 +256,40 @@ async def async_start(self, inputs: Optional[List[dict]] = None):
for key, value in _input.items():
vertex = self.get_vertex(key)
vertex.set_input_value(key, value)
while True:
# I want to keep a counter of how many tyimes result.vertex.id
# has been yielded
yielded_counts: dict[str, int] = defaultdict(int)

while should_continue(yielded_counts, max_iterations):
result = await self.astep()
yield result
if hasattr(result, "vertex"):
yielded_counts[result.vertex.id] += 1
if isinstance(result, Finish):
return

def start(self, inputs: Optional[List[dict]] = None) -> Generator:
raise ValueError("Max iterations reached")

def __apply_config(self, config: StartConfigDict):
for vertex in self.vertices:
if vertex._custom_component is None:
continue
for output in vertex._custom_component.outputs:
for key, value in config["output"].items():
setattr(output, key, value)

def start(
self,
inputs: Optional[List[dict]] = None,
max_iterations: Optional[int] = None,
config: Optional[StartConfigDict] = None,
) -> Generator:
if config is not None:
self.__apply_config(config)
#! Change this ASAP
nest_asyncio.apply()
loop = asyncio.get_event_loop()
async_gen = self.async_start(inputs)
async_gen = self.async_start(inputs, max_iterations)
async_gen_task = asyncio.ensure_future(async_gen.__anext__())

while True:
Expand Down
8 changes: 8 additions & 0 deletions src/backend/base/langflow/graph/graph/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ class VertexBuildResult(NamedTuple):
valid: bool
artifacts: dict
vertex: "Vertex"


class OutputConfigDict(TypedDict):
cache: bool


class StartConfigDict(TypedDict):
output: OutputConfigDict
Loading