diff --git a/pinax/stripe/admin.py b/pinax/stripe/admin.py index 15bb1c15e..798f71070 100644 --- a/pinax/stripe/admin.py +++ b/pinax/stripe/admin.py @@ -1,4 +1,5 @@ from django.contrib import admin +from django.contrib.admin.views.main import ChangeList from django.contrib.auth import get_user_model from django.db.models import Count @@ -103,8 +104,28 @@ def queryset(self, request, queryset): return queryset.all() +class PrefetchingChangeList(ChangeList): + """A custom changelist to prefetch related fields.""" + def get_queryset(self, request): + qs = super(PrefetchingChangeList, self).get_queryset(request) + + if subscription_status in self.list_display: + qs = qs.prefetch_related("subscription_set") + if "customer" in self.list_display: + qs = qs.prefetch_related("customer") + if "user" in self.list_display: + qs = qs.prefetch_related("user") + return qs + + +class ModelAdmin(admin.ModelAdmin): + def get_changelist(self, request, **kwargs): + return PrefetchingChangeList + + admin.site.register( Charge, + admin_class=ModelAdmin, list_display=[ "stripe_id", "customer", @@ -200,6 +221,7 @@ def subscription_status(obj): admin.site.register( Customer, + admin_class=ModelAdmin, raw_id_fields=["user"], list_display=[ "stripe_id", diff --git a/pinax/stripe/tests/test_admin.py b/pinax/stripe/tests/test_admin.py index 538f2c42c..3e8241729 100644 --- a/pinax/stripe/tests/test_admin.py +++ b/pinax/stripe/tests/test_admin.py @@ -1,5 +1,6 @@ import datetime +import django from django.contrib.auth import get_user_model from django.test import Client, TestCase from django.utils import timezone @@ -99,8 +100,16 @@ def setUp(self): def test_customer_admin(self): """Make sure we get good responses for all filter options""" url = reverse("admin:pinax_stripe_customer_changelist") - response = self.client.get(url) - self.assertEqual(response.status_code, 200) + + # Django 1.10 has the following query twice: + # SELECT COUNT(*) AS "__count" FROM "pinax_stripe_customer" + # (since https://github.com/django/django/commit/5fa7b592b3f) + # We might want to test for "num < 10" here instead, and/or compare the + # number to be equal with X and X+1 customers + num = 8 if django.VERSION >= (1, 10) else 7 + with self.assertNumQueries(num): + response = self.client.get(url) + self.assertEqual(response.status_code, 200) response = self.client.get(url + "?sub_status=active") self.assertEqual(response.status_code, 200)