Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(migrations): automatically add FULLTEXT index to Problem model when running MYSQL-like db #2353

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: build
on: [push, pull_request]
on: [push, pull_request, workflow_dispatch]
jobs:
lint:
runs-on: ubuntu-latest
Expand Down
48 changes: 48 additions & 0 deletions judge/migrations/0148_judge_add_fulltext_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Generated by Django 3.2.25 on 2024-10-01 06:08

from django.db import migrations


def execute_mysql_command(apps, schema_editor, sql, error_msg):
if schema_editor.connection.vendor != 'mysql':
return

Problem = apps.get_model('judge', 'Problem')
formatted_sql = sql.format(Problem._meta.db_table)

with schema_editor.connection.cursor() as cursor:
try:
cursor.execute(formatted_sql)
except Exception as e:
if error_msg in str(e):
print(f'Info: {error_msg}')
else:
raise


def add_fulltext_index(apps, schema_editor):
execute_mysql_command(
apps,
schema_editor,
'ALTER TABLE {} ADD FULLTEXT(code, name, description)',
'Duplicate key name',
)


def remove_fulltext_index(apps, schema_editor):
execute_mysql_command(
apps,
schema_editor,
'ALTER TABLE {} DROP INDEX code',
'check that column/key exists',
)


class Migration(migrations.Migration):
dependencies = [
('judge', '0147_judge_add_tiers'),
]

operations = [
migrations.RunPython(add_fulltext_index, remove_fulltext_index),
]
159 changes: 131 additions & 28 deletions judge/models/tests/test_problem.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,43 @@
from unittest import skipIf

from django.core.exceptions import ValidationError
from django.db import connection
from django.db.models import F
from django.test import SimpleTestCase, TestCase
from django.utils import timezone

from judge.models import Language, LanguageLimit, Problem, Submission
from judge.models.problem import VotePermission, disallowed_characters_validator
from judge.models.tests.util import CommonDataMixin, create_contest, create_contest_participation, \
create_organization, create_problem, create_problem_type, create_solution, create_user
from judge.models.tests.util import (
CommonDataMixin,
create_contest,
create_contest_participation,
create_organization,
create_problem,
create_problem_type,
create_solution,
create_user,
)


class ProblemTestCase(CommonDataMixin, TestCase):
@classmethod
def setUpTestData(self):
def setUpTestData(cls):
super().setUpTestData()

self.users.update({
'staff_problem_edit_only_all': create_user(
username='staff_problem_edit_only_all',
is_staff=True,
user_permissions=('edit_all_problem',),
),
})
cls.users.update(
{
'staff_problem_edit_only_all': create_user(
username='staff_problem_edit_only_all',
is_staff=True,
user_permissions=('edit_all_problem',),
),
},
)

create_problem_type(name='type')

self.basic_problem = create_problem(
cls.basic_problem = create_problem(
code='basic',
allowed_languages=Language.objects.values_list('key', flat=True),
types=('type',),
Expand All @@ -35,32 +49,32 @@ def setUpTestData(self):
for lang in Language.objects.filter(common_name=Language.get_python3().common_name):
limits.append(
LanguageLimit(
problem=self.basic_problem,
problem=cls.basic_problem,
language=lang,
time_limit=100,
memory_limit=131072,
),
)
LanguageLimit.objects.bulk_create(limits)

self.organization_private_problem = create_problem(
cls.organization_private_problem = create_problem(
code='organization_private',
time_limit=2,
is_public=True,
is_organization_private=True,
curators=('staff_problem_edit_own', 'staff_problem_edit_own_no_staff'),
)

self.problem_organization = create_organization(
cls.problem_organization = create_organization(
name='problem organization',
admins=('normal', 'staff_problem_edit_public'),
)
self.organization_admin_private_problem = create_problem(
cls.organization_admin_private_problem = create_problem(
code='org_admin_private',
is_organization_private=True,
organizations=('problem organization',),
)
self.organization_admin_problem = create_problem(
cls.organization_admin_problem = create_problem(
code='organization_admin',
organizations=('problem organization',),
)
Expand All @@ -79,7 +93,10 @@ def test_basic_problem(self):

self.assertListEqual(list(self.basic_problem.author_ids), [self.users['normal'].profile.id])
self.assertListEqual(list(self.basic_problem.editor_ids), [self.users['normal'].profile.id])
self.assertListEqual(list(self.basic_problem.tester_ids), [self.users['staff_problem_edit_public'].profile.id])
self.assertListEqual(
list(self.basic_problem.tester_ids),
[self.users['staff_problem_edit_public'].profile.id],
)
self.assertListEqual(list(self.basic_problem.usable_languages), [])
self.assertListEqual(self.basic_problem.types_list, ['type'])
self.assertSetEqual(self.basic_problem.usable_common_names, set())
Expand Down Expand Up @@ -255,7 +272,10 @@ def give_basic_problem_ac(self, user, points=None):
)

def test_problem_voting_permissions(self):
self.assertEqual(self.basic_problem.vote_permission_for_user(self.users['anonymous']), VotePermission.NONE)
self.assertEqual(
self.basic_problem.vote_permission_for_user(self.users['anonymous']),
VotePermission.NONE,
)

now = timezone.now()
basic_contest = create_contest(
Expand All @@ -281,17 +301,29 @@ def test_problem_voting_permissions(self):
banned_from_voting = create_user(username='banned_from_voting')
banned_from_voting.profile.is_banned_from_problem_voting = True
self.give_basic_problem_ac(banned_from_voting)
self.assertEqual(self.basic_problem.vote_permission_for_user(banned_from_voting), VotePermission.VIEW)
self.assertEqual(
self.basic_problem.vote_permission_for_user(banned_from_voting),
VotePermission.VIEW,
)

banned_from_problem = create_user(username='banned_from_problem')
self.basic_problem.banned_users.add(banned_from_problem.profile)
self.give_basic_problem_ac(banned_from_problem)
self.assertEqual(self.basic_problem.vote_permission_for_user(banned_from_problem), VotePermission.VIEW)
self.assertEqual(
self.basic_problem.vote_permission_for_user(banned_from_problem),
VotePermission.VIEW,
)

self.assertEqual(self.basic_problem.vote_permission_for_user(self.users['normal']), VotePermission.VIEW)
self.assertEqual(
self.basic_problem.vote_permission_for_user(self.users['normal']),
VotePermission.VIEW,
)

self.give_basic_problem_ac(self.users['normal'])
self.assertEqual(self.basic_problem.vote_permission_for_user(self.users['normal']), VotePermission.VOTE)
self.assertEqual(
self.basic_problem.vote_permission_for_user(self.users['normal']),
VotePermission.VOTE,
)

partial_ac = create_user(username='partial_ac')
self.give_basic_problem_ac(partial_ac, 0.5) # ensure this value is not equal to its point value
Expand Down Expand Up @@ -330,12 +362,14 @@ class SolutionTestCase(CommonDataMixin, TestCase):
@classmethod
def setUpTestData(self):
super().setUpTestData()
self.users.update({
'staff_solution_see_all': create_user(
username='staff_solution_see_all',
user_permissions=('see_private_solution',),
),
})
self.users.update(
{
'staff_solution_see_all': create_user(
username='staff_solution_see_all',
user_permissions=('see_private_solution',),
),
},
)

now = timezone.now()

Expand Down Expand Up @@ -448,3 +482,72 @@ def test_invalid(self):
disallowed_characters_validator('“')
with self.assertRaisesRegex(ValidationError, 'Disallowed characters: (?=.*‘)(?=.*’)'):
disallowed_characters_validator('‘’')


@skipIf(connection.vendor != 'mysql', 'FULLTEXT search is only supported on MySQL')
class FullTextSearchTestCase(CommonDataMixin, TestCase):
def setUpTestData(self):
super().setUpTestData()

languages = [
('P1', 'Django Test', 'A test problem for Django'),
('P2', 'Python Challenge', 'A challenging Python problem'),
('P3', 'Database Query', 'A problem about SQL and databases'),
]

for code, name, description in languages:
create_problem_type(
name=name,
code=code,
description=description,
allowed_languages=Language.objects.values_list('key', flat=True),
types=('type',),
authors=('normal',),
testers=('staff_problem_edit_public',),
)


def test_fulltext_search_name(self):
results = Problem.objects.filter(name__search='Python')
self.assertEqual(results.count(), 1)
self.assertEqual(results[0].code, 'P2')


def test_fulltext_search_description(self):
results = Problem.objects.filter(description__search='database')
self.assertEqual(results.count(), 1)
self.assertEqual(results[0].code, 'P3')


def test_fulltext_search_multiple_columns(self):
results = Problem.objects.filter(name__search='test') | Problem.objects.filter(description__search='test')
self.assertEqual(results.count(), 1)
self.assertEqual(results[0].code, 'P1')


def test_fulltext_search_ranking(self):
Problem.objects.create(code='P4', name='Advanced Python', description='Python for advanced users')
Problem.objects.create(code='P5', name='Python Basics', description='Introduction to Python programming')

results = Problem.objects.filter(name__search='Python') | Problem.objects.filter(description__search='Python')
results = results.annotate(relevance=F('name__search') + F('description__search')).order_by('-relevance')

self.assertTrue(len(results) > 1)
self.assertEqual(results[0].code, 'P2')


def test_fulltext_search_boolean_mode(self):
results = Problem.objects.filter(description__search='+SQL -Python')
self.assertEqual(results.count(), 1)
self.assertEqual(results[0].code, 'P3')


def test_fulltext_search_no_results(self):
results = Problem.objects.filter(name__search='NonexistentTerm')
self.assertEqual(results.count(), 0)


@classmethod
def tearDownClass(cls):
Problem.objects.all().delete()
super().tearDownClass()