From 2bc7df1be7b8929e55cb25f13845caf0503630d8 Mon Sep 17 00:00:00 2001 From: Lex Date: Fri, 23 Feb 2024 18:13:30 +1000 Subject: [PATCH] Refactor structure for maintainability Refactor, add cachelib --- CHANGES.rst | 20 +- docs/api.rst | 9 +- docs/config_exceptions.rst | 8 +- docs/config_flask_session.rst | 3 +- docs/config_storage.rst | 16 + docs/index.rst | 2 +- docs/installation.rst | 51 +- docs/{quickstart.rst => usage.rst} | 37 +- src/flask_session/__init__.py | 38 +- src/flask_session/_utils.py | 5 + src/flask_session/base.py | 310 ++++++++ src/flask_session/cachelib/__init__.py | 1 + src/flask_session/cachelib/cachelib.py | 68 ++ src/flask_session/defaults.py | 4 + src/flask_session/filesystem/__init__.py | 1 + src/flask_session/filesystem/filesystem.py | 101 +++ src/flask_session/memcached/__init__.py | 1 + src/flask_session/memcached/memcached.py | 114 +++ src/flask_session/mongodb/__init__.py | 1 + src/flask_session/mongodb/mongodb.py | 113 +++ src/flask_session/redis/__init__.py | 1 + src/flask_session/redis/redis.py | 79 ++ src/flask_session/sessions.py | 827 --------------------- src/flask_session/sqlalchemy/__init__.py | 1 + src/flask_session/sqlalchemy/sqlalchemy.py | 186 +++++ tests/test_basic.py | 7 - tests/test_cachelib.py | 29 + tests/test_filesystem.py | 7 +- tests/test_memcached.py | 4 +- tests/test_mongodb.py | 5 +- tests/test_redis.py | 4 +- tests/test_sqlalchemy.py | 17 +- 32 files changed, 1183 insertions(+), 887 deletions(-) rename docs/{quickstart.rst => usage.rst} (50%) create mode 100644 src/flask_session/base.py create mode 100644 src/flask_session/cachelib/__init__.py create mode 100644 src/flask_session/cachelib/cachelib.py create mode 100644 src/flask_session/filesystem/__init__.py create mode 100644 src/flask_session/filesystem/filesystem.py create mode 100644 src/flask_session/memcached/__init__.py create mode 100644 src/flask_session/memcached/memcached.py create mode 100644 src/flask_session/mongodb/__init__.py create mode 100644 src/flask_session/mongodb/mongodb.py create mode 100644 src/flask_session/redis/__init__.py create mode 100644 src/flask_session/redis/redis.py delete mode 100644 src/flask_session/sessions.py create mode 100644 src/flask_session/sqlalchemy/__init__.py create mode 100644 src/flask_session/sqlalchemy/sqlalchemy.py create mode 100644 tests/test_cachelib.py diff --git a/CHANGES.rst b/CHANGES.rst index bbb2be8f..b7790133 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,20 +1,26 @@ Version 0.7.0 ------------------ +Added - Use msgpack for serialization, along with ``SESSION_SERIALIZATION_FORMAT`` to choose between ``json`` and ``msgpack``. -- Deprecated pickle. It is still available to read existing sessions, but will be removed in 1.0.0. All sessions will transfer to msgspec upon first interaction with 0.7.0. -- Prevent sid reuse on storage miss. - Add time-to-live expiration for MongoDB. - Add retry for SQL based storage. -- Abstraction to improve consistency between backends. -- Enforce ``PERMANENT_SESSION_LIFETIME`` as expiration consistently for all backends. -- Add logo and additional documentation. - Add ``flask session_cleanup`` command and alternatively, ``SESSION_CLEANUP_N_REQUESTS`` for SQLAlchemy or future non-TTL backends. -- Use Vary cookie header. - Type hints. +- Add logo and additional documentation. + +Deprecated +- Deprecated pickle. It is still available to read existing sessions, but will be removed in 1.0.0. All sessions will transfer to msgspec upon first interaction with 0.7.0. - Remove null session in favour of specific exception messages. - Deprecate ``SESSION_USE_SIGNER``. -- Remove backend session interfaces from public API and semver. +- Deprecate FileSystemSessionInterface in favor of the broader CacheLibSessionInterface. + +Fixed +- Prevent sid reuse on storage miss. +- Abstraction to improve consistency between backends. +- Enforce ``PERMANENT_SESSION_LIFETIME`` as expiration consistently for all backends. +- Use Vary cookie header as per Flask. +- Specifically include backend session interfaces in public API and document usage. Version 0.6.0 diff --git a/docs/api.rst b/docs/api.rst index 0f3da578..851ee882 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -9,4 +9,11 @@ Anything documented here is part of the public API that Flask-Session provides, .. autoclass:: Session :members: init_app -.. autoclass:: flask_session.sessions.ServerSideSession +.. autoclass:: flask_session.base.ServerSideSession + +.. autoclass:: flask_session.redis.RedisSessionInterface +.. autoclass:: flask_session.memcached.MemcachedSessionInterface +.. autoclass:: flask_session.filesystem.FileSystemSessionInterface +.. autoclass:: flask_session.cachelib.CacheLibSessionInterface +.. autoclass:: flask_session.mongodb.MongoDBSessionInterface +.. autoclass:: flask_session.sqlalchemy.SqlAlchemySessionInterface \ No newline at end of file diff --git a/docs/config_exceptions.rst b/docs/config_exceptions.rst index edc9ba82..fd8cde73 100644 --- a/docs/config_exceptions.rst +++ b/docs/config_exceptions.rst @@ -1,12 +1,7 @@ -Storage exceptions -=================== - -For various reasons, database operations can fail. When a database operation fails, the database client will raise an Exception. - Retries -------- -Upon an Exception, Flask-Session will retry with backoff up to 3 times for SQL based storage. If the operation still fails after 3 retries, the Exception will be raised. +Only for SQL based storage, upon an exception, Flask-Session will retry with backoff up to 3 times. If the operation still fails after 3 retries, the exception will be raised. For other storage types, the retry logic is either included or can be configured in the client setup. Refer to the client's documentation for more information. @@ -22,6 +17,7 @@ Redis example with retries on certain errors: ConnectionError, TimeoutError ) + ... retry = Retry(ExponentialBackoff(), 3) SESSION_REDIS = Redis(host='localhost', port=6379, retry=retry, retry_on_error=[BusyLoadingError, ConnectionError, TimeoutError]) diff --git a/docs/config_flask_session.rst b/docs/config_flask_session.rst index cf643827..2cb6e426 100644 --- a/docs/config_flask_session.rst +++ b/docs/config_flask_session.rst @@ -10,7 +10,8 @@ These are specific to Flask-Session. - **redis**: RedisSessionInterface - **memcached**: MemcachedSessionInterface - - **filesystem**: FileSystemSessionInterface + - **filesystem**: FileSystemSessionInterface (Deprecated in 0.7.0, will be removed in 1.0.0 in favor of CacheLibSessionInterface) + - **cachelib**: CacheLibSessionInterface - **mongodb**: MongoDBSessionInterface - **sqlalchemy**: SqlAlchemySessionInterface diff --git a/docs/config_storage.rst b/docs/config_storage.rst index df3083aa..c61c8b34 100644 --- a/docs/config_storage.rst +++ b/docs/config_storage.rst @@ -43,6 +43,22 @@ FileSystem Default: ``0600`` +.. deprecated:: 0.7.0 + ``SESSION_FILE_MODE``, ``SESSION_FILE_THRESHOLD`` and ``SESSION_FILE_DIR``. Use ``SESSION_CACHELIB`` instead. + +Cachelib +~~~~~~~~~~~~~~~~~~~~~~~ +.. py:data:: SESSION_CACHELIB + + Any valid `cachelib backend `_. This allows you maximum flexibility in choosing the cache backend and its configuration. + + The following would set a cache directory called "flask_session" and a threshold of 500 items before it starts deleting some. + + .. code-block:: python + + app.config['SESSION_CACHELIB'] = FileSystemCache(cache_dir='flask_session', threshold=500) + + Default: ``FileSystemCache`` in ``./flask_session`` directory. MongoDB ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/index.rst b/docs/index.rst index 786c3b3e..087cd40d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,7 +6,7 @@ Table of Contents introduction installation - quickstart + usage config api contributing diff --git a/docs/installation.rst b/docs/installation.rst index 9dc44cd5..986a23d7 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -12,7 +12,7 @@ Flask-Session's only required dependency is msgspec for serialization, which has .. note:: - You need to choose a storage type and install an appropriate client library, unless you are using the FileSystemCache. + You need to choose a storage type and install an appropriate client library. For example, if you want to use Redis as your storage, you will need to install the redis-py client library: @@ -22,8 +22,10 @@ For example, if you want to use Redis as your storage, you will need to install Redis is the recommended storage type for Flask-Session, as it has the most complete support for the features of Flask-Session with minimal configuration. -Supported storage and client libraries: +Support +-------- +Directly supported storage and client libraries: .. list-table:: :header-rows: 1 @@ -33,7 +35,7 @@ Supported storage and client libraries: * - Redis - redis-py_ * - Memcached - - pylibmc_, python-memcached_, pymemcache_ + - pylibmc_, python-memcached_ or pymemcache_ * - MongoDB - pymongo_ * - SQL Alchemy @@ -41,9 +43,48 @@ Supported storage and client libraries: Other clients may work if they use the same commands as the ones listed above. +Cachelib +-------- + +Flask-Session also indirectly supports storage and client libraries via cachelib_, which is a wrapper around various cache libraries and subject to change. You must also install cachelib_ itself to use these. + +.. warning:: + + As of writing, cachelib_ still use pickle_ as the default serializer, which may have security implications. + +Using cachlib :class:`FileSystemCache`` or :class:`SimpleCache` may be useful for development. + +.. list-table:: + :header-rows: 1 + + * - Storage + - Client Library + * - File System + - Not required + * - Simple Memory + - Not required + * - UWSGI + - uwsgi_ + * - Redis + - redis-py_ + * - Memcached + - pylibmc_, memcached, libmc_ or `google.appengine.api.memcached`_ + * - MongoDB + - pymongo_ + * - DynamoDB + - boto3_ + + + .. _redis-py: https://github.com/andymccurdy/redis-py .. _pylibmc: http://sendapatch.se/projects/pylibmc/ .. _python-memcached: https://github.com/linsomniac/python-memcached .. _pymemcache: https://github.com/pinterest/pymemcache -.. _pymongo: http://api.mongodb.org/python/current/index.html -.. _Flask-SQLAlchemy: https://github.com/pallets-eco/flask-sqlalchemy \ No newline at end of file +.. _pymongo: https://pymongo.readthedocs.io/en/stable +.. _Flask-SQLAlchemy: https://github.com/pallets-eco/flask-sqlalchemy +.. _cachelib: https://cachelib.readthedocs.io/en/stable/ +.. _google.appengine.api.memcached: https://cloud.google.com/appengine/docs/legacy/standard/python/memcache +.. _boto3: https://boto3.amazonaws.com/v1/documentation/api/latest/index.html +.. _libmc: https://github.com/douban/libmc +.. _uwsgi: https://uwsgi-docs.readthedocs.io/en/latest/WSGIquickstart.html +.. _pickle: https://docs.python.org/3/library/pickle \ No newline at end of file diff --git a/docs/quickstart.rst b/docs/usage.rst similarity index 50% rename from docs/quickstart.rst rename to docs/usage.rst index 447a76e9..66bb7ba2 100644 --- a/docs/quickstart.rst +++ b/docs/usage.rst @@ -1,6 +1,9 @@ -Quick Start +Usage =========== +Quickstart +----------- + .. currentmodule:: flask_session @@ -11,7 +14,7 @@ then create the :class:`Session` object by passing it the application. You can not use ``Session`` instance directly, what ``Session`` does is just change the :attr:`~flask.Flask.session_interface` attribute on - your Flask applications. You should always use :class:`flask.session`. + your Flask applications. You should always use :class:`flask.session` when accessing the current session. .. code-block:: python @@ -19,7 +22,7 @@ then create the :class:`Session` object by passing it the application. from flask_session import Session app = Flask(__name__) - # Check Configuration section for more details + SESSION_TYPE = 'redis' app.config.from_object(__name__) Session(app) @@ -33,10 +36,32 @@ then create the :class:`Session` object by passing it the application. def get(): return session.get('key', 'not set') -You may also set up your application later using :meth:`~Session.init_app` -method. +This would automatically setup a redis client connected to `localhost:6379` and use it to store the session data. -.. code-block:: python +See the `configuration section `_ for more details. +Alternative initialization +--------------------------- + +Rather than calling `Session(app)`, you may initialize later using :meth:`~Session.init_app`. + +.. code-block:: python + sess = Session() sess.init_app(app) + +Or, if you prefer to directly set parameters rather than using the configuration constants, you can initialize by setting the interface constructor directly to the :attr:`session_interface`. + +.. code-block:: python + + from flask_session.implementations.redis import RedisSessionInterface + + ... + + redis = Redis( + host='localhost', + port=6379, + ) + app.session_interface = RedisSessionInterface( + client=redis, + ) \ No newline at end of file diff --git a/src/flask_session/__init__.py b/src/flask_session/__init__.py index c031f8ac..93aa5a0d 100644 --- a/src/flask_session/__init__.py +++ b/src/flask_session/__init__.py @@ -1,11 +1,4 @@ from .defaults import Defaults -from .sessions import ( - FileSystemSessionInterface, - MemcachedSessionInterface, - MongoDBSessionInterface, - RedisSessionInterface, - SqlAlchemySessionInterface, -) __version__ = "0.6.0rc1" @@ -14,6 +7,8 @@ class Session: """This class is used to add Server-side Session to one or more Flask applications. + :param app: A Flask app instance. + For a typical setup use the following initialization:: app = Flask(__name__) @@ -50,10 +45,11 @@ def _get_interface(self, app): # Flask-session specific settings SESSION_TYPE = config.get("SESSION_TYPE", Defaults.SESSION_TYPE) + SESSION_PERMANENT = config.get("SESSION_PERMANENT", Defaults.SESSION_PERMANENT) SESSION_USE_SIGNER = config.get( "SESSION_USE_SIGNER", Defaults.SESSION_USE_SIGNER - ) + ) # TODO: remove in 1.0 SESSION_KEY_PREFIX = config.get( "SESSION_KEY_PREFIX", Defaults.SESSION_KEY_PREFIX ) @@ -70,7 +66,11 @@ def _get_interface(self, app): # Memcached settings SESSION_MEMCACHED = config.get("SESSION_MEMCACHED", Defaults.SESSION_MEMCACHED) + # CacheLib settings + SESSION_CACHELIB = config.get("SESSION_CACHELIB", Defaults.SESSION_CACHELIB) + # Filesystem settings + # TODO: remove in 1.0 SESSION_FILE_DIR = config.get("SESSION_FILE_DIR", Defaults.SESSION_FILE_DIR) SESSION_FILE_THRESHOLD = config.get( "SESSION_FILE_THRESHOLD", Defaults.SESSION_FILE_THRESHOLD @@ -107,7 +107,6 @@ def _get_interface(self, app): ) common_params = { - "app": app, "key_prefix": SESSION_KEY_PREFIX, "use_signer": SESSION_USE_SIGNER, "permanent": SESSION_PERMANENT, @@ -116,23 +115,37 @@ def _get_interface(self, app): } if SESSION_TYPE == "redis": + from .redis import RedisSessionInterface + session_interface = RedisSessionInterface( **common_params, - redis=SESSION_REDIS, + client=SESSION_REDIS, ) elif SESSION_TYPE == "memcached": + from .memcached import MemcachedSessionInterface + session_interface = MemcachedSessionInterface( **common_params, client=SESSION_MEMCACHED, ) elif SESSION_TYPE == "filesystem": + from .filesystem import FileSystemSessionInterface + session_interface = FileSystemSessionInterface( **common_params, cache_dir=SESSION_FILE_DIR, threshold=SESSION_FILE_THRESHOLD, mode=SESSION_FILE_MODE, ) + elif SESSION_TYPE == "cachelib": + from .cachelib import CacheLibSessionInterface + + session_interface = CacheLibSessionInterface( + **common_params, client=SESSION_CACHELIB + ) elif SESSION_TYPE == "mongodb": + from .mongodb import MongoDBSessionInterface + session_interface = MongoDBSessionInterface( **common_params, client=SESSION_MONGODB, @@ -140,9 +153,12 @@ def _get_interface(self, app): collection=SESSION_MONGODB_COLLECT, ) elif SESSION_TYPE == "sqlalchemy": + from .sqlalchemy import SqlAlchemySessionInterface + session_interface = SqlAlchemySessionInterface( + app=app, **common_params, - db=SESSION_SQLALCHEMY, + client=SESSION_SQLALCHEMY, table=SESSION_SQLALCHEMY_TABLE, sequence=SESSION_SQLALCHEMY_SEQUENCE, schema=SESSION_SQLALCHEMY_SCHEMA, diff --git a/src/flask_session/_utils.py b/src/flask_session/_utils.py index 9967fa10..cd7f6855 100644 --- a/src/flask_session/_utils.py +++ b/src/flask_session/_utils.py @@ -21,6 +21,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + import time from functools import wraps from typing import Any, Callable @@ -28,6 +29,10 @@ from flask import current_app +def total_seconds(timedelta): + return int(timedelta.total_seconds()) + + def retry_query( *, max_attempts: int = 3, delay: float = 0.3, backoff: int = 2 ) -> Callable[..., Any]: diff --git a/src/flask_session/base.py b/src/flask_session/base.py new file mode 100644 index 00000000..e3853815 --- /dev/null +++ b/src/flask_session/base.py @@ -0,0 +1,310 @@ +import secrets +import warnings +from abc import ABC, abstractmethod +from contextlib import suppress + +try: + import cPickle as pickle +except ImportError: + import pickle + +import random +from datetime import timedelta as TimeDelta +from typing import Any, Optional + +import msgspec +from flask import Flask, Request, Response +from flask.sessions import SessionInterface as FlaskSessionInterface +from flask.sessions import SessionMixin +from itsdangerous import BadSignature, Signer, want_bytes +from werkzeug.datastructures import CallbackDict + +from ._utils import retry_query +from .defaults import Defaults + + +class ServerSideSession(CallbackDict, SessionMixin): + """Baseclass for server-side based sessions. This can be accessed through ``flask.session``. + + .. attribute:: sid + + Session id, internally we use :func:`secrets.token_urlsafe` to generate one + session id. + + .. attribute:: modified + + When data is changed, this is set to ``True``. Only the session dictionary + itself is tracked; if the session contains mutable data (for example a nested + dict) then this must be set to ``True`` manually when modifying that data. The + session cookie will only be written to the response if this is ``True``. + + Default is ``False``. + + .. attribute:: permanent + + This sets and reflects the ``'_permanent'`` key in the dict. + + Default is ``False``. + + """ + + def __bool__(self) -> bool: + return bool(dict(self)) and self.keys() != {"_permanent"} + + def __init__( + self, + initial: Optional[dict[str, Any]] = None, + sid: Optional[str] = None, + permanent: Optional[bool] = None, + ): + def on_update(self) -> None: + self.modified = True + + CallbackDict.__init__(self, initial, on_update) + self.sid = sid + if permanent: + self.permanent = permanent + self.modified = False + + +class Serializer(ABC): + """Baseclass for session serialization.""" + + @abstractmethod + def decode(self, serialized_data: bytes) -> dict: + """Deserialize the session data.""" + raise NotImplementedError() + + @abstractmethod + def encode(self, session: ServerSideSession) -> bytes: + """Serialize the session data.""" + raise NotImplementedError() + + +class MsgSpecSerializer(Serializer): + def __init__(self, app: Flask, format: str): + self.app = app + if format == "msgpack": + self.encoder = msgspec.msgpack.Encoder() + self.decoder = msgspec.msgpack.Decoder() + elif format == "json": + self.encoder = msgspec.json.Encoder() + self.decoder = msgspec.json.Decoder() + else: + raise ValueError(f"Unsupported serialization format: {format}") + + def encode(self, session: ServerSideSession) -> bytes: + """Serialize the session data.""" + try: + return self.encoder.encode(dict(session)) + except Exception as e: + self.app.logger.error(f"Failed to serialize session data: {e}") + raise + + def decode(self, serialized_data: bytes) -> dict: + """Deserialize the session data.""" + # TODO: Remove the pickle fallback in 1.0.0 + with suppress(msgspec.DecodeError): + return self.decoder.decode(serialized_data) + with suppress(msgspec.DecodeError): + return self.alternate_decoder.decode(serialized_data) + with suppress(msgspec.DecodeError): + return pickle.loads(serialized_data) + # If all decoders fail, raise the final exception + self.app.logger.error("Failed to deserialize session data", exc_info=True) + raise pickle.UnpicklingError("Failed to deserialize session data") + + +class ServerSideSessionInterface(FlaskSessionInterface, ABC): + """Used to open a :class:`flask.sessions.ServerSideSessionInterface` instance.""" + + session_class = ServerSideSession + serializer = None + ttl = True + + def __init__( + self, + app: Optional[Flask], + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_SID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS, + ): + self.app = app + self.key_prefix = key_prefix + self.use_signer = use_signer + if use_signer: + warnings.warn( + "The 'use_signer' option is deprecated and will be removed in the next minor release. " + "Please update your configuration accordingly or open an issue.", + DeprecationWarning, + stacklevel=1, + ) + self.permanent = permanent + self.sid_length = sid_length + self.has_same_site_capability = hasattr(self, "get_cookie_samesite") + self.cleanup_n_requests = cleanup_n_requests + + # Cleanup settings for non-TTL databases only + if getattr(self, "ttl", None) is False: + if self.cleanup_n_requests: + self.app.before_request(self._cleanup_n_requests) + else: + self._register_cleanup_app_command() + + # Set the serialization format + self.serializer = MsgSpecSerializer(format=serialization_format, app=app) + + # INTERNAL METHODS + + def _generate_sid(self, session_id_length: int) -> str: + """Generate a random session id.""" + return secrets.token_urlsafe(session_id_length) + + # TODO: Remove in 1.0.0 + def _get_signer(self, app: Flask) -> Signer: + if not hasattr(app, "secret_key") or not app.secret_key: + raise KeyError("SECRET_KEY must be set when SESSION_USE_SIGNER=True") + return Signer(app.secret_key, salt="flask-session", key_derivation="hmac") + + # TODO: Remove in 1.0.0 + def _unsign(self, app, sid: str) -> str: + signer = self._get_signer(app) + sid_as_bytes = signer.unsign(sid) + sid = sid_as_bytes.decode() + return sid + + # TODO: Remove in 1.0.0 + def _sign(self, app, sid: str) -> str: + signer = self._get_signer(app) + sid_as_bytes = want_bytes(sid) + return signer.sign(sid_as_bytes).decode("utf-8") + + def _get_store_id(self, sid: str) -> str: + return self.key_prefix + sid + + # CLEANUP METHODS FOR NON TTL DATABASES + + def _register_cleanup_app_command(self): + """ + Register a custom Flask CLI command for cleaning up expired sessions. + + Run the command with `flask session_cleanup`. Run with a cron job + or scheduler such as Heroku Scheduler to automatically clean up expired sessions. + """ + + @self.app.cli.command("session_cleanup") + def session_cleanup(): + with self.app.app_context(): + self._delete_expired_sessions() + + def _cleanup_n_requests(self) -> None: + """ + Delete expired sessions on average every N requests. + + This is less desirable than using the scheduled app command cleanup as it may + slow down some requests but may be useful for rapid development. + """ + if self.cleanup_n_requests and random.randint(0, self.cleanup_n_requests) == 0: + self._delete_expired_sessions() + + # METHODS OVERRIDE FLASK SESSION INTERFACE + + def save_session( + self, app: Flask, session: ServerSideSession, response: Response + ) -> None: + if not self.should_set_cookie(app, session): + return + + # Get the domain and path for the cookie from the app + domain = self.get_cookie_domain(app) + path = self.get_cookie_path(app) + + # Generate a prefixed session id + store_id = self._get_store_id(session.sid) + + # If the session is empty, do not save it to the database or set a cookie + if not session: + # If the session was deleted (empty and modified), delete the saved session from the database and tell the client to delete the cookie + if session.modified: + self._delete_session(store_id) + response.delete_cookie( + app.config["SESSION_COOKIE_NAME"], domain=domain, path=path + ) + response.vary.add("Cookie") + return + + # Update existing or create new session in the database + self._upsert_session(app.permanent_session_lifetime, session, store_id) + + # Set the browser cookie + response.set_cookie( + key=app.config["SESSION_COOKIE_NAME"], + value=self._sign(app, session.sid) if self.use_signer else session.sid, + expires=self.get_expiration_time(app, session), + httponly=self.get_cookie_httponly(app), + domain=self.get_cookie_domain(app), + path=self.get_cookie_path(app), + secure=self.get_cookie_secure(app), + samesite=( + self.get_cookie_samesite(app) if self.has_same_site_capability else None + ), + ) + response.vary.add("Cookie") + + def open_session(self, app: Flask, request: Request) -> ServerSideSession: + # Get the session ID from the cookie + sid = request.cookies.get(app.config["SESSION_COOKIE_NAME"]) + + # If there's no session ID, generate a new one + if not sid: + sid = self._generate_sid(self.sid_length) + return self.session_class(sid=sid, permanent=self.permanent) + # If the session ID is signed, unsign it + if self.use_signer: + try: + sid = self._unsign(app, sid) + except BadSignature: + sid = self._generate_sid(self.sid_length) + return self.session_class(sid=sid, permanent=self.permanent) + + # Retrieve the session data from the database + store_id = self._get_store_id(sid) + saved_session_data = self._retrieve_session_data(store_id) + + # If the saved session exists, load the session data from the document + if saved_session_data is not None: + return self.session_class(saved_session_data, sid=sid) + + # If the saved session does not exist, create a new session + sid = self._generate_sid(self.sid_length) + return self.session_class(sid=sid, permanent=self.permanent) + + # METHODS TO BE IMPLEMENTED BY SUBCLASSES + + @abstractmethod + @retry_query() # use only when retry not supported directly by the client + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + """Get the saved session from the session storage.""" + raise NotImplementedError() + + @abstractmethod + @retry_query() # use only when retry not supported directly by the client + def _delete_session(self, store_id: str) -> None: + """Delete session from the session storage.""" + raise NotImplementedError() + + @abstractmethod + @retry_query() # use only when retry not supported directly by the client + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + """Update existing or create new session in the session storage.""" + raise NotImplementedError() + + @retry_query() # use only when retry not supported directly by the client + def _delete_expired_sessions(self) -> None: + """Delete expired sessions from the session storage. Only required for non-TTL databases.""" + pass diff --git a/src/flask_session/cachelib/__init__.py b/src/flask_session/cachelib/__init__.py new file mode 100644 index 00000000..959c5d5d --- /dev/null +++ b/src/flask_session/cachelib/__init__.py @@ -0,0 +1 @@ +from .cachelib import CacheLibSessionInterface, CacheLibSession diff --git a/src/flask_session/cachelib/cachelib.py b/src/flask_session/cachelib/cachelib.py new file mode 100644 index 00000000..62335391 --- /dev/null +++ b/src/flask_session/cachelib/cachelib.py @@ -0,0 +1,68 @@ +from datetime import timedelta as TimeDelta +from typing import Optional + +from flask import Flask +from .._utils import total_seconds +from ..defaults import Defaults +from ..base import ServerSideSession, ServerSideSessionInterface +from cachelib.file import FileSystemCache +import warnings + + +class CacheLibSession(ServerSideSession): + pass + + +class CacheLibSessionInterface(ServerSideSessionInterface): + """Uses any :class:`cachelib` backend as a session storage. + + :param client: A :class:`cachelib` backend instance. + :param key_prefix: A prefix that is added to store keys. + :param use_signer: Whether to sign the session id cookie or not. + :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param serialization_format: The serialization format to use for the session data. + """ + + session_class = CacheLibSession + ttl = True + + def __init__( + self, + client: Optional[FileSystemCache] = Defaults.SESSION_CACHELIB, + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_SID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + ): + + if client is None: + client = FileSystemCache("flask_session", threshold=500) + self.cache = client + + super().__init__( + None, key_prefix, use_signer, permanent, sid_length, serialization_format + ) + + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + # Get the saved session (item) from the database + return self.cache.get(store_id) + + def _delete_session(self, store_id: str) -> None: + self.cache.delete(store_id) + + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + storage_time_to_live = total_seconds(session_lifetime) + + # Serialize the session data (or just cast into dictionary in this case) + session_data = dict(session) + + # Update existing or create new session in the database + self.cache.set( + key=store_id, + value=session_data, + timeout=storage_time_to_live, + ) diff --git a/src/flask_session/defaults.py b/src/flask_session/defaults.py index 1df76711..ef968e4f 100644 --- a/src/flask_session/defaults.py +++ b/src/flask_session/defaults.py @@ -19,7 +19,11 @@ class Defaults: # Memcached settings SESSION_MEMCACHED = None + # CacheLib settings + SESSION_CACHELIB = None + # Filesystem settings + # TODO: remove in 1.0 SESSION_FILE_DIR = os.path.join(os.getcwd(), "flask_session") SESSION_FILE_THRESHOLD = 500 SESSION_FILE_MODE = 384 diff --git a/src/flask_session/filesystem/__init__.py b/src/flask_session/filesystem/__init__.py new file mode 100644 index 00000000..79792687 --- /dev/null +++ b/src/flask_session/filesystem/__init__.py @@ -0,0 +1 @@ +from .filesystem import FileSystemSessionInterface, FileSystemSession diff --git a/src/flask_session/filesystem/filesystem.py b/src/flask_session/filesystem/filesystem.py new file mode 100644 index 00000000..6cf1b392 --- /dev/null +++ b/src/flask_session/filesystem/filesystem.py @@ -0,0 +1,101 @@ +from datetime import timedelta as TimeDelta +from typing import Optional + +from flask import Flask +from .._utils import total_seconds +from ..defaults import Defaults +from ..base import ServerSideSession, ServerSideSessionInterface +from cachelib.file import FileSystemCache +import warnings + + +class FileSystemSession(ServerSideSession): + pass + + +class FileSystemSessionInterface(ServerSideSessionInterface): + """Uses the :class:`cachelib.file.FileSystemCache` as a session storage. + + :param key_prefix: A prefix that is added to FileSystemCache store keys. + :param use_signer: Whether to sign the session id cookie or not. + :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param serialization_format: The serialization format to use for the session data. + :param cache_dir: the directory where session files are stored. + :param threshold: the maximum number of items the session stores before it + :param mode: the file mode wanted for the session files, default 0600 + + .. versionadded:: 0.7 + The `serialization_format` and `app` parameters were added. + + .. versionadded:: 0.6 + The `sid_length` parameter was added. + + .. versionadded:: 0.2 + The `use_signer` parameter was added. + """ + + session_class = FileSystemSession + ttl = True + + def __init__( + self, + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_SID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + cache_dir: str = Defaults.SESSION_FILE_DIR, + threshold: int = Defaults.SESSION_FILE_THRESHOLD, + mode: int = Defaults.SESSION_FILE_MODE, + ): + + # Deprecation warnings + if cache_dir != Defaults.SESSION_FILE_DIR: + warnings.warn( + "'SESSION_FILE_DIR' is deprecated and will be removed in a future release. Instead pass FileSystemCache(directory, threshold, mode) instance as SESSION_CACHELIB.", + DeprecationWarning, + stacklevel=2, + ) + if threshold != Defaults.SESSION_FILE_THRESHOLD: + warnings.warn( + "'SESSION_FILE_THRESHOLD' is deprecated and will be removed in a future release. Instead pass FileSystemCache(directory, threshold, mode) instance as SESSION_CLIENT.", + DeprecationWarning, + stacklevel=2, + ) + if mode != Defaults.SESSION_FILE_MODE: + warnings.warn( + "'SESSION_FILE_MODE' is deprecated and will be removed in a future release. Instead pass FileSystemCache(directory, threshold, mode) instance as SESSION_CLIENT.", + DeprecationWarning, + stacklevel=2, + ) + + self.cache = FileSystemCache( + cache_dir=cache_dir, threshold=threshold, mode=mode + ) + + super().__init__( + None, key_prefix, use_signer, permanent, sid_length, serialization_format + ) + + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + # Get the saved session (item) from the database + return self.cache.get(store_id) + + def _delete_session(self, store_id: str) -> None: + self.cache.delete(store_id) + + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + storage_time_to_live = total_seconds(session_lifetime) + + # Serialize the session data (or just cast into dictionary in this case) + session_data = dict(session) + + # Update existing or create new session in the database + self.cache.set( + key=store_id, + value=session_data, + timeout=storage_time_to_live, + ) diff --git a/src/flask_session/memcached/__init__.py b/src/flask_session/memcached/__init__.py new file mode 100644 index 00000000..9ca54ee1 --- /dev/null +++ b/src/flask_session/memcached/__init__.py @@ -0,0 +1 @@ +from .memcached import MemcachedSessionInterface, MemcachedSession diff --git a/src/flask_session/memcached/memcached.py b/src/flask_session/memcached/memcached.py new file mode 100644 index 00000000..edbf0b63 --- /dev/null +++ b/src/flask_session/memcached/memcached.py @@ -0,0 +1,114 @@ +from datetime import timedelta as TimeDelta +from typing import Any, Optional + +import msgspec +import time +from flask import Flask +from .._utils import total_seconds +from ..defaults import Defaults +from ..base import ServerSideSession, ServerSideSessionInterface +from typing import Protocol, Optional, Any + + +class MemcacheClientProtocol(Protocol): + def get(self, key: str) -> Optional[Any]: ... + def set(self, key: str, value: Any, timeout: int) -> bool: ... + def delete(self, key: str) -> bool: ... + + +class MemcachedSession(ServerSideSession): + pass + + +class MemcachedSessionInterface(ServerSideSessionInterface): + """A Session interface that uses memcached as session storage. (`pylibmc`, `libmc`, `python-memcached` or `pymemcache` required) + + :param client: A ``memcache.Client`` instance. + :param key_prefix: A prefix that is added to all Memcached store keys. + :param use_signer: Whether to sign the session id cookie or not. + :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param serialization_format: The serialization format to use for the session data. + + .. versionadded:: 0.7 + The `serialization_format` and `app` parameters were added. + + .. versionadded:: 0.6 + The `sid_length` parameter was added. + + .. versionadded:: 0.2 + The `use_signer` parameter was added. + """ + + serializer = ServerSideSessionInterface.serializer + session_class = MemcachedSession + ttl = True + + def __init__( + self, + client: Optional[MemcacheClientProtocol] = Defaults.SESSION_MEMCACHED, + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_SID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + ): + if client is None: + client = self._get_preferred_memcache_client() + self.client = client + super().__init__( + None, key_prefix, use_signer, permanent, sid_length, serialization_format + ) + + def _get_preferred_memcache_client(self): + clients = [ + ("pylibmc", ["127.0.0.1:11211"]), + ("memcache", ["127.0.0.1:11211"]), # python-memcached + ("pymemcache.client.base", "127.0.0.1:11211"), + ("libmc", ["localhost:11211"]), + ] + + for module_name, server in clients: + try: + module = __import__(module_name) + ClientClass = module.Client + return ClientClass(server) + except ImportError: + continue + + raise ImportError("No memcache module found") + + def _get_memcache_timeout(self, timeout: int) -> int: + """ + Memcached deals with long (> 30 days) timeouts in a special + way. Call this function to obtain a safe value for your timeout. + """ + if timeout > 2592000: # 60*60*24*30, 30 days + # Switch to absolute timestamps. + timeout += int(time.time()) + return timeout + + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + # Get the saved session (item) from the database + serialized_session_data = self.client.get(store_id) + if serialized_session_data: + return self.serializer.decode(serialized_session_data) + return None + + def _delete_session(self, store_id: str) -> None: + self.client.delete(store_id) + + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + storage_time_to_live = total_seconds(session_lifetime) + + # Serialize the session data + serialized_session_data = self.serializer.encode(session) + + # Update existing or create new session in the database + self.client.set( + store_id, + serialized_session_data, + self._get_memcache_timeout(storage_time_to_live), + ) diff --git a/src/flask_session/mongodb/__init__.py b/src/flask_session/mongodb/__init__.py new file mode 100644 index 00000000..1a003df7 --- /dev/null +++ b/src/flask_session/mongodb/__init__.py @@ -0,0 +1 @@ +from .mongodb import MongoDBSessionInterface, MongoDBSession diff --git a/src/flask_session/mongodb/mongodb.py b/src/flask_session/mongodb/mongodb.py new file mode 100644 index 00000000..52dfddcf --- /dev/null +++ b/src/flask_session/mongodb/mongodb.py @@ -0,0 +1,113 @@ +from datetime import datetime +from datetime import timedelta as TimeDelta +from typing import Any, Optional + +import msgspec +from flask import Flask +from itsdangerous import want_bytes +from ..defaults import Defaults +from ..base import ServerSideSession, ServerSideSessionInterface + +from pymongo import MongoClient, version + + +class MongoDBSession(ServerSideSession): + pass + + +class MongoDBSessionInterface(ServerSideSessionInterface): + """A Session interface that uses mongodb as session storage. (`pymongo` required) + + :param client: A ``pymongo.MongoClient`` instance. + :param key_prefix: A prefix that is added to all MongoDB store keys. + :param use_signer: Whether to sign the session id cookie or not. + :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param serialization_format: The serialization format to use for the session data. + :param db: The database you want to use. + :param collection: The collection you want to use. + + .. versionadded:: 0.7 + The `serialization_format` and `app` parameters were added. + + .. versionadded:: 0.6 + The `sid_length` parameter was added. + + .. versionadded:: 0.2 + The `use_signer` parameter was added. + """ + + session_class = MongoDBSession + ttl = True + + def __init__( + self, + client: Optional[MongoClient] = Defaults.SESSION_MONGODB, + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_SID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + db: str = Defaults.SESSION_MONGODB_DB, + collection: str = Defaults.SESSION_MONGODB_COLLECT, + ): + + if client is None: + client = MongoClient() + + self.client = client + self.store = client[db][collection] + self.use_deprecated_method = int(version.split(".")[0]) < 4 + + # Create a TTL index on the expiration time, so that mongo can automatically delete expired sessions + self.store.create_index("expiration", expireAfterSeconds=0) + + super().__init__( + None, key_prefix, use_signer, permanent, sid_length, serialization_format + ) + + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + # Get the saved session (document) from the database + document = self.store.find_one({"id": store_id}) + if document: + serialized_session_data = want_bytes(document["val"]) + return self.serializer.decode(serialized_session_data) + return None + + def _delete_session(self, store_id: str) -> None: + if self.use_deprecated_method: + self.store.remove({"id": store_id}) + else: + self.store.delete_one({"id": store_id}) + + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + storage_expiration_datetime = datetime.utcnow() + session_lifetime + + # Serialize the session data + serialized_session_data = self.serializer.encode(session) + + # Update existing or create new session in the database + if self.use_deprecated_method: + self.store.update( + {"id": store_id}, + { + "id": store_id, + "val": serialized_session_data, + "expiration": storage_expiration_datetime, + }, + True, + ) + else: + self.store.update_one( + {"id": store_id}, + { + "$set": { + "id": store_id, + "val": serialized_session_data, + "expiration": storage_expiration_datetime, + } + }, + True, + ) diff --git a/src/flask_session/redis/__init__.py b/src/flask_session/redis/__init__.py new file mode 100644 index 00000000..8c9f2a9e --- /dev/null +++ b/src/flask_session/redis/__init__.py @@ -0,0 +1 @@ +from .redis import RedisSessionInterface, RedisSession diff --git a/src/flask_session/redis/redis.py b/src/flask_session/redis/redis.py new file mode 100644 index 00000000..9833fded --- /dev/null +++ b/src/flask_session/redis/redis.py @@ -0,0 +1,79 @@ +from datetime import timedelta as TimeDelta +from typing import Any, Optional + +import msgspec + +from flask import Flask +from .._utils import total_seconds +from ..defaults import Defaults +from ..base import ServerSideSession, ServerSideSessionInterface +from redis import Redis + + +class RedisSession(ServerSideSession): + pass + + +class RedisSessionInterface(ServerSideSessionInterface): + """Uses the Redis key-value store as a session storage. (`redis-py` required) + + :param client: A ``redis.Redis`` instance. + :param key_prefix: A prefix that is added to all Redis store keys. + :param use_signer: Whether to sign the session id cookie or not. + :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param serialization_format: The serialization format to use for the session data. + + .. versionadded:: 0.7 + The `serialization_format` and `app` parameters were added. + + .. versionadded:: 0.6 + The `sid_length` parameter was added. + + .. versionadded:: 0.2 + The `use_signer` parameter was added. + """ + + session_class = RedisSession + ttl = True + + def __init__( + self, + client: Optional[Redis] = Defaults.SESSION_REDIS, + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_SID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + ): + if client is None: + client = Redis() + self.client = client + super().__init__( + None, key_prefix, use_signer, permanent, sid_length, serialization_format + ) + + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + # Get the saved session (value) from the database + serialized_session_data = self.client.get(store_id) + if serialized_session_data: + return self.serializer.decode(serialized_session_data) + return None + + def _delete_session(self, store_id: str) -> None: + self.client.delete(store_id) + + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + storage_time_to_live = total_seconds(session_lifetime) + + # Serialize the session data + serialized_session_data = self.serializer.encode(session) + + # Update existing or create new session in the database + self.client.set( + name=store_id, + value=serialized_session_data, + ex=storage_time_to_live, + ) diff --git a/src/flask_session/sessions.py b/src/flask_session/sessions.py deleted file mode 100644 index 8182bed5..00000000 --- a/src/flask_session/sessions.py +++ /dev/null @@ -1,827 +0,0 @@ -import secrets -import time -import warnings -from abc import ABC -from contextlib import suppress - -try: - import cPickle as pickle -except ImportError: - import pickle - -import random -from datetime import datetime -from datetime import timedelta as TimeDelta -from typing import Any, Optional - -import msgspec -from flask import Flask, Request, Response -from flask.sessions import SessionInterface as FlaskSessionInterface -from flask.sessions import SessionMixin -from itsdangerous import BadSignature, Signer, want_bytes -from werkzeug.datastructures import CallbackDict - -from ._utils import retry_query -from .defaults import Defaults - - -def total_seconds(timedelta): - return int(timedelta.total_seconds()) - - -class ServerSideSession(CallbackDict, SessionMixin): - """Baseclass for server-side based sessions. This can be accessed through ``flask.session``. - - .. attribute:: sid - - Session id, internally we use :func:`secrets.token_urlsafe` to generate one - session id. - - .. attribute:: modified - - When data is changed, this is set to ``True``. Only the session dictionary - itself is tracked; if the session contains mutable data (for example a nested - dict) then this must be set to ``True`` manually when modifying that data. The - session cookie will only be written to the response if this is ``True``. - - Default is ``False``. - - .. attribute:: permanent - - This sets and reflects the ``'_permanent'`` key in the dict. - - Default is ``False``. - - """ - - def __bool__(self) -> bool: - return bool(dict(self)) and self.keys() != {"_permanent"} - - def __init__( - self, - initial: Optional[dict[str, Any]] = None, - sid: Optional[str] = None, - permanent: Optional[bool] = None, - ): - def on_update(self) -> None: - self.modified = True - - CallbackDict.__init__(self, initial, on_update) - self.sid = sid - if permanent: - self.permanent = permanent - self.modified = False - - -class RedisSession(ServerSideSession): - pass - - -class MemcachedSession(ServerSideSession): - pass - - -class FileSystemSession(ServerSideSession): - pass - - -class MongoDBSession(ServerSideSession): - pass - - -class SqlAlchemySession(ServerSideSession): - pass - - -class SessionInterface(FlaskSessionInterface): - def _generate_sid(self, session_id_length: int) -> str: - return secrets.token_urlsafe(session_id_length) - - def __get_signer(self, app: Flask) -> Signer: - if not hasattr(app, "secret_key") or not app.secret_key: - raise KeyError("SECRET_KEY must be set when SESSION_USE_SIGNER=True") - return Signer(app.secret_key, salt="flask-session", key_derivation="hmac") - - def _unsign(self, app, sid: str) -> str: - signer = self.__get_signer(app) - sid_as_bytes = signer.unsign(sid) - sid = sid_as_bytes.decode() - return sid - - def _sign(self, app, sid: str) -> str: - signer = self.__get_signer(app) - sid_as_bytes = want_bytes(sid) - return signer.sign(sid_as_bytes).decode("utf-8") - - def _serialize(self, session: ServerSideSession) -> bytes: - return self.encoder.encode(dict(session)) - - def _deserialize(self, serialized_data): - with suppress(msgspec.DecodeError): - return self.decoder.decode(serialized_data) - with suppress(msgspec.DecodeError): - return self.alternate_decoder.decode(serialized_data) - with suppress(msgspec.DecodeError): - return pickle.loads(serialized_data) - # If all decoders fail, raise the original exception - raise pickle.UnpicklingError("Failed to deserialize session data") - - def _get_store_id(self, sid: str) -> str: - return self.key_prefix + sid - - -class ServerSideSessionInterface(SessionInterface, ABC): - """Used to open a :class:`flask.sessions.ServerSideSessionInterface` instance.""" - - session_class = ServerSideSession - serializer = None - ttl = True - - def __init__( - self, - app: Flask, - key_prefix: str = Defaults.SESSION_KEY_PREFIX, - use_signer: bool = Defaults.SESSION_USE_SIGNER, - permanent: bool = Defaults.SESSION_PERMANENT, - sid_length: int = Defaults.SESSION_SID_LENGTH, - serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, - cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS, - ): - self.app = app - self.key_prefix = key_prefix - self.use_signer = use_signer - if use_signer: - warnings.warn( - "The 'use_signer' option is deprecated and will be removed in the next minor release. " - "Please update your configuration accordingly or open an issue.", - DeprecationWarning, - stacklevel=1, - ) - self.permanent = permanent - self.sid_length = sid_length - self.has_same_site_capability = hasattr(self, "get_cookie_samesite") - self.cleanup_n_requests = cleanup_n_requests - - # Cleanup settings for non-TTL databases only - if getattr(self, "ttl", None) is False: - if self.cleanup_n_requests: - self.app.before_request(self._cleanup_n_requests) - else: - self._register_cleanup_app_command() - - # Set the serialization format - if serialization_format == "msgpack": - self.decoder = msgspec.msgpack.Decoder() - self.encoder = msgspec.msgpack.Encoder() - self.alternate_decoder = msgspec.json.Decoder() - elif serialization_format == "json": - self.decoder = msgspec.json.Decoder() - self.encoder = msgspec.json.Encoder() - self.alternate_decoder = msgspec.msgpack.Decoder() - else: - value = app.config.get("SESSION_SERIALIZATION_FORMAT") - raise ValueError( - f"Unrecognized value for SESSION_SERIALIZATION_FORMAT: {value}" - ) - - def save_session( - self, app: Flask, session: ServerSideSession, response: Response - ) -> None: - if not self.should_set_cookie(app, session): - return - - # Get the domain and path for the cookie from the app - domain = self.get_cookie_domain(app) - path = self.get_cookie_path(app) - - # Generate a prefixed session id - store_id = self._get_store_id(session.sid) - - # If the session is empty, do not save it to the database or set a cookie - if not session: - # If the session was deleted (empty and modified), delete the saved session from the database and tell the client to delete the cookie - if session.modified: - self._delete_session(store_id) - response.delete_cookie( - app.config["SESSION_COOKIE_NAME"], domain=domain, path=path - ) - response.vary.add("Cookie") - return - - # Update existing or create new session in the database - self._upsert_session(app.permanent_session_lifetime, session, store_id) - - # Set the browser cookie - response.set_cookie( - key=app.config["SESSION_COOKIE_NAME"], - value=self._sign(app, session.sid) if self.use_signer else session.sid, - expires=self.get_expiration_time(app, session), - httponly=self.get_cookie_httponly(app), - domain=self.get_cookie_domain(app), - path=self.get_cookie_path(app), - secure=self.get_cookie_secure(app), - samesite=( - self.get_cookie_samesite(app) if self.has_same_site_capability else None - ), - ) - response.vary.add("Cookie") - - def open_session(self, app: Flask, request: Request) -> ServerSideSession: - # Get the session ID from the cookie - sid = request.cookies.get(app.config["SESSION_COOKIE_NAME"]) - - # If there's no session ID, generate a new one - if not sid: - sid = self._generate_sid(self.sid_length) - return self.session_class(sid=sid, permanent=self.permanent) - # If the session ID is signed, unsign it - if self.use_signer: - try: - sid = self._unsign(app, sid) - except BadSignature: - sid = self._generate_sid(self.sid_length) - return self.session_class(sid=sid, permanent=self.permanent) - - # Retrieve the session data from the database - store_id = self._get_store_id(sid) - saved_session_data = self._retrieve_session_data(store_id) - - # If the saved session exists, load the session data from the document - if saved_session_data is not None: - return self.session_class(saved_session_data, sid=sid) - - # If the saved session does not exist, create a new session - sid = self._generate_sid(self.sid_length) - return self.session_class(sid=sid, permanent=self.permanent) - - # CLEANUP METHODS FOR NON TTL DATABASES - - def _register_cleanup_app_command(self): - """ - Register a custom Flask CLI command for cleaning up expired sessions. - - Run the command with `flask session_cleanup`. Run with a cron job - or scheduler such as Heroku Scheduler to automatically clean up expired sessions. - """ - - @self.app.cli.command("session_cleanup") - def session_cleanup(): - with self.app.app_context(): - self._delete_expired_sessions() - - def _cleanup_n_requests(self) -> None: - """ - Delete expired sessions on average every N requests. - - This is less desirable than using the scheduled app command cleanup as it may - slow down some requests but may be useful for rapid development. - """ - if self.cleanup_n_requests and random.randint(0, self.cleanup_n_requests) == 0: - self._delete_expired_sessions() - - # METHODS TO BE IMPLEMENTED BY SUBCLASSES - - @retry_query() - def _retrieve_session_data(self, store_id: str) -> Optional[dict]: - raise NotImplementedError() - - @retry_query() - def _delete_session(self, store_id: str) -> None: - raise NotImplementedError() - - @retry_query() - def _upsert_session( - self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str - ) -> None: - raise NotImplementedError() - - @retry_query() - def _delete_expired_sessions(self) -> None: - """Delete expired sessions from the session storage. Only required for non-TTL databases.""" - pass - - -class RedisSessionInterface(ServerSideSessionInterface): - """Uses the Redis key-value store as a session storage. (`redis-py` required) - - :param app: A Flask app instance. - :param key_prefix: A prefix that is added to all Redis store keys. - :param use_signer: Whether to sign the session id cookie or not. - :param permanent: Whether to use permanent session or not. - :param sid_length: The length of the generated session id in bytes. - :param serialization_format: The serialization format to use for the session data. - :param redis: A ``redis.Redis`` instance. - - .. versionadded:: 0.7 - The `serialization_format` and `app` parameters were added. - - .. versionadded:: 0.6 - The `sid_length` parameter was added. - - .. versionadded:: 0.2 - The `use_signer` parameter was added. - """ - - session_class = RedisSession - ttl = True - - def __init__( - self, - app: Flask, - key_prefix: str, - use_signer: bool, - permanent: bool, - sid_length: int, - serialization_format: str, - redis: Any = Defaults.SESSION_REDIS, - ): - if redis is None: - from redis import Redis - - redis = Redis() - self.redis = redis - super().__init__( - app, key_prefix, use_signer, permanent, sid_length, serialization_format - ) - - @retry_query() - def _retrieve_session_data(self, store_id: str) -> Optional[dict]: - # Get the saved session (value) from the database - serialized_session_data = self.redis.get(store_id) - if serialized_session_data: - try: - session_data = self._deserialize(serialized_session_data) - return session_data - except msgspec.DecodeError: - self.app.logger.error( - "Failed to deserialize session data", exc_info=True - ) - return None - - @retry_query() - def _delete_session(self, store_id: str) -> None: - self.redis.delete(store_id) - - @retry_query() - def _upsert_session( - self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str - ) -> None: - storage_time_to_live = total_seconds(session_lifetime) - - # Serialize the session data - serialized_session_data = self._serialize(dict(session)) - - # Update existing or create new session in the database - self.redis.set( - name=store_id, - value=serialized_session_data, - ex=storage_time_to_live, - ) - - -class MemcachedSessionInterface(ServerSideSessionInterface): - """A Session interface that uses memcached as session storage. (`pylibmc` or `python-memcached` or `pymemcache` required) - - :param app: A Flask app instance. - :param key_prefix: A prefix that is added to all Memcached store keys. - :param use_signer: Whether to sign the session id cookie or not. - :param permanent: Whether to use permanent session or not. - :param sid_length: The length of the generated session id in bytes. - :param serialization_format: The serialization format to use for the session data. - :param client: A ``memcache.Client`` instance. - - .. versionadded:: 0.7 - The `serialization_format` and `app` parameters were added. - - .. versionadded:: 0.6 - The `sid_length` parameter was added. - - .. versionadded:: 0.2 - The `use_signer` parameter was added. - """ - - session_class = MemcachedSession - ttl = True - - def __init__( - self, - app: Flask, - key_prefix: str, - use_signer: bool, - permanent: bool, - sid_length: int, - serialization_format: str, - client: Any = Defaults.SESSION_MEMCACHED, - ): - if client is None: - client = self._get_preferred_memcache_client() - self.client = client - super().__init__( - app, key_prefix, use_signer, permanent, sid_length, serialization_format - ) - - def _get_preferred_memcache_client(self): - clients = [ - ("pylibmc", ["127.0.0.1:11211"]), - ("memcache", ["127.0.0.1:11211"]), - ("pymemcache.client.base", "127.0.0.1:11211"), - ] - - for module_name, server in clients: - try: - module = __import__(module_name) - ClientClass = module.Client - return ClientClass(server) - except ImportError: - continue - - raise ImportError("No memcache module found") - - def _get_memcache_timeout(self, timeout: int) -> int: - """ - Memcached deals with long (> 30 days) timeouts in a special - way. Call this function to obtain a safe value for your timeout. - """ - if timeout > 2592000: # 60*60*24*30, 30 days - # Switch to absolute timestamps. - timeout += int(time.time()) - return timeout - - def _retrieve_session_data(self, store_id: str) -> Optional[dict]: - # Get the saved session (item) from the database - serialized_session_data = self.client.get(store_id) - if serialized_session_data: - try: - session_data = self._deserialize(serialized_session_data) - return session_data - except msgspec.DecodeError: - self.app.logger.error( - "Failed to deserialize session data", exc_info=True - ) - return None - - def _delete_session(self, store_id: str) -> None: - self.client.delete(store_id) - - def _upsert_session( - self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str - ) -> None: - storage_time_to_live = total_seconds(session_lifetime) - - # Serialize the session data - serialized_session_data = self._serialize(dict(session)) - - # Update existing or create new session in the database - self.client.set( - store_id, - serialized_session_data, - self._get_memcache_timeout(storage_time_to_live), - ) - - -class FileSystemSessionInterface(ServerSideSessionInterface): - """Uses the :class:`cachelib.file.FileSystemCache` as a session storage. - - :param app: A Flask app instance. - :param key_prefix: A prefix that is added to FileSystemCache store keys. - :param use_signer: Whether to sign the session id cookie or not. - :param permanent: Whether to use permanent session or not. - :param sid_length: The length of the generated session id in bytes. - :param serialization_format: The serialization format to use for the session data. - :param cache_dir: the directory where session files are stored. - :param threshold: the maximum number of items the session stores before it - :param mode: the file mode wanted for the session files, default 0600 - - .. versionadded:: 0.7 - The `serialization_format` and `app` parameters were added. - - .. versionadded:: 0.6 - The `sid_length` parameter was added. - - .. versionadded:: 0.2 - The `use_signer` parameter was added. - """ - - session_class = FileSystemSession - ttl = True - - def __init__( - self, - app: Flask, - key_prefix: str, - use_signer: bool, - permanent: bool, - sid_length: int, - serialization_format: str, - cache_dir: str = Defaults.SESSION_FILE_DIR, - threshold: int = Defaults.SESSION_FILE_THRESHOLD, - mode: int = Defaults.SESSION_FILE_MODE, - ): - from cachelib.file import FileSystemCache - - self.cache = FileSystemCache(cache_dir, threshold=threshold, mode=mode) - super().__init__( - app, key_prefix, use_signer, permanent, sid_length, serialization_format - ) - - def _retrieve_session_data(self, store_id: str) -> Optional[dict]: - # Get the saved session (item) from the database - return self.cache.get(store_id) - - def _delete_session(self, store_id: str) -> None: - self.cache.delete(store_id) - - def _upsert_session( - self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str - ) -> None: - storage_time_to_live = total_seconds(session_lifetime) - - # Serialize the session data (or just cast into dictionary in this case) - session_data = dict(session) - - # Update existing or create new session in the database - self.cache.set( - key=store_id, - value=session_data, - timeout=storage_time_to_live, - ) - - -class MongoDBSessionInterface(ServerSideSessionInterface): - """A Session interface that uses mongodb as session storage. (`pymongo` required) - - :param app: A Flask app instance. - :param key_prefix: A prefix that is added to all MongoDB store keys. - :param use_signer: Whether to sign the session id cookie or not. - :param permanent: Whether to use permanent session or not. - :param sid_length: The length of the generated session id in bytes. - :param serialization_format: The serialization format to use for the session data. - :param client: A ``pymongo.MongoClient`` instance. - :param db: The database you want to use. - :param collection: The collection you want to use. - - .. versionadded:: 0.7 - The `serialization_format` and `app` parameters were added. - - .. versionadded:: 0.6 - The `sid_length` parameter was added. - - .. versionadded:: 0.2 - The `use_signer` parameter was added. - """ - - session_class = MongoDBSession - ttl = True - - def __init__( - self, - app: Flask, - key_prefix: str, - use_signer: bool, - permanent: bool, - sid_length: int, - serialization_format: str, - client: Any = Defaults.SESSION_MONGODB, - db: str = Defaults.SESSION_MONGODB_DB, - collection: str = Defaults.SESSION_MONGODB_COLLECT, - ): - import pymongo - - if client is None: - client = pymongo.MongoClient() - - self.client = client - self.store = client[db][collection] - self.use_deprecated_method = int(pymongo.version.split(".")[0]) < 4 - - # Create a TTL index on the expiration time, so that mongo can automatically delete expired sessions - self.store.create_index("expiration", expireAfterSeconds=0) - - super().__init__( - app, key_prefix, use_signer, permanent, sid_length, serialization_format - ) - - def _retrieve_session_data(self, store_id: str) -> Optional[dict]: - # Get the saved session (document) from the database - document = self.store.find_one({"id": store_id}) - if document: - serialized_session_data = want_bytes(document["val"]) - try: - session_data = self._deserialize(serialized_session_data) - return session_data - except msgspec.DecodeError: - self.app.logger.error( - "Failed to deserialize session data", exc_info=True - ) - return None - - def _delete_session(self, store_id: str) -> None: - if self.use_deprecated_method: - self.store.remove({"id": store_id}) - else: - self.store.delete_one({"id": store_id}) - - def _upsert_session( - self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str - ) -> None: - storage_expiration_datetime = datetime.utcnow() + session_lifetime - - # Serialize the session data - serialized_session_data = self._serialize(dict(session)) - - # Update existing or create new session in the database - if self.use_deprecated_method: - self.store.update( - {"id": store_id}, - { - "id": store_id, - "val": serialized_session_data, - "expiration": storage_expiration_datetime, - }, - True, - ) - else: - self.store.update_one( - {"id": store_id}, - { - "$set": { - "id": store_id, - "val": serialized_session_data, - "expiration": storage_expiration_datetime, - } - }, - True, - ) - - -class SqlAlchemySessionInterface(ServerSideSessionInterface): - """Uses the Flask-SQLAlchemy from a flask app as session storage. - - :param app: A Flask app instance. - :param key_prefix: A prefix that is added to all store keys. - :param use_signer: Whether to sign the session id cookie or not. - :param permanent: Whether to use permanent session or not. - :param sid_length: The length of the generated session id in bytes. - :param serialization_format: The serialization format to use for the session data. - :param db: A Flask-SQLAlchemy instance. - :param table: The table name you want to use. - :param sequence: The sequence to use for the primary key if needed. - :param schema: The db schema to use - :param bind_key: The db bind key to use - :param cleanup_n_requests: Delete expired sessions on average every N requests. - - .. versionadded:: 0.7 - The `cleanup_n_requests`, `app`, `cleanup_n_requests` parameters were added. - - .. versionadded:: 0.6 - The `sid_length`, `sequence`, `schema` and `bind_key` parameters were added. - - .. versionadded:: 0.2 - The `use_signer` parameter was added. - """ - - session_class = SqlAlchemySession - ttl = False - - def __init__( - self, - app: Flask, - key_prefix: str, - use_signer: bool, - permanent: bool, - sid_length: int, - serialization_format: str, - db: Any = Defaults.SESSION_SQLALCHEMY, - table: str = Defaults.SESSION_SQLALCHEMY_TABLE, - sequence: Optional[str] = Defaults.SESSION_SQLALCHEMY_SEQUENCE, - schema: Optional[str] = Defaults.SESSION_SQLALCHEMY_SCHEMA, - bind_key: Optional[str] = Defaults.SESSION_SQLALCHEMY_BIND_KEY, - cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS, - ): - self.app = app - if db is None: - from flask_sqlalchemy import SQLAlchemy - - db = SQLAlchemy(app) - self.db = db - self.sequence = sequence - self.schema = schema - self.bind_key = bind_key - super().__init__( - app, - key_prefix, - use_signer, - permanent, - sid_length, - serialization_format, - cleanup_n_requests, - ) - - # Create the Session database model - class Session(self.db.Model): - __tablename__ = table - - if self.schema is not None: - __table_args__ = {"schema": self.schema, "keep_existing": True} - else: - __table_args__ = {"keep_existing": True} - - if self.bind_key is not None: - __bind_key__ = self.bind_key - - # Set the database columns, support for id sequences - if sequence: - id = self.db.Column( - self.db.Integer, self.db.Sequence(sequence), primary_key=True - ) - else: - id = self.db.Column(self.db.Integer, primary_key=True) - session_id = self.db.Column(self.db.String(255), unique=True) - data = self.db.Column(self.db.LargeBinary) - expiry = self.db.Column(self.db.DateTime) - - def __init__(self, session_id: str, data: Any, expiry: datetime): - self.session_id = session_id - self.data = data - self.expiry = expiry - - def __repr__(self): - return "" % self.data - - with app.app_context(): - self.db.create_all() - - self.sql_session_model = Session - - @retry_query() - def _delete_expired_sessions(self) -> None: - try: - self.db.session.query(self.sql_session_model).filter( - self.sql_session_model.expiry <= datetime.utcnow() - ).delete(synchronize_session=False) - self.db.session.commit() - except Exception: - self.db.session.rollback() - raise - - @retry_query() - def _retrieve_session_data(self, store_id: str) -> Optional[dict]: - # Get the saved session (record) from the database - record = self.sql_session_model.query.filter_by(session_id=store_id).first() - - # "Delete the session record if it is expired as SQL has no TTL ability - if record and (record.expiry is None or record.expiry <= datetime.utcnow()): - try: - self.db.session.delete(record) - self.db.session.commit() - except Exception: - self.db.session.rollback() - raise - record = None - - if record: - serialized_session_data = want_bytes(record.data) - try: - session_data = self._deserialize(serialized_session_data) - return session_data - except msgspec.DecodeError as e: - self.app.logger.exception( - e, "Failed to deserialize session data", exc_info=True - ) - return None - - @retry_query() - def _delete_session(self, store_id: str) -> None: - try: - self.sql_session_model.query.filter_by(session_id=store_id).delete() - self.db.session.commit() - except Exception: - self.db.session.rollback() - raise - - @retry_query() - def _upsert_session( - self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str - ) -> None: - storage_expiration_datetime = datetime.utcnow() + session_lifetime - - # Serialize session data - serialized_session_data = self._serialize(dict(session)) - - # Update existing or create new session in the database - try: - record = self.sql_session_model.query.filter_by(session_id=store_id).first() - if record: - record.data = serialized_session_data - record.expiry = storage_expiration_datetime - else: - record = self.sql_session_model( - session_id=store_id, - data=serialized_session_data, - expiry=storage_expiration_datetime, - ) - self.db.session.add(record) - self.db.session.commit() - except Exception: - self.db.session.rollback() - raise diff --git a/src/flask_session/sqlalchemy/__init__.py b/src/flask_session/sqlalchemy/__init__.py new file mode 100644 index 00000000..2344d3a6 --- /dev/null +++ b/src/flask_session/sqlalchemy/__init__.py @@ -0,0 +1 @@ +from .sqlalchemy import SqlAlchemySessionInterface, SqlAlchemySession diff --git a/src/flask_session/sqlalchemy/sqlalchemy.py b/src/flask_session/sqlalchemy/sqlalchemy.py new file mode 100644 index 00000000..95888e66 --- /dev/null +++ b/src/flask_session/sqlalchemy/sqlalchemy.py @@ -0,0 +1,186 @@ +from datetime import datetime +from datetime import timedelta as TimeDelta +from typing import Any, Optional + +import msgspec +from flask import Flask +from itsdangerous import want_bytes + +from .._utils import retry_query +from ..defaults import Defaults +from ..base import ServerSideSession, ServerSideSessionInterface +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import Column, String, LargeBinary, DateTime, Integer, Sequence + + +class SqlAlchemySession(ServerSideSession): + pass + + +def create_session_model(db, table_name, schema=None, bind_key=None, sequence=None): + class Session(db.Model): + __tablename__ = table_name + __table_args__ = ( + {"schema": schema, "keep_existing": True} + if schema + else {"keep_existing": True} + ) + __bind_key__ = bind_key + + id = ( + Column(Integer, Sequence(sequence), primary_key=True) + if sequence + else Column(Integer, primary_key=True) + ) + session_id = Column(String(255), unique=True) + data = Column(LargeBinary) + expiry = Column(DateTime) + + def __init__(self, session_id: str, data: Any, expiry: datetime): + self.session_id = session_id + self.data = data + self.expiry = expiry + + def __repr__(self): + return f"" + + return Session + + +class SqlAlchemySessionInterface(ServerSideSessionInterface): + """Uses the Flask-SQLAlchemy from a flask app as session storage. + + :param app: A Flask app instance. + :param client: A Flask-SQLAlchemy instance. + :param key_prefix: A prefix that is added to all store keys. + :param use_signer: Whether to sign the session id cookie or not. + :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param serialization_format: The serialization format to use for the session data. + :param table: The table name you want to use. + :param sequence: The sequence to use for the primary key if needed. + :param schema: The db schema to use + :param bind_key: The db bind key to use + :param cleanup_n_requests: Delete expired sessions on average every N requests. + + .. versionadded:: 1.0 + db changed to client to be standard on all session interfaces. + + .. versionadded:: 0.7 + The `cleanup_n_requests`, `app`, `cleanup_n_requests` parameters were added. + + .. versionadded:: 0.6 + The `sid_length`, `sequence`, `schema` and `bind_key` parameters were added. + + .. versionadded:: 0.2 + The `use_signer` parameter was added. + """ + + session_class = SqlAlchemySession + ttl = False + + def __init__( + self, + app: Optional[Flask], + client: Optional[SQLAlchemy] = Defaults.SESSION_SQLALCHEMY, + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_SID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + table: str = Defaults.SESSION_SQLALCHEMY_TABLE, + sequence: Optional[str] = Defaults.SESSION_SQLALCHEMY_SEQUENCE, + schema: Optional[str] = Defaults.SESSION_SQLALCHEMY_SCHEMA, + bind_key: Optional[str] = Defaults.SESSION_SQLALCHEMY_BIND_KEY, + cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS, + ): + self.app = app + + if client is None: + client = SQLAlchemy(app) + self.client = client + + # Create the session model + self.sql_session_model = create_session_model( + client, table, schema, bind_key, sequence + ) + # Create the table if it does not exist + with app.app_context(): + self.client.create_all() + + super().__init__( + app, + key_prefix, + use_signer, + permanent, + sid_length, + serialization_format, + cleanup_n_requests, + ) + + @retry_query() + def _delete_expired_sessions(self) -> None: + try: + self.client.session.query(self.sql_session_model).filter( + self.sql_session_model.expiry <= datetime.utcnow() + ).delete(synchronize_session=False) + self.client.session.commit() + except Exception: + self.client.session.rollback() + raise + + @retry_query() + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + # Get the saved session (record) from the database + record = self.sql_session_model.query.filter_by(session_id=store_id).first() + + # "Delete the session record if it is expired as SQL has no TTL ability + if record and (record.expiry is None or record.expiry <= datetime.utcnow()): + try: + self.client.session.delete(record) + self.client.session.commit() + except Exception: + self.client.session.rollback() + raise + record = None + + if record: + serialized_session_data = want_bytes(record.data) + return self.serializer.decode(serialized_session_data) + return None + + @retry_query() + def _delete_session(self, store_id: str) -> None: + try: + self.sql_session_model.query.filter_by(session_id=store_id).delete() + self.client.session.commit() + except Exception: + self.client.session.rollback() + raise + + @retry_query() + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + storage_expiration_datetime = datetime.utcnow() + session_lifetime + + # Serialize session data + serialized_session_data = self.serializer.encode(session) + + # Update existing or create new session in the database + try: + record = self.sql_session_model.query.filter_by(session_id=store_id).first() + if record: + record.data = serialized_session_data + record.expiry = storage_expiration_datetime + else: + record = self.sql_session_model( + session_id=store_id, + data=serialized_session_data, + expiry=storage_expiration_datetime, + ) + self.client.session.add(record) + self.client.session.commit() + except Exception: + self.client.session.rollback() + raise diff --git a/tests/test_basic.py b/tests/test_basic.py index 8a037322..1a904266 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -3,13 +3,6 @@ import pytest -def test_tot_seconds_func(): - import datetime - - td = datetime.timedelta(days=1) - assert flask_session.sessions.total_seconds(td) == 86400 - - def test_null_session(): """Invalid session should fail to get/set the flask session""" with pytest.raises(ValueError): diff --git a/tests/test_cachelib.py b/tests/test_cachelib.py new file mode 100644 index 00000000..0de0df29 --- /dev/null +++ b/tests/test_cachelib.py @@ -0,0 +1,29 @@ +import tempfile + +import flask +from flask_session.cachelib import CacheLibSession + + +class TestCachelibSession: + + def retrieve_stored_session(self, key, app): + return app.session_interface.cache.get(key) + + def test_filesystem_default(self, app_utils): + app = app_utils.create_app( + {"SESSION_TYPE": "cachelib", "SESSION_SERIALIZATION_FORMAT": "json"} + ) + + # Should be using FileSystem + with app.test_request_context(): + assert isinstance( + flask.session, + CacheLibSession, + ) + app_utils.test_session(app) + + # Check if the session is stored in the filesystem + cookie = app_utils.test_session_with_cookie(app) + session_id = cookie.split(";")[0].split("=")[1] + stored_session = self.retrieve_stored_session(f"session:{session_id}", app) + assert stored_session.get("value") == "44" diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py index bcf5f458..be4a9c4b 100644 --- a/tests/test_filesystem.py +++ b/tests/test_filesystem.py @@ -1,7 +1,7 @@ import tempfile import flask -import flask_session +from flask_session.filesystem import FileSystemSession class TestFileSystemSession: @@ -16,7 +16,10 @@ def test_filesystem_default(self, app_utils): # Should be using FileSystem with app.test_request_context(): - assert isinstance(flask.session, flask_session.sessions.FileSystemSession) + assert isinstance( + flask.session, + FileSystemSession, + ) app_utils.test_session(app) # Check if the session is stored in the filesystem diff --git a/tests/test_memcached.py b/tests/test_memcached.py index 2e68b037..acd734c4 100644 --- a/tests/test_memcached.py +++ b/tests/test_memcached.py @@ -4,6 +4,7 @@ import flask import flask_session import memcache # Import the memcache library +from flask_session.memcached import MemcachedSession class TestMemcachedSession: @@ -28,7 +29,8 @@ def test_memcached_default(self, app_utils): with app.test_request_context(): assert isinstance( - flask.session, flask_session.sessions.MemcachedSession + flask.session, + MemcachedSession, ) app_utils.test_session(app) diff --git a/tests/test_mongodb.py b/tests/test_mongodb.py index d173f9db..12a57a21 100644 --- a/tests/test_mongodb.py +++ b/tests/test_mongodb.py @@ -2,9 +2,9 @@ from contextlib import contextmanager import flask -import flask_session from itsdangerous import want_bytes from pymongo import MongoClient +from flask_session.mongodb import MongoDBSession class TestMongoSession: @@ -24,7 +24,6 @@ def setup_mongo(self): def retrieve_stored_session(self, key): document = self.collection.find_one({"id": key}) - print(document) return want_bytes(document["val"]) def test_mongo_default(self, app_utils): @@ -37,7 +36,7 @@ def test_mongo_default(self, app_utils): ) with app.test_request_context(): - assert isinstance(flask.session, flask_session.sessions.MongoDBSession) + assert isinstance(flask.session, MongoDBSession) app_utils.test_session(app) # Check if the session is stored in MongoDB diff --git a/tests/test_redis.py b/tests/test_redis.py index 761e32c5..b92d04df 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -2,7 +2,7 @@ from contextlib import contextmanager import flask -import flask_session +from flask_session.redis import RedisSession from redis import Redis @@ -27,7 +27,7 @@ def test_redis_default(self, app_utils): app = app_utils.create_app({"SESSION_TYPE": "redis"}) with app.test_request_context(): - assert isinstance(flask.session, flask_session.sessions.RedisSession) + assert isinstance(flask.session, RedisSession) app_utils.test_session(app) # Check if the session is stored in Redis diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 48774331..049b62a9 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -2,7 +2,7 @@ from contextlib import contextmanager import flask -import flask_session +from flask_session.sqlalchemy import SqlAlchemySession from sqlalchemy import text @@ -12,16 +12,16 @@ class TestSQLAlchemy: @contextmanager def setup_sqlalchemy(self, app): try: - app.session_interface.db.session.execute(text("DELETE FROM sessions")) - app.session_interface.db.session.commit() + app.session_interface.client.session.execute(text("DELETE FROM sessions")) + app.session_interface.client.session.commit() yield finally: - app.session_interface.db.session.execute(text("DELETE FROM sessions")) - app.session_interface.db.session.close() + app.session_interface.client.session.execute(text("DELETE FROM sessions")) + app.session_interface.client.session.close() def retrieve_stored_session(self, key, app): session_model = ( - app.session_interface.db.session.query( + app.session_interface.client.session.query( app.session_interface.sql_session_model ) .filter_by(session_id=key) @@ -41,7 +41,10 @@ def test_use_signer(self, app_utils): with app.app_context() and self.setup_sqlalchemy( app ) and app.test_request_context(): - assert isinstance(flask.session, flask_session.sessions.SqlAlchemySession) + assert isinstance( + flask.session, + SqlAlchemySession, + ) app_utils.test_session(app) # Check if the session is stored in SQLAlchemy