Skip to content

Commit

Permalink
Added support for meta indexes and constraints in sqldiff. (#1726)
Browse files Browse the repository at this point in the history
  • Loading branch information
noamkush authored Sep 9, 2022
1 parent 25b9fd5 commit 9f55f81
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 8 deletions.
34 changes: 26 additions & 8 deletions django_extensions/management/commands/sqldiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from django.core.management.base import OutputWrapper
from django.core.management.color import no_style
from django.db import connection, transaction, models
from django.db.models import UniqueConstraint
from django.db.models.fields import AutoField, IntegerField
from django.db.models.options import normalize_together

Expand Down Expand Up @@ -381,6 +382,23 @@ def strip_parameters(self, field_type):
return field_type.split(" ")[0].split("(")[0].lower()
return field_type

def get_index_together(self, meta):
indexes_normalized = list(normalize_together(meta.index_together))

for idx in meta.indexes:
indexes_normalized.append(idx.fields)

return self.expand_together(indexes_normalized, meta)

def get_unique_together(self, meta):
unique_normalized = list(normalize_together(meta.unique_together))

for constraint in meta.constraints:
if isinstance(constraint, UniqueConstraint):
unique_normalized.append(constraint.fields)

return self.expand_together(unique_normalized, meta)

def expand_together(self, together, meta):
new_together = []
for fields in normalize_together(together):
Expand Down Expand Up @@ -411,7 +429,7 @@ def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, tabl
if db_type.startswith('text'):
self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' text_pattern_ops')

unique_together = self.expand_together(meta.unique_together, meta)
unique_together = self.get_unique_together(meta)
db_unique_columns = normalize_together([v['columns'] for v in table_constraints.values() if v['unique'] and not v['index']])

for unique_columns in unique_together:
Expand All @@ -427,7 +445,7 @@ def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, tabl

def find_unique_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
fields = dict([(field.column, field) for field in all_local_fields(meta)])
unique_together = self.expand_together(meta.unique_together, meta)
unique_together = self.get_unique_together(meta)

for constraint_name, constraint in table_constraints.items():
if not constraint['unique']:
Expand Down Expand Up @@ -463,7 +481,7 @@ def find_index_missing_in_db(self, meta, table_indexes, table_constraints, table
if db_type.startswith('text'):
self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' text_pattern_ops')

index_together = self.expand_together(meta.index_together, meta)
index_together = self.get_index_together(meta)
db_index_together = normalize_together([v['columns'] for v in table_constraints.values() if v['index'] and not v['unique']])
for columns in index_together:
if columns in db_index_together:
Expand All @@ -478,7 +496,7 @@ def find_index_missing_in_db(self, meta, table_indexes, table_constraints, table
def find_index_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
fields = dict([(field.column, field) for field in all_local_fields(meta)])
meta_index_names = [idx.name for idx in meta.indexes]
index_together = self.expand_together(meta.index_together, meta)
index_together = self.get_index_together(meta)

for constraint_name, constraint in table_constraints.items():
if constraint_name in meta_index_names:
Expand Down Expand Up @@ -838,8 +856,8 @@ def get_field_db_type(self, description, field=None, table_name=None):
def find_index_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
fields = dict([(field.column, field) for field in all_local_fields(meta)])
meta_index_names = [idx.name for idx in meta.indexes]
index_together = self.expand_together(meta.index_together, meta)
unique_together = self.expand_together(meta.unique_together, meta)
index_together = self.get_index_together(meta)
unique_together = self.get_unique_together(meta)

for constraint_name, constraint in table_constraints.items():
if constraint_name in meta_index_names:
Expand Down Expand Up @@ -904,7 +922,7 @@ def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, tabl
if db_type.startswith('text'):
self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' text_pattern_ops')

unique_together = self.expand_together(meta.unique_together, meta)
unique_together = self.get_unique_together(meta)

# This comparison changed from superclass - otherwise function is the same
db_unique_columns = normalize_together([v['columns'] for v in table_constraints.values() if v['unique']])
Expand Down Expand Up @@ -953,7 +971,7 @@ def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, tabl
if column in unique_columns and (constraint['unique'] or constraint['primary_key']):
skip_list.append(column)

unique_together = self.expand_together(meta.unique_together, meta)
unique_together = self.get_unique_together(meta)
db_unique_columns = normalize_together([v['columns'] for v in table_constraints.values() if v['unique']])

for unique_columns in unique_together:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_sqldiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

# from django.core.management import call_command
from django_extensions.management.commands.sqldiff import SqliteSQLDiff, Command, MySQLDiff, PostgresqlSQLDiff
from tests.testapp.models import PostWithUniqField, SluggedWithUniqueTogetherTestModel, \
RandomCharTestModelUniqueTogether, SqlDiffUniqueTogether, SqlDiff, SqlDiffIndexes


class SqlDiffTests(TestCase):
Expand Down Expand Up @@ -53,6 +55,29 @@ def test_format_field_names(self):
expected_field_name = ['name', 'email', 'address']
self.assertEqual(instance.format_field_names(['Name', 'EMAIL', 'aDDress']), expected_field_name)

def test_get_index_together(self):
instance = MySQLDiff(
apps.get_models(include_auto_created=True),
vars(self.options),
stdout=self.tmp_out,
stderr=self.tmp_err,
)
self.assertEqual(instance.get_index_together(SqlDiff._meta), [('number', 'creator')])
self.assertEqual(instance.get_index_together(SqlDiffIndexes._meta), [('first', 'second')])

def test_get_unique_together(self):
instance = MySQLDiff(
apps.get_models(include_auto_created=True),
vars(self.options),
stdout=self.tmp_out,
stderr=self.tmp_err,
)
self.assertEqual(instance.get_unique_together(SluggedWithUniqueTogetherTestModel._meta), [('slug', 'category')])
self.assertEqual(instance.get_unique_together(RandomCharTestModelUniqueTogether._meta),
[('random_char_field', 'common_field')])
self.assertEqual(instance.get_unique_together(SqlDiffUniqueTogether._meta), [('aaa', 'bbb')])
self.assertEqual(instance.get_unique_together(PostWithUniqField._meta), [('common_field', 'uniq_field')])

@pytest.mark.skipif(settings.DATABASES['default']['ENGINE'] != 'django.db.backends.mysql', reason="Test can only run on mysql")
def test_mysql_to_dict(self):
mysql_instance = MySQLDiff(
Expand Down
13 changes: 13 additions & 0 deletions tests/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,19 @@ class SqlDiff(models.Model):
number = models.CharField(max_length=40, null=True, verbose_name='Chargennummer')
creator = models.CharField(max_length=20, null=True, blank=True)

class Meta:
index_together = ['number', 'creator']


class SqlDiffIndexes(models.Model):
first = models.CharField(max_length=40, null=True, verbose_name='Chargennummer')
second = models.CharField(max_length=20, null=True, blank=True)

class Meta:
indexes = [
models.Index(fields=['first', 'second']),
]


class SqlDiffUniqueTogether(models.Model):
aaa = models.CharField(max_length=20)
Expand Down

0 comments on commit 9f55f81

Please sign in to comment.