diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index 8f1cdfae..b98af996 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -301,7 +301,7 @@ def get_debug_queries(): class Pagination(object): - """Internal helper class returned by :meth:`BaseQuery.paginate`. You + """Internal helper class returned by :meth:`paginate`. You can also construct it from any other SQLAlchemy query object if you are working with other libraries. Additionally it is possible to pass `None` as query object in which case the :meth:`prev` and :meth:`next` will @@ -334,7 +334,7 @@ def prev(self, error_out=False): """Returns a :class:`Pagination` object for the previous page.""" assert self.query is not None, 'a query object is required ' \ 'for this method to work' - return self.query.paginate(self.page - 1, self.per_page, error_out) + return paginate(self.query, self.page - 1, self.per_page, error_out) @property def prev_num(self): @@ -350,7 +350,7 @@ def next(self, error_out=False): """Returns a :class:`Pagination` object for the next page.""" assert self.query is not None, 'a query object is required ' \ 'for this method to work' - return self.query.paginate(self.page + 1, self.per_page, error_out) + return paginate(self.query, self.page + 1, self.per_page, error_out) @property def has_next(self): @@ -427,26 +427,33 @@ def first_or_404(self): return rv def paginate(self, page, per_page=20, error_out=True): - """Returns `per_page` items from page `page`. By default it will - abort with 404 if no items were found and the page was larger than - 1. This behavor can be disabled by setting `error_out` to `False`. + """:class:`BaseQuery` wrapper for :func:`sqlalchemy.paginate` Returns an :class:`Pagination` object. """ - if error_out and page < 1: - abort(404) - items = self.limit(per_page).offset((page - 1) * per_page).all() - if not items and page != 1 and error_out: - abort(404) + return paginate(self, page, per_page, error_out) - # No need to count if we're on the first page and there are fewer - # items than we expected. - if page == 1 and len(items) < per_page: - total = len(items) - else: - total = self.order_by(None).count() +def paginate(query, page, per_page=20, error_out=True): + """Returns `per_page` items from page `page`. By default it will + abort with 404 if no items were found and the page was larger than + 1. This behavor can be disabled by setting `error_out` to `False`. - return Pagination(self, page, per_page, total, items) + Returns an :class:`Pagination` object. + """ + if error_out and page < 1: + abort(404) + items = query.limit(per_page).offset((page - 1) * per_page).all() + if not items and page != 1 and error_out: + abort(404) + + # No need to count if we're on the first page and there are fewer + # items than we expected. + if page == 1 and len(items) < per_page: + total = len(items) + else: + total = query.order_by(None).count() + + return Pagination(query, page, per_page, total, items) class _QueryProperty(object): diff --git a/test_sqlalchemy.py b/test_sqlalchemy.py index b545b7eb..0163df04 100644 --- a/test_sqlalchemy.py +++ b/test_sqlalchemy.py @@ -295,6 +295,63 @@ def test_pagination_pages_when_0_items_per_page(self): self.assertEqual(p.pages, 0) +class MockPaginationTestCase(unittest.TestCase): + + def setUp(self): + app = flask.Flask(__name__) + app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + app.config['TESTING'] = True + db = sqlalchemy.SQLAlchemy(app) + + self.book_name = "To Kill a Mockingbird" + + class Book(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(10)) + chapter = db.Column(db.Integer) + + def __init__(self, name, chapter): + self.name = name + self.chapter = chapter + + db.create_all() + + for index in range(1, 10): + db.session.add(Book(self.book_name, index)) + db.session.commit() + + self.db = db + self.Book = Book + + def tearDown(self): + self.db.drop_all() + + def test_base_query_paginate(self): + book = self.Book.query \ + .filter(self.Book.name==self.book_name) \ + .paginate(5, 1) + + self.assertEqual(book.page, 5) + self.assertTrue(book.has_prev) + self.assertTrue(book.has_next) + self.assertEqual(book.total, 9) + self.assertEqual(book.pages, 9) + self.assertEqual(book.next_num, 6) + + def test_sqlalchemy_query_paginate(self): + book = self.db \ + .session \ + .query(self.Book) \ + .filter(self.Book.name==self.book_name) + p = sqlalchemy.paginate(book, 5, 1) + + self.assertEqual(p.page, 5) + self.assertTrue(p.has_prev) + self.assertTrue(p.has_next) + self.assertEqual(p.total, 9) + self.assertEqual(p.pages, 9) + self.assertEqual(p.next_num, 6) + class BindsTestCase(unittest.TestCase): def test_basic_binds(self): @@ -596,6 +653,7 @@ def suite(): suite.addTest(unittest.makeSuite(TestQueryProperty)) suite.addTest(unittest.makeSuite(TablenameTestCase)) suite.addTest(unittest.makeSuite(PaginationTestCase)) + suite.addTest(unittest.makeSuite(MockPaginationTestCase)) suite.addTest(unittest.makeSuite(BindsTestCase)) suite.addTest(unittest.makeSuite(DefaultQueryClassTestCase)) suite.addTest(unittest.makeSuite(SQLAlchemyIncludesTestCase))