def check_normalizer():
    """Exercise ``normalize_for_cache`` on dynamic objects and a vertex snapshot.

    Two payloads are normalized:

    * a dict holding a dynamically created class, a locally defined function
      and a plain integer;
    * a vertex-shaped dict whose ``built_object`` is a function.

    Every failed expectation is reported as ``ValueError`` with a message
    naming what was wrong; malformed normalizer output (a missing key, or a
    descriptor that is not a dict) is reported the same way instead of
    leaking a ``KeyError``/``TypeError``.

    Raises:
        ValueError: If the integer is altered, if the class or function is not
            replaced by a ``__class_path__`` / ``__callable_path__``
            descriptor, if the snapshot is not marked with
            ``__cache_vertex__``, or if ``built_object`` is not replaced by
            the "unbuilt" placeholder.
    """
    dynamic_type = type("Dynamic", (), {"x": 1})

    def dyn_func():
        return 42

    test_value = 123
    obj = {"cls": dynamic_type, "func": dyn_func, "value": test_value}
    out = normalize_for_cache(obj)

    normalized_value = out.get("value")
    if normalized_value != test_value:
        msg = f"Expected value {test_value}, got {normalized_value}"
        raise ValueError(msg)

    # Descriptors must be dicts; checking the type first keeps a bad result
    # from surfacing as "argument of type ... is not iterable".
    class_descriptor = out.get("cls")
    if not isinstance(class_descriptor, dict) or "__class_path__" not in class_descriptor:
        msg = "Missing __class_path__ in normalized class"
        raise ValueError(msg)

    callable_descriptor = out.get("func")
    if not isinstance(callable_descriptor, dict) or "__callable_path__" not in callable_descriptor:
        msg = "Missing __callable_path__ in normalized function"
        raise ValueError(msg)

    vertex_snapshot = {
        "built": True,
        "results": {"x": 1},
        "artifacts": {},
        "built_object": dyn_func,
        "built_result": {"y": 2},
        "full_data": {"id": "v1"},
    }
    ov = normalize_for_cache(vertex_snapshot)

    # .get() so that an absent marker produces the diagnostic below.
    if ov.get("__cache_vertex__") is not True:
        msg = "Expected __cache_vertex__ to be True"
        raise ValueError(msg)

    built_object = ov.get("built_object")
    if built_object != {"__cache_placeholder__": "unbuilt"}:
        msg = f"Expected built_object placeholder, got {built_object}"
        raise ValueError(msg)
class _AsyncLogger:
    """Async-compatible wrapper over standard logging.Logger.

    Provides awaitable methods used in this module (e.g., aexception, ainfo)
    to avoid attribute errors when lfx logger is not installed. Every other
    attribute is forwarded to the wrapped ``logging.Logger``.
    """

    def __init__(self, base: logging.Logger) -> None:
        self._base = base

    # Pass-through for unknown attributes (sync logging API)
    def __getattr__(self, name: str) -> Any:  # pragma: no cover - thin shim
        # __getattr__ runs only after normal lookup failed. If "_base" itself
        # is what is missing (instance built via __new__, copy or unpickling),
        # reading self._base below would re-enter this method forever.
        if name == "_base":
            raise AttributeError(name)
        return getattr(self._base, name)

    async def aexception(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Awaitable form of ``Logger.exception``; the call itself is synchronous."""
        self._base.exception(msg, *args, **kwargs)

    async def ainfo(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Awaitable form of ``Logger.info``; the call itself is synchronous."""
        self._base.info(msg, *args, **kwargs)
@@ -59,6 +89,19 @@ async def health_check( except Exception: # noqa: BLE001 await logger.aexception("Error checking chat service") + # Check Rich pickle support status + try: + if is_rich_pickle_enabled(): + if validate_rich_pickle_support(): + response.rich_pickle = "ok" + else: + response.rich_pickle = "enabled_but_validation_failed" + else: + response.rich_pickle = "not_enabled" + except Exception: # noqa: BLE001 + await logger.aexception("Error checking Rich pickle support") + response.rich_pickle = "error check the server logs" + if response.has_error(): raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=response.model_dump()) response.status = "ok" diff --git a/src/backend/base/langflow/services/cache/__init__.py b/src/backend/base/langflow/services/cache/__init__.py index 72f74d7dadea..674eeaddc74a 100644 --- a/src/backend/base/langflow/services/cache/__init__.py +++ b/src/backend/base/langflow/services/cache/__init__.py @@ -1,12 +1,17 @@ from langflow.services.cache.service import AsyncInMemoryCache, CacheService, RedisCache, ThreadingInMemoryCache +from langflow.services.cache.utils import is_rich_pickle_enabled, setup_rich_pickle_support from . 
def __init__(self) -> None:
    """Register ``CacheService`` with the base factory and enable Rich pickling.

    The outcome of ``setup_rich_pickle_support()`` is kept on
    ``self._rich_pickle_enabled``. When setup succeeded, the patch is also
    checked with ``validate_rich_pickle_support()``; every outcome is only
    logged, never raised.
    """
    super().__init__(CacheService)

    enabled = setup_rich_pickle_support()
    self._rich_pickle_enabled = enabled

    if not enabled:
        logger.info("Rich pickle support could not be enabled")
        return

    logger.debug("Rich pickle support enabled for cache serialization")

    # Validation is advisory: a failure is reported but does not disable the flag.
    if not validate_rich_pickle_support():
        logger.warning("Rich pickle support validation failed")
        return

    logger.debug("Rich pickle support validation successful")
def _sanitize_for_pickle(self, obj):
    """Return a copy of ``obj`` with hard-to-pickle parts replaced by descriptors.

    Replacement rules, applied while walking dicts, lists, tuples and sets:

    - Any class object becomes ``{"__class_path__": "module.Name"}``, except
      ``lfx.io.schema*.InputSchema`` classes, which become
      ``{"__lfx_skipped_class__": "module.Name"}``.
    - Pydantic instances of such an ``InputSchema`` class become plain data
      (``model_dump()``, or a copy of ``__dict__`` if dumping fails).
    - Functions, methods and builtins become
      ``{"__callable_path__": "module.qualname"}``.
    - Instances whose class lives in ``lfx.custom*``, ``__main__`` or
      ``builtins``, or was defined inside a function (``<locals>`` in its
      qualified name), become ``{"__repr__": repr(value)}``, or
      ``{"__class__": "module.qualname"}`` when ``repr`` fails.
    - Sets become ``{"__set__": [...]}`` because their sanitized members may
      be unhashable dicts.
    - Everything else is returned unchanged.

    Only a container that is reached again while it is still being walked is
    treated as a cycle and replaced by ``{"__cycle__": True}``. Repeated
    scalars and sub-objects shared between siblings are sanitized normally.

    Args:
        obj: The value about to be pickled.

    Returns:
        The sanitized structure; tuples stay tuples, lists stay lists.
    """
    import inspect

    try:
        from pydantic import BaseModel  # type: ignore[import-untyped]

        base_model = BaseModel
    except ImportError:  # Failed to import pydantic
        base_model = None  # type: ignore[assignment]

    # Types that are either walked as containers or returned as-is.
    plain_types = (dict, list, tuple, set, bytes, bytearray, str, int, float, bool, type(None))
    # ids of the containers on the current recursion path (ancestors only).
    in_progress: set[int] = set()

    def _is_input_schema(cls) -> bool:
        """Identify InputSchema by module and name, without subclass checks."""
        mod = getattr(cls, "__module__", "")
        return mod.startswith("lfx.io.schema") and getattr(cls, "__name__", "") == "InputSchema"

    def _walk_container(value):
        """Sanitize the members of a dict/list/tuple/set."""
        if isinstance(value, dict):
            return {k: _walk(v) for k, v in value.items()}
        seq = [_walk(v) for v in value]
        if isinstance(value, tuple):
            return tuple(seq)
        if isinstance(value, set):
            # Sets cannot contain dicts (unhashable) — encode as a list with a marker
            return {"__set__": seq}
        return seq

    def _walk(value):
        # Class objects: never pickle them, keep only a dotted path.
        if isinstance(value, type):
            mod = getattr(value, "__module__", "")
            name = getattr(value, "__name__", "")
            if _is_input_schema(value):
                return {"__lfx_skipped_class__": f"{mod}.{name}"}
            return {"__class_path__": f"{mod}.{name}"}

        # InputSchema instances: reduce to plain data.
        if base_model is not None and isinstance(value, base_model) and _is_input_schema(value.__class__):
            try:
                return value.model_dump()
            except Exception:  # noqa: BLE001
                return dict(value.__dict__)

        # Callables (functions, methods, builtins): keep a path-like hint.
        if inspect.isfunction(value) or inspect.ismethod(value) or inspect.isbuiltin(value):
            mod = getattr(value, "__module__", "")
            name = getattr(value, "__qualname__", getattr(value, "__name__", ""))
            return {"__callable_path__": f"{mod}.{name}"}

        # Containers: recurse, flagging only true self-references.
        if isinstance(value, (dict, list, tuple, set)):
            marker = id(value)
            if marker in in_progress:
                # Lightweight marker instead of the original object, which
                # may be the very thing that cannot be pickled.
                return {"__cycle__": True}
            in_progress.add(marker)
            try:
                return _walk_container(value)
            finally:
                in_progress.discard(marker)

        # Instances of dynamically created or custom component classes that
        # commonly resist pickling.
        if not isinstance(value, plain_types):
            cls = value.__class__
            mod = getattr(cls, "__module__", "")
            qual = getattr(cls, "__qualname__", getattr(cls, "__name__", ""))
            if mod.startswith("lfx.custom") or "<locals>" in qual or mod in {"__main__", "builtins"}:
                # Best-effort shallow representation
                try:
                    return {"__repr__": repr(value)}
                except (AttributeError, TypeError, ValueError, RecursionError):  # repr() can fail
                    return {"__class__": f"{mod}.{qual}"}

        return value

    return _walk(obj)
" - raise TypeError(msg) from exc + # First attempt: try to pickle as-is, suppressing noisy PicklingWarnings + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + # Try ignoring dill's own PicklingWarning if available + try: + from dill._dill import PicklingWarning as _DillPicklingWarning # type: ignore[import-untyped] + + warnings.simplefilter("ignore", category=_DillPicklingWarning) + except ImportError: # dill._dill may not be available + logger.debug("Could not import dill PicklingWarning") + pickled = dill.dumps(value, recurse=False, byref=True) + except (AttributeError, TypeError, ValueError, RecursionError) as e: + # Fallback: sanitize value to strip problematic dynamic schemas + logger.debug("Initial pickle attempt failed: %s, trying sanitized version", e) + sanitized = self._sanitize_for_pickle(value) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + try: + from dill._dill import PicklingWarning as _DillPicklingWarning # type: ignore[import-untyped] + + warnings.simplefilter("ignore", category=_DillPicklingWarning) + except ImportError: + logger.debug("Could not import dill PicklingWarning for sanitized pickle") + pickled = dill.dumps(sanitized, recurse=False, byref=True) + + if pickled: + result = await self._client.setex(str(key), self.expiration_time, pickled) + if not result: + msg = "RedisCache could not set the value." 
# Define CACHE_MISS for compatibility
class CacheMiss:
    """Marker type for the module-level ``CACHE_MISS`` sentinel.

    NOTE(review): presumably cache lookups hand back ``CACHE_MISS`` when a key
    is absent; no caller is visible here — confirm before relying on it.
    """

    def __repr__(self) -> str:
        # A non-empty repr keeps the sentinel recognizable in logs and
        # debugger output; an empty string made it invisible.
        return "<CACHE_MISS>"


CACHE_MISS = CacheMiss()
+ + Returns: + bool: True if setup was successful, False otherwise + """ + try: + from rich.console import Console, ConsoleThreadLocals + + # Check if already setup + if hasattr(ConsoleThreadLocals, "_langflow_pickle_enabled"): + logger.debug("Rich pickle support already enabled") + return True + + # ConsoleThreadLocals pickle methods + def _console_thread_locals_getstate(self) -> dict[str, Any]: + """Serialize ConsoleThreadLocals for caching.""" + return { + "theme_stack": self.theme_stack, + "buffer": self.buffer.copy() if self.buffer else [], + "buffer_index": self.buffer_index, + } + + def _console_thread_locals_setstate(self, state: dict[str, Any]) -> None: + """Restore ConsoleThreadLocals from cached state.""" + self.theme_stack = state["theme_stack"] + self.buffer = state["buffer"] + self.buffer_index = state["buffer_index"] + + # Console pickle methods + def _console_getstate(self) -> dict[str, Any]: + """Serialize Console for caching.""" + state = self.__dict__.copy() + # Remove unpickleable locks and file handles / environment + for key in ( + "_lock", + "_record_buffer_lock", + "_file", + "_stderr", + "_environ", + ): + state.pop(key, None) + return state + + def _console_setstate(self, state: dict[str, Any]) -> None: + """Restore Console from cached state.""" + self.__dict__.update(state) + # Recreate locks + self._lock = threading.RLock() + self._record_buffer_lock = threading.RLock() + + # Apply the methods + ConsoleThreadLocals.__getstate__ = _console_thread_locals_getstate + ConsoleThreadLocals.__setstate__ = _console_thread_locals_setstate + Console.__getstate__ = _console_getstate + Console.__setstate__ = _console_setstate + + # Mark as setup + ConsoleThreadLocals._langflow_pickle_enabled = True + Console._langflow_pickle_enabled = True + + logger.debug("Rich pickle support setup completed - only Rich ConsoleThreadLocals and Console objects patched") + + except ImportError: + logger.debug("Rich library not available, pickle support not enabled") + 
def validate_rich_pickle_support() -> bool:
    """Validate that Rich objects can be pickled successfully.

    A fresh ``rich.console.Console`` is pickled and unpickled inside a small
    dict, then the restored console is asked to print into a capture buffer.
    Validation passes when the printed text appears in the captured output.

    Any failure along the way (Rich missing, pickling or unpickling error, a
    restored console lacking attributes) is logged as a warning and reported
    as ``False``; this function does not raise for those cases.

    Returns:
        bool: True if validation passes, False otherwise
    """
    import pickle

    try:
        from rich.console import Console

        # Test basic serialization
        console = Console()
        test_data = {"console": console, "metadata": {"validator": "langflow_cache", "test_type": "rich_pickle"}}

        # Serialize and deserialize. The bytes are produced on the line above
        # and never come from outside this function.
        pickled = pickle.dumps(test_data)
        restored = pickle.loads(pickled)  # noqa: S301

        # Verify functionality
        restored_console = restored["console"]
        with restored_console.capture() as capture:
            restored_console.print("validation_test")

        validation_passed = "validation_test" in capture.get()
    except (ImportError, AttributeError, TypeError, pickle.PickleError) as e:
        # pickle.PickleError is the common base of PicklingError and
        # UnpicklingError; without it a failed round-trip would propagate to
        # callers that expect a plain bool.
        logger.warning("Rich pickle validation failed: %s", e)
        return False

    if validation_passed:
        logger.debug("Rich pickle validation successful")
    else:
        logger.warning("Rich pickle validation failed - console not functional")
    return validation_passed
+ + Returns: + bool: True if Rich pickle support is enabled, False otherwise + """ + try: + from rich.console import ConsoleThreadLocals + + return hasattr(ConsoleThreadLocals, "_langflow_pickle_enabled") + except ImportError: + return False diff --git a/src/backend/base/langflow/services/chat/service.py b/src/backend/base/langflow/services/chat/service.py index 2f73578eac92..665f6394ae62 100644 --- a/src/backend/base/langflow/services/chat/service.py +++ b/src/backend/base/langflow/services/chat/service.py @@ -3,6 +3,8 @@ from threading import RLock from typing import Any +from lfx.serialization.normalizer import normalize_for_cache + from langflow.services.base import Service from langflow.services.cache.base import AsyncBaseCacheService, CacheService from langflow.services.deps import get_cache_service @@ -29,9 +31,11 @@ async def set_cache(self, key: str, data: Any, lock: asyncio.Lock | None = None) Returns: bool: True if the cache was set successfully, False otherwise. """ + normalized = normalize_for_cache(data) result_dict = { - "result": data, - "type": type(data), + "result": normalized, + "type": "normalized", + "__envelope_version__": 1, } if isinstance(self.cache_service, AsyncBaseCacheService): await self.cache_service.upsert(str(key), result_dict, lock=lock or self.async_cache_locks[key]) diff --git a/src/backend/base/pyproject.toml b/src/backend/base/pyproject.toml index 9f17cf255eee..dd69ade13705 100644 --- a/src/backend/base/pyproject.toml +++ b/src/backend/base/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "langchainhub~=0.1.15", "loguru>=0.7.1,<1.0.0", "structlog>=25.4.0", - "rich>=13.7.0,<14.0.0", + "rich @ git+https://github.com/pkusnail/rich.git@3f2eb2d988fe22e3598542dd1773ae010ea4aacd", "langchain-experimental>=0.3.4,<1.0.0", "sqlmodel==0.0.22", "pydantic~=2.10.1", @@ -133,6 +133,9 @@ dev = [ ] +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel] packages = ["langflow"] diff --git 
class _FakeSyncCache:
    """Dictionary-backed synchronous cache double used by the test below.

    Keys are always converted with ``str`` before touching ``store``; the
    ``lock`` arguments are accepted and ignored.
    """

    def __init__(self):
        # Public on purpose: tests may inspect what was written.
        self.store = {}

    def upsert(self, key, value, lock=None):  # noqa: ARG002
        """Insert or overwrite the entry for ``key``."""
        self.store.update({str(key): value})

    def __contains__(self, key):
        """Support ``key in cache`` using the stringified key."""
        text_key = str(key)
        return text_key in self.store

    def get(self, key, lock=None):  # noqa: ARG002
        """Return the stored value, or ``None`` when the key is unknown."""
        text_key = str(key)
        if text_key in self.store:
            return self.store[text_key]
        return None
dyn_func(): + return 42 + + obj = { + "cls": dynamic_class, + "func": dyn_func, + "value": 123, + } + + out = normalize_for_cache(obj) + assert out["value"] == 123 + assert out["cls"].get("__class_path__") + assert out["func"].get("__callable_path__") + + +def test_normalize_pydantic_model(): + model = create_model("X", a=(int, ...)) + m = model(a=3) + out = normalize_for_cache(m) + assert out == {"a": 3} + + +def test_normalize_vertex_like_dict_replaces_built_object(): + vertex_snapshot = { + "built": True, + "results": {"x": 1}, + "artifacts": {}, + "built_object": lambda x: x, # should never be cached as executable + "built_result": {"y": 2}, + "full_data": {"id": "v1"}, + } + out = normalize_for_cache(vertex_snapshot) + assert out["__cache_vertex__"] is True + assert out["built"] is True + assert out["results"] == {"x": 1} + assert out["artifacts"] == {} + assert out["built_result"] == {"y": 2} + assert out["full_data"] == {"id": "v1"} + assert out["built_object"] == {"__cache_placeholder__": "unbuilt"} diff --git a/src/lfx/src/lfx/graph/graph/base.py b/src/lfx/src/lfx/graph/graph/base.py index 6d2c89e769ce..2a419cd7355d 100644 --- a/src/lfx/src/lfx/graph/graph/base.py +++ b/src/lfx/src/lfx/graph/graph/base.py @@ -30,7 +30,7 @@ should_continue, ) from lfx.graph.schema import InterfaceComponentTypes, RunOutputs -from lfx.graph.utils import log_vertex_build +from lfx.graph.utils import UnbuiltObject, log_vertex_build from lfx.graph.vertex.base import Vertex, VertexStates from lfx.graph.vertex.schema import NodeData, NodeTypeEnum from lfx.graph.vertex.vertex_types import ComponentVertex, InterfaceVertex, StateVertex @@ -1532,13 +1532,27 @@ async def build_vertex( else: try: cached_vertex_dict = cached_result["result"] - # Now set update the vertex with the cached vertex - vertex.built = cached_vertex_dict["built"] - vertex.artifacts = cached_vertex_dict["artifacts"] - vertex.built_object = cached_vertex_dict["built_object"] - vertex.built_result = 
cached_vertex_dict["built_result"] - vertex.full_data = cached_vertex_dict["full_data"] - vertex.results = cached_vertex_dict["results"] + # Support normalized (DTO) vertex snapshots + if isinstance(cached_vertex_dict, dict) and cached_vertex_dict.get("__cache_vertex__"): + vertex.built = cached_vertex_dict.get("built", True) + vertex.artifacts = cached_vertex_dict.get("artifacts", {}) + built_obj = cached_vertex_dict.get("built_object") + if isinstance(built_obj, dict) and built_obj.get("__cache_placeholder__") == "unbuilt": + vertex.built_object = UnbuiltObject() + else: + vertex.built_object = built_obj + vertex.built_result = cached_vertex_dict.get("built_result") + vertex.full_data = cached_vertex_dict.get("full_data", vertex.full_data) + vertex.results = cached_vertex_dict.get("results", {}) + else: + # Backwards compatibility: original shape + # Now set update the vertex with the cached vertex + vertex.built = cached_vertex_dict["built"] + vertex.artifacts = cached_vertex_dict["artifacts"] + vertex.built_object = cached_vertex_dict["built_object"] + vertex.built_result = cached_vertex_dict["built_result"] + vertex.full_data = cached_vertex_dict["full_data"] + vertex.results = cached_vertex_dict["results"] try: vertex.finalize_build() diff --git a/src/lfx/src/lfx/io/schema.py b/src/lfx/src/lfx/io/schema.py index 1c6736d2326c..3ad480e5216f 100644 --- a/src/lfx/src/lfx/io/schema.py +++ b/src/lfx/src/lfx/io/schema.py @@ -238,6 +238,14 @@ def create_input_schema(inputs: list["InputTypes"]) -> type[BaseModel]: # Create and return the InputSchema model model = create_model("InputSchema", **fields) model.model_rebuild() + + # Register class on module to improve importability for serializers + import sys + + current_module = sys.modules[__name__] + model.__module__ = __name__ + current_module.InputSchema = model + return model @@ -286,4 +294,12 @@ def create_input_schema_from_dict(inputs: list[dotdict], param_key: str | None = model = create_model("InputSchema", 
def normalize_for_cache(obj: Any) -> Any:
    """Normalize arbitrary Python objects into cache-safe DTOs.

    Conversion rules, applied in this order:

    - ``str``/``int``/``float``/``bool``/``None``/``bytes``/``bytearray`` are
      returned unchanged.
    - Pydantic models are converted via ``.model_dump()`` (falling back to a
      copy of ``__dict__`` when dumping fails).
    - Classes become ``{"__class_path__": "module.Name"}``.
    - Functions, methods and builtins become
      ``{"__callable_path__": "module.qualname"}``.
    - Iterators and async iterators (generators included) become
      ``{"__non_cacheable__": "generator"}``.
    - Dicts are rebuilt with normalized values. Every ``built_object`` entry
      is replaced by ``{"__cache_placeholder__": "unbuilt"}``; dicts holding
      both ``built`` and ``results`` also get ``"__cache_vertex__": True``.
    - Lists, tuples and sets are normalized element-wise; tuples stay tuples,
      sets become lists.
    - Anything else becomes ``{"__repr__": repr(value)}``, or
      ``{"__class__": "module.qualname"}`` when ``repr`` itself fails.

    Cycle protection tracks only the containers currently being traversed. A
    container reached again from inside itself yields ``{"__cycle__": True}``;
    repeated scalars and sub-objects shared between siblings are normalized
    normally.

    Args:
        obj: The value to normalize.

    Returns:
        A structure built from primitives, dicts, lists and tuples.
    """
    import inspect

    # Resolved here so the function works whether or not pydantic is installed.
    try:
        from pydantic import BaseModel as pydantic_model  # type: ignore[import-untyped]
    except ImportError:  # pragma: no cover
        pydantic_model = None  # type: ignore[assignment]

    primitive_types = (str, int, float, bool, type(None), bytes, bytearray)
    # ids of the containers on the current recursion path (ancestors only).
    in_progress: set[int] = set()

    def _normalize_dict(mapping: dict) -> dict:
        out: dict[Any, Any] = {}
        for key, item in mapping.items():
            if key == "built_object":
                # Never store executable object in cache
                out[key] = {"__cache_placeholder__": "unbuilt"}
            else:
                out[key] = _normalize(item)
        # Treat vertex snapshots specially if recognizable
        if "built" in mapping and "results" in mapping:
            out["__cache_vertex__"] = True
        return out

    def _normalize_sequence(items: Any) -> Any:
        seq = [_normalize(item) for item in items]
        # Sets are emitted as lists: normalized members may be unhashable.
        return tuple(seq) if isinstance(items, tuple) else seq

    def _describe(value: Any) -> dict:
        """Last resort: a shallow, string-only descriptor of ``value``."""
        try:
            return {"__repr__": repr(value)}
        except (AttributeError, TypeError, ValueError, RecursionError):
            cls = value.__class__
            mod = getattr(cls, "__module__", "")
            qual = getattr(cls, "__qualname__", getattr(cls, "__name__", ""))
            return {"__class__": f"{mod}.{qual}"}

    def _normalize(value: Any) -> Any:
        # Primitives
        if isinstance(value, primitive_types):
            return value

        # Pydantic models
        if pydantic_model is not None and isinstance(value, pydantic_model):
            try:
                return value.model_dump()
            except (AttributeError, TypeError, ValueError):
                return dict(getattr(value, "__dict__", {}))

        # Classes
        if isinstance(value, type):
            mod = getattr(value, "__module__", "")
            name = getattr(value, "__name__", "")
            return {"__class_path__": f"{mod}.{name}"}

        # Functions/methods/builtins
        if inspect.isfunction(value) or inspect.ismethod(value) or inspect.isbuiltin(value):
            mod = getattr(value, "__module__", "")
            name = getattr(value, "__qualname__", getattr(value, "__name__", ""))
            return {"__callable_path__": f"{mod}.{name}"}

        # Generators/iterators (non-cacheable)
        if isinstance(value, (Iterator, AsyncIterator)):
            return {"__non_cacheable__": "generator"}

        # Containers: recurse, flagging only true self-references.
        if isinstance(value, (dict, list, tuple, set)):
            marker = id(value)
            if marker in in_progress:
                return {"__cycle__": True}
            in_progress.add(marker)
            try:
                if isinstance(value, dict):
                    return _normalize_dict(value)
                return _normalize_sequence(value)
            finally:
                in_progress.discard(marker)

        # Fallback: dynamic/custom instances or unknown complex objects
        return _describe(value)

    return _normalize(obj)
"https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149, upload-time = "2024-11-01T16:43:57.873Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424, upload-time = "2024-11-01T16:43:55.817Z" }, ] [[package]]