8 changes: 4 additions & 4 deletions docs/developer_portal/extensions/tasks.md
@@ -50,7 +50,7 @@ When GTF is considered stable, it will replace legacy Celery tasks for built-in
### Define a Task

```python
-from superset_core.api.types import task, get_context
+from superset_core.api.tasks import task, get_context

@task
def process_data(dataset_id: int) -> None:
@@ -245,7 +245,7 @@ Always implement an abort handler for long-running tasks. This allows users to c
Set a timeout to automatically abort tasks that run too long:

```python
-from superset_core.api.types import task, get_context, TaskOptions
+from superset_core.api.tasks import task, get_context, TaskOptions

# Set default timeout in decorator
@task(timeout=300) # 5 minutes
@@ -299,7 +299,7 @@ Timeouts require an abort handler to be effective. Without one, the timeout trigg
Use `task_key` to prevent duplicate task execution:

```python
-from superset_core.api.types import TaskOptions
+from superset_core.api.tasks import TaskOptions

# Without key - creates new task each time (random UUID)
task1 = my_task.schedule(x=1)
@@ -331,7 +331,7 @@ print(task2.status) # "success" (terminal status)
## Task Scopes

```python
-from superset_core.api.types import task, TaskScope
+from superset_core.api.tasks import task, TaskScope

@task # Private by default
def private_task(): ...
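For readers skimming this doc change: the rename only touches the import path (`superset_core.api.types` to `superset_core.api.tasks`); the dedup semantics shown above are unchanged. A minimal sketch of that flow under the corrected import; how `TaskOptions` is attached to `schedule()` is an assumption here (an `options` keyword), and the task body and key are illustrative:

```python
from superset_core.api.tasks import TaskOptions, task


@task
def my_task(x: int) -> None:
    ...  # illustrative body


# Without a key, each schedule() call creates a new task (random UUID).
task1 = my_task.schedule(x=1)

# With a stable task_key, scheduling again while the first run is still
# active returns the existing task instead of creating a duplicate.
opts = TaskOptions(task_key="nightly-refresh")  # hypothetical key
task2 = my_task.schedule(x=1, options=opts)     # assumed kwarg name
task3 = my_task.schedule(x=1, options=opts)     # same task as task2

# Once task2 reaches a terminal status such as "success", the key is
# freed and the next schedule() call with it creates a fresh task.
```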
2 changes: 1 addition & 1 deletion superset-core/src/superset_core/api/tasks.py
@@ -259,7 +259,7 @@ def task(
is discarded; only side effects and context updates matter.

Example:
-from superset_core.api.types import task, get_context, TaskScope
+from superset_core.api.tasks import task, get_context, TaskScope

# Private task (default scope)
@task
10 changes: 7 additions & 3 deletions superset/daos/tasks.py
@@ -28,9 +28,9 @@
from superset.extensions import db
from superset.models.task_subscribers import TaskSubscriber
from superset.models.tasks import Task
-from superset.tasks.constants import ABORTABLE_STATES
+from superset.tasks.constants import ABORTABLE_STATES, TERMINAL_STATES
from superset.tasks.filters import TaskFilter
-from superset.tasks.utils import get_active_dedup_key, json
+from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key, json

logger = logging.getLogger(__name__)

@@ -243,7 +243,7 @@ def abort_task(cls, task_uuid: UUID, skip_base_filter: bool = False) -> Task | N
)

# Transition to ABORTING (not ABORTED yet)
-task.status = TaskStatus.ABORTING.value
+task.set_status(TaskStatus.ABORTING)
db.session.merge(task)
logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope)

Expand Down Expand Up @@ -444,6 +444,10 @@ def conditional_status_update(
if set_ended_at:
    update_values["ended_at"] = datetime.now(timezone.utc)

+# Update dedup_key if transitioning to terminal state
+if new_status_val in TERMINAL_STATES:
+    update_values["dedup_key"] = get_finished_dedup_key(task_uuid)

# Atomic compare-and-swap: only update if status matches expected
rows_updated = (
db.session.query(Task)
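The new branch in `conditional_status_update` mirrors what `Task.set_status` already does in the model layer: on entering a terminal state, the row's dedup_key is swapped from the active key to a finished key, freeing the uniqueness slot for the next task with the same task_key. A sketch of the two helpers' presumed shapes; the signature and hashing scheme for the active key are assumptions, while the finished key being the task's UUID string follows from the "Update dedup_key to UUID to free up the slot" comment in the model diff below:

```python
import hashlib
from uuid import UUID


def get_active_dedup_key(scope: str, task_key: str) -> str:
    # Assumption: the active key is a fixed-width hash over the scope and
    # the user-supplied task_key, suitable for a unique-column constraint.
    return hashlib.sha256(f"{scope}:{task_key}".encode()).hexdigest()


def get_finished_dedup_key(task_uuid: UUID) -> str:
    # Finished tasks park their slot under their own UUID, which cannot
    # collide with another active key, so the task_key becomes reusable.
    return str(task_uuid)
```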
15 changes: 3 additions & 12 deletions superset/models/tasks.py
@@ -37,6 +37,7 @@

from superset.models.helpers import AuditMixinNullable
from superset.models.task_subscribers import TaskSubscriber
+from superset.tasks.constants import TERMINAL_STATES
from superset.tasks.utils import (
error_update,
get_finished_dedup_key,
@@ -218,12 +219,7 @@ def set_status(self, status: TaskStatus | str) -> None:
# (will be set to True if/when an abort handler is registered)
if self.properties_dict.get("is_abortable") is None:
self.update_properties({"is_abortable": False})
-elif status in [
-    TaskStatus.SUCCESS.value,
-    TaskStatus.FAILURE.value,
-    TaskStatus.ABORTED.value,
-    TaskStatus.TIMED_OUT.value,
-]:
+elif status in TERMINAL_STATES:
if not self.ended_at:
self.ended_at = now
# Update dedup_key to UUID to free up the slot for new tasks
@@ -244,12 +240,7 @@ def is_running(self) -> bool:
@property
def is_finished(self) -> bool:
"""Check if task has finished (success, failure, aborted, or timed out)."""
-return self.status in [
-    TaskStatus.SUCCESS.value,
-    TaskStatus.FAILURE.value,
-    TaskStatus.ABORTED.value,
-    TaskStatus.TIMED_OUT.value,
-]
+return self.status in TERMINAL_STATES

@property
def is_successful(self) -> bool:
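Both call sites in this file now defer to one shared constant. Its definition is not part of this diff, but the removed literals here and the deleted `test_terminal_states_constant` at the bottom of this PR pin down its contents; a sketch, with the `TaskStatus` import path assumed:

```python
# superset/tasks/constants.py (sketch; the values are confirmed by the
# removed test, the import path for TaskStatus is an assumption)
from superset.tasks.types import TaskStatus

TERMINAL_STATES: set[str] = {
    TaskStatus.SUCCESS.value,    # "success"
    TaskStatus.FAILURE.value,    # "failure"
    TaskStatus.ABORTED.value,    # "aborted"
    TaskStatus.TIMED_OUT.value,  # "timed_out"
}
```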
11 changes: 4 additions & 7 deletions superset/tasks/manager.py
@@ -112,9 +112,6 @@ class TaskManager:
_completion_channel_prefix: str = "gtf:complete:"
_initialized: bool = False

-# Backward compatibility alias - prefer importing from superset.tasks.constants
-TERMINAL_STATES = TERMINAL_STATES

@classmethod
def init_app(cls, app: Flask) -> None:
"""
@@ -271,7 +268,7 @@ def get_task() -> "Task | None":
if not task:
raise ValueError(f"Task {task_uuid} not found")

-if task.status in cls.TERMINAL_STATES:
+if task.status in TERMINAL_STATES:
return task

logger.debug(
@@ -342,13 +339,13 @@ def _wait_via_pubsub(
message.get("data"),
)
task = get_task()
-if task and task.status in cls.TERMINAL_STATES:
+if task and task.status in TERMINAL_STATES:
return task

# Also check database periodically in case we missed the message
# (e.g., task completed before we subscribed)
task = get_task()
-if task and task.status in cls.TERMINAL_STATES:
+if task and task.status in TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via db check): status=%s",
task_uuid,
@@ -384,7 +381,7 @@ def _wait_via_polling(
if not task:
raise ValueError(f"Task {task_uuid} not found")

-if task.status in cls.TERMINAL_STATES:
+if task.status in TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via polling): status=%s",
task_uuid,
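The hunks above only swap `cls.TERMINAL_STATES` for the module-level constant, but the surrounding wait logic is worth spelling out: `_wait_via_pubsub` re-checks the database after handling messages because a task can complete before the subscription is live. A generic sketch of that pattern, not the actual TaskManager code (the two callables are placeholders for a DB lookup and a blocking pub/sub read):

```python
import time
from typing import Any, Callable

TERMINAL_STATES = {"success", "failure", "aborted", "timed_out"}


def wait_for_terminal(
    get_task: Callable[[], Any],
    get_message: Callable[[float], Any],
    timeout: float,
    poll_interval: float,
) -> Any:
    """Block until the task reaches a terminal state or the timeout expires.

    Assumes the pub/sub subscription was opened before the first call to
    get_task(), so a completion between the two cannot be lost: it shows
    up either as a message or in the next database check.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        task = get_task()  # periodic DB check covers missed messages
        if task is not None and task.status in TERMINAL_STATES:
            return task
        get_message(poll_interval)  # wait briefly for a notification
    raise TimeoutError("task did not reach a terminal state in time")
```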
4 changes: 4 additions & 0 deletions superset/tasks/schemas.py
@@ -25,6 +25,9 @@
# Field descriptions
uuid_description = "The unique identifier (UUID) of the task"
task_key_description = "The task identifier used for deduplication"
+dedup_key_description = (
+    "The hashed deduplication key used internally for task deduplication"
+)
Comment on lines +28 to +30 (bito, Contributor):

Inaccurate API field description

The dedup_key description claims it's always 'hashed', but for finished tasks, it's set to the task's UUID string. This could mislead API consumers about the field's format.

Suggested change:
-dedup_key_description = (
-    "The hashed deduplication key used internally for task deduplication"
-)
+dedup_key_description = (
+    "The deduplication key used internally for task deduplication"
+)

Code Review Run #8b1382
task_type_description = (
"The type of task (e.g., 'sql_execution', 'thumbnail_generation')"
)
@@ -74,6 +77,7 @@ class TaskResponseSchema(Schema):
id = fields.Int(metadata={"description": "Internal task ID"})
uuid = fields.UUID(metadata={"description": uuid_description})
task_key = fields.String(metadata={"description": task_key_description})
+dedup_key = fields.String(metadata={"description": dedup_key_description})
task_type = fields.String(metadata={"description": task_type_description})
task_name = fields.String(
metadata={"description": task_name_description}, allow_none=True
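With the field registered, serialized task payloads expose dedup_key next to task_key. A standalone marshmallow illustration with made-up values (for finished tasks the dedup_key equals the task's UUID string, per the review note above):

```python
from marshmallow import Schema, fields


class TaskResponseSketch(Schema):  # trimmed stand-in, not the real schema
    uuid = fields.UUID()
    task_key = fields.String()
    dedup_key = fields.String()


payload = TaskResponseSketch().dump(
    {
        "uuid": "11111111-2222-3333-4444-555555555555",
        "task_key": "report-1",
        "dedup_key": "11111111-2222-3333-4444-555555555555",  # finished task
    }
)
print(payload["dedup_key"])
```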
15 changes: 0 additions & 15 deletions tests/integration_tests/tasks/test_sync_join_wait.py
@@ -68,21 +68,6 @@ def test_submit_task_distinguishes_new_vs_existing(
db.session.commit()


-def test_terminal_states_recognized_correctly(app_context) -> None:
-    """
-    Test that TaskManager.TERMINAL_STATES contains the expected values.
-    """
-    assert TaskStatus.SUCCESS.value in TaskManager.TERMINAL_STATES
-    assert TaskStatus.FAILURE.value in TaskManager.TERMINAL_STATES
-    assert TaskStatus.ABORTED.value in TaskManager.TERMINAL_STATES
-    assert TaskStatus.TIMED_OUT.value in TaskManager.TERMINAL_STATES
-
-    # Non-terminal states should not be in the set
-    assert TaskStatus.PENDING.value not in TaskManager.TERMINAL_STATES
-    assert TaskStatus.IN_PROGRESS.value not in TaskManager.TERMINAL_STATES
-    assert TaskStatus.ABORTING.value not in TaskManager.TERMINAL_STATES


def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
"""
Test that wait_for_completion raises TimeoutError on timeout.
83 changes: 83 additions & 0 deletions tests/unit_tests/daos/test_tasks.py
@@ -418,3 +418,86 @@ def test_get_status_not_found(session_with_task: Session) -> None:
result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))

assert result is None


+def test_conditional_status_update_non_terminal_state_keeps_dedup_key(
+    session_with_task: Session,
+) -> None:
+    """Test that conditional_status_update preserves dedup_key for
+    non-terminal transitions"""
+    from superset.daos.tasks import TaskDAO
+
+    # Create task in PENDING state
+    task = create_task(
+        session_with_task,
+        task_uuid=TASK_UUID,
+        task_key="non-terminal-test-task",
+        status=TaskStatus.PENDING,
+    )
+
+    # Store original active dedup_key
+    original_dedup_key = task.dedup_key
+
+    # Transition to non-terminal state (IN_PROGRESS)
+    result = TaskDAO.conditional_status_update(
+        task_uuid=TASK_UUID,
+        new_status=TaskStatus.IN_PROGRESS,
+        expected_status=TaskStatus.PENDING,
+        set_started_at=True,
+    )
+
+    # Should succeed
+    assert result is True
+
+    # Refresh task and verify dedup_key was NOT changed
+    session_with_task.refresh(task)
+    assert task.status == TaskStatus.IN_PROGRESS.value
+    assert task.dedup_key == original_dedup_key  # Should remain the same
+    assert task.started_at is not None
+
+
+@pytest.mark.parametrize(
+    "terminal_state",
+    [
+        TaskStatus.SUCCESS,
+        TaskStatus.FAILURE,
+        TaskStatus.ABORTED,
+        TaskStatus.TIMED_OUT,
+    ],
+)
+def test_conditional_status_update_terminal_state_updates_dedup_key(
+    session_with_task: Session, terminal_state: TaskStatus
+) -> None:
+    """Test that terminal states (SUCCESS, FAILURE, ABORTED, TIMED_OUT)
+    update dedup_key"""
+    from superset.daos.tasks import TaskDAO
+
+    task = create_task(
+        session_with_task,
+        task_uuid=TASK_UUID,
+        task_key=f"terminal-test-{terminal_state.value}",
+        status=TaskStatus.IN_PROGRESS,
+    )
+
+    original_dedup_key = task.dedup_key
+    expected_finished_key = get_finished_dedup_key(TASK_UUID)
+
+    # Transition to terminal state
+    result = TaskDAO.conditional_status_update(
+        task_uuid=TASK_UUID,
+        new_status=terminal_state,
+        expected_status=TaskStatus.IN_PROGRESS,
+        set_ended_at=True,
+    )
+
+    assert result is True, f"Failed to update to {terminal_state.value}"
+
+    # Verify dedup_key was updated
+    session_with_task.refresh(task)
+    assert task.status == terminal_state.value
+    assert task.dedup_key == expected_finished_key, (
+        f"dedup_key not updated for {terminal_state.value}"
+    )
+    assert task.dedup_key != original_dedup_key, (
+        f"dedup_key should have changed for {terminal_state.value}"
+    )
5 changes: 0 additions & 5 deletions tests/unit_tests/tasks/test_manager.py
@@ -455,8 +455,3 @@ def test_wait_for_completion_pubsub_error_raises(
timeout=5.0,
poll_interval=0.1,
)

-    def test_terminal_states_constant(self):
-        """Test TERMINAL_STATES contains expected values"""
-        expected = {"success", "failure", "aborted", "timed_out"}
-        assert TaskManager.TERMINAL_STATES == expected