feat: add dump and dumps methods to Graph #3202

Merged
merged 26 commits into from
Aug 5, 2024
3927bf5
feat(utils.py): add escape_json_dump function to escape JSON strings …
ogabrielluiz Aug 5, 2024
6c40f7b
refactor(Output): streamline add_types method to prevent duplicate en…
ogabrielluiz Aug 5, 2024
a744489
feat(data.py): add classmethod decorator to validate_data for enhance…
ogabrielluiz Aug 5, 2024
8254275
feat(setup.py): implement retry logic for loading starter projects to…
ogabrielluiz Aug 5, 2024
fd770f2
fix(input_mixin.py): improve model_config formatting and update field…
ogabrielluiz Aug 5, 2024
22fd048
feat(types.py): refactor vertex constructors to use NodeData and add …
ogabrielluiz Aug 5, 2024
e75d06e
feat(schema.py): add NodeData and Position TypedDicts for improved ty…
ogabrielluiz Aug 5, 2024
7e65921
feat(base.py): update Vertex to use NodeData type and add to_data met…
ogabrielluiz Aug 5, 2024
862ae9c
refactor(schema.py): update TargetHandle and SourceHandle models to i…
ogabrielluiz Aug 5, 2024
ac5e541
Add TypedDict classes for graph schema serialization in `schema.py`
ogabrielluiz Aug 5, 2024
3006402
Refactor `Edge` class to improve handle validation and data handling
ogabrielluiz Aug 5, 2024
cf69f83
Refactor `Edge` class to improve handle validation and data handling
ogabrielluiz Aug 5, 2024
6ea2715
Refactor: Standardize attribute naming and add `to_data` method in Ed…
ogabrielluiz Aug 5, 2024
c6f31f8
Refactor: Update Edge class to consistently use snake_case for attrib…
ogabrielluiz Aug 5, 2024
8983260
Refactor: Change node argument type in add_node and _create_vertex me…
ogabrielluiz Aug 5, 2024
5f24db8
Refactor: Implement JSON serialization for graph data with `dumps` an…
ogabrielluiz Aug 5, 2024
4239ab6
Refactor: Add pytest fixtures for ingestion and RAG graphs, enhance t…
ogabrielluiz Aug 5, 2024
46fec80
Refactor: Add pytest fixtures for memory_chatbot_graph tests and impr…
ogabrielluiz Aug 5, 2024
78ff654
Refactor: Remove unused methods in ComponentVertex class to streamlin…
ogabrielluiz Aug 5, 2024
cc3fc12
Refactor: Remove unnecessary line in ComponentVertex class to enhance…
ogabrielluiz Aug 5, 2024
2a4c500
Refactor: Update import path for DefaultPromptField to improve code o…
ogabrielluiz Aug 5, 2024
d84488b
Refactor: Update import path for DefaultPromptField to enhance code o…
ogabrielluiz Aug 5, 2024
8bdb2ef
fix: Remove fixture in test_memory_chatbot.py that blocked db setup
ogabrielluiz Aug 5, 2024
420aaa7
Refactor: Add durations path for unit tests to improve test reporting
ogabrielluiz Aug 5, 2024
4746b80
Refactor: Add splitting algorithm option for unit tests
ogabrielluiz Aug 5, 2024
5f28fb2
Add async option to Makefile for unit tests and update GitHub Actions…
ogabrielluiz Aug 5, 2024
2 changes: 1 addition & 1 deletion .github/workflows/python_test.yml
@@ -50,7 +50,7 @@ jobs:
with:
timeout_minutes: 12
max_attempts: 2
- command: make unit_tests args="--splits ${{ matrix.splitCount }} --group ${{ matrix.group }}"
+ command: make unit_tests async=false args="--splits ${{ matrix.splitCount }} --group ${{ matrix.group }}"

test-cli:
name: Test CLI - Python ${{ matrix.python-version }}
17 changes: 14 additions & 3 deletions Makefile
@@ -18,7 +18,7 @@ env ?= .env
open_browser ?= true
path = src/backend/base/langflow/frontend
workers ?= 1
-
+ async ?= true
all: help

######################
@@ -130,14 +130,25 @@ coverage: ## run the tests and generate a coverage report
@poetry run coverage erase

unit_tests: ## run unit tests
+ ifeq ($(async), true)
+ poetry run pytest src/backend/tests \
+ --ignore=src/backend/tests/integration \
+ --instafail -n auto -ra -m "not api_key_required" \
+ --durations-path src/backend/tests/.test_durations \
+ --splitting-algorithm least_duration \
+ $(args)
+ else
poetry run pytest src/backend/tests \
--ignore=src/backend/tests/integration \
- --instafail -ra -n auto -m "not api_key_required" \
+ --instafail -ra -m "not api_key_required" \
+ --durations-path src/backend/tests/.test_durations \
+ --splitting-algorithm least_duration \
$(args)
+ endif

integration_tests: ## run integration tests
poetry run pytest src/backend/tests/integration \
- --instafail -ra -n auto \
+ --instafail -ra \
$(args)
$(args)

tests: ## run unit, integration, coverage tests
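
Note on the `unit_tests` change: the two branches differ only in whether `-n auto` (the pytest-xdist worker flag) is passed. A minimal sketch of that branching, written in Python rather than Make. `unit_test_command` is a hypothetical helper used only for illustration and does not exist in the repository; the flags are copied from the Makefile lines in this diff.

```python
from typing import List


def unit_test_command(async_enabled: bool, extra_args: List[str]) -> List[str]:
    """Build the pytest argument list the way the unit_tests target appears to."""
    cmd = [
        "poetry", "run", "pytest", "src/backend/tests",
        "--ignore=src/backend/tests/integration",
        "--instafail",
    ]
    if async_enabled:
        # Only the async=true branch passes `-n auto`.
        cmd += ["-n", "auto"]
    cmd += [
        "-ra", "-m", "not api_key_required",
        "--durations-path", "src/backend/tests/.test_durations",
        "--splitting-algorithm", "least_duration",
    ]
    # Extra arguments (for example --splits/--group from CI) go last.
    return cmd + extra_args


print(" ".join(unit_test_command(False, ["--splits", "5", "--group", "1"])))
```

Under this reading, the `python_test.yml` change corresponds to calling the helper with `async_enabled=False`, so the split CI jobs run without `-n auto`.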
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/prompts/api_utils.py
@@ -6,7 +6,7 @@
from loguru import logger

from langflow.interface.utils import extract_input_variables_from_prompt
- from langflow.template.field.prompt import DefaultPromptField
+ from langflow.inputs.inputs import DefaultPromptField


_INVALID_CHARACTERS = {
74 changes: 31 additions & 43 deletions src/backend/base/langflow/graph/edge/base.py
@@ -1,51 +1,31 @@
from typing import TYPE_CHECKING, Any, List, Optional, cast
from typing import TYPE_CHECKING, Any, cast

from loguru import logger
from pydantic import BaseModel, Field, field_validator

from langflow.graph.edge.schema import EdgeData
from langflow.graph.edge.schema import EdgeData, SourceHandle, TargetHandle, TargetHandleDict
from langflow.schema.schema import INPUT_FIELD_NAME

if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex


class SourceHandle(BaseModel):
baseClasses: list[str] = Field(default_factory=list, description="List of base classes for the source handle.")
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
name: Optional[str] = Field(None, description="Name of the source handle.")
output_types: List[str] = Field(default_factory=list, description="List of output types for the source handle.")

@field_validator("name", mode="before")
@classmethod
def validate_name(cls, v, _info):
if _info.data["dataType"] == "GroupNode":
# 'OpenAIModel-u4iGV_text_output'
splits = v.split("_", 1)
if len(splits) != 2:
raise ValueError(f"Invalid source handle name {v}")
v = splits[1]
return v


class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
type: str = Field(..., description="Type of the target handle.")


class Edge:
def __init__(self, source: "Vertex", target: "Vertex", edge: EdgeData):
self.source_id: str = source.id if source else ""
self.target_id: str = target.id if target else ""
self.valid_handles: bool = False
self.target_param: str | None = None
self._target_handle: TargetHandleDict | str | None = None
self._data = edge.copy()
if data := edge.get("data", {}):
self._source_handle = data.get("sourceHandle", {})
self._target_handle = data.get("targetHandle", {})
self._target_handle = cast(TargetHandleDict, data.get("targetHandle", {}))
self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
self.target_param = self.target_handle.fieldName
if isinstance(self._target_handle, dict):
self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
else:
raise ValueError("Target handle is not a dictionary")
self.target_param = self.target_handle.field_name
# validate handles
self.validate_handles(source, target)
else:
@@ -55,23 +35,31 @@ def __init__(self, source: "Vertex", target: "Vertex", edge: EdgeData):
self._target_handle = edge.get("targetHandle", "") # type: ignore
# 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
# target_param is documents
self.target_param = cast(str, self._target_handle.split("|")[1]) # type: ignore
if isinstance(self._target_handle, str):
self.target_param = self._target_handle.split("|")[1]
self.source_handle = None
self.target_handle = None
else:
raise ValueError("Target handle is not a string")
# Validate in __init__ to fail fast
self.validate_edge(source, target)

def to_data(self):
return self._data

def validate_handles(self, source, target) -> None:
if isinstance(self._source_handle, str) or self.source_handle.baseClasses:
if isinstance(self._source_handle, str) or self.source_handle.base_classes:
self._legacy_validate_handles(source, target)
else:
self._validate_handles(source, target)

def _validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
if self.target_handle.input_types is None:
self.valid_handles = self.target_handle.type in self.source_handle.output_types

elif self.source_handle.output_types is not None:
self.valid_handles = (
any(output_type in self.target_handle.inputTypes for output_type in self.source_handle.output_types)
any(output_type in self.target_handle.input_types for output_type in self.source_handle.output_types)
or self.target_handle.type in self.source_handle.output_types
)

@@ -81,12 +69,12 @@ def _validate_handles(self, source, target) -> None:
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")

def _legacy_validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
if self.target_handle.input_types is None:
self.valid_handles = self.target_handle.type in self.source_handle.base_classes
else:
self.valid_handles = (
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
or self.target_handle.type in self.source_handle.baseClasses
any(baseClass in self.target_handle.input_types for baseClass in self.source_handle.base_classes)
or self.target_handle.type in self.source_handle.base_classes
)
if not self.valid_handles:
logger.debug(self.source_handle)
@@ -101,9 +89,9 @@ def __setstate__(self, state):
self.target_handle = state.get("target_handle")

def validate_edge(self, source, target) -> None:
# If the self.source_handle has baseClasses, then we are using the legacy
# If the self.source_handle has base_classes, then we are using the legacy
# way of defining the source and target handles
if isinstance(self._source_handle, str) or self.source_handle.baseClasses:
if isinstance(self._source_handle, str) or self.source_handle.base_classes:
self._legacy_validate_edge(source, target)
else:
self._validate_edge(source, target)
@@ -230,5 +218,5 @@ def __repr__(self) -> str:
if (hasattr(self, "source_handle") and self.source_handle) and (
hasattr(self, "target_handle") and self.target_handle
):
return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.fieldName}]-> {self.target_id}"
return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.field_name}]-> {self.target_id}"
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
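
Note on the handle checks in this file: they reduce to two small rules, sketched below as standalone functions. `legacy_target_param` and `handles_are_valid` are hypothetical names used only for illustration; the real logic lives in `Edge.__init__` and `Edge._validate_handles`. Two assumptions in the sketch: a missing `input_types` is treated like an empty list, and the length check on the legacy string is an addition that the diff does not contain.

```python
from typing import List, Optional


def legacy_target_param(target_handle: str) -> str:
    """Extract the parameter name from a legacy 'types|param|vertex_id' string."""
    parts = target_handle.split("|")
    if len(parts) < 2:
        # Extra guard for this sketch; the diff indexes [1] directly.
        raise ValueError(f"Invalid legacy target handle: {target_handle!r}")
    return parts[1]


def handles_are_valid(
    target_type: str,
    input_types: Optional[List[str]],
    output_types: List[str],
) -> bool:
    """Valid if a source output type is accepted by the target, or matches its type."""
    accepted = input_types or []
    return any(t in accepted for t in output_types) or target_type in output_types


print(legacy_target_param("BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD"))  # documents
print(handles_are_valid("str", ["Message"], ["Message"]))  # True
print(handles_are_valid("str", None, ["Data"]))  # False
```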
6 changes: 4 additions & 2 deletions src/backend/base/langflow/graph/edge/schema.py
@@ -1,6 +1,6 @@
from typing import Any, List, Optional

- from pydantic import Field, field_validator
+ from pydantic import ConfigDict, Field, field_validator
from typing_extensions import TypedDict

from langflow.helpers.base_model import BaseModel
@@ -39,7 +39,8 @@ def format(self, sep: str = "\n") -> str:


class TargetHandle(BaseModel):
- fieldName: str = Field(..., alias="fieldName", description="Field name for the target handle.")
+ model_config = ConfigDict(populate_by_name=True)
+ field_name: str = Field(..., alias="fieldName", description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
input_types: List[str] = Field(
default_factory=list, alias="inputTypes", description="List of input types for the target handle."
@@ -48,6 +49,7 @@ class TargetHandle(BaseModel):


class SourceHandle(BaseModel):
+ model_config = ConfigDict(populate_by_name=True)
base_classes: list[str] = Field(
default_factory=list, alias="baseClasses", description="List of base classes for the source handle."
)
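
Note on `populate_by_name=True`: combined with the camelCase aliases, it lets a handle be built either from frontend-style keys (`fieldName`) or from the snake_case attribute names used in `edge/base.py`. A minimal sketch, assuming pydantic v2 and using pydantic's own `BaseModel` instead of `langflow.helpers.base_model.BaseModel`. `TargetHandleSketch` is a hypothetical stand-in, not the class from this PR.

```python
from typing import List

from pydantic import BaseModel, ConfigDict, Field


class TargetHandleSketch(BaseModel):
    # populate_by_name lets callers use either the alias or the field name.
    model_config = ConfigDict(populate_by_name=True)

    field_name: str = Field(..., alias="fieldName")
    id: str
    input_types: List[str] = Field(default_factory=list, alias="inputTypes")
    type: str


# Frontend-style payload with camelCase keys.
from_frontend = TargetHandleSketch(
    fieldName="documents",
    id="PromptTemplate-zmTlD",
    inputTypes=["Document"],
    type="str",
)

# Python-side construction with snake_case names.
from_python = TargetHandleSketch(
    field_name="documents",
    id="PromptTemplate-zmTlD",
    type="str",
)

print(from_frontend.field_name == from_python.field_name)
# Serializing by alias restores the camelCase keys.
print(from_frontend.model_dump(by_alias=True))
```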
48 changes: 41 additions & 7 deletions src/backend/base/langflow/graph/graph/base.py
@@ -1,4 +1,5 @@
import asyncio
import json
import uuid
from collections import defaultdict, deque
from datetime import datetime, timezone
@@ -14,11 +15,12 @@
from langflow.graph.edge.schema import EdgeData
from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import VertexBuildResult
from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
from langflow.graph.vertex.schema import NodeData
from langflow.graph.vertex.types import ComponentVertex, InterfaceVertex, StateVertex
from langflow.schema import Data
from langflow.schema.schema import INPUT_FIELD_NAME, InputType
@@ -75,7 +77,7 @@ def __init__(
self.vertices: List[Vertex] = []
self.run_manager = RunnableVerticesManager()
self.state_manager = GraphStateManager()
self._vertices: List[dict] = []
self._vertices: List[NodeData] = []
self._edges: List[EdgeData] = []
self.top_level_vertices: List[str] = []
self.vertex_map: Dict[str, Vertex] = {}
@@ -86,6 +88,7 @@ def __init__(
self._run_queue: deque[str] = deque()
self._first_layer: List[str] = []
self._lock = asyncio.Lock()
self.raw_graph_data: GraphData = {"nodes": [], "edges": []}
try:
self.tracing_service: "TracingService" | None = get_tracing_service()
except Exception as exc:
@@ -97,7 +100,39 @@ def __init__(
if (start is not None and end is None) or (start is None and end is not None):
raise ValueError("You must provide both input and output components")

def add_nodes_and_edges(self, nodes: List[Dict], edges: List[EdgeData]):
def dumps(
self,
name: Optional[str] = None,
description: Optional[str] = None,
endpoint_name: Optional[str] = None,
) -> str:
graph_dict = self.dump(name, description, endpoint_name)
return json.dumps(graph_dict, indent=4, sort_keys=True)

def dump(
self, name: Optional[str] = None, description: Optional[str] = None, endpoint_name: Optional[str] = None
) -> GraphDump:
if self.raw_graph_data != {"nodes": [], "edges": []}:
data_dict = self.raw_graph_data
else:
# we need to convert the vertices and edges to json
nodes = [node.to_data() for node in self.vertices]
edges = [edge.to_data() for edge in self.edges]
self.raw_graph_data = {"nodes": nodes, "edges": edges}
data_dict = self.raw_graph_data
graph_dict: GraphDump = {
"data": data_dict,
"is_component": len(data_dict.get("nodes", [])) == 1 and data_dict["edges"] == [],
}
if name:
graph_dict["name"] = name
if description:
graph_dict["description"] = description
if endpoint_name:
graph_dict["endpoint_name"] = endpoint_name
return graph_dict

def add_nodes_and_edges(self, nodes: List[NodeData], edges: List[EdgeData]):
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
@@ -183,7 +218,7 @@ async def async_start(self, inputs: Optional[List[dict]] = None):
return

def start(self, inputs: Optional[List[dict]] = None) -> Generator:
#! Change this soon
#! Change this ASAP
nest_asyncio.apply()
loop = asyncio.get_event_loop()
async_gen = self.async_start(inputs)
@@ -208,8 +243,7 @@ def _add_edge(self, edge: EdgeData):
self.in_degree_map[target_id] += 1
self.parent_child_map[source_id].append(target_id)

# TODO: Create a TypedDict to represente the node
def add_node(self, node: dict):
def add_node(self, node: NodeData):
self._vertices.append(node)

def add_edge(self, edge: EdgeData):
@@ -1400,7 +1434,7 @@ def _build_vertices(self) -> List[Vertex]:

return vertices

def _create_vertex(self, frontend_data: dict):
def _create_vertex(self, frontend_data: NodeData):
vertex_data = frontend_data["data"]
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
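
Note on the new `dump` and `dumps` methods in this file: `dump` wraps the raw node and edge data in a dict, marks the graph as a component when it has exactly one node and no edges, and only adds `name`, `description` and `endpoint_name` when they are truthy. `dumps` serializes that dict with `indent=4` and `sort_keys=True`. A minimal sketch of that behaviour, assuming a stripped-down stand-in class. `MiniGraph` and the node id `ChatInput-abc12` are hypothetical, and the sketch skips the branch that rebuilds the data from vertices and edges.

```python
import json
from typing import Any, Dict, List, Optional


class MiniGraph:
    """Hypothetical stand-in for Graph; holds only raw node and edge dicts."""

    def __init__(self, nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]):
        self.raw_graph_data = {"nodes": nodes, "edges": edges}

    def dump(
        self,
        name: Optional[str] = None,
        description: Optional[str] = None,
        endpoint_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        data = self.raw_graph_data
        graph_dict: Dict[str, Any] = {
            "data": data,
            # A single node with no edges is treated as a component.
            "is_component": len(data["nodes"]) == 1 and data["edges"] == [],
        }
        # Optional metadata is only included when a truthy value is passed.
        optional = (("name", name), ("description", description), ("endpoint_name", endpoint_name))
        for key, value in optional:
            if value:
                graph_dict[key] = value
        return graph_dict

    def dumps(self, **kwargs: Optional[str]) -> str:
        return json.dumps(self.dump(**kwargs), indent=4, sort_keys=True)


graph = MiniGraph(nodes=[{"id": "ChatInput-abc12"}], edges=[])
dumped = graph.dump(name="Example")
print(dumped["is_component"])  # True
print(sorted(dumped))  # ['data', 'is_component', 'name']
```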
25 changes: 25 additions & 0 deletions src/backend/base/langflow/graph/graph/schema.py
@@ -1,10 +1,35 @@
from typing import TYPE_CHECKING, NamedTuple

from typing_extensions import NotRequired, TypedDict

from langflow.graph.edge.schema import EdgeData
from langflow.graph.vertex.schema import NodeData

if TYPE_CHECKING:
from langflow.graph.schema import ResultData
from langflow.graph.vertex.base import Vertex


class ViewPort(TypedDict):
x: float
y: float
zoom: float


class GraphData(TypedDict):
nodes: list[NodeData]
edges: list[EdgeData]
viewport: NotRequired[ViewPort]


class GraphDump(TypedDict, total=False):
data: GraphData
is_component: bool
name: str
description: str
endpoint_name: str


class VertexBuildResult(NamedTuple):
result_dict: "ResultData"
params: str
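
Note on the new TypedDicts: `GraphDump` is declared with `total=False`, so every key is optional for type checkers, which matches how `Graph.dump` only sets `name`, `description` and `endpoint_name` when they are provided. A minimal sketch using `typing.TypedDict` from the standard library (the PR imports from `typing_extensions`). The `*Sketch` classes are hypothetical simplifications and omit `viewport`.

```python
from typing import List, TypedDict


class GraphDataSketch(TypedDict):
    nodes: List[dict]
    edges: List[dict]


class GraphDumpSketch(TypedDict, total=False):
    # total=False makes every key optional for type checkers.
    data: GraphDataSketch
    is_component: bool
    name: str
    description: str
    endpoint_name: str


dump: GraphDumpSketch = {"data": {"nodes": [], "edges": []}, "is_component": False}

# At runtime a TypedDict value is a plain dict; the annotations are not enforced.
print(type(dump) is dict)  # True
print(sorted(GraphDumpSketch.__optional_keys__))
```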
8 changes: 6 additions & 2 deletions src/backend/base/langflow/graph/vertex/base.py
@@ -13,6 +13,7 @@
from langflow.exceptions.component import ComponentBuildException
from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData
from langflow.graph.utils import UnbuiltObject, UnbuiltResult, log_transaction
+ from langflow.graph.vertex.schema import NodeData
from langflow.interface.initialize import loading
from langflow.interface.listing import lazy_load_dict
from langflow.schema.artifact import ArtifactType
@@ -42,7 +43,7 @@ class VertexStates(str, Enum):
class Vertex:
def __init__(
self,
- data: Dict,
+ data: NodeData,
graph: "Graph",
base_type: Optional[str] = None,
is_task: bool = False,
@@ -63,7 +64,7 @@ def __init__(
self.has_external_input = False
self.has_external_output = False
self.graph = graph
- self._data = data
+ self._data = data.copy()
self.base_type: Optional[str] = base_type
self.outputs: List[Dict] = []
self._parse_data()
@@ -101,6 +102,9 @@ def set_input_value(self, name: str, value: Any):
raise ValueError(f"Vertex {self.id} does not have a component instance.")
self._custom_component._set_input_value(name, value)

+ def to_data(self):
+ return self._data
+
def add_component_instance(self, component_instance: "Component"):
component_instance.set_vertex(self)
self._custom_component = component_instance
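
Note on `data.copy()` in this file: assuming `NodeData` is a plain dict at runtime (it is described as a TypedDict in the commit list), `copy()` is a shallow copy, and `to_data()` returns the stored dict itself rather than a fresh copy. The sketch below shows what that means for callers; the node contents are made up for illustration.

```python
node = {"id": "ChatInput-abc12", "data": {"type": "ChatInput"}}

stored = node.copy()  # shallow copy, as in Vertex.__init__

# Rebinding a top-level key on the copy does not affect the caller's dict...
stored["id"] = "renamed"
print(node["id"])  # ChatInput-abc12

# ...but nested objects are still shared between the two dicts.
stored["data"]["type"] = "ChatOutput"
print(node["data"]["type"])  # ChatOutput
```

If full isolation were ever needed, `copy.deepcopy` would be the usual alternative, at the cost of copying the whole nested structure.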