Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor graph vertex sorting #2583

Merged
merged 8 commits into from
Jul 10, 2024
91 changes: 20 additions & 71 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import find_start_component_id, process_flow
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.types import InterfaceVertex, StateVertex
Expand Down Expand Up @@ -1197,74 +1197,6 @@ def __repr__(self):
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"

def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
    """Collect the vertices kept when the graph is cut at ``vertex_id``.

    A stack-driven traversal starts at the given vertex and follows the
    ``predecessors`` of every vertex it visits. When ``is_start`` is false,
    the successors of the given vertex (and their direct successors) are put
    in ``excluded`` and are never visited. When ``is_start`` is true, those
    successors are pushed on the stack instead, and the successors of other
    visited vertices are followed too, unless that vertex is a direct
    predecessor of the given vertex.

    Args:
        vertex_id: Id of the stop (or start) vertex. If ``get_vertex`` raises
            ``ValueError`` for it, the vertex returned by
            ``get_root_of_group_node`` is used instead.
        is_start: Whether ``vertex_id`` is treated as a starting point rather
            than a stopping point.

    Returns:
        The visited vertices. They are read back from a set, so this method
        does not itself establish any ordering.
    """
    # Initial setup
    visited = set()  # To keep track of visited vertices
    excluded = set()  # To keep track of vertices that should be excluded

    def get_successors(vertex, recursive=True):
        # Return the successors of `vertex`; with recursive=True the successors
        # of each successor are collected first (the nested call always recurses).
        successors = vertex.successors
        if not successors:
            return []
        successors_result = []
        for successor in successors:
            # With recursive=False this only copies the direct successors
            if recursive:
                next_successors = get_successors(successor)
                successors_result.extend(next_successors)
            successors_result.append(successor)
        return successors_result

    try:
        stop_or_start_vertex = self.get_vertex(vertex_id)
        stack = [vertex_id]  # Use a list as a stack for DFS
    except ValueError:
        # Fall back to the group-node root and continue with its id from here on
        stop_or_start_vertex = self.get_root_of_group_node(vertex_id)
        stack = [stop_or_start_vertex.id]
        vertex_id = stop_or_start_vertex.id
    stop_predecessors = [pre.id for pre in stop_or_start_vertex.predecessors]
    # DFS to collect all vertices that can reach the specified vertex
    while stack:
        current_id = stack.pop()
        if current_id not in visited and current_id not in excluded:
            visited.add(current_id)
            current_vertex = self.get_vertex(current_id)
            # Predecessors are always followed, whatever the value of is_start
            for predecessor in current_vertex.predecessors:
                stack.append(predecessor.id)

            if current_id == vertex_id:
                # For the stop/start vertex itself, handle its successors and their
                # direct successors: queue them when starting from this vertex,
                # otherwise exclude them from the traversal.
                for successor in current_vertex.successors:
                    if is_start:
                        stack.append(successor.id)
                    else:
                        excluded.add(successor.id)
                    all_successors = get_successors(successor, recursive=False)
                    # NOTE: the inner loop reuses the name `successor`, shadowing the
                    # outer loop variable; the outer value is not read again afterwards.
                    for successor in all_successors:
                        if is_start:
                            stack.append(successor.id)
                        else:
                            excluded.add(successor.id)
            elif current_id not in stop_predecessors and is_start:
                # When starting from a vertex, also follow the successors of every
                # other visited vertex, except direct predecessors of the start vertex.
                for successor in current_vertex.successors:
                    if successor.id not in visited:
                        stack.append(successor.id)

    # Resolve the collected ids back into vertex objects
    vertices_to_keep = [self.get_vertex(vid) for vid in visited]

    return vertices_to_keep

def layered_topological_sort(
self,
vertices: List[Vertex],
Expand Down Expand Up @@ -1395,6 +1327,21 @@ def _max_dependency_index(self, vertex_id: str, index_map: Dict[str, int]) -> in
max_index = max(max_index, index_map[successor.id])
return max_index

def __to_dict(self) -> Dict[str, Dict[str, List[str]]]:
    """Build an adjacency mapping of the graph keyed by vertex id.

    Each vertex id maps to a dict holding two lists of ids: ``"successors"``
    (taken from ``get_all_successors``) and ``"predecessors"`` (taken from
    ``get_predecessors``).
    """
    adjacency: Dict = {}
    for node in self.vertices:
        node_id = node.id
        downstream_ids = [succ.id for succ in self.get_all_successors(node)]
        upstream_ids = [pred.id for pred in self.get_predecessors(node)]
        adjacency[node_id] = {"successors": downstream_ids, "predecessors": upstream_ids}
    return adjacency

def __filter_vertices(self, vertex_id: str, is_start: bool = False):
    """Return the vertices selected by ``sort_up_to_vertex`` for ``vertex_id``.

    The graph is first converted to its dictionary form, the module-level
    ``sort_up_to_vertex`` picks the ids to keep, and each id is resolved back
    to a vertex through ``get_vertex``.
    """
    adjacency = self.__to_dict()
    kept_ids = sort_up_to_vertex(adjacency, vertex_id, is_start)
    return [self.get_vertex(kept_id) for kept_id in kept_ids]

def sort_vertices(
self,
stop_component_id: Optional[str] = None,
Expand All @@ -1404,9 +1351,11 @@ def sort_vertices(
self.mark_all_vertices("ACTIVE")
if stop_component_id is not None:
self.stop_vertex = stop_component_id
vertices = self.sort_up_to_vertex(stop_component_id)
vertices = self.__filter_vertices(stop_component_id)

elif start_component_id:
vertices = self.sort_up_to_vertex(start_component_id, is_start=True)
vertices = self.__filter_vertices(start_component_id, is_start=True)

else:
vertices = self.vertices
# without component_id we are probably running in the chat
Expand Down
48 changes: 48 additions & 0 deletions src/backend/base/langflow/graph/graph/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import List, Dict
import copy
from collections import deque


PRIORITY_LIST_OF_INPUTS = ["webhook", "chat"]


Expand Down Expand Up @@ -224,3 +226,49 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
if edge["target"] == group_node_id or edge["source"] == group_node_id:
updated_edges.append(new_edge)
return updated_edges


def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> List[str]:
    """Return ``vertex_id`` followed by every vertex reachable through successors.

    The traversal is an iterative depth-first walk. Each vertex is reported
    once, in the order it is first popped from the stack, so graphs with
    converging paths yield no duplicates and graphs with cycles terminate.

    Args:
        graph: Mapping of vertex id to a dict with a ``"successors"`` list of ids.
        vertex_id: Id of the vertex to start from; it is included in the result.

    Returns:
        The reachable vertex ids, starting with ``vertex_id``.

    Raises:
        KeyError: If ``vertex_id`` or any id listed as a successor is not a key of ``graph``.
    """
    seen = set()
    reachable = []
    stack = [vertex_id]
    while stack:
        current_id = stack.pop()
        # Skipping already-seen ids is what guarantees termination on cycles.
        if current_id in seen:
            continue
        seen.add(current_id)
        reachable.append(current_id)
        stack.extend(graph[current_id]["successors"])
    return reachable


def sort_up_to_vertex(graph: Dict[str, Dict[str, List[str]]], vertex_id: str, is_start: bool = False) -> List[str]:
    """Select the ids of the vertices kept when the graph is cut at ``vertex_id``.

    A stack-driven traversal starts at ``vertex_id`` and always follows the
    predecessors of every vertex it visits. What happens to successors depends
    on ``is_start``:

    * ``is_start=False``: everything downstream of ``vertex_id`` is excluded
      and never visited.
    * ``is_start=True``: everything downstream of ``vertex_id`` is queued for
      visiting, and so is everything downstream of any other visited vertex
      that is not a direct predecessor of ``vertex_id``.

    Args:
        graph: Mapping of vertex id to a dict with ``"successors"`` and
            ``"predecessors"`` lists of ids.
        vertex_id: Id of the stop (or start) vertex.
        is_start: Treat ``vertex_id`` as a starting point instead of a stopping point.

    Returns:
        The ids of the visited vertices. They come from a set, so no particular
        order is established by this function.

    Raises:
        ValueError: If ``vertex_id`` is not a key of ``graph``.
    """
    try:
        stop_or_start_vertex = graph[vertex_id]
    except KeyError as exc:
        # Chain the lookup failure so the original KeyError stays visible in tracebacks.
        raise ValueError(f"Vertex {vertex_id} not found into graph") from exc

    visited, excluded = set(), set()
    stack = [vertex_id]
    stop_predecessors = set(stop_or_start_vertex["predecessors"])

    while stack:
        current_id = stack.pop()
        if current_id in visited or current_id in excluded:
            continue

        visited.add(current_id)
        current_vertex = graph[current_id]

        # Upstream vertices are always needed, whichever mode is active.
        stack.extend(current_vertex["predecessors"])

        handles_successors = current_id == vertex_id or (is_start and current_id not in stop_predecessors)
        if not handles_successors:
            continue

        for successor_id in current_vertex["successors"]:
            downstream = [successor_id, *get_successors(graph, successor_id)]
            if is_start:
                stack.extend(downstream)
            else:
                excluded.update(downstream)

    return list(visited)
122 changes: 122 additions & 0 deletions tests/unit/graph/graph/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest

from langflow.graph.graph import utils


@pytest.fixture
def graph():
    """Adjacency mapping shared by the tests: vertex id -> successors/predecessors."""
    successors_of = {
        "A": ["B"],
        "B": ["D"],
        "C": ["B", "I"],
        "D": ["E", "F"],
        "E": ["G"],
        "F": ["G", "H"],
        "G": [],
        "H": [],
        "I": ["M"],
        "J": ["I", "K"],
        "K": ["Q", "P", "O"],
        "L": ["K"],
        "M": [],
        "N": ["C", "J"],
        "O": ["R"],
        "P": ["U"],
        "Q": ["V"],
        "R": ["S"],
        "S": ["T"],
        "T": [],
        "U": ["W"],
        "V": ["Y"],
        "W": ["X"],
        "X": [],
        "Y": ["Z"],
        "Z": [],
    }
    # Predecessor lists are the successor table inverted; sources are walked in
    # table order, so each list comes out in that same order.
    predecessors_of = {vertex_id: [] for vertex_id in successors_of}
    for source_id, target_ids in successors_of.items():
        for target_id in target_ids:
            predecessors_of[target_id].append(source_id)
    return {
        vertex_id: {"successors": target_ids, "predecessors": predecessors_of[vertex_id]}
        for vertex_id, target_ids in successors_of.items()
    }


def test_get_successors_a(graph):
    """get_successors from "A" yields exactly the expected ids (order and repeats ignored)."""
    expected = {"A", "B", "D", "E", "F", "H", "G"}

    reachable = utils.get_successors(graph, "A")

    assert set(reachable) == expected


def test_get_successors_z(graph):
    """get_successors from "Z" yields only "Z" itself."""
    expected = {"Z"}

    reachable = utils.get_successors(graph, "Z")

    assert set(reachable) == expected


def test_sort_up_to_vertex_n_is_start(graph):
    """Starting from "N" with is_start=True selects every vertex of the fixture."""
    selected = utils.sort_up_to_vertex(graph, "N", is_start=True)

    # The result should be all the vertices
    every_vertex = set(graph.keys())
    assert set(selected) == every_vertex


def test_sort_up_to_vertex_z(graph):
    """Stopping at "Z" selects exactly the expected vertex ids."""
    expected = {"L", "N", "J", "K", "Q", "V", "Y", "Z"}

    selected = utils.sort_up_to_vertex(graph, "Z")

    assert set(selected) == expected


def test_sort_up_to_vertex_x(graph):
    """Stopping at "X" selects exactly the expected vertex ids."""
    expected = {"L", "N", "J", "K", "P", "U", "W", "X"}

    selected = utils.sort_up_to_vertex(graph, "X")

    assert set(selected) == expected


def test_sort_up_to_vertex_t(graph):
    """Stopping at "T" selects exactly the expected vertex ids."""
    expected = {"L", "N", "J", "K", "O", "R", "S", "T"}

    selected = utils.sort_up_to_vertex(graph, "T")

    assert set(selected) == expected


def test_sort_up_to_vertex_m(graph):
    """Stopping at "M" selects exactly the expected vertex ids."""
    expected = {"N", "C", "J", "I", "M"}

    selected = utils.sort_up_to_vertex(graph, "M")

    assert set(selected) == expected


def test_sort_up_to_vertex_h(graph):
    """Stopping at "H" selects exactly the expected vertex ids."""
    expected = {"N", "C", "A", "B", "D", "F", "H"}

    selected = utils.sort_up_to_vertex(graph, "H")

    assert set(selected) == expected


def test_sort_up_to_vertex_g(graph):
    """Stopping at "G" selects exactly the expected vertex ids."""
    expected = {"N", "C", "A", "B", "D", "F", "E", "G"}

    selected = utils.sort_up_to_vertex(graph, "G")

    assert set(selected) == expected


def test_sort_up_to_vertex_a(graph):
    """Stopping at "A" selects only "A" itself."""
    expected = {"A"}

    selected = utils.sort_up_to_vertex(graph, "A")

    assert set(selected) == expected


def test_sort_up_to_vertex_invalid_vertex(graph):
    """An id that is not a key of the graph makes sort_up_to_vertex raise ValueError."""
    unknown_id = "7"

    with pytest.raises(ValueError):
        utils.sort_up_to_vertex(graph, unknown_id)
Loading