Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add tests to cycles in Graph and improve error handling #3628

Merged
merged 7 commits into from
Sep 2, 2024
98 changes: 78 additions & 20 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,31 @@
import uuid
import warnings
from collections import defaultdict, deque
from collections.abc import Generator, Iterable
from datetime import datetime, timezone
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Optional
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Optional, cast

import nest_asyncio
from loguru import logger

from langflow.exceptions.component import ComponentBuildException
from langflow.graph.edge.base import CycleEdge
from langflow.graph.edge.base import CycleEdge, Edge
from langflow.graph.edge.schema import EdgeData
from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import GraphData, GraphDump, StartConfigDict, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.state_model import create_state_model_from_graph
from langflow.graph.graph.utils import find_start_component_id, process_flow, should_continue, sort_up_to_vertex
from langflow.graph.graph.utils import (
find_all_cycle_edges,
find_start_component_id,
has_cycle,
process_flow,
should_continue,
sort_up_to_vertex,
)
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.schema import NodeData
Expand Down Expand Up @@ -279,6 +286,15 @@ async def async_start(self, inputs: list[dict] | None = None, max_iterations: in

raise ValueError("Max iterations reached")

def _snapshot(self):
    """Return a detached copy of the graph's current scheduling state.

    The run queue and first layer are copied with their own ``.copy()``
    (shallow), while the vertex layers, the vertices still to run and the
    run manager's ``to_dict()`` output are deep-copied, so later mutations
    of the graph do not leak into the returned dictionary.

    Returns:
        dict: Keys ``_run_queue``, ``_first_layer``, ``vertices_layers``,
        ``vertices_to_run`` and ``run_manager``.
    """
    run_queue = self._run_queue.copy()
    first_layer = self._first_layer.copy()
    layers = copy.deepcopy(self.vertices_layers)
    pending = copy.deepcopy(self.vertices_to_run)
    manager_state = copy.deepcopy(self.run_manager.to_dict())
    return {
        "_run_queue": run_queue,
        "_first_layer": first_layer,
        "vertices_layers": layers,
        "vertices_to_run": pending,
        "run_manager": manager_state,
    }

def __apply_config(self, config: StartConfigDict):
for vertex in self.vertices:
if vertex._custom_component is None:
Expand Down Expand Up @@ -460,6 +476,23 @@ def first_layer(self):
raise ValueError("Graph not prepared. Call prepare() first.")
return self._first_layer

@property
def is_cyclic(self):
    """
    Tell whether the graph contains at least one cycle.

    The answer is computed on first access and cached in ``_is_cyclic``.
    Edge endpoints are read from each edge's ``data`` handles; if any edge
    lacks those keys, the top-level ``source``/``target`` fields are used
    for every edge instead.

    Returns:
        bool: True if the graph has any cycles, False otherwise.
    """
    # Guard clause: reuse the cached answer when there is one.
    if self._is_cyclic is not None:
        return self._is_cyclic

    vertex_ids = [vertex.id for vertex in self.vertices]
    try:
        edge_pairs = [
            (edge["data"]["sourceHandle"]["id"], edge["data"]["targetHandle"]["id"]) for edge in self._edges
        ]
    except KeyError:
        edge_pairs = [(edge["source"], edge["target"]) for edge in self._edges]
    self._is_cyclic = has_cycle(vertex_ids, edge_pairs)
    return self._is_cyclic

@property
def run_id(self):
"""
Expand Down Expand Up @@ -1380,7 +1413,7 @@ async def process(self, fallback_to_env_vars: bool, start_component_id: str | No

def find_next_runnable_vertices(self, vertex_id: str, vertex_successors_ids: list[str]) -> list[str]:
next_runnable_vertices = set()
for v_id in vertex_successors_ids:
for v_id in sorted(vertex_successors_ids):
if not self.is_vertex_runnable(v_id):
next_runnable_vertices.update(self.find_runnable_predecessors_for_successor(v_id))
else:
Expand Down Expand Up @@ -1536,29 +1569,42 @@ def get_vertex_neighbors(self, vertex: "Vertex") -> dict["Vertex", int]:
neighbors[neighbor] += 1
return neighbors

@property
def cycles(self):
    """
    Return the cycle edges found from the start vertex, computed once and cached.

    When no start vertex is set, the result is an empty list. Otherwise the
    edges are reduced to ``(source_id, target_id)`` pairs and handed to
    ``find_all_cycle_edges`` together with the start vertex id.

    Returns:
        The cached value of ``_cycles``: ``[]`` without a start vertex,
        otherwise whatever ``find_all_cycle_edges`` returns (presumably a
        collection of ``(source_id, target_id)`` pairs - ``build_edge``
        tests membership with such tuples).
    """
    if self._cycles is None:
        if self._start is None:
            self._cycles = []
        else:
            entry_vertex = self._start._id
            try:
                edges = [(e["data"]["sourceHandle"]["id"], e["data"]["targetHandle"]["id"]) for e in self._edges]
            except KeyError:
                # Same fallback as ``is_cyclic``: edges without handle data
                # still carry top-level ``source``/``target`` ids, so a graph
                # that passes the cycle check does not crash here.
                edges = [(e["source"], e["target"]) for e in self._edges]
            self._cycles = find_all_cycle_edges(entry_vertex, edges)
    return self._cycles

def _build_edges(self) -> list[CycleEdge]:
"""Builds the edges of the graph."""
# Edge takes two vertices as arguments, so we need to build the vertices first
# and then build the edges
# if we can't find a vertex, we raise an error

edges: set[CycleEdge] = set()
edges: set[CycleEdge | Edge] = set()
for edge in self._edges:
new_edge = self.build_edge(edge)
edges.add(new_edge)
if self.vertices and not edges:
warnings.warn("Graph has vertices but no edges")
return list(edges)
return list(cast(Iterable[CycleEdge], edges))

def build_edge(self, edge: EdgeData) -> CycleEdge:
def build_edge(self, edge: EdgeData) -> CycleEdge | Edge:
source = self.get_vertex(edge["source"])
target = self.get_vertex(edge["target"])

if source is None:
raise ValueError(f"Source vertex {edge['source']} not found")
if target is None:
raise ValueError(f"Target vertex {edge['target']} not found")
new_edge = CycleEdge(source, target, edge)
if (source.id, target.id) in self.cycles:
new_edge: CycleEdge | Edge = CycleEdge(source, target, edge)
else:
new_edge = Edge(source, target, edge)
return new_edge

def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> type["Vertex"]:
Expand Down Expand Up @@ -1608,7 +1654,6 @@ def prepare(self, stop_component_id: str | None = None, start_component_id: str
if stop_component_id and start_component_id:
raise ValueError("You can only provide one of stop_component_id or start_component_id")
self.validate_stream()
self.edges = self._build_edges()

if stop_component_id or start_component_id:
try:
Expand Down Expand Up @@ -1658,12 +1703,25 @@ def layered_topological_sort(
"""Performs a layered topological sort of the vertices in the graph."""
vertices_ids = {vertex.id for vertex in vertices}
# Queue for vertices with no incoming edges
queue = deque(
vertex.id
for vertex in vertices
# if filter_graphs then only vertex.is_input will be considered
if self.in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input)
)
in_degree_map = self.in_degree_map.copy()
if self.is_cyclic and all(in_degree_map.values()):
# Every vertex has an in-degree > 0, so the cycle leaves no zero-in-degree entry point.
# In that case, start the queue from ._start if it exists.
if self._start is not None:
queue = deque([self._start._id])
else:
# Find the chat input component
chat_input = find_start_component_id(vertices_ids)
if chat_input is None:
raise ValueError("No input component found and no start component provided")
queue = deque([chat_input])
else:
queue = deque(
vertex.id
for vertex in vertices
# if filter_graphs then only vertex.is_input will be considered
if in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input)
)
layers: list[list[str]] = []
visited = set(queue)

Expand All @@ -1684,13 +1742,13 @@ def layered_topological_sort(
if neighbor not in vertices_ids:
continue

self.in_degree_map[neighbor] -= 1 # 'remove' edge
if self.in_degree_map[neighbor] == 0 and neighbor not in visited:
in_degree_map[neighbor] -= 1 # 'remove' edge
if in_degree_map[neighbor] == 0 and neighbor not in visited:
queue.append(neighbor)

# if > 0 it might mean not all predecessors have added to the queue
# so we should process the neighbors predecessors
elif self.in_degree_map[neighbor] > 0:
elif in_degree_map[neighbor] > 0:
for predecessor in self.predecessor_map[neighbor]:
if predecessor not in queue and predecessor not in visited:
queue.append(predecessor)
Expand Down
16 changes: 5 additions & 11 deletions src/backend/base/langflow/graph/vertex/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import os
import traceback
import types
import json
from collections.abc import AsyncIterator, Callable, Iterator, Mapping
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
from collections.abc import AsyncIterator, Callable, Iterator, Mapping

import pandas as pd
from loguru import logger
Expand Down Expand Up @@ -37,9 +36,9 @@
class VertexStates(str, Enum):
"""Vertex are related to it being active, inactive, or in an error state."""

ACTIVE = "active"
INACTIVE = "inactive"
ERROR = "error"
ACTIVE = "ACTIVE"
INACTIVE = "INACTIVE"
ERROR = "ERROR"


class Vertex:
Expand Down Expand Up @@ -105,12 +104,7 @@ def set_input_value(self, name: str, value: Any):
self._custom_component._set_input_value(name, value)

def to_data(self):
try:
data = json.loads(json.dumps(self._data, default=str))
except TypeError:
data = self._data

return data
return self._data

def add_component_instance(self, component_instance: "Component"):
component_instance.set_vertex(self)
Expand Down
2 changes: 0 additions & 2 deletions src/backend/base/langflow/graph/vertex/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ async def _get_result(self, requester: "Vertex", target_handle_name: str | None
)
for edge in self.get_edge_with_target(requester.id):
# We need to check if the edge is a normal edge
# or a contract edge

if edge.is_cycle and edge.target_param:
return requester.get_value_from_template_dict(edge.target_param)

Expand Down
111 changes: 111 additions & 0 deletions src/backend/tests/unit/graph/graph/test_cycles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import pytest

from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.outputs.TextOutput import TextOutputComponent
from langflow.components.prototypes.ConditionalRouter import ConditionalRouterComponent
from langflow.custom.custom_component.component import Component
from langflow.graph.graph.base import Graph
from langflow.io import MessageTextInput, Output
from langflow.schema.message import Message


@pytest.fixture
def client():
    """Fixture named ``client`` that provides no value (``None``)."""
    return None


class Concatenate(Component):
    # Test-only component: its single output is the ``text`` input written
    # twice in a row.
    display_name = "Concatenate"
    description = "Concatenates two strings"

    inputs = [
        MessageTextInput(name="text", display_name="Text", required=True),
    ]
    outputs = [
        Output(display_name="Text", name="some_text", method="concatenate"),
    ]

    def concatenate(self) -> Message:
        """Return a Message holding ``text`` repeated twice, or "test" if that is empty."""
        doubled = f"{self.text}{self.text}"
        # An empty concatenation is falsy, so the literal "test" is used instead.
        return Message(text=doubled or "test")


def test_cycle_in_graph():
    """Run a graph with a chat_input -> concatenate -> router -> chat_input loop.

    The router compares its input with "testtesttesttest"; the test expects
    the loop to be executed four times before the text and chat outputs run.
    """
    # Wire the loop: the router's false branch feeds back into the chat input.
    chat_input = ChatInput(_id="chat_input")
    router = ConditionalRouterComponent(_id="router")
    chat_input.set(input_value=router.false_response)
    concat_component = Concatenate(_id="concatenate")
    concat_component.set(text=chat_input.message_response)
    router.set(
        input_text=chat_input.message_response,
        match_text="testtesttesttest",
        operator="equals",
        message=concat_component.concatenate,
    )
    # The router's true branch leaves the loop through the two outputs.
    text_output = TextOutputComponent(_id="text_output")
    text_output.set(input_value=router.true_response)
    chat_output = ChatOutput(_id="chat_output")
    chat_output.set(input_value=text_output.text_response)

    graph = Graph(chat_input, chat_output)
    assert graph.is_cyclic is True

    # Run queue should contain chat_input and not router
    assert "chat_input" in graph._run_queue
    assert "router" not in graph._run_queue

    max_iterations = 20
    run_config = {"output": {"cache": False}}
    collected = []
    # Keep a snapshot before the run and after every step; they are only
    # used as the failure message of the length assertion below.
    snapshots = [graph._snapshot()]
    for step_result in graph.start(max_iterations=max_iterations, config=run_config):
        snapshots.append(graph._snapshot())
        collected.append(step_result)
    results_ids = [item.vertex.id for item in collected if hasattr(item, "vertex")]

    # Check that text_output and chat_output are the last vertices in the results
    assert results_ids[-2:] == ["text_output", "chat_output"]
    # More results than vertices means at least one vertex ran more than once
    assert len(results_ids) > len(graph.vertices), snapshots

    loop_pass = ["chat_input", "concatenate", "router"]
    expected_ids = loop_pass * 4 + ["text_output", "chat_output"]
    assert results_ids == expected_ids, f"Results: {results_ids}"


def test_cycle_in_graph_max_iterations():
    """Check that a cyclic graph run with ``max_iterations=2`` raises "Max iterations reached"."""
    # Wire the loop: the router's false branch feeds back into the chat input.
    chat_input = ChatInput(_id="chat_input")
    router = ConditionalRouterComponent(_id="router")
    chat_input.set(input_value=router.false_response)
    concat_component = Concatenate(_id="concatenate")
    concat_component.set(text=chat_input.message_response)
    router.set(
        input_text=chat_input.message_response,
        match_text="testtesttesttest",
        operator="equals",
        message=concat_component.concatenate,
    )
    # The router's true branch leaves the loop through the two outputs.
    text_output = TextOutputComponent(_id="text_output")
    text_output.set(input_value=router.true_response)
    chat_output = ChatOutput(_id="chat_output")
    chat_output.set(input_value=text_output.text_response)

    graph = Graph(chat_input, chat_output)
    assert graph.is_cyclic is True

    # Run queue should contain chat_input and not router
    assert "chat_input" in graph._run_queue
    assert "router" not in graph._run_queue

    collected = []
    run_config = {"output": {"cache": False}}
    # Draining the run is what triggers the error once the iteration cap is hit.
    with pytest.raises(ValueError, match="Max iterations reached"):
        collected.extend(graph.start(max_iterations=2, config=run_config))
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,35 @@ def memory_chatbot_graph():
openai_component.set(
input_value=prompt_component.build_prompt, max_tokens=100, temperature=0.1, api_key="test_api_key"
)
openai_component.get_output("text_output").value = "Mock response"
openai_component.set_on_output(name="text_output", value="Mock response", cache=True)

chat_output = ChatOutput(_id="chat_output")
chat_output.set(input_value=openai_component.text_response)

graph = Graph(chat_input, chat_output)
assert graph.in_degree_map == {"chat_output": 1, "prompt": 2, "openai": 1, "chat_input": 0, "chat_memory": 0}
return graph


def test_memory_chatbot(memory_chatbot_graph):
# Now we run step by step
expected_order = deque(["chat_input", "chat_memory", "prompt", "openai", "chat_output"])
assert memory_chatbot_graph.in_degree_map == {
"chat_output": 1,
"prompt": 2,
"openai": 1,
"chat_input": 0,
"chat_memory": 0,
}
assert memory_chatbot_graph.vertices_layers == [["prompt"], ["openai"], ["chat_output"]]
assert memory_chatbot_graph.first_layer == ["chat_input", "chat_memory"]

for step in expected_order:
result = memory_chatbot_graph.step()
if isinstance(result, Finish):
break
assert step == result.vertex.id

assert step == result.vertex.id, (memory_chatbot_graph.in_degree_map, memory_chatbot_graph.vertices_layers)


def test_memory_chatbot_dump_structure(memory_chatbot_graph: Graph):
Expand Down
Loading
Loading