Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix async cache #4265

Merged
merged 2 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/backend/base/langflow/services/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ def delete(self, key, lock: LockType | None = None):
def clear(self, lock: LockType | None = None):
"""Clear all items from the cache."""

@abc.abstractmethod
def contains(self, key) -> bool:
    """Return whether an entry for *key* currently exists in the cache.

    Args:
        key: Cache key to look up.

    Returns:
        ``True`` when the key has a cached entry, ``False`` otherwise.
    """

@abc.abstractmethod
def __contains__(self, key) -> bool:
"""Check if the key is in the cache.
Expand Down Expand Up @@ -147,7 +158,7 @@ async def clear(self, lock: AsyncLockType | None = None):
"""Clear all items from the cache."""

@abc.abstractmethod
def __contains__(self, key) -> bool:
async def contains(self, key) -> bool:
"""Check if the key is in the cache.

Args:
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/services/cache/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ async def _upsert(self, key, value) -> None:
value = existing_value
await self.set(key, value)

async def contains(self, key) -> bool:
    """Check membership in the backing disk cache without blocking the event loop."""
    def _check() -> bool:
        # Disk-cache membership is a blocking call; run it in a worker thread.
        return key in self.cache

    return await asyncio.to_thread(_check)

async def teardown(self) -> None:
# Clean up the cache directory
Expand Down
12 changes: 8 additions & 4 deletions src/backend/base/langflow/services/cache/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,14 @@ def clear(self, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP00
with lock or self._lock:
self._cache.clear()

def contains(self, key) -> bool:
    """Report whether an entry for *key* is present in the in-memory store."""
    found = key in self._cache
    return found

def __contains__(self, key) -> bool:
    """Support the ``key in cache`` membership test by delegating to :meth:`contains`."""
    return self.contains(key)

def __getitem__(self, key):
"""Retrieve an item from the cache using the square bracket notation."""
return self.get(key)
Expand Down Expand Up @@ -274,11 +278,11 @@ async def clear(self, lock=None) -> None:
"""Clear all items from the cache."""
await self._client.flushdb()

async def contains(self, key) -> bool:
    """Return whether Redis holds an entry for *key*.

    ``None`` is never used as a cache key, so it short-circuits to ``False``
    without a round-trip to the server.
    """
    if key is None:
        return False
    # EXISTS returns the number of matching keys; coerce to a plain bool.
    count = await self._client.exists(str(key))
    return bool(count)

def __repr__(self) -> str:
"""Return a string representation of the RedisCache instance."""
Expand Down Expand Up @@ -364,5 +368,5 @@ async def _upsert(self, key, value) -> None:
value = existing_value
await self.set(key, value)

async def contains(self, key) -> bool:
    """Check whether *key* is present in the in-memory store.

    The lookup itself is synchronous; the method is ``async`` only to satisfy
    the async cache interface.
    """
    present = key in self.cache
    return present
72 changes: 14 additions & 58 deletions src/backend/base/langflow/services/chat/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from langflow.services.base import Service
from langflow.services.cache.base import AsyncBaseCacheService
from langflow.services.cache.base import AsyncBaseCacheService, CacheService
from langflow.services.deps import get_cache_service


Expand All @@ -16,60 +16,7 @@ class ChatService(Service):
def __init__(self) -> None:
self.async_cache_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._sync_cache_locks: dict[str, RLock] = defaultdict(RLock)
self.cache_service = get_cache_service()

def _get_lock(self, key: str):
"""Retrieves the lock associated with the given key.

Args:
key (str): The key to retrieve the lock for.

Returns:
threading.Lock or asyncio.Lock: The lock associated with the given key.
"""
if isinstance(self.cache_service, AsyncBaseCacheService):
return self.async_cache_locks[key]
return self._sync_cache_locks[key]

async def _perform_cache_operation(
self, operation: str, key: str, data: Any = None, lock: asyncio.Lock | None = None
):
"""Perform a cache operation based on the given operation type.

Args:
operation (str): The type of cache operation to perform. Possible values are "upsert", "get", or "delete".
key (str): The key associated with the cache operation.
data (Any, optional): The data to be stored in the cache. Only applicable for "upsert" operation.
Defaults to None.
lock (Optional[asyncio.Lock], optional): The lock to be used for the cache operation. Defaults to None.

Returns:
Any: The result of the cache operation. Only applicable for "get" operation.

Raises:
None

"""
lock = lock or self._get_lock(key)
if isinstance(self.cache_service, AsyncBaseCacheService):
if operation == "upsert":
await self.cache_service.upsert(str(key), data, lock=lock)
return None
if operation == "get":
return await self.cache_service.get(key, lock=lock)
if operation == "delete":
await self.cache_service.delete(key, lock=lock)
return None
return None
if operation == "upsert":
self.cache_service.upsert(str(key), data, lock=lock)
return None
if operation == "get":
return self.cache_service.get(key, lock=lock)
if operation == "delete":
self.cache_service.delete(key, lock=lock)
return None
return None
self.cache_service: CacheService | AsyncBaseCacheService = get_cache_service()

async def set_cache(self, key: str, data: Any, lock: asyncio.Lock | None = None) -> bool:
"""Set the cache for a client.
Expand All @@ -86,7 +33,12 @@ async def set_cache(self, key: str, data: Any, lock: asyncio.Lock | None = None)
"result": data,
"type": type(data),
}
await self._perform_cache_operation("upsert", key, result_dict, lock)
if isinstance(self.cache_service, AsyncBaseCacheService):
await self.cache_service.upsert(str(key), result_dict, lock=lock or self.async_cache_locks[key])
return await self.cache_service.contains(key)
await asyncio.to_thread(
self.cache_service.upsert, str(key), result_dict, lock=lock or self._sync_cache_locks[key]
)
return key in self.cache_service

async def get_cache(self, key: str, lock: asyncio.Lock | None = None) -> Any:
Expand All @@ -99,7 +51,9 @@ async def get_cache(self, key: str, lock: asyncio.Lock | None = None) -> Any:
Returns:
Any: The cached data.
"""
return await self._perform_cache_operation("get", key, lock=lock or self._get_lock(key))
if isinstance(self.cache_service, AsyncBaseCacheService):
return await self.cache_service.get(key, lock=lock or self.async_cache_locks[key])
return await asyncio.to_thread(self.cache_service.get, key, lock=lock or self._sync_cache_locks[key])

async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None:
    """Clear the cache for a client.

    Args:
        key (str): The cache key.
        lock (Optional[asyncio.Lock], optional): The lock to use for the cache operation. Defaults to None.
    """
    if isinstance(self.cache_service, AsyncBaseCacheService):
        # Bug fix: this branch previously called ``get`` instead of ``delete``,
        # so async cache entries were never actually cleared.
        await self.cache_service.delete(key, lock=lock or self.async_cache_locks[key])
        return
    # Sync cache: run the blocking delete in a worker thread.
    await asyncio.to_thread(self.cache_service.delete, key, lock=lock or self._sync_cache_locks[key])
4 changes: 2 additions & 2 deletions src/backend/base/langflow/services/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sqlmodel import Session

from langflow.services.cache.service import CacheService
from langflow.services.cache.service import AsyncBaseCacheService, CacheService
from langflow.services.chat.service import ChatService
from langflow.services.database.service import DatabaseService
from langflow.services.plugins.service import PluginService
Expand Down Expand Up @@ -188,7 +188,7 @@ def session_scope() -> Generator[Session, None, None]:
raise


def get_cache_service() -> CacheService:
def get_cache_service() -> CacheService | AsyncBaseCacheService:
"""Retrieves the cache service from the service manager.

Returns:
Expand Down
11 changes: 6 additions & 5 deletions src/backend/base/langflow/services/session/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ def __init__(self, cache_service) -> None:

async def load_session(self, key, flow_id: str, data_graph: dict | None = None):
# Check if the data is cached
if key in self.cache_service:
result = self.cache_service.get(key)
if isinstance(result, Coroutine):
result = await result
return result
is_cached = self.cache_service.contains(key)
if isinstance(is_cached, Coroutine):
if await is_cached:
return await self.cache_service.get(key)
elif is_cached:
return self.cache_service.get(key)

if key is None:
key = self.generate_key(session_id=None, data_graph=data_graph)
Expand Down
22 changes: 13 additions & 9 deletions src/backend/base/langflow/services/socket/service.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from typing import TYPE_CHECKING, Any
from typing import Any

import socketio
from loguru import logger

from langflow.services.base import Service
from langflow.services.cache.base import AsyncBaseCacheService, CacheService
from langflow.services.deps import get_chat_service
from langflow.services.socket.utils import build_vertex, get_vertices

if TYPE_CHECKING:
from langflow.services.cache.service import CacheService


class SocketIOService(Service):
name = "socket_service"

def __init__(self, cache_service: "CacheService"):
def __init__(self, cache_service: CacheService | AsyncBaseCacheService):
self.cache_service = cache_service

def init(self, sio: socketio.AsyncServer) -> None:
Expand Down Expand Up @@ -63,11 +61,14 @@ async def on_build_vertex(self, sid, flow_id, vertex_id) -> None:
set_cache=self.set_cache,
)

async def get_cache(self, sid: str) -> Any:
    """Get the cache for a client."""
    # Async cache services return a coroutine from ``get``; sync ones return
    # the value directly. Await only in the async case.
    if isinstance(self.cache_service, AsyncBaseCacheService):
        return await self.cache_service.get(sid)
    return self.cache_service.get(sid)

async def set_cache(self, sid: str, build_result: Any) -> bool:
    """Set the cache for a client."""
    # client_id is the flow id but that already exists in the cache
    # so we need to change it to something else
    result_dict = {
        "result": build_result,
        "type": type(build_result),
    }
    result = self.cache_service.upsert(sid, result_dict)
    if isinstance(self.cache_service, AsyncBaseCacheService):
        # ``upsert`` returned a coroutine for async services; await it, then
        # confirm the entry landed via the async membership check.
        await result
        return await self.cache_service.contains(sid)
    return sid in self.cache_service
4 changes: 2 additions & 2 deletions src/backend/base/langflow/services/socket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def build_vertex(
set_cache: Callable,
) -> None:
try:
cache = get_cache(flow_id)
cache = await get_cache(flow_id)
graph = cache.get("result")

if not isinstance(graph, Graph):
Expand Down Expand Up @@ -86,7 +86,7 @@ async def build_vertex(
valid = False
result_dict = ResultDataResponse(results={})
artifacts = {}
set_cache(flow_id, graph)
await set_cache(flow_id, graph)
log_vertex_build(
flow_id=flow_id,
vertex_id=vertex_id,
Expand Down
Loading