Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-95: Add SQLAlchemy storage backend #252

Open
wants to merge 1 commit into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/flask_session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def _get_interface(self, app):
"SESSION_SQLALCHEMY_BIND_KEY", Defaults.SESSION_SQLALCHEMY_BIND_KEY
)

# SQLAlchemy-native settings
SESSION_SQLALCHEMY_ENGINE = config.get(
"SESSION_SQLALCHEMY_ENGINE", Defaults.SESSION_SQLALCHEMY_ENGINE
)

# DynamoDB settings
SESSION_DYNAMODB = config.get("SESSION_DYNAMODB", Defaults.SESSION_DYNAMODB)
SESSION_DYNAMODB_TABLE = config.get(
Expand Down Expand Up @@ -187,6 +192,18 @@ def _get_interface(self, app):
bind_key=SESSION_SQLALCHEMY_BIND_KEY,
cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS,
)
elif SESSION_TYPE == "sqlalchemy_native":
from .sqlalchemy_native import NativeSqlAlchemySessionInterface

session_interface = NativeSqlAlchemySessionInterface(
**common_params,
engine=SESSION_SQLALCHEMY_ENGINE,
table=SESSION_SQLALCHEMY_TABLE,
sequence=SESSION_SQLALCHEMY_SEQUENCE,
schema=SESSION_SQLALCHEMY_SCHEMA,
bind_key=SESSION_SQLALCHEMY_BIND_KEY,
cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS,
)
elif SESSION_TYPE == "dynamodb":
from .dynamodb import DynamoDBSessionInterface

Expand Down
3 changes: 3 additions & 0 deletions src/flask_session/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class Defaults:
SESSION_SQLALCHEMY_SCHEMA = None
SESSION_SQLALCHEMY_BIND_KEY = None

# SQLAlchemy-native settings
SESSION_SQLALCHEMY_ENGINE = None

# DynamoDB settings
SESSION_DYNAMODB = None
SESSION_DYNAMODB_TABLE = "Sessions"
Expand Down
4 changes: 4 additions & 0 deletions src/flask_session/sqlalchemy_native/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .sqlalchemy_native import ( # noqa: F401
NativeSqlAlchemySession,
NativeSqlAlchemySessionInterface,
)
178 changes: 178 additions & 0 deletions src/flask_session/sqlalchemy_native/sqlalchemy_native.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from datetime import datetime
from datetime import timedelta as TimeDelta
from typing import Optional

from flask import Flask
from itsdangerous import want_bytes
from sqlalchemy import (
Column,
DateTime,
Engine,
Integer,
LargeBinary,
Sequence,
String,
delete,
select,
)
from sqlalchemy.orm import DeclarativeBase, Session

from .._utils import retry_query
from ..base import ServerSideSession, ServerSideSessionInterface
from ..defaults import Defaults


class NativeSqlAlchemySession(ServerSideSession):
pass


class Base(DeclarativeBase):
pass


def create_session_model(table_name, schema=None, bind_key=None, sequence=None):
class Session(Base):
__tablename__ = table_name
__table_args__ = {"schema": schema} if schema else {}
__bind_key__ = bind_key

id = (
Column(Integer, Sequence(sequence), primary_key=True)
if sequence
else Column(Integer, primary_key=True)
)
session_id = Column(String(255), unique=True)
data = Column(LargeBinary)
expiry = Column(DateTime)

def __repr__(self):
return f"<Session data {self.data}>"

return Session


class NativeSqlAlchemySessionInterface(ServerSideSessionInterface):
"""Uses a SQLAlchemy engine as session storage.

:param app: A Flask app instance.
:param engine: A SQLAlchemy engine instance.
:param key_prefix: A prefix that is added to all storage keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param sid_length: The length of the generated session id in bytes.
:param serialization_format: The serialization format to use for the session data.
:param table: The table name you want to use.
:param sequence: The sequence to use for the primary key if needed.
:param schema: The db schema to use.
:param bind_key: The db bind key to use.
:param cleanup_n_requests: Delete expired sessions on average every N requests.
"""

session_class = NativeSqlAlchemySession
ttl = False

def __init__(
self,
app: Optional[Flask],
engine: Optional[Engine] = Defaults.SESSION_SQLALCHEMY_ENGINE,
key_prefix: str = Defaults.SESSION_KEY_PREFIX,
use_signer: bool = Defaults.SESSION_USE_SIGNER,
permanent: bool = Defaults.SESSION_PERMANENT,
sid_length: int = Defaults.SESSION_ID_LENGTH,
serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT,
table: str = Defaults.SESSION_SQLALCHEMY_TABLE,
sequence: Optional[str] = Defaults.SESSION_SQLALCHEMY_SEQUENCE,
schema: Optional[str] = Defaults.SESSION_SQLALCHEMY_SCHEMA,
bind_key: Optional[str] = Defaults.SESSION_SQLALCHEMY_BIND_KEY,
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
):
self.app = app

if engine is None or not isinstance(engine, Engine):
raise TypeError("No valid Engine instance provided.")
self.engine = engine

# Create the session model
self.sql_session_model = create_session_model(
table, schema, bind_key, sequence
)
# Create the table if it does not exist
self.sql_session_model.__table__.create(bind=engine, checkfirst=True)

super().__init__(
app,
key_prefix,
use_signer,
permanent,
sid_length,
serialization_format,
cleanup_n_requests,
)

@retry_query()
def _delete_expired_sessions(self) -> None:
with Session(self.engine) as session:
session.execute(
delete(self.sql_session_model)
.where(self.sql_session_model.expiry <= datetime.utcnow()),
execution_options={"synchronize_session": False}
)
session.commit()

@retry_query()
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
# Get the saved session (record) from the database
with Session(self.engine) as session:
record = session.scalars(
select(self.sql_session_model)
.where(self.sql_session_model.session_id == store_id)
).first()

# "Delete the session record if it is expired as SQL has no TTL ability
if record and (record.expiry is None or record.expiry <= datetime.utcnow()):
with Session(self.engine) as session:
session.delete(record)
session.commit()
record = None

if record:
serialized_session_data = want_bytes(record.data)
return self.serializer.loads(serialized_session_data)
return None

@retry_query()
def _delete_session(self, store_id: str) -> None:
with Session(self.engine) as session:
session.execute(
delete(self.sql_session_model)
.where(self.sql_session_model.session_id == store_id)
)
session.commit()

@retry_query()
def _upsert_session(
self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str
) -> None:
storage_expiration_datetime = datetime.utcnow() + session_lifetime

# Serialize session data
serialized_session_data = self.serializer.dumps(dict(session))

# Update existing or create new session in the database
with Session(self.engine) as session:
record = session.scalars(
select(self.sql_session_model)
.where(self.sql_session_model.session_id == store_id)
).first()

if record:
record.data = serialized_session_data
record.expiry = storage_expiration_datetime
else:
record = self.sql_session_model(
session_id=store_id,
data=serialized_session_data,
expiry=storage_expiration_datetime,
)
session.add(record)
session.commit()
59 changes: 59 additions & 0 deletions tests/test_sqlalchemy_native.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json
from contextlib import contextmanager

import flask
from flask_session.sqlalchemy_native import NativeSqlAlchemySession
from sqlalchemy import create_engine, select, text
from sqlalchemy.orm import Session


class TestNativeSQLAlchemy:
"""This requires package: sqlalchemy"""

@contextmanager
def setup_sqlalchemy(self, app):
try:
with Session(app.session_interface.engine) as session:
session.execute(text("DELETE FROM sessions"))
session.commit()
yield
finally:
with Session(app.session_interface.engine) as session:
session.execute(text("DELETE FROM sessions"))
session.close()

def retrieve_stored_session(self, key, app):
with Session(app.session_interface.engine) as session:
session_model = session.scalars(
select(app.session_interface.sql_session_model)
.where(app.session_interface.sql_session_model.session_id == key)
).first()
if session_model:
return session_model.data
return None

def test_use_signer(self, app_utils):
engine = create_engine("sqlite:///")
app = app_utils.create_app(
{
"SESSION_TYPE": "sqlalchemy_native",
"SESSION_SQLALCHEMY_ENGINE": engine,
}
)
with app.app_context() and self.setup_sqlalchemy(
app
) and app.test_request_context():
assert isinstance(
flask.session,
NativeSqlAlchemySession,
)
app_utils.test_session(app)

# Check if the session is stored in SQLAlchemy
cookie = app_utils.test_session_with_cookie(app)
session_id = cookie.split(";")[0].split("=")[1]
byte_string = self.retrieve_stored_session(f"session:{session_id}", app)
stored_session = (
json.loads(byte_string.decode("utf-8")) if byte_string else {}
)
assert stored_session.get("value") == "44"
Loading