Skip to content
26 changes: 26 additions & 0 deletions docs/custom_json_encoder.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Custom Json Encoder

flask-mongoengine have option to add custom encoder for flask
By this way you can handel encoding special object

Examples:

```python
from flask_mongoengine.json import MongoEngineJSONProvider
class CustomJSONEncoder(MongoEngineJSONProvider):
@staticmethod
def default(obj):
if isinstance(obj, set):
return list(obj)
if isinstance(obj, Decimal128):
return str(obj)
return MongoEngineJSONProvider.default(obj)


# Tell your flask app to use your customised JSON encoder


app.json_provider_class = CustomJSONEncoder
app.json = app.json_provider_class(app)

```
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ You can also use `WTForms <http://wtforms.simplecodes.com/>`_ as model forms for
forms
migration_to_v2
custom_queryset
custom_json_encoder
wtf_forms
session_interface
debug_toolbar
Expand Down
6 changes: 6 additions & 0 deletions flask_mongoengine/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from werkzeug.exceptions import HTTPException


class InvalidPage(HTTPException):
code = 404
description = "Invalid page number."
2 changes: 1 addition & 1 deletion flask_mongoengine/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def default(obj):
(BaseDocument, QuerySet, CommandCursor, DBRef, ObjectId),
):
return _convert_mongo_objects(obj)
return super().default(obj)
return superclass.default(obj)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What? Where it from?


return MongoEngineJSONProvider

Expand Down
21 changes: 11 additions & 10 deletions flask_mongoengine/pagination.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,38 @@
"""Module responsible for custom pagination."""
import math

from flask import abort
from mongoengine.queryset import QuerySet

from flask_mongoengine.exceptions import InvalidPage

__all__ = ("Pagination", "ListFieldPagination")


class Pagination(object):
def __init__(self, iterable, page, per_page):
def __init__(self, iterable, page, per_page, max_depth=None):

if page < 1:
abort(404)
raise InvalidPage

self.iterable = iterable
self.page = page
self.per_page = per_page

if isinstance(self.iterable, QuerySet):
self.total = iterable.count()
self.items = (
self.iterable.skip(self.per_page * (self.page - 1))
.limit(self.per_page)
.select_related()
self.items = self.iterable.skip(self.per_page * (self.page - 1)).limit(
self.per_page
)
if max_depth is not None:
self.items = self.items.select_related(max_depth)
else:
start_index = (page - 1) * per_page
end_index = page * per_page

self.total = len(iterable)
self.items = iterable[start_index:end_index]
if not self.items and page != 1:
abort(404)
raise InvalidPage

@property
def pages(self):
Expand Down Expand Up @@ -133,7 +134,7 @@ def __init__(self, queryset, doc_id, field_name, page, per_page, total=None):
elsewhere, but we still use array.length as a fallback.
"""
if page < 1:
abort(404)
raise InvalidPage

self.page = page
self.per_page = per_page
Expand All @@ -153,7 +154,7 @@ def __init__(self, queryset, doc_id, field_name, page, per_page, total=None):
)

if not self.items and page != 1:
abort(404)
raise InvalidPage

def prev(self, error_out=False):
"""Returns a :class:`Pagination` object for the previous page."""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def extended_db(app):
test_db.connection["default"].drop_database(db_name)


class DummyEncoder(flask.json.JSONEncoder):
class DummyEncoder(flask.json._json.JSONEncoder):
"""
An example encoder which a user may create and override
the apps json_encoder with.
Expand Down
41 changes: 33 additions & 8 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
import flask
import pytest
from werkzeug.exceptions import NotFound

from flask_mongoengine import ListFieldPagination, Pagination
from flask_mongoengine.exceptions import InvalidPage


def test_queryset_paginator(app, todo):
@pytest.fixture(autouse=True)
def setup_endpoints(app, todo):
Todo = todo
for i in range(42):
Todo(title=f"post: {i}").save()

with pytest.raises(NotFound):
@app.route("/")
def index():
page = int(flask.request.form.get("page"))
per_page = int(flask.request.form.get("per_page"))
return Pagination(iterable=Todo.objects, page=page, per_page=per_page)


def test_queryset_paginator(app, todo):
Todo = todo

with pytest.raises(InvalidPage):
Pagination(iterable=Todo.objects, page=0, per_page=10)

with pytest.raises(NotFound):
with pytest.raises(InvalidPage):
Pagination(iterable=Todo.objects, page=6, per_page=10)

paginator = Pagination(Todo.objects, 1, 10)
Expand All @@ -26,10 +38,10 @@ def test_queryset_paginator(app, todo):


def test_paginate_plain_list():
with pytest.raises(NotFound):
with pytest.raises(InvalidPage):
Pagination(iterable=range(1, 42), page=0, per_page=10)

with pytest.raises(NotFound):
with pytest.raises(InvalidPage):
Pagination(iterable=range(1, 42), page=6, per_page=10)

paginator = Pagination(range(1, 42), 1, 10)
Expand Down Expand Up @@ -68,14 +80,14 @@ def _test_paginator(paginator):

if i == 1:
assert not paginator.has_prev
with pytest.raises(NotFound):
with pytest.raises(InvalidPage):
paginator.prev()
else:
assert paginator.has_prev

if i == 5:
assert not paginator.has_next
with pytest.raises(NotFound):
with pytest.raises(InvalidPage):
paginator.next()
else:
assert paginator.has_next
Expand All @@ -90,3 +102,16 @@ def _test_paginator(paginator):
# Paginate to the next page
if i < 5:
paginator = paginator.next()


def test_flask_pagination(app, todo):
client = app.test_client()
response = client.get(f"/", data={"page": 0, "per_page": 10})
print(response.status_code)
assert response.status_code == 404
assert "Invalid page number" in response.text

response = client.get(f"/", data={"page": 6, "per_page": 10})
print(response.status_code)
assert response.status_code == 404
assert "Invalid page number" in response.text