Skip to content

Commit

Permalink
refactor: separate initialization of nodes and edges in test_graph.py (
Browse files Browse the repository at this point in the history
…#2828)

* refactor: move test_graph.py

* refactor: allow Graph to be initialized with no nodes and edges

The Graph class in `base.py` was refactored to separate the initialization of nodes and edges into a separate method called `add_nodes_and_edges()`. This improves code readability and maintainability by organizing the code logic more effectively.

* refactor: separate initialization of nodes and edges in get_graph()

The `get_graph()` function in `conftest.py` was refactored to separate the initialization of nodes and edges. This improves code readability and maintainability by organizing the code logic more effectively.

* refactor: separate initialization of nodes and edges in test_graph.py

* refactor: separate initialization of nodes and edges in base.py

The `add_node()` and `add_edge()` methods were added to the `Graph` class in `base.py` to separate the initialization of nodes and edges. This improves code readability and maintainability by organizing the code logic more effectively.
  • Loading branch information
ogabrielluiz authored Jul 22, 2024
1 parent 077f68f commit 77cc789
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 31 deletions.
73 changes: 44 additions & 29 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ class Graph:

def __init__(
self,
nodes: List[Dict],
edges: List[Dict[str, str]],
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
user_id: Optional[str] = None,
Expand All @@ -48,9 +46,6 @@ def __init__(
edges (List[Dict[str, str]]): A list of dictionaries representing the edges of the graph.
flow_id (Optional[str], optional): The ID of the flow. Defaults to None.
"""
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
self._runs = 0
self._updates = 0
self.flow_id = flow_id
Expand All @@ -63,35 +58,49 @@ def __init__(
self._sorted_vertices_layers: List[List[str]] = []
self._run_id = ""
self._start_time = datetime.now(timezone.utc)

self.top_level_vertices = []
for vertex in self._vertices:
if vertex_id := vertex.get("id"):
self.top_level_vertices.append(vertex_id)
self._graph_data = process_flow(self.raw_graph_data)

self._vertices = self._graph_data["nodes"]
self._edges = self._graph_data["edges"]
self.inactivated_vertices: set = set()
self.activated_vertices: List[str] = []
self.vertices_layers: List[List[str]] = []
self.vertices_to_run: set[str] = set()
self.stop_vertex: Optional[str] = None

self.inactive_vertices: set = set()
self.edges: List[ContractEdge] = []
self.vertices: List[Vertex] = []
self.run_manager = RunnableVerticesManager()
self._build_graph()
self.build_graph_maps(self.edges)
self.define_vertices_lists()
self.state_manager = GraphStateManager()
try:
self.tracing_service: "TracingService" | None = get_tracing_service()
except Exception as exc:
logger.error(f"Error getting tracing service: {exc}")
self.tracing_service = None

def add_nodes_and_edges(self, nodes: List[Dict], edges: List[Dict[str, str]]):
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
self.top_level_vertices = []
for vertex in self._vertices:
if vertex_id := vertex.get("id"):
self.top_level_vertices.append(vertex_id)
self._graph_data = process_flow(self.raw_graph_data)

self._vertices = self._graph_data["nodes"]
self._edges = self._graph_data["edges"]
self.initialize()

# TODO: Create a TypedDict to represente the node
def add_node(self, node: dict):
self._vertices.append(node)

# TODO: Create a TypedDict to represente the edge
def add_edge(self, edge: dict):
self._edges.append(edge)

def initialize(self):
self._build_graph()
self.build_graph_maps(self.edges)
self.define_vertices_lists()

def get_state(self, name: str) -> Optional[Data]:
"""
Returns the state of the graph with the given name.
Expand Down Expand Up @@ -638,7 +647,9 @@ def from_payload(
try:
vertices = payload["nodes"]
edges = payload["edges"]
return cls(vertices, edges, flow_id, flow_name, user_id)
graph = cls(flow_id, flow_name, user_id)
graph.add_nodes_and_edges(vertices, edges)
return graph
except KeyError as exc:
logger.exception(exc)
if "nodes" not in payload and "edges" not in payload:
Expand Down Expand Up @@ -1188,20 +1199,24 @@ def _build_vertices(self) -> List[Vertex]:
"""Builds the vertices of the graph."""
vertices: List[Vertex] = []
for vertex in self._vertices:
vertex_data = vertex["data"]
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
if "id" not in vertex_data:
raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id")

VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])

vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertex_instance = self._create_vertex(vertex)
vertices.append(vertex_instance)

return vertices

def _create_vertex(self, vertex: dict):
vertex_data = vertex["data"]
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
if "id" not in vertex_data:
raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id")

VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])

vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
return vertex_instance

def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
Expand Down
4 changes: 3 additions & 1 deletion src/backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def get_graph(_type="basic"):
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
return Graph(nodes, edges)
graph = Graph()
graph.add_nodes_and_edges(nodes, edges)
return graph


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def test_invalid_node_types():
"edges": [],
}
with pytest.raises(Exception):
Graph(graph_data["nodes"], graph_data["edges"])
g = Graph()
g.add_nodes_and_edges(graph_data["nodes"], graph_data["edges"])


def test_get_vertices_with_target(basic_graph):
Expand Down

0 comments on commit 77cc789

Please sign in to comment.