Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def get_debug_queries():


class Pagination(object):
"""Internal helper class returned by :meth:`BaseQuery.paginate`. You
"""Internal helper class returned by :meth:`paginate`. You
can also construct it from any other SQLAlchemy query object if you are
working with other libraries. Additionally it is possible to pass `None`
as query object in which case the :meth:`prev` and :meth:`next` will
Expand Down Expand Up @@ -334,7 +334,7 @@ def prev(self, error_out=False):
"""Returns a :class:`Pagination` object for the previous page."""
assert self.query is not None, 'a query object is required ' \
'for this method to work'
return self.query.paginate(self.page - 1, self.per_page, error_out)
return paginate(self.query, self.page - 1, self.per_page, error_out)

@property
def prev_num(self):
Expand All @@ -350,7 +350,7 @@ def next(self, error_out=False):
"""Returns a :class:`Pagination` object for the next page."""
assert self.query is not None, 'a query object is required ' \
'for this method to work'
return self.query.paginate(self.page + 1, self.per_page, error_out)
return paginate(self.query, self.page + 1, self.per_page, error_out)

@property
def has_next(self):
Expand Down Expand Up @@ -427,26 +427,33 @@ def first_or_404(self):
return rv

def paginate(self, page, per_page=20, error_out=True):
"""Returns `per_page` items from page `page`. By default it will
abort with 404 if no items were found and the page was larger than
1. This behavor can be disabled by setting `error_out` to `False`.
""":class:`BaseQuery` wrapper for :func:`sqlalchemy.paginate`

Returns an :class:`Pagination` object.
"""
if error_out and page < 1:
abort(404)
items = self.limit(per_page).offset((page - 1) * per_page).all()
if not items and page != 1 and error_out:
abort(404)
return paginate(self, page, per_page, error_out)

# No need to count if we're on the first page and there are fewer
# items than we expected.
if page == 1 and len(items) < per_page:
total = len(items)
else:
total = self.order_by(None).count()
def paginate(query, page, per_page=20, error_out=True):
"""Returns `per_page` items from page `page`. By default it will
abort with 404 if no items were found and the page was larger than
1. This behavor can be disabled by setting `error_out` to `False`.

return Pagination(self, page, per_page, total, items)
Returns an :class:`Pagination` object.
"""
if error_out and page < 1:
abort(404)
items = query.limit(per_page).offset((page - 1) * per_page).all()
if not items and page != 1 and error_out:
abort(404)

# No need to count if we're on the first page and there are fewer
# items than we expected.
if page == 1 and len(items) < per_page:
total = len(items)
else:
total = query.order_by(None).count()

return Pagination(query, page, per_page, total, items)


class _QueryProperty(object):
Expand Down
58 changes: 58 additions & 0 deletions test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,63 @@ def test_pagination_pages_when_0_items_per_page(self):
self.assertEqual(p.pages, 0)


class MockPaginationTestCase(unittest.TestCase):

def setUp(self):
app = flask.Flask(__name__)
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
app.config['TESTING'] = True
db = sqlalchemy.SQLAlchemy(app)

self.book_name = "To Kill a Mockingbird"

class Book(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(10))
chapter = db.Column(db.Integer)

def __init__(self, name, chapter):
self.name = name
self.chapter = chapter

db.create_all()

for index in range(1, 10):
db.session.add(Book(self.book_name, index))
db.session.commit()

self.db = db
self.Book = Book

def tearDown(self):
self.db.drop_all()

def test_base_query_paginate(self):
book = self.Book.query \
.filter(self.Book.name==self.book_name) \
.paginate(5, 1)

self.assertEqual(book.page, 5)
self.assertTrue(book.has_prev)
self.assertTrue(book.has_next)
self.assertEqual(book.total, 9)
self.assertEqual(book.pages, 9)
self.assertEqual(book.next_num, 6)

def test_sqlalchemy_query_paginate(self):
book = self.db \
.session \
.query(self.Book) \
.filter(self.Book.name==self.book_name)
p = sqlalchemy.paginate(book, 5, 1)

self.assertEqual(p.page, 5)
self.assertTrue(p.has_prev)
self.assertTrue(p.has_next)
self.assertEqual(p.total, 9)
self.assertEqual(p.pages, 9)
self.assertEqual(p.next_num, 6)

class BindsTestCase(unittest.TestCase):

def test_basic_binds(self):
Expand Down Expand Up @@ -596,6 +653,7 @@ def suite():
suite.addTest(unittest.makeSuite(TestQueryProperty))
suite.addTest(unittest.makeSuite(TablenameTestCase))
suite.addTest(unittest.makeSuite(PaginationTestCase))
suite.addTest(unittest.makeSuite(MockPaginationTestCase))
suite.addTest(unittest.makeSuite(BindsTestCase))
suite.addTest(unittest.makeSuite(DefaultQueryClassTestCase))
suite.addTest(unittest.makeSuite(SQLAlchemyIncludesTestCase))
Expand Down