From 08179e797c67f069a989f2d77dc11e5221cfe314 Mon Sep 17 00:00:00 2001 From: Lex Date: Mon, 12 Feb 2024 09:35:43 +1000 Subject: [PATCH] Add retry decorator Formatting --- src/flask_session/_utils.py | 62 +++++++++++++++++++++++++++++++++++ src/flask_session/sessions.py | 44 +++++++++++++++---------- 2 files changed, 88 insertions(+), 18 deletions(-) create mode 100644 src/flask_session/_utils.py diff --git a/src/flask_session/_utils.py b/src/flask_session/_utils.py new file mode 100644 index 00000000..9967fa10 --- /dev/null +++ b/src/flask_session/_utils.py @@ -0,0 +1,62 @@ +""" +MIT License + +Copyright (c) 2023 giuppep + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import time +from functools import wraps +from typing import Any, Callable + +from flask import current_app + + +def retry_query( + *, max_attempts: int = 3, delay: float = 0.3, backoff: int = 2 +) -> Callable[..., Any]: + """Decorator to retry a query when an OperationalError is raised. + + Args: + max_attempts: Maximum number of attempts. Defaults to 3. + delay: Delay between attempts in seconds. Defaults to 0.3. + backoff: Backoff factor. Defaults to 2. + """ + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + for attempt in range(max_attempts): + try: + return func(*args, **kwargs) + # TODO: use proper exception type + except Exception as e: + if attempt == max_attempts - 1: + raise e + + sleep_time = delay * backoff**attempt + current_app.logger.exception( + f"Exception when querying database ({e})." + f"Retrying ({attempt + 1}/{max_attempts}) in {sleep_time:.2f}s." + ) + time.sleep(sleep_time) + + return wrapper + + return decorator diff --git a/src/flask_session/sessions.py b/src/flask_session/sessions.py index 38429cde..7b502fa5 100644 --- a/src/flask_session/sessions.py +++ b/src/flask_session/sessions.py @@ -19,6 +19,7 @@ from sqlalchemy.exc import SQLAlchemyError from werkzeug.datastructures import CallbackDict +from ._utils import retry_query from .defaults import Defaults @@ -219,17 +220,21 @@ def _cleanup_n_requests(self) -> None: # METHODS TO BE IMPLEMENTED BY SUBCLASSES + @retry_query() def _retrieve_session_data(self, store_id: str) -> Optional[dict]: raise NotImplementedError() + @retry_query() def _delete_session(self, store_id: str) -> None: raise NotImplementedError() + @retry_query() def _upsert_session( self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str ) -> None: raise NotImplementedError() + @retry_query() def _delete_expired_sessions(self) -> None: """Delete expired sessions from the backend storage. Only required for non-TTL databases.""" pass @@ -271,6 +276,7 @@ def __init__( self.redis = redis super().__init__(app, key_prefix, use_signer, permanent, sid_length) + @retry_query() def _retrieve_session_data(self, store_id: str) -> Optional[dict]: # Get the saved session (value) from the database serialized_session_data = self.redis.get(store_id) @@ -282,9 +288,11 @@ def _retrieve_session_data(self, store_id: str) -> Optional[dict]: self.app.logger.error("Failed to unpickle session data", exc_info=True) return None + @retry_query() def _delete_session(self, store_id: str) -> None: self.redis.delete(store_id) + @retry_query() def _upsert_session( self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str ) -> None: @@ -315,7 +323,6 @@ class MemcachedSessionInterface(ServerSideSessionInterface): .. versionadded:: 0.2 The `use_signer` parameter was added. - """ serializer = pickle @@ -363,6 +370,7 @@ def _get_memcache_timeout(self, timeout: int) -> int: timeout += int(time.time()) return timeout + @retry_query() def _retrieve_session_data(self, store_id: str) -> Optional[dict]: # Get the saved session (item) from the database serialized_session_data = self.client.get(store_id) @@ -374,9 +382,11 @@ def _retrieve_session_data(self, store_id: str) -> Optional[dict]: self.app.logger.error("Failed to unpickle session data", exc_info=True) return None + @retry_query() def _delete_session(self, store_id: str) -> None: self.client.delete(store_id) + @retry_query() def _upsert_session( self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str ) -> None: @@ -432,13 +442,16 @@ def __init__( self.cache = FileSystemCache(cache_dir, threshold=threshold, mode=mode) super().__init__(app, key_prefix, use_signer, permanent, sid_length) + @retry_query() def _retrieve_session_data(self, store_id: str) -> Optional[dict]: # Get the saved session (item) from the database return self.cache.get(store_id) + @retry_query() def _delete_session(self, store_id: str) -> None: self.cache.delete(store_id) + @retry_query() def _upsert_session( self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str ) -> None: @@ -502,6 +515,7 @@ def __init__( super().__init__(app, key_prefix, use_signer, permanent, sid_length) + @retry_query() def _retrieve_session_data(self, store_id: str) -> Optional[dict]: # Get the saved session (document) from the database document = self.store.find_one({"id": store_id}) @@ -514,12 +528,14 @@ def _retrieve_session_data(self, store_id: str) -> Optional[dict]: self.app.logger.error("Failed to unpickle session data", exc_info=True) return None + @retry_query() def _delete_session(self, store_id: str) -> None: if self.use_deprecated_method: self.store.remove({"id": store_id}) else: self.store.delete_one({"id": store_id}) + @retry_query() def _upsert_session( self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str ) -> None: @@ -645,19 +661,18 @@ def __repr__(self): self.sql_session_model = Session + @retry_query() def _delete_expired_sessions(self) -> None: try: self.db.session.query(self.sql_session_model).filter( self.sql_session_model.expiry <= datetime.utcnow() ).delete(synchronize_session=False) self.db.session.commit() - except SQLAlchemyError as e: - self.app.logger.exception( - e, "Failed to delete expired sessions. Retrying...", exc_info=True - ) + except SQLAlchemyError: self.db.session.rollback() raise + @retry_query() def _retrieve_session_data(self, store_id: str) -> Optional[dict]: # Get the saved session (record) from the database record = self.sql_session_model.query.filter_by(session_id=store_id).first() @@ -667,10 +682,7 @@ def _retrieve_session_data(self, store_id: str) -> Optional[dict]: try: self.db.session.delete(record) self.db.session.commit() - except SQLAlchemyError as e: - self.app.logger.exception( - e, "Failed to retrieve sessions. Retrying...", exc_info=True - ) + except SQLAlchemyError: self.db.session.rollback() raise record = None @@ -681,22 +693,21 @@ def _retrieve_session_data(self, store_id: str) -> Optional[dict]: session_data = self.serializer.loads(serialized_session_data) return session_data except pickle.UnpicklingError as e: - self.app.logger.error( + self.app.logger.exception( e, "Failed to unpickle session data", exc_info=True ) return None + @retry_query() def _delete_session(self, store_id: str) -> None: try: self.sql_session_model.query.filter_by(session_id=store_id).delete() self.db.session.commit() - except SQLAlchemyError as e: - self.app.logger.exception( - e, "Failed to delete session. Retrying...", exc_info=True - ) + except SQLAlchemyError: self.db.session.rollback() raise + @retry_query() def _upsert_session( self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str ) -> None: @@ -719,9 +730,6 @@ def _upsert_session( ) self.db.session.add(record) self.db.session.commit() - except SQLAlchemyError as e: - self.app.logger.exception( - e, "Failed to upsert session. Retrying...", exc_info=True - ) + except SQLAlchemyError: self.db.session.rollback() raise