Skip to content
Closed
26 changes: 26 additions & 0 deletions docs/custom_json_encoder.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Custom Json Encoder

flask-mongoengine have option to add custom encoder for flask
By this way you can handle encoding special object

Examples:

```python
from flask_mongoengine.json import MongoEngineJSONProvider
class CustomJSONEncoder(MongoEngineJSONProvider):
@staticmethod
def default(obj):
if isinstance(obj, set):
return list(obj)
if isinstance(obj, Decimal128):
return str(obj)
return MongoEngineJSONProvider.default(obj)


# Tell your flask app to use your customised JSON encoder


app.json_provider_class = CustomJSONEncoder
app.json = app.json_provider_class(app)

```
3 changes: 2 additions & 1 deletion docs/custom_queryset.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ flask-mongoengine attaches the following methods to Mongoengine's default QueryS
Optional arguments: *message* - custom message to display.
* **first_or_404**: same as above, except for .first().
Optional arguments: *message* - custom message to display.
* **paginate**: paginates the QuerySet. Takes two arguments, *page* and *per_page*.
* **paginate**: paginates the QuerySet. Takes two required arguments, *page* and *per_page*.
And one optional arguments *max_depth*.
* **paginate_field**: paginates a field from one document in the QuerySet.
Arguments: *field_name*, *doc_id*, *page*, *per_page*.

Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ You can also use `WTForms <http://wtforms.simplecodes.com/>`_ as model forms for
forms
migration_to_v2
custom_queryset
custom_json_encoder
wtf_forms
session_interface
debug_toolbar
Expand Down
2 changes: 1 addition & 1 deletion flask_mongoengine/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def default(obj):
(BaseDocument, QuerySet, CommandCursor, DBRef, ObjectId),
):
return _convert_mongo_objects(obj)
return super().default(obj)
return superclass.default(obj)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What? Where it from?


return MongoEngineJSONProvider

Expand Down
18 changes: 13 additions & 5 deletions flask_mongoengine/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@


class Pagination(object):
def __init__(self, iterable, page, per_page):
def __init__(self, iterable, page: int, per_page: int, max_depth: int = None):
"""
:param iterable: iterable object .
:param page: Required page number start from 1.
:param per_page: Required number of documents per page.
:param max_depth: Option for limit number of dereference documents.


"""

if page < 1:
abort(404)
Expand All @@ -19,11 +27,11 @@ def __init__(self, iterable, page, per_page):

if isinstance(self.iterable, QuerySet):
self.total = iterable.count()
self.items = (
self.iterable.skip(self.per_page * (self.page - 1))
.limit(self.per_page)
.select_related()
self.items = self.iterable.skip(self.per_page * (self.page - 1)).limit(
self.per_page
)
if max_depth is not None:
self.items = self.items.select_related(max_depth)
else:
start_index = (page - 1) * per_page
end_index = page * per_page
Expand Down
6 changes: 3 additions & 3 deletions flask_mongoengine/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_expiration_time(self, app, session) -> timedelta:
return timedelta(**app.config.get("SESSION_TTL", {"days": 1}))

def open_session(self, app, request):
sid = request.cookies.get(app.session_cookie_name)
sid = request.cookies.get(app.config["SESSION_COOKIE_NAME"])
if sid:
stored_session = self.cls.objects(sid=sid).first()

Expand All @@ -81,7 +81,7 @@ def save_session(self, app, session, response):
# If the session is empty, return without setting the cookie.
if not session:
if session.modified:
response.delete_cookie(app.session_cookie_name, domain=domain)
response.delete_cookie(app.config["SESSION_COOKIE_NAME"], domain=domain)
return

expiration = datetime.utcnow().replace(tzinfo=utc) + self.get_expiration_time(
Expand All @@ -92,7 +92,7 @@ def save_session(self, app, session, response):
self.cls(sid=session.sid, data=session, expiration=expiration).save()

response.set_cookie(
app.session_cookie_name,
app.config["SESSION_COOKIE_NAME"],
session.sid,
expires=expiration,
httponly=httponly,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def extended_db(app):
test_db.connection["default"].drop_database(db_name)


class DummyEncoder(flask.json.JSONEncoder):
class DummyEncoder(flask.json._json.JSONEncoder):
"""
An example encoder which a user may create and override
the apps json_encoder with.
Expand Down
41 changes: 39 additions & 2 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
import flask
import pytest
from werkzeug.exceptions import NotFound

from flask_mongoengine import ListFieldPagination, Pagination


def test_queryset_paginator(app, todo):
@pytest.fixture(autouse=True)
def setup_endpoints(app, todo):
Todo = todo
for i in range(42):
Todo(title=f"post: {i}").save()

@app.route("/")
def index():
page = int(flask.request.form.get("page"))
per_page = int(flask.request.form.get("per_page"))
query_set = Todo.objects().paginate(page=page, per_page=per_page)
return {'data': [_ for _ in query_set.items],
'total': query_set.total,
'has_next': query_set.has_next,
}


def test_queryset_paginator(app, todo):
Todo = todo

with pytest.raises(NotFound):
Pagination(iterable=Todo.objects, page=0, per_page=10)

Expand Down Expand Up @@ -90,3 +105,25 @@ def _test_paginator(paginator):
# Paginate to the next page
if i < 5:
paginator = paginator.next()


def test_flask_pagination(app, todo):
client = app.test_client()
response = client.get(f"/", data={"page": 0, "per_page": 10})
print(response.status_code)
assert response.status_code == 404

response = client.get(f"/", data={"page": 6, "per_page": 10})
print(response.status_code)
assert response.status_code == 404


def test_flask_pagination_next(app, todo):
client = app.test_client()
has_next = True
page = 1
while has_next:
response = client.get(f"/", data={"page": page, "per_page": 10})
assert response.status_code == 200
has_next = response.json['has_next']
page += 1