diff --git a/docs/developer_portal/extensions/tasks.md b/docs/developer_portal/extensions/tasks.md index 25b6569b4ad1..c833cd93680c 100644 --- a/docs/developer_portal/extensions/tasks.md +++ b/docs/developer_portal/extensions/tasks.md @@ -50,7 +50,7 @@ When GTF is considered stable, it will replace legacy Celery tasks for built-in ### Define a Task ```python -from superset_core.api.types import task, get_context +from superset_core.api.tasks import task, get_context @task def process_data(dataset_id: int) -> None: @@ -245,7 +245,7 @@ Always implement an abort handler for long-running tasks. This allows users to c Set a timeout to automatically abort tasks that run too long: ```python -from superset_core.api.types import task, get_context, TaskOptions +from superset_core.api.tasks import task, get_context, TaskOptions # Set default timeout in decorator @task(timeout=300) # 5 minutes @@ -299,7 +299,7 @@ Timeouts require an abort handler to be effective. Without one, the timeout trig Use `task_key` to prevent duplicate task execution: ```python -from superset_core.api.types import TaskOptions +from superset_core.api.tasks import TaskOptions # Without key - creates new task each time (random UUID) task1 = my_task.schedule(x=1) @@ -331,7 +331,7 @@ print(task2.status) # "success" (terminal status) ## Task Scopes ```python -from superset_core.api.types import task, TaskScope +from superset_core.api.tasks import task, TaskScope @task # Private by default def private_task(): ... diff --git a/superset-core/src/superset_core/api/tasks.py b/superset-core/src/superset_core/api/tasks.py index 1adcd9ab3275..cc00689d622a 100644 --- a/superset-core/src/superset_core/api/tasks.py +++ b/superset-core/src/superset_core/api/tasks.py @@ -259,7 +259,7 @@ def task( is discarded; only side effects and context updates matter. Example: - from superset_core.api.types import task, get_context, TaskScope + from superset_core.api.tasks import task, get_context, TaskScope # Private task (default scope) @task diff --git a/superset/daos/tasks.py b/superset/daos/tasks.py index 8253cf6d579f..c8b92cd948a9 100644 --- a/superset/daos/tasks.py +++ b/superset/daos/tasks.py @@ -28,9 +28,9 @@ from superset.extensions import db from superset.models.task_subscribers import TaskSubscriber from superset.models.tasks import Task -from superset.tasks.constants import ABORTABLE_STATES +from superset.tasks.constants import ABORTABLE_STATES, TERMINAL_STATES from superset.tasks.filters import TaskFilter -from superset.tasks.utils import get_active_dedup_key, json +from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key, json logger = logging.getLogger(__name__) @@ -243,7 +243,7 @@ def abort_task(cls, task_uuid: UUID, skip_base_filter: bool = False) -> Task | N ) # Transition to ABORTING (not ABORTED yet) - task.status = TaskStatus.ABORTING.value + task.set_status(TaskStatus.ABORTING) db.session.merge(task) logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope) @@ -444,6 +444,10 @@ def conditional_status_update( if set_ended_at: update_values["ended_at"] = datetime.now(timezone.utc) + # Update dedup_key if transitioning to terminal state + if new_status_val in TERMINAL_STATES: + update_values["dedup_key"] = get_finished_dedup_key(task_uuid) + # Atomic compare-and-swap: only update if status matches expected rows_updated = ( db.session.query(Task) diff --git a/superset/models/tasks.py b/superset/models/tasks.py index 6c6995e9563e..e7c3992f2bb7 100644 --- a/superset/models/tasks.py +++ b/superset/models/tasks.py @@ -37,6 +37,7 @@ from superset.models.helpers import AuditMixinNullable from superset.models.task_subscribers import TaskSubscriber +from superset.tasks.constants import TERMINAL_STATES from superset.tasks.utils import ( error_update, get_finished_dedup_key, @@ -218,12 +219,7 @@ def set_status(self, status: TaskStatus | str) -> None: # (will be set to True if/when an abort handler is registered) if self.properties_dict.get("is_abortable") is None: self.update_properties({"is_abortable": False}) - elif status in [ - TaskStatus.SUCCESS.value, - TaskStatus.FAILURE.value, - TaskStatus.ABORTED.value, - TaskStatus.TIMED_OUT.value, - ]: + elif status in TERMINAL_STATES: if not self.ended_at: self.ended_at = now # Update dedup_key to UUID to free up the slot for new tasks @@ -244,12 +240,7 @@ def is_running(self) -> bool: @property def is_finished(self) -> bool: """Check if task has finished (success, failure, aborted, or timed out).""" - return self.status in [ - TaskStatus.SUCCESS.value, - TaskStatus.FAILURE.value, - TaskStatus.ABORTED.value, - TaskStatus.TIMED_OUT.value, - ] + return self.status in TERMINAL_STATES @property def is_successful(self) -> bool: diff --git a/superset/tasks/manager.py b/superset/tasks/manager.py index f4595c51167c..21b28c7d42bb 100644 --- a/superset/tasks/manager.py +++ b/superset/tasks/manager.py @@ -112,9 +112,6 @@ class TaskManager: _completion_channel_prefix: str = "gtf:complete:" _initialized: bool = False - # Backward compatibility alias - prefer importing from superset.tasks.constants - TERMINAL_STATES = TERMINAL_STATES - @classmethod def init_app(cls, app: Flask) -> None: """ @@ -271,7 +268,7 @@ def get_task() -> "Task | None": if not task: raise ValueError(f"Task {task_uuid} not found") - if task.status in cls.TERMINAL_STATES: + if task.status in TERMINAL_STATES: return task logger.debug( @@ -342,13 +339,13 @@ def _wait_via_pubsub( message.get("data"), ) task = get_task() - if task and task.status in cls.TERMINAL_STATES: + if task and task.status in TERMINAL_STATES: return task # Also check database periodically in case we missed the message # (e.g., task completed before we subscribed) task = get_task() - if task and task.status in cls.TERMINAL_STATES: + if task and task.status in TERMINAL_STATES: logger.debug( "Task %s completed (detected via db check): status=%s", task_uuid, @@ -384,7 +381,7 @@ def _wait_via_polling( if not task: raise ValueError(f"Task {task_uuid} not found") - if task.status in cls.TERMINAL_STATES: + if task.status in TERMINAL_STATES: logger.debug( "Task %s completed (detected via polling): status=%s", task_uuid, diff --git a/superset/tasks/schemas.py b/superset/tasks/schemas.py index 9fe0b31ec7b4..bf93c5d0c472 100644 --- a/superset/tasks/schemas.py +++ b/superset/tasks/schemas.py @@ -25,6 +25,9 @@ # Field descriptions uuid_description = "The unique identifier (UUID) of the task" task_key_description = "The task identifier used for deduplication" +dedup_key_description = ( + "The hashed deduplication key used internally for task deduplication" +) task_type_description = ( "The type of task (e.g., 'sql_execution', 'thumbnail_generation')" ) @@ -74,6 +77,7 @@ class TaskResponseSchema(Schema): id = fields.Int(metadata={"description": "Internal task ID"}) uuid = fields.UUID(metadata={"description": uuid_description}) task_key = fields.String(metadata={"description": task_key_description}) + dedup_key = fields.String(metadata={"description": dedup_key_description}) task_type = fields.String(metadata={"description": task_type_description}) task_name = fields.String( metadata={"description": task_name_description}, allow_none=True diff --git a/tests/integration_tests/tasks/test_sync_join_wait.py b/tests/integration_tests/tasks/test_sync_join_wait.py index 9379efca1c75..9a611cd6b29d 100644 --- a/tests/integration_tests/tasks/test_sync_join_wait.py +++ b/tests/integration_tests/tasks/test_sync_join_wait.py @@ -68,21 +68,6 @@ def test_submit_task_distinguishes_new_vs_existing( db.session.commit() -def test_terminal_states_recognized_correctly(app_context) -> None: - """ - Test that TaskManager.TERMINAL_STATES contains the expected values. - """ - assert TaskStatus.SUCCESS.value in TaskManager.TERMINAL_STATES - assert TaskStatus.FAILURE.value in TaskManager.TERMINAL_STATES - assert TaskStatus.ABORTED.value in TaskManager.TERMINAL_STATES - assert TaskStatus.TIMED_OUT.value in TaskManager.TERMINAL_STATES - - # Non-terminal states should not be in the set - assert TaskStatus.PENDING.value not in TaskManager.TERMINAL_STATES - assert TaskStatus.IN_PROGRESS.value not in TaskManager.TERMINAL_STATES - assert TaskStatus.ABORTING.value not in TaskManager.TERMINAL_STATES - - def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None: """ Test that wait_for_completion raises TimeoutError on timeout. diff --git a/tests/unit_tests/daos/test_tasks.py b/tests/unit_tests/daos/test_tasks.py index f8f3bdc073a9..8d767a17373c 100644 --- a/tests/unit_tests/daos/test_tasks.py +++ b/tests/unit_tests/daos/test_tasks.py @@ -418,3 +418,86 @@ def test_get_status_not_found(session_with_task: Session) -> None: result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000")) assert result is None + + +def test_conditional_status_update_non_terminal_state_keeps_dedup_key( + session_with_task: Session, +) -> None: + """Test that conditional_status_update preserves dedup_key for + non-terminal transitions""" + from superset.daos.tasks import TaskDAO + + # Create task in PENDING state + task = create_task( + session_with_task, + task_uuid=TASK_UUID, + task_key="non-terminal-test-task", + status=TaskStatus.PENDING, + ) + + # Store original active dedup_key + original_dedup_key = task.dedup_key + + # Transition to non-terminal state (IN_PROGRESS) + result = TaskDAO.conditional_status_update( + task_uuid=TASK_UUID, + new_status=TaskStatus.IN_PROGRESS, + expected_status=TaskStatus.PENDING, + set_started_at=True, + ) + + # Should succeed + assert result is True + + # Refresh task and verify dedup_key was NOT changed + session_with_task.refresh(task) + assert task.status == TaskStatus.IN_PROGRESS.value + assert task.dedup_key == original_dedup_key # Should remain the same + assert task.started_at is not None + + +@pytest.mark.parametrize( + "terminal_state", + [ + TaskStatus.SUCCESS, + TaskStatus.FAILURE, + TaskStatus.ABORTED, + TaskStatus.TIMED_OUT, + ], +) +def test_conditional_status_update_terminal_state_updates_dedup_key( + session_with_task: Session, terminal_state: TaskStatus +) -> None: + """Test that terminal states (SUCCESS, FAILURE, ABORTED, TIMED_OUT) + update dedup_key""" + from superset.daos.tasks import TaskDAO + + task = create_task( + session_with_task, + task_uuid=TASK_UUID, + task_key=f"terminal-test-{terminal_state.value}", + status=TaskStatus.IN_PROGRESS, + ) + + original_dedup_key = task.dedup_key + expected_finished_key = get_finished_dedup_key(TASK_UUID) + + # Transition to terminal state + result = TaskDAO.conditional_status_update( + task_uuid=TASK_UUID, + new_status=terminal_state, + expected_status=TaskStatus.IN_PROGRESS, + set_ended_at=True, + ) + + assert result is True, f"Failed to update to {terminal_state.value}" + + # Verify dedup_key was updated + session_with_task.refresh(task) + assert task.status == terminal_state.value + assert task.dedup_key == expected_finished_key, ( + f"dedup_key not updated for {terminal_state.value}" + ) + assert task.dedup_key != original_dedup_key, ( + f"dedup_key should have changed for {terminal_state.value}" + ) diff --git a/tests/unit_tests/tasks/test_manager.py b/tests/unit_tests/tasks/test_manager.py index 13997a7f113e..4fc77c2e080d 100644 --- a/tests/unit_tests/tasks/test_manager.py +++ b/tests/unit_tests/tasks/test_manager.py @@ -455,8 +455,3 @@ def test_wait_for_completion_pubsub_error_raises( timeout=5.0, poll_interval=0.1, ) - - def test_terminal_states_constant(self): - """Test TERMINAL_STATES contains expected values""" - expected = {"success", "failure", "aborted", "timed_out"} - assert TaskManager.TERMINAL_STATES == expected