Skip to content

Commit

Permalink
feat: added customisable schema + query manager class
Browse files Browse the repository at this point in the history
  • Loading branch information
giuppep committed May 20, 2023
1 parent 040180b commit 9860830
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 42 deletions.
109 changes: 77 additions & 32 deletions flask_pg_session/_queries.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,77 @@
CREATE_TABLE = """CREATE TABLE IF NOT EXISTS {table} (
session_id VARCHAR(255) NOT NULL PRIMARY KEY,
created TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'),
data BYTEA,
expiry TIMESTAMP WITHOUT TIME ZONE
);
--- Unique session_id
CREATE UNIQUE INDEX IF NOT EXISTS
uq_{table}_session_id ON {table} (session_id);
--- Index for expiry timestamp
CREATE INDEX IF NOT EXISTS
{table}_expiry_idx ON {table} (expiry);
"""

RETRIEVE_SESSION_DATA = """--- If the current sessions is expired, delete it
DELETE FROM {table} WHERE session_id = %(session_id)s AND expiry < NOW();
--- Else retrieve it
SELECT data FROM {table} WHERE session_id = %(session_id)s;
"""


UPSERT_SESSION = """INSERT INTO {table} (session_id, data, expiry)
VALUES (%(session_id)s, %(data)s, %(expiry)s)
ON CONFLICT (session_id)
DO UPDATE SET data = %(data)s, expiry = %(expiry)s;
"""


DELETE_EXPIRED_SESSIONS = "DELETE FROM {table} WHERE expiry < NOW();"
DELETE_SESSION = "DELETE FROM {table} WHERE session_id = %(session_id)s"
from psycopg2 import sql


class Queries:
def __init__(self, schema: str, table: str) -> None:
"""Class to hold all the queries used by the session interface.
Args:
schema (str): The name of the schema to use for the session data.
table (str): The name of the table to use for the session data.
"""
self.schema = schema
self.table = table

@property
def create_schema(self) -> str:
return sql.SQL("CREATE SCHEMA IF NOT EXISTS {schema};").format(
schema=sql.Identifier(self.schema)
)

@property
def create_table(self) -> str:
uq_idx = sql.Identifier(f"uq_{self.table}_session_id")
expiry_idx = sql.Identifier(f"{self.table}_expiry_idx")
return sql.SQL(
"""CREATE TABLE IF NOT EXISTS {schema}.{table} (
session_id VARCHAR(255) NOT NULL PRIMARY KEY,
created TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'),
data BYTEA,
expiry TIMESTAMP WITHOUT TIME ZONE
);
--- Unique session_id
CREATE UNIQUE INDEX IF NOT EXISTS
{uq_idx} ON {schema}.{table} (session_id);
--- Index for expiry timestamp
CREATE INDEX IF NOT EXISTS
{expiry_idx} ON {schema}.{table} (expiry);"""
).format(
schema=sql.Identifier(self.schema),
table=sql.Identifier(self.table),
uq_idx=uq_idx,
expiry_idx=expiry_idx,
)

@property
def retrieve_session_data(self) -> str:
return sql.SQL(
"""--- If the current sessions is expired, delete it
DELETE FROM {schema}.{table} WHERE session_id = %(session_id)s AND expiry < NOW();
--- Else retrieve it
SELECT data FROM {schema}.{table} WHERE session_id = %(session_id)s;
"""
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))

@property
def upsert_session(self) -> str:
return sql.SQL(
"""INSERT INTO {schema}.{table} (session_id, data, expiry)
VALUES (%(session_id)s, %(data)s, %(expiry)s)
ON CONFLICT (session_id)
DO UPDATE SET data = %(data)s, expiry = %(expiry)s;
"""
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))

@property
def delete_expired_sessions(self) -> str:
return sql.SQL("DELETE FROM {schema}.{table} WHERE expiry < NOW();").format(
schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)
)

@property
def delete_session(self) -> str:
return sql.SQL(
"DELETE FROM {schema}.{table} WHERE session_id = %(session_id)s;"
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table))
28 changes: 18 additions & 10 deletions flask_pg_session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from psycopg2.pool import ThreadedConnectionPool
from werkzeug.datastructures import CallbackDict

from . import _queries as queries
from ._queries import Queries
from .utils import retry_query

logger = logging.getLogger(__name__)

DEFAULT_TABLE_NAME = "flask_sessions"
DEFAULT_SCHEMA_NAME = "public"
DEFAULT_KEY_PREFIX = ""
DEFAULT_USE_SIGNER = False
DELETE_EXPIRED_SESSIONS_EVERY_REQUESTS = 1000
Expand Down Expand Up @@ -58,7 +59,8 @@ def init_app(cls, app: Flask) -> "FlaskPgSession":
"""Initialize the Flask-PgSession extension using the app's configuration."""
session_interface = cls(
app.config["SQLALCHEMY_DATABASE_URI"],
table_name=app.config.get("SESSION_SQLALCHEMY_TABLE", DEFAULT_TABLE_NAME),
table_name=app.config.get("SESSION_PG_TABLE", DEFAULT_TABLE_NAME),
schema_name=app.config.get("SESSION_PG_SCHEMA", DEFAULT_SCHEMA_NAME),
key_prefix=app.config.get("SESSION_KEY_PREFIX", DEFAULT_KEY_PREFIX),
use_signer=app.config.get("SESSION_USE_SIGNER", DEFAULT_USE_SIGNER),
permanent=app.config.get("SESSION_PERMANENT", True),
Expand All @@ -71,6 +73,7 @@ def __init__(
uri: str,
*,
table_name: str = DEFAULT_TABLE_NAME,
schema_name: str = DEFAULT_SCHEMA_NAME,
key_prefix: str = DEFAULT_KEY_PREFIX,
use_signer: bool = DEFAULT_USE_SIGNER,
permanent: bool = True,
Expand All @@ -83,6 +86,8 @@ def __init__(
uri (str): The database URI to connect to.
table_name (str, optional): The name of the table to store sessions in.
Defaults to "flask_sessions".
schema_name (str, optional): The name of the schema to store sessions in.
Defaults to "public".
key_prefix (str, optional): The prefix to prepend to the session ID when
storing it in the database. Defaults to "".
use_signer (bool, optional): Whether to use a signer to sign the session.
Expand All @@ -96,14 +101,16 @@ def __init__(
"""
self.pool = ThreadedConnectionPool(1, max_db_conn, uri)
self.key_prefix = key_prefix
self.table_name = table_name

self.permanent = permanent
self.use_signer = use_signer
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")

self.autodelete_expired_sessions = autodelete_expired_sessions

self._create_table(self.table_name)
self._queries = Queries(schema_name, table_name)

self._create_schema_and_table()

# HELPERS

Expand Down Expand Up @@ -151,28 +158,29 @@ def _get_cursor(
self.pool.putconn(_conn)

@retry_query(max_attempts=3)
def _create_table(self, table_name: str) -> None:
def _create_schema_and_table(self) -> None:
with self._get_cursor() as cur:
cur.execute(queries.CREATE_TABLE.format(table=table_name))
cur.execute(self._queries.create_schema)
cur.execute(self._queries.create_table)

def _delete_expired_sessions(self) -> None:
"""Delete all expired sessions from the database."""
with self._get_cursor() as cur:
cur.execute(queries.DELETE_EXPIRED_SESSIONS.format(table=self.table_name))
cur.execute(self._queries.delete_expired_sessions)

@retry_query(max_attempts=3)
def _delete_session(self, sid: str) -> None:
with self._get_cursor() as cur:
cur.execute(
queries.DELETE_SESSION.format(table=self.table_name),
self._queries.delete_session,
dict(session_id=self._get_store_id(sid)),
)

@retry_query(max_attempts=3)
def _retrieve_session_data(self, sid: str) -> bytes | None:
with self._get_cursor() as cur:
cur.execute(
queries.RETRIEVE_SESSION_DATA.format(table=self.table_name),
self._queries.retrieve_session_data,
dict(session_id=self._get_store_id(sid)),
)
data = cur.fetchone()
Expand All @@ -186,7 +194,7 @@ def _update_session(
data = self.serializer.dumps(dict(session))
with self._get_cursor() as cur:
cur.execute(
queries.UPSERT_SESSION.format(table=self.table_name),
self._queries.upsert_session,
dict(session_id=self._get_store_id(sid), data=data, expiry=expires),
)

Expand Down

0 comments on commit 9860830

Please sign in to comment.