Skip to content

Commit

Permalink
let users pick state manager mode (#4041)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lendemor authored Oct 10, 2024
1 parent 1aed39a commit 6f586c8
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 14 deletions.
16 changes: 16 additions & 0 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union

from reflex.utils.exceptions import ConfigError

try:
import pydantic.v1 as pydantic
except ModuleNotFoundError:
Expand Down Expand Up @@ -220,6 +222,9 @@ class Config:
# Number of gunicorn workers from user
gunicorn_workers: Optional[int] = None

# Indicate which type of state manager to use
state_manager_mode: constants.StateManagerMode = constants.StateManagerMode.DISK

# Maximum expiration lock time for redis state manager
redis_lock_expiration: int = constants.Expiration.LOCK

Expand All @@ -235,6 +240,9 @@ def __init__(self, *args, **kwargs):
Args:
*args: The args to pass to the Pydantic init method.
**kwargs: The kwargs to pass to the Pydantic init method.
Raises:
ConfigError: If some values in the config are invalid.
"""
super().__init__(*args, **kwargs)

Expand All @@ -248,6 +256,14 @@ def __init__(self, *args, **kwargs):
self._non_default_attributes.update(kwargs)
self._replace_defaults(**kwargs)

if (
self.state_manager_mode == constants.StateManagerMode.REDIS
and not self.redis_url
):
raise ConfigError(
"REDIS_URL is required when using the redis state manager."
)

@property
def module(self) -> str:
"""Get the module name of the app.
Expand Down
2 changes: 2 additions & 0 deletions reflex/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
RouteRegex,
RouteVar,
)
from .state import StateManagerMode
from .style import Tailwind

__ALL__ = [
Expand Down Expand Up @@ -115,6 +116,7 @@
SETTER_PREFIX,
SKIP_COMPILE_ENV_VAR,
SocketEvent,
StateManagerMode,
Tailwind,
Templates,
CompileVars,
Expand Down
11 changes: 11 additions & 0 deletions reflex/constants/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""State-related constants."""

from enum import Enum


class StateManagerMode(str, Enum):
"""State manager constants."""

DISK = "disk"
MEMORY = "memory"
REDIS = "redis"
39 changes: 25 additions & 14 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
DynamicRouteArgShadowsStateVar,
EventHandlerShadowsBuiltInStateMethod,
ImmutableStateError,
InvalidStateManagerMode,
LockExpiredError,
SetUndefinedStateVarError,
StateSchemaMismatchError,
Expand Down Expand Up @@ -2514,20 +2515,30 @@ def create(cls, state: Type[BaseState]):
Args:
state: The state class to use.
Returns:
The state manager (either disk or redis).
"""
redis = prerequisites.get_redis()
if redis is not None:
# make sure expiration values are obtained only from the config object on creation
config = get_config()
return StateManagerRedis(
state=state,
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
)
return StateManagerDisk(state=state)
Raises:
InvalidStateManagerMode: If the state manager mode is invalid.
Returns:
The state manager (either disk, memory or redis).
"""
config = get_config()
if config.state_manager_mode == constants.StateManagerMode.DISK:
return StateManagerMemory(state=state)
if config.state_manager_mode == constants.StateManagerMode.MEMORY:
return StateManagerDisk(state=state)
if config.state_manager_mode == constants.StateManagerMode.REDIS:
redis = prerequisites.get_redis()
if redis is not None:
# make sure expiration values are obtained only from the config object on creation
return StateManagerRedis(
state=state,
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
)
raise InvalidStateManagerMode(
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
)

@abstractmethod
async def get_state(self, token: str) -> BaseState:
Expand Down
8 changes: 8 additions & 0 deletions reflex/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ class ReflexError(Exception):
"""Base exception for all Reflex exceptions."""


class ConfigError(ReflexError):
"""Custom exception for config related errors."""


class InvalidStateManagerMode(ReflexError, ValueError):
"""Raised when an invalid state manager mode is provided."""


class ReflexRuntimeError(ReflexError, RuntimeError):
"""Custom RuntimeError for Reflex."""

Expand Down
1 change: 1 addition & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3201,6 +3201,7 @@ def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_
config = rx.Config(
app_name="project1",
redis_url="redis://localhost:6379",
state_manager_mode="redis",
{config_items}
)
"""
Expand Down

0 comments on commit 6f586c8

Please sign in to comment.