From 28dd33d3c9e87b80c6c9e185034f0a758604845f Mon Sep 17 00:00:00 2001 From: Seb Date: Fri, 2 Dec 2022 18:29:49 +1100 Subject: [PATCH 1/5] Fixes #227 run all middleware passed through decorators --- example/home/factories.py | 6 +++++ .../home/migrations/0006_middleware_model.py | 25 +++++++++++++++++++ example/home/models.py | 18 +++++++++++++ example/home/test/test_general.py | 21 ++++++++++++++++ grapple/middleware.py | 5 ++-- 5 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 example/home/migrations/0006_middleware_model.py diff --git a/example/home/factories.py b/example/home/factories.py index 6c3ee044..98c0b737 100644 --- a/example/home/factories.py +++ b/example/home/factories.py @@ -15,6 +15,7 @@ AuthorPage, BlogPage, BlogPageRelatedLink, + MiddlewareModel, Person, SimpleModel, ) @@ -163,3 +164,8 @@ class Meta: class SimpleModelFactory(factory.django.DjangoModelFactory): class Meta: model = SimpleModel + + +class MiddlewareModelFactory(factory.django.DjangoModelFactory): + class Meta: + model = MiddlewareModel diff --git a/example/home/migrations/0006_middleware_model.py b/example/home/migrations/0006_middleware_model.py new file mode 100644 index 00000000..1470d803 --- /dev/null +++ b/example/home/migrations/0006_middleware_model.py @@ -0,0 +1,25 @@ +# Generated by Django 3.2.16 on 2022-12-02 05:10 + +from django.db import migrations, models +import wagtail.blocks +import wagtail.embeds.blocks +import wagtail.fields +import wagtail.images.blocks +import wagtail.snippets.blocks + + +class Migration(migrations.Migration): + + dependencies = [ + ('home', '0005_auto_20220909_0959'), + ] + + operations = [ + migrations.CreateModel( + name='MiddlewareModel', + fields=[ + ('id', models.AutoField(auto_created=True, + primary_key=True, serialize=False, verbose_name='ID')), + ], + ), + ] diff --git a/example/home/models.py b/example/home/models.py index 80f7d19b..df13f9ce 100644 --- a/example/home/models.py +++ b/example/home/models.py @@ -50,6 +50,24 @@ class SimpleModel(models.Model): pass +def custom_middleware_one(next, root, info, **args): + info.context.custom_middleware_one = True + return next(root, info, **args) + + +def custom_middleware_two(next, root, info, **args): + if not info.context.custom_middleware_one: + raise Exception("Middleware one should have been called") + if args['id'] == 2: + return None + return next(root, info, **args) + + +@register_query_field('middlewareModel', middleware=[custom_middleware_one, custom_middleware_two]) +class MiddlewareModel(models.Model): + pass + + class HomePage(Page): pass diff --git a/example/home/test/test_general.py b/example/home/test/test_general.py index 8dd9f143..f5a9328c 100644 --- a/example/home/test/test_general.py +++ b/example/home/test/test_general.py @@ -2,7 +2,11 @@ from django.contrib.auth.models import AnonymousUser from django.test import RequestFactory, override_settings +<<<<<<< HEAD from home.factories import AdvertFactory, BlogPageFactory, SimpleModelFactory +======= +from home.factories import BlogPageFactory, SimpleModelFactory, MiddlewareModelFactory +>>>>>>> e8bcc5a (Fixes #227 run all middleware passed through decorators) from example.tests.test_grapple import BaseGrappleTest @@ -80,6 +84,7 @@ def setUp(self): self.blog_post = BlogPageFactory(parent=self.home, slug="post-one") self.another_post = BlogPageFactory(parent=self.home, slug="post-two") self.child_post = BlogPageFactory(parent=self.another_post, slug="post-one") + self.middleware_instance = MiddlewareModelFactory() def test_query_field_plural(self): query = """ @@ -141,6 +146,22 @@ def test_query_field(self): data = results["data"]["post"] self.assertEqual(int(data["id"]), self.another_post.id) + def test_multiple_middleware(self): + query = """ + query ($id: Int) { + middlewareModel(id: $id) { + id + } + } + """ + results = self.client.execute( + query, variables={'id': 1}, context_value=self.request) + # Check that both middleware ran ok, value returned means the assert passed in middleware_2 + self.assertEqual(int(results["data"]["middlewareModel"]["id"]), 1) + results = self.client.execute( + query, variables={'id': 2}, context_value=self.request) + # Check that the second middleware failed when id = 2 + self.assertEqual(results["data"]["middlewareModel"], None) class TestRegisterPaginatedQueryField(BaseGrappleTest): def setUp(self): diff --git a/grapple/middleware.py b/grapple/middleware.py index 203f4679..bf138580 100644 --- a/grapple/middleware.py +++ b/grapple/middleware.py @@ -38,6 +38,7 @@ def resolve(self, next, root, info: ResolveInfo, **kwargs): parent_name = info.parent_type.name if field_name in self.field_middlewares and parent_name in ROOT_TYPES: for middleware in self.field_middlewares[field_name]: - return middleware(next, root, info, **kwargs) - + response = middleware(next, root, info, **args) + if not response: + return None return next(root, info, **kwargs) From d73550df9fc1c8c15dfb5653a8b3d48f40889608 Mon Sep 17 00:00:00 2001 From: Seb Date: Tue, 6 Dec 2022 16:43:27 +1100 Subject: [PATCH 2/5] use partials for the middleware, isort fix --- example/home/test/test_general.py | 20 +++++++++++++------- grapple/middleware.py | 12 +++++++----- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/example/home/test/test_general.py b/example/home/test/test_general.py index f5a9328c..d68b8fd9 100644 --- a/example/home/test/test_general.py +++ b/example/home/test/test_general.py @@ -2,11 +2,14 @@ from django.contrib.auth.models import AnonymousUser from django.test import RequestFactory, override_settings -<<<<<<< HEAD -from home.factories import AdvertFactory, BlogPageFactory, SimpleModelFactory -======= -from home.factories import BlogPageFactory, SimpleModelFactory, MiddlewareModelFactory ->>>>>>> e8bcc5a (Fixes #227 run all middleware passed through decorators) +from home.factories import ( + AdvertFactory, + BlogPageFactory, + MiddlewareModelFactory, + SimpleModelFactory, +) + +from home.factories import BlogPageFactory, MiddlewareModelFactory, SimpleModelFactory from example.tests.test_grapple import BaseGrappleTest @@ -155,14 +158,17 @@ def test_multiple_middleware(self): } """ results = self.client.execute( - query, variables={'id': 1}, context_value=self.request) + query, variables={"id": 1}, context_value=self.request + ) # Check that both middleware ran ok, value returned means the assert passed in middleware_2 self.assertEqual(int(results["data"]["middlewareModel"]["id"]), 1) results = self.client.execute( - query, variables={'id': 2}, context_value=self.request) + query, variables={"id": 2}, context_value=self.request + ) # Check that the second middleware failed when id = 2 self.assertEqual(results["data"]["middlewareModel"], None) + class TestRegisterPaginatedQueryField(BaseGrappleTest): def setUp(self): super().setUp() diff --git a/grapple/middleware.py b/grapple/middleware.py index bf138580..19a8c050 100644 --- a/grapple/middleware.py +++ b/grapple/middleware.py @@ -1,3 +1,5 @@ +from functools import partial + from graphene import ResolveInfo from graphql.execution.middleware import get_middleware_resolvers @@ -37,8 +39,8 @@ def resolve(self, next, root, info: ResolveInfo, **kwargs): field_name = info.field_name parent_name = info.parent_type.name if field_name in self.field_middlewares and parent_name in ROOT_TYPES: - for middleware in self.field_middlewares[field_name]: - response = middleware(next, root, info, **args) - if not response: - return None - return next(root, info, **kwargs) + middlewares = self.field_middlewares[field_name].copy() + while middlewares: + next = partial(middlewares.pop(), next) + + return next(root, info, **args) From 39d2eb55bb39624c299829086f519d78bde6401c Mon Sep 17 00:00:00 2001 From: Seb Date: Tue, 6 Dec 2022 16:44:09 +1100 Subject: [PATCH 3/5] flake8 fix --- example/home/migrations/0006_middleware_model.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/example/home/migrations/0006_middleware_model.py b/example/home/migrations/0006_middleware_model.py index 1470d803..37b7d3eb 100644 --- a/example/home/migrations/0006_middleware_model.py +++ b/example/home/migrations/0006_middleware_model.py @@ -1,11 +1,6 @@ # Generated by Django 3.2.16 on 2022-12-02 05:10 from django.db import migrations, models -import wagtail.blocks -import wagtail.embeds.blocks -import wagtail.fields -import wagtail.images.blocks -import wagtail.snippets.blocks class Migration(migrations.Migration): From f88b31e90e72644ec54334ed9305667890bcb394 Mon Sep 17 00:00:00 2001 From: Seb Date: Wed, 7 Dec 2022 11:58:28 +1100 Subject: [PATCH 4/5] aaand black --- example/home/migrations/0006_middleware_model.py | 15 +++++++++++---- example/home/models.py | 6 ++++-- example/home/test/test_general.py | 2 +- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/example/home/migrations/0006_middleware_model.py b/example/home/migrations/0006_middleware_model.py index 37b7d3eb..c4926874 100644 --- a/example/home/migrations/0006_middleware_model.py +++ b/example/home/migrations/0006_middleware_model.py @@ -6,15 +6,22 @@ class Migration(migrations.Migration): dependencies = [ - ('home', '0005_auto_20220909_0959'), + ("home", "0005_auto_20220909_0959"), ] operations = [ migrations.CreateModel( - name='MiddlewareModel', + name="MiddlewareModel", fields=[ - ('id', models.AutoField(auto_created=True, - primary_key=True, serialize=False, verbose_name='ID')), + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), ], ), ] diff --git a/example/home/models.py b/example/home/models.py index df13f9ce..313cfc85 100644 --- a/example/home/models.py +++ b/example/home/models.py @@ -58,12 +58,14 @@ def custom_middleware_one(next, root, info, **args): def custom_middleware_two(next, root, info, **args): if not info.context.custom_middleware_one: raise Exception("Middleware one should have been called") - if args['id'] == 2: + if args["id"] == 2: return None return next(root, info, **args) -@register_query_field('middlewareModel', middleware=[custom_middleware_one, custom_middleware_two]) +@register_query_field( + "middlewareModel", middleware=[custom_middleware_one, custom_middleware_two] +) class MiddlewareModel(models.Model): pass diff --git a/example/home/test/test_general.py b/example/home/test/test_general.py index d68b8fd9..38eea5ca 100644 --- a/example/home/test/test_general.py +++ b/example/home/test/test_general.py @@ -160,7 +160,7 @@ def test_multiple_middleware(self): results = self.client.execute( query, variables={"id": 1}, context_value=self.request ) - # Check that both middleware ran ok, value returned means the assert passed in middleware_2 + # Check that both middleware ran ok, value returned means the check for middleware_1 passed in middleware_2 self.assertEqual(int(results["data"]["middlewareModel"]["id"]), 1) results = self.client.execute( query, variables={"id": 2}, context_value=self.request From 43f90b690d05dc4bce785c4f9dfb30d3fd7cd207 Mon Sep 17 00:00:00 2001 From: Seb Date: Wed, 18 Jan 2023 09:47:58 +1100 Subject: [PATCH 5/5] pre-commit fixes --- example/home/migrations/0006_middleware_model.py | 1 - example/home/test/test_general.py | 2 -- grapple/middleware.py | 2 +- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/example/home/migrations/0006_middleware_model.py b/example/home/migrations/0006_middleware_model.py index c4926874..08b0a307 100644 --- a/example/home/migrations/0006_middleware_model.py +++ b/example/home/migrations/0006_middleware_model.py @@ -4,7 +4,6 @@ class Migration(migrations.Migration): - dependencies = [ ("home", "0005_auto_20220909_0959"), ] diff --git a/example/home/test/test_general.py b/example/home/test/test_general.py index 38eea5ca..710b4e46 100644 --- a/example/home/test/test_general.py +++ b/example/home/test/test_general.py @@ -9,8 +9,6 @@ SimpleModelFactory, ) -from home.factories import BlogPageFactory, MiddlewareModelFactory, SimpleModelFactory - from example.tests.test_grapple import BaseGrappleTest diff --git a/grapple/middleware.py b/grapple/middleware.py index 19a8c050..33dc7990 100644 --- a/grapple/middleware.py +++ b/grapple/middleware.py @@ -43,4 +43,4 @@ def resolve(self, next, root, info: ResolveInfo, **kwargs): while middlewares: next = partial(middlewares.pop(), next) - return next(root, info, **args) + return next(root, info, **kwargs)