Skip to content

Commit

Permalink
fix: Fix async cache (langflow-ai#4265)
Browse files Browse the repository at this point in the history
Fix async cache
  • Loading branch information
cbornet authored and diogocabral committed Nov 26, 2024
1 parent 2e2e7c5 commit bc09939
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 83 deletions.
13 changes: 12 additions & 1 deletion src/backend/base/langflow/services/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ def delete(self, key, lock: LockType | None = None):
def clear(self, lock: LockType | None = None):
"""Clear all items from the cache."""

@abc.abstractmethod
def contains(self, key) -> bool:
"""Check if the key is in the cache.
Args:
key: The key of the item to check.
Returns:
True if the key is in the cache, False otherwise.
"""

@abc.abstractmethod
def __contains__(self, key) -> bool:
"""Check if the key is in the cache.
Expand Down Expand Up @@ -147,7 +158,7 @@ async def clear(self, lock: AsyncLockType | None = None):
"""Clear all items from the cache."""

@abc.abstractmethod
def __contains__(self, key) -> bool:
async def contains(self, key) -> bool:
"""Check if the key is in the cache.
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/services/cache/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ async def _upsert(self, key, value) -> None:
value = existing_value
await self.set(key, value)

def __contains__(self, key) -> bool:
return asyncio.run(asyncio.to_thread(self.cache.__contains__, key))
async def contains(self, key) -> bool:
return await asyncio.to_thread(self.cache.__contains__, key)

async def teardown(self) -> None:
# Clean up the cache directory
Expand Down
12 changes: 8 additions & 4 deletions src/backend/base/langflow/services/cache/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,14 @@ def clear(self, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP00
with lock or self._lock:
self._cache.clear()

def __contains__(self, key) -> bool:
def contains(self, key) -> bool:
"""Check if the key is in the cache."""
return key in self._cache

def __contains__(self, key) -> bool:
"""Check if the key is in the cache."""
return self.contains(key)

def __getitem__(self, key):
"""Retrieve an item from the cache using the square bracket notation."""
return self.get(key)
Expand Down Expand Up @@ -274,11 +278,11 @@ async def clear(self, lock=None) -> None:
"""Clear all items from the cache."""
await self._client.flushdb()

def __contains__(self, key) -> bool:
async def contains(self, key) -> bool:
"""Check if the key is in the cache."""
if key is None:
return False
return bool(asyncio.run(self._client.exists(str(key))))
return bool(await self._client.exists(str(key)))

def __repr__(self) -> str:
"""Return a string representation of the RedisCache instance."""
Expand Down Expand Up @@ -364,5 +368,5 @@ async def _upsert(self, key, value) -> None:
value = existing_value
await self.set(key, value)

def __contains__(self, key) -> bool:
async def contains(self, key) -> bool:
return key in self.cache
72 changes: 14 additions & 58 deletions src/backend/base/langflow/services/chat/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from langflow.services.base import Service
from langflow.services.cache.base import AsyncBaseCacheService
from langflow.services.cache.base import AsyncBaseCacheService, CacheService
from langflow.services.deps import get_cache_service


Expand All @@ -16,60 +16,7 @@ class ChatService(Service):
def __init__(self) -> None:
self.async_cache_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._sync_cache_locks: dict[str, RLock] = defaultdict(RLock)
self.cache_service = get_cache_service()

def _get_lock(self, key: str):
"""Retrieves the lock associated with the given key.
Args:
key (str): The key to retrieve the lock for.
Returns:
threading.Lock or asyncio.Lock: The lock associated with the given key.
"""
if isinstance(self.cache_service, AsyncBaseCacheService):
return self.async_cache_locks[key]
return self._sync_cache_locks[key]

async def _perform_cache_operation(
self, operation: str, key: str, data: Any = None, lock: asyncio.Lock | None = None
):
"""Perform a cache operation based on the given operation type.
Args:
operation (str): The type of cache operation to perform. Possible values are "upsert", "get", or "delete".
key (str): The key associated with the cache operation.
data (Any, optional): The data to be stored in the cache. Only applicable for "upsert" operation.
Defaults to None.
lock (Optional[asyncio.Lock], optional): The lock to be used for the cache operation. Defaults to None.
Returns:
Any: The result of the cache operation. Only applicable for "get" operation.
Raises:
None
"""
lock = lock or self._get_lock(key)
if isinstance(self.cache_service, AsyncBaseCacheService):
if operation == "upsert":
await self.cache_service.upsert(str(key), data, lock=lock)
return None
if operation == "get":
return await self.cache_service.get(key, lock=lock)
if operation == "delete":
await self.cache_service.delete(key, lock=lock)
return None
return None
if operation == "upsert":
self.cache_service.upsert(str(key), data, lock=lock)
return None
if operation == "get":
return self.cache_service.get(key, lock=lock)
if operation == "delete":
self.cache_service.delete(key, lock=lock)
return None
return None
self.cache_service: CacheService | AsyncBaseCacheService = get_cache_service()

async def set_cache(self, key: str, data: Any, lock: asyncio.Lock | None = None) -> bool:
"""Set the cache for a client.
Expand All @@ -86,7 +33,12 @@ async def set_cache(self, key: str, data: Any, lock: asyncio.Lock | None = None)
"result": data,
"type": type(data),
}
await self._perform_cache_operation("upsert", key, result_dict, lock)
if isinstance(self.cache_service, AsyncBaseCacheService):
await self.cache_service.upsert(str(key), result_dict, lock=lock or self.async_cache_locks[key])
return await self.cache_service.contains(key)
await asyncio.to_thread(
self.cache_service.upsert, str(key), result_dict, lock=lock or self._sync_cache_locks[key]
)
return key in self.cache_service

async def get_cache(self, key: str, lock: asyncio.Lock | None = None) -> Any:
Expand All @@ -99,7 +51,9 @@ async def get_cache(self, key: str, lock: asyncio.Lock | None = None) -> Any:
Returns:
Any: The cached data.
"""
return await self._perform_cache_operation("get", key, lock=lock or self._get_lock(key))
if isinstance(self.cache_service, AsyncBaseCacheService):
return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key])
return await asyncio.to_thread(self.cache_service.get, key, lock=lock or self._sync_cache_locks[key])

async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None:
"""Clear the cache for a client.
Expand All @@ -108,4 +62,6 @@ async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None:
key (str): The cache key.
lock (Optional[asyncio.Lock], optional): The lock to use for the cache operation. Defaults to None.
"""
await self._perform_cache_operation("delete", key, lock=lock or self._get_lock(key))
if isinstance(self.cache_service, AsyncBaseCacheService):
return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key])
return await asyncio.to_thread(self.cache_service.delete, key, lock=lock or self._sync_cache_locks[key])
4 changes: 2 additions & 2 deletions src/backend/base/langflow/services/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sqlmodel import Session

from langflow.services.cache.service import CacheService
from langflow.services.cache.service import AsyncBaseCacheService, CacheService
from langflow.services.chat.service import ChatService
from langflow.services.database.service import DatabaseService
from langflow.services.plugins.service import PluginService
Expand Down Expand Up @@ -188,7 +188,7 @@ def session_scope() -> Generator[Session, None, None]:
raise


def get_cache_service() -> CacheService:
def get_cache_service() -> CacheService | AsyncBaseCacheService:
"""Retrieves the cache service from the service manager.
Returns:
Expand Down
11 changes: 6 additions & 5 deletions src/backend/base/langflow/services/session/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ def __init__(self, cache_service) -> None:

async def load_session(self, key, flow_id: str, data_graph: dict | None = None):
# Check if the data is cached
if key in self.cache_service:
result = self.cache_service.get(key)
if isinstance(result, Coroutine):
result = await result
return result
is_cached = self.cache_service.contains(key)
if isinstance(is_cached, Coroutine):
if await is_cached:
return await self.cache_service.get(key)
elif is_cached:
return self.cache_service.get(key)

if key is None:
key = self.generate_key(session_id=None, data_graph=data_graph)
Expand Down
22 changes: 13 additions & 9 deletions src/backend/base/langflow/services/socket/service.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from typing import TYPE_CHECKING, Any
from typing import Any

import socketio
from loguru import logger

from langflow.services.base import Service
from langflow.services.cache.base import AsyncBaseCacheService, CacheService
from langflow.services.deps import get_chat_service
from langflow.services.socket.utils import build_vertex, get_vertices

if TYPE_CHECKING:
from langflow.services.cache.service import CacheService


class SocketIOService(Service):
name = "socket_service"

def __init__(self, cache_service: "CacheService"):
def __init__(self, cache_service: CacheService | AsyncBaseCacheService):
self.cache_service = cache_service

def init(self, sio: socketio.AsyncServer) -> None:
Expand Down Expand Up @@ -63,11 +61,14 @@ async def on_build_vertex(self, sid, flow_id, vertex_id) -> None:
set_cache=self.set_cache,
)

def get_cache(self, sid: str) -> Any:
async def get_cache(self, sid: str) -> Any:
"""Get the cache for a client."""
return self.cache_service.get(sid)
value = self.cache_service.get(sid)
if isinstance(self.cache_service, AsyncBaseCacheService):
return await value
return value

def set_cache(self, sid: str, build_result: Any) -> bool:
async def set_cache(self, sid: str, build_result: Any) -> bool:
"""Set the cache for a client."""
# client_id is the flow id but that already exists in the cache
# so we need to change it to something else
Expand All @@ -76,5 +77,8 @@ def set_cache(self, sid: str, build_result: Any) -> bool:
"result": build_result,
"type": type(build_result),
}
self.cache_service.upsert(sid, result_dict)
result = self.cache_service.upsert(sid, result_dict)
if isinstance(self.cache_service, AsyncBaseCacheService):
await result
return await self.cache_service.contains(sid)
return sid in self.cache_service
4 changes: 2 additions & 2 deletions src/backend/base/langflow/services/socket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def build_vertex(
set_cache: Callable,
) -> None:
try:
cache = get_cache(flow_id)
cache = await get_cache(flow_id)
graph = cache.get("result")

if not isinstance(graph, Graph):
Expand Down Expand Up @@ -86,7 +86,7 @@ async def build_vertex(
valid = False
result_dict = ResultDataResponse(results={})
artifacts = {}
set_cache(flow_id, graph)
await set_cache(flow_id, graph)
log_vertex_build(
flow_id=flow_id,
vertex_id=vertex_id,
Expand Down

0 comments on commit bc09939

Please sign in to comment.