Skip to content

Commit

Permalink
Fix detection of non-atomic migrations
Browse files Browse the repository at this point in the history
Operations that cannot be run in a migration can also be defined in a
SeparateDatabaseAndState operation, or be written as raw SQL. This makes
an attempt at detecting those cases as well.
  • Loading branch information
ljodal committed Jun 29, 2023
1 parent f8f42f8 commit 3ca054b
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 8 deletions.
66 changes: 58 additions & 8 deletions migration_checker/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
Helper to execute migrations and record results
"""

from typing import Any, Callable, Union, cast
from typing import Any, Callable, Sequence, Union, cast

import django
import sqlparse
from django.contrib.postgres.operations import NotInTransactionMixin
from django.db import connections, transaction
from django.db.migrations import Migration
from django.db.migrations import Migration, RunSQL, SeparateDatabaseAndState
from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.operations.base import Operation
from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.state import ProjectState

Expand Down Expand Up @@ -43,7 +45,7 @@ def __init__(
*,
database: str,
apply_migrations: bool,
outputs: list[Union[ConsoleOutput, GithubCommentOutput]]
outputs: list[Union[ConsoleOutput, GithubCommentOutput]],
) -> None:
self.database = database
self.apply_migrations = apply_migrations
Expand Down Expand Up @@ -118,11 +120,7 @@ def _apply_migration(
# Some operations, like AddIndexConcurrently, cannot be run in a
# transaction, so for those special cases we skip recording locks
# because we ahave no way of doing that.
must_be_non_atomic = any(
isinstance(operation, NotInTransactionMixin)
for operation in migration.operations
)
if must_be_non_atomic:
if self._must_be_non_atomic(migration.operations):
return self._apply_non_atomic_migration(migration, state), None

# Apply the migration in the database and record queries and locks
Expand All @@ -137,6 +135,58 @@ def _apply_migration(

return query_logger.queries, locks

def _must_be_non_atomic_query(self, query: str) -> bool:
"""
Try to detect if a raw query must be non-atomic.
"""

patterns = [
[
(sqlparse.tokens.DDL, "CREATE"),
(sqlparse.tokens.Keyword, "INDEX"),
(sqlparse.tokens.Keyword, "CONCURRENTLY"),
],
[
(sqlparse.tokens.DDL, "DROP"),
(sqlparse.tokens.Keyword, "INDEX"),
(sqlparse.tokens.Keyword, "CONCURRENTLY"),
],
]

for statement in sqlparse.parse(query):
for pattern in patterns:
if all(
any(token.match(ttype, value) for token in statement.tokens)
for ttype, value in pattern
):
return True
return False

def _must_be_non_atomic(self, operations: Sequence[Operation]) -> bool:
"""
Check if any of the operations must be run outside of a transaction.
This is the case for some operations, like AddIndexConcurrently. This
will recursivey check SeparateDatabaseAndState migrations.
"""

for operation in operations:
if isinstance(operation, NotInTransactionMixin):
return True
if isinstance(operation, SeparateDatabaseAndState):
return self._must_be_non_atomic(operation.database_operations)
if isinstance(operation, RunSQL):
if isinstance(operation.sql, str):
return self._must_be_non_atomic_query(operation.sql)
else:
return any(
self._must_be_non_atomic_query(statement)
if isinstance(statement, str)
else self._must_be_non_atomic_query(statement[0])
for statement in operation.sql
)

return False

def _apply_non_atomic_migration(
self, migration: Migration, state: ProjectState
) -> list[str]:
Expand Down
57 changes: 57 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from unittest.mock import Mock

import pytest
from django.contrib.postgres.operations import (
AddIndexConcurrently,
RemoveIndexConcurrently,
)
from django.db.migrations import AddIndex, RunSQL, SeparateDatabaseAndState
from django.db.migrations.operations.base import Operation

from migration_checker.executor import Executor
from migration_checker.output import ConsoleOutput

Expand All @@ -7,3 +17,50 @@ def test_executor(setup_db: None) -> None:
database="default", apply_migrations=True, outputs=[ConsoleOutput()]
)
executor.run()


@pytest.mark.parametrize(
"operation,must_be_non_atomic",
[
(AddIndex("test", index=Mock()), False),
(AddIndexConcurrently("test", index=Mock()), True),
(RemoveIndexConcurrently("foo", "test"), True),
(RunSQL("CREATE INDEX foobar"), False),
(RunSQL("CREATE INDEX foobar CONCURRENTLY", RunSQL.noop), True),
(RunSQL("DROP INDEX foobar CONCURRENTLY", RunSQL.noop), True),
(RunSQL([("CREATE INDEX foobar", None)]), False),
(RunSQL([("CREATE INDEX foobar CONCURRENTLY", None)]), True),
(
SeparateDatabaseAndState(
database_operations=[RunSQL("CREATE INDEX foobar")]
),
False,
),
(
SeparateDatabaseAndState(
database_operations=[RunSQL("CREATE INDEX foobar CONCURRENTLY")]
),
True,
),
(
SeparateDatabaseAndState(
database_operations=[AddIndex("test", index=Mock())]
),
False,
),
(
SeparateDatabaseAndState(
database_operations=[AddIndexConcurrently("test", index=Mock())]
),
True,
),
],
)
def test_run_sql_must_be_non_atomic(
operation: Operation, must_be_non_atomic: bool
) -> None:
executor = Executor(
database="default", apply_migrations=True, outputs=[ConsoleOutput()]
)

assert executor._must_be_non_atomic([operation]) is must_be_non_atomic

0 comments on commit 3ca054b

Please sign in to comment.