feat: add dump and dumps methods to Graph (#3202)
* feat(utils.py): add escape_json_dump function to escape JSON strings for Edge dictionaries

* refactor(Output): streamline add_types method to prevent duplicate entries in types list for improved type management

* feat(data.py): add classmethod decorator to validate_data for enhanced validation logic when checking data types

* feat(setup.py): implement retry logic for loading starter projects to enhance robustness against JSON decode errors

* fix(input_mixin.py): improve model_config formatting and update field_type alias for clarity and consistency in field definitions

* feat(types.py): refactor vertex constructors to use NodeData and add input/output methods for better component interaction

* feat(schema.py): add NodeData and Position TypedDicts for improved type safety and structure in vertex data handling

* feat(base.py): update Vertex to use NodeData type and add to_data method for better data management and access

* refactor(schema.py): update TargetHandle and SourceHandle models to include model_config attribute

* Add TypedDict classes for graph schema serialization in `schema.py`

* Refactor `Edge` class to improve handle validation and data handling

- Consolidated imports and removed redundant `BaseModel` definitions for `SourceHandle` and `TargetHandle`.
- Added `valid_handles`, `target_param`, and `_target_handle` attributes to `Edge` class.
- Enhanced handle validation logic to distinguish between dictionary and string types.
- Introduced `to_data` method to return edge data.
- Updated attribute names to follow consistent naming conventions (`base_classes`, `input_types`, `field_name`).

* Refactor `Edge` class to improve handle validation and data handling

* Refactor: Standardize attribute naming and add `to_data` method in Edge class

- Renamed attributes to use snake_case consistently (`baseClasses` to `base_classes`, `inputTypes` to `input_types`, `fieldName` to `field_name`).
- Added `to_data` method to return `_data` attribute.
- Updated validation methods to use new attribute names.

* Refactor: Update Edge class to consistently use snake_case for attributes and improve validation logic for handles

* Refactor: Change node argument type in add_node and _create_vertex methods to NodeData for better type safety and clarity

* Refactor: Implement JSON serialization for graph data with `dumps` and `dump` methods, enhancing data export capabilities

* Refactor: Add pytest fixtures for ingestion and RAG graphs, enhance test structure for better clarity and organization

* Refactor: Add pytest fixtures for memory_chatbot_graph tests and improve test structure

* Refactor: Remove unused methods in ComponentVertex class to streamline code and improve readability

* Refactor: Remove unnecessary line in ComponentVertex class to enhance code clarity and maintainability

* Refactor: Update import path for DefaultPromptField to improve code organization and maintainability in api_utils.py

* Refactor: Update import path for DefaultPromptField to enhance code organization and maintainability in prompt.py

* fix: Remove fixture in test_memory_chatbot.py that blocked db setup

* Refactor: Add durations path for unit tests to improve test reporting

* Refactor: Add splitting algorithm option for unit tests

* Add async option to Makefile for unit tests and update GitHub Actions workflow

- Introduced `async` variable in Makefile to conditionally run unit tests with or without parallel execution.
- Updated `unit_tests` target in Makefile to handle `async` flag.
- Modified GitHub Actions workflow to set `async=false` for unit tests.
ogabrielluiz authored Aug 5, 2024
1 parent 8ece1ca commit bb1bc5c
Showing 19 changed files with 798 additions and 108 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python_test.yml
@@ -50,7 +50,7 @@ jobs:
with:
timeout_minutes: 12
max_attempts: 2
command: make unit_tests args="--splits ${{ matrix.splitCount }} --group ${{ matrix.group }}"
command: make unit_tests async=false args="--splits ${{ matrix.splitCount }} --group ${{ matrix.group }}"

test-cli:
name: Test CLI - Python ${{ matrix.python-version }}
17 changes: 14 additions & 3 deletions Makefile
@@ -18,7 +18,7 @@ env ?= .env
open_browser ?= true
path = src/backend/base/langflow/frontend
workers ?= 1

async ?= true
all: help

######################
@@ -130,14 +130,25 @@ coverage: ## run the tests and generate a coverage report
@poetry run coverage erase

unit_tests: ## run unit tests
ifeq ($(async), true)
poetry run pytest src/backend/tests \
--ignore=src/backend/tests/integration \
--instafail -n auto -ra -m "not api_key_required" \
--durations-path src/backend/tests/.test_durations \
--splitting-algorithm least_duration \
$(args)
else
poetry run pytest src/backend/tests \
--ignore=src/backend/tests/integration \
--instafail -ra -n auto -m "not api_key_required" \
--instafail -ra -m "not api_key_required" \
--durations-path src/backend/tests/.test_durations \
--splitting-algorithm least_duration \
$(args)
endif

integration_tests: ## run integration tests
poetry run pytest src/backend/tests/integration \
--instafail -ra -n auto \
--instafail -ra \
$(args)

tests: ## run unit, integration, coverage tests
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/prompts/api_utils.py
@@ -6,7 +6,7 @@
from loguru import logger

from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.template.field.prompt import DefaultPromptField
from langflow.inputs.inputs import DefaultPromptField


_INVALID_CHARACTERS = {
74 changes: 31 additions & 43 deletions src/backend/base/langflow/graph/edge/base.py
@@ -1,51 +1,31 @@
from typing import TYPE_CHECKING, Any, List, Optional, cast
from typing import TYPE_CHECKING, Any, cast

from loguru import logger
from pydantic import BaseModel, Field, field_validator

from langflow.graph.edge.schema import EdgeData
from langflow.graph.edge.schema import EdgeData, SourceHandle, TargetHandle, TargetHandleDict
from langflow.schema.schema import INPUT_FIELD_NAME

if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex


class SourceHandle(BaseModel):
baseClasses: list[str] = Field(default_factory=list, description="List of base classes for the source handle.")
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
name: Optional[str] = Field(None, description="Name of the source handle.")
output_types: List[str] = Field(default_factory=list, description="List of output types for the source handle.")

@field_validator("name", mode="before")
@classmethod
def validate_name(cls, v, _info):
if _info.data["dataType"] == "GroupNode":
# 'OpenAIModel-u4iGV_text_output'
splits = v.split("_", 1)
if len(splits) != 2:
raise ValueError(f"Invalid source handle name {v}")
v = splits[1]
return v


class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
type: str = Field(..., description="Type of the target handle.")


class Edge:
def __init__(self, source: "Vertex", target: "Vertex", edge: EdgeData):
self.source_id: str = source.id if source else ""
self.target_id: str = target.id if target else ""
self.valid_handles: bool = False
self.target_param: str | None = None
self._target_handle: TargetHandleDict | str | None = None
self._data = edge.copy()
if data := edge.get("data", {}):
self._source_handle = data.get("sourceHandle", {})
self._target_handle = data.get("targetHandle", {})
self._target_handle = cast(TargetHandleDict, data.get("targetHandle", {}))
self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
self.target_param = self.target_handle.fieldName
if isinstance(self._target_handle, dict):
self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
else:
raise ValueError("Target handle is not a dictionary")
self.target_param = self.target_handle.field_name
# validate handles
self.validate_handles(source, target)
else:
@@ -55,23 +55,31 @@ def __init__(self, source: "Vertex", target: "Vertex", edge: EdgeData):
self._target_handle = edge.get("targetHandle", "") # type: ignore
# 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
# target_param is documents
self.target_param = cast(str, self._target_handle.split("|")[1]) # type: ignore
if isinstance(self._target_handle, str):
self.target_param = self._target_handle.split("|")[1]
self.source_handle = None
self.target_handle = None
else:
raise ValueError("Target handle is not a string")
# Validate in __init__ to fail fast
self.validate_edge(source, target)

def to_data(self):
return self._data

def validate_handles(self, source, target) -> None:
if isinstance(self._source_handle, str) or self.source_handle.baseClasses:
if isinstance(self._source_handle, str) or self.source_handle.base_classes:
self._legacy_validate_handles(source, target)
else:
self._validate_handles(source, target)

def _validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
if self.target_handle.input_types is None:
self.valid_handles = self.target_handle.type in self.source_handle.output_types

elif self.source_handle.output_types is not None:
self.valid_handles = (
any(output_type in self.target_handle.inputTypes for output_type in self.source_handle.output_types)
any(output_type in self.target_handle.input_types for output_type in self.source_handle.output_types)
or self.target_handle.type in self.source_handle.output_types
)

@@ -81,12 +69,12 @@ def _validate_handles(self, source, target) -> None:
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")

def _legacy_validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
if self.target_handle.input_types is None:
self.valid_handles = self.target_handle.type in self.source_handle.base_classes
else:
self.valid_handles = (
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
or self.target_handle.type in self.source_handle.baseClasses
any(baseClass in self.target_handle.input_types for baseClass in self.source_handle.base_classes)
or self.target_handle.type in self.source_handle.base_classes
)
if not self.valid_handles:
logger.debug(self.source_handle)
@@ -101,9 +89,9 @@ def __setstate__(self, state):
self.target_handle = state.get("target_handle")

def validate_edge(self, source, target) -> None:
# If the self.source_handle has baseClasses, then we are using the legacy
# If the self.source_handle has base_classes, then we are using the legacy
# way of defining the source and target handles
if isinstance(self._source_handle, str) or self.source_handle.baseClasses:
if isinstance(self._source_handle, str) or self.source_handle.base_classes:
self._legacy_validate_edge(source, target)
else:
self._validate_edge(source, target)
@@ -230,5 +218,5 @@ def __repr__(self) -> str:
if (hasattr(self, "source_handle") and self.source_handle) and (
hasattr(self, "target_handle") and self.target_handle
):
return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.fieldName}]-> {self.target_id}"
return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.field_name}]-> {self.target_id}"
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
6 changes: 4 additions & 2 deletions src/backend/base/langflow/graph/edge/schema.py
@@ -1,6 +1,6 @@
from typing import Any, List, Optional

from pydantic import Field, field_validator
from pydantic import ConfigDict, Field, field_validator
from typing_extensions import TypedDict

from langflow.helpers.base_model import BaseModel
@@ -39,7 +39,8 @@ def format(self, sep: str = "\n") -> str:


class TargetHandle(BaseModel):
fieldName: str = Field(..., alias="fieldName", description="Field name for the target handle.")
model_config = ConfigDict(populate_by_name=True)
field_name: str = Field(..., alias="fieldName", description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
input_types: List[str] = Field(
default_factory=list, alias="inputTypes", description="List of input types for the target handle."
@@ -48,6 +49,7 @@ class TargetHandle(BaseModel):


class SourceHandle(BaseModel):
model_config = ConfigDict(populate_by_name=True)
base_classes: list[str] = Field(
default_factory=list, alias="baseClasses", description="List of base classes for the source handle."
)
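The model_config = ConfigDict(populate_by_name=True) added above lets these models be built from either the camelCase aliases used in flow JSON or the new snake_case field names. A small sketch, assuming langflow is importable and that the id and type fields carry over unchanged from the previous model:

from langflow.graph.edge.schema import TargetHandle

# Either spelling populates the same field once populate_by_name=True is set.
from_alias = TargetHandle(fieldName="documents", id="PromptTemplate-zmTlD", inputTypes=["Document"], type="Document")
from_field = TargetHandle(field_name="documents", id="PromptTemplate-zmTlD", input_types=["Document"], type="Document")
assert from_alias == from_field

# Dumping by alias restores the camelCase keys expected in flow JSON.
assert from_alias.model_dump(by_alias=True)["fieldName"] == "documents"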
48 changes: 41 additions & 7 deletions src/backend/base/langflow/graph/graph/base.py
@@ -1,4 +1,5 @@
import asyncio
import json
import uuid
from collections import defaultdict, deque
from datetime import datetime, timezone
@@ -14,11 +15,12 @@
from langflow.graph.edge.schema import EdgeData
from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import VertexBuildResult
from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.schema import NodeData
from langflow.graph.vertex.types import ComponentVertex, InterfaceVertex, StateVertex
from langflow.schema import Data
from langflow.schema.schema import INPUT_FIELD_NAME, InputType
@@ -75,7 +77,7 @@ def __init__(
self.vertices: List[Vertex] = []
self.run_manager = RunnableVerticesManager()
self.state_manager = GraphStateManager()
self._vertices: List[dict] = []
self._vertices: List[NodeData] = []
self._edges: List[EdgeData] = []
self.top_level_vertices: List[str] = []
self.vertex_map: Dict[str, Vertex] = {}
@@ -86,6 +88,7 @@ def __init__(
self._run_queue: deque[str] = deque()
self._first_layer: List[str] = []
self._lock = asyncio.Lock()
self.raw_graph_data: GraphData = {"nodes": [], "edges": []}
try:
self.tracing_service: "TracingService" | None = get_tracing_service()
except Exception as exc:
@@ -97,7 +100,39 @@ def __init__(
if (start is not None and end is None) or (start is None and end is not None):
raise ValueError("You must provide both input and output components")

def add_nodes_and_edges(self, nodes: List[Dict], edges: List[EdgeData]):
def dumps(
self,
name: Optional[str] = None,
description: Optional[str] = None,
endpoint_name: Optional[str] = None,
) -> str:
graph_dict = self.dump(name, description, endpoint_name)
return json.dumps(graph_dict, indent=4, sort_keys=True)

def dump(
self, name: Optional[str] = None, description: Optional[str] = None, endpoint_name: Optional[str] = None
) -> GraphDump:
if self.raw_graph_data != {"nodes": [], "edges": []}:
data_dict = self.raw_graph_data
else:
# we need to convert the vertices and edges to json
nodes = [node.to_data() for node in self.vertices]
edges = [edge.to_data() for edge in self.edges]
self.raw_graph_data = {"nodes": nodes, "edges": edges}
data_dict = self.raw_graph_data
graph_dict: GraphDump = {
"data": data_dict,
"is_component": len(data_dict.get("nodes", [])) == 1 and data_dict["edges"] == [],
}
if name:
graph_dict["name"] = name
if description:
graph_dict["description"] = description
if endpoint_name:
graph_dict["endpoint_name"] = endpoint_name
return graph_dict

def add_nodes_and_edges(self, nodes: List[NodeData], edges: List[EdgeData]):
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
@@ -183,7 +218,7 @@ async def async_start(self, inputs: Optional[List[dict]] = None):
return

def start(self, inputs: Optional[List[dict]] = None) -> Generator:
#! Change this soon
#! Change this ASAP
nest_asyncio.apply()
loop = asyncio.get_event_loop()
async_gen = self.async_start(inputs)
@@ -208,8 +243,7 @@ def _add_edge(self, edge: EdgeData):
self.in_degree_map[target_id] += 1
self.parent_child_map[source_id].append(target_id)

# TODO: Create a TypedDict to represente the node
def add_node(self, node: dict):
def add_node(self, node: NodeData):
self._vertices.append(node)

def add_edge(self, edge: EdgeData):
@@ -1400,7 +1434,7 @@ def _build_vertices(self) -> List[Vertex]:

return vertices

def _create_vertex(self, frontend_data: dict):
def _create_vertex(self, frontend_data: NodeData):
vertex_data = frontend_data["data"]
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
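For reference, a minimal sketch of the new serialization API on Graph (it assumes a Graph can be instantiated without arguments in a configured langflow environment; a real flow would first be loaded with add_nodes_and_edges):

import json

from langflow.graph.graph.base import Graph

graph = Graph()  # empty graph, used only to show the API surface
flow_dict = graph.dump(name="My Flow", description="Example flow")  # GraphDump TypedDict
print(flow_dict["is_component"])  # False unless the graph is a single node with no edges

flow_json = graph.dumps(name="My Flow")  # pretty-printed, key-sorted JSON string
assert json.loads(flow_json)["name"] == "My Flow"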
25 changes: 25 additions & 0 deletions src/backend/base/langflow/graph/graph/schema.py
@@ -1,10 +1,35 @@
from typing import TYPE_CHECKING, NamedTuple

from typing_extensions import NotRequired, TypedDict

from langflow.graph.edge.schema import EdgeData
from langflow.graph.vertex.schema import NodeData

if TYPE_CHECKING:
from langflow.graph.schema import ResultData
from langflow.graph.vertex.base import Vertex


class ViewPort(TypedDict):
x: float
y: float
zoom: float


class GraphData(TypedDict):
nodes: list[NodeData]
edges: list[EdgeData]
viewport: NotRequired[ViewPort]


class GraphDump(TypedDict, total=False):
data: GraphData
is_component: bool
name: str
description: str
endpoint_name: str


class VertexBuildResult(NamedTuple):
result_dict: "ResultData"
params: str
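Because GraphData and GraphDump are plain TypedDicts, a conforming value is just a dictionary with the right keys; node and edge payloads are elided in this sketch:

from langflow.graph.graph.schema import GraphData, GraphDump

# TypedDicts constrain shape only at type-check time; at runtime these are plain dicts.
data: GraphData = {"nodes": [], "edges": [], "viewport": {"x": 0.0, "y": 0.0, "zoom": 1.0}}
dump: GraphDump = {
    "data": data,
    "is_component": False,
    "name": "My Flow",
    "description": "Example flow",
    "endpoint_name": "my-flow",
}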
8 changes: 6 additions & 2 deletions src/backend/base/langflow/graph/vertex/base.py
@@ -13,6 +13,7 @@
from langflow.exceptions.component import ComponentBuildException
from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData
from langflow.graph.utils import UnbuiltObject, UnbuiltResult, log_transaction
from langflow.graph.vertex.schema import NodeData
from langflow.interface.initialize import loading
from langflow.interface.listing import lazy_load_dict
from langflow.schema.artifact import ArtifactType
@@ -42,7 +43,7 @@ class VertexStates(str, Enum):
class Vertex:
def __init__(
self,
data: Dict,
data: NodeData,
graph: "Graph",
base_type: Optional[str] = None,
is_task: bool = False,
@@ -63,7 +64,7 @@ def __init__(
self.has_external_input = False
self.has_external_output = False
self.graph = graph
self._data = data
self._data = data.copy()
self.base_type: Optional[str] = base_type
self.outputs: List[Dict] = []
self._parse_data()
@@ -101,6 +102,9 @@ def set_input_value(self, name: str, value: Any):
raise ValueError(f"Vertex {self.id} does not have a component instance.")
self._custom_component._set_input_value(name, value)

def to_data(self):
return self._data

def add_component_instance(self, component_instance: "Component"):
component_instance.set_vertex(self)
self._custom_component = component_instance