From 58cb47772a228d16fba47266e7bd643b721f62c8 Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 29 Aug 2022 10:13:50 -0700 Subject: [PATCH 01/27] move extension --- src/flask_sqlalchemy/__init__.py | 1088 +---------------------------- src/flask_sqlalchemy/extension.py | 1085 ++++++++++++++++++++++++++++ tests/test_basic_app.py | 2 +- tests/test_binds.py | 2 +- tests/test_pagination.py | 2 +- tests/test_query_class.py | 2 +- tests/test_signals.py | 4 +- tests/test_sqlalchemy_includes.py | 2 +- 8 files changed, 1096 insertions(+), 1091 deletions(-) create mode 100644 src/flask_sqlalchemy/extension.py diff --git a/src/flask_sqlalchemy/__init__.py b/src/flask_sqlalchemy/__init__.py index 192788af..d54cf4a7 100644 --- a/src/flask_sqlalchemy/__init__.py +++ b/src/flask_sqlalchemy/__init__.py @@ -1,1087 +1,7 @@ -import functools -import os -import sys -import warnings -from math import ceil -from operator import itemgetter -from threading import Lock -from time import perf_counter - -import sqlalchemy -from flask import _app_ctx_stack -from flask import abort -from flask import current_app -from flask import request -from flask.signals import Namespace -from sqlalchemy import event -from sqlalchemy import inspect -from sqlalchemy import orm -from sqlalchemy.engine.url import make_url -from sqlalchemy.orm.exc import UnmappedClassError -from sqlalchemy.orm.session import Session as SessionBase - -from .model import DefaultMeta -from .model import Model - -try: - from sqlalchemy.orm import declarative_base - from sqlalchemy.orm import DeclarativeMeta -except ImportError: - # SQLAlchemy <= 1.3 - from sqlalchemy.ext.declarative import declarative_base - from sqlalchemy.ext.declarative import DeclarativeMeta - -# Scope the session to the current greenlet if greenlet is available, -# otherwise fall back to the current thread. -try: - from greenlet import getcurrent as _ident_func -except ImportError: - from threading import get_ident as _ident_func +from .extension import SQLAlchemy __version__ = "3.0.0.dev0" -_signals = Namespace() -models_committed = _signals.signal("models-committed") -before_models_committed = _signals.signal("before-models-committed") - - -def _sa_url_set(url, **kwargs): - try: - url = url.set(**kwargs) - except AttributeError: - # SQLAlchemy <= 1.3 - for key, value in kwargs.items(): - setattr(url, key, value) - - return url - - -def _sa_url_query_setdefault(url, **kwargs): - query = dict(url.query) - - for key, value in kwargs.items(): - query.setdefault(key, value) - - return _sa_url_set(url, query=query) - - -def _make_table(db): - def _make_table(*args, **kwargs): - if len(args) > 1 and isinstance(args[1], db.Column): - args = (args[0], db.metadata) + args[1:] - info = kwargs.pop("info", None) or {} - info.setdefault("bind_key", None) - kwargs["info"] = info - return sqlalchemy.Table(*args, **kwargs) - - return _make_table - - -def _set_default_query_class(d, cls): - if "query_class" not in d: - d["query_class"] = cls - - -def _wrap_with_default_query_class(fn, cls): - @functools.wraps(fn) - def newfn(*args, **kwargs): - _set_default_query_class(kwargs, cls) - if "backref" in kwargs: - backref = kwargs["backref"] - if isinstance(backref, str): - backref = (backref, {}) - _set_default_query_class(backref[1], cls) - return fn(*args, **kwargs) - - return newfn - - -def _include_sqlalchemy(obj, cls): - for module in sqlalchemy, sqlalchemy.orm: - for key in module.__all__: - if not hasattr(obj, key): - setattr(obj, key, getattr(module, key)) - # Note: obj.Table does not attempt to be a SQLAlchemy Table class. - obj.Table = _make_table(obj) - obj.relationship = _wrap_with_default_query_class(obj.relationship, cls) - obj.relation = _wrap_with_default_query_class(obj.relation, cls) - obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader, cls) - obj.event = event - - -class _DebugQueryTuple(tuple): - statement = property(itemgetter(0)) - parameters = property(itemgetter(1)) - start_time = property(itemgetter(2)) - end_time = property(itemgetter(3)) - context = property(itemgetter(4)) - - @property - def duration(self): - return self.end_time - self.start_time - - def __repr__(self): - return ( - f"" - ) - - -def _calling_context(app_path): - frm = sys._getframe(1) - while frm.f_back is not None: - name = frm.f_globals.get("__name__") - if name and (name == app_path or name.startswith(f"{app_path}.")): - funcname = frm.f_code.co_name - return f"{frm.f_code.co_filename}:{frm.f_lineno} ({funcname})" - frm = frm.f_back - return "" - - -class SignallingSession(SessionBase): - """The signalling session is the default session that Flask-SQLAlchemy - uses. It extends the default session system with bind selection and - modification tracking. - - If you want to use a different session you can override the - :meth:`SQLAlchemy.create_session` function. - - .. versionadded:: 2.0 - - .. versionadded:: 2.1 - The `binds` option was added, which allows a session to be joined - to an external transaction. - """ - - def __init__(self, db, autocommit=False, autoflush=True, **options): - #: The application that this session belongs to. - self.app = app = db.get_app() - track_modifications = app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] - bind = options.pop("bind", None) or db.engine - binds = options.pop("binds", db.get_binds(app)) - - if track_modifications: - _SessionSignalEvents.register(self) - - SessionBase.__init__( - self, - autocommit=autocommit, - autoflush=autoflush, - bind=bind, - binds=binds, - **options, - ) - - def get_bind(self, mapper=None, **kwargs): - """Return the engine or connection for a given model or - table, using the ``__bind_key__`` if it is set. - """ - # mapper is None if someone tries to just get a connection - if mapper is not None: - try: - # SA >= 1.3 - persist_selectable = mapper.persist_selectable - except AttributeError: - # SA < 1.3 - persist_selectable = mapper.mapped_table - - info = getattr(persist_selectable, "info", {}) - bind_key = info.get("bind_key") - if bind_key is not None: - state = get_state(self.app) - return state.db.get_engine(self.app, bind=bind_key) - - return super().get_bind(mapper, **kwargs) - - -class _SessionSignalEvents: - @classmethod - def register(cls, session): - if not hasattr(session, "_model_changes"): - session._model_changes = {} - - event.listen(session, "before_flush", cls.record_ops) - event.listen(session, "before_commit", cls.record_ops) - event.listen(session, "before_commit", cls.before_commit) - event.listen(session, "after_commit", cls.after_commit) - event.listen(session, "after_rollback", cls.after_rollback) - - @classmethod - def unregister(cls, session): - if hasattr(session, "_model_changes"): - del session._model_changes - - event.remove(session, "before_flush", cls.record_ops) - event.remove(session, "before_commit", cls.record_ops) - event.remove(session, "before_commit", cls.before_commit) - event.remove(session, "after_commit", cls.after_commit) - event.remove(session, "after_rollback", cls.after_rollback) - - @staticmethod - def record_ops(session, flush_context=None, instances=None): - try: - d = session._model_changes - except AttributeError: - return - - for targets, operation in ( - (session.new, "insert"), - (session.dirty, "update"), - (session.deleted, "delete"), - ): - for target in targets: - state = inspect(target) - key = state.identity_key if state.has_identity else id(target) - d[key] = (target, operation) - - @staticmethod - def before_commit(session): - try: - d = session._model_changes - except AttributeError: - return - - if d: - before_models_committed.send(session.app, changes=list(d.values())) - - @staticmethod - def after_commit(session): - try: - d = session._model_changes - except AttributeError: - return - - if d: - models_committed.send(session.app, changes=list(d.values())) - d.clear() - - @staticmethod - def after_rollback(session): - try: - d = session._model_changes - except AttributeError: - return - - d.clear() - - -class _EngineDebuggingSignalEvents: - """Sets up handlers for two events that let us track the execution time of - queries.""" - - def __init__(self, engine, import_name): - self.engine = engine - self.app_package = import_name - - def register(self): - event.listen(self.engine, "before_cursor_execute", self.before_cursor_execute) - event.listen(self.engine, "after_cursor_execute", self.after_cursor_execute) - - def before_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): - if current_app: - context._query_start_time = perf_counter() - - def after_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): - if current_app: - try: - queries = _app_ctx_stack.top.sqlalchemy_queries - except AttributeError: - queries = _app_ctx_stack.top.sqlalchemy_queries = [] - - queries.append( - _DebugQueryTuple( - ( - statement, - parameters, - context._query_start_time, - perf_counter(), - _calling_context(self.app_package), - ) - ) - ) - - -def get_debug_queries(): - """In debug mode or testing mode, Flask-SQLAlchemy will log all the SQL - queries sent to the database. This information is available until the end - of request which makes it possible to easily ensure that the SQL generated - is the one expected on errors or in unittesting. Alternatively, you can also - enable the query recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` - config variable to `True`. - - The value returned will be a list of named tuples with the following - attributes: - - `statement` - The SQL statement issued - - `parameters` - The parameters for the SQL statement - - `start_time` / `end_time` - Time the query started / the results arrived. Please keep in mind - that the timer function used depends on your platform. These - values are only useful for sorting or comparing. They do not - necessarily represent an absolute timestamp. - - `duration` - Time the query took in seconds - - `context` - A string giving a rough estimation of where in your application - query was issued. The exact format is undefined so don't try - to reconstruct filename or function name. - """ - return getattr(_app_ctx_stack.top, "sqlalchemy_queries", []) - - -class Pagination: - """Internal helper class returned by :meth:`BaseQuery.paginate`. You - can also construct it from any other SQLAlchemy query object if you are - working with other libraries. Additionally it is possible to pass `None` - as query object in which case the :meth:`prev` and :meth:`next` will - no longer work. - """ - - def __init__(self, query, page, per_page, total, items): - #: the unlimited query object that was used to create this - #: pagination object. - self.query = query - #: the current page number (1 indexed) - self.page = page - #: the number of items to be displayed on a page. - self.per_page = per_page - #: the total number of items matching the query - self.total = total - #: the items for the current page - self.items = items - - @property - def pages(self): - """The total number of pages""" - if self.per_page == 0 or self.total is None: - pages = 0 - else: - pages = int(ceil(self.total / float(self.per_page))) - return pages - - def prev(self, error_out=False): - """Returns a :class:`Pagination` object for the previous page.""" - assert ( - self.query is not None - ), "a query object is required for this method to work" - return self.query.paginate(self.page - 1, self.per_page, error_out) - - @property - def prev_num(self): - """Number of the previous page.""" - if not self.has_prev: - return None - return self.page - 1 - - @property - def has_prev(self): - """True if a previous page exists""" - return self.page > 1 - - def next(self, error_out=False): - """Returns a :class:`Pagination` object for the next page.""" - assert ( - self.query is not None - ), "a query object is required for this method to work" - return self.query.paginate(self.page + 1, self.per_page, error_out) - - @property - def has_next(self): - """True if a next page exists.""" - return self.page < self.pages - - @property - def next_num(self): - """Number of the next page""" - if not self.has_next: - return None - return self.page + 1 - - def iter_pages(self, left_edge=2, left_current=2, right_current=5, right_edge=2): - """Iterates over the page numbers in the pagination. The four - parameters control the thresholds how many numbers should be produced - from the sides. Skipped page numbers are represented as `None`. - This is how you could render such a pagination in the templates: - - .. sourcecode:: html+jinja - - {% macro render_pagination(pagination, endpoint) %} - - {% endmacro %} - """ - last = 0 - for num in range(1, self.pages + 1): - if ( - num <= left_edge - or ( - num > self.page - left_current - 1 - and num < self.page + right_current - ) - or num > self.pages - right_edge - ): - if last + 1 != num: - yield None - yield num - last = num - - -class BaseQuery(orm.Query): - """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with - convenience methods for querying in a web application. - - This is the default :attr:`~Model.query` object used for models, and - exposed as :attr:`~SQLAlchemy.Query`. Override the query class for - an individual model by subclassing this and setting - :attr:`~Model.query_class`. - """ - - def get_or_404(self, ident, description=None): - """Like :meth:`get` but aborts with 404 if not found instead of - returning ``None``. - """ - rv = self.get(ident) - if rv is None: - abort(404, description=description) - return rv - - def first_or_404(self, description=None): - """Like :meth:`first` but aborts with 404 if not found instead - of returning ``None``. - """ - rv = self.first() - if rv is None: - abort(404, description=description) - return rv - - def paginate( - self, page=None, per_page=None, error_out=True, max_per_page=None, count=True - ): - """Returns ``per_page`` items from page ``page``. - - If ``page`` or ``per_page`` are ``None``, they will be retrieved from - the request query. If ``max_per_page`` is specified, ``per_page`` will - be limited to that value. If there is no request or they aren't in the - query, they default to 1 and 20 respectively. If ``count`` is ``False``, - no query to help determine total page count will be run. - - When ``error_out`` is ``True`` (default), the following rules will - cause a 404 response: - - * No items are found and ``page`` is not 1. - * ``page`` is less than 1, or ``per_page`` is negative. - * ``page`` or ``per_page`` are not ints. - - When ``error_out`` is ``False``, ``page`` and ``per_page`` default to - 1 and 20 respectively. - - Returns a :class:`Pagination` object. - """ - - if request: - if page is None: - try: - page = int(request.args.get("page", 1)) - except (TypeError, ValueError): - if error_out: - abort(404) - - page = 1 - - if per_page is None: - try: - per_page = int(request.args.get("per_page", 20)) - except (TypeError, ValueError): - if error_out: - abort(404) - - per_page = 20 - else: - if page is None: - page = 1 - - if per_page is None: - per_page = 20 - - if max_per_page is not None: - per_page = min(per_page, max_per_page) - - if page < 1: - if error_out: - abort(404) - else: - page = 1 - - if per_page < 0: - if error_out: - abort(404) - else: - per_page = 20 - - items = self.limit(per_page).offset((page - 1) * per_page).all() - - if not items and page != 1 and error_out: - abort(404) - - if not count: - total = None - else: - total = self.order_by(None).count() - - return Pagination(self, page, per_page, total, items) - - -class _QueryProperty: - def __init__(self, sa): - self.sa = sa - - def __get__(self, obj, type): - try: - mapper = orm.class_mapper(type) - if mapper: - return type.query_class(mapper, session=self.sa.session()) - except UnmappedClassError: - return None - - -def _record_queries(app): - if app.debug: - return True - rq = app.config["SQLALCHEMY_RECORD_QUERIES"] - if rq is not None: - return rq - return bool(app.config.get("TESTING")) - - -class _EngineConnector: - def __init__(self, sa, app, bind=None): - self._sa = sa - self._app = app - self._engine = None - self._connected_for = None - self._bind = bind - self._lock = Lock() - - def get_uri(self): - if self._bind is None: - return self._app.config["SQLALCHEMY_DATABASE_URI"] - binds = self._app.config.get("SQLALCHEMY_BINDS") or () - assert ( - self._bind in binds - ), f"Bind {self._bind!r} is not configured in 'SQLALCHEMY_BINDS'." - return binds[self._bind] - - def get_engine(self): - with self._lock: - uri = self.get_uri() - echo = self._app.config["SQLALCHEMY_ECHO"] - if (uri, echo) == self._connected_for: - return self._engine - - sa_url = make_url(uri) - sa_url, options = self.get_options(sa_url, echo) - self._engine = rv = self._sa.create_engine(sa_url, options) - - if _record_queries(self._app): - _EngineDebuggingSignalEvents( - self._engine, self._app.import_name - ).register() - - self._connected_for = (uri, echo) - - return rv - - def get_options(self, sa_url, echo): - options = {} - sa_url, options = self._sa.apply_driver_hacks(self._app, sa_url, options) - - if echo: - options["echo"] = echo - - # Give the config options set by a developer explicitly priority - # over decisions FSA makes. - options.update(self._app.config["SQLALCHEMY_ENGINE_OPTIONS"]) - # Give options set in SQLAlchemy.__init__() ultimate priority - options.update(self._sa._engine_options) - return sa_url, options - - -def get_state(app): - """Gets the state for the application""" - assert "sqlalchemy" in app.extensions, ( - "The sqlalchemy extension was not registered to the current " - "application. Please make sure to call init_app() first." - ) - return app.extensions["sqlalchemy"] - - -class _SQLAlchemyState: - """Remembers configuration for the (db, app) tuple.""" - - def __init__(self, db): - self.db = db - self.connectors = {} - - -class SQLAlchemy: - """This class is used to control the SQLAlchemy integration to one - or more Flask applications. Depending on how you initialize the - object it is usable right away or will attach as needed to a - Flask application. - - There are two usage modes which work very similarly. One is binding - the instance to a very specific Flask application:: - - app = Flask(__name__) - db = SQLAlchemy(app) - - The second possibility is to create the object once and configure the - application later to support it:: - - db = SQLAlchemy() - - def create_app(): - app = Flask(__name__) - db.init_app(app) - return app - - The difference between the two is that in the first case methods like - :meth:`create_all` and :meth:`drop_all` will work all the time but in - the second case a :meth:`flask.Flask.app_context` has to exist. - - By default Flask-SQLAlchemy will apply some backend-specific settings - to improve your experience with them. - - This class also provides access to all the SQLAlchemy functions and classes - from the :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` modules. So you can - declare models like this:: - - class User(db.Model): - username = db.Column(db.String(80), unique=True) - pw_hash = db.Column(db.String(80)) - - You can still use :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, but - note that Flask-SQLAlchemy customizations are available only through an - instance of this :class:`SQLAlchemy` class. Query classes default to - :class:`BaseQuery` for `db.Query`, `db.Model.query_class`, and the default - query_class for `db.relationship` and `db.backref`. If you use these - interfaces through :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, - the default query class will be that of :mod:`sqlalchemy`. - - .. admonition:: Check types carefully - - Don't perform type or `isinstance` checks against `db.Table`, which - emulates `Table` behavior but is not a class. `db.Table` exposes the - `Table` interface, but is a function which allows omission of metadata. - - The ``session_options`` parameter, if provided, is a dict of parameters - to be passed to the session constructor. See - :class:`~sqlalchemy.orm.session.Session` for the standard options. - - The ``engine_options`` parameter, if provided, is a dict of parameters - to be passed to create engine. See :func:`~sqlalchemy.create_engine` - for the standard options. The values given here will be merged with and - override anything set in the ``'SQLALCHEMY_ENGINE_OPTIONS'`` config - variable or othewise set by this library. - - .. versionchanged:: 3.0 - Removed the ``use_native_unicode`` parameter and config. - - .. versionchanged:: 3.0 - ``COMMIT_ON_TEARDOWN`` is deprecated and will be removed in - version 3.1. Call ``db.session.commit()`` directly instead. - - .. versionchanged:: 2.4 - Added the ``engine_options`` parameter. - - .. versionchanged:: 2.1 - Added the ``metadata`` parameter. This allows for setting custom - naming conventions among other, non-trivial things. - - .. versionchanged:: 2.1 - Added the ``query_class`` parameter, to allow customisation - of the query class, in place of the default of - :class:`BaseQuery`. - - .. versionchanged:: 2.1 - Added the ``model_class`` parameter, which allows a custom model - class to be used in place of :class:`Model`. - - .. versionchanged:: 2.1 - Use the same query class across ``session``, ``Model.query`` and - ``Query``. - - .. versionchanged:: 0.16 - ``scopefunc`` is now accepted on ``session_options``. It allows - specifying a custom function which will define the SQLAlchemy - session's scoping. - - .. versionchanged:: 0.10 - Added the ``session_options`` parameter. - """ - - #: Default query class used by :attr:`Model.query` and other queries. - #: Customize this by passing ``query_class`` to :func:`SQLAlchemy`. - #: Defaults to :class:`BaseQuery`. - Query = None - - def __init__( - self, - app=None, - session_options=None, - metadata=None, - query_class=BaseQuery, - model_class=Model, - engine_options=None, - ): - - self.Query = query_class - self.session = self.create_scoped_session(session_options) - self.Model = self.make_declarative_base(model_class, metadata) - self._engine_lock = Lock() - self.app = app - self._engine_options = engine_options or {} - _include_sqlalchemy(self, query_class) - - if app is not None: - self.init_app(app) - - @property - def metadata(self): - """The metadata associated with ``db.Model``.""" - - return self.Model.metadata - - def create_scoped_session(self, options=None): - """Create a :class:`~sqlalchemy.orm.scoping.scoped_session` - on the factory from :meth:`create_session`. - - An extra key ``'scopefunc'`` can be set on the ``options`` dict to - specify a custom scope function. If it's not provided, Flask's app - context stack identity is used. This will ensure that sessions are - created and removed with the request/response cycle, and should be fine - in most cases. - - :param options: dict of keyword arguments passed to session class in - ``create_session`` - """ - - if options is None: - options = {} - - scopefunc = options.pop("scopefunc", _ident_func) - options.setdefault("query_cls", self.Query) - return orm.scoped_session(self.create_session(options), scopefunc=scopefunc) - - def create_session(self, options): - """Create the session factory used by :meth:`create_scoped_session`. - - The factory **must** return an object that SQLAlchemy recognizes as a session, - or registering session events may raise an exception. - - Valid factories include a :class:`~sqlalchemy.orm.session.Session` - class or a :class:`~sqlalchemy.orm.session.sessionmaker`. - - The default implementation creates a ``sessionmaker`` for - :class:`SignallingSession`. - - :param options: dict of keyword arguments passed to session class - """ - - return orm.sessionmaker(class_=SignallingSession, db=self, **options) - - def make_declarative_base(self, model, metadata=None): - """Creates the declarative base that all models will inherit from. - - :param model: base model class (or a tuple of base classes) to pass - to :func:`~sqlalchemy.ext.declarative.declarative_base`. Or a class - returned from ``declarative_base``, in which case a new base class - is not created. - :param metadata: :class:`~sqlalchemy.MetaData` instance to use, or - none to use SQLAlchemy's default. - - .. versionchanged 2.3.0:: - ``model`` can be an existing declarative base in order to support - complex customization such as changing the metaclass. - """ - if not isinstance(model, DeclarativeMeta): - model = declarative_base( - cls=model, name="Model", metadata=metadata, metaclass=DefaultMeta - ) - - # if user passed in a declarative base and a metaclass for some reason, - # make sure the base uses the metaclass - if metadata is not None and model.metadata is not metadata: - model.metadata = metadata - - if not getattr(model, "query_class", None): - model.query_class = self.Query - - model.query = _QueryProperty(self) - return model - - def init_app(self, app): - """This callback can be used to initialize an application for the - use with this database setup. Never use a database in the context - of an application not initialized that way or connections will - leak. - """ - - # We intentionally don't set self.app = app, to support multiple - # applications. If the app is passed in the constructor, - # we set it and don't support multiple applications. - if not ( - app.config.get("SQLALCHEMY_DATABASE_URI") - or app.config.get("SQLALCHEMY_BINDS") - ): - raise RuntimeError( - "Either SQLALCHEMY_DATABASE_URI or SQLALCHEMY_BINDS needs to be set." - ) - - app.config.setdefault("SQLALCHEMY_DATABASE_URI", None) - app.config.setdefault("SQLALCHEMY_BINDS", None) - app.config.setdefault("SQLALCHEMY_ECHO", False) - app.config.setdefault("SQLALCHEMY_RECORD_QUERIES", None) - app.config.setdefault("SQLALCHEMY_COMMIT_ON_TEARDOWN", False) - app.config.setdefault("SQLALCHEMY_TRACK_MODIFICATIONS", False) - app.config.setdefault("SQLALCHEMY_ENGINE_OPTIONS", {}) - - app.extensions["sqlalchemy"] = _SQLAlchemyState(self) - - @app.teardown_appcontext - def shutdown_session(response_or_exc): - if app.config["SQLALCHEMY_COMMIT_ON_TEARDOWN"]: - warnings.warn( - "'COMMIT_ON_TEARDOWN' is deprecated and will be" - " removed in version 3.1. Call" - " 'db.session.commit()'` directly instead.", - DeprecationWarning, - ) - - if response_or_exc is None: - self.session.commit() - - self.session.remove() - return response_or_exc - - def apply_driver_hacks(self, app, sa_url, options): - """This method is called before engine creation and used to inject - driver specific hacks into the options. The `options` parameter is - a dictionary of keyword arguments that will then be used to call - the :func:`sqlalchemy.create_engine` function. - - The default implementation provides some defaults for things - like pool sizes for MySQL and SQLite. - - .. versionchanged:: 3.0 - Change the default MySQL character set to "utf8mb4". - - .. versionchanged:: 2.5 - Returns ``(sa_url, options)``. SQLAlchemy 1.4 made the URL - immutable, so any changes to it must now be passed back up - to the original caller. - """ - if sa_url.drivername.startswith("mysql"): - sa_url = _sa_url_query_setdefault(sa_url, charset="utf8mb4") - - if sa_url.drivername != "mysql+gaerdbms": - options.setdefault("pool_size", 10) - options.setdefault("pool_recycle", 7200) - elif sa_url.drivername == "sqlite": - pool_size = options.get("pool_size") - detected_in_memory = False - if sa_url.database in (None, "", ":memory:"): - detected_in_memory = True - from sqlalchemy.pool import StaticPool - - options["poolclass"] = StaticPool - if "connect_args" not in options: - options["connect_args"] = {} - options["connect_args"]["check_same_thread"] = False - - # we go to memory and the pool size was explicitly set - # to 0 which is fail. Let the user know that - if pool_size == 0: - raise RuntimeError( - "SQLite in memory database with an " - "empty queue not possible due to data " - "loss." - ) - # if pool size is None or explicitly set to 0 we assume the - # user did not want a queue for this sqlite connection and - # hook in the null pool. - elif not pool_size: - from sqlalchemy.pool import NullPool - - options["poolclass"] = NullPool - - # If the database path is not absolute, it's relative to the - # app instance path, which might need to be created. - if not detected_in_memory and not os.path.isabs(sa_url.database): - os.makedirs(app.instance_path, exist_ok=True) - sa_url = _sa_url_set( - sa_url, database=os.path.join(app.root_path, sa_url.database) - ) - - return sa_url, options - - @property - def engine(self): - """Gives access to the engine. If the database configuration is bound - to a specific application (initialized with an application) this will - always return a database connection. If however the current application - is used this might raise a :exc:`RuntimeError` if no application is - active at the moment. - """ - return self.get_engine() - - def make_connector(self, app=None, bind=None): - """Creates the connector for a given state and bind.""" - return _EngineConnector(self, self.get_app(app), bind) - - def get_engine(self, app=None, bind=None): - """Returns a specific engine.""" - - app = self.get_app(app) - state = get_state(app) - - with self._engine_lock: - connector = state.connectors.get(bind) - - if connector is None: - connector = self.make_connector(app, bind) - state.connectors[bind] = connector - - return connector.get_engine() - - def create_engine(self, sa_url, engine_opts): - """Override this method to have final say over how the - SQLAlchemy engine is created. - - In most cases, you will want to use - ``'SQLALCHEMY_ENGINE_OPTIONS'`` config variable or set - ``engine_options`` for :func:`SQLAlchemy`. - """ - return sqlalchemy.create_engine(sa_url, **engine_opts) - - def get_app(self, reference_app=None): - """Helper method that implements the logic to look up an - application.""" - - if reference_app is not None: - return reference_app - - if current_app: - return current_app._get_current_object() - - if self.app is not None: - return self.app - - raise RuntimeError( - "No application found. Either work inside a view function or push" - " an application context. See" - " https://flask-sqlalchemy.palletsprojects.com/contexts/." - ) - - def get_tables_for_bind(self, bind=None): - """Returns a list of all tables relevant for a bind.""" - result = [] - for table in self.Model.metadata.tables.values(): - if table.info.get("bind_key") == bind: - result.append(table) - return result - - def get_binds(self, app=None): - """Returns a dictionary with a table->engine mapping. - - This is suitable for use of sessionmaker(binds=db.get_binds(app)). - """ - app = self.get_app(app) - binds = [None] + list(app.config.get("SQLALCHEMY_BINDS") or ()) - retval = {} - for bind in binds: - engine = self.get_engine(app, bind) - tables = self.get_tables_for_bind(bind) - retval.update({table: engine for table in tables}) - return retval - - def _execute_for_all_tables(self, app, bind, operation, skip_tables=False): - app = self.get_app(app) - - if bind == "__all__": - binds = [None] + list(app.config.get("SQLALCHEMY_BINDS") or ()) - elif isinstance(bind, str) or bind is None: - binds = [bind] - else: - binds = bind - - for bind in binds: - extra = {} - if not skip_tables: - tables = self.get_tables_for_bind(bind) - extra["tables"] = tables - op = getattr(self.Model.metadata, operation) - op(bind=self.get_engine(app, bind), **extra) - - def create_all(self, bind="__all__", app=None): - """Create all tables that do not already exist in the database. - This does not update existing tables, use a migration library - for that. - - :param bind: A bind key or list of keys to create the tables - for. Defaults to all binds. - :param app: Use this app instead of requiring an app context. - - .. versionchanged:: 0.12 - Added the ``bind`` and ``app`` parameters. - """ - self._execute_for_all_tables(app, bind, "create_all") - - def drop_all(self, bind="__all__", app=None): - """Drop all tables. - - :param bind: A bind key or list of keys to drop the tables for. - Defaults to all binds. - :param app: Use this app instead of requiring an app context. - - .. versionchanged:: 0.12 - Added the ``bind`` and ``app`` parameters. - """ - self._execute_for_all_tables(app, bind, "drop_all") - - def reflect(self, bind="__all__", app=None): - """Reflects tables from the database. - - :param bind: A bind key or list of keys to reflect the tables - from. Defaults to all binds. - :param app: Use this app instead of requiring an app context. - - .. versionchanged:: 0.12 - Added the ``bind`` and ``app`` parameters. - """ - self._execute_for_all_tables(app, bind, "reflect", skip_tables=True) - - def __repr__(self): - url = self.engine.url if self.app or current_app else None - return f"<{type(self).__name__} engine={url!r}>" +__all__ = [ + "SQLAlchemy", +] diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py new file mode 100644 index 00000000..0752cd3b --- /dev/null +++ b/src/flask_sqlalchemy/extension.py @@ -0,0 +1,1085 @@ +import functools +import os +import sys +import warnings +from math import ceil +from operator import itemgetter +from threading import Lock +from time import perf_counter + +import sqlalchemy +from flask import _app_ctx_stack +from flask import abort +from flask import current_app +from flask import request +from flask.signals import Namespace +from sqlalchemy import event +from sqlalchemy import inspect +from sqlalchemy import orm +from sqlalchemy.engine.url import make_url +from sqlalchemy.orm.exc import UnmappedClassError +from sqlalchemy.orm.session import Session as SessionBase + +from .model import DefaultMeta +from .model import Model + +try: + from sqlalchemy.orm import declarative_base + from sqlalchemy.orm import DeclarativeMeta +except ImportError: + # SQLAlchemy <= 1.3 + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.declarative import DeclarativeMeta + +# Scope the session to the current greenlet if greenlet is available, +# otherwise fall back to the current thread. +try: + from greenlet import getcurrent as _ident_func +except ImportError: + from threading import get_ident as _ident_func + +_signals = Namespace() +models_committed = _signals.signal("models-committed") +before_models_committed = _signals.signal("before-models-committed") + + +def _sa_url_set(url, **kwargs): + try: + url = url.set(**kwargs) + except AttributeError: + # SQLAlchemy <= 1.3 + for key, value in kwargs.items(): + setattr(url, key, value) + + return url + + +def _sa_url_query_setdefault(url, **kwargs): + query = dict(url.query) + + for key, value in kwargs.items(): + query.setdefault(key, value) + + return _sa_url_set(url, query=query) + + +def _make_table(db): + def _make_table(*args, **kwargs): + if len(args) > 1 and isinstance(args[1], db.Column): + args = (args[0], db.metadata) + args[1:] + info = kwargs.pop("info", None) or {} + info.setdefault("bind_key", None) + kwargs["info"] = info + return sqlalchemy.Table(*args, **kwargs) + + return _make_table + + +def _set_default_query_class(d, cls): + if "query_class" not in d: + d["query_class"] = cls + + +def _wrap_with_default_query_class(fn, cls): + @functools.wraps(fn) + def newfn(*args, **kwargs): + _set_default_query_class(kwargs, cls) + if "backref" in kwargs: + backref = kwargs["backref"] + if isinstance(backref, str): + backref = (backref, {}) + _set_default_query_class(backref[1], cls) + return fn(*args, **kwargs) + + return newfn + + +def _include_sqlalchemy(obj, cls): + for module in sqlalchemy, sqlalchemy.orm: + for key in module.__all__: + if not hasattr(obj, key): + setattr(obj, key, getattr(module, key)) + # Note: obj.Table does not attempt to be a SQLAlchemy Table class. + obj.Table = _make_table(obj) + obj.relationship = _wrap_with_default_query_class(obj.relationship, cls) + obj.relation = _wrap_with_default_query_class(obj.relation, cls) + obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader, cls) + obj.event = event + + +class _DebugQueryTuple(tuple): + statement = property(itemgetter(0)) + parameters = property(itemgetter(1)) + start_time = property(itemgetter(2)) + end_time = property(itemgetter(3)) + context = property(itemgetter(4)) + + @property + def duration(self): + return self.end_time - self.start_time + + def __repr__(self): + return ( + f"" + ) + + +def _calling_context(app_path): + frm = sys._getframe(1) + while frm.f_back is not None: + name = frm.f_globals.get("__name__") + if name and (name == app_path or name.startswith(f"{app_path}.")): + funcname = frm.f_code.co_name + return f"{frm.f_code.co_filename}:{frm.f_lineno} ({funcname})" + frm = frm.f_back + return "" + + +class SignallingSession(SessionBase): + """The signalling session is the default session that Flask-SQLAlchemy + uses. It extends the default session system with bind selection and + modification tracking. + + If you want to use a different session you can override the + :meth:`SQLAlchemy.create_session` function. + + .. versionadded:: 2.0 + + .. versionadded:: 2.1 + The `binds` option was added, which allows a session to be joined + to an external transaction. + """ + + def __init__(self, db, autocommit=False, autoflush=True, **options): + #: The application that this session belongs to. + self.app = app = db.get_app() + track_modifications = app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] + bind = options.pop("bind", None) or db.engine + binds = options.pop("binds", db.get_binds(app)) + + if track_modifications: + _SessionSignalEvents.register(self) + + SessionBase.__init__( + self, + autocommit=autocommit, + autoflush=autoflush, + bind=bind, + binds=binds, + **options, + ) + + def get_bind(self, mapper=None, **kwargs): + """Return the engine or connection for a given model or + table, using the ``__bind_key__`` if it is set. + """ + # mapper is None if someone tries to just get a connection + if mapper is not None: + try: + # SA >= 1.3 + persist_selectable = mapper.persist_selectable + except AttributeError: + # SA < 1.3 + persist_selectable = mapper.mapped_table + + info = getattr(persist_selectable, "info", {}) + bind_key = info.get("bind_key") + if bind_key is not None: + state = get_state(self.app) + return state.db.get_engine(self.app, bind=bind_key) + + return super().get_bind(mapper, **kwargs) + + +class _SessionSignalEvents: + @classmethod + def register(cls, session): + if not hasattr(session, "_model_changes"): + session._model_changes = {} + + event.listen(session, "before_flush", cls.record_ops) + event.listen(session, "before_commit", cls.record_ops) + event.listen(session, "before_commit", cls.before_commit) + event.listen(session, "after_commit", cls.after_commit) + event.listen(session, "after_rollback", cls.after_rollback) + + @classmethod + def unregister(cls, session): + if hasattr(session, "_model_changes"): + del session._model_changes + + event.remove(session, "before_flush", cls.record_ops) + event.remove(session, "before_commit", cls.record_ops) + event.remove(session, "before_commit", cls.before_commit) + event.remove(session, "after_commit", cls.after_commit) + event.remove(session, "after_rollback", cls.after_rollback) + + @staticmethod + def record_ops(session, flush_context=None, instances=None): + try: + d = session._model_changes + except AttributeError: + return + + for targets, operation in ( + (session.new, "insert"), + (session.dirty, "update"), + (session.deleted, "delete"), + ): + for target in targets: + state = inspect(target) + key = state.identity_key if state.has_identity else id(target) + d[key] = (target, operation) + + @staticmethod + def before_commit(session): + try: + d = session._model_changes + except AttributeError: + return + + if d: + before_models_committed.send(session.app, changes=list(d.values())) + + @staticmethod + def after_commit(session): + try: + d = session._model_changes + except AttributeError: + return + + if d: + models_committed.send(session.app, changes=list(d.values())) + d.clear() + + @staticmethod + def after_rollback(session): + try: + d = session._model_changes + except AttributeError: + return + + d.clear() + + +class _EngineDebuggingSignalEvents: + """Sets up handlers for two events that let us track the execution time of + queries.""" + + def __init__(self, engine, import_name): + self.engine = engine + self.app_package = import_name + + def register(self): + event.listen(self.engine, "before_cursor_execute", self.before_cursor_execute) + event.listen(self.engine, "after_cursor_execute", self.after_cursor_execute) + + def before_cursor_execute( + self, conn, cursor, statement, parameters, context, executemany + ): + if current_app: + context._query_start_time = perf_counter() + + def after_cursor_execute( + self, conn, cursor, statement, parameters, context, executemany + ): + if current_app: + try: + queries = _app_ctx_stack.top.sqlalchemy_queries + except AttributeError: + queries = _app_ctx_stack.top.sqlalchemy_queries = [] + + queries.append( + _DebugQueryTuple( + ( + statement, + parameters, + context._query_start_time, + perf_counter(), + _calling_context(self.app_package), + ) + ) + ) + + +def get_debug_queries(): + """In debug mode or testing mode, Flask-SQLAlchemy will log all the SQL + queries sent to the database. This information is available until the end + of request which makes it possible to easily ensure that the SQL generated + is the one expected on errors or in unittesting. Alternatively, you can also + enable the query recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` + config variable to `True`. + + The value returned will be a list of named tuples with the following + attributes: + + `statement` + The SQL statement issued + + `parameters` + The parameters for the SQL statement + + `start_time` / `end_time` + Time the query started / the results arrived. Please keep in mind + that the timer function used depends on your platform. These + values are only useful for sorting or comparing. They do not + necessarily represent an absolute timestamp. + + `duration` + Time the query took in seconds + + `context` + A string giving a rough estimation of where in your application + query was issued. The exact format is undefined so don't try + to reconstruct filename or function name. + """ + return getattr(_app_ctx_stack.top, "sqlalchemy_queries", []) + + +class Pagination: + """Internal helper class returned by :meth:`BaseQuery.paginate`. You + can also construct it from any other SQLAlchemy query object if you are + working with other libraries. Additionally it is possible to pass `None` + as query object in which case the :meth:`prev` and :meth:`next` will + no longer work. + """ + + def __init__(self, query, page, per_page, total, items): + #: the unlimited query object that was used to create this + #: pagination object. + self.query = query + #: the current page number (1 indexed) + self.page = page + #: the number of items to be displayed on a page. + self.per_page = per_page + #: the total number of items matching the query + self.total = total + #: the items for the current page + self.items = items + + @property + def pages(self): + """The total number of pages""" + if self.per_page == 0 or self.total is None: + pages = 0 + else: + pages = int(ceil(self.total / float(self.per_page))) + return pages + + def prev(self, error_out=False): + """Returns a :class:`Pagination` object for the previous page.""" + assert ( + self.query is not None + ), "a query object is required for this method to work" + return self.query.paginate(self.page - 1, self.per_page, error_out) + + @property + def prev_num(self): + """Number of the previous page.""" + if not self.has_prev: + return None + return self.page - 1 + + @property + def has_prev(self): + """True if a previous page exists""" + return self.page > 1 + + def next(self, error_out=False): + """Returns a :class:`Pagination` object for the next page.""" + assert ( + self.query is not None + ), "a query object is required for this method to work" + return self.query.paginate(self.page + 1, self.per_page, error_out) + + @property + def has_next(self): + """True if a next page exists.""" + return self.page < self.pages + + @property + def next_num(self): + """Number of the next page""" + if not self.has_next: + return None + return self.page + 1 + + def iter_pages(self, left_edge=2, left_current=2, right_current=5, right_edge=2): + """Iterates over the page numbers in the pagination. The four + parameters control the thresholds how many numbers should be produced + from the sides. Skipped page numbers are represented as `None`. + This is how you could render such a pagination in the templates: + + .. sourcecode:: html+jinja + + {% macro render_pagination(pagination, endpoint) %} + + {% endmacro %} + """ + last = 0 + for num in range(1, self.pages + 1): + if ( + num <= left_edge + or ( + num > self.page - left_current - 1 + and num < self.page + right_current + ) + or num > self.pages - right_edge + ): + if last + 1 != num: + yield None + yield num + last = num + + +class BaseQuery(orm.Query): + """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with + convenience methods for querying in a web application. + + This is the default :attr:`~Model.query` object used for models, and + exposed as :attr:`~SQLAlchemy.Query`. Override the query class for + an individual model by subclassing this and setting + :attr:`~Model.query_class`. + """ + + def get_or_404(self, ident, description=None): + """Like :meth:`get` but aborts with 404 if not found instead of + returning ``None``. + """ + rv = self.get(ident) + if rv is None: + abort(404, description=description) + return rv + + def first_or_404(self, description=None): + """Like :meth:`first` but aborts with 404 if not found instead + of returning ``None``. + """ + rv = self.first() + if rv is None: + abort(404, description=description) + return rv + + def paginate( + self, page=None, per_page=None, error_out=True, max_per_page=None, count=True + ): + """Returns ``per_page`` items from page ``page``. + + If ``page`` or ``per_page`` are ``None``, they will be retrieved from + the request query. If ``max_per_page`` is specified, ``per_page`` will + be limited to that value. If there is no request or they aren't in the + query, they default to 1 and 20 respectively. If ``count`` is ``False``, + no query to help determine total page count will be run. + + When ``error_out`` is ``True`` (default), the following rules will + cause a 404 response: + + * No items are found and ``page`` is not 1. + * ``page`` is less than 1, or ``per_page`` is negative. + * ``page`` or ``per_page`` are not ints. + + When ``error_out`` is ``False``, ``page`` and ``per_page`` default to + 1 and 20 respectively. + + Returns a :class:`Pagination` object. + """ + + if request: + if page is None: + try: + page = int(request.args.get("page", 1)) + except (TypeError, ValueError): + if error_out: + abort(404) + + page = 1 + + if per_page is None: + try: + per_page = int(request.args.get("per_page", 20)) + except (TypeError, ValueError): + if error_out: + abort(404) + + per_page = 20 + else: + if page is None: + page = 1 + + if per_page is None: + per_page = 20 + + if max_per_page is not None: + per_page = min(per_page, max_per_page) + + if page < 1: + if error_out: + abort(404) + else: + page = 1 + + if per_page < 0: + if error_out: + abort(404) + else: + per_page = 20 + + items = self.limit(per_page).offset((page - 1) * per_page).all() + + if not items and page != 1 and error_out: + abort(404) + + if not count: + total = None + else: + total = self.order_by(None).count() + + return Pagination(self, page, per_page, total, items) + + +class _QueryProperty: + def __init__(self, sa): + self.sa = sa + + def __get__(self, obj, type): + try: + mapper = orm.class_mapper(type) + if mapper: + return type.query_class(mapper, session=self.sa.session()) + except UnmappedClassError: + return None + + +def _record_queries(app): + if app.debug: + return True + rq = app.config["SQLALCHEMY_RECORD_QUERIES"] + if rq is not None: + return rq + return bool(app.config.get("TESTING")) + + +class _EngineConnector: + def __init__(self, sa, app, bind=None): + self._sa = sa + self._app = app + self._engine = None + self._connected_for = None + self._bind = bind + self._lock = Lock() + + def get_uri(self): + if self._bind is None: + return self._app.config["SQLALCHEMY_DATABASE_URI"] + binds = self._app.config.get("SQLALCHEMY_BINDS") or () + assert ( + self._bind in binds + ), f"Bind {self._bind!r} is not configured in 'SQLALCHEMY_BINDS'." + return binds[self._bind] + + def get_engine(self): + with self._lock: + uri = self.get_uri() + echo = self._app.config["SQLALCHEMY_ECHO"] + if (uri, echo) == self._connected_for: + return self._engine + + sa_url = make_url(uri) + sa_url, options = self.get_options(sa_url, echo) + self._engine = rv = self._sa.create_engine(sa_url, options) + + if _record_queries(self._app): + _EngineDebuggingSignalEvents( + self._engine, self._app.import_name + ).register() + + self._connected_for = (uri, echo) + + return rv + + def get_options(self, sa_url, echo): + options = {} + sa_url, options = self._sa.apply_driver_hacks(self._app, sa_url, options) + + if echo: + options["echo"] = echo + + # Give the config options set by a developer explicitly priority + # over decisions FSA makes. + options.update(self._app.config["SQLALCHEMY_ENGINE_OPTIONS"]) + # Give options set in SQLAlchemy.__init__() ultimate priority + options.update(self._sa._engine_options) + return sa_url, options + + +def get_state(app): + """Gets the state for the application""" + assert "sqlalchemy" in app.extensions, ( + "The sqlalchemy extension was not registered to the current " + "application. Please make sure to call init_app() first." + ) + return app.extensions["sqlalchemy"] + + +class _SQLAlchemyState: + """Remembers configuration for the (db, app) tuple.""" + + def __init__(self, db): + self.db = db + self.connectors = {} + + +class SQLAlchemy: + """This class is used to control the SQLAlchemy integration to one + or more Flask applications. Depending on how you initialize the + object it is usable right away or will attach as needed to a + Flask application. + + There are two usage modes which work very similarly. One is binding + the instance to a very specific Flask application:: + + app = Flask(__name__) + db = SQLAlchemy(app) + + The second possibility is to create the object once and configure the + application later to support it:: + + db = SQLAlchemy() + + def create_app(): + app = Flask(__name__) + db.init_app(app) + return app + + The difference between the two is that in the first case methods like + :meth:`create_all` and :meth:`drop_all` will work all the time but in + the second case a :meth:`flask.Flask.app_context` has to exist. + + By default Flask-SQLAlchemy will apply some backend-specific settings + to improve your experience with them. + + This class also provides access to all the SQLAlchemy functions and classes + from the :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` modules. So you can + declare models like this:: + + class User(db.Model): + username = db.Column(db.String(80), unique=True) + pw_hash = db.Column(db.String(80)) + + You can still use :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, but + note that Flask-SQLAlchemy customizations are available only through an + instance of this :class:`SQLAlchemy` class. Query classes default to + :class:`BaseQuery` for `db.Query`, `db.Model.query_class`, and the default + query_class for `db.relationship` and `db.backref`. If you use these + interfaces through :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, + the default query class will be that of :mod:`sqlalchemy`. + + .. admonition:: Check types carefully + + Don't perform type or `isinstance` checks against `db.Table`, which + emulates `Table` behavior but is not a class. `db.Table` exposes the + `Table` interface, but is a function which allows omission of metadata. + + The ``session_options`` parameter, if provided, is a dict of parameters + to be passed to the session constructor. See + :class:`~sqlalchemy.orm.session.Session` for the standard options. + + The ``engine_options`` parameter, if provided, is a dict of parameters + to be passed to create engine. See :func:`~sqlalchemy.create_engine` + for the standard options. The values given here will be merged with and + override anything set in the ``'SQLALCHEMY_ENGINE_OPTIONS'`` config + variable or othewise set by this library. + + .. versionchanged:: 3.0 + Removed the ``use_native_unicode`` parameter and config. + + .. versionchanged:: 3.0 + ``COMMIT_ON_TEARDOWN`` is deprecated and will be removed in + version 3.1. Call ``db.session.commit()`` directly instead. + + .. versionchanged:: 2.4 + Added the ``engine_options`` parameter. + + .. versionchanged:: 2.1 + Added the ``metadata`` parameter. This allows for setting custom + naming conventions among other, non-trivial things. + + .. versionchanged:: 2.1 + Added the ``query_class`` parameter, to allow customisation + of the query class, in place of the default of + :class:`BaseQuery`. + + .. versionchanged:: 2.1 + Added the ``model_class`` parameter, which allows a custom model + class to be used in place of :class:`Model`. + + .. versionchanged:: 2.1 + Use the same query class across ``session``, ``Model.query`` and + ``Query``. + + .. versionchanged:: 0.16 + ``scopefunc`` is now accepted on ``session_options``. It allows + specifying a custom function which will define the SQLAlchemy + session's scoping. + + .. versionchanged:: 0.10 + Added the ``session_options`` parameter. + """ + + #: Default query class used by :attr:`Model.query` and other queries. + #: Customize this by passing ``query_class`` to :func:`SQLAlchemy`. + #: Defaults to :class:`BaseQuery`. + Query = None + + def __init__( + self, + app=None, + session_options=None, + metadata=None, + query_class=BaseQuery, + model_class=Model, + engine_options=None, + ): + + self.Query = query_class + self.session = self.create_scoped_session(session_options) + self.Model = self.make_declarative_base(model_class, metadata) + self._engine_lock = Lock() + self.app = app + self._engine_options = engine_options or {} + _include_sqlalchemy(self, query_class) + + if app is not None: + self.init_app(app) + + @property + def metadata(self): + """The metadata associated with ``db.Model``.""" + + return self.Model.metadata + + def create_scoped_session(self, options=None): + """Create a :class:`~sqlalchemy.orm.scoping.scoped_session` + on the factory from :meth:`create_session`. + + An extra key ``'scopefunc'`` can be set on the ``options`` dict to + specify a custom scope function. If it's not provided, Flask's app + context stack identity is used. This will ensure that sessions are + created and removed with the request/response cycle, and should be fine + in most cases. + + :param options: dict of keyword arguments passed to session class in + ``create_session`` + """ + + if options is None: + options = {} + + scopefunc = options.pop("scopefunc", _ident_func) + options.setdefault("query_cls", self.Query) + return orm.scoped_session(self.create_session(options), scopefunc=scopefunc) + + def create_session(self, options): + """Create the session factory used by :meth:`create_scoped_session`. + + The factory **must** return an object that SQLAlchemy recognizes as a session, + or registering session events may raise an exception. + + Valid factories include a :class:`~sqlalchemy.orm.session.Session` + class or a :class:`~sqlalchemy.orm.session.sessionmaker`. + + The default implementation creates a ``sessionmaker`` for + :class:`SignallingSession`. + + :param options: dict of keyword arguments passed to session class + """ + + return orm.sessionmaker(class_=SignallingSession, db=self, **options) + + def make_declarative_base(self, model, metadata=None): + """Creates the declarative base that all models will inherit from. + + :param model: base model class (or a tuple of base classes) to pass + to :func:`~sqlalchemy.ext.declarative.declarative_base`. Or a class + returned from ``declarative_base``, in which case a new base class + is not created. + :param metadata: :class:`~sqlalchemy.MetaData` instance to use, or + none to use SQLAlchemy's default. + + .. versionchanged 2.3.0:: + ``model`` can be an existing declarative base in order to support + complex customization such as changing the metaclass. + """ + if not isinstance(model, DeclarativeMeta): + model = declarative_base( + cls=model, name="Model", metadata=metadata, metaclass=DefaultMeta + ) + + # if user passed in a declarative base and a metaclass for some reason, + # make sure the base uses the metaclass + if metadata is not None and model.metadata is not metadata: + model.metadata = metadata + + if not getattr(model, "query_class", None): + model.query_class = self.Query + + model.query = _QueryProperty(self) + return model + + def init_app(self, app): + """This callback can be used to initialize an application for the + use with this database setup. Never use a database in the context + of an application not initialized that way or connections will + leak. + """ + + # We intentionally don't set self.app = app, to support multiple + # applications. If the app is passed in the constructor, + # we set it and don't support multiple applications. + if not ( + app.config.get("SQLALCHEMY_DATABASE_URI") + or app.config.get("SQLALCHEMY_BINDS") + ): + raise RuntimeError( + "Either SQLALCHEMY_DATABASE_URI or SQLALCHEMY_BINDS needs to be set." + ) + + app.config.setdefault("SQLALCHEMY_DATABASE_URI", None) + app.config.setdefault("SQLALCHEMY_BINDS", None) + app.config.setdefault("SQLALCHEMY_ECHO", False) + app.config.setdefault("SQLALCHEMY_RECORD_QUERIES", None) + app.config.setdefault("SQLALCHEMY_COMMIT_ON_TEARDOWN", False) + app.config.setdefault("SQLALCHEMY_TRACK_MODIFICATIONS", False) + app.config.setdefault("SQLALCHEMY_ENGINE_OPTIONS", {}) + + app.extensions["sqlalchemy"] = _SQLAlchemyState(self) + + @app.teardown_appcontext + def shutdown_session(response_or_exc): + if app.config["SQLALCHEMY_COMMIT_ON_TEARDOWN"]: + warnings.warn( + "'COMMIT_ON_TEARDOWN' is deprecated and will be" + " removed in version 3.1. Call" + " 'db.session.commit()'` directly instead.", + DeprecationWarning, + ) + + if response_or_exc is None: + self.session.commit() + + self.session.remove() + return response_or_exc + + def apply_driver_hacks(self, app, sa_url, options): + """This method is called before engine creation and used to inject + driver specific hacks into the options. The `options` parameter is + a dictionary of keyword arguments that will then be used to call + the :func:`sqlalchemy.create_engine` function. + + The default implementation provides some defaults for things + like pool sizes for MySQL and SQLite. + + .. versionchanged:: 3.0 + Change the default MySQL character set to "utf8mb4". + + .. versionchanged:: 2.5 + Returns ``(sa_url, options)``. SQLAlchemy 1.4 made the URL + immutable, so any changes to it must now be passed back up + to the original caller. + """ + if sa_url.drivername.startswith("mysql"): + sa_url = _sa_url_query_setdefault(sa_url, charset="utf8mb4") + + if sa_url.drivername != "mysql+gaerdbms": + options.setdefault("pool_size", 10) + options.setdefault("pool_recycle", 7200) + elif sa_url.drivername == "sqlite": + pool_size = options.get("pool_size") + detected_in_memory = False + if sa_url.database in (None, "", ":memory:"): + detected_in_memory = True + from sqlalchemy.pool import StaticPool + + options["poolclass"] = StaticPool + if "connect_args" not in options: + options["connect_args"] = {} + options["connect_args"]["check_same_thread"] = False + + # we go to memory and the pool size was explicitly set + # to 0 which is fail. Let the user know that + if pool_size == 0: + raise RuntimeError( + "SQLite in memory database with an " + "empty queue not possible due to data " + "loss." + ) + # if pool size is None or explicitly set to 0 we assume the + # user did not want a queue for this sqlite connection and + # hook in the null pool. + elif not pool_size: + from sqlalchemy.pool import NullPool + + options["poolclass"] = NullPool + + # If the database path is not absolute, it's relative to the + # app instance path, which might need to be created. + if not detected_in_memory and not os.path.isabs(sa_url.database): + os.makedirs(app.instance_path, exist_ok=True) + sa_url = _sa_url_set( + sa_url, database=os.path.join(app.root_path, sa_url.database) + ) + + return sa_url, options + + @property + def engine(self): + """Gives access to the engine. If the database configuration is bound + to a specific application (initialized with an application) this will + always return a database connection. If however the current application + is used this might raise a :exc:`RuntimeError` if no application is + active at the moment. + """ + return self.get_engine() + + def make_connector(self, app=None, bind=None): + """Creates the connector for a given state and bind.""" + return _EngineConnector(self, self.get_app(app), bind) + + def get_engine(self, app=None, bind=None): + """Returns a specific engine.""" + + app = self.get_app(app) + state = get_state(app) + + with self._engine_lock: + connector = state.connectors.get(bind) + + if connector is None: + connector = self.make_connector(app, bind) + state.connectors[bind] = connector + + return connector.get_engine() + + def create_engine(self, sa_url, engine_opts): + """Override this method to have final say over how the + SQLAlchemy engine is created. + + In most cases, you will want to use + ``'SQLALCHEMY_ENGINE_OPTIONS'`` config variable or set + ``engine_options`` for :func:`SQLAlchemy`. + """ + return sqlalchemy.create_engine(sa_url, **engine_opts) + + def get_app(self, reference_app=None): + """Helper method that implements the logic to look up an + application.""" + + if reference_app is not None: + return reference_app + + if current_app: + return current_app._get_current_object() + + if self.app is not None: + return self.app + + raise RuntimeError( + "No application found. Either work inside a view function or push" + " an application context. See" + " https://flask-sqlalchemy.palletsprojects.com/contexts/." + ) + + def get_tables_for_bind(self, bind=None): + """Returns a list of all tables relevant for a bind.""" + result = [] + for table in self.Model.metadata.tables.values(): + if table.info.get("bind_key") == bind: + result.append(table) + return result + + def get_binds(self, app=None): + """Returns a dictionary with a table->engine mapping. + + This is suitable for use of sessionmaker(binds=db.get_binds(app)). + """ + app = self.get_app(app) + binds = [None] + list(app.config.get("SQLALCHEMY_BINDS") or ()) + retval = {} + for bind in binds: + engine = self.get_engine(app, bind) + tables = self.get_tables_for_bind(bind) + retval.update({table: engine for table in tables}) + return retval + + def _execute_for_all_tables(self, app, bind, operation, skip_tables=False): + app = self.get_app(app) + + if bind == "__all__": + binds = [None] + list(app.config.get("SQLALCHEMY_BINDS") or ()) + elif isinstance(bind, str) or bind is None: + binds = [bind] + else: + binds = bind + + for bind in binds: + extra = {} + if not skip_tables: + tables = self.get_tables_for_bind(bind) + extra["tables"] = tables + op = getattr(self.Model.metadata, operation) + op(bind=self.get_engine(app, bind), **extra) + + def create_all(self, bind="__all__", app=None): + """Create all tables that do not already exist in the database. + This does not update existing tables, use a migration library + for that. + + :param bind: A bind key or list of keys to create the tables + for. Defaults to all binds. + :param app: Use this app instead of requiring an app context. + + .. versionchanged:: 0.12 + Added the ``bind`` and ``app`` parameters. + """ + self._execute_for_all_tables(app, bind, "create_all") + + def drop_all(self, bind="__all__", app=None): + """Drop all tables. + + :param bind: A bind key or list of keys to drop the tables for. + Defaults to all binds. + :param app: Use this app instead of requiring an app context. + + .. versionchanged:: 0.12 + Added the ``bind`` and ``app`` parameters. + """ + self._execute_for_all_tables(app, bind, "drop_all") + + def reflect(self, bind="__all__", app=None): + """Reflects tables from the database. + + :param bind: A bind key or list of keys to reflect the tables + from. Defaults to all binds. + :param app: Use this app instead of requiring an app context. + + .. versionchanged:: 0.12 + Added the ``bind`` and ``app`` parameters. + """ + self._execute_for_all_tables(app, bind, "reflect", skip_tables=True) + + def __repr__(self): + url = self.engine.url if self.app or current_app else None + return f"<{type(self).__name__} engine={url!r}>" diff --git a/tests/test_basic_app.py b/tests/test_basic_app.py index 0eaef3d8..c2c53855 100644 --- a/tests/test_basic_app.py +++ b/tests/test_basic_app.py @@ -1,7 +1,7 @@ import flask -from flask_sqlalchemy import get_debug_queries from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.extension import get_debug_queries def test_basic_insert(app, db, Todo): diff --git a/tests/test_binds.py b/tests/test_binds.py index 4a063b5b..b0a9fbb9 100644 --- a/tests/test_binds.py +++ b/tests/test_binds.py @@ -1,5 +1,5 @@ -from flask_sqlalchemy import get_state from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.extension import get_state def test_basic_binds(app, db): diff --git a/tests/test_pagination.py b/tests/test_pagination.py index fbc7307a..cbe1b308 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,7 +1,7 @@ import pytest from werkzeug.exceptions import NotFound -from flask_sqlalchemy import Pagination +from flask_sqlalchemy.extension import Pagination def test_basic_pagination(): diff --git a/tests/test_query_class.py b/tests/test_query_class.py index fca5b7a7..d13916d6 100644 --- a/tests/test_query_class.py +++ b/tests/test_query_class.py @@ -1,5 +1,5 @@ -from flask_sqlalchemy import BaseQuery from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.extension import BaseQuery def test_default_query_class(db): diff --git a/tests/test_signals.py b/tests/test_signals.py index 05af7033..e362439c 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,8 +1,8 @@ import flask import pytest -from flask_sqlalchemy import before_models_committed -from flask_sqlalchemy import models_committed +from flask_sqlalchemy.extension import before_models_committed +from flask_sqlalchemy.extension import models_committed pytestmark = pytest.mark.skipif( not flask.signals_available, reason="Signals require the blinker library." diff --git a/tests/test_sqlalchemy_includes.py b/tests/test_sqlalchemy_includes.py index 0e2d5c3d..da8fcdd2 100644 --- a/tests/test_sqlalchemy_includes.py +++ b/tests/test_sqlalchemy_includes.py @@ -1,7 +1,7 @@ import sqlalchemy as sa -from flask_sqlalchemy import BaseQuery from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.extension import BaseQuery def test_sqlalchemy_includes(): From 05576719e0f09834b7ecbfd6bd840660fc7784db Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 29 Aug 2022 10:18:45 -0700 Subject: [PATCH 02/27] move track_modifications --- src/flask_sqlalchemy/extension.py | 78 +-------------------- src/flask_sqlalchemy/track_modifications.py | 78 +++++++++++++++++++++ tests/test_signals.py | 4 +- 3 files changed, 81 insertions(+), 79 deletions(-) create mode 100644 src/flask_sqlalchemy/track_modifications.py diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 0752cd3b..773e8b71 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -12,9 +12,7 @@ from flask import abort from flask import current_app from flask import request -from flask.signals import Namespace from sqlalchemy import event -from sqlalchemy import inspect from sqlalchemy import orm from sqlalchemy.engine.url import make_url from sqlalchemy.orm.exc import UnmappedClassError @@ -22,6 +20,7 @@ from .model import DefaultMeta from .model import Model +from .track_modifications import _SessionSignalEvents try: from sqlalchemy.orm import declarative_base @@ -38,10 +37,6 @@ except ImportError: from threading import get_ident as _ident_func -_signals = Namespace() -models_committed = _signals.signal("models-committed") -before_models_committed = _signals.signal("before-models-committed") - def _sa_url_set(url, **kwargs): try: @@ -192,77 +187,6 @@ def get_bind(self, mapper=None, **kwargs): return super().get_bind(mapper, **kwargs) -class _SessionSignalEvents: - @classmethod - def register(cls, session): - if not hasattr(session, "_model_changes"): - session._model_changes = {} - - event.listen(session, "before_flush", cls.record_ops) - event.listen(session, "before_commit", cls.record_ops) - event.listen(session, "before_commit", cls.before_commit) - event.listen(session, "after_commit", cls.after_commit) - event.listen(session, "after_rollback", cls.after_rollback) - - @classmethod - def unregister(cls, session): - if hasattr(session, "_model_changes"): - del session._model_changes - - event.remove(session, "before_flush", cls.record_ops) - event.remove(session, "before_commit", cls.record_ops) - event.remove(session, "before_commit", cls.before_commit) - event.remove(session, "after_commit", cls.after_commit) - event.remove(session, "after_rollback", cls.after_rollback) - - @staticmethod - def record_ops(session, flush_context=None, instances=None): - try: - d = session._model_changes - except AttributeError: - return - - for targets, operation in ( - (session.new, "insert"), - (session.dirty, "update"), - (session.deleted, "delete"), - ): - for target in targets: - state = inspect(target) - key = state.identity_key if state.has_identity else id(target) - d[key] = (target, operation) - - @staticmethod - def before_commit(session): - try: - d = session._model_changes - except AttributeError: - return - - if d: - before_models_committed.send(session.app, changes=list(d.values())) - - @staticmethod - def after_commit(session): - try: - d = session._model_changes - except AttributeError: - return - - if d: - models_committed.send(session.app, changes=list(d.values())) - d.clear() - - @staticmethod - def after_rollback(session): - try: - d = session._model_changes - except AttributeError: - return - - d.clear() - - class _EngineDebuggingSignalEvents: """Sets up handlers for two events that let us track the execution time of queries.""" diff --git a/src/flask_sqlalchemy/track_modifications.py b/src/flask_sqlalchemy/track_modifications.py new file mode 100644 index 00000000..006cd493 --- /dev/null +++ b/src/flask_sqlalchemy/track_modifications.py @@ -0,0 +1,78 @@ +from flask.signals import Namespace +from sqlalchemy import event +from sqlalchemy import inspect + +_signals = Namespace() +models_committed = _signals.signal("models-committed") +before_models_committed = _signals.signal("before-models-committed") + + +class _SessionSignalEvents: + @classmethod + def register(cls, session): + if not hasattr(session, "_model_changes"): + session._model_changes = {} + + event.listen(session, "before_flush", cls.record_ops) + event.listen(session, "before_commit", cls.record_ops) + event.listen(session, "before_commit", cls.before_commit) + event.listen(session, "after_commit", cls.after_commit) + event.listen(session, "after_rollback", cls.after_rollback) + + @classmethod + def unregister(cls, session): + if hasattr(session, "_model_changes"): + del session._model_changes + + event.remove(session, "before_flush", cls.record_ops) + event.remove(session, "before_commit", cls.record_ops) + event.remove(session, "before_commit", cls.before_commit) + event.remove(session, "after_commit", cls.after_commit) + event.remove(session, "after_rollback", cls.after_rollback) + + @staticmethod + def record_ops(session, flush_context=None, instances=None): + try: + d = session._model_changes + except AttributeError: + return + + for targets, operation in ( + (session.new, "insert"), + (session.dirty, "update"), + (session.deleted, "delete"), + ): + for target in targets: + state = inspect(target) + key = state.identity_key if state.has_identity else id(target) + d[key] = (target, operation) + + @staticmethod + def before_commit(session): + try: + d = session._model_changes + except AttributeError: + return + + if d: + before_models_committed.send(session.app, changes=list(d.values())) + + @staticmethod + def after_commit(session): + try: + d = session._model_changes + except AttributeError: + return + + if d: + models_committed.send(session.app, changes=list(d.values())) + d.clear() + + @staticmethod + def after_rollback(session): + try: + d = session._model_changes + except AttributeError: + return + + d.clear() diff --git a/tests/test_signals.py b/tests/test_signals.py index e362439c..bdd02c43 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,8 +1,8 @@ import flask import pytest -from flask_sqlalchemy.extension import before_models_committed -from flask_sqlalchemy.extension import models_committed +from flask_sqlalchemy.track_modifications import before_models_committed +from flask_sqlalchemy.track_modifications import models_committed pytestmark = pytest.mark.skipif( not flask.signals_available, reason="Signals require the blinker library." From 35a1f579ef2f208db9f8782afc7433a66170a985 Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 29 Aug 2022 10:21:18 -0700 Subject: [PATCH 03/27] move record_queries --- src/flask_sqlalchemy/extension.py | 108 +----------------------- src/flask_sqlalchemy/record_queries.py | 110 +++++++++++++++++++++++++ tests/test_basic_app.py | 2 +- 3 files changed, 112 insertions(+), 108 deletions(-) create mode 100644 src/flask_sqlalchemy/record_queries.py diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 773e8b71..12c845b1 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -1,14 +1,10 @@ import functools import os -import sys import warnings from math import ceil -from operator import itemgetter from threading import Lock -from time import perf_counter import sqlalchemy -from flask import _app_ctx_stack from flask import abort from flask import current_app from flask import request @@ -20,6 +16,7 @@ from .model import DefaultMeta from .model import Model +from .record_queries import _EngineDebuggingSignalEvents from .track_modifications import _SessionSignalEvents try: @@ -102,35 +99,6 @@ def _include_sqlalchemy(obj, cls): obj.event = event -class _DebugQueryTuple(tuple): - statement = property(itemgetter(0)) - parameters = property(itemgetter(1)) - start_time = property(itemgetter(2)) - end_time = property(itemgetter(3)) - context = property(itemgetter(4)) - - @property - def duration(self): - return self.end_time - self.start_time - - def __repr__(self): - return ( - f"" - ) - - -def _calling_context(app_path): - frm = sys._getframe(1) - while frm.f_back is not None: - name = frm.f_globals.get("__name__") - if name and (name == app_path or name.startswith(f"{app_path}.")): - funcname = frm.f_code.co_name - return f"{frm.f_code.co_filename}:{frm.f_lineno} ({funcname})" - frm = frm.f_back - return "" - - class SignallingSession(SessionBase): """The signalling session is the default session that Flask-SQLAlchemy uses. It extends the default session system with bind selection and @@ -187,80 +155,6 @@ def get_bind(self, mapper=None, **kwargs): return super().get_bind(mapper, **kwargs) -class _EngineDebuggingSignalEvents: - """Sets up handlers for two events that let us track the execution time of - queries.""" - - def __init__(self, engine, import_name): - self.engine = engine - self.app_package = import_name - - def register(self): - event.listen(self.engine, "before_cursor_execute", self.before_cursor_execute) - event.listen(self.engine, "after_cursor_execute", self.after_cursor_execute) - - def before_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): - if current_app: - context._query_start_time = perf_counter() - - def after_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): - if current_app: - try: - queries = _app_ctx_stack.top.sqlalchemy_queries - except AttributeError: - queries = _app_ctx_stack.top.sqlalchemy_queries = [] - - queries.append( - _DebugQueryTuple( - ( - statement, - parameters, - context._query_start_time, - perf_counter(), - _calling_context(self.app_package), - ) - ) - ) - - -def get_debug_queries(): - """In debug mode or testing mode, Flask-SQLAlchemy will log all the SQL - queries sent to the database. This information is available until the end - of request which makes it possible to easily ensure that the SQL generated - is the one expected on errors or in unittesting. Alternatively, you can also - enable the query recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` - config variable to `True`. - - The value returned will be a list of named tuples with the following - attributes: - - `statement` - The SQL statement issued - - `parameters` - The parameters for the SQL statement - - `start_time` / `end_time` - Time the query started / the results arrived. Please keep in mind - that the timer function used depends on your platform. These - values are only useful for sorting or comparing. They do not - necessarily represent an absolute timestamp. - - `duration` - Time the query took in seconds - - `context` - A string giving a rough estimation of where in your application - query was issued. The exact format is undefined so don't try - to reconstruct filename or function name. - """ - return getattr(_app_ctx_stack.top, "sqlalchemy_queries", []) - - class Pagination: """Internal helper class returned by :meth:`BaseQuery.paginate`. You can also construct it from any other SQLAlchemy query object if you are diff --git a/src/flask_sqlalchemy/record_queries.py b/src/flask_sqlalchemy/record_queries.py new file mode 100644 index 00000000..a7b490cc --- /dev/null +++ b/src/flask_sqlalchemy/record_queries.py @@ -0,0 +1,110 @@ +import sys +from operator import itemgetter +from time import perf_counter + +from flask import _app_ctx_stack +from flask import current_app +from sqlalchemy import event + + +class _DebugQueryTuple(tuple): + statement = property(itemgetter(0)) + parameters = property(itemgetter(1)) + start_time = property(itemgetter(2)) + end_time = property(itemgetter(3)) + context = property(itemgetter(4)) + + @property + def duration(self): + return self.end_time - self.start_time + + def __repr__(self): + return ( + f"" + ) + + +def _calling_context(app_path): + frm = sys._getframe(1) + while frm.f_back is not None: + name = frm.f_globals.get("__name__") + if name and (name == app_path or name.startswith(f"{app_path}.")): + funcname = frm.f_code.co_name + return f"{frm.f_code.co_filename}:{frm.f_lineno} ({funcname})" + frm = frm.f_back + return "" + + +class _EngineDebuggingSignalEvents: + """Sets up handlers for two events that let us track the execution time of + queries.""" + + def __init__(self, engine, import_name): + self.engine = engine + self.app_package = import_name + + def register(self): + event.listen(self.engine, "before_cursor_execute", self.before_cursor_execute) + event.listen(self.engine, "after_cursor_execute", self.after_cursor_execute) + + def before_cursor_execute( + self, conn, cursor, statement, parameters, context, executemany + ): + if current_app: + context._query_start_time = perf_counter() + + def after_cursor_execute( + self, conn, cursor, statement, parameters, context, executemany + ): + if current_app: + try: + queries = _app_ctx_stack.top.sqlalchemy_queries + except AttributeError: + queries = _app_ctx_stack.top.sqlalchemy_queries = [] + + queries.append( + _DebugQueryTuple( + ( + statement, + parameters, + context._query_start_time, + perf_counter(), + _calling_context(self.app_package), + ) + ) + ) + + +def get_debug_queries(): + """In debug mode or testing mode, Flask-SQLAlchemy will log all the SQL + queries sent to the database. This information is available until the end + of request which makes it possible to easily ensure that the SQL generated + is the one expected on errors or in unittesting. Alternatively, you can also + enable the query recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` + config variable to `True`. + + The value returned will be a list of named tuples with the following + attributes: + + `statement` + The SQL statement issued + + `parameters` + The parameters for the SQL statement + + `start_time` / `end_time` + Time the query started / the results arrived. Please keep in mind + that the timer function used depends on your platform. These + values are only useful for sorting or comparing. They do not + necessarily represent an absolute timestamp. + + `duration` + Time the query took in seconds + + `context` + A string giving a rough estimation of where in your application + query was issued. The exact format is undefined so don't try + to reconstruct filename or function name. + """ + return getattr(_app_ctx_stack.top, "sqlalchemy_queries", []) diff --git a/tests/test_basic_app.py b/tests/test_basic_app.py index c2c53855..b69d7e5d 100644 --- a/tests/test_basic_app.py +++ b/tests/test_basic_app.py @@ -1,7 +1,7 @@ import flask from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.extension import get_debug_queries +from flask_sqlalchemy.record_queries import get_debug_queries def test_basic_insert(app, db, Todo): From 7dc6a6464f8dd514abed844b28a1cb2c9887f5e2 Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 29 Aug 2022 10:22:09 -0700 Subject: [PATCH 04/27] move pagination --- src/flask_sqlalchemy/extension.py | 110 +---------------------------- src/flask_sqlalchemy/pagination.py | 109 ++++++++++++++++++++++++++++ tests/test_pagination.py | 2 +- 3 files changed, 111 insertions(+), 110 deletions(-) create mode 100644 src/flask_sqlalchemy/pagination.py diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 12c845b1..c24f2aa9 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -1,7 +1,6 @@ import functools import os import warnings -from math import ceil from threading import Lock import sqlalchemy @@ -16,6 +15,7 @@ from .model import DefaultMeta from .model import Model +from .pagination import Pagination from .record_queries import _EngineDebuggingSignalEvents from .track_modifications import _SessionSignalEvents @@ -155,114 +155,6 @@ def get_bind(self, mapper=None, **kwargs): return super().get_bind(mapper, **kwargs) -class Pagination: - """Internal helper class returned by :meth:`BaseQuery.paginate`. You - can also construct it from any other SQLAlchemy query object if you are - working with other libraries. Additionally it is possible to pass `None` - as query object in which case the :meth:`prev` and :meth:`next` will - no longer work. - """ - - def __init__(self, query, page, per_page, total, items): - #: the unlimited query object that was used to create this - #: pagination object. - self.query = query - #: the current page number (1 indexed) - self.page = page - #: the number of items to be displayed on a page. - self.per_page = per_page - #: the total number of items matching the query - self.total = total - #: the items for the current page - self.items = items - - @property - def pages(self): - """The total number of pages""" - if self.per_page == 0 or self.total is None: - pages = 0 - else: - pages = int(ceil(self.total / float(self.per_page))) - return pages - - def prev(self, error_out=False): - """Returns a :class:`Pagination` object for the previous page.""" - assert ( - self.query is not None - ), "a query object is required for this method to work" - return self.query.paginate(self.page - 1, self.per_page, error_out) - - @property - def prev_num(self): - """Number of the previous page.""" - if not self.has_prev: - return None - return self.page - 1 - - @property - def has_prev(self): - """True if a previous page exists""" - return self.page > 1 - - def next(self, error_out=False): - """Returns a :class:`Pagination` object for the next page.""" - assert ( - self.query is not None - ), "a query object is required for this method to work" - return self.query.paginate(self.page + 1, self.per_page, error_out) - - @property - def has_next(self): - """True if a next page exists.""" - return self.page < self.pages - - @property - def next_num(self): - """Number of the next page""" - if not self.has_next: - return None - return self.page + 1 - - def iter_pages(self, left_edge=2, left_current=2, right_current=5, right_edge=2): - """Iterates over the page numbers in the pagination. The four - parameters control the thresholds how many numbers should be produced - from the sides. Skipped page numbers are represented as `None`. - This is how you could render such a pagination in the templates: - - .. sourcecode:: html+jinja - - {% macro render_pagination(pagination, endpoint) %} - - {% endmacro %} - """ - last = 0 - for num in range(1, self.pages + 1): - if ( - num <= left_edge - or ( - num > self.page - left_current - 1 - and num < self.page + right_current - ) - or num > self.pages - right_edge - ): - if last + 1 != num: - yield None - yield num - last = num - - class BaseQuery(orm.Query): """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with convenience methods for querying in a web application. diff --git a/src/flask_sqlalchemy/pagination.py b/src/flask_sqlalchemy/pagination.py new file mode 100644 index 00000000..03b321a8 --- /dev/null +++ b/src/flask_sqlalchemy/pagination.py @@ -0,0 +1,109 @@ +from math import ceil + + +class Pagination: + """Internal helper class returned by :meth:`BaseQuery.paginate`. You + can also construct it from any other SQLAlchemy query object if you are + working with other libraries. Additionally it is possible to pass `None` + as query object in which case the :meth:`prev` and :meth:`next` will + no longer work. + """ + + def __init__(self, query, page, per_page, total, items): + #: the unlimited query object that was used to create this + #: pagination object. + self.query = query + #: the current page number (1 indexed) + self.page = page + #: the number of items to be displayed on a page. + self.per_page = per_page + #: the total number of items matching the query + self.total = total + #: the items for the current page + self.items = items + + @property + def pages(self): + """The total number of pages""" + if self.per_page == 0 or self.total is None: + pages = 0 + else: + pages = int(ceil(self.total / float(self.per_page))) + return pages + + def prev(self, error_out=False): + """Returns a :class:`Pagination` object for the previous page.""" + assert ( + self.query is not None + ), "a query object is required for this method to work" + return self.query.paginate(self.page - 1, self.per_page, error_out) + + @property + def prev_num(self): + """Number of the previous page.""" + if not self.has_prev: + return None + return self.page - 1 + + @property + def has_prev(self): + """True if a previous page exists""" + return self.page > 1 + + def next(self, error_out=False): + """Returns a :class:`Pagination` object for the next page.""" + assert ( + self.query is not None + ), "a query object is required for this method to work" + return self.query.paginate(self.page + 1, self.per_page, error_out) + + @property + def has_next(self): + """True if a next page exists.""" + return self.page < self.pages + + @property + def next_num(self): + """Number of the next page""" + if not self.has_next: + return None + return self.page + 1 + + def iter_pages(self, left_edge=2, left_current=2, right_current=5, right_edge=2): + """Iterates over the page numbers in the pagination. The four + parameters control the thresholds how many numbers should be produced + from the sides. Skipped page numbers are represented as `None`. + This is how you could render such a pagination in the templates: + + .. sourcecode:: html+jinja + + {% macro render_pagination(pagination, endpoint) %} + + {% endmacro %} + """ + last = 0 + for num in range(1, self.pages + 1): + if ( + num <= left_edge + or ( + num > self.page - left_current - 1 + and num < self.page + right_current + ) + or num > self.pages - right_edge + ): + if last + 1 != num: + yield None + yield num + last = num diff --git a/tests/test_pagination.py b/tests/test_pagination.py index cbe1b308..985cd777 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,7 +1,7 @@ import pytest from werkzeug.exceptions import NotFound -from flask_sqlalchemy.extension import Pagination +from flask_sqlalchemy.pagination import Pagination def test_basic_pagination(): From 9d644c0f653ef74df4d0fd7bda0b849c26995f4d Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 29 Aug 2022 10:23:44 -0700 Subject: [PATCH 05/27] move query --- src/flask_sqlalchemy/extension.py | 109 +---------------------------- src/flask_sqlalchemy/query.py | 110 ++++++++++++++++++++++++++++++ tests/test_query_class.py | 2 +- tests/test_sqlalchemy_includes.py | 2 +- 4 files changed, 113 insertions(+), 110 deletions(-) create mode 100644 src/flask_sqlalchemy/query.py diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index c24f2aa9..b0fa116f 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -4,9 +4,7 @@ from threading import Lock import sqlalchemy -from flask import abort from flask import current_app -from flask import request from sqlalchemy import event from sqlalchemy import orm from sqlalchemy.engine.url import make_url @@ -15,7 +13,7 @@ from .model import DefaultMeta from .model import Model -from .pagination import Pagination +from .query import BaseQuery from .record_queries import _EngineDebuggingSignalEvents from .track_modifications import _SessionSignalEvents @@ -155,111 +153,6 @@ def get_bind(self, mapper=None, **kwargs): return super().get_bind(mapper, **kwargs) -class BaseQuery(orm.Query): - """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with - convenience methods for querying in a web application. - - This is the default :attr:`~Model.query` object used for models, and - exposed as :attr:`~SQLAlchemy.Query`. Override the query class for - an individual model by subclassing this and setting - :attr:`~Model.query_class`. - """ - - def get_or_404(self, ident, description=None): - """Like :meth:`get` but aborts with 404 if not found instead of - returning ``None``. - """ - rv = self.get(ident) - if rv is None: - abort(404, description=description) - return rv - - def first_or_404(self, description=None): - """Like :meth:`first` but aborts with 404 if not found instead - of returning ``None``. - """ - rv = self.first() - if rv is None: - abort(404, description=description) - return rv - - def paginate( - self, page=None, per_page=None, error_out=True, max_per_page=None, count=True - ): - """Returns ``per_page`` items from page ``page``. - - If ``page`` or ``per_page`` are ``None``, they will be retrieved from - the request query. If ``max_per_page`` is specified, ``per_page`` will - be limited to that value. If there is no request or they aren't in the - query, they default to 1 and 20 respectively. If ``count`` is ``False``, - no query to help determine total page count will be run. - - When ``error_out`` is ``True`` (default), the following rules will - cause a 404 response: - - * No items are found and ``page`` is not 1. - * ``page`` is less than 1, or ``per_page`` is negative. - * ``page`` or ``per_page`` are not ints. - - When ``error_out`` is ``False``, ``page`` and ``per_page`` default to - 1 and 20 respectively. - - Returns a :class:`Pagination` object. - """ - - if request: - if page is None: - try: - page = int(request.args.get("page", 1)) - except (TypeError, ValueError): - if error_out: - abort(404) - - page = 1 - - if per_page is None: - try: - per_page = int(request.args.get("per_page", 20)) - except (TypeError, ValueError): - if error_out: - abort(404) - - per_page = 20 - else: - if page is None: - page = 1 - - if per_page is None: - per_page = 20 - - if max_per_page is not None: - per_page = min(per_page, max_per_page) - - if page < 1: - if error_out: - abort(404) - else: - page = 1 - - if per_page < 0: - if error_out: - abort(404) - else: - per_page = 20 - - items = self.limit(per_page).offset((page - 1) * per_page).all() - - if not items and page != 1 and error_out: - abort(404) - - if not count: - total = None - else: - total = self.order_by(None).count() - - return Pagination(self, page, per_page, total, items) - - class _QueryProperty: def __init__(self, sa): self.sa = sa diff --git a/src/flask_sqlalchemy/query.py b/src/flask_sqlalchemy/query.py new file mode 100644 index 00000000..4cac8744 --- /dev/null +++ b/src/flask_sqlalchemy/query.py @@ -0,0 +1,110 @@ +from flask import abort +from flask import request +from sqlalchemy import orm + +from .pagination import Pagination + + +class BaseQuery(orm.Query): + """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with + convenience methods for querying in a web application. + + This is the default :attr:`~Model.query` object used for models, and + exposed as :attr:`~SQLAlchemy.Query`. Override the query class for + an individual model by subclassing this and setting + :attr:`~Model.query_class`. + """ + + def get_or_404(self, ident, description=None): + """Like :meth:`get` but aborts with 404 if not found instead of + returning ``None``. + """ + rv = self.get(ident) + if rv is None: + abort(404, description=description) + return rv + + def first_or_404(self, description=None): + """Like :meth:`first` but aborts with 404 if not found instead + of returning ``None``. + """ + rv = self.first() + if rv is None: + abort(404, description=description) + return rv + + def paginate( + self, page=None, per_page=None, error_out=True, max_per_page=None, count=True + ): + """Returns ``per_page`` items from page ``page``. + + If ``page`` or ``per_page`` are ``None``, they will be retrieved from + the request query. If ``max_per_page`` is specified, ``per_page`` will + be limited to that value. If there is no request or they aren't in the + query, they default to 1 and 20 respectively. If ``count`` is ``False``, + no query to help determine total page count will be run. + + When ``error_out`` is ``True`` (default), the following rules will + cause a 404 response: + + * No items are found and ``page`` is not 1. + * ``page`` is less than 1, or ``per_page`` is negative. + * ``page`` or ``per_page`` are not ints. + + When ``error_out`` is ``False``, ``page`` and ``per_page`` default to + 1 and 20 respectively. + + Returns a :class:`Pagination` object. + """ + + if request: + if page is None: + try: + page = int(request.args.get("page", 1)) + except (TypeError, ValueError): + if error_out: + abort(404) + + page = 1 + + if per_page is None: + try: + per_page = int(request.args.get("per_page", 20)) + except (TypeError, ValueError): + if error_out: + abort(404) + + per_page = 20 + else: + if page is None: + page = 1 + + if per_page is None: + per_page = 20 + + if max_per_page is not None: + per_page = min(per_page, max_per_page) + + if page < 1: + if error_out: + abort(404) + else: + page = 1 + + if per_page < 0: + if error_out: + abort(404) + else: + per_page = 20 + + items = self.limit(per_page).offset((page - 1) * per_page).all() + + if not items and page != 1 and error_out: + abort(404) + + if not count: + total = None + else: + total = self.order_by(None).count() + + return Pagination(self, page, per_page, total, items) diff --git a/tests/test_query_class.py b/tests/test_query_class.py index d13916d6..df11b181 100644 --- a/tests/test_query_class.py +++ b/tests/test_query_class.py @@ -1,5 +1,5 @@ from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.extension import BaseQuery +from flask_sqlalchemy.query import BaseQuery def test_default_query_class(db): diff --git a/tests/test_sqlalchemy_includes.py b/tests/test_sqlalchemy_includes.py index da8fcdd2..b805054f 100644 --- a/tests/test_sqlalchemy_includes.py +++ b/tests/test_sqlalchemy_includes.py @@ -1,7 +1,7 @@ import sqlalchemy as sa from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.extension import BaseQuery +from flask_sqlalchemy.query import BaseQuery def test_sqlalchemy_includes(): From 402745bf60e084c9dcb04104d4a12f4963142682 Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 29 Aug 2022 10:27:02 -0700 Subject: [PATCH 06/27] move session --- src/flask_sqlalchemy/extension.py | 59 +----------------------------- src/flask_sqlalchemy/session.py | 61 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 58 deletions(-) create mode 100644 src/flask_sqlalchemy/session.py diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index b0fa116f..cf1e3519 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -9,13 +9,12 @@ from sqlalchemy import orm from sqlalchemy.engine.url import make_url from sqlalchemy.orm.exc import UnmappedClassError -from sqlalchemy.orm.session import Session as SessionBase from .model import DefaultMeta from .model import Model from .query import BaseQuery from .record_queries import _EngineDebuggingSignalEvents -from .track_modifications import _SessionSignalEvents +from .session import SignallingSession try: from sqlalchemy.orm import declarative_base @@ -97,62 +96,6 @@ def _include_sqlalchemy(obj, cls): obj.event = event -class SignallingSession(SessionBase): - """The signalling session is the default session that Flask-SQLAlchemy - uses. It extends the default session system with bind selection and - modification tracking. - - If you want to use a different session you can override the - :meth:`SQLAlchemy.create_session` function. - - .. versionadded:: 2.0 - - .. versionadded:: 2.1 - The `binds` option was added, which allows a session to be joined - to an external transaction. - """ - - def __init__(self, db, autocommit=False, autoflush=True, **options): - #: The application that this session belongs to. - self.app = app = db.get_app() - track_modifications = app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] - bind = options.pop("bind", None) or db.engine - binds = options.pop("binds", db.get_binds(app)) - - if track_modifications: - _SessionSignalEvents.register(self) - - SessionBase.__init__( - self, - autocommit=autocommit, - autoflush=autoflush, - bind=bind, - binds=binds, - **options, - ) - - def get_bind(self, mapper=None, **kwargs): - """Return the engine or connection for a given model or - table, using the ``__bind_key__`` if it is set. - """ - # mapper is None if someone tries to just get a connection - if mapper is not None: - try: - # SA >= 1.3 - persist_selectable = mapper.persist_selectable - except AttributeError: - # SA < 1.3 - persist_selectable = mapper.mapped_table - - info = getattr(persist_selectable, "info", {}) - bind_key = info.get("bind_key") - if bind_key is not None: - state = get_state(self.app) - return state.db.get_engine(self.app, bind=bind_key) - - return super().get_bind(mapper, **kwargs) - - class _QueryProperty: def __init__(self, sa): self.sa = sa diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py new file mode 100644 index 00000000..ffa2d06f --- /dev/null +++ b/src/flask_sqlalchemy/session.py @@ -0,0 +1,61 @@ +from sqlalchemy.orm import Session as SessionBase + +from .track_modifications import _SessionSignalEvents + + +class SignallingSession(SessionBase): + """The signalling session is the default session that Flask-SQLAlchemy + uses. It extends the default session system with bind selection and + modification tracking. + + If you want to use a different session you can override the + :meth:`SQLAlchemy.create_session` function. + + .. versionadded:: 2.0 + + .. versionadded:: 2.1 + The `binds` option was added, which allows a session to be joined + to an external transaction. + """ + + def __init__(self, db, autocommit=False, autoflush=True, **options): + #: The application that this session belongs to. + self.app = app = db.get_app() + track_modifications = app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] + bind = options.pop("bind", None) or db.engine + binds = options.pop("binds", db.get_binds(app)) + + if track_modifications: + _SessionSignalEvents.register(self) + + SessionBase.__init__( + self, + autocommit=autocommit, + autoflush=autoflush, + bind=bind, + binds=binds, + **options, + ) + + def get_bind(self, mapper=None, **kwargs): + """Return the engine or connection for a given model or + table, using the ``__bind_key__`` if it is set. + """ + # mapper is None if someone tries to just get a connection + if mapper is not None: + try: + # SA >= 1.3 + persist_selectable = mapper.persist_selectable + except AttributeError: + # SA < 1.3 + persist_selectable = mapper.mapped_table + + info = getattr(persist_selectable, "info", {}) + bind_key = info.get("bind_key") + if bind_key is not None: + from .extension import get_state + + state = get_state(self.app) + return state.db.get_engine(self.app, bind=bind_key) + + return super().get_bind(mapper, **kwargs) From 00d32639b1d7114f895a38750b3444c4a6a1a7c9 Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 29 Aug 2022 10:36:50 -0700 Subject: [PATCH 07/27] move model --- src/flask_sqlalchemy/extension.py | 15 +-------------- src/flask_sqlalchemy/model.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index cf1e3519..fcc764dc 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -8,8 +8,8 @@ from sqlalchemy import event from sqlalchemy import orm from sqlalchemy.engine.url import make_url -from sqlalchemy.orm.exc import UnmappedClassError +from .model import _QueryProperty from .model import DefaultMeta from .model import Model from .query import BaseQuery @@ -96,19 +96,6 @@ def _include_sqlalchemy(obj, cls): obj.event = event -class _QueryProperty: - def __init__(self, sa): - self.sa = sa - - def __get__(self, obj, type): - try: - mapper = orm.class_mapper(type) - if mapper: - return type.query_class(mapper, session=self.sa.session()) - except UnmappedClassError: - return None - - def _record_queries(app): if app.debug: return True diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index aac077cf..3f578b08 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -2,8 +2,10 @@ import sqlalchemy as sa from sqlalchemy import inspect +from sqlalchemy import orm from sqlalchemy.ext.declarative import DeclarativeMeta from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm.exc import UnmappedClassError from sqlalchemy.schema import _get_table_key @@ -138,3 +140,16 @@ def __repr__(self): pk = ", ".join(str(value) for value in identity) return f"<{type(self).__name__} {pk}>" + + +class _QueryProperty: + def __init__(self, sa): + self.sa = sa + + def __get__(self, obj, type): + try: + mapper = orm.class_mapper(type) + if mapper: + return type.query_class(mapper, session=self.sa.session()) + except UnmappedClassError: + return None From d2db52cb9e251bdab88b23a188cc574ce784babe Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 30 Aug 2022 07:02:05 -0700 Subject: [PATCH 08/27] refactor pagination --- CHANGES.rst | 7 + src/flask_sqlalchemy/pagination.py | 329 +++++++++++++++++++++++------ src/flask_sqlalchemy/query.py | 114 ++++------ tests/test_pagination.py | 6 +- 4 files changed, 309 insertions(+), 147 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 5b01d79d..fab6a87b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -23,6 +23,13 @@ Unreleased design issues that are difficult to debug. Call ``db.session.commit()`` directly instead. :issue:`216` - Change the default MySQL character set to "utf8mb4". :issue:`875` +- ``Pagination``, ``Pagination.iter_pages``, and ``Query.paginate`` parameters are + keyword-only. +- ``Pagination`` is iterable, iterating over its items. +- ``Pagination.apply_to_query`` can be used instead of ``query.paginate``. +- ``Query.paginate`` ``count`` is more efficient. +- ``Pagination.iter_pages`` is more efficient. +- ``Pagination.iter_pages`` ``right_current`` parameter is inclusive. Version 2.5.1 diff --git a/src/flask_sqlalchemy/pagination.py b/src/flask_sqlalchemy/pagination.py index 03b321a8..1c51ab7d 100644 --- a/src/flask_sqlalchemy/pagination.py +++ b/src/flask_sqlalchemy/pagination.py @@ -1,109 +1,298 @@ +from __future__ import annotations + +import typing as t from math import ceil +import sqlalchemy as sa +import sqlalchemy.orm +from flask import abort +from flask import request + class Pagination: - """Internal helper class returned by :meth:`BaseQuery.paginate`. You - can also construct it from any other SQLAlchemy query object if you are - working with other libraries. Additionally it is possible to pass `None` - as query object in which case the :meth:`prev` and :meth:`next` will - no longer work. + """Returned by :meth:`.Query.paginate`, this describes the current page of data. + + :param query: The original query that was paginated. + :param page: The current page. + :param per_page: The maximum number of items on a page. + :param total: The total number of items across all pages. + :param items: The items on the current page. + + .. versionchanged:: 3.0 + All parameters are keyword-only. + + .. versionchanged:: 3.0 + Iterating over a pagination object iterates over its items. """ - def __init__(self, query, page, per_page, total, items): - #: the unlimited query object that was used to create this - #: pagination object. + def __init__( + self, + *, + query: sa.orm.Query[t.Any] | None, + page: int, + per_page: int, + total: int | None, + items: list[t.Any], + ) -> None: self.query = query - #: the current page number (1 indexed) + """The original query that was paginated.""" + self.page = page - #: the number of items to be displayed on a page. + """The current page.""" + self.per_page = per_page - #: the total number of items matching the query + """The maximum number of items on a page.""" + self.total = total - #: the items for the current page + """The total number of items across all pages.""" + self.items = items + """The items on the current page. Iterating over the pagination object is + equivalent to iterating over the items. + """ + + @staticmethod + def _prepare_args( + *, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = None, + error_out: bool = True, + ) -> tuple[int, int]: + if request: + if page is None: + try: + page = int(request.args.get("page", 1)) + except (TypeError, ValueError): + if error_out: + abort(404) + + page = 1 + + if per_page is None: + try: + per_page = int(request.args.get("per_page", 20)) + except (TypeError, ValueError): + if error_out: + abort(404) + + per_page = 20 + else: + if page is None: + page = 1 + + if per_page is None: + per_page = 20 + + if max_per_page is not None: + per_page = min(per_page, max_per_page) + + if page < 1: + if error_out: + abort(404) + else: + page = 1 + + if per_page < 0: + if error_out: + abort(404) + else: + per_page = 20 + + return page, per_page + + @classmethod + def apply_to_query( + cls, + query: sa.orm.Query[t.Any], + *, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = None, + error_out: bool = True, + count: bool = True, + ) -> Pagination: + """Apply an offset and limit to the query based on the current page and number + of items per page, returning a :class:`Pagination` object. This is called by + :meth:`.Query.paginate`, or can be called manually. + + :param query: The query to paginate. + :param page: The current page, used to calculate the offset. Defaults to the + ``page`` query arg during a request, or 1 otherwise. + :param per_page: The maximum number of items on a page, used to calculate the + offset and limit. Defaults to the ``per_page`` query arg during a request, + or 20 otherwise. + :param max_per_page: The maximum allowed value for ``per_page``, to limit a + user-provided value. + :param error_out: Abort with a ``404 Not Found`` error if no items are returned + and ``page`` is not 1, or if ``page`` is less than 1 or ``per_page`` is + negative, or if either are not ints. + :param count: Calculate the total number of values by issuing an extra count + query. For very complex queries this may be inaccurate or slow, so it can be + disabled and set manually if necessary. + + .. versionadded:: 3.0 + + .. versionchanged:: 3.0 + The ``count`` query is more efficient. + """ + page, per_page = cls._prepare_args( + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + ) + items = query.limit(per_page).offset((page - 1) * per_page).all() + + if not items and page != 1 and error_out: + abort(404) + + if count: + total = query.options(sa.orm.lazyload("*")).order_by(None).count() + # Using `.with_entities([sa.func.count()]).scalar()` is an alternative, but + # is not guaranteed to be correct for many possible queries. If custom + # counting is needed, it can be disabled here and set manually after. + else: + total = None + + return cls( + query=query, + page=page, + per_page=per_page, + total=total, + items=items, + ) + + # TODO: apply_to_select, requires access to session @property - def pages(self): - """The total number of pages""" + def pages(self) -> int: + """The total number of pages.""" if self.per_page == 0 or self.total is None: - pages = 0 - else: - pages = int(ceil(self.total / float(self.per_page))) - return pages + return 0 + + return ceil(self.total / self.per_page) - def prev(self, error_out=False): - """Returns a :class:`Pagination` object for the previous page.""" - assert ( - self.query is not None - ), "a query object is required for this method to work" - return self.query.paginate(self.page - 1, self.per_page, error_out) + @property + def has_prev(self) -> bool: + """``True`` if this is not the first page.""" + return self.page > 1 @property - def prev_num(self): - """Number of the previous page.""" + def prev_num(self) -> int | None: + """The previous page number, or ``None`` if this is the first page.""" if not self.has_prev: return None - return self.page - 1 - @property - def has_prev(self): - """True if a previous page exists""" - return self.page > 1 + return self.page - 1 - def next(self, error_out=False): - """Returns a :class:`Pagination` object for the next page.""" - assert ( - self.query is not None - ), "a query object is required for this method to work" - return self.query.paginate(self.page + 1, self.per_page, error_out) + def prev(self, error_out: bool = False) -> Pagination: + """Query the :class:`Pagination` object for the previous page.""" + assert self.query is not None + return self.apply_to_query( + self.query, page=self.page - 1, per_page=self.per_page, error_out=error_out + ) @property - def has_next(self): - """True if a next page exists.""" + def has_next(self) -> bool: + """``True`` if this is not the last page.""" return self.page < self.pages @property - def next_num(self): - """Number of the next page""" + def next_num(self) -> int | None: + """The next page number, or ``None`` if this is the last page.""" if not self.has_next: return None + return self.page + 1 - def iter_pages(self, left_edge=2, left_current=2, right_current=5, right_edge=2): - """Iterates over the page numbers in the pagination. The four - parameters control the thresholds how many numbers should be produced - from the sides. Skipped page numbers are represented as `None`. - This is how you could render such a pagination in the templates: + def next(self, error_out: bool = False) -> Pagination: + """Query the :class:`Pagination` object for the next page.""" + assert self.query is not None + return self.apply_to_query( + self.query, page=self.page + 1, per_page=self.per_page, error_out=error_out + ) + + def iter_pages( + self, + *, + left_edge: int = 2, + left_current: int = 2, + right_current: int = 4, + right_edge: int = 2, + ) -> t.Iterator[int | None]: + """Yield page numbers for a pagination widget. Skipped pages between the edges + and middle are represented by a ``None``. + + For example, if there are 20 pages and the current page is 7, the following + values are yielded. + + .. code-block:: python - .. sourcecode:: html+jinja + 1, 2, None, 5, 6, 7, 8, 9, 10, 11, None, 19, 20 + + The following Jinja macro renders a simple pagination widget. + + .. code-block:: jinja {% macro render_pagination(pagination, endpoint) %} {% endmacro %} + + :param left_edge: How many pages to show from the first page. + :param left_current: How many pages to show left of the current page. + :param right_current: How many pages to show right of the current page. + :param right_edge: How many pages to show from the last page. + + .. versionchanged:: 3.0 + Improved efficiency of calculating what to yield. + + .. versionchanged:: 3.0 + ``right_current`` boundary is inclusive. + + .. versionchanged:: 3.0 + All parameters are keyword-only. """ - last = 0 - for num in range(1, self.pages + 1): - if ( - num <= left_edge - or ( - num > self.page - left_current - 1 - and num < self.page + right_current - ) - or num > self.pages - right_edge - ): - if last + 1 != num: - yield None - yield num - last = num + pages_end = self.pages + 1 + + if pages_end == 1: + return + + left_end = min(1 + left_edge, pages_end) + yield from range(1, left_end) + + if left_end == pages_end: + return + + mid_start = max(left_end, self.page - left_current) + mid_end = min(self.page + right_current + 1, pages_end) + + if mid_start - left_end > 0: + yield None + + yield from range(mid_start, mid_end) + + if mid_end == pages_end: + return + + right_start = max(mid_end, pages_end - right_edge) + + if right_start - mid_end > 0: + yield None + + yield from range(right_start, pages_end) + + def __iter__(self) -> t.Iterator[t.Any]: + yield from self.items diff --git a/src/flask_sqlalchemy/query.py b/src/flask_sqlalchemy/query.py index 4cac8744..6ba1ba58 100644 --- a/src/flask_sqlalchemy/query.py +++ b/src/flask_sqlalchemy/query.py @@ -1,5 +1,4 @@ from flask import abort -from flask import request from sqlalchemy import orm from .pagination import Pagination @@ -34,77 +33,44 @@ def first_or_404(self, description=None): return rv def paginate( - self, page=None, per_page=None, error_out=True, max_per_page=None, count=True - ): - """Returns ``per_page`` items from page ``page``. - - If ``page`` or ``per_page`` are ``None``, they will be retrieved from - the request query. If ``max_per_page`` is specified, ``per_page`` will - be limited to that value. If there is no request or they aren't in the - query, they default to 1 and 20 respectively. If ``count`` is ``False``, - no query to help determine total page count will be run. - - When ``error_out`` is ``True`` (default), the following rules will - cause a 404 response: - - * No items are found and ``page`` is not 1. - * ``page`` is less than 1, or ``per_page`` is negative. - * ``page`` or ``per_page`` are not ints. - - When ``error_out`` is ``False``, ``page`` and ``per_page`` default to - 1 and 20 respectively. - - Returns a :class:`Pagination` object. + self, + *, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = None, + error_out: bool = True, + count: bool = True, + ) -> Pagination: + """Apply an offset and limit to the query based on the current page and number + of items per page, returning a :class:`.Pagination` object. + + :param query: The query to paginate. + :param page: The current page, used to calculate the offset. Defaults to the + ``page`` query arg during a request, or 1 otherwise. + :param per_page: The maximum number of items on a page, used to calculate the + offset and limit. Defaults to the ``per_page`` query arg during a request, + or 20 otherwise. + :param max_per_page: The maximum allowed value for ``per_page``, to limit a + user-provided value. + :param error_out: Abort with a ``404 Not Found`` error if no items are returned + and ``page`` is not 1, or if ``page`` is less than 1 or ``per_page`` is + negative, or if either are not ints. If disabled, an invalid ``page`` + defaults to 1, and ``per_page`` defaults to 20. + :param count: Calculate the total number of values by issuing an extra count + query. For very complex queries this may be inaccurate or slow, so it can be + disabled and set manually if necessary. + + .. versionchanged:: 3.0 + All parameters are keyword-only. + + .. versionchanged:: 3.0 + The ``count`` query is more efficient. """ - - if request: - if page is None: - try: - page = int(request.args.get("page", 1)) - except (TypeError, ValueError): - if error_out: - abort(404) - - page = 1 - - if per_page is None: - try: - per_page = int(request.args.get("per_page", 20)) - except (TypeError, ValueError): - if error_out: - abort(404) - - per_page = 20 - else: - if page is None: - page = 1 - - if per_page is None: - per_page = 20 - - if max_per_page is not None: - per_page = min(per_page, max_per_page) - - if page < 1: - if error_out: - abort(404) - else: - page = 1 - - if per_page < 0: - if error_out: - abort(404) - else: - per_page = 20 - - items = self.limit(per_page).offset((page - 1) * per_page).all() - - if not items and page != 1 and error_out: - abort(404) - - if not count: - total = None - else: - total = self.order_by(None).count() - - return Pagination(self, page, per_page, total, items) + return Pagination.apply_to_query( + self, + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + count=count, + ) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 985cd777..1d91e580 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -5,7 +5,7 @@ def test_basic_pagination(): - p = Pagination(None, 1, 20, 500, []) + p = Pagination(query=None, page=1, per_page=20, total=500, items=[]) assert p.page == 1 assert not p.has_prev assert p.has_next @@ -18,12 +18,12 @@ def test_basic_pagination(): def test_pagination_pages_when_0_items_per_page(): - p = Pagination(None, 1, 0, 500, []) + p = Pagination(query=None, page=1, per_page=0, total=500, items=[]) assert p.pages == 0 def test_pagination_pages_when_total_is_none(): - p = Pagination(None, 1, 100, None, []) + p = Pagination(query=None, page=1, per_page=20, total=None, items=[]) assert p.pages == 0 From 2da06091cbf38d2b8751740933538b9cbd88ac44 Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 30 Aug 2022 07:20:55 -0700 Subject: [PATCH 09/27] refactor query Rename from BaseQuery Add one_or_404 --- CHANGES.rst | 2 + docs/api.rst | 2 +- docs/customizing.rst | 4 +- docs/quickstart.rst | 2 +- src/flask_sqlalchemy/extension.py | 13 ++---- src/flask_sqlalchemy/model.py | 2 +- src/flask_sqlalchemy/query.py | 69 +++++++++++++++++++++++++------ tests/test_query_class.py | 18 ++++---- tests/test_sqlalchemy_includes.py | 4 +- 9 files changed, 78 insertions(+), 38 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index fab6a87b..ce0b8452 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -30,6 +30,8 @@ Unreleased - ``Query.paginate`` ``count`` is more efficient. - ``Pagination.iter_pages`` is more efficient. - ``Pagination.iter_pages`` ``right_current`` parameter is inclusive. +- ``Query`` is renamed from ``BaseQuery``. +- ``Query.one_or_404`` is added. Version 2.5.1 diff --git a/docs/api.rst b/docs/api.rst index dea92574..2fd4c0bf 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -27,7 +27,7 @@ Models primary key defined. If the ``__table__`` or ``__tablename__`` is set explicitly, that will be used instead. -.. autoclass:: BaseQuery +.. autoclass:: Query :members: Sessions diff --git a/docs/customizing.rst b/docs/customizing.rst index 63e9b39d..8eb2290a 100644 --- a/docs/customizing.rst +++ b/docs/customizing.rst @@ -84,9 +84,9 @@ It is also possible to customize what is available for use on the special ``query`` property of models. For example, providing a ``get_or`` method:: - from flask_sqlalchemy import BaseQuery, SQLAlchemy + from flask_sqlalchemy import Query, SQLAlchemy - class GetOrQuery(BaseQuery): + class GetOrQuery(Query): def get_or(self, ident, default=None): return self.get(ident) or default diff --git a/docs/quickstart.rst b/docs/quickstart.rst index dc184053..acbeaffd 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -181,7 +181,7 @@ The only things you need to know compared to plain SQLAlchemy are: 2. The :class:`Model` declarative base class behaves like a regular Python class but has a ``query`` attribute attached that can be used to - query the model. (:class:`Model` and :class:`BaseQuery`) + query the model. (:class:`Model` and :class:`Query`) 3. You have to commit the session, but you don't have to remove it at the end of the request, Flask-SQLAlchemy does that for you. diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index fcc764dc..84d2f346 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -12,7 +12,7 @@ from .model import _QueryProperty from .model import DefaultMeta from .model import Model -from .query import BaseQuery +from .query import Query from .record_queries import _EngineDebuggingSignalEvents from .session import SignallingSession @@ -215,7 +215,7 @@ class User(db.Model): You can still use :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, but note that Flask-SQLAlchemy customizations are available only through an instance of this :class:`SQLAlchemy` class. Query classes default to - :class:`BaseQuery` for `db.Query`, `db.Model.query_class`, and the default + :class:`Query` for `db.Query`, `db.Model.query_class`, and the default query_class for `db.relationship` and `db.backref`. If you use these interfaces through :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, the default query class will be that of :mod:`sqlalchemy`. @@ -253,7 +253,7 @@ class User(db.Model): .. versionchanged:: 2.1 Added the ``query_class`` parameter, to allow customisation of the query class, in place of the default of - :class:`BaseQuery`. + :class:`Query`. .. versionchanged:: 2.1 Added the ``model_class`` parameter, which allows a custom model @@ -272,17 +272,12 @@ class to be used in place of :class:`Model`. Added the ``session_options`` parameter. """ - #: Default query class used by :attr:`Model.query` and other queries. - #: Customize this by passing ``query_class`` to :func:`SQLAlchemy`. - #: Defaults to :class:`BaseQuery`. - Query = None - def __init__( self, app=None, session_options=None, metadata=None, - query_class=BaseQuery, + query_class=Query, model_class=Model, engine_options=None, ): diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 3f578b08..ff7ba3b4 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -123,7 +123,7 @@ class Model: """ #: Query class used by :attr:`query`. Defaults to - # :class:`SQLAlchemy.Query`, which defaults to :class:`BaseQuery`. + # :class:`SQLAlchemy.Query`, which defaults to :class:`Query`. query_class = None #: Convenience property to query the database for instances of this model diff --git a/src/flask_sqlalchemy/query.py b/src/flask_sqlalchemy/query.py index 6ba1ba58..a9efc35f 100644 --- a/src/flask_sqlalchemy/query.py +++ b/src/flask_sqlalchemy/query.py @@ -1,37 +1,65 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy as sa +import sqlalchemy.exc +import sqlalchemy.orm from flask import abort -from sqlalchemy import orm from .pagination import Pagination -class BaseQuery(orm.Query): - """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with - convenience methods for querying in a web application. +class Query(sa.orm.Query): + """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods + useful for querying in a web application. + + This is the default query class for :attr:`.Model.query`. - This is the default :attr:`~Model.query` object used for models, and - exposed as :attr:`~SQLAlchemy.Query`. Override the query class for - an individual model by subclassing this and setting - :attr:`~Model.query_class`. + .. versionchanged:: 3.0 + Renamed to ``Query`` from ``BaseQuery``. """ - def get_or_404(self, ident, description=None): - """Like :meth:`get` but aborts with 404 if not found instead of + def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any: + """Like :meth:`get` but aborts with a ``404 Not Found`` error instead of returning ``None``. + + :param ident: The primary key to query. + :param description: A custom message to show on the error page. """ rv = self.get(ident) + if rv is None: abort(404, description=description) + return rv - def first_or_404(self, description=None): - """Like :meth:`first` but aborts with 404 if not found instead - of returning ``None``. + def first_or_404(self, description: str | None = None) -> t.Any: + """Like :meth:`first` but aborts with a ``404 Not Found`` error instead of + returning ``None``. + + :param description: A custom message to show on the error page. """ rv = self.first() + if rv is None: abort(404, description=description) + return rv + def one_or_404(self, description: str | None = None) -> t.Any: + """Like :meth:`one` but aborts with a ``404 Not Found`` error instead of raising + ``NoResultFound`` or ``MultipleResultsFound``. + + :param description: A custom message to show on the error page. + + .. versionadded:: 3.0 + """ + try: + return self.one() + except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound): + abort(404, description=description) + def paginate( self, *, @@ -74,3 +102,18 @@ def paginate( error_out=error_out, count=count, ) + + +def __getattr__(name: str) -> t.Any: + import warnings + + if name == "BaseQuery": + warnings.warn( + "'BaseQuery' is renamed to 'Query'. The old name is deprecated and will be" + " removed in Flask-SQLAlchemy 3.1.", + DeprecationWarning, + stacklevel=2, + ) + return Query + + raise AttributeError(name) diff --git a/tests/test_query_class.py b/tests/test_query_class.py index df11b181..f74fff64 100644 --- a/tests/test_query_class.py +++ b/tests/test_query_class.py @@ -1,5 +1,5 @@ from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.query import BaseQuery +from flask_sqlalchemy.query import Query def test_default_query_class(db): @@ -15,14 +15,14 @@ class Child(db.Model): c = Child() c.parent = p - assert type(Parent.query) == BaseQuery - assert type(Child.query) == BaseQuery - assert isinstance(p.children, BaseQuery) - assert isinstance(db.session.query(Parent), BaseQuery) + assert type(Parent.query) == Query + assert type(Child.query) == Query + assert isinstance(p.children, Query) + assert isinstance(db.session.query(Parent), Query) def test_custom_query_class(app): - class CustomQueryClass(BaseQuery): + class CustomQueryClass(Query): pass db = SQLAlchemy(app, query_class=CustomQueryClass) @@ -48,13 +48,13 @@ class Child(db.Model): def test_dont_override_model_default(app): - class CustomQueryClass(BaseQuery): + class CustomQueryClass(Query): pass db = SQLAlchemy(app, query_class=CustomQueryClass) class SomeModel(db.Model): id = db.Column(db.Integer, primary_key=True) - query_class = BaseQuery + query_class = Query - assert type(SomeModel.query) == BaseQuery + assert type(SomeModel.query) == Query diff --git a/tests/test_sqlalchemy_includes.py b/tests/test_sqlalchemy_includes.py index b805054f..aa84a7f5 100644 --- a/tests/test_sqlalchemy_includes.py +++ b/tests/test_sqlalchemy_includes.py @@ -1,7 +1,7 @@ import sqlalchemy as sa from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.query import BaseQuery +from flask_sqlalchemy.query import Query def test_sqlalchemy_includes(): @@ -11,4 +11,4 @@ def test_sqlalchemy_includes(): assert db.Column == sa.Column # The Query object we expose is actually our own subclass. - assert db.Query == BaseQuery + assert db.Query == Query From 8ce1b447f228ff5688f2a08593fbf1e53d945bd1 Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 30 Aug 2022 07:35:12 -0700 Subject: [PATCH 10/27] reorganize model --- src/flask_sqlalchemy/model.py | 142 +++++++++++++++++----------------- 1 file changed, 71 insertions(+), 71 deletions(-) diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index ff7ba3b4..ad8cbe90 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -9,44 +9,55 @@ from sqlalchemy.schema import _get_table_key -def should_set_tablename(cls): - """Determine whether ``__tablename__`` should be automatically generated - for a model. +class _QueryProperty: + def __init__(self, sa): + self.sa = sa - * If no class in the MRO sets a name, one should be generated. - * If a declared attr is found, it should be used instead. - * If a name is found, it should be used if the class is a mixin, otherwise - one should be generated. - * Abstract models should not have one generated. + def __get__(self, obj, type): + try: + mapper = orm.class_mapper(type) + if mapper: + return type.query_class(mapper, session=self.sa.session()) + except UnmappedClassError: + return None - Later, :meth:`._BoundDeclarativeMeta.__table_cls__` will determine if the - model looks like single or joined-table inheritance. If no primary key is - found, the name will be unset. + +class Model: + """Base class for SQLAlchemy declarative base model. + + To define models, subclass :attr:`db.Model `, not this + class. To customize ``db.Model``, subclass this and pass it as + ``model_class`` to :class:`SQLAlchemy`. """ - if cls.__dict__.get("__abstract__", False) or not any( - isinstance(b, DeclarativeMeta) for b in cls.__mro__[1:] - ): - return False - for base in cls.__mro__: - if "__tablename__" not in base.__dict__: - continue + #: Query class used by :attr:`query`. Defaults to + # :class:`SQLAlchemy.Query`, which defaults to :class:`Query`. + query_class = None - if isinstance(base.__dict__["__tablename__"], declared_attr): - return False + #: Convenience property to query the database for instances of this model + # using the current session. Equivalent to ``db.session.query(Model)`` + # unless :attr:`query_class` has been changed. + query = None - return not ( - base is cls - or base.__dict__.get("__abstract__", False) - or not isinstance(base, DeclarativeMeta) - ) + def __repr__(self): + identity = inspect(self).identity - return True + if identity is None: + pk = f"(transient {id(self)})" + else: + pk = ", ".join(str(value) for value in identity) + return f"<{type(self).__name__} {pk}>" -def camel_to_snake_case(name): - name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name) - return name.lower().lstrip("_") + +class BindMetaMixin(type): + def __init__(cls, name, bases, d): + bind_key = d.pop("__bind_key__", None) or getattr(cls, "__bind_key__", None) + + super().__init__(name, bases, d) + + if bind_key is not None and getattr(cls, "__table__", None) is not None: + cls.__table__.info["bind_key"] = bind_key class NameMetaMixin(type): @@ -100,56 +111,45 @@ def __table_cls__(cls, *args, **kwargs): del cls.__tablename__ -class BindMetaMixin(type): - def __init__(cls, name, bases, d): - bind_key = d.pop("__bind_key__", None) or getattr(cls, "__bind_key__", None) - - super().__init__(name, bases, d) - - if bind_key is not None and getattr(cls, "__table__", None) is not None: - cls.__table__.info["bind_key"] = bind_key - - -class DefaultMeta(NameMetaMixin, BindMetaMixin, DeclarativeMeta): - pass - +def should_set_tablename(cls): + """Determine whether ``__tablename__`` should be automatically generated + for a model. -class Model: - """Base class for SQLAlchemy declarative base model. + * If no class in the MRO sets a name, one should be generated. + * If a declared attr is found, it should be used instead. + * If a name is found, it should be used if the class is a mixin, otherwise + one should be generated. + * Abstract models should not have one generated. - To define models, subclass :attr:`db.Model `, not this - class. To customize ``db.Model``, subclass this and pass it as - ``model_class`` to :class:`SQLAlchemy`. + Later, :meth:`._BoundDeclarativeMeta.__table_cls__` will determine if the + model looks like single or joined-table inheritance. If no primary key is + found, the name will be unset. """ + if cls.__dict__.get("__abstract__", False) or not any( + isinstance(b, DeclarativeMeta) for b in cls.__mro__[1:] + ): + return False - #: Query class used by :attr:`query`. Defaults to - # :class:`SQLAlchemy.Query`, which defaults to :class:`Query`. - query_class = None + for base in cls.__mro__: + if "__tablename__" not in base.__dict__: + continue - #: Convenience property to query the database for instances of this model - # using the current session. Equivalent to ``db.session.query(Model)`` - # unless :attr:`query_class` has been changed. - query = None + if isinstance(base.__dict__["__tablename__"], declared_attr): + return False - def __repr__(self): - identity = inspect(self).identity + return not ( + base is cls + or base.__dict__.get("__abstract__", False) + or not isinstance(base, DeclarativeMeta) + ) - if identity is None: - pk = f"(transient {id(self)})" - else: - pk = ", ".join(str(value) for value in identity) + return True - return f"<{type(self).__name__} {pk}>" +def camel_to_snake_case(name): + name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name) + return name.lower().lstrip("_") -class _QueryProperty: - def __init__(self, sa): - self.sa = sa - def __get__(self, obj, type): - try: - mapper = orm.class_mapper(type) - if mapper: - return type.query_class(mapper, session=self.sa.session()) - except UnmappedClassError: - return None +class DefaultMeta(NameMetaMixin, BindMetaMixin, DeclarativeMeta): + pass From c6b61e401637c9b7c3c75cf2a0963136339bcc3d Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 30 Aug 2022 10:47:19 -0700 Subject: [PATCH 11/27] reorganize record_queries --- src/flask_sqlalchemy/record_queries.py | 68 +++++++++++++------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/src/flask_sqlalchemy/record_queries.py b/src/flask_sqlalchemy/record_queries.py index a7b490cc..308df402 100644 --- a/src/flask_sqlalchemy/record_queries.py +++ b/src/flask_sqlalchemy/record_queries.py @@ -7,6 +7,40 @@ from sqlalchemy import event +def get_debug_queries(): + """In debug mode or testing mode, Flask-SQLAlchemy will log all the SQL + queries sent to the database. This information is available until the end + of request which makes it possible to easily ensure that the SQL generated + is the one expected on errors or in unittesting. Alternatively, you can also + enable the query recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` + config variable to `True`. + + The value returned will be a list of named tuples with the following + attributes: + + `statement` + The SQL statement issued + + `parameters` + The parameters for the SQL statement + + `start_time` / `end_time` + Time the query started / the results arrived. Please keep in mind + that the timer function used depends on your platform. These + values are only useful for sorting or comparing. They do not + necessarily represent an absolute timestamp. + + `duration` + Time the query took in seconds + + `context` + A string giving a rough estimation of where in your application + query was issued. The exact format is undefined so don't try + to reconstruct filename or function name. + """ + return getattr(_app_ctx_stack.top, "sqlalchemy_queries", []) + + class _DebugQueryTuple(tuple): statement = property(itemgetter(0)) parameters = property(itemgetter(1)) @@ -74,37 +108,3 @@ def after_cursor_execute( ) ) ) - - -def get_debug_queries(): - """In debug mode or testing mode, Flask-SQLAlchemy will log all the SQL - queries sent to the database. This information is available until the end - of request which makes it possible to easily ensure that the SQL generated - is the one expected on errors or in unittesting. Alternatively, you can also - enable the query recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` - config variable to `True`. - - The value returned will be a list of named tuples with the following - attributes: - - `statement` - The SQL statement issued - - `parameters` - The parameters for the SQL statement - - `start_time` / `end_time` - Time the query started / the results arrived. Please keep in mind - that the timer function used depends on your platform. These - values are only useful for sorting or comparing. They do not - necessarily represent an absolute timestamp. - - `duration` - Time the query took in seconds - - `context` - A string giving a rough estimation of where in your application - query was issued. The exact format is undefined so don't try - to reconstruct filename or function name. - """ - return getattr(_app_ctx_stack.top, "sqlalchemy_queries", []) From 580d75b202a4d0e837fedc7c2dd45db9044e1513 Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 30 Aug 2022 11:08:16 -0700 Subject: [PATCH 12/27] refactor record_queries Rename get_debug_queries to get_recorded_queries Query info is a dataclass Rename context to location --- CHANGES.rst | 4 + src/flask_sqlalchemy/extension.py | 7 +- src/flask_sqlalchemy/record_queries.py | 228 +++++++++++++++---------- tests/test_basic_app.py | 16 +- 4 files changed, 151 insertions(+), 104 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index ce0b8452..085313f7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -32,6 +32,10 @@ Unreleased - ``Pagination.iter_pages`` ``right_current`` parameter is inclusive. - ``Query`` is renamed from ``BaseQuery``. - ``Query.one_or_404`` is added. +- ``get_debug_queries`` is renamed to ``get_recorded_queries`` to better match the + config and functionality. +- Recorded query info is a dataclass instead of a tuple. The ``context`` attribute is + renamed to ``location``. Finding the location uses a more inclusive check. Version 2.5.1 diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 84d2f346..98455b54 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -13,7 +13,6 @@ from .model import DefaultMeta from .model import Model from .query import Query -from .record_queries import _EngineDebuggingSignalEvents from .session import SignallingSession try: @@ -135,9 +134,9 @@ def get_engine(self): self._engine = rv = self._sa.create_engine(sa_url, options) if _record_queries(self._app): - _EngineDebuggingSignalEvents( - self._engine, self._app.import_name - ).register() + from . import record_queries + + record_queries._listen(self._engine) self._connected_for = (uri, echo) diff --git a/src/flask_sqlalchemy/record_queries.py b/src/flask_sqlalchemy/record_queries.py index 308df402..dfe92403 100644 --- a/src/flask_sqlalchemy/record_queries.py +++ b/src/flask_sqlalchemy/record_queries.py @@ -1,110 +1,154 @@ -import sys -from operator import itemgetter +from __future__ import annotations + +import dataclasses +import inspect +import typing as t from time import perf_counter -from flask import _app_ctx_stack +import sqlalchemy as sa +import sqlalchemy.event from flask import current_app -from sqlalchemy import event +from flask import g +from flask import has_app_context + + +def get_recorded_queries() -> list[_QueryInfo]: + """Get the list of recorded query information for the current session. Queries are + recorded if the app is in debug or testing mode, or if the config + :data:`SQLALCHEMY_RECORD_QUERIES` is enabled. + + Each query info object has the following attributes: + + ``statement`` + The string of SQL generated by SQLAlchemy with parameter placeholders. + ``parameters`` + The parameters sent with the SQL statment. + ``start_time`` / ``end_time`` + Timing info about when the query started execution and when the results where + returned. Accuracy and value depends on the operating system. + ``duration`` + The time the query took in seconds. + ``location`` + A string description of where in your application code the query was executed. + This may not be possible to calculate, and the format is not stable. + + .. versionchanged:: 3.0 + Renamed from ``get_debug_queries``. + + .. versionchanged:: 3.0 + The info object is a dataclass instead of a tuple. + + .. versionchanged:: 3.0 + The info object attribute ``context`` is renamed to ``location``. + """ + return g.get("_sqlalchemy_queries", []) -def get_debug_queries(): - """In debug mode or testing mode, Flask-SQLAlchemy will log all the SQL - queries sent to the database. This information is available until the end - of request which makes it possible to easily ensure that the SQL generated - is the one expected on errors or in unittesting. Alternatively, you can also - enable the query recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` - config variable to `True`. +@dataclasses.dataclass +class _QueryInfo: + """Information about an executed query. Returned by :func:`get_recorded_queries`. - The value returned will be a list of named tuples with the following - attributes: + .. versionchanged:: 3.0 + Renamed from ``_DebugQueryTuple``. - `statement` - The SQL statement issued + .. versionchanged:: 3.0 + Changed to a dataclass instead of a tuple. - `parameters` - The parameters for the SQL statement + .. versionchanged:: 3.0 + ``context`` is renamed to ``location``. + """ - `start_time` / `end_time` - Time the query started / the results arrived. Please keep in mind - that the timer function used depends on your platform. These - values are only useful for sorting or comparing. They do not - necessarily represent an absolute timestamp. + statement: str + parameters: t.Any + start_time: float + end_time: float + location: str - `duration` - Time the query took in seconds + @property + def duration(self) -> float: + return self.end_time - self.start_time - `context` - A string giving a rough estimation of where in your application - query was issued. The exact format is undefined so don't try - to reconstruct filename or function name. - """ - return getattr(_app_ctx_stack.top, "sqlalchemy_queries", []) + @property + def context(self) -> str: + import warnings + + warnings.warn( + "'context' is renamed to 'location'. The old name is deprecated and will be" + " removed in Flask-SQLAlchemy 3.1.", + DeprecationWarning, + stacklevel=2, + ) + return self.location + def __getitem__(self, key: int) -> object: + import warnings -class _DebugQueryTuple(tuple): - statement = property(itemgetter(0)) - parameters = property(itemgetter(1)) - start_time = property(itemgetter(2)) - end_time = property(itemgetter(3)) - context = property(itemgetter(4)) + name = ("statement", "parameters", "start_time", "end_time", "location")[key] + warnings.warn( + "Query info is a dataclass, not a tuple. Lookup by index is deprecated and" + f" will be removed in Flask-SQLAlchemy 3.1. Use 'info.{name}' instead.", + DeprecationWarning, + stacklevel=2, + ) + return getattr(self, name) + + +def _listen(engine: sa.Engine) -> None: + sa.event.listen(engine, "before_cursor_execute", _record_start, named=True) + sa.event.listen(engine, "after_cursor_execute", _record_end, named=True) - @property - def duration(self): - return self.end_time - self.start_time - def __repr__(self): - return ( - f"" +def _record_start(context: sa.ExecutionContext, **kwargs: t.Any) -> None: + if not has_app_context(): + return + + context._fsa_start_time = perf_counter() + + +def _record_end(context: sa.ExecutionContext, **kwargs: t.Any) -> None: + if not has_app_context(): + return + + if "_sqlalchemy_queries" not in g: + g._sqlalchemy_queries = [] + + import_top = current_app.import_name.partition(".")[0] + import_dot = f"{import_top}." + frame = inspect.currentframe() + + while frame: + name = frame.f_globals.get("__name__") + + if name and (name == import_top or name.startswith(import_dot)): + code = frame.f_code + location = f"{code.co_filename}:{frame.f_lineno} ({code.co_name})" + break + + frame = frame.f_back + else: + location = "" + + g._sqlalchemy_queries.append( + _QueryInfo( + statement=context.statement, + parameters=context.parameters, + start_time=context._fsa_start_time, + end_time=perf_counter(), + location=location, ) + ) -def _calling_context(app_path): - frm = sys._getframe(1) - while frm.f_back is not None: - name = frm.f_globals.get("__name__") - if name and (name == app_path or name.startswith(f"{app_path}.")): - funcname = frm.f_code.co_name - return f"{frm.f_code.co_filename}:{frm.f_lineno} ({funcname})" - frm = frm.f_back - return "" - - -class _EngineDebuggingSignalEvents: - """Sets up handlers for two events that let us track the execution time of - queries.""" - - def __init__(self, engine, import_name): - self.engine = engine - self.app_package = import_name - - def register(self): - event.listen(self.engine, "before_cursor_execute", self.before_cursor_execute) - event.listen(self.engine, "after_cursor_execute", self.after_cursor_execute) - - def before_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): - if current_app: - context._query_start_time = perf_counter() - - def after_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): - if current_app: - try: - queries = _app_ctx_stack.top.sqlalchemy_queries - except AttributeError: - queries = _app_ctx_stack.top.sqlalchemy_queries = [] - - queries.append( - _DebugQueryTuple( - ( - statement, - parameters, - context._query_start_time, - perf_counter(), - _calling_context(self.app_package), - ) - ) - ) +def __getattr__(name: str) -> t.Any: + import warnings + + if name == "get_debug_queries": + warnings.warn( + "'get_debug_queries' is renamed to 'get_recorded_queries'. The old name is" + " deprecated and will be removed in Flask-SQLAlchemy 3.1.", + DeprecationWarning, + stacklevel=2, + ) + return get_recorded_queries + + raise AttributeError(name) diff --git a/tests/test_basic_app.py b/tests/test_basic_app.py index b69d7e5d..b56ae841 100644 --- a/tests/test_basic_app.py +++ b/tests/test_basic_app.py @@ -1,7 +1,7 @@ import flask from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.record_queries import get_debug_queries +from flask_sqlalchemy.record_queries import get_recorded_queries def test_basic_insert(app, db, Todo): @@ -32,20 +32,20 @@ def test_query_recording(app, db, Todo): todo.done = True db.session.commit() - queries = get_debug_queries() + queries = get_recorded_queries() assert len(queries) == 2 query = queries[0] assert "insert into" in query.statement.lower() - assert query.parameters[0] == "Test 1" - assert query.parameters[1] == "test" - assert "test_basic_app.py" in query.context - assert "test_query_recording" in query.context + assert query.parameters[0][0] == "Test 1" + assert query.parameters[0][1] == "test" + assert "test_basic_app.py" in query.location + assert "test_query_recording" in query.location query = queries[1] assert "update" in query.statement.lower() - assert query.parameters[0] == 1 - assert query.parameters[1] == 1 + assert query.parameters[0][0] == 1 + assert query.parameters[0][1] == 1 def test_helper_api(db): From 898dc5991ccb0429a7b9ffda61ff3bb7125b7fb3 Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 30 Aug 2022 11:23:30 -0700 Subject: [PATCH 13/27] refactor track_modifications --- src/flask_sqlalchemy/session.py | 8 +- src/flask_sqlalchemy/track_modifications.py | 135 ++++++++++---------- tests/test_signals.py | 35 +++-- 3 files changed, 93 insertions(+), 85 deletions(-) diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py index ffa2d06f..cfb31352 100644 --- a/src/flask_sqlalchemy/session.py +++ b/src/flask_sqlalchemy/session.py @@ -1,7 +1,5 @@ from sqlalchemy.orm import Session as SessionBase -from .track_modifications import _SessionSignalEvents - class SignallingSession(SessionBase): """The signalling session is the default session that Flask-SQLAlchemy @@ -21,12 +19,16 @@ class SignallingSession(SessionBase): def __init__(self, db, autocommit=False, autoflush=True, **options): #: The application that this session belongs to. self.app = app = db.get_app() + self._db = db + self._model_changes = {} track_modifications = app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] bind = options.pop("bind", None) or db.engine binds = options.pop("binds", db.get_binds(app)) if track_modifications: - _SessionSignalEvents.register(self) + from . import track_modifications + + track_modifications._listen(self) SessionBase.__init__( self, diff --git a/src/flask_sqlalchemy/track_modifications.py b/src/flask_sqlalchemy/track_modifications.py index 006cd493..4a518963 100644 --- a/src/flask_sqlalchemy/track_modifications.py +++ b/src/flask_sqlalchemy/track_modifications.py @@ -1,78 +1,71 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy as sa +import sqlalchemy.event +from flask import current_app +from flask import has_app_context from flask.signals import Namespace -from sqlalchemy import event -from sqlalchemy import inspect + +if t.TYPE_CHECKING: + from .session import SignallingSession _signals = Namespace() models_committed = _signals.signal("models-committed") before_models_committed = _signals.signal("before-models-committed") -class _SessionSignalEvents: - @classmethod - def register(cls, session): - if not hasattr(session, "_model_changes"): - session._model_changes = {} - - event.listen(session, "before_flush", cls.record_ops) - event.listen(session, "before_commit", cls.record_ops) - event.listen(session, "before_commit", cls.before_commit) - event.listen(session, "after_commit", cls.after_commit) - event.listen(session, "after_rollback", cls.after_rollback) - - @classmethod - def unregister(cls, session): - if hasattr(session, "_model_changes"): - del session._model_changes - - event.remove(session, "before_flush", cls.record_ops) - event.remove(session, "before_commit", cls.record_ops) - event.remove(session, "before_commit", cls.before_commit) - event.remove(session, "after_commit", cls.after_commit) - event.remove(session, "after_rollback", cls.after_rollback) - - @staticmethod - def record_ops(session, flush_context=None, instances=None): - try: - d = session._model_changes - except AttributeError: - return - - for targets, operation in ( - (session.new, "insert"), - (session.dirty, "update"), - (session.deleted, "delete"), - ): - for target in targets: - state = inspect(target) - key = state.identity_key if state.has_identity else id(target) - d[key] = (target, operation) - - @staticmethod - def before_commit(session): - try: - d = session._model_changes - except AttributeError: - return - - if d: - before_models_committed.send(session.app, changes=list(d.values())) - - @staticmethod - def after_commit(session): - try: - d = session._model_changes - except AttributeError: - return - - if d: - models_committed.send(session.app, changes=list(d.values())) - d.clear() - - @staticmethod - def after_rollback(session): - try: - d = session._model_changes - except AttributeError: - return - - d.clear() +def _listen(session) -> None: + sa.event.listen(session, "before_flush", _record_ops, named=True) + sa.event.listen(session, "before_commit", _record_ops, named=True) + sa.event.listen(session, "before_commit", _before_commit) + sa.event.listen(session, "after_commit", _after_commit) + sa.event.listen(session, "after_rollback", _after_rollback) + + +def _record_ops(session: SignallingSession, **kwargs: t.Any) -> None: + if not has_app_context(): + return + + if not current_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: + return + + for targets, operation in ( + (session.new, "insert"), + (session.dirty, "update"), + (session.deleted, "delete"), + ): + for target in targets: + state = sa.inspect(target) + key = state.identity_key if state.has_identity else id(target) + session._model_changes[key] = (target, operation) + + +def _before_commit(session: SignallingSession) -> None: + if not has_app_context(): + return + + if not current_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: + return + + if session._model_changes: + changes = list(session._model_changes.values()) + before_models_committed.send(current_app._get_current_object(), changes=changes) + + +def _after_commit(session: SignallingSession) -> None: + if not has_app_context(): + return + + if not current_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: + return + + if session._model_changes: + changes = list(session._model_changes.values()) + models_committed.send(current_app._get_current_object(), changes=changes) + session._model_changes.clear() + + +def _after_rollback(session: SignallingSession) -> None: + session._model_changes.clear() diff --git a/tests/test_signals.py b/tests/test_signals.py index bdd02c43..dde1cfc6 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -24,13 +24,16 @@ def before_committed(sender, changes): before_models_committed.connect(before_committed) todo = Todo("Awesome", "the text") - db.session.add(todo) - db.session.commit() + + with app.app_context(): + db.session.add(todo) + db.session.commit() + assert Namespace.is_received before_models_committed.disconnect(before_committed) -def test_model_signals(db, Todo): +def test_model_signals(app, db, Todo): recorded = [] def committed(sender, changes): @@ -38,22 +41,32 @@ def committed(sender, changes): recorded.extend(changes) models_committed.connect(committed) - todo = Todo("Awesome", "the text") - db.session.add(todo) - assert len(recorded) == 0 - db.session.commit() + + with app.app_context(): + todo = Todo("Awesome", "the text") + db.session.add(todo) + assert len(recorded) == 0 + db.session.commit() + assert len(recorded) == 1 assert recorded[0][0] == todo assert recorded[0][1] == "insert" del recorded[:] - todo.text = "aha" - db.session.commit() + + with app.app_context(): + db.session.add(todo) + todo.text = "aha" + db.session.commit() + assert len(recorded) == 1 assert recorded[0][0] == todo assert recorded[0][1] == "update" del recorded[:] - db.session.delete(todo) - db.session.commit() + + with app.app_context(): + db.session.delete(todo) + db.session.commit() + assert len(recorded) == 1 assert recorded[0][0] == todo assert recorded[0][1] == "delete" From 287710d110f7bb2280c6c6682244c30b7aec32c4 Mon Sep 17 00:00:00 2001 From: David Lord Date: Sat, 3 Sep 2022 06:46:26 -0700 Subject: [PATCH 14/27] use getattr for sqlalchemy aliases query class is applied to backref correctly --- CHANGES.rst | 3 + src/flask_sqlalchemy/extension.py | 122 +++++++++++++++++------------- 2 files changed, 74 insertions(+), 51 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 085313f7..89257e20 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -36,6 +36,9 @@ Unreleased config and functionality. - Recorded query info is a dataclass instead of a tuple. The ``context`` attribute is renamed to ``location``. Finding the location uses a more inclusive check. +- The ``SQLAlchemy`` extension object uses ``__getattr__`` to alias names from the + SQLAlchemy package, rather than copying them as attributes. +- The query class is applied to ``backref`` in ``relationship``. Version 2.5.1 diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 98455b54..7faead64 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -1,13 +1,14 @@ -import functools +from __future__ import annotations + import os +import typing as t import warnings from threading import Lock -import sqlalchemy +import sqlalchemy as sa +import sqlalchemy.event +import sqlalchemy.orm from flask import current_app -from sqlalchemy import event -from sqlalchemy import orm -from sqlalchemy.engine.url import make_url from .model import _QueryProperty from .model import DefaultMeta @@ -15,14 +16,6 @@ from .query import Query from .session import SignallingSession -try: - from sqlalchemy.orm import declarative_base - from sqlalchemy.orm import DeclarativeMeta -except ImportError: - # SQLAlchemy <= 1.3 - from sqlalchemy.ext.declarative import declarative_base - from sqlalchemy.ext.declarative import DeclarativeMeta - # Scope the session to the current greenlet if greenlet is available, # otherwise fall back to the current thread. try: @@ -63,38 +56,6 @@ def _make_table(*args, **kwargs): return _make_table -def _set_default_query_class(d, cls): - if "query_class" not in d: - d["query_class"] = cls - - -def _wrap_with_default_query_class(fn, cls): - @functools.wraps(fn) - def newfn(*args, **kwargs): - _set_default_query_class(kwargs, cls) - if "backref" in kwargs: - backref = kwargs["backref"] - if isinstance(backref, str): - backref = (backref, {}) - _set_default_query_class(backref[1], cls) - return fn(*args, **kwargs) - - return newfn - - -def _include_sqlalchemy(obj, cls): - for module in sqlalchemy, sqlalchemy.orm: - for key in module.__all__: - if not hasattr(obj, key): - setattr(obj, key, getattr(module, key)) - # Note: obj.Table does not attempt to be a SQLAlchemy Table class. - obj.Table = _make_table(obj) - obj.relationship = _wrap_with_default_query_class(obj.relationship, cls) - obj.relation = _wrap_with_default_query_class(obj.relation, cls) - obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader, cls) - obj.event = event - - def _record_queries(app): if app.debug: return True @@ -129,7 +90,7 @@ def get_engine(self): if (uri, echo) == self._connected_for: return self._engine - sa_url = make_url(uri) + sa_url = sa.engine.make_url(uri) sa_url, options = self.get_options(sa_url, echo) self._engine = rv = self._sa.create_engine(sa_url, options) @@ -287,7 +248,7 @@ def __init__( self._engine_lock = Lock() self.app = app self._engine_options = engine_options or {} - _include_sqlalchemy(self, query_class) + self.Table = _make_table(self) if app is not None: self.init_app(app) @@ -317,7 +278,7 @@ def create_scoped_session(self, options=None): scopefunc = options.pop("scopefunc", _ident_func) options.setdefault("query_cls", self.Query) - return orm.scoped_session(self.create_session(options), scopefunc=scopefunc) + return sa.orm.scoped_session(self.create_session(options), scopefunc=scopefunc) def create_session(self, options): """Create the session factory used by :meth:`create_scoped_session`. @@ -334,7 +295,7 @@ class or a :class:`~sqlalchemy.orm.session.sessionmaker`. :param options: dict of keyword arguments passed to session class """ - return orm.sessionmaker(class_=SignallingSession, db=self, **options) + return sa.orm.sessionmaker(class_=SignallingSession, db=self, **options) def make_declarative_base(self, model, metadata=None): """Creates the declarative base that all models will inherit from. @@ -350,8 +311,8 @@ def make_declarative_base(self, model, metadata=None): ``model`` can be an existing declarative base in order to support complex customization such as changing the metaclass. """ - if not isinstance(model, DeclarativeMeta): - model = declarative_base( + if not isinstance(model, sa.orm.DeclarativeMeta): + model = sa.orm.declarative_base( cls=model, name="Model", metadata=metadata, metaclass=DefaultMeta ) @@ -610,3 +571,62 @@ def reflect(self, bind="__all__", app=None): def __repr__(self): url = self.engine.url if self.app or current_app else None return f"<{type(self).__name__} engine={url!r}>" + + def _set_rel_query(self, kwargs: dict[str, t.Any]) -> None: + """Apply the extension's :attr:`Query` class as the default for relationships + and backrefs. + + :meta private: + """ + kwargs.setdefault("query_class", self.Query) + + if "backref" in kwargs: + backref = kwargs["backref"] + + if isinstance(backref, str): + backref = (backref, {}) + + backref[1].setdefault("query_class", self.Query) + + def relationship( + self, *args: t.Any, **kwargs: t.Any + ) -> sa.orm.RelationshipProperty: + """A SQLAlchemy :func:`~sqlalchemy.orm.relationship` that applies this + extension's :attr:`Query` class for dynamic relationships and backrefs. + """ + self._set_rel_query(kwargs) + return sa.orm.relationship(*args, **kwargs) + + def dynamic_loader( + self, argument: t.Any, **kwargs: t.Any + ) -> sa.orm.RelationshipProperty: + """A SQLAlchemy :func:`~sqlalchemy.orm.dynamic_loader` that applies this + extension's :attr:`Query` class for relationships and backrefs. + """ + self._set_rel_query(kwargs) + return sa.orm.dynamic_loader(argument, **kwargs) + + def _relation(self, *args: t.Any, **kwargs: t.Any) -> sa.orm.RelationshipProperty: + """A SQLAlchemy :func:`~sqlalchemy.orm.relationship` that applies this + extension's :attr:`Query` class for dynamic relationships and backrefs. + + SQLAlchemy 2.0 removes this name, use ``relationship`` instead. + + :meta private: + """ + # Deprecated, removed in SQLAlchemy 2.0. Accessed through ``__getattr__``. + self._set_rel_query(kwargs) + return sa.orm.relation(*args, **kwargs) + + def __getattr__(self, name: str) -> t.Any: + if name == "relation": + return self._relation + + if name == "event": + return sa.event + + for mod in (sa, sa.orm): + if name in mod.__all__: + return getattr(sa, name) + + raise AttributeError(name) From 73be43443f4ebb12d1775794fa5432b9988ed64c Mon Sep 17 00:00:00 2001 From: David Lord Date: Wed, 7 Sep 2022 15:57:55 -0700 Subject: [PATCH 15/27] draw the rest of the owl unique metadata per bind key engines are created immediately, no connectors extension is stored directly in app.extensions engine_options param is lower precedence than config make setup methods private rename SignallingSession to Session model repr distinguishes transient and pending Table is a subclass not a function session class can be customized sqlite does not use null pool mysql does not set pool size --- CHANGES.rst | 28 + src/flask_sqlalchemy/extension.py | 1052 +++++++++++++++++------------ src/flask_sqlalchemy/model.py | 188 ++++-- src/flask_sqlalchemy/session.py | 119 ++-- tests/test_basic_app.py | 6 +- tests/test_binds.py | 95 +-- tests/test_commit_on_teardown.py | 37 - tests/test_config.py | 67 +- tests/test_query_property.py | 3 +- 9 files changed, 914 insertions(+), 681 deletions(-) delete mode 100644 tests/test_commit_on_teardown.py diff --git a/CHANGES.rst b/CHANGES.rst index 89257e20..8518ba78 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -39,6 +39,34 @@ Unreleased - The ``SQLAlchemy`` extension object uses ``__getattr__`` to alias names from the SQLAlchemy package, rather than copying them as attributes. - The query class is applied to ``backref`` in ``relationship``. +- ``SignallingSession`` is renamed to ``Session``. +- ``Session.get_bind`` more closely matches the base implementation. +- ``Model`` ``repr`` distinguishes between transient and pending instances. +- Different bind keys use different SQLAlchemy ``MetaData`` registries, allowing + tables in different databases to have the same name. Bind keys are stored and looked + up on the resulting metadata rather than the model or table. +- The ``engine_options`` parameter is applied as defaults before per-engine + configuration. +- ``SQLALCHEMY_BINDS`` values can either be an engine URL, or a dict of engine options + including URL, for each bind. ``SQLALCHEMY_DATABASE_URI`` and + ``SQLALCHEMY_ENGINE_OPTIONS`` correspond to the ``None`` key and take precedence. +- Engines are created when calling ``init_app`` rather than the first time they are + accessed. +- The extension instance is stored directly as ``app.extensions["sqlalchemy"]``. +- All parameters except ``app`` are keyword-only. +- Setup methods that create the engines and session are renamed with a leading + underscore. They are considered internal interfaces which may change at any time. +- ``db.Table`` is a subclass instead of a function. +- The session class can be customized by passing the ``class_`` key in the + ``session_options`` parameter. +- SQLite engines do not use ``NullPool`` if ``pool_size`` is 0. +- MySQL engines do not set ``pool_size`` to 10. +- ``db.engines`` exposes the map of bind keys to engines for the current app. +- ``get_engine``, ``get_tables_for_bind``, and ``get_binds`` are deprecated. +- Renamed the ``bind`` parameter to ``bind_key`` and removed the ``app`` parameter + from various methods. +- ``SQLALCHEMY_RECORD_QUERIES`` configuration takes precedence over ``app.debug`` and + ``app.testing``, allowing it to be disabled in those modes. Version 2.5.1 diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 7faead64..6fc78c9a 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -2,19 +2,21 @@ import os import typing as t -import warnings -from threading import Lock +from weakref import WeakKeyDictionary import sqlalchemy as sa import sqlalchemy.event +import sqlalchemy.exc import sqlalchemy.orm +import sqlalchemy.pool from flask import current_app +from flask import Flask +from flask import has_app_context -from .model import _QueryProperty from .model import DefaultMeta from .model import Model from .query import Query -from .session import SignallingSession +from .session import Session # Scope the session to the current greenlet if greenlet is available, # otherwise fall back to the current thread. @@ -24,553 +26,742 @@ from threading import get_ident as _ident_func -def _sa_url_set(url, **kwargs): - try: - url = url.set(**kwargs) - except AttributeError: - # SQLAlchemy <= 1.3 - for key, value in kwargs.items(): - setattr(url, key, value) +class SQLAlchemy: + """Integrates SQLAlchemy with Flask. This handles setting up one or more engines, + associating tables and models with specific engines, and cleaning up connections and + sessions after each request. + + Only the engine configuration is specific to each application, other things like + the model, table, metadata, and session are shared for all applications using that + extension instance. Call :meth:`init_app` to configure the extension on an + application. + + After creating the extension, create model classes by subclassing :attr:`Model`, and + table classes with :attr:`Table`. These can be accessed before :meth:`init_app` is + called, making it possible to define the models separately from the application. + + Accessing :attr:`session` and :attr:`engine` requires an active Flask application + context. This includes methods like :meth:`create_all` which use the engine. + + This class also provides access to names in SQLAlchemy's :mod:`sqlalchemy` and + :mod:`sqlalchemy.orm` modules. For example, you can use ``db.Column`` and + ``db.relationship`` instead of importing ``sqlalchemy.Column`` and + ``sqlalchemy.orm.relationship``. This can be convenient when defining models. + + :param app: Call :meth:`init_app` on this Flask application now. + :param metadata: Use this as the default :class:sqlalchemy.MetaData`. Useful for + setting a naming convention. + :param session_options: Arguments used by :attr:`db.session` to create each session + instance. A ``scopefunc`` key will be passed to :attr:`db.session`, not the + session instance. See :class:`sqlalchemy.orm.sessionmaker` for a list of + arguments. + :param query_class: Use this as the default query class for models and dynamic + relationships. The query interface is considered legacy in SQLAlchemy 2.0. + :param model_class: Use this as the model base class when creating the declarative + model class :attr:`db.Model`. Can also be a fully created declarative model + class for further customization. + :param engine_options: Default arguments used when creating every engine. These are + lower precedence than application config. See :func:`sqlalchemy.create_engine` + for a list of arguments. - return url + .. versionchanged:: 3.0 + Separate ``metadata`` are used for each bind key. + .. versionchanged:: 3.0 + The ``engine_options`` parameter is applied as defaults before per-engine + configuration. -def _sa_url_query_setdefault(url, **kwargs): - query = dict(url.query) + .. versionchanged:: 3.0 + The session class can be customized in ``session_options``. - for key, value in kwargs.items(): - query.setdefault(key, value) + .. versionchanged:: 3.0 + Engines are created when calling ``init_app`` rather than the first time they + are accessed. - return _sa_url_set(url, query=query) + .. versionchanged:: 3.0 + All parameters except ``app`` are keyword-only. + .. versionchanged:: 3.0 + The extension instance is stored directly as ``app.extensions["sqlalchemy"]``. -def _make_table(db): - def _make_table(*args, **kwargs): - if len(args) > 1 and isinstance(args[1], db.Column): - args = (args[0], db.metadata) + args[1:] - info = kwargs.pop("info", None) or {} - info.setdefault("bind_key", None) - kwargs["info"] = info - return sqlalchemy.Table(*args, **kwargs) + .. versionchanged:: 3.0 + Setup methods are renamed with a leading underscore. They are considered + internal interfaces which may change at any time. - return _make_table + .. versionchanged:: 3.0 + Removed the ``use_native_unicode`` parameter and config. + .. versionchanged:: 3.0 + The ``COMMIT_ON_TEARDOWN`` configuration is deprecated and will + be removed in Flask-SQLAlchemy 3.1. Call ``db.session.commit()`` + directly instead. -def _record_queries(app): - if app.debug: - return True - rq = app.config["SQLALCHEMY_RECORD_QUERIES"] - if rq is not None: - return rq - return bool(app.config.get("TESTING")) + .. versionchanged:: 2.4 + Added the ``engine_options`` parameter. + .. versionchanged:: 2.1 + Added the ``metadata``, ``query_class``, and ``model_class`` parameters. -class _EngineConnector: - def __init__(self, sa, app, bind=None): - self._sa = sa - self._app = app - self._engine = None - self._connected_for = None - self._bind = bind - self._lock = Lock() + .. versionchanged:: 2.1 + Use the same query class across ``session``, ``Model.query`` and + ``Query``. - def get_uri(self): - if self._bind is None: - return self._app.config["SQLALCHEMY_DATABASE_URI"] - binds = self._app.config.get("SQLALCHEMY_BINDS") or () - assert ( - self._bind in binds - ), f"Bind {self._bind!r} is not configured in 'SQLALCHEMY_BINDS'." - return binds[self._bind] + .. versionchanged:: 0.16 + ``scopefunc`` is accepted in ``session_options``. - def get_engine(self): - with self._lock: - uri = self.get_uri() - echo = self._app.config["SQLALCHEMY_ECHO"] - if (uri, echo) == self._connected_for: - return self._engine + .. versionchanged:: 0.10 + Added the ``session_options`` parameter. + """ - sa_url = sa.engine.make_url(uri) - sa_url, options = self.get_options(sa_url, echo) - self._engine = rv = self._sa.create_engine(sa_url, options) + def __init__( + self, + app: Flask | None = None, + *, + metadata: sa.MetaData | None = None, + session_options: dict[str, t.Any] | None = None, + query_class: t.Type[Query] = Query, + model_class: t.Type[Model] | sa.orm.DeclarativeMeta = Model, + engine_options: dict[str, t.Any] | None = None, + ): + if session_options is None: + session_options = {} - if _record_queries(self._app): - from . import record_queries + self.Query = query_class + """The default query class used by ``Model.query`` and ``lazy="dynamic"`` + relationships. - record_queries._listen(self._engine) + .. warning:: + The query interface is considered legacy in SQLAlchemy 2.0. - self._connected_for = (uri, echo) + Customize this by passing the ``query_class`` parameter to the extension. + """ - return rv + self.session = self._make_scoped_session(session_options) + """A :class:`sqlalchemy.orm.scoped_session` that creates instances of + :class:`.Session` scoped to the current Flask application context. The session + will be removed, returning the engine connection to the pool, when the + application context exits. - def get_options(self, sa_url, echo): - options = {} - sa_url, options = self._sa.apply_driver_hacks(self._app, sa_url, options) + Customize this by passing ``session_options`` to the extension. + """ - if echo: - options["echo"] = echo + self.metadatas: dict[str | None, sa.MetaData] = {} + """Map of bind keys to :class:`sqlalchemy.MetaData` instances. The ``None`` key + refers to the default metadata, and is available as :attr:`metadata`. - # Give the config options set by a developer explicitly priority - # over decisions FSA makes. - options.update(self._app.config["SQLALCHEMY_ENGINE_OPTIONS"]) - # Give options set in SQLAlchemy.__init__() ultimate priority - options.update(self._sa._engine_options) - return sa_url, options + Customize the default metadata by passing the ``metadata`` parameter to the + extension. This can be used to set a naming convention. When metadata for + another bind key is created, it copies the default's naming convention. + .. versionadded:: 3.0 + """ -def get_state(app): - """Gets the state for the application""" - assert "sqlalchemy" in app.extensions, ( - "The sqlalchemy extension was not registered to the current " - "application. Please make sure to call init_app() first." - ) - return app.extensions["sqlalchemy"] + if metadata is not None: + metadata.info["bind_key"] = None + self.metadatas[None] = metadata + self.Table = self._make_table_class() + """A :class:`sqlalchemy.Table` class that chooses a metadata automatically. -class _SQLAlchemyState: - """Remembers configuration for the (db, app) tuple.""" + Unlike the base ``Table``, the ``metadata`` argument is not required. If it is + not given, it is selected based on the ``bind_key`` argument. - def __init__(self, db): - self.db = db - self.connectors = {} + :param bind_key: Used to select a different metadata. + :param args: Arguments passed to the base class. These are typically the table's + name, columns, and constraints. + :param kwargs: Arguments passed to the base class. + .. versionchanged:: 3.0 + This is a subclass of SQLAlchemy's ``Table`` rather than a function. + """ -class SQLAlchemy: - """This class is used to control the SQLAlchemy integration to one - or more Flask applications. Depending on how you initialize the - object it is usable right away or will attach as needed to a - Flask application. + self.Model = self._make_declarative_base(model_class) + """A SQLAlchemy declarative model class. Subclass this to define database + models. - There are two usage modes which work very similarly. One is binding - the instance to a very specific Flask application:: + If a model does not set ``__tablename__``, it will be generated by converting + the class name from ``CamelCase`` to ``snake_case``. It will not be generated + if the model looks like it uses single-table inheritance. - app = Flask(__name__) - db = SQLAlchemy(app) + If a model or parent class sets ``__bind_key__``, it will use that metadata and + database engine. Otherwise, it will use the default :attr:`metadata` and + :attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``. - The second possibility is to create the object once and configure the - application later to support it:: + Customize this by subclassing :class:`.Model` and passing the ``model_class`` + parameter to the extension. A fully created declarative model class can be + passed as well, to use a custom metaclass. + """ - db = SQLAlchemy() + if engine_options is None: + engine_options = {} - def create_app(): - app = Flask(__name__) - db.init_app(app) - return app + self._engine_options = engine_options + self._app_engines: WeakKeyDictionary[Flask, dict[str | None, sa.engine.Engine]] + self._app_engines = WeakKeyDictionary() - The difference between the two is that in the first case methods like - :meth:`create_all` and :meth:`drop_all` will work all the time but in - the second case a :meth:`flask.Flask.app_context` has to exist. + self._app: Flask | None = app - By default Flask-SQLAlchemy will apply some backend-specific settings - to improve your experience with them. + if app is not None: + self.init_app(app) - This class also provides access to all the SQLAlchemy functions and classes - from the :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` modules. So you can - declare models like this:: + def __repr__(self): + if not has_app_context() and self._app is None: + return f"<{type(self).__name__}>" - class User(db.Model): - username = db.Column(db.String(80), unique=True) - pw_hash = db.Column(db.String(80)) + message = f"{type(self).__name__} {self.engine.url}" - You can still use :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, but - note that Flask-SQLAlchemy customizations are available only through an - instance of this :class:`SQLAlchemy` class. Query classes default to - :class:`Query` for `db.Query`, `db.Model.query_class`, and the default - query_class for `db.relationship` and `db.backref`. If you use these - interfaces through :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, - the default query class will be that of :mod:`sqlalchemy`. + if len(self.engines) > 1: + message = f"{message} +{len(self.engines) - 1}" - .. admonition:: Check types carefully + return f"<{message}>" - Don't perform type or `isinstance` checks against `db.Table`, which - emulates `Table` behavior but is not a class. `db.Table` exposes the - `Table` interface, but is a function which allows omission of metadata. + def init_app(self, app: Flask) -> None: + """Initialize a Flask application for use with this extension instance. This + must be called before accessing the database engine or session with the app. - The ``session_options`` parameter, if provided, is a dict of parameters - to be passed to the session constructor. See - :class:`~sqlalchemy.orm.session.Session` for the standard options. + This sets default configuration values, then configures the extension on the + application and creates the engines for each bind key. Therefore, this must be + called after the application has been configured. Changes to application config + after this call will not be reflected. - The ``engine_options`` parameter, if provided, is a dict of parameters - to be passed to create engine. See :func:`~sqlalchemy.create_engine` - for the standard options. The values given here will be merged with and - override anything set in the ``'SQLALCHEMY_ENGINE_OPTIONS'`` config - variable or othewise set by this library. + The following keys from ``app.config`` are used: - .. versionchanged:: 3.0 - Removed the ``use_native_unicode`` parameter and config. + - :data:`.SQLALCHEMY_DATABASE_URI` + - :data:`.SQLALCHEMY_ENGINE_OPTIONS` + - :data:`.SQLALCHEMY_ECHO` + - :data:`.SQLALCHEMY_BINDS` + - :data:`.SQLALCHEMY_RECORD_QUERIES` + - :data:`.SQLALCHEMY_TRACK_MODIFICATIONS` - .. versionchanged:: 3.0 - ``COMMIT_ON_TEARDOWN`` is deprecated and will be removed in - version 3.1. Call ``db.session.commit()`` directly instead. + :param app: The Flask application to initialize. + """ + # We intentionally don't set self._app, to support initializing multiple apps. + app.extensions["sqlalchemy"] = self - .. versionchanged:: 2.4 - Added the ``engine_options`` parameter. + if app.config.get("SQLALCHEMY_COMMIT_ON_TEARDOWN", False): + import warnings - .. versionchanged:: 2.1 - Added the ``metadata`` parameter. This allows for setting custom - naming conventions among other, non-trivial things. + warnings.warn( + "'SQLALCHEMY_COMMIT_ON_TEARDOWN' is deprecated and will be removed in" + " Flask-SQAlchemy 3.1. Call 'db.session.commit()'` directly instead.", + DeprecationWarning, + ) + app.teardown_appcontext(self._teardown_commit) + else: + app.teardown_appcontext(self._teardown_session) - .. versionchanged:: 2.1 - Added the ``query_class`` parameter, to allow customisation - of the query class, in place of the default of - :class:`Query`. + basic_uri: str | sa.engine.URL | None = app.config.setdefault( + "SQLALCHEMY_DATABASE_URI", None + ) + basic_engine_options = self._engine_options.copy() + basic_engine_options.update( + app.config.setdefault("SQLALCHEMY_ENGINE_OPTIONS", {}) + ) + echo: bool = app.config.setdefault("SQLALCHEMY_ECHO", False) + config_binds: dict[ + str | None, str | sa.engine.URL | dict[str, t.Any] + ] = app.config.setdefault("SQLALCHEMY_BINDS", {}) + engine_options: dict[str | None, dict[str, t.Any]] = {} - .. versionchanged:: 2.1 - Added the ``model_class`` parameter, which allows a custom model - class to be used in place of :class:`Model`. + # Build the engine config for each bind key. + for key, value in config_binds.items(): + engine_options[key] = self._engine_options.copy() - .. versionchanged:: 2.1 - Use the same query class across ``session``, ``Model.query`` and - ``Query``. + if isinstance(value, (str, sa.engine.URL)): + engine_options[key]["url"] = value + else: + engine_options[key].update(value) - .. versionchanged:: 0.16 - ``scopefunc`` is now accepted on ``session_options``. It allows - specifying a custom function which will define the SQLAlchemy - session's scoping. + # Build the engine config for the default bind key. + if basic_uri is not None: + basic_engine_options["url"] = basic_uri - .. versionchanged:: 0.10 - Added the ``session_options`` parameter. - """ + if basic_engine_options: + engine_options.setdefault(None, {}).update(basic_engine_options) - def __init__( - self, - app=None, - session_options=None, - metadata=None, - query_class=Query, - model_class=Model, - engine_options=None, - ): + if not engine_options: + raise RuntimeError( + "Either 'SQLALCHEMY_DATABASE_URI' or 'SQLALCHEMY_BINDS' must be set." + ) - self.Query = query_class - self.session = self.create_scoped_session(session_options) - self.Model = self.make_declarative_base(model_class, metadata) - self._engine_lock = Lock() - self.app = app - self._engine_options = engine_options or {} - self.Table = _make_table(self) + engines = self._app_engines.setdefault(app, {}) - if app is not None: - self.init_app(app) + # Dispose existing engines in case init_app is called again. + if engines: + for engine in engines.values(): + engine.dispose() - @property - def metadata(self): - """The metadata associated with ``db.Model``.""" + engines.clear() + + # Create the metadata and engine for each bind key. + for key, options in engine_options.items(): + self._make_metadata(key) + options.setdefault("echo", echo) + options.setdefault("echo_pool", echo) + self._apply_driver_defaults(options, app) + engines[key] = self._make_engine(key, options, app) + + record: bool | None = app.config.setdefault("SQLALCHEMY_RECORD_QUERIES", None) + + if record is None: + record = app.debug or app.testing + + if record: + from . import record_queries + + for engine in engines.values(): + record_queries._listen(engine) + + if app.config.setdefault("SQLALCHEMY_TRACK_MODIFICATIONS", False): + from . import track_modifications - return self.Model.metadata + track_modifications._listen(self.session) - def create_scoped_session(self, options=None): - """Create a :class:`~sqlalchemy.orm.scoping.scoped_session` - on the factory from :meth:`create_session`. + def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_session: + """Create a :class:`sqlalchemy.orm.scoped_session` around the factory from + :meth:`_make_session_factory`. The result is available as :attr:`session`. - An extra key ``'scopefunc'`` can be set on the ``options`` dict to - specify a custom scope function. If it's not provided, Flask's app - context stack identity is used. This will ensure that sessions are - created and removed with the request/response cycle, and should be fine - in most cases. + The scope function can be customized using the ``scopefunc`` key in the + ``session_options`` parameter to the extension. By default it uses the current + thread or greenlet id. - :param options: dict of keyword arguments passed to session class in - ``create_session`` + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param options: The ``session_options`` parameter from ``__init__``. Keyword + arguments passed to the session factory. A ``scopefunc`` key is popped. + + .. versionchanged:: 3.0 + Renamed from ``create_scoped_session``, this method is internal. """ + scope = options.pop("scopefunc", _ident_func) + factory = self._make_session_factory(options) + return sa.orm.scoped_session(factory, scope) + + def _make_session_factory(self, options: dict[str, t.Any]) -> sa.orm.sessionmaker: + """Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by + :meth:`_make_scoped_session`. - if options is None: - options = {} + To customize, pass the ``session_options`` parameter to :class:`SQLAlchemy`. To + customize the session class, subclass :class:`.Session` and pass it as the + ``class_`` key. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param options: The ``session_options`` parameter from ``__init__``. Keyword + arguments passed to the session factory. + + .. versionchanged:: 3.0 + The session class can be customized. - scopefunc = options.pop("scopefunc", _ident_func) + .. versionchanged:: 3.0 + Renamed from ``create_session``, this method is internal. + """ + options.setdefault("class_", Session) options.setdefault("query_cls", self.Query) - return sa.orm.scoped_session(self.create_session(options), scopefunc=scopefunc) + return sa.orm.sessionmaker(db=self, **options) + + def _teardown_commit(self, exc: BaseException | None) -> None: + """Commit the session at the end of the request if there was not an unhandled + exception during the request. + + :meta private: + + .. deprecated:: 3.0 + Will be removed in 3.1. Use ``db.session.commit()`` directly instead. + """ + if exc is None: + self.session.commit() + + self.session.remove() + + def _teardown_session(self, exc: BaseException | None) -> None: + """Remove the current session at the end of the request. + + :meta private: - def create_session(self, options): - """Create the session factory used by :meth:`create_scoped_session`. + .. versionadded:: 3.0 + """ + self.session.remove() + + def _make_metadata(self, bind_key: str | None) -> sa.MetaData: + """Get or create a :class:`sqlalchemy.MetaData` for the given bind key. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: + + :param bind_key: The name of the metadata being created. + + .. versionadded:: 3.0 + """ + if bind_key in self.metadatas: + return self.metadatas[bind_key] + + if bind_key is not None: + # Copy the naming convention from the default metadata. + naming_convention = self._make_metadata(None).naming_convention + else: + naming_convention = None + + # Set the bind key in info to be used by session.get_bind. + metadata = sa.MetaData( + naming_convention=naming_convention, info={"bind_key": bind_key} + ) + self.metadatas[bind_key] = metadata + return metadata - The factory **must** return an object that SQLAlchemy recognizes as a session, - or registering session events may raise an exception. + def _make_table_class(self) -> t.Type[sa.Table]: + """Create a SQLAlchemy :class:`sqlalchemy.Table` class that chooses a metadata + automatically based on the ``bind_key``. The result is available as + :attr:`Table`. - Valid factories include a :class:`~sqlalchemy.orm.session.Session` - class or a :class:`~sqlalchemy.orm.session.sessionmaker`. + This method is used for internal setup. Its signature may change at any time. - The default implementation creates a ``sessionmaker`` for - :class:`SignallingSession`. + :meta private: - :param options: dict of keyword arguments passed to session class + .. versionadded:: 3.0 """ - return sa.orm.sessionmaker(class_=SignallingSession, db=self, **options) + class Table(sa.Table): + def __new__( + cls, *args: t.Any, bind_key: str | None = None, **kwargs: t.Any + ) -> Table: + # If a metadata arg is passed, go directly to the base Table. Also do + # this for no args so the correct error is shown. + if not args or (len(args) >= 2 and isinstance(args[1], sa.MetaData)): + return super().__new__(cls, *args, **kwargs) + + if ( + bind_key is None + and "info" in kwargs + and "bind_key" in kwargs["info"] + ): + import warnings + + warnings.warn( + "'table.info['bind_key'] is deprecated and will not be used in" + " Flask-SQLAlchemy 3.1. Pass the 'bind_key' parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + bind_key = kwargs["info"].get("bind_key") + + metadata = self._make_metadata(bind_key) + return super().__new__(cls, args[0], metadata, *args[1:], **kwargs) + + return Table + + def _make_declarative_base( + self, model: t.Type[Model] | sa.orm.DeclarativeMeta + ) -> t.Type[t.Any]: + """Create a SQLAlchemy declarative model class. The result is available as + :attr:`Model`. + + To customize, subclass :class:`.Model` and pass it as ``model_class`` to + :class:`SQLAlchemy`. To customize at the metaclass level, pass an already + created declarative model class as ``model_class``. - def make_declarative_base(self, model, metadata=None): - """Creates the declarative base that all models will inherit from. + This method is used for internal setup. Its signature may change at any time. - :param model: base model class (or a tuple of base classes) to pass - to :func:`~sqlalchemy.ext.declarative.declarative_base`. Or a class - returned from ``declarative_base``, in which case a new base class - is not created. - :param metadata: :class:`~sqlalchemy.MetaData` instance to use, or - none to use SQLAlchemy's default. + :meta private: + + :param model: A model base class, or an already created declarative model class. + + .. versionchanged:: 3.0 + Renamed with a leading underscore, this method is internal. - .. versionchanged 2.3.0:: - ``model`` can be an existing declarative base in order to support - complex customization such as changing the metaclass. + .. versionchanged:: 2.3 + ``model`` can be an already created declarative model class. """ + metadata = self._make_metadata(None) + if not isinstance(model, sa.orm.DeclarativeMeta): model = sa.orm.declarative_base( - cls=model, name="Model", metadata=metadata, metaclass=DefaultMeta + metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta ) - # if user passed in a declarative base and a metaclass for some reason, - # make sure the base uses the metaclass - if metadata is not None and model.metadata is not metadata: - model.metadata = metadata + model.metadata = metadata # type: ignore[union-attr] + model.query_class = self.Query + model.__fsa__ = self + return model - if not getattr(model, "query_class", None): - model.query_class = self.Query + def _apply_driver_defaults(self, options: dict[str, t.Any], app: Flask) -> None: + """Apply driver-specific configuration to an engine. - model.query = _QueryProperty(self) - return model + SQLite in-memory databases use ``StaticPool`` and disable ``check_same_thread``. + File paths are relative to the app's :func:`~flask.Flask.instance_path`, which + is created if it doesn't exist. - def init_app(self, app): - """This callback can be used to initialize an application for the - use with this database setup. Never use a database in the context - of an application not initialized that way or connections will - leak. - """ + MySQL sets ``charset="utf8mb4"``, and ``pool_timeout`` defaults to 2 hours. - # We intentionally don't set self.app = app, to support multiple - # applications. If the app is passed in the constructor, - # we set it and don't support multiple applications. - if not ( - app.config.get("SQLALCHEMY_DATABASE_URI") - or app.config.get("SQLALCHEMY_BINDS") - ): - raise RuntimeError( - "Either SQLALCHEMY_DATABASE_URI or SQLALCHEMY_BINDS needs to be set." - ) + This method is used for internal setup. Its signature may change at any time. + + :meta private: - app.config.setdefault("SQLALCHEMY_DATABASE_URI", None) - app.config.setdefault("SQLALCHEMY_BINDS", None) - app.config.setdefault("SQLALCHEMY_ECHO", False) - app.config.setdefault("SQLALCHEMY_RECORD_QUERIES", None) - app.config.setdefault("SQLALCHEMY_COMMIT_ON_TEARDOWN", False) - app.config.setdefault("SQLALCHEMY_TRACK_MODIFICATIONS", False) - app.config.setdefault("SQLALCHEMY_ENGINE_OPTIONS", {}) - - app.extensions["sqlalchemy"] = _SQLAlchemyState(self) - - @app.teardown_appcontext - def shutdown_session(response_or_exc): - if app.config["SQLALCHEMY_COMMIT_ON_TEARDOWN"]: - warnings.warn( - "'COMMIT_ON_TEARDOWN' is deprecated and will be" - " removed in version 3.1. Call" - " 'db.session.commit()'` directly instead.", - DeprecationWarning, - ) - - if response_or_exc is None: - self.session.commit() - - self.session.remove() - return response_or_exc - - def apply_driver_hacks(self, app, sa_url, options): - """This method is called before engine creation and used to inject - driver specific hacks into the options. The `options` parameter is - a dictionary of keyword arguments that will then be used to call - the :func:`sqlalchemy.create_engine` function. - - The default implementation provides some defaults for things - like pool sizes for MySQL and SQLite. + :param options: Arguments passed to the engine. + :param app: The application that the engine configuration belongs to. .. versionchanged:: 3.0 - Change the default MySQL character set to "utf8mb4". + SQLite paths are relative to ``app.instance_path``. It does not use + ``NullPool`` if ``pool_size`` is 0. + + .. versionchanged:: 3.0 + MySQL sets ``charset="utf8mb4". It does not set ``pool_size`` to 10. + + .. versionchanged:: 3.0 + Renamed from ``apply_driver_hacks``, this method is internal. It does not + return anything. .. versionchanged:: 2.5 - Returns ``(sa_url, options)``. SQLAlchemy 1.4 made the URL - immutable, so any changes to it must now be passed back up - to the original caller. + Returns ``(sa_url, options)``. """ - if sa_url.drivername.startswith("mysql"): - sa_url = _sa_url_query_setdefault(sa_url, charset="utf8mb4") - - if sa_url.drivername != "mysql+gaerdbms": - options.setdefault("pool_size", 10) - options.setdefault("pool_recycle", 7200) - elif sa_url.drivername == "sqlite": - pool_size = options.get("pool_size") - detected_in_memory = False - if sa_url.database in (None, "", ":memory:"): - detected_in_memory = True - from sqlalchemy.pool import StaticPool - - options["poolclass"] = StaticPool + url = sa.engine.make_url(options["url"]) + + if url.drivername in {"sqlite", "sqlite+pysqlite"}: + if url.database in {None, "", ":memory:"}: + options["poolclass"] = sa.pool.StaticPool + if "connect_args" not in options: options["connect_args"] = {} - options["connect_args"]["check_same_thread"] = False - # we go to memory and the pool size was explicitly set - # to 0 which is fail. Let the user know that - if pool_size == 0: - raise RuntimeError( - "SQLite in memory database with an " - "empty queue not possible due to data " - "loss." + options["connect_args"]["check_same_thread"] = False + else: + if not os.path.isabs(url.database): + os.makedirs(app.instance_path, exist_ok=True) + options["url"] = url.set( + database=os.path.join(app.instance_path, url.database) ) - # if pool size is None or explicitly set to 0 we assume the - # user did not want a queue for this sqlite connection and - # hook in the null pool. - elif not pool_size: - from sqlalchemy.pool import NullPool + elif url.drivername.startswith("mysql"): + options.setdefault("pool_recycle", 7200) + + if "charset" not in url.query: + options["url"] = url.update_query_dict({"charset": "utf8mb4"}) - options["poolclass"] = NullPool + def _make_engine( + self, bind_key: str | None, options: dict[str, t.Any], app: Flask + ) -> sa.engine.Engine: + """Create the :class:`sqlalchemy.engine.Engine` for the given bind key and app. - # If the database path is not absolute, it's relative to the - # app instance path, which might need to be created. - if not detected_in_memory and not os.path.isabs(sa_url.database): - os.makedirs(app.instance_path, exist_ok=True) - sa_url = _sa_url_set( - sa_url, database=os.path.join(app.root_path, sa_url.database) - ) + To customize, use :data:`.SQLALCHEMY_ENGINE_OPTIONS` or + :data:`.SQLALCHEMY_BINDS` config. Pass ``engine_options`` to :class:`SQLAlchemy` + to set defaults for all engines. + + This method is used for internal setup. Its signature may change at any time. + + :meta private: - return sa_url, options + :param bind_key: The name of the engine being created. + :param options: Arguments passed to the engine. + :param app: The application that the engine configuration belongs to. + + .. versionchanged:: 3.0 + Renamed from ``create_engine``, this method is internal. + """ + return sa.engine_from_config(options, prefix="") @property - def engine(self): - """Gives access to the engine. If the database configuration is bound - to a specific application (initialized with an application) this will - always return a database connection. If however the current application - is used this might raise a :exc:`RuntimeError` if no application is - active at the moment. + def metadata(self) -> sa.MetaData: + """The default metadata used by :attr:`Model` and :attr:`Table` if no bind key + is set. """ - return self.get_engine() + return self.metadatas[None] - def make_connector(self, app=None, bind=None): - """Creates the connector for a given state and bind.""" - return _EngineConnector(self, self.get_app(app), bind) + @property + def engines(self) -> t.Mapping[str | None, sa.engine.Engine]: + """Map of bind keys to :class:`sqlalchemy.engine.Engine` instances for current + application. The ``None`` key refers to the default engine, and is available as + :attr:`engine`. - def get_engine(self, app=None, bind=None): - """Returns a specific engine.""" + To customize, set the :data:`.SQLALCHEMY_BINDS` config, and set defaults by + passing the ``engine_options`` parameter to the extension. - app = self.get_app(app) - state = get_state(app) + This requires that a Flask application context is active. - with self._engine_lock: - connector = state.connectors.get(bind) + .. versionadded:: 3.0 + """ + if not has_app_context() and self._app is not None: + app = self._app + else: + app = current_app._get_current_object() - if connector is None: - connector = self.make_connector(app, bind) - state.connectors[bind] = connector + return self._app_engines[app] - return connector.get_engine() + @property + def engine(self) -> sa.engine.Engine: + """The default :class:`~sqlalchemy.engine.Engine` for the current application, + used by :attr:`session` if the :attr:`Model` or :attr:`Table` being queried does + not set a bind key. - def create_engine(self, sa_url, engine_opts): - """Override this method to have final say over how the - SQLAlchemy engine is created. + To customize, set the :data:`.SQLALCHEMY_ENGINE_OPTIONS` config, and set + defaults by passing the ``engine_options`` parameter to the extension. - In most cases, you will want to use - ``'SQLALCHEMY_ENGINE_OPTIONS'`` config variable or set - ``engine_options`` for :func:`SQLAlchemy`. + This requires that a Flask application context is active. """ - return sqlalchemy.create_engine(sa_url, **engine_opts) + return self.engines[None] + + def get_engine(self, bind_key: str | None = None) -> sa.engine.Engine: + """Get the engine for the given bind key for the current application. - def get_app(self, reference_app=None): - """Helper method that implements the logic to look up an - application.""" + This requires that a Flask application context is active. - if reference_app is not None: - return reference_app + :param bind_key: The name of the engine. - if current_app: - return current_app._get_current_object() + .. deprecated:: 3.0 + Will be removed in Flask-SQLAlchemy 3.1. Use ``engines[key]`` instead. - if self.app is not None: - return self.app + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. Removed the ``app`` + parameter. + """ + import warnings - raise RuntimeError( - "No application found. Either work inside a view function or push" - " an application context. See" - " https://flask-sqlalchemy.palletsprojects.com/contexts/." + warnings.warn( + "'get_engine' is deprecated and will be removed in Flask-SQLAlchemy 3.1." + " Use 'engine' or 'engines[key]' instead.", + DeprecationWarning, + stacklevel=2, ) + return self.engines[bind_key] + + def get_tables_for_bind(self, bind_key: str | None = None) -> list[sa.Table]: + """Get all tables in the metadata for the given bind key. - def get_tables_for_bind(self, bind=None): - """Returns a list of all tables relevant for a bind.""" - result = [] - for table in self.Model.metadata.tables.values(): - if table.info.get("bind_key") == bind: - result.append(table) - return result + :param bind_key: The bind key to get. + + .. deprecated:: 3.0 + Will be removed in Flask-SQLAlchemy 3.1. Use ``metadata.tables`` instead. + + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. + """ + import warnings + + warnings.warn( + "'get_tables_for_bind' is deprecated and will be removed in" + " Flask-SQLAlchemy 3.1. Use 'metadata.tables' instead.", + DeprecationWarning, + stacklevel=2, + ) + return list(self.metadatas[bind_key].tables.values()) + + def get_binds(self) -> dict[sa.Table, sa.engine.Engine]: + """Map all tables to their engine based on their bind key, which can be used to + create a session with ``Session(binds=db.get_binds(app))``. + + This requires that a Flask application context is active. + + .. deprecated:: 3.0 + Will be removed in Flask-SQLAlchemy 3.1. ``db.session`` supports multiple + binds directly. + + .. versionchanged:: 3.0 + Removed the ``app`` parameter. + """ + import warnings + + warnings.warn( + "'get_binds' is deprecated and will be removed in Flask-SQLAlchemy 3.1." + " 'db.session' supports multiple binds directly.", + DeprecationWarning, + stacklevel=2, + ) + return { + table: engine + for bind_key, engine in self.engines.items() + for table in self.metadatas[bind_key].tables.values() + } + + def _call_for_binds( + self, bind_key: str | None | list[str | None], op_name: str + ) -> None: + """Call a method on each metadata. + + :meta private: - def get_binds(self, app=None): - """Returns a dictionary with a table->engine mapping. + :param bind_key: A bind key or list of keys. Defaults to all binds. + :param op_name: The name of the method to call. - This is suitable for use of sessionmaker(binds=db.get_binds(app)). + .. versionchanged:: 3.0 + Renamed from ``_execute_for_all_tables``. """ - app = self.get_app(app) - binds = [None] + list(app.config.get("SQLALCHEMY_BINDS") or ()) - retval = {} - for bind in binds: - engine = self.get_engine(app, bind) - tables = self.get_tables_for_bind(bind) - retval.update({table: engine for table in tables}) - return retval - - def _execute_for_all_tables(self, app, bind, operation, skip_tables=False): - app = self.get_app(app) - - if bind == "__all__": - binds = [None] + list(app.config.get("SQLALCHEMY_BINDS") or ()) - elif isinstance(bind, str) or bind is None: - binds = [bind] + if bind_key == "__all__": + keys: list[str | None] = list(self.metadatas) + elif bind_key is None or isinstance(bind_key, str): + keys = [bind_key] else: - binds = bind + keys = bind_key + + for key in keys: + try: + engine = self.engines[key] + except KeyError: + message = f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config." + + if key is None: + message = f"'SQLALCHEMY_DATABASE_URI' config is not set. {message}" + + raise sa.exc.UnboundExecutionError(message) from None + + metadata = self.metadatas[key] + getattr(metadata, op_name)(bind=engine) + + def create_all(self, bind_key: str | None | list[str | None] = "__all__") -> None: + """Create tables that do not exist in the database by calling + ``metadata.create_all()`` for all or some bind keys. This does not + update existing tables, use a migration library for that. - for bind in binds: - extra = {} - if not skip_tables: - tables = self.get_tables_for_bind(bind) - extra["tables"] = tables - op = getattr(self.Model.metadata, operation) - op(bind=self.get_engine(app, bind), **extra) + This requires that a Flask application context is active. - def create_all(self, bind="__all__", app=None): - """Create all tables that do not already exist in the database. - This does not update existing tables, use a migration library - for that. + :param bind_key: A bind key or list of keys to create the tables for. Defaults + to all binds. - :param bind: A bind key or list of keys to create the tables - for. Defaults to all binds. - :param app: Use this app instead of requiring an app context. + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. Removed the ``app`` + parameter. .. versionchanged:: 0.12 Added the ``bind`` and ``app`` parameters. """ - self._execute_for_all_tables(app, bind, "create_all") + self._call_for_binds(bind_key, "create_all") + + def drop_all(self, bind_key: str | None | list[str | None] = "__all__") -> None: + """Drop tables by calling ``metadata.drop_all()`` for all or some bind keys. + + This requires that a Flask application context is active. - def drop_all(self, bind="__all__", app=None): - """Drop all tables. + :param bind_key: A bind key or list of keys to drop the tables from. Defaults to + all binds. - :param bind: A bind key or list of keys to drop the tables for. - Defaults to all binds. - :param app: Use this app instead of requiring an app context. + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. Removed the ``app`` + parameter. .. versionchanged:: 0.12 Added the ``bind`` and ``app`` parameters. """ - self._execute_for_all_tables(app, bind, "drop_all") + self._call_for_binds(bind_key, "drop_all") + + def reflect(self, bind_key: str | None | list[str | None] = "__all__") -> None: + """Load table definitions from the database by calling ``metadata.reflect()`` + for all or some bind keys. - def reflect(self, bind="__all__", app=None): - """Reflects tables from the database. + This requires that a Flask application context is active. - :param bind: A bind key or list of keys to reflect the tables - from. Defaults to all binds. - :param app: Use this app instead of requiring an app context. + :param bind_key: A bind key or list of keys to reflect the tables from. Defaults + to all binds. + + .. versionchanged:: 3.0 + Renamed the ``bind`` parameter to ``bind_key``. Removed the ``app`` + parameter. .. versionchanged:: 0.12 Added the ``bind`` and ``app`` parameters. """ - self._execute_for_all_tables(app, bind, "reflect", skip_tables=True) - - def __repr__(self): - url = self.engine.url if self.app or current_app else None - return f"<{type(self).__name__} engine={url!r}>" + self._call_for_binds(bind_key, "reflect") def _set_rel_query(self, kwargs: dict[str, t.Any]) -> None: """Apply the extension's :attr:`Query` class as the default for relationships @@ -591,8 +782,11 @@ def _set_rel_query(self, kwargs: dict[str, t.Any]) -> None: def relationship( self, *args: t.Any, **kwargs: t.Any ) -> sa.orm.RelationshipProperty: - """A SQLAlchemy :func:`~sqlalchemy.orm.relationship` that applies this - extension's :attr:`Query` class for dynamic relationships and backrefs. + """A :func:`sqlalchemy.orm.relationship` that applies this extension's + :attr:`Query` class for dynamic relationships and backrefs. + + .. versionchanged:: 3.0 + The :attr:`Query` class is set on ``backref``. """ self._set_rel_query(kwargs) return sa.orm.relationship(*args, **kwargs) @@ -600,25 +794,43 @@ def relationship( def dynamic_loader( self, argument: t.Any, **kwargs: t.Any ) -> sa.orm.RelationshipProperty: - """A SQLAlchemy :func:`~sqlalchemy.orm.dynamic_loader` that applies this - extension's :attr:`Query` class for relationships and backrefs. + """A :func:`sqlalchemy.orm.dynamic_loader` that applies this extension's + :attr:`Query` class for relationships and backrefs. + + .. versionchanged:: 3.0 + The :attr:`Query` class is set on ``backref``. """ self._set_rel_query(kwargs) return sa.orm.dynamic_loader(argument, **kwargs) def _relation(self, *args: t.Any, **kwargs: t.Any) -> sa.orm.RelationshipProperty: - """A SQLAlchemy :func:`~sqlalchemy.orm.relationship` that applies this - extension's :attr:`Query` class for dynamic relationships and backrefs. + """A :func:`sqlalchemy.orm.relationship` that applies this extension's + :attr:`Query` class for dynamic relationships and backrefs. SQLAlchemy 2.0 removes this name, use ``relationship`` instead. :meta private: + + .. versionchanged:: 3.0 + The :attr:`Query` class is set on ``backref``. """ # Deprecated, removed in SQLAlchemy 2.0. Accessed through ``__getattr__``. self._set_rel_query(kwargs) return sa.orm.relation(*args, **kwargs) def __getattr__(self, name: str) -> t.Any: + if name == "db": + import warnings + + warnings.warn( + "The 'db' attribute is deprecated and will be removed in" + " Flask-SQLAlchemy 3.1. The extension is registered directly as" + " 'app.extensions[\"sqlalchemy\"]'.", + DeprecationWarning, + stacklevel=2, + ) + return self + if name == "relation": return self._relation @@ -627,6 +839,6 @@ def __getattr__(self, name: str) -> t.Any: for mod in (sa, sa.orm): if name in mod.__all__: - return getattr(sa, name) + return getattr(mod, name) raise AttributeError(name) diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index ad8cbe90..fa83803a 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -1,74 +1,118 @@ +from __future__ import annotations + import re +import typing as t import sqlalchemy as sa -from sqlalchemy import inspect -from sqlalchemy import orm -from sqlalchemy.ext.declarative import DeclarativeMeta -from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm.exc import UnmappedClassError -from sqlalchemy.schema import _get_table_key +import sqlalchemy.orm + +from .query import Query + +if t.TYPE_CHECKING: + from .extension import SQLAlchemy class _QueryProperty: - def __init__(self, sa): - self.sa = sa + """A class property that creates a query object for a model. + + :meta private: + """ + + @t.overload + def __get__(self, obj: None, cls: t.Type[Model]) -> Query: + ... + + @t.overload + def __get__(self, obj: Model, cls: t.Type[Model]) -> Query: + ... - def __get__(self, obj, type): - try: - mapper = orm.class_mapper(type) - if mapper: - return type.query_class(mapper, session=self.sa.session()) - except UnmappedClassError: - return None + def __get__(self, obj: Model | None, cls: t.Type[Model]) -> Query: + return cls.query_class(cls, session=cls.__fsa__.session()) class Model: - """Base class for SQLAlchemy declarative base model. + """The base class of the :class:`.SQLAlchemy.Model` declarative model class. - To define models, subclass :attr:`db.Model `, not this - class. To customize ``db.Model``, subclass this and pass it as - ``model_class`` to :class:`SQLAlchemy`. + To define models, subclass :attr:`db.Model <.SQLAlchemy.Model>`, not this. To + customize ``db.Model``, subclass this and pass it as ``model_class`` to + :class:`.SQLAlchemy`. To customize ``db.Model`` at the metaclass level, pass an + already created declarative model class as ``model_class``. """ - #: Query class used by :attr:`query`. Defaults to - # :class:`SQLAlchemy.Query`, which defaults to :class:`Query`. - query_class = None + __fsa__: t.ClassVar[SQLAlchemy] + """Internal reference to the extension object. - #: Convenience property to query the database for instances of this model - # using the current session. Equivalent to ``db.session.query(Model)`` - # unless :attr:`query_class` has been changed. - query = None + :meta private: + """ + + query_class: t.ClassVar[t.Type[Query]] = Query + """Query class used by :attr:`query`. Defaults to :attr:`.SQLAlchemy.Query`, which + defaults to :class:`.Query`. + """ + + query: t.ClassVar[Query] = _QueryProperty() + """A SQLAlchemy query for a model. Equivalent to ``db.session.query(Model)``. Can be + customized per-model by overriding :attr:`query_class`. + + .. warning:: + The ``Query`` interface is considered legacy in SQLAlchemy 2.0. Prefer using the + ``execute(select())`` pattern instead. + """ - def __repr__(self): - identity = inspect(self).identity + def __repr__(self) -> str: + state: sa.orm.InstanceState = sa.inspect(self) - if identity is None: + if state.transient: pk = f"(transient {id(self)})" + elif state.pending: + pk = f"(pending {id(self)})" else: - pk = ", ".join(str(value) for value in identity) + pk = ", ".join(str(value) for value in state.identity) return f"<{type(self).__name__} {pk}>" class BindMetaMixin(type): - def __init__(cls, name, bases, d): - bind_key = d.pop("__bind_key__", None) or getattr(cls, "__bind_key__", None) + """Metaclass mixin that sets a model's ``metadata`` based on its ``__bind_key__``. - super().__init__(name, bases, d) + If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is + ignored. If the ``metadata`` is the same as the parent model, it will not be set + directly on the child model. + """ - if bind_key is not None and getattr(cls, "__table__", None) is not None: - cls.__table__.info["bind_key"] = bind_key + __fsa__: SQLAlchemy + metadata: sa.MetaData + + def __init__(cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any]) -> None: + if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): + bind_key = getattr(cls, "__bind_key__", None) + parent_metadata = getattr(cls, "metadata", None) + metadata = cls.__fsa__._make_metadata(bind_key) + + if metadata is not parent_metadata: + cls.metadata = metadata + + super().__init__(name, bases, d) class NameMetaMixin(type): - def __init__(cls, name, bases, d): + """Metaclass mixin that sets a model's ``__tablename__`` by converting the + ``CamelCase`` class name to ``snake_case``. A name is set for non-abstract models + that do not otherwise define ``__tablename__``. If a model does not define a primary + key, it will not generate a name or ``__table__``, for single-table inheritance. + """ + + metadata: sa.MetaData + __tablename__: str + __table__: sa.Table + + def __init__(cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any]) -> None: if should_set_tablename(cls): cls.__tablename__ = camel_to_snake_case(cls.__name__) super().__init__(name, bases, d) - # __table_cls__ has run at this point - # if no table was created, use the parent table + # __table_cls__ has run. If no table was created, use the parent table. if ( "__tablename__" not in cls.__dict__ and "__table__" in cls.__dict__ @@ -76,57 +120,62 @@ def __init__(cls, name, bases, d): ): del cls.__table__ - def __table_cls__(cls, *args, **kwargs): - """This is called by SQLAlchemy during mapper setup. It determines the - final table object that the model will use. + def __table_cls__(cls, *args, **kwargs) -> sa.Table | None: + """This is called by SQLAlchemy during mapper setup. It determines the final + table object that the model will use. - If no primary key is found, that indicates single-table inheritance, - so no table will be created and ``__tablename__`` will be unset. + If no primary key is found, that indicates single-table inheritance, so no table + will be created and ``__tablename__`` will be unset. """ - # check if a table with this name already exists - # allows reflected tables to be applied to model by name - key = _get_table_key(args[0], kwargs.get("schema")) + schema = kwargs.get("schema") + + if schema is None: + key = args[0] + else: + key = f"{schema}.{args[0]}" + # Check if a table with this name already exists. Allows reflected tables to be + # applied to models by name. if key in cls.metadata.tables: return sa.Table(*args, **kwargs) - # if a primary key or constraint is found, create a table for - # joined-table inheritance + # If a primary key is found, create a table for joined-table inheritance. for arg in args: if (isinstance(arg, sa.Column) and arg.primary_key) or isinstance( arg, sa.PrimaryKeyConstraint ): return sa.Table(*args, **kwargs) - # if no base classes define a table, return one - # ensures the correct error shows up when missing a primary key + # If no base classes define a table, return one that's missing a primary key + # so SQLAlchemy shows the correct error. for base in cls.__mro__[1:-1]: if "__table__" in base.__dict__: break else: return sa.Table(*args, **kwargs) - # single-table inheritance, use the parent tablename + # Single-table inheritance, use the parent table name. __init__ will unset + # __table__ based on this. if "__tablename__" in cls.__dict__: del cls.__tablename__ + return None + -def should_set_tablename(cls): - """Determine whether ``__tablename__`` should be automatically generated - for a model. +def should_set_tablename(cls: type) -> bool: + """Determine whether ``__tablename__`` should be generated for a model. - * If no class in the MRO sets a name, one should be generated. - * If a declared attr is found, it should be used instead. - * If a name is found, it should be used if the class is a mixin, otherwise - one should be generated. - * Abstract models should not have one generated. + - If no class in the MRO sets a name, one should be generated. + - If a declared attr is found, it should be used instead. + - If a name is found, it should be used if the class is a mixin, otherwise one + should be generated. + - Abstract models should not have one generated. - Later, :meth:`._BoundDeclarativeMeta.__table_cls__` will determine if the - model looks like single or joined-table inheritance. If no primary key is - found, the name will be unset. + Later, ``__table_cls__`` will determine if the model looks like single or + joined-table inheritance. If no primary key is found, the name will be unset. """ if cls.__dict__.get("__abstract__", False) or not any( - isinstance(b, DeclarativeMeta) for b in cls.__mro__[1:] + isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:] ): return False @@ -134,22 +183,25 @@ def should_set_tablename(cls): if "__tablename__" not in base.__dict__: continue - if isinstance(base.__dict__["__tablename__"], declared_attr): + if isinstance(base.__dict__["__tablename__"], sa.orm.declared_attr): return False return not ( base is cls or base.__dict__.get("__abstract__", False) - or not isinstance(base, DeclarativeMeta) + or not isinstance(base, sa.orm.DeclarativeMeta) ) return True -def camel_to_snake_case(name): +def camel_to_snake_case(name: str) -> str: + """Convert a ``CamelCase`` name to ``snake_case``.""" name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name) return name.lower().lstrip("_") -class DefaultMeta(NameMetaMixin, BindMetaMixin, DeclarativeMeta): - pass +class DefaultMeta(BindMetaMixin, NameMetaMixin, sa.orm.DeclarativeMeta): + """SQLAlchemy declarative metaclass that provides ``__bind_key__`` and + ``__tablename__`` support. + """ diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py index cfb31352..dda1a736 100644 --- a/src/flask_sqlalchemy/session.py +++ b/src/flask_sqlalchemy/session.py @@ -1,63 +1,86 @@ -from sqlalchemy.orm import Session as SessionBase +from __future__ import annotations +import typing as t -class SignallingSession(SessionBase): - """The signalling session is the default session that Flask-SQLAlchemy - uses. It extends the default session system with bind selection and - modification tracking. +import sqlalchemy as sa +import sqlalchemy.exc +import sqlalchemy.orm - If you want to use a different session you can override the - :meth:`SQLAlchemy.create_session` function. +if t.TYPE_CHECKING: + from .extension import SQLAlchemy - .. versionadded:: 2.0 - .. versionadded:: 2.1 - The `binds` option was added, which allows a session to be joined - to an external transaction. +class Session(sa.orm.Session): + """A SQLAlchemy :class:`~sqlalchemy.orm.Session` class that chooses what engine to + use based on the bind key associated with the metadata associated with the thing + being queried. + + To customize ``db.session``, subclass this and pass it as the ``class_`` key in the + ``session_options`` to :class:`.SQLAlchemy`. + + .. versionchanged:: 3.0 + Renamed from ``SignallingSession``. """ - def __init__(self, db, autocommit=False, autoflush=True, **options): - #: The application that this session belongs to. - self.app = app = db.get_app() + def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None: + super().__init__(**kwargs) self._db = db - self._model_changes = {} - track_modifications = app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] - bind = options.pop("bind", None) or db.engine - binds = options.pop("binds", db.get_binds(app)) - - if track_modifications: - from . import track_modifications - - track_modifications._listen(self) - - SessionBase.__init__( - self, - autocommit=autocommit, - autoflush=autoflush, - bind=bind, - binds=binds, - **options, - ) + self._model_changes: dict[object, tuple[t.Any, str]] = {} + + def get_bind( + self, mapper=None, clause=None, bind=None, **kwargs: t.Any + ) -> sa.engine.Engine: + """Select an engine based on the ``bind_key`` of the metadata associated with + the model or table being queried. If no bind key is set, uses the default bind. - def get_bind(self, mapper=None, **kwargs): - """Return the engine or connection for a given model or - table, using the ``__bind_key__`` if it is set. + .. versionchanged:: 3.0 + The implementation more closely matches the base SQLAlchemy implementation. + + .. versionchanged:: 2.1 + Support joining an external transaction. """ - # mapper is None if someone tries to just get a connection + if bind is not None: + return bind + if mapper is not None: try: - # SA >= 1.3 - persist_selectable = mapper.persist_selectable - except AttributeError: - # SA < 1.3 - persist_selectable = mapper.mapped_table + mapper = sa.inspect(mapper) + except sa.exc.NoInspectionAvailable as e: + if isinstance(mapper, type): + raise sa.orm.exc.UnmappedClassError(mapper) from e + + raise + + clause = mapper.persist_selectable - info = getattr(persist_selectable, "info", {}) - bind_key = info.get("bind_key") - if bind_key is not None: - from .extension import get_state + engines = self._db.engines - state = get_state(self.app) - return state.db.get_engine(self.app, bind=bind_key) + if isinstance(clause, sa.Table) and "bind_key" in clause.metadata.info: + key = clause.metadata.info["bind_key"] + + if key not in engines: + raise sa.exc.UnboundExecutionError( + f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config." + ) + + return engines[key] + + if None in engines: + return engines[None] + + return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs) + + +def __getattr__(name: str) -> t.Any: + import warnings + + if name == "SignallingSession": + warnings.warn( + "'SignallingSession' has been renamed to 'Session'. The old name is" + " deprecated and will be removed in Flask-SQLAlchemy 3.1.", + DeprecationWarning, + stacklevel=2, + ) + return Session - return super().get_bind(mapper, **kwargs) + raise AttributeError(name) diff --git a/tests/test_basic_app.py b/tests/test_basic_app.py index b56ae841..10c343e9 100644 --- a/tests/test_basic_app.py +++ b/tests/test_basic_app.py @@ -71,15 +71,15 @@ def test_sqlite_relative_path(app, tmp_path): app.instance_path = tmp_path / "instance" # tests default to memory, shouldn't create - SQLAlchemy(app).get_engine() + SQLAlchemy(app) assert not app.instance_path.exists() # absolute path, shouldn't create app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:////tmp/test.sqlite" - SQLAlchemy(app).get_engine() + SQLAlchemy(app) assert not app.instance_path.exists() # relative path, should create app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.sqlite" - SQLAlchemy(app).get_engine() + SQLAlchemy(app) assert app.instance_path.exists() diff --git a/tests/test_binds.py b/tests/test_binds.py index b0a9fbb9..90f56d79 100644 --- a/tests/test_binds.py +++ b/tests/test_binds.py @@ -1,9 +1,17 @@ +import sqlalchemy as sa + from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.extension import get_state -def test_basic_binds(app, db): +def test_basic_binds(app): app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://", "bar": "sqlite://"} + db = SQLAlchemy(app) + + assert str(db.engine.url) == app.config["SQLALCHEMY_DATABASE_URI"] + + for key in "foo", "bar": + engine = db.engines[key] + assert str(engine.url) == app.config["SQLALCHEMY_BINDS"][key] class Foo(db.Model): __bind_key__ = "foo" @@ -19,45 +27,31 @@ class Baz(db.Model): db.create_all() - # simple way to check if the engines are looked up properly - assert db.get_engine(app, None) == db.engine - for key in "foo", "bar": - engine = db.get_engine(app, key) - connector = app.extensions["sqlalchemy"].connectors[key] - assert engine == connector.get_engine() - assert str(engine.url) == app.config["SQLALCHEMY_BINDS"][key] - # do the models have the correct engines? - assert db.metadata.tables["foo"].info["bind_key"] == "foo" - assert db.metadata.tables["bar"].info["bind_key"] == "bar" - assert db.metadata.tables["baz"].info.get("bind_key") is None + assert "foo" in db.metadatas["foo"].tables + assert "bar" in db.metadatas["bar"].tables + assert "baz" in db.metadata.tables # see the tables created in an engine - metadata = db.MetaData() - metadata.reflect(bind=db.get_engine(app, "foo")) + metadata = sa.MetaData() + metadata.reflect(bind=db.engines["foo"]) assert len(metadata.tables) == 1 assert "foo" in metadata.tables - metadata = db.MetaData() - metadata.reflect(bind=db.get_engine(app, "bar")) + metadata = sa.MetaData() + metadata.reflect(bind=db.engines["bar"]) assert len(metadata.tables) == 1 assert "bar" in metadata.tables - metadata = db.MetaData() - metadata.reflect(bind=db.get_engine(app)) + metadata = sa.MetaData() + metadata.reflect(bind=db.engine) assert len(metadata.tables) == 1 assert "baz" in metadata.tables - # do the session have the right binds set? - assert db.get_binds(app) == { - Foo.__table__: db.get_engine(app, "foo"), - Bar.__table__: db.get_engine(app, "bar"), - Baz.__table__: db.get_engine(app, None), - } - -def test_abstract_binds(app, db): +def test_abstract_binds(app): app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://"} + db = SQLAlchemy(app) class AbstractFooBoundModel(db.Model): __abstract__ = True @@ -68,59 +62,38 @@ class FooBoundModel(AbstractFooBoundModel): db.create_all() - # does the model have the correct engines? - assert db.metadata.tables["foo_bound_model"].info["bind_key"] == "foo" + # does the model have the correct engine? + assert "foo_bound_model" in db.metadatas["foo"].tables # see the tables created in an engine - metadata = db.MetaData() - metadata.reflect(bind=db.get_engine(app, "foo")) + metadata = sa.MetaData() + metadata.reflect(bind=db.engines["foo"]) assert len(metadata.tables) == 1 assert "foo_bound_model" in metadata.tables -def test_connector_cache(app): - db = SQLAlchemy() - db.init_app(app) - - with app.app_context(): - db.get_engine() - - connector = get_state(app).connectors[None] - assert connector._app is app - - -def test_polymorphic_bind(app, db): +def test_polymorphic_bind(app): bind_key = "polymorphic_bind_key" - - app.config["SQLALCHEMY_BINDS"] = { - bind_key: "sqlite:///:memory", - } + app.config["SQLALCHEMY_BINDS"] = {bind_key: "sqlite:///:memory"} + db = SQLAlchemy(app) class Base(db.Model): __bind_key__ = bind_key - __tablename__ = "base" - id = db.Column(db.Integer, primary_key=True) - p_type = db.Column(db.String(50)) - __mapper_args__ = {"polymorphic_identity": "base", "polymorphic_on": p_type} class Child1(Base): - child_1_data = db.Column(db.String(50)) - __mapper_args__ = { - "polymorphic_identity": "child_1", - } + __mapper_args__ = {"polymorphic_identity": "child_1"} - assert Base.__table__.info["bind_key"] == bind_key - assert Child1.__table__.info["bind_key"] == bind_key + assert Base.metadata.info["bind_key"] == bind_key + assert Child1.metadata.info["bind_key"] == bind_key -def test_execute_with_binds_arguments(app, db): +def test_execute_with_binds_arguments(app): app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://", "bar": "sqlite://"} + db = SQLAlchemy(app) db.create_all() - db.session.execute( - "SELECT true", bind_arguments={"bind": db.get_engine(app, "foo")} - ) + db.session.execute("SELECT true", bind_arguments={"bind": db.engines["foo"]}) diff --git a/tests/test_commit_on_teardown.py b/tests/test_commit_on_teardown.py deleted file mode 100644 index 7399395d..00000000 --- a/tests/test_commit_on_teardown.py +++ /dev/null @@ -1,37 +0,0 @@ -import flask -import pytest - - -@pytest.fixture -def client(app, db, Todo): - app.testing = False - app.config["SQLALCHEMY_COMMIT_ON_TEARDOWN"] = True - - @app.route("/") - def index(): - return "\n".join(x.title for x in Todo.query.all()) - - @app.route("/create", methods=["POST"]) - def create(): - db.session.add(Todo("Test one", "test")) - if flask.request.form.get("fail"): - raise RuntimeError("Failing as requested") - return "ok" - - return app.test_client() - - -def test_commit_on_success(client): - with pytest.warns(DeprecationWarning, match="COMMIT_ON_TEARDOWN"): - resp = client.post("/create") - - assert resp.status_code == 200 - assert client.get("/").data == b"Test one" - - -def test_roll_back_on_failure(client): - with pytest.warns(DeprecationWarning, match="COMMIT_ON_TEARDOWN"): - resp = client.post("/create", data={"fail": "on"}) - - assert resp.status_code == 500 - assert client.get("/").data == b"" diff --git a/tests/test_config.py b/tests/test_config.py index ede0a1fd..6839d1d9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,8 @@ import os -from unittest import mock import pytest -import sqlalchemy -from sqlalchemy.pool import NullPool +import sqlalchemy as sa +import sqlalchemy.pool from flask_sqlalchemy import SQLAlchemy @@ -34,7 +33,7 @@ def test_default_error_without_uri_or_binds(self, app, recwarn): with pytest.raises(RuntimeError) as exc_info: SQLAlchemy(app) - expected = "Either SQLALCHEMY_DATABASE_URI or SQLALCHEMY_BINDS needs to be set." + expected = "Either 'SQLALCHEMY_DATABASE_URI' or 'SQLALCHEMY_BINDS' must be set." assert exc_info.value.args[0] == expected def test_defaults_with_uri(self, app, recwarn): @@ -50,7 +49,7 @@ def test_defaults_with_uri(self, app, recwarn): # Expecting no warnings for default config with URI assert len(recwarn) == 0 - assert app.config["SQLALCHEMY_BINDS"] is None + assert app.config["SQLALCHEMY_BINDS"] == {} assert app.config["SQLALCHEMY_ECHO"] is False assert app.config["SQLALCHEMY_RECORD_QUERIES"] is None assert app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] is False @@ -60,57 +59,39 @@ def test_engine_creation_ok(self, app): """create_engine() isn't called until needed. Make sure we can do that without errors or warnings. """ - assert SQLAlchemy(app).get_engine() + assert SQLAlchemy(app).engine -@mock.patch.object(sqlalchemy, "create_engine", autospec=True, spec_set=True) class TestCreateEngine: """Tests for _EngineConnector and SQLAlchemy methods involved in setting up the SQLAlchemy engine. """ - def test_engine_echo_default(self, m_create_engine, app_nr): - SQLAlchemy(app_nr).get_engine() + def test_engine_echo_default(self, app_nr): + db = SQLAlchemy(app_nr) + assert not db.engine.echo + assert not db.engine.pool.echo - args, options = m_create_engine.call_args - assert "echo" not in options - - def test_engine_echo_true(self, m_create_engine, app_nr): + def test_engine_echo_true(self, app_nr): app_nr.config["SQLALCHEMY_ECHO"] = True - SQLAlchemy(app_nr).get_engine() - - args, options = m_create_engine.call_args - assert options["echo"] is True - - def test_config_from_engine_options(self, m_create_engine, app_nr): - app_nr.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"foo": "bar"} - SQLAlchemy(app_nr).get_engine() - - args, options = m_create_engine.call_args - assert options["foo"] == "bar" - - def test_config_from_init(self, m_create_engine, app_nr): - SQLAlchemy(app_nr, engine_options={"bar": "baz"}).get_engine() - - args, options = m_create_engine.call_args - assert options["bar"] == "baz" - - def test_pool_class_default(self, m_create_engine, app_nr): - SQLAlchemy(app_nr).get_engine() + db = SQLAlchemy(app_nr) + assert db.engine.echo + assert db.engine.pool.echo - args, options = m_create_engine.call_args - assert options["poolclass"].__name__ == "StaticPool" + def test_config_from_engine_options(self, app_nr): + app_nr.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"echo": True} + assert SQLAlchemy(app_nr).engine.echo - def test_pool_class_nullpool(self, m_create_engine, app_nr): - engine_options = {"poolclass": NullPool} - SQLAlchemy(app_nr, engine_options=engine_options).get_engine() + def test_config_from_init(self, app_nr): + db = SQLAlchemy(app_nr, engine_options={"echo": True}) + assert db.engine.echo - args, options = m_create_engine.call_args - assert options["poolclass"].__name__ == "NullPool" - assert "pool_size" not in options + def test_pool_class_default(self, app_nr): + db = SQLAlchemy(app_nr) + assert isinstance(db.engine.pool, sa.pool.StaticPool) -def test_sqlite_relative_to_app_root(app): +def test_sqlite_relative_to_instance_path(app): app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.db" db = SQLAlchemy(app) - assert db.engine.url.database == os.path.join(app.root_path, "test.db") + assert db.engine.url.database == os.path.join(app.instance_path, "test.db") diff --git a/tests/test_query_property.py b/tests/test_query_property.py index ab6d1781..b5740fb6 100644 --- a/tests/test_query_property.py +++ b/tests/test_query_property.py @@ -13,7 +13,8 @@ class Foo(db.Model): # If no app is bound to the SQLAlchemy instance, a # request context is required to access Model.query. - pytest.raises(RuntimeError, getattr, Foo, "query") + assert Foo.query + with app.test_request_context(): db.create_all() foo = Foo() From 335241ae0c07339731f811453a73de5bc4747486 Mon Sep 17 00:00:00 2001 From: David Lord Date: Wed, 7 Sep 2022 16:16:49 -0700 Subject: [PATCH 16/27] session is scoped to the current app context ensures the session is cleaned up after every request and command --- CHANGES.rst | 3 +++ src/flask_sqlalchemy/extension.py | 18 ++++++++++-------- src/flask_sqlalchemy/session.py | 6 ++++++ tests/conftest.py | 6 ++++++ tests/test_binds.py | 2 ++ tests/test_model_class.py | 1 + tests/test_pagination.py | 13 +++++++------ tests/test_query_class.py | 5 +++++ tests/test_query_property.py | 18 +++++------------- tests/test_regressions.py | 5 +++-- tests/test_table_name.py | 1 + 11 files changed, 49 insertions(+), 29 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 8518ba78..68ad41c3 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -67,6 +67,9 @@ Unreleased from various methods. - ``SQLALCHEMY_RECORD_QUERIES`` configuration takes precedence over ``app.debug`` and ``app.testing``, allowing it to be disabled in those modes. +- The session is scoped to the current app context instead of the thread. This + requires that an app context is active. This ensures that the session is cleaned up + after every request. Version 2.5.1 diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 6fc78c9a..9ee4894a 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -16,15 +16,9 @@ from .model import DefaultMeta from .model import Model from .query import Query +from .session import _app_ctx_id from .session import Session -# Scope the session to the current greenlet if greenlet is available, -# otherwise fall back to the current thread. -try: - from greenlet import getcurrent as _ident_func -except ImportError: - from threading import get_ident as _ident_func - class SQLAlchemy: """Integrates SQLAlchemy with Flask. This handles setting up one or more engines, @@ -143,6 +137,11 @@ def __init__( application context exits. Customize this by passing ``session_options`` to the extension. + + This requires that a Flask application context is active. + + .. versionchanged:: 3.0 + The session is scoped to the current app context. """ self.metadatas: dict[str | None, sa.MetaData] = {} @@ -332,10 +331,13 @@ def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_sessi :param options: The ``session_options`` parameter from ``__init__``. Keyword arguments passed to the session factory. A ``scopefunc`` key is popped. + .. versionchanged:: 3.0 + The session is scoped to the current app context. + .. versionchanged:: 3.0 Renamed from ``create_scoped_session``, this method is internal. """ - scope = options.pop("scopefunc", _ident_func) + scope = options.pop("scopefunc", _app_ctx_id) factory = self._make_session_factory(options) return sa.orm.scoped_session(factory, scope) diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py index dda1a736..1ac4c8b3 100644 --- a/src/flask_sqlalchemy/session.py +++ b/src/flask_sqlalchemy/session.py @@ -5,6 +5,7 @@ import sqlalchemy as sa import sqlalchemy.exc import sqlalchemy.orm +from flask.globals import app_ctx if t.TYPE_CHECKING: from .extension import SQLAlchemy @@ -71,6 +72,11 @@ def get_bind( return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs) +def _app_ctx_id() -> int: + """Get the id of the current Flask application context for the session scope.""" + return id(app_ctx._get_current_object()) + + def __getattr__(name: str) -> t.Any: import warnings diff --git a/tests/conftest.py b/tests/conftest.py index cccc9b2d..066476fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,12 @@ def app(request): return app +@pytest.fixture +def app_ctx(app): + with app.app_context() as ctx: + yield ctx + + @pytest.fixture def db(app): return SQLAlchemy(app) diff --git a/tests/test_binds.py b/tests/test_binds.py index 90f56d79..8cc1d903 100644 --- a/tests/test_binds.py +++ b/tests/test_binds.py @@ -1,3 +1,4 @@ +import pytest import sqlalchemy as sa from flask_sqlalchemy import SQLAlchemy @@ -92,6 +93,7 @@ class Child1(Base): assert Child1.metadata.info["bind_key"] == bind_key +@pytest.mark.usefixtures("app_ctx") def test_execute_with_binds_arguments(app): app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://", "bar": "sqlite://"} db = SQLAlchemy(app) diff --git a/tests/test_model_class.py b/tests/test_model_class.py index 728e7eb3..166d04a9 100644 --- a/tests/test_model_class.py +++ b/tests/test_model_class.py @@ -34,6 +34,7 @@ class User(db.Model): pass +@pytest.mark.usefixtures("app_ctx") def test_repr(db): class User(db.Model): name = db.Column(db.String, primary_key=True) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1d91e580..c962a4e6 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -51,18 +51,18 @@ def index(): assert p.total == 100 +@pytest.mark.usefixtures("app_ctx") def test_query_paginate_more_than_20(app, db, Todo): - with app.app_context(): - db.session.add_all(Todo("", "") for _ in range(20)) - db.session.commit() + db.session.add_all(Todo("", "") for _ in range(20)) + db.session.commit() assert len(Todo.query.paginate(max_per_page=10).items) == 10 +@pytest.mark.usefixtures("app_ctx") def test_paginate_min(app, db, Todo): - with app.app_context(): - db.session.add_all(Todo(str(x), "") for x in range(20)) - db.session.commit() + db.session.add_all(Todo(str(x), "") for x in range(20)) + db.session.commit() assert Todo.query.paginate(error_out=False, page=-1).items[0].title == "0" assert len(Todo.query.paginate(error_out=False, per_page=0).items) == 0 @@ -75,6 +75,7 @@ def test_paginate_min(app, db, Todo): Todo.query.paginate(per_page=-1) +@pytest.mark.usefixtures("app_ctx") def test_paginate_without_count(app, db, Todo): with app.app_context(): db.session.add_all(Todo("", "") for _ in range(20)) diff --git a/tests/test_query_class.py b/tests/test_query_class.py index f74fff64..c598ea5f 100644 --- a/tests/test_query_class.py +++ b/tests/test_query_class.py @@ -1,7 +1,10 @@ +import pytest + from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.query import Query +@pytest.mark.usefixtures("app_ctx") def test_default_query_class(db): class Parent(db.Model): id = db.Column(db.Integer, primary_key=True) @@ -21,6 +24,7 @@ class Child(db.Model): assert isinstance(db.session.query(Parent), Query) +@pytest.mark.usefixtures("app_ctx") def test_custom_query_class(app): class CustomQueryClass(Query): pass @@ -47,6 +51,7 @@ class Child(db.Model): assert isinstance(db.session.query(Parent), CustomQueryClass) +@pytest.mark.usefixtures("app_ctx") def test_dont_override_model_default(app): class CustomQueryClass(Query): pass diff --git a/tests/test_query_property.py b/tests/test_query_property.py index b5740fb6..8c12ec87 100644 --- a/tests/test_query_property.py +++ b/tests/test_query_property.py @@ -4,16 +4,15 @@ from flask_sqlalchemy import SQLAlchemy -def test_no_app_bound(app): +def test_app_ctx_required(app): db = SQLAlchemy() db.init_app(app) class Foo(db.Model): id = db.Column(db.Integer, primary_key=True) - # If no app is bound to the SQLAlchemy instance, a - # request context is required to access Model.query. - assert Foo.query + with pytest.raises(RuntimeError): + assert Foo.query with app.test_request_context(): db.create_all() @@ -23,15 +22,7 @@ class Foo(db.Model): assert len(Foo.query.all()) == 1 -def test_app_bound(db, Todo): - # If an app was passed to the SQLAlchemy constructor, - # the query property is always available. - todo = Todo("Test", "test") - db.session.add(todo) - db.session.commit() - assert len(Todo.query.all()) == 1 - - +@pytest.mark.usefixtures("app_ctx") def test_get_or_404(Todo): with pytest.raises(NotFound): Todo.query.get_or_404(1) @@ -44,6 +35,7 @@ def test_get_or_404(Todo): assert e_info.value.description == expected +@pytest.mark.usefixtures("app_ctx") def test_first_or_404(Todo): with pytest.raises(NotFound): Todo.query.first_or_404() diff --git a/tests/test_regressions.py b/tests/test_regressions.py index b216282b..ec4c3e82 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -36,6 +36,7 @@ class SubBase(Base): db.create_all() +@pytest.mark.usefixtures("app_ctx") def test_joined_inheritance_relation(db): class Relation(db.Model): id = db.Column(db.Integer, primary_key=True) @@ -61,9 +62,9 @@ class SubBase(Base): base.relations = [Relation(name="foo")] db.session.add(base) db.session.commit() - - base = base.query.one() + base.query.one() +@pytest.mark.usefixtures("app_ctx") def test_connection_binds(db): assert db.session.connection() diff --git a/tests/test_table_name.py b/tests/test_table_name.py index 5a9ccee1..a970094f 100644 --- a/tests/test_table_name.py +++ b/tests/test_table_name.py @@ -158,6 +158,7 @@ class RubberDuck(IdMixin, Duck): assert RubberDuck.__tablename__ == "rubber_duck" +@pytest.mark.usefixtures("app_ctx") def test_manual_name(db): """Setting a manual name prevents generation for the immediate model. A name is generated for joined but not single-table inheritance. From af20c1d5175cc758b0520c467f81174cd6cfc7fb Mon Sep 17 00:00:00 2001 From: David Lord Date: Wed, 7 Sep 2022 16:46:33 -0700 Subject: [PATCH 17/27] allow Model.__init_subclass__ parameters --- CHANGES.rst | 2 ++ src/flask_sqlalchemy/model.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 68ad41c3..4018555a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -70,6 +70,8 @@ Unreleased - The session is scoped to the current app context instead of the thread. This requires that an app context is active. This ensures that the session is cleaned up after every request. +- A custom model class can implement ``__init_subclass__`` with class parameters. + :issue:`1002` Version 2.5.1 diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index fa83803a..83b6efd2 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -83,7 +83,9 @@ class BindMetaMixin(type): __fsa__: SQLAlchemy metadata: sa.MetaData - def __init__(cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any]) -> None: + def __init__( + cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any], **kwargs: t.Any + ) -> None: if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): bind_key = getattr(cls, "__bind_key__", None) parent_metadata = getattr(cls, "metadata", None) @@ -92,7 +94,7 @@ def __init__(cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any]) -> No if metadata is not parent_metadata: cls.metadata = metadata - super().__init__(name, bases, d) + super().__init__(name, bases, d, **kwargs) class NameMetaMixin(type): @@ -106,11 +108,13 @@ class NameMetaMixin(type): __tablename__: str __table__: sa.Table - def __init__(cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any]) -> None: + def __init__( + cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any], **kwargs: t.Any + ) -> None: if should_set_tablename(cls): cls.__tablename__ = camel_to_snake_case(cls.__name__) - super().__init__(name, bases, d) + super().__init__(name, bases, d, **kwargs) # __table_cls__ has run. If no table was created, use the parent table. if ( From 41bb661a42dafdb75de2c7a3049aa36bfc2b3eaa Mon Sep 17 00:00:00 2001 From: David Lord Date: Wed, 7 Sep 2022 18:32:37 -0700 Subject: [PATCH 18/27] address mypy findings --- src/flask_sqlalchemy/extension.py | 23 ++++++++++++------ src/flask_sqlalchemy/model.py | 12 +++++---- src/flask_sqlalchemy/query.py | 2 +- src/flask_sqlalchemy/record_queries.py | 14 +++++------ src/flask_sqlalchemy/session.py | 12 ++++++--- src/flask_sqlalchemy/track_modifications.py | 27 ++++++++++++--------- 6 files changed, 54 insertions(+), 36 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 9ee4894a..60bc6c6a 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -203,7 +203,7 @@ def __init__( if app is not None: self.init_app(app) - def __repr__(self): + def __repr__(self) -> str: if not has_app_context() and self._app is None: return f"<{type(self).__name__}>" @@ -341,7 +341,9 @@ def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_sessi factory = self._make_session_factory(options) return sa.orm.scoped_session(factory, scope) - def _make_session_factory(self, options: dict[str, t.Any]) -> sa.orm.sessionmaker: + def _make_session_factory( + self, options: dict[str, t.Any] + ) -> sa.orm.sessionmaker[Session]: # type: ignore[type-var] """Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by :meth:`_make_scoped_session`. @@ -532,10 +534,13 @@ def _apply_driver_defaults(self, options: dict[str, t.Any], app: Flask) -> None: options["connect_args"]["check_same_thread"] = False else: - if not os.path.isabs(url.database): + if not os.path.isabs(url.database): # type: ignore[arg-type] os.makedirs(app.instance_path, exist_ok=True) options["url"] = url.set( - database=os.path.join(app.instance_path, url.database) + database=os.path.join( + app.instance_path, + url.database, # type: ignore[arg-type] + ) ) elif url.drivername.startswith("mysql"): options.setdefault("pool_recycle", 7200) @@ -588,7 +593,7 @@ def engines(self) -> t.Mapping[str | None, sa.engine.Engine]: if not has_app_context() and self._app is not None: app = self._app else: - app = current_app._get_current_object() + app = current_app._get_current_object() # type: ignore[attr-defined] return self._app_engines[app] @@ -783,7 +788,7 @@ def _set_rel_query(self, kwargs: dict[str, t.Any]) -> None: def relationship( self, *args: t.Any, **kwargs: t.Any - ) -> sa.orm.RelationshipProperty: + ) -> sa.orm.RelationshipProperty[t.Any]: """A :func:`sqlalchemy.orm.relationship` that applies this extension's :attr:`Query` class for dynamic relationships and backrefs. @@ -795,7 +800,7 @@ def relationship( def dynamic_loader( self, argument: t.Any, **kwargs: t.Any - ) -> sa.orm.RelationshipProperty: + ) -> sa.orm.RelationshipProperty[t.Any]: """A :func:`sqlalchemy.orm.dynamic_loader` that applies this extension's :attr:`Query` class for relationships and backrefs. @@ -805,7 +810,9 @@ def dynamic_loader( self._set_rel_query(kwargs) return sa.orm.dynamic_loader(argument, **kwargs) - def _relation(self, *args: t.Any, **kwargs: t.Any) -> sa.orm.RelationshipProperty: + def _relation( + self, *args: t.Any, **kwargs: t.Any + ) -> sa.orm.RelationshipProperty[t.Any]: """A :func:`sqlalchemy.orm.relationship` that applies this extension's :attr:`Query` class for dynamic relationships and backrefs. diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 83b6efd2..ba6124b6 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -27,7 +27,9 @@ def __get__(self, obj: Model, cls: t.Type[Model]) -> Query: ... def __get__(self, obj: Model | None, cls: t.Type[Model]) -> Query: - return cls.query_class(cls, session=cls.__fsa__.session()) + return cls.query_class( + cls, session=cls.__fsa__.session() # type: ignore[arg-type] + ) class Model: @@ -50,7 +52,7 @@ class Model: defaults to :class:`.Query`. """ - query: t.ClassVar[Query] = _QueryProperty() + query: t.ClassVar[Query] = _QueryProperty() # type: ignore[assignment] """A SQLAlchemy query for a model. Equivalent to ``db.session.query(Model)``. Can be customized per-model by overriding :attr:`query_class`. @@ -60,14 +62,14 @@ class Model: """ def __repr__(self) -> str: - state: sa.orm.InstanceState = sa.inspect(self) + state = sa.inspect(self) if state.transient: pk = f"(transient {id(self)})" elif state.pending: pk = f"(pending {id(self)})" else: - pk = ", ".join(str(value) for value in state.identity) + pk = ", ".join(map(str, state.identity)) return f"<{type(self).__name__} {pk}>" @@ -124,7 +126,7 @@ def __init__( ): del cls.__table__ - def __table_cls__(cls, *args, **kwargs) -> sa.Table | None: + def __table_cls__(cls, *args: t.Any, **kwargs: t.Any) -> sa.Table | None: """This is called by SQLAlchemy during mapper setup. It determines the final table object that the model will use. diff --git a/src/flask_sqlalchemy/query.py b/src/flask_sqlalchemy/query.py index a9efc35f..1ef75b57 100644 --- a/src/flask_sqlalchemy/query.py +++ b/src/flask_sqlalchemy/query.py @@ -10,7 +10,7 @@ from .pagination import Pagination -class Query(sa.orm.Query): +class Query(sa.orm.Query): # type: ignore[type-arg] """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods useful for querying in a web application. diff --git a/src/flask_sqlalchemy/record_queries.py b/src/flask_sqlalchemy/record_queries.py index dfe92403..bbad15b2 100644 --- a/src/flask_sqlalchemy/record_queries.py +++ b/src/flask_sqlalchemy/record_queries.py @@ -41,7 +41,7 @@ def get_recorded_queries() -> list[_QueryInfo]: .. versionchanged:: 3.0 The info object attribute ``context`` is renamed to ``location``. """ - return g.get("_sqlalchemy_queries", []) + return g.get("_sqlalchemy_queries", []) # type: ignore[no-any-return] @dataclasses.dataclass @@ -58,7 +58,7 @@ class _QueryInfo: ``context`` is renamed to ``location``. """ - statement: str + statement: str | None parameters: t.Any start_time: float end_time: float @@ -93,19 +93,19 @@ def __getitem__(self, key: int) -> object: return getattr(self, name) -def _listen(engine: sa.Engine) -> None: +def _listen(engine: sa.engine.Engine) -> None: sa.event.listen(engine, "before_cursor_execute", _record_start, named=True) sa.event.listen(engine, "after_cursor_execute", _record_end, named=True) -def _record_start(context: sa.ExecutionContext, **kwargs: t.Any) -> None: +def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: if not has_app_context(): return - context._fsa_start_time = perf_counter() + context._fsa_start_time = perf_counter() # type: ignore[attr-defined] -def _record_end(context: sa.ExecutionContext, **kwargs: t.Any) -> None: +def _record_end(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: if not has_app_context(): return @@ -132,7 +132,7 @@ def _record_end(context: sa.ExecutionContext, **kwargs: t.Any) -> None: _QueryInfo( statement=context.statement, parameters=context.parameters, - start_time=context._fsa_start_time, + start_time=context._fsa_start_time, # type: ignore[attr-defined] end_time=perf_counter(), location=location, ) diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py index 1ac4c8b3..40e7b20e 100644 --- a/src/flask_sqlalchemy/session.py +++ b/src/flask_sqlalchemy/session.py @@ -28,9 +28,13 @@ def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None: self._db = db self._model_changes: dict[object, tuple[t.Any, str]] = {} - def get_bind( - self, mapper=None, clause=None, bind=None, **kwargs: t.Any - ) -> sa.engine.Engine: + def get_bind( # type: ignore[override] + self, + mapper: t.Any | None = None, + clause: t.Any | None = None, + bind: sa.engine.Engine | sa.engine.Connection | None = None, + **kwargs: t.Any, + ) -> sa.engine.Engine | sa.engine.Connection: """Select an engine based on the ``bind_key`` of the metadata associated with the model or table being queried. If no bind key is set, uses the default bind. @@ -74,7 +78,7 @@ def get_bind( def _app_ctx_id() -> int: """Get the id of the current Flask application context for the session scope.""" - return id(app_ctx._get_current_object()) + return id(app_ctx._get_current_object()) # type: ignore[attr-defined] def __getattr__(name: str) -> t.Any: diff --git a/src/flask_sqlalchemy/track_modifications.py b/src/flask_sqlalchemy/track_modifications.py index 4a518963..8c9f38f5 100644 --- a/src/flask_sqlalchemy/track_modifications.py +++ b/src/flask_sqlalchemy/track_modifications.py @@ -4,19 +4,20 @@ import sqlalchemy as sa import sqlalchemy.event +import sqlalchemy.orm from flask import current_app from flask import has_app_context -from flask.signals import Namespace +from flask.signals import Namespace # type: ignore[attr-defined] if t.TYPE_CHECKING: - from .session import SignallingSession + from .session import Session _signals = Namespace() models_committed = _signals.signal("models-committed") before_models_committed = _signals.signal("before-models-committed") -def _listen(session) -> None: +def _listen(session: sa.orm.scoped_session) -> None: sa.event.listen(session, "before_flush", _record_ops, named=True) sa.event.listen(session, "before_commit", _record_ops, named=True) sa.event.listen(session, "before_commit", _before_commit) @@ -24,7 +25,7 @@ def _listen(session) -> None: sa.event.listen(session, "after_rollback", _after_rollback) -def _record_ops(session: SignallingSession, **kwargs: t.Any) -> None: +def _record_ops(session: Session, **kwargs: t.Any) -> None: if not has_app_context(): return @@ -42,30 +43,34 @@ def _record_ops(session: SignallingSession, **kwargs: t.Any) -> None: session._model_changes[key] = (target, operation) -def _before_commit(session: SignallingSession) -> None: +def _before_commit(session: Session) -> None: if not has_app_context(): return - if not current_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: + app = current_app._get_current_object() # type: ignore[attr-defined] + + if not app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: return if session._model_changes: changes = list(session._model_changes.values()) - before_models_committed.send(current_app._get_current_object(), changes=changes) + before_models_committed.send(app, changes=changes) -def _after_commit(session: SignallingSession) -> None: +def _after_commit(session: Session) -> None: if not has_app_context(): return - if not current_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: + app = current_app._get_current_object() # type: ignore[attr-defined] + + if not app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]: return if session._model_changes: changes = list(session._model_changes.values()) - models_committed.send(current_app._get_current_object(), changes=changes) + models_committed.send(app, changes=changes) session._model_changes.clear() -def _after_rollback(session: SignallingSession) -> None: +def _after_rollback(session: Session) -> None: session._model_changes.clear() From c285fc3994b2d125702be313dee50e04b493201d Mon Sep 17 00:00:00 2001 From: David Lord Date: Thu, 8 Sep 2022 05:42:29 -0700 Subject: [PATCH 19/27] deprecate top-level imports --- src/flask_sqlalchemy/__init__.py | 39 ++++++++++++++++++++++++++ src/flask_sqlalchemy/query.py | 15 ---------- src/flask_sqlalchemy/record_queries.py | 15 ---------- src/flask_sqlalchemy/session.py | 15 ---------- 4 files changed, 39 insertions(+), 45 deletions(-) diff --git a/src/flask_sqlalchemy/__init__.py b/src/flask_sqlalchemy/__init__.py index d54cf4a7..7de06603 100644 --- a/src/flask_sqlalchemy/__init__.py +++ b/src/flask_sqlalchemy/__init__.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import typing as t + from .extension import SQLAlchemy __version__ = "3.0.0.dev0" @@ -5,3 +9,38 @@ __all__ = [ "SQLAlchemy", ] + +_deprecated_map = { + "Model": ".model.Model", + "DefaultMeta": ".model.DefaultMeta", + "Pagination": ".pagination.Pagination", + "BaseQuery": ".query.Query", + "get_debug_queries": ".record_queries.get_recorded_queries", + "SignallingSession": ".session.Session", + "before_models_committed": ".track_modifications.before_models_committed", + "models_committed": ".track_modifications.models_committed", +} + + +def __getattr__(name: str) -> t.Any: + import importlib + import warnings + + if name in _deprecated_map: + path = _deprecated_map[name] + import_path, _, new_name = path.rpartition(".") + action = "moved and renamed" + + if new_name == name: + action = "moved" + + warnings.warn( + f"'{name}' has been {action} to '{path[1:]}'. The top-level import is" + " deprecated and will be removed in Flask-SQLAlchemy 3.1.", + DeprecationWarning, + stacklevel=2, + ) + mod = importlib.import_module(import_path, __name__) + return getattr(mod, new_name) + + raise AttributeError(name) diff --git a/src/flask_sqlalchemy/query.py b/src/flask_sqlalchemy/query.py index 1ef75b57..a1632c60 100644 --- a/src/flask_sqlalchemy/query.py +++ b/src/flask_sqlalchemy/query.py @@ -102,18 +102,3 @@ def paginate( error_out=error_out, count=count, ) - - -def __getattr__(name: str) -> t.Any: - import warnings - - if name == "BaseQuery": - warnings.warn( - "'BaseQuery' is renamed to 'Query'. The old name is deprecated and will be" - " removed in Flask-SQLAlchemy 3.1.", - DeprecationWarning, - stacklevel=2, - ) - return Query - - raise AttributeError(name) diff --git a/src/flask_sqlalchemy/record_queries.py b/src/flask_sqlalchemy/record_queries.py index bbad15b2..d50492be 100644 --- a/src/flask_sqlalchemy/record_queries.py +++ b/src/flask_sqlalchemy/record_queries.py @@ -137,18 +137,3 @@ def _record_end(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: location=location, ) ) - - -def __getattr__(name: str) -> t.Any: - import warnings - - if name == "get_debug_queries": - warnings.warn( - "'get_debug_queries' is renamed to 'get_recorded_queries'. The old name is" - " deprecated and will be removed in Flask-SQLAlchemy 3.1.", - DeprecationWarning, - stacklevel=2, - ) - return get_recorded_queries - - raise AttributeError(name) diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py index 40e7b20e..2715969c 100644 --- a/src/flask_sqlalchemy/session.py +++ b/src/flask_sqlalchemy/session.py @@ -79,18 +79,3 @@ def get_bind( # type: ignore[override] def _app_ctx_id() -> int: """Get the id of the current Flask application context for the session scope.""" return id(app_ctx._get_current_object()) # type: ignore[attr-defined] - - -def __getattr__(name: str) -> t.Any: - import warnings - - if name == "SignallingSession": - warnings.warn( - "'SignallingSession' has been renamed to 'Session'. The old name is" - " deprecated and will be removed in Flask-SQLAlchemy 3.1.", - DeprecationWarning, - stacklevel=2, - ) - return Session - - raise AttributeError(name) From 5f439821d211d747f32899e3f59f883a9219bee8 Mon Sep 17 00:00:00 2001 From: David Lord Date: Thu, 8 Sep 2022 06:39:03 -0700 Subject: [PATCH 20/27] require app context --- CHANGES.rst | 2 ++ src/flask_sqlalchemy/extension.py | 15 ++++++--------- tests/conftest.py | 10 +++++++--- tests/test_binds.py | 2 ++ tests/test_config.py | 9 +++++++++ tests/test_regressions.py | 6 ++++-- tests/test_sessions.py | 31 +++++++++++++++++++++++-------- 7 files changed, 53 insertions(+), 22 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 4018555a..101cac49 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -72,6 +72,8 @@ Unreleased after every request. - A custom model class can implement ``__init_subclass__`` with class parameters. :issue:`1002` +- An active Flask application context is always required to access ``session`` and + ``engine``, regardless of if an application was passed to the constructor. Version 2.5.1 diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 60bc6c6a..2e1a7f11 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -58,6 +58,10 @@ class for further customization. lower precedence than application config. See :func:`sqlalchemy.create_engine` for a list of arguments. + .. versionchanged:: 3.0 + An active Flask application context is always required to access ``session`` and + ``engine``. + .. versionchanged:: 3.0 Separate ``metadata`` are used for each bind key. @@ -198,13 +202,11 @@ def __init__( self._app_engines: WeakKeyDictionary[Flask, dict[str | None, sa.engine.Engine]] self._app_engines = WeakKeyDictionary() - self._app: Flask | None = app - if app is not None: self.init_app(app) def __repr__(self) -> str: - if not has_app_context() and self._app is None: + if not has_app_context(): return f"<{type(self).__name__}>" message = f"{type(self).__name__} {self.engine.url}" @@ -234,7 +236,6 @@ def init_app(self, app: Flask) -> None: :param app: The Flask application to initialize. """ - # We intentionally don't set self._app, to support initializing multiple apps. app.extensions["sqlalchemy"] = self if app.config.get("SQLALCHEMY_COMMIT_ON_TEARDOWN", False): @@ -590,11 +591,7 @@ def engines(self) -> t.Mapping[str | None, sa.engine.Engine]: .. versionadded:: 3.0 """ - if not has_app_context() and self._app is not None: - app = self._app - else: - app = current_app._get_current_object() # type: ignore[attr-defined] - + app = current_app._get_current_object() # type: ignore[attr-defined] return self._app_engines[app] @property diff --git a/tests/conftest.py b/tests/conftest.py index 066476fe..e6bf8e6d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,7 @@ def db(app): @pytest.fixture -def Todo(db): +def Todo(app, db): class Todo(db.Model): __tablename__ = "todos" id = db.Column("todo_id", db.Integer, primary_key=True) @@ -41,6 +41,10 @@ def __init__(self, title, text): self.done = False self.pub_date = datetime.utcnow() - db.create_all() + with app.app_context(): + db.create_all() + yield Todo - db.drop_all() + + with app.app_context(): + db.drop_all() diff --git a/tests/test_binds.py b/tests/test_binds.py index 8cc1d903..308e3373 100644 --- a/tests/test_binds.py +++ b/tests/test_binds.py @@ -4,6 +4,7 @@ from flask_sqlalchemy import SQLAlchemy +@pytest.mark.usefixtures("app_ctx") def test_basic_binds(app): app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://", "bar": "sqlite://"} db = SQLAlchemy(app) @@ -50,6 +51,7 @@ class Baz(db.Model): assert "baz" in metadata.tables +@pytest.mark.usefixtures("app_ctx") def test_abstract_binds(app): app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://"} db = SQLAlchemy(app) diff --git a/tests/test_config.py b/tests/test_config.py index 6839d1d9..7b108bc0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -16,6 +16,12 @@ def app_nr(app): return app +@pytest.fixture +def nr_app_ctx(app_nr): + with app_nr.app_context() as ctx: + yield ctx + + class TestConfigKeys: def test_default_error_without_uri_or_binds(self, app, recwarn): """ @@ -55,6 +61,7 @@ def test_defaults_with_uri(self, app, recwarn): assert app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] is False assert app.config["SQLALCHEMY_ENGINE_OPTIONS"] == {} + @pytest.mark.usefixtures("app_ctx") def test_engine_creation_ok(self, app): """create_engine() isn't called until needed. Make sure we can do that without errors or warnings. @@ -62,6 +69,7 @@ def test_engine_creation_ok(self, app): assert SQLAlchemy(app).engine +@pytest.mark.usefixtures("nr_app_ctx") class TestCreateEngine: """Tests for _EngineConnector and SQLAlchemy methods involved in setting up the SQLAlchemy engine. @@ -91,6 +99,7 @@ def test_pool_class_default(self, app_nr): assert isinstance(db.engine.pool, sa.pool.StaticPool) +@pytest.mark.usefixtures("nr_app_ctx") def test_sqlite_relative_to_instance_path(app): app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.db" db = SQLAlchemy(app) diff --git a/tests/test_regressions.py b/tests/test_regressions.py index ec4c3e82..14bc00c9 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -7,7 +7,8 @@ def db(app, db): return db -def test_joined_inheritance(db): +@pytest.mark.usefixtures("app_ctx") +def test_joined_inheritance(app, db): class Base(db.Model): id = db.Column(db.Integer, primary_key=True) type = db.Column(db.String(20)) @@ -22,7 +23,8 @@ class SubBase(Base): db.create_all() -def test_single_table_inheritance(db): +@pytest.mark.usefixtures("app_ctx") +def test_single_table_inheritance(app, db): class Base(db.Model): id = db.Column(db.Integer, primary_key=True) type = db.Column(db.String(20)) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 5cb1468d..e4183db2 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -1,5 +1,4 @@ -import random - +import pytest import sqlalchemy as sa from sqlalchemy.orm import sessionmaker @@ -10,7 +9,8 @@ def test_default_session_scoping(app, db): class FOOBar(db.Model): id = db.Column(db.Integer, primary_key=True) - db.create_all() + with app.app_context(): + db.create_all() with app.test_request_context(): fb = FOOBar() @@ -19,17 +19,32 @@ class FOOBar(db.Model): def test_session_scoping_changing(app): - db = SQLAlchemy(app, session_options={"scopefunc": random.random}) + count = 0 + + def scope(): + nonlocal count + count += 1 + return count + + db = SQLAlchemy(app, session_options={"scopefunc": scope}) class Example(db.Model): id = db.Column(db.Integer, primary_key=True) - db.create_all() - fb = Example() - db.session.add(fb) - assert fb not in db.session # because a new scope is generated on each call + with app.app_context(): + db.create_all() + fb = Example() + db.session.add(fb) + assert fb not in db.session # because a new scope is generated on each call + assert count == 2 + + for session in db.session.registry.registry.values(): + session.close() + + db.session.registry.registry.clear() +@pytest.mark.usefixtures("app_ctx") def test_insert_update_delete(db): # Ensure _SignalTrackingMapperExtension doesn't croak when # faced with a vanilla SQLAlchemy session. Verify that From 867aaea25b9e64762f20e5dd0fb35a92af3e7cac Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 12 Sep 2022 16:30:17 -0700 Subject: [PATCH 21/27] rename test files --- tests/{test_table_name.py => test_model_name.py} | 0 tests/{test_sessions.py => test_session.py} | 0 tests/{test_signals.py => test_track_modifications.py} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/{test_table_name.py => test_model_name.py} (100%) rename tests/{test_sessions.py => test_session.py} (100%) rename tests/{test_signals.py => test_track_modifications.py} (100%) diff --git a/tests/test_table_name.py b/tests/test_model_name.py similarity index 100% rename from tests/test_table_name.py rename to tests/test_model_name.py diff --git a/tests/test_sessions.py b/tests/test_session.py similarity index 100% rename from tests/test_sessions.py rename to tests/test_session.py diff --git a/tests/test_signals.py b/tests/test_track_modifications.py similarity index 100% rename from tests/test_signals.py rename to tests/test_track_modifications.py From 749602d17f18ba0da73f4158aff9b5d403776a6a Mon Sep 17 00:00:00 2001 From: David Lord Date: Mon, 12 Sep 2022 16:40:51 -0700 Subject: [PATCH 22/27] rewrite tests --- tests/conftest.py | 37 +++--- tests/test_basic_app.py | 85 -------------- tests/test_binds.py | 103 ----------------- tests/test_config.py | 106 ----------------- tests/test_engine.py | 106 +++++++++++++++++ tests/test_meta_data.py | 51 --------- tests/test_metadata.py | 102 +++++++++++++++++ tests/test_model.py | 49 ++++++++ tests/test_model_bind.py | 90 +++++++++++++++ tests/test_model_class.py | 65 ----------- tests/test_model_name.py | 110 +++++++++--------- tests/test_pagination.py | 184 +++++++++++++++++++++--------- tests/test_query.py | 98 ++++++++++++++++ tests/test_query_class.py | 65 ----------- tests/test_query_property.py | 48 -------- tests/test_record_queries.py | 59 ++++++++++ tests/test_regressions.py | 72 ------------ tests/test_session.py | 92 +++++++-------- tests/test_sqlalchemy_includes.py | 14 --- tests/test_table_bind.py | 39 +++++++ tests/test_track_modifications.py | 97 ++++++++-------- 21 files changed, 832 insertions(+), 840 deletions(-) delete mode 100644 tests/test_basic_app.py delete mode 100644 tests/test_binds.py delete mode 100644 tests/test_config.py create mode 100644 tests/test_engine.py delete mode 100644 tests/test_meta_data.py create mode 100644 tests/test_metadata.py create mode 100644 tests/test_model.py create mode 100644 tests/test_model_bind.py delete mode 100644 tests/test_model_class.py create mode 100644 tests/test_query.py delete mode 100644 tests/test_query_class.py delete mode 100644 tests/test_query_property.py create mode 100644 tests/test_record_queries.py delete mode 100644 tests/test_regressions.py delete mode 100644 tests/test_sqlalchemy_includes.py create mode 100644 tests/test_table_bind.py diff --git a/tests/conftest.py b/tests/conftest.py index e6bf8e6d..ff09ab27 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,45 +1,40 @@ -from datetime import datetime +from __future__ import annotations + +import typing as t +from pathlib import Path -import flask import pytest +import sqlalchemy as sa +from flask import Flask +from flask.ctx import AppContext from flask_sqlalchemy import SQLAlchemy @pytest.fixture -def app(request): - app = flask.Flask(request.module.__name__) - app.testing = True - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" +def app(request: pytest.FixtureRequest, tmp_path: Path) -> Flask: + app = Flask(request.module.__name__, instance_path=str(tmp_path / "instance")) + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" + app.config["SQLALCHEMY_RECORD_QUERIES"] = False return app @pytest.fixture -def app_ctx(app): +def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]: with app.app_context() as ctx: yield ctx @pytest.fixture -def db(app): +def db(app: Flask) -> SQLAlchemy: return SQLAlchemy(app) @pytest.fixture -def Todo(app, db): +def Todo(app: Flask, db: SQLAlchemy) -> t.Any: class Todo(db.Model): - __tablename__ = "todos" - id = db.Column("todo_id", db.Integer, primary_key=True) - title = db.Column(db.String(60)) - text = db.Column(db.String) - done = db.Column(db.Boolean) - pub_date = db.Column(db.DateTime) - - def __init__(self, title, text): - self.title = title - self.text = text - self.done = False - self.pub_date = datetime.utcnow() + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String) with app.app_context(): db.create_all() diff --git a/tests/test_basic_app.py b/tests/test_basic_app.py deleted file mode 100644 index 10c343e9..00000000 --- a/tests/test_basic_app.py +++ /dev/null @@ -1,85 +0,0 @@ -import flask - -from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.record_queries import get_recorded_queries - - -def test_basic_insert(app, db, Todo): - @app.route("/") - def index(): - return "\n".join(x.title for x in Todo.query.all()) - - @app.route("/add", methods=["POST"]) - def add(): - form = flask.request.form - todo = Todo(form["title"], form["text"]) - db.session.add(todo) - db.session.commit() - return "added" - - c = app.test_client() - c.post("/add", data=dict(title="First Item", text="The text")) - c.post("/add", data=dict(title="2nd Item", text="The text")) - rv = c.get("/") - assert rv.data == b"First Item\n2nd Item" - - -def test_query_recording(app, db, Todo): - with app.test_request_context(): - todo = Todo("Test 1", "test") - db.session.add(todo) - db.session.flush() - todo.done = True - db.session.commit() - - queries = get_recorded_queries() - assert len(queries) == 2 - - query = queries[0] - assert "insert into" in query.statement.lower() - assert query.parameters[0][0] == "Test 1" - assert query.parameters[0][1] == "test" - assert "test_basic_app.py" in query.location - assert "test_query_recording" in query.location - - query = queries[1] - assert "update" in query.statement.lower() - assert query.parameters[0][0] == 1 - assert query.parameters[0][1] == 1 - - -def test_helper_api(db): - assert db.metadata == db.Model.metadata - - -def test_persist_selectable(app, db, Todo, recwarn): - """In SA 1.3, mapper.mapped_table should be replaced with - mapper.persist_selectable. - """ - with app.test_request_context(): - todo = Todo("Test 1", "test") - db.session.add(todo) - db.session.commit() - - assert len(recwarn) == 0 - - -def test_sqlite_relative_path(app, tmp_path): - """If a SQLite URI has a relative path, it should be relative to the - instance path, and that directory should be created. - """ - app.instance_path = tmp_path / "instance" - - # tests default to memory, shouldn't create - SQLAlchemy(app) - assert not app.instance_path.exists() - - # absolute path, shouldn't create - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:////tmp/test.sqlite" - SQLAlchemy(app) - assert not app.instance_path.exists() - - # relative path, should create - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.sqlite" - SQLAlchemy(app) - assert app.instance_path.exists() diff --git a/tests/test_binds.py b/tests/test_binds.py deleted file mode 100644 index 308e3373..00000000 --- a/tests/test_binds.py +++ /dev/null @@ -1,103 +0,0 @@ -import pytest -import sqlalchemy as sa - -from flask_sqlalchemy import SQLAlchemy - - -@pytest.mark.usefixtures("app_ctx") -def test_basic_binds(app): - app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://", "bar": "sqlite://"} - db = SQLAlchemy(app) - - assert str(db.engine.url) == app.config["SQLALCHEMY_DATABASE_URI"] - - for key in "foo", "bar": - engine = db.engines[key] - assert str(engine.url) == app.config["SQLALCHEMY_BINDS"][key] - - class Foo(db.Model): - __bind_key__ = "foo" - __table_args__ = {"info": {"bind_key": "foo"}} - id = db.Column(db.Integer, primary_key=True) - - class Bar(db.Model): - __bind_key__ = "bar" - id = db.Column(db.Integer, primary_key=True) - - class Baz(db.Model): - id = db.Column(db.Integer, primary_key=True) - - db.create_all() - - # do the models have the correct engines? - assert "foo" in db.metadatas["foo"].tables - assert "bar" in db.metadatas["bar"].tables - assert "baz" in db.metadata.tables - - # see the tables created in an engine - metadata = sa.MetaData() - metadata.reflect(bind=db.engines["foo"]) - assert len(metadata.tables) == 1 - assert "foo" in metadata.tables - - metadata = sa.MetaData() - metadata.reflect(bind=db.engines["bar"]) - assert len(metadata.tables) == 1 - assert "bar" in metadata.tables - - metadata = sa.MetaData() - metadata.reflect(bind=db.engine) - assert len(metadata.tables) == 1 - assert "baz" in metadata.tables - - -@pytest.mark.usefixtures("app_ctx") -def test_abstract_binds(app): - app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://"} - db = SQLAlchemy(app) - - class AbstractFooBoundModel(db.Model): - __abstract__ = True - __bind_key__ = "foo" - - class FooBoundModel(AbstractFooBoundModel): - id = db.Column(db.Integer, primary_key=True) - - db.create_all() - - # does the model have the correct engine? - assert "foo_bound_model" in db.metadatas["foo"].tables - - # see the tables created in an engine - metadata = sa.MetaData() - metadata.reflect(bind=db.engines["foo"]) - assert len(metadata.tables) == 1 - assert "foo_bound_model" in metadata.tables - - -def test_polymorphic_bind(app): - bind_key = "polymorphic_bind_key" - app.config["SQLALCHEMY_BINDS"] = {bind_key: "sqlite:///:memory"} - db = SQLAlchemy(app) - - class Base(db.Model): - __bind_key__ = bind_key - __tablename__ = "base" - id = db.Column(db.Integer, primary_key=True) - p_type = db.Column(db.String(50)) - __mapper_args__ = {"polymorphic_identity": "base", "polymorphic_on": p_type} - - class Child1(Base): - child_1_data = db.Column(db.String(50)) - __mapper_args__ = {"polymorphic_identity": "child_1"} - - assert Base.metadata.info["bind_key"] == bind_key - assert Child1.metadata.info["bind_key"] == bind_key - - -@pytest.mark.usefixtures("app_ctx") -def test_execute_with_binds_arguments(app): - app.config["SQLALCHEMY_BINDS"] = {"foo": "sqlite://", "bar": "sqlite://"} - db = SQLAlchemy(app) - db.create_all() - db.session.execute("SELECT true", bind_arguments={"bind": db.engines["foo"]}) diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index 7b108bc0..00000000 --- a/tests/test_config.py +++ /dev/null @@ -1,106 +0,0 @@ -import os - -import pytest -import sqlalchemy as sa -import sqlalchemy.pool - -from flask_sqlalchemy import SQLAlchemy - - -@pytest.fixture -def app_nr(app): - """Signal/event registration with record queries breaks when - sqlalchemy.create_engine() is mocked out. - """ - app.config["SQLALCHEMY_RECORD_QUERIES"] = False - return app - - -@pytest.fixture -def nr_app_ctx(app_nr): - with app_nr.app_context() as ctx: - yield ctx - - -class TestConfigKeys: - def test_default_error_without_uri_or_binds(self, app, recwarn): - """ - Test that default configuration throws an error because - SQLALCHEMY_DATABASE_URI and SQLALCHEMY_BINDS are unset - """ - - SQLAlchemy(app) - - # Our pytest fixture for creating the app sets - # SQLALCHEMY_DATABASE_URI, so undo that here so that we - # can inspect what FSA does below: - del app.config["SQLALCHEMY_DATABASE_URI"] - - with pytest.raises(RuntimeError) as exc_info: - SQLAlchemy(app) - - expected = "Either 'SQLALCHEMY_DATABASE_URI' or 'SQLALCHEMY_BINDS' must be set." - assert exc_info.value.args[0] == expected - - def test_defaults_with_uri(self, app, recwarn): - """ - Test default config values when URI is provided, in the order they - appear in the documentation: https://flask-sqlalchemy.palletsprojects.com/config - - Our pytest fixture for creating the app sets SQLALCHEMY_DATABASE_URI - """ - - SQLAlchemy(app) - - # Expecting no warnings for default config with URI - assert len(recwarn) == 0 - - assert app.config["SQLALCHEMY_BINDS"] == {} - assert app.config["SQLALCHEMY_ECHO"] is False - assert app.config["SQLALCHEMY_RECORD_QUERIES"] is None - assert app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] is False - assert app.config["SQLALCHEMY_ENGINE_OPTIONS"] == {} - - @pytest.mark.usefixtures("app_ctx") - def test_engine_creation_ok(self, app): - """create_engine() isn't called until needed. Make sure we can - do that without errors or warnings. - """ - assert SQLAlchemy(app).engine - - -@pytest.mark.usefixtures("nr_app_ctx") -class TestCreateEngine: - """Tests for _EngineConnector and SQLAlchemy methods involved in - setting up the SQLAlchemy engine. - """ - - def test_engine_echo_default(self, app_nr): - db = SQLAlchemy(app_nr) - assert not db.engine.echo - assert not db.engine.pool.echo - - def test_engine_echo_true(self, app_nr): - app_nr.config["SQLALCHEMY_ECHO"] = True - db = SQLAlchemy(app_nr) - assert db.engine.echo - assert db.engine.pool.echo - - def test_config_from_engine_options(self, app_nr): - app_nr.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"echo": True} - assert SQLAlchemy(app_nr).engine.echo - - def test_config_from_init(self, app_nr): - db = SQLAlchemy(app_nr, engine_options={"echo": True}) - assert db.engine.echo - - def test_pool_class_default(self, app_nr): - db = SQLAlchemy(app_nr) - assert isinstance(db.engine.pool, sa.pool.StaticPool) - - -@pytest.mark.usefixtures("nr_app_ctx") -def test_sqlite_relative_to_instance_path(app): - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.db" - db = SQLAlchemy(app) - assert db.engine.url.database == os.path.join(app.instance_path, "test.db") diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 00000000..a01b8ae9 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import os.path +import unittest.mock + +import pytest +import sqlalchemy as sa +import sqlalchemy.pool +from flask import Flask + +from flask_sqlalchemy import SQLAlchemy + + +def test_default_engine(app: Flask, db: SQLAlchemy) -> None: + with app.app_context(): + assert db.engine is db.engines[None] + + with pytest.raises(RuntimeError): + assert db.engine + + +@pytest.mark.usefixtures("app_ctx") +def test_engine_per_bind(app: Flask) -> None: + app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} + db = SQLAlchemy(app) + assert db.engines["a"] is not db.engine + + +@pytest.mark.usefixtures("app_ctx") +def test_config_engine_options(app: Flask) -> None: + app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"echo": True} + db = SQLAlchemy(app) + assert db.engine.echo + + +@pytest.mark.usefixtures("app_ctx") +def test_init_engine_options(app: Flask) -> None: + app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"echo": False} + app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} + db = SQLAlchemy(app, engine_options={"echo": True}) + # init is default + assert db.engines["a"].echo + # config overrides init + assert not db.engine.echo + + +@pytest.mark.usefixtures("app_ctx") +def test_config_echo(app: Flask) -> None: + app.config["SQLALCHEMY_ECHO"] = True + db = SQLAlchemy(app) + assert db.engine.echo + assert db.engine.pool.echo + + +@pytest.mark.usefixtures("app_ctx") +@pytest.mark.parametrize( + "value", + [ + "sqlite://", + sa.engine.URL.create("sqlite"), + {"url": "sqlite://"}, + {"url": sa.engine.URL.create("sqlite")}, + ], +) +def test_url_type(app: Flask, value: str | sa.engine.URL) -> None: + app.config["SQLALCHEMY_BINDS"] = {"a": value} + db = SQLAlchemy(app) + assert str(db.engines["a"].url) == "sqlite://" + + +def test_no_default_url(app: Flask) -> None: + del app.config["SQLALCHEMY_DATABASE_URI"] + + with pytest.raises(RuntimeError) as info: + SQLAlchemy(app) + + e = "Either 'SQLALCHEMY_DATABASE_URI' or 'SQLALCHEMY_BINDS' must be set." + assert str(info.value) == e + + +@pytest.mark.usefixtures("app_ctx") +def test_sqlite_relative_path(app: Flask) -> None: + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.db" + db = SQLAlchemy(app) + db.create_all() + assert isinstance(db.engine.pool, sa.pool.NullPool) + db_path = db.engine.url.database + assert db_path.startswith(app.instance_path) # type: ignore[union-attr] + assert os.path.exists(db_path) # type: ignore[arg-type] + + +@unittest.mock.patch.object(SQLAlchemy, "_make_engine", autospec=True) +def test_sqlite_memory_defaults(make_engine: unittest.mock.Mock, app: Flask) -> None: + SQLAlchemy(app) + options = make_engine.call_args[0][2] + assert options["poolclass"] is sa.pool.StaticPool + assert options["connect_args"]["check_same_thread"] is False + + +@unittest.mock.patch.object(SQLAlchemy, "_make_engine", autospec=True) +def test_mysql_defaults(make_engine: unittest.mock.Mock, app: Flask) -> None: + app.config["SQLALCHEMY_DATABASE_URI"] = "mysql:///test" + SQLAlchemy(app) + options = make_engine.call_args[0][2] + assert options["pool_recycle"] == 7200 + assert options["url"].query["charset"] == "utf8mb4" diff --git a/tests/test_meta_data.py b/tests/test_meta_data.py deleted file mode 100644 index a32b3bbc..00000000 --- a/tests/test_meta_data.py +++ /dev/null @@ -1,51 +0,0 @@ -import sqlalchemy as sa - -from flask_sqlalchemy import SQLAlchemy - - -def test_default_metadata(app): - db = SQLAlchemy(app, metadata=None) - - class One(db.Model): - id = db.Column(db.Integer, primary_key=True) - myindex = db.Column(db.Integer, index=True) - - class Two(db.Model): - id = db.Column(db.Integer, primary_key=True) - one_id = db.Column(db.Integer, db.ForeignKey(One.id)) - myunique = db.Column(db.Integer, unique=True) - - assert One.metadata.__class__ is sa.MetaData - assert Two.metadata.__class__ is sa.MetaData - - assert One.__table__.schema is None - assert Two.__table__.schema is None - - -def test_custom_metadata(app): - class CustomMetaData(sa.MetaData): - pass - - custom_metadata = CustomMetaData(schema="test_schema") - db = SQLAlchemy(app, metadata=custom_metadata) - - class One(db.Model): - id = db.Column(db.Integer, primary_key=True) - myindex = db.Column(db.Integer, index=True) - - class Two(db.Model): - id = db.Column(db.Integer, primary_key=True) - one_id = db.Column(db.Integer, db.ForeignKey(One.id)) - myunique = db.Column(db.Integer, unique=True) - - assert One.metadata is custom_metadata - assert Two.metadata is custom_metadata - - assert One.metadata.__class__ is not sa.MetaData - assert One.metadata.__class__ is CustomMetaData - - assert Two.metadata.__class__ is not sa.MetaData - assert Two.metadata.__class__ is CustomMetaData - - assert One.__table__.schema == "test_schema" - assert Two.__table__.schema == "test_schema" diff --git a/tests/test_metadata.py b/tests/test_metadata.py new file mode 100644 index 00000000..1bb1e78d --- /dev/null +++ b/tests/test_metadata.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import pytest +import sqlalchemy as sa +import sqlalchemy.exc +from flask import Flask + +from flask_sqlalchemy import SQLAlchemy + + +def test_default_metadata(db: SQLAlchemy) -> None: + assert db.metadata is db.metadatas[None] + assert db.metadata.info["bind_key"] is None + + +def test_custom_metadata(app: Flask) -> None: + metadata = sa.MetaData() + db = SQLAlchemy(app, metadata=metadata) + assert db.metadata is metadata + assert db.metadata.info["bind_key"] is None + + +def test_metadata_per_bind(app: Flask) -> None: + app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} + db = SQLAlchemy(app) + assert db.metadatas["a"] is not db.metadata + assert db.metadatas["a"].info["bind_key"] == "a" + + +def test_copy_naming_convention(app: Flask) -> None: + app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} + db = SQLAlchemy( + app, metadata=sa.MetaData(naming_convention={"pk": "spk_%(table_name)s"}) + ) + assert db.metadata.naming_convention["pk"] == "spk_%(table_name)s" + assert db.metadatas["a"].naming_convention == db.metadata.naming_convention + + +@pytest.mark.usefixtures("app_ctx") +def test_create_drop_all(app: Flask) -> None: + app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} + db = SQLAlchemy(app) + + class User(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + + class Post(db.Model): + __bind_key__ = "a" + id = sa.Column(sa.Integer, primary_key=True) + + with pytest.raises(sa.exc.OperationalError): + User.query.all() + + with pytest.raises(sa.exc.OperationalError): + Post.query.all() + + db.create_all() + User.query.all() + Post.query.all() + db.drop_all() + + with pytest.raises(sa.exc.OperationalError): + User.query.all() + + with pytest.raises(sa.exc.OperationalError): + Post.query.all() + + +@pytest.mark.usefixtures("app_ctx") +@pytest.mark.parametrize("bind_key", ["a", ["a"]]) +def test_create_key_spec(app: Flask, bind_key: str | list[str | None]) -> None: + app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} + db = SQLAlchemy(app) + + class User(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + + class Post(db.Model): + __bind_key__ = "a" + id = sa.Column(sa.Integer, primary_key=True) + + db.create_all(bind_key=bind_key) + Post.query.all() + + with pytest.raises(sa.exc.OperationalError): + User.query.all() + + +@pytest.mark.usefixtures("app_ctx") +def test_reflect(app: Flask) -> None: + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///user.db" + app.config["SQLALCHEMY_BINDS"] = {"post": "sqlite:///post.db"} + db = SQLAlchemy(app) + db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) + db.Table("post", sa.Column("id", sa.Integer, primary_key=True), bind_key="post") + db.create_all() + + db = SQLAlchemy(app) + assert not db.metadata.tables + db.reflect() + assert "user" in db.metadata.tables + assert "post" in db.metadatas["post"].tables diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 00000000..fe76df8c --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import pytest +import sqlalchemy as sa +import sqlalchemy.orm +from flask import Flask + +from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.model import DefaultMeta +from flask_sqlalchemy.model import Model + + +def test_default_model_class(db: SQLAlchemy) -> None: + assert db.Model.query_class is db.Query + assert db.Model.metadata is db.metadata + assert issubclass(db.Model, Model) + assert isinstance(db.Model, DefaultMeta) + + +def test_custom_model_class(app: Flask) -> None: + class CustomModel(Model): + pass + + db = SQLAlchemy(app, model_class=CustomModel) + assert issubclass(db.Model, CustomModel) + assert isinstance(db.Model, DefaultMeta) + + +def test_custom_declarative_class(app: Flask) -> None: + class CustomMeta(DefaultMeta): + pass + + CustomModel = sa.orm.declarative_base(cls=Model, name="Model", metaclass=CustomMeta) + db = SQLAlchemy(app, model_class=CustomModel) + assert db.Model is CustomModel + + +@pytest.mark.usefixtures("app_ctx") +def test_model_repr(db: SQLAlchemy) -> None: + class User(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + + db.create_all() + user = User() + assert repr(user) == f"" + db.session.add(user) + assert repr(user) == f"" + db.session.flush() + assert repr(user) == f"" diff --git a/tests/test_model_bind.py b/tests/test_model_bind.py new file mode 100644 index 00000000..7c633c83 --- /dev/null +++ b/tests/test_model_bind.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import sqlalchemy as sa + +from flask_sqlalchemy import SQLAlchemy + + +def test_bind_key_default(db: SQLAlchemy) -> None: + class User(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + + assert User.metadata is db.metadata + + +def test_metadata_per_bind(db: SQLAlchemy) -> None: + class User(db.Model): + __bind_key__ = "other" + id = sa.Column(sa.Integer, primary_key=True) + + assert User.metadata is db.metadatas["other"] + + +def test_multiple_binds_same_table_name(db: SQLAlchemy) -> None: + class UserA(db.Model): + __tablename__ = "user" + id = sa.Column(sa.Integer, primary_key=True) + + class UserB(db.Model): + __bind_key__ = "other" + __tablename__ = "user" + id = sa.Column(sa.Integer, primary_key=True) + + assert UserA.metadata is db.metadata + assert UserB.metadata is db.metadatas["other"] + assert UserA.__table__.metadata is not UserB.__table__.metadata + + +def test_inherit_parent(db: SQLAlchemy) -> None: + class User(db.Model): + __bind_key__ = "auth" + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(sa.String) + __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"} + + class Admin(User): + id = sa.Column(sa.Integer, sa.ForeignKey(User.id), primary_key=True) + __mapper_args__ = {"polymorphic_identity": "admin"} + + assert "admin" in db.metadatas["auth"].tables + # inherits metadata, doesn't set it directly + assert "metadata" not in Admin.__dict__ + + +def test_inherit_abstract_parent(db: SQLAlchemy) -> None: + class AbstractUser(db.Model): + __abstract__ = True + __bind_key__ = "auth" + + class User(AbstractUser): + id = sa.Column(sa.Integer, primary_key=True) + + assert "user" in db.metadatas["auth"].tables + assert "metadata" not in User.__dict__ + + +def test_explicit_metadata(db: SQLAlchemy) -> None: + other_metadata = sa.MetaData() + + class User(db.Model): + __bind_key__ = "other" + metadata = other_metadata + id = sa.Column(sa.Integer, primary_key=True) + + assert User.__table__.metadata is other_metadata + assert "other" not in db.metadatas + + +def test_explicit_table(db: SQLAlchemy) -> None: + user_table = db.Table( + "user", + sa.Column("id", sa.Integer, primary_key=True), + bind_key="auth", + ) + + class User(db.Model): + __bind_key__ = "other" + __table__ = user_table + + assert User.__table__.metadata is db.metadatas["auth"] + assert "other" not in db.metadatas diff --git a/tests/test_model_class.py b/tests/test_model_class.py deleted file mode 100644 index 166d04a9..00000000 --- a/tests/test_model_class.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest -from sqlalchemy.exc import InvalidRequestError -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.ext.declarative import DeclarativeMeta - -from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.model import BindMetaMixin -from flask_sqlalchemy.model import Model - - -def test_custom_model_class(): - class CustomModelClass(Model): - pass - - db = SQLAlchemy(model_class=CustomModelClass) - - class SomeModel(db.Model): - id = db.Column(db.Integer, primary_key=True) - - assert isinstance(SomeModel(), CustomModelClass) - - -def test_no_table_name(): - class NoNameMeta(BindMetaMixin, DeclarativeMeta): - pass - - db = SQLAlchemy( - model_class=declarative_base(cls=Model, metaclass=NoNameMeta, name="Model") - ) - - with pytest.raises(InvalidRequestError): - - class User(db.Model): - pass - - -@pytest.mark.usefixtures("app_ctx") -def test_repr(db): - class User(db.Model): - name = db.Column(db.String, primary_key=True) - - class Report(db.Model): - id = db.Column(db.Integer, primary_key=True, autoincrement=False) - user_name = db.Column(db.ForeignKey(User.name), primary_key=True) - - db.create_all() - - u = User(name="test") - assert repr(u).startswith("" - assert repr(u) == str(u) - - u2 = User(name="🐍") - db.session.add(u2) - db.session.flush() - assert repr(u2) == "" - assert repr(u2) == str(u2) - - r = Report(id=2, user_name=u.name) - db.session.add(r) - db.session.flush() - assert repr(r) == "" - assert repr(u) == str(u) diff --git a/tests/test_model_name.py b/tests/test_model_name.py index a970094f..8030a6e3 100644 --- a/tests/test_model_name.py +++ b/tests/test_model_name.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import inspect +import typing as t import pytest -from sqlalchemy.exc import ArgumentError -from sqlalchemy.ext.declarative import declared_attr +import sqlalchemy as sa +import sqlalchemy.exc +import sqlalchemy.orm +from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.model import camel_to_snake_case @@ -39,31 +44,31 @@ # ("__test__Method", "test___method"), ], ) -def test_camel_to_snake_case(name, expect): +def test_camel_to_snake_case(name: str, expect: str) -> None: assert camel_to_snake_case(name) == expect -def test_name(db): +def test_name(db: SQLAlchemy) -> None: class FOOBar(db.Model): - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class BazBar(db.Model): - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class Ham(db.Model): __tablename__ = "spam" - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) assert FOOBar.__tablename__ == "foo_bar" assert BazBar.__tablename__ == "baz_bar" assert Ham.__tablename__ == "spam" -def test_single_name(db): +def test_single_name(db: SQLAlchemy) -> None: """Single table inheritance should not set a new name.""" class Duck(db.Model): - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class Mallard(Duck): pass @@ -72,25 +77,25 @@ class Mallard(Duck): assert Mallard.__tablename__ == "duck" -def test_joined_name(db): +def test_joined_name(db: SQLAlchemy) -> None: """Model has a separate primary key; it should set a new name.""" class Duck(db.Model): - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class Donald(Duck): - id = db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True) + id = sa.Column(sa.Integer, sa.ForeignKey(Duck.id), primary_key=True) assert Donald.__tablename__ == "donald" -def test_mixin_id(db): +def test_mixin_id(db: SQLAlchemy) -> None: """Primary key provided by mixin should still allow model to set tablename. """ class Base: - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class Duck(Base, db.Model): pass @@ -99,38 +104,38 @@ class Duck(Base, db.Model): assert Duck.__tablename__ == "duck" -def test_mixin_attr(db): +def test_mixin_attr(db: SQLAlchemy) -> None: """A declared attr tablename will be used down multiple levels of inheritance. """ class Mixin: - @declared_attr - def __tablename__(cls): # noqa: B902 - return cls.__name__.upper() + @sa.orm.declared_attr + def __tablename__(cls) -> str: # noqa: B902 + return cls.__name__.upper() # type: ignore[attr-defined,no-any-return] class Bird(Mixin, db.Model): - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class Duck(Bird): # object reference - id = db.Column(db.ForeignKey(Bird.id), primary_key=True) + id = sa.Column(sa.Integer, sa.ForeignKey(Bird.id), primary_key=True) class Mallard(Duck): # string reference - id = db.Column(db.ForeignKey("DUCK.id"), primary_key=True) + id = sa.Column(sa.Integer, sa.ForeignKey("DUCK.id"), primary_key=True) assert Bird.__tablename__ == "BIRD" assert Duck.__tablename__ == "DUCK" assert Mallard.__tablename__ == "MALLARD" -def test_abstract_name(db): +def test_abstract_name(db: SQLAlchemy) -> None: """Abstract model should not set a name. Subclass should set a name.""" class Base(db.Model): __abstract__ = True - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class Duck(Base): pass @@ -139,88 +144,85 @@ class Duck(Base): assert Duck.__tablename__ == "duck" -def test_complex_inheritance(db): +def test_complex_inheritance(db: SQLAlchemy) -> None: """Joined table inheritance, but the new primary key is provided by a mixin, not directly on the class. """ class Duck(db.Model): - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class IdMixin: - @declared_attr - def id(cls): # noqa: B902 - return db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True) + @sa.orm.declared_attr + def id(cls) -> sa.Column[sa.Integer]: # noqa: B902 + return sa.Column(sa.Integer, sa.ForeignKey(Duck.id), primary_key=True) - class RubberDuck(IdMixin, Duck): + class RubberDuck(IdMixin, Duck): # type: ignore[misc] pass assert RubberDuck.__tablename__ == "rubber_duck" -@pytest.mark.usefixtures("app_ctx") -def test_manual_name(db): +def test_manual_name(db: SQLAlchemy) -> None: """Setting a manual name prevents generation for the immediate model. A name is generated for joined but not single-table inheritance. """ class Duck(db.Model): __tablename__ = "DUCK" - id = db.Column(db.Integer, primary_key=True) - type = db.Column(db.String) + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(sa.String) __mapper_args__ = {"polymorphic_on": type} class Daffy(Duck): - id = db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True) + id = sa.Column(sa.Integer, sa.ForeignKey(Duck.id), primary_key=True) - __mapper_args__ = {"polymorphic_identity": "Warner"} + __mapper_args__ = {"polymorphic_identity": "Tower"} # type: ignore[dict-item] class Donald(Duck): - __mapper_args__ = {"polymorphic_identity": "Disney"} + __mapper_args__ = {"polymorphic_identity": "Mouse"} # type: ignore[dict-item] assert Duck.__tablename__ == "DUCK" assert Daffy.__tablename__ == "daffy" assert "__tablename__" not in Donald.__dict__ assert Donald.__tablename__ == "DUCK" - # polymorphic condition for single-table query - assert 'WHERE "DUCK".type' in str(Donald.query) -def test_primary_constraint(db): +def test_primary_constraint(db: SQLAlchemy) -> None: """Primary key will be picked up from table args.""" class Duck(db.Model): - id = db.Column(db.Integer) + id = sa.Column(sa.Integer) - __table_args__ = (db.PrimaryKeyConstraint(id),) + __table_args__ = (sa.PrimaryKeyConstraint(id),) assert Duck.__table__ is not None assert Duck.__tablename__ == "duck" -def test_no_access_to_class_property(db): +def test_no_access_to_class_property(db: SQLAlchemy) -> None: """Ensure the implementation doesn't access class properties or declared attrs while inspecting the unmapped model. """ class class_property: - def __init__(self, f): + def __init__(self, f: t.Callable[..., t.Any]) -> None: self.f = f - def __get__(self, instance, owner): + def __get__(self, instance: t.Any, owner: t.Type[t.Any]) -> t.Any: return self.f(owner) class Duck(db.Model): - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class ns: is_duck = False floats = False class Witch(Duck): - @declared_attr - def is_duck(self): + @sa.orm.declared_attr + def is_duck(self) -> None: # declared attrs will be accessed during mapper configuration, # but make sure they're not accessed before that info = inspect.getouterframes(inspect.currentframe())[2] @@ -228,15 +230,15 @@ def is_duck(self): ns.is_duck = True @class_property - def floats(self): + def floats(self) -> None: ns.floats = True assert ns.is_duck assert not ns.floats -def test_metadata_has_table(db): - user = db.Table("user", db.Column("id", db.Integer, primary_key=True)) +def test_metadata_has_table(db: SQLAlchemy) -> None: + user = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) class User(db.Model): pass @@ -244,8 +246,8 @@ class User(db.Model): assert User.__table__ is user -def test_correct_error_for_no_primary_key(db): - with pytest.raises(ArgumentError) as info: +def test_correct_error_for_no_primary_key(db: SQLAlchemy) -> None: + with pytest.raises(sa.exc.ArgumentError) as info: class User(db.Model): pass @@ -253,9 +255,9 @@ class User(db.Model): assert "could not assemble any primary key" in str(info.value) -def test_single_has_parent_table(db): +def test_single_has_parent_table(db: SQLAlchemy) -> None: class Duck(db.Model): - id = db.Column(db.Integer, primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) class Call(Duck): pass diff --git a/tests/test_pagination.py b/tests/test_pagination.py index c962a4e6..f425eccb 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,84 +1,162 @@ +from __future__ import annotations + +import typing as t + import pytest +from flask import Flask from werkzeug.exceptions import NotFound +from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.pagination import Pagination -def test_basic_pagination(): - p = Pagination(query=None, page=1, per_page=20, total=500, items=[]) +def _make_page( + *, page: int = 1, per_page: int = 10, total: int | None = 150 +) -> Pagination: + return Pagination(query=None, page=page, per_page=per_page, total=total, items=[]) + + +def test_first_page() -> None: + p = _make_page() assert p.page == 1 + assert p.per_page == 10 + assert p.total == 150 + assert p.pages == 15 assert not p.has_prev + assert p.prev_num is None assert p.has_next - assert p.total == 500 - assert p.pages == 25 assert p.next_num == 2 - assert list(p.iter_pages()) == [1, 2, 3, 4, 5, None, 24, 25] - p.page = 10 - assert list(p.iter_pages()) == [1, 2, None, 8, 9, 10, 11, 12, 13, 14, None, 24, 25] - - -def test_pagination_pages_when_0_items_per_page(): - p = Pagination(query=None, page=1, per_page=0, total=500, items=[]) - assert p.pages == 0 -def test_pagination_pages_when_total_is_none(): - p = Pagination(query=None, page=1, per_page=20, total=None, items=[]) +def test_last_page() -> None: + p = _make_page(page=15) + assert p.page == 15 + assert p.has_prev + assert p.prev_num == 14 + assert not p.has_next + assert p.next_num is None + + +@pytest.mark.parametrize( + ("per_page", "total"), + [ + (0, 150), + (10, 0), + (10, None), + ], +) +def test_0_pages(per_page: int, total: int | None) -> None: + p = _make_page(per_page=per_page, total=total) assert p.pages == 0 - - -def test_query_paginate(app, db, Todo): + assert not p.has_prev + assert not p.has_next + + +@pytest.mark.parametrize( + ("page", "expect"), + [ + (1, [1, 2, 3, 4, 5, None, 14, 15]), + (2, [1, 2, 3, 4, 5, 6, None, 14, 15]), + (3, [1, 2, 3, 4, 5, 6, 7, None, 14, 15]), + (4, [1, 2, 3, 4, 5, 6, 7, 8, None, 14, 15]), + (5, [1, 2, 3, 4, 5, 6, 7, 8, 9, None, 14, 15]), + (6, [1, 2, None, 4, 5, 6, 7, 8, 9, 10, None, 14, 15]), + (7, [1, 2, None, 5, 6, 7, 8, 9, 10, 11, None, 14, 15]), + (8, [1, 2, None, 6, 7, 8, 9, 10, 11, 12, None, 14, 15]), + (9, [1, 2, None, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + (10, [1, 2, None, 8, 9, 10, 11, 12, 13, 14, 15]), + (11, [1, 2, None, 9, 10, 11, 12, 13, 14, 15]), + (12, [1, 2, None, 10, 11, 12, 13, 14, 15]), + (13, [1, 2, None, 11, 12, 13, 14, 15]), + (14, [1, 2, None, 12, 13, 14, 15]), + (15, [1, 2, None, 13, 14, 15]), + ], +) +def test_iter_pages(page: int, expect: list[int | None]) -> None: + p = _make_page(page=page) + assert list(p.iter_pages()) == expect + + +def test_iter_0_pages() -> None: + p = _make_page(total=0) + assert list(p.iter_pages()) == [] + + +@pytest.mark.parametrize("page", [1, 2, 3, 4]) +def test_iter_pages_short(page: int) -> None: + p = _make_page(page=page, total=40) + assert list(p.iter_pages()) == [1, 2, 3, 4] + + +class _PaginateCallable: + def __init__(self, app: Flask, Todo: t.Any) -> None: + self.app = app + self.Todo = Todo + + def __call__( + self, + page: int | None = None, + per_page: int | None = None, + max_per_page: int | None = None, + error_out: bool = True, + count: bool = True, + ) -> Pagination: + with self.app.test_request_context( + query_string={"page": page, "per_page": per_page} + ): + return self.Todo.query.paginate( # type: ignore[no-any-return] + max_per_page=max_per_page, error_out=error_out, count=count + ) + + +@pytest.fixture +def paginate(app: Flask, db: SQLAlchemy, Todo: t.Any) -> _PaginateCallable: with app.app_context(): - db.session.add_all([Todo("", "") for _ in range(100)]) + for i in range(1, 101): + db.session.add(Todo(title=f"task {i}")) + db.session.commit() - @app.route("/") - def index(): - p = Todo.query.paginate() - return f"{len(p.items)} items retrieved" + return _PaginateCallable(app, Todo) - c = app.test_client() - # request default - r = c.get("/") - assert r.status_code == 200 - # request args - r = c.get("/?per_page=10") - assert r.data.decode("utf8") == "10 items retrieved" - with app.app_context(): - # query default - p = Todo.query.paginate() - assert p.total == 100 +def test_paginate(paginate: _PaginateCallable) -> None: + p = paginate() + assert p.page == 1 + assert p.per_page == 20 + assert len(p.items) == 20 + assert p.total == 100 + assert p.pages == 5 -@pytest.mark.usefixtures("app_ctx") -def test_query_paginate_more_than_20(app, db, Todo): - db.session.add_all(Todo("", "") for _ in range(20)) - db.session.commit() +def test_paginate_qs(paginate: _PaginateCallable) -> None: + p = paginate(page=2, per_page=10) + assert p.page == 2 + assert p.per_page == 10 - assert len(Todo.query.paginate(max_per_page=10).items) == 10 +def test_paginate_max(paginate: _PaginateCallable) -> None: + p = paginate(per_page=100, max_per_page=50) + assert p.per_page == 50 -@pytest.mark.usefixtures("app_ctx") -def test_paginate_min(app, db, Todo): - db.session.add_all(Todo(str(x), "") for x in range(20)) - db.session.commit() - assert Todo.query.paginate(error_out=False, page=-1).items[0].title == "0" - assert len(Todo.query.paginate(error_out=False, per_page=0).items) == 0 - assert len(Todo.query.paginate(error_out=False, per_page=-1).items) == 20 +def test_no_count(paginate: _PaginateCallable) -> None: + p = paginate(count=False) + assert p.total is None - with pytest.raises(NotFound): - Todo.query.paginate(page=0) +@pytest.mark.parametrize( + ("page", "per_page"), [("abc", None), (None, "abc"), (0, None), (None, -1)] +) +def test_error_out(paginate: _PaginateCallable, page: t.Any, per_page: t.Any) -> None: with pytest.raises(NotFound): - Todo.query.paginate(per_page=-1) + paginate(page=page, per_page=per_page) @pytest.mark.usefixtures("app_ctx") -def test_paginate_without_count(app, db, Todo): - with app.app_context(): - db.session.add_all(Todo("", "") for _ in range(20)) - db.session.commit() +def test_no_items_404(Todo: t.Any) -> None: + p = Todo.query.paginate() + assert len(p.items) == 0 - assert len(Todo.query.paginate(count=False, page=1, per_page=10).items) == 10 + with pytest.raises(NotFound): + Todo.query.paginate(page=2) diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 00000000..cfc79a8f --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import typing as t + +import pytest +import sqlalchemy as sa +from flask import Flask +from werkzeug.exceptions import NotFound + +from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.query import Query + + +@pytest.mark.usefixtures("app_ctx") +def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: + item = Todo() + db.session.add(item) + db.session.commit() + assert Todo.query.get_or_404(1) is item + + with pytest.raises(NotFound): + Todo.query.get_or_404(2) + + +@pytest.mark.usefixtures("app_ctx") +def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: + db.session.add(Todo(title="a")) + db.session.commit() + assert Todo.query.filter_by(title="a").first_or_404().title == "a" + + with pytest.raises(NotFound): + Todo.query.filter_by(title="b").first_or_404() + + +@pytest.mark.usefixtures("app_ctx") +def test_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: + db.session.add(Todo(title="a")) + db.session.add(Todo(title="b")) + db.session.add(Todo(title="b")) + db.session.commit() + assert Todo.query.filter_by(title="a").one_or_404().title == "a" + + with pytest.raises(NotFound): + # MultipleResultsFound + Todo.query.filter_by(title="b").one_or_404() + + with pytest.raises(NotFound): + # NoResultFound + Todo.query.filter_by(title="c").one_or_404() + + +@pytest.mark.usefixtures("app_ctx") +def test_default_query_class(db: SQLAlchemy) -> None: + class Parent(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + children1 = db.relationship("Child", backref="parent1", lazy="dynamic") + + class Child(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + parent_id = sa.Column(sa.ForeignKey(Parent.id)) + parent2 = db.relationship( + Parent, + backref=db.backref("children2", lazy="dynamic", viewonly=True), + viewonly=True, + ) + + p = Parent() + assert type(Parent.query) is Query + assert isinstance(p.children1, Query) + assert isinstance(p.children2, Query) + assert isinstance(db.session.query(Child), Query) + + +@pytest.mark.usefixtures("app_ctx") +def test_custom_query_class(app: Flask) -> None: + class CustomQuery(Query): + pass + + db = SQLAlchemy(app, query_class=CustomQuery) + + class Parent(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + children1 = db.relationship("Child", backref="parent1", lazy="dynamic") + + class Child(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + parent_id = sa.Column(sa.ForeignKey(Parent.id)) + parent2 = db.relationship( + Parent, + backref=db.backref("children2", lazy="dynamic", viewonly=True), + viewonly=True, + ) + + p = Parent() + assert type(Parent.query) is CustomQuery + assert isinstance(p.children1, CustomQuery) + assert isinstance(p.children2, CustomQuery) + assert isinstance(db.session.query(Child), CustomQuery) diff --git a/tests/test_query_class.py b/tests/test_query_class.py deleted file mode 100644 index c598ea5f..00000000 --- a/tests/test_query_class.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest - -from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.query import Query - - -@pytest.mark.usefixtures("app_ctx") -def test_default_query_class(db): - class Parent(db.Model): - id = db.Column(db.Integer, primary_key=True) - children = db.relationship("Child", backref="parent", lazy="dynamic") - - class Child(db.Model): - id = db.Column(db.Integer, primary_key=True) - parent_id = db.Column(db.Integer, db.ForeignKey("parent.id")) - - p = Parent() - c = Child() - c.parent = p - - assert type(Parent.query) == Query - assert type(Child.query) == Query - assert isinstance(p.children, Query) - assert isinstance(db.session.query(Parent), Query) - - -@pytest.mark.usefixtures("app_ctx") -def test_custom_query_class(app): - class CustomQueryClass(Query): - pass - - db = SQLAlchemy(app, query_class=CustomQueryClass) - - class Parent(db.Model): - id = db.Column(db.Integer, primary_key=True) - children = db.relationship("Child", backref="parent", lazy="dynamic") - - class Child(db.Model): - id = db.Column(db.Integer, primary_key=True) - parent_id = db.Column(db.Integer, db.ForeignKey("parent.id")) - - p = Parent() - c = Child() - c.parent = p - - assert type(Parent.query) == CustomQueryClass - assert type(Child.query) == CustomQueryClass - assert isinstance(p.children, CustomQueryClass) - assert db.Query == CustomQueryClass - assert db.Model.query_class == CustomQueryClass - assert isinstance(db.session.query(Parent), CustomQueryClass) - - -@pytest.mark.usefixtures("app_ctx") -def test_dont_override_model_default(app): - class CustomQueryClass(Query): - pass - - db = SQLAlchemy(app, query_class=CustomQueryClass) - - class SomeModel(db.Model): - id = db.Column(db.Integer, primary_key=True) - query_class = Query - - assert type(SomeModel.query) == Query diff --git a/tests/test_query_property.py b/tests/test_query_property.py deleted file mode 100644 index 8c12ec87..00000000 --- a/tests/test_query_property.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -from werkzeug.exceptions import NotFound - -from flask_sqlalchemy import SQLAlchemy - - -def test_app_ctx_required(app): - db = SQLAlchemy() - db.init_app(app) - - class Foo(db.Model): - id = db.Column(db.Integer, primary_key=True) - - with pytest.raises(RuntimeError): - assert Foo.query - - with app.test_request_context(): - db.create_all() - foo = Foo() - db.session.add(foo) - db.session.commit() - assert len(Foo.query.all()) == 1 - - -@pytest.mark.usefixtures("app_ctx") -def test_get_or_404(Todo): - with pytest.raises(NotFound): - Todo.query.get_or_404(1) - - expected = "Expected message" - - with pytest.raises(NotFound) as e_info: - Todo.query.get_or_404(1, description=expected) - - assert e_info.value.description == expected - - -@pytest.mark.usefixtures("app_ctx") -def test_first_or_404(Todo): - with pytest.raises(NotFound): - Todo.query.first_or_404() - - expected = "Expected message" - - with pytest.raises(NotFound) as e_info: - Todo.query.first_or_404(description=expected) - - assert e_info.value.description == expected diff --git a/tests/test_record_queries.py b/tests/test_record_queries.py new file mode 100644 index 00000000..e1e48abb --- /dev/null +++ b/tests/test_record_queries.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import pytest +import sqlalchemy as sa +from flask import Flask + +from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.record_queries import get_recorded_queries + + +@pytest.mark.usefixtures("app_ctx") +@pytest.mark.parametrize( + ("record", "debug", "testing", "expect"), + [ + (None, False, False, False), + (False, True, True, False), + (None, True, False, True), + (None, False, True, True), + (True, False, False, True), + ], +) +def test_record_enabled( + app: Flask, + record: bool | None, + debug: bool, + testing: bool, + expect: bool, +) -> None: + app.config["SQLALCHEMY_RECORD_QUERIES"] = record + app.debug = debug + app.testing = testing + db = SQLAlchemy(app) + + class Example(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + + db.create_all() + Example.query.all() + assert bool(get_recorded_queries()) is expect + + +@pytest.mark.usefixtures("app_ctx") +def test_query_info(app: Flask) -> None: + app.config["SQLALCHEMY_RECORD_QUERIES"] = True + db = SQLAlchemy(app) + + class Example(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + + db.create_all() + Example.query.filter(Example.id < 5).all() + info = get_recorded_queries()[-1] + assert info.statement is not None + assert "SELECT" in info.statement + assert "FROM example" in info.statement + assert info.parameters[0][0] == 5 + assert info.duration == info.end_time - info.start_time + assert "tests/test_record_queries.py:" in info.location + assert "(test_query_info)" in info.location diff --git a/tests/test_regressions.py b/tests/test_regressions.py deleted file mode 100644 index 14bc00c9..00000000 --- a/tests/test_regressions.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest - - -@pytest.fixture -def db(app, db): - app.testing = False - return db - - -@pytest.mark.usefixtures("app_ctx") -def test_joined_inheritance(app, db): - class Base(db.Model): - id = db.Column(db.Integer, primary_key=True) - type = db.Column(db.String(20)) - __mapper_args__ = {"polymorphic_on": type} - - class SubBase(Base): - id = db.Column(db.Integer, db.ForeignKey("base.id"), primary_key=True) - __mapper_args__ = {"polymorphic_identity": "sub"} - - assert Base.__tablename__ == "base" - assert SubBase.__tablename__ == "sub_base" - db.create_all() - - -@pytest.mark.usefixtures("app_ctx") -def test_single_table_inheritance(app, db): - class Base(db.Model): - id = db.Column(db.Integer, primary_key=True) - type = db.Column(db.String(20)) - __mapper_args__ = {"polymorphic_on": type} - - class SubBase(Base): - __mapper_args__ = {"polymorphic_identity": "sub"} - - assert Base.__tablename__ == "base" - assert SubBase.__tablename__ == "base" - db.create_all() - - -@pytest.mark.usefixtures("app_ctx") -def test_joined_inheritance_relation(db): - class Relation(db.Model): - id = db.Column(db.Integer, primary_key=True) - base_id = db.Column(db.Integer, db.ForeignKey("base.id")) - name = db.Column(db.String(20)) - - def __init__(self, name): - self.name = name - - class Base(db.Model): - id = db.Column(db.Integer, primary_key=True) - type = db.Column(db.String(20)) - __mapper_args__ = {"polymorphic_on": type} - - class SubBase(Base): - id = db.Column(db.Integer, db.ForeignKey("base.id"), primary_key=True) - __mapper_args__ = {"polymorphic_identity": "sub"} - relations = db.relationship(Relation) - - db.create_all() - - base = SubBase() - base.relations = [Relation(name="foo")] - db.session.add(base) - db.session.commit() - base.query.one() - - -@pytest.mark.usefixtures("app_ctx") -def test_connection_binds(db): - assert db.session.connection() diff --git a/tests/test_session.py b/tests/test_session.py index e4183db2..ae517add 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,76 +1,66 @@ +from __future__ import annotations + import pytest import sqlalchemy as sa -from sqlalchemy.orm import sessionmaker +from flask import Flask from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.session import Session -def test_default_session_scoping(app, db): - class FOOBar(db.Model): - id = db.Column(db.Integer, primary_key=True) +def test_scope(app: Flask, db: SQLAlchemy) -> None: + with pytest.raises(RuntimeError): + db.session() with app.app_context(): - db.create_all() + first = db.session() + second = db.session() + assert first is second + assert isinstance(first, Session) - with app.test_request_context(): - fb = FOOBar() - db.session.add(fb) - assert fb in db.session + with app.app_context(): + third = db.session() + assert first is not third -def test_session_scoping_changing(app): +def test_custom_scope(app: Flask) -> None: count = 0 - def scope(): + def scope() -> int: nonlocal count count += 1 return count db = SQLAlchemy(app, session_options={"scopefunc": scope}) - class Example(db.Model): - id = db.Column(db.Integer, primary_key=True) - with app.app_context(): - db.create_all() - fb = Example() - db.session.add(fb) - assert fb not in db.session # because a new scope is generated on each call - assert count == 2 + first = db.session() + second = db.session() + assert first is not second # a new scope is generated on each call + first.close() + second.close() + - for session in db.session.registry.registry.values(): - session.close() +@pytest.mark.usefixtures("app_ctx") +def test_session_class(app: Flask) -> None: + class CustomSession(Session): + pass - db.session.registry.registry.clear() + db = SQLAlchemy(app, session_options={"class_": CustomSession}) + assert isinstance(db.session(), CustomSession) @pytest.mark.usefixtures("app_ctx") -def test_insert_update_delete(db): - # Ensure _SignalTrackingMapperExtension doesn't croak when - # faced with a vanilla SQLAlchemy session. Verify that - # "AttributeError: 'SessionMaker' object has no attribute - # '_model_changes'" is not thrown. - Session = sessionmaker(bind=db.engine) - - class QazWsx(db.Model): - id = db.Column(db.Integer, primary_key=True) - x = db.Column(db.String, default="") - - db.create_all() - session = Session() - session.add(QazWsx()) - session.flush() # issues an INSERT. - session.expunge_all() - qaz_wsx = session.query(QazWsx).first() - assert qaz_wsx.x == "" - qaz_wsx.x = "test" - session.flush() # issues an UPDATE. - session.expunge_all() - qaz_wsx = session.query(QazWsx).first() - assert qaz_wsx.x == "test" - session.delete(qaz_wsx) # issues a DELETE. - assert session.query(QazWsx).first() is None - - -def test_listen_to_session_event(db): - sa.event.listen(db.session, "after_commit", lambda session: None) +def test_session_uses_bind_key(app: Flask) -> None: + app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} + db = SQLAlchemy(app) + + class User(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + + class Post(db.Model): + __bind_key__ = "a" + id = sa.Column(sa.Integer, primary_key=True) + + assert db.session.get_bind(mapper=User) is db.engine + assert db.session.get_bind(mapper=Post) is db.engines["a"] diff --git a/tests/test_sqlalchemy_includes.py b/tests/test_sqlalchemy_includes.py deleted file mode 100644 index aa84a7f5..00000000 --- a/tests/test_sqlalchemy_includes.py +++ /dev/null @@ -1,14 +0,0 @@ -import sqlalchemy as sa - -from flask_sqlalchemy import SQLAlchemy -from flask_sqlalchemy.query import Query - - -def test_sqlalchemy_includes(): - """Various SQLAlchemy objects are exposed as attributes.""" - db = SQLAlchemy() - - assert db.Column == sa.Column - - # The Query object we expose is actually our own subclass. - assert db.Query == Query diff --git a/tests/test_table_bind.py b/tests/test_table_bind.py new file mode 100644 index 00000000..fd83d1a9 --- /dev/null +++ b/tests/test_table_bind.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import sqlalchemy as sa + +from flask_sqlalchemy import SQLAlchemy + + +def test_bind_key_default(db: SQLAlchemy) -> None: + user_table = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) + assert user_table.metadata is db.metadata + + +def test_metadata_per_bind(db: SQLAlchemy) -> None: + user_table = db.Table( + "user", sa.Column("id", sa.Integer, primary_key=True), bind_key="other" + ) + assert user_table.metadata is db.metadatas["other"] + + +def test_multiple_binds_same_table_name(db: SQLAlchemy) -> None: + user1_table = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) + user2_table = db.Table( + "user", sa.Column("id", sa.Integer, primary_key=True), bind_key="other" + ) + + assert user1_table.metadata is db.metadata + assert user2_table.metadata is db.metadatas["other"] + + +def test_explicit_metadata(db: SQLAlchemy) -> None: + other_metadata = sa.MetaData() + user_table = db.Table( + "user", + other_metadata, + sa.Column("id", sa.Integer, primary_key=True), + bind_key="other", + ) + assert user_table.metadata is other_metadata + assert "other" not in db.metadatas diff --git a/tests/test_track_modifications.py b/tests/test_track_modifications.py index dde1cfc6..e48053f1 100644 --- a/tests/test_track_modifications.py +++ b/tests/test_track_modifications.py @@ -1,73 +1,66 @@ -import flask +from __future__ import annotations + +import typing as t + import pytest +import sqlalchemy as sa +from flask import Flask +from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.track_modifications import before_models_committed from flask_sqlalchemy.track_modifications import models_committed -pytestmark = pytest.mark.skipif( - not flask.signals_available, reason="Signals require the blinker library." -) +pytest.importorskip("blinker") -@pytest.fixture() -def app(app): +@pytest.mark.usefixtures("app_ctx") +def test_track_modifications(app: Flask) -> None: app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = True - return app - + db = SQLAlchemy(app) -def test_before_committed(app, db, Todo): - class Namespace: - is_received = False + class Example(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + data = sa.Column(sa.String) - def before_committed(sender, changes): - Namespace.is_received = True + db.create_all() + before: list[tuple[t.Any, str]] = [] + after: list[tuple[t.Any, str]] = [] - before_models_committed.connect(before_committed) - todo = Todo("Awesome", "the text") - - with app.app_context(): - db.session.add(todo) - db.session.commit() + def before_commit(sender: Flask, changes: list[tuple[t.Any, str]]) -> None: + nonlocal before + before = changes - assert Namespace.is_received - before_models_committed.disconnect(before_committed) + def after_commit(sender: Flask, changes: list[tuple[t.Any, str]]) -> None: + nonlocal after + after = changes + connect_before = before_models_committed.connected_to(before_commit, app) + connect_after = models_committed.connected_to(after_commit, app) -def test_model_signals(app, db, Todo): - recorded = [] + with connect_before, connect_after: + item = Example() - def committed(sender, changes): - assert isinstance(changes, list) - recorded.extend(changes) + db.session.add(item) + assert not before + assert not after - models_committed.connect(committed) - - with app.app_context(): - todo = Todo("Awesome", "the text") - db.session.add(todo) - assert len(recorded) == 0 db.session.commit() + assert len(before) == 1 + assert before[0] == (item, "insert") + assert before == after - assert len(recorded) == 1 - assert recorded[0][0] == todo - assert recorded[0][1] == "insert" - del recorded[:] - - with app.app_context(): - db.session.add(todo) - todo.text = "aha" + db.session.remove() + item = Example.query.get(1) + item.data = "test" # type: ignore[assignment] db.session.commit() + assert len(before) == 1 + assert before[0] == (item, "update") + assert before == after - assert len(recorded) == 1 - assert recorded[0][0] == todo - assert recorded[0][1] == "update" - del recorded[:] - - with app.app_context(): - db.session.delete(todo) + db.session.remove() + item = Example.query.get(1) + db.session.delete(item) db.session.commit() - - assert len(recorded) == 1 - assert recorded[0][0] == todo - assert recorded[0][1] == "delete" - models_committed.disconnect(committed) + assert len(before) == 1 + assert before[0] == (item, "delete") + assert before == after From 55d532e0cdd1f7e8d8989bb88dd5e613678e74b6 Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 13 Sep 2022 17:35:51 -0700 Subject: [PATCH 23/27] update requirements for pip-compile-multi --- requirements/dev.in | 2 +- requirements/dev.txt | 20 ++++++++++++-------- requirements/docs.txt | 4 ++-- requirements/tests.in | 1 - requirements/tests.txt | 4 +--- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/requirements/dev.in b/requirements/dev.in index c854000e..20148ab6 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -1,5 +1,5 @@ -r docs.in -r tests.in -pip-tools +pip-compile-multi pre-commit tox diff --git a/requirements/dev.txt b/requirements/dev.txt index bb06214d..31ea6f27 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,4 +1,4 @@ -# SHA1:9df2a4dd582fac9b474679829c35ad897ecf5e4b +# SHA1:54196885a2acdc154945dacc9470e2a9900fd8c1 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -12,7 +12,9 @@ build==0.8.0 cfgv==3.3.1 # via pre-commit click==8.1.3 - # via pip-tools + # via + # pip-compile-multi + # pip-tools distlib==0.3.6 # via virtualenv filelock==3.8.0 @@ -25,8 +27,10 @@ nodeenv==1.7.0 # via pre-commit pep517==0.13.0 # via build -pip-tools==6.8.0 +pip-compile-multi==2.4.6 # via -r requirements/dev.in +pip-tools==6.8.0 + # via pip-compile-multi platformdirs==2.5.2 # via virtualenv pre-commit==2.20.0 @@ -36,12 +40,12 @@ pyyaml==6.0 six==1.16.0 # via tox toml==0.10.2 - # via - # pre-commit - # tox -tox==3.25.1 + # via pre-commit +toposort==1.7 + # via pip-compile-multi +tox==3.26.0 # via -r requirements/dev.in -virtualenv==20.16.4 +virtualenv==20.16.5 # via # pre-commit # tox diff --git a/requirements/docs.txt b/requirements/docs.txt index 14b89ed2..96322c42 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -9,13 +9,13 @@ alabaster==0.7.12 # via sphinx babel==2.10.3 # via sphinx -certifi==2022.6.15 +certifi==2022.6.15.2 # via requests charset-normalizer==2.1.1 # via requests docutils==0.19 # via sphinx -idna==3.3 +idna==3.4 # via requests imagesize==1.4.1 # via sphinx diff --git a/requirements/tests.in b/requirements/tests.in index 743f173c..528c35eb 100644 --- a/requirements/tests.in +++ b/requirements/tests.in @@ -1,3 +1,2 @@ pytest blinker -mock diff --git a/requirements/tests.txt b/requirements/tests.txt index e77f940b..3866317e 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,4 @@ -# SHA1:12cbb27fb6b9e5b10590bcbd02ce029861185abe +# SHA1:9d3a5f2ea12fad5bb7b944df8244cb9209535c8c # # This file is autogenerated by pip-compile-multi # To update, run: @@ -11,8 +11,6 @@ blinker==1.5 # via -r requirements/tests.in iniconfig==1.1.1 # via pytest -mock==4.0.3 - # via -r requirements/tests.in packaging==21.3 # via pytest pluggy==1.0.0 From 8bb2b1611f3835dfe0847fbd7867ad8c37554d3c Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 13 Sep 2022 17:54:53 -0700 Subject: [PATCH 24/27] bump minimum versions --- .pre-commit-config.yaml | 5 +++-- CHANGES.rst | 4 ++-- setup.cfg | 2 +- setup.py | 2 +- tox.ini | 4 ++-- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index af4b97c7..274bcaee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,17 +5,18 @@ repos: rev: v2.37.3 hooks: - id: pyupgrade - args: ["--py36-plus"] + args: ["--py37-plus"] - repo: https://github.com/asottile/reorder_python_imports rev: v3.8.2 hooks: - id: reorder-python-imports files: "^(?!examples/)" - args: ["--application-directories", "src"] + args: ["--py37-plus", "--application-directories", "src"] - repo: https://github.com/psf/black rev: 22.8.0 hooks: - id: black + args: ["--target-version", "py37"] - repo: https://github.com/PyCQA/flake8 rev: 5.0.4 hooks: diff --git a/CHANGES.rst b/CHANGES.rst index 101cac49..8c5cb83c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,8 +4,8 @@ Version 3.0.0 Unreleased - Drop support for Python 2, 3.4, 3.5, and 3.6. -- Bump minimum version of Flask to 1.0.4. -- Bump minimum version of SQLAlchemy to 1.2. +- Bump minimum version of Flask to 2.2. +- Bump minimum version of SQLAlchemy to 1.4.18. - Remove previously deprecated code. - The CamelCase to snake_case table name converter handles more patterns correctly. If such a name was already created in the diff --git a/setup.cfg b/setup.cfg index b17bf105..929191b2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ classifiers = packages = find: package_dir = = src include_package_data = true -python_requires = >= 3.6 +python_requires = >= 3.7 # Dependencies are in setup.py for GitHub's dependency graph. [options.packages.find] diff --git a/setup.py b/setup.py index 88ef48b5..b79abc1a 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ from setuptools import setup # Metadata goes in setup.cfg. These are here for GitHub's dependency graph. -setup(name="Flask-SQLAlchemy", install_requires=["Flask>=1.0.4", "SQLAlchemy>=1.2"]) +setup(name="Flask-SQLAlchemy", install_requires=["Flask>=2.0", "SQLAlchemy>=1.4"]) diff --git a/tox.ini b/tox.ini index e59472c4..e198e9be 100644 --- a/tox.ini +++ b/tox.ini @@ -10,8 +10,8 @@ skip_missing_interpreters = true [testenv] deps = -r requirements/tests.txt - lowest: flask==1.0.4 - lowest: sqlalchemy==1.2 + lowest: flask==2.2 + lowest: sqlalchemy==1.4.18 commands = pytest -v --tb=short --basetemp={envtmpdir} {posargs} [testenv:style] From 9dbd2738d93490385ae99196cf0ce6a9d0ba669c Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 13 Sep 2022 18:03:00 -0700 Subject: [PATCH 25/27] update urls --- CONTRIBUTING.rst | 10 +++++----- README.rst | 6 +++--- docs/conf.py | 4 ++-- setup.cfg | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 9e228836..142368b5 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -91,7 +91,7 @@ First time setup .. code-block:: text - $ git clone https://github.com/pallets/flask-sqlalchemy + $ git clone https://github.com/pallets-eco/flask-sqlalchemy $ cd flask-sqlalchemy - Add your fork as a remote to push your work to. Replace @@ -132,7 +132,7 @@ First time setup .. _username: https://docs.github.com/en/github/using-git/setting-your-username-in-git .. _email: https://docs.github.com/en/github/setting-up-and-managing-your-github-user-account/setting-your-commit-email-address .. _GitHub account: https://github.com/join -.. _Fork: https://github.com/pallets/jinja/fork +.. _Fork: https://github.com/pallets-eco/flask-sqlalchemy/fork .. _Clone: https://docs.github.com/en/github/getting-started-with-github/fork-a-repo#step-2-create-a-local-clone-of-your-fork @@ -146,15 +146,15 @@ Start coding .. code-block:: text $ git fetch origin - $ git checkout -b your-branch-name origin/2.x + $ git checkout -b your-branch-name origin/3.0.x If you're submitting a feature addition or change, branch off of the - "master" branch. + "main" branch. .. code-block:: text $ git fetch origin - $ git checkout -b your-branch-name origin/master + $ git checkout -b your-branch-name origin/main - Using your favorite editor, make your changes, `committing as you go`_. diff --git a/README.rst b/README.rst index c2e6d1b2..1d65ea21 100644 --- a/README.rst +++ b/README.rst @@ -53,7 +53,7 @@ Contributing For guidance on setting up a development environment and how to make a contribution to Flask-SQLAlchemy, see the `contributing guidelines`_. -.. _contributing guidelines: https://github.com/pallets/flask-sqlalchemy/blob/master/CONTRIBUTING.rst +.. _contributing guidelines: https://github.com/pallets-eco/flask-sqlalchemy/blob/main/CONTRIBUTING.rst Donate @@ -73,8 +73,8 @@ Links - Documentation: https://flask-sqlalchemy.palletsprojects.com/ - Changes: https://flask-sqlalchemy.palletsprojects.com/changes/ - PyPI Releases: https://pypi.org/project/Flask-SQLAlchemy/ -- Source Code: https://github.com/pallets/flask-sqlalchemy/ -- Issue Tracker: https://github.com/pallets/flask-sqlalchemy/issues/ +- Source Code: https://github.com/pallets-eco/flask-sqlalchemy/ +- Issue Tracker: https://github.com/pallets-eco/flask-sqlalchemy/issues/ - Website: https://palletsprojects.com/ - Twitter: https://twitter.com/PalletsTeam - Chat: https://discord.gg/pallets diff --git a/docs/conf.py b/docs/conf.py index 3198a593..3993d7bc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,9 +31,9 @@ "project_links": [ ProjectLink("Donate", "https://palletsprojects.com/donate"), ProjectLink("PyPI Releases", "https://pypi.org/project/Flask-SQLAlchemy/"), - ProjectLink("Source Code", "https://github.com/pallets/flask-sqlalchemy/"), + ProjectLink("Source Code", "https://github.com/pallets-eco/flask-sqlalchemy/"), ProjectLink( - "Issue Tracker", "https://github.com/pallets/flask-sqlalchemy/issues/" + "Issue Tracker", "https://github.com/pallets-eco/flask-sqlalchemy/issues/" ), ProjectLink("Website", "https://palletsprojects.com/"), ProjectLink("Twitter", "https://twitter.com/PalletsTeam"), diff --git a/setup.cfg b/setup.cfg index 929191b2..db489fff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,8 +6,8 @@ project_urls = Donate = https://palletsprojects.com/donate Documentation = https://flask-sqlalchemy.palletsprojects.com/ Changes = https://flask-sqlalchemy.palletsprojects.com/changes/ - Source Code = https://github.com/pallets/flask-sqlalchemy/ - Issue Tracker = https://github.com/pallets/flask-sqlalchemy/issues/ + Source Code = https://github.com/pallets-eco/flask-sqlalchemy/ + Issue Tracker = https://github.com/pallets-eco/flask-sqlalchemy/issues/ Twitter = https://twitter.com/PalletsTeam Chat = https://discord.gg/pallets license = BSD-3-Clause From bee0c29d27dcdb965456c26920b89e91bf4ce0e1 Mon Sep 17 00:00:00 2001 From: David Lord Date: Tue, 13 Sep 2022 18:33:04 -0700 Subject: [PATCH 26/27] add issue links for v3 changes --- CHANGES.rst | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 8c5cb83c..67f7d7c8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,28 +7,25 @@ Unreleased - Bump minimum version of Flask to 2.2. - Bump minimum version of SQLAlchemy to 1.4.18. - Remove previously deprecated code. -- The CamelCase to snake_case table name converter handles more - patterns correctly. If such a name was already created in the - database, either use Alembic to rename the table, or set - ``__tablename__`` to keep the old name. :issue:`406` -- Set ``SQLALCHEMY_TRACK_MODIFICATIONS`` to ``False`` by default. - :pr:`727` -- Remove default ``'sqlite:///:memory:'`` setting for - ``SQLALCHEMY_DATABASE_URI``, raise error when both it and - ``SQLALCHEMY_BINDS`` are unset. :pr:`731` -- Configuring SQLite with a relative path is relative to - ``app.instance_path`` instead of ``app.root_path``. The instance - folder is created if necessary. :issue:`462` -- Deprecate ``SQLALCHEMY_COMMIT_ON_TEARDOWN`` as it can cause various - design issues that are difficult to debug. Call - ``db.session.commit()`` directly instead. :issue:`216` +- The ``CamelCase`` to ``snake_case`` table name converter handles more patterns + correctly. If such a that was was already created in the database changed, either + use Alembic to rename the table, or set ``__tablename__`` to keep the old name. + :issue:`406` +- Set ``SQLALCHEMY_TRACK_MODIFICATIONS`` to ``False`` by default. :pr:`727` +- Remove default ``'sqlite:///:memory:'`` setting for ``SQLALCHEMY_DATABASE_URI``, + raise error when both it and ``SQLALCHEMY_BINDS`` are unset. :pr:`731` +- Configuring SQLite with a relative path is relative to ``app.instance_path`` instead + of ``app.root_path``. The instance folder is created if necessary. :issue:`462` +- Deprecate ``SQLALCHEMY_COMMIT_ON_TEARDOWN`` as it can cause various design issues + that are difficult to debug. Call ``db.session.commit()`` directly instead. + :issue:`216` - Change the default MySQL character set to "utf8mb4". :issue:`875` - ``Pagination``, ``Pagination.iter_pages``, and ``Query.paginate`` parameters are keyword-only. -- ``Pagination`` is iterable, iterating over its items. +- ``Pagination`` is iterable, iterating over its items. :issue:`70` - ``Pagination.apply_to_query`` can be used instead of ``query.paginate``. - ``Query.paginate`` ``count`` is more efficient. -- ``Pagination.iter_pages`` is more efficient. +- ``Pagination.iter_pages`` is more efficient. :issue:`622` - ``Pagination.iter_pages`` ``right_current`` parameter is inclusive. - ``Query`` is renamed from ``BaseQuery``. - ``Query.one_or_404`` is added. @@ -38,10 +35,11 @@ Unreleased renamed to ``location``. Finding the location uses a more inclusive check. - The ``SQLAlchemy`` extension object uses ``__getattr__`` to alias names from the SQLAlchemy package, rather than copying them as attributes. -- The query class is applied to ``backref`` in ``relationship``. +- The query class is applied to ``backref`` in ``relationship``. :issue:`417` - ``SignallingSession`` is renamed to ``Session``. - ``Session.get_bind`` more closely matches the base implementation. - ``Model`` ``repr`` distinguishes between transient and pending instances. + :issue:`967` - Different bind keys use different SQLAlchemy ``MetaData`` registries, allowing tables in different databases to have the same name. Bind keys are stored and looked up on the resulting metadata rather than the model or table. @@ -50,15 +48,17 @@ Unreleased - ``SQLALCHEMY_BINDS`` values can either be an engine URL, or a dict of engine options including URL, for each bind. ``SQLALCHEMY_DATABASE_URI`` and ``SQLALCHEMY_ENGINE_OPTIONS`` correspond to the ``None`` key and take precedence. + :issue:`783` - Engines are created when calling ``init_app`` rather than the first time they are - accessed. + accessed. :issue:`698` - The extension instance is stored directly as ``app.extensions["sqlalchemy"]``. + :issue:`698` - All parameters except ``app`` are keyword-only. - Setup methods that create the engines and session are renamed with a leading underscore. They are considered internal interfaces which may change at any time. - ``db.Table`` is a subclass instead of a function. - The session class can be customized by passing the ``class_`` key in the - ``session_options`` parameter. + ``session_options`` parameter. :issue:`327` - SQLite engines do not use ``NullPool`` if ``pool_size`` is 0. - MySQL engines do not set ``pool_size`` to 10. - ``db.engines`` exposes the map of bind keys to engines for the current app. @@ -74,6 +74,7 @@ Unreleased :issue:`1002` - An active Flask application context is always required to access ``session`` and ``engine``, regardless of if an application was passed to the constructor. + :issue:`508, 944` Version 2.5.1 From 53dfc8eaf66e41aa21ba8d6d323de70fd2fbbe8e Mon Sep 17 00:00:00 2001 From: David Lord Date: Sat, 17 Sep 2022 13:59:50 -0700 Subject: [PATCH 27/27] rewrite documentation --- docs/Makefile | 7 +- docs/api.rst | 79 ++++-- docs/binds.rst | 126 +++++---- docs/conf.py | 14 +- docs/config.rst | 250 ++++++++++-------- docs/contexts.rst | 109 +++++--- docs/customizing.rst | 218 +++++++++------- docs/index.rst | 52 ++-- docs/make.bat | 10 +- docs/models.rst | 241 ++++++----------- docs/pagination.rst | 74 ++++++ docs/queries.rst | 160 +++++------- docs/quickstart.rst | 271 ++++++++++---------- docs/record-queries.rst | 27 ++ docs/signals.rst | 27 -- docs/track-modifications.rst | 25 ++ src/flask_sqlalchemy/extension.py | 40 +-- src/flask_sqlalchemy/model.py | 2 +- src/flask_sqlalchemy/pagination.py | 24 +- src/flask_sqlalchemy/query.py | 12 +- src/flask_sqlalchemy/record_queries.py | 4 +- src/flask_sqlalchemy/track_modifications.py | 12 + 22 files changed, 975 insertions(+), 809 deletions(-) create mode 100644 docs/pagination.rst create mode 100644 docs/record-queries.rst delete mode 100644 docs/signals.rst create mode 100644 docs/track-modifications.rst diff --git a/docs/Makefile b/docs/Makefile index 51285967..d4bb2cbb 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,9 +1,10 @@ # Minimal makefile for Sphinx documentation # -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build diff --git a/docs/api.rst b/docs/api.rst index 2fd4c0bf..80f56626 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,45 +1,86 @@ API ---- +=== -.. module:: flask_sqlalchemy -Configuration -````````````` +Extension +--------- + +.. module:: flask_sqlalchemy .. autoclass:: SQLAlchemy - :members: + :members: + + +Model +----- -Models -`````` +.. module:: flask_sqlalchemy.model .. autoclass:: Model :members: .. attribute:: __bind_key__ - Optionally declares the bind to use. ``None`` refers to the default - bind. For more information see :ref:`binds`. + Use this bind key to select a metadata and engine to associate with this model's + table. Ignored if ``metadata`` or ``__table__`` is set. If not given, uses the + default key, ``None``. .. attribute:: __tablename__ - The name of the table in the database. This is required by SQLAlchemy; - however, Flask-SQLAlchemy will set it automatically if a model has a - primary key defined. If the ``__table__`` or ``__tablename__`` is set - explicitly, that will be used instead. + The name of the table in the database. This is required by SQLAlchemy; however, + Flask-SQLAlchemy will set it automatically if a model has a primary key defined. + If the ``__table__`` or ``__tablename__`` is set explicitly, that will be used + instead. + +.. autoclass:: DefaultMeta + +.. autoclass:: BindMetaMixin + +.. autoclass:: NameMetaMixin + + +Query +----- + +.. module:: flask_sqlalchemy.query .. autoclass:: Query :members: -Sessions -```````` -.. autoclass:: SignallingSession +Session +------- + +.. module:: flask_sqlalchemy.session + +.. autoclass:: Session :members: -Utilities -````````` + +Pagination +---------- + +.. module:: flask_sqlalchemy.pagination .. autoclass:: Pagination :members: -.. autofunction:: get_debug_queries + +Record Queries +-------------- + +.. module:: flask_sqlalchemy.record_queries + +.. autofunction:: get_recorded_queries + + +Track Modifications +------------------- + +.. module:: flask_sqlalchemy.track_modifications + +.. autodata:: models_committed + :no-value: + +.. autodata:: before_models_committed + :no-value: diff --git a/docs/binds.rst b/docs/binds.rst index 579d93b6..92ffd4d2 100644 --- a/docs/binds.rst +++ b/docs/binds.rst @@ -1,73 +1,97 @@ -.. _binds: - -.. currentmodule:: flask_sqlalchemy - Multiple Databases with Binds ============================= -Starting with 0.12 Flask-SQLAlchemy can easily connect to multiple -databases. To achieve that it preconfigures SQLAlchemy to support -multiple “binds”. +SQLAlchemy can connect to more than one database at a time. It refers to different +engines as "binds". Flask-SQLAlchemy simplifies how binds work by associating each +engine with a short string, a "bind key", and then associating each model and table with +a bind key. The session will choose what engine to use for a query based on the bind key +of the thing being queried. If no bind key is given, the default engine is used. -What are binds? In SQLAlchemy speak a bind is something that can execute -SQL statements and is usually a connection or engine. In Flask-SQLAlchemy -binds are always engines that are created for you automatically behind the -scenes. Each of these engines is then associated with a short key (the -bind key). This key is then used at model declaration time to assocate a -model with a specific engine. -If no bind key is specified for a model the default connection is used -instead (as configured by ``SQLALCHEMY_DATABASE_URI``). +Configuring Binds +----------------- -Example Configuration ---------------------- +The default bind is still configured by setting :data:`.SQLALCHEMY_DATABASE_URI`, and +:data:`.SQLALCHEMY_ENGINE_OPTIONS` for any engine options. Additional binds are given in +:data:`.SQLALCHEMY_BINDS`, a dict mapping bind keys to engine URLs. To specify engine +options for a bind, the value can be a dict of engine options with the ``"url"`` key, +instead of only a URL string. -The following configuration declares three database connections. The -special default one as well as two others named `users` (for the users) -and `appmeta` (which connects to a sqlite database for read only access to -some data the application provides internally):: +.. code-block:: python - SQLALCHEMY_DATABASE_URI = 'postgres://localhost/main' + SQLALCHEMY_DATABASE_URI = "postgresql:///main" SQLALCHEMY_BINDS = { - 'users': 'mysqldb://localhost/users', - 'appmeta': 'sqlite:////path/to/appmeta.db' + "meta": "sqlite:////path/to/meta.db", + "auth": { + "url": "mysql://localhost/users", + "pool_recycle": 3600, + }, } -Creating and Dropping Tables ----------------------------- -The :meth:`~SQLAlchemy.create_all` and :meth:`~SQLAlchemy.drop_all` methods -by default operate on all declared binds, including the default one. This -behavior can be customized by providing the `bind` parameter. It takes -either a single bind name, ``'__all__'`` to refer to all binds or a list -of binds. The default bind (``SQLALCHEMY_DATABASE_URI``) is named `None`: +Defining Models and Tables with Binds +------------------------------------- ->>> db.create_all() ->>> db.create_all(bind=['users']) ->>> db.create_all(bind='appmeta') ->>> db.drop_all(bind=None) +Flask-SQLAlchemy will create a metadata and engine for each configured bind. Models and +tables with a bind key will be registered with the corresponding metadata, and the +session will query them using the corresponding engine. -Referring to Binds ------------------- +To set the bind for a model, set the ``__bind_key__`` class attribute. Not setting a +bind key is equivalent to setting it to ``None``, the default key. -If you declare a model you can specify the bind to use with the -:attr:`~Model.__bind_key__` attribute:: +.. code-block:: python class User(db.Model): - __bind_key__ = 'users' + __bind_key__ = "auth" id = db.Column(db.Integer, primary_key=True) - username = db.Column(db.String(80), unique=True) -Internally the bind key is stored in the table's `info` dictionary as -``'bind_key'``. This is important to know because when you want to create -a table object directly you will have to put it in there:: +Models that inherit from this model will share the same bind key, or can override it. + +To set the bind for a table, pass the ``bind_key`` keyword argument. - user_favorites = db.Table('user_favorites', - db.Column('user_id', db.Integer, db.ForeignKey('user.id')), - db.Column('message_id', db.Integer, db.ForeignKey('message.id')), - info={'bind_key': 'users'} +.. code-block:: python + + user_table = db.Table( + "user", + db.Column("id", db.Integer, primary_key=True), + bind_key="auth", ) -If you specified the `__bind_key__` on your models you can use them exactly the -way you are used to. The model connects to the specified database connection -itself. +Ultimately, the session looks up the bind key on the metadata associated with the model +or table. That association happens during creation. Therefore, changing the bind key +after creating a model or table will have no effect. + + +Accessing Metadata and Engines +------------------------------ + +You may need to inspect the metadata or engine for a bind. Note that you should execute +queries through the session, not directly on the engine. + +The default engine is :attr:`.SQLAlchemy.engine`, and the default metadata is +:attr:`.SQLAlchemy.metadata`. :attr:`.SQLAlchemy.engines` and +:attr:`.SQLAlchemy.metadatas` are dicts mapping all bind keys. + + +Creating and Dropping Tables +---------------------------- + +The :meth:`~.SQLAlchemy.create_all` and :meth:`~.SQLAlchemy.drop_all` methods operate on +all binds by default. The ``bind_key`` argument to these methods can be a string or +``None`` to operate on a single bind, or a list of strings or ``None`` to operate on a +subset of binds. Because these methods access the engines, they must be called inside an +application context. + +.. code-block:: python + + # create tables for all binds + db.create_all() + + # create tables for the default and "auth" binds + db.create_all(bind=[None, "auth"]) + + # create tables for the "meta" bind + db.create_all(bind="meta") + + # drop tables for the default bind + db.drop_all(bind=None) diff --git a/docs/conf.py b/docs/conf.py index 3993d7bc..dd88d508 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -6,23 +6,23 @@ project = "Flask-SQLAlchemy" copyright = "2010 Pallets" author = "Pallets" -release, version = get_version("Flask-SQLAlchemy", version_length=1) +release, version = get_version("Flask-SQLAlchemy") # General -------------------------------------------------------------- -master_doc = "index" extensions = [ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", "pallets_sphinx_themes", "sphinx_issues", ] +autodoc_typehints = "description" intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "flask": ("https://flask.palletsprojects.com/", None), - "sqlalchemy": ("https://docs.sqlalchemy.org/en/latest/", None), + "sqlalchemy": ("https://docs.sqlalchemy.org/", None), } -issues_github_path = "pallets/flask-sqlalchemy" +issues_github_path = "pallets-eco/flask-sqlalchemy" # HTML ----------------------------------------------------------------- @@ -50,9 +50,3 @@ html_logo = "_static/flask-sqlalchemy-logo.png" html_title = f"Flask-SQLAlchemy Documentation ({version})" html_show_sourcelink = False - -# LaTeX ---------------------------------------------------------------- - -latex_documents = [ - (master_doc, f"{project}-{version}.tex", html_title, author, "manual") -] diff --git a/docs/config.rst b/docs/config.rst index f89912c5..0ba9b692 100644 --- a/docs/config.rst +++ b/docs/config.rst @@ -1,163 +1,199 @@ -.. currentmodule:: flask_sqlalchemy - Configuration ============= -The following configuration values exist for Flask-SQLAlchemy. -Flask-SQLAlchemy loads these values from your main Flask config which can -be populated in various ways. Note that some of those cannot be modified -after the engine was created so make sure to configure as early as -possible and to not modify them at runtime. Configuration Keys ------------------ -A list of configuration keys currently understood by the extension: - -.. tabularcolumns:: |p{6.5cm}|p{8.5cm}| - -================================== ========================================= -``SQLALCHEMY_DATABASE_URI`` The database URI that should be used for - the connection. Examples: - - - ``sqlite:////tmp/test.db`` - - ``mysql://username:password@server/db`` -``SQLALCHEMY_BINDS`` A dictionary that maps bind keys to - SQLAlchemy connection URIs. For more - information about binds see :ref:`binds`. -``SQLALCHEMY_ECHO`` If set to `True` SQLAlchemy will log all - the statements issued to stderr which can - be useful for debugging. -``SQLALCHEMY_RECORD_QUERIES`` Can be used to explicitly disable or - enable query recording. Query recording - automatically happens in debug or testing - mode. See :func:`get_debug_queries` for - more information. -``SQLALCHEMY_ENGINE_OPTIONS`` A dictionary of keyword args to send to - :func:`~sqlalchemy.create_engine`. See - also ``engine_options`` to :class:`SQLAlchemy`. -================================== ========================================= +Configuration is loaded from the Flask ``app.config`` when :meth:`.SQLAlchemy.init_app` +is called. The configuration is not read again after that. Therefore, all configuration +must happen before initializing the application. -.. versionchanged:: 3.0 - ``SQLALCHEMY_TRACK_MODIFICATIONS`` defaults to ``False``. +.. module:: flask_sqlalchemy.config -.. versionchanged:: 3.0 - ``SQLALCHEMY_DATABASE_URI`` no longer defaults to - ``'sqlite:///:memory:'`` +.. data:: SQLALCHEMY_DATABASE_URI + + The database connection URI used for the default engine. It can be either a string + or a SQLAlchemy ``URL`` instance. See below and :external:doc:`core/engines` for + examples. + + At least one of this and :data:`SQLALCHEMY_BINDS` must be set. + + .. versionchanged:: 3.0 + No longer defaults to an in-memory SQLite database if not set. + +.. data:: SQLALCHEMY_ENGINE_OPTIONS + + A dict of arguments to pass to :func:`sqlalchemy.create_engine` for the default + engine. + + This takes precedence over the ``engine_options`` argument to :class:`.SQLAlchemy`, + which can be used to set default options for all engines. + + .. versionchanged:: 3.0 + Only applies to the default bind. + + .. versionadded:: 2.4 + +.. data:: SQLALCHEMY_BINDS + + A dict mapping bind keys to engine options. The value can be a string or a + SQLAlchemy ``URL`` instance. Or it can be a dict of arguments, including the ``url`` + key, that will be passed to :func:`sqlalchemy.create_engine`. The ``None`` key can + be used to configure the default bind, but :data:`SQLALCHEMY_ENGINE_OPTIONS` and + :data:`SQLALCHEMY_DATABASE_URI` take precedence. + + At least one of this and :data:`SQLALCHEMY_DATABASE_URI` must be set. + + .. versionadded:: 0.12 + +.. data:: SQLALCHEMY_ECHO + + The default value for ``echo`` and ``echo_pool`` for every engine. This is useful to + quickly debug the connections and queries issued from SQLAlchemy. + + .. versionchanged:: 3.0 + Sets ``echo_pool`` in addition to ``echo``. + +.. data:: SQLALCHEMY_RECORD_QUERIES + + If enabled, information about each query during a request will be recorded. Use + :func:`.get_recorded_queries` to get a list of queries that were issued during the + request. + + If not set, this is enabled if ``app.debug`` or ``app.testing`` are enabled. + +.. data:: SQLALCHEMY_TRACK_MODIFICATIONS + + If enabled, all ``insert``, ``update``, and ``delete`` operations on models are + recorded, then sent in :data:`.models_committed` and + :data:`.before_models_committed` signals when ``session.commit()`` is called. + + This adds a significant amount of overhead to every session. Prefer using + SQLAlchemy's :external:doc:`orm/events` directly for the exact information you need. + + .. versionchanged:: 3.0 + Disabled by default. + + .. versionadded:: 2.0 + +.. data:: SQLALCHEMY_COMMIT_ON_TEARDOWN + + Call ``db.session.commit()`` automatically if the request finishes without an + unhandled exception. + + .. deprecated:: 3.0 + Will be removed in Flask-SQLAlchemy 3.1. .. versionchanged:: 3.0 Removed ``SQLALCHEMY_NATIVE_UNICODE``, ``SQLALCHEMY_POOL_SIZE``, ``SQLALCHEMY_POOL_TIMEOUT``, ``SQLALCHEMY_POOL_RECYCLE``, and ``SQLALCHEMY_MAX_OVERFLOW``. -.. versionchanged:: 3.0 - Deprecated ``SQLALCHEMY_COMMIT_ON_TEARDOWN``. -.. versionadded:: 2.4 - Added ``SQLALCHEMY_ENGINE_OPTIONS``. +Connection URL Format +--------------------- -.. versionchanged:: 2.4 - Deprecated ``SQLALCHEMY_NATIVE_UNICODE``, ``SQLALCHEMY_POOL_SIZE``, - ``SQLALCHEMY_POOL_TIMEOUT``, ``SQLALCHEMY_POOL_RECYCLE``, and - ``SQLALCHEMY_MAX_OVERFLOW``. +See SQLAlchemy's documentation on :external:doc:`core/engines` for a complete +description of syntax, dialects, and options. -.. versionadded:: 2.0 - Added ``SQLALCHEMY_TRACK_MODIFICATIONS``. +A basic database connection URL uses the following format. Username, password, host, and +port are optional depending on the database type and configuration. -.. versionadded:: 0.17 - Added ``SQLALCHEMY_MAX_OVERFLOW``. +.. code-block:: text -.. versionadded:: 0.12 - Added ``SQLALCHEMY_BINDS``. + dialect://username:password@host:port/database -.. versionadded:: 0.8 - Added ``SQLALCHEMY_NATIVE_UNICODE``, ``SQLALCHEMY_POOL_SIZE``, - ``SQLALCHEMY_POOL_TIMEOUT`` and ``SQLALCHEMY_POOL_RECYCLE``. +Here are some example connection strings: +.. code-block:: text -Connection URI Format ---------------------- + # SQLite, relative to Flask instance path + sqlite:///project.db + + # PostgreSQL + postgresql://scott:tiger@localhost/project -For a complete list of connection URIs head over to the SQLAlchemy -documentation under (`Supported Databases -`_). This here shows -some common connection strings. + # MySQL / MariaDB + mysql://scott:tiger@localhost/project -SQLAlchemy indicates the source of an Engine as a URI combined with -optional keyword arguments to specify options for the Engine. The form of -the URI is:: +SQLite does not use a user or host, so its URLs always start with _three_ slashes +instead of two. The ``dbname`` value is a file path. Absolute paths start with a +_fourth_ slash (on Linux or Mac). Relative paths are relative to the Flask application's +:attr:`~flask.Flask.instance_path`. - dialect+driver://username:password@host:port/database -Many of the parts in the string are optional. If no driver is specified -the default one is selected (make sure to *not* include the ``+`` in that -case). +Default Driver Options +---------------------- -Postgres:: +Some default options are set for SQLite and MySQL engines to make them more usable by +default in web applications. - postgresql://scott:tiger@localhost/mydatabase +SQLite relative file paths are relative to the Flask instance path instead of the +current working directory. In-memory databases use a static pool and +``check_same_thread`` to work across requests. -MySQL:: +MySQL (and MariaDB) servers are configured to drop connections that have been idle for +8 hours, which can result in an error like ``2013: Lost connection to MySQL server +during query``. A default ``pool_recycle`` value of 2 hours (7200 seconds) is used to +recreate connections before that timeout. - mysql://scott:tiger@localhost/mydatabase -Oracle:: +Engine Configuration Precedence +------------------------------- - oracle://scott:tiger@127.0.0.1:1521/sidname +Because Flask-SQLAlchemy has support for multiple engines, there are rules for which +config overrides other config. Most applications will only have a single database and +only need to use :data:`SQLALCHEMY_DATABASE_URI` and :data:`SQLALCHEMY_ENGINE_OPTIONS`. -SQLite (note that platform path conventions apply):: +- If the ``engine_options`` argument is given to :class:`.SQLAlchemy`, it sets default + options for *all* engines. :data:`SQLALCHEMY_ECHO` sets the default value for both + ``echo`` and ``echo_pool`` for all engines. +- The options for each engine in :data:`.SQLALCHEMY_BINDS` override those defaults. +- :data:`.SQLALCHEMY_ENGINE_OPTIONS` overrides the ``None`` key in + ``SQLALCHEMY_BINDS``, and :data:`.SQLALCHEMY_DATABASE_URI` overrides the ``url`` key + in that engine's options. - #Unix/Mac (note the four leading slashes) - sqlite:////absolute/path/to/foo.db - #Windows (note 3 leading forward slashes and backslash escapes) - sqlite:///C:\\absolute\\path\\to\\foo.db - #Windows (alternative using raw string) - r'sqlite:///C:\absolute\path\to\foo.db' Using custom MetaData and naming conventions -------------------------------------------- -You can optionally construct the :class:`SQLAlchemy` object with a custom -:class:`~sqlalchemy.schema.MetaData` object. -This allows you to, among other things, -specify a `custom constraint naming convention -`_ -in conjunction with SQLAlchemy 0.9.2 or higher. -Doing so is important for dealing with database migrations (for instance using -`alembic `_ as stated -`here `_. Here's an -example, as suggested by the SQLAlchemy docs:: +You can optionally construct the :class:`.SQLAlchemy` object with a custom +:class:`~sqlalchemy.schema.MetaData` object. This allows you to specify a custom +constraint `naming convention`_. This makes constraint names consistent and predictable, +useful when using migrations, as described by `Alembic`_. + +.. code-block:: python from sqlalchemy import MetaData - from flask import Flask from flask_sqlalchemy import SQLAlchemy - convention = { + db = SQLAlchemy(metadata=MetaData(naming_convention={ "ix": 'ix_%(column_0_label)s', "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "pk": "pk_%(table_name)s" - } - - metadata = MetaData(naming_convention=convention) - db = SQLAlchemy(app, metadata=metadata) + })) -For more info about :class:`~sqlalchemy.schema.MetaData`, -`check out the official docs on it -`_. +.. _naming convention: https://docs.sqlalchemy.org/core/constraints.html#constraint-naming-conventions +.. _Alembic: https://alembic.sqlalchemy.org/en/latest/naming.html -.. _timeouts: Timeouts -------- -Certain database backends may impose different inactive connection timeouts, -which interferes with Flask-SQLAlchemy's connection pooling. +Certain databases may be configured to close inactive connections after a period of +time. MySQL and MariaDB are configured for this by default, but database services may +also configure this type of limit. This can result in an error like +``2013: Lost connection to MySQL server during query``. + +If you encounter this error, try setting ``pool_recycle`` in the engine options to +a value less than the database's timeout. + +Alternatively, you can try setting ``pool_pre_ping`` if you expect the database to close +connections often, such as if it's running in a container that may restart. -By default, MariaDB is configured to have a 600 second timeout. This often -surfaces hard to debug, production environment only exceptions like ``2013: Lost connection to MySQL server during query``. +See SQAlchemy's docs on `dealing with disconnects`_ for more information. -If you are using a backend (or a pre-configured database-as-a-service) with a -lower connection timeout, it is recommended that you set -`SQLALCHEMY_POOL_RECYCLE` to a value less than your backend's timeout. +.. _dealing with disconnects: https://docs.sqlalchemy.org/core/pooling.html#dealing-with-disconnects diff --git a/docs/contexts.rst b/docs/contexts.rst index 16d3beac..ccf68efa 100644 --- a/docs/contexts.rst +++ b/docs/contexts.rst @@ -1,65 +1,90 @@ -.. _contexts: +Flask Application Context +========================= -.. currentmodule:: flask_sqlalchemy +An active Flask application context is required to make queries and to access +``db.engine`` and ``db.session``. This is because the session is scoped to the context +so that it is cleaned up properly after every request or CLI command. -Introduction into Contexts -========================== +Regardless of how an application is initialized with the extension, it is not stored for +later use. Instead, the extension uses Flask's ``current_app`` proxy to get the active +application, which requires an active application context. -If you are planning on using only one application you can largely skip -this chapter. Just pass your application to the :class:`SQLAlchemy` -constructor and you're usually set. However if you want to use more than -one application or create the application dynamically in a function you -want to read on. -If you define your application in a function, but the :class:`SQLAlchemy` -object globally, how does the latter learn about the former? The answer -is the :meth:`~SQLAlchemy.init_app` function:: +Automatic Context +----------------- - from flask import Flask - from flask_sqlalchemy import SQLAlchemy +When Flask is handling a request or a CLI command, an application context will +automatically be pushed. Therefore you don't need to do anything special to use the +database during requests or CLI commands. - db = SQLAlchemy() + +Manual Context +-------------- + +If you try to use the database when an application context is not active, you will see +the following error. + +.. code-block:: text + + RuntimeError: Working outside of application context. + + This typically means that you attempted to use functionality that needed + the current application. To solve this, set up an application context + with app.app_context(). See the documentation for more information. + +If you find yourself in a situation where you need the database and don't have a +context, you can push one with ``app_context``. This is common when calling +``db.create_all`` to creat the tables, for example. + +.. code-block:: python def create_app(): app = Flask(__name__) - db.init_app(app) - return app + app.config.from_object("project.config") + import project.models -What it does is prepare the application to work with -:class:`SQLAlchemy`. However that does not now bind the -:class:`SQLAlchemy` object to your application. Why doesn't it do that? -Because there might be more than one application created. + with app.app_context(): + db.create_all() + + return app -So how does :class:`SQLAlchemy` come to know about your application? -You will have to setup an application context. If you are working inside -a Flask view function or a CLI command, that automatically happens. However, -if you are working inside the interactive shell, you will have to do that -yourself (see `Creating an Application Context -`_). -If you try to perform database operations outside an application context, you -will see the following error: +Tests +----- - No application found. Either work inside a view function or push an - application context. +If you test your application using the Flask test client to make requests to your +endpoints, the context will be available as part of the request. If you need to test +something about your database or models directly, rather than going through a request, +you need to push a context manually. -In a nutshell, do something like this: +Only push a context exactly where and for how long it's needed for each test. Do not +push an application context globally for every test, as that can interfere with how the +session is cleaned up. ->>> from yourapp import create_app ->>> app = create_app() ->>> app.app_context().push() +.. code-block:: python -Alternatively, use the with-statement to take care of setup and teardown:: + def test_user_model(app): + user = User() - def my_function(): with app.app_context(): - user = db.User(...) db.session.add(user) db.session.commit() -Some functions inside Flask-SQLAlchemy also accept optionally the -application to operate on: +If you find yourself writing many tests like that, you can use a pytest fixture to push +a context for a specific test. + +.. code-block:: python + + import pytest + + @pytest.mark.fixture + def app_ctx(app): + with app.app_context(): + yield ->>> from yourapp import db, create_app ->>> db.create_all(app=create_app()) + @pytest.mark.usefixtures("app_ctx") + def test_user_model(app): + user = User() + db.session.add(user) + db.session.commit() diff --git a/docs/customizing.rst b/docs/customizing.rst index 8eb2290a..a0c42613 100644 --- a/docs/customizing.rst +++ b/docs/customizing.rst @@ -1,45 +1,35 @@ -.. _customizing: +Advanced Customization +====================== -.. currentmodule:: flask_sqlalchemy - -Customizing -=========== - -Flask-SQLAlchemy defines sensible defaults. However, sometimes customization is -needed. There are various ways to customize how the models are defined and -interacted with. - -These customizations are applied at the creation of the :class:`SQLAlchemy` -object and extend to all models derived from its ``Model`` class. +The various objects managed by the extension can be customized by passing arguments to +the :class:`.SQLAlchemy` constructor. Model Class ----------- -SQLAlchemy models all inherit from a declarative base class. This is exposed -as ``db.Model`` in Flask-SQLAlchemy, which all models extend. This can be -customized by subclassing the default and passing the custom class to -``model_class``. +SQLAlchemy models all inherit from a declarative base class. This is exposed as +``db.Model`` in Flask-SQLAlchemy, which all models extend. This can be customized by +subclassing the default and passing the custom class to ``model_class``. -The following example gives every model an integer primary key, or a foreign -key for joined-table inheritance. +The following example gives every model an integer primary key, or a foreign key for +joined-table inheritance. .. note:: + Integer primary keys for everything is not necessarily the best database design + (that's up to your project's requirements), this is only an example. - Integer primary keys for everything is not necessarily the best database - design (that's up to your project's requirements), this is only an example. +.. code-block:: python -:: - - from flask_sqlalchemy import Model, SQLAlchemy + from flask_sqlalchemy.model import model import sqlalchemy as sa - from sqlalchemy.ext.declarative import declared_attr + import sqlalchemy.orm class IdModel(Model): - @declared_attr + @sa.orm.declared_attr def id(cls): for base in cls.__mro__[1:-1]: - if getattr(base, '__table__', None) is not None: + if getattr(base, "__table__", None) is not None: type = sa.ForeignKey(base.id) break else: @@ -56,101 +46,147 @@ key for joined-table inheritance. title = db.Column(db.String) -Model Mixins ------------- +Abstract Models and Mixins +-------------------------- + +If behavior is only needed on some models rather than all models, use an abstract model +base class to customize only those models. For example, if some models should track when +they are created or updated. -If behavior is only needed on some models rather than all models, use mixin -classes to customize only those models. For example, if some models should -track when they are created or updated:: +.. code-block:: python from datetime import datetime - class TimestampMixin(object): - created = db.Column( - db.DateTime, nullable=False, default=datetime.utcnow) + class TimestampModel(db.Model): + __abstract__ = True + created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) updated = db.Column(db.DateTime, onupdate=datetime.utcnow) class Author(db.Model): ... + class Post(TimestampModel): + ... + +This can also be done with a mixin class, inheriting from ``db.Model`` separately. + +.. code-block:: python + + class TimestampMixin: + created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) + updated = db.Column(db.DateTime, onupdate=datetime.utcnow) + class Post(TimestampMixin, db.Model): ... +Session Class +------------- + +Flask-SQLAlchemy's :class:`.Session` class chooses which engine to query based on the +bind key associated with the model or table. However, there are other strategies such as +horizontal sharding that can be implemented with a different session class. The +``class_`` key to the ``session_options`` argument to the extension to change the +session class. + +Flask-SQLAlchemy will always pass the extension instance as the ``db`` argument to the +session, so it must accept that to continue working. That can be used to get access to +``db.engines``. + +.. code-block:: python + + from sqlalchemy.ext.horizontal_shard import ShardedSession + from flask_sqlalchemy.session import Session + + class CustomSession(ShardedSession, Session): + ... + + db = SQLAlchemy(session_options={"class_": CustomSession}) + + Query Class ----------- -It is also possible to customize what is available for use on the -special ``query`` property of models. For example, providing a -``get_or`` method:: +.. warning:: + The query interface is considered legacy in SQLAlchemy 2.0. This includes + ``session.query``, ``Model.query``, ``db.Query``, and ``lazy="dynamic"`` + relationships. Prefer using selects instead of the query class. + +It is possible to customize the query interface used by the session, models, and +relationships. This can be used to add extra query methods. For example, you could add +a ``get_or`` method that gets a row or returns a default. + +.. code-block:: python - from flask_sqlalchemy import Query, SQLAlchemy + from flask_sqlalchemy.query import Query class GetOrQuery(Query): def get_or(self, ident, default=None): - return self.get(ident) or default + out = self.get(ident) + + if out is None: + return default + + return out db = SQLAlchemy(query_class=GetOrQuery) - # get a user by id, or return an anonymous user instance user = User.query.get_or(user_id, anonymous_user) -And now all queries executed from the special ``query`` property -on Flask-SQLAlchemy models can use the ``get_or`` method as part -of their queries. All relationships defined with -``db.relationship`` (but not :func:`sqlalchemy.orm.relationship`) -will also be provided with this functionality. +Passing the ``query_class`` argument will customize ``db.Query``, ``db.session.query``, +``Model.query``, and ``db.relationship(lazy="dynamic")`` relationships. It's also +possible to customize these on a per-object basis. -It also possible to define a custom query class for individual -relationships as well, by providing the ``query_class`` keyword -in the definition. This works with both ``db.relationship`` -and ``sqlalchemy.relationship``:: +To customize a specific model's ``query`` property, set the ``query_class`` attribute on +the model class. - class MyModel(db.Model): - cousin = db.relationship('OtherModel', query_class=GetOrQuery) +.. code-block:: python -.. note:: + class User(db.Model): + query_class = GetOrQuery - If a query class is defined on a relationship, it will take precedence over - the query class attached to its corresponding model. +To customize a specific dynamic relationship, pass the ``query_class`` argument to the +relationship. -It is also possible to define a specific query class for individual models -by overriding the ``query_class`` class attribute on the model:: +.. code-block:: python - class MyModel(db.Model): - query_class = GetOrQuery + db.relationship(User, lazy="dynamic", query_class=GetOrQuery) + +To customize only ``session.query``, pass the ``query_cls`` key to the +``session_options`` argument to the constructor. + +.. code-block:: python -In this case, the ``get_or`` method will be only availble on queries -orginating from ``MyModel.query``. + db = SQLAlchemy(session_options={"query_cls": GetOrQuery}) Model Metaclass --------------- .. warning:: + Metaclasses are an advanced topic, and you probably don't need to customize them to + achieve what you want. It is mainly documented here to show how to disable table + name generation. - Metaclasses are an advanced topic, and you probably don't need to customize - them to achieve what you want. It is mainly documented here to show how to - disable table name generation. +The model metaclass is responsible for setting up the SQLAlchemy internals when defining +model subclasses. Flask-SQLAlchemy adds some extra behaviors through mixins; its default +metaclass, :class:`~.DefaultMeta`, inherits them all. -The model metaclass is responsible for setting up the SQLAlchemy internals when -defining model subclasses. Flask-SQLAlchemy adds some extra behaviors through -mixins; its default metaclass, :class:`~model.DefaultMeta`, inherits them all. - -* :class:`~model.BindMetaMixin`: ``__bind_key__`` is extracted from the class - and applied to the table. See :ref:`binds`. -* :class:`~model.NameMetaMixin`: If the model does not specify a - ``__tablename__`` but does specify a primary key, a name is automatically - generated. +- :class:`.BindMetaMixin`: ``__bind_key__`` sets the bind to use for the model. +- :class:`.NameMetaMixin`: If the model does not specify a ``__tablename__`` but does + specify a primary key, a name is automatically generated. You can add your own behaviors by defining your own metaclass and creating the -declarative base yourself. Be sure to still inherit from the mixins you want -(or just inherit from the default metaclass). +declarative base yourself. Be sure to still inherit from the mixins you want (or just +inherit from the default metaclass). + +Passing a declarative base class instead of a simple model base class to ``model_class`` +will cause Flask-SQLAlchemy to use this base instead of constructing one with the +default metaclass. -Passing a declarative base class instead of a simple model base class, as shown -above, to ``base_class`` will cause Flask-SQLAlchemy to use this base instead -of constructing one with the default metaclass. :: +.. code-block:: python + from sqlalchemy.orm import declarative_base from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.model import DefaultMeta, Model @@ -163,28 +199,32 @@ of constructing one with the default metaclass. :: # custom class-only methods could go here - db = SQLAlchemy(model_class=declarative_base( - cls=Model, metaclass=CustomMeta, name='Model')) + CustomModel = declarative_base(cls=Model, metaclass=CustomMeta, name="Model") + db = SQLAlchemy(model_class=CustomModel) You can also pass whatever other arguments you want to -:func:`~sqlalchemy.ext.declarative.declarative_base` to customize the base -class as needed. +:func:`~sqlalchemy.orm.declarative_base` to customize the base class. + Disabling Table Name Generation ``````````````````````````````` -Some projects prefer to set each model's ``__tablename__`` manually rather than -relying on Flask-SQLAlchemy's detection and generation. The table name -generation can be disabled by defining a custom metaclass. :: +Some projects prefer to set each model's ``__tablename__`` manually rather than relying +on Flask-SQLAlchemy's detection and generation. The simple way to achieve that is to +set each ``__tablename__`` and not modify the base class. However, the table name +generation can be disabled by defining a custom metaclass with only the +``BindMetaMixin`` and not the ``NameMetaMixin``. + +.. code-block:: python + from sqlalchemy.orm import DeclarativeMeta, declarative_base from flask_sqlalchemy.model import BindMetaMixin, Model - from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base class NoNameMeta(BindMetaMixin, DeclarativeMeta): pass - db = SQLAlchemy(model_class=declarative_base( - cls=Model, metaclass=NoNameMeta, name='Model')) + CustomModel = declarative_base(cls=Model, metaclass=NoNameMeta, name="Model") + db = SQLAlchemy(model_class=CustomModel) -This creates a base that still supports the ``__bind_key__`` feature but does -not generate table names. +This creates a base that still supports the ``__bind_key__`` feature but does not +generate table names. diff --git a/docs/index.rst b/docs/index.rst index eb5d55e6..bbbace9a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,48 +6,52 @@ Flask-SQLAlchemy .. image:: _static/flask-sqlalchemy-title.png :align: center -Flask-SQLAlchemy is an extension for `Flask`_ that adds support for `SQLAlchemy`_ to your -application. It aims to simplify using SQLAlchemy with Flask by providing useful defaults and extra -helpers that make it easier to accomplish common tasks. +Flask-SQLAlchemy is an extension for `Flask`_ that adds support for `SQLAlchemy`_ to +your application. It simplifies using SQLAlchemy with Flask by setting up common objects +and patterns for using those objects, such as a session tied to each web request, models, +and engines. -.. _SQLAlchemy: https://www.sqlalchemy.org/ -.. _Flask: https://palletsprojects.com/p/flask/ - -See `the SQLAlchemy documentation`_ to learn how to work with the ORM in depth. The following -documentation is a brief overview of the most common tasks, as well as the features specific to -Flask-SQLAlchemy. +Flask-SQLAlchemy does not change how SQLAlchemy works or is used. See the +`SQLAlchemy documentation`_ to learn how to work with the ORM in depth. The +documentation here will only cover setting up the extension, not how to use SQLAlchemy. -.. _the SQLAlchemy documentation: https://docs.sqlalchemy.org/en/latest/ +.. _SQLAlchemy: https://www.sqlalchemy.org/ +.. _Flask: https://flask.palletsprojects.com/ +.. _SQLAlchemy documentation: https://docs.sqlalchemy.org/ User Guide ---------- .. toctree:: - :maxdepth: 2 + :maxdepth: 2 + + quickstart + config + models + queries + pagination + contexts + binds + record-queries + track-modifications + customizing - quickstart - contexts - config - models - queries - binds - signals - customizing API Reference ------------- .. toctree:: - :maxdepth: 2 + :maxdepth: 2 + + api - api Additional Information ---------------------- .. toctree:: - :maxdepth: 2 + :maxdepth: 2 - license - changes + license + changes diff --git a/docs/make.bat b/docs/make.bat index 7893348a..954237b9 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -10,8 +10,6 @@ if "%SPHINXBUILD%" == "" ( set SOURCEDIR=. set BUILDDIR=_build -if "%1" == "" goto help - %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. @@ -21,15 +19,17 @@ if errorlevel 9009 ( echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ + echo.https://www.sphinx-doc.org/ exit /b 1 ) -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd diff --git a/docs/models.rst b/docs/models.rst index c355e4bc..71a1a28b 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -1,169 +1,94 @@ -.. _models: +Models and Tables +================= -.. currentmodule:: flask_sqlalchemy +Use the ``db.Model`` class to define models, or the ``db.Table`` class to create tables. +Both handle Flask-SQLAlchemy's bind keys to associate with a specific engine. -Declaring Models -================ -Generally Flask-SQLAlchemy behaves like a properly configured declarative -base from the :mod:`~sqlalchemy.ext.declarative` extension. As such we -recommend reading the SQLAlchemy docs for a full reference. However the -most common use cases are also documented here. +Defining Models +--------------- -Things to keep in mind: +See SQLAlchemy's `declarative documentation`_ for full information about defining model +classes declaratively. -- The baseclass for all your models is called ``db.Model``. It's stored - on the SQLAlchemy instance you have to create. See :ref:`quickstart` - for more details. -- Some parts that are required in SQLAlchemy are optional in - Flask-SQLAlchemy. For instance the table name is automatically set - for you unless overridden. It's derived from the class name converted - to lowercase and with “CamelCase” converted to “camel_case”. To override - the table name, set the ``__tablename__`` class attribute. +.. _declarative documentation: https://docs.sqlalchemy.org/orm/declarative_tables.html -Simple Example --------------- +Subclass ``db.Model`` to create a model class. This is a SQLAlchemy declarative base +class, it will take ``Column`` attributes and create a table. Unlike plain SQLAlchemy, +Flask-SQLAlchemy's model will automatically generate a table name if ``__tablename__`` +is not set and a primary key column is defined. -A very simple example:: +.. code-block:: python + + import sqlalchemy as sa class User(db.Model): - id = db.Column(db.Integer, primary_key=True) - username = db.Column(db.String(80), unique=True, nullable=False) - email = db.Column(db.String(120), unique=True, nullable=False) - - def __repr__(self): - return f"" - -Use :class:`~sqlalchemy.schema.Column` to define a column. The name of the -column is the name you assign it to. If you want to use a different name -in the table you can provide an optional first argument which is a string -with the desired column name. Primary keys are marked with -``primary_key=True``. Multiple keys can be marked as primary keys in -which case they become a compound primary key. - -The types of the column are the first argument to -:class:`~sqlalchemy.schema.Column`. You can either provide them directly -or call them to further specify them (like providing a length). The -following types are the most common: - -================================================ ===================================== -:class:`~sqlalchemy.types.Integer` an integer -:class:`String(size) ` a string, size is optional in some - databases, including SQLite and - PostgreSQL -:class:`~sqlalchemy.types.Text` some longer text -:class:`~sqlalchemy.types.DateTime` date and time expressed as Python - :class:`~datetime.datetime` object. -:class:`~sqlalchemy.types.Float` stores floating point values -:class:`~sqlalchemy.types.Boolean` stores a boolean value -:class:`~sqlalchemy.types.PickleType` stores a pickled Python object -:class:`~sqlalchemy.types.LargeBinary` stores large arbitrary binary data -================================================ ===================================== - -One-to-Many Relationships -------------------------- - -The most common relationships are one-to-many relationships. Because -relationships are declared before they are established you can use strings -to refer to classes that are not created yet (for instance if ``Person`` -defines a relationship to ``Address`` which is declared later in the file). - -Relationships are expressed with the :func:`~sqlalchemy.orm.relationship` -function. However the foreign key has to be separately declared with the -:class:`~sqlalchemy.schema.ForeignKey` class:: - - class Person(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(50), nullable=False) - addresses = db.relationship('Address', backref='person', lazy=True) - - class Address(db.Model): - id = db.Column(db.Integer, primary_key=True) - email = db.Column(db.String(120), nullable=False) - person_id = db.Column(db.Integer, db.ForeignKey('person.id'), - nullable=False) - -What does :func:`db.relationship() ` do? -That function returns a new property that can do multiple things. -In this case we told it to point to the ``Address`` class and load -multiple of those. How does it know that this will return more than -one address? Because SQLAlchemy guesses a useful default from your -declaration. If you would want to have a one-to-one relationship you -can pass ``uselist=False`` to :func:`~sqlalchemy.orm.relationship`. - -Since a person with no name or an email address with no address associated -makes no sense, ``nullable=False`` tells SQLAlchemy to create the column -as ``NOT NULL``. This is implied for primary key columns, but it's a good -idea to specify it for all other columns to make it clear to other people -working on your code that you did actually want a nullable column and did -not just forget to add it. - -So what do ``backref`` and ``lazy`` mean? ``backref`` is a simple way to also -declare a new property on the ``Address`` class. You can then also use -``my_address.person`` to get to the person at that address. ``lazy`` defines -when SQLAlchemy will load the data from the database: - -- ``'select'`` / ``True`` (which is the default, but explicit is better - than implicit) means that SQLAlchemy will load the data as necessary - in one go using a standard ``SELECT`` statement. -- ``'joined'`` / ``False`` tells SQLAlchemy to load the relationship in - the same query as the parent using a ``JOIN`` statement. -- ``'subquery'`` works like ``'joined'`` but instead SQLAlchemy will - use a subquery. -- ``'dynamic'`` is special and can be useful if you have many items - and always want to apply additional SQL filters to them. - Instead of loading the items SQLAlchemy will return another query - object which you can further refine before loading the items. - Note that this cannot be turned into a different loading strategy - when querying so it's often a good idea to avoid using this in - favor of ``lazy=True``. A query object equivalent to a dynamic - ``user.addresses`` relationship can be created using - :meth:`Address.query.with_parent(user) ` - while still being able to use - lazy or eager loading on the relationship itself as necessary. - -How do you define the lazy status for backrefs? By using the -:func:`~sqlalchemy.orm.backref` function:: - - class Person(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(50), nullable=False) - addresses = db.relationship('Address', lazy='select', - backref=db.backref('person', lazy='joined')) - -Many-to-Many Relationships --------------------------- - -If you want to use many-to-many relationships you will need to define a -helper table that is used for the relationship. For this helper table it -is strongly recommended to *not* use a model but an actual table:: - - tags = db.Table('tags', - db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True), - db.Column('page_id', db.Integer, db.ForeignKey('page.id'), primary_key=True) + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(sa.String) + +For convenience, the extension object provides access to names in the ``sqlalchemy`` and +``sqlalchemy.orm`` modules. So you can use ``db.Column`` instead of importing and using +``sqlalchemy.Column``, although the two are equivalent. + +Defining a model does not create it in the database. Use :meth:`~.SQLAlchemy.create_all` +to create the models and tables after defining them. If you define models in submodules, +you must import them so that SQLAlchemy knows about them before calling ``create_all``. + +.. code-block:: python + + with app.app_context(): + db.create_all() + + +Defining Tables +--------------- + +See SQLAlchemy's `table documentation`_ for full information about defining table +objects. + +.. _table documentation: https://docs.sqlalchemy.org/core/metadata.html + +Create instances of ``db.Table`` to define tables. The class takes a table name, then +any columns and other table parts such as columns and constraints. Unlike plain +SQLAlchemy, the ``metadata`` argument is not required. A metadata will be chosen based +on the ``bind_key`` argument, or the default will be used. + +A common reason to create a table directly is when defining many to many relationships. +The association table doesn't need its own model class, as it will be accessed through +the relevant relationship attributes on the related models. + +.. code-block:: python + + import sqlalchemy as sa + + user_book_m2m = db.Table( + "user_book", + sa.Column("user_id", sa.ForeignKey(User.id), primary_key=True), + sa.Column("book_id", sa.ForeignKey(Book.id), primary_key=True), ) - class Page(db.Model): - id = db.Column(db.Integer, primary_key=True) - tags = db.relationship('Tag', secondary=tags, lazy='subquery', - backref=db.backref('pages', lazy=True)) - - class Tag(db.Model): - id = db.Column(db.Integer, primary_key=True) - -Here we configured ``Page.tags`` to be loaded immediately after loading -a Page, but using a separate query. This always results in two -queries when retrieving a Page, but when querying for multiple pages -you will not get additional queries. - -The list of pages for a tag on the other hand is something that's -rarely needed. For example, you won't need that list when retrieving -the tags for a specific page. Therefore, the backref is set to be -lazy-loaded so that accessing it for the first time will trigger a -query to get the list of pages for that tag. If you need to apply -further query options on that list, you could either switch to the -``'dynamic'`` strategy - with the drawbacks mentioned above - or get -a query object using -:meth:`Page.query.with_parent(some_tag) ` -and then use it exactly as you would with the query object from a dynamic -relationship. + +Reflecting Tables +----------------- + +If you are connecting to a database that already has tables, SQLAlchemy can detect that +schema and create tables with columns automatically. This is called reflection. Those +tables can also be assigned to model classes with the ``__table__`` attribute instead of +defining the full model. + +Call the :meth:`~.SQLAlchemy.reflect` method on the extension. It will reflect all the +tables for each bind key. Each metadata's ``table`` attribute will contain the detected +table objects. + +.. code-block:: python + + with app.app_context(): + db.reflect() + + class User: + __table__ = db.metadata["user"] + +In most cases, it will be more maintainable to define the model classes yourself. You +only need to define the models and columns you will actually use, even if you're +connecting to a broader schema. IDEs will know the available attributes, and migration +tools like Alembic can detect changes and generate schema migrations. diff --git a/docs/pagination.rst b/docs/pagination.rst new file mode 100644 index 00000000..b985803e --- /dev/null +++ b/docs/pagination.rst @@ -0,0 +1,74 @@ +Paging Query Results +==================== + +If you have a lot of results, you may only want to show a certain number at a time, +allowing the user to click next and previous links to see pages of data. This is +sometimes called *pagination*, and uses the verb *paginate*. + +Pagination is currently available through the ``Model.query`` and ``session.query`` +interfaces by calling the :meth:`.Query.paginate` method. This returns a +:class:`.Pagination` object. + +During a request, this will take ``page`` and ``per_page`` arguments from the query +string ``request.args``. Pass ``max_per_page`` to prevent users from requesting too many +results on a single page. If not given, the default values will be page 1 with 20 items +per page. + +.. code-block:: python + + page = User.query.order_by(User.join_date).paginate() + return render_template("user/list.html", page=page) + + +Showing the Items +----------------- + +The :class:`.Pagination` object's :attr:`.Pagination.items` attribute is the list of +items for the current page. The object can also be iterated over directly. + +.. code-block:: jinja + +
    + {% for user in page %} +
  • {{ user.username }} + {% endfor %} +
+ + +Page Selection Widget +--------------------- + +The :class:`.Pagination` object has attributes that can be used to create a page +selection widget by iterating over page numbers and checking the current page. +:meth:`~.Pagination.iter_pages` will produce up to three groups of numbers, separated by +``None``. It defaults to showing 2 page numbers at either edge, 2 numbers before the +current, the current, and 4 numbers after the current. For example, if there are 20 +pages and the current page is 7, the following values are yielded. + +.. code-block:: python + + users.iter_pages() + [1, 2, None, 5, 6, 7, 8, 9, 10, 11, None, 19, 20] + +The following Jinja macro renders a simple pagination widget. + +.. code-block:: jinja + + {% macro render_pagination(pagination, endpoint) %} + + {% endmacro %} + +You might also use the :attr:`~.Pagination.total` attribute to show the total number of +results. diff --git a/docs/queries.rst b/docs/queries.rst index 05e69c67..de2b887f 100644 --- a/docs/queries.rst +++ b/docs/queries.rst @@ -1,133 +1,103 @@ -.. currentmodule:: flask_sqlalchemy +Modifying and Querying Data +=========================== -Select, Insert, Delete -====================== -Now that you have :ref:`declared models ` it's time to query the -data from the database. We will be using the model definitions from the -:ref:`quickstart` chapter. +Insert, Update, Delete +---------------------- -Inserting Records ------------------ +See SQLAlchemy's `ORM tutorial`_ and other SQLAlchemy documentation for more information +about modifying data with the ORM. -Before we can query something we will have to insert some data. All your -models should have a constructor, so make sure to add one if you forgot. -Constructors are only used by you, not by SQLAlchemy internally so it's -entirely up to you how you define them. +.. _ORM tutorial: https://docs.sqlalchemy.org/tutorial/orm_data_manipulation.html -Inserting data into the database is a three step process: +To insert data, pass the model object to ``db.session.add()``: -1. Create the Python object -2. Add it to the session -3. Commit the session +.. code-block:: python -The session here is not the Flask session, but the Flask-SQLAlchemy one. -It is essentially a beefed up version of a database transaction. This is -how it works: + user = User() + db.session.add(user) + db.session.commit() ->>> from yourapp import User ->>> me = User(username='admin', email='admin@example.com') ->>> db.session.add(me) ->>> db.session.commit() +To update data, modify attributes on the model objects: -Alright, that was not hard. What happens at what point? Before you add -the object to the session, SQLAlchemy basically does not plan on adding it -to the transaction. That is good because you can still discard the -changes. For example think about creating the post at a page but you only -want to pass the post to the template for preview rendering instead of -storing it in the database. +.. code-block:: python -The :func:`~sqlalchemy.orm.session.Session.add` function call then adds -the object. It will issue an `INSERT` statement for the database but -because the transaction is still not committed you won't get an ID back -immediately. If you do the commit, your user will have an ID: + user.verified = True + db.session.commit() ->>> me.id -1 +To delete data, pass the model object to ``db.session.delete()``: -Deleting Records ----------------- +.. code-block:: python -Deleting records is very similar, instead of -:func:`~sqlalchemy.orm.session.Session.add` use -:func:`~sqlalchemy.orm.session.Session.delete`: + db.session.delete(user) + db.session.commit() ->>> db.session.delete(me) ->>> db.session.commit() +After modifying data, you must call ``db.session.commit()`` to commit the changes to +the database. Otherwise, they will be discarded at the end of the request. -Querying Records ----------------- -So how do we get data back out of our database? For this purpose -Flask-SQLAlchemy provides a :attr:`~Model.query` attribute on your -:class:`Model` class. When you access it you will get back a new query -object over all records. You can then use methods like -:func:`~sqlalchemy.orm.query.Query.filter` to filter the records before -you fire the select with :func:`~sqlalchemy.orm.query.Query.all` or -:func:`~sqlalchemy.orm.query.Query.first`. If you want to go by -primary key you can also use :func:`~sqlalchemy.orm.query.Query.get`. +Select +------ -The following queries assume following entries in the database: +See SQLAlchemy's `Querying Guide`_ and other SQLAlchemy documentation for more +information about querying data with the ORM. -=========== =========== ===================== -`id` `username` `email` -1 admin admin@example.com -2 peter peter@example.org -3 guest guest@example.com -=========== =========== ===================== +.. _Querying Guide: https://docs.sqlalchemy.org/orm/queryguide.html -Retrieve a user by username: +Queries are executed through ``db.session.execute()``. They can be constructed +using :func:`~sqlalchemy.sql.expression.select`. Executing a select returns a +:class:`~sqlalchemy.engine.Result` object that has many methods for working with the +returned rows. ->>> peter = User.query.filter_by(username='peter').first() ->>> peter.id -2 ->>> peter.email -'peter@example.org' +.. code-block:: python -Same as above but for a non existing username gives `None`: + user = db.session.execute(db.select(User).filter_by(username=username)).one() ->>> missing = User.query.filter_by(username='missing').first() ->>> missing is None -True -Selecting a bunch of users by a more complex expression: +Legacy Query Interface +---------------------- ->>> User.query.filter(User.email.endswith('@example.com')).all() -[, ] +.. warning:: + SQLAlchemy 2.0 has designated the ``Query`` interface as "legacy". It will no + longer be updated and may be deprecated in the future. Prefer using + ``db.session.execute(db.select(...))`` instead. -Ordering users by something: +Flask-SQLAlchemy adds a ``query`` object to each model. This can be used to query +instances of a given model. ``User.query`` is a shortcut for ``db.session.query(User)``. ->>> User.query.order_by(User.username).all() -[, , ] +.. code-block:: python -Limiting users: + # get the user with id 5 + user = User.query.get(5) ->>> User.query.limit(1).all() -[] + # get a user by username + user = User.query.filter_by(username=username).one() -Getting user by primary key: ->>> User.query.get(1) - +Queries for Views +````````````````` +If you write a Flask view function it's often useful to return a ``404 Not Found`` error +for missing entries. Flask-SQLAlchemy provides some extra query methods. -Queries in Views ----------------- +- :meth:`.Query.get_or_404` will raise a 404 if the row with the given id doesn't + exist, otherwise it will return the instance. +- :meth:`.Query.first_or_404` will raise a 404 if the query does not return any + results, otherwise it will return the first result. +- :meth:`.Query.one_or_404` will raise a 404 if the query does not return exactly one + result, otherwise it will return the result. -If you write a Flask view function it's often very handy to return a 404 -error for missing entries. Because this is a very common idiom, -Flask-SQLAlchemy provides a helper for this exact purpose. Instead of -:meth:`~sqlalchemy.orm.query.Query.get` one can use -:meth:`~Query.get_or_404` and instead of -:meth:`~sqlalchemy.orm.query.Query.first` :meth:`~Query.first_or_404`. -This will raise 404 errors instead of returning `None`:: +.. code-block:: python - @app.route('/user/') + @app.route("/user/") def show_user(username): - user = User.query.filter_by(username=username).first_or_404() - return render_template('show_user.html', user=user) + user = User.query.filter_by(username=username).one_or_404() + return render_template("show_user.html", user=user) +You can add a custom message to the 404 error: -Also, if you want to add a description with abort(), you can use it as argument as well. + .. code-block:: python ->>> User.query.filter_by(username=username).first_or_404(description=f"There is no data with {username}") + user = User.query.filter_by(username=username).one_or_404( + description=f"No user named '{username}'." + ) diff --git a/docs/quickstart.rst b/docs/quickstart.rst index acbeaffd..1dcb0871 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -1,187 +1,198 @@ .. _quickstart: -Quickstart -========== +Quick Start +=========== .. currentmodule:: flask_sqlalchemy -Flask-SQLAlchemy is fun to use, incredibly easy for basic applications, and -readily extends for larger applications. For the complete guide, checkout -the API documentation on the :class:`SQLAlchemy` class. +Flask-SQLAlchemy simplifies using SQLAlchemy by automatically handling creating, using, +and cleaning up the SQLAlchemy objects you'd normally work with. While it adds a few +useful features, it still works like SQLAlchemy. + +This page will walk you through the basic use of Flask-SQLAlchemy. For full capabilities +and customization, see the rest of these docs, including the API docs for the +:class:`SQLAlchemy` object. + + +Check the SQLAlchemy Documentation +---------------------------------- + +Flask-SQLAlchemy is a wrapper around SQLAlchemy. You should follow the +`SQLAlchemy Tutorial`_ to learn about how to use it, and consult its documentation +for detailed information about its features. These docs show how to set up +Flask-SQLAlchemy itself, not how to use SQLAlchemy. Flask-SQLAlchemy sets up the +engine, declarative model class, and scoped session automatically, so you can skip those +parts of the SQLAlchemy tutorial. + +.. _SQLAlchemy Tutorial: https://docs.sqlalchemy.org/tutorial/index.html + Installation ------------ -Install and update using `pip `_:: +Flask-SQLAlchemy is available on `PyPI`_ and can be installed with various Python tools. +For example, to install or update the latest version using pip: + +.. code-block:: text $ pip install -U Flask-SQLAlchemy -A Minimal Application ---------------------- +.. _PyPI: https://pypi.org/project/Flask-SQLAlchemy/ + + +Configure the Extension +----------------------- -For the common case of having one Flask application all you have to do is -to create your Flask application, load the configuration of choice and -then create the :class:`SQLAlchemy` object by passing it the application. +The only required Flask app config is the :data:`.SQLALCHEMY_DATABASE_URI` key. That +is a connection string that tells SQLAlchemy what database to connect to. -Once created, that object then contains all the functions and helpers -from both :mod:`sqlalchemy` and :mod:`sqlalchemy.orm`. Furthermore it -provides a class called ``Model`` that is a declarative base which can be -used to declare models:: +Create your Flask application object, load any config, and then initialize the +:class:`SQLAlchemy` extension class with the application by calling +:meth:`db.init_app <.SQLAlchemy.init_app>`. This example connects to a SQLite database, +which is stored in the app's instance folder. + +.. code-block:: python from flask import Flask from flask_sqlalchemy import SQLAlchemy + # create the extension + db = SQLAlchemy() + # create the app app = Flask(__name__) - app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:////tmp/test.db' - db = SQLAlchemy(app) - - - class User(db.Model): - id = db.Column(db.Integer, primary_key=True) - username = db.Column(db.String(80), unique=True, nullable=False) - email = db.Column(db.String(120), unique=True, nullable=False) + # configure the SQLite database, relative to the app instance folder + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///project.db" + # initialize the app with the extension + db.init_app(app) - def __repr__(self): - return f"" +The ``db`` object gives you access to the :attr:`db.Model <.SQLAlchemy.Model>` class to +define models, and the :attr:`db.session <.SQLAlchemy.session>` to execute queries. -To create the initial database schema, just import the ``db`` object from an -interactive Python shell and run the -:meth:`SQLAlchemy.create_all` method to create the -tables (note that SQLite will create the database as well, this is not true in general):: +See :doc:`config` for an explanation of connections strings and what other configuration +keys are used. The :class:`SQLAlchemy` object also takes some arguments to customize the +objects it manages. - >>> from yourapplication import db - >>> db.create_all() -Boom, and there is your database. Now to create some users:: +Define Models +------------- - >>> from yourapplication import User - >>> admin = User(username='admin', email='admin@example.com') - >>> guest = User(username='guest', email='guest@example.com') +Subclass ``db.Model`` to define a model class. The ``db`` object makes the names in +``sqlalchemy`` and ``sqlalchemy.orm`` available for convenience, such as ``db.Column``. +The model will generate a table name by converting the ``CamelCase`` class name to +``snake_case``. -But they are not yet in the database, so let's make sure they are:: +.. code-block:: python - >>> db.session.add(admin) - >>> db.session.add(guest) - >>> db.session.commit() + class User(db.Model): + id = db.Column(db.Integer, primary_key=True) + username = db.Column(db.String, unique=True, nullable=False) + email = db.Column(db.String) -Accessing the data in database is easy as a pie:: +The table name ``"user"`` will automatically be assigned to the model's table. - >>> User.query.all() - [, ] - >>> User.query.filter_by(username='admin').first() - +See :doc:`models` for more information about defining and creating models and tables. -Note how we never defined a ``__init__`` method on the ``User`` class? -That's because SQLAlchemy adds an implicit constructor to all model -classes which accepts keyword arguments for all its columns and -relationships. If you decide to override the constructor for any -reason, make sure to keep accepting ``**kwargs`` and call the super -constructor with those ``**kwargs`` to preserve this behavior:: - class Foo(db.Model): - # ... - def __init__(self, **kwargs): - super(Foo, self).__init__(**kwargs) - # do custom stuff +Create the Tables +----------------- -Simple Relationships --------------------- +After all models and tables are defined, call :meth:`.SQLAlchemy.create_all` to create +the table schema in the database. This requires an application context. Since you're not +in a request at this point, create one manually. -SQLAlchemy connects to relational databases and what relational databases -are really good at are relations. As such, we shall have an example of an -application that uses two tables that have a relationship to each other:: +.. code-block:: python - from datetime import datetime + with app.app_context(): + db.create_all() +If you define models in other modules, you must import them before calling +``create_all``, otherwise SQLAlchemy will not know about them. - class Post(db.Model): - id = db.Column(db.Integer, primary_key=True) - title = db.Column(db.String(80), nullable=False) - body = db.Column(db.Text, nullable=False) - pub_date = db.Column(db.DateTime, nullable=False, - default=datetime.utcnow) +``create_all`` does not update tables if they are already in the database. If you change +a model's columns, use a migration library like `Alembic`_ with `Flask-Alembic`_ or +`Flask-Migrate`_ to generate migrations that update the database schema. - category_id = db.Column(db.Integer, db.ForeignKey('category.id'), - nullable=False) - category = db.relationship('Category', - backref=db.backref('posts', lazy=True)) +.. _Alembic: https://alembic.sqlalchemy.org/ +.. _Flask-Alembic: https://flask-alembic.readthedocs.io/ +.. _Flask-Migrate: https://flask-migrate.readthedocs.io/ - def __repr__(self): - return f"" +Query the Data +-------------- - class Category(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(50), nullable=False) +Within a Flask view or CLI command, you can use ``db.session`` to execute queries and +modify model data. - def __repr__(self): - return f"" +SQLAlchemy automatically defines an ``__init__`` method for each model that assigns any +keyword arguments to corresponding database columns and other attributes. -First let's create some objects:: +``db.session.add(obj)`` adds an object to the session, to be inserted. Modifying an +object's attributes updates the object. ``db.session.delete(obj)`` deletes an object. +Remember to call ``db.session.commit()`` after modifying, adding, or deleting any data. - >>> py = Category(name='Python') - >>> Post(title='Hello Python!', body='Python is pretty cool', category=py) - >>> p = Post(title='Snakes', body='Ssssssss') - >>> py.posts.append(p) - >>> db.session.add(py) +``db.session.execute(db.select(...))`` constructs a query to select data from the +database. Building queries is the main feature of SQLAlchemy, so you'll want to read its +`tutorial on select`_ to learn all about it. -As you can see, there is no need to add the ``Post`` objects to the -session. Since the ``Category`` is part of the session all objects -associated with it through relationships will be added too. It does -not matter whether :meth:`db.session.add() ` -is called before or after creating these objects. The association can -also be done on either side of the relationship - so a post can be -created with a category or it can be added to the list of posts of -the category. +.. _tutorial on select: https://docs.sqlalchemy.org/tutorial/data_select.html -Let's look at the posts. Accessing them will load them from the database -since the relationship is lazy-loaded, but you will probably not notice -the difference - loading a list is quite fast:: +.. code-block:: python - >>> py.posts - [, ] + @app.route("/users") + def user_list(): + users = db.session.execute(db.select(User).order_by(User.username)).all() + return render_template("user/list.html", users=users) -While lazy-loading a relationship is fast, it can easily become a major -bottleneck when you end up triggering extra queries in a loop for more -than a few objects. For this case, SQLAlchemy lets you override the -loading strategy on the query level. If you wanted a single query to -load all categories and their posts, you could do it like this:: + @app.route("/users/create", methods=["GET", "POST"]) + def user_create(): + if request.method == "POST": + user = User( + username=request.form["username"], + email=request.form["email"], + ) + db.session.add(user) + db.session.commit() + return redirect(url_for("user_detail", id=user.id)) - >>> from sqlalchemy.orm import joinedload - >>> query = Category.query.options(joinedload('posts')) - >>> for category in query: - ... print(category, category.posts) - [, ] + return render_template("user/create.html") + @app.route("/user/") + def user_detail(id): + user = User.query.get_or_404(id) + return render_template("user/detail.html", user=user) -If you want to get a query object for that relationship, you can do so -using :meth:`~sqlalchemy.orm.query.Query.with_parent`. Let's exclude -that post about Snakes for example:: + @app.route("/user//delete", methods=["GET", "POST"]) + def user_delete(id): + user = User.query.get_or_404(id) - >>> Post.query.with_parent(py).filter(Post.title != 'Snakes').all() - [] + if request.method == "POST": + db.session.delete(user) + db.session.commit + return redirect(url_for("user_list")) + return render_template("user/delete.html", user=user) -Road to Enlightenment ---------------------- +You may see uses of ``Model.query`` to build queries. This is an older interface for +queries that is considered legacy in SQLAlchemy 2.0. Prefer using +``db.session.execute(db.select(...))`` instead. -The only things you need to know compared to plain SQLAlchemy are: +See :doc:`queries` for more information about queries. -1. The :class:`SQLAlchemy` extension instance gives you access to the - following things: - - All the functions and classes from the :mod:`sqlalchemy` and - :mod:`sqlalchemy.orm` modules. - - a preconfigured scoped session called ``session`` - - the :attr:`~SQLAlchemy.metadata` - - the :attr:`~SQLAlchemy.engine` - - a :meth:`SQLAlchemy.create_all` and :meth:`SQLAlchemy.drop_all` - methods to create and drop tables according to the models. - - a :class:`Model` baseclass that is a configured declarative base. +What to Remember +---------------- -2. The :class:`Model` declarative base class behaves like a regular - Python class but has a ``query`` attribute attached that can be used to - query the model. (:class:`Model` and :class:`Query`) +For the most part, you should use SQLAlchemy as usual. The :class:`SQLAlchemy` extension +instance creates, configures, and gives access to the following things: -3. You have to commit the session, but you don't have to remove it at - the end of the request, Flask-SQLAlchemy does that for you. +- :attr:`.SQLAlchemy.Model` declarative model base class. It sets the table + name automatically instead of needing ``__tablename__``. +- :attr:`.SQLAlchemy.session` is a session that is scoped to the current + Flask application context. It is cleaned up after every request. +- :attr:`.SQLAlchemy.metadata` and :attr:`.SQLAlchemy.metadatas` gives access to each + metadata defined in the config. +- :attr:`.SQLAlchemy.engine` and :attr:`.SQLAlchemy.engines` gives access to each + engine defined in the config. +- :meth:`.SQLAlchemy.create_all` creates all tables. +- You must be in an active Flask application context to execute queries and to access + the session and engine. diff --git a/docs/record-queries.rst b/docs/record-queries.rst new file mode 100644 index 00000000..afc43481 --- /dev/null +++ b/docs/record-queries.rst @@ -0,0 +1,27 @@ +Recording Query Information +=========================== + +.. warning:: + This feature is intended for debugging only. + +Flask-SQLAlchemy can record some information about every query that executes during a +request. This information can then be retrieved to aid in debugging performance. For +example, it can reveal that a relationship performed too many individual selects, or +reveal a query that took a long time. + +To enable this feature, set :data:`.SQLALCHEMY_RECORD_QUERIES` to ``True`` in the Flask +app config. Use :func:`.get_recorded_queries` to get a list of query info objects. Each +object has the following attributes: + +``statement`` + The string of SQL generated by SQLAlchemy with parameter placeholders. +``parameters`` + The parameters sent with the SQL statement. +``start_time`` / ``end_time`` + Timing info about when the query started execution and when the results where + returned. Accuracy and value depends on the operating system. +``duration`` + The time the query took in seconds. +``location`` + A string description of where in your application code the query was executed. This + may be unknown in certain cases. diff --git a/docs/signals.rst b/docs/signals.rst deleted file mode 100644 index 4aafac02..00000000 --- a/docs/signals.rst +++ /dev/null @@ -1,27 +0,0 @@ -Signalling Support -================== - -Connect to the following signals to get notified before and after -changes are committed to the database. Tracking changes adds significant -overhead, so it is only enabled if ``SQLALCHEMY_TRACK_MODIFICATIONS`` is -enabled in the config. In most cases, you'll probably be better served -by using `SQLAlchemy events`_ directly. - -.. _SQLAlchemy events: https://docs.sqlalchemy.org/core/event.html - -.. data:: models_committed - - This signal is sent when changed models are committed to the - database. - - The sender is the application that emitted the changes. The receiver - is passed the ``changes`` parameter with a list of tuples in the - form ``(model instance, operation)``. - - The operation is one of ``'insert'``, ``'update'``, and - ``'delete'``. - -.. data:: before_models_committed - - This signal works exactly like :data:`models_committed` but is - emitted before the commit takes place. diff --git a/docs/track-modifications.rst b/docs/track-modifications.rst new file mode 100644 index 00000000..88be7e2e --- /dev/null +++ b/docs/track-modifications.rst @@ -0,0 +1,25 @@ +Tracking Modifications +====================== + +.. warning:: + Tracking changes adds significant overhead. In most cases, you'll be better served by + using `SQLAlchemy events`_ directly. + +.. _SQLAlchemy events: https://docs.sqlalchemy.org/core/event.html + +Flask-SQLAlchemy can set up its session to track inserts, updates, and deletes for +models, then send a Blinker signal with a list of these changes either before or during +calls to ``session.flush()`` and ``session.commit()``. + +To enable this feature, set :data:`.SQLALCHEMY_TRACK_MODIFICATIONS` in the Flask app +config. Then add a listener to :data:`.models_committed` (emitted after the commit) or +:data:`.before_models_committed` (emitted before the commit). + +.. code-block:: python + + from flask_sqlalchemy.track_modifications import models_committed + + def get_modifications(sender: Flask, changes: list[tuple[t.Any, str]]) -> None: + ... + + models_committed.connect(get_modifications) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 2e1a7f11..6f2d92a2 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -37,23 +37,23 @@ class SQLAlchemy: Accessing :attr:`session` and :attr:`engine` requires an active Flask application context. This includes methods like :meth:`create_all` which use the engine. - This class also provides access to names in SQLAlchemy's :mod:`sqlalchemy` and - :mod:`sqlalchemy.orm` modules. For example, you can use ``db.Column`` and + This class also provides access to names in SQLAlchemy's ``sqlalchemy`` and + ``sqlalchemy.orm`` modules. For example, you can use ``db.Column`` and ``db.relationship`` instead of importing ``sqlalchemy.Column`` and ``sqlalchemy.orm.relationship``. This can be convenient when defining models. :param app: Call :meth:`init_app` on this Flask application now. - :param metadata: Use this as the default :class:sqlalchemy.MetaData`. Useful for - setting a naming convention. - :param session_options: Arguments used by :attr:`db.session` to create each session - instance. A ``scopefunc`` key will be passed to :attr:`db.session`, not the + :param metadata: Use this as the default :class:`sqlalchemy.schema.MetaData`. Useful + for setting a naming convention. + :param session_options: Arguments used by :attr:`session` to create each session + instance. A ``scopefunc`` key will be passed to the scoped session, not the session instance. See :class:`sqlalchemy.orm.sessionmaker` for a list of arguments. :param query_class: Use this as the default query class for models and dynamic relationships. The query interface is considered legacy in SQLAlchemy 2.0. :param model_class: Use this as the model base class when creating the declarative - model class :attr:`db.Model`. Can also be a fully created declarative model - class for further customization. + model class :attr:`Model`. Can also be a fully created declarative model class + for further customization. :param engine_options: Default arguments used when creating every engine. These are lower precedence than application config. See :func:`sqlalchemy.create_engine` for a list of arguments. @@ -135,7 +135,7 @@ def __init__( """ self.session = self._make_scoped_session(session_options) - """A :class:`sqlalchemy.orm.scoped_session` that creates instances of + """A :class:`sqlalchemy.orm.scoping.scoped_session` that creates instances of :class:`.Session` scoped to the current Flask application context. The session will be removed, returning the engine connection to the pool, when the application context exits. @@ -149,8 +149,9 @@ def __init__( """ self.metadatas: dict[str | None, sa.MetaData] = {} - """Map of bind keys to :class:`sqlalchemy.MetaData` instances. The ``None`` key - refers to the default metadata, and is available as :attr:`metadata`. + """Map of bind keys to :class:`sqlalchemy.schema.MetaData` instances. The + ``None`` key refers to the default metadata, and is available as + :attr:`metadata`. Customize the default metadata by passing the ``metadata`` parameter to the extension. This can be used to set a naming convention. When metadata for @@ -164,7 +165,8 @@ def __init__( self.metadatas[None] = metadata self.Table = self._make_table_class() - """A :class:`sqlalchemy.Table` class that chooses a metadata automatically. + """A :class:`sqlalchemy.schema.Table` class that chooses a metadata + automatically. Unlike the base ``Table``, the ``metadata`` argument is not required. If it is not given, it is selected based on the ``bind_key`` argument. @@ -318,8 +320,8 @@ def init_app(self, app: Flask) -> None: track_modifications._listen(self.session) def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_session: - """Create a :class:`sqlalchemy.orm.scoped_session` around the factory from - :meth:`_make_session_factory`. The result is available as :attr:`session`. + """Create a :class:`sqlalchemy.orm.scoping.scoped_session` around the factory + from :meth:`_make_session_factory`. The result is available as :attr:`session`. The scope function can be customized using the ``scopefunc`` key in the ``session_options`` parameter to the extension. By default it uses the current @@ -393,7 +395,7 @@ def _teardown_session(self, exc: BaseException | None) -> None: self.session.remove() def _make_metadata(self, bind_key: str | None) -> sa.MetaData: - """Get or create a :class:`sqlalchemy.MetaData` for the given bind key. + """Get or create a :class:`sqlalchemy.schema.MetaData` for the given bind key. This method is used for internal setup. Its signature may change at any time. @@ -420,8 +422,8 @@ def _make_metadata(self, bind_key: str | None) -> sa.MetaData: return metadata def _make_table_class(self) -> t.Type[sa.Table]: - """Create a SQLAlchemy :class:`sqlalchemy.Table` class that chooses a metadata - automatically based on the ``bind_key``. The result is available as + """Create a SQLAlchemy :class:`sqlalchemy.schema.Table` class that chooses a + metadata automatically based on the ``bind_key``. The result is available as :attr:`Table`. This method is used for internal setup. Its signature may change at any time. @@ -498,8 +500,8 @@ def _apply_driver_defaults(self, options: dict[str, t.Any], app: Flask) -> None: """Apply driver-specific configuration to an engine. SQLite in-memory databases use ``StaticPool`` and disable ``check_same_thread``. - File paths are relative to the app's :func:`~flask.Flask.instance_path`, which - is created if it doesn't exist. + File paths are relative to the app's :attr:`~flask.Flask.instance_path`, + which is created if it doesn't exist. MySQL sets ``charset="utf8mb4"``, and ``pool_timeout`` defaults to 2 hours. diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index ba6124b6..cd4bf864 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -33,7 +33,7 @@ def __get__(self, obj: Model | None, cls: t.Type[Model]) -> Query: class Model: - """The base class of the :class:`.SQLAlchemy.Model` declarative model class. + """The base class of the :attr:`.SQLAlchemy.Model` declarative model class. To define models, subclass :attr:`db.Model <.SQLAlchemy.Model>`, not this. To customize ``db.Model``, subclass this and pass it as ``model_class`` to diff --git a/src/flask_sqlalchemy/pagination.py b/src/flask_sqlalchemy/pagination.py index 1c51ab7d..3d51cade 100644 --- a/src/flask_sqlalchemy/pagination.py +++ b/src/flask_sqlalchemy/pagination.py @@ -35,7 +35,9 @@ def __init__( items: list[t.Any], ) -> None: self.query = query - """The original query that was paginated.""" + """The original query that was paginated. This is used to produce :meth:`next` + and :meth:`prev` pages. + """ self.page = page """The current page.""" @@ -231,26 +233,6 @@ def iter_pages( 1, 2, None, 5, 6, 7, 8, 9, 10, 11, None, 19, 20 - The following Jinja macro renders a simple pagination widget. - - .. code-block:: jinja - - {% macro render_pagination(pagination, endpoint) %} - - {% endmacro %} - :param left_edge: How many pages to show from the first page. :param left_current: How many pages to show left of the current page. :param right_current: How many pages to show right of the current page. diff --git a/src/flask_sqlalchemy/query.py b/src/flask_sqlalchemy/query.py index a1632c60..6daa6bf5 100644 --- a/src/flask_sqlalchemy/query.py +++ b/src/flask_sqlalchemy/query.py @@ -21,8 +21,8 @@ class Query(sa.orm.Query): # type: ignore[type-arg] """ def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any: - """Like :meth:`get` but aborts with a ``404 Not Found`` error instead of - returning ``None``. + """Like :meth:`~sqlalchemy.orm.Query.get` but aborts with a ``404 Not Found`` + error instead of returning ``None``. :param ident: The primary key to query. :param description: A custom message to show on the error page. @@ -35,8 +35,8 @@ def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any: return rv def first_or_404(self, description: str | None = None) -> t.Any: - """Like :meth:`first` but aborts with a ``404 Not Found`` error instead of - returning ``None``. + """Like :meth:`~sqlalchemy.orm.Query.first` but aborts with a ``404 Not Found`` + error instead of returning ``None``. :param description: A custom message to show on the error page. """ @@ -48,8 +48,8 @@ def first_or_404(self, description: str | None = None) -> t.Any: return rv def one_or_404(self, description: str | None = None) -> t.Any: - """Like :meth:`one` but aborts with a ``404 Not Found`` error instead of raising - ``NoResultFound`` or ``MultipleResultsFound``. + """Like :meth:`~sqlalchemy.orm.Query.one` but aborts with a ``404 Not Found`` + error instead of raising ``NoResultFound`` or ``MultipleResultsFound``. :param description: A custom message to show on the error page. diff --git a/src/flask_sqlalchemy/record_queries.py b/src/flask_sqlalchemy/record_queries.py index d50492be..6404c179 100644 --- a/src/flask_sqlalchemy/record_queries.py +++ b/src/flask_sqlalchemy/record_queries.py @@ -15,14 +15,14 @@ def get_recorded_queries() -> list[_QueryInfo]: """Get the list of recorded query information for the current session. Queries are recorded if the app is in debug or testing mode, or if the config - :data:`SQLALCHEMY_RECORD_QUERIES` is enabled. + :data:`.SQLALCHEMY_RECORD_QUERIES` is enabled. Each query info object has the following attributes: ``statement`` The string of SQL generated by SQLAlchemy with parameter placeholders. ``parameters`` - The parameters sent with the SQL statment. + The parameters sent with the SQL statement. ``start_time`` / ``end_time`` Timing info about when the query started execution and when the results where returned. Accuracy and value depends on the operating system. diff --git a/src/flask_sqlalchemy/track_modifications.py b/src/flask_sqlalchemy/track_modifications.py index 8c9f38f5..fac5e411 100644 --- a/src/flask_sqlalchemy/track_modifications.py +++ b/src/flask_sqlalchemy/track_modifications.py @@ -13,8 +13,20 @@ from .session import Session _signals = Namespace() + models_committed = _signals.signal("models-committed") +"""This Blinker signal is sent after the session is committed if there were changed +models in the session. + +The sender is the application that emitted the changes. The receiver is passed the +``changes`` argument with a list of tuples in the form ``(instance, operation)``. +The operations are ``"insert"``, ``"update"``, and ``"delete"``. +""" + before_models_committed = _signals.signal("before-models-committed") +"""This signal works exactly like :data:`models_committed` but is emitted before the +commit takes place. +""" def _listen(session: sa.orm.scoped_session) -> None: