diff --git a/courses/factories.py b/courses/factories.py index 65a3eef885..05af36dd9b 100644 --- a/courses/factories.py +++ b/courses/factories.py @@ -59,3 +59,12 @@ class CourseRunFactory(DjangoModelFactory): class Meta: # pylint: disable=missing-docstring model = CourseRun + + @classmethod + def create(cls, **kwargs): + # If 'program' is provided, build a Course from it first + if 'program' in kwargs: + kwargs['course'] = CourseFactory.create(program=kwargs.pop('program')) + # Create the CourseRun as normal + attrs = cls.attributes(create=True, extra=kwargs) + return cls._generate(True, attrs) diff --git a/dashboard/factories.py b/dashboard/factories.py index 7edf9d3610..1a9482866d 100644 --- a/dashboard/factories.py +++ b/dashboard/factories.py @@ -4,9 +4,7 @@ from datetime import datetime, timedelta import pytz -from factory import ( - SubFactory, -) +from factory import SubFactory from factory.django import DjangoModelFactory from factory.fuzzy import ( FuzzyAttribute, @@ -17,8 +15,13 @@ from dashboard.models import ( CachedCertificate, CachedEnrollment, + ProgramEnrollment +) +from courses.factories import ( + CourseRunFactory, + CourseFactory, + ProgramFactory, ) -from courses.factories import CourseRunFactory from profiles.factories import UserFactory @@ -47,3 +50,32 @@ class CachedEnrollmentFactory(DjangoModelFactory): class Meta: # pylint: disable=missing-docstring,no-init,too-few-public-methods,old-style-class model = CachedEnrollment + + +class ProgramEnrollmentFactory(DjangoModelFactory): + """Factory for ProgramEnrollment""" + user = SubFactory(UserFactory) + program = SubFactory(ProgramFactory) + + class Meta: # pylint: disable=missing-docstring,no-init,too-few-public-methods,old-style-class + model = ProgramEnrollment + + @classmethod + def create(cls, **kwargs): + """ + Overrides default ProgramEnrollment object creation for the factory. + + ProgramEnrollments should only exist if there is a CachedEnrollment associated with + the given User and Program. Instead of creating a new record with the factory, we + will create the necessary objects to trigger its creation. + """ + user = kwargs.get('user', UserFactory.create()) + program = kwargs.get('program', ProgramFactory.create()) + course = CourseFactory.create(program=program) + course_run = CourseRunFactory.create(course=course) + CachedEnrollmentFactory.create(user=user, course_run=course_run) + # CachedCertificate isn't strictly necessary to create a ProgramEnrollment. This is here for test convenience. + CachedCertificateFactory.create(user=user, course_run=course_run) + # Signal from the creation of a CachedEnrollment should have created a ProgramEnrollment + program_enrollment = ProgramEnrollment.objects.get(user=user, program=program) + return program_enrollment diff --git a/dashboard/models.py b/dashboard/models.py index 172ca8a0f2..03041898cf 100644 --- a/dashboard/models.py +++ b/dashboard/models.py @@ -25,6 +25,16 @@ class CachedEnrollment(Model): class Meta: unique_together = (('user', 'course_run'), ) + @classmethod + def active_count(cls, user, program): + """ + Returns the number of active CachedEnrollments for a User/Program pair + """ + return cls.objects.filter( + user=user, + course_run__course__program=program + ).exclude(data__isnull=True).count() + def __str__(self): """ String representation of the model object diff --git a/dashboard/signals.py b/dashboard/signals.py index b1cfa9cedc..dc25fcc619 100644 --- a/dashboard/signals.py +++ b/dashboard/signals.py @@ -1,10 +1,27 @@ """ Signals for user profiles """ -from django.db.models.signals import pre_save, post_save +from django.db.models.signals import pre_save, post_save, post_delete from django.dispatch import receiver -from dashboard.models import CachedEnrollment, ProgramEnrollment +from dashboard.models import CachedEnrollment, CachedCertificate, ProgramEnrollment +from search.tasks import index_program_enrolled_users, index_users, remove_program_enrolled_user + + +@receiver(post_save, sender=ProgramEnrollment, dispatch_uid="programenrollment_post_save") +def handle_create_programenrollment(sender, instance, created, **kwargs): # pylint: disable=unused-argument + """ + When a ProgramEnrollment model is created/updated, update index. + """ + index_program_enrolled_users.delay([instance]) + + +@receiver(post_delete, sender=ProgramEnrollment, dispatch_uid="programenrollment_post_delete") +def handle_delete_programenrollment(sender, instance, **kwargs): # pylint: disable=unused-argument + """ + When a ProgramEnrollment model is deleted, update index. + """ + remove_program_enrolled_user.delay(instance) @receiver(pre_save, sender=CachedEnrollment, dispatch_uid="preupdate_programenrollment") @@ -25,27 +42,65 @@ def precreate_programenrollment(sender, instance, **kwargs): # pylint: disable= instance_in_db = CachedEnrollment.objects.filter(id=instance.id).exclude(data__isnull=True).count() # if the count is 1, it means the student unenrolled from the course run if instance_in_db == 1: - active_enrollment_count = CachedEnrollment.objects.filter( - user=user, - course_run__course__program=program - ).exclude(data__isnull=True).count() # if there is only one enrollment with data non None, it means that it is the # current instance is the only one for the program, so the program enrollment # needs to be deleted - if active_enrollment_count <= 1: # theoretically this cannot be <1, but just in case + if CachedEnrollment.active_count(user, program) <= 1: # theoretically this cannot be <1, but just in case ProgramEnrollment.objects.filter( user=user, program=program ).delete() -@receiver(post_save, sender=CachedEnrollment, dispatch_uid="update_programenrollment") -def create_programenrollment(sender, instance, **kwargs): # pylint: disable=unused-argument +@receiver(post_save, sender=CachedEnrollment, dispatch_uid="cachedenrollment_post_save") +def handle_update_enrollment(sender, instance, **kwargs): # pylint: disable=unused-argument + """ + Create ProgramEnrollment when a CachedEnrollment is created/updated, and update the index. + """ + if instance.data is not None: + program_enrollment, _ = ProgramEnrollment.objects.get_or_create( + user=instance.user, + program=instance.course_run.course.program + ) + index_program_enrolled_users.delay([program_enrollment]) + + +@receiver(post_save, sender=CachedCertificate, dispatch_uid="cachedcertificate_post_save") +def handle_update_certificate(sender, instance, **kwargs): # pylint: disable=unused-argument """ - Signal handler to create Program enrollment when the CachedEnrollment table is updated + When a CachedCertificate model is updated, update index. """ if instance.data is not None: - ProgramEnrollment.objects.get_or_create( + program_enrollment, _ = ProgramEnrollment.objects.get_or_create( user=instance.user, program=instance.course_run.course.program ) + index_program_enrolled_users.delay([program_enrollment]) + + +@receiver(post_delete, sender=CachedEnrollment, dispatch_uid="cachedenrollment_post_delete") +def handle_delete_enrollment(sender, instance, **kwargs): # pylint: disable=unused-argument + """ + Update index when CachedEnrollment model instance is deleted. + """ + user = instance.user + program = instance.course_run.course.program + program_enrollment = ProgramEnrollment.objects.filter(user=user, program=program).first() + if program_enrollment is not None: + if CachedEnrollment.active_count(user, program) == 0: + program_enrollment.delete() + index_users.delay([user]) + else: + index_program_enrolled_users.delay([program_enrollment]) + + +@receiver(post_delete, sender=CachedCertificate, dispatch_uid="cachedcertificate_post_delete") +def handle_delete_certificate(sender, instance, **kwargs): # pylint: disable=unused-argument + """ + Update index when CachedCertificate model instance is deleted. + """ + user = instance.user + program = instance.course_run.course.program + program_enrollment = ProgramEnrollment.objects.filter(user=user, program=program).first() + if program_enrollment is not None: + index_program_enrolled_users.delay([program_enrollment]) diff --git a/micromasters/utils.py b/micromasters/utils.py index f945f4d4a2..a290cc0f9e 100644 --- a/micromasters/utils.py +++ b/micromasters/utils.py @@ -19,6 +19,15 @@ def webpack_dev_server_url(request): return 'http://{}:{}'.format(webpack_dev_server_host(request), settings.WEBPACK_DEV_SERVER_PORT) +def dict_without_key(dictionary, key): + """ + Helper method to remove a key from a dict and return the dict. This can be used in cases like a list comprehension + where the actual dictionary is needed once the key is deleted ('del' does not return anything) + """ + del dictionary[key] + return dictionary + + def load_json_from_file(project_rel_filepath): """ Loads JSON data from a file diff --git a/search/api.py b/search/api.py index b5dd1df8f2..75a00a37f7 100644 --- a/search/api.py +++ b/search/api.py @@ -5,7 +5,6 @@ import logging from django.conf import settings -from django.contrib.auth.models import User from elasticsearch.helpers import bulk from elasticsearch.exceptions import NotFoundError from elasticsearch_dsl import Mapping @@ -13,11 +12,12 @@ from profiles.models import Profile from profiles.serializers import ProfileSerializer +from dashboard.models import ProgramEnrollment from search.exceptions import ReindexException log = logging.getLogger(__name__) -USER_DOC_TYPE = 'user' +USER_DOC_TYPE = 'program_user' DOC_TYPES = (USER_DOC_TYPE, ) _CONN = None # When we create the connection, check to make sure all appropriate mappings exist @@ -75,13 +75,13 @@ def get_conn(verify=True): return _CONN -def _index_users_chunk(users): +def _index_program_enrolled_users_chunk(program_enrollments): """ - Add/update a small number of user records in Elasticsearch + Add/update a list of ProgramEnrollment records in Elasticsearch Args: - users (list of User): - List of users + program_enrollments (list of ProgramEnrollments): + List of ProgramEnrollments to serialize and index Returns: int: Number of items inserted into Elasticsearch @@ -90,27 +90,25 @@ def _index_users_chunk(users): conn = get_conn() insert_count, errors = bulk( conn, - (serialize_user(user) for user in users), + (serialize_program_enrolled_user(program_enrollment) for program_enrollment in program_enrollments), index=settings.ELASTICSEARCH_INDEX, doc_type=USER_DOC_TYPE, ) - if len(errors) > 0: raise ReindexException("Error during bulk insert: {errors}".format( errors=errors )) refresh_index() - return insert_count -def index_users(users, chunk_size=100): +def index_program_enrolled_users(program_enrollments, chunk_size=100): """ - Add/update profile records in Elasticsearch. + Add/update ProgramEnrollment records in Elasticsearch. Args: - users (iterable of User): - Iterable of users + program_enrollments (iterable of ProgramEnrollments): + Iterable of ProgramEnrollments to serialize and index chunk_size (int): How many users to index at once. @@ -118,55 +116,80 @@ def index_users(users, chunk_size=100): int: Number of indexed items """ # Use an iterator so we can keep track of what's been indexed already - users = iter(users) - + program_enrollments = iter(program_enrollments) count = 0 - chunk = list(islice(users, chunk_size)) + chunk = list(islice(program_enrollments, chunk_size)) while len(chunk) > 0: - count += _index_users_chunk(chunk) - chunk = list(islice(users, chunk_size)) - + count += _index_program_enrolled_users_chunk(chunk) + chunk = list(islice(program_enrollments, chunk_size)) refresh_index() - return count -def remove_user(user): +def index_users(users, chunk_size=100): """ - Remove a user from Elasticsearch. + Indexes a list of users via their ProgramEnrollments + """ + program_enrollments = ProgramEnrollment.objects.filter(user__in=users).select_related('user', 'program').all() + return index_program_enrolled_users(program_enrollments, chunk_size) + + +def remove_program_enrolled_user(program_enrollment): + """ + Remove a program-enrolled user from Elasticsearch. """ conn = get_conn() try: - conn.delete(index=settings.ELASTICSEARCH_INDEX, doc_type=USER_DOC_TYPE, id=user.id) + conn.delete(index=settings.ELASTICSEARCH_INDEX, doc_type=USER_DOC_TYPE, id=program_enrollment.id) except NotFoundError: # Item is already gone pass -def serialize_user(user): +def remove_user(user): + """ + Remove a user from Elasticsearch. + """ + program_enrollments = ProgramEnrollment.objects.filter(user=user).select_related('user', 'program').all() + for program_enrollment in program_enrollments: + remove_program_enrolled_user(program_enrollment) + + +def serialize_program_enrolled_user(program_enrollment): """ - Serializes user for use with Elasticsearch. + Serializes a program-enrolled user for use with Elasticsearch. Args: - user (User): A user to serialize + program_enrollment (ProgramEnrollment): A program_enrollment to serialize Returns: dict: The data to be sent to Elasticsearch """ + user = program_enrollment.user + program = program_enrollment.program serialized = { - 'id': user.id, - '_id': user.id, + 'id': program_enrollment.id, + '_id': program_enrollment.id, + 'user_id': user.id } try: serialized['profile'] = ProfileSerializer().to_representation(user.profile) except Profile.DoesNotExist: # Just in case pass - serialized['certificates'] = [ - certificate.data for certificate in user.cachedcertificate_set.all().exclude(data__isnull=True) - ] - serialized['enrollments'] = [ - enrollment.data for enrollment in user.cachedenrollment_set.all().exclude(data__isnull=True) - ] + + serialized['program'] = { + 'id': program.id, + 'certificates': [ + certificate.data for certificate in user.cachedcertificate_set.filter( + course_run__course__program=program + ).exclude(data__isnull=True) + ], + 'enrollments': [ + enrollment.data for enrollment in user.cachedenrollment_set.filter( + course_run__course__program=program + ).exclude(data__isnull=True) + ] + } return serialized @@ -177,19 +200,13 @@ def refresh_index(): get_conn().indices.refresh(index=settings.ELASTICSEARCH_INDEX) -def create_user_mapping(): +def program_enrolled_user_mapping(): """ - Create a mapping for profiles. If one already exists, delete it first. + Builds the raw mapping data for the program-enrolled user doc type """ - conn = get_conn(verify=False) - - index_name = settings.ELASTICSEARCH_INDEX - if conn.indices.exists_type(index=index_name, doc_type=USER_DOC_TYPE): - conn.indices.delete_mapping(index=index_name, doc_type=USER_DOC_TYPE) - mapping = Mapping(USER_DOC_TYPE) - mapping.field("id", "long") + mapping.field("user_id", "long") mapping.field("profile", "nested", properties={ 'account_privacy': NOT_ANALYZED_STRING_TYPE, 'agreed_to_terms_of_service': BOOL_TYPE, @@ -235,18 +252,21 @@ def create_user_mapping(): 'state_or_territory': NOT_ANALYZED_STRING_TYPE, }}, }) - mapping.field('enrollments', 'nested', properties={ - 'course_details': { - 'type': 'object', - 'properties': { - 'course_modes': { - 'type': 'nested', + mapping.field("program", "nested", properties={ + 'id': LONG_TYPE, + 'grade_average': LONG_TYPE, + 'enrollments': {'type': 'nested', 'properties': { + 'course_details': { + 'type': 'object', + 'properties': { + 'course_modes': { + 'type': 'nested', + } } } - } + }}, + 'certificates': {'type': 'nested'} }) - mapping.field('certificates', 'nested') - # Make strings not_analyzed by default mapping.meta('dynamic_templates', [{ "notanalyzed": { @@ -255,7 +275,18 @@ def create_user_mapping(): "mapping": NOT_ANALYZED_STRING_TYPE } }]) + return mapping + +def create_program_enrolled_user_mapping(): + """ + Save the mapping for a program user. If one already exists, delete it first. + """ + conn = get_conn(verify=False) + index_name = settings.ELASTICSEARCH_INDEX + if conn.indices.exists_type(index=index_name, doc_type=USER_DOC_TYPE): + conn.indices.delete_mapping(index=index_name, doc_type=USER_DOC_TYPE) + mapping = program_enrolled_user_mapping() mapping.save(index_name) @@ -263,7 +294,7 @@ def create_mappings(): """ Create all mappings, deleting existing mappings if they exist. """ - create_user_mapping() + create_program_enrolled_user_mapping() def clear_index(): @@ -284,4 +315,4 @@ def recreate_index(): Wipe and recreate index and mapping, and index all items. """ clear_index() - index_users(User.objects.iterator()) + index_program_enrolled_users(ProgramEnrollment.objects.iterator()) diff --git a/search/api_test.py b/search/api_test.py index 08e1438b0e..12ba081f34 100644 --- a/search/api_test.py +++ b/search/api_test.py @@ -2,25 +2,29 @@ Tests for search API functions. """ -from urllib.parse import urljoin - from django.conf import settings -from django.contrib.auth.models import User from django.db.models.signals import post_save from factory.django import mute_signals from rest_framework.fields import DateTimeField from requests import get +from mock import patch from dashboard.factories import ( CachedCertificateFactory, CachedEnrollmentFactory, + ProgramEnrollmentFactory +) +from dashboard.models import ProgramEnrollment +from courses.factories import ( + ProgramFactory, + CourseFactory, + CourseRunFactory, ) from profiles.api import get_social_username from profiles.factories import ( EducationFactory, EmploymentFactory, ProfileFactory, - UserFactory, ) from profiles.serializers import ( EducationSerializer, @@ -32,65 +36,56 @@ ) from search.api import ( get_conn, - index_users, + recreate_index, refresh_index, - remove_user, - serialize_user, + index_program_enrolled_users, + remove_program_enrolled_user, + serialize_program_enrolled_user, + USER_DOC_TYPE ) from search.base import ESTestCase from search.exceptions import ReindexException from search.util import traverse_mapping +from micromasters.utils import dict_without_key -def remove_key(dictionary, key): - """Helper method to remove a key from a dict and return the dict""" - del dictionary[key] - return dictionary - - -def search(): +class ESTestActions: """ - Execute a search and get results + Provides helper functions for tests to communicate with ES """ - # Refresh the index so we can read current data - refresh_index() + def __init__(self): + self.index = settings.ELASTICSEARCH_INDEX + self.url = "{}/{}".format(settings.ELASTICSEARCH_URL, self.index) + if not self.url.startswith("http"): + self.url = "http://{}".format(self.url) + self.search_url = "{}/{}".format(self.url, "_search") + self.mapping_url = "{}/{}".format(self.url, "_mapping") - elasticsearch_url = settings.ELASTICSEARCH_URL - if not elasticsearch_url.startswith("http"): - elasticsearch_url = "http://{}".format(elasticsearch_url) - url = urljoin( - elasticsearch_url, - "{}/{}".format(settings.ELASTICSEARCH_INDEX, "_search") - ) - return get(url).json()['hits'] + def search(self): + """Gets full index data from the _search endpoint""" + refresh_index() + return get(self.search_url).json()['hits'] + def get_mappings(self): + """Gets mapping data""" + refresh_index() + return get(self.mapping_url).json()[self.index]['mappings'] -def get_es_mappings(): - """ - Retrieve the current mapping - """ - # Refresh the index so we can read current data - refresh_index() - elasticsearch_url = settings.ELASTICSEARCH_URL - elasticsearch_index = settings.ELASTICSEARCH_INDEX - if not elasticsearch_url.startswith("http"): - elasticsearch_url = "http://{}".format(elasticsearch_url) - url = urljoin( - elasticsearch_url, - "{}/{}".format(elasticsearch_index, "_mapping") - ) - return get(url).json()[elasticsearch_index]['mappings'] +es = ESTestActions() -def assert_search(results, users): +def assert_search(results, program_enrollments): """ - Assert that search results match users + Assert that search results match program-enrolled users """ - assert results['total'] == len(users) + assert results['total'] == len(program_enrollments) sources = sorted([hit['_source'] for hit in results['hits']], key=lambda hit: hit['id']) - sorted_users = sorted(users, key=lambda user: user.id) - serialized = [remove_key(serialize_user(user), "_id") for user in sorted_users] + sorted_program_enrollments = sorted(program_enrollments, key=lambda program_enrollment: program_enrollment.id) + serialized = [ + dict_without_key(serialize_program_enrolled_user(program_enrollment), "_id") + for program_enrollment in sorted_program_enrollments + ] assert serialized == sources @@ -100,192 +95,172 @@ class IndexTests(ESTestCase): Tests for indexing """ - def test_user_add(self): + def test_program_enrollment_add(self): """ - Test that a newly created User is indexed properly + Test that a newly created ProgramEnrollment is indexed properly """ - assert search()['total'] == 0 - user = UserFactory.create() - assert_search(search(), [user]) + assert es.search()['total'] == 0 + program_enrollment = ProgramEnrollmentFactory.create() + assert_search(es.search(), [program_enrollment]) - def test_user_update(self): + def test_program_enrollment_delete(self): """ - Test that User is reindexed after being updated + Test that ProgramEnrollment is removed from index after the user is removed """ - user = UserFactory.create() - assert search()['total'] == 1 - profile = user.profile + program_enrollment = ProgramEnrollmentFactory.create() + assert es.search()['total'] == 1 + program_enrollment.user.delete() + assert es.search()['total'] == 0 + + def test_profile_update(self): + """ + Test that ProgramEnrollment is reindexed after the User's Profile has been updated + """ + program_enrollment = ProgramEnrollmentFactory.create() + assert es.search()['total'] == 1 + profile = program_enrollment.user.profile profile.first_name = 'updated' profile.save() - assert_search(search(), [user]) + assert_search(es.search(), [program_enrollment]) - def test_user_delete(self): + def test_program_enrollment_clear_upon_profile_deletion(self): """ - Test that User is removed from index after being updated + Test that all ProgramEnrollments are cleared from the index after the User's Profile has been deleted """ - user = UserFactory.create() - assert search()['total'] == 1 - user.profile.delete() - assert search()['total'] == 0 + with mute_signals(post_save): + profile = ProfileFactory.create() + ProgramEnrollmentFactory.create(user=profile.user) + ProgramEnrollmentFactory.create(user=profile.user) + assert es.search()['total'] == 2 + profile.delete() + assert es.search()['total'] == 0 def test_education_add(self): """ Test that Education is indexed after being added """ - user = UserFactory.create() - assert search()['total'] == 1 - EducationFactory.create(profile=user.profile) - assert_search(search(), [user]) + program_enrollment = ProgramEnrollmentFactory.create() + assert es.search()['total'] == 1 + EducationFactory.create(profile=program_enrollment.user.profile) + assert_search(es.search(), [program_enrollment]) def test_education_update(self): """ Test that Education is reindexed after being updated """ - user = UserFactory.create() - assert search()['total'] == 1 - education = EducationFactory.create(profile=user.profile) + program_enrollment = ProgramEnrollmentFactory.create() + assert es.search()['total'] == 1 + education = EducationFactory.create(profile=program_enrollment.user.profile) education.school_city = 'city' education.save() - assert_search(search(), [user]) + assert_search(es.search(), [program_enrollment]) def test_education_delete(self): """ Test that Education is removed from index after being deleted """ - user = UserFactory.create() - education = EducationFactory.create(profile=user.profile) - assert_search(search(), [user]) + program_enrollment = ProgramEnrollmentFactory.create() + education = EducationFactory.create(profile=program_enrollment.user.profile) + assert_search(es.search(), [program_enrollment]) education.delete() - assert_search(search(), [user]) + assert_search(es.search(), [program_enrollment]) def test_employment_add(self): """ Test that Employment is indexed after being added """ - user = UserFactory.create() - assert search()['total'] == 1 - EmploymentFactory.create(profile=user.profile) - assert_search(search(), [user]) + program_enrollment = ProgramEnrollmentFactory.create() + assert es.search()['total'] == 1 + EmploymentFactory.create(profile=program_enrollment.user.profile) + assert_search(es.search(), [program_enrollment]) def test_employment_update(self): """ Test that Employment is reindexed after being updated """ - user = UserFactory.create() - assert search()['total'] == 1 - employment = EmploymentFactory.create(profile=user.profile) + program_enrollment = ProgramEnrollmentFactory.create() + assert es.search()['total'] == 1 + employment = EmploymentFactory.create(profile=program_enrollment.user.profile) employment.city = 'city' employment.save() - assert_search(search(), [user]) + assert_search(es.search(), [program_enrollment]) def test_employment_delete(self): """ Test that Employment is removed from index after being deleted """ - user = UserFactory.create() - employment = EmploymentFactory.create(profile=user.profile) - assert_search(search(), [user]) + program_enrollment = ProgramEnrollmentFactory.create() + employment = EmploymentFactory.create(profile=program_enrollment.user.profile) + assert_search(es.search(), [program_enrollment]) employment.delete() - assert_search(search(), [user]) + assert_search(es.search(), [program_enrollment]) - def test_remove_profile(self): + def test_remove_program_enrolled_user(self): """ - Test that remove_profile removes the profile from the index + Test that remove_program_enrolled_user removes the user from the index for that program """ - user = UserFactory.create() - assert_search(search(), [user]) - remove_user(user) - assert_search(search(), []) + program_enrollment = ProgramEnrollmentFactory.create() + assert_search(es.search(), [program_enrollment]) + remove_program_enrolled_user(program_enrollment) + assert_search(es.search(), []) - def test_index_users(self): + def test_index_program_enrolled_users(self): """ - Test that index_users indexes an iterable of users + Test that index_program_enrolled_users indexes an iterable of program-enrolled users """ - for _ in range(10): - with mute_signals(post_save): - # using ProfileFactory instead of UserFactory here since UserFactory will not fill in any - # fields on Profile - profile = ProfileFactory.create() - # Not strictly necessary, the muted post_save will prevent indexing - remove_user(profile.user) - - # Confirm nothing in index - assert_search(search(), []) - index_users(User.objects.iterator(), chunk_size=4) - assert_search(search(), list(User.objects.all())) - - def test_add_certificate(self): - """ - Test that Certificate is indexed after being added - """ - user = UserFactory.create() - assert search()['total'] == 1 - CachedCertificateFactory.create(user=user) - assert_search(search(), [user]) - - def test_update_certificate(self): - """ - Test that Certificate is reindexed after being updated - """ - user = UserFactory.create() - assert search()['total'] == 1 - certificate = CachedCertificateFactory.create(user=user) - certificate.data = {'new': 'data'} - certificate.save() - assert_search(search(), [user]) - - def test_delete_certificate(self): - """ - Test that Certificate is removed from index after being deleted - """ - user = UserFactory.create() - certificate = CachedCertificateFactory.create(user=user) - assert_search(search(), [user]) - certificate.delete() - assert_search(search(), [user]) + with mute_signals(post_save): + program_enrollments = [ProgramEnrollmentFactory.build() for _ in range(10)] + with patch('search.api._index_program_enrolled_users_chunk', autospec=True, return_value=0) as index_chunk: + index_program_enrolled_users(program_enrollments, chunk_size=4) + assert index_chunk.call_count == 3 + index_chunk.assert_any_call(program_enrollments[0:4]) - def test_add_enrollment(self): + def test_add_edx_record(self): """ - Test that Enrollment is indexed after being added + Test that cached edX records are indexed after being added """ - user = UserFactory.create() - CachedCertificateFactory.create(user=user) - assert_search(search(), [user]) + program_enrollment = ProgramEnrollmentFactory.create() + for edx_cached_model_factory in [CachedCertificateFactory, CachedEnrollmentFactory]: + assert es.search()['total'] == 1 + course_run = CourseRunFactory.create(program=program_enrollment.program) + edx_cached_model_factory.create(user=program_enrollment.user, course_run=course_run) + assert_search(es.search(), [program_enrollment]) - def test_update_enrollment(self): + def test_update_edx_record(self): """ - Test that Enrollment is reindexed after being updated + Test that a cached edX record is reindexed after being updated """ - user = UserFactory.create() - assert search()['total'] == 1 - enrollment = CachedEnrollmentFactory.create(user=user) - enrollment.data = {'new': 'data'} - enrollment.save() - assert_search(search(), [user]) + program_enrollment = ProgramEnrollmentFactory.create() + for edx_cached_model_factory in [CachedCertificateFactory, CachedEnrollmentFactory]: + assert es.search()['total'] == 1 + course_run = CourseRunFactory.create(program=program_enrollment.program) + edx_record = edx_cached_model_factory.create(user=program_enrollment.user, course_run=course_run) + edx_record.data = {'new': 'data'} + edx_record.save() + assert_search(es.search(), [program_enrollment]) - def test_delete_enrollment(self): + def test_delete_edx_record(self): """ - Test that Enrollment is removed from index after being deleted + Test that a cached edX record is removed from index after being deleted """ - user = UserFactory.create() - enrollment = CachedEnrollmentFactory.create(user=user) - assert_search(search(), [user]) - enrollment.delete() - assert_search(search(), [user]) + program_enrollment = ProgramEnrollmentFactory.create() + for edx_cached_model_factory in [CachedCertificateFactory, CachedEnrollmentFactory]: + course_run = CourseRunFactory.create(program=program_enrollment.program) + edx_record = edx_cached_model_factory.create(user=program_enrollment.user, course_run=course_run) + assert_search(es.search(), [program_enrollment]) + edx_record.delete() + assert_search(es.search(), [program_enrollment]) def test_not_analyzed(self): """ At the moment no string fields in the mapping should be 'analyzed' since there's no field supporting full text search. """ - with mute_signals(post_save): - profile = ProfileFactory.create() - EducationFactory.create(profile=profile) - EmploymentFactory.create(profile=profile) - CachedCertificateFactory.create(user=profile.user) - CachedEnrollmentFactory.create(user=profile.user) + program_enrollment = ProgramEnrollmentFactory.create() + EducationFactory.create(profile=program_enrollment.user.profile) + EmploymentFactory.create(profile=program_enrollment.user.profile) - mapping = get_es_mappings() + mapping = es.get_mappings() nodes = list(traverse_mapping(mapping)) for node in nodes: if node.get('type') == 'string': @@ -294,23 +269,28 @@ def test_not_analyzed(self): class SerializerTests(ESTestCase): """ - Tests for profile serializer + Tests for document serializers """ - def test_profile_serializer(self): # pylint: disable=no-self-use + def test_program_enrolled_user_serializer(self): # pylint: disable=no-self-use """ - Asserts the output of the profile serializer + Asserts the output of the serializer for program-enrolled users (ProgramEnrollments) """ with mute_signals(post_save): profile = ProfileFactory.create() EducationFactory.create(profile=profile) EmploymentFactory.create(profile=profile) - certificate = CachedCertificateFactory.create(user=profile.user) - enrollment = CachedEnrollmentFactory.create(user=profile.user) - - assert serialize_user(profile.user) == { - '_id': profile.user.id, - 'id': profile.user.id, + program = ProgramFactory.create() + course = CourseFactory.create(program=program) + course_run = CourseRunFactory.create(course=course) + certificate = CachedCertificateFactory.create(user=profile.user, course_run=course_run) + enrollment = CachedEnrollmentFactory.create(user=profile.user, course_run=course_run) + program_enrollment = ProgramEnrollment.objects.get(user=profile.user, program=program) + + assert serialize_program_enrolled_user(program_enrollment) == { + '_id': program_enrollment.id, + 'id': program_enrollment.id, + 'user_id': profile.user.id, 'profile': { 'username': get_social_username(profile.user), 'first_name': profile.first_name, @@ -344,8 +324,11 @@ def test_profile_serializer(self): # pylint: disable=no-self-use profile.work_history.all() ] }, - 'certificates': [certificate.data], - 'enrollments': [enrollment.data], + 'program': { + 'id': program.id, + 'certificates': [certificate.data], + 'enrollments': [enrollment.data] + } } @@ -385,4 +368,39 @@ def test_no_mapping(self): with self.assertRaises(ReindexException) as ex: get_conn() - assert str(ex.exception) == "Mapping user not found" + assert str(ex.exception) == "Mapping {} not found".format(USER_DOC_TYPE) + + +class RecreateIndexTests(ESTestCase): + """ + Tests for management commands + """ + def setUp(self): + """ + Start without any index + """ + super(RecreateIndexTests, self).setUp() + conn = get_conn(verify=False) + index_name = settings.ELASTICSEARCH_INDEX + if conn.indices.exists(index_name): + conn.indices.delete(index_name) + + def test_create_index(self): # pylint: disable=no-self-use + """ + Test that recreate_index will create an index and let search successfully + """ + recreate_index() + assert es.search()['total'] == 0 + + def test_update_index(self): # pylint: disable=no-self-use + """ + Test that recreate_index will clear old data and index all profiles + """ + recreate_index() + program_enrollment = ProgramEnrollmentFactory.create() + assert_search(es.search(), [program_enrollment]) + remove_program_enrolled_user(program_enrollment) + assert_search(es.search(), []) + # recreate_index should index the program-enrolled user + recreate_index() + assert_search(es.search(), [program_enrollment]) diff --git a/search/commands_test.py b/search/commands_test.py deleted file mode 100644 index bba5ce7428..0000000000 --- a/search/commands_test.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Test search management commands -""" - -from django.conf import settings - -from profiles.factories import UserFactory -from search.api import ( - get_conn, - recreate_index, - remove_user, -) -from search.api_test import ( - assert_search, - search, -) -from search.base import ESTestCase - - -class RecreateIndexTests(ESTestCase): - """ - Tests for management commands - """ - def setUp(self): - """ - Start without any index - """ - super(RecreateIndexTests, self).setUp() - conn = get_conn(verify=False) - index_name = settings.ELASTICSEARCH_INDEX - if conn.indices.exists(index_name): - conn.indices.delete(index_name) - - def test_create_index(self): # pylint: disable=no-self-use - """ - Test that recreate_index will create an index and let search successfully - """ - recreate_index() - assert search()['total'] == 0 - - def test_update_index(self): # pylint: disable=no-self-use - """ - Test that recreate_index will clear old data and index all profiles - """ - recreate_index() - user = UserFactory.create() - assert_search(search(), [user]) - remove_user(user) - # No profiles in Elasticsearch - assert_search(search(), []) - - # recreate_index will index the profile - recreate_index() - assert_search(search(), [user]) diff --git a/search/management/commands/clear_index.py b/search/management/commands/clear_index.py new file mode 100644 index 0000000000..3ed74912d7 --- /dev/null +++ b/search/management/commands/clear_index.py @@ -0,0 +1,22 @@ +""" +Management command to clear the Elasticsearch index +""" + +from django.core.management.base import BaseCommand + +from search.api import ( + clear_index, +) + + +class Command(BaseCommand): + """ + Command for clear_index + """ + help = "Clears existing Elasticsearch indices." + + def handle(self, *args, **kwargs): # pylint: disable=unused-argument + """ + Clear the index + """ + clear_index() diff --git a/search/signals.py b/search/signals.py index 088c2fbf6f..b1a7908c84 100644 --- a/search/signals.py +++ b/search/signals.py @@ -12,10 +12,6 @@ Employment, Profile, ) -from dashboard.models import ( - CachedCertificate, - CachedEnrollment, -) from search.tasks import index_users, remove_user log = logging.getLogger(__name__) @@ -26,17 +22,6 @@ # because each signal handler needs to be hooked to a single sender # otherwise it would run for any `post_save`/`post_delete` coming from any model -@receiver(post_save, sender=CachedCertificate, dispatch_uid="cachedcertificate_post_save_index") -def handle_update_certificate(sender, instance, **kwargs): # pylint: disable=unused-argument - """Update index when a CachedCertificate model is updated.""" - index_users.delay([instance.user]) - - -@receiver(post_save, sender=CachedEnrollment, dispatch_uid="cachedenrollment_post_save_index") -def handle_update_enrollment(sender, instance, **kwargs): # pylint: disable=unused-argument - """Update index when CachedEnrollment model is updated.""" - index_users.delay([instance.user]) - @receiver(post_save, sender=Profile, dispatch_uid="profile_post_save_index") def handle_update_profile(sender, instance, **kwargs): # pylint: disable=unused-argument @@ -72,15 +57,3 @@ def handle_delete_education(sender, instance, **kwargs): # pylint: disable=unus def handle_delete_employment(sender, instance, **kwargs): # pylint: disable=unused-argument """Update index when Employment model instance is deleted.""" index_users.delay([instance.profile.user]) - - -@receiver(post_delete, sender=CachedCertificate, dispatch_uid="cachedcertificate_post_delete_index") -def handle_delete_certificate(sender, instance, **kwargs): # pylint: disable=unused-argument - """Update index when CachedCertificate model instance is deleted.""" - index_users.delay([instance.user]) - - -@receiver(post_delete, sender=CachedEnrollment, dispatch_uid="cachedenrollment_post_delete_index") -def handle_delete_enrollment(sender, instance, **kwargs): # pylint: disable=unused-argument - """Update index when CachedEnrollment model instance is deleted.""" - index_users.delay([instance.user]) diff --git a/search/tasks.py b/search/tasks.py index 08c4cc793b..942e9893a9 100644 --- a/search/tasks.py +++ b/search/tasks.py @@ -4,19 +4,42 @@ from micromasters.celery import async from search.api import ( + index_program_enrolled_users as _index_program_enrolled_users, + remove_program_enrolled_user as _remove_program_enrolled_user, index_users as _index_users, remove_user as _remove_user ) @async.task -def index_users(users): +def remove_program_enrolled_user(user): + """ + Remove program-enrolled user from index + + Args: + user (User): A program-enrolled user to remove from index + """ + _remove_program_enrolled_user(user) + + +@async.task +def index_program_enrolled_users(program_enrollments): """ Index profiles Args: - users (iterable of User): - Iterable of Users + program_enrollments (iterable of ProgramEnrollments): Program-enrolled users to remove from index + """ + _index_program_enrolled_users(program_enrollments) + + +@async.task +def index_users(users): + """ + Index users + + Args: + users (iterable of Users): Users to remove from index """ _index_users(users) @@ -24,10 +47,9 @@ def index_users(users): @async.task def remove_user(user): """ - Remove profile from index + Remove user from index Args: - user (User): - A user to remove from index + user (User): A user to remove from index """ _remove_user(user)