diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 8a42f1cbb38..d0adb66ff77 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -34,8 +34,6 @@ class Graph: def __init__( self, - nodes: List[Dict], - edges: List[Dict[str, str]], flow_id: Optional[str] = None, flow_name: Optional[str] = None, user_id: Optional[str] = None, @@ -48,9 +46,6 @@ def __init__( edges (List[Dict[str, str]]): A list of dictionaries representing the edges of the graph. flow_id (Optional[str], optional): The ID of the flow. Defaults to None. """ - self._vertices = nodes - self._edges = edges - self.raw_graph_data = {"nodes": nodes, "edges": edges} self._runs = 0 self._updates = 0 self.flow_id = flow_id @@ -63,28 +58,15 @@ def __init__( self._sorted_vertices_layers: List[List[str]] = [] self._run_id = "" self._start_time = datetime.now(timezone.utc) - - self.top_level_vertices = [] - for vertex in self._vertices: - if vertex_id := vertex.get("id"): - self.top_level_vertices.append(vertex_id) - self._graph_data = process_flow(self.raw_graph_data) - - self._vertices = self._graph_data["nodes"] - self._edges = self._graph_data["edges"] self.inactivated_vertices: set = set() self.activated_vertices: List[str] = [] self.vertices_layers: List[List[str]] = [] self.vertices_to_run: set[str] = set() self.stop_vertex: Optional[str] = None - self.inactive_vertices: set = set() self.edges: List[ContractEdge] = [] self.vertices: List[Vertex] = [] self.run_manager = RunnableVerticesManager() - self._build_graph() - self.build_graph_maps(self.edges) - self.define_vertices_lists() self.state_manager = GraphStateManager() try: self.tracing_service: "TracingService" | None = get_tracing_service() @@ -92,6 +74,33 @@ def __init__( logger.error(f"Error getting tracing service: {exc}") self.tracing_service = None + def add_nodes_and_edges(self, nodes: List[Dict], edges: List[Dict[str, str]]): + self._vertices = nodes + self._edges = edges + self.raw_graph_data = {"nodes": nodes, "edges": edges} + self.top_level_vertices = [] + for vertex in self._vertices: + if vertex_id := vertex.get("id"): + self.top_level_vertices.append(vertex_id) + self._graph_data = process_flow(self.raw_graph_data) + + self._vertices = self._graph_data["nodes"] + self._edges = self._graph_data["edges"] + self.initialize() + + # TODO: Create a TypedDict to represente the node + def add_node(self, node: dict): + self._vertices.append(node) + + # TODO: Create a TypedDict to represente the edge + def add_edge(self, edge: dict): + self._edges.append(edge) + + def initialize(self): + self._build_graph() + self.build_graph_maps(self.edges) + self.define_vertices_lists() + def get_state(self, name: str) -> Optional[Data]: """ Returns the state of the graph with the given name. @@ -638,7 +647,9 @@ def from_payload( try: vertices = payload["nodes"] edges = payload["edges"] - return cls(vertices, edges, flow_id, flow_name, user_id) + graph = cls(flow_id, flow_name, user_id) + graph.add_nodes_and_edges(vertices, edges) + return graph except KeyError as exc: logger.exception(exc) if "nodes" not in payload and "edges" not in payload: @@ -1188,20 +1199,24 @@ def _build_vertices(self) -> List[Vertex]: """Builds the vertices of the graph.""" vertices: List[Vertex] = [] for vertex in self._vertices: - vertex_data = vertex["data"] - vertex_type: str = vertex_data["type"] # type: ignore - vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore - if "id" not in vertex_data: - raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id") - - VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) - - vertex_instance = VertexClass(vertex, graph=self) - vertex_instance.set_top_level(self.top_level_vertices) + vertex_instance = self._create_vertex(vertex) vertices.append(vertex_instance) return vertices + def _create_vertex(self, vertex: dict): + vertex_data = vertex["data"] + vertex_type: str = vertex_data["type"] # type: ignore + vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore + if "id" not in vertex_data: + raise ValueError(f"Vertex data for {vertex_data['display_name']} does not contain an id") + + VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) + + vertex_instance = VertexClass(vertex, graph=self) + vertex_instance.set_top_level(self.top_level_vertices) + return vertex_instance + def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]: """Returns the children of a vertex based on the vertex type.""" children = [] diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index ab77315f30a..11a3fa2b7d6 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -158,7 +158,9 @@ def get_graph(_type="basic"): data_graph = flow_graph["data"] nodes = data_graph["nodes"] edges = data_graph["edges"] - return Graph(nodes, edges) + graph = Graph() + graph.add_nodes_and_edges(nodes, edges) + return graph @pytest.fixture diff --git a/src/backend/tests/unit/test_graph.py b/src/backend/tests/unit/graph/test_graph.py similarity index 99% rename from src/backend/tests/unit/test_graph.py rename to src/backend/tests/unit/graph/test_graph.py index ec69051d290..f6b85aa4e04 100644 --- a/src/backend/tests/unit/test_graph.py +++ b/src/backend/tests/unit/graph/test_graph.py @@ -116,7 +116,8 @@ def test_invalid_node_types(): "edges": [], } with pytest.raises(Exception): - Graph(graph_data["nodes"], graph_data["edges"]) + g = Graph() + g.add_nodes_and_edges(graph_data["nodes"], graph_data["edges"]) def test_get_vertices_with_target(basic_graph):