Skip to content

Commit

Permalink
feat: refactor graph vertex sorting (#2583)
Browse files Browse the repository at this point in the history
* refactor: extract method from class to new func

* test: add new tests

* refactor: simplify funcs to improve readability

* refactor: extract new func from larger func

* refactor: remove recursion from func

* refactor: remove coupling with graph and vertex

* refactor: create adapter funcs to use new code

* refactor: add test for sorting up to vertex N with is_start=True

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
  • Loading branch information
italojohnny and ogabrielluiz authored Jul 10, 2024
1 parent 3406575 commit aa1958a
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 71 deletions.
91 changes: 20 additions & 71 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import find_start_component_id, process_flow
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.types import InterfaceVertex, StateVertex
Expand Down Expand Up @@ -1197,74 +1197,6 @@ def __repr__(self):
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"

def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
    """Cuts the graph up to a given vertex and returns the vertices of the resulting subgraph.

    Starting at ``vertex_id``, a stack-based DFS walks predecessor edges and
    records every vertex it reaches. What happens to the target's successors
    depends on ``is_start``.

    Args:
        vertex_id: Id of the stop/start vertex. If ``self.get_vertex`` raises
            ``ValueError`` for it, ``self.get_root_of_group_node(vertex_id)`` is
            used instead and its id replaces ``vertex_id``.
        is_start: When True, successors are pushed onto the stack and traversed
            as well; when False, the target's successors are placed in an
            exclusion set so they are never visited.

    Returns:
        The ``Vertex`` objects for every visited id. They are produced by
        iterating a ``set``, so no particular order is applied here.
    """
    # Initial setup
    visited = set()  # To keep track of visited vertices
    excluded = set()  # To keep track of vertices that should be excluded

    def get_successors(vertex, recursive=True):
        # Collect the successors of ``vertex``; with recursive=True each successor's own
        # (recursively collected) successors are added before the successor itself.
        successors = vertex.successors
        if not successors:
            return []
        successors_result = []
        for successor in successors:
            # With recursive=False only the direct successors end up in the result
            if recursive:
                next_successors = get_successors(successor)
                successors_result.extend(next_successors)
            successors_result.append(successor)
        return successors_result

    try:
        stop_or_start_vertex = self.get_vertex(vertex_id)
        stack = [vertex_id]  # Use a list as a stack for DFS
    except ValueError:
        # The id did not resolve to a vertex: fall back to the root of the group node
        stop_or_start_vertex = self.get_root_of_group_node(vertex_id)
        stack = [stop_or_start_vertex.id]
        vertex_id = stop_or_start_vertex.id
    stop_predecessors = [pre.id for pre in stop_or_start_vertex.predecessors]
    # DFS to collect all vertices that can reach the specified vertex
    while stack:
        current_id = stack.pop()
        if current_id not in visited and current_id not in excluded:
            visited.add(current_id)
            current_vertex = self.get_vertex(current_id)
            # Every visited vertex contributes its predecessors to the traversal
            for predecessor in current_vertex.predecessors:
                stack.append(predecessor.id)

            if current_id == vertex_id:
                # Target vertex: its successors are either traversed (is_start) or excluded.
                # NOTE(review): get_successors is called with recursive=False, so only the
                # successors and their direct successors are handled here, not the whole
                # downstream subgraph -- confirm this is intended.
                for successor in current_vertex.successors:
                    if is_start:
                        stack.append(successor.id)
                    else:
                        excluded.add(successor.id)
                    all_successors = get_successors(successor, recursive=False)
                    # The inner loop reuses the name ``successor`` of the outer loop
                    for successor in all_successors:
                        if is_start:
                            stack.append(successor.id)
                        else:
                            excluded.add(successor.id)
            elif current_id not in stop_predecessors and is_start:
                # If the current vertex is not the target vertex, we should add all its successors
                # to the stack if they are not in visited

                # Direct predecessors of the target are deliberately not expanded forward
                for successor in current_vertex.successors:
                    if successor.id not in visited:
                        stack.append(successor.id)

    # Resolve the collected ids back to vertex objects
    vertices_to_keep = [self.get_vertex(vid) for vid in visited]

    return vertices_to_keep

def layered_topological_sort(
self,
vertices: List[Vertex],
Expand Down Expand Up @@ -1395,6 +1327,21 @@ def _max_dependency_index(self, vertex_id: str, index_map: Dict[str, int]) -> in
max_index = max(max_index, index_map[successor.id])
return max_index

def __to_dict(self) -> Dict[str, Dict[str, List[str]]]:
    """Converts the graph to a dictionary keyed by vertex id.

    Returns:
        A mapping ``{vertex_id: {"successors": [...], "predecessors": [...]}}``
        with one entry per vertex in ``self.vertices``. The lists hold the ids of
        the vertices returned by ``self.get_all_successors(vertex)`` and
        ``self.get_predecessors(vertex)`` respectively, in the order returned.
    """
    # A single comprehension replaces merging a throwaway one-entry dict per vertex.
    return {
        vertex.id: {
            "successors": [successor.id for successor in self.get_all_successors(vertex)],
            "predecessors": [predecessor.id for predecessor in self.get_predecessors(vertex)],
        }
        for vertex in self.vertices
    }

def __filter_vertices(self, vertex_id: str, is_start: bool = False):
    """Return the vertices that ``sort_up_to_vertex`` selects for ``vertex_id``.

    The graph is first converted to its dictionary form, the id-based
    ``sort_up_to_vertex`` helper is run on it, and every returned id is
    resolved back to a vertex through ``self.get_vertex``.
    """
    adjacency = self.__to_dict()
    selected_ids = sort_up_to_vertex(adjacency, vertex_id, is_start)
    return list(map(self.get_vertex, selected_ids))

def sort_vertices(
self,
stop_component_id: Optional[str] = None,
Expand All @@ -1404,9 +1351,11 @@ def sort_vertices(
self.mark_all_vertices("ACTIVE")
if stop_component_id is not None:
self.stop_vertex = stop_component_id
vertices = self.sort_up_to_vertex(stop_component_id)
vertices = self.__filter_vertices(stop_component_id)

elif start_component_id:
vertices = self.sort_up_to_vertex(start_component_id, is_start=True)
vertices = self.__filter_vertices(start_component_id, is_start=True)

else:
vertices = self.vertices
# without component_id we are probably running in the chat
Expand Down
48 changes: 48 additions & 0 deletions src/backend/base/langflow/graph/graph/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import List, Dict
import copy
from collections import deque


PRIORITY_LIST_OF_INPUTS = ["webhook", "chat"]


Expand Down Expand Up @@ -224,3 +226,49 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
if edge["target"] == group_node_id or edge["source"] == group_node_id:
updated_edges.append(new_edge)
return updated_edges


def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> List[str]:
    """Return ``vertex_id`` followed by every vertex reachable from it via successor edges.

    The traversal is an iterative depth-first walk. Each vertex is reported once,
    at the position where it is first reached, so graphs containing cycles
    terminate and shared descendants are not re-traversed.

    Args:
        graph: Mapping of vertex id to a dict with (at least) a ``"successors"``
            list of vertex ids.
        vertex_id: Id of the vertex to start from; it is always the first item
            of the result.

    Returns:
        The ids reached, starting vertex included, without duplicates.

    Raises:
        KeyError: If ``vertex_id`` or any id reached is missing from ``graph``.
    """
    successors_result: List[str] = []
    seen = set()
    stack = [vertex_id]
    while stack:
        current_id = stack.pop()
        if current_id in seen:
            # Already expanded: skipping keeps the walk finite on cycles.
            continue
        seen.add(current_id)
        successors_result.append(current_id)
        stack.extend(graph[current_id]["successors"])
    return successors_result


def sort_up_to_vertex(graph: Dict[str, Dict[str, List[str]]], vertex_id: str, is_start: bool = False) -> List[str]:
    """Cuts the graph up to a given vertex and returns the ids of the resulting subgraph.

    Starting at ``vertex_id``, predecessor edges are followed from every vertex
    that gets visited. Successor edges are handled according to ``is_start``:

    * ``is_start=False``: everything downstream of ``vertex_id`` is excluded, so
      the result is ``vertex_id`` plus the vertices it (transitively) depends on.
    * ``is_start=True``: downstream vertices are traversed too, except that the
      direct predecessors of ``vertex_id`` are not expanded forward.

    Args:
        graph: Mapping of vertex id to ``{"successors": [...], "predecessors": [...]}``.
        vertex_id: Id of the stop (or start) vertex.
        is_start: Whether ``vertex_id`` is a starting point rather than a stop point.

    Returns:
        The selected vertex ids. They come from a ``set``, so the order of the
        list is not meaningful.

    Raises:
        ValueError: If ``vertex_id`` is not a key of ``graph``.
    """
    try:
        stop_or_start_vertex = graph[vertex_id]
    except KeyError as exc:
        # Chain explicitly so the original lookup failure stays visible as the cause.
        raise ValueError(f"Vertex {vertex_id} not found into graph") from exc

    visited: set = set()
    excluded: set = set()
    stack = [vertex_id]
    stop_predecessors = set(stop_or_start_vertex["predecessors"])

    while stack:
        current_id = stack.pop()
        if current_id in visited or current_id in excluded:
            continue

        visited.add(current_id)
        current_vertex = graph[current_id]

        # Everything a visited vertex depends on belongs to the subgraph.
        stack.extend(current_vertex["predecessors"])

        expands_forward = current_id == vertex_id or (current_id not in stop_predecessors and is_start)
        if not expands_forward:
            continue

        for successor_id in current_vertex["successors"]:
            downstream = [successor_id, *get_successors(graph, successor_id)]
            if is_start:
                stack.extend(downstream)
            else:
                excluded.update(downstream)

    return list(visited)
122 changes: 122 additions & 0 deletions tests/unit/graph/graph/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest

from langflow.graph.graph import utils


@pytest.fixture
def graph():
    """Adjacency mapping for a 26-vertex acyclic graph (ids ``A`` to ``Z``).

    Each entry lists the ids of a vertex's direct successors and predecessors.
    ``A``, ``L`` and ``N`` have no predecessors; ``G``, ``H``, ``M``, ``T``,
    ``X`` and ``Z`` have no successors.
    """
    return {
        "A": {"successors": ["B"], "predecessors": []},
        "B": {"successors": ["D"], "predecessors": ["A", "C"]},
        "C": {"successors": ["B", "I"], "predecessors": ["N"]},
        "D": {"successors": ["E", "F"], "predecessors": ["B"]},
        "E": {"successors": ["G"], "predecessors": ["D"]},
        "F": {"successors": ["G", "H"], "predecessors": ["D"]},
        "G": {"successors": [], "predecessors": ["E", "F"]},
        "H": {"successors": [], "predecessors": ["F"]},
        "I": {"successors": ["M"], "predecessors": ["C", "J"]},
        "J": {"successors": ["I", "K"], "predecessors": ["N"]},
        "K": {"successors": ["Q", "P", "O"], "predecessors": ["J", "L"]},
        "L": {"successors": ["K"], "predecessors": []},
        "M": {"successors": [], "predecessors": ["I"]},
        "N": {"successors": ["C", "J"], "predecessors": []},
        "O": {"successors": ["R"], "predecessors": ["K"]},
        "P": {"successors": ["U"], "predecessors": ["K"]},
        "Q": {"successors": ["V"], "predecessors": ["K"]},
        "R": {"successors": ["S"], "predecessors": ["O"]},
        "S": {"successors": ["T"], "predecessors": ["R"]},
        "T": {"successors": [], "predecessors": ["S"]},
        "U": {"successors": ["W"], "predecessors": ["P"]},
        "V": {"successors": ["Y"], "predecessors": ["Q"]},
        "W": {"successors": ["X"], "predecessors": ["U"]},
        "X": {"successors": [], "predecessors": ["W"]},
        "Y": {"successors": ["Z"], "predecessors": ["V"]},
        "Z": {"successors": [], "predecessors": ["Y"]},
    }


def test_get_successors_a(graph):
    """From A, ``get_successors`` yields A together with everything downstream of it."""
    expected = {"A", "B", "D", "E", "F", "H", "G"}

    reachable = utils.get_successors(graph, "A")

    assert set(reachable) == expected


def test_get_successors_z(graph):
    """Z has no successors, so only Z itself is returned."""
    reachable = utils.get_successors(graph, "Z")

    assert set(reachable) == {"Z"}


def test_sort_up_to_vertex_n_is_start(graph):
    """Using N as the start vertex selects every vertex of the graph."""
    every_vertex = set(graph.keys())

    selected = utils.sort_up_to_vertex(graph, "N", is_start=True)

    # The result should contain all the vertices
    assert set(selected) == every_vertex


def test_sort_up_to_vertex_z(graph):
    """Stopping at Z keeps Z and the vertices it depends on."""
    expected = {"L", "N", "J", "K", "Q", "V", "Y", "Z"}

    selected = utils.sort_up_to_vertex(graph, "Z")

    assert set(selected) == expected


def test_sort_up_to_vertex_x(graph):
    """Stopping at X keeps X and the vertices it depends on."""
    expected = {"L", "N", "J", "K", "P", "U", "W", "X"}

    selected = utils.sort_up_to_vertex(graph, "X")

    assert set(selected) == expected


def test_sort_up_to_vertex_t(graph):
    """Stopping at T keeps T and the vertices it depends on."""
    expected = {"L", "N", "J", "K", "O", "R", "S", "T"}

    selected = utils.sort_up_to_vertex(graph, "T")

    assert set(selected) == expected


def test_sort_up_to_vertex_m(graph):
    """Stopping at M keeps M and the vertices it depends on."""
    expected = {"N", "C", "J", "I", "M"}

    selected = utils.sort_up_to_vertex(graph, "M")

    assert set(selected) == expected


def test_sort_up_to_vertex_h(graph):
    """Stopping at H keeps H and the vertices it depends on."""
    expected = {"N", "C", "A", "B", "D", "F", "H"}

    selected = utils.sort_up_to_vertex(graph, "H")

    assert set(selected) == expected


def test_sort_up_to_vertex_g(graph):
    """Stopping at G keeps G and the vertices it depends on, through both E and F."""
    expected = {"N", "C", "A", "B", "D", "F", "E", "G"}

    selected = utils.sort_up_to_vertex(graph, "G")

    assert set(selected) == expected


def test_sort_up_to_vertex_a(graph):
    """A has no predecessors, so stopping at A selects only A."""
    selected = utils.sort_up_to_vertex(graph, "A")

    assert set(selected) == {"A"}


def test_sort_up_to_vertex_invalid_vertex(graph):
    """An id that is not in the graph makes ``sort_up_to_vertex`` raise ``ValueError``."""
    missing_id = "7"

    with pytest.raises(ValueError):
        utils.sort_up_to_vertex(graph, missing_id)

0 comments on commit aa1958a

Please sign in to comment.