From bc09939f7e1f0cea89646ba602b4a1151a8c1ed0 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sat, 26 Oct 2024 10:30:51 +0200 Subject: [PATCH] fix: Fix async cache (#4265) Fix async cache --- .../base/langflow/services/cache/base.py | 13 +++- .../base/langflow/services/cache/disk.py | 4 +- .../base/langflow/services/cache/service.py | 12 ++-- .../base/langflow/services/chat/service.py | 72 ++++--------------- src/backend/base/langflow/services/deps.py | 4 +- .../base/langflow/services/session/service.py | 11 +-- .../base/langflow/services/socket/service.py | 22 +++--- .../base/langflow/services/socket/utils.py | 4 +- 8 files changed, 59 insertions(+), 83 deletions(-) diff --git a/src/backend/base/langflow/services/cache/base.py b/src/backend/base/langflow/services/cache/base.py index 7e1fcee0359a..dfa391cac13a 100644 --- a/src/backend/base/langflow/services/cache/base.py +++ b/src/backend/base/langflow/services/cache/base.py @@ -59,6 +59,17 @@ def delete(self, key, lock: LockType | None = None): def clear(self, lock: LockType | None = None): """Clear all items from the cache.""" + @abc.abstractmethod + def contains(self, key) -> bool: + """Check if the key is in the cache. + + Args: + key: The key of the item to check. + + Returns: + True if the key is in the cache, False otherwise. + """ + @abc.abstractmethod def __contains__(self, key) -> bool: """Check if the key is in the cache. @@ -147,7 +158,7 @@ async def clear(self, lock: AsyncLockType | None = None): """Clear all items from the cache.""" @abc.abstractmethod - def __contains__(self, key) -> bool: + async def contains(self, key) -> bool: """Check if the key is in the cache. Args: diff --git a/src/backend/base/langflow/services/cache/disk.py b/src/backend/base/langflow/services/cache/disk.py index ed3412e327fb..170f134f7bc8 100644 --- a/src/backend/base/langflow/services/cache/disk.py +++ b/src/backend/base/langflow/services/cache/disk.py @@ -87,8 +87,8 @@ async def _upsert(self, key, value) -> None: value = existing_value await self.set(key, value) - def __contains__(self, key) -> bool: - return asyncio.run(asyncio.to_thread(self.cache.__contains__, key)) + async def contains(self, key) -> bool: + return await asyncio.to_thread(self.cache.__contains__, key) async def teardown(self) -> None: # Clean up the cache directory diff --git a/src/backend/base/langflow/services/cache/service.py b/src/backend/base/langflow/services/cache/service.py index baec3dbd4025..8c1c51ed6da2 100644 --- a/src/backend/base/langflow/services/cache/service.py +++ b/src/backend/base/langflow/services/cache/service.py @@ -139,10 +139,14 @@ def clear(self, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP00 with lock or self._lock: self._cache.clear() - def __contains__(self, key) -> bool: + def contains(self, key) -> bool: """Check if the key is in the cache.""" return key in self._cache + def __contains__(self, key) -> bool: + """Check if the key is in the cache.""" + return self.contains(key) + def __getitem__(self, key): """Retrieve an item from the cache using the square bracket notation.""" return self.get(key) @@ -274,11 +278,11 @@ async def clear(self, lock=None) -> None: """Clear all items from the cache.""" await self._client.flushdb() - def __contains__(self, key) -> bool: + async def contains(self, key) -> bool: """Check if the key is in the cache.""" if key is None: return False - return bool(asyncio.run(self._client.exists(str(key)))) + return bool(await self._client.exists(str(key))) def __repr__(self) -> str: """Return a string representation of the RedisCache instance.""" @@ -364,5 +368,5 @@ async def _upsert(self, key, value) -> None: value = existing_value await self.set(key, value) - def __contains__(self, key) -> bool: + async def contains(self, key) -> bool: return key in self.cache diff --git a/src/backend/base/langflow/services/chat/service.py b/src/backend/base/langflow/services/chat/service.py index b6db599378f3..e9528ea90e8a 100644 --- a/src/backend/base/langflow/services/chat/service.py +++ b/src/backend/base/langflow/services/chat/service.py @@ -4,7 +4,7 @@ from typing import Any from langflow.services.base import Service -from langflow.services.cache.base import AsyncBaseCacheService +from langflow.services.cache.base import AsyncBaseCacheService, CacheService from langflow.services.deps import get_cache_service @@ -16,60 +16,7 @@ class ChatService(Service): def __init__(self) -> None: self.async_cache_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._sync_cache_locks: dict[str, RLock] = defaultdict(RLock) - self.cache_service = get_cache_service() - - def _get_lock(self, key: str): - """Retrieves the lock associated with the given key. - - Args: - key (str): The key to retrieve the lock for. - - Returns: - threading.Lock or asyncio.Lock: The lock associated with the given key. - """ - if isinstance(self.cache_service, AsyncBaseCacheService): - return self.async_cache_locks[key] - return self._sync_cache_locks[key] - - async def _perform_cache_operation( - self, operation: str, key: str, data: Any = None, lock: asyncio.Lock | None = None - ): - """Perform a cache operation based on the given operation type. - - Args: - operation (str): The type of cache operation to perform. Possible values are "upsert", "get", or "delete". - key (str): The key associated with the cache operation. - data (Any, optional): The data to be stored in the cache. Only applicable for "upsert" operation. - Defaults to None. - lock (Optional[asyncio.Lock], optional): The lock to be used for the cache operation. Defaults to None. - - Returns: - Any: The result of the cache operation. Only applicable for "get" operation. - - Raises: - None - - """ - lock = lock or self._get_lock(key) - if isinstance(self.cache_service, AsyncBaseCacheService): - if operation == "upsert": - await self.cache_service.upsert(str(key), data, lock=lock) - return None - if operation == "get": - return await self.cache_service.get(key, lock=lock) - if operation == "delete": - await self.cache_service.delete(key, lock=lock) - return None - return None - if operation == "upsert": - self.cache_service.upsert(str(key), data, lock=lock) - return None - if operation == "get": - return self.cache_service.get(key, lock=lock) - if operation == "delete": - self.cache_service.delete(key, lock=lock) - return None - return None + self.cache_service: CacheService | AsyncBaseCacheService = get_cache_service() async def set_cache(self, key: str, data: Any, lock: asyncio.Lock | None = None) -> bool: """Set the cache for a client. @@ -86,7 +33,12 @@ async def set_cache(self, key: str, data: Any, lock: asyncio.Lock | None = None) "result": data, "type": type(data), } - await self._perform_cache_operation("upsert", key, result_dict, lock) + if isinstance(self.cache_service, AsyncBaseCacheService): + await self.cache_service.upsert(str(key), result_dict, lock=lock or self.async_cache_locks[key]) + return await self.cache_service.contains(key) + await asyncio.to_thread( + self.cache_service.upsert, str(key), result_dict, lock=lock or self._sync_cache_locks[key] + ) return key in self.cache_service async def get_cache(self, key: str, lock: asyncio.Lock | None = None) -> Any: @@ -99,7 +51,9 @@ async def get_cache(self, key: str, lock: asyncio.Lock | None = None) -> Any: Returns: Any: The cached data. """ - return await self._perform_cache_operation("get", key, lock=lock or self._get_lock(key)) + if isinstance(self.cache_service, AsyncBaseCacheService): + return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key]) + return await asyncio.to_thread(self.cache_service.get, key, lock=lock or self._sync_cache_locks[key]) async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None: """Clear the cache for a client. @@ -108,4 +62,6 @@ async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None: key (str): The cache key. lock (Optional[asyncio.Lock], optional): The lock to use for the cache operation. Defaults to None. """ - await self._perform_cache_operation("delete", key, lock=lock or self._get_lock(key)) + if isinstance(self.cache_service, AsyncBaseCacheService): + return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key]) + return await asyncio.to_thread(self.cache_service.delete, key, lock=lock or self._sync_cache_locks[key]) diff --git a/src/backend/base/langflow/services/deps.py b/src/backend/base/langflow/services/deps.py index 01bd841d5dcb..40152a0175bc 100644 --- a/src/backend/base/langflow/services/deps.py +++ b/src/backend/base/langflow/services/deps.py @@ -12,7 +12,7 @@ from sqlmodel import Session - from langflow.services.cache.service import CacheService + from langflow.services.cache.service import AsyncBaseCacheService, CacheService from langflow.services.chat.service import ChatService from langflow.services.database.service import DatabaseService from langflow.services.plugins.service import PluginService @@ -188,7 +188,7 @@ def session_scope() -> Generator[Session, None, None]: raise -def get_cache_service() -> CacheService: +def get_cache_service() -> CacheService | AsyncBaseCacheService: """Retrieves the cache service from the service manager. Returns: diff --git a/src/backend/base/langflow/services/session/service.py b/src/backend/base/langflow/services/session/service.py index abd4e630a78f..3c8106a9219d 100644 --- a/src/backend/base/langflow/services/session/service.py +++ b/src/backend/base/langflow/services/session/service.py @@ -16,11 +16,12 @@ def __init__(self, cache_service) -> None: async def load_session(self, key, flow_id: str, data_graph: dict | None = None): # Check if the data is cached - if key in self.cache_service: - result = self.cache_service.get(key) - if isinstance(result, Coroutine): - result = await result - return result + is_cached = self.cache_service.contains(key) + if isinstance(is_cached, Coroutine): + if await is_cached: + return await self.cache_service.get(key) + elif is_cached: + return self.cache_service.get(key) if key is None: key = self.generate_key(session_id=None, data_graph=data_graph) diff --git a/src/backend/base/langflow/services/socket/service.py b/src/backend/base/langflow/services/socket/service.py index d4ec60579fb9..8f6e44000858 100644 --- a/src/backend/base/langflow/services/socket/service.py +++ b/src/backend/base/langflow/services/socket/service.py @@ -1,20 +1,18 @@ -from typing import TYPE_CHECKING, Any +from typing import Any import socketio from loguru import logger from langflow.services.base import Service +from langflow.services.cache.base import AsyncBaseCacheService, CacheService from langflow.services.deps import get_chat_service from langflow.services.socket.utils import build_vertex, get_vertices -if TYPE_CHECKING: - from langflow.services.cache.service import CacheService - class SocketIOService(Service): name = "socket_service" - def __init__(self, cache_service: "CacheService"): + def __init__(self, cache_service: CacheService | AsyncBaseCacheService): self.cache_service = cache_service def init(self, sio: socketio.AsyncServer) -> None: @@ -63,11 +61,14 @@ async def on_build_vertex(self, sid, flow_id, vertex_id) -> None: set_cache=self.set_cache, ) - def get_cache(self, sid: str) -> Any: + async def get_cache(self, sid: str) -> Any: """Get the cache for a client.""" - return self.cache_service.get(sid) + value = self.cache_service.get(sid) + if isinstance(self.cache_service, AsyncBaseCacheService): + return await value + return value - def set_cache(self, sid: str, build_result: Any) -> bool: + async def set_cache(self, sid: str, build_result: Any) -> bool: """Set the cache for a client.""" # client_id is the flow id but that already exists in the cache # so we need to change it to something else @@ -76,5 +77,8 @@ def set_cache(self, sid: str, build_result: Any) -> bool: "result": build_result, "type": type(build_result), } - self.cache_service.upsert(sid, result_dict) + result = self.cache_service.upsert(sid, result_dict) + if isinstance(self.cache_service, AsyncBaseCacheService): + await result + return await self.cache_service.contains(sid) return sid in self.cache_service diff --git a/src/backend/base/langflow/services/socket/utils.py b/src/backend/base/langflow/services/socket/utils.py index d754638014cb..723acd48dd5f 100644 --- a/src/backend/base/langflow/services/socket/utils.py +++ b/src/backend/base/langflow/services/socket/utils.py @@ -50,7 +50,7 @@ async def build_vertex( set_cache: Callable, ) -> None: try: - cache = get_cache(flow_id) + cache = await get_cache(flow_id) graph = cache.get("result") if not isinstance(graph, Graph): @@ -86,7 +86,7 @@ async def build_vertex( valid = False result_dict = ResultDataResponse(results={}) artifacts = {} - set_cache(flow_id, graph) + await set_cache(flow_id, graph) log_vertex_build( flow_id=flow_id, vertex_id=vertex_id,