Skip to content

Commit

Permalink
Add project state back to check interface
Browse files Browse the repository at this point in the history
Need this to detect changes to a field in AlterField statements
  • Loading branch information
ljodal committed Feb 10, 2023
1 parent cbf4adf commit f8f42f8
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 19 deletions.
55 changes: 41 additions & 14 deletions migration_checker/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from django.db.migrations.operations.fields import FieldOperation
from django.db.migrations.operations.models import ModelOperation
from django.db.migrations.state import ProjectState

from .warnings import (
ADD_INDEX_IN_SEPARATE_MIGRATION,
Expand All @@ -35,11 +36,13 @@


class Check(Protocol):
def __call__(self, *, migration: Migration) -> Iterable[Warning]:
def __call__(
self, *, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
...


def check_add_index(*, migration: Migration) -> Iterable[Warning]:
def check_add_index(*, migration: Migration, state: ProjectState) -> Iterable[Warning]:
if any(
isinstance(operation, AddIndex)
and not isinstance(operation, AddIndexConcurrently)
Expand All @@ -52,7 +55,9 @@ def check_add_index(*, migration: Migration) -> Iterable[Warning]:
yield ADD_INDEX_IN_SEPARATE_MIGRATION


def check_add_non_nullable_field(*, migration: Migration) -> Iterable[Warning]:
def check_add_non_nullable_field(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
if any(
isinstance(operation, AddField) and not operation.field.null
for operation in migration.operations
Expand All @@ -61,7 +66,9 @@ def check_add_non_nullable_field(*, migration: Migration) -> Iterable[Warning]:
yield ADDING_NON_NULLABLE_FIELD


def check_alter_multiple_tables(*, migration: Migration) -> Iterable[Warning]:
def check_alter_multiple_tables(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
altered_models = set()

if migration.atomic and not migration.initial:
Expand All @@ -76,14 +83,18 @@ def check_alter_multiple_tables(*, migration: Migration) -> Iterable[Warning]:
yield ALTERING_MULTIPLE_MODELS


def check_atomic_run_python(*, migration: Migration) -> Iterable[Warning]:
def check_atomic_run_python(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
if migration.atomic and any(
isinstance(operation, RunPython) for operation in migration.operations
):
yield ATOMIC_DATA_MIGRATION


def check_data_and_schema_changes(*, migration: Migration) -> Iterable[Warning]:
def check_data_and_schema_changes(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
data_migration, schema_migration = False, False
for operation in migration.operations:
if isinstance(operation, (RunPython, RunSQL)):
Expand All @@ -95,22 +106,30 @@ def check_data_and_schema_changes(*, migration: Migration) -> Iterable[Warning]:
yield SCHEMA_AND_DATA_CHANGES


def check_rename_model(*, migration: Migration) -> Iterable[Warning]:
def check_rename_model(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
if any(isinstance(operation, RenameModel) for operation in migration.operations):
yield RENAMING_MODEL


def check_rename_field(*, migration: Migration) -> Iterable[Warning]:
def check_rename_field(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
if any(isinstance(operation, RenameField) for operation in migration.operations):
yield RENAMING_FIELD


def check_remove_field(*, migration: Migration) -> Iterable[Warning]:
def check_remove_field(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
if any(isinstance(operation, RemoveField) for operation in migration.operations):
yield REMOVING_FIELD


def check_field_with_check_constraint(*, migration: Migration) -> Iterable[Warning]:
def check_field_with_check_constraint(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
if any(
connection.data_type_check_constraints.get(operation.field.get_internal_type())
is not None
Expand All @@ -120,12 +139,16 @@ def check_field_with_check_constraint(*, migration: Migration) -> Iterable[Warni
yield ADDING_FIELD_WITH_CHECK


def check_add_constraint(*, migration: Migration) -> Iterable[Warning]:
def check_add_constraint(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
if any(isinstance(operation, AddConstraint) for operation in migration.operations):
yield ADDING_CONSTRAINT


def check_validate_constraint(*, migration: Migration) -> Iterable[Warning]:
def check_validate_constraint(
*, migration: Migration, state: ProjectState
) -> Iterable[Warning]:
# This feature is only available in Django >= 4.0
if django.VERSION < (4, 0):
return
Expand Down Expand Up @@ -158,5 +181,9 @@ def check_validate_constraint(*, migration: Migration) -> Iterable[Warning]:
]


def run_checks(migration: Migration) -> list[Warning]:
return [warning for check in ALL_CHECKS for warning in check(migration=migration)]
def run_checks(migration: Migration, state: ProjectState) -> list[Warning]:
return [
warning
for check in ALL_CHECKS
for warning in check(migration=migration, state=state)
]
2 changes: 1 addition & 1 deletion migration_checker/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def run(self) -> None:

for migration, _ in plan:
# Run checkers on the migration
warnings = run_checks(migration)
warnings = run_checks(migration, state)

if self.apply_migrations:
queries, locks = self._apply_migration(migration, state)
Expand Down
24 changes: 20 additions & 4 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
RunSQL,
)
from django.db.migrations.operations.base import Operation
from django.db.migrations.state import ProjectState
from django.db.models import (
AutoField,
CheckConstraint,
Index,
IntegerField,
Expand All @@ -36,11 +38,25 @@
)


def check_migration(*_operations: Operation) -> set[Warning]:
class Migration(migrations.Migration):
operations = list(_operations)
def check_migration(*test_operations: Operation) -> set[Warning]:
class InitialMigration(migrations.Migration):
initial = True
operations = [
migrations.CreateModel(
name="Foo",
fields=[("id", AutoField())],
),
]

return set(run_checks(migration=Migration(name="0001_foo", app_label="foo")))
class TestMigration(migrations.Migration):
operations = list(test_operations)

state = ProjectState()
state = InitialMigration(name="0001_foo", app_label="foo").mutate_state(state)

migration = TestMigration(name="0002_foo", app_label="foo")

return set(run_checks(migration, state))


def test_add_index() -> None:
Expand Down

0 comments on commit f8f42f8

Please sign in to comment.