diff --git a/kitsune/notifications/README.rst b/kitsune/notifications/README.rst index 86a816c23a8..9d390db3dbe 100644 --- a/kitsune/notifications/README.rst +++ b/kitsune/notifications/README.rst @@ -105,3 +105,23 @@ Notifications have only a few properties: Notifications also don't have any direct alerting properties, but they are used by other systems to alert users in some way. + +Realtime Notifications +====================== + +Realtime notifications are generally similar to the Notifications above in +function. They have different semantics however. Realtime notification +registrations are meant to be short lived, about one session. They also are not +intended to be a general notifications, and will not be shown to clients beside +the requesting client. Most importantly, they are unaffected by a user's follow +preferences. + +A realtime registration links a particular object to a simple push end point. +When any Action (as above) is created, it checks for any matching Realtime +registrations and then sends a SimplePush message. Clients are then expected to +check what actions they are being notified about. + +In the view clients check for actions in, much more information is provided +compared to when a user checks for notifications. This is because the action list +is intended to be used to populate a UI in realtime, as opposed to sending short +notifications to users. diff --git a/kitsune/notifications/api.py b/kitsune/notifications/api.py index a8d76191c68..9ddb372915b 100644 --- a/kitsune/notifications/api.py +++ b/kitsune/notifications/api.py @@ -1,9 +1,13 @@ +from django.db.models import Q + import django_filters +from actstream.models import Action from rest_framework import serializers, viewsets, permissions, mixins, status from rest_framework.decorators import action from rest_framework.response import Response -from kitsune.notifications.models import PushNotificationRegistration, Notification +from kitsune.notifications.models import ( + PushNotificationRegistration, Notification, RealtimeRegistration) from kitsune.sumo.api import OnlyCreatorEdits, DateTimeUTCField, GenericRelatedField @@ -131,3 +135,77 @@ class PushNotificationRegistrationViewSet(mixins.CreateModelMixin, permissions.IsAuthenticated, OnlyCreatorEdits, ] + + +class RealtimeRegistrationSerializer(serializers.ModelSerializer): + endpoint = serializers.CharField(write_only=True) + creator = serializers.SlugRelatedField(slug_field='username', required=False) + content_type = serializers.SlugRelatedField(slug_field='name') + + class Meta: + model = RealtimeRegistration + fields = [ + 'id', + 'creator', + 'created', + 'endpoint', + 'content_type', + 'object_id', + ] + + def validate_creator(self, attrs, source): + authed_user = getattr(self.context.get('request'), 'user') + creator = attrs.get('creator') + + if creator is None: + attrs['creator'] = authed_user + elif creator != authed_user: + raise serializers.ValidationError( + "Can't register push notifications for another user.") + + return attrs + + +class RealtimeActionSerializer(serializers.ModelSerializer): + action_object = GenericRelatedField(serializer_type='full') + actor = GenericRelatedField(serializer_type='full') + target = GenericRelatedField(serializer_type='full') + verb = serializers.CharField() + timestamp = DateTimeUTCField() + + class Meta: + model = PushNotificationRegistration + fields = ( + 'action_object', + 'actor', + 'id', + 'target', + 'timestamp', + 'verb', + ) + + +class RealtimeRegistrationViewSet(mixins.CreateModelMixin, + mixins.DestroyModelMixin, + viewsets.GenericViewSet): + model = RealtimeRegistration + serializer_class = RealtimeRegistrationSerializer + permission_classes = [ + permissions.IsAuthenticated, + OnlyCreatorEdits, + ] + + @action(methods=['GET']) + def updates(self, request, pk=None): + """Get all the actions that correspond to this registration.""" + reg = self.get_object() + + query = Q(actor_content_type=reg.content_type, actor_object_id=reg.object_id) + query |= Q(target_content_type=reg.content_type, target_object_id=reg.object_id) + query |= Q(action_object_content_type=reg.content_type, + action_object_object_id=reg.object_id) + + actions = Action.objects.filter(query) + serializer = RealtimeActionSerializer(actions, many=True) + + return Response(serializer.data) diff --git a/kitsune/notifications/models.py b/kitsune/notifications/models.py index e0478b1830b..40d431ff76d 100644 --- a/kitsune/notifications/models.py +++ b/kitsune/notifications/models.py @@ -1,21 +1,16 @@ -import logging from datetime import datetime from django.contrib.auth.models import User +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes import generic from django.db import models from django.db.models.signals import post_save from django.dispatch import receiver import actstream.registry from actstream.models import Action -import requests -from requests.exceptions import RequestException from kitsune.sumo.models import ModelBase -from kitsune.notifications.decorators import notification_handler - - -logger = logging.getLogger('k.notifications') class Notification(ModelBase): @@ -62,21 +57,19 @@ def send_notification(sender, instance, created, **kwargs): tasks.send_notification.delay(instance.id) -@notification_handler -def simple_push(notification): - """ - Send simple push notifications to users that have opted in to them. - - This will be called as a part of a celery task. - """ - registrations = PushNotificationRegistration.objects.filter(creator=notification.owner) - for reg in registrations: - try: - r = requests.put(reg.push_url, 'version={}'.format(notification.id)) - # If something does wrong, the SimplePush server will give back - # json encoded error messages. - if r.status_code != 200: - logger.error('SimplePush error: %s %s', r.status_code, r.json()) - except RequestException as e: - # This will go to Sentry. - logger.error('SimplePush PUT failed: %s', e) +class RealtimeRegistration(ModelBase): + creator = models.ForeignKey(User) + created = models.DateTimeField(default=datetime.now) + endpoint = models.CharField(max_length=256) + + content_type = models.ForeignKey(ContentType) + object_id = models.PositiveIntegerField() + target = generic.GenericForeignKey('content_type', 'object_id') + + +@receiver(post_save, sender=Action, dispatch_uid='action_send_realtimes') +def send_realtimes_for_action(sender, instance, created, **kwargs): + if not created: + return + from kitsune.notifications import tasks # avoid circular import + tasks.send_realtimes_for_action.delay(instance.id) diff --git a/kitsune/notifications/south_migrations/0003_auto__add_realtimeregistration__chg_field_notification_read_at.py b/kitsune/notifications/south_migrations/0003_auto__add_realtimeregistration__chg_field_notification_read_at.py new file mode 100644 index 00000000000..e1279ee8f53 --- /dev/null +++ b/kitsune/notifications/south_migrations/0003_auto__add_realtimeregistration__chg_field_notification_read_at.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +from south.utils import datetime_utils as datetime +from south.db import db +from south.v2 import SchemaMigration +from django.db import models + + +class Migration(SchemaMigration): + + def forwards(self, orm): + # Adding model 'RealtimeRegistration' + db.create_table(u'notifications_realtimeregistration', ( + (u'id', self.gf('django.db.models.fields.AutoField')(primary_key=True)), + ('creator', self.gf('django.db.models.fields.related.ForeignKey')(to=orm['auth.User'])), + ('created', self.gf('django.db.models.fields.DateTimeField')(default=datetime.datetime.now)), + ('endpoint', self.gf('django.db.models.fields.CharField')(max_length=256)), + ('content_type', self.gf('django.db.models.fields.related.ForeignKey')(to=orm['contenttypes.ContentType'])), + ('object_id', self.gf('django.db.models.fields.PositiveIntegerField')()), + )) + db.send_create_signal(u'notifications', ['RealtimeRegistration']) + + def backwards(self, orm): + # Deleting model 'RealtimeRegistration' + db.delete_table(u'notifications_realtimeregistration') + + models = { + u'actstream.action': { + 'Meta': {'ordering': "('-timestamp',)", 'object_name': 'Action'}, + 'action_object_content_type': ('django.db.models.fields.related.ForeignKey', [], {'blank': 'True', 'related_name': "'action_object'", 'null': 'True', 'to': u"orm['contenttypes.ContentType']"}), + 'action_object_object_id': ('django.db.models.fields.CharField', [], {'max_length': '255', 'null': 'True', 'blank': 'True'}), + 'actor_content_type': ('django.db.models.fields.related.ForeignKey', [], {'related_name': "'actor'", 'to': u"orm['contenttypes.ContentType']"}), + 'actor_object_id': ('django.db.models.fields.CharField', [], {'max_length': '255'}), + 'data': ('jsonfield.fields.JSONField', [], {'null': 'True', 'blank': 'True'}), + 'description': ('django.db.models.fields.TextField', [], {'null': 'True', 'blank': 'True'}), + u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'public': ('django.db.models.fields.BooleanField', [], {'default': 'True'}), + 'target_content_type': ('django.db.models.fields.related.ForeignKey', [], {'blank': 'True', 'related_name': "'target'", 'null': 'True', 'to': u"orm['contenttypes.ContentType']"}), + 'target_object_id': ('django.db.models.fields.CharField', [], {'max_length': '255', 'null': 'True', 'blank': 'True'}), + 'timestamp': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), + 'verb': ('django.db.models.fields.CharField', [], {'max_length': '255'}) + }, + u'auth.group': { + 'Meta': {'object_name': 'Group'}, + u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '80'}), + 'permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': u"orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}) + }, + u'auth.permission': { + 'Meta': {'ordering': "(u'content_type__app_label', u'content_type__model', u'codename')", 'unique_together': "((u'content_type', u'codename'),)", 'object_name': 'Permission'}, + 'codename': ('django.db.models.fields.CharField', [], {'max_length': '100'}), + 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': u"orm['contenttypes.ContentType']"}), + u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}) + }, + u'auth.user': { + 'Meta': {'object_name': 'User'}, + 'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), + 'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}), + 'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}), + 'groups': ('django.db.models.fields.related.ManyToManyField', [], {'symmetrical': 'False', 'related_name': "u'user_set'", 'blank': 'True', 'to': u"orm['auth.Group']"}), + u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}), + 'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), + 'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), + 'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), + 'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}), + 'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}), + 'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'symmetrical': 'False', 'related_name': "u'user_set'", 'blank': 'True', 'to': u"orm['auth.Permission']"}), + 'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'}) + }, + u'contenttypes.contenttype': { + 'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"}, + 'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}), + u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}), + 'name': ('django.db.models.fields.CharField', [], {'max_length': '100'}) + }, + u'notifications.notification': { + 'Meta': {'object_name': 'Notification'}, + 'action': ('django.db.models.fields.related.ForeignKey', [], {'to': u"orm['actstream.Action']"}), + u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'owner': ('django.db.models.fields.related.ForeignKey', [], {'to': u"orm['auth.User']"}), + 'read_at': ('django.db.models.fields.DateTimeField', [], {'null': 'True', 'blank': 'True'}) + }, + u'notifications.pushnotificationregistration': { + 'Meta': {'object_name': 'PushNotificationRegistration'}, + 'created': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), + 'creator': ('django.db.models.fields.related.ForeignKey', [], {'to': u"orm['auth.User']"}), + u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'push_url': ('django.db.models.fields.CharField', [], {'max_length': '256'}) + }, + u'notifications.realtimeregistration': { + 'Meta': {'object_name': 'RealtimeRegistration'}, + 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': u"orm['contenttypes.ContentType']"}), + 'created': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), + 'creator': ('django.db.models.fields.related.ForeignKey', [], {'to': u"orm['auth.User']"}), + 'endpoint': ('django.db.models.fields.CharField', [], {'max_length': '256'}), + u'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), + 'object_id': ('django.db.models.fields.PositiveIntegerField', [], {}) + } + } + + complete_apps = ['notifications'] diff --git a/kitsune/notifications/tasks.py b/kitsune/notifications/tasks.py index 27bd0f0534d..f346d352577 100644 --- a/kitsune/notifications/tasks.py +++ b/kitsune/notifications/tasks.py @@ -1,43 +1,68 @@ +import logging + from django.contrib.contenttypes.models import ContentType from django.db.models import Q import actstream.registry +import requests from actstream.models import Action, Follow from celery import task +from requests.exceptions import RequestException -from kitsune.notifications.models import Notification -from kitsune.notifications.decorators import notification_handlers +from kitsune.notifications.models import ( + Notification, RealtimeRegistration, PushNotificationRegistration) +from kitsune.notifications.decorators import notification_handler, notification_handlers -@task(ignore_result=True) -def add_notification_for_action(action_id): - action = Action.objects.get(id=action_id) +logger = logging.getLogger('k.notifications.tasks') - # For each attribute of the action, check that the attribute is valid, and - # build a query that finds all Follow objects that match it. - actstream.registry.check(action.actor) - query = Q( - content_type=ContentType.objects.get_for_model(action.actor).pk, - object_id=action.actor.pk) +def _ct_query(object, actor_only=None, **kwargs): + ct = ContentType.objects.get_for_model(object) + if actor_only is not None: + kwargs['actor_only'] = actor_only + return Q(content_type=ct.pk, object_id=object.pk, **kwargs) + + +def _full_ct_query(action, actor_only=None): + """Build a query that matches objects with a content type that matches an action.""" + actstream.registry.check(action.actor) + query = _ct_query(action.actor) if action.target is not None: actstream.registry.check(action.target) - query = query | Q( - content_type=ContentType.objects.get_for_model(action.target).pk, - object_id=action.target.pk, - actor_only=False) - + query |= _ct_query(action.target, actor_only) if action.action_object is not None: actstream.registry.check(action.action_object) - query = query | Q( - content_type=ContentType.objects.get_for_model(action.action_object).pk, - object_id=action.action_object.pk, - actor_only=False) + query |= _ct_query(action.action_object, actor_only) + return query + + +def _send_simple_push(endpoint, version): + """ + Hit a simple push endpoint to send a notification to a user. + + Handles and record any HTTP errors. + """ + try: + r = requests.put(endpoint, 'version={}'.format(version)) + # If something does wrong, the SimplePush server will give back + # json encoded error messages. + if r.status_code != 200: + logger.error('SimplePush error: %s %s', r.status_code, r.json()) + except RequestException as e: + # This will go to Sentry. + logger.error('SimplePush PUT failed: %s', e) - query = query & ~Q(user=action.actor) + +@task(ignore_result=True) +def add_notification_for_action(action_id): + action = Action.objects.get(id=action_id) + query = _full_ct_query(action, actor_only=False) + # Don't send notifications to a user about actions they take. + query &= ~Q(user=action.actor) # execute the above query, iterate through the results, get every user - # assocated with those Follow objects, and fire off a notification to + # assocated with those Follow objects, and fire off a notification to # every one of them. Use a set to only notify each user once. users_to_notify = set(f.user for f in Follow.objects.filter(query)) # Don't use bulk save since that wouldn't trigger signal handlers @@ -45,9 +70,33 @@ def add_notification_for_action(action_id): Notification.objects.create(owner=u, action=action) +@task(ignore_result=True) +def send_realtimes_for_action(action_id): + action = Action.objects.get(id=action_id) + query = _full_ct_query(action) + # Don't send notifications to a user about actions they take. + query &= ~Q(creator=action.actor) + + registrations = RealtimeRegistration.objects.filter(query) + for reg in registrations: + _send_simple_push(reg.endpoint, action.id) + + @task(ignore_result=True) def send_notification(notification_id): """Call every notification handler for a notification.""" notification = Notification.objects.get(id=notification_id) for handler in notification_handlers: handler(notification) + + +@notification_handler +def simple_push(notification): + """ + Send simple push notifications to users that have opted in to them. + + This will be called as a part of a celery task. + """ + registrations = PushNotificationRegistration.objects.filter(creator=notification.owner) + for reg in registrations: + _send_simple_push(reg.push_url, notification.id) diff --git a/kitsune/notifications/tests/test_api.py b/kitsune/notifications/tests/test_api.py index b01ca038b6f..85509408307 100644 --- a/kitsune/notifications/tests/test_api.py +++ b/kitsune/notifications/tests/test_api.py @@ -1,17 +1,20 @@ from datetime import datetime +from django.contrib.contenttypes.models import ContentType + from actstream.actions import follow from actstream.signals import action from actstream.models import Action, Follow -import mock +from mock import Mock, patch from nose.tools import eq_, ok_ from rest_framework.test import APIClient from kitsune.notifications import api -from kitsune.notifications.models import Notification +from kitsune.notifications import tasks as notification_tasks +from kitsune.notifications.models import Notification, RealtimeRegistration from kitsune.sumo.tests import TestCase from kitsune.sumo.urlresolvers import reverse -from kitsune.questions.tests import question +from kitsune.questions.tests import question, answer from kitsune.users.tests import profile, user from kitsune.users.helpers import profile_avatar @@ -21,7 +24,7 @@ class TestPushNotificationRegistrationSerializer(TestCase): def setUp(self): self.profile = profile() self.user = self.profile.user - self.request = mock.Mock() + self.request = Mock() self.request.user = self.user self.context = { 'request': self.request, @@ -81,6 +84,7 @@ def test_correct_fields(self): class TestNotificationViewSet(TestCase): + def setUp(self): self.client = APIClient() @@ -104,16 +108,16 @@ def _makeNotification(self, is_read=False): def test_mark_read(self): n = self._makeNotification() self.client.force_authenticate(user=self.follower) - req = self.client.post(reverse('notification-mark-read', args=[n.id])) - eq_(req.status_code, 204) + res = self.client.post(reverse('notification-mark-read', args=[n.id])) + eq_(res.status_code, 204) n = Notification.objects.get(id=n.id) eq_(n.is_read, True) def test_mark_unread(self): n = self._makeNotification(is_read=True) self.client.force_authenticate(user=self.follower) - req = self.client.post(reverse('notification-mark-unread', args=[n.id])) - eq_(req.status_code, 204) + res = self.client.post(reverse('notification-mark-unread', args=[n.id])) + eq_(res.status_code, 204) n = Notification.objects.get(id=n.id) eq_(n.is_read, False) @@ -121,14 +125,45 @@ def test_filter_is_read_false(self): n = self._makeNotification(is_read=False) self._makeNotification(is_read=True) self.client.force_authenticate(user=self.follower) - req = self.client.get(reverse('notification-list') + '?is_read=0') - eq_(req.status_code, 200) - eq_([d['id'] for d in req.data], [n.id]) + res = self.client.get(reverse('notification-list') + '?is_read=0') + eq_(res.status_code, 200) + eq_([d['id'] for d in res.data], [n.id]) def test_filter_is_read_true(self): self._makeNotification(is_read=False) n = self._makeNotification(is_read=True) self.client.force_authenticate(user=self.follower) - req = self.client.get(reverse('notification-list') + '?is_read=1') - eq_(req.status_code, 200) - eq_([d['id'] for d in req.data], [n.id]) + res = self.client.get(reverse('notification-list') + '?is_read=1') + eq_(res.status_code, 200) + eq_([d['id'] for d in res.data], [n.id]) + + +@patch.object(notification_tasks, 'requests') +class RealtimeViewSet(TestCase): + + def setUp(self): + self.client = APIClient() + + def test_updates_subview(self, requests): + requests.put.return_value.status_code = 200 + + u = profile().user + q = question(content='asdf', save=True) + ct = ContentType.objects.get_for_model(q) + rt = RealtimeRegistration.objects.create( + creator=u, content_type=ct, object_id=q.id, endpoint='http://example.com/') + # Some of the above may have created actions, which we don't care about. + Action.objects.all().delete() + # This shuld create an action that will trigger the above. + a = answer(question=q, content='asdf', save=True) + + self.client.force_authenticate(user=u) + url = reverse('realtimeregistration-updates', args=[rt.id]) + res = self.client.get(url) + eq_(res.status_code, 200) + + eq_(len(res.data), 1) + act = res.data[0] + eq_(act['actor']['username'], a.creator.username) + eq_(act['target']['content'], q.content_parsed) + eq_(act['action_object']['content'], a.content_parsed) diff --git a/kitsune/notifications/tests/test_signals.py b/kitsune/notifications/tests/test_signals.py index 36285c0bfaa..d7c78465539 100644 --- a/kitsune/notifications/tests/test_signals.py +++ b/kitsune/notifications/tests/test_signals.py @@ -1,11 +1,14 @@ +from django.contrib.contenttypes.models import ContentType + from actstream.actions import follow from actstream.signals import action from actstream.models import Action, Follow from mock import patch from nose.tools import eq_ -from kitsune.notifications import models as notification_models -from kitsune.notifications.models import Notification, PushNotificationRegistration +from kitsune.notifications import tasks as notification_tasks +from kitsune.notifications.models import ( + Notification, PushNotificationRegistration, RealtimeRegistration) from kitsune.notifications.tests import notification from kitsune.questions.tests import answer, question from kitsune.sumo.tests import TestCase @@ -75,7 +78,7 @@ def test_no_action_for_self(self): eq_(Notification.objects.filter(action=act).count(), 0) -@patch.object(notification_models, 'requests') +@patch.object(notification_tasks, 'requests') class TestSimplePushNotifier(TestCase): def test_simple_push_send(self, requests): @@ -106,3 +109,21 @@ def test_from_action_to_simple_push(self, requests): n = Notification.objects.get(owner=u) # Assert that they got notified. requests.put.assert_called_once_with(url, 'version={}'.format(n.id)) + + def test_from_action_to_realtime_notification(self, requests): + """ + Test that when an action is created, it results in a realtime notification being sent. + """ + # Create a user + u = profile().user + # Register realtime notifications for that user on a question + q = question(save=True) + url = 'http://example.com/simple_push/asdf' + ct = ContentType.objects.get_for_model(q) + RealtimeRegistration.objects.create( + creator=u, endpoint=url, content_type=ct, object_id=q.id) + # Create an action involving that question + action.send(profile().user, verb='looked at funny', action_object=q) + a = Action.objects.order_by('-id')[0] + # Assert that they got notified. + requests.put.assert_called_once_with(url, 'version={}'.format(a.id)) diff --git a/kitsune/notifications/urls_api.py b/kitsune/notifications/urls_api.py index cb991a5fdc1..f255c09346b 100644 --- a/kitsune/notifications/urls_api.py +++ b/kitsune/notifications/urls_api.py @@ -5,4 +5,5 @@ router = routers.SimpleRouter() router.register(r'pushnotification', api.PushNotificationRegistrationViewSet) router.register(r'notification', api.NotificationViewSet) +router.register(r'realtime', api.RealtimeRegistrationViewSet) urlpatterns = router.urls diff --git a/kitsune/questions/models.py b/kitsune/questions/models.py index d046ca95366..04780435ecf 100755 --- a/kitsune/questions/models.py +++ b/kitsune/questions/models.py @@ -383,10 +383,15 @@ def get_mapping_type(cls): return QuestionMappingType @classmethod - def get_generic_fk_serializer(cls): + def get_serializer(cls, serializer_type='full'): # Avoid circular import - from kitsune.questions.api import QuestionFKSerializer - return QuestionFKSerializer + from kitsune.questions import api + if serializer_type == 'full': + return api.QuestionSerializer + elif serializer_type == 'fk': + return api.QuestionFKSerializer + else: + raise ValueError('Unknown serializer type "{}".'.format(serializer_type)) @classmethod def recent_asked_count(cls, extra_filter=None): @@ -1122,10 +1127,15 @@ def get_mapping_type(cls): return AnswerMetricsMappingType @classmethod - def get_generic_fk_serializer(cls): + def get_serializer(cls, serializer_type='full'): # Avoid circular import - from kitsune.questions.api import AnswerFKSerializer - return AnswerFKSerializer + from kitsune.questions import api + if serializer_type == 'full': + return api.AnswerSerializer + elif serializer_type == 'fk': + return api.AnswerFKSerializer + else: + raise ValueError('Unknown serializer type "{}".'.format(serializer_type)) def mark_as_spam(self, by_user): """Mark the answer as spam by the specified user.""" diff --git a/kitsune/sumo/api.py b/kitsune/sumo/api.py index 151ace77648..65ed6ad4fa0 100644 --- a/kitsune/sumo/api.py +++ b/kitsune/sumo/api.py @@ -206,9 +206,13 @@ class Meta: class GenericRelatedField(relations.RelatedField): """ - Serializes GenericForeignKey relations. + Serializes GenericForeignKey relations using specified type of serializer. """ + def __init__(self, serializer_type='fk', **kwargs): + self.serializer_type = serializer_type + super(GenericRelatedField, self).__init__(**kwargs) + def to_native(self, value): content_type = ContentType.objects.get_for_model(value) data = {'type': content_type.model} @@ -216,8 +220,8 @@ def to_native(self, value): if isinstance(value, User): value = value.get_profile() - if hasattr(value, 'get_generic_fk_serializer'): - SerializerClass = value.get_generic_fk_serializer() + if hasattr(value, 'get_serializer'): + SerializerClass = value.get_serializer(self.serializer_type) else: SerializerClass = _IDSerializer data.update(SerializerClass(instance=value).data) diff --git a/kitsune/users/models.py b/kitsune/users/models.py index 68b2115e2c6..2d523985e28 100644 --- a/kitsune/users/models.py +++ b/kitsune/users/models.py @@ -133,10 +133,15 @@ def get_mapping_type(cls): return UserMappingType @classmethod - def get_generic_fk_serializer(cls): + def get_serializer(cls, serializer_type='full'): # Avoid circular import - from kitsune.users.api import ProfileFKSerializer - return ProfileFKSerializer + from kitsune.users import api + if serializer_type == 'full': + return api.ProfileSerializer + elif serializer_type == 'fk': + return api.ProfileFKSerializer + else: + raise ValueError('Unknown serializer type "{}".'.format(serializer_type)) @property def last_contribution_date(self):