-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added customisable schema + query manager class
- Loading branch information
Showing
2 changed files
with
95 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,77 @@ | ||
CREATE_TABLE = """CREATE TABLE IF NOT EXISTS {table} ( | ||
session_id VARCHAR(255) NOT NULL PRIMARY KEY, | ||
created TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'), | ||
data BYTEA, | ||
expiry TIMESTAMP WITHOUT TIME ZONE | ||
); | ||
--- Unique session_id | ||
CREATE UNIQUE INDEX IF NOT EXISTS | ||
uq_{table}_session_id ON {table} (session_id); | ||
--- Index for expiry timestamp | ||
CREATE INDEX IF NOT EXISTS | ||
{table}_expiry_idx ON {table} (expiry); | ||
""" | ||
|
||
RETRIEVE_SESSION_DATA = """--- If the current sessions is expired, delete it | ||
DELETE FROM {table} WHERE session_id = %(session_id)s AND expiry < NOW(); | ||
--- Else retrieve it | ||
SELECT data FROM {table} WHERE session_id = %(session_id)s; | ||
""" | ||
|
||
|
||
UPSERT_SESSION = """INSERT INTO {table} (session_id, data, expiry) | ||
VALUES (%(session_id)s, %(data)s, %(expiry)s) | ||
ON CONFLICT (session_id) | ||
DO UPDATE SET data = %(data)s, expiry = %(expiry)s; | ||
""" | ||
|
||
|
||
DELETE_EXPIRED_SESSIONS = "DELETE FROM {table} WHERE expiry < NOW();" | ||
DELETE_SESSION = "DELETE FROM {table} WHERE session_id = %(session_id)s" | ||
from psycopg2 import sql | ||
|
||
|
||
class Queries: | ||
def __init__(self, schema: str, table: str) -> None: | ||
"""Class to hold all the queries used by the session interface. | ||
Args: | ||
schema (str): The name of the schema to use for the session data. | ||
table (str): The name of the table to use for the session data. | ||
""" | ||
self.schema = schema | ||
self.table = table | ||
|
||
@property | ||
def create_schema(self) -> str: | ||
return sql.SQL("CREATE SCHEMA IF NOT EXISTS {schema};").format( | ||
schema=sql.Identifier(self.schema) | ||
) | ||
|
||
@property | ||
def create_table(self) -> str: | ||
uq_idx = sql.Identifier(f"uq_{self.table}_session_id") | ||
expiry_idx = sql.Identifier(f"{self.table}_expiry_idx") | ||
return sql.SQL( | ||
"""CREATE TABLE IF NOT EXISTS {schema}.{table} ( | ||
session_id VARCHAR(255) NOT NULL PRIMARY KEY, | ||
created TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'), | ||
data BYTEA, | ||
expiry TIMESTAMP WITHOUT TIME ZONE | ||
); | ||
--- Unique session_id | ||
CREATE UNIQUE INDEX IF NOT EXISTS | ||
{uq_idx} ON {schema}.{table} (session_id); | ||
--- Index for expiry timestamp | ||
CREATE INDEX IF NOT EXISTS | ||
{expiry_idx} ON {schema}.{table} (expiry);""" | ||
).format( | ||
schema=sql.Identifier(self.schema), | ||
table=sql.Identifier(self.table), | ||
uq_idx=uq_idx, | ||
expiry_idx=expiry_idx, | ||
) | ||
|
||
@property | ||
def retrieve_session_data(self) -> str: | ||
return sql.SQL( | ||
"""--- If the current sessions is expired, delete it | ||
DELETE FROM {schema}.{table} WHERE session_id = %(session_id)s AND expiry < NOW(); | ||
--- Else retrieve it | ||
SELECT data FROM {schema}.{table} WHERE session_id = %(session_id)s; | ||
""" | ||
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) | ||
|
||
@property | ||
def upsert_session(self) -> str: | ||
return sql.SQL( | ||
"""INSERT INTO {schema}.{table} (session_id, data, expiry) | ||
VALUES (%(session_id)s, %(data)s, %(expiry)s) | ||
ON CONFLICT (session_id) | ||
DO UPDATE SET data = %(data)s, expiry = %(expiry)s; | ||
""" | ||
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) | ||
|
||
@property | ||
def delete_expired_sessions(self) -> str: | ||
return sql.SQL("DELETE FROM {schema}.{table} WHERE expiry < NOW();").format( | ||
schema=sql.Identifier(self.schema), table=sql.Identifier(self.table) | ||
) | ||
|
||
@property | ||
def delete_session(self) -> str: | ||
return sql.SQL( | ||
"DELETE FROM {schema}.{table} WHERE session_id = %(session_id)s;" | ||
).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters