Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: enhance CustomComponent class and updates tests #3201

Merged
merged 12 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,6 @@ async def build_vertex(
artifacts = vertex_build_result.artifacts
next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)

result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
except Exception as exc:
if isinstance(exc, ComponentBuildException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def build_tool(self) -> Tool:
return_direct=self.return_direct,
inputs=inputs,
flow_id=str(flow_data.id),
user_id=str(self._user_id),
user_id=str(self.user_id),
)
description_repr = repr(tool.description).strip("'")
args_str = "\n".join([f"- {arg_name}: {arg_data['description']}" for arg_name, arg_data in tool.args.items()])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import operator
import warnings
from typing import Any, ClassVar, Optional
from uuid import UUID
import warnings

from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
from langflow.utils import validate

if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackHandler

from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.services.storage.service import StorageService
from langflow.services.tracing.service import TracingService
from langchain.callbacks.base import BaseCallbackHandler


class CustomComponent(BaseComponent):
Expand Down Expand Up @@ -74,7 +75,7 @@ class CustomComponent(BaseComponent):
"""The build parameters of the component. Defaults to None."""
_vertex: Optional["Vertex"] = None
"""The edge target parameter of the component. Defaults to None."""
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
_code_class_base_inheritance: ClassVar[str] = "CustomComponent"
function_entrypoint_name: ClassVar[str] = "build"
function: Optional[Callable] = None
repr_value: Optional[Any] = ""
Expand All @@ -85,6 +86,20 @@ class CustomComponent(BaseComponent):
_logs: List[Log] = []
_output_logs: dict[str, Log] = {}
_tracing_service: Optional["TracingService"] = None
_tree: Optional[dict] = None

def __init__(self, **data):
"""
Initializes a new instance of the CustomComponent class.

Args:
**data: Additional keyword arguments to initialize the custom component.
"""
self.cache = TTLCache(maxsize=1024, ttl=60)
self._logs = []
self._results = {}
self._artifacts = {}
super().__init__(**data)

def set_attributes(self, parameters: dict):
pass
Expand Down Expand Up @@ -133,19 +148,6 @@ def get_state(self, name: str):
except Exception as e:
raise ValueError(f"Error getting state: {e}")

_tree: Optional[dict] = None

def __init__(self, **data):
"""
Initializes a new instance of the CustomComponent class.

Args:
**data: Additional keyword arguments to initialize the custom component.
"""
self.cache = TTLCache(maxsize=1024, ttl=60)
self._logs = []
super().__init__(**data)

@staticmethod
def resolve_path(path: str) -> str:
"""Resolves the path to an absolute path."""
Expand All @@ -169,6 +171,20 @@ def get_full_path(self, path: str) -> str:
def graph(self):
return self._vertex.graph

@property
def user_id(self):
if hasattr(self, "_user_id"):
return self._user_id
return self.graph.user_id

@property
def flow_id(self):
return self.graph.flow_id

@property
def flow_name(self):
return self.graph.flow_name

def _get_field_order(self):
return self.field_order or list(self.field_config.keys())

Expand Down Expand Up @@ -305,7 +321,7 @@ def get_function_entrypoint_args(self) -> list:
Returns:
list: The arguments of the function entrypoint.
"""
build_method = self.get_method(self.function_entrypoint_name)
build_method = self.get_method(self._function_entrypoint_name)
if not build_method:
return []

Expand Down Expand Up @@ -346,9 +362,9 @@ def get_function_entrypoint_return_type(self) -> List[Any]:
Returns:
List[Any]: The return type of the function entrypoint.
"""
return self.get_method_return_type(self.function_entrypoint_name)
return self.get_method_return_type(self._function_entrypoint_name)

def _extract_return_type(self, return_type: Any):
def _extract_return_type(self, return_type: Any) -> List[Any]:
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
list,
List,
Expand All @@ -374,8 +390,8 @@ def get_main_class_name(self):
if not self._code:
return ""

base_name = self.code_class_base_inheritance
method_name = self.function_entrypoint_name
base_name = self._code_class_base_inheritance
method_name = self._function_entrypoint_name

classes = []
for item in self.tree.get("classes", []):
Expand Down Expand Up @@ -412,12 +428,12 @@ def variables(self):
"""

def get_variable(name: str, field: str):
if hasattr(self, "_user_id") and not self._user_id:
if hasattr(self, "_user_id") and not self.user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
variable_service = get_variable_service() # Get service instance
# Retrieve and decrypt the variable by name for the current user
with session_scope() as session:
user_id = self._user_id or ""
user_id = self.user_id or ""
return variable_service.get_variable(user_id=user_id, name=name, field=field, session=session)

return get_variable
Expand All @@ -432,12 +448,12 @@ def list_key_names(self):
Returns:
List[str]: The names of the variables for the current user.
"""
if hasattr(self, "_user_id") and not self._user_id:
if hasattr(self, "_user_id") and not self.user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
variable_service = get_variable_service()

with session_scope() as session:
return variable_service.list_variables(user_id=self._user_id, session=session)
return variable_service.list_variables(user_id=self.user_id, session=session)

def index(self, value: int = 0):
"""
Expand All @@ -462,10 +478,10 @@ def get_function(self):
Returns:
Callable: The function associated with the custom component.
"""
return validate.create_function(self._code, self.function_entrypoint_name)
return validate.create_function(self._code, self._function_entrypoint_name)

async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
if not self._user_id:
if not self.user_id:
raise ValueError("Session is invalid")
return await load_flow(user_id=str(self._user_id), flow_id=flow_id, tweaks=tweaks)

Expand All @@ -487,7 +503,7 @@ async def run_flow(
)

def list_flows(self) -> List[Data]:
if not self._user_id:
if not self.user_id:
raise ValueError("Session is invalid")
try:
return list_flows(user_id=str(self._user_id))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from loguru import logger

from langflow.custom import CustomComponent
from langflow.custom import Component


class CustomComponentPathValueError(ValueError):
Expand Down Expand Up @@ -373,7 +373,7 @@ def get_output_types_from_code(code: str) -> list:
"""
Get the output types from the code.
"""
custom_component = CustomComponent(_code=code)
custom_component = Component(_code=code)
types_list = custom_component.get_function_entrypoint_return_type

# Get the name of types classes
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/custom/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ def build_custom_component_template_from_inputs(
frontend_node.validate_component()
# ! This should be removed when we have a better way to handle this
frontend_node.set_base_classes_from_outputs()
reorder_fields(frontend_node, custom_component._get_field_order())
cc_instance = get_component_instance(custom_component, user_id=user_id)
reorder_fields(frontend_node, cc_instance._get_field_order())

return frontend_node.to_dict(keep_name=False), cc_instance


Expand Down
12 changes: 9 additions & 3 deletions src/backend/base/langflow/schema/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@


def _timestamp_to_str(timestamp: datetime | str) -> str:
if isinstance(timestamp, datetime):
return timestamp.strftime("%Y-%m-%d %H:%M:%S")
return timestamp
if isinstance(timestamp, str):
# Just check if the string is a valid datetime
try:
datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
return timestamp
except ValueError:
raise ValueError(f"Invalid timestamp: {timestamp}")
return timestamp.strftime("%Y-%m-%d %H:%M:%S")


class Message(Data):
Expand Down Expand Up @@ -163,6 +168,7 @@ def sync_get_file_content_dicts(self):
loop = asyncio.get_event_loop()
return loop.run_until_complete(coro)

# Keep this async method for backwards compatibility
async def get_file_content_dicts(self):
content_dicts = []
files = await get_file_paths(self.files)
Expand Down
Loading
Loading