diff --git a/UPDATING.md b/UPDATING.md
index 0f22ef36f037..7329e6c6acac 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -24,6 +24,28 @@ assists people when migrating to a new version.
## Next
+### Signal Cache Backend
+
+A new `SIGNAL_CACHE_CONFIG` configuration provides a unified Redis-based backend for real-time coordination features in Superset. This backend enables:
+
+- **Pub/sub messaging** for real-time event notifications between workers
+- **Atomic distributed locking** using Redis SET NX EX (more performant than database-backed locks)
+- **Event-based coordination** for background task management
+
+The signal cache is used by the Global Task Framework (GTF) for abort notifications and task completion signaling, and will eventually replace `GLOBAL_ASYNC_QUERIES_CACHE_BACKEND` as the standard signaling backend. Configuring it is recommended for Redis-enabled production deployments.
+
+Example configuration in `superset_config.py`:
+```python
+SIGNAL_CACHE_CONFIG = {
+ "CACHE_TYPE": "RedisCache",
+ "CACHE_KEY_PREFIX": "signal_",
+ "CACHE_REDIS_URL": "redis://localhost:6379/1",
+ "CACHE_DEFAULT_TIMEOUT": 300,
+}
+```
+
+See `superset/config.py` for complete configuration options.
+
### WebSocket config for GAQ with Docker
[35896](https://github.com/apache/superset/pull/35896) and [37624](https://github.com/apache/superset/pull/37624) updated documentation on how to run and configure Superset with Docker. Specifically for the WebSocket configuration, a new `docker/superset-websocket/config.example.json` was added to the repo, so that users could copy it to create a `docker/superset-websocket/config.json` file. The existing `docker/superset-websocket/config.json` was removed and git-ignored, so if you're using GAQ / WebSocket make sure to:
diff --git a/docs/developer_portal/extensions/overview.md b/docs/developer_portal/extensions/overview.md
index db626b0a1e16..175fa9701caa 100644
--- a/docs/developer_portal/extensions/overview.md
+++ b/docs/developer_portal/extensions/overview.md
@@ -51,4 +51,5 @@ Extensions can provide:
- **[Deployment](./deployment)** - Packaging and deploying extensions
- **[MCP Integration](./mcp)** - Adding AI agent capabilities using extensions
- **[Security](./security)** - Security considerations and best practices
+- **[Tasks](./tasks)** - Framework for creating and managing long-running tasks
- **[Community Extensions](./registry)** - Browse extensions shared by the community
diff --git a/docs/developer_portal/extensions/registry.md b/docs/developer_portal/extensions/registry.md
index 20ebfaf15963..36c05f39803f 100644
--- a/docs/developer_portal/extensions/registry.md
+++ b/docs/developer_portal/extensions/registry.md
@@ -1,6 +1,6 @@
---
title: Community Extensions
-sidebar_position: 10
+sidebar_position: 11
---
diff --git a/docs/developer_portal/extensions/tasks.md b/docs/developer_portal/extensions/tasks.md
new file mode 100644
--- /dev/null
+++ b/docs/developer_portal/extensions/tasks.md
+---
+title: Tasks
+sidebar_position: 10
+---
+
+# Global Task Framework
+
+The Global Task Framework (GTF) provides a unified way to manage background tasks. It handles execution, progress tracking, cancellation, and deduplication for both synchronous and asynchronous tasks. The framework uses distributed locking internally to ensure race-free operations, so you don't need to worry about concurrent task creation or cancellation conflicts.
+
+## Enabling GTF
+
+GTF is disabled by default and must be enabled via the `GLOBAL_TASK_FRAMEWORK` feature flag in your `superset_config.py`:
+
+```python
+FEATURE_FLAGS = {
+ "GLOBAL_TASK_FRAMEWORK": True,
+}
+```
+
+When GTF is disabled:
+- The Task List UI menu item is hidden
+- The `/api/v1/task/*` endpoints return 404
+- Calling or scheduling a `@task`-decorated function raises `GlobalTaskFrameworkDisabledError`
+
+:::note Future Migration
+When GTF is considered stable, it will replace legacy Celery tasks for built-in features like thumbnails and alerts & reports. Enabling this flag prepares your deployment for that migration.
+:::
+
+## Quick Start
+
+### Define a Task
+
+```python
+from superset_core.api.types import task, get_context
+
+@task
+def process_data(dataset_id: int) -> None:
+ ctx = get_context()
+
+ @ctx.on_cleanup
+ def cleanup():
+ logger.info("Processing complete")
+
+ data = fetch_dataset(dataset_id)
+ process_and_cache(data)
+```
+
+### Execute a Task
+
+```python
+# Async execution - schedules on Celery worker
+task = process_data.schedule(dataset_id=123)
+print(task.status) # "pending"
+
+# Sync execution - runs inline in current process
+task = process_data(dataset_id=123)
+# ... blocks until complete
+print(task.status) # "success"
+```
+
+### Async vs Sync Execution
+
+| Method | When to Use |
+|--------|-------------|
+| `.schedule()` | Long-running operations, background processing, when you need to return immediately |
+| Direct call | Short operations, when deduplication matters, when you need the result before responding |
+
+Both execution modes provide the same task features: deduplication, progress tracking, cancellation, and visibility in the Task List UI. The difference is whether execution happens in a Celery worker (async) or inline (sync).
+
+## Task Lifecycle
+
+```
+PENDING ──→ IN_PROGRESS ────→ SUCCESS
+ │ │
+ │ ├──────────→ FAILURE
+ │ ↓ ↑
+ │ ABORTING ────────────┘
+ │ │
+ │ ├──────────→ TIMED_OUT (timeout)
+ │ │
+ └─────────────┴──────────→ ABORTED (user cancel)
+```
+
+| Status | Description |
+|--------|-------------|
+| `PENDING` | Queued, awaiting execution |
+| `IN_PROGRESS` | Executing |
+| `ABORTING` | Abort/timeout triggered, abort handlers running |
+| `SUCCESS` | Completed successfully |
+| `FAILURE` | Failed with error or abort/cleanup handler exception |
+| `ABORTED` | Cancelled by user/admin |
+| `TIMED_OUT` | Exceeded configured timeout |
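+
+The statuses map to the `TaskStatus` enum in `superset_core.api.tasks`. As a minimal sketch (a helper for illustration, not part of the framework), checking whether a task has reached a terminal state could look like this:
+
+```python
+from superset_core.api.tasks import TaskStatus
+
+# Terminal states per the table above; PENDING, IN_PROGRESS, and
+# ABORTING are the active states.
+TERMINAL_STATUSES = {
+    TaskStatus.SUCCESS,
+    TaskStatus.FAILURE,
+    TaskStatus.ABORTED,
+    TaskStatus.TIMED_OUT,
+}
+
+
+def is_terminal(status: str) -> bool:
+    """Return True if the task can no longer change state."""
+    return TaskStatus(status) in TERMINAL_STATUSES
+```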
+
+## Context API
+
+Access task context via `get_context()` from within any `@task` function. The context provides methods for updating task metadata and registering handlers.
+
+### Updating Task Metadata
+
+Use `update_task()` to report progress and store custom payload data:
+
+```python
+@task
+def my_task(items: list[int]) -> None:
+ ctx = get_context()
+
+ for i, item in enumerate(items):
+ result = process(item)
+ ctx.update_task(
+ progress=(i + 1, len(items)),
+ payload={"last_result": result}
+ )
+```
+
+:::tip
+Call `update_task()` once per iteration for best performance. Frequent DB writes are throttled to limit metastore load, so batching progress and payload updates together in a single call ensures both are persisted at the same time.
+:::
+
+#### Progress Formats
+
+The `progress` parameter accepts three formats:
+
+| Format | Example | Display |
+|--------|---------|---------|
+| `tuple[int, int]` | `progress=(3, 100)` | 3 of 100 (3%) with ETA |
+| `float` (0.0-1.0) | `progress=0.5` | 50% with ETA |
+| `int` | `progress=42` | 42 processed |
+
+:::tip
+Use the tuple format `(current, total)` whenever possible. It gives users the richest information, showing both the count and the percentage while still computing the ETA automatically.
+:::
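+
+For illustration, here are the three formats side by side, matching the `update_task()` examples in the API docstrings:
+
+```python
+ctx = get_context()
+
+ctx.update_task(progress=(3, 100))  # "3 of 100 (3%)", ETA computed
+ctx.update_task(progress=0.5)       # "50%", ETA computed
+ctx.update_task(progress=42)        # "42 processed", total unknown
+```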
+
+#### Payload
+
+The `payload` parameter stores custom metadata that can help users understand what the task is doing. Each call to `update_task()` merges the provided data into the existing payload.
+
+In the Task List UI, when a payload is defined, an info icon appears in the **Details** column. Users can hover over it to see the JSON content.
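+
+For example, a task might record which step it is on so users can follow along (the field names here are illustrative):
+
+```python
+ctx.update_task(payload={"step": "loading", "rows_read": 10000})
+```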
+
+### Handlers
+
+Register handlers to run cleanup logic or respond to abort requests:
+
+| Handler | When it runs | Use case |
+|---------|--------------|----------|
+| `on_cleanup` | Always (success, failure, abort) | Release resources, close connections |
+| `on_abort` | When task is aborted | Set stop flag, cancel external operations |
+
+```python
+@task
+def my_task() -> None:
+ ctx = get_context()
+
+ @ctx.on_cleanup
+ def cleanup():
+ logger.info("Task ended, cleaning up")
+
+ @ctx.on_abort
+ def handle_abort():
+ logger.info("Abort requested")
+
+ # ... task logic
+```
+
+Multiple handlers of the same type execute in LIFO order (last registered runs first). Abort handlers run first when abort is detected, then cleanup handlers run when the task ends.
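+
+A short sketch of the LIFO ordering, assuming two cleanup handlers (`connection` and `buffer` are placeholders):
+
+```python
+@ctx.on_cleanup
+def close_connection():  # registered first, runs second
+    connection.close()
+
+@ctx.on_cleanup
+def flush_buffers():  # registered last, runs first
+    buffer.flush()
+```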
+
+#### Best-Effort Execution
+
+**All registered handlers will always be attempted, even if one fails.** This ensures that a failure in one handler doesn't prevent other handlers from running their cleanup logic.
+
+For example, if you have three cleanup handlers and the second one throws an exception:
+1. Handler 3 runs ✓
+2. Handler 2 throws an exception ✗ (logged, but execution continues)
+3. Handler 1 runs ✓
+
+If any handler fails, the task is marked as `FAILURE` with combined error details showing all handler failures.
+
+:::tip
+Write handlers to be independent and self-contained. Don't assume previous handlers succeeded, and don't rely on shared state between handlers.
+:::
+
+## Making Tasks Abortable
+
+When users click **Cancel** in the Task List, the system decides whether to **abort** (stop) the task or **unsubscribe** (remove the user from a shared task). Abort occurs when:
+- It's a private or system task
+- It's a shared task and the user is the last subscriber
+- An admin checks **Force abort** to stop the task for all subscribers
+
+Pending tasks can always be aborted: they simply won't start. In-progress tasks require an abort handler to be abortable:
+
+```python
+@task
+def abortable_task(items: list[str]) -> None:
+ ctx = get_context()
+ should_stop = False
+
+ @ctx.on_abort
+ def handle_abort():
+ nonlocal should_stop
+ should_stop = True
+ logger.info("Abort signal received")
+
+ @ctx.on_cleanup
+ def cleanup():
+ logger.info("Task ended, cleaning up")
+
+ for item in items:
+ if should_stop:
+ return # Exit gracefully
+ process(item)
+```
+
+**Key points:**
+- Registering `on_abort` marks the task as abortable and starts the abort listener
+- The abort handler fires automatically when abort is triggered
+- Use a flag pattern to gracefully stop processing at safe points
+- Without an abort handler, in-progress tasks cannot be aborted: the Cancel button in the Task List UI will be disabled
+
+The framework automatically skips execution if a task was aborted while pending: no manual check needed at task start.
+
+:::tip
+Always implement an abort handler for long-running tasks. This allows users to cancel unneeded tasks and free up worker capacity for other operations.
+:::
+
+## Timeouts
+
+Set a timeout to automatically abort tasks that run too long:
+
+```python
+from superset_core.api.types import task, get_context, TaskOptions
+
+# Set default timeout in decorator
+@task(timeout=300) # 5 minutes
+def process_data(dataset_id: int) -> None:
+ ctx = get_context()
+ should_stop = False
+
+ @ctx.on_abort
+ def handle_abort():
+ nonlocal should_stop
+ should_stop = True
+
+ for chunk in fetch_large_dataset(dataset_id):
+ if should_stop:
+ return
+ process(chunk)
+
+# Override timeout at call time
+scheduled = process_data.schedule(
+    dataset_id=123,
+    options=TaskOptions(timeout=600),  # Override to 10 minutes
+)
+```
+
+### How Timeouts Work
+
+The timeout timer starts when the task begins executing (status changes to `IN_PROGRESS`). When the timeout expires:
+
+1. **With an abort handler registered:** The task transitions to `ABORTING`, abort handlers run, then cleanup handlers run. The final status depends on handler execution:
+ - If handlers complete successfully → `TIMED_OUT` status
+ - If handlers throw an exception → `FAILURE` status
+
+2. **Without an abort handler:** The framework cannot forcibly terminate the task. A warning is logged, and the task continues running. The Task List UI shows a warning indicator (⚠️) in the Details column to alert users that the timeout cannot be enforced.
+
+### Timeout Precedence
+
+| Source | Priority | Example |
+|--------|----------|---------|
+| `TaskOptions.timeout` | Highest | `options=TaskOptions(timeout=600)` |
+| `@task(timeout=...)` | Default | `@task(timeout=300)` |
+| Not set | No timeout | Task runs indefinitely |
+
+Call-time options always override decorator defaults, allowing tasks to have sensible defaults while permitting callers to extend or shorten the timeout for specific use cases.
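+
+For instance, with the `process_data` task defined above (decorator default of 300 seconds):
+
+```python
+process_data.schedule(dataset_id=1)  # uses the 300-second decorator default
+process_data.schedule(
+    dataset_id=2,
+    options=TaskOptions(timeout=60),  # call-time override wins: 60 seconds
+)
+```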
+
+:::warning
+Timeouts require an abort handler to be effective. Without one, the timeout triggers only a warning and the task continues running. Always implement an abort handler when using timeouts.
+:::
+
+## Deduplication
+
+Use `task_key` to prevent duplicate task execution:
+
+```python
+from superset_core.api.types import TaskOptions
+
+# Without key - creates new task each time (random UUID)
+task1 = my_task.schedule(x=1)
+task2 = my_task.schedule(x=1) # Different task
+
+# With key - joins existing task if active
+task1 = my_task.schedule(x=1, options=TaskOptions(task_key="report_123"))
+task2 = my_task.schedule(x=1, options=TaskOptions(task_key="report_123")) # Returns same task
+```
+
+When a task with matching key already exists, the user is added as a subscriber and the existing task is returned. This behavior is consistent across all scopes—private tasks naturally have only one subscriber since their deduplication key includes the user ID.
+
+Deduplication only applies to active tasks (pending/in-progress). Once a task completes, a new task with the same key can be created.
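+
+The per-scope uniqueness rules above can be pictured as a computed key. This is a simplified sketch of the `dedup_key` logic described in `TaskDAO.find_by_task_key`, not the actual implementation:
+
+```python
+def dedup_key(scope: str, task_type: str, task_key: str, user_id: int | None) -> str:
+    if scope == "private":
+        # Private tasks deduplicate per user
+        return f"{scope}:{task_type}:{task_key}:{user_id}"
+    # Shared and system tasks deduplicate across all users
+    return f"{scope}:{task_type}:{task_key}"
+```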
+
+### Sync Join-and-Wait
+
+When a sync call joins an existing task, it blocks until the task completes:
+
+```python
+# Schedule async task
+task = my_task.schedule(options=TaskOptions(task_key="report_123"))
+
+# Later sync call with same key blocks until completion of the active task
+task2 = my_task(options=TaskOptions(task_key="report_123"))
+assert task.uuid == task2.uuid # True
+print(task2.status) # "success" (terminal status)
+```
+
+## Task Scopes
+
+```python
+from superset_core.api.types import task, TaskScope
+
+@task # Private by default
+def private_task(): ...
+
+@task(scope=TaskScope.SHARED) # Multiple users can subscribe
+def shared_task(): ...
+
+@task(scope=TaskScope.SYSTEM) # Admin-only visibility
+def system_task(): ...
+```
+
+| Scope | Visibility | Cancel Behavior |
+|-------|------------|-----------------|
+| `PRIVATE` | Creator only | Cancels immediately |
+| `SHARED` | All subscribers | Last subscriber cancels; others unsubscribe |
+| `SYSTEM` | Admins only | Admin cancels |
+
+## Task Cleanup
+
+Completed tasks accumulate in the database over time. Configure a scheduled prune job to automatically remove old tasks:
+
+```python
+# In your superset_config.py, add to your Celery beat schedule:
+from celery.schedules import crontab
+
+CELERY_CONFIG.beat_schedule["prune_tasks"] = {
+ "task": "prune_tasks",
+ "schedule": crontab(minute=0, hour=0), # Run daily at midnight
+ "kwargs": {
+ "retention_period_days": 90, # Keep tasks for 90 days
+ "max_rows_per_run": 10000, # Limit deletions per run
+ },
+}
+```
+
+The prune job only removes tasks in terminal states (`SUCCESS`, `FAILURE`, `ABORTED`, `TIMED_OUT`). Active tasks (`PENDING`, `IN_PROGRESS`, `ABORTING`) are never pruned.
+
+See `superset/config.py` for a complete example configuration.
+
+:::tip Signal Cache for Faster Notifications
+By default, abort detection and sync join-and-wait use database polling. Configure `SIGNAL_CACHE_CONFIG` to enable Redis pub/sub for real-time notifications. See [Signal Cache Backend](/docs/configuration/cache#signal-cache-backend) for configuration details.
+:::
+
+## API Reference
+
+### @task Decorator
+
+```python
+@task(
+ name: str | None = None,
+ scope: TaskScope = TaskScope.PRIVATE,
+ timeout: int | None = None
+)
+```
+
+- `name`: Task identifier (defaults to function name)
+- `scope`: `PRIVATE`, `SHARED`, or `SYSTEM`
+- `timeout`: Default timeout in seconds (can be overridden via `TaskOptions`)
+
+### TaskContext Methods
+
+| Method | Description |
+|--------|-------------|
+| `update_task(progress, payload)` | Update progress and/or custom payload |
+| `on_cleanup(handler)` | Register cleanup handler |
+| `on_abort(handler)` | Register abort handler (makes task abortable) |
+
+### TaskOptions
+
+```python
+TaskOptions(
+ task_key: str | None = None,
+ task_name: str | None = None,
+ timeout: int | None = None
+)
+```
+
+- `task_key`: Deduplication key (also used as display name if `task_name` is not set)
+- `task_name`: Human-readable display name for the Task List UI
+- `timeout`: Timeout in seconds (overrides decorator default)
+
+:::tip
+Provide a descriptive `task_name` for better readability in the Task List UI. While `task_key` is used for deduplication and may be technical (e.g., `chart_export_123`), `task_name` can be user-friendly (e.g., `"Export Sales Chart 123"`).
+:::
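+
+For example (the task and names here are illustrative):
+
+```python
+export_chart.schedule(
+    chart_id=123,
+    options=TaskOptions(
+        task_key="chart_export_123",         # deduplication key
+        task_name="Export Sales Chart 123",  # shown in the Task List UI
+    ),
+)
+```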
+
+## Error Handling
+
+Let exceptions propagate: the framework captures them automatically and sets task status to `FAILURE`:
+
+```python
+@task
+def risky_task() -> None:
+    # No try/except needed - the framework handles it
+ result = operation_that_might_fail()
+```
+
+On failure, the framework records:
+- `error_message`: Exception message
+- `exception_type`: Exception class name
+- `stack_trace`: Full traceback (visible when `SHOW_STACKTRACE=True`)
+
+In the Task List UI, failed tasks show error details when hovering over the status. When stack traces are enabled, a separate bug icon appears in the **Details** column for viewing the full traceback.
+
+Cleanup handlers still run after an exception, so resources can be properly released as necessary.
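+
+A sketch combining the two behaviors: the exception is recorded and marks the task as `FAILURE`, while the cleanup handler still runs afterwards (`temp_dir` is a placeholder):
+
+```python
+@task
+def risky_task() -> None:
+    ctx = get_context()
+
+    @ctx.on_cleanup
+    def release():
+        temp_dir.cleanup()  # runs even after the raise below
+
+    raise ValueError("could not fetch source data")  # stored as error_message
+```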
+
+:::tip
+Use descriptive exception messages. In environments where stack traces are hidden (`SHOW_STACKTRACE=False`), users see only the error message and exception type when hovering over failed tasks. Clear messages help users troubleshoot issues without administrator assistance.
+:::
diff --git a/docs/developer_portal/sidebars.js b/docs/developer_portal/sidebars.js
index 7c376be945e7..7926d80cf615 100644
--- a/docs/developer_portal/sidebars.js
+++ b/docs/developer_portal/sidebars.js
@@ -53,6 +53,7 @@ module.exports = {
'extensions/deployment',
'extensions/mcp',
'extensions/security',
+ 'extensions/tasks',
'extensions/registry',
],
},
diff --git a/docs/docs/configuration/cache.mdx b/docs/docs/configuration/cache.mdx
index c89cdc85eb78..73a73332ba63 100644
--- a/docs/docs/configuration/cache.mdx
+++ b/docs/docs/configuration/cache.mdx
@@ -7,6 +7,12 @@ version: 1
# Caching
+:::note
+When a cache backend is configured, Superset expects it to remain available. If the
+configured backend becomes unavailable, operations fail rather than silently degrading. This
+fail-fast behavior ensures operators are immediately aware of infrastructure issues.
+:::
+
Superset uses [Flask-Caching](https://flask-caching.readthedocs.io/) for caching purposes.
Flask-Caching supports various caching backends, including Redis (recommended), Memcached,
SimpleCache (in-memory), or the local filesystem.
@@ -153,6 +159,84 @@ Then on configuration:
WEBDRIVER_AUTH_FUNC = auth_driver
```
+## Signal Cache Backend
+
+Superset supports an optional signal cache (`SIGNAL_CACHE_CONFIG`) for
+high-performance distributed operations. This configuration enables:
+
+- **Distributed locking**: Moves lock operations from the metadata database to Redis, improving
+ performance and reducing metastore load
+- **Real-time event notifications**: Enables instant pub/sub messaging for task abort signals and
+ completion notifications instead of polling-based approaches
+
+:::note
+This requires Redis or Valkey specifically—it uses Redis-specific features (pub/sub, `SET NX EX`)
+that are not available in general Flask-Caching backends.
+:::
+
+### Configuration
+
+The signal cache uses Flask-Caching style configuration for consistency with other cache
+backends. Configure `SIGNAL_CACHE_CONFIG` in `superset_config.py`:
+
+```python
+SIGNAL_CACHE_CONFIG = {
+ "CACHE_TYPE": "RedisCache",
+ "CACHE_REDIS_HOST": "localhost",
+ "CACHE_REDIS_PORT": 6379,
+ "CACHE_REDIS_DB": 0,
+ "CACHE_REDIS_PASSWORD": "", # Optional
+}
+```
+
+For Redis Sentinel deployments:
+
+```python
+SIGNAL_CACHE_CONFIG = {
+ "CACHE_TYPE": "RedisSentinelCache",
+ "CACHE_REDIS_SENTINELS": [("sentinel1", 26379), ("sentinel2", 26379)],
+ "CACHE_REDIS_SENTINEL_MASTER": "mymaster",
+ "CACHE_REDIS_SENTINEL_PASSWORD": None, # Sentinel password (if different)
+ "CACHE_REDIS_PASSWORD": "", # Redis password
+ "CACHE_REDIS_DB": 0,
+}
+```
+
+For SSL/TLS connections:
+
+```python
+SIGNAL_CACHE_CONFIG = {
+ "CACHE_TYPE": "RedisCache",
+ "CACHE_REDIS_HOST": "redis.example.com",
+ "CACHE_REDIS_PORT": 6380,
+ "CACHE_REDIS_SSL": True,
+ "CACHE_REDIS_SSL_CERTFILE": "/path/to/client.crt",
+ "CACHE_REDIS_SSL_KEYFILE": "/path/to/client.key",
+ "CACHE_REDIS_SSL_CA_CERTS": "/path/to/ca.crt",
+}
+```
+
+### Distributed Lock TTL
+
+You can configure the default lock TTL (time-to-live) in seconds. Locks automatically expire after
+this duration to prevent deadlocks from crashed processes:
+
+```python
+DISTRIBUTED_LOCK_DEFAULT_TTL = 30 # Default: 30 seconds
+```
+
+Individual lock acquisitions can override this value when needed.
+
+### Database-Only Mode
+
+When `SIGNAL_CACHE_CONFIG` is not configured, Superset uses database-backed operations:
+
+- **Locking**: Uses the KeyValue table with periodic cleanup of expired entries
+- **Event notifications**: Uses database polling instead of pub/sub
+
+While database-backed operations work reliably, the Redis backend is recommended for production
+deployments where low latency and reduced database load are important.
+
:::resources
- [Blog: The Data Engineer's Guide to Lightning-Fast Superset Dashboards](https://preset.io/blog/the-data-engineers-guide-to-lightning-fast-apache-superset-dashboards/)
- [Blog: Accelerating Dashboards with Materialized Views](https://preset.io/blog/accelerating-apache-superset-dashboards-with-materialized-views/)
diff --git a/docs/sidebarTutorials.js b/docs/sidebarTutorials.js
index b527517a9540..b786478c0b1f 100644
--- a/docs/sidebarTutorials.js
+++ b/docs/sidebarTutorials.js
@@ -97,6 +97,7 @@ const sidebars = {
'extensions/deployment',
'extensions/mcp',
'extensions/security',
+ 'extensions/tasks',
'extensions/registry',
],
},
diff --git a/superset-core/src/superset_core/api/daos.py b/superset-core/src/superset_core/api/daos.py
index 3dc4cf0a7de7..f686f0c659dc 100644
--- a/superset-core/src/superset_core/api/daos.py
+++ b/superset-core/src/superset_core/api/daos.py
@@ -46,6 +46,7 @@
Query,
SavedQuery,
Tag,
+ Task,
User,
)
@@ -248,6 +249,48 @@ class KeyValueDAO(BaseDAO[KeyValue]):
id_column_name = "id"
+class TaskDAO(BaseDAO[Task]):
+ """
+ Abstract Task DAO interface.
+
+ Host implementations will replace this class during initialization
+ with a concrete implementation providing actual functionality.
+ """
+
+ # Class variables that will be set by host implementation
+ model_cls = None
+ base_filter = None
+ id_column_name = "id"
+ uuid_column_name = "uuid"
+
+ @classmethod
+ @abstractmethod
+ def find_by_task_key(
+ cls,
+ task_type: str,
+ task_key: str,
+ scope: str = "private",
+ user_id: int | None = None,
+ ) -> Task | None:
+ """
+ Find active task by type, key, scope, and user.
+
+ Uses dedup_key internally for efficient querying with a unique index.
+ Only returns tasks that are active (pending or in progress).
+
+ Uniqueness logic by scope:
+ - private: scope + task_type + task_key + user_id
+ - shared/system: scope + task_type + task_key (user-agnostic)
+
+ :param task_type: Task type to filter by
+ :param task_key: Task identifier for deduplication
+ :param scope: Task scope (private/shared/system)
+ :param user_id: User ID (required for private tasks)
+ :returns: Task instance or None if not found or not active
+ """
+ ...
+
+
__all__ = [
"BaseDAO",
"DatasetDAO",
@@ -259,4 +302,5 @@ class KeyValueDAO(BaseDAO[KeyValue]):
"SavedQueryDAO",
"TagDAO",
"KeyValueDAO",
+ "TaskDAO",
]
diff --git a/superset-core/src/superset_core/api/models.py b/superset-core/src/superset_core/api/models.py
index 346e8392f165..91e10255d040 100644
--- a/superset-core/src/superset_core/api/models.py
+++ b/superset-core/src/superset_core/api/models.py
@@ -40,6 +40,7 @@
from sqlalchemy.orm import scoped_session
if TYPE_CHECKING:
+ from superset_core.api.tasks import TaskProperties
from superset_core.api.types import (
AsyncQueryHandle,
QueryOptions,
@@ -361,6 +362,132 @@ class KeyValue(CoreModel):
changed_by_fk: int | None
+class Task(CoreModel):
+ """
+ Abstract Task model interface.
+
+ Host implementations will replace this class during initialization
+ with concrete implementation providing actual functionality.
+
+ This model represents async tasks in the Global Task Framework (GTF).
+
+ Non-filterable fields (progress, error info, execution config) are stored
+ in a `properties` JSON blob for schema flexibility.
+ """
+
+ __abstract__ = True
+
+ # Type hints for expected column attributes
+ id: int
+ uuid: UUID
+ task_key: str # For deduplication
+ task_type: str # e.g., 'sql_execution'
+ task_name: str | None # Human readable name
+ scope: str # private/shared/system
+ status: str
+ dedup_key: str # Computed deduplication key
+
+ # Timestamps (from AuditMixinNullable)
+ created_on: datetime | None
+ changed_on: datetime | None
+ started_at: datetime | None
+ ended_at: datetime | None
+
+ # User context
+ created_by_fk: int | None
+ user_id: int | None
+
+ # Task output data
+ payload: str # JSON serialized task output data
+
+ def get_payload(self) -> dict[str, Any]:
+ """
+ Get payload as parsed JSON.
+
+ Payload contains task-specific output data set by task code.
+
+ Host implementations will replace this method during initialization
+ with concrete implementation providing actual functionality.
+
+ :returns: Dictionary containing payload data
+ """
+ raise NotImplementedError("Method will be replaced during initialization")
+
+ def set_payload(self, data: dict[str, Any]) -> None:
+ """
+ Update payload with new data (merges with existing).
+
+ Host implementations will replace this method during initialization
+ with concrete implementation providing actual functionality.
+
+ :param data: Dictionary of data to merge into payload
+ """
+ raise NotImplementedError("Method will be replaced during initialization")
+
+ @property
+ def properties(self) -> Any:
+ """
+ Get typed properties (runtime state and execution config).
+
+ Properties contain:
+ - is_abortable: bool | None - has abort handler registered
+ - progress_percent: float | None - progress 0.0-1.0
+ - progress_current: int | None - current iteration count
+ - progress_total: int | None - total iterations
+ - error_message: str | None - human-readable error message
+ - exception_type: str | None - exception class name
+ - stack_trace: str | None - full formatted traceback
+ - timeout: int | None - timeout in seconds
+
+ Host implementations will replace this property during initialization.
+
+        :returns: TaskProperties TypedDict with the fields currently set
+ """
+ raise NotImplementedError("Property will be replaced during initialization")
+
+ def update_properties(self, updates: "TaskProperties") -> None:
+ """
+ Update specific properties fields (merge semantics).
+
+ Only updates fields present in the updates dict.
+
+ Host implementations will replace this method during initialization.
+
+ :param updates: TaskProperties dict with fields to update
+
+ Example:
+ task.update_properties({"is_abortable": True})
+ """
+ raise NotImplementedError("Method will be replaced during initialization")
+
+
+class TaskSubscriber(CoreModel):
+ """
+ Abstract TaskSubscriber model interface.
+
+ Host implementations will replace this class during initialization
+ with concrete implementation providing actual functionality.
+
+ This model tracks task subscriptions for multi-user shared tasks. When a user
+ schedules a shared task with the same parameters as an existing task,
+ they are subscribed to that task instead of creating a duplicate.
+ """
+
+ __abstract__ = True
+
+ # Type hints for expected attributes (no actual field definitions)
+ id: int
+ task_id: int
+ user_id: int
+ subscribed_at: datetime
+
+ # Audit fields from AuditMixinNullable
+ created_on: datetime | None
+ changed_on: datetime | None
+ created_by_fk: int | None
+ changed_by_fk: int | None
+
+
def get_session() -> scoped_session:
"""
Retrieve the SQLAlchemy session to directly interface with the
@@ -384,6 +511,8 @@ def get_session() -> scoped_session:
"SavedQuery",
"Tag",
"KeyValue",
+ "Task",
+ "TaskSubscriber",
"CoreModel",
"get_session",
]
diff --git a/superset-core/src/superset_core/api/tasks.py b/superset-core/src/superset_core/api/tasks.py
new file mode 100644
index 000000000000..1adcd9ab3275
--- /dev/null
+++ b/superset-core/src/superset_core/api/tasks.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Generic, Literal, ParamSpec, TypedDict, TypeVar
+
+from superset_core.api.models import Task
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class TaskStatus(str, Enum):
+ """
+ Status of task execution.
+ """
+
+ PENDING = "pending"
+ IN_PROGRESS = "in_progress"
+ SUCCESS = "success"
+ FAILURE = "failure"
+ ABORTING = "aborting" # Abort/timeout requested, handlers running
+ ABORTED = "aborted" # User/admin cancelled
+ TIMED_OUT = "timed_out" # Timeout expired
+
+
+class TaskScope(str, Enum):
+ """
+ Scope of task visibility and access control.
+ """
+
+ PRIVATE = "private" # User-specific tasks (default)
+ SHARED = "shared" # Multi-user collaborative tasks
+ SYSTEM = "system" # Admin-only background tasks
+
+
+class TaskProperties(TypedDict, total=False):
+ """
+ TypedDict for task runtime state and execution config.
+
+ Stored as JSON in the database, accessed as a dict throughout the codebase.
+ All fields are optional (total=False) - only set keys are present in the dict.
+
+ Usage:
+ # Reading - always use .get() since keys may not be present
+ if task.properties.get("is_abortable"):
+ ...
+
+ # Writing/updating - only include keys you want to set
+ task.update_properties({"is_abortable": True, "progress_percent": 0.5})
+
+ Notes:
+ - Sparse dict: only keys that are explicitly set are present
+ - Unknown keys from JSON are preserved (forward compatibility)
+ - Always use .get() for reads since keys may be absent
+ """
+
+ # Execution config - set at task creation
+ execution_mode: Literal["async", "sync"]
+ timeout: int
+
+ # Runtime state - set by framework during execution
+ is_abortable: bool
+ progress_percent: float
+ progress_current: int
+ progress_total: int
+
+ # Error info - set when task fails
+ error_message: str
+ exception_type: str
+ stack_trace: str
+
+
+@dataclass(frozen=True)
+class TaskOptions:
+ """
+ Execution metadata for tasks.
+
+ NOTE: This is intentionally minimal for the initial implementation.
+ Additional options (queue, priority, run_at, delay_s,
+ max_retries, retry_backoff_s, tags, etc.) can be added later when needed.
+
+ Future enhancements will include:
+ - Validation (e.g., run_at vs delay_s mutual exclusion)
+ - Queue routing and priority management
+ - Retry policies and backoff strategies
+
+ Example:
+ from superset_core.api.tasks import TaskOptions, TaskScope
+
+ # Private task (default)
+ task = my_task.schedule(arg1)
+
+ # Custom task with deduplication
+ task = my_task.schedule(
+ arg1,
+ options=TaskOptions(
+ task_key="custom_key",
+ task_name="Custom Task Name"
+ )
+ )
+
+ # Task with custom name
+ task = admin_task.schedule(
+ options=TaskOptions(task_name="Admin Operation")
+ )
+
+ # Task with timeout (overrides decorator default)
+ task = long_task.schedule(
+ options=TaskOptions(timeout=600) # 10 minute timeout
+ )
+ """
+
+ task_key: str | None = None
+ task_name: str | None = None
+ timeout: int | None = None # Timeout in seconds
+
+
+class TaskContext(ABC):
+ """
+ Abstract task context for write-only task state updates.
+
+    Tasks use this context to update their state (progress, payload) and
+    register abort and cleanup handlers. Tasks should not need to read their own state -
+ they are the source of state, not consumers of it.
+
+ Host implementations will replace this abstract class during initialization
+ with a concrete implementation providing actual functionality.
+ """
+
+ @abstractmethod
+ def update_task(
+ self,
+ progress: float | int | tuple[int, int] | None = None,
+ payload: dict[str, Any] | None = None,
+ ) -> None:
+ """
+ Update task progress and/or payload atomically.
+
+ All parameters are optional. Payload is merged with existing data,
+ not replaced. All updates occur in a single database transaction.
+
+ Progress can be specified in three ways:
+ - float (0.0-1.0): Percentage only, e.g., 0.5 means 50%
+ - int: Count only (total unknown), e.g., 42 means "42 items processed"
+ - tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100"
+ The percentage is automatically computed from count/total.
+
+ :param progress: Progress value, or None to leave unchanged
+ :param payload: Payload data to merge (dict), or None to leave unchanged
+
+ Examples:
+ # Percentage only - displays as "In progress: 50 %"
+ ctx.update_task(progress=0.5)
+
+ # Count only (total unknown) - displays as "In progress: 42"
+ ctx.update_task(progress=42)
+
+ # Count and total - displays as "In progress: 3 of 100 (3 %)"
+ ctx.update_task(progress=(3, 100))
+
+ # Update payload only
+ ctx.update_task(payload={"step": "processing"})
+
+ # Update both atomically
+ ctx.update_task(
+ progress=(80, 100),
+ payload={"processed": 80, "total": 100}
+ )
+ """
+ ...
+
+ @abstractmethod
+ def on_cleanup(self, handler: Callable[[], None]) -> Callable[[], None]:
+ """
+ Register a cleanup handler that runs when the task ends.
+
+ Cleanup handlers are called when the task completes (success),
+ fails with an error, or is cancelled. Multiple handlers can be
+ registered and will execute in LIFO order (last registered runs first).
+
+ Can be used as a decorator:
+ @ctx.on_cleanup
+ def cleanup():
+ logger.info("Task ended")
+
+ Or called directly:
+ ctx.on_cleanup(lambda: logger.info("Task ended"))
+
+ :param handler: Cleanup function to register
+ :returns: The handler (for decorator compatibility)
+ """
+ ...
+
+ @abstractmethod
+ def on_abort(self, handler: Callable[[], None]) -> Callable[[], None]:
+ """
+ Register handler that runs when task is aborted.
+
+ When the first handler is registered, background polling starts
+ automatically. The handler will be called when an abort is detected.
+
+ The handler executes in a background thread and the task code
+ continues running unless the handler takes action to stop it.
+
+ :param handler: Callback function to execute when abort is detected
+ :returns: The handler (for decorator compatibility)
+
+ Example:
+ @ctx.on_abort
+ def handle_abort():
+ logger.info("Task was aborted!")
+ cleanup_partial_work()
+ """
+ ...
+
+
+def task(
+ name: str | None = None,
+ scope: TaskScope = TaskScope.PRIVATE,
+ timeout: int | None = None,
+) -> Callable[[Callable[P, R]], "TaskWrapper[P]"]:
+ """
+ Decorator to register a task.
+
+ Host implementations will replace this function during initialization
+ with a concrete implementation providing actual functionality.
+
+ :param name: Optional unique task name (e.g., "superset.generate_thumbnail").
+ If not provided, uses the function name as the task name.
+ :param scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM).
+ Defaults to TaskScope.PRIVATE.
+ :param timeout: Optional timeout in seconds. When the timeout is reached,
+ abort handlers are triggered if registered. Can be overridden
+ at call time via TaskOptions(timeout=...).
+ :returns: TaskWrapper with .schedule() method
+
+ Note:
+ Both direct calls and .schedule() return Task, regardless of the
+ original function's return type. The decorated function's return value
+ is discarded; only side effects and context updates matter.
+
+ Example:
+ from superset_core.api.types import task, get_context, TaskScope
+
+ # Private task (default scope)
+ @task
+ def generate_thumbnail(chart_id: int) -> None:
+ ctx = get_context()
+ # ... task implementation
+
+ # Named task with shared scope
+ @task(name="generate_report", scope=TaskScope.SHARED)
+ def generate_chart_thumbnail(chart_id: int) -> None:
+ ctx = get_context()
+
+ # Update progress and payload atomically
+ ctx.update_task(
+ progress=0.5,
+ payload={"chart_id": chart_id, "status": "processing"}
+ )
+ # ... task implementation
+
+ ctx.update_task(progress=1.0)
+
+ # System task (admin-only)
+ @task(scope=TaskScope.SYSTEM)
+ def cleanup_old_data() -> None:
+ ctx = get_context()
+ # ... cleanup implementation
+
+ # Task with timeout
+ @task(timeout=300) # 5-minute timeout
+ def long_running_task() -> None:
+ ctx = get_context()
+
+ @ctx.on_abort
+ def handle_abort():
+ # Called when timeout or manual abort
+ pass
+
+ # Schedule async execution
+ task = generate_chart_thumbnail.schedule(chart_id=123) # Returns Task
+
+ # Direct call for sync execution (blocks until task is complete)
+ task = generate_chart_thumbnail(chart_id=123) # Also returns Task
+ """
+ raise NotImplementedError("Function will be replaced during initialization")
+
+
+class TaskWrapper(Generic[P]):
+ """
+ Type stub for task wrapper returned by @task decorator.
+
+ Both __call__ and .schedule() return Task.
+ """
+
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Task:
+ """Execute the task synchronously."""
+ raise NotImplementedError("Will be replaced during initialization")
+
+ def schedule(self, *args: P.args, **kwargs: P.kwargs) -> Task:
+ """Schedule the task for async execution."""
+ raise NotImplementedError("Will be replaced during initialization")
+
+
+def get_context() -> TaskContext:
+ """
+ Get the current task context from ambient context.
+
+ Host implementations will replace this function during initialization
+ with a concrete implementation providing actual functionality.
+
+ This function provides ambient access to the task context without
+ requiring it to be passed as a parameter. It can only be called
+    from within a task execution (sync or async).
+
+ :returns: The current TaskContext
+ :raises RuntimeError: If called outside a task execution context
+
+ Example:
+ @task("thumbnail_generation")
+ def generate_chart_thumbnail(chart_id: int):
+ ctx = get_context() # Access ambient context
+
+ # Update task state - no need to fetch task object
+ ctx.update_task(
+ progress=0.5,
+ payload={"chart_id": chart_id}
+ )
+ """
+ raise NotImplementedError("Function will be replaced during initialization")
+
+
+__all__ = [
+ "TaskStatus",
+ "TaskScope",
+ "TaskProperties",
+ "TaskContext",
+ "TaskOptions",
+ "task",
+ "get_context",
+]
diff --git a/superset-frontend/package-lock.json b/superset-frontend/package-lock.json
index d555db9b026e..2ba264c8572c 100644
--- a/superset-frontend/package-lock.json
+++ b/superset-frontend/package-lock.json
@@ -109,6 +109,7 @@
"mustache": "^4.2.0",
"nanoid": "^5.1.6",
"ol": "^7.5.2",
+ "pretty-ms": "^9.3.0",
"query-string": "9.3.1",
"re-resizable": "^6.11.2",
"react": "^17.0.2",
@@ -43687,6 +43688,21 @@
"integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==",
"license": "MIT"
},
+ "node_modules/pretty-ms": {
+ "version": "9.3.0",
+ "resolved": "https://registry.npmjs.org/pretty-ms/-/pretty-ms-9.3.0.tgz",
+ "integrity": "sha512-gjVS5hOP+M3wMm5nmNOucbIrqudzs9v/57bWRHQWLYklXqoXKrVfYW2W9+glfGsqtPgpiz5WwyEEB+ksXIx3gQ==",
+ "license": "MIT",
+ "dependencies": {
+ "parse-ms": "^4.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
"node_modules/prismjs": {
"version": "1.30.0",
"resolved": "https://registry.npmjs.org/prismjs/-/prismjs-1.30.0.tgz",
@@ -56437,20 +56453,6 @@
"url": "https://github.com/sponsors/wooorm"
}
},
- "packages/superset-ui-core/node_modules/pretty-ms": {
- "version": "9.3.0",
- "resolved": "https://registry.npmjs.org/pretty-ms/-/pretty-ms-9.3.0.tgz",
- "integrity": "sha512-gjVS5hOP+M3wMm5nmNOucbIrqudzs9v/57bWRHQWLYklXqoXKrVfYW2W9+glfGsqtPgpiz5WwyEEB+ksXIx3gQ==",
- "dependencies": {
- "parse-ms": "^4.0.0"
- },
- "engines": {
- "node": ">=18"
- },
- "funding": {
- "url": "https://github.com/sponsors/sindresorhus"
- }
- },
"packages/superset-ui-core/node_modules/property-information": {
"version": "7.1.0",
"resolved": "https://registry.npmjs.org/property-information/-/property-information-7.1.0.tgz",
diff --git a/superset-frontend/package.json b/superset-frontend/package.json
index c500b481d797..4422bd79feda 100644
--- a/superset-frontend/package.json
+++ b/superset-frontend/package.json
@@ -187,6 +187,7 @@
"markdown-to-jsx": "^9.7.3",
"match-sorter": "^6.3.4",
"memoize-one": "^5.2.1",
+ "pretty-ms": "^9.3.0",
"mousetrap": "^1.6.5",
"mustache": "^4.2.0",
"nanoid": "^5.1.6",
diff --git a/superset-frontend/packages/superset-ui-core/src/utils/featureFlags.ts b/superset-frontend/packages/superset-ui-core/src/utils/featureFlags.ts
index 46f832cc0946..57bda77b7dbb 100644
--- a/superset-frontend/packages/superset-ui-core/src/utils/featureFlags.ts
+++ b/superset-frontend/packages/superset-ui-core/src/utils/featureFlags.ts
@@ -54,6 +54,7 @@ export enum FeatureFlag {
EstimateQueryCost = 'ESTIMATE_QUERY_COST',
FilterBarClosedByDefault = 'FILTERBAR_CLOSED_BY_DEFAULT',
GlobalAsyncQueries = 'GLOBAL_ASYNC_QUERIES',
+ GlobalTaskFramework = 'GLOBAL_TASK_FRAMEWORK',
ListviewsDefaultCardView = 'LISTVIEWS_DEFAULT_CARD_VIEW',
Matrixify = 'MATRIXIFY',
ScheduledQueries = 'SCHEDULED_QUERIES',
diff --git a/superset-frontend/src/components/AuditInfo/index.tsx b/superset-frontend/src/components/AuditInfo/index.tsx
index 503dbf17cb6b..1df21c848cc8 100644
--- a/superset-frontend/src/components/AuditInfo/index.tsx
+++ b/superset-frontend/src/components/AuditInfo/index.tsx
@@ -19,9 +19,9 @@
import getOwnerName from 'src/utils/getOwnerName';
import { t } from '@apache-superset/core';
import { Tooltip } from '@superset-ui/core/components';
-import type { ModifiedInfoProps } from './types';
+import type { AuditInfoProps } from './types';
-export const ModifiedInfo = ({ user, date }: ModifiedInfoProps) => {
+export const ModifiedInfo = ({ user, date }: AuditInfoProps) => {
  const dateSpan = (
    <span className="no-wrap">
      {date}
@@ -40,4 +40,23 @@ export const ModifiedInfo = ({ user, date }: ModifiedInfoProps) => {
return dateSpan;
};
-export type { ModifiedInfoProps };
+export const CreatedInfo = ({ user, date }: AuditInfoProps) => {
+  const dateSpan = (
+    <span className="no-wrap">
+      {date}
+    </span>
+  );
+
+  if (user) {
+    const userName = getOwnerName(user);
+    const title = t('Created by: %s', userName);
+    return (
+      <Tooltip title={title} placement="bottom">
+        {dateSpan}
+      </Tooltip>
+    );
+  }
+  return dateSpan;
+};
+
+export type { AuditInfoProps };
diff --git a/superset-frontend/src/components/AuditInfo/types.ts b/superset-frontend/src/components/AuditInfo/types.ts
index 06166b783758..f097198c1441 100644
--- a/superset-frontend/src/components/AuditInfo/types.ts
+++ b/superset-frontend/src/components/AuditInfo/types.ts
@@ -18,7 +18,7 @@
*/
import type Owner from 'src/types/Owner';
-export type ModifiedInfoProps = {
+export type AuditInfoProps = {
user?: Owner;
date: string;
};
diff --git a/superset-frontend/src/components/index.ts b/superset-frontend/src/components/index.ts
index 558f080261f4..5d936e225a28 100644
--- a/superset-frontend/src/components/index.ts
+++ b/superset-frontend/src/components/index.ts
@@ -41,7 +41,7 @@ export * from './GenericLink';
export { GridTable, type TableProps } from './GridTable';
export * from './Tag';
export * from './TagsList';
-export { ModifiedInfo, type ModifiedInfoProps } from './AuditInfo';
+export { CreatedInfo, ModifiedInfo, type AuditInfoProps } from './AuditInfo';
export {
DynamicPluginProvider,
PluginContext,
diff --git a/superset-frontend/src/features/tasks/TaskPayloadPopover.tsx b/superset-frontend/src/features/tasks/TaskPayloadPopover.tsx
new file mode 100644
index 000000000000..eaa27b900ad2
--- /dev/null
+++ b/superset-frontend/src/features/tasks/TaskPayloadPopover.tsx
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import { useState } from 'react';
+import { styled } from '@apache-superset/core/ui';
+import { Popover } from '@superset-ui/core/components';
+import { Icons } from '@superset-ui/core/components/Icons';
+
+const PayloadContainer = styled.div`
+ max-width: 400px;
+ max-height: 300px;
+ overflow: auto;
+ padding: ${({ theme }) => theme.sizeUnit * 2}px;
+`;
+
+const PayloadPre = styled.pre`
+ margin: 0;
+ font-size: ${({ theme }) => theme.fontSizeSM}px;
+ white-space: pre-wrap;
+ word-wrap: break-word;
+`;
+
+const InfoIconWrapper = styled.span`
+ cursor: pointer;
+ color: ${({ theme }) => theme.colorIcon};
+
+ &:hover {
+ color: ${({ theme }) => theme.colorPrimary};
+ }
+`;
+
+interface TaskPayloadPopoverProps {
+  payload: Record<string, any>;
+}
+
+export default function TaskPayloadPopover({
+ payload,
+}: TaskPayloadPopoverProps) {
+ const [visible, setVisible] = useState(false);
+
+  const content = (
+    <PayloadContainer>
+      <PayloadPre>{JSON.stringify(payload, null, 2)}</PayloadPre>
+    </PayloadContainer>
+  );
+
+  return (
+    <Popover
+      content={content}
+      trigger="hover"
+      open={visible}
+      onOpenChange={setVisible}
+    >
+      <InfoIconWrapper>
+        <Icons.InfoCircleOutlined iconSize="m" />
+      </InfoIconWrapper>
+    </Popover>
+  );
+}
diff --git a/superset-frontend/src/features/tasks/TaskStackTracePopover.tsx b/superset-frontend/src/features/tasks/TaskStackTracePopover.tsx
new file mode 100644
index 000000000000..caacf6b633b9
--- /dev/null
+++ b/superset-frontend/src/features/tasks/TaskStackTracePopover.tsx
@@ -0,0 +1,137 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import { useState, useCallback } from 'react';
+import { t } from '@apache-superset/core';
+import { styled } from '@apache-superset/core/ui';
+import { Popover, Tooltip } from '@superset-ui/core/components';
+import { Icons } from '@superset-ui/core/components/Icons';
+import { useToasts } from 'src/components/MessageToasts/withToasts';
+import copyTextToClipboard from 'src/utils/copy';
+
+const StackTraceContainer = styled.div`
+ max-width: 600px;
+ max-height: 400px;
+ display: flex;
+ flex-direction: column;
+`;
+
+const Header = styled.div`
+ display: flex;
+ justify-content: flex-end;
+ padding: ${({ theme }) => theme.sizeUnit}px
+ ${({ theme }) => theme.sizeUnit * 2}px;
+ border-bottom: 1px solid ${({ theme }) => theme.colorBorder};
+`;
+
+const CopyButton = styled.button`
+ background: none;
+ border: none;
+ cursor: pointer;
+ padding: ${({ theme }) => theme.sizeUnit / 2}px;
+ color: ${({ theme }) => theme.colorTextSecondary};
+ display: flex;
+ align-items: center;
+ gap: ${({ theme }) => theme.sizeUnit / 2}px;
+ font-size: ${({ theme }) => theme.fontSizeSM}px;
+
+ &:hover {
+ color: ${({ theme }) => theme.colorText};
+ }
+`;
+
+const StackTraceContent = styled.div`
+ overflow: auto;
+ padding: ${({ theme }) => theme.sizeUnit * 2}px;
+ flex: 1;
+`;
+
+const StackTrace = styled.pre`
+ margin: 0;
+ font-size: ${({ theme }) => theme.fontSizeSM}px;
+ white-space: pre-wrap;
+ word-wrap: break-word;
+ font-family: ${({ theme }) => theme.fontFamilyCode};
+`;
+
+const ErrorIconWrapper = styled.span`
+ cursor: pointer;
+ color: ${({ theme }) => theme.colorError};
+
+ &:hover {
+ opacity: 0.8;
+ }
+`;
+
+interface TaskStackTracePopoverProps {
+ stackTrace: string;
+}
+
+export default function TaskStackTracePopover({
+ stackTrace,
+}: TaskStackTracePopoverProps) {
+ const [visible, setVisible] = useState(false);
+ const [copied, setCopied] = useState(false);
+ const { addDangerToast } = useToasts();
+
+ const handleCopy = useCallback(() => {
+ copyTextToClipboard(() => Promise.resolve(stackTrace))
+ .then(() => {
+ setCopied(true);
+ setTimeout(() => setCopied(false), 2000);
+ })
+ .catch(() => {
+ addDangerToast(t('Failed to copy stack trace to clipboard'));
+ });
+ }, [stackTrace, addDangerToast]);
+
+  const content = (
+    <StackTraceContainer>
+      <Header>
+        <CopyButton onClick={handleCopy}>
+          {copied ? (
+            <Icons.CheckOutlined iconSize="s" />
+          ) : (
+            <Icons.CopyOutlined iconSize="s" />
+          )}
+          {t('Copy')}
+        </CopyButton>
+      </Header>
+      <StackTraceContent>
+        <StackTrace>{stackTrace}</StackTrace>
+      </StackTraceContent>
+    </StackTraceContainer>
+  );
+
+  return (
+    <Popover
+      content={content}
+      trigger="hover"
+      open={visible}
+      onOpenChange={setVisible}
+    >
+      <Tooltip title={t('View stack trace')}>
+        <ErrorIconWrapper>
+          <Icons.BugOutlined iconSize="m" />
+        </ErrorIconWrapper>
+      </Tooltip>
+    </Popover>
+  );
+}
diff --git a/superset-frontend/src/features/tasks/TaskStatusIcon.tsx b/superset-frontend/src/features/tasks/TaskStatusIcon.tsx
new file mode 100644
index 000000000000..182afdcbf446
--- /dev/null
+++ b/superset-frontend/src/features/tasks/TaskStatusIcon.tsx
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import React from 'react';
+import { useTheme, SupersetTheme, t } from '@apache-superset/core/ui';
+import { Icons } from '@superset-ui/core/components/Icons';
+import { Tooltip } from '@superset-ui/core/components';
+import { TaskStatus } from './types';
+import { formatProgressTooltip } from './timeUtils';
+
+function getStatusColor(status: TaskStatus, theme: SupersetTheme): string {
+ switch (status) {
+ case TaskStatus.Pending:
+ return theme.colorPrimaryText;
+ case TaskStatus.InProgress:
+ return theme.colorPrimaryText;
+ case TaskStatus.Success:
+ return theme.colorSuccessText;
+ case TaskStatus.Failure:
+ return theme.colorErrorText;
+ case TaskStatus.TimedOut:
+ return theme.colorErrorText;
+ case TaskStatus.Aborting:
+ return theme.colorWarningText;
+ case TaskStatus.Aborted:
+ return theme.colorWarningText;
+ default:
+ return theme.colorText;
+ }
+}
+
+const statusIcons = {
+ [TaskStatus.Pending]: Icons.ClockCircleOutlined,
+ [TaskStatus.InProgress]: Icons.LoadingOutlined,
+ [TaskStatus.Success]: Icons.CheckCircleOutlined,
+ [TaskStatus.Failure]: Icons.CloseCircleOutlined,
+ [TaskStatus.TimedOut]: Icons.ClockCircleOutlined, // Clock to indicate timeout
+ [TaskStatus.Aborting]: Icons.LoadingOutlined, // Spinning to show in-progress abort
+ [TaskStatus.Aborted]: Icons.StopOutlined,
+};
+
+const statusLabels = {
+ [TaskStatus.Pending]: t('Pending'),
+ [TaskStatus.InProgress]: t('In Progress'),
+ [TaskStatus.Success]: t('Success'),
+ [TaskStatus.Failure]: t('Failed'),
+ [TaskStatus.TimedOut]: t('Timed Out'),
+ [TaskStatus.Aborting]: t('Aborting'),
+ [TaskStatus.Aborted]: t('Aborted'),
+};
+
+interface TaskStatusIconProps {
+ status: TaskStatus;
+ progressPercent?: number | null;
+ progressCurrent?: number | null;
+ progressTotal?: number | null;
+ durationSeconds?: number | null;
+ errorMessage?: string | null;
+ exceptionType?: string | null;
+}
+
+export default function TaskStatusIcon({
+ status,
+ progressPercent,
+ progressCurrent,
+ progressTotal,
+ durationSeconds,
+ errorMessage,
+ exceptionType,
+}: TaskStatusIconProps) {
+ const theme = useTheme();
+ const IconComponent = statusIcons[status];
+ const label = statusLabels[status];
+
+ // Build tooltip content based on status
+ let tooltipContent: React.ReactNode;
+ if (status === TaskStatus.InProgress || status === TaskStatus.Aborting) {
+ // Progress tooltip for active tasks (multiline)
+ const lines = formatProgressTooltip(
+ label,
+ progressCurrent,
+ progressTotal,
+ progressPercent,
+ durationSeconds,
+ );
+    tooltipContent = (
+      <>
+        {lines.map((line, index) => (
+          <React.Fragment key={index}>
+            {index > 0 && <br />}
+            {line}
+          </React.Fragment>
+        ))}
+      </>
+    );
+ } else if (
+ (status === TaskStatus.Failure || status === TaskStatus.TimedOut) &&
+ (exceptionType || errorMessage)
+ ) {
+ // Error tooltip for failed/timed out tasks: "Label (ExceptionType): message"
+ if (exceptionType && errorMessage) {
+ tooltipContent = `${label} (${exceptionType}): ${errorMessage}`;
+ } else if (exceptionType) {
+ tooltipContent = `${label} (${exceptionType})`;
+ } else if (errorMessage) {
+ tooltipContent = `${label}: ${errorMessage}`;
+ } else {
+ tooltipContent = label;
+ }
+ } else {
+ tooltipContent = label;
+ }
+
+ // Spin for in-progress and aborting states
+ const shouldSpin =
+ status === TaskStatus.InProgress || status === TaskStatus.Aborting;
+
+  return (
+    <Tooltip title={tooltipContent}>
+      <IconComponent
+        iconColor={getStatusColor(status, theme)}
+        spin={shouldSpin}
+      />
+    </Tooltip>
+  );
+}
diff --git a/superset-frontend/src/features/tasks/timeUtils.test.ts b/superset-frontend/src/features/tasks/timeUtils.test.ts
new file mode 100644
index 000000000000..3a286a4140b7
--- /dev/null
+++ b/superset-frontend/src/features/tasks/timeUtils.test.ts
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import {
+ formatDuration,
+ calculateEta,
+ formatProgressTooltip,
+} from './timeUtils';
+
+test('formatDuration returns null for invalid inputs', () => {
+ expect(formatDuration(null)).toBeNull();
+ expect(formatDuration(undefined)).toBeNull();
+ expect(formatDuration(0)).toBeNull();
+ expect(formatDuration(-5)).toBeNull();
+});
+
+test('formatDuration formats seconds correctly', () => {
+ expect(formatDuration(37.5)).toBe('37s');
+ expect(formatDuration(1)).toBe('1s');
+ expect(formatDuration(30)).toBe('30s');
+});
+
+test('formatDuration formats minutes correctly', () => {
+ expect(formatDuration(60)).toBe('1m');
+ expect(formatDuration(90)).toBe('1m 30s');
+ expect(formatDuration(150)).toBe('2m 30s');
+});
+
+test('formatDuration formats hours correctly', () => {
+ expect(formatDuration(3600)).toBe('1h');
+ expect(formatDuration(3660)).toBe('1h 1m');
+ expect(formatDuration(7200)).toBe('2h');
+});
+
+test('calculateEta returns null for invalid inputs', () => {
+ expect(calculateEta(null, 60)).toBeNull();
+ expect(calculateEta(undefined, 60)).toBeNull();
+ expect(calculateEta(0.5, null)).toBeNull();
+ expect(calculateEta(0.5, undefined)).toBeNull();
+});
+
+test('calculateEta returns null for edge case progress values', () => {
+ // No progress yet
+ expect(calculateEta(0, 60)).toBeNull();
+ // Already complete
+ expect(calculateEta(1, 60)).toBeNull();
+ // Negative progress (invalid)
+ expect(calculateEta(-0.1, 60)).toBeNull();
+ // Over 100% (invalid)
+ expect(calculateEta(1.1, 60)).toBeNull();
+});
+
+test('calculateEta calculates correct remaining time', () => {
+ // 50% done in 60s -> ETA = 60s remaining
+ expect(calculateEta(0.5, 60)).toBe('1m');
+
+ // 30% done in 60s -> remaining = (60/0.3) * 0.7 = 140s
+ expect(calculateEta(0.3, 60)).toBe('2m 20s');
+
+ // 10% done in 10s -> remaining = (10/0.1) * 0.9 = 90s
+ expect(calculateEta(0.1, 10)).toBe('1m 30s');
+
+ // 90% done in 90s -> remaining = (90/0.9) * 0.1 = 10s
+ expect(calculateEta(0.9, 90)).toBe('10s');
+});
+
+test('calculateEta returns null for ETAs over 24 hours', () => {
+ // 0.1% done in 100s -> remaining = (100/0.001) * 0.999 = ~99900s > 86400s
+ expect(calculateEta(0.001, 100)).toBeNull();
+});
+
+test('formatProgressTooltip returns label only when no progress data', () => {
+ expect(formatProgressTooltip('In Progress')).toEqual(['In Progress']);
+ expect(formatProgressTooltip('In Progress', null, null, null, null)).toEqual([
+ 'In Progress',
+ ]);
+});
+
+test('formatProgressTooltip formats count and total correctly', () => {
+ expect(formatProgressTooltip('In Progress', 9, 60)).toEqual([
+ 'In Progress: 9 of 60',
+ ]);
+});
+
+test('formatProgressTooltip formats count only correctly', () => {
+ expect(formatProgressTooltip('In Progress', 42)).toEqual([
+ 'In Progress: 42 processed',
+ ]);
+ expect(formatProgressTooltip('In Progress', 42, null)).toEqual([
+ 'In Progress: 42 processed',
+ ]);
+});
+
+test('formatProgressTooltip formats percentage correctly', () => {
+ expect(formatProgressTooltip('In Progress', null, null, 0.5)).toEqual([
+ 'In Progress: 50%',
+ ]);
+ expect(formatProgressTooltip('In Progress', null, null, 0.333)).toEqual([
+ 'In Progress: 33%',
+ ]);
+});
+
+test('formatProgressTooltip combines count, total, and percentage', () => {
+ expect(formatProgressTooltip('In Progress', 9, 60, 0.15)).toEqual([
+ 'In Progress: 9 of 60 (15%)',
+ ]);
+});
+
+test('formatProgressTooltip includes ETA when duration is provided', () => {
+ // 50% done in 60s -> ETA = 60s = ~1m
+ expect(formatProgressTooltip('In Progress', 30, 60, 0.5, 60)).toEqual([
+ 'In Progress: 30 of 60 (50%)',
+ 'ETA: 1m',
+ ]);
+});
+
+test('formatProgressTooltip works with percentage and ETA only', () => {
+ // 25% done in 30s -> ETA = (30/0.25) * 0.75 = 90s = 1m 30s
+ expect(formatProgressTooltip('In Progress', null, null, 0.25, 30)).toEqual([
+ 'In Progress: 25%',
+ 'ETA: 1m 30s',
+ ]);
+});
+
+test('formatProgressTooltip works with different labels', () => {
+ expect(formatProgressTooltip('Aborting', 5, 10, 0.5)).toEqual([
+ 'Aborting: 5 of 10 (50%)',
+ ]);
+});
diff --git a/superset-frontend/src/features/tasks/timeUtils.ts b/superset-frontend/src/features/tasks/timeUtils.ts
new file mode 100644
index 000000000000..4a52f6f48141
--- /dev/null
+++ b/superset-frontend/src/features/tasks/timeUtils.ts
@@ -0,0 +1,151 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import prettyMs from 'pretty-ms';
+
+/**
+ * Maximum ETA to display (24 hours in seconds).
+ * ETAs beyond this are not shown as they're unreliable.
+ */
+const MAX_ETA_SECONDS = 86400;
+
+/**
+ * Format a duration in seconds to a human-readable string.
+ *
+ * @param seconds - Duration in seconds
+ * @returns Formatted string like "1m 30s" or "2h 15m", or null if invalid
+ */
+export function formatDuration(
+ seconds: number | null | undefined,
+): string | null {
+ if (seconds === null || seconds === undefined || seconds <= 0) {
+ return null;
+ }
+
+ return prettyMs(seconds * 1000, {
+ unitCount: 2,
+ secondsDecimalDigits: 0,
+ keepDecimalsOnWholeSeconds: false,
+ });
+}
+
+/**
+ * Calculate and format estimated time to completion based on progress and elapsed time.
+ *
+ * Uses the formula: ETA = (elapsed / progress) * (1 - progress)
+ * For example, if 30% done in 60s, remaining = (60/0.3) * 0.7 = 140s
+ *
+ * @param progressPercent - Progress as a fraction (0.0 to 1.0)
+ * @param durationSeconds - Time elapsed so far in seconds
+ * @returns Formatted ETA string or null if cannot be calculated
+ */
+export function calculateEta(
+ progressPercent: number | null | undefined,
+ durationSeconds: number | null | undefined,
+): string | null {
+ // Need both progress and duration to calculate ETA
+ if (
+ progressPercent === null ||
+ progressPercent === undefined ||
+ durationSeconds === null ||
+ durationSeconds === undefined
+ ) {
+ return null;
+ }
+
+ // Can't calculate ETA if no progress yet or already complete
+ if (progressPercent <= 0 || progressPercent >= 1) {
+ return null;
+ }
+
+ // ETA = (elapsed / progress) * (1 - progress)
+ const estimatedTotalTime = durationSeconds / progressPercent;
+ const remainingSeconds = estimatedTotalTime * (1 - progressPercent);
+
+ // Only show ETA if it's reasonable (less than 24 hours)
+ if (remainingSeconds <= 0 || remainingSeconds > MAX_ETA_SECONDS) {
+ return null;
+ }
+
+ // Use unitCount: 2 to show up to 2 units (e.g., "1m 30s" instead of just "1m")
+ // Use secondsDecimalDigits: 0 to show whole seconds (e.g., "52s" instead of "52.4s")
+ return prettyMs(remainingSeconds * 1000, {
+ unitCount: 2,
+ secondsDecimalDigits: 0,
+ });
+}
+
+/**
+ * Build a progress display for task status tooltips.
+ *
+ * Returns an array of lines for proper multiline tooltip rendering:
+ * - ["In Progress: 9 of 60 (15%)", "ETA: 51s"]
+ * - ["In Progress: 42 processed"]
+ * - ["In Progress: 50%"]
+ * - ["In Progress: 50%", "ETA: 2m"]
+ *
+ * @param label - Status label (e.g., "In Progress", "Aborting")
+ * @param progressCurrent - Current count of items processed
+ * @param progressTotal - Total count of items to process
+ * @param progressPercent - Progress as a fraction (0.0 to 1.0)
+ * @param durationSeconds - Time elapsed so far in seconds (used for ETA calculation)
+ * @returns Array of lines for tooltip display
+ */
+export function formatProgressTooltip(
+ label: string,
+ progressCurrent?: number | null,
+ progressTotal?: number | null,
+ progressPercent?: number | null,
+ durationSeconds?: number | null,
+): string[] {
+ const lines: string[] = [];
+ let progressPart = '';
+
+ // Build progress part
+ if (progressCurrent !== null && progressCurrent !== undefined) {
+ if (progressTotal !== null && progressTotal !== undefined) {
+ // Count and total with percentage: "3 of 278 (15%)"
+ progressPart = `${progressCurrent} of ${progressTotal}`;
+ if (progressPercent !== null && progressPercent !== undefined) {
+ progressPart += ` (${Math.round(progressPercent * 100)}%)`;
+ }
+ } else {
+ // Count only: "3 processed"
+ progressPart = `${progressCurrent} processed`;
+ }
+ } else if (progressPercent !== null && progressPercent !== undefined) {
+ // Percentage only: "50%"
+ progressPart = `${Math.round(progressPercent * 100)}%`;
+ }
+
+ // Add the main progress line
+ if (progressPart) {
+ lines.push(`${label}: ${progressPart}`);
+ } else {
+ lines.push(label);
+ }
+
+ // Add ETA on a separate line if available
+ const eta = calculateEta(progressPercent, durationSeconds);
+ if (eta) {
+ lines.push(`ETA: ${eta}`);
+ }
+
+ return lines;
+}
diff --git a/superset-frontend/src/features/tasks/types.ts b/superset-frontend/src/features/tasks/types.ts
new file mode 100644
index 000000000000..27d129cb8563
--- /dev/null
+++ b/superset-frontend/src/features/tasks/types.ts
@@ -0,0 +1,115 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+export interface TaskSubscriber {
+ user_id: number;
+ first_name: string;
+ last_name: string;
+ subscribed_at: string;
+}
+
+export enum TaskScope {
+ Private = 'private',
+ Shared = 'shared',
+ System = 'system',
+}
+
+/**
+ * Task properties - runtime state and execution config stored in JSON blob.
+ */
+export interface TaskProperties {
+ // Execution config - set at task creation
+ execution_mode: 'async' | 'sync' | null;
+ timeout: number | null;
+
+ // Runtime state - set by framework during execution
+ is_abortable: boolean | null;
+ progress_percent: number | null;
+ progress_current: number | null;
+ progress_total: number | null;
+
+ // Error info - set when task fails
+ error_message: string | null;
+ exception_type: string | null;
+ stack_trace: string | null;
+}
+
+export interface Task {
+ id: number;
+ uuid: string;
+ task_key: string;
+ task_type: string;
+ task_name: string | null;
+ status:
+ | 'pending'
+ | 'in_progress'
+ | 'success'
+ | 'failure'
+ | 'aborting'
+ | 'aborted'
+ | 'timed_out';
+ scope: TaskScope;
+ created_on: string;
+ created_on_delta_humanized?: string;
+ changed_on: string;
+ started_at: string | null;
+ ended_at: string | null;
+ created_by: {
+ id: number;
+ first_name: string;
+ last_name: string;
+ } | null;
+ changed_by?: {
+ first_name: string;
+ last_name: string;
+ } | null;
+ user_id: number | null;
+  payload: Record<string, any>;
+ properties: TaskProperties;
+ duration_seconds: number | null;
+ subscriber_count: number;
+ subscribers: TaskSubscriber[];
+}
+
+// Derived status helpers (frontend computes these from status and properties)
+export function isTaskFinished(task: Task): boolean {
+ return ['success', 'failure', 'aborted', 'timed_out'].includes(task.status);
+}
+
+export function isTaskAborting(task: Task): boolean {
+ return task.status === 'aborting';
+}
+
+export function canAbortTask(task: Task): boolean {
+ if (task.status === 'pending') return true;
+ if (task.status === 'in_progress' && task.properties.is_abortable === true)
+ return true;
+ if (task.status === 'aborting') return true; // Idempotent
+ return false;
+}
+
+export enum TaskStatus {
+ Pending = 'pending',
+ InProgress = 'in_progress',
+ Success = 'success',
+ Failure = 'failure',
+ Aborting = 'aborting',
+ Aborted = 'aborted',
+ TimedOut = 'timed_out',
+}
diff --git a/superset-frontend/src/pages/TaskList/TaskList.test.tsx b/superset-frontend/src/pages/TaskList/TaskList.test.tsx
new file mode 100644
index 000000000000..f7b85c42e749
--- /dev/null
+++ b/superset-frontend/src/pages/TaskList/TaskList.test.tsx
@@ -0,0 +1,328 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import { MemoryRouter } from 'react-router-dom';
+import fetchMock from 'fetch-mock';
+import {
+ render,
+ screen,
+ waitFor,
+ fireEvent,
+} from 'spec/helpers/testing-library';
+import { QueryParamProvider } from 'use-query-params';
+import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
+import { TaskStatus, TaskScope } from 'src/features/tasks/types';
+import TaskList from 'src/pages/TaskList';
+
+// Enable the GLOBAL_TASK_FRAMEWORK feature flag for the components under test
+window.featureFlags = { GLOBAL_TASK_FRAMEWORK: true };
+
+// Mock getBootstrapData before importing components that use it
+jest.mock('src/utils/getBootstrapData', () => ({
+ __esModule: true,
+ default: () => ({
+ user: {
+ userId: 1,
+ firstName: 'admin',
+ lastName: 'user',
+ roles: { Admin: [] },
+ },
+ common: {
+ feature_flags: { GLOBAL_TASK_FRAMEWORK: true },
+ conf: {},
+ },
+ }),
+}));
+
+const tasksInfoEndpoint = 'glob:*/api/v1/task/_info*';
+const tasksCreatedByEndpoint = 'glob:*/api/v1/task/related/created_by*';
+const tasksEndpoint = 'glob:*/api/v1/task/?*';
+const taskCancelEndpoint = 'glob:*/api/v1/task/*/cancel';
+
+const mockTasks = [
+ {
+ id: 1,
+ uuid: 'task-uuid-1',
+ task_key: 'test_task_1',
+ task_type: 'data_export',
+ task_name: 'Export Data Task',
+ status: TaskStatus.Success,
+ scope: TaskScope.Private,
+ created_on: '2024-01-15T10:00:00Z',
+ changed_on: '2024-01-15T10:05:00Z',
+ created_on_delta_humanized: '5 minutes ago',
+ started_at: '2024-01-15T10:00:01Z',
+ ended_at: '2024-01-15T10:05:00Z',
+ created_by: { id: 1, first_name: 'admin', last_name: 'user' },
+ user_id: 1,
+ payload: {},
+ duration_seconds: 299,
+ subscriber_count: 0,
+ subscribers: [],
+ properties: {
+ is_abortable: null,
+ progress_percent: 1.0,
+ progress_current: null,
+ progress_total: null,
+ error_message: null,
+ exception_type: null,
+ stack_trace: null,
+ timeout: null,
+ },
+ },
+ {
+ id: 2,
+ uuid: 'task-uuid-2',
+ task_key: 'test_task_2',
+ task_type: 'report_generation',
+ task_name: null,
+ status: TaskStatus.InProgress,
+ scope: TaskScope.Private,
+ created_on: '2024-01-15T11:00:00Z',
+ changed_on: '2024-01-15T11:00:00Z',
+ created_on_delta_humanized: '1 minute ago',
+ started_at: '2024-01-15T11:00:01Z',
+ ended_at: null,
+ created_by: { id: 1, first_name: 'admin', last_name: 'user' },
+ user_id: 1,
+ payload: { report_id: 42 },
+ duration_seconds: null,
+ subscriber_count: 0,
+ subscribers: [],
+ properties: {
+ is_abortable: true,
+ progress_percent: 0.5,
+ progress_current: null,
+ progress_total: null,
+ error_message: null,
+ exception_type: null,
+ stack_trace: null,
+ timeout: null,
+ },
+ },
+ {
+ id: 3,
+ uuid: 'task-uuid-3',
+ task_key: 'shared_task_1',
+ task_type: 'bulk_operation',
+ task_name: 'Shared Bulk Task',
+ status: TaskStatus.Pending,
+ scope: TaskScope.Shared,
+ created_on: '2024-01-15T12:00:00Z',
+ changed_on: '2024-01-15T12:00:00Z',
+ created_on_delta_humanized: 'just now',
+ started_at: null,
+ ended_at: null,
+ created_by: { id: 2, first_name: 'other', last_name: 'user' },
+ user_id: 2,
+ payload: {},
+ duration_seconds: null,
+ subscriber_count: 2,
+ subscribers: [
+ {
+ user_id: 1,
+ first_name: 'admin',
+ last_name: 'user',
+ subscribed_at: '2024-01-15T12:00:00Z',
+ },
+ {
+ user_id: 2,
+ first_name: 'other',
+ last_name: 'user',
+ subscribed_at: '2024-01-15T12:00:01Z',
+ },
+ ],
+ properties: {
+ is_abortable: null,
+ progress_percent: null,
+ progress_current: null,
+ progress_total: null,
+ error_message: null,
+ exception_type: null,
+ stack_trace: null,
+ timeout: null,
+ },
+ },
+];
+
+const mockUser = {
+ userId: 1,
+ firstName: 'admin',
+ lastName: 'user',
+};
+
+fetchMock.get(
+ tasksInfoEndpoint,
+ { permissions: ['can_read', 'can_write'] },
+ { name: tasksInfoEndpoint },
+);
+fetchMock.get(
+ tasksCreatedByEndpoint,
+ { result: [] },
+ { name: tasksCreatedByEndpoint },
+);
+fetchMock.get(
+ tasksEndpoint,
+ { result: mockTasks, count: 3 },
+ { name: tasksEndpoint },
+);
+fetchMock.post(
+ taskCancelEndpoint,
+ { action: 'aborted', message: 'Task cancelled' },
+ { name: taskCancelEndpoint },
+);
+
+const renderTaskList = (props = {}, userProp = mockUser) =>
+  render(
+    <MemoryRouter>
+      <QueryParamProvider adapter={ReactRouter5Adapter}>
+        <TaskList user={userProp} {...props} />
+      </QueryParamProvider>
+    </MemoryRouter>,
+    { useRedux: true },
+  );
+
+beforeEach(() => {
+ fetchMock.clearHistory();
+});
+
+test('renders TaskList with title, ListView, and fetches data from endpoints', async () => {
+ renderTaskList();
+
+ // Wait for data to load and verify title
+ expect(await screen.findByText('Tasks')).toBeInTheDocument();
+ expect(screen.getByTestId('task-list-view')).toBeInTheDocument();
+
+ // Verify API calls were made
+ expect(fetchMock.callHistory.calls(/task\/_info/).length).toBe(1);
+ expect(fetchMock.callHistory.calls(/task\/\?q/).length).toBe(1);
+});
+
+test('displays task data including types, scope labels, and duration', async () => {
+ renderTaskList();
+
+ // Wait for data to load
+ await screen.findByText('Export Data Task');
+
+ // Task types
+ expect(screen.getByText('data_export')).toBeInTheDocument();
+ expect(screen.getByText('report_generation')).toBeInTheDocument();
+
+ // Scope labels
+ expect(screen.getAllByText('Private').length).toBeGreaterThan(0);
+ expect(screen.getByText('Shared')).toBeInTheDocument();
+
+ // Duration (299s = 4m 59s via prettyMs)
+ expect(screen.getByText('4m 59s')).toBeInTheDocument();
+});
+
+test('shows cancel button and modal for cancellable tasks', async () => {
+ renderTaskList();
+
+ // Wait for data to load
+ await screen.findByText('test_task_2');
+
+ // Cancel buttons exist for in-progress and shared tasks
+ const stopIcons = screen.getAllByRole('img', { name: 'stop' });
+ expect(stopIcons.length).toBeGreaterThan(0);
+
+ // Click a cancel button to show confirmation modal
+ const cancelButton = stopIcons.find(
+ icon => icon.closest('[role="button"]') !== null,
+ );
+ expect(cancelButton).toBeDefined();
+ fireEvent.click(cancelButton!);
+
+ expect(await screen.findByText('Cancel Task')).toBeInTheDocument();
+});
+
+test('does not show cancel button for completed shared tasks', async () => {
+ const completedSharedTask = {
+ id: 4,
+ uuid: 'task-uuid-4',
+ task_key: 'completed_shared_task',
+ task_type: 'bulk_operation',
+ task_name: 'Completed Shared Task',
+ status: TaskStatus.Success,
+ scope: TaskScope.Shared,
+ created_on: '2024-01-15T12:00:00Z',
+ changed_on: '2024-01-15T12:05:00Z',
+ created_on_delta_humanized: '5 minutes ago',
+ started_at: '2024-01-15T12:00:01Z',
+ ended_at: '2024-01-15T12:05:00Z',
+ created_by: { id: 2, first_name: 'other', last_name: 'user' },
+ user_id: 2,
+ payload: {},
+ duration_seconds: 299,
+ subscriber_count: 1,
+ subscribers: [
+ {
+ user_id: 1,
+ first_name: 'admin',
+ last_name: 'user',
+ subscribed_at: '2024-01-15T12:00:00Z',
+ },
+ ],
+ properties: {
+ is_abortable: null,
+ progress_percent: 1.0,
+ progress_current: null,
+ progress_total: null,
+ error_message: null,
+ exception_type: null,
+ stack_trace: null,
+ timeout: null,
+ },
+ };
+
+ fetchMock.modifyRoute(tasksEndpoint, {
+ response: { result: [completedSharedTask], count: 1 },
+ });
+
+ renderTaskList();
+ await screen.findByText('Completed Shared Task');
+
+ // No action buttons with stop icons for completed tasks
+ const stopIcons = screen.queryAllByRole('img', { name: 'stop' });
+ const actionButtons = stopIcons.filter(
+ icon => icon.closest('[role="button"]') !== null,
+ );
+ expect(actionButtons).toHaveLength(0);
+
+ // Restore mock
+ fetchMock.modifyRoute(tasksEndpoint, {
+ response: { result: mockTasks, count: 3 },
+ });
+});
+
+test('displays empty state when no tasks', async () => {
+ fetchMock.modifyRoute(tasksEndpoint, {
+ response: { result: [], count: 0 },
+ });
+
+ renderTaskList();
+
+ await waitFor(() => {
+ expect(screen.getByText('No tasks yet')).toBeInTheDocument();
+ });
+
+ // Restore mock
+ fetchMock.modifyRoute(tasksEndpoint, {
+ response: { result: mockTasks, count: 3 },
+ });
+});
diff --git a/superset-frontend/src/pages/TaskList/index.tsx b/superset-frontend/src/pages/TaskList/index.tsx
new file mode 100644
index 000000000000..31590001eaef
--- /dev/null
+++ b/superset-frontend/src/pages/TaskList/index.tsx
@@ -0,0 +1,658 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import {
+ FeatureFlag,
+ isFeatureEnabled,
+ SupersetClient,
+} from '@superset-ui/core';
+import { t, useTheme } from '@apache-superset/core';
+import { useMemo, useCallback, useState } from 'react';
+import { Tooltip, Label, Modal, Checkbox } from '@superset-ui/core/components';
+import {
+ CreatedInfo,
+ ListView,
+ ListViewFilterOperator as FilterOperator,
+ type ListViewFilters,
+ FacePile,
+} from 'src/components';
+import { Icons } from '@superset-ui/core/components/Icons';
+import withToasts from 'src/components/MessageToasts/withToasts';
+import SubMenu from 'src/features/home/SubMenu';
+import { useListViewResource } from 'src/views/CRUD/hooks';
+import { createErrorHandler, createFetchRelated } from 'src/views/CRUD/utils';
+import TaskStatusIcon from 'src/features/tasks/TaskStatusIcon';
+import TaskPayloadPopover from 'src/features/tasks/TaskPayloadPopover';
+import TaskStackTracePopover from 'src/features/tasks/TaskStackTracePopover';
+import { formatDuration } from 'src/features/tasks/timeUtils';
+import {
+ Task,
+ TaskStatus,
+ TaskScope,
+ canAbortTask,
+ isTaskAborting,
+ TaskSubscriber,
+} from 'src/features/tasks/types';
+import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
+import getBootstrapData from 'src/utils/getBootstrapData';
+
+const PAGE_SIZE = 25;
+
+/**
+ * Typed cell props for react-table columns.
+ * Replaces `: any` for better type safety in Cell render functions.
+ */
+interface TaskCellProps {
+ row: {
+ original: Task;
+ };
+}
+
+interface TaskListProps {
+ addDangerToast: (msg: string) => void;
+ addSuccessToast: (msg: string) => void;
+ user: {
+ userId: string | number;
+ firstName: string;
+ lastName: string;
+ };
+}
+
+function TaskList({ addDangerToast, addSuccessToast, user }: TaskListProps) {
+ const theme = useTheme();
+
+ // Check if GTF feature flag is enabled
+ if (!isFeatureEnabled(FeatureFlag.GlobalTaskFramework)) {
+    return (
+      <>
+        <SubMenu name={t('Tasks')} />
+        <div>
+          <h3>{t('Feature Not Enabled')}</h3>
+          <p>
+            {t(
+              'The Global Task Framework is not enabled. Please contact your administrator to enable the GLOBAL_TASK_FRAMEWORK feature flag.',
+            )}
+          </p>
+        </div>
+      </>
+    );
+ }
+
+ const {
+ state: { loading, resourceCount: tasksCount, resourceCollection: tasks },
+ fetchData,
+ refreshData,
+  } = useListViewResource<Task>('task', t('task'), addDangerToast);
+
+ // Get full user with roles to check admin status
+ const bootstrapData = getBootstrapData();
+ const fullUser = bootstrapData?.user;
+ const isAdmin = useMemo(() => isUserAdmin(fullUser), [fullUser]);
+
+ // State for cancel confirmation modal
+  const [cancelModalTask, setCancelModalTask] = useState<Task | null>(null);
+ const [forceCancel, setForceCancel] = useState(false);
+
+ // Determine dialog message based on task context
+ const getCancelDialogMessage = useCallback((task: Task) => {
+ const isSharedTask = task.scope === TaskScope.Shared;
+ const subscriberCount = task.subscriber_count || 0;
+ const otherSubscribers = subscriberCount - 1;
+
+ // If it's going to abort (private, system, or last subscriber)
+ if (!isSharedTask || subscriberCount <= 1) {
+ return t('This will cancel the task.');
+ }
+
+ // Shared task with multiple subscribers
+ return t(
+ "You'll be removed from this task. It will continue running for %s other subscriber(s).",
+ otherSubscribers,
+ );
+ }, []);
+
+ // Get force abort message for admin checkbox
+ const getForceAbortMessage = useCallback((task: Task) => {
+ const subscriberCount = task.subscriber_count || 0;
+ return t(
+ 'This will abort (stop) the task for all %s subscriber(s).',
+ subscriberCount,
+ );
+ }, []);
+
+ // Check if current user is subscribed to a task
+ const isUserSubscribed = useCallback(
+ (task: Task) =>
+ task.subscribers?.some(
+ (sub: TaskSubscriber) => sub.user_id === user.userId,
+ ) ?? false,
+ [user.userId],
+ );
+
+ // Check if force cancel option should be shown (for admins on shared tasks)
+ const showForceCancelOption = useCallback(
+ (task: Task) => {
+ const isSharedTask = task.scope === TaskScope.Shared;
+ const subscriberCount = task.subscriber_count || 0;
+ const userSubscribed = isUserSubscribed(task);
+ // Show for admins on shared tasks when:
+ // - Not subscribed (can only abort, so show checkbox pre-checked disabled), OR
+ // - Multiple subscribers (can choose between unsubscribe and force abort)
+ // Don't show when admin is the sole subscriber - cancel will abort anyway
+ return (
+ isAdmin && isSharedTask && (subscriberCount > 1 || !userSubscribed)
+ );
+ },
+ [isAdmin, isUserSubscribed],
+ );
+
+ // Check if force cancel checkbox should be disabled (admin not subscribed)
+ const isForceCancelDisabled = useCallback(
+ (task: Task) => isAdmin && !isUserSubscribed(task),
+ [isAdmin, isUserSubscribed],
+ );
+
+ const handleTaskCancel = useCallback(
+ (task: Task, force: boolean = false) => {
+ SupersetClient.post({
+ endpoint: `/api/v1/task/${task.uuid}/cancel`,
+ jsonPayload: force ? { force: true } : {},
+ }).then(
+ ({ json }) => {
+ refreshData();
+ const { action } = json as { action: string };
+ if (action === 'aborted') {
+ addSuccessToast(
+ t('Task cancelled: %s', task.task_name || task.task_key),
+ );
+ } else {
+ addSuccessToast(
+ t(
+ 'You have been removed from task: %s',
+ task.task_name || task.task_key,
+ ),
+ );
+ }
+ },
+ createErrorHandler(errMsg =>
+ addDangerToast(
+ t('There was an issue cancelling the task: %s', errMsg),
+ ),
+ ),
+ );
+ },
+ [addDangerToast, addSuccessToast, refreshData],
+ );
+
+ // Handle opening the cancel modal - set initial forceCancel state
+ const openCancelModal = useCallback(
+ (task: Task) => {
+ // Pre-check force cancel if admin is not subscribed
+ const shouldPreCheck = isAdmin && !isUserSubscribed(task);
+ setForceCancel(shouldPreCheck);
+ setCancelModalTask(task);
+ },
+ [isAdmin, isUserSubscribed],
+ );
+
+ // Handle modal confirmation
+ const handleCancelConfirm = useCallback(() => {
+ if (cancelModalTask) {
+ handleTaskCancel(cancelModalTask, forceCancel);
+ setCancelModalTask(null);
+ setForceCancel(false);
+ }
+ }, [cancelModalTask, forceCancel, handleTaskCancel]);
+
+ // Handle modal close
+ const handleCancelModalClose = useCallback(() => {
+ setCancelModalTask(null);
+ setForceCancel(false);
+ }, []);
+
+ const columns = useMemo(
+ () => [
+ {
+ Cell: ({
+ row: {
+ original: { task_name, task_key, uuid },
+ },
+ }: TaskCellProps) => {
+ // Display preference: task_name > task_key
+ const displayText = task_name || task_key;
+ const truncated =
+ displayText.length > 30
+ ? `${displayText.slice(0, 30)}...`
+ : displayText;
+
+ // Build tooltip with all identifiers
+ const tooltipLines = [];
+ if (task_name) tooltipLines.push(`Name: ${task_name}`);
+ tooltipLines.push(`Key: ${task_key}`);
+ tooltipLines.push(`UUID: ${uuid}`);
+ const tooltipText = tooltipLines.join('\n');
+
+          return (
+            <Tooltip
+              title={
+                <span style={{ whiteSpace: 'pre-line' }}>{tooltipText}</span>
+              }
+              placement="top"
+            >
+              <span>{truncated}</span>
+            </Tooltip>
+          );
+ },
+ accessor: 'task_name',
+ Header: t('Task'),
+ size: 'xl',
+ id: 'task',
+ },
+ {
+ Cell: ({
+ row: {
+ original: { status, properties, duration_seconds },
+ },
+ }: TaskCellProps) => (
+          <TaskStatusIcon
+            status={status}
+            properties={properties}
+            durationSeconds={duration_seconds}
+          />
+ ),
+ accessor: 'status',
+ Header: t('Status'),
+ size: 'xs',
+ id: 'status',
+ },
+ {
+ accessor: 'task_type',
+ Header: t('Type'),
+ size: 'md',
+ id: 'task_type',
+ },
+ {
+ Cell: ({
+ row: {
+ original: { scope },
+ },
+ }: TaskCellProps) => {
+ const scopeConfig: Record<
+ TaskScope,
+ { label: string; type: 'default' | 'info' | 'warning' }
+ > = {
+ [TaskScope.Private]: { label: t('Private'), type: 'default' },
+ [TaskScope.Shared]: { label: t('Shared'), type: 'info' },
+ [TaskScope.System]: { label: t('System'), type: 'warning' },
+ };
+
+ const config = scopeConfig[scope as TaskScope] || {
+ label: scope,
+ type: 'default' as const,
+ };
+
+          return <Label type={config.type}>{config.label}</Label>;
+ },
+ accessor: 'scope',
+ Header: t('Scope'),
+ size: 'sm',
+ id: 'scope',
+ },
+ {
+ Cell: ({
+ row: {
+ original: { subscriber_count, subscribers },
+ },
+ }: TaskCellProps) => {
+ if (!subscribers || subscriber_count === 0) {
+ return '-';
+ }
+
+ // Convert subscribers to FacePile format
+ const users = subscribers.map((sub: TaskSubscriber) => ({
+ id: sub.user_id,
+ first_name: sub.first_name,
+ last_name: sub.last_name,
+ }));
+
+          return <FacePile users={users} />;
+ },
+ accessor: 'subscriber_count',
+ Header: t('Subscribers'),
+ size: 'md',
+ id: 'subscribers',
+ disableSortBy: true,
+ },
+ {
+ Cell: ({
+ row: {
+ original: {
+ created_on_delta_humanized: createdOn,
+ created_by: createdBy,
+ },
+ },
+ }: TaskCellProps) => (
+          <CreatedInfo createdOn={createdOn} createdBy={createdBy} />
+ ),
+ Header: t('Created'),
+ accessor: 'created_on',
+ size: 'xl',
+ id: 'created_on',
+ },
+ {
+ // Hidden column for filtering by created_by
+ accessor: 'created_by',
+ id: 'created_by',
+ hidden: true,
+ },
+ {
+ Cell: ({
+ row: {
+ original: { duration_seconds },
+ },
+ }: TaskCellProps) => formatDuration(duration_seconds) ?? '-',
+ accessor: 'duration_seconds',
+ Header: t('Duration'),
+ size: 'sm',
+ id: 'duration_seconds',
+ disableSortBy: true,
+ },
+ {
+ Cell: ({
+ row: {
+ original: { payload, properties, status },
+ },
+ }: TaskCellProps) => {
+ const hasPayload = payload && Object.keys(payload).length > 0;
+ const hasStackTrace = !!properties?.stack_trace;
+
+ // Show warning if timeout is set but no abort handler during execution
+ // Only show for IN_PROGRESS (abort handler registers at runtime, not during PENDING)
+ const hasTimeoutWithoutHandler =
+ status === TaskStatus.InProgress &&
+ properties?.timeout &&
+ !properties?.is_abortable;
+
+ if (!hasPayload && !hasStackTrace && !hasTimeoutWithoutHandler) {
+ return null;
+ }
+
+          return (
+            <>
+              {hasTimeoutWithoutHandler && (
+                <Tooltip
+                  title={t('Task has a timeout but no abort handler registered')}
+                  placement="top"
+                >
+                  <Icons.WarningOutlined iconColor={theme.colorWarning} />
+                </Tooltip>
+              )}
+              {hasPayload && <TaskPayloadPopover payload={payload} />}
+              {hasStackTrace && properties.stack_trace && (
+                <TaskStackTracePopover stackTrace={properties.stack_trace} />
+              )}
+            </>
+          );
+ },
+ accessor: 'payload',
+ Header: t('Details'),
+ size: 'xs',
+ id: 'payload',
+ disableSortBy: true,
+ },
+ {
+ Cell: ({ row: { original } }: TaskCellProps) => {
+ // Unified Cancel button logic:
+ // - Show Cancel for any active task that the user can cancel
+ // - The backend handles the smart behavior (unsubscribe vs abort)
+ const isRunning = original.status === TaskStatus.InProgress;
+ // Task is not cancellable if running without abort handler
+          // The truthiness check catches false, undefined, and null
+ const isRunningButNotCancellable =
+ isRunning && !original.properties?.is_abortable;
+
+ const isSharedTask = original.scope === TaskScope.Shared;
+ const userIsSubscribed = original.subscribers?.some(
+            (sub: TaskSubscriber) => sub.user_id === user.userId,
+ );
+
+ // Check if task is in a non-active state (completed or aborting)
+ const isNonActiveStatus = [
+ TaskStatus.Success,
+ TaskStatus.Failure,
+ TaskStatus.Aborted,
+ TaskStatus.Aborting,
+ TaskStatus.TimedOut,
+ ].includes(original.status as TaskStatus);
+
+ // Show disabled button for running tasks without abort handler
+ // (only for non-shared tasks or when user is the only subscriber)
+ const showDisabledCancel =
+ isRunningButNotCancellable &&
+ !isNonActiveStatus &&
+ (!isSharedTask || (original.subscriber_count || 0) <= 1);
+
+ // Show Cancel button when:
+ // 1. Task can be aborted (pending, or in-progress with handler), OR
+ // 2. User is subscribed to a shared task (can always unsubscribe)
+ // But NOT when disabled cancel is shown (mutually exclusive)
+ const canCancelTask =
+ !showDisabledCancel &&
+ ((canAbortTask(original) && !isTaskAborting(original)) ||
+ (isSharedTask && userIsSubscribed && !isNonActiveStatus));
+
+ if (!canCancelTask && !showDisabledCancel) {
+ return null;
+ }
+
+          return (
+            <span className="actions">
+              {showDisabledCancel && (
+                <Tooltip
+                  title={t('This task is running without an abort handler and cannot be cancelled')}
+                  placement="top"
+                >
+                  <span className="action-button disabled">
+                    <Icons.StopOutlined />
+                  </span>
+                </Tooltip>
+              )}
+              {canCancelTask && (
+                <Tooltip title={t('Cancel task')} placement="top">
+                  <span
+                    role="button"
+                    tabIndex={0}
+                    className="action-button"
+                    onClick={() => openCancelModal(original)}
+                  >
+                    <Icons.StopOutlined />
+                  </span>
+                </Tooltip>
+              )}
+            </span>
+          );
+ },
+ Header: t('Actions'),
+ id: 'actions',
+ size: 'sm',
+ disableSortBy: true,
+ },
+ ],
+ [user.userId, theme, openCancelModal],
+ );
+
+ const filters: ListViewFilters = useMemo(
+ () => [
+ {
+ Header: t('Status'),
+ key: 'status',
+ id: 'status',
+ input: 'select',
+ operator: FilterOperator.Equals,
+ unfilteredLabel: t('Any'),
+ selects: [
+ { label: t('Pending'), value: TaskStatus.Pending },
+ { label: t('In Progress'), value: TaskStatus.InProgress },
+ { label: t('Success'), value: TaskStatus.Success },
+ { label: t('Failed'), value: TaskStatus.Failure },
+ { label: t('Timed Out'), value: TaskStatus.TimedOut },
+ { label: t('Aborting'), value: TaskStatus.Aborting },
+ { label: t('Aborted'), value: TaskStatus.Aborted },
+ ],
+ },
+ {
+ Header: t('Type'),
+ key: 'task_type',
+ id: 'task_type',
+ input: 'search',
+ operator: FilterOperator.Contains,
+ },
+ {
+ Header: t('Scope'),
+ key: 'scope',
+ id: 'scope',
+ input: 'select',
+ operator: FilterOperator.Equals,
+ unfilteredLabel: t('Any'),
+ selects: [
+ { label: t('Private'), value: TaskScope.Private },
+ { label: t('Shared'), value: TaskScope.Shared },
+ { label: t('System'), value: TaskScope.System },
+ ],
+ },
+ {
+ Header: t('Created by'),
+ key: 'created_by',
+ id: 'created_by',
+ input: 'select',
+ operator: FilterOperator.RelationOneMany,
+ unfilteredLabel: t('All'),
+ fetchSelects: createFetchRelated(
+ 'task',
+ 'created_by',
+ createErrorHandler(errMsg =>
+ addDangerToast(
+ t(
+ 'An error occurred while fetching created by values: %s',
+ errMsg,
+ ),
+ ),
+ ),
+ ),
+ },
+ ],
+ [addDangerToast],
+ );
+
+ const initialSort = [{ id: 'created_on', desc: true }];
+
+ const emptyState = {
+ title: t('No tasks yet'),
+ image: 'filter-results.svg',
+ description: t(
+ 'Tasks will appear here as background operations are executed.',
+ ),
+ };
+
+  return (
+    <>
+      <SubMenu name={t('Tasks')} />
+      <ListView<Task>
+        className="task-list-view"
+        columns={columns}
+        count={tasksCount}
+        data={tasks}
+        emptyState={emptyState}
+        fetchData={fetchData}
+        filters={filters}
+        initialSort={initialSort}
+        loading={loading}
+        pageSize={PAGE_SIZE}
+        refreshData={refreshData}
+        addDangerToast={addDangerToast}
+        addSuccessToast={addSuccessToast}
+      />
+
+      {/* Cancel Confirmation Modal */}
+      <Modal
+        title={t('Cancel Task')}
+        show={!!cancelModalTask}
+        onHide={handleCancelModalClose}
+        onHandledPrimaryAction={handleCancelConfirm}
+        primaryButtonName={t('Confirm')}
+      >
+        {cancelModalTask && (
+          <>
+            <p>
+              {forceCancel
+                ? getForceAbortMessage(cancelModalTask)
+                : getCancelDialogMessage(cancelModalTask)}
+            </p>
+            {showForceCancelOption(cancelModalTask) && (
+              <Checkbox
+                checked={forceCancel}
+                onChange={e => setForceCancel(e.target.checked)}
+                disabled={isForceCancelDisabled(cancelModalTask)}
+              >
+                {t('Force abort (stops task for all subscribers)')}
+              </Checkbox>
+            )}
+          </>
+        )}
+      </Modal>
+    </>
+  );
+}
+
+export default withToasts(TaskList);
diff --git a/superset-frontend/src/views/routes.tsx b/superset-frontend/src/views/routes.tsx
index a961c9fb7fcb..674fbc5846c9 100644
--- a/superset-frontend/src/views/routes.tsx
+++ b/superset-frontend/src/views/routes.tsx
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
+
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
import {
lazy,
@@ -138,6 +139,10 @@ const RowLevelSecurityList = lazy(
),
);
+const TaskList = lazy(
+ () => import(/* webpackChunkName: "TaskList" */ 'src/pages/TaskList'),
+);
+
const RolesList = lazy(
() => import(/* webpackChunkName: "RolesList" */ 'src/pages/RolesList'),
);
@@ -297,6 +302,10 @@ export const routes: Routes = [
path: '/rowlevelsecurity/list',
Component: RowLevelSecurityList,
},
+ {
+ path: '/tasks/list/',
+ Component: TaskList,
+ },
{
path: '/sqllab/',
Component: SqlLab,
diff --git a/superset/async_events/cache_backend.py b/superset/async_events/cache_backend.py
index 9158e2d119ad..65130a3dcf9c 100644
--- a/superset/async_events/cache_backend.py
+++ b/superset/async_events/cache_backend.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from __future__ import annotations
+
+from typing import Any
import redis
from flask_caching.backends.rediscache import RedisCache, RedisSentinelCache
@@ -28,15 +30,15 @@ def __init__( # pylint: disable=too-many-arguments
self,
host: str,
port: int,
- password: Optional[str] = None,
+ password: str | None = None,
db: int = 0,
default_timeout: int = 300,
- key_prefix: Optional[str] = None,
+ key_prefix: str | None = None,
ssl: bool = False,
- ssl_certfile: Optional[str] = None,
- ssl_keyfile: Optional[str] = None,
+ ssl_certfile: str | None = None,
+ ssl_keyfile: str | None = None,
ssl_cert_reqs: str = "required",
- ssl_ca_certs: Optional[str] = None,
+ ssl_ca_certs: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -61,12 +63,61 @@ def __init__( # pylint: disable=too-many-arguments
**kwargs,
)
+ def set(
+ self,
+ name: str,
+ value: Any,
+ ex: int | None = None,
+ px: int | None = None,
+ nx: bool = False,
+ xx: bool = False,
+ ) -> bool | None:
+ """
+ Set the value at key ``name``.
+
+ :param name: Key name
+ :param value: Value to set
+ :param ex: Expire time in seconds
+ :param px: Expire time in milliseconds
+ :param nx: If True, set only if key does not exist
+ :param xx: If True, set only if key already exists
+ :returns: True if set successfully, None if nx/xx condition not met
+ """
+ return self._cache.set(name, value, ex=ex, px=px, nx=nx, xx=xx)
+
+ def delete(self, *names: str) -> int:
+ """
+ Delete one or more keys.
+
+ :param names: Key names to delete
+ :returns: Number of keys deleted
+ """
+ return self._cache.delete(*names)
+
+ def publish(self, channel: str, message: str) -> int:
+ """
+ Publish a message to a Redis pub/sub channel.
+
+ :param channel: The channel name to publish to
+ :param message: The message to publish
+ :returns: Number of subscribers that received the message
+ """
+ return self._cache.publish(channel, message)
+
+ def pubsub(self) -> redis.client.PubSub:
+ """
+ Create a pub/sub subscription object.
+
+ :returns: PubSub object for subscribing to channels
+ """
+ return self._cache.pubsub()
+
def xadd(
self,
stream_name: str,
- event_data: Dict[str, Any],
+ event_data: dict[str, Any],
event_id: str = "*",
- maxlen: Optional[int] = None,
+ maxlen: int | None = None,
) -> str:
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
@@ -75,13 +126,13 @@ def xrange(
stream_name: str,
start: str = "-",
end: str = "+",
- count: Optional[int] = None,
- ) -> List[Any]:
+ count: int | None = None,
+ ) -> list[Any]:
count = count or self.MAX_EVENT_COUNT
return self._cache.xrange(stream_name, start, end, count)
@classmethod
- def from_config(cls, config: Dict[str, Any]) -> "RedisCacheBackend":
+ def from_config(cls, config: dict[str, Any]) -> RedisCacheBackend:
kwargs = {
"host": config.get("CACHE_REDIS_HOST", "localhost"),
"port": config.get("CACHE_REDIS_PORT", 6379),
@@ -108,18 +159,18 @@ class RedisSentinelCacheBackend(RedisSentinelCache):
def __init__( # pylint: disable=too-many-arguments
self,
- sentinels: List[Tuple[str, int]],
+ sentinels: list[tuple[str, int]],
master: str,
- password: Optional[str] = None,
- sentinel_password: Optional[str] = None,
+ password: str | None = None,
+ sentinel_password: str | None = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: str = "",
ssl: bool = False,
- ssl_certfile: Optional[str] = None,
- ssl_keyfile: Optional[str] = None,
+ ssl_certfile: str | None = None,
+ ssl_keyfile: str | None = None,
ssl_cert_reqs: str = "required",
- ssl_ca_certs: Optional[str] = None,
+ ssl_ca_certs: str | None = None,
**kwargs: Any,
) -> None:
# Sentinel dont directly support SSL
@@ -177,12 +228,61 @@ def __init__( # pylint: disable=too-many-arguments
**kwargs,
)
+ def set(
+ self,
+ name: str,
+ value: Any,
+ ex: int | None = None,
+ px: int | None = None,
+ nx: bool = False,
+ xx: bool = False,
+ ) -> bool | None:
+ """
+ Set the value at key ``name``.
+
+ :param name: Key name
+ :param value: Value to set
+ :param ex: Expire time in seconds
+ :param px: Expire time in milliseconds
+ :param nx: If True, set only if key does not exist
+ :param xx: If True, set only if key already exists
+ :returns: True if set successfully, None if nx/xx condition not met
+ """
+ return self._cache.set(name, value, ex=ex, px=px, nx=nx, xx=xx)
+
+ def delete(self, *names: str) -> int:
+ """
+ Delete one or more keys.
+
+ :param names: Key names to delete
+ :returns: Number of keys deleted
+ """
+ return self._cache.delete(*names)
+
+ def publish(self, channel: str, message: str) -> int:
+ """
+ Publish a message to a Redis pub/sub channel.
+
+ :param channel: The channel name to publish to
+ :param message: The message to publish
+ :returns: Number of subscribers that received the message
+ """
+ return self._cache.publish(channel, message)
+
+ def pubsub(self) -> redis.client.PubSub:
+ """
+ Create a pub/sub subscription object.
+
+ :returns: PubSub object for subscribing to channels
+ """
+ return self._cache.pubsub()
+
def xadd(
self,
stream_name: str,
- event_data: Dict[str, Any],
+ event_data: dict[str, Any],
event_id: str = "*",
- maxlen: Optional[int] = None,
+ maxlen: int | None = None,
) -> str:
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
@@ -191,13 +291,13 @@ def xrange(
stream_name: str,
start: str = "-",
end: str = "+",
- count: Optional[int] = None,
- ) -> List[Any]:
+ count: int | None = None,
+ ) -> list[Any]:
count = count or self.MAX_EVENT_COUNT
return self._cache.xrange(stream_name, start, end, count)
@classmethod
- def from_config(cls, config: Dict[str, Any]) -> "RedisSentinelCacheBackend":
+ def from_config(cls, config: dict[str, Any]) -> RedisSentinelCacheBackend:
kwargs = {
"sentinels": config.get("CACHE_REDIS_SENTINELS", [("127.0.0.1", 26379)]),
"master": config.get("CACHE_REDIS_SENTINEL_MASTER", "mymaster"),
diff --git a/superset/commands/distributed_lock/__init__.py b/superset/commands/distributed_lock/__init__.py
index e69de29bb2d1..c145aa79c9c8 100644
--- a/superset/commands/distributed_lock/__init__.py
+++ b/superset/commands/distributed_lock/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from superset.commands.distributed_lock.acquire import AcquireDistributedLock
+from superset.commands.distributed_lock.release import ReleaseDistributedLock
+
+__all__ = [
+ "AcquireDistributedLock",
+ "ReleaseDistributedLock",
+]
diff --git a/superset/commands/distributed_lock/acquire.py b/superset/commands/distributed_lock/acquire.py
new file mode 100644
index 000000000000..e06439a49b48
--- /dev/null
+++ b/superset/commands/distributed_lock/acquire.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+from datetime import datetime, timedelta, timezone
+from functools import partial
+from typing import Any
+
+import redis
+from sqlalchemy.exc import SQLAlchemyError
+
+from superset.commands.distributed_lock.base import (
+ BaseDistributedLockCommand,
+ get_default_lock_ttl,
+ get_redis_client,
+)
+from superset.daos.key_value import KeyValueDAO
+from superset.exceptions import AcquireDistributedLockFailedException
+from superset.key_value.exceptions import (
+ KeyValueCodecEncodeException,
+ KeyValueUpsertFailedError,
+)
+from superset.key_value.types import KeyValueResource
+from superset.utils.decorators import on_error, transaction
+
+logger = logging.getLogger(__name__)
+
+
+class AcquireDistributedLock(BaseDistributedLockCommand):
+ """
+ Acquire a distributed lock with automatic backend selection.
+
+ Uses Redis SET NX EX when SIGNAL_CACHE_CONFIG is configured,
+ otherwise falls back to KeyValue table.
+
+ Raises AcquireDistributedLockFailedException if:
+ - Lock is already held by another process
+ - Redis connection fails
+ """
+
+ ttl_seconds: int
+
+ def __init__(
+ self,
+ namespace: str,
+ params: dict[str, Any] | None = None,
+ ttl_seconds: int | None = None,
+ ) -> None:
+ super().__init__(namespace, params)
+ self.ttl_seconds = ttl_seconds or get_default_lock_ttl()
+
+ def run(self) -> None:
+ if (redis_client := get_redis_client()) is not None:
+ self._acquire_redis(redis_client)
+ else:
+ self._acquire_kv()
+
+ def _acquire_redis(self, redis_client: Any) -> None:
+ """Acquire lock using Redis SET NX EX (atomic)."""
+ try:
+ # SET NX EX: Set if not exists, with expiration
+ # Returns True if lock acquired, None if already exists
+ acquired = redis_client.set(
+ self.redis_lock_key,
+ "1",
+ nx=True,
+ ex=self.ttl_seconds,
+ )
+
+ if not acquired:
+ logger.debug("Redis lock on %s already taken", self.redis_lock_key)
+ raise AcquireDistributedLockFailedException("Lock already taken")
+
+ logger.debug(
+ "Acquired Redis lock: %s (TTL=%ds)",
+ self.redis_lock_key,
+ self.ttl_seconds,
+ )
+
+ except redis.RedisError as ex:
+ logger.error("Redis lock error for %s: %s", self.redis_lock_key, ex)
+ raise AcquireDistributedLockFailedException(
+ f"Redis lock failed: {ex}"
+ ) from ex
+
+ @transaction(
+ on_error=partial(
+ on_error,
+ catches=(
+ KeyValueCodecEncodeException,
+ KeyValueUpsertFailedError,
+ SQLAlchemyError,
+ ),
+ reraise=AcquireDistributedLockFailedException,
+ ),
+ )
+ def _acquire_kv(self) -> None:
+ """Acquire lock using KeyValue table (database)."""
+ # Delete expired entries first to prevent stale locks from blocking
+ KeyValueDAO.delete_expired_entries(self.resource)
+
+ # Create entry - unique constraint will raise if lock already exists
+ KeyValueDAO.create_entry(
+ resource=KeyValueResource.LOCK,
+ value={"value": True},
+ codec=self.codec,
+ key=self.key,
+ expires_on=datetime.now(timezone.utc) + timedelta(seconds=self.ttl_seconds),
+ )
+
+ logger.debug(
+ "Acquired KV lock: namespace=%s key=%s (TTL=%ds)",
+ self.namespace,
+ self.key,
+ self.ttl_seconds,
+ )
diff --git a/superset/commands/distributed_lock/base.py b/superset/commands/distributed_lock/base.py
index 03fb9f2a6b4b..3317887d8c8a 100644
--- a/superset/commands/distributed_lock/base.py
+++ b/superset/commands/distributed_lock/base.py
@@ -15,27 +15,58 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import logging
import uuid
-from typing import Any, Union
+from typing import Any, TYPE_CHECKING
-from flask import current_app as app
+from flask import current_app
from superset.commands.base import BaseCommand
from superset.distributed_lock.utils import get_key
+from superset.extensions import cache_manager
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
+if TYPE_CHECKING:
+ import redis
+
logger = logging.getLogger(__name__)
-stats_logger = app.config["STATS_LOGGER"]
+
+
+def get_default_lock_ttl() -> int:
+ """Get the default lock TTL from config."""
+ return int(current_app.config.get("DISTRIBUTED_LOCK_DEFAULT_TTL", 30))
+
+
+def get_redis_client() -> "redis.Redis[Any] | None":
+ """
+ Get Redis client from signal cache if available.
+
+ Returns None if SIGNAL_CACHE_CONFIG is not configured,
+ allowing fallback to database-backed locking.
+ """
+ backend = cache_manager.signal_cache
+ return backend._cache if backend else None
class BaseDistributedLockCommand(BaseCommand):
+ """Base command for distributed lock operations."""
+
key: uuid.UUID
+ namespace: str
codec = JsonKeyValueCodec()
resource = KeyValueResource.LOCK
- def __init__(self, namespace: str, params: Union[dict[str, Any], None] = None):
- self.key = get_key(namespace, **(params or {}))
+ def __init__(self, namespace: str, params: dict[str, Any] | None = None) -> None:
+ self.namespace = namespace
+ self.params = params or {}
+ self.key = get_key(namespace, **self.params)
+
+ @property
+ def redis_lock_key(self) -> str:
+ """Redis key for this lock."""
+ return f"lock:{self.namespace}:{self.key}"
def validate(self) -> None:
pass
diff --git a/superset/commands/distributed_lock/create.py b/superset/commands/distributed_lock/create.py
deleted file mode 100644
index 2ac443df57fd..000000000000
--- a/superset/commands/distributed_lock/create.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import logging
-from datetime import datetime, timedelta
-from functools import partial
-
-from flask import current_app as app
-from sqlalchemy.exc import SQLAlchemyError
-
-from superset.commands.distributed_lock.base import BaseDistributedLockCommand
-from superset.daos.key_value import KeyValueDAO
-from superset.exceptions import CreateKeyValueDistributedLockFailedException
-from superset.key_value.exceptions import (
- KeyValueCodecEncodeException,
- KeyValueUpsertFailedError,
-)
-from superset.key_value.types import KeyValueResource
-from superset.utils.decorators import on_error, transaction
-
-logger = logging.getLogger(__name__)
-stats_logger = app.config["STATS_LOGGER"]
-
-
-class CreateDistributedLock(BaseDistributedLockCommand):
- lock_expiration = timedelta(seconds=30)
-
- def validate(self) -> None:
- pass
-
- @transaction(
- on_error=partial(
- on_error,
- catches=(
- KeyValueCodecEncodeException,
- KeyValueUpsertFailedError,
- SQLAlchemyError,
- ),
- reraise=CreateKeyValueDistributedLockFailedException,
- ),
- )
- def run(self) -> None:
- KeyValueDAO.delete_expired_entries(self.resource)
- KeyValueDAO.create_entry(
- resource=KeyValueResource.LOCK,
- value={"value": True},
- codec=self.codec,
- key=self.key,
- expires_on=datetime.now() + self.lock_expiration,
- )
diff --git a/superset/commands/distributed_lock/delete.py b/superset/commands/distributed_lock/delete.py
deleted file mode 100644
index 2f4b64901005..000000000000
--- a/superset/commands/distributed_lock/delete.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import logging
-from functools import partial
-
-from flask import current_app as app
-from sqlalchemy.exc import SQLAlchemyError
-
-from superset.commands.distributed_lock.base import BaseDistributedLockCommand
-from superset.daos.key_value import KeyValueDAO
-from superset.exceptions import DeleteKeyValueDistributedLockFailedException
-from superset.key_value.exceptions import KeyValueDeleteFailedError
-from superset.utils.decorators import on_error, transaction
-
-logger = logging.getLogger(__name__)
-stats_logger = app.config["STATS_LOGGER"]
-
-
-class DeleteDistributedLock(BaseDistributedLockCommand):
- def validate(self) -> None:
- pass
-
- @transaction(
- on_error=partial(
- on_error,
- catches=(
- KeyValueDeleteFailedError,
- SQLAlchemyError,
- ),
- reraise=DeleteKeyValueDistributedLockFailedException,
- ),
- )
- def run(self) -> None:
- KeyValueDAO.delete_entry(self.resource, self.key)
diff --git a/superset/commands/distributed_lock/release.py b/superset/commands/distributed_lock/release.py
new file mode 100644
index 000000000000..6d98f82aa4b0
--- /dev/null
+++ b/superset/commands/distributed_lock/release.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+from functools import partial
+from typing import Any
+
+import redis
+from sqlalchemy.exc import SQLAlchemyError
+
+from superset.commands.distributed_lock.base import (
+ BaseDistributedLockCommand,
+ get_redis_client,
+)
+from superset.daos.key_value import KeyValueDAO
+from superset.exceptions import ReleaseDistributedLockFailedException
+from superset.key_value.exceptions import KeyValueDeleteFailedError
+from superset.utils.decorators import on_error, transaction
+
+logger = logging.getLogger(__name__)
+
+
+class ReleaseDistributedLock(BaseDistributedLockCommand):
+ """
+ Release a distributed lock with automatic backend selection.
+
+ Uses Redis DELETE when SIGNAL_CACHE_CONFIG is configured,
+ otherwise deletes from KeyValue table.
+ """
+
+ def run(self) -> None:
+ if (redis_client := get_redis_client()) is not None:
+ self._release_redis(redis_client)
+ else:
+ self._release_kv()
+
+ def _release_redis(self, redis_client: Any) -> None:
+ """Release lock using Redis DELETE."""
+ try:
+ redis_client.delete(self.redis_lock_key)
+ logger.debug("Released Redis lock: %s", self.redis_lock_key)
+ except redis.RedisError as ex:
+ # Log warning but don't raise - TTL will handle cleanup
+ logger.warning(
+ "Failed to release Redis lock %s: %s (TTL will handle cleanup)",
+ self.redis_lock_key,
+ ex,
+ )
+
+ @transaction(
+ on_error=partial(
+ on_error,
+ catches=(
+ KeyValueDeleteFailedError,
+ SQLAlchemyError,
+ ),
+ reraise=ReleaseDistributedLockFailedException,
+ ),
+ )
+ def _release_kv(self) -> None:
+ """Release lock using KeyValue table (database)."""
+ KeyValueDAO.delete_entry(self.resource, self.key)
+ logger.debug(
+ "Released KV lock: namespace=%s key=%s",
+ self.namespace,
+ self.key,
+ )
diff --git a/superset/commands/tasks/__init__.py b/superset/commands/tasks/__init__.py
new file mode 100644
index 000000000000..c474e55340d4
--- /dev/null
+++ b/superset/commands/tasks/__init__.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from superset.commands.tasks.cancel import CancelTaskCommand
+from superset.commands.tasks.prune import TaskPruneCommand
+from superset.commands.tasks.submit import SubmitTaskCommand
+from superset.commands.tasks.update import UpdateTaskCommand
+
+__all__ = [
+ "CancelTaskCommand",
+ "SubmitTaskCommand",
+ "TaskPruneCommand",
+ "UpdateTaskCommand",
+]
diff --git a/superset/commands/tasks/cancel.py b/superset/commands/tasks/cancel.py
new file mode 100644
index 000000000000..befe6074242b
--- /dev/null
+++ b/superset/commands/tasks/cancel.py
@@ -0,0 +1,314 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unified cancel task command for GTF."""
+
+import logging
+from functools import partial
+from typing import TYPE_CHECKING
+from uuid import UUID
+
+from flask import current_app
+from superset_core.api.tasks import TaskScope, TaskStatus
+
+from superset.commands.base import BaseCommand
+from superset.commands.tasks.exceptions import (
+ TaskAbortFailedError,
+ TaskNotAbortableError,
+ TaskNotFoundError,
+ TaskPermissionDeniedError,
+)
+from superset.extensions import security_manager
+from superset.stats_logger import BaseStatsLogger
+from superset.tasks.locks import task_lock
+from superset.tasks.utils import get_active_dedup_key
+from superset.utils.core import get_user_id
+from superset.utils.decorators import on_error, transaction
+
+if TYPE_CHECKING:
+ from superset.models.tasks import Task
+
+logger = logging.getLogger(__name__)
+
+
+class CancelTaskCommand(BaseCommand):
+ """
+ Unified command to cancel a task.
+
+ Behavior:
+ - For private tasks or single-subscriber tasks: aborts the task
+ - For shared tasks with multiple subscribers (non-admin): unsubscribes user
+ - For shared tasks with force=True (admin only): aborts for all subscribers
+
+ The term "cancel" is user-facing; internally this may abort or unsubscribe.
+
+ This command acquires a distributed lock before starting a transaction to
+ prevent race conditions with concurrent submit/cancel operations.
+
+ Permission checks are deferred to inside the lock to minimize SELECTs:
+ we only fetch the task once, then validate permissions on the fetched data.
+ """
+
+ def __init__(self, task_uuid: UUID, force: bool = False):
+ """
+ Initialize the cancel command.
+
+ :param task_uuid: UUID of the task to cancel
+ :param force: If True, force abort even with multiple subscribers (admin only)
+ """
+ self._task_uuid = task_uuid
+ self._force = force
+ # Set to 'aborted' or 'unsubscribed' once the command runs
+ self._action_taken: str = "cancelled"
+ self._should_publish_abort: bool = False
+
+ def run(self) -> "Task":
+ """
+ Execute the cancel command with distributed locking.
+
+ The lock is acquired BEFORE starting the transaction to avoid holding
+ a DB connection during lock acquisition. Uses dedup_key as lock key
+ to ensure Submit and Cancel operations use the same lock.
+
+ :returns: The updated task model
+ """
+ from superset.daos.tasks import TaskDAO
+
+ # Lightweight fetch to compute dedup_key for locking
+ # This is needed to use the same lock key as SubmitTaskCommand
+ task = TaskDAO.find_one_or_none(
+ skip_base_filter=security_manager.is_admin(), uuid=self._task_uuid
+ )
+
+ if not task:
+ raise TaskNotFoundError()
+
+ # Compute dedup_key using the same logic as SubmitTaskCommand
+ dedup_key = get_active_dedup_key(
+ scope=task.scope,
+ task_type=task.task_type,
+ task_key=task.task_key,
+ user_id=task.user_id,
+ )
+
+ # Acquire lock BEFORE transaction starts
+ # Using dedup_key ensures Submit and Cancel use the same lock
+ with task_lock(dedup_key):
+ result = self._execute_with_transaction()
+
+ # Publish abort notification AFTER transaction commits
+ # This prevents race conditions where listeners check DB before commit
+ if self._should_publish_abort:
+ from superset.tasks.manager import TaskManager
+
+ TaskManager.publish_abort(self._task_uuid)
+
+ return result
+
+ @transaction(on_error=partial(on_error, reraise=TaskAbortFailedError))
+ def _execute_with_transaction(self) -> "Task":
+ """
+ Execute the cancel operation inside a transaction.
+
+ Combines fetch + validation + execution in a single transaction,
+ reducing the number of SELECTs from 3 to 1 (plus DAO operations).
+
+ :returns: The updated task model
+ """
+ from superset.daos.tasks import TaskDAO
+
+ # Check admin status (no DB access)
+ is_admin = security_manager.is_admin()
+
+ # Force flag requires admin
+ if self._force and not is_admin:
+ raise TaskPermissionDeniedError(
+ "Only administrators can force cancel a task"
+ )
+
+ # Single SELECT: fetch task and validate permissions on it
+ task = TaskDAO.find_one_or_none(skip_base_filter=is_admin, uuid=self._task_uuid)
+
+ if not task:
+ raise TaskNotFoundError()
+
+ # Validate permissions on the fetched task
+ self._validate_permissions(task, is_admin)
+
+ # Execute cancel and return updated task
+ return self._do_cancel(task, is_admin)
+
+ def _validate_permissions(self, task: "Task", is_admin: bool) -> None:
+ """
+ Validate permissions on an already-fetched task.
+
+ Permission rules by scope:
+ - private: Only creator or admin (already filtered by base_filter)
+ - shared: Subscribers or admin
+ - system: Only admin
+
+ :param task: The task to validate permissions for
+ :param is_admin: Whether current user is admin
+ :raises TaskAbortFailedError: If task is not in cancellable state
+ :raises TaskPermissionDeniedError: If user lacks permission
+ """
+ # Check if task is in a cancellable state
+ if task.status not in [
+ TaskStatus.PENDING.value,
+ TaskStatus.IN_PROGRESS.value,
+ TaskStatus.ABORTING.value, # Already aborting is OK (idempotent)
+ ]:
+ raise TaskAbortFailedError()
+
+ # Admin can cancel anything
+ if is_admin:
+ return
+
+ # Non-admin permission checks by scope
+ user_id = get_user_id()
+
+ if task.scope == TaskScope.SYSTEM.value:
+ # System tasks are admin-only
+ raise TaskPermissionDeniedError(
+ "Only administrators can cancel system tasks"
+ )
+
+ if task.is_shared:
+ # Shared tasks: must be a subscriber
+ if not user_id or not task.has_subscriber(user_id):
+ raise TaskPermissionDeniedError(
+ "You must be subscribed to cancel this shared task"
+ )
+
+ # Private tasks: already filtered by base_filter (only creator can see)
+ # If we got here, user has permission
+
+ def _do_cancel(self, task: "Task", is_admin: bool) -> "Task":
+ """
+ Execute the cancel operation (abort or unsubscribe).
+
+ :param task: The task to cancel
+ :param is_admin: Whether current user is admin
+ :returns: The updated task model
+ """
+ user_id = get_user_id()
+
+ # Determine action based on task scope and force flag
+ should_abort = (
+ # Admin with force flag always aborts
+ (is_admin and self._force)
+ # Private tasks always abort (only one user)
+ or task.is_private
+ # System tasks always abort (admin only anyway)
+ or task.is_system
+ # Single or last subscriber - abort
+ or task.subscriber_count <= 1
+ )
+
+ if should_abort:
+ return self._do_abort(task, is_admin)
+ else:
+ return self._do_unsubscribe(task, user_id)
+
+ def _do_abort(self, task: "Task", is_admin: bool) -> "Task":
+ """
+ Execute abort operation.
+
+ :param task: The task to abort
+ :param is_admin: Whether current user is admin
+ :returns: The updated task model
+ """
+ from superset.daos.tasks import TaskDAO
+
+ # TaskNotAbortableError raised by abort_task propagates to the caller
+ result: Task | None = TaskDAO.abort_task(
+ task.uuid, skip_base_filter=is_admin
+ )
+
+ if result is None:
+ # abort_task returned None - task wasn't aborted
+ # This can happen if task is already finished
+ raise TaskAbortFailedError()
+
+ self._action_taken = "aborted"
+
+ # Track if we need to publish abort after commit
+ if TaskStatus(result.status) == TaskStatus.ABORTING:
+ self._should_publish_abort = True
+
+ # Emit stats metric
+ stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
+ stats_logger.incr("gtf.task.abort")
+
+ logger.info(
+ "Task aborted: %s (scope: %s, force: %s)",
+ task.uuid,
+ task.scope,
+ self._force,
+ )
+
+ return result
+
+ def _do_unsubscribe(self, task: "Task", user_id: int | None) -> "Task":
+ """
+ Execute unsubscribe operation.
+
+ :param task: The task to unsubscribe from
+ :param user_id: ID of user to unsubscribe
+ :returns: The updated task model
+ """
+ from superset.daos.tasks import TaskDAO
+
+ self._action_taken = "unsubscribed"
+
+ if not user_id or not task.has_subscriber(user_id):
+ # User not subscribed - they shouldn't be able to cancel
+ raise TaskPermissionDeniedError(
+ "You are not subscribed to this shared task"
+ )
+
+ result = TaskDAO.remove_subscriber(task.id, user_id)
+ if result is None:
+ raise TaskPermissionDeniedError(
+ "You are not subscribed to this shared task"
+ )
+
+ # Emit stats metric
+ stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
+ stats_logger.incr("gtf.task.unsubscribe")
+
+ logger.info(
+ "User %s unsubscribed from shared task: %s",
+ user_id,
+ task.uuid,
+ )
+
+ return result
+
+ def validate(self) -> None:
+ pass
+
+ @property
+ def action_taken(self) -> str:
+ """
+ Get the action that was taken.
+
+ :returns: 'aborted' or 'unsubscribed'
+ """
+ return self._action_taken
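+
+
+# Illustrative usage sketch (not part of this module; the UUID value is
+# hypothetical):
+#
+# command = CancelTaskCommand(task_uuid=some_task_uuid, force=False)
+# task = command.run()
+# command.action_taken # 'aborted' or 'unsubscribed'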
diff --git a/superset/commands/tasks/exceptions.py b/superset/commands/tasks/exceptions.py
new file mode 100644
index 000000000000..a54030dd5c43
--- /dev/null
+++ b/superset/commands/tasks/exceptions.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from flask_babel import lazy_gettext as _
+
+from superset.commands.exceptions import (
+ CommandException,
+ CommandInvalidError,
+ CreateFailedError,
+ ForbiddenError,
+ UpdateFailedError,
+)
+
+
+class TaskNotFoundError(CommandException):
+ """Task not found."""
+
+ status = 404
+ message = _("Task not found.")
+
+
+class TaskInvalidError(CommandInvalidError):
+ """Task parameters are invalid."""
+
+ message = _("Task parameters are invalid.")
+
+
+class TaskCreateFailedError(CreateFailedError):
+ """Task creation failed."""
+
+ message = _("Task could not be created.")
+
+
+class TaskUpdateFailedError(UpdateFailedError):
+ """Task update failed."""
+
+ message = _("Task could not be updated.")
+
+
+class TaskAbortFailedError(CommandException):
+ """Task abortion failed."""
+
+ status = 422
+ message = _("Task could not be aborted.")
+
+
+class TaskNotAbortableError(CommandException):
+ """
+ Task cannot be aborted.
+
+ Raised when attempting to abort an in-progress task that has not
+ registered an abort handler (is_abortable is not True).
+ """
+
+ status = 400
+ message = _(
+ "Task is not abortable. The task is in progress but has not "
+ "registered an abort handler."
+ )
+
+
+class TaskForbiddenError(ForbiddenError):
+ """Task operation forbidden."""
+
+ message = _("Changing this task is forbidden")
+
+
+class TaskPermissionDeniedError(ForbiddenError):
+ """Task operation not permitted for current user."""
+
+ def __init__(self, message: str | None = None):
+ super().__init__()
+ if message:
+ self.message = message
+ else:
+ self.message = _("You do not have permission to perform this operation")
+
+
+class GlobalTaskFrameworkDisabledError(CommandException):
+ """
+ Raised when a GTF task is called or scheduled but GTF is disabled.
+
+ This exception is raised at call/schedule time (not decoration time) to allow
+ modules with @task-decorated functions to be imported safely when GTF is disabled.
+ The check is deferred until someone actually tries to execute a task.
+ """
+
+ message = _(
+ "The Global Task Framework is not enabled. "
+ "Set GLOBAL_TASK_FRAMEWORK=True in your feature flags to use @task. "
+ "See https://superset.apache.org/docs/configuration/async-queries-celery "
+ "for configuration details."
+ )
diff --git a/superset/commands/tasks/internal_update.py b/superset/commands/tasks/internal_update.py
new file mode 100644
index 000000000000..fc177a0cee61
--- /dev/null
+++ b/superset/commands/tasks/internal_update.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Internal task update commands for GTF internal use only.
+
+These commands perform zero-read updates using targeted SQL UPDATE statements.
+They're designed for use by TaskContext and executor code where the framework
+owns the authoritative state and doesn't need to read before writing.
+
+Unlike UpdateTaskCommand, these commands:
+- Do NOT fetch the task entity before updating
+- Do NOT check permissions (internal use only)
+- Use targeted SQL UPDATE for efficiency
+"""
+
+from __future__ import annotations
+
+import logging
+from functools import partial
+from typing import Any
+from uuid import UUID
+
+from superset_core.api.tasks import TaskProperties, TaskStatus
+
+from superset.commands.base import BaseCommand
+from superset.commands.tasks.exceptions import TaskUpdateFailedError
+from superset.daos.tasks import TaskDAO
+from superset.utils.decorators import on_error, transaction
+
+logger = logging.getLogger(__name__)
+
+
+class InternalUpdateTaskCommand(BaseCommand):
+ """
+ Zero-read task update command for properties/payload.
+
+ This command directly writes properties and/or payload to the database
+ without reading the current values first. The caller (TaskContext)
+ maintains the authoritative cached state and passes complete merged
+ values to write.
+
+ This is an optimization for task execution where:
+ 1. The executor owns the properties/payload state
+ 2. No permission checks are needed (internal framework code)
+ 3. Status column should not be touched (use InternalStatusTransitionCommand)
+
+ WARNING: This command should ONLY be used by TaskContext and similar
+ internal framework code. External callers should use UpdateTaskCommand.
+ """
+
+ def __init__(
+ self,
+ task_uuid: UUID,
+ properties: TaskProperties | None = None,
+ payload: dict[str, Any] | None = None,
+ ):
+ """
+ Initialize internal update command.
+
+ :param task_uuid: UUID of the task to update
+ :param properties: Complete properties dict to write (replaces existing)
+ :param payload: Complete payload dict to write (replaces existing)
+ """
+ self._task_uuid = task_uuid
+ self._properties = properties
+ self._payload = payload
+
+ def validate(self) -> None:
+ """No validation needed for internal command."""
+ pass
+
+ @transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
+ def run(self) -> bool:
+ """
+ Execute zero-read update.
+
+ :returns: True if task was updated, False if not found or nothing to update
+ """
+ if self._properties is None and self._payload is None:
+ return False
+
+ updated = TaskDAO.set_properties_and_payload(
+ task_uuid=self._task_uuid,
+ properties=self._properties,
+ payload=self._payload,
+ )
+
+ if updated:
+ logger.debug(
+ "Internal update for task %s: properties=%s, payload=%s",
+ self._task_uuid,
+ self._properties is not None,
+ self._payload is not None,
+ )
+
+ return updated
+
+
+class InternalStatusTransitionCommand(BaseCommand):
+ """
+ Atomic conditional status transition command for executor use.
+
+ This command provides race-safe status transitions by using atomic
+ compare-and-swap semantics. The status is only updated if the current
+ status matches the expected value(s).
+
+ Use cases:
+ - PENDING → IN_PROGRESS: Task pickup (executor starting)
+ - IN_PROGRESS → SUCCESS: Normal completion (only if not ABORTING)
+ - IN_PROGRESS → FAILURE: Task exception (only if not ABORTING)
+ - ABORTING → ABORTED: Abort handlers completed successfully
+ - ABORTING → TIMED_OUT: Timeout handlers completed successfully
+ - ABORTING → FAILURE: Abort/cleanup handlers failed
+
+ The atomic nature prevents race conditions where:
+ - Executor tries to set SUCCESS but task was concurrently aborted
+ - Multiple executors try to pick up the same task
+
+ WARNING: This command should ONLY be used by executor code (decorators.py,
+ scheduler.py). External callers should use UpdateTaskCommand.
+ """
+
+ def __init__(
+ self,
+ task_uuid: UUID,
+ new_status: TaskStatus | str,
+ expected_status: TaskStatus | str | list[TaskStatus | str],
+ properties: TaskProperties | None = None,
+ set_started_at: bool = False,
+ set_ended_at: bool = False,
+ ):
+ """
+ Initialize status transition command.
+
+ :param task_uuid: UUID of the task to update
+ :param new_status: Target status to set
+ :param expected_status: Current status(es) required for update to succeed.
+ Can be a single status or list of acceptable current statuses.
+ :param properties: Optional properties to update atomically with status
+ (e.g., error_message on FAILURE)
+ :param set_started_at: If True, also set started_at to current timestamp.
+ Should be True for PENDING → IN_PROGRESS transitions.
+ :param set_ended_at: If True, also set ended_at to current timestamp.
+ Should be True for terminal status transitions.
+ """
+ self._task_uuid = task_uuid
+ self._new_status = new_status
+ self._expected_status = expected_status
+ self._properties = properties
+ self._set_started_at = set_started_at
+ self._set_ended_at = set_ended_at
+
+ def validate(self) -> None:
+ """No validation needed for internal command."""
+ pass
+
+ @transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
+ def run(self) -> bool:
+ """
+ Execute atomic conditional status update.
+
+ :returns: True if status was updated (expected matched), False otherwise
+ """
+ return TaskDAO.conditional_status_update(
+ task_uuid=self._task_uuid,
+ new_status=self._new_status,
+ expected_status=self._expected_status,
+ properties=self._properties,
+ set_started_at=self._set_started_at,
+ set_ended_at=self._set_ended_at,
+ )
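+
+
+# Illustrative sketch of an atomic pickup transition (the UUID value is
+# hypothetical); run() returns False if another executor already claimed
+# the task:
+#
+# picked_up = InternalStatusTransitionCommand(
+# task_uuid=some_task_uuid,
+# new_status=TaskStatus.IN_PROGRESS,
+# expected_status=TaskStatus.PENDING,
+# set_started_at=True,
+# ).run()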
diff --git a/superset/commands/tasks/prune.py b/superset/commands/tasks/prune.py
new file mode 100644
index 000000000000..cf59e182cbe4
--- /dev/null
+++ b/superset/commands/tasks/prune.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+import time
+from datetime import datetime, timedelta
+
+import sqlalchemy as sa
+from superset_core.api.tasks import TaskStatus
+
+from superset import db
+from superset.commands.base import BaseCommand
+
+logger = logging.getLogger(__name__)
+
+
+# pylint: disable=consider-using-transaction
+class TaskPruneCommand(BaseCommand):
+ """
+ Command to prune the tasks table by deleting rows older than the specified
+ retention period.
+
+ This command deletes records from the `Task` table that are in terminal states
+ (success, failure, aborted, or timed_out) and ended more than the specified
+ number of days ago. It helps maintain the database by removing outdated
+ entries and freeing up space.
+
+ Attributes:
+ retention_period_days (int): The number of days for which records should be retained.
+ Records older than this period will be deleted.
+ max_rows_per_run (int | None): The maximum number of rows to delete in a single run.
+ If provided and greater than zero, rows are selected
+ deterministically from the oldest first (by timestamp then id)
+ up to this limit in this execution.
+ """ # noqa: E501
+
+ def __init__(self, retention_period_days: int, max_rows_per_run: int | None = None):
+ """
+ :param retention_period_days: Number of days to keep in the tasks table
+ :param max_rows_per_run: The maximum number of rows to delete in a single run.
+ If provided and greater than zero, rows are selected deterministically from the
+ oldest first (by timestamp then id) up to this limit in this execution.
+ """ # noqa: E501
+ self.retention_period_days = retention_period_days
+ self.max_rows_per_run = max_rows_per_run
+
+ def run(self) -> None:
+ """
+ Executes the prune command
+ """
+ batch_size = 999 # SQLite limits IN clauses to 999 parameters
+ total_deleted = 0
+ start_time = time.time()
+
+ # Select all IDs that need to be deleted
+ # Only delete tasks in terminal states (success, failure, aborted, or timed_out)
+ from superset.models.tasks import Task
+
+ select_stmt = sa.select(Task.id).where(
+ Task.ended_at < datetime.now() - timedelta(days=self.retention_period_days),
+ Task.status.in_(
+ [
+ TaskStatus.SUCCESS.value,
+ TaskStatus.FAILURE.value,
+ TaskStatus.ABORTED.value,
+ TaskStatus.TIMED_OUT.value,
+ ]
+ ),
+ )
+
+ # Optionally limited by max_rows_per_run
+ # order by oldest first for deterministic deletion
+ if self.max_rows_per_run is not None and self.max_rows_per_run > 0:
+ select_stmt = select_stmt.order_by(
+ Task.ended_at.asc(), Task.id.asc()
+ ).limit(self.max_rows_per_run)
+
+ ids_to_delete = db.session.execute(select_stmt).scalars().all()
+
+ total_rows = len(ids_to_delete)
+
+ logger.info("Total rows to be deleted: %s", f"{total_rows:,}")
+
+ next_logging_threshold = 1
+
+ # Iterate over the IDs in batches
+ for i in range(0, total_rows, batch_size):
+ batch_ids = ids_to_delete[i : i + batch_size]
+
+ # Delete the selected batch using IN clause
+ result = db.session.execute(sa.delete(Task).where(Task.id.in_(batch_ids)))
+
+ # Update the total number of deleted records
+ total_deleted += result.rowcount
+
+ # Commit after each batch so that, if a later error occurs, the rows
+ # deleted so far remain committed
+ db.session.commit()
+
+ # Log progress after each additional 1% of rows deleted
+ percentage_complete = (total_deleted / total_rows) * 100
+ if percentage_complete >= next_logging_threshold:
+ logger.info(
+ "Deleted %s rows from the tasks table older than %s days (%d%% complete)", # noqa: E501
+ f"{total_deleted:,}",
+ self.retention_period_days,
+ percentage_complete,
+ )
+ next_logging_threshold += 1
+
+ elapsed_time = time.time() - start_time
+ minutes, seconds = divmod(elapsed_time, 60)
+ formatted_time = f"{int(minutes):02}:{int(seconds):02}"
+ logger.info(
+ "Pruning complete: %s rows deleted in %s",
+ f"{total_deleted:,}",
+ formatted_time,
+ )
+
+ def validate(self) -> None:
+ pass
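+
+
+# Illustrative sketch mirroring the celery beat example in superset/config.py:
+#
+# TaskPruneCommand(retention_period_days=90, max_rows_per_run=10000).run()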
diff --git a/superset/commands/tasks/submit.py b/superset/commands/tasks/submit.py
new file mode 100644
index 000000000000..247f209b7d07
--- /dev/null
+++ b/superset/commands/tasks/submit.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Submit task command for GTF."""
+
+import logging
+import uuid
+from functools import partial
+from typing import Any, TYPE_CHECKING
+
+from flask import current_app
+from marshmallow import ValidationError
+from superset_core.api.tasks import TaskScope
+
+from superset.commands.base import BaseCommand
+from superset.commands.tasks.exceptions import (
+ TaskCreateFailedError,
+ TaskInvalidError,
+)
+from superset.daos.exceptions import DAOCreateFailedError
+from superset.stats_logger import BaseStatsLogger
+from superset.tasks.locks import task_lock
+from superset.tasks.utils import get_active_dedup_key
+from superset.utils.decorators import on_error, transaction
+
+if TYPE_CHECKING:
+ from superset.models.tasks import Task
+
+logger = logging.getLogger(__name__)
+
+
+class SubmitTaskCommand(BaseCommand):
+ """
+ Command to submit a task (create new or join existing).
+
+ This command owns locking and create-vs-join business logic.
+ It acquires a distributed lock and then decides whether to:
+ - Create a new task (if no existing task with same dedup_key)
+ - Join an existing task by adding the user as subscriber
+ """
+
+ def __init__(self, data: dict[str, Any]):
+ self._properties = data.copy()
+
+ @transaction(on_error=partial(on_error, reraise=TaskCreateFailedError))
+ def run(self) -> "Task":
+ """
+ Execute the command with distributed locking.
+
+ Acquires lock based on dedup_key, then checks for existing task
+ and either creates new or joins existing (adding subscriber).
+
+ :returns: Task model (either newly created or existing)
+ """
+ task, _ = self.run_with_info()
+ return task
+
+ @transaction(on_error=partial(on_error, reraise=TaskCreateFailedError))
+ def run_with_info(self) -> tuple["Task", bool]:
+ """
+ Execute the command and return (task, is_new) tuple.
+
+ This variant allows callers to distinguish between creating a new task
+ and joining an existing one. Useful for sync execution where the caller
+ needs to wait for an existing task to complete rather than executing again.
+
+ :returns: Tuple of (Task, is_new) where is_new is True if task was created
+ """
+ from superset.daos.tasks import TaskDAO
+
+ self.validate()
+
+ # Extract and normalize parameters
+ task_type = self._properties["task_type"]
+ task_key = self._properties.get("task_key") or str(uuid.uuid4())
+ scope = self._properties.get("scope", TaskScope.PRIVATE.value)
+ user_id = self._properties.get("user_id")
+
+ # Build dedup_key for lock
+ dedup_key = get_active_dedup_key(
+ scope=scope,
+ task_type=task_type,
+ task_key=task_key,
+ user_id=user_id,
+ )
+
+ # Acquire lock to prevent race conditions during create/join
+ with task_lock(dedup_key):
+ # Check for existing task (safe under lock)
+ existing = TaskDAO.find_by_task_key(task_type, task_key, scope, user_id)
+
+ # Get stats logger
+ stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
+
+ if existing:
+ # Join existing task - add subscriber if not already subscribed
+ if user_id and not existing.has_subscriber(user_id):
+ TaskDAO.add_subscriber(existing.id, user_id)
+ stats_logger.incr("gtf.task.subscribe")
+ logger.info(
+ "User %s joined existing task: %s",
+ user_id,
+ task_key,
+ )
+ else:
+ # Same user submitted the same task - deduplication hit
+ stats_logger.incr("gtf.task.dedupe")
+ logger.debug(
+ "Deduplication hit for task: %s (user_id=%s)",
+ task_key,
+ user_id,
+ )
+ return existing, False # is_new=False: joined existing task
+
+ # Create new task (DAO is now a pure data operation)
+ try:
+ task = TaskDAO.create_task(
+ task_type=task_type,
+ task_key=task_key,
+ scope=scope,
+ task_name=self._properties.get("task_name"),
+ user_id=user_id,
+ payload=self._properties.get("payload", {}),
+ properties=self._properties.get("properties", {}),
+ )
+ stats_logger.incr("gtf.task.create")
+ return task, True # is_new=True: created new task
+ except DAOCreateFailedError as ex:
+ raise TaskCreateFailedError() from ex
+
+ def validate(self) -> None:
+ """Validate command parameters."""
+ exceptions: list[ValidationError] = []
+
+ # Require task_type
+ if not self._properties.get("task_type"):
+ exceptions.append(
+ ValidationError("task_type is required", field_name="task_type")
+ )
+
+ scope = self._properties.get("scope", TaskScope.PRIVATE.value)
+ scope_value = scope.value if isinstance(scope, TaskScope) else scope
+ valid_scopes = [s.value for s in TaskScope]
+ if scope_value not in valid_scopes:
+ exceptions.append(
+ ValidationError(
+ f"scope must be one of {valid_scopes}",
+ field_name="scope",
+ )
+ )
+ # Store normalized value for use in run()
+ self._properties["scope"] = scope_value
+
+ if exceptions:
+ raise TaskInvalidError(exceptions=exceptions)
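+
+
+# Illustrative sketch (field values are hypothetical):
+#
+# task, is_new = SubmitTaskCommand({
+# "task_type": "report_generation",
+# "task_key": "report-42",
+# "scope": "private",
+# "user_id": 1,
+# }).run_with_info()
+# # is_new is False when the user joined an existing deduplicated task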
diff --git a/superset/commands/tasks/update.py b/superset/commands/tasks/update.py
new file mode 100644
index 000000000000..29ba0509ea6d
--- /dev/null
+++ b/superset/commands/tasks/update.py
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from datetime import datetime
+from functools import partial
+from typing import Any, TYPE_CHECKING
+from uuid import UUID
+
+from superset_core.api.tasks import TaskProperties
+
+from superset import security_manager
+from superset.commands.base import BaseCommand
+from superset.commands.tasks.exceptions import (
+ TaskForbiddenError,
+ TaskNotFoundError,
+ TaskUpdateFailedError,
+)
+from superset.exceptions import SupersetSecurityException
+from superset.tasks.locks import task_lock
+from superset.tasks.utils import get_active_dedup_key
+from superset.utils.decorators import on_error, transaction
+
+if TYPE_CHECKING:
+ from superset.models.tasks import Task
+
+logger = logging.getLogger(__name__)
+
+
+class UpdateTaskCommand(BaseCommand):
+ """
+ Command to update a task.
+
+ Uses explicit typed parameters to avoid confusion between
+ payload (task output) and properties (runtime state/config).
+
+ This command acquires a distributed lock to prevent race conditions with
+ concurrent submit/cancel operations on the same logical task.
+ """
+
+ def __init__(
+ self,
+ task_uuid: UUID,
+ *,
+ status: str | None = None,
+ started_at: datetime | None = None,
+ ended_at: datetime | None = None,
+ payload: dict[str, Any] | None = None,
+ properties: TaskProperties | None = None,
+ skip_security_check: bool = False,
+ ):
+ """
+ Initialize UpdateTaskCommand.
+
+ :param task_uuid: UUID of the task to update
+ :param status: New status value (column field)
+ :param started_at: Started timestamp (column field)
+ :param ended_at: Ended timestamp (column field)
+ :param payload: Task output data to merge (stored in payload column)
+ :param properties: Runtime state/config updates as dict. Keys must be
+ valid TaskProperties field names (is_abortable, progress_percent, etc.)
+ :param skip_security_check: If True, skip ownership validation.
+ Use this for internal task updates (e.g., task executor updating
+ its own task's progress). Default is False for API-driven updates.
+ """
+ self._task_uuid = task_uuid
+ self._status = status
+ self._started_at = started_at
+ self._ended_at = ended_at
+ self._payload = payload
+ self._properties = properties
+ self._model: Task | None = None
+ self._skip_security_check = skip_security_check
+
+ @transaction(on_error=partial(on_error, reraise=TaskUpdateFailedError))
+ def run(self) -> Task:
+ """
+ Execute the update command with distributed locking.
+
+ Acquires lock based on dedup_key to prevent race conditions with
+ concurrent submit/cancel operations on the same logical task.
+
+ :returns: The updated task model
+ """
+ from superset.daos.tasks import TaskDAO
+
+ self.validate()
+
+ # Fetch task to compute dedup_key for locking
+ task = TaskDAO.find_one_or_none(
+ skip_base_filter=self._skip_security_check,
+ uuid=self._task_uuid,
+ )
+ if not task:
+ raise TaskNotFoundError()
+
+ self._model = task
+
+ # Build lock key from task properties (same structure as dedup_key)
+ dedup_key = get_active_dedup_key(
+ scope=self._model.scope,
+ task_type=self._model.task_type,
+ task_key=self._model.task_key,
+ user_id=self._model.user_id,
+ )
+
+ # Acquire lock to prevent race with submit/cancel operations
+ with task_lock(dedup_key):
+ return self._execute_update()
+
+ def _execute_update(self) -> "Task":
+ """
+ Execute the update operation under lock.
+
+ :returns: The updated task model
+ """
+ from superset.daos.tasks import TaskDAO
+
+ # Re-fetch model under lock to get fresh state
+ fresh_model = TaskDAO.find_one_or_none(
+ skip_base_filter=self._skip_security_check,
+ uuid=self._task_uuid,
+ )
+ if not fresh_model:
+ raise TaskNotFoundError()
+ self._model = fresh_model
+
+ # Verify ownership (user can only update their own tasks)
+ # Skip this check for internal updates (e.g., task executor updating progress)
+ if not self._skip_security_check:
+ try:
+ security_manager.raise_for_ownership(self._model)
+ except SupersetSecurityException as ex:
+ raise TaskForbiddenError() from ex
+
+ # Update status via set_status() for proper timestamp handling
+ if self._status is not None:
+ self._model.set_status(self._status)
+ if self._started_at is not None:
+ self._model.started_at = self._started_at
+ if self._ended_at is not None:
+ self._model.ended_at = self._ended_at
+
+ # Update payload (merges with existing)
+ if self._payload is not None:
+ self._model.set_payload(self._payload)
+
+ # Update properties (dict passed through to model)
+ if self._properties:
+ self._model.update_properties(self._properties)
+
+ return TaskDAO.update(self._model)
+
+ def validate(self) -> None:
+ pass
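+
+
+# Illustrative sketch (values are hypothetical; skip_security_check=True is
+# reserved for internal executor updates):
+#
+# UpdateTaskCommand(
+# task_uuid=some_task_uuid,
+# properties={"progress_percent": 50},
+# skip_security_check=True,
+# ).run()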
diff --git a/superset/config.py b/superset/config.py
index 970c965d0514..423f6fb3bde9 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -662,6 +662,9 @@ class D3TimeFormat(TypedDict, total=False):
# sts:AssumeRole permissions to prevent unauthorized access.
# @lifecycle: testing
"AWS_DATABASE_IAM_AUTH": False,
+ # Global Task Framework - unified task management with progress tracking,
+ # cancellation, and deduplication.
+ "GLOBAL_TASK_FRAMEWORK": False,
# Use analogous colors in charts
# @lifecycle: testing
"USE_ANALOGOUS_COLORS": False,
@@ -1393,6 +1396,12 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# "schedule": crontab(minute="*", hour="*"),
# "kwargs": {"retention_period_days": 180, "max_rows_per_run": 10000},
# },
+ # Uncomment to enable pruning of the tasks table
+ # "prune_tasks": {
+ # "task": "prune_tasks",
+ # "schedule": crontab(minute=0, hour=0),
+ # "kwargs": {"retention_period_days": 90, "max_rows_per_run": 10000},
+ # },
# Uncomment to enable Slack channel cache warm-up
# "slack.cache_channels": {
# "task": "slack.cache_channels",
@@ -2456,6 +2465,62 @@ class ExtraDynamicQueryFilters(TypedDict, total=False):
LOCAL_EXTENSIONS: list[str] = []
EXTENSIONS_PATH: str | None = None
+# Default interval (in seconds) for polling task abort status
+TASK_ABORT_POLLING_DEFAULT_INTERVAL = 10
+
+# Minimum interval in seconds between database writes for task progress updates.
+# Set to 0 to disable throttling (write every update to DB).
+TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL = 2 # seconds
+
+# ---------------------------------------------------
+# Signal Cache Configuration
+# ---------------------------------------------------
+# Shared Redis/Valkey configuration for signaling features that require
+# Redis-specific primitives (pub/sub messaging, distributed locks).
+#
+# Uses Flask-Caching style configuration for consistency with other cache backends.
+# Set CACHE_TYPE to 'RedisCache' for standard Redis or 'RedisSentinelCache' for
+# Sentinel.
+#
+# These features cannot use generic cache backends because they rely on:
+# - Pub/Sub: Real-time message broadcasting between workers
+# - SET NX EX: Atomic lock acquisition with automatic expiration
+#
+# When configured, enables:
+# - Real-time abort/completion notifications for GTF tasks (vs database polling)
+# - Redis-based distributed locking (vs KeyValueDAO-backed DistributedLock)
+#
+# Future: This cache will also be used by Global Async Queries, consolidating
+# GLOBAL_ASYNC_QUERIES_CACHE_BACKEND into this unified configuration.
+#
+# Example with standard Redis:
+# SIGNAL_CACHE_CONFIG: CacheConfig = {
+# "CACHE_TYPE": "RedisCache",
+# "CACHE_REDIS_HOST": "localhost",
+# "CACHE_REDIS_PORT": 6379,
+# "CACHE_REDIS_DB": 0,
+# "CACHE_REDIS_PASSWORD": "",
+# }
+#
+# Example with Redis Sentinel:
+# SIGNAL_CACHE_CONFIG: CacheConfig = {
+# "CACHE_TYPE": "RedisSentinelCache",
+# "CACHE_REDIS_SENTINELS": [("sentinel1", 26379), ("sentinel2", 26379)],
+# "CACHE_REDIS_SENTINEL_MASTER": "mymaster",
+# "CACHE_REDIS_SENTINEL_PASSWORD": None,
+# "CACHE_REDIS_DB": 0,
+# "CACHE_REDIS_PASSWORD": "",
+# }
+SIGNAL_CACHE_CONFIG: CacheConfig | None = None
+
+# Default lock TTL (time-to-live) in seconds for distributed locks.
+# Can be overridden per-call via the `ttl_seconds` parameter.
+# After TTL expires, the lock is automatically released to prevent deadlocks.
+DISTRIBUTED_LOCK_DEFAULT_TTL = 30
+
+# Channel prefix for task abort pub/sub messages
+TASKS_ABORT_CHANNEL_PREFIX = "gtf:abort:"
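+# (Illustrative assumption: per-task abort channels append the task UUID to
+# this prefix, e.g. f"{TASKS_ABORT_CHANNEL_PREFIX}{task_uuid}".)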
+
# -------------------------------------------------------------------
# * WARNING: STOP EDITING HERE *
# -------------------------------------------------------------------
diff --git a/superset/core/api/core_api_injection.py b/superset/core/api/core_api_injection.py
index 28e3c6be3192..be4ea69db4c5 100644
--- a/superset/core/api/core_api_injection.py
+++ b/superset/core/api/core_api_injection.py
@@ -52,6 +52,7 @@ def inject_dao_implementations() -> None:
SavedQueryDAO as HostSavedQueryDAO,
)
from superset.daos.tag import TagDAO as HostTagDAO
+ from superset.daos.tasks import TaskDAO as HostTaskDAO
from superset.daos.user import UserDAO as HostUserDAO
# Replace abstract classes with concrete implementations
@@ -64,18 +65,7 @@ def inject_dao_implementations() -> None:
core_dao_module.SavedQueryDAO = HostSavedQueryDAO # type: ignore[assignment,misc]
core_dao_module.TagDAO = HostTagDAO # type: ignore[assignment,misc]
core_dao_module.KeyValueDAO = HostKeyValueDAO # type: ignore[assignment,misc]
-
- core_dao_module.__all__ = [
- "DatasetDAO",
- "DatabaseDAO",
- "ChartDAO",
- "DashboardDAO",
- "UserDAO",
- "QueryDAO",
- "SavedQueryDAO",
- "TagDAO",
- "KeyValueDAO",
- ]
+ core_dao_module.TaskDAO = HostTaskDAO # type: ignore[assignment,misc]
def inject_model_implementations() -> None:
@@ -94,6 +84,7 @@ def inject_model_implementations() -> None:
from superset.models.dashboard import Dashboard as HostDashboard
from superset.models.slice import Slice as HostChart
from superset.models.sql_lab import Query as HostQuery, SavedQuery as HostSavedQuery
+ from superset.models.tasks import Task as HostTask
from superset.tags.models import Tag as HostTag
# In-place replacement - extensions will import concrete implementations
@@ -106,6 +97,7 @@ def inject_model_implementations() -> None:
core_models_module.SavedQuery = HostSavedQuery # type: ignore[misc]
core_models_module.Tag = HostTag # type: ignore[misc]
core_models_module.KeyValue = HostKeyValue # type: ignore[misc]
+ core_models_module.Task = HostTask # type: ignore[misc]
def inject_query_implementations() -> None:
@@ -124,7 +116,23 @@ def get_sqlglot_dialect(database: "Database") -> Any:
)
core_query_module.get_sqlglot_dialect = get_sqlglot_dialect
- core_query_module.__all__ = ["get_sqlglot_dialect"]
+
+
+def inject_task_implementations() -> None:
+ """
+ Replace abstract task functions in superset_core.api.tasks with concrete
+ implementations from Superset.
+ """
+ import superset_core.api.tasks as core_tasks_module
+
+ from superset.tasks.ambient_context import get_context
+ from superset.tasks.context import TaskContext
+ from superset.tasks.decorators import task
+
+ # Replace abstract classes and functions with concrete implementations
+ core_tasks_module.TaskContext = TaskContext # type: ignore[assignment,misc]
+ core_tasks_module.task = task # type: ignore[assignment]
+ core_tasks_module.get_context = get_context
def inject_rest_api_implementations() -> None:
@@ -147,7 +155,6 @@ def add_extension_api(api: "type[RestApi]") -> None:
core_rest_api_module.add_api = add_api
core_rest_api_module.add_extension_api = add_extension_api
- core_rest_api_module.__all__ = ["RestApi", "add_api", "add_extension_api"]
def inject_model_session_implementation() -> None:
@@ -163,7 +170,6 @@ def get_session() -> scoped_session:
return db.session
core_models_module.get_session = get_session
- # Update __all__ to include get_session (already done in the module)
def initialize_core_api_dependencies() -> None:
@@ -177,4 +183,5 @@ def initialize_core_api_dependencies() -> None:
inject_model_implementations()
inject_model_session_implementation()
inject_query_implementations()
+ inject_task_implementations()
inject_rest_api_implementations()
diff --git a/superset/daos/chart.py b/superset/daos/chart.py
index 56488f1c1937..44d0057eea83 100644
--- a/superset/daos/chart.py
+++ b/superset/daos/chart.py
@@ -18,7 +18,7 @@
import logging
from datetime import datetime
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -30,9 +30,6 @@
from superset.models.slice import id_or_uuid_filter, Slice
from superset.utils.core import get_user_id
-if TYPE_CHECKING:
- pass
-
logger = logging.getLogger(__name__)
# Custom filterable fields for charts
diff --git a/superset/daos/tasks.py b/superset/daos/tasks.py
new file mode 100644
index 000000000000..8253cf6d579f
--- /dev/null
+++ b/superset/daos/tasks.py
@@ -0,0 +1,470 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task DAO for Global Task Framework (GTF)"""
+
+import logging
+from datetime import datetime, timezone
+from typing import Any
+from uuid import UUID
+
+from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
+
+from superset.daos.base import BaseDAO
+from superset.daos.exceptions import DAODeleteFailedError
+from superset.extensions import db
+from superset.models.task_subscribers import TaskSubscriber
+from superset.models.tasks import Task
+from superset.tasks.constants import ABORTABLE_STATES
+from superset.tasks.filters import TaskFilter
+from superset.tasks.utils import get_active_dedup_key, json
+
+logger = logging.getLogger(__name__)
+
+
+class TaskDAO(BaseDAO[Task]):
+ """
+ Concrete TaskDAO for the Global Task Framework (GTF).
+
+ Provides database access operations for async tasks including
+ creation, status management, filtering, and subscription management
+ for shared tasks.
+ """
+
+ base_filter = TaskFilter
+
+ @classmethod
+ def get_status(cls, task_uuid: UUID) -> str | None:
+ """
+ Get only the status of a task by UUID.
+
+ This is a lightweight query that only fetches the status column,
+ optimized for polling endpoints where full entity loading is unnecessary.
+ Applies the base filter (TaskFilter) to enforce permission checks.
+
+ :param task_uuid: UUID of the task
+ :returns: Task status string, or None if task not found or not accessible
+ """
+ # Start with query on Task model so base filter can be applied
+ query = db.session.query(Task)
+ query = cls._apply_base_filter(query)
+ query = query.filter(Task.uuid == task_uuid)
+
+ # Select only the status column for efficiency
+ result = query.with_entities(Task.status).one_or_none()
+ return result[0] if result else None
+
+ @classmethod
+ def find_by_task_key(
+ cls,
+ task_type: str,
+ task_key: str,
+ scope: TaskScope | str = TaskScope.PRIVATE,
+ user_id: int | None = None,
+ ) -> Task | None:
+ """
+ Find active task by type, key, scope, and user.
+
+ Uses dedup_key internally for efficient querying with a unique index.
+ Only returns tasks that are active (pending or in progress).
+
+ Uniqueness logic by scope:
+ - private: scope + task_type + task_key + user_id
+ - shared/system: scope + task_type + task_key (user-agnostic)
+
+ :param task_type: Task type to filter by
+ :param task_key: Task identifier for deduplication
+ :param scope: Task scope (private/shared/system)
+ :param user_id: User ID (required for private tasks)
+ :returns: Task instance or None if not found or not active
+ """
+ dedup_key = get_active_dedup_key(
+ scope=scope,
+ task_type=task_type,
+ task_key=task_key,
+ user_id=user_id,
+ )
+
+ # Simple single-column query with unique index
+ return db.session.query(Task).filter(Task.dedup_key == dedup_key).one_or_none()
+
+ @classmethod
+ def create_task(
+ cls,
+ task_type: str,
+ task_key: str,
+ scope: TaskScope | str = TaskScope.PRIVATE,
+ user_id: int | None = None,
+ payload: dict[str, Any] | None = None,
+ properties: TaskProperties | None = None,
+ **kwargs: Any,
+ ) -> Task:
+ """
+ Create a new task record in the database.
+
+ This is a pure data operation - assumes caller holds lock and has
+ already checked for existing tasks. Business logic (create vs join)
+ is handled by SubmitTaskCommand.
+
+ :param task_type: Type of task to create
+ :param task_key: Task identifier (required)
+ :param scope: Task scope (private/shared/system), defaults to private
+ :param user_id: User ID creating the task
+ :param payload: Optional user-defined context data (dict)
+ :param properties: Optional framework-managed runtime state (e.g., timeout)
+ :param kwargs: Additional task attributes (e.g., task_name)
+ :returns: Created Task instance
+ """
+ # Handle both TaskScope enum and string values
+ scope_value = scope.value if isinstance(scope, TaskScope) else scope
+ scope_enum = scope if isinstance(scope, TaskScope) else TaskScope(scope)
+
+ # Validate user_id is required for private tasks
+ if scope_enum == TaskScope.PRIVATE and user_id is None:
+ raise ValueError("user_id is required for private tasks")
+
+ # Build dedup_key for active task
+ dedup_key = get_active_dedup_key(
+ scope=scope,
+ task_type=task_type,
+ task_key=task_key,
+ user_id=user_id,
+ )
+
+ # Note: properties is handled separately via update_properties()
+ task_data = {
+ "task_type": task_type,
+ "task_key": task_key,
+ "scope": scope_value,
+ "status": TaskStatus.PENDING.value,
+ "dedup_key": dedup_key,
+ **kwargs,
+ }
+
+ # Handle payload - serialize to JSON if dict provided
+ if payload:
+ task_data["payload"] = json.dumps(payload)
+
+ if user_id is not None:
+ task_data["user_id"] = user_id
+
+ task = cls.create(attributes=task_data)
+
+ # Set properties after creation via update_properties (handles caching)
+ if properties:
+ task.update_properties(properties)
+
+ # Flush to get the task ID (auto-incremented primary key)
+ db.session.flush()
+
+ # Auto-subscribe creator for all tasks
+ # This enables consistent subscriber display across all task types
+ if user_id:
+ cls.add_subscriber(task.id, user_id)
+ logger.info(
+ "Creator %s auto-subscribed to task: %s (scope: %s)",
+ user_id,
+ task_key,
+ scope_value,
+ )
+
+ logger.info(
+ "Created new async task: %s (type: %s, scope: %s)",
+ task_key,
+ task_type,
+ scope_value,
+ )
+ return task
+
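+ # Illustrative sketch (values are hypothetical; SubmitTaskCommand is the
+ # intended caller and is expected to hold the task lock):
+ #
+ # task = TaskDAO.create_task(
+ # task_type="report_generation",
+ # task_key="report-42",
+ # scope=TaskScope.PRIVATE,
+ # user_id=1,
+ # )
+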
+ @classmethod
+ def abort_task(cls, task_uuid: UUID, skip_base_filter: bool = False) -> Task | None:
+ """
+ Abort a task by UUID.
+
+ This is a pure data operation. Business logic (subscriber count checks,
+ permission validation) is handled by CancelTaskCommand which holds the lock.
+
+ Abort behavior by status:
+ - PENDING: Goes directly to ABORTED (always abortable)
+ - IN_PROGRESS with is_abortable=True: Goes to ABORTING
+ - IN_PROGRESS with is_abortable=False/None: Raises TaskNotAbortableError
+ - ABORTING: Returns task (idempotent)
+ - Finished statuses: Returns None
+
+ Note: Caller is responsible for calling TaskManager.publish_abort() AFTER
+ the transaction commits if task.status == ABORTING. This prevents race
+ conditions where listeners check the DB before the status is visible.
+
+ :param task_uuid: UUID of task to abort
+ :param skip_base_filter: If True, skip base filter (for admin abortions)
+ :returns: Task if aborted/aborting, None if not found or already finished
+ :raises TaskNotAbortableError: If in-progress task has no abort handler
+ """
+ from superset.commands.tasks.exceptions import TaskNotAbortableError
+
+ task = cls.find_one_or_none(skip_base_filter=skip_base_filter, uuid=task_uuid)
+ if not task:
+ return None
+
+ # Already aborting - idempotent success
+ if task.status == TaskStatus.ABORTING.value:
+ logger.info("Task %s is already aborting", task_uuid)
+ return task
+
+ # Already finished - cannot abort
+ if task.status not in ABORTABLE_STATES:
+ return None
+
+ # PENDING: Go directly to ABORTED
+ if task.status == TaskStatus.PENDING.value:
+ task.set_status(TaskStatus.ABORTED)
+ logger.info("Aborted pending task: %s (scope: %s)", task_uuid, task.scope)
+ return task
+
+ # IN_PROGRESS: Check if abortable
+ if task.status == TaskStatus.IN_PROGRESS.value:
+ if task.properties_dict.get("is_abortable") is not True:
+ raise TaskNotAbortableError(
+ f"Task {task_uuid} is in progress but has not registered "
+ "an abort handler (is_abortable is not true)"
+ )
+
+ # Transition to ABORTING (not ABORTED yet)
+ task.status = TaskStatus.ABORTING.value
+ db.session.merge(task)
+ logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope)
+
+ # NOTE: publish_abort is NOT called here - caller handles it after commit
+ # This prevents race conditions where listeners check DB before commit
+
+ return task
+
+ return None
+
+ # Subscription management methods
+
+ @classmethod
+ def add_subscriber(cls, task_id: int, user_id: int) -> bool:
+ """
+ Add a user as a subscriber to a task.
+
+ :param task_id: ID of the task
+ :param user_id: ID of the user to subscribe
+ :returns: True if subscriber was added, False if already exists
+ """
+ # Check first to avoid IntegrityError which invalidates the session
+ # in nested transaction contexts (IntegrityError can't be recovered from)
+ existing = (
+ db.session.query(TaskSubscriber)
+ .filter_by(task_id=task_id, user_id=user_id)
+ .first()
+ )
+ if existing:
+ logger.debug(
+ "Subscriber %s already subscribed to task %s", user_id, task_id
+ )
+ return False
+
+ subscription = TaskSubscriber(
+ task_id=task_id,
+ user_id=user_id,
+ subscribed_at=datetime.now(timezone.utc),
+ )
+ db.session.add(subscription)
+ db.session.flush()
+ logger.info("Added subscriber %s to task %s", user_id, task_id)
+ return True
+
+ @classmethod
+ def remove_subscriber(cls, task_id: int, user_id: int) -> Task | None:
+ """
+ Remove a user's subscription from a task and return the updated task.
+
+ This is a pure data operation. Business logic (whether to abort after
+ last subscriber leaves) is handled by CancelTaskCommand which holds
+ the lock and decides whether to call abort_task() separately.
+
+ :param task_id: ID of the task
+ :param user_id: ID of the user to unsubscribe
+ :returns: Updated Task if subscriber was removed, None if not subscribed
+ :raises DAODeleteFailedError: If subscription removal fails
+ """
+ subscription = (
+ db.session.query(TaskSubscriber)
+ .filter(
+ TaskSubscriber.task_id == task_id,
+ TaskSubscriber.user_id == user_id,
+ )
+ .one_or_none()
+ )
+
+ if not subscription:
+ return None
+
+ try:
+ db.session.delete(subscription)
+ db.session.flush()
+ logger.info("Removed subscriber %s from task %s", user_id, task_id)
+
+ # Return the updated task
+ task = cls.find_by_id(task_id, skip_base_filter=True)
+ if task:
+ db.session.refresh(task) # Ensure subscribers list is fresh
+ return task
+
+ except DAODeleteFailedError:
+ raise
+ except Exception as ex:
+ raise DAODeleteFailedError(
+ f"Failed to remove subscription for task {task_id}, user {user_id}"
+ ) from ex
+
+ @classmethod
+ def set_properties_and_payload(
+ cls,
+ task_uuid: UUID,
+ properties: TaskProperties | None = None,
+ payload: dict[str, Any] | None = None,
+ ) -> bool:
+ """
+ Perform a zero-read SQL UPDATE on properties and/or payload columns.
+
+ This method directly writes the provided values without reading first.
+ The caller (TaskContext) is responsible for maintaining the authoritative
+ cached state and passing complete values to write.
+
+ This method is designed for internal task updates (progress, is_abortable)
+ where the executor owns the state and doesn't need to read before writing.
+
+ IMPORTANT: This method only touches properties and payload columns.
+ It does NOT touch the status column, so it's safe to use concurrently
+ with operations that modify status (like abort).
+
+ :param task_uuid: UUID of the task to update
+ :param properties: Complete properties dict to write (replaces existing)
+ :param payload: Complete payload dict to write (replaces existing)
+ :returns: True if task was updated, False if not found or nothing to update
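+
+        Example (illustrative values; keys as documented on the Task model):
+            TaskDAO.set_properties_and_payload(
+                task_uuid,
+                properties={"is_abortable": True, "progress_percent": 0.5},
+            )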
+ """
+ if properties is None and payload is None:
+ return False
+
+ # Build update values dict - no reads, just write what caller provides
+ update_values: dict[str, Any] = {}
+
+ if properties is not None:
+ # Write complete properties (caller manages merging in their cache)
+ update_values["properties"] = json.dumps(properties)
+
+ if payload is not None:
+ # Write complete payload (payload column name matches attribute name)
+ update_values["payload"] = json.dumps(payload)
+
+ if not update_values:
+ return False
+
+ # Execute targeted UPDATE - zero read, just write
+ rows_updated = (
+ db.session.query(Task)
+ .filter(Task.uuid == task_uuid)
+ .update(update_values, synchronize_session=False)
+ )
+
+ return rows_updated > 0
+
+ @classmethod
+ def conditional_status_update(
+ cls,
+ task_uuid: UUID,
+ new_status: TaskStatus | str,
+ expected_status: TaskStatus | str | list[TaskStatus | str],
+ properties: TaskProperties | None = None,
+ set_started_at: bool = False,
+ set_ended_at: bool = False,
+ ) -> bool:
+ """
+ Atomically update task status only if current status matches expected.
+
+ This provides atomic compare-and-swap semantics for status transitions,
+ preventing race conditions between executor status updates and concurrent
+ abort operations. Uses a single UPDATE with WHERE clause for atomicity.
+
+ Use cases:
+ - Executor transitioning IN_PROGRESS → SUCCESS (only if not ABORTING)
+ - Executor transitioning ABORTING → ABORTED/TIMED_OUT (cleanup complete)
+ - Initial PENDING → IN_PROGRESS (task pickup)
+
+ :param task_uuid: UUID of the task to update
+ :param new_status: Target status to set
+ :param expected_status: Current status(es) required for update to succeed.
+ Can be a single status or list of statuses.
+ :param properties: Optional properties to update atomically with status
+ :param set_started_at: If True, also set started_at to current timestamp
+ :param set_ended_at: If True, also set ended_at to current timestamp
+ :returns: True if status was updated (expected matched), False otherwise
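+
+        Example (illustrative, mirroring the executor use case above):
+            # Only mark SUCCESS if the task was not moved to ABORTING meanwhile
+            TaskDAO.conditional_status_update(
+                task_uuid,
+                TaskStatus.SUCCESS,
+                expected_status=TaskStatus.IN_PROGRESS,
+                set_ended_at=True,
+            )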
+ """
+ # Normalize status values
+ new_status_val = (
+ new_status.value if isinstance(new_status, TaskStatus) else new_status
+ )
+
+ # Build list of expected status values
+ if isinstance(expected_status, list):
+ expected_vals = [
+ s.value if isinstance(s, TaskStatus) else s for s in expected_status
+ ]
+ else:
+ expected_vals = [
+ expected_status.value
+ if isinstance(expected_status, TaskStatus)
+ else expected_status
+ ]
+
+ # Build update values
+ update_values: dict[str, Any] = {"status": new_status_val}
+
+ if properties is not None:
+ update_values["properties"] = json.dumps(properties)
+
+ if set_started_at:
+ update_values["started_at"] = datetime.now(timezone.utc)
+
+ if set_ended_at:
+ update_values["ended_at"] = datetime.now(timezone.utc)
+
+ # Atomic compare-and-swap: only update if status matches expected
+ rows_updated = (
+ db.session.query(Task)
+ .filter(Task.uuid == task_uuid, Task.status.in_(expected_vals))
+ .update(update_values, synchronize_session=False)
+ )
+
+ if rows_updated > 0:
+ logger.debug(
+ "Conditional status update succeeded: %s -> %s (expected: %s)",
+ task_uuid,
+ new_status_val,
+ expected_vals,
+ )
+ else:
+ logger.debug(
+ "Conditional status update skipped: %s -> %s "
+ "(current status not in expected: %s)",
+ task_uuid,
+ new_status_val,
+ expected_vals,
+ )
+
+ return rows_updated > 0
diff --git a/superset/distributed_lock/__init__.py b/superset/distributed_lock/__init__.py
index 4374e85aa672..69ec3d825384 100644
--- a/superset/distributed_lock/__init__.py
+++ b/superset/distributed_lock/__init__.py
@@ -17,61 +17,46 @@
from __future__ import annotations
-import logging
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
-from datetime import timedelta
from typing import Any
from superset.distributed_lock.utils import get_key
-from superset.exceptions import CreateKeyValueDistributedLockFailedException
-from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
-
-logger = logging.getLogger(__name__)
-
-CODEC = JsonKeyValueCodec()
-LOCK_EXPIRATION = timedelta(seconds=30)
-RESOURCE = KeyValueResource.LOCK
@contextmanager
-def KeyValueDistributedLock( # pylint: disable=invalid-name # noqa: N802
+def DistributedLock( # noqa: N802
namespace: str,
+ ttl_seconds: int | None = None,
**kwargs: Any,
) -> Iterator[uuid.UUID]:
"""
- KV global lock for refreshing tokens.
-
- This context manager acquires a distributed lock for a given namespace, with
- optional parameters (eg, namespace="cache", user_id=1). It yields a UUID for the
- lock that can be used within the context, and corresponds to the key in the KV
- store.
-
- :param namespace: The namespace for which the lock is to be acquired.
- :param kwargs: Additional keyword arguments.
- :yields: A unique identifier (UUID) for the acquired lock (the KV key).
- :raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
+ Distributed lock for coordinating operations across workers.
+
+ Automatically uses Redis-based locking when SIGNAL_CACHE_CONFIG is
+ configured, falling back to database-backed locking otherwise.
+
+ Redis locking uses SET NX EX for atomic acquisition with automatic expiration.
+ Database locking uses the KeyValue table with manual expiration cleanup.
+
+ :param namespace: Lock namespace for grouping related locks
+ :param ttl_seconds: Lock TTL in seconds. Defaults to 30 seconds.
+ After expiration, the lock is automatically released
+ to prevent deadlocks from crashed processes.
+ :param kwargs: Additional key parameters to differentiate locks
+ :yields: UUID identifying this lock acquisition
+ :raises AcquireDistributedLockFailedException: If lock is already held
+ or Redis connection fails
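+
+    Example (illustrative namespace and kwargs):
+        with DistributedLock("cache", ttl_seconds=60, user_id=1) as lock_key:
+            ...  # critical section; lock is released on exit, even on error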
"""
-
# pylint: disable=import-outside-toplevel
- from superset.commands.distributed_lock.create import CreateDistributedLock
- from superset.commands.distributed_lock.delete import DeleteDistributedLock
- from superset.commands.distributed_lock.get import GetDistributedLock
+ from superset.commands.distributed_lock.acquire import AcquireDistributedLock
+ from superset.commands.distributed_lock.release import ReleaseDistributedLock
key = get_key(namespace, **kwargs)
- value = GetDistributedLock(namespace=namespace, params=kwargs).run()
- if value:
- logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
- raise CreateKeyValueDistributedLockFailedException("Lock already taken")
- logger.debug("Acquiring lock on namespace %s for key %s", namespace, key)
+ AcquireDistributedLock(namespace, kwargs, ttl_seconds).run()
try:
- CreateDistributedLock(namespace=namespace, params=kwargs).run()
- except CreateKeyValueDistributedLockFailedException as ex:
- logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
- raise CreateKeyValueDistributedLockFailedException("Lock already taken") from ex
-
- yield key
- DeleteDistributedLock(namespace=namespace, params=kwargs).run()
- logger.debug("Removed lock on namespace %s for key %s", namespace, key)
+ yield key
+ finally:
+ ReleaseDistributedLock(namespace, kwargs).run()
diff --git a/superset/exceptions.py b/superset/exceptions.py
index fabbbe133477..3a81a249c477 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -414,15 +414,15 @@ def __init__(self, tables: set[str]):
)
-class CreateKeyValueDistributedLockFailedException(Exception): # noqa: N818
+class AcquireDistributedLockFailedException(Exception): # noqa: N818
"""
Exception to signalize failure to acquire lock.
"""
-class DeleteKeyValueDistributedLockFailedException(Exception): # noqa: N818
+class ReleaseDistributedLockFailedException(Exception): # noqa: N818
"""
- Exception to signalize failure to delete lock.
+ Exception to signalize failure to release lock.
"""
diff --git a/superset/extensions/local_extensions_watcher.py b/superset/extensions/local_extensions_watcher.py
index 4a5c9a0501ba..6233f91fe5cf 100644
--- a/superset/extensions/local_extensions_watcher.py
+++ b/superset/extensions/local_extensions_watcher.py
@@ -23,13 +23,10 @@
import threading
import time
from pathlib import Path
-from typing import Any, TYPE_CHECKING
+from typing import Any
from flask import Flask
-if TYPE_CHECKING:
- pass
-
logger = logging.getLogger(__name__)
# Guard to prevent multiple initializations
diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py
index 3a34d315bf52..c3dacc7d6ea2 100644
--- a/superset/initialization/__init__.py
+++ b/superset/initialization/__init__.py
@@ -218,6 +218,7 @@ def init_views(self) -> None:
)
from superset.views.sqllab import SqllabView
from superset.views.tags import TagModelView, TagView
+ from superset.views.tasks import TaskModelView
from superset.views.themes import ThemeModelView
from superset.views.user_info import UserInfoView
from superset.views.user_registrations import UserRegistrationsView
@@ -275,6 +276,11 @@ def init_views(self) -> None:
appbuilder.add_api(ExtensionsRestApi)
+ if feature_flag_manager.is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
+ from superset.tasks.api import TaskRestApi
+
+ appbuilder.add_api(TaskRestApi)
+
#
# Setup regular views
#
@@ -408,6 +414,18 @@ def init_views(self) -> None:
),
)
+ appbuilder.add_view(
+ TaskModelView,
+ "Tasks",
+ label=_("Tasks"),
+ icon="fa-clock-o",
+ category="Manage",
+ category_label=_("Manage"),
+ menu_cond=lambda: feature_flag_manager.is_feature_enabled(
+ "GLOBAL_TASK_FRAMEWORK"
+ ),
+ )
+
#
# Setup views with no menu
#
@@ -588,6 +606,7 @@ def init_app_in_ctx(self) -> None:
self.configure_async_queries()
self.configure_ssh_manager()
self.configure_stats_manager()
+ self.configure_task_manager()
# Hook that provides administrators a handle on the Flask APP
# after initialization
@@ -928,6 +947,13 @@ def configure_async_queries(self) -> None:
if feature_flag_manager.is_feature_enabled("GLOBAL_ASYNC_QUERIES"):
async_query_manager_factory.init_app(self.superset_app)
+ def configure_task_manager(self) -> None:
+ """Initialize the TaskManager for GTF realtime notifications."""
+ if feature_flag_manager.is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
+ from superset.tasks.manager import TaskManager
+
+ TaskManager.init_app(self.superset_app)
+
def register_blueprints(self) -> None:
# Register custom blueprints from config
for bp in self.config["BLUEPRINTS"]:
diff --git a/superset/migrations/versions/2025_12_18_0220_create_tasks_table.py b/superset/migrations/versions/2025_12_18_0220_create_tasks_table.py
new file mode 100644
index 000000000000..ade179975e34
--- /dev/null
+++ b/superset/migrations/versions/2025_12_18_0220_create_tasks_table.py
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Create tasks and task_subscriber tables for Global Task Framework (GTF)
+
+Revision ID: 4b2a8c9d3e1f
+Revises: 9787190b3d89
+Create Date: 2025-12-18 02:20:00.000000
+
+"""
+
+from sqlalchemy import (
+ Column,
+ DateTime,
+ Integer,
+ String,
+ Text,
+ UniqueConstraint,
+)
+from sqlalchemy_utils import UUIDType
+
+from superset.migrations.shared.utils import (
+ create_fks_for_table,
+ create_index,
+ create_table,
+ drop_fks_for_table,
+ drop_index,
+ drop_table,
+)
+
+# revision identifiers, used by Alembic.
+revision = "4b2a8c9d3e1f"
+down_revision = "9787190b3d89"
+
+TASKS_TABLE = "tasks"
+TASK_SUBSCRIBERS_TABLE = "task_subscribers"
+
+
+def upgrade():
+ """
+ Create tasks and task_subscribers tables for the Global Task Framework (GTF).
+
+ This migration creates:
+ 1. tasks table - unified tracking for all long running tasks
+ 2. task_subscribers table - multi-user task subscriptions for shared tasks
+
+ The scope feature allows tasks to be:
+ - private: user-specific (default)
+ - shared: multi-user collaborative tasks
+ - system: admin-only background tasks
+ """
+ # Create tasks table
+ create_table(
+ TASKS_TABLE,
+ Column("id", Integer, primary_key=True),
+ Column("uuid", UUIDType(binary=True), nullable=False, unique=True),
+ Column("task_key", String(256), nullable=False),
+ Column("task_type", String(100), nullable=False),
+ Column("task_name", String(256), nullable=True),
+ Column("scope", String(20), nullable=False, server_default="private"),
+ Column("status", String(50), nullable=False),
+ Column("dedup_key", String(64), nullable=False),
+ # AuditMixinNullable columns
+ Column("created_on", DateTime, nullable=True),
+ Column("changed_on", DateTime, nullable=True),
+ Column("created_by_fk", Integer, nullable=True),
+ Column("changed_by_fk", Integer, nullable=True),
+ # Task-specific columns
+ Column("started_at", DateTime, nullable=True),
+ Column("ended_at", DateTime, nullable=True),
+ Column("user_id", Integer, nullable=True),
+ Column("payload", Text, nullable=True),
+ Column("properties", Text, nullable=True),
+ )
+
+ # Create indexes for optimal query performance
+ create_index(TASKS_TABLE, "idx_tasks_dedup_key", ["dedup_key"], unique=True)
+ create_index(TASKS_TABLE, "idx_tasks_status", ["status"])
+ create_index(TASKS_TABLE, "idx_tasks_scope", ["scope"])
+ create_index(TASKS_TABLE, "idx_tasks_ended_at", ["ended_at"])
+ create_index(TASKS_TABLE, "idx_tasks_created_by", ["created_by_fk"])
+ create_index(TASKS_TABLE, "idx_tasks_created_on", ["created_on"])
+ create_index(TASKS_TABLE, "idx_tasks_task_key", ["task_key"])
+ create_index(TASKS_TABLE, "idx_tasks_task_type", ["task_type"])
+ create_index(TASKS_TABLE, "idx_tasks_uuid", ["uuid"], unique=True)
+
+ # Create foreign key constraints for tasks
+ create_fks_for_table(
+ foreign_key_name="fk_tasks_created_by_fk_ab_user",
+ table_name=TASKS_TABLE,
+ referenced_table="ab_user",
+ local_cols=["created_by_fk"],
+ remote_cols=["id"],
+ ondelete="SET NULL",
+ )
+
+ create_fks_for_table(
+ foreign_key_name="fk_tasks_changed_by_fk_ab_user",
+ table_name=TASKS_TABLE,
+ referenced_table="ab_user",
+ local_cols=["changed_by_fk"],
+ remote_cols=["id"],
+ ondelete="SET NULL",
+ )
+
+ create_fks_for_table(
+ foreign_key_name="fk_tasks_user_id_ab_user",
+ table_name=TASKS_TABLE,
+ referenced_table="ab_user",
+ local_cols=["user_id"],
+ remote_cols=["id"],
+ ondelete="SET NULL",
+ )
+
+ # Create task_subscribers table for multi-user task subscriptions
+ create_table(
+ TASK_SUBSCRIBERS_TABLE,
+ Column("id", Integer, primary_key=True),
+ Column("task_id", Integer, nullable=False),
+ Column("user_id", Integer, nullable=False),
+ Column("subscribed_at", DateTime, nullable=False),
+ # AuditMixinNullable columns
+ Column("created_on", DateTime, nullable=True),
+ Column("created_by_fk", Integer, nullable=True),
+ Column("changed_on", DateTime, nullable=True),
+ Column("changed_by_fk", Integer, nullable=True),
+ # Unique constraint defined as part of table creation (SQLite compatible)
+ UniqueConstraint("task_id", "user_id", name="uq_task_subscribers_task_user"),
+ )
+
+ # Create indexes for task_subscribers table
+ create_index(TASK_SUBSCRIBERS_TABLE, "idx_task_subscribers_user_id", ["user_id"])
+
+ # Create foreign key constraints for task_subscribers
+ create_fks_for_table(
+ foreign_key_name="fk_task_subscribers_task_id_tasks",
+ table_name=TASK_SUBSCRIBERS_TABLE,
+ referenced_table=TASKS_TABLE,
+ local_cols=["task_id"],
+ remote_cols=["id"],
+ ondelete="CASCADE",
+ )
+
+ create_fks_for_table(
+ foreign_key_name="fk_task_subscribers_user_id_ab_user",
+ table_name=TASK_SUBSCRIBERS_TABLE,
+ referenced_table="ab_user",
+ local_cols=["user_id"],
+ remote_cols=["id"],
+ ondelete="CASCADE",
+ )
+
+ create_fks_for_table(
+ foreign_key_name="fk_task_subscribers_created_by_fk_ab_user",
+ table_name=TASK_SUBSCRIBERS_TABLE,
+ referenced_table="ab_user",
+ local_cols=["created_by_fk"],
+ remote_cols=["id"],
+ ondelete="SET NULL",
+ )
+
+ create_fks_for_table(
+ foreign_key_name="fk_task_subscribers_changed_by_fk_ab_user",
+ table_name=TASK_SUBSCRIBERS_TABLE,
+ referenced_table="ab_user",
+ local_cols=["changed_by_fk"],
+ remote_cols=["id"],
+ ondelete="SET NULL",
+ )
+
+
+def downgrade():
+ """
+ Drop tasks and task_subscribers tables and all related indexes and foreign keys.
+ """
+ drop_fks_for_table(
+ TASK_SUBSCRIBERS_TABLE,
+ [
+ "fk_task_subscribers_task_id_tasks",
+ "fk_task_subscribers_user_id_ab_user",
+ "fk_task_subscribers_created_by_fk_ab_user",
+ "fk_task_subscribers_changed_by_fk_ab_user",
+ ],
+ )
+
+ drop_index(TASK_SUBSCRIBERS_TABLE, "idx_task_subscribers_user_id")
+ drop_table(TASK_SUBSCRIBERS_TABLE)
+
+ drop_fks_for_table(
+ TASKS_TABLE,
+ [
+ "fk_tasks_created_by_fk_ab_user",
+ "fk_tasks_changed_by_fk_ab_user",
+ "fk_tasks_user_id_ab_user",
+ ],
+ )
+
+ drop_index(TASKS_TABLE, "idx_tasks_dedup_key")
+ drop_index(TASKS_TABLE, "idx_tasks_status")
+ drop_index(TASKS_TABLE, "idx_tasks_scope")
+ drop_index(TASKS_TABLE, "idx_tasks_ended_at")
+ drop_index(TASKS_TABLE, "idx_tasks_created_by")
+ drop_index(TASKS_TABLE, "idx_tasks_created_on")
+ drop_index(TASKS_TABLE, "idx_tasks_task_key")
+ drop_index(TASKS_TABLE, "idx_tasks_task_type")
+ drop_index(TASKS_TABLE, "idx_tasks_uuid")
+
+ drop_table(TASKS_TABLE)
diff --git a/superset/models/task_subscribers.py b/superset/models/task_subscribers.py
new file mode 100644
index 000000000000..509f5fd17f6f
--- /dev/null
+++ b/superset/models/task_subscribers.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""TaskSubscriber model for tracking multi-user task subscriptions"""
+
+from datetime import datetime, timezone
+
+from flask_appbuilder import Model
+from sqlalchemy import Column, DateTime, ForeignKey, Integer, UniqueConstraint
+from sqlalchemy.orm import relationship
+from superset_core.api.models import TaskSubscriber as CoreTaskSubscriber
+
+from superset.models.helpers import AuditMixinNullable
+
+
+class TaskSubscriber(CoreTaskSubscriber, AuditMixinNullable, Model):
+ """
+ Model for tracking task subscriptions in shared tasks.
+
+ This model enables multi-user collaboration on async tasks. When a user
+ schedules a shared task with the same parameters as an existing task,
+ they are automatically subscribed to that task instead of creating a
+ duplicate.
+
+ Subscribers can unsubscribe from shared tasks. When the last subscriber
+ unsubscribes, the task is automatically aborted.
+ """
+
+ __tablename__ = "task_subscribers"
+
+ id = Column(Integer, primary_key=True)
+ task_id = Column(
+ Integer, ForeignKey("tasks.id", ondelete="CASCADE"), nullable=False
+ )
+ user_id = Column(
+ Integer, ForeignKey("ab_user.id", ondelete="CASCADE"), nullable=False
+ )
+    # Use a callable so the timestamp is evaluated per row at insert time,
+    # not once at import time
+    subscribed_at = Column(
+        DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)
+    )
+
+ # Relationships
+ task = relationship("Task", back_populates="subscribers")
+ user = relationship("User", foreign_keys=[user_id], lazy="joined")
+
+ __table_args__ = (
+ UniqueConstraint("task_id", "user_id", name="uq_task_subscribers_task_user"),
+ )
+
+ def __repr__(self) -> str:
+        return f"<TaskSubscriber task_id={self.task_id} user_id={self.user_id}>"
diff --git a/superset/models/tasks.py b/superset/models/tasks.py
new file mode 100644
index 000000000000..6c6995e9563e
--- /dev/null
+++ b/superset/models/tasks.py
@@ -0,0 +1,367 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task model for Global Task Framework (GTF)"""
+
+from __future__ import annotations
+
+import uuid as uuid_module
+from datetime import datetime, timezone
+from typing import Any, cast
+
+from flask_appbuilder import Model
+from sqlalchemy import (
+ Column,
+ DateTime,
+ Integer,
+ String,
+ Text,
+)
+from sqlalchemy.orm import relationship
+from sqlalchemy_utils import UUIDType
+from superset_core.api.models import Task as CoreTask
+from superset_core.api.tasks import TaskProperties, TaskStatus
+
+from superset.models.helpers import AuditMixinNullable
+from superset.models.task_subscribers import TaskSubscriber
+from superset.tasks.utils import (
+ error_update,
+ get_finished_dedup_key,
+ parse_properties,
+ serialize_properties,
+)
+from superset.utils import json
+
+
+class Task(CoreTask, AuditMixinNullable, Model):
+ """
+ Concrete Task model for the Global Task Framework (GTF).
+
+ This model represents async tasks in Superset, providing unified tracking
+ for all background operations including SQL queries, thumbnail generation,
+ reports, and other async operations.
+
+ Non-filterable fields (progress, error info, execution config) are stored
+ in a `properties` JSON blob for schema flexibility.
+ """
+
+ __tablename__ = "tasks"
+
+ # Primary key and identifiers
+ id = Column(Integer, primary_key=True)
+ uuid = Column(
+ UUIDType(binary=True), nullable=False, unique=True, default=uuid_module.uuid4
+ )
+
+ # Task metadata (filterable)
+ task_key = Column(String(256), nullable=False, index=True) # For deduplication
+ task_type = Column(String(100), nullable=False, index=True) # e.g., 'sql_execution'
+ task_name = Column(String(256), nullable=True) # Human readable name
+ scope = Column(
+ String(20), nullable=False, index=True, default="private"
+ ) # private/shared/system
+ status = Column(
+ String(50), nullable=False, index=True, default=TaskStatus.PENDING.value
+ )
+ dedup_key = Column(
+ String(64), nullable=False, unique=True, index=True
+ ) # Hashed deduplication key (SHA-256 = 64 chars, UUID = 36 chars)
+
+ # Timestamps
+ started_at = Column(DateTime, nullable=True)
+ ended_at = Column(DateTime, nullable=True)
+
+ # User context for execution
+ user_id = Column(Integer, nullable=True)
+
+ # Task-specific output data (set by task code via ctx.update_task(payload=...))
+ payload = Column(Text, nullable=True, default="{}")
+
+ # Properties JSON blob - contains runtime state and execution config:
+ # - is_abortable: bool - has abort handler registered
+ # - progress_percent: float - progress 0.0-1.0
+ # - progress_current: int - current iteration count
+ # - progress_total: int - total iterations
+ # - error_message: str - human-readable error message
+ # - exception_type: str - exception class name
+ # - stack_trace: str - full formatted traceback
+ # - timeout: int - timeout in seconds
+ properties = Column(Text, nullable=True, default="{}")
+
+ # Relationships
+ # Use lazy="selectin" to avoid N+1 queries when listing tasks with subscribers
+ subscribers = relationship(
+ TaskSubscriber,
+ back_populates="task",
+ cascade="all, delete-orphan",
+ lazy="selectin",
+ )
+
+ def __repr__(self) -> str:
+        return f"<Task uuid={self.uuid} type={self.task_type} status={self.status}>"
+
+ # -------------------------------------------------------------------------
+ # Properties accessor
+ # -------------------------------------------------------------------------
+
+ @property
+ def properties_dict(self) -> TaskProperties:
+ """
+ Get typed properties.
+
+ Properties contain runtime state and execution config that doesn't
+ need database filtering. Always use .get() for reads since keys may
+ be absent.
+
+ :returns: TaskProperties dict (sparse - only contains keys that were set)
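+
+        Example (illustrative):
+            if task.properties_dict.get("is_abortable"):
+                ...  # an abort handler has been registered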
+ """
+ return parse_properties(self.properties)
+
+ def update_properties(self, updates: TaskProperties) -> None:
+ """
+ Update specific properties fields (merge semantics).
+
+ Only updates fields present in the updates dict.
+
+ :param updates: TaskProperties dict with fields to update
+
+ Example:
+ task.update_properties({"is_abortable": True})
+ task.update_properties(progress_update((50, 100)))
+ """
+ current = cast(TaskProperties, dict(self.properties_dict))
+ current.update(updates) # Merge updates
+ self.properties = serialize_properties(current)
+
+ # -------------------------------------------------------------------------
+ # Payload accessor (for task-specific output data)
+ # -------------------------------------------------------------------------
+
+ @property
+ def payload_dict(self) -> dict[str, Any]:
+ """
+ Get payload as parsed JSON.
+
+ Payload contains task-specific output data set by task code via
+ ctx.update_task(payload=...).
+
+ :returns: Dictionary containing payload data
+ """
+ try:
+ return json.loads(self.payload or "{}")
+ except (json.JSONDecodeError, TypeError):
+ return {}
+
+ def set_payload(self, data: dict[str, Any]) -> None:
+ """
+ Update payload with new data.
+
+ The payload is merged with existing data, not replaced.
+
+ :param data: Dictionary of data to merge into payload
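+
+        Example (illustrative keys):
+            task.set_payload({"result_key": "abc"})
+            task.set_payload({"row_count": 42})  # merged; "result_key" is kept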
+ """
+ current = self.payload_dict
+ current.update(data)
+ self.payload = json.dumps(current)
+
+ # -------------------------------------------------------------------------
+ # Error handling
+ # -------------------------------------------------------------------------
+
+ def set_error_from_exception(self, exception: BaseException) -> None:
+ """
+ Set error fields from an exception.
+
+ Captures the error message, exception type, and full stack trace.
+ Called automatically by the executor when a task raises an exception.
+
+ :param exception: The exception that caused the failure
+ """
+ self.update_properties(error_update(exception))
+
+ # -------------------------------------------------------------------------
+ # Status management
+ # -------------------------------------------------------------------------
+
+ def set_status(self, status: TaskStatus | str) -> None:
+ """
+ Update task status and dedup_key.
+
+ When a task finishes (success, failure, or abort), the dedup_key is
+ changed to the task's UUID. This frees up the slot so new tasks with
+ the same parameters can be created.
+
+ :param status: New task status
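+
+        Example (illustrative):
+            task.set_status(TaskStatus.IN_PROGRESS)  # sets started_at
+            task.set_status(TaskStatus.SUCCESS)  # sets ended_at, frees dedup_key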
+ """
+ if isinstance(status, TaskStatus):
+ status = status.value
+ self.status = status
+
+ # Update timestamps and is_abortable based on status
+ now = datetime.now(timezone.utc)
+ if status == TaskStatus.IN_PROGRESS.value and not self.started_at:
+ self.started_at = now
+ # Set is_abortable to False when task starts executing
+ # (will be set to True if/when an abort handler is registered)
+ if self.properties_dict.get("is_abortable") is None:
+ self.update_properties({"is_abortable": False})
+ elif status in [
+ TaskStatus.SUCCESS.value,
+ TaskStatus.FAILURE.value,
+ TaskStatus.ABORTED.value,
+ TaskStatus.TIMED_OUT.value,
+ ]:
+ if not self.ended_at:
+ self.ended_at = now
+ # Update dedup_key to UUID to free up the slot for new tasks
+ self.dedup_key = get_finished_dedup_key(self.uuid)
+ # Note: ABORTING status doesn't set ended_at yet - that happens when
+ # the task transitions to ABORTED after handlers complete
+
+ @property
+ def is_pending(self) -> bool:
+ """Check if task is pending."""
+ return self.status == TaskStatus.PENDING.value
+
+ @property
+ def is_running(self) -> bool:
+ """Check if task is currently running."""
+ return self.status == TaskStatus.IN_PROGRESS.value
+
+ @property
+ def is_finished(self) -> bool:
+ """Check if task has finished (success, failure, aborted, or timed out)."""
+ return self.status in [
+ TaskStatus.SUCCESS.value,
+ TaskStatus.FAILURE.value,
+ TaskStatus.ABORTED.value,
+ TaskStatus.TIMED_OUT.value,
+ ]
+
+ @property
+ def is_successful(self) -> bool:
+ """Check if task completed successfully."""
+ return self.status == TaskStatus.SUCCESS.value
+
+ @property
+ def duration_seconds(self) -> float | None:
+ """
+ Get task duration in seconds.
+
+ - Finished tasks: Time from started_at to ended_at (None if never started)
+ - Running/aborting tasks: Time from started_at to now
+ - Pending tasks: Time from created_on to now (queue time)
+
+ Note: started_at/ended_at are stored in UTC, but created_on from
+ AuditMixinNullable is stored as naive local time. We handle both cases.
+ """
+ if self.is_finished:
+ # Task has completed - use fixed timestamps, never increment
+ if self.started_at and self.ended_at:
+ # Finished task - both timestamps use the same timezone (UTC)
+ # Just compute the difference directly
+ return (self.ended_at - self.started_at).total_seconds()
+ # Never started (e.g., aborted while pending) - no duration
+ return None
+ elif self.started_at:
+ # Running or aborting - started_at is UTC (set by set_status)
+ # Use UTC now for comparison
+ now = datetime.now(timezone.utc)
+ started = (
+ self.started_at.replace(tzinfo=timezone.utc)
+ if self.started_at.tzinfo is None
+ else self.started_at
+ )
+ return (now - started).total_seconds()
+ elif self.created_on:
+ # Pending - created_on is naive LOCAL time (from AuditMixinNullable)
+ # Use naive local time for comparison
+ now = datetime.now() # Local time, no timezone
+ created = (
+ self.created_on.replace(tzinfo=None)
+ if self.created_on.tzinfo is not None
+ else self.created_on
+ )
+ return (now - created).total_seconds()
+ return None
+
+ # Scope-related properties
+ @property
+ def is_private(self) -> bool:
+ """Check if task is private (user-specific)."""
+ return self.scope == "private"
+
+ @property
+ def is_shared(self) -> bool:
+ """Check if task is shared (multi-user)."""
+ return self.scope == "shared"
+
+ @property
+ def is_system(self) -> bool:
+ """Check if task is system (admin-only)."""
+ return self.scope == "system"
+
+ # Subscriber-related methods
+ @property
+ def subscriber_count(self) -> int:
+ """Get number of subscribers to this task."""
+ return len(self.subscribers)
+
+ def has_subscriber(self, user_id: int) -> bool:
+ """
+ Check if a user is subscribed to this task.
+
+ :param user_id: User ID to check
+ :returns: True if user is subscribed
+ """
+ return any(sub.user_id == user_id for sub in self.subscribers)
+
+ def get_subscriber_ids(self) -> list[int]:
+ """
+ Get list of all subscriber user IDs.
+
+ :returns: List of user IDs subscribed to this task
+ """
+ return [sub.user_id for sub in self.subscribers]
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Convert task to dictionary representation.
+
+ Minimal API payload - frontend derives status booleans and abort logic
+ from status and properties.is_abortable.
+
+ :returns: Dictionary representation of the task
+ """
+ return {
+ "id": self.id,
+ "uuid": str(self.uuid),
+ "task_key": self.task_key,
+ "task_type": self.task_type,
+ "task_name": self.task_name,
+ "scope": self.scope,
+ "status": self.status,
+ "created_on": self.created_on.isoformat() if self.created_on else None,
+ "changed_on": self.changed_on.isoformat() if self.changed_on else None,
+ "started_at": self.started_at.isoformat() if self.started_at else None,
+ "ended_at": self.ended_at.isoformat() if self.ended_at else None,
+ "created_by_fk": self.created_by_fk,
+ "user_id": self.user_id,
+ "payload": self.payload_dict,
+ "properties": self.properties_dict,
+ "subscriber_count": self.subscriber_count,
+ "subscriber_ids": self.get_subscriber_ids(),
+ }
diff --git a/superset/sql/execution/celery_task.py b/superset/sql/execution/celery_task.py
index c5dd1e9f1c9f..4e68d3d9eb46 100644
--- a/superset/sql/execution/celery_task.py
+++ b/superset/sql/execution/celery_task.py
@@ -26,7 +26,7 @@
import dataclasses
import logging
import uuid
-from typing import Any, TYPE_CHECKING
+from typing import Any
import msgpack
from celery.exceptions import SoftTimeLimitExceeded
@@ -56,9 +56,6 @@
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing
-if TYPE_CHECKING:
- pass
-
logger = logging.getLogger(__name__)
BYTES_IN_MB = 1024 * 1024
diff --git a/superset/tasks/ambient_context.py b/superset/tasks/ambient_context.py
new file mode 100644
index 000000000000..90618149bdef
--- /dev/null
+++ b/superset/tasks/ambient_context.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Ambient context management for the Global Task Framework (GTF)"""
+
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Iterator
+
+from superset.tasks.context import TaskContext
+
+# Global context variable for ambient context pattern
+# This is thread-safe and async-safe via Python's contextvars
+_current_context: ContextVar[TaskContext | None] = ContextVar(
+ "task_context", default=None
+)
+
+
+def get_context() -> TaskContext:
+ """
+ Get the current task context from contextvars.
+
+ This function provides ambient access to the task context without
+ requiring it to be passed as a parameter. It can only be called
+ from within a task execution.
+
+ :returns: The current TaskContext
+ :raises RuntimeError: If called outside a task execution context
+
+ Example:
+ >>> @task()
+ >>> def my_task(chart_id: int) -> None:
+ >>> ctx = get_context() # Access ambient context
+ >>>
+ >>> # Update progress and payload atomically
+ >>> ctx.update_task(
+ >>> progress=0.5,
+ >>> payload={"chart_id": chart_id}
+ >>> )
+ """
+ ctx = _current_context.get()
+ if ctx is None:
+ raise RuntimeError(
+ "get_context() called outside task execution context. "
+ "This function can only be called from within a @task "
+ "decorated function."
+ )
+ return ctx
+
+
+@contextmanager
+def use_context(ctx: TaskContext) -> Iterator[None]:
+ """
+ Context manager to set ambient context for task execution.
+
+ This is used internally by the framework to establish the ambient context
+ before executing a task function. The context is automatically cleaned up
+ after execution, even if the task raises an exception.
+
+ :param ctx: TaskContext to set as the current context
+ :yields: None
+
+ Example (internal framework use):
+        >>> ctx = TaskContext(task)  # takes the pre-fetched Task entity
+ >>> with use_context(ctx):
+ >>> # Task function can now call get_context()
+ >>> task_function(*args, **kwargs)
+ >>> # Context automatically reset after execution
+ """
+ token = _current_context.set(ctx)
+ try:
+ yield
+ finally:
+ _current_context.reset(token)
diff --git a/superset/tasks/api.py b/superset/tasks/api.py
new file mode 100644
index 000000000000..1870d963b670
--- /dev/null
+++ b/superset/tasks/api.py
@@ -0,0 +1,471 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task REST API"""
+
+import logging
+from uuid import UUID
+
+from flask import Response
+from flask_appbuilder.api import expose, protect, safe
+from flask_appbuilder.models.sqla.interface import SQLAInterface
+
+from superset.commands.tasks.cancel import CancelTaskCommand
+from superset.commands.tasks.exceptions import (
+ TaskAbortFailedError,
+ TaskForbiddenError,
+ TaskNotAbortableError,
+ TaskNotFoundError,
+ TaskPermissionDeniedError,
+)
+from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
+from superset.extensions import event_logger
+from superset.models.tasks import Task
+from superset.tasks.filters import TaskFilter
+from superset.tasks.schemas import (
+ openapi_spec_methods_override,
+ TaskCancelRequestSchema,
+ TaskCancelResponseSchema,
+ TaskResponseSchema,
+ TaskStatusResponseSchema,
+)
+from superset.views.base_api import (
+ BaseSupersetModelRestApi,
+ RelatedFieldFilter,
+ statsd_metrics,
+)
+from superset.views.filters import BaseFilterRelatedUsers, FilterRelatedOwners
+
+logger = logging.getLogger(__name__)
+
+
+class TaskRestApi(BaseSupersetModelRestApi):
+ """REST API for task management"""
+
+ datamodel = SQLAInterface(Task)
+ resource_name = "task"
+ allow_browser_login = True
+
+ class_permission_name = "Task"
+
+ # Map cancel and status to write/read permissions
+ method_permission_name = {
+ **MODEL_API_RW_METHOD_PERMISSION_MAP,
+ "cancel": "write",
+ "status": "read",
+ }
+
+ # Only allow read operations - no create/update/delete through REST API
+ # Tasks are created via SubmitTaskCommand, cancelled via /cancel endpoint
+ include_route_methods = {
+ RouteMethod.GET,
+ RouteMethod.GET_LIST,
+ RouteMethod.INFO,
+ "cancel",
+ "status",
+ "related_subscribers",
+ "related",
+ }
+
+ list_columns = [
+ "id",
+ "uuid",
+ "task_type",
+ "task_key",
+ "task_name",
+ "scope",
+ "status",
+ "created_on",
+ "created_on_delta_humanized",
+ "changed_on",
+ "changed_by.first_name",
+ "changed_by.last_name",
+ "started_at",
+ "ended_at",
+ "created_by.id",
+ "created_by.first_name",
+ "created_by.last_name",
+ "user_id",
+ "payload",
+ "properties",
+ "duration_seconds",
+ "subscriber_count",
+ "subscribers",
+ ]
+
+ list_select_columns = list_columns + ["created_by_fk", "changed_by_fk"]
+
+ show_columns = list_columns
+
+ order_columns = [
+ "task_type",
+ "scope",
+ "status",
+ "created_on",
+ "changed_on",
+ "started_at",
+ "ended_at",
+ ]
+
+ search_columns = [
+ "task_type",
+ "task_key",
+ "task_name",
+ "scope",
+ "status",
+ "created_by",
+ "created_on",
+ ]
+
+ base_order = ("created_on", "desc")
+ base_filters = [["id", TaskFilter, lambda: []]]
+
+ # Related field configuration for filter dropdowns
+ allowed_rel_fields = {"created_by"}
+ related_field_filters = {
+ "created_by": RelatedFieldFilter("first_name", FilterRelatedOwners),
+ }
+ base_related_field_filters = {
+ "created_by": [["id", BaseFilterRelatedUsers, lambda: []]],
+ }
+
+ show_model_schema = TaskResponseSchema()
+ list_model_schema = TaskResponseSchema()
+ cancel_request_schema = TaskCancelRequestSchema()
+
+ openapi_spec_tag = "Tasks"
+ openapi_spec_component_schemas = (
+ TaskResponseSchema,
+ TaskCancelRequestSchema,
+ TaskCancelResponseSchema,
+ TaskStatusResponseSchema,
+ )
+ openapi_spec_methods = openapi_spec_methods_override
+
+    @expose("/<task_uuid>", methods=("GET",))
+ @protect()
+ @safe
+ @statsd_metrics
+ @event_logger.log_this_with_context(
+ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
+ log_to_statsd=False,
+ )
+ def get(self, task_uuid: str) -> Response:
+ """Get a task.
+ ---
+ get:
+ summary: Get a task
+ parameters:
+ - in: path
+ schema:
+ type: string
+ format: uuid
+ name: task_uuid
+ description: The UUID of the task
+ responses:
+ 200:
+ description: Task detail
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ result:
+ $ref: '#/components/schemas/TaskResponseSchema'
+ 401:
+ $ref: '#/components/responses/401'
+ 403:
+ $ref: '#/components/responses/403'
+ 404:
+ $ref: '#/components/responses/404'
+ """
+ from superset.daos.tasks import TaskDAO
+
+ try:
+ uuid = UUID(task_uuid)
+ task = TaskDAO.find_one_or_none(uuid=uuid)
+
+ if not task:
+ return self.response_404()
+
+ result = self.show_model_schema.dump(task)
+ return self.response(200, result=result)
+ except (ValueError, TypeError):
+ return self.response_404()
+
+    @expose("/<task_uuid>/status", methods=("GET",))
+ @protect()
+ @safe
+ @statsd_metrics
+ @event_logger.log_this_with_context(
+ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.status",
+ log_to_statsd=False,
+ )
+ def status(self, task_uuid: str) -> Response:
+ """Get only the status of a task (lightweight for polling).
+ ---
+ get:
+ summary: Get task status
+ parameters:
+ - in: path
+ schema:
+ type: string
+ format: uuid
+ name: task_uuid
+ description: The UUID of the task
+ responses:
+ 200:
+ description: Task status
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ status:
+ type: string
+ description: Current status of the task
+ 401:
+ $ref: '#/components/responses/401'
+ 403:
+ $ref: '#/components/responses/403'
+ 404:
+ $ref: '#/components/responses/404'
+ """
+ from superset.daos.tasks import TaskDAO
+
+ try:
+ uuid = UUID(task_uuid)
+ status = TaskDAO.get_status(uuid)
+
+ if status is None:
+ return self.response_404()
+
+ return self.response(200, status=status)
+ except (ValueError, TypeError):
+ return self.response_404()
+
+    @expose("/<task_uuid>/cancel", methods=("POST",))
+ @protect()
+ @safe
+ @statsd_metrics
+ @event_logger.log_this_with_context(
+ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.cancel",
+ log_to_statsd=False,
+ )
+ def cancel(self, task_uuid: str) -> Response:
+ """Cancel a task.
+ ---
+ post:
+ summary: Cancel a task
+ description: >
+ Cancel a task. The behavior depends on task scope and subscriber
+ count:
+
+ - **Private tasks**: Aborts the task
+ - **Shared tasks (single subscriber)**: Aborts the task
+ - **Shared tasks (multiple subscribers)**: Removes current user's
+ subscription; the task continues for other subscribers
+ - **Shared tasks with force=true (admin only)**: Aborts task for
+ all subscribers
+
+ The `action` field in the response indicates what happened:
+ - `aborted`: Task was terminated
+ - `unsubscribed`: User was removed from task (task continues)
+ parameters:
+ - in: path
+ schema:
+ type: string
+ format: uuid
+ name: task_uuid
+ description: The UUID of the task to cancel
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskCancelRequestSchema'
+ responses:
+ 200:
+ description: Task cancelled successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskCancelResponseSchema'
+ 401:
+ $ref: '#/components/responses/401'
+ 403:
+ $ref: '#/components/responses/403'
+ 404:
+ $ref: '#/components/responses/404'
+ 422:
+ $ref: '#/components/responses/422'
+ """
+ return self._execute_cancel(task_uuid)
+
+ def _execute_cancel(self, task_uuid_str: str) -> Response:
+ """Execute the cancel operation with error handling."""
+ try:
+ task_uuid = UUID(task_uuid_str)
+
+ command, updated_task = self._run_cancel_command(task_uuid)
+ return self._build_cancel_response(command, updated_task)
+
+ except TaskNotFoundError:
+ return self.response_404()
+ except (TaskForbiddenError, TaskPermissionDeniedError) as ex:
+ if isinstance(ex, TaskPermissionDeniedError):
+ logger.warning(
+ "Permission denied cancelling task %s: %s",
+ task_uuid_str,
+ str(ex),
+ )
+ return self.response_403()
+ except TaskNotAbortableError as ex:
+ logger.warning("Task %s is not cancellable: %s", task_uuid_str, str(ex))
+ return self.response_422(message=str(ex))
+ except TaskAbortFailedError as ex:
+ logger.error(
+ "Error cancelling task %s: %s", task_uuid_str, str(ex), exc_info=True
+ )
+ return self.response_422(message=str(ex))
+ except (ValueError, TypeError):
+ return self.response_404()
+
+ def _run_cancel_command(self, task_uuid: UUID) -> tuple[CancelTaskCommand, "Task"]:
+ """Parse request and run the cancel command."""
+ from flask import request
+
+ force = False
+ # Use get_json with silent=True to handle missing Content-Type gracefully
+ json_data = request.get_json(silent=True)
+ if json_data:
+ parsed = self.cancel_request_schema.load(json_data)
+ force = parsed.get("force", False)
+
+ command = CancelTaskCommand(task_uuid, force=force)
+ updated_task = command.run()
+ return command, updated_task
+
+ def _build_cancel_response(
+ self, command: CancelTaskCommand, updated_task: "Task"
+ ) -> Response:
+ """Build the response for a successful cancel operation."""
+ action = command.action_taken
+ message = (
+ "Task cancelled"
+ if action == "aborted"
+ else "You have been removed from this task"
+ )
+ result = {
+ "message": message,
+ "action": action,
+ "task": self.show_model_schema.dump(updated_task),
+ }
+ return self.response(200, **result)
+
+ @expose("/related/subscribers", methods=("GET",))
+ @protect()
+ @safe
+ @statsd_metrics
+ @event_logger.log_this_with_context(
+ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+ ".related_subscribers",
+ log_to_statsd=False,
+ )
+ def related_subscribers(self) -> Response:
+ """Get users who are subscribers to tasks.
+ ---
+ get:
+ summary: Get related subscribers
+ description: >
+ Returns a list of users who are subscribed to tasks, for use in filter
+ dropdowns. Results can be filtered by a search query parameter.
+ parameters:
+ - in: query
+ schema:
+ type: string
+ name: q
+ description: Search query to filter subscribers by name
+ responses:
+ 200:
+ description: List of subscribers
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ count:
+ type: integer
+ description: Total number of matching subscribers
+ result:
+ type: array
+ items:
+ type: object
+ properties:
+ value:
+ type: integer
+ description: User ID
+ text:
+ type: string
+ description: User display name
+ 401:
+ $ref: '#/components/responses/401'
+ """
+ from flask import request
+
+ from superset import db, security_manager
+ from superset.models.task_subscribers import TaskSubscriber
+
+ # Get user model
+ user_model = security_manager.user_model
+
+ # Query distinct users who are task subscribers
+ query = (
+ db.session.query(user_model.id, user_model.first_name, user_model.last_name)
+ .join(TaskSubscriber, user_model.id == TaskSubscriber.user_id)
+ .distinct()
+ )
+
+ # Apply search filter if provided
+ if search_query := request.args.get("q", ""):
+ like_value = f"%{search_query}%"
+ query = query.filter(
+ (user_model.first_name + " " + user_model.last_name).ilike(like_value)
+ | user_model.username.ilike(like_value)
+ )
+
+ # Order by name
+ query = query.order_by(user_model.first_name, user_model.last_name)
+
+ # Limit results
+ query = query.limit(100)
+
+ # Execute and format results
+ results = query.all()
+
+ return self.response(
+ 200,
+ count=len(results),
+ result=[
+ {
+ "value": user_id,
+ "text": f"{first_name or ''} {last_name or ''}".strip()
+ or str(user_id),
+ }
+ for user_id, first_name, last_name in results
+ ],
+ )
diff --git a/superset/tasks/constants.py b/superset/tasks/constants.py
new file mode 100644
index 000000000000..231859e23cd7
--- /dev/null
+++ b/superset/tasks/constants.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constants for the Global Task Framework (GTF)."""
+
+from superset_core.api.tasks import TaskStatus
+
+# Terminal states: Task execution has ended and dedup_key slot is freed
+TERMINAL_STATES: frozenset[str] = frozenset(
+ {
+ TaskStatus.SUCCESS.value,
+ TaskStatus.FAILURE.value,
+ TaskStatus.ABORTED.value,
+ TaskStatus.TIMED_OUT.value,
+ }
+)
+
+# Active states: Task is still in progress and dedup_key is reserved
+ACTIVE_STATES: frozenset[str] = frozenset(
+ {
+ TaskStatus.PENDING.value,
+ TaskStatus.IN_PROGRESS.value,
+ TaskStatus.ABORTING.value,
+ }
+)
+
+# Abortable states: Task can be aborted (PENDING, or IN_PROGRESS when an
+# abort handler is registered)
+ABORTABLE_STATES: frozenset[str] = frozenset(
+ {
+ TaskStatus.PENDING.value,
+ TaskStatus.IN_PROGRESS.value,
+ }
+)
+
+# Abort-related states: Task is being or has been aborted
+ABORT_STATES: frozenset[str] = frozenset(
+ {
+ TaskStatus.ABORTING.value,
+ TaskStatus.ABORTED.value,
+ }
+)
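+
+# Example (illustrative): gate an abort request on the current status
+#     if task.status in ABORTABLE_STATES:
+#         ...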
diff --git a/superset/tasks/context.py b/superset/tasks/context.py
new file mode 100644
index 000000000000..1bce61462b23
--- /dev/null
+++ b/superset/tasks/context.py
@@ -0,0 +1,673 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Concrete TaskContext implementation for GTF"""
+
+import logging
+import threading
+import time
+import traceback
+from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar
+
+from flask import current_app
+from superset_core.api.tasks import (
+ TaskContext as CoreTaskContext,
+ TaskProperties,
+ TaskStatus,
+)
+
+from superset.stats_logger import BaseStatsLogger
+from superset.tasks.constants import ABORT_STATES
+from superset.tasks.utils import progress_update
+
+if TYPE_CHECKING:
+ from superset.models.tasks import Task
+ from superset.tasks.manager import AbortListener
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+class TaskContext(CoreTaskContext):
+ """
+    Concrete implementation of TaskContext for the Global Task Framework (GTF).
+
+ Provides write-only access to task state. Tasks use this context to update
+ their progress and payload, and check for cancellation. Tasks should not
+ need to read their own state - they are the source of state, not consumers.
+ """
+
+ # Type alias for handler failures: (handler_type, exception, stack_trace)
+ HandlerFailure = tuple[str, Exception, str]
+
+ def __init__(self, task: "Task") -> None:
+ """
+ Initialize TaskContext with a pre-fetched task entity.
+
+ The task entity must be pre-fetched by the caller (executor) to ensure
+ caching works correctly and to enforce the pattern of single initial fetch.
+
+ :param task: Pre-fetched Task entity (required)
+ """
+ self._task_uuid = task.uuid
+ self._cleanup_handlers: list[Callable[[], None]] = []
+ self._abort_handlers: list[Callable[[], None]] = []
+ self._abort_listener: "AbortListener | None" = None
+ self._abort_detected = False
+ self._abort_handlers_completed = False # Track if all abort handlers finished
+ self._execution_completed = False # Set by executor after task work completes
+
+ # Collected handler failures for unified reporting
+ self._handler_failures: list[TaskContext.HandlerFailure] = []
+
+ # Timeout timer state
+ self._timeout_timer: threading.Timer | None = None
+ self._timeout_triggered = False
+
+ # Throttling state for update_task()
+ # These manage the minimum interval between DB writes
+ self._last_db_write_time: float | None = None
+ self._has_pending_updates: bool = False
+ self._deferred_flush_timer: threading.Timer | None = None
+ self._throttle_lock = threading.Lock()
+
+ # Cached task entity - avoids repeated DB fetches.
+ # Updated only by _refresh_task() when checking external state changes.
+ self._task: "Task" = task
+
+ # In-memory state caches - authoritative during execution
+ # These are initialized from the task entity and updated locally
+ # before being written to DB via targeted SQL updates.
+ # We copy the dicts to avoid mutating the Task's cached instances.
+ self._properties_cache: TaskProperties = cast(
+ TaskProperties, {**task.properties_dict}
+ )
+ self._payload_cache: dict[str, Any] = {**task.payload_dict}
+
+ # Store Flask app reference for background thread database access
+ # Use _get_current_object() to get actual app, not proxy
+ try:
+ self._app = current_app._get_current_object()
+ # Cache stats logger to avoid repeated config lookups
+ self._stats_logger: BaseStatsLogger = current_app.config.get(
+ "STATS_LOGGER", BaseStatsLogger()
+ )
+ except RuntimeError:
+ # Handle case where app context isn't available (e.g., tests)
+ self._app = None
+ self._stats_logger = BaseStatsLogger()
+
+ def _refresh_task(self) -> "Task":
+ """
+ Force refresh the task entity from the database.
+
+ Use this method when you need to check for external state changes,
+ such as whether the task has been aborted by a concurrent operation.
+
+ This method:
+ - Fetches fresh task entity from database
+ - Updates the cached _task reference
+ - Updates properties/payload caches from fresh data
+
+ :returns: Fresh task entity from database
+ :raises ValueError: If task is not found
+ """
+ from superset.daos.tasks import TaskDAO
+
+ fresh_task = TaskDAO.find_one_or_none(uuid=self._task_uuid)
+ if not fresh_task:
+ raise ValueError(f"Task {self._task_uuid} not found")
+
+ self._task = fresh_task
+
+ # Update caches from fresh data (copy to avoid mutating Task's cache)
+ self._properties_cache = cast(TaskProperties, {**fresh_task.properties_dict})
+ self._payload_cache = {**fresh_task.payload_dict}
+
+ return self._task
+
+ def update_task(
+ self,
+ progress: float | int | tuple[int, int] | None = None,
+ payload: dict[str, object] | None = None,
+ ) -> None:
+ """
+ Update task progress and/or payload atomically.
+
+ All parameters are optional. Payload is merged with existing cached data.
+ In-memory caches are always updated immediately, but DB writes are
+ throttled according to TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL to prevent
+ excessive database load from eager tasks.
+
+ Progress can be specified in three ways:
+ - float (0.0-1.0): Percentage only, e.g., 0.5 means 50%
+ - int: Count only (total unknown), e.g., 42 means "42 items processed"
+ - tuple[int, int]: Count and total, e.g., (3, 100) means "3 of 100"
+ The percentage is automatically computed from count/total.
+
+ :param progress: Progress value, or None to leave unchanged
+ :param payload: Payload data to merge (dict), or None to leave unchanged
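+
+        Example (illustrative; ``ctx`` is the ambient context from get_context()):
+            ctx.update_task(progress=(3, 100), payload={"stage": "render"})
+            ctx.update_task(progress=0.5)  # 50%, payload left unchanged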
+ """
+ has_updates = False
+
+ # Handle progress updates - always update in-memory cache
+ if progress is not None:
+ progress_props = progress_update(progress)
+ if progress_props:
+ # Merge progress into cached properties
+ self._properties_cache.update(progress_props)
+ has_updates = True
+ else:
+ # Invalid progress format - progress_update returns empty dict
+ logger.warning(
+ "Invalid progress value for task %s: %s "
+ "(expected float, int, or tuple[int, int])",
+ self._task_uuid,
+ progress,
+ )
+
+ # Handle payload updates - always update in-memory cache
+ if payload is not None:
+ # Merge payload into cached payload
+ self._payload_cache.update(payload)
+ has_updates = True
+
+ if not has_updates:
+ return
+
+ # Get throttle interval from config
+ throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]
+
+ # If throttling is disabled (0), write immediately
+ if throttle_interval <= 0:
+ self._write_to_db()
+ return
+
+ # Apply throttling with deferred flush
+ with self._throttle_lock:
+ now = time.time()
+
+ if self._last_db_write_time is None:
+ # First update - write immediately
+ self._write_to_db()
+ self._last_db_write_time = now
+ elif now - self._last_db_write_time >= throttle_interval:
+ # Throttle window has passed - write immediately
+ self._cancel_deferred_flush_timer()
+ self._write_to_db()
+ self._last_db_write_time = now
+ self._has_pending_updates = False
+ else:
+ # Within throttle window - defer the write
+ self._has_pending_updates = True
+ self._stats_logger.incr("gtf.task.update_deferred")
+
+ # Start deferred flush timer if not already running
+ if self._deferred_flush_timer is None:
+ remaining_time = throttle_interval - (
+ now - self._last_db_write_time
+ )
+ self._deferred_flush_timer = threading.Timer(
+ remaining_time, self._deferred_flush
+ )
+ self._deferred_flush_timer.daemon = True
+ self._deferred_flush_timer.start()
+
+ def _write_to_db(self) -> None:
+ """
+ Write current cached state to database.
+
+ This method performs the actual DB write using InternalUpdateTaskCommand.
+ It writes whatever is in the caches at the time of the call.
+ """
+ from superset.commands.tasks.internal_update import InternalUpdateTaskCommand
+
+ self._stats_logger.incr("gtf.task.update_write")
+
+ InternalUpdateTaskCommand(
+ task_uuid=self._task_uuid,
+ properties=self._properties_cache,
+ payload=self._payload_cache,
+ ).run()
+
+ def _deferred_flush(self) -> None:
+ """
+ Timer callback that flushes pending updates at end of throttle window.
+
+ This ensures the UI never shows stale progress for longer than the
+ throttle interval.
+ """
+ with self._throttle_lock:
+ self._deferred_flush_timer = None
+
+ if self._has_pending_updates:
+ # Need app context for DB operations in timer thread
+ if self._app:
+ with self._app.app_context():
+ self._write_to_db()
+ else:
+ self._write_to_db()
+
+ self._last_db_write_time = time.time()
+ self._has_pending_updates = False
+
+ def _cancel_deferred_flush_timer(self) -> None:
+ """Cancel the deferred flush timer if running."""
+ if self._deferred_flush_timer is not None:
+ self._deferred_flush_timer.cancel()
+ self._deferred_flush_timer = None
+
+ def on_cleanup(self, handler: Callable[[], None]) -> Callable[[], None]:
+ """
+ Register a cleanup handler that runs when the task ends.
+
+ Cleanup handlers are called when the task completes (success),
+ fails with an error, or is aborted. Multiple handlers can be
+ registered and will execute in LIFO order (last registered runs first).
+
+ Can be used as a decorator:
+ @ctx.on_cleanup
+ def cleanup():
+ logger.info("Task ended")
+
+ Or called directly:
+ ctx.on_cleanup(lambda: logger.info("Task ended"))
+
+ :param handler: Cleanup function to register
+ :returns: The handler (for decorator compatibility)
+ """
+ self._cleanup_handlers.append(handler)
+ return handler
+
+ def on_abort(self, handler: Callable[[], None]) -> Callable[[], None]:
+ """
+ Register abort handler with automatic background listening.
+
+ When the first handler is registered:
+ 1. Sets is_abortable=true in the database (marks task as abortable)
+ 2. Background abort listener starts automatically (pub/sub or polling)
+
+ The handler will be called automatically when an abort is detected.
+
+ :param handler: Callback function to execute when abort is detected
+ :returns: The handler (for decorator compatibility)
+
+ Example:
+ @ctx.on_abort
+ def handle_abort():
+ logger.info("Task was aborted!")
+ cleanup_partial_work()
+
+ Note:
+ The handler executes in a background thread when abort is detected.
+ The task code continues running unless the handler does something
+ to stop it (e.g., raises an exception, modifies shared state, etc.)
+ """
+ is_first_handler = len(self._abort_handlers) == 0
+ self._abort_handlers.append(handler)
+
+ if is_first_handler:
+ # Mark task as abortable in database
+ self._set_abortable()
+
+ # Auto-start abort listener when first handler is registered
+ interval = current_app.config["TASK_ABORT_POLLING_DEFAULT_INTERVAL"]
+ self._start_abort_listener(interval)
+
+ return handler
+
+ def _set_abortable(self) -> None:
+ """Mark the task as abortable (abort handler has been registered)."""
+ from superset.commands.tasks.internal_update import InternalUpdateTaskCommand
+
+ # Update local cache and write to DB
+ self._properties_cache["is_abortable"] = True
+ InternalUpdateTaskCommand(
+ task_uuid=self._task_uuid,
+ properties=self._properties_cache,
+ ).run()
+
+ def _start_abort_listener(self, interval: float) -> None:
+ """
+ Start background abort listener via TaskManager.
+
+ Uses Redis pub/sub if available, otherwise falls back to database polling.
+ The implementation is encapsulated in TaskManager.
+ """
+ if self._abort_listener is not None:
+ return # Already listening
+
+ from superset.tasks.manager import TaskManager
+
+ self._abort_listener = TaskManager.listen_for_abort(
+ task_uuid=self._task_uuid,
+ callback=self._on_abort_detected,
+ poll_interval=interval,
+ app=self._app,
+ )
+
+ def _on_abort_detected(self) -> None:
+ """
+ Callback invoked by TaskManager when abort is detected.
+
+ Triggers all registered abort handlers.
+ """
+ if self._abort_detected:
+ return # Already handled
+
+ # Check if task execution has already completed (late abort race).
+ # Executor sets _execution_completed after task work finishes.
+ if self._execution_completed:
+ logger.info(
+ "Abort detected for task %s but execution already completed",
+ self._task_uuid,
+ )
+ return
+
+ self._abort_detected = True
+ logger.info("Abort detected for task %s", self._task_uuid)
+ self._trigger_abort_handlers()
+
+ def mark_execution_completed(self) -> None:
+ """
+ Mark that the task's main execution has completed.
+
+ Called by the executor after the task function returns (successfully
+ or with an exception). This prevents late abort callbacks from running
+ handlers when the task work has already finished. Cleanup handlers
+ still run after this is set.
+ """
+ self._execution_completed = True
+
+ def start_abort_polling(self, interval: float | None = None) -> None:
+ """
+ Start background abort listener.
+
+ This method is kept for backwards compatibility. It now delegates
+ to _start_abort_listener which uses TaskManager.
+
+ :param interval: Polling interval in seconds (uses config default if None)
+ """
+ if interval is None:
+ interval = current_app.config["TASK_ABORT_POLLING_DEFAULT_INTERVAL"]
+ self._start_abort_listener(interval)
+
+ def _trigger_abort_handlers(self) -> None:
+ """
+ Execute all registered abort handlers (called by polling thread or cleanup).
+
+ All handlers are attempted even if some fail (best-effort cleanup).
+ Failures are collected in self._handler_failures for unified reporting.
+
+ Note: This method never writes to DB directly. All failures are collected
+ and written by _run_cleanup() in the executor's finally block, ensuring
+ abort and cleanup handler failures are combined into a single record.
+ """
+ for handler in reversed(self._abort_handlers):
+ try:
+ handler()
+ except Exception as ex:
+ stack_trace = traceback.format_exc()
+ logger.error(
+ "Abort handler failed for task %s: %s",
+ self._task_uuid,
+ str(ex),
+ exc_info=True,
+ )
+ self._handler_failures.append(("abort", ex, stack_trace))
+
+ # Check if all abort handlers completed successfully
+ abort_failures = [f for f in self._handler_failures if f[0] == "abort"]
+ if not abort_failures:
+ self._abort_handlers_completed = True
+
+ def _write_handler_failures_to_db(self) -> None:
+ """
+ Write collected handler failures to the database.
+
+ Combines all failures (abort + cleanup) into a single error record.
+ If the task already has an error (e.g., task function threw exception),
+ handler failures are APPENDED to preserve the original error context.
+ """
+ from superset.commands.tasks.update import UpdateTaskCommand
+
+ if not self._handler_failures:
+ return
+
+ # Build error message from all handler failures
+ error_messages = [str(ex) for _, ex, _ in self._handler_failures]
+ handler_types = {htype for htype, _, _ in self._handler_failures}
+
+ if len(self._handler_failures) == 1:
+ htype, ex, handler_stack_trace = self._handler_failures[0]
+ handler_error_msg = (
+ f"{htype.capitalize()} handler failed: {error_messages[0]}"
+ )
+ handler_exception_type = type(ex).__name__
+ else:
+ # Multiple failures
+ handler_error_msg = f"Handler(s) failed: {'; '.join(error_messages)}"
+ if handler_types == {"abort"}:
+ handler_exception_type = "MultipleAbortHandlerFailures"
+ elif handler_types == {"cleanup"}:
+ handler_exception_type = "MultipleCleanupHandlerFailures"
+ else:
+ handler_exception_type = "MultipleHandlerFailures"
+
+ # Combine stack traces with clear separators
+ handler_stack_trace = "\n--- Next handler failure ---\n".join(
+ f"[{htype}:{type(ex).__name__}]\n{trace}"
+ for htype, ex, trace in self._handler_failures
+ )
+
+ if self._app:
+ with self._app.app_context():
+ # Check if task already has an error (preserve original context)
+ task = self._task
+ original_error = task.properties_dict.get("error_message")
+ original_type = task.properties_dict.get("exception_type")
+ original_trace = task.properties_dict.get("stack_trace")
+
+ if original_error:
+ # Append handler failures to original error
+ error_msg = f"{original_error} | {handler_error_msg}"
+ exception_type = (
+ f"{original_type}+{handler_exception_type}"
+ if original_type
+ else handler_exception_type
+ )
+ stack_trace = (
+ f"{original_trace}\n\n"
+ f"=== Handler failures during cleanup ===\n\n"
+ f"{handler_stack_trace}"
+ if original_trace
+ else handler_stack_trace
+ )
+ else:
+ # No original error, just use handler failures
+ error_msg = handler_error_msg
+ exception_type = handler_exception_type
+ stack_trace = handler_stack_trace
+
+ # Update task with combined error info
+ UpdateTaskCommand(
+ self._task_uuid,
+ status=TaskStatus.FAILURE.value,
+ properties={
+ "error_message": error_msg,
+ "exception_type": exception_type,
+ "stack_trace": stack_trace,
+ },
+ skip_security_check=True,
+                ).run()
+        else:
+            # Without an app context the update command cannot run; surface
+            # the dropped failures instead of discarding them silently
+            logger.warning(
+                "No app context available; dropping %d handler failure(s) "
+                "for task %s",
+                len(self._handler_failures),
+                self._task_uuid,
+            )
+
+        # Clear failures after writing
+        self._handler_failures = []
+
+ def stop_abort_polling(self) -> None:
+ """Stop the background abort listener."""
+ if self._abort_listener is not None:
+ self._abort_listener.stop()
+ self._abort_listener = None
+
+ def start_timeout_timer(self, timeout_seconds: int) -> None:
+ """
+ Start a timeout timer that triggers abort when elapsed.
+
+ Called by execute_task when task transitions to IN_PROGRESS.
+ Timer only triggers abort handlers if task is abortable.
+
+ :param timeout_seconds: Timeout duration in seconds
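+
+        Example (illustrative):
+            ctx.start_timeout_timer(300)  # abort flow triggers after 5 minutes
+            # If no abort handler is registered, only a warning is logged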
+ """
+ if self._timeout_timer is not None:
+ return # Already started
+
+ def on_timeout() -> None:
+ if self._abort_detected:
+ return # Already aborting
+
+ self._timeout_triggered = True
+
+ # Check if task has abort handler (requires app context)
+ if not self._app:
+ logger.error(
+ "Timeout fired for task %s but no app context available",
+ self._task_uuid,
+ )
+ return
+
+ with self._app.app_context():
+ from superset.commands.tasks.update import UpdateTaskCommand
+
+ task = self._task
+ if task.properties_dict.get("is_abortable", False):
+ logger.info(
+ "Timeout reached for task %s after %d seconds - "
+ "transitioning to ABORTING and triggering abort handlers",
+ self._task_uuid,
+ timeout_seconds,
+ )
+ # Set status to ABORTING (same as user abort)
+ # The executor will determine TIMED_OUT vs FAILURE based on
+ # whether handlers complete successfully
+ UpdateTaskCommand(
+ self._task_uuid,
+ status=TaskStatus.ABORTING.value,
+ properties={"error_message": "Task timed out"},
+ skip_security_check=True,
+ ).run()
+
+ # Trigger abort handlers for cleanup
+ self._on_abort_detected()
+ else:
+ # No abort handler - just log warning
+ logger.warning(
+ "Timeout reached for task %s after %d seconds, but no "
+ "abort handler is registered. Task will continue running.",
+ self._task_uuid,
+ timeout_seconds,
+ )
+
+ self._timeout_timer = threading.Timer(timeout_seconds, on_timeout)
+ # Timer is daemon so it won't prevent process exit. If the worker dies,
+ # the task is already in an inconsistent state (stuck IN_PROGRESS) that
+ # requires external recovery (orphan detection). A non-daemon timer with
+ # long timeouts (hours) would block graceful worker shutdown.
+ self._timeout_timer.daemon = True
+ self._timeout_timer.start()
+ logger.debug(
+ "Started timeout timer for task %s: %d seconds",
+ self._task_uuid,
+ timeout_seconds,
+ )
+
+ def stop_timeout_timer(self) -> None:
+ """Cancel the timeout timer if running."""
+ if self._timeout_timer is not None:
+ self._timeout_timer.cancel()
+ self._timeout_timer = None
+
+ @property
+ def timeout_triggered(self) -> bool:
+ """Check if the timeout was triggered."""
+ return self._timeout_triggered
+
+ @property
+ def abort_handlers_completed(self) -> bool:
+ """Check if all abort handlers have completed successfully."""
+ return self._abort_handlers_completed
+
+ def _run_cleanup(self) -> None:
+ """
+ Run cleanup handlers (called by executor in finally block).
+
+ This runs:
+ 1. Flushes any pending throttled updates to ensure final state is persisted
+ 2. Abort handlers if task was aborting/aborted (but not yet detected)
+ 3. All cleanup handlers (always)
+
+ All handler failures (abort + cleanup) are collected and written to DB
+ as a unified error record at the end.
+ """
+ # Flush any pending throttled updates before cleanup
+ with self._throttle_lock:
+ self._cancel_deferred_flush_timer()
+ if self._has_pending_updates:
+ self._write_to_db()
+ self._has_pending_updates = False
+
+ # Stop abort listener and timeout timer
+ self.stop_abort_polling()
+ self.stop_timeout_timer()
+
+ # If aborting/aborted but handlers haven't run yet, run them now
+ # (This catches the case where task ended before listener detected abort)
+ if self._app:
+ with self._app.app_context():
+ task = self._task
+ if task.status in ABORT_STATES and not self._abort_detected:
+ self._trigger_abort_handlers()
+ else:
+ # Fallback without app context
+ try:
+ task = self._task
+ if task.status in ABORT_STATES and not self._abort_detected:
+ self._trigger_abort_handlers()
+ except Exception as ex:
+ logger.warning(
+ "Could not check abort status during cleanup for task %s: %s",
+ self._task_uuid,
+ str(ex),
+ )
+
+ # Always run cleanup handlers, collecting failures
+ for handler in reversed(self._cleanup_handlers):
+ try:
+ handler()
+ except Exception as ex:
+ stack_trace = traceback.format_exc()
+ logger.error(
+ "Cleanup handler failed for task %s: %s",
+ self._task_uuid,
+ str(ex),
+ exc_info=True,
+ )
+ self._handler_failures.append(("cleanup", ex, stack_trace))
+
+ # Write all collected failures (abort + cleanup) to DB as unified record
+ if self._handler_failures:
+ self._write_handler_failures_to_db()
diff --git a/superset/tasks/decorators.py b/superset/tasks/decorators.py
new file mode 100644
index 000000000000..1c69280f80fa
--- /dev/null
+++ b/superset/tasks/decorators.py
@@ -0,0 +1,609 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Decorators for the Global Task Framework (GTF)"""
+
+from __future__ import annotations
+
+import inspect
+import logging
+from typing import Any, Callable, cast, Generic, ParamSpec, TYPE_CHECKING, TypeVar
+
+from superset_core.api.tasks import TaskOptions, TaskScope, TaskStatus
+
+from superset import is_feature_enabled
+from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
+from superset.tasks.ambient_context import use_context
+from superset.tasks.constants import TERMINAL_STATES
+from superset.tasks.context import TaskContext
+from superset.tasks.manager import TaskManager
+from superset.tasks.registry import TaskRegistry
+from superset.tasks.utils import generate_random_task_key
+from superset.utils.core import get_user_id
+
+if TYPE_CHECKING:
+ from superset.models.tasks import Task
+
+logger = logging.getLogger(__name__)
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def task(
+ func: Callable[P, R] | None = None,
+ *,
+ name: str | None = None,
+ scope: TaskScope = TaskScope.PRIVATE,
+ timeout: int | None = None,
+) -> Callable[[Callable[P, R]], "TaskWrapper[P]"] | "TaskWrapper[P]":
+ """
+ Decorator to register a task with default scope.
+
+ Can be used with or without parentheses:
+ @task
+ def my_func(): ...
+
+ @task()
+ def my_func(): ...
+
+ @task(name="custom_name", scope=TaskScope.SHARED)
+ def my_func(): ...
+
+ @task(timeout=300) # 5-minute timeout
+ def long_running_func(): ...
+
+ Args:
+ func: The function to decorate (when used without parentheses).
+ name: Optional unique task name (e.g., "superset.generate_thumbnail").
+ If not provided, uses the function name as the task name.
+ scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM).
+ Defaults to TaskScope.PRIVATE.
+ timeout: Optional timeout in seconds. When the timeout is reached,
+ abort handlers are triggered if registered. Can be overridden
+ at call time via TaskOptions(timeout=...).
+
+ Usage:
+ # Private task (default scope) - no parentheses
+ @task
+ def my_async_func(chart_id: int) -> None:
+ ctx = get_context()
+ ...
+
+ # Named task with shared scope
+ @task(name="generate_report", scope=TaskScope.SHARED)
+ def generate_expensive_report(report_id: int) -> None:
+ ctx = get_context()
+ ...
+
+ # System task (admin-only)
+ @task(scope=TaskScope.SYSTEM)
+ def cleanup_task() -> None:
+ ctx = get_context()
+ ...
+
+ # Task with timeout
+ @task(timeout=300)
+ def long_task() -> None:
+ ctx = get_context()
+
+ @ctx.on_abort
+ def handle_abort():
+ # Called when timeout is reached or user cancels
+ ...
+
+ Note:
+ Both direct calls and .schedule() return Task, regardless of the
+ original function's return type. The decorated function's return value
+ is discarded; only side effects and context updates matter.
+ """
+
+ def decorator(f: Callable[P, R]) -> "TaskWrapper[P]":
+ # Use function name if no name provided
+ task_name = name if name is not None else f.__name__
+
+ # Create default options with no scope (scope is now in decorator)
+ default_options = TaskOptions()
+
+ # Validate function signature - must not have ctx or options params
+ sig = inspect.signature(f)
+ forbidden = {"ctx", "options"}
+ if any(param in forbidden for param in sig.parameters):
+ raise TypeError(
+ f"Task function {f.__name__} must not define 'ctx' or "
+ "'options' parameters. "
+ f"Use get_context() instead for ambient context access."
+ )
+
+ # Register task
+ TaskRegistry.register(task_name, f)
+
+ # Create wrapper with schedule() method, default options, scope, and timeout
+ wrapper = TaskWrapper(task_name, f, default_options, scope, timeout)
+
+ # Preserve signature for introspection
+ wrapper.__signature__ = sig # type: ignore[attr-defined]
+
+ return wrapper
+
+ if func is None:
+ # Called with parentheses: @task() or @task(name="foo", scope=TaskScope.SHARED)
+ return decorator
+ else:
+ # Called without parentheses: @task
+ return decorator(func)
+
+
+class TaskWrapper(Generic[P]):
+ """
+ Wrapper for task functions that provides .schedule() method.
+
+ Both direct calls and .schedule() return Task. The original function's
+ return value is discarded.
+
+ Direct calls execute synchronously, .schedule() runs async via Celery.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ func: Callable[P, R],
+ default_options: TaskOptions,
+ scope: TaskScope = TaskScope.PRIVATE,
+ default_timeout: int | None = None,
+ ) -> None:
+ self.name = name
+ self.func = func
+ self.default_options = default_options
+ self.scope = scope
+ self.default_timeout = default_timeout
+ self.__name__ = func.__name__
+ self.__doc__ = func.__doc__
+ self.__module__ = func.__module__
+
+ # Patch schedule.__signature__ to mirror function + options parameter
+ # This enables proper IDE support and introspection
+ sig = inspect.signature(func)
+ params = list(sig.parameters.values())
+ # Add keyword-only options parameter
+ params.append(
+ inspect.Parameter(
+ "options",
+ inspect.Parameter.KEYWORD_ONLY,
+ default=None,
+ annotation=TaskOptions | None,
+ )
+ )
+ self.schedule.__func__.__signature__ = sig.replace( # type: ignore[attr-defined]
+ parameters=params, return_annotation="Task"
+ )
+
+ def _merge_options(self, override_options: TaskOptions | None) -> TaskOptions:
+ """
+ Merge decorator defaults with call-time overrides.
+
+ Call-time options take precedence over decorator defaults.
+        For timeout, a call-time value takes precedence; a None timeout falls
+        back to the decorator default.
+
+ Args:
+ override_options: Options provided at call time, or None
+
+ Returns:
+ Merged TaskOptions with overrides applied
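+
+        Example (illustrative, for a task declared with @task(timeout=300)):
+            wrapper._merge_options(None).timeout                     # 300
+            wrapper._merge_options(TaskOptions(timeout=60)).timeout  # 60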
+ """
+ if override_options is None:
+ return TaskOptions(
+ task_key=self.default_options.task_key,
+ task_name=self.default_options.task_name,
+ timeout=self.default_timeout, # Use decorator default
+ )
+
+ # Merge: use override if provided, otherwise use default
+        # For timeout: use the call-time value when set; a None timeout falls
+        # back to the decorator default
+ return TaskOptions(
+ task_key=override_options.task_key or self.default_options.task_key,
+ task_name=override_options.task_name or self.default_options.task_name,
+ timeout=override_options.timeout
+ if override_options.timeout is not None
+ else self.default_timeout,
+ )
+
+ def _validate_task(self, options: TaskOptions) -> None:
+ """
+ Validate task configuration before execution.
+
+ Args:
+ options: Merged task options to validate
+
+ Raises:
+ ValueError: If validation fails
+ """
+ # Shared tasks must have an explicit task_key for deduplication
+ if self.scope == TaskScope.SHARED and options.task_key is None:
+ raise ValueError(
+ f"Shared task '{self.name}' requires an explicit task_key in "
+ "TaskOptions for deduplication. Without a task_key, each "
+ "invocation creates a separate task with a random UUID, "
+ "defeating the purpose of shared tasks."
+ )
+
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Task":
+ """
+ Call the function synchronously.
+
+ This is invoked when you call the decorated function directly:
+ task = generate_thumbnail(chart_id) # Blocks until completion
+
+ Flow:
+ 1. Submit task (create new or join existing via deduplication)
+ 2. If joining existing task: wait for it to complete (blocking)
+ 3. If new task: execute inline and return completed task
+
+ Sync execution always blocks until completion - even when joining an
+ existing task that's running in another process/worker.
+
+ Returns the Task entity in terminal state (SUCCESS, FAILURE, etc.).
+
+ Raises:
+ GlobalTaskFrameworkDisabledError: If GTF feature flag is not enabled
+ ValueError: If task validation fails
+
+        Note: If the timeout expires while waiting for an existing task, the
+        task is returned in its current (possibly non-terminal) state rather
+        than raising TimeoutError.
+ """
+ from superset.commands.tasks.submit import SubmitTaskCommand
+
+ if not is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
+ raise GlobalTaskFrameworkDisabledError()
+
+ # Extract and merge options (decorator defaults + call-time overrides)
+ override_options = cast(TaskOptions | None, kwargs.pop("options", None))
+ options = self._merge_options(override_options)
+
+ # Validate task configuration
+ self._validate_task(options)
+
+ # Extract task_name and task_key from merged options, scope from decorator
+ task_name = (
+ options.task_name or f"{self.name}:{generate_random_task_key()[:50]}"
+ )
+ task_key = options.task_key or generate_random_task_key()
+ scope = self.scope # Use scope from decorator
+
+ # Build properties with execution_mode and timeout
+ properties: dict[str, str | int] = {"execution_mode": "sync"}
+ if options.timeout:
+ properties["timeout"] = options.timeout
+
+ # Submit task - may create new or join existing
+ task, is_new = SubmitTaskCommand(
+ {
+ "task_type": self.name,
+ "task_key": task_key,
+ "task_name": task_name,
+ "scope": scope.value,
+ "properties": properties,
+ "user_id": get_user_id(),
+ }
+ ).run_with_info()
+
+ # If joining existing task, wait for it to complete
+ if not is_new:
+ return self._wait_for_existing_task(task, options.timeout)
+
+ # New task - execute inline
+ return self._execute_inline(task, options, args, kwargs)
+
+ def _wait_for_existing_task(self, task: "Task", timeout: int | None) -> "Task":
+ """
+ Wait for an existing task to complete.
+
+ Called when sync execution joins a pre-existing task via deduplication.
+ Blocks until the task reaches a terminal state.
+
+ :param task: The existing task to wait for
+ :param timeout: Maximum time to wait in seconds (None = no limit)
+        :returns: Task in terminal state, or in its current (possibly
+            non-terminal) state if the timeout expires first
+ """
+ from flask import current_app
+
+ from superset.daos.tasks import TaskDAO
+
+ # Check if already in terminal state
+ if task.status in TERMINAL_STATES:
+ logger.info(
+ "Joined already-completed task %s (uuid=%s, status=%s)",
+ self.name,
+ task.uuid,
+ task.status,
+ )
+ return task
+
+ # Wait for the existing task to complete
+ logger.info(
+ "Joined active task %s (uuid=%s, status=%s), waiting for completion",
+ self.name,
+ task.uuid,
+ task.status,
+ )
+
+ try:
+ app = current_app._get_current_object()
+ except RuntimeError:
+ app = None
+
+ try:
+ task = TaskManager.wait_for_completion(
+ task_uuid=task.uuid,
+ timeout=float(timeout) if timeout else None,
+ poll_interval=1.0,
+ app=app,
+ )
+ logger.info(
+ "Task %s (uuid=%s) completed with status=%s",
+ self.name,
+ task.uuid,
+ task.status,
+ )
+ return task
+
+ except TimeoutError:
+ logger.warning(
+ "Timeout waiting for task %s (uuid=%s)",
+ self.name,
+ task.uuid,
+ )
+ # Return task in current state (caller can check status)
+ refreshed = TaskDAO.find_one_or_none(uuid=task.uuid)
+ return refreshed if refreshed else task
+
+ def _execute_inline(
+ self,
+ task: "Task",
+ options: TaskOptions,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+ ) -> "Task":
+ """
+ Execute task function inline (synchronously).
+
+ Called when this is a new task (not joining existing).
+ Uses atomic conditional status transitions for race-safe execution.
+
+ :param task: The newly created task
+ :param options: Merged task options
+ :param args: Positional arguments for the task function
+ :param kwargs: Keyword arguments for the task function
+ :returns: Task in terminal state
+ """
+ from superset.commands.tasks.internal_update import (
+ InternalStatusTransitionCommand,
+ )
+ from superset.daos.tasks import TaskDAO
+ from superset.tasks.constants import ABORT_STATES
+
+ # PRE-EXECUTION CHECK: Don't execute if already aborted/aborting
+ # (Matches async flow in scheduler.py)
+ if task.status in ABORT_STATES:
+ logger.info(
+ "Task %s (uuid=%s) was aborted before execution started",
+ self.name,
+ task.uuid,
+ )
+ # Ensure status is ABORTED (not just ABORTING)
+ InternalStatusTransitionCommand(
+ task_uuid=task.uuid,
+ new_status=TaskStatus.ABORTED,
+ expected_status=[TaskStatus.PENDING, TaskStatus.ABORTING],
+ set_ended_at=True,
+ ).run()
+ # Refresh to get updated task
+ refreshed = TaskDAO.find_one_or_none(uuid=task.uuid)
+ return refreshed if refreshed else task
+
+ # Atomic transition: PENDING → IN_PROGRESS (set started_at for duration
+ # tracking)
+ task_uuid = task.uuid # Cache UUID before any potential state changes
+ if not InternalStatusTransitionCommand(
+ task_uuid=task_uuid,
+ new_status=TaskStatus.IN_PROGRESS,
+ expected_status=TaskStatus.PENDING,
+ set_started_at=True,
+ ).run():
+ # Status wasn't PENDING - task may have been aborted concurrently
+ logger.warning(
+ "Task %s (uuid=%s) failed PENDING → IN_PROGRESS transition "
+ "(may have been aborted concurrently)",
+ self.name,
+ task_uuid,
+ )
+ refreshed = TaskDAO.find_one_or_none(uuid=task_uuid)
+ return refreshed if refreshed else task
+
+ # Update cached status (no DB read needed - we just wrote IN_PROGRESS)
+ task.status = TaskStatus.IN_PROGRESS.value
+
+ # Build context with the updated task entity
+ ctx = TaskContext(task)
+
+ # Start timeout timer if configured
+ if options.timeout:
+ ctx.start_timeout_timer(options.timeout)
+ logger.debug(
+ "Started timeout timer for task %s: %d seconds",
+ task.uuid,
+ options.timeout,
+ )
+
+ # Track final task state for completion notification
+ final_task: Task | None = None
+
+ try:
+ # Execute with ambient context
+ with use_context(ctx):
+ self.func(*args, **kwargs)
+
+ # Determine terminal status based on abort detection
+ # Use atomic conditional updates to prevent overwriting concurrent abort
+ if ctx._abort_detected or ctx.timeout_triggered:
+ # Abort was detected - transition ABORTING → terminal
+ if ctx.timeout_triggered:
+ InternalStatusTransitionCommand(
+ task_uuid=task_uuid,
+ new_status=TaskStatus.TIMED_OUT,
+ expected_status=TaskStatus.ABORTING,
+ set_ended_at=True,
+ ).run()
+ logger.info(
+ "Task %s (uuid=%s) timed out and completed cleanup",
+ self.name,
+ task_uuid,
+ )
+ else:
+ InternalStatusTransitionCommand(
+ task_uuid=task_uuid,
+ new_status=TaskStatus.ABORTED,
+ expected_status=TaskStatus.ABORTING,
+ set_ended_at=True,
+ ).run()
+ logger.info(
+ "Task %s (uuid=%s) was aborted by user",
+ self.name,
+ task_uuid,
+ )
+ else:
+ # Normal completion - atomic IN_PROGRESS → SUCCESS
+ # This will fail (return False) if task was concurrently aborted
+ if InternalStatusTransitionCommand(
+ task_uuid=task_uuid,
+ new_status=TaskStatus.SUCCESS,
+ expected_status=TaskStatus.IN_PROGRESS,
+ set_ended_at=True,
+ ).run():
+ logger.debug(
+ "Synchronous execution of task %s (uuid=%s) "
+ "completed successfully",
+ self.name,
+ task_uuid,
+ )
+ else:
+ # Transition failed - task was likely aborted concurrently
+ logger.info(
+ "Task %s (uuid=%s) IN_PROGRESS → SUCCESS failed "
+ "(may have been aborted concurrently)",
+ self.name,
+ task_uuid,
+ )
+
+ # Refresh once at end to return current state
+ final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
+ return final_task if final_task else task
+
+ except Exception as ex:
+ # Atomic transition to FAILURE (only if still IN_PROGRESS)
+ InternalStatusTransitionCommand(
+ task_uuid=task_uuid,
+ new_status=TaskStatus.FAILURE,
+ expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
+ properties={"error_message": str(ex)},
+ set_ended_at=True,
+ ).run()
+
+ logger.error(
+ "Synchronous execution of task %s (uuid=%s) failed: %s",
+ self.name,
+ task_uuid,
+ str(ex),
+ exc_info=True,
+ )
+
+ # Refresh once at end to return current state
+ final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
+ return final_task if final_task else task
+
+ finally:
+ # Always clean up timer and handlers
+ ctx._run_cleanup()
+
+ # Publish completion notification for any waiters
+ # Use final_task if set by try/except, otherwise refresh (fallback)
+ if final_task is None:
+ final_task = TaskDAO.find_one_or_none(uuid=task_uuid)
+ if final_task and final_task.status in TERMINAL_STATES:
+ TaskManager.publish_completion(task_uuid, final_task.status)
+
+ def schedule(self, *args: P.args, **kwargs: P.kwargs) -> "Task":
+ """
+ Schedule this task for asynchronous execution.
+
+ The signature mirrors the original task function, with an additional
+ keyword-only 'options' parameter for execution metadata.
+
+ Args:
+ *args, **kwargs: Business arguments for the task function
+ options: Execution options
+
+ Returns:
+ Task model representing the scheduled task (PENDING status)
+
+ Raises:
+ GlobalTaskFrameworkDisabledError: If GTF feature flag is not enabled
+ ValueError: If task is SHARED scope but no task_key is provided
+
+ Usage:
+ # Auto-generated task_key (random UUID, no deduplication):
+ task = generate_thumbnail.schedule(chart_id)
+
+ # Custom task_key for task deduplication:
+ task = generate_thumbnail.schedule(
+ chart_id,
+ options=TaskOptions(task_key=f"thumb_{chart_id}")
+ )
+
+ # SHARED tasks require task_key:
+ task = shared_task.schedule(
+ data_id,
+ options=TaskOptions(task_key=f"shared_{data_id}")
+ )
+
+ Note: Unlike direct calls (__call__), this schedules async execution.
+ The function returns immediately with the Task model in PENDING status.
+ """
+ if not is_feature_enabled("GLOBAL_TASK_FRAMEWORK"):
+ raise GlobalTaskFrameworkDisabledError()
+
+ # Extract and merge options (decorator defaults + call-time overrides)
+ override_options = cast(TaskOptions | None, kwargs.pop("options", None))
+ options = self._merge_options(override_options)
+
+ # Validate task configuration
+ self._validate_task(options)
+
+ # Extract task_name and task_key from merged options, scope from decorator
+ task_name = options.task_name
+ task_key = options.task_key
+ scope = self.scope # Use scope from decorator
+
+ # Create task entry in metastore and schedule execution
+ return TaskManager.submit_task(
+ task_type=self.name,
+ task_name=task_name,
+ task_key=task_key,
+ scope=scope,
+ timeout=options.timeout,
+ args=args,
+ kwargs=kwargs,
+ )
diff --git a/superset/tasks/filters.py b/superset/tasks/filters.py
new file mode 100644
index 000000000000..862f19987247
--- /dev/null
+++ b/superset/tasks/filters.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Filters for Task model"""
+
+from typing import Any
+
+from sqlalchemy.orm.query import Query
+
+from superset.utils.core import get_user_id
+from superset.views.base import BaseFilter
+
+
+class TaskFilter(BaseFilter): # pylint: disable=too-few-public-methods
+ """
+ Filter for Task that shows tasks based on scope and user permissions.
+
+ Filtering rules:
+ - Admins: See all tasks (private, shared, system)
+ - Non-admins:
+ - Private tasks: Only their own tasks
+ - Shared tasks: Tasks they're subscribed to
+ - System tasks: None (admin-only)
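+
+    For a non-admin, the resulting filter is roughly (illustrative SQL; the
+    actual table and column names are those of the Task/TaskSubscriber models):
+        WHERE (scope = 'private' AND created_by_fk = :user_id)
+           OR id IN (SELECT task_id FROM task_subscriber
+                     WHERE user_id = :user_id)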
+ """
+
+ def apply(self, query: Query, value: Any) -> Query:
+ """Apply the filter to the query."""
+ from flask import g, has_request_context
+ from sqlalchemy import or_
+
+ from superset import db, security_manager
+ from superset.models.task_subscribers import TaskSubscriber
+ from superset.models.tasks import Task
+
+ # If no request context or no user, return unfiltered query
+ # (this handles background tasks and system operations)
+ if not has_request_context() or not hasattr(g, "user"):
+ return query
+
+ # If user is admin, return unfiltered query
+ if security_manager.is_admin():
+ return query
+
+ # For non-admins, filter by scope and permissions
+ user_id = get_user_id()
+
+ # Use subquery for shared tasks to avoid join ambiguity
+ shared_task_ids_query = (
+ db.session.query(Task.id)
+ .join(TaskSubscriber, Task.id == TaskSubscriber.task_id)
+ .filter(
+ Task.scope == "shared",
+ TaskSubscriber.user_id == user_id,
+ )
+ )
+
+ # Build filter conditions:
+ # 1. Private tasks created by current user
+ # 2. Shared tasks where user is subscribed (via subquery)
+ # 3. System tasks are excluded (admin-only)
+ return query.filter(
+ or_(
+ # Own private tasks
+ (Task.scope == "private") & (Task.created_by_fk == user_id),
+ # Shared tasks where user is subscribed
+ Task.id.in_(shared_task_ids_query),
+ )
+ )
+
+
+class TaskSubscriberFilter(BaseFilter): # pylint: disable=too-few-public-methods
+ """
+ Filter tasks by subscriber user ID.
+
+ This filter allows finding tasks where a specific user is subscribed.
+ Used by the frontend for the subscribers filter dropdown.
+ """
+
+ def apply(self, query: Query, value: Any) -> Query:
+ """Apply the filter to the query."""
+ from superset import db
+ from superset.models.task_subscribers import TaskSubscriber
+ from superset.models.tasks import Task
+
+ if not value:
+ return query
+
+ # Handle both single ID and list of IDs
+ if isinstance(value, (list, tuple)):
+ user_ids = [int(v) for v in value]
+ else:
+ user_ids = [int(value)]
+
+ # Find tasks where any of these users are subscribers
+ subscribed_task_ids = db.session.query(TaskSubscriber.task_id).filter(
+ TaskSubscriber.user_id.in_(user_ids)
+ )
+
+ return query.filter(Task.id.in_(subscribed_task_ids))
diff --git a/superset/tasks/locks.py b/superset/tasks/locks.py
new file mode 100644
index 000000000000..f6af3df13c7f
--- /dev/null
+++ b/superset/tasks/locks.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Distributed locking utilities for the Global Task Framework (GTF).
+
+This module provides distributed locks for task operations to prevent race
+conditions during concurrent task creation, subscription, and cancellation.
+
+The lock key uses the task's dedup_key, ensuring all operations on the same
+logical task serialize correctly.
+
+When SIGNAL_CACHE_CONFIG is configured, uses Redis SET NX EX for
+efficient single-command locking. Otherwise falls back to database-backed
+locking via DistributedLock.
+"""
+
+from __future__ import annotations
+
+import logging
+from contextlib import contextmanager
+from typing import Iterator
+
+from superset.distributed_lock import DistributedLock
+
+logger = logging.getLogger(__name__)
+
+
+# Task operations use a shorter TTL than the global default since
+# they complete quickly (just DB operations, no external calls)
+TASK_LOCK_TTL_SECONDS = 10
+
+
+@contextmanager
+def task_lock(dedup_key: str) -> Iterator[None]:
+ """
+ Acquire a distributed lock for task operations.
+
+ Uses the task's dedup_key as the lock key. All operations on the same
+ logical task (create, subscribe, cancel) use the same lock, ensuring
+ mutual exclusion. This prevents race conditions such as:
+ - Two concurrent creates with the same key
+ - Subscribe racing with cancel
+ - Multiple concurrent cancel requests
+
+ When SIGNAL_CACHE_CONFIG is configured, uses Redis SET NX EX
+ for efficient single-command locking. Otherwise falls back to
+ database-backed DistributedLock.
+
+ :param dedup_key: Task deduplication key (from get_active_dedup_key)
+ :yields: Nothing; used as context manager
+ :raises AcquireDistributedLockFailedException: If lock is already held
+
+ Example:
+ dedup_key = get_active_dedup_key(TaskScope.SHARED, "report", "monthly")
+ with task_lock(dedup_key):
+ # Create, subscribe, or cancel task here
+ ...
+ """
+ logger.debug("Acquiring task lock for key: %s", dedup_key)
+
+ with DistributedLock(
+ namespace="gtf:task",
+ key=dedup_key,
+ ttl_seconds=TASK_LOCK_TTL_SECONDS,
+ ):
+ yield
+
+ logger.debug("Released task lock for key: %s", dedup_key)
diff --git a/superset/tasks/manager.py b/superset/tasks/manager.py
new file mode 100644
index 000000000000..f4595c51167c
--- /dev/null
+++ b/superset/tasks/manager.py
@@ -0,0 +1,764 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task manager for the Global Task Framework (GTF)"""
+
+from __future__ import annotations
+
+import logging
+import threading
+import time
+from typing import Any, Callable, TYPE_CHECKING
+from uuid import UUID
+
+import redis
+from superset_core.api.tasks import TaskProperties, TaskScope
+
+from superset.async_events.cache_backend import (
+ RedisCacheBackend,
+ RedisSentinelCacheBackend,
+)
+from superset.extensions import cache_manager
+from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES
+from superset.tasks.utils import generate_random_task_key
+
+if TYPE_CHECKING:
+ from flask import Flask
+
+ from superset.models.tasks import Task
+
+logger = logging.getLogger(__name__)
+
+
+class AbortListener:
+ """
+ Handle for a background abort listener.
+
+ Returned by TaskManager.listen_for_abort() to allow stopping the listener.
+ """
+
+ def __init__(
+ self,
+ task_uuid: UUID,
+ thread: threading.Thread,
+ stop_event: threading.Event,
+ pubsub: redis.client.PubSub | None = None,
+ ) -> None:
+ self._task_uuid = task_uuid
+ self._thread = thread
+ self._stop_event = stop_event
+ self._pubsub = pubsub
+
+ def stop(self) -> None:
+ """Stop the abort listener."""
+ self._stop_event.set()
+
+ # Close pub/sub subscription if active
+ if self._pubsub is not None:
+ try:
+ self._pubsub.unsubscribe()
+ self._pubsub.close()
+ except Exception as ex:
+ logger.debug("Error closing pub/sub during stop: %s", ex)
+
+        # Wait for thread to finish (with timeout to avoid blocking indefinitely)
+        if self._thread.is_alive():
+            self._thread.join(timeout=2.0)
+
+        # Check if thread is still running after the join timeout
+        if self._thread.is_alive():
+            # Thread is a daemon, so it will be killed when process exits.
+            # Log warning but continue - cleanup will still proceed.
+            logger.warning(
+                "Abort listener thread for task %s did not terminate within "
+                "2 seconds. Thread will be terminated when process exits.",
+                self._task_uuid,
+            )
+        else:
+            logger.debug("Stopped abort listener for task %s", self._task_uuid)
+
+
+class TaskManager:
+ """
+ Handles task creation, scheduling, and abort notifications.
+
+ The TaskManager is responsible for:
+ 1. Creating task entries in the metastore (Task model)
+ 2. Scheduling task execution via Celery
+ 3. Handling deduplication (returning existing active task if duplicate)
+ 4. Managing real-time abort notifications (optional)
+
+ Redis pub/sub is opt-in via SIGNAL_CACHE_CONFIG configuration. When not
+ configured, tasks use database polling for abort detection.
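+
+    Typical flow (illustrative):
+        task = TaskManager.submit_task(...)   # create entry, schedule via Celery
+        TaskManager.publish_abort(task.uuid)  # request abort across workers
+        task = TaskManager.wait_for_completion(task.uuid, timeout=60.0)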
+ """
+
+ # Class-level state (initialized once via init_app)
+ _channel_prefix: str = "gtf:abort:"
+ _completion_channel_prefix: str = "gtf:complete:"
+ _initialized: bool = False
+
+ # Backward compatibility alias - prefer importing from superset.tasks.constants
+ TERMINAL_STATES = TERMINAL_STATES
+
+ @classmethod
+ def init_app(cls, app: Flask) -> None:
+ """
+ Initialize the TaskManager with Flask app config.
+
+ Redis connection is managed by CacheManager - this just reads channel prefixes.
+
+ :param app: Flask application instance
+ """
+ if cls._initialized:
+ return
+
+ cls._channel_prefix = app.config.get("TASKS_ABORT_CHANNEL_PREFIX", "gtf:abort:")
+ cls._completion_channel_prefix = app.config.get(
+ "TASKS_COMPLETION_CHANNEL_PREFIX", "gtf:complete:"
+ )
+
+ cls._initialized = True
+
+ @classmethod
+ def _get_cache(cls) -> RedisCacheBackend | RedisSentinelCacheBackend | None:
+ """
+ Get the signal cache backend.
+
+ :returns: The signal cache backend, or None if not configured
+ """
+ return cache_manager.signal_cache
+
+ @classmethod
+ def is_pubsub_available(cls) -> bool:
+ """
+ Check if Redis pub/sub backend is configured and available.
+
+ :returns: True if Redis is available for pub/sub, False otherwise
+ """
+ return cls._get_cache() is not None
+
+ @classmethod
+ def get_abort_channel(cls, task_uuid: UUID) -> str:
+ """
+ Get the abort channel name for a task.
+
+ :param task_uuid: UUID of the task
+ :returns: Channel name for the task's abort notifications
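+
+        Example (with the default prefix): "gtf:abort:<task-uuid>"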
+ """
+ return f"{cls._channel_prefix}{task_uuid}"
+
+ @classmethod
+ def publish_abort(cls, task_uuid: UUID) -> bool:
+ """
+ Publish an abort message to the task's channel.
+
+ :param task_uuid: UUID of the task to abort
+ :returns: True if message was published, False if Redis unavailable
+ """
+ cache = cls._get_cache()
+ if not cache:
+ return False
+
+ try:
+ channel = cls.get_abort_channel(task_uuid)
+ subscriber_count = cache.publish(channel, "abort")
+ logger.debug(
+ "Published abort to channel %s (%d subscribers)",
+ channel,
+ subscriber_count,
+ )
+ return True
+ except redis.RedisError as ex:
+ logger.error("Failed to publish abort for task %s: %s", task_uuid, ex)
+ return False
+
+ @classmethod
+ def get_completion_channel(cls, task_uuid: UUID) -> str:
+ """
+ Get the completion channel name for a task.
+
+ :param task_uuid: UUID of the task
+ :returns: Channel name for the task's completion notifications
+ """
+ return f"{cls._completion_channel_prefix}{task_uuid}"
+
+ @classmethod
+ def publish_completion(cls, task_uuid: UUID, status: str) -> bool:
+ """
+ Publish a completion message to the task's channel.
+
+ Called when task reaches terminal state (SUCCESS, FAILURE, ABORTED, TIMED_OUT).
+ This notifies any waiters (e.g., sync callers waiting for an existing task).
+
+ :param task_uuid: UUID of the completed task
+ :param status: Final status of the task
+ :returns: True if message was published, False if Redis unavailable
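+
+        Example (illustrative):
+            TaskManager.publish_completion(task.uuid, TaskStatus.SUCCESS.value)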
+ """
+ cache = cls._get_cache()
+ if not cache:
+ return False
+
+ try:
+ channel = cls.get_completion_channel(task_uuid)
+ subscriber_count = cache.publish(channel, status)
+ logger.debug(
+ "Published completion to channel %s (status=%s, %d subscribers)",
+ channel,
+ status,
+ subscriber_count,
+ )
+ return True
+ except redis.RedisError as ex:
+ logger.error("Failed to publish completion for task %s: %s", task_uuid, ex)
+ return False
+
+ @classmethod
+ def wait_for_completion(
+ cls,
+ task_uuid: UUID,
+ timeout: float | None = None,
+ poll_interval: float = 1.0,
+ app: Any = None,
+ ) -> "Task":
+ """
+ Block until task reaches terminal state.
+
+ Uses Redis pub/sub if configured for low-latency, low-CPU waiting.
+ Uses database polling if Redis is not configured.
+
+ :param task_uuid: UUID of the task to wait for
+ :param timeout: Maximum time to wait in seconds (None = no limit)
+ :param poll_interval: Interval for database polling (seconds)
+ :param app: Flask app for database access
+ :returns: Task in terminal state
+ :raises TimeoutError: If timeout expires before task completes
+ :raises ValueError: If task not found
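+
+        Example (illustrative):
+            task = TaskManager.wait_for_completion(task.uuid, timeout=30.0)
+            assert task.status in TERMINAL_STATES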
+ """
+ from superset.daos.tasks import TaskDAO
+
+ start_time = time.monotonic()
+
+ def time_remaining() -> float | None:
+ if timeout is None:
+ return None
+ elapsed = time.monotonic() - start_time
+ remaining = timeout - elapsed
+ return remaining if remaining > 0 else 0
+
+ def get_task() -> "Task | None":
+ if app:
+ with app.app_context():
+ return TaskDAO.find_one_or_none(uuid=task_uuid)
+ return TaskDAO.find_one_or_none(uuid=task_uuid)
+
+ # Check current state first
+ task = get_task()
+ if not task:
+ raise ValueError(f"Task {task_uuid} not found")
+
+ if task.status in cls.TERMINAL_STATES:
+ return task
+
+ logger.debug(
+ "Waiting for task %s to complete (current status=%s, timeout=%s)",
+ task_uuid,
+ task.status,
+ timeout,
+ )
+
+ # Use Redis pub/sub if configured
+ if (cache := cls._get_cache()) is not None:
+ task = cls._wait_via_pubsub(
+ task_uuid,
+ cache.pubsub(),
+ timeout,
+ poll_interval,
+ get_task,
+ time_remaining,
+ )
+ if task:
+ return task
+ # Should not reach here - _wait_via_pubsub returns task or raises
+ raise RuntimeError(f"Unexpected state waiting for task {task_uuid}")
+
+ # Use database polling when Redis is not configured
+ return cls._wait_via_polling(task_uuid, poll_interval, get_task, time_remaining)
+
+ @classmethod
+ def _wait_via_pubsub(
+ cls,
+ task_uuid: UUID,
+ pubsub: redis.client.PubSub,
+ timeout: float | None,
+ poll_interval: float,
+ get_task: Callable[[], "Task | None"],
+ time_remaining: Callable[[], float | None],
+ ) -> "Task | None":
+ """
+ Wait for task completion using Redis pub/sub.
+
+ :returns: Task when completed
+ :raises TimeoutError: If timeout expires
+ :raises redis.RedisError: If Redis connection fails
+ """
+ channel = cls.get_completion_channel(task_uuid)
+ pubsub.subscribe(channel)
+
+ try:
+ while True:
+ remaining = time_remaining()
+ if remaining is not None and remaining <= 0:
+ raise TimeoutError(
+ f"Timeout waiting for task {task_uuid} to complete"
+ )
+
+ # Wait for message with short timeout for responsive checking
+ wait_time = min(1.0, remaining) if remaining else 1.0
+ message = pubsub.get_message(
+ ignore_subscribe_messages=True,
+ timeout=wait_time,
+ )
+
+ if message and message.get("type") == "message":
+ # Completion received - fetch fresh task state
+ logger.debug(
+ "Received completion message for task %s: %s",
+ task_uuid,
+ message.get("data"),
+ )
+ task = get_task()
+ if task and task.status in cls.TERMINAL_STATES:
+ return task
+
+ # Also check database periodically in case we missed the message
+ # (e.g., task completed before we subscribed)
+ task = get_task()
+ if task and task.status in cls.TERMINAL_STATES:
+ logger.debug(
+ "Task %s completed (detected via db check): status=%s",
+ task_uuid,
+ task.status,
+ )
+ return task
+
+ finally:
+ pubsub.unsubscribe()
+ pubsub.close()
+
+ @classmethod
+ def _wait_via_polling(
+ cls,
+ task_uuid: UUID,
+ poll_interval: float,
+ get_task: Callable[[], "Task | None"],
+ time_remaining: Callable[[], float | None],
+ ) -> "Task":
+ """
+ Wait for task completion using database polling.
+
+ :returns: Task when completed
+ :raises TimeoutError: If timeout expires
+ :raises ValueError: If task not found
+ """
+ while True:
+ remaining = time_remaining()
+ if remaining is not None and remaining <= 0:
+ raise TimeoutError(f"Timeout waiting for task {task_uuid} to complete")
+
+ task = get_task()
+ if not task:
+ raise ValueError(f"Task {task_uuid} not found")
+
+ if task.status in cls.TERMINAL_STATES:
+ logger.debug(
+ "Task %s completed (detected via polling): status=%s",
+ task_uuid,
+ task.status,
+ )
+ return task
+
+ # Sleep with timeout awareness
+ sleep_time = min(poll_interval, remaining) if remaining else poll_interval
+ time.sleep(sleep_time)
+
+ @classmethod
+ def listen_for_abort(
+ cls,
+ task_uuid: UUID,
+ callback: Callable[[], None],
+ poll_interval: float,
+ app: Any = None,
+ ) -> AbortListener:
+ """
+ Start listening for abort notifications for a task.
+
+ Uses Redis pub/sub if configured, otherwise uses database polling.
+ The callback is invoked when an abort is detected.
+
+ :param task_uuid: UUID of the task to monitor (native UUID)
+ :param callback: Function to call when abort is detected
+ :param poll_interval: Interval for database polling (when Redis not configured)
+ :param app: Flask app for database access in background thread
+ :returns: AbortListener handle to stop listening
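+
+        Example (illustrative; handle_abort is a hypothetical callback):
+            listener = TaskManager.listen_for_abort(
+                task_uuid=task.uuid,
+                callback=handle_abort,
+                poll_interval=2.0,
+            )
+            ...
+            listener.stop()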
+ """
+ stop_event = threading.Event()
+ pubsub: redis.client.PubSub | None = None
+ uuid_str = str(task_uuid)
+
+ # Use Redis pub/sub if configured
+ if (cache := cls._get_cache()) is not None:
+ pubsub = cache.pubsub()
+ channel = cls.get_abort_channel(task_uuid)
+ pubsub.subscribe(channel)
+ logger.debug("Subscribed to abort channel: %s", channel)
+
+ # Start pub/sub listener thread
+ thread = threading.Thread(
+ target=cls._listen_pubsub,
+ args=(task_uuid, pubsub, callback, stop_event, app),
+ daemon=True,
+ name=f"abort-listener-{uuid_str[:8]}",
+ )
+ logger.debug("Started pub/sub abort listener for task %s", task_uuid)
+ else:
+ # Use polling when Redis is not configured
+ pubsub = None
+ thread = threading.Thread(
+ target=cls._poll_for_abort,
+ args=(task_uuid, callback, stop_event, poll_interval, app),
+ daemon=True,
+ name=f"abort-poller-{uuid_str[:8]}",
+ )
+ logger.debug(
+ "Started database abort polling for task %s (interval=%ss)",
+ task_uuid,
+ poll_interval,
+ )
+
+ thread.start()
+ return AbortListener(task_uuid, thread, stop_event, pubsub)
+
+ @staticmethod
+ def _invoke_callback_with_context(
+ callback: Callable[[], None],
+ app: Any,
+ ) -> None:
+ """
+ Invoke callback with Flask app context if provided.
+
+ :param callback: Function to invoke
+ :param app: Flask app for context, or None
+ """
+ if app:
+ with app.app_context():
+ callback()
+ else:
+ callback()
+
+ @classmethod
+ def _check_abort_status(cls, task_uuid: UUID) -> bool:
+ """
+ Check if task has been aborted via database query.
+
+ :param task_uuid: UUID of the task to check (native UUID)
+ :returns: True if task is in ABORTING or ABORTED state
+ """
+ from superset.daos.tasks import TaskDAO
+
+ task = TaskDAO.find_one_or_none(uuid=task_uuid)
+ return task is not None and task.status in ABORT_STATES
+
+ @classmethod
+ def _run_abort_listener_loop(
+ cls,
+ task_uuid: UUID,
+ callback: Callable[[], None],
+ stop_event: threading.Event,
+ interval: float,
+ app: Any,
+ check_fn: Callable[[], bool],
+ source: str,
+ ) -> None:
+ """
+ Common abort listener loop used by both pub/sub and polling modes.
+
+ :param task_uuid: UUID of the task to monitor (native UUID)
+ :param callback: Function to call when abort is detected
+ :param stop_event: Event to signal loop termination
+ :param interval: Wait interval between checks
+ :param app: Flask app for context
+ :param check_fn: Function that returns True if abort was detected
+ :param source: Source identifier for logging ("pub/sub" or "polling")
+ """
+ while not stop_event.is_set():
+ try:
+ if check_fn():
+ logger.info(
+ "Abort detected via %s for task %s",
+ source,
+ task_uuid,
+ )
+ cls._invoke_callback_with_context(callback, app)
+ break
+
+ # Wait for interval or until stop is requested
+ stop_event.wait(timeout=interval)
+
+ except (ValueError, OSError) as ex:
+ # ValueError/OSError with "I/O operation on closed file" or
+ # "Bad file descriptor" typically means the connection was closed
+ # during shutdown. Check if stop was requested.
+ if stop_event.is_set():
+ logger.debug(
+ "Abort %s for task %s stopped cleanly (connection closed)",
+ source,
+ task_uuid,
+ )
+ else:
+ logger.error(
+ "Error in abort %s for task %s: %s",
+ source,
+ task_uuid,
+ str(ex),
+ exc_info=True,
+ )
+ break
+
+ except Exception as ex:
+ # Check if stop was requested - if so, this may be expected
+ if stop_event.is_set():
+ logger.debug(
+ "Abort %s for task %s stopped with exception: %s",
+ source,
+ task_uuid,
+ ex,
+ )
+ else:
+ logger.error(
+ "Error in abort %s for task %s: %s",
+ source,
+ task_uuid,
+ str(ex),
+ exc_info=True,
+ )
+ break
+
+ @classmethod
+ def _listen_pubsub(
+ cls,
+ task_uuid: UUID,
+ pubsub: redis.client.PubSub,
+ callback: Callable[[], None],
+ stop_event: threading.Event,
+ app: Any,
+ ) -> None:
+ """Listen for abort via Redis pub/sub."""
+ def check_pubsub() -> bool:
+ # get_message blocks for up to 1s, so the outer loop needs no extra wait
+ message = pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
+ return message is not None and message.get("type") == "message"
+
+ try:
+ cls._run_abort_listener_loop(
+ task_uuid=task_uuid,
+ callback=callback,
+ stop_event=stop_event,
+ interval=0, # pub/sub has its own timeout in get_message
+ app=app,
+ check_fn=check_pubsub,
+ source="pub/sub",
+ )
+
+ except redis.RedisError as ex:
+ # Check if we were asked to stop - if so, this is expected
+ if stop_event.is_set():
+ logger.debug(
+ "Abort listener for task %s stopped (Redis error: %s)",
+ task_uuid,
+ ex,
+ )
+ else:
+ # Log error but don't fall back - let the failure be visible
+ logger.error(
+ "Redis signal backend failed for task %s abort listener: %s. "
+ "Task may not receive abort signal.",
+ task_uuid,
+ ex,
+ )
+
+ except (ValueError, OSError) as ex:
+ # ValueError: "I/O operation on closed file" - expected when stop() closes
+ # OSError: Similar connection-closed errors
+ if stop_event.is_set():
+ # Clean shutdown, expected behavior
+ logger.debug(
+ "Abort listener for task %s stopped cleanly",
+ task_uuid,
+ )
+ else:
+ # Unexpected error while running
+ logger.error(
+ "Error in abort listener for task %s: %s",
+ task_uuid,
+ str(ex),
+ exc_info=True,
+ )
+
+ except Exception as ex:
+ # Only log as error if we weren't asked to stop
+ if stop_event.is_set():
+ logger.debug(
+ "Abort listener for task %s stopped with exception: %s",
+ task_uuid,
+ ex,
+ )
+ else:
+ logger.error(
+ "Error in abort listener for task %s: %s",
+ task_uuid,
+ str(ex),
+ exc_info=True,
+ )
+
+ finally:
+ # Clean up pub/sub subscription
+ try:
+ pubsub.unsubscribe()
+ pubsub.close()
+ except Exception as ex:
+ logger.debug("Error closing pub/sub during cleanup: %s", ex)
+
+ @classmethod
+ def _poll_for_abort(
+ cls,
+ task_uuid: UUID,
+ callback: Callable[[], None],
+ stop_event: threading.Event,
+ interval: float,
+ app: Any,
+ ) -> None:
+ """Background polling loop - used when Redis pub/sub is not configured."""
+
+ def check_database() -> bool:
+ # Need app context for database access
+ if app:
+ with app.app_context():
+ return cls._check_abort_status(task_uuid)
+ else:
+ return cls._check_abort_status(task_uuid)
+
+ cls._run_abort_listener_loop(
+ task_uuid=task_uuid,
+ callback=callback,
+ stop_event=stop_event,
+ interval=interval,
+ app=app,
+ check_fn=check_database,
+ source="polling",
+ )
+
+ @staticmethod
+ def submit_task(
+ task_type: str,
+ task_key: str | None,
+ task_name: str | None,
+ scope: TaskScope,
+ timeout: int | None,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+ ) -> "Task":
+ """
+ Create task entry and schedule for async execution.
+
+ Flow:
+ 1. Generate task_key if not provided (random UUID)
+ 2. Submit to SubmitTaskCommand which handles locking and create-vs-join
+ 3. Schedule Celery task ONLY for new tasks (not deduplicated ones)
+ 4. Return Task model to caller
+
+ The SubmitTaskCommand uses a distributed lock to prevent race conditions,
+ returning either a new task or an existing active task with the same key.
+
+ :param task_type: Task type identifier (e.g., "superset.generate_thumbnail")
+ :param task_key: Optional deduplication key (None for random UUID)
+ :param task_name: Human-readable task name
+ :param scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM)
+ :param timeout: Optional timeout in seconds
+ :param args: Positional arguments for the task function
+ :param kwargs: Keyword arguments for the task function
+ :returns: Task model representing the scheduled task
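+
+ Example (illustrative values):
+
+ task = TaskManager.submit_task(
+ task_type="superset.generate_thumbnail",
+ task_key=None, # None -> random key, effectively no deduplication
+ task_name="Generate thumbnail",
+ scope=TaskScope.PRIVATE,
+ timeout=300,
+ args=(),
+ kwargs={"dashboard_id": 42}, # hypothetical task arguments
+ )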
+ """
+ from superset.commands.tasks.submit import SubmitTaskCommand
+
+ if task_key is None:
+ task_key = generate_random_task_key()
+
+ # Build properties with execution_mode and timeout
+ properties: TaskProperties = {"execution_mode": "async"}
+ if timeout:
+ properties["timeout"] = timeout
+
+ # Create or join task entry in metastore
+ # SubmitTaskCommand handles locking and create-vs-join logic:
+ # - Acquires distributed lock on dedup_key
+ # - If active task exists: adds subscriber and returns existing task
+ # (is_new=False)
+ # - If no active task: creates new task (is_new=True)
+ task, is_new = SubmitTaskCommand(
+ {
+ "task_key": task_key,
+ "task_type": task_type,
+ "task_name": task_name,
+ "scope": scope.value,
+ "properties": properties,
+ }
+ ).run_with_info()
+
+ # Only schedule Celery task for NEW tasks, not deduplicated ones
+ # Deduplicated tasks are already pending or running
+ if is_new:
+ # Import here to avoid circular dependency
+ from superset.tasks.scheduler import execute_task
+
+ # Schedule Celery task for async execution
+ execute_task.delay(
+ task_uuid=str(task.uuid),
+ task_type=task_type,
+ args=args,
+ kwargs=kwargs,
+ )
+
+ logger.debug(
+ "Scheduled task %s (uuid=%s) for async execution",
+ task_type,
+ task.uuid,
+ )
+ else:
+ logger.debug(
+ "Joined existing task %s (uuid=%s) - no new Celery task scheduled",
+ task_type,
+ task.uuid,
+ )
+
+ return task
diff --git a/superset/tasks/registry.py b/superset/tasks/registry.py
new file mode 100644
index 000000000000..f712aa7f3404
--- /dev/null
+++ b/superset/tasks/registry.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task registry for the Global Task Framework (GTF)"""
+
+import logging
+from typing import Any, Callable
+
+logger = logging.getLogger(__name__)
+
+
+class TaskRegistry:
+ """
+ Registry for task functions.
+
+ Stores task functions by name, allowing the Celery executor to look up
+ and execute registered tasks. This enables the decorator pattern where
+ functions are registered at module import time.
+ """
+
+ _tasks: dict[str, Callable[..., Any]] = {}
+
+ @classmethod
+ def register(cls, task_name: str, func: Callable[..., Any]) -> None:
+ """
+ Register a task function by name.
+
+ :param task_name: Unique task identifier (e.g., "superset.generate_thumbnail")
+ :param func: The task function to register
+ :raises ValueError: If task name is already registered
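+
+ Example (illustrative; the function is hypothetical):
+
+ def generate_thumbnail(dashboard_id: int) -> None:
+ ...
+
+ TaskRegistry.register("superset.generate_thumbnail", generate_thumbnail)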
+ """
+ if task_name in cls._tasks:
+ existing_func = cls._tasks[task_name]
+ if existing_func is not func:
+ raise ValueError(
+ f"Task '{task_name}' is already registered with a different "
+ "function. "
+ f"Existing: {existing_func.__module__}.{existing_func.__name__}, "
+ f"New: {func.__module__}.{func.__name__}"
+ )
+ # Same function being registered again (e.g., module reload) - allow it
+ logger.debug("Task '%s' re-registered with same function", task_name)
+ return
+
+ cls._tasks[task_name] = func
+ logger.info(
+ "Registered async task: %s -> %s.%s",
+ task_name,
+ func.__module__,
+ func.__name__,
+ )
+
+ @classmethod
+ def get_executor(cls, task_name: str) -> Callable[..., Any]:
+ """
+ Get the executor function for a task.
+
+ :param task_name: Task identifier to look up
+ :returns: The registered task function
+ :raises KeyError: If task name is not registered
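+
+ Example (illustrative):
+
+ func = TaskRegistry.get_executor("superset.generate_thumbnail")
+ func(dashboard_id=42) # hypothetical task arguments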
+ """
+ if task_name not in cls._tasks:
+ raise KeyError(
+ f"Task '{task_name}' is not registered. "
+ f"Available tasks: {', '.join(sorted(cls._tasks.keys()))}"
+ )
+ return cls._tasks[task_name]
+
+ @classmethod
+ def is_registered(cls, task_name: str) -> bool:
+ """
+ Check if a task is registered.
+
+ :param task_name: Task identifier to check
+ :returns: True if task is registered
+ """
+ return task_name in cls._tasks
+
+ @classmethod
+ def list_tasks(cls) -> list[str]:
+ """
+ Get list of all registered task names.
+
+ :returns: Sorted list of task names
+ """
+ return sorted(cls._tasks.keys())
+
+ @classmethod
+ def clear(cls) -> None:
+ """
+ Clear all registered tasks.
+
+ WARNING: This is primarily for testing purposes. In production,
+ tasks should remain registered for the lifetime of the process.
+ """
+ cls._tasks.clear()
+ logger.warning("Task registry cleared")
diff --git a/superset/tasks/scheduler.py b/superset/tasks/scheduler.py
index 9f439166c26c..1462850755af 100644
--- a/superset/tasks/scheduler.py
+++ b/superset/tasks/scheduler.py
@@ -19,11 +19,13 @@
import logging
from datetime import datetime, timezone
from typing import Any
+from uuid import UUID
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.signals import task_failure
from flask import current_app
+from superset_core.api.tasks import TaskStatus
from superset import is_feature_enabled
from superset.commands.exceptions import CommandException
@@ -32,10 +34,17 @@
from superset.commands.report.execute import AsyncExecuteReportScheduleCommand
from superset.commands.report.log_prune import AsyncPruneReportScheduleLogCommand
from superset.commands.sql_lab.query import QueryPruneCommand
+from superset.commands.tasks.prune import TaskPruneCommand
from superset.daos.report import ReportScheduleDAO
+from superset.daos.tasks import TaskDAO
from superset.extensions import celery_app
from superset.stats_logger import BaseStatsLogger
+from superset.tasks.ambient_context import use_context
+from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES
+from superset.tasks.context import TaskContext
from superset.tasks.cron_util import cron_schedule_window
+from superset.tasks.manager import TaskManager
+from superset.tasks.registry import TaskRegistry
from superset.utils.core import LoggerLevel
from superset.utils.log import get_logger_from_status
@@ -199,3 +208,251 @@ def prune_logs(
LogPruneCommand(retention_period_days, max_rows_per_run).run()
except CommandException as ex:
logger.exception("An error occurred while pruning logs: %s", ex)
+
+
+@celery_app.task(name="prune_tasks", bind=True)
+def prune_tasks(
+ self: Task,
+ retention_period_days: int | None = None,
+ max_rows_per_run: int | None = None,
+ **kwargs: Any,
+) -> None:
+ stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
+ stats_logger.incr("prune_tasks")
+
+ # TODO: Deprecated: Remove support for passing retention period via options in 6.0
+ if retention_period_days is None:
+ retention_period_days = prune_tasks.request.properties.get(
+ "retention_period_days"
+ )
+ logger.warning(
+ "Your `prune_tasks` beat schedule uses `options` to pass the "
+ "retention period, please use `kwargs` instead."
+ )
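+
+ # Illustrative beat schedule entry passing the retention period via
+ # `kwargs` (schedule and values are examples only):
+ #
+ # "prune_tasks": {
+ # "task": "prune_tasks",
+ # "schedule": crontab(hour=0, minute=0),
+ # "kwargs": {"retention_period_days": 90, "max_rows_per_run": 10000},
+ # },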
+
+ try:
+ TaskPruneCommand(retention_period_days, max_rows_per_run).run()
+ except CommandException as ex:
+ logger.exception("An error occurred while pruning async tasks: %s", ex)
+
+
+@celery_app.task(name="tasks.execute", bind=True)
+def execute_task( # noqa: C901
+ self: Any, # Celery task instance
+ task_uuid: str,
+ task_type: str,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+) -> dict[str, Any]:
+ """
+ Generic task executor for GTF tasks.
+
+ This executor:
+ 1. Checks if task was aborted before execution starts
+ 2. Fetches task from metastore
+ 3. Builds context (task + user) and sets ambient context via contextvars
+ 4. Executes the task function (which accesses context via get_context())
+ 5. Updates task status throughout lifecycle using atomic conditional updates
+ 6. Runs cleanup handlers on task end (success/failure/abortion)
+ 7. Resets context after execution
+
+ Uses atomic conditional status updates to prevent race conditions with
+ concurrent abort operations.
+
+ :param task_uuid: UUID of the task to execute
+ :param task_type: Type of the task (for registry lookup)
+ :param args: Positional arguments for the task function
+ :param kwargs: Keyword arguments for the task function
+ :returns: Dict with status and task_uuid
+ """
+ from superset.commands.tasks.internal_update import InternalStatusTransitionCommand
+
+ # Convert string UUID to native UUID (Celery deserializes as string)
+ native_uuid = UUID(task_uuid)
+
+ task = TaskDAO.find_one_or_none(uuid=native_uuid)
+ if not task:
+ logger.error("Task %s not found in metastore", task_uuid)
+ return {"status": "error", "message": "Task not found"}
+
+ # AUTOMATIC PRE-EXECUTION CHECK: Don't execute if already aborted/aborting
+ if task.status in ABORT_STATES:
+ logger.info(
+ "Task %s (uuid=%s) was aborted before execution started",
+ task_type,
+ task_uuid,
+ )
+ # Atomic transition to ABORTED (if not already)
+ InternalStatusTransitionCommand(
+ task_uuid=native_uuid,
+ new_status=TaskStatus.ABORTED,
+ expected_status=[TaskStatus.PENDING, TaskStatus.ABORTING],
+ set_ended_at=True,
+ ).run()
+ return {"status": TaskStatus.ABORTED.value, "task_uuid": task_uuid}
+
+ # Atomic transition: PENDING → IN_PROGRESS (set started_at for duration tracking)
+ if not InternalStatusTransitionCommand(
+ task_uuid=native_uuid,
+ new_status=TaskStatus.IN_PROGRESS,
+ expected_status=TaskStatus.PENDING,
+ set_started_at=True,
+ ).run():
+ # Status wasn't PENDING - task may have been aborted concurrently
+ logger.warning(
+ "Task %s (uuid=%s) failed PENDING → IN_PROGRESS transition "
+ "(may have been aborted concurrently)",
+ task_type,
+ task_uuid,
+ )
+ refreshed = TaskDAO.find_one_or_none(uuid=native_uuid)
+ return {
+ "status": refreshed.status if refreshed else "unknown",
+ "task_uuid": task_uuid,
+ }
+
+ # Update cached status (no DB read needed - we just wrote IN_PROGRESS)
+ task.status = TaskStatus.IN_PROGRESS.value
+
+ # Build context from task (includes user who created the task)
+ ctx = TaskContext(task)
+
+ # Start timeout timer if configured (timer starts from execution time)
+ if timeout := task.properties_dict.get("timeout"):
+ ctx.start_timeout_timer(timeout)
+ logger.debug(
+ "Started timeout timer for task %s: %d seconds",
+ task_uuid,
+ timeout,
+ )
+
+ try:
+ # Get registered executor function
+ executor_fn = TaskRegistry.get_executor(task_type)
+
+ logger.info(
+ "Executing task %s (uuid=%s) with function %s.%s",
+ task_type,
+ task_uuid,
+ executor_fn.__module__,
+ executor_fn.__name__,
+ )
+
+ # Execute with ambient context (no ctx parameter!)
+ with use_context(ctx):
+ executor_fn(*args, **kwargs)
+
+ # Mark execution as completed to prevent late abort handlers
+ ctx.mark_execution_completed()
+
+ # Determine terminal status based on abort detection
+ # Use atomic conditional updates to prevent overwriting concurrent abort
+ if ctx._abort_detected or ctx.timeout_triggered:
+ # Abort was detected - will be handled in finally block
+ pass
+ else:
+ # Normal completion - also allow ABORTING → SUCCESS for late abort
+ # (task finished before abort was detected)
+ if InternalStatusTransitionCommand(
+ task_uuid=native_uuid,
+ new_status=TaskStatus.SUCCESS,
+ expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
+ set_ended_at=True,
+ ).run():
+ # Emit stats metric for success
+ stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"]
+ stats_logger.incr("gtf.task.success")
+ logger.info(
+ "Task %s (uuid=%s) completed successfully", task_type, task_uuid
+ )
+ else:
+ # Transition failed - task was likely already in a terminal state
+ logger.info(
+ "Task %s (uuid=%s) completion transition failed "
+ "(task may already be in terminal state)",
+ task_type,
+ task_uuid,
+ )
+
+ except Exception as ex:
+ # Mark execution as completed to prevent late abort handlers
+ ctx.mark_execution_completed()
+
+ # Atomic transition to FAILURE (only if still IN_PROGRESS or ABORTING)
+ InternalStatusTransitionCommand(
+ task_uuid=native_uuid,
+ new_status=TaskStatus.FAILURE,
+ expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
+ properties={"error_message": str(ex)},
+ set_ended_at=True,
+ ).run()
+
+ logger.error(
+ "Task %s (uuid=%s) failed with error: %s",
+ task_type,
+ task_uuid,
+ str(ex),
+ exc_info=True,
+ )
+
+ # Emit stats metric for failure
+ stats_logger = current_app.config["STATS_LOGGER"]
+ stats_logger.incr("gtf.task.failure")
+
+ finally:
+ # ALWAYS run cleanup handlers (also stops timeout timer)
+ ctx._run_cleanup()
+
+ # Handle abort/timeout terminal transitions
+ # Use atomic updates to safely transition ABORTING → terminal state
+ if ctx._abort_detected or ctx.timeout_triggered:
+ if ctx.abort_handlers_completed:
+ # All handlers succeeded - determine terminal state based on cause
+ if ctx.timeout_triggered:
+ InternalStatusTransitionCommand(
+ task_uuid=native_uuid,
+ new_status=TaskStatus.TIMED_OUT,
+ expected_status=TaskStatus.ABORTING,
+ set_ended_at=True,
+ ).run()
+ logger.info(
+ "Task %s (uuid=%s) timed out and completed cleanup",
+ task_type,
+ task_uuid,
+ )
+ else:
+ InternalStatusTransitionCommand(
+ task_uuid=native_uuid,
+ new_status=TaskStatus.ABORTED,
+ expected_status=TaskStatus.ABORTING,
+ set_ended_at=True,
+ ).run()
+ logger.info(
+ "Task %s (uuid=%s) was aborted by user",
+ task_type,
+ task_uuid,
+ )
+ else:
+ # Handlers didn't complete successfully - mark as FAILURE
+ InternalStatusTransitionCommand(
+ task_uuid=native_uuid,
+ new_status=TaskStatus.FAILURE,
+ expected_status=TaskStatus.ABORTING,
+ properties={"error_message": "Abort handlers did not complete"},
+ set_ended_at=True,
+ ).run()
+ logger.warning(
+ "Task %s (uuid=%s) stuck in ABORTING - marking as FAILURE",
+ task_type,
+ task_uuid,
+ )
+
+ # Refresh to get final status for return value and completion notification
+ refreshed = TaskDAO.find_one_or_none(uuid=native_uuid)
+ final_status = refreshed.status if refreshed else "unknown"
+
+ # Publish completion notification for any waiters (e.g., sync callers)
+ if final_status in TERMINAL_STATES:
+ TaskManager.publish_completion(native_uuid, final_status)
+
+ return {"status": final_status, "task_uuid": task_uuid}
diff --git a/superset/tasks/schemas.py b/superset/tasks/schemas.py
new file mode 100644
index 000000000000..9fe0b31ec7b4
--- /dev/null
+++ b/superset/tasks/schemas.py
@@ -0,0 +1,200 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task API schemas"""
+
+from marshmallow import fields, Schema
+from marshmallow.fields import Method
+
+# RISON/JSON schemas for query parameters
+get_delete_ids_schema = {"type": "array", "items": {"type": "string"}}
+
+# Field descriptions
+uuid_description = "The unique identifier (UUID) of the task"
+task_key_description = "The task identifier used for deduplication"
+task_type_description = (
+ "The type of task (e.g., 'sql_execution', 'thumbnail_generation')"
+)
+task_name_description = "Human-readable name for the task"
+status_description = "Current status of the task"
+created_on_description = "Timestamp when the task was created"
+changed_on_description = "Timestamp when the task was last updated"
+started_at_description = "Timestamp when the task started execution"
+ended_at_description = "Timestamp when the task completed or failed"
+created_by_description = "User who created the task"
+user_id_description = "ID of the user context for task execution"
+payload_description = "Task-specific data in JSON format"
+properties_description = (
+ "Runtime state and execution config. Contains: is_abortable, progress_percent, "
+ "progress_current, progress_total, error_message, exception_type, stack_trace, "
+ "timeout"
+)
+duration_seconds_description = (
+ "Duration in seconds - for finished tasks: execution time, "
+ "for running tasks: time since start, for pending: queue time"
+)
+scope_description = (
+ "Task scope: 'private' (user-specific), 'shared' (multi-user), "
+ "or 'system' (admin-only)"
+)
+subscriber_count_description = (
+ "Number of users subscribed to this task (for shared tasks)"
+)
+subscribers_description = "List of users subscribed to this task (for shared tasks)"
+
+
+class UserSchema(Schema):
+ """Schema for user information"""
+
+ id = fields.Int()
+ first_name = fields.String()
+ last_name = fields.String()
+
+
+class TaskResponseSchema(Schema):
+ """
+ Schema for task response.
+
+ Used for both list and detail endpoints.
+ """
+
+ id = fields.Int(metadata={"description": "Internal task ID"})
+ uuid = fields.UUID(metadata={"description": uuid_description})
+ task_key = fields.String(metadata={"description": task_key_description})
+ task_type = fields.String(metadata={"description": task_type_description})
+ task_name = fields.String(
+ metadata={"description": task_name_description}, allow_none=True
+ )
+ status = fields.String(metadata={"description": status_description})
+ created_on = fields.DateTime(metadata={"description": created_on_description})
+ created_on_delta_humanized = Method(
+ "get_created_on_delta_humanized",
+ metadata={"description": "Humanized time since creation"},
+ )
+ changed_on = fields.DateTime(metadata={"description": changed_on_description})
+ changed_by = fields.Nested(UserSchema, allow_none=True)
+ started_at = fields.DateTime(
+ metadata={"description": started_at_description}, allow_none=True
+ )
+ ended_at = fields.DateTime(
+ metadata={"description": ended_at_description}, allow_none=True
+ )
+ created_by = fields.Nested(UserSchema, allow_none=True)
+ user_id = fields.Int(metadata={"description": user_id_description}, allow_none=True)
+ payload = Method("get_payload_dict", metadata={"description": payload_description})
+ properties = Method(
+ "get_properties", metadata={"description": properties_description}
+ )
+ duration_seconds = Method(
+ "get_duration",
+ metadata={"description": duration_seconds_description},
+ )
+ scope = fields.String(metadata={"description": scope_description})
+ subscriber_count = Method(
+ "get_subscriber_count", metadata={"description": subscriber_count_description}
+ )
+ subscribers = Method(
+ "get_subscribers", metadata={"description": subscribers_description}
+ )
+
+ def get_payload_dict(self, obj: object) -> dict[str, object] | None:
+ """Get payload as dictionary"""
+ return obj.payload_dict # type: ignore[attr-defined]
+
+ def get_properties(self, obj: object) -> dict[str, object]:
+ """Get properties dict, filtering stack_trace if SHOW_STACKTRACE is disabled."""
+ from flask import current_app
+
+ properties = dict(obj.properties_dict) # type: ignore[attr-defined]
+
+ # Remove stack_trace unless SHOW_STACKTRACE is enabled
+ if not current_app.config.get("SHOW_STACKTRACE", False):
+ properties.pop("stack_trace", None)
+
+ return properties
+
+ def get_duration(self, obj: object) -> float | None:
+ """Get duration in seconds"""
+ return obj.duration_seconds # type: ignore[attr-defined]
+
+ def get_created_on_delta_humanized(self, obj: object) -> str:
+ """Get humanized time since creation"""
+ return obj.created_on_delta_humanized() # type: ignore[attr-defined]
+
+ def get_subscriber_count(self, obj: object) -> int:
+ """Get number of subscribers"""
+ return obj.subscriber_count # type: ignore[attr-defined]
+
+ def get_subscribers(self, obj: object) -> list[dict[str, object]]:
+ """Get list of subscribers with user info"""
+ subscribers = []
+ for sub in obj.subscribers: # type: ignore[attr-defined]
+ subscribers.append(
+ {
+ "user_id": sub.user_id,
+ "first_name": sub.user.first_name if sub.user else None,
+ "last_name": sub.user.last_name if sub.user else None,
+ "subscribed_at": sub.subscribed_at.isoformat()
+ if sub.subscribed_at
+ else None,
+ }
+ )
+ return subscribers
+
+
+class TaskStatusResponseSchema(Schema):
+ """Schema for task status response (lightweight for polling)"""
+
+ status = fields.String(metadata={"description": status_description})
+
+
+class TaskCancelRequestSchema(Schema):
+ """Schema for task cancellation request"""
+
+ force = fields.Boolean(
+ load_default=False,
+ metadata={
+ "description": "Force cancel the task for all subscribers (admin only). "
+ "Only applicable for shared tasks with multiple subscribers."
+ },
+ )
+
+
+class TaskCancelResponseSchema(Schema):
+ """Schema for task cancellation response"""
+
+ message = fields.String(metadata={"description": "Success or status message"})
+ action = fields.String(
+ metadata={
+ "description": "The action taken: 'aborted' (task terminated) or "
+ "'unsubscribed' (user removed from shared task)"
+ }
+ )
+ task = fields.Nested(TaskResponseSchema, allow_none=True)
+
+
+openapi_spec_methods_override = {
+ "get": {"get": {"summary": "Get a task detail"}},
+ "get_list": {
+ "get": {
+ "summary": "Get a list of tasks",
+ "description": "Gets a list of tasks for the current user. "
+ "Use Rison or JSON query parameters for filtering, sorting, "
+ "pagination and for selecting specific columns and metadata.",
+ }
+ },
+ "info": {"get": {"summary": "Get metadata information about this API resource"}},
+}
diff --git a/superset/tasks/types.py b/superset/tasks/types.py
index 8f2f76b40528..c6fd1e438d11 100644
--- a/superset/tasks/types.py
+++ b/superset/tasks/types.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from typing import NamedTuple
from superset.utils.backports import StrEnum
diff --git a/superset/tasks/utils.py b/superset/tasks/utils.py
index 845ea2b5fc90..5316e7334cf1 100644
--- a/superset/tasks/utils.py
+++ b/superset/tasks/utils.py
@@ -18,16 +18,25 @@
from __future__ import annotations
import logging
+import traceback
from http.client import HTTPResponse
-from typing import Optional, TYPE_CHECKING
+from typing import cast, TYPE_CHECKING
from urllib import request
+from uuid import UUID, uuid4
from celery.utils.log import get_task_logger
from flask import g
+from superset_core.api.tasks import TaskProperties, TaskScope
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
-from superset.tasks.types import ChosenExecutor, Executor, ExecutorType, FixedExecutor
+from superset.tasks.types import (
+ ChosenExecutor,
+ Executor,
+ ExecutorType,
+ FixedExecutor,
+)
from superset.utils import json
+from superset.utils.hashing import hash_from_str
from superset.utils.urls import get_url_path
if TYPE_CHECKING:
@@ -123,7 +132,7 @@ def fetch_csrf_token(
response: HTTPResponse
with request.urlopen(req, timeout=600) as response: # noqa: S310
body = response.read().decode("utf-8")
- session_cookie: Optional[str] = None
+ session_cookie: str | None = None
cookie_headers = response.headers.get_all("set-cookie")
if cookie_headers:
for cookie in cookie_headers:
@@ -142,3 +151,164 @@ def fetch_csrf_token(
logger.error("Error fetching CSRF token, status code: %s", response.status)
return {}
+
+
+def generate_random_task_key() -> str:
+ """
+ Generate a random task key.
+
+ This is the default behavior - each task submission gets a unique UUID
+ unless an explicit task_key is provided in TaskOptions.
+
+ :returns: A random UUID string
+ """
+ return str(uuid4())
+
+
+def get_active_dedup_key(
+ scope: TaskScope | str,
+ task_type: str,
+ task_key: str,
+ user_id: int | None = None,
+) -> str:
+ """
+ Build a deduplication key for active tasks.
+
+ The dedup_key enforces uniqueness at the database level via a unique index.
+ Active tasks use a composite key based on scope, which is then hashed using
+ the configured HASH_ALGORITHM to produce a fixed-length key.
+
+ The composite key format before hashing is:
+ - Private: private|task_type|task_key|user_id
+ - Shared: shared|task_type|task_key
+ - System: system|task_type|task_key
+
+ The final key is a hash digest (64 chars for sha256, 32 chars for md5).
+
+ :param scope: Task scope (PRIVATE/SHARED/SYSTEM) as TaskScope enum or string
+ :param task_type: Type of task (e.g., 'sql_execution')
+ :param task_key: Task identifier for deduplication
+ :param user_id: User ID (required for private tasks)
+ :returns: Hashed deduplication key string
+ :raises ValueError: If user_id is missing for private scope
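+
+ Example (illustrative; digest length depends on HASH_ALGORITHM):
+
+ # composite key "private|sql_execution|my_query|42" -> hashed key
+ key = get_active_dedup_key(TaskScope.PRIVATE, "sql_execution", "my_query", 42)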
+ """
+ # Convert string to TaskScope if needed
+ if isinstance(scope, str):
+ scope = TaskScope(scope)
+
+ # Build composite key
+ match scope:
+ case TaskScope.PRIVATE:
+ if user_id is None:
+ raise ValueError("user_id required for private tasks")
+ composite_key = f"{scope.value}|{task_type}|{task_key}|{user_id}"
+ case TaskScope.SHARED:
+ composite_key = f"{scope.value}|{task_type}|{task_key}"
+ case TaskScope.SYSTEM:
+ composite_key = f"{scope.value}|{task_type}|{task_key}"
+ case _:
+ raise ValueError(f"Invalid scope: {scope}")
+
+ # Hash the composite key to produce a fixed-length dedup_key
+ # Truncate to 64 chars max to fit the database column in case
+ # a hash algo is used that generates hashes that exceed 64 chars
+ return hash_from_str(composite_key)[:64]
+
+
+def get_finished_dedup_key(task_uuid: UUID) -> str:
+ """
+ Build a deduplication key for finished tasks.
+
+ When a task completes (success, failure, or abort), its dedup_key is
+ changed to its UUID. This frees up the slot so new tasks with the same
+ parameters can be created.
+
+ :param task_uuid: Task UUID (native UUID type)
+ :returns: The task UUID string as the dedup key
+
+ Example:
+ >>> from uuid import UUID
+ >>> get_finished_dedup_key(UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890"))
+ 'a1b2c3d4-e5f6-7890-abcd-ef1234567890'
+ """
+ return str(task_uuid)
+
+
+# -----------------------------------------------------------------------------
+# TaskProperties helper functions
+# -----------------------------------------------------------------------------
+
+
+def progress_update(progress: float | int | tuple[int, int]) -> TaskProperties:
+ """
+ Create a properties update dict for progress values.
+
+ :param progress: One of:
+ - float (0.0-1.0): Percentage only
+ - int: Count only (total unknown)
+ - tuple[int, int]: (current, total) with auto-computed percentage
+ :returns: TaskProperties dict with appropriate progress fields set
+
+ Example:
+ task.update_properties(progress_update((50, 100)))
+ """
+ if isinstance(progress, float):
+ return {"progress_percent": progress}
+ if isinstance(progress, int):
+ return {"progress_current": progress}
+ # tuple of (current, total)
+ current, total = progress
+ result: TaskProperties = {
+ "progress_current": current,
+ "progress_total": total,
+ }
+ if total > 0:
+ result["progress_percent"] = current / total
+ return result
+
+
+def error_update(exception: BaseException) -> TaskProperties:
+ """
+ Create a properties update dict from an exception.
+
+ :param exception: The exception that caused the failure
+ :returns: TaskProperties dict with error fields populated
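+
+ Example (illustrative; the workload is hypothetical):
+
+ try:
+ run_query()
+ except Exception as ex:
+ task.update_properties(error_update(ex))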
+ """
+ return {
+ "error_message": str(exception),
+ "exception_type": type(exception).__name__,
+ "stack_trace": traceback.format_exc(),
+ }
+
+
+def parse_properties(json_str: str | None) -> TaskProperties:
+ """
+ Parse JSON string into TaskProperties dict.
+
+ Returns empty dict on parse errors. Unknown keys are preserved
+ for forward compatibility (allows adding new properties without
+ breaking existing code).
+
+ :param json_str: JSON string or None
+ :returns: TaskProperties dict (sparse - only contains keys that were set)
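+
+ Example (illustrative):
+
+ parse_properties('{"progress_percent": 0.5}') # {"progress_percent": 0.5}
+ parse_properties("not json") # {}
+ parse_properties(None) # {}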
+ """
+ if not json_str:
+ return {}
+
+ try:
+ raw = json.loads(json_str)
+ if isinstance(raw, dict):
+ return cast(TaskProperties, raw)
+ return {}
+ except (json.JSONDecodeError, TypeError):
+ return {}
+
+
+def serialize_properties(props: TaskProperties) -> str:
+ """
+ Serialize TaskProperties to JSON string.
+
+ :param props: TaskProperties dict
+ :returns: JSON string
+ """
+ return json.dumps(props)
diff --git a/superset/utils/cache_manager.py b/superset/utils/cache_manager.py
index 0804e0d4b5d3..48ff0e11cd0a 100644
--- a/superset/utils/cache_manager.py
+++ b/superset/utils/cache_manager.py
@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import hashlib
import logging
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from flask import current_app, Flask
from flask_caching import Cache
@@ -24,6 +26,12 @@
from superset.utils.core import DatasourceType
+if TYPE_CHECKING:
+ from superset.async_events.cache_backend import (
+ RedisCacheBackend,
+ RedisSentinelCacheBackend,
+ )
+
logger = logging.getLogger(__name__)
CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache"
@@ -185,6 +193,7 @@ def __init__(self) -> None:
self._thumbnail_cache = SupersetCache()
self._filter_state_cache = SupersetCache()
self._explore_form_data_cache = ExploreFormDataCache()
+ self._signal_cache: RedisCacheBackend | RedisSentinelCacheBackend | None = None
@staticmethod
def _init_cache(
@@ -226,6 +235,30 @@ def init_app(self, app: Flask) -> None:
"EXPLORE_FORM_DATA_CACHE_CONFIG",
required=True,
)
+ self._init_signal_cache(app)
+
+ def _init_signal_cache(self, app: Flask) -> None:
+ """Initialize the signal cache for pub/sub and distributed locks."""
+ from superset.async_events.cache_backend import (
+ RedisCacheBackend,
+ RedisSentinelCacheBackend,
+ )
+
+ config = app.config.get("SIGNAL_CACHE_CONFIG")
+ if not config:
+ return
+
+ cache_type = config.get("CACHE_TYPE")
+ if cache_type == "RedisCache":
+ self._signal_cache = RedisCacheBackend.from_config(config)
+ elif cache_type == "RedisSentinelCache":
+ self._signal_cache = RedisSentinelCacheBackend.from_config(config)
+ else:
+ logger.warning(
+ "Unsupported CACHE_TYPE for SIGNAL_CACHE_CONFIG: %s. "
+ "Use 'RedisCache' or 'RedisSentinelCache'.",
+ cache_type,
+ )
@property
def data_cache(self) -> Cache:
@@ -246,3 +279,23 @@ def filter_state_cache(self) -> Cache:
@property
def explore_form_data_cache(self) -> Cache:
return self._explore_form_data_cache
+
+ @property
+ def signal_cache(
+ self,
+ ) -> RedisCacheBackend | RedisSentinelCacheBackend | None:
+ """
+ Return the signal cache backend.
+
+ Used for signaling features that require Redis-specific primitives:
+ - Pub/Sub messaging for real-time abort/completion notifications
+ - SET NX EX for atomic distributed lock acquisition
+
+ The backend provides:
+ - `._cache`: Raw Redis client
+ - `.key_prefix`: Configured key prefix (from CACHE_KEY_PREFIX)
+ - `.default_timeout`: Default timeout in seconds (from CACHE_DEFAULT_TIMEOUT)
+
+ Returns None if SIGNAL_CACHE_CONFIG is not configured.
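+
+ Example (illustrative; assumes the usual `cache_manager` singleton):
+
+ if (cache := cache_manager.signal_cache) is not None:
+ client = cache._cache # raw Redis client for pub/sub or SET NX EX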
+ """
+ return self._signal_cache
diff --git a/superset/utils/log.py b/superset/utils/log.py
index 51e91e1a37fb..6685010b4d21 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -24,7 +24,7 @@
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta
-from typing import Any, Callable, cast, Literal, TYPE_CHECKING
+from typing import Any, Callable, cast, Literal
from flask import g, has_request_context, request
from flask_appbuilder.const import API_URI_RIS_KEY
@@ -34,9 +34,6 @@
from superset.utils import json
from superset.utils.core import get_user_id, LoggerLevel, to_int
-if TYPE_CHECKING:
- pass
-
logger = logging.getLogger(__name__)
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 9a24f0c09581..57cc0a25ce9f 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -31,8 +31,8 @@
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from superset import db
-from superset.distributed_lock import KeyValueDistributedLock
-from superset.exceptions import CreateKeyValueDistributedLockFailedException
+from superset.distributed_lock import DistributedLock
+from superset.exceptions import AcquireDistributedLockFailedException
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
if TYPE_CHECKING:
@@ -77,7 +77,7 @@ def generate_code_challenge(code_verifier: str) -> str:
@backoff.on_exception(
backoff.expo,
- CreateKeyValueDistributedLockFailedException,
+ AcquireDistributedLockFailedException,
factor=10,
base=2,
max_tries=5,
@@ -128,8 +128,10 @@ def refresh_oauth2_token(
db_engine_spec: type[BaseEngineSpec],
token: DatabaseUserOAuth2Tokens,
) -> str | None:
- with KeyValueDistributedLock(
+ # Use longer TTL for OAuth2 token refresh (may involve network calls)
+ with DistributedLock(
namespace="refresh_oauth2_token",
+ ttl_seconds=30,
user_id=user_id,
database_id=database_id,
):
diff --git a/superset/commands/distributed_lock/get.py b/superset/views/tasks.py
similarity index 51%
rename from superset/commands/distributed_lock/get.py
rename to superset/views/tasks.py
index 30115f542260..57f5bb7194ac 100644
--- a/superset/commands/distributed_lock/get.py
+++ b/superset/views/tasks.py
@@ -14,32 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from flask_appbuilder import expose, has_access
-from __future__ import annotations
+from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP
+from superset.superset_typing import FlaskResponse
+from superset.views.base import BaseSupersetView
-import logging
-from typing import cast
-from flask import current_app as app
+class TaskModelView(BaseSupersetView):
+ route_base = "/tasks"
+ class_permission_name = "Task"
+ method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
-from superset.commands.distributed_lock.base import BaseDistributedLockCommand
-from superset.daos.key_value import KeyValueDAO
-from superset.distributed_lock.types import LockValue
-
-logger = logging.getLogger(__name__)
-stats_logger = app.config["STATS_LOGGER"]
-
-
-class GetDistributedLock(BaseDistributedLockCommand):
- def validate(self) -> None:
- pass
-
- def run(self) -> LockValue | None:
- entry = KeyValueDAO.get_entry(
- resource=self.resource,
- key=self.key,
- )
- if not entry or entry.is_expired():
- return None
-
- return cast(LockValue, self.codec.decode(entry.value))
+ @expose("/list/")
+ @has_access
+ def list(self) -> FlaskResponse:
+ return super().render_app_template()
diff --git a/tests/integration_tests/superset_test_config.py b/tests/integration_tests/superset_test_config.py
index c1992af8f0b3..c16591ed375a 100644
--- a/tests/integration_tests/superset_test_config.py
+++ b/tests/integration_tests/superset_test_config.py
@@ -73,6 +73,7 @@
"AVOID_COLORS_COLLISION": True,
"DRILL_TO_DETAIL": True,
"DRILL_BY": True,
+ "GLOBAL_TASK_FRAMEWORK": True,
}
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"
diff --git a/tests/integration_tests/tasks/api_tests.py b/tests/integration_tests/tasks/api_tests.py
new file mode 100644
index 000000000000..37e96a60bc30
--- /dev/null
+++ b/tests/integration_tests/tasks/api_tests.py
@@ -0,0 +1,538 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Integration tests for Task REST API"""
+
+from contextlib import contextmanager
+from typing import Generator
+
+import prison
+from superset_core.api.tasks import TaskStatus
+
+from superset import db
+from superset.models.tasks import Task
+from superset.utils import json
+from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.constants import (
+ ADMIN_USERNAME,
+ GAMMA_USERNAME,
+)
+
+
+class TestTaskApi(SupersetTestCase):
+ """Tests for Task REST API"""
+
+ TASK_API_BASE = "api/v1/task"
+
+ @contextmanager
+ def _create_tasks(self) -> Generator[list[Task], None, None]:
+ """
+ Context manager to create test tasks with guaranteed cleanup.
+
+ Uses TaskDAO to create tasks, testing the actual production code path.
+
+ Usage:
+ with self._create_tasks() as tasks:
+ # Use tasks in test
+ # Cleanup happens automatically even if test fails
+ """
+ from superset_core.api.tasks import TaskScope
+
+ from superset.daos.tasks import TaskDAO
+
+ admin = self.get_user("admin")
+ gamma = self.get_user("gamma")
+
+ tasks = []
+
+ try:
+ # Create tasks with different statuses using TaskDAO
+ for i in range(5):
+ task_key = f"test_task_{i}"
+
+ # Create task using DAO (this tests the dedup_key creation logic)
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key=task_key,
+ task_name=f"Test Task {i}",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ payload={"test": "data"},
+ )
+
+ # Set created_by for test purposes (DAO uses Flask-AppBuilder context)
+ task.created_by = admin
+
+ # Alternate between pending and finished tasks
+ if i % 2 != 0:
+ # Simulate realistic task lifecycle: PENDING → IN_PROGRESS → SUCCESS
+ # This sets both started_at (on IN_PROGRESS) and ended_at (on
+ # SUCCESS) so duration_seconds returns a valid value
+ task.set_status(TaskStatus.IN_PROGRESS)
+ task.set_status(TaskStatus.SUCCESS)
+
+ db.session.commit()
+ tasks.append(task)
+
+ # Create pending task for gamma user (use PENDING so it can be aborted)
+ gamma_task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="gamma_task",
+ task_name="Gamma Task",
+ scope=TaskScope.PRIVATE,
+ user_id=gamma.id,
+ payload={"user": "gamma"},
+ )
+ # Set created_by for test purposes
+ gamma_task.created_by = gamma
+ db.session.commit()
+ tasks.append(gamma_task)
+
+ yield tasks
+ finally:
+ # Cleanup happens here regardless of test success/failure
+ for task in tasks:
+ try:
+ db.session.delete(task)
+ except Exception: # noqa: S110
+ # Task may already be deleted or session may be in bad state
+ pass
+ try:
+ db.session.commit()
+ except Exception:
+ # Rollback if commit fails
+ db.session.rollback()
+
+ def test_info_task(self):
+ """
+ Task API: Test info endpoint
+ """
+ self.login(ADMIN_USERNAME)
+ uri = f"{self.TASK_API_BASE}/_info"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+ data = json.loads(rv.data.decode("utf-8"))
+ assert "permissions" in data
+
+ def test_get_task_by_uuid(self):
+ """
+ Task API: Test get task by UUID and verify dedup_key is hashed
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ admin = self.get_user("admin")
+
+ # Get a pending task to verify active dedup_key format
+ task = (
+ db.session.query(Task)
+ .filter_by(
+ created_by_fk=admin.id,
+ status=TaskStatus.PENDING.value,
+ task_type="test_type",
+ )
+ .first()
+ )
+ assert task is not None
+
+ # Verify active task has hashed dedup_key (64 chars for SHA-256)
+ assert len(task.dedup_key) == 64
+ assert all(c in "0123456789abcdef" for c in task.dedup_key)
+ assert task.dedup_key != str(task.uuid)
+
+ uri = f"{self.TASK_API_BASE}/{task.uuid}"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ # Compare strings since JSON response contains string UUID
+ assert data["result"]["uuid"] == str(task.uuid)
+ assert data["result"]["id"] == task.id
+
+ def test_get_task_not_found(self):
+ """
+ Task API: Test get task not found with non-existent UUID
+ """
+ self.login(ADMIN_USERNAME)
+ # Use a valid UUID that doesn't exist in the database
+ uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000"
+ rv = self.client.get(uri)
+ assert rv.status_code == 404
+
+ def test_get_task_invalid_uuid(self):
+ """
+ Task API: Test get task with invalid UUID
+ """
+ self.login(ADMIN_USERNAME)
+ uri = f"{self.TASK_API_BASE}/invalid-uuid"
+ rv = self.client.get(uri)
+ assert rv.status_code == 404
+
+ def test_get_task_list(self):
+ """
+ Task API: Test get task list
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ uri = f"{self.TASK_API_BASE}/"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ assert data["count"] >= 6 # At least the fixtures we created
+ assert "result" in data
+
+ def test_get_task_list_filtered_by_status(self):
+ """
+ Task API: Test get task list filtered by status
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ arguments = {
+ "filters": [
+ {"col": "status", "opr": "eq", "value": TaskStatus.PENDING.value}
+ ]
+ }
+ uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ for task in data["result"]:
+ assert task["status"] == TaskStatus.PENDING.value
+
+ def test_get_task_list_filtered_by_type(self):
+ """
+ Task API: Test get task list filtered by type
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ arguments = {
+ "filters": [{"col": "task_type", "opr": "eq", "value": "test_type"}]
+ }
+ uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ assert data["count"] >= 6
+ for task in data["result"]:
+ assert task["task_type"] == "test_type"
+
+ def test_get_task_list_ordered(self):
+ """
+ Task API: Test get task list with ordering
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ arguments = {
+ "order_column": "created_on",
+ "order_direction": "desc",
+ }
+ uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ assert len(data["result"]) > 0
+
+ def test_get_task_list_paginated(self):
+ """
+ Task API: Test get task list with pagination
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ arguments = {"page": 0, "page_size": 2}
+ uri = f"{self.TASK_API_BASE}/?q={prison.dumps(arguments)}"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ assert len(data["result"]) <= 2
+ assert data["count"] >= 6
+
+ def test_cancel_task_by_uuid(self):
+ """
+ Task API: Test cancel task by UUID
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ admin = self.get_user("admin")
+
+ task = (
+ db.session.query(Task)
+ .filter_by(created_by_fk=admin.id, status=TaskStatus.PENDING.value)
+ .first()
+ )
+ assert task is not None
+
+ uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
+ rv = self.client.post(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ # Compare strings since JSON response contains string UUID
+ assert data["task"]["uuid"] == str(task.uuid)
+ assert data["task"]["status"] == TaskStatus.ABORTED.value
+ assert data["action"] == "aborted"
+
+ def test_cancel_task_not_found(self):
+ """
+ Task API: Test cancel task not found with non-existent UUID
+ """
+ self.login(ADMIN_USERNAME)
+ uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/cancel"
+ rv = self.client.post(uri)
+ assert rv.status_code == 404
+
+ def test_cancel_task_not_owned(self):
+ """
+ Task API: Test cancel task not owned by user
+ """
+ with self._create_tasks():
+ self.login(GAMMA_USERNAME)
+ admin = self.get_user("admin")
+
+ # Try to cancel admin's task as gamma user
+ task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
+ assert task is not None
+
+ uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
+ rv = self.client.post(uri)
+ assert rv.status_code == 404
+
+ def test_cancel_task_admin_can_cancel_others(self):
+ """
+ Task API: Test admin can cancel other users' tasks
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ gamma = self.get_user("gamma")
+
+ # Admin cancels gamma's task
+ task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
+ assert task is not None
+
+ uri = f"{self.TASK_API_BASE}/{task.uuid}/cancel"
+ rv = self.client.post(uri)
+ assert rv.status_code == 200
+
+ def test_get_task_status_by_uuid(self):
+ """
+ Task API: Test get task status by UUID
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ admin = self.get_user("admin")
+
+ task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
+ assert task is not None
+
+ uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ assert "status" in data
+ assert data["status"] == task.status
+
+ def test_get_task_status_not_found(self):
+ """
+ Task API: Test get task status not found with non-existent UUID
+ """
+ self.login(ADMIN_USERNAME)
+ uri = f"{self.TASK_API_BASE}/00000000-0000-0000-0000-000000000000/status"
+ rv = self.client.get(uri)
+ assert rv.status_code == 404
+
+ def test_get_task_status_not_owned(self):
+ """
+ Task API: Test non-owner can't see task status
+ """
+ with self._create_tasks():
+ self.login(GAMMA_USERNAME)
+ admin = self.get_user("admin")
+
+ # Try to get status of admin's task as gamma user
+ task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
+ assert task is not None
+
+ uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
+ rv = self.client.get(uri)
+ # Base filter hides other users' tasks, so the API returns 404 (not 403)
+ assert rv.status_code == 404
+
+ def test_get_task_status_admin_can_see_others(self):
+ """
+ Task API: Test admin can see other users' task status
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ gamma = self.get_user("gamma")
+
+ # Admin gets gamma's task status
+ task = db.session.query(Task).filter_by(created_by_fk=gamma.id).first()
+ assert task is not None
+
+ uri = f"{self.TASK_API_BASE}/{task.uuid}/status"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ assert data["status"] == task.status
+
+ def test_get_task_list_user_sees_own_tasks(self):
+ """
+ Task API: Test non-admin user only sees their own tasks
+ """
+ with self._create_tasks():
+ self.login(GAMMA_USERNAME)
+ gamma = self.get_user("gamma")
+
+ uri = f"{self.TASK_API_BASE}/"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ # Gamma should only see their own task
+ for task in data["result"]:
+ assert task["created_by"]["id"] == gamma.id
+
+ def test_get_task_list_admin_sees_all_tasks(self):
+ """
+ Task API: Test admin sees all tasks
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+
+ uri = f"{self.TASK_API_BASE}/"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ # Admin should see all tasks
+ assert data["count"] >= 6
+
+ def test_task_response_schema(self):
+ """
+ Task API: Test response schema includes all expected fields
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ admin = self.get_user("admin")
+
+ task = db.session.query(Task).filter_by(created_by_fk=admin.id).first()
+ uri = f"{self.TASK_API_BASE}/{task.uuid}"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ result = data["result"]
+
+ # Check all expected fields are present
+ expected_fields = [
+ "id",
+ "uuid",
+ "task_key",
+ "task_type",
+ "task_name",
+ "status",
+ "created_on",
+ "created_on_delta_humanized",
+ "changed_on",
+ "changed_by",
+ "started_at",
+ "ended_at",
+ "created_by",
+ "user_id",
+ "payload",
+ "properties",
+ "duration_seconds",
+ "scope",
+ "subscriber_count",
+ "subscribers",
+ ]
+
+ for field in expected_fields:
+ assert field in result, f"Field {field} missing from response"
+
+ # Verify properties is a dict with expected structure
+ properties = result["properties"]
+ assert isinstance(properties, dict)
+
+ def test_task_payload_serialization(self):
+ """
+ Task API: Test payload is properly serialized as dict
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ admin = self.get_user("admin")
+
+ task = (
+ db.session.query(Task)
+ .filter_by(created_by_fk=admin.id, task_type="test_type")
+ .first()
+ )
+ uri = f"{self.TASK_API_BASE}/{task.uuid}"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ payload = data["result"]["payload"]
+
+ # Payload should be a dict, not a string
+ assert isinstance(payload, dict)
+ assert "test" in payload
+ assert payload["test"] == "data"
+
+ def test_task_computed_properties(self):
+ """
+ Task API: Test computed properties in response
+
+ This test verifies that computed properties (status, duration_seconds)
+ are correctly returned in the API response. Internal DB columns like
+ dedup_key are tested in unit tests (test_find_by_task_key_finished_not_found).
+ """
+ with self._create_tasks():
+ self.login(ADMIN_USERNAME)
+ admin = self.get_user("admin")
+
+ # Get a successful task
+ task = (
+ db.session.query(Task)
+ .filter_by(created_by_fk=admin.id, status=TaskStatus.SUCCESS.value)
+ .first()
+ )
+ assert task is not None
+
+ uri = f"{self.TASK_API_BASE}/{task.uuid}"
+ rv = self.client.get(uri)
+ assert rv.status_code == 200
+
+ data = json.loads(rv.data.decode("utf-8"))
+ result = data["result"]
+
+ # Check status field (computed properties are now derived from status)
+ assert result["status"] == TaskStatus.SUCCESS.value
+
+ # Properties dict should exist and be a dict
+ assert "properties" in result
+ assert isinstance(result["properties"], dict)
+
+ # Verify duration_seconds is not null for completed tasks with timestamps
+ # (requires both started_at and ended_at to be set)
+ if result.get("started_at") and result.get("ended_at"):
+ assert result["duration_seconds"] is not None
+ assert result["duration_seconds"] >= 0.0
diff --git a/tests/integration_tests/tasks/commands/__init__.py b/tests/integration_tests/tasks/commands/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/tests/integration_tests/tasks/commands/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/integration_tests/tasks/commands/test_cancel.py b/tests/integration_tests/tasks/commands/test_cancel.py
new file mode 100644
index 000000000000..fa41e5b94296
--- /dev/null
+++ b/tests/integration_tests/tasks/commands/test_cancel.py
@@ -0,0 +1,482 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import patch
+from uuid import UUID, uuid4
+
+import pytest
+from superset_core.api.tasks import TaskScope, TaskStatus
+
+from superset import db
+from superset.commands.tasks.cancel import CancelTaskCommand
+from superset.commands.tasks.exceptions import (
+ TaskAbortFailedError,
+ TaskNotAbortableError,
+ TaskNotFoundError,
+ TaskPermissionDeniedError,
+)
+from superset.daos.tasks import TaskDAO
+from superset.utils.core import override_user
+from tests.integration_tests.test_app import app
+
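+# Cancellation decision tree exercised by this suite (as implied by the tests,
+# not a normative spec):
+#   - PENDING                        -> ABORTED immediately
+#   - IN_PROGRESS, abortable         -> ABORTING (executor completes the abort)
+#   - IN_PROGRESS, not abortable     -> TaskNotAbortableError
+#   - finished (e.g. SUCCESS)        -> TaskAbortFailedError
+#   - SHARED, other subscribers left -> caller is unsubscribed instead
+#   - force=True                     -> admin-only; aborts despite subscribers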
+
+def test_cancel_pending_task_aborts(app_context, get_user) -> None:
+ """Test canceling a pending task directly aborts it"""
+ admin = get_user("admin")
+
+ # Create a pending private task
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="cancel_pending_test",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ try:
+ # Cancel the pending task with admin user context
+ with override_user(admin):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+ result = command.run()
+
+ # Verify task is aborted (pending goes directly to ABORTED)
+ assert result.uuid == task.uuid
+ assert result.status == TaskStatus.ABORTED.value
+ assert command.action_taken == "aborted"
+
+ # Verify in database
+ db.session.refresh(task)
+ assert task.status == TaskStatus.ABORTED.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_in_progress_abortable_task_sets_aborting(app_context, get_user) -> None:
+ """Test canceling an in-progress task with abort handler sets ABORTING"""
+ admin = get_user("admin")
+
+ # Create an in-progress abortable task
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="cancel_in_progress_test",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ properties={"is_abortable": True},
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Cancel the in-progress task - mock publish_abort to avoid Redis dependency
+ with (
+ override_user(admin),
+ patch("superset.tasks.manager.TaskManager.publish_abort"),
+ ):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+ result = command.run()
+
+ # In-progress tasks go to ABORTING (not ABORTED)
+ assert result.uuid == task.uuid
+ assert result.status == TaskStatus.ABORTING.value
+ assert command.action_taken == "aborted"
+
+ # Verify in database
+ db.session.refresh(task)
+ assert task.status == TaskStatus.ABORTING.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_in_progress_not_abortable_raises_error(app_context, get_user) -> None:
+ """Test canceling an in-progress task without abort handler raises error"""
+ admin = get_user("admin")
+ unique_key = f"cancel_not_abortable_test_{uuid4().hex[:8]}"
+
+ # Create an in-progress non-abortable task
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key=unique_key,
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ properties={"is_abortable": False},
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ with override_user(admin):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+
+ with pytest.raises(TaskNotAbortableError):
+ command.run()
+
+ # Verify task status unchanged
+ db.session.refresh(task)
+ assert task.status == TaskStatus.IN_PROGRESS.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_task_not_found(app_context, get_user) -> None:
+ """Test canceling non-existent task raises error"""
+ admin = get_user("admin")
+
+ with override_user(admin):
+ command = CancelTaskCommand(
+ task_uuid=UUID("00000000-0000-0000-0000-000000000000")
+ )
+
+ with pytest.raises(TaskNotFoundError):
+ command.run()
+
+
+def test_cancel_finished_task_raises_error(app_context, get_user) -> None:
+ """Test canceling an already finished task raises error"""
+
+ admin = get_user("admin")
+ unique_key = f"cancel_finished_test_{uuid4().hex[:8]}"
+
+ # Create a finished task
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key=unique_key,
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.SUCCESS)
+ db.session.commit()
+
+ try:
+ with override_user(admin):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+
+ with pytest.raises(TaskAbortFailedError):
+ command.run()
+
+ # Verify task status unchanged
+ db.session.refresh(task)
+ assert task.status == TaskStatus.SUCCESS.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_shared_task_with_multiple_subscribers_unsubscribes(
+ app_context, get_user
+) -> None:
+ """Test canceling a shared task with multiple subscribers unsubscribes user"""
+ admin = get_user("admin")
+ gamma = get_user("gamma")
+
+ # Create a shared task with admin as creator
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="cancel_shared_test",
+ scope=TaskScope.SHARED,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ # Add gamma as subscriber
+ TaskDAO.add_subscriber(task.id, user_id=gamma.id)
+ db.session.commit()
+
+ try:
+ # Verify we have 2 subscribers
+ db.session.refresh(task)
+ assert task.subscriber_count == 2
+
+ # Cancel as gamma (non-admin subscriber)
+ with override_user(gamma):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+ result = command.run()
+
+ # Should unsubscribe, not abort
+ assert command.action_taken == "unsubscribed"
+ assert result.status == TaskStatus.PENDING.value # Status unchanged
+
+ # Verify gamma was unsubscribed
+ db.session.refresh(task)
+ assert task.subscriber_count == 1
+ assert not task.has_subscriber(gamma.id)
+ assert task.has_subscriber(admin.id)
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_shared_task_last_subscriber_aborts(app_context, get_user) -> None:
+ """Test canceling a shared task as last subscriber aborts it"""
+ admin = get_user("admin")
+
+ # Create a shared task with only admin as subscriber
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="cancel_last_subscriber_test",
+ scope=TaskScope.SHARED,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ try:
+ # Verify only 1 subscriber
+ db.session.refresh(task)
+ assert task.subscriber_count == 1
+
+ # Cancel as the only subscriber
+ with override_user(admin):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+ result = command.run()
+
+ # Should abort since last subscriber
+ assert command.action_taken == "aborted"
+ assert result.status == TaskStatus.ABORTED.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_with_force_aborts_for_all_subscribers(app_context, get_user) -> None:
+ """Test force cancel aborts shared task even with multiple subscribers"""
+ admin = get_user("admin")
+ gamma = get_user("gamma")
+
+ # Create a shared task with multiple subscribers
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="force_cancel_test",
+ scope=TaskScope.SHARED,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ # Add gamma as subscriber
+ TaskDAO.add_subscriber(task.id, user_id=gamma.id)
+ db.session.commit()
+
+ try:
+ # Verify 2 subscribers
+ db.session.refresh(task)
+ assert task.subscriber_count == 2
+
+ # Force cancel as admin
+ with override_user(admin):
+ command = CancelTaskCommand(task_uuid=task.uuid, force=True)
+ result = command.run()
+
+ # Should abort despite multiple subscribers
+ assert command.action_taken == "aborted"
+ assert result.status == TaskStatus.ABORTED.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_with_force_requires_admin(app_context, get_user) -> None:
+ """Test force cancel requires admin privileges"""
+ admin = get_user("admin")
+ gamma = get_user("gamma")
+
+ # Create a shared task
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="force_requires_admin_test",
+ scope=TaskScope.SHARED,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ # Add gamma as subscriber
+ TaskDAO.add_subscriber(task.id, user_id=gamma.id)
+ db.session.commit()
+
+ try:
+ # Try to force cancel as gamma (non-admin)
+ with override_user(gamma):
+ command = CancelTaskCommand(task_uuid=task.uuid, force=True)
+
+ with pytest.raises(TaskPermissionDeniedError):
+ command.run()
+
+ # Verify task unchanged
+ db.session.refresh(task)
+ assert task.status == TaskStatus.PENDING.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_private_task_permission_denied(app_context, get_user) -> None:
+ """Test non-owner cannot cancel private task"""
+ admin = get_user("admin")
+ gamma = get_user("gamma")
+ unique_key = f"private_permission_test_{uuid4().hex[:8]}"
+
+ # Use test_request_context to ensure has_request_context() returns True
+ # so that TaskFilter properly applies permission filtering
+ with app.test_request_context():
+ # Create a private task owned by admin
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key=unique_key,
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ try:
+ # Try to cancel admin's private task as gamma (non-owner)
+ with override_user(gamma):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+
+ # Should fail because gamma can't see admin's private task (base filter)
+ with pytest.raises(TaskNotFoundError):
+ command.run()
+
+ # Verify task unchanged
+ db.session.refresh(task)
+ assert task.status == TaskStatus.PENDING.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_system_task_requires_admin(app_context, get_user) -> None:
+ """Test system tasks can only be canceled by admin"""
+ admin = get_user("admin")
+ gamma = get_user("gamma")
+ unique_key = f"system_task_test_{uuid4().hex[:8]}"
+
+ # Use test_request_context to ensure has_request_context() returns True
+ # so that TaskFilter properly applies permission filtering
+ with app.test_request_context():
+ # Create a system task
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key=unique_key,
+ scope=TaskScope.SYSTEM,
+ user_id=None,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ try:
+ # Try to cancel as gamma (non-admin)
+ with override_user(gamma):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+
+ # System tasks are not visible to non-admins via base filter
+ with pytest.raises(TaskNotFoundError):
+ command.run()
+
+ # Verify task unchanged
+ db.session.refresh(task)
+ assert task.status == TaskStatus.PENDING.value
+
+ # But admin can cancel it
+ with override_user(admin):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+ result = command.run()
+
+ assert result.status == TaskStatus.ABORTED.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_already_aborting_is_idempotent(app_context, get_user) -> None:
+ """Test canceling an already aborting task is idempotent"""
+ admin = get_user("admin")
+
+ # Create a task already in ABORTING state
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="idempotent_cancel_test",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.ABORTING)
+ db.session.commit()
+
+ try:
+ # Cancel the already aborting task
+ with override_user(admin):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+ result = command.run()
+
+ # Should succeed without error
+ assert result.uuid == task.uuid
+ assert result.status == TaskStatus.ABORTING.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_cancel_shared_task_not_subscribed_raises_error(app_context, get_user) -> None:
+ """Test non-subscriber cannot cancel shared task"""
+ admin = get_user("admin")
+ gamma = get_user("gamma")
+
+ # Create a shared task with only admin as subscriber
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="not_subscribed_test",
+ scope=TaskScope.SHARED,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ try:
+ # Try to cancel as gamma (not subscribed)
+ with override_user(gamma):
+ command = CancelTaskCommand(task_uuid=task.uuid)
+
+ with pytest.raises(TaskPermissionDeniedError):
+ command.run()
+
+ # Verify task unchanged
+ db.session.refresh(task)
+ assert task.status == TaskStatus.PENDING.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
diff --git a/tests/integration_tests/tasks/commands/test_internal_update.py b/tests/integration_tests/tasks/commands/test_internal_update.py
new file mode 100644
index 000000000000..b4c8ee56c935
--- /dev/null
+++ b/tests/integration_tests/tasks/commands/test_internal_update.py
@@ -0,0 +1,419 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Integration tests for internal task state update commands."""
+
+from uuid import UUID
+
+from superset_core.api.tasks import TaskScope, TaskStatus
+
+from superset import db
+from superset.commands.tasks.internal_update import (
+ InternalStatusTransitionCommand,
+ InternalUpdateTaskCommand,
+)
+from superset.daos.tasks import TaskDAO
+
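+# "Zero-read" below means the command issues a single UPDATE keyed by uuid
+# without loading the row first, conceptually (illustrative sketch, not
+# necessarily the exact implementation):
+#
+#     db.session.query(Task).filter_by(uuid=task_uuid).update(values)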
+
+def test_internal_update_properties(app_context, get_user, login_as) -> None:
+ """Test updating only properties without reading task first."""
+ admin = get_user("admin")
+ login_as("admin")
+
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="internal_update_props",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Perform zero-read update
+ command = InternalUpdateTaskCommand(
+ task_uuid=task.uuid,
+ properties={"is_abortable": True, "progress_percent": 0.5},
+ )
+ result = command.run()
+
+ assert result is True
+
+ # Verify in database
+ db.session.refresh(task)
+ assert task.properties_dict.get("is_abortable") is True
+ assert task.properties_dict.get("progress_percent") == 0.5
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_internal_update_payload(app_context, get_user, login_as) -> None:
+ """Test updating only payload without reading task first."""
+ admin = get_user("admin")
+ login_as("admin")
+
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="internal_update_payload",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Perform zero-read update
+ command = InternalUpdateTaskCommand(
+ task_uuid=task.uuid,
+ payload={"custom_key": "value", "count": 42},
+ )
+ result = command.run()
+
+ assert result is True
+
+ # Verify in database
+ db.session.refresh(task)
+ assert task.payload_dict == {"custom_key": "value", "count": 42}
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_internal_update_both_properties_and_payload(
+ app_context, get_user, login_as
+) -> None:
+ """Test updating both properties and payload in one call."""
+ admin = get_user("admin")
+ login_as("admin")
+
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="internal_update_both",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Perform zero-read update of both
+ command = InternalUpdateTaskCommand(
+ task_uuid=task.uuid,
+ properties={"progress_current": 50, "progress_total": 100},
+ payload={"last_item": "xyz"},
+ )
+ result = command.run()
+
+ assert result is True
+
+ # Verify in database
+ db.session.refresh(task)
+ assert task.properties_dict.get("progress_current") == 50
+ assert task.properties_dict.get("progress_total") == 100
+ assert task.payload_dict == {"last_item": "xyz"}
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_internal_update_returns_false_for_nonexistent_task(
+ app_context, login_as
+) -> None:
+ """Test that updating non-existent task returns False."""
+ login_as("admin")
+
+ command = InternalUpdateTaskCommand(
+ task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
+ properties={"is_abortable": True},
+ )
+ result = command.run()
+
+ assert result is False
+
+
+def test_internal_update_returns_false_when_nothing_to_update(
+ app_context, get_user, login_as
+) -> None:
+ """Test that passing no properties or payload returns False early."""
+ admin = get_user("admin")
+ login_as("admin")
+
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="internal_update_empty",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ try:
+ # No properties or payload provided
+ command = InternalUpdateTaskCommand(
+ task_uuid=task.uuid,
+ properties=None,
+ payload=None,
+ )
+ result = command.run()
+
+ assert result is False
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_internal_update_does_not_change_status(
+ app_context, get_user, login_as
+) -> None:
+ """Test that internal update leaves status unchanged (safe for concurrent abort)."""
+ admin = get_user("admin")
+ login_as("admin")
+
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="internal_update_status",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Update properties - status should not change
+ command = InternalUpdateTaskCommand(
+ task_uuid=task.uuid,
+ properties={"progress_percent": 0.75},
+ )
+ result = command.run()
+
+ assert result is True
+
+ # Verify status unchanged
+ db.session.refresh(task)
+ assert task.status == TaskStatus.IN_PROGRESS.value
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_internal_update_replaces_entire_properties(
+ app_context, get_user, login_as
+) -> None:
+ """Test that internal update replaces properties entirely (no merge)."""
+ admin = get_user("admin")
+ login_as("admin")
+
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="internal_update_replace",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ properties={"is_abortable": True, "timeout": 300},
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ try:
+ # Replace with new properties (caller is responsible for merging if needed)
+ command = InternalUpdateTaskCommand(
+ task_uuid=task.uuid,
+ properties={"error_message": "new_value"},
+ )
+ result = command.run()
+
+ assert result is True
+
+ # Verify entire replacement occurred
+ db.session.refresh(task)
+ # The caller should have merged if they wanted to preserve is_abortable
+ assert task.properties_dict == {"error_message": "new_value"}
+ assert "is_abortable" not in task.properties_dict
+ assert "timeout" not in task.properties_dict
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
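+# Callers that need merge semantics must read-merge-write themselves, e.g.
+# (illustrative sketch; not a helper the command provides):
+#
+#     merged = {**task.properties_dict, "error_message": "new_value"}
+#     InternalUpdateTaskCommand(task_uuid=task.uuid, properties=merged).run()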
+
+# =============================================================================
+# InternalStatusTransitionCommand Tests
+# =============================================================================
+
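+# The transitions below assume compare-and-swap semantics: conceptually a
+# single conditional UPDATE (illustrative SQL; the actual statement is up to
+# the command):
+#
+#     UPDATE tasks SET status = :new
+#     WHERE uuid = :uuid AND status IN (:expected)
+#
+# with run() returning True only when a row was actually changed.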
+
+def test_status_transition_atomic_compare_and_swap(
+ app_context, get_user, login_as
+) -> None:
+ """Test atomic conditional status transitions with comprehensive scenarios.
+
+ Covers: success case, failure case, list of expected statuses, properties update,
+ ended_at timestamp, and string status values.
+ """
+ admin = get_user("admin")
+ login_as("admin")
+
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="status_transition_comprehensive",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ db.session.commit()
+
+ try:
+ # 1. SUCCESS CASE: PENDING → IN_PROGRESS (expected matches)
+ result = InternalStatusTransitionCommand(
+ task_uuid=task.uuid,
+ new_status=TaskStatus.IN_PROGRESS,
+ expected_status=TaskStatus.PENDING,
+ ).run()
+ assert result is True
+ db.session.refresh(task)
+ assert task.status == TaskStatus.IN_PROGRESS.value
+
+ # 2. FAILURE CASE: Try wrong expected status (should fail, status unchanged)
+ result = InternalStatusTransitionCommand(
+ task_uuid=task.uuid,
+ new_status=TaskStatus.SUCCESS,
+ expected_status=TaskStatus.PENDING, # Wrong! Current is IN_PROGRESS
+ ).run()
+ assert result is False
+ db.session.refresh(task)
+ assert task.status == TaskStatus.IN_PROGRESS.value # Unchanged
+
+ # 3. LIST OF EXPECTED: Transition with multiple acceptable source statuses
+ task.set_status(TaskStatus.ABORTING)
+ db.session.commit()
+
+ result = InternalStatusTransitionCommand(
+ task_uuid=task.uuid,
+ new_status=TaskStatus.FAILURE,
+ expected_status=[TaskStatus.IN_PROGRESS, TaskStatus.ABORTING],
+ properties={"error_message": "Test error"},
+ ).run()
+ assert result is True
+ db.session.refresh(task)
+ assert task.status == TaskStatus.FAILURE.value
+ assert task.properties_dict.get("error_message") == "Test error"
+
+ # 4. ENDED_AT: Reset to IN_PROGRESS and test ended_at timestamp
+ task.set_status(TaskStatus.IN_PROGRESS)
+ task.ended_at = None
+ db.session.commit()
+ assert task.ended_at is None
+
+ result = InternalStatusTransitionCommand(
+ task_uuid=task.uuid,
+ new_status=TaskStatus.SUCCESS,
+ expected_status=TaskStatus.IN_PROGRESS,
+ set_ended_at=True,
+ ).run()
+ assert result is True
+ db.session.refresh(task)
+ assert task.status == TaskStatus.SUCCESS.value
+ assert task.ended_at is not None
+
+ # 5. STRING VALUES: Reset and test string status values
+ task.set_status(TaskStatus.PENDING)
+ db.session.commit()
+
+ result = InternalStatusTransitionCommand(
+ task_uuid=task.uuid,
+ new_status="in_progress",
+ expected_status="pending",
+ ).run()
+ assert result is True
+ db.session.refresh(task)
+ assert task.status == "in_progress"
+
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_status_transition_prevents_race_condition(
+ app_context, get_user, login_as
+) -> None:
+ """Test that conditional update prevents overwriting concurrent abort.
+
+    This is the key race-condition fix: if the task is aborted concurrently,
+ the executor's attempt to set SUCCESS should fail (return False),
+ preserving the ABORTING state.
+ """
+ admin = get_user("admin")
+ login_as("admin")
+
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="status_transition_race",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Simulate concurrent abort: directly set ABORTING in DB
+ # (as if CancelTaskCommand ran in another process)
+ task.set_status(TaskStatus.ABORTING)
+ db.session.commit()
+
+ # Executor tries to set SUCCESS (expecting IN_PROGRESS) - stale expectation
+ result = InternalStatusTransitionCommand(
+ task_uuid=task.uuid,
+ new_status=TaskStatus.SUCCESS,
+ expected_status=TaskStatus.IN_PROGRESS,
+ ).run()
+
+ # Should fail - task was aborted concurrently
+ assert result is False
+
+ # Verify ABORTING is preserved (not overwritten to SUCCESS)
+ db.session.refresh(task)
+ assert task.status == TaskStatus.ABORTING.value
+
+ # Verify correct transition from ABORTING still works
+ result = InternalStatusTransitionCommand(
+ task_uuid=task.uuid,
+ new_status=TaskStatus.ABORTED,
+ expected_status=TaskStatus.ABORTING,
+ set_ended_at=True,
+ ).run()
+ assert result is True
+ db.session.refresh(task)
+ assert task.status == TaskStatus.ABORTED.value
+
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_status_transition_nonexistent_task(app_context, login_as) -> None:
+ """Test that transitioning non-existent task returns False."""
+ login_as("admin")
+
+ result = InternalStatusTransitionCommand(
+ task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
+ new_status=TaskStatus.IN_PROGRESS,
+ expected_status=TaskStatus.PENDING,
+ ).run()
+
+ assert result is False
diff --git a/tests/integration_tests/tasks/commands/test_prune.py b/tests/integration_tests/tasks/commands/test_prune.py
new file mode 100644
index 000000000000..0706319951d3
--- /dev/null
+++ b/tests/integration_tests/tasks/commands/test_prune.py
@@ -0,0 +1,258 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+from freezegun import freeze_time
+from superset_core.api.tasks import TaskScope, TaskStatus
+
+from superset import db
+from superset.commands.tasks import TaskPruneCommand
+from superset.daos.tasks import TaskDAO
+from superset.models.tasks import Task
+
+
+@freeze_time("2024-02-15")
+@patch("superset.tasks.utils.get_current_user")
+def test_prune_tasks_success(mock_get_user, app_context, get_user, login_as) -> None:
+ """Test successful pruning of old completed tasks"""
+ login_as("admin")
+ admin = get_user("admin")
+ mock_get_user.return_value = admin.username
+
+ # Create old completed tasks (35 days ago = Jan 11, 2024)
+ old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
+ task_ids = []
+ for i in range(3):
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key=f"prune_task_{i}",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.SUCCESS)
+ task.ended_at = old_date
+ task_ids.append(task.id)
+
+ # Create a recent task (5 days ago = Feb 10, 2024) that should NOT be deleted
+ recent_date = datetime(2024, 2, 10, tzinfo=timezone.utc)
+ recent_task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="recent_task",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ recent_task.created_by = admin
+ recent_task.set_status(TaskStatus.SUCCESS)
+ recent_task.ended_at = recent_date
+ recent_task_id = recent_task.id
+
+ db.session.commit()
+
+ try:
+ # Prune tasks older than 30 days
+ command = TaskPruneCommand(retention_period_days=30)
+ command.run()
+
+ # Verify old tasks are deleted
+ for task_id in task_ids:
+ assert db.session.get(Task, task_id) is None
+
+ # Verify recent task is NOT deleted
+ assert db.session.get(Task, recent_task_id) is not None
+ finally:
+ # Cleanup remaining tasks
+ for task_id in task_ids:
+ existing = db.session.get(Task, task_id)
+ if existing:
+ db.session.delete(existing)
+        recent = db.session.get(Task, recent_task_id)
+        if recent:
+            db.session.delete(recent)
+ db.session.commit()
+
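+# The retention cutoff these tests rely on is presumably now minus
+# retention_period_days, compared against ended_at, so only finished tasks
+# (ended_at set) are eligible; with max_rows_per_run, deletion is presumably
+# ordered oldest-first and capped, roughly (illustrative sketch):
+#
+#     cutoff = datetime.now(timezone.utc) - timedelta(days=retention_period_days)
+#     query.filter(Task.ended_at < cutoff).order_by(Task.ended_at)
+#          .limit(max_rows_per_run)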
+
+@freeze_time("2024-02-15")
+@patch("superset.tasks.utils.get_current_user")
+def test_prune_tasks_with_max_rows(
+ mock_get_user, app_context, get_user, login_as
+) -> None:
+ """Test pruning with max_rows_per_run limit"""
+ login_as("admin")
+ admin = get_user("admin")
+ mock_get_user.return_value = admin.username
+
+ # Create old completed tasks (35 days ago = Jan 11, 2024)
+ task_ids = []
+ for i in range(5):
+ # Different ages for ordering (older tasks have smaller hour values)
+ old_date = datetime(2024, 1, 11, i, tzinfo=timezone.utc)
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key=f"max_rows_task_{i}",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.SUCCESS)
+ task.ended_at = old_date
+ task_ids.append(task.id)
+
+ db.session.commit()
+
+ try:
+ # Prune with max_rows_per_run=2 (should only delete 2 oldest)
+ command = TaskPruneCommand(retention_period_days=30, max_rows_per_run=2)
+ command.run()
+
+ # Count remaining tasks
+ remaining = sum(
+ 1 for task_id in task_ids if db.session.get(Task, task_id) is not None
+ )
+ assert remaining == 3 # 5 - 2 = 3 remaining
+ finally:
+ # Cleanup remaining tasks
+ for task_id in task_ids:
+ existing = db.session.get(Task, task_id)
+ if existing:
+ db.session.delete(existing)
+ db.session.commit()
+
+
+@freeze_time("2024-02-15")
+@patch("superset.tasks.utils.get_current_user")
+def test_prune_does_not_delete_pending_tasks(
+ mock_get_user, app_context, get_user, login_as
+) -> None:
+ """Test that pruning does not delete pending or in-progress tasks"""
+ login_as("admin")
+ admin = get_user("admin")
+ mock_get_user.return_value = admin.username
+
+ pending_task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="pending_task",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ pending_task.created_by = admin
+ # Keep as PENDING (no ended_at)
+
+ in_progress_task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="in_progress_task",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ in_progress_task.created_by = admin
+ in_progress_task.set_status(TaskStatus.IN_PROGRESS)
+ # No ended_at for in-progress tasks
+
+ db.session.commit()
+
+ try:
+ # Prune tasks older than 30 days
+ command = TaskPruneCommand(retention_period_days=30)
+ command.run()
+
+ # Verify non-completed tasks are NOT deleted
+ assert db.session.get(Task, pending_task.id) is not None
+ assert db.session.get(Task, in_progress_task.id) is not None
+ finally:
+ # Cleanup
+ for task in [pending_task, in_progress_task]:
+ existing = db.session.get(Task, task.id)
+ if existing:
+ db.session.delete(existing)
+ db.session.commit()
+
+
+@freeze_time("2024-02-15")
+@patch("superset.tasks.utils.get_current_user")
+def test_prune_deletes_all_completed_statuses(
+ mock_get_user, app_context, get_user, login_as
+) -> None:
+ """Test pruning deletes SUCCESS, FAILURE, and ABORTED tasks"""
+ login_as("admin")
+ admin = get_user("admin")
+ mock_get_user.return_value = admin.username
+
+ old_date = datetime(2024, 1, 11, tzinfo=timezone.utc)
+
+ # Create tasks with different completed statuses
+ success_task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="success_task",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ success_task.created_by = admin
+ success_task.set_status(TaskStatus.SUCCESS)
+ success_task.ended_at = old_date
+ success_task_id = success_task.id
+
+ failure_task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="failure_task",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ failure_task.created_by = admin
+ failure_task.set_status(TaskStatus.FAILURE)
+ failure_task.ended_at = old_date
+ failure_task_id = failure_task.id
+
+ aborted_task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="aborted_task",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ aborted_task.created_by = admin
+ aborted_task.set_status(TaskStatus.ABORTED)
+ aborted_task.ended_at = old_date
+ aborted_task_id = aborted_task.id
+
+ db.session.commit()
+ task_ids = [success_task_id, failure_task_id, aborted_task_id]
+
+ try:
+ # Prune tasks older than 30 days
+ command = TaskPruneCommand(retention_period_days=30)
+ command.run()
+
+ # Verify all completed tasks are deleted
+ for task_id in task_ids:
+ assert db.session.get(Task, task_id) is None
+ except AssertionError:
+ # Cleanup if test fails
+ for task_id in task_ids:
+ existing = db.session.get(Task, task_id)
+ if existing:
+ db.session.delete(existing)
+ db.session.commit()
+ raise
+
+
+def test_prune_no_tasks_to_delete(app_context, login_as) -> None:
+ """Test pruning when no old tasks exist"""
+ login_as("admin")
+
+ # Don't create any tasks - should handle gracefully
+ command = TaskPruneCommand(retention_period_days=30)
+ command.run() # Should not raise any errors
diff --git a/tests/integration_tests/tasks/commands/test_submit.py b/tests/integration_tests/tasks/commands/test_submit.py
new file mode 100644
index 000000000000..a6a7f6f31711
--- /dev/null
+++ b/tests/integration_tests/tasks/commands/test_submit.py
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import pytest
+from superset_core.api.tasks import TaskStatus
+
+from superset import db
+from superset.commands.tasks import SubmitTaskCommand
+from superset.commands.tasks.exceptions import (
+ TaskInvalidError,
+)
+
+
+def test_submit_task_success(app_context, login_as, get_user) -> None:
+ """Test successful task submission"""
+ login_as("admin")
+ admin = get_user("admin")
+
+ command = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "test-key",
+ "task_name": "Test Task",
+ "user_id": admin.id,
+ }
+ )
+
+    # Run outside the try block so a failing run() cannot leave `result`
+    # undefined when the cleanup in the finally block executes.
+    result = command.run()
+
+    try:
+        # Verify task was created
+ assert result.task_type == "test-type"
+ assert result.task_key == "test-key"
+ assert result.task_name == "Test Task"
+ assert result.status == TaskStatus.PENDING.value
+ assert result.payload == "{}"
+
+ # Verify in database
+ db.session.refresh(result)
+ assert result.id is not None
+ assert result.uuid is not None
+ finally:
+ # Cleanup
+ db.session.delete(result)
+ db.session.commit()
+
+
+def test_submit_task_with_all_fields(app_context, login_as, get_user) -> None:
+ """Test task submission with all optional fields"""
+ login_as("admin")
+ admin = get_user("admin")
+
+ command = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "test-key-full",
+ "task_name": "Test Task Full",
+ "user_id": admin.id,
+ "payload": {"key": "value"},
+ "properties": {"execution_mode": "async", "timeout": 300},
+ }
+ )
+
+    # Run outside the try block (see test_submit_task_success).
+    result = command.run()
+
+    try:
+        # Verify all fields were set
+ assert result.task_type == "test-type"
+ assert result.task_key == "test-key-full"
+ assert result.task_name == "Test Task Full"
+ assert result.user_id == admin.id
+ assert result.payload_dict == {"key": "value"}
+ assert result.properties_dict.get("execution_mode") == "async"
+ assert result.properties_dict.get("timeout") == 300
+ finally:
+ # Cleanup
+ db.session.delete(result)
+ db.session.commit()
+
+
+def test_submit_task_missing_task_type(app_context, login_as) -> None:
+ """Test submission fails when task_type is missing"""
+ login_as("admin")
+
+ command = SubmitTaskCommand(data={})
+
+ with pytest.raises(TaskInvalidError) as exc_info:
+ command.run()
+
+ assert len(exc_info.value._exceptions) == 1
+ assert "task_type" in exc_info.value._exceptions[0].field_name
+
+
+def test_submit_task_joins_existing(app_context, login_as, get_user) -> None:
+ """Test that submitting with duplicate key joins existing task"""
+ login_as("admin")
+ admin = get_user("admin")
+
+ # Create first task
+ command1 = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "shared-key",
+ "task_name": "First Task",
+ "user_id": admin.id,
+ }
+ )
+ task1 = command1.run()
+
+ try:
+ # Submit second task with same task_key and type
+ command2 = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "shared-key",
+ "task_name": "Second Task",
+ "user_id": admin.id,
+ }
+ )
+
+ # Should return existing task, not create new one
+ task2 = command2.run()
+ assert task2.id == task1.id
+ assert task2.uuid == task1.uuid
+ finally:
+ # Cleanup
+ db.session.delete(task1)
+ db.session.commit()
+
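+# Joining presumably dedups on (task_type, task_key) for unfinished tasks,
+# conceptually (illustrative sketch; helper names are assumptions):
+#
+#     existing = TaskDAO.find_by_task_key(task_type, task_key)
+#     return (existing, False) if existing else (new_task, True)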
+
+def test_submit_task_without_task_key(app_context, login_as, get_user) -> None:
+ """Test task submission without task_key (command generates UUID)"""
+ login_as("admin")
+ admin = get_user("admin")
+
+ command = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_name": "Test Task No ID",
+ "user_id": admin.id,
+ }
+ )
+
+    # Run outside the try block (see test_submit_task_success).
+    result = command.run()
+
+    try:
+        # Verify task was created and command generated a task_key
+ assert result.task_type == "test-type"
+ assert result.task_name == "Test Task No ID"
+ assert result.task_key is not None # Command generated UUID
+ assert result.uuid is not None
+ finally:
+ # Cleanup
+ db.session.delete(result)
+ db.session.commit()
+
+
+def test_submit_task_run_with_info_returns_is_new_true(
+ app_context, login_as, get_user
+) -> None:
+ """Test run_with_info returns is_new=True for new task"""
+ login_as("admin")
+ admin = get_user("admin")
+
+ command = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "unique-key-is-new",
+ "task_name": "Test Task",
+ "user_id": admin.id,
+ }
+ )
+
+    # Run outside the try block (see test_submit_task_success).
+    task, is_new = command.run_with_info()
+
+    try:
+ assert is_new is True
+ assert task.task_key == "unique-key-is-new"
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_submit_task_run_with_info_returns_is_new_false(
+ app_context, login_as, get_user
+) -> None:
+ """Test run_with_info returns is_new=False when joining existing task"""
+ login_as("admin")
+ admin = get_user("admin")
+
+ # Create first task
+ command1 = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "shared-key-is-new",
+ "task_name": "First Task",
+ "user_id": admin.id,
+ }
+ )
+ task1, is_new1 = command1.run_with_info()
+ assert is_new1 is True
+
+ try:
+ # Submit second task with same key
+ command2 = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "shared-key-is-new",
+ "task_name": "Second Task",
+ "user_id": admin.id,
+ }
+ )
+ task2, is_new2 = command2.run_with_info()
+
+ # Should return existing task with is_new=False
+ assert is_new2 is False
+ assert task2.id == task1.id
+ assert task2.uuid == task1.uuid
+ finally:
+ # Cleanup
+ db.session.delete(task1)
+ db.session.commit()
diff --git a/tests/integration_tests/tasks/commands/test_update.py b/tests/integration_tests/tasks/commands/test_update.py
new file mode 100644
index 000000000000..8ace6360efcd
--- /dev/null
+++ b/tests/integration_tests/tasks/commands/test_update.py
@@ -0,0 +1,260 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from uuid import UUID
+
+import pytest
+from superset_core.api.tasks import TaskScope, TaskStatus
+
+from superset import db
+from superset.commands.tasks import UpdateTaskCommand
+from superset.commands.tasks.exceptions import (
+ TaskForbiddenError,
+ TaskNotFoundError,
+)
+from superset.daos.tasks import TaskDAO
+
+
+def test_update_task_success(app_context, get_user, login_as) -> None:
+ """Test successful task update"""
+ admin = get_user("admin")
+ login_as("admin")
+
+ # Create a task using DAO
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="update_test",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Update the task status
+ command = UpdateTaskCommand(
+ task_uuid=task.uuid,
+ status=TaskStatus.SUCCESS.value,
+ )
+ result = command.run()
+
+ # Verify update
+ assert result.uuid == task.uuid
+ assert result.status == TaskStatus.SUCCESS.value
+
+ # Verify in database
+ db.session.refresh(task)
+ assert task.status == TaskStatus.SUCCESS.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_update_task_not_found(app_context, login_as) -> None:
+ """Test update fails when task not found"""
+ login_as("admin")
+
+ command = UpdateTaskCommand(
+ task_uuid=UUID("00000000-0000-0000-0000-000000000000"),
+ status=TaskStatus.SUCCESS.value,
+ )
+
+ with pytest.raises(TaskNotFoundError):
+ command.run()
+
+
+def test_update_task_forbidden(app_context, get_user, login_as) -> None:
+ """Test update fails when user doesn't own task (via base filter)"""
+ gamma = get_user("gamma")
+ login_as("gamma")
+
+ # Create a task owned by gamma (non-admin) using DAO
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="forbidden_test",
+ scope=TaskScope.PRIVATE,
+ user_id=gamma.id,
+ )
+ task.created_by = gamma
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Login as alpha user (different non-admin, non-owner)
+ login_as("alpha")
+
+ # Try to update gamma's task as alpha user
+ command = UpdateTaskCommand(
+ task_uuid=task.uuid,
+ status=TaskStatus.SUCCESS.value,
+ )
+
+        # Should raise TaskForbiddenError because the ownership check fails
+ with pytest.raises(TaskForbiddenError):
+ command.run()
+
+ # Verify task was NOT updated
+ db.session.refresh(task)
+ assert task.status == TaskStatus.IN_PROGRESS.value
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_update_task_payload(app_context, get_user, login_as) -> None:
+ """Test updating task payload"""
+ admin = get_user("admin")
+ login_as("admin")
+
+ # Create a task using DAO
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="payload_test",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ payload={"initial": "data"},
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Update payload
+ command = UpdateTaskCommand(
+ task_uuid=task.uuid,
+ payload={"progress": 50, "message": "halfway"},
+ )
+ result = command.run()
+
+ # Verify payload was updated
+ assert result.uuid == task.uuid
+ payload = result.payload_dict
+ assert payload["progress"] == 50
+ assert payload["message"] == "halfway"
+
+ # Verify in database
+ db.session.refresh(task)
+ task_payload = task.payload_dict
+ assert task_payload["progress"] == 50
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_update_all_supported_fields(app_context, get_user, login_as) -> None:
+ """Test updating all supported task fields
+ (status, error, progress, abortable, timeout)"""
+ admin = get_user("admin")
+ login_as("admin")
+
+ # Create a task with initial execution_mode and timeout in properties
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="all_fields_test",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ properties={"execution_mode": "async", "timeout": 300},
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Update all field types at once
+ command = UpdateTaskCommand(
+ task_uuid=task.uuid,
+ status=TaskStatus.FAILURE.value,
+ properties={
+ "error_message": "Task failed due to error",
+ "progress_percent": 0.75,
+ "progress_current": 75,
+ "progress_total": 100,
+ "is_abortable": True,
+ },
+ )
+ result = command.run()
+
+ # Verify all fields were updated
+ assert result.uuid == task.uuid
+ assert result.status == TaskStatus.FAILURE.value
+ assert result.properties_dict.get("error_message") == "Task failed due to error"
+ assert result.properties_dict.get("progress_percent") == 0.75
+ assert result.properties_dict.get("progress_current") == 75
+ assert result.properties_dict.get("progress_total") == 100
+ assert result.properties_dict.get("is_abortable") is True
+ assert result.properties_dict.get("execution_mode") == "async"
+ assert result.properties_dict.get("timeout") == 300
+
+ # Verify in database
+ db.session.refresh(task)
+ assert task.status == TaskStatus.FAILURE.value
+ assert task.properties_dict.get("error_message") == "Task failed due to error"
+ assert task.properties_dict.get("progress_percent") == 0.75
+ assert task.properties_dict.get("progress_current") == 75
+ assert task.properties_dict.get("progress_total") == 100
+ assert task.properties_dict.get("is_abortable") is True
+ assert task.properties_dict.get("execution_mode") == "async"
+ assert task.properties_dict.get("timeout") == 300
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_update_task_skip_security_check(app_context, get_user, login_as) -> None:
+ """Test skip_security_check allows updating any task"""
+ admin = get_user("admin")
+ login_as("admin")
+
+ # Create a task owned by admin
+ task = TaskDAO.create_task(
+ task_type="test_type",
+ task_key="skip_security_test",
+ scope=TaskScope.PRIVATE,
+ user_id=admin.id,
+ )
+ task.created_by = admin
+ task.set_status(TaskStatus.IN_PROGRESS)
+ db.session.commit()
+
+ try:
+ # Login as gamma user (non-owner)
+ login_as("gamma")
+
+ # With skip_security_check=True, should succeed even though gamma doesn't own it
+ command = UpdateTaskCommand(
+ task_uuid=task.uuid,
+ properties={"progress_percent": 0.75},
+ skip_security_check=True,
+ )
+ result = command.run()
+
+ # Verify update succeeded
+ assert result.uuid == task.uuid
+ assert result.properties_dict.get("progress_percent") == 0.75
+
+ # Verify in database
+ db.session.refresh(task)
+ assert task.properties_dict.get("progress_percent") == 0.75
+ finally:
+ # Cleanup
+ db.session.delete(task)
+ db.session.commit()
diff --git a/tests/integration_tests/tasks/test_event_handlers.py b/tests/integration_tests/tasks/test_event_handlers.py
new file mode 100644
index 000000000000..7c7aacd15ed0
--- /dev/null
+++ b/tests/integration_tests/tasks/test_event_handlers.py
@@ -0,0 +1,415 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""End-to-end integration tests for task event handlers (abort and cleanup)
+
+These tests verify that abort and cleanup handlers work correctly through
+the full task execution path, using task functions registered in the
+TaskRegistry and executed via the Celery executor (synchronously via .apply()).
+"""
+
+from __future__ import annotations
+
+import uuid
+from typing import Any
+
+from superset_core.api.tasks import TaskScope, TaskStatus
+
+from superset.commands.tasks.cancel import CancelTaskCommand
+from superset.daos.tasks import TaskDAO
+from superset.extensions import db
+from superset.models.tasks import Task
+from superset.tasks.ambient_context import get_context
+from superset.tasks.context import TaskContext
+from superset.tasks.registry import TaskRegistry
+from superset.tasks.scheduler import execute_task
+from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.constants import ADMIN_USERNAME
+
+# Module-level state to track handler calls across test executions
+# (Since decorated functions are defined at module level)
+_handler_state: dict[str, Any] = {}
+
+
+def _reset_handler_state():
+ """Reset handler state before each test."""
+ global _handler_state
+ _handler_state = {
+ "cleanup_called": False,
+ "abort_called": False,
+ "cleanup_order": [],
+ "abort_order": [],
+ "cleanup_data": {},
+ }
+
+
+def cleanup_test_task() -> None:
+ """Task that registers a cleanup handler."""
+ ctx = get_context()
+
+ @ctx.on_cleanup
+ def handle_cleanup() -> None:
+ _handler_state["cleanup_called"] = True
+
+ # Simulate some work
+ ctx.update_task(progress=1.0)
+
+
+def abort_test_task() -> None:
+ """Task that registers an abort handler."""
+ ctx = get_context()
+
+ @ctx.on_abort
+ def handle_abort() -> None:
+ _handler_state["abort_called"] = True
+
+
+def both_handlers_task() -> None:
+ """Task that registers both abort and cleanup handlers."""
+ ctx = get_context()
+
+ @ctx.on_abort
+ def handle_abort() -> None:
+ _handler_state["abort_called"] = True
+ _handler_state["abort_order"].append("abort")
+
+ @ctx.on_cleanup
+ def handle_cleanup() -> None:
+ _handler_state["cleanup_called"] = True
+ _handler_state["cleanup_order"].append("cleanup")
+
+
+def multiple_cleanup_handlers_task() -> None:
+ """Task that registers multiple cleanup handlers."""
+ ctx = get_context()
+
+ @ctx.on_cleanup
+ def cleanup_first() -> None:
+ _handler_state["cleanup_order"].append("first")
+
+ @ctx.on_cleanup
+ def cleanup_second() -> None:
+ _handler_state["cleanup_order"].append("second")
+
+ @ctx.on_cleanup
+ def cleanup_third() -> None:
+ _handler_state["cleanup_order"].append("third")
+
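+# Cleanup handlers are expected to run LIFO (last registered first): the
+# context presumably keeps a stack and unwinds it on teardown, roughly:
+#
+#     for handler in reversed(self._cleanup_handlers):  # illustrative
+#         handler()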
+
+def cleanup_with_data_task() -> None:
+ """Task that uses cleanup handler to clean up partial work."""
+ ctx = get_context()
+
+ # Simulate partial work in module-level state
+ _handler_state["cleanup_data"]["temp_key"] = "temp_value"
+
+ @ctx.on_cleanup
+ def handle_cleanup() -> None:
+ # Clean up the partial work
+ _handler_state["cleanup_data"].clear()
+ _handler_state["cleanup_called"] = True
+
+
+def _register_test_tasks() -> None:
+ """Register test task functions if not already registered.
+
+ Called in setUp() to ensure tasks are registered regardless of
+ whether other tests have cleared the registry.
+ """
+ registrations = [
+ ("test_cleanup_task", cleanup_test_task),
+ ("test_abort_task", abort_test_task),
+ ("test_both_handlers_task", both_handlers_task),
+ ("test_multiple_cleanup_task", multiple_cleanup_handlers_task),
+ ("test_cleanup_with_data", cleanup_with_data_task),
+ ]
+ for name, func in registrations:
+ if not TaskRegistry.is_registered(name):
+ TaskRegistry.register(name, func)
+
+
+class TestCleanupHandlers(SupersetTestCase):
+ """E2E tests for on_cleanup functionality using Celery executor."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ super().setUp()
+ self.login(ADMIN_USERNAME)
+ _register_test_tasks()
+ _reset_handler_state()
+
+ def test_cleanup_handler_fires_on_success(self):
+ """Test cleanup handler runs when task completes successfully."""
+ # Create task entry directly
+ task_obj = TaskDAO.create_task(
+ task_type="test_cleanup_task",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Cleanup",
+ scope=TaskScope.SYSTEM,
+ )
+
+ # Execute task synchronously through Celery executor
+ # Use str(uuid) since Celery serializes args as JSON strings
+ result = execute_task.apply(
+ args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
+ )
+
+ # Verify task completed successfully
+ assert result.successful()
+ assert result.result["status"] == TaskStatus.SUCCESS.value
+
+ # Verify cleanup handler was called
+ assert _handler_state["cleanup_called"]
+
+ def test_multiple_cleanup_handlers_in_lifo_order(self):
+ """Test multiple cleanup handlers execute in LIFO order."""
+ task_obj = TaskDAO.create_task(
+ task_type="test_multiple_cleanup_task",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Multiple Cleanup",
+ scope=TaskScope.SYSTEM,
+ )
+
+ result = execute_task.apply(
+ args=[str(task_obj.uuid), "test_multiple_cleanup_task", (), {}]
+ )
+
+ assert result.successful()
+
+ # Handlers should execute in LIFO order (last registered first)
+ assert _handler_state["cleanup_order"] == ["third", "second", "first"]
+
+ def test_cleanup_handler_cleans_up_partial_work(self):
+ """Test cleanup handler can clean up partial work."""
+ task_obj = TaskDAO.create_task(
+ task_type="test_cleanup_with_data",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Cleanup Data",
+ scope=TaskScope.SYSTEM,
+ )
+
+ result = execute_task.apply(
+ args=[str(task_obj.uuid), "test_cleanup_with_data", (), {}]
+ )
+
+ assert result.successful()
+ assert _handler_state["cleanup_called"]
+ # Cleanup handler should have cleared the data
+ assert len(_handler_state["cleanup_data"]) == 0
+
+
+class TestAbortHandlers(SupersetTestCase):
+ """E2E tests for on_abort functionality."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ super().setUp()
+ self.login(ADMIN_USERNAME)
+ _register_test_tasks()
+ _reset_handler_state()
+
+ def test_abort_handler_fires_when_task_aborting(self):
+ """Test abort handler runs when task is in ABORTING state during cleanup."""
+ # Create task entry
+ task_obj = TaskDAO.create_task(
+ task_type="test_abort_task",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Abort",
+ scope=TaskScope.SYSTEM,
+ )
+
+ # Manually set to IN_PROGRESS and then ABORTING to simulate abort
+ task_obj.status = TaskStatus.IN_PROGRESS.value
+ task_obj.update_properties({"is_abortable": True})
+ db.session.merge(task_obj)
+ db.session.commit()
+
+ # Refresh to get the updated task
+ db.session.refresh(task_obj)
+
+ # Create context (simulating what executor does)
+ ctx = TaskContext(task_obj)
+
+ # Register abort handler
+ @ctx.on_abort
+ def handle_abort():
+ _handler_state["abort_called"] = True
+
+ # Set status to ABORTING (simulating CancelTaskCommand)
+ task_obj.status = TaskStatus.ABORTING.value
+ db.session.merge(task_obj)
+ db.session.commit()
+
+ # Run cleanup (simulating executor's finally block)
+ ctx._run_cleanup()
+
+ # Verify abort handler was called
+ assert _handler_state["abort_called"]
+
+ def test_both_handlers_fire_on_abort(self):
+ """Test both abort and cleanup handlers run when task is aborted."""
+ task_obj = TaskDAO.create_task(
+ task_type="test_both_handlers_task",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Both Handlers",
+ scope=TaskScope.SYSTEM,
+ )
+
+ task_obj.status = TaskStatus.IN_PROGRESS.value
+ task_obj.update_properties({"is_abortable": True})
+ db.session.merge(task_obj)
+ db.session.commit()
+
+ # Refresh to get the updated task
+ db.session.refresh(task_obj)
+
+ ctx = TaskContext(task_obj)
+
+ @ctx.on_abort
+ def handle_abort():
+ _handler_state["abort_called"] = True
+ _handler_state["abort_order"].append("abort")
+
+ @ctx.on_cleanup
+ def handle_cleanup():
+ _handler_state["cleanup_called"] = True
+ _handler_state["cleanup_order"].append("cleanup")
+
+ # Set to ABORTING
+ task_obj.status = TaskStatus.ABORTING.value
+ db.session.merge(task_obj)
+ db.session.commit()
+
+ ctx._run_cleanup()
+
+ # Both should have been called
+ assert _handler_state["abort_called"]
+ assert _handler_state["cleanup_called"]
+
+ def test_abort_handler_not_called_on_success(self):
+ """Test abort handler doesn't run when task succeeds."""
+ task_obj = TaskDAO.create_task(
+ task_type="test_abort_task",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test No Abort on Success",
+ scope=TaskScope.SYSTEM,
+ )
+
+ task_obj.status = TaskStatus.SUCCESS.value
+ db.session.merge(task_obj)
+ db.session.commit()
+
+ # Refresh to get the updated task
+ db.session.refresh(task_obj)
+
+ ctx = TaskContext(task_obj)
+
+ @ctx.on_abort
+ def handle_abort():
+ _handler_state["abort_called"] = True
+
+ @ctx.on_cleanup
+ def handle_cleanup():
+ _handler_state["cleanup_called"] = True
+
+ ctx._run_cleanup()
+
+ # Abort handler should NOT be called
+ assert not _handler_state["abort_called"]
+ # Cleanup handler should still be called
+ assert _handler_state["cleanup_called"]
+
+
+class TestTaskContextMethods(SupersetTestCase):
+ """Tests for TaskContext public methods."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ super().setUp()
+ self.login(ADMIN_USERNAME)
+
+ def test_on_abort_marks_task_abortable(self):
+ """Test that registering an on_abort handler marks task as abortable."""
+ task_obj = TaskDAO.create_task(
+ task_type="test_abortable_flag",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Abortable",
+ scope=TaskScope.SYSTEM,
+ )
+
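+        # The property is unset (None) until a handler is registered, so not True yet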
+ assert task_obj.properties_dict.get("is_abortable") is not True
+
+ ctx = TaskContext(task_obj)
+
+ @ctx.on_abort
+ def handle_abort():
+ pass
+
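+        # Re-read from the DB: registering the handler should persist is_abortable=True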
+ db.session.expire_all()
+ task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
+ assert task_obj.properties_dict.get("is_abortable") is True
+
+
+class TestAbortBeforeExecution(SupersetTestCase):
+ """Tests for aborting tasks before they start executing."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ super().setUp()
+ self.login(ADMIN_USERNAME)
+ _register_test_tasks()
+
+ def test_abort_pending_task(self):
+ """Test that pending tasks can be aborted directly."""
+ task_obj = TaskDAO.create_task(
+ task_type="test_abort_before_start",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Before Start",
+ scope=TaskScope.SYSTEM,
+ )
+
+ # Cancel immediately (task is still PENDING)
+ CancelTaskCommand(task_obj.uuid, force=True).run()
+
+ db.session.expire_all()
+ task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
+ assert task_obj.status == TaskStatus.ABORTED.value
+
+ def test_executor_skips_aborted_task(self):
+ """Test that executor skips tasks already aborted before execution."""
+ task_obj = TaskDAO.create_task(
+ task_type="test_cleanup_task",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Skip Aborted",
+ scope=TaskScope.SYSTEM,
+ )
+
+ # Abort the task before execution
+ task_obj.status = TaskStatus.ABORTED.value
+ db.session.merge(task_obj)
+ db.session.commit()
+
+ _reset_handler_state()
+
+ # Try to execute - should skip
+ result = execute_task.apply(
+ args=[str(task_obj.uuid), "test_cleanup_task", (), {}]
+ )
+
+ assert result.successful()
+ assert result.result["status"] == TaskStatus.ABORTED.value
+ # Cleanup handler should NOT have been called (task was skipped)
+ assert not _handler_state["cleanup_called"]
diff --git a/tests/integration_tests/tasks/test_sync_join_wait.py b/tests/integration_tests/tasks/test_sync_join_wait.py
new file mode 100644
index 000000000000..9379efca1c75
--- /dev/null
+++ b/tests/integration_tests/tasks/test_sync_join_wait.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Integration tests for sync join-and-wait functionality in GTF."""
+
+import time
+
+from superset_core.api.tasks import TaskStatus
+
+from superset import db
+from superset.commands.tasks import SubmitTaskCommand
+from superset.daos.tasks import TaskDAO
+from superset.tasks.manager import TaskManager
+
+
+def test_submit_task_distinguishes_new_vs_existing(
+ app_context, login_as, get_user
+) -> None:
+ """
+    Test that SubmitTaskCommand.run_with_info() correctly returns the is_new flag.
+ """
+ login_as("admin")
+ admin = get_user("admin")
+
+ # First submission - should be new
+ task1, is_new1 = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "distinguish-key",
+ "task_name": "First Task",
+ "user_id": admin.id,
+ }
+ ).run_with_info()
+
+ assert is_new1 is True
+
+ try:
+ # Second submission with same key - should join existing
+ task2, is_new2 = SubmitTaskCommand(
+ data={
+ "task_type": "test-type",
+ "task_key": "distinguish-key",
+ "task_name": "Second Task",
+ "user_id": admin.id,
+ }
+ ).run_with_info()
+
+ assert is_new2 is False
+ assert task2.uuid == task1.uuid
+
+ finally:
+ # Cleanup
+ db.session.delete(task1)
+ db.session.commit()
+
+
+def test_terminal_states_recognized_correctly(app_context) -> None:
+ """
+ Test that TaskManager.TERMINAL_STATES contains the expected values.
+ """
+ assert TaskStatus.SUCCESS.value in TaskManager.TERMINAL_STATES
+ assert TaskStatus.FAILURE.value in TaskManager.TERMINAL_STATES
+ assert TaskStatus.ABORTED.value in TaskManager.TERMINAL_STATES
+ assert TaskStatus.TIMED_OUT.value in TaskManager.TERMINAL_STATES
+
+ # Non-terminal states should not be in the set
+ assert TaskStatus.PENDING.value not in TaskManager.TERMINAL_STATES
+ assert TaskStatus.IN_PROGRESS.value not in TaskManager.TERMINAL_STATES
+ assert TaskStatus.ABORTING.value not in TaskManager.TERMINAL_STATES
+
+
+def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
+ """
+ Test that wait_for_completion raises TimeoutError on timeout.
+ """
+ from unittest.mock import patch
+
+ import pytest
+
+ login_as("admin")
+ admin = get_user("admin")
+
+ # Create a pending task (won't complete)
+ task, _ = SubmitTaskCommand(
+ data={
+ "task_type": "test-timeout",
+ "task_key": "timeout-key",
+ "task_name": "Timeout Task",
+ "user_id": admin.id,
+ }
+ ).run_with_info()
+
+ try:
+ # Force polling mode by mocking signal_cache as None
+ with patch("superset.tasks.manager.cache_manager") as mock_cache_manager:
+ mock_cache_manager.signal_cache = None
+ with pytest.raises(TimeoutError):
+ TaskManager.wait_for_completion(
+ task.uuid,
+ timeout=0.2,
+ poll_interval=0.05,
+ )
+ finally:
+ db.session.delete(task)
+ db.session.commit()
+
+
+def test_wait_returns_immediately_for_terminal_task(
+ app_context, login_as, get_user
+) -> None:
+ """
+ Test that wait_for_completion returns immediately if task is already terminal.
+ """
+ login_as("admin")
+ admin = get_user("admin")
+
+ # Create and immediately complete a task
+ task, _ = SubmitTaskCommand(
+ data={
+ "task_type": "test-immediate",
+ "task_key": "immediate-key",
+ "task_name": "Immediate Task",
+ "user_id": admin.id,
+ }
+ ).run_with_info()
+
+ TaskDAO.update(task, {"status": TaskStatus.SUCCESS.value})
+ db.session.commit()
+
+ try:
+ start = time.time()
+ result = TaskManager.wait_for_completion(
+ task.uuid,
+ timeout=5.0,
+ poll_interval=0.5,
+ )
+ elapsed = time.time() - start
+
+ assert result.status == TaskStatus.SUCCESS.value
+ # Should return almost immediately since task is already terminal
+ assert elapsed < 0.2
+ finally:
+ db.session.delete(task)
+ db.session.commit()
diff --git a/tests/integration_tests/tasks/test_throttling.py b/tests/integration_tests/tasks/test_throttling.py
new file mode 100644
index 000000000000..a0e77ba041b0
--- /dev/null
+++ b/tests/integration_tests/tasks/test_throttling.py
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Integration tests for TaskContext update_task throttling.
+
+Tests verify:
+1. Final state is persisted correctly via cleanup flush
+2. Throttled updates are deferred; the timer writes the latest pending update
+"""
+
+from __future__ import annotations
+
+import time
+import uuid
+
+from superset_core.api.tasks import TaskScope, TaskStatus
+
+from superset.daos.tasks import TaskDAO
+from superset.extensions import db
+from superset.models.tasks import Task
+from superset.tasks.ambient_context import get_context
+from superset.tasks.registry import TaskRegistry
+from superset.tasks.scheduler import execute_task
+from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.constants import ADMIN_USERNAME
+
+
+def task_with_throttled_updates() -> None:
+ """Task with rapid progress and payload updates (exercises throttling)."""
+ ctx = get_context()
+
+ # Rapid-fire updates within throttle window
+ for i in range(10):
+ ctx.update_task(progress=(i + 1, 10), payload={"step": i + 1})
+
+
+def _register_test_tasks() -> None:
+ """Register test task functions if not already registered.
+
+ Called in setUp() to ensure tasks are registered regardless of
+ whether other tests have cleared the registry.
+ """
+ if not TaskRegistry.is_registered("test_throttle_combined"):
+ TaskRegistry.register("test_throttle_combined", task_with_throttled_updates)
+
+
+class TestUpdateTaskThrottling(SupersetTestCase):
+ """Integration test for update_task() throttling behavior."""
+
+ def setUp(self) -> None:
+ super().setUp()
+ self.login(ADMIN_USERNAME)
+ _register_test_tasks()
+
+ def test_throttled_updates_persisted_on_cleanup(self) -> None:
+ """Final state should be persisted regardless of throttling.
+
+ Verifies the core invariant: cleanup flush ensures final state is persisted.
+ """
+ task_obj = TaskDAO.create_task(
+ task_type="test_throttle_combined",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Throttled Updates",
+ scope=TaskScope.SYSTEM,
+ )
+
+ # Use str(uuid) since Celery serializes args as JSON strings
+ result = execute_task.apply(
+ args=[str(task_obj.uuid), "test_throttle_combined", (), {}]
+ )
+
+ assert result.successful()
+ assert result.result["status"] == TaskStatus.SUCCESS.value
+
+ # Verify final state is persisted
+ db.session.expire_all()
+ task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
+
+ # Progress: 10/10 = 100%
+ props = task_obj.properties_dict
+ assert props.get("progress_current") == 10
+ assert props.get("progress_total") == 10
+ assert props.get("progress_percent") == 1.0
+
+ # Payload: final step
+ payload = task_obj.payload_dict
+ assert payload.get("step") == 10
+
+ def test_throttle_behavior(self) -> None:
+ """Test complete throttle behavior: immediate write, deferral, and timer.
+
+ Verifies:
+ 1. First update writes immediately
+ 2. Second and third updates within throttle window are deferred
+ 3. Deferred timer fires and writes the LATEST pending update (third)
+ """
+ from flask import current_app
+
+ from superset.commands.tasks.submit import SubmitTaskCommand
+ from superset.tasks.context import TaskContext
+
+ # Get throttle interval from config (default: 2 seconds)
+ throttle_interval = current_app.config["TASK_PROGRESS_UPDATE_THROTTLE_INTERVAL"]
+
+ # Create task
+ task_obj = SubmitTaskCommand(
+ data={
+ "task_type": "test_throttle_behavior",
+ "task_key": f"test_key_{uuid.uuid4().hex[:8]}",
+ "task_name": "Test Throttle Behavior",
+ "scope": TaskScope.SYSTEM,
+ }
+ ).run()
+ task_uuid = task_obj.uuid
+
+ # Get fresh task for context
+ fresh_task = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
+ assert fresh_task is not None
+ ctx = TaskContext(fresh_task)
+
+ try:
+ # === Step 1: First update - writes immediately ===
+ ctx.update_task(progress=0.1, payload={"step": 1})
+
+ db.session.expire_all()
+ task_step1 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
+ assert task_step1 is not None
+ assert task_step1.properties_dict.get("progress_percent") == 0.1
+ assert task_step1.payload_dict.get("step") == 1
+
+ # === Step 2: Second update - deferred (within throttle window) ===
+ ctx.update_task(progress=0.5, payload={"step": 2})
+
+ # === Step 3: Third update - also deferred, overwrites second in cache ===
+ ctx.update_task(progress=0.7, payload={"step": 3})
+
+ # Verify in-memory cache has LATEST update (third)
+ assert ctx._properties_cache.get("progress_percent") == 0.7
+ assert ctx._payload_cache.get("step") == 3
+
+ # Verify DB still has first update (both second and third deferred)
+ db.session.expire_all()
+ task_step2 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
+ assert task_step2 is not None
+ assert task_step2.properties_dict.get("progress_percent") == 0.1
+ assert task_step2.payload_dict.get("step") == 1
+
+ # === Step 4: Wait for deferred timer to fire ===
+ time.sleep(throttle_interval + 0.5)
+
+ # Verify timer fired and wrote the LATEST update (third, not second)
+ db.session.expire_all()
+ task_step3 = TaskDAO.find_one_or_none(uuid=task_uuid, skip_base_filter=True)
+ assert task_step3 is not None
+ assert task_step3.properties_dict.get("progress_percent") == 0.7
+ assert task_step3.payload_dict.get("step") == 3
+
+ finally:
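+            # Cancel any pending deferred flush so no timer thread outlives the test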
+ ctx._cancel_deferred_flush_timer()
diff --git a/tests/integration_tests/tasks/test_timeout.py b/tests/integration_tests/tasks/test_timeout.py
new file mode 100644
index 000000000000..efd92ffd80e8
--- /dev/null
+++ b/tests/integration_tests/tasks/test_timeout.py
@@ -0,0 +1,226 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Integration tests for GTF timeout handling.
+
+Uses module-level task functions with manual registry (like test_event_handlers.py)
+to avoid mypy issues with the @task decorator's complex generic types.
+
+NOTE: Tests that use background threads (timeout/abort handlers) are skipped in
+SQLite environments because SQLite connections cannot be shared across threads.
+"""
+
+from __future__ import annotations
+
+import time
+import uuid
+from typing import Any
+
+import pytest
+from superset_core.api.tasks import TaskScope, TaskStatus
+
+from superset.commands.tasks.cancel import CancelTaskCommand
+from superset.daos.tasks import TaskDAO
+from superset.extensions import db
+from superset.models.tasks import Task
+from superset.tasks.ambient_context import get_context
+from superset.tasks.registry import TaskRegistry
+from superset.tasks.scheduler import execute_task
+from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.constants import ADMIN_USERNAME
+
+
+def _skip_if_sqlite() -> None:
+ """Skip test if running with SQLite database.
+
+ SQLite connections cannot be shared across threads, which breaks
+ timeout tests that use background threads for abort handlers.
+ Must be called from within a test method (with app context).
+ """
+ if "sqlite" in db.engine.url.drivername:
+ pytest.skip("SQLite connections cannot be shared across threads")
+
+
+# Module-level state to track handler calls
+_handler_state: dict[str, Any] = {}
+
+
+def _reset_handler_state() -> None:
+ """Reset handler state before each test."""
+ global _handler_state
+ _handler_state = {
+ "abort_called": False,
+ "handler_exception": None,
+ }
+
+
+def timeout_abortable_task() -> None:
+ """Task with abort handler that exits when aborted."""
+ ctx = get_context()
+
+ @ctx.on_abort
+ def on_abort() -> None:
+ _handler_state["abort_called"] = True
+
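+    # Loop bound is 50 x 0.1s = 5s; tests set a 1s timeout, so abort fires well before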
+ # Poll for abort signal
+ for _ in range(50):
+ if _handler_state["abort_called"]:
+ return
+ time.sleep(0.1)
+
+
+def timeout_handler_fails_task() -> None:
+ """Task with abort handler that throws an exception."""
+ ctx = get_context()
+
+ @ctx.on_abort
+ def on_abort() -> None:
+ _handler_state["abort_called"] = True
+ raise ValueError("Handler crashed!")
+
+ # Sleep longer than timeout
+ time.sleep(5)
+
+
+def simple_task_with_abort() -> None:
+ """Simple task with abort handler for testing."""
+ ctx = get_context()
+
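+    # Registering an on_abort handler marks the task as abortable (is_abortable=True)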
+ @ctx.on_abort
+ def on_abort() -> None:
+ pass
+
+
+def quick_task_with_abort() -> None:
+ """Quick task that completes before timeout."""
+ ctx = get_context()
+
+ @ctx.on_abort
+ def on_abort() -> None:
+ pass
+
+ time.sleep(0.2)
+
+
+def _register_test_tasks() -> None:
+ """Register test task functions if not already registered.
+
+ Called in setUp() to ensure tasks are registered regardless of
+ whether other tests have cleared the registry.
+ """
+ registrations = [
+ ("test_timeout_abortable", timeout_abortable_task),
+ ("test_timeout_handler_fails", timeout_handler_fails_task),
+ ("test_timeout_simple", simple_task_with_abort),
+ ("test_timeout_quick", quick_task_with_abort),
+ ]
+ for name, func in registrations:
+ if not TaskRegistry.is_registered(name):
+ TaskRegistry.register(name, func)
+
+
+class TestTimeoutHandling(SupersetTestCase):
+ """E2E tests for task timeout functionality."""
+
+ def setUp(self) -> None:
+ """Set up test fixtures."""
+ super().setUp()
+ self.login(ADMIN_USERNAME)
+ _register_test_tasks()
+ _reset_handler_state()
+
+ def test_timeout_with_abort_handler_results_in_timed_out_status(self) -> None:
+ """Task with timeout and abort handler should end with TIMED_OUT status."""
+ _skip_if_sqlite()
+
+ # Create task with timeout
+ task_obj = TaskDAO.create_task(
+ task_type="test_timeout_abortable",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Timeout",
+ scope=TaskScope.SYSTEM,
+ properties={"timeout": 1}, # 1 second timeout
+ )
+
+ # Execute task via Celery executor (synchronously)
+ # Use str(uuid) since Celery serializes args as JSON strings
+ result = execute_task.apply(
+ args=[str(task_obj.uuid), "test_timeout_abortable", (), {}]
+ )
+
+ # Verify execution completed
+ assert result.successful()
+ assert result.result["status"] == TaskStatus.TIMED_OUT.value
+
+ # Verify abort handler was called
+ assert _handler_state["abort_called"]
+
+ def test_user_abort_results_in_aborted_status(self) -> None:
+ """User-initiated abort on pending task should result in ABORTED."""
+ # Create task (pending state)
+ task_obj = TaskDAO.create_task(
+ task_type="test_timeout_simple",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Abort Task",
+ scope=TaskScope.SYSTEM,
+ )
+
+ # Cancel before execution (pending task abort)
+ CancelTaskCommand(task_obj.uuid, force=True).run()
+
+ # Refresh from DB
+ db.session.expire_all()
+ task_obj = db.session.query(Task).filter_by(uuid=task_obj.uuid).first()
+ assert task_obj.status == TaskStatus.ABORTED.value
+
+ def test_no_timeout_when_not_configured(self) -> None:
+ """Task without timeout should run to completion regardless of duration."""
+ task_obj = TaskDAO.create_task(
+ task_type="test_timeout_quick",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test No Timeout",
+ scope=TaskScope.SYSTEM,
+ # No timeout property
+ )
+
+ # Use str(uuid) since Celery serializes args as JSON strings
+ result = execute_task.apply(
+ args=[str(task_obj.uuid), "test_timeout_quick", (), {}]
+ )
+
+ assert result.successful()
+ assert result.result["status"] == TaskStatus.SUCCESS.value
+
+ def test_abort_handler_exception_results_in_failure(self) -> None:
+ """If abort handler throws during timeout, task should be FAILURE."""
+ _skip_if_sqlite()
+
+ task_obj = TaskDAO.create_task(
+ task_type="test_timeout_handler_fails",
+ task_key=f"test_key_{uuid.uuid4().hex[:8]}",
+ task_name="Test Handler Fails",
+ scope=TaskScope.SYSTEM,
+ properties={"timeout": 1}, # 1 second timeout
+ )
+
+ # Use str(uuid) since Celery serializes args as JSON strings
+ result = execute_task.apply(
+ args=[str(task_obj.uuid), "test_timeout_handler_fails", (), {}]
+ )
+
+ assert result.successful()
+ assert result.result["status"] == TaskStatus.FAILURE.value
+ assert _handler_state["abort_called"]
diff --git a/tests/unit_tests/daos/test_tasks.py b/tests/unit_tests/daos/test_tasks.py
new file mode 100644
index 000000000000..f8f3bdc073a9
--- /dev/null
+++ b/tests/unit_tests/daos/test_tasks.py
@@ -0,0 +1,420 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections.abc import Iterator
+from uuid import UUID
+
+import pytest
+from sqlalchemy.orm.session import Session
+from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
+
+from superset.commands.tasks.exceptions import TaskNotAbortableError
+from superset.models.tasks import Task
+from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key
+
+# Test constants
+TASK_UUID = UUID("e7765491-40c1-4f35-a4f5-06308e79310e")
+TASK_ID = 42
+TEST_TASK_TYPE = "test_type"
+TEST_TASK_KEY = "test-key"
+TEST_USER_ID = 1
+
+
+def create_task(
+ session: Session,
+ *,
+ task_id: int | None = None,
+ task_uuid: UUID | None = None,
+ task_key: str = TEST_TASK_KEY,
+ task_type: str = TEST_TASK_TYPE,
+ scope: TaskScope = TaskScope.PRIVATE,
+ status: TaskStatus = TaskStatus.PENDING,
+ user_id: int | None = TEST_USER_ID,
+ properties: TaskProperties | None = None,
+ use_finished_dedup_key: bool = False,
+) -> Task:
+ """Helper to create a task with sensible defaults for testing."""
+ if use_finished_dedup_key:
+ dedup_key = get_finished_dedup_key(task_uuid or TASK_UUID)
+ else:
+ dedup_key = get_active_dedup_key(
+ scope=scope,
+ task_type=task_type,
+ task_key=task_key,
+ user_id=user_id,
+ )
+
+ task = Task(
+ task_type=task_type,
+ task_key=task_key,
+ scope=scope.value,
+ status=status.value,
+ dedup_key=dedup_key,
+ user_id=user_id,
+ )
+ if task_id is not None:
+ task.id = task_id
+ if task_uuid:
+ task.uuid = task_uuid
+ if properties:
+ task.update_properties(properties)
+
+ session.add(task)
+ session.flush()
+ return task
+
+
+@pytest.fixture
+def session_with_task(session: Session) -> Iterator[Session]:
+ """Create a session with Task and TaskSubscriber tables."""
+ from superset.models.task_subscribers import TaskSubscriber
+
+ engine = session.get_bind()
+ Task.metadata.create_all(engine)
+ TaskSubscriber.metadata.create_all(engine)
+
+ yield session
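+    # Discard any uncommitted test data so it does not leak into other tests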
+ session.rollback()
+
+
+def test_find_by_task_key_active(session_with_task: Session) -> None:
+ """Test finding active task by task_key"""
+ from superset.daos.tasks import TaskDAO
+
+ create_task(session_with_task)
+
+ result = TaskDAO.find_by_task_key(
+ task_type=TEST_TASK_TYPE,
+ task_key=TEST_TASK_KEY,
+ scope=TaskScope.PRIVATE,
+ user_id=TEST_USER_ID,
+ )
+
+ assert result is not None
+ assert result.task_key == TEST_TASK_KEY
+ assert result.task_type == TEST_TASK_TYPE
+ assert result.status == TaskStatus.PENDING.value
+
+
+def test_find_by_task_key_not_found(session_with_task: Session) -> None:
+ """Test finding task by task_key returns None when not found"""
+ from superset.daos.tasks import TaskDAO
+
+ result = TaskDAO.find_by_task_key(
+ task_type=TEST_TASK_TYPE,
+ task_key="nonexistent-key",
+ scope=TaskScope.PRIVATE,
+ user_id=TEST_USER_ID,
+ )
+
+ assert result is None
+
+
+def test_find_by_task_key_finished_not_found(session_with_task: Session) -> None:
+ """Test that find_by_task_key returns None for finished tasks.
+
+ Finished tasks have a different dedup_key format (UUID-based),
+ so they won't be found by the active task lookup.
+ """
+ from superset.daos.tasks import TaskDAO
+
+ create_task(
+ session_with_task,
+ task_key="finished-key",
+ status=TaskStatus.SUCCESS,
+ use_finished_dedup_key=True,
+ task_uuid=TASK_UUID,
+ )
+
+ # Should not find SUCCESS task via active lookup
+ result = TaskDAO.find_by_task_key(
+ task_type=TEST_TASK_TYPE,
+ task_key="finished-key",
+ scope=TaskScope.PRIVATE,
+ user_id=TEST_USER_ID,
+ )
+ assert result is None
+
+
+def test_create_task_success(session_with_task: Session) -> None:
+ """Test successful task creation."""
+ from superset.daos.tasks import TaskDAO
+
+ result = TaskDAO.create_task(
+ task_type=TEST_TASK_TYPE,
+ task_key=TEST_TASK_KEY,
+ scope=TaskScope.PRIVATE,
+ user_id=TEST_USER_ID,
+ )
+
+ assert result is not None
+ assert result.task_key == TEST_TASK_KEY
+ assert result.task_type == TEST_TASK_TYPE
+ assert result.status == TaskStatus.PENDING.value
+ assert isinstance(result, Task)
+
+
+def test_create_task_with_user_id(session_with_task: Session) -> None:
+ """Test task creation with explicit user_id."""
+ from superset.daos.tasks import TaskDAO
+
+ result = TaskDAO.create_task(
+ task_type=TEST_TASK_TYPE,
+ task_key="user-task",
+ scope=TaskScope.PRIVATE,
+ user_id=42,
+ )
+
+ assert result is not None
+ assert result.user_id == 42
+ # Creator should be auto-subscribed
+ assert len(result.subscribers) == 1
+ assert result.subscribers[0].user_id == 42
+
+
+def test_create_task_with_properties(session_with_task: Session) -> None:
+ """Test task creation with properties."""
+ from superset.daos.tasks import TaskDAO
+
+ result = TaskDAO.create_task(
+ task_type=TEST_TASK_TYPE,
+ task_key="props-task",
+ scope=TaskScope.PRIVATE,
+ user_id=TEST_USER_ID,
+ properties={"timeout": 300},
+ )
+
+ assert result is not None
+ assert result.properties_dict.get("timeout") == 300
+
+
+def test_abort_task_pending_success(session_with_task: Session) -> None:
+ """Test successful abort of pending task - goes directly to ABORTED"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="pending-task",
+ status=TaskStatus.PENDING,
+ )
+
+ result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
+
+ assert result is not None
+ assert result.status == TaskStatus.ABORTED.value
+
+
+def test_abort_task_in_progress_abortable(session_with_task: Session) -> None:
+ """Test abort of in-progress task with abort handler.
+
+ Should transition to ABORTING status.
+ """
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="abortable-task",
+ status=TaskStatus.IN_PROGRESS,
+ properties={"is_abortable": True},
+ )
+
+ result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
+
+ assert result is not None
+ # Should set status to ABORTING, not ABORTED
+ assert result.status == TaskStatus.ABORTING.value
+
+
+def test_abort_task_in_progress_not_abortable(session_with_task: Session) -> None:
+ """Test abort of in-progress task without abort handler - raises error"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="non-abortable-task",
+ status=TaskStatus.IN_PROGRESS,
+ properties={"is_abortable": False},
+ )
+
+ with pytest.raises(TaskNotAbortableError):
+ TaskDAO.abort_task(task.uuid, skip_base_filter=True)
+
+
+def test_abort_task_in_progress_is_abortable_none(session_with_task: Session) -> None:
+ """Test abort of in-progress task with is_abortable not set - raises error"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="no-abortable-prop-task",
+ status=TaskStatus.IN_PROGRESS,
+ # Empty properties - no is_abortable key
+ )
+
+ with pytest.raises(TaskNotAbortableError):
+ TaskDAO.abort_task(task.uuid, skip_base_filter=True)
+
+
+def test_abort_task_already_aborting(session_with_task: Session) -> None:
+ """Test abort of already aborting task - idempotent success"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="aborting-task",
+ status=TaskStatus.ABORTING,
+ )
+
+ result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
+
+ # Idempotent - returns task without error
+ assert result is not None
+ assert result.status == TaskStatus.ABORTING.value
+
+
+def test_abort_task_not_found(session_with_task: Session) -> None:
+ """Test abort fails when task not found"""
+ from superset.daos.tasks import TaskDAO
+
+ result = TaskDAO.abort_task(UUID("00000000-0000-0000-0000-000000000000"))
+
+ assert result is None
+
+
+def test_abort_task_already_finished(session_with_task: Session) -> None:
+ """Test abort fails when task already finished"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="finished-task",
+ status=TaskStatus.SUCCESS,
+ use_finished_dedup_key=True,
+ task_uuid=TASK_UUID,
+ )
+
+ result = TaskDAO.abort_task(task.uuid, skip_base_filter=True)
+
+ assert result is None
+
+
+def test_add_subscriber(session_with_task: Session) -> None:
+ """Test adding a subscriber to a task"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="shared-task",
+ scope=TaskScope.SHARED,
+ user_id=None,
+ )
+
+ # Add subscriber
+ result = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
+ assert result is True
+
+ # Verify subscriber was added
+ session_with_task.refresh(task)
+ assert len(task.subscribers) == 1
+ assert task.subscribers[0].user_id == TEST_USER_ID
+
+
+def test_add_subscriber_idempotent(session_with_task: Session) -> None:
+ """Test adding same subscriber twice is idempotent"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="shared-task-2",
+ scope=TaskScope.SHARED,
+ user_id=None,
+ )
+
+ # Add subscriber twice
+ result1 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
+ result2 = TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
+
+ assert result1 is True
+ assert result2 is False # Already subscribed
+
+ # Verify only one subscriber
+ session_with_task.refresh(task)
+ assert len(task.subscribers) == 1
+
+
+def test_remove_subscriber(session_with_task: Session) -> None:
+ """Test removing a subscriber from a task"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="shared-task-3",
+ scope=TaskScope.SHARED,
+ user_id=None,
+ )
+
+ TaskDAO.add_subscriber(task.id, user_id=TEST_USER_ID)
+ session_with_task.refresh(task)
+ assert len(task.subscribers) == 1
+
+ # Remove subscriber
+ result = TaskDAO.remove_subscriber(task.id, user_id=TEST_USER_ID)
+
+ assert result is not None
+ assert len(result.subscribers) == 0
+
+
+def test_remove_subscriber_not_subscribed(session_with_task: Session) -> None:
+ """Test removing non-existent subscriber returns None"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_key="shared-task-4",
+ scope=TaskScope.SHARED,
+ user_id=None,
+ )
+
+ # Try to remove non-existent subscriber
+ result = TaskDAO.remove_subscriber(task.id, user_id=999)
+
+ assert result is None
+
+
+def test_get_status(session_with_task: Session) -> None:
+ """Test get_status returns status string when task found by UUID"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_uuid=TASK_UUID,
+ task_key="status-task",
+ status=TaskStatus.IN_PROGRESS,
+ )
+
+ result = TaskDAO.get_status(task.uuid)
+
+ assert result == TaskStatus.IN_PROGRESS.value
+
+
+def test_get_status_not_found(session_with_task: Session) -> None:
+ """Test get_status returns None when task not found"""
+ from superset.daos.tasks import TaskDAO
+
+ result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))
+
+ assert result is None
diff --git a/tests/unit_tests/distributed_lock/distributed_lock_tests.py b/tests/unit_tests/distributed_lock/distributed_lock_tests.py
index 398fb8683d92..3b22c3adc43c 100644
--- a/tests/unit_tests/distributed_lock/distributed_lock_tests.py
+++ b/tests/unit_tests/distributed_lock/distributed_lock_tests.py
@@ -18,17 +18,21 @@
# pylint: disable=invalid-name
from typing import Any
+from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from freezegun import freeze_time
from sqlalchemy.orm import Session, sessionmaker
+# Force module loading before tests run so patches work correctly
+import superset.commands.distributed_lock.acquire as acquire_module
+import superset.commands.distributed_lock.release as release_module
from superset import db
-from superset.distributed_lock import KeyValueDistributedLock
+from superset.distributed_lock import DistributedLock
from superset.distributed_lock.types import LockValue
from superset.distributed_lock.utils import get_key
-from superset.exceptions import CreateKeyValueDistributedLockFailedException
+from superset.exceptions import AcquireDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec
LOCK_VALUE: LockValue = {"value": True}
@@ -56,9 +60,9 @@ def _get_other_session() -> Session:
return SessionMaker()
-def test_key_value_distributed_lock_happy_path() -> None:
+def test_distributed_lock_kv_happy_path() -> None:
"""
- Test successfully acquiring and returning the distributed lock.
+ Test successfully acquiring and returning the distributed lock via KV backend.
Note, we're using another session for asserting the lock state in the Metastore
to simulate what another worker will observe. Otherwise, there's the risk that
@@ -66,24 +70,29 @@ def test_key_value_distributed_lock_happy_path() -> None:
"""
session = _get_other_session()
- with freeze_time("2021-01-01"):
- assert _get_lock(MAIN_KEY, session) is None
+ # Ensure Redis is not configured so KV backend is used
+ with (
+ patch.object(acquire_module, "get_redis_client", return_value=None),
+ patch.object(release_module, "get_redis_client", return_value=None),
+ ):
+ with freeze_time("2021-01-01"):
+ assert _get_lock(MAIN_KEY, session) is None
- with KeyValueDistributedLock("ns", a=1, b=2) as key:
- assert key == MAIN_KEY
- assert _get_lock(key, session) == LOCK_VALUE
- assert _get_lock(OTHER_KEY, session) is None
+ with DistributedLock("ns", a=1, b=2) as key:
+ assert key == MAIN_KEY
+ assert _get_lock(key, session) == LOCK_VALUE
+ assert _get_lock(OTHER_KEY, session) is None
- with pytest.raises(CreateKeyValueDistributedLockFailedException):
- with KeyValueDistributedLock("ns", a=1, b=2):
- pass
+ with pytest.raises(AcquireDistributedLockFailedException):
+ with DistributedLock("ns", a=1, b=2):
+ pass
- assert _get_lock(MAIN_KEY, session) is None
+ assert _get_lock(MAIN_KEY, session) is None
-def test_key_value_distributed_lock_expired() -> None:
+def test_distributed_lock_kv_expired() -> None:
"""
- Test expiration of the distributed lock
+ Test expiration of the distributed lock via KV backend.
Note, we're using another session for asserting the lock state in the Metastore
to simulate what another worker will observe. Otherwise, there's the risk that
@@ -91,11 +100,112 @@ def test_key_value_distributed_lock_expired() -> None:
"""
session = _get_other_session()
- with freeze_time("2021-01-01"):
- assert _get_lock(MAIN_KEY, session) is None
- with KeyValueDistributedLock("ns", a=1, b=2):
- assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
- with freeze_time("2022-01-01"):
- assert _get_lock(MAIN_KEY, session) is None
-
- assert _get_lock(MAIN_KEY, session) is None
+ # Ensure Redis is not configured so KV backend is used
+ with (
+ patch.object(acquire_module, "get_redis_client", return_value=None),
+ patch.object(release_module, "get_redis_client", return_value=None),
+ ):
+ with freeze_time("2021-01-01"):
+ assert _get_lock(MAIN_KEY, session) is None
+ with DistributedLock("ns", a=1, b=2):
+ assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
+ with freeze_time("2022-01-01"):
+ assert _get_lock(MAIN_KEY, session) is None
+
+ assert _get_lock(MAIN_KEY, session) is None
+
+
+def test_distributed_lock_uses_redis_when_configured() -> None:
+ """Test that DistributedLock uses Redis backend when configured."""
+ mock_redis = MagicMock()
+ mock_redis.set.return_value = True # Lock acquired
+
+ # Use patch.object to patch on already-imported modules
+ with (
+ patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
+ patch.object(release_module, "get_redis_client", return_value=mock_redis),
+ ):
+ with DistributedLock("test_redis", key="value") as lock_key:
+ assert lock_key is not None
+ # Verify SET NX EX was called
+ mock_redis.set.assert_called_once()
+ call_args = mock_redis.set.call_args
+ assert call_args.kwargs["nx"] is True
+ assert "ex" in call_args.kwargs
+
+ # Verify DELETE was called on exit
+ mock_redis.delete.assert_called_once()
+
+
+def test_distributed_lock_redis_already_taken() -> None:
+ """Test Redis lock fails when already held."""
+ mock_redis = MagicMock()
+ mock_redis.set.return_value = None # Lock not acquired (already taken)
+
+ with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
+ with pytest.raises(AcquireDistributedLockFailedException):
+ with DistributedLock("test_redis", key="value"):
+ pass
+
+
+def test_distributed_lock_redis_connection_error() -> None:
+ """Test Redis connection error raises exception (fail fast)."""
+ import redis
+
+ mock_redis = MagicMock()
+ mock_redis.set.side_effect = redis.RedisError("Connection failed")
+
+ with patch.object(acquire_module, "get_redis_client", return_value=mock_redis):
+ with pytest.raises(AcquireDistributedLockFailedException):
+ with DistributedLock("test_redis", key="value"):
+ pass
+
+
+def test_distributed_lock_custom_ttl() -> None:
+ """Test Redis lock with custom TTL."""
+ mock_redis = MagicMock()
+ mock_redis.set.return_value = True
+
+ with (
+ patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
+ patch.object(release_module, "get_redis_client", return_value=mock_redis),
+ ):
+ with DistributedLock("test", ttl_seconds=60, key="value"):
+ call_args = mock_redis.set.call_args
+ assert call_args.kwargs["ex"] == 60 # Custom TTL
+
+
+def test_distributed_lock_default_ttl(app_context: None) -> None:
+ """Test Redis lock uses default TTL when not specified."""
+ from superset.commands.distributed_lock.base import get_default_lock_ttl
+
+ mock_redis = MagicMock()
+ mock_redis.set.return_value = True
+
+ with (
+ patch.object(acquire_module, "get_redis_client", return_value=mock_redis),
+ patch.object(release_module, "get_redis_client", return_value=mock_redis),
+ ):
+ with DistributedLock("test", key="value"):
+ call_args = mock_redis.set.call_args
+ assert call_args.kwargs["ex"] == get_default_lock_ttl()
+
+
+def test_distributed_lock_fallback_to_kv_when_redis_not_configured() -> None:
+ """Test falls back to KV lock when Redis not configured."""
+ session = _get_other_session()
+ test_key = get_key("test_fallback", key="value")
+
+ with (
+ patch.object(acquire_module, "get_redis_client", return_value=None),
+ patch.object(release_module, "get_redis_client", return_value=None),
+ ):
+ with freeze_time("2021-01-01"):
+ # When Redis is not configured, should use KV backend
+ with DistributedLock("test_fallback", key="value") as lock_key:
+ assert lock_key == test_key
+ # Verify lock exists in KV store
+ assert _get_lock(test_key, session) == LOCK_VALUE
+
+ # Lock should be released
+ assert _get_lock(test_key, session) is None
diff --git a/tests/unit_tests/tasks/test_decorators.py b/tests/unit_tests/tasks/test_decorators.py
new file mode 100644
index 000000000000..e998fbeeadac
--- /dev/null
+++ b/tests/unit_tests/tasks/test_decorators.py
@@ -0,0 +1,477 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for task decorators"""
+
+from unittest.mock import MagicMock, patch
+from uuid import UUID
+
+import pytest
+from superset_core.api.tasks import TaskOptions, TaskScope
+
+from superset.commands.tasks.exceptions import GlobalTaskFrameworkDisabledError
+from superset.tasks.decorators import task, TaskWrapper
+from superset.tasks.registry import TaskRegistry
+
+TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
+
+
+class TestTaskDecoratorFeatureFlag:
+ """Tests for @task decorator feature flag behavior"""
+
+ def setup_method(self):
+ """Clear task registry before each test"""
+ TaskRegistry._tasks.clear()
+
+ @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
+ def test_decorator_succeeds_when_gtf_disabled(self, mock_feature_flag):
+ """Test that @task decorator can be applied even when GTF is disabled.
+
+ This enables safe module imports during app startup or Celery autodiscovery.
+ """
+
+ # Decoration should succeed - no error raised
+ @task(name="test_gtf_disabled_decorator")
+ def my_task() -> None:
+ pass
+
+ assert isinstance(my_task, TaskWrapper)
+ assert my_task.name == "test_gtf_disabled_decorator"
+
+ @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
+ def test_call_raises_error_when_gtf_disabled(self, mock_feature_flag):
+ """Test that calling a task raises GlobalTaskFrameworkDisabledError
+ when GTF is disabled."""
+
+ @task(name="test_gtf_disabled_call")
+ def my_task() -> None:
+ pass
+
+ with pytest.raises(GlobalTaskFrameworkDisabledError):
+ my_task()
+
+ @patch("superset.tasks.decorators.is_feature_enabled", return_value=False)
+ def test_schedule_raises_error_when_gtf_disabled(self, mock_feature_flag):
+ """Test that scheduling a task raises GlobalTaskFrameworkDisabledError
+ when GTF is disabled."""
+
+ @task(name="test_gtf_disabled_schedule")
+ def my_task() -> None:
+ pass
+
+ with pytest.raises(GlobalTaskFrameworkDisabledError):
+ my_task.schedule()
+
+
+class TestTaskDecorator:
+ """Tests for @task decorator"""
+
+ def test_decorator_basic(self):
+ """Test basic decorator usage without options"""
+
+ @task(name="test_task")
+ def my_task(arg1: int, arg2: str) -> None:
+ pass
+
+ assert isinstance(my_task, TaskWrapper)
+ assert my_task.name == "test_task"
+ assert my_task.scope == TaskScope.PRIVATE
+
+ def test_decorator_without_parentheses(self):
+ """Test decorator usage without parentheses"""
+
+ @task
+ def my_no_parens_task(arg1: int, arg2: str) -> None:
+ pass
+
+ assert isinstance(my_no_parens_task, TaskWrapper)
+ assert my_no_parens_task.name == "my_no_parens_task" # Uses function name
+ assert my_no_parens_task.scope == TaskScope.PRIVATE
+
+ def test_decorator_with_default_scope_private(self):
+ """Test decorator with explicit PRIVATE scope"""
+
+ @task(name="private_task", scope=TaskScope.PRIVATE)
+ def my_private_task(arg1: int) -> None:
+ pass
+
+ assert my_private_task.scope == TaskScope.PRIVATE
+
+ def test_decorator_with_default_scope_shared(self):
+ """Test decorator with SHARED scope"""
+
+ @task(name="shared_task", scope=TaskScope.SHARED)
+ def my_shared_task(arg1: int) -> None:
+ pass
+
+ assert my_shared_task.scope == TaskScope.SHARED
+
+ def test_decorator_with_default_scope_system(self):
+ """Test decorator with SYSTEM scope"""
+
+ @task(name="system_task", scope=TaskScope.SYSTEM)
+ def my_system_task() -> None:
+ pass
+
+ assert my_system_task.scope == TaskScope.SYSTEM
+
+ def test_decorator_forbids_ctx_parameter(self):
+ """Test decorator rejects functions with ctx parameter"""
+
+ with pytest.raises(TypeError, match="must not define 'ctx'"):
+
+ @task(name="bad_task")
+ def bad_task(ctx, arg1: int) -> None: # noqa: ARG001
+ pass
+
+ def test_decorator_forbids_options_parameter(self):
+ """Test decorator rejects functions with options parameter"""
+
+ with pytest.raises(TypeError, match="must not define.*'options'"):
+
+ @task(name="bad_task")
+ def bad_task(options, arg1: int) -> None: # noqa: ARG001
+ pass
+
+
+class TestTaskWrapperMergeOptions:
+ """Tests for TaskWrapper._merge_options()"""
+
+ def setup_method(self):
+ """Clear task registry before each test"""
+ TaskRegistry._tasks.clear()
+
+ def test_merge_options_no_override(self):
+ """Test merging with no override returns defaults"""
+
+ @task(name="test_merge_no_override_unique")
+ def merge_task_1() -> None:
+ pass
+
+ # Set default options for testing
+ merge_task_1.default_options = TaskOptions(
+ task_key="default_key",
+ task_name="Default Name",
+ )
+
+ merged = merge_task_1._merge_options(None)
+ assert merged.task_key == "default_key"
+ assert merged.task_name == "Default Name"
+
+ def test_merge_options_override_task_key(self):
+ """Test overriding task_key at call time"""
+
+ @task(name="test_merge_override_key_unique")
+ def merge_task_2() -> None:
+ pass
+
+ # Set default options for testing
+ merge_task_2.default_options = TaskOptions(task_key="default_key")
+
+ override = TaskOptions(task_key="override_key")
+ merged = merge_task_2._merge_options(override)
+ assert merged.task_key == "override_key"
+
+ def test_merge_options_override_task_name(self):
+ """Test overriding task_name at call time"""
+
+ @task(name="test_merge_override_name_unique")
+ def merge_task_3() -> None:
+ pass
+
+ # Set default options for testing
+ merge_task_3.default_options = TaskOptions(task_name="Default Name")
+
+ override = TaskOptions(task_name="Override Name")
+ merged = merge_task_3._merge_options(override)
+ assert merged.task_name == "Override Name"
+
+ def test_merge_options_override_all(self):
+ """Test overriding all options at call time"""
+
+ @task(name="test_merge_override_all_unique")
+ def merge_task_4() -> None:
+ pass
+
+ # Set default options for testing
+ merge_task_4.default_options = TaskOptions(
+ task_key="default_key",
+ task_name="Default Name",
+ )
+
+ override = TaskOptions(
+ task_key="override_key",
+ task_name="Override Name",
+ )
+ merged = merge_task_4._merge_options(override)
+ assert merged.task_key == "override_key"
+ assert merged.task_name == "Override Name"
+
+
+class TestTaskWrapperSchedule:
+ """Tests for TaskWrapper.schedule() with scope"""
+
+ def setup_method(self):
+ """Clear task registry before each test"""
+ TaskRegistry._tasks.clear()
+
+ @patch("superset.tasks.decorators.TaskManager.submit_task")
+ def test_schedule_uses_default_scope(self, mock_submit):
+ """Test schedule() uses decorator's default scope"""
+ mock_submit.return_value = MagicMock()
+
+ @task(name="test_schedule_default_unique", scope=TaskScope.SHARED)
+ def schedule_task_1(arg1: int) -> None:
+ pass
+
+ # Shared tasks require explicit task_key
+ schedule_task_1.schedule(123, options=TaskOptions(task_key="test_key"))
+
+ # Verify TaskManager.submit_task was called with correct scope
+ mock_submit.assert_called_once()
+ call_args = mock_submit.call_args
+ assert call_args[1]["scope"] == TaskScope.SHARED
+
+ @patch("superset.tasks.decorators.TaskManager.submit_task")
+ def test_schedule_uses_private_scope_by_default(self, mock_submit):
+ """Test schedule() uses PRIVATE scope when no scope specified"""
+ mock_submit.return_value = MagicMock()
+
+ @task(name="test_schedule_override_unique")
+ def schedule_task_2(arg1: int) -> None:
+ pass
+
+ schedule_task_2.schedule(123)
+
+ # Verify PRIVATE scope was used (default)
+ mock_submit.assert_called_once()
+ call_args = mock_submit.call_args
+ assert call_args[1]["scope"] == TaskScope.PRIVATE
+
+ @patch("superset.tasks.decorators.TaskManager.submit_task")
+ def test_schedule_with_custom_options(self, mock_submit):
+ """Test schedule() with custom task options"""
+ mock_submit.return_value = MagicMock()
+
+ @task(name="test_schedule_custom_unique", scope=TaskScope.SYSTEM)
+ def schedule_task_3(arg1: int) -> None:
+ pass
+
+ # Use custom task key and name
+ schedule_task_3.schedule(
+ 123,
+ options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
+ )
+
+ # Verify scope from decorator and options from call time
+ mock_submit.assert_called_once()
+ call_args = mock_submit.call_args
+ assert call_args[1]["scope"] == TaskScope.SYSTEM
+ assert call_args[1]["task_key"] == "custom_key"
+ assert call_args[1]["task_name"] == "Custom Task Name"
+
+ @patch("superset.tasks.decorators.TaskManager.submit_task")
+ def test_schedule_with_no_decorator_options(self, mock_submit):
+ """Test schedule() uses default PRIVATE scope when no options provided"""
+ mock_submit.return_value = MagicMock()
+
+ @task(name="test_schedule_no_options_unique")
+ def schedule_task_4(arg1: int) -> None:
+ pass
+
+ schedule_task_4.schedule(123)
+
+ # Verify default PRIVATE scope
+ mock_submit.assert_called_once()
+ call_args = mock_submit.call_args
+ assert call_args[1]["scope"] == TaskScope.PRIVATE
+
+ @patch("superset.tasks.decorators.TaskManager.submit_task")
+ def test_schedule_shared_task_requires_task_key(self, mock_submit):
+ """Test shared task schedule() requires explicit task_key"""
+
+ @task(name="test_shared_requires_key", scope=TaskScope.SHARED)
+ def shared_task(arg1: int) -> None:
+ pass
+
+ # Should raise ValueError when no task_key provided
+ with pytest.raises(
+ ValueError,
+ match="Shared task.*requires an explicit task_key.*for deduplication",
+ ):
+ shared_task.schedule(123)
+
+ # Should work with task_key provided
+ mock_submit.return_value = MagicMock()
+ shared_task.schedule(123, options=TaskOptions(task_key="valid_key"))
+ mock_submit.assert_called_once()
+
+ @patch("superset.tasks.decorators.TaskManager.submit_task")
+ def test_schedule_private_task_allows_no_task_key(self, mock_submit):
+ """Test private task schedule() works without task_key"""
+ mock_submit.return_value = MagicMock()
+
+ @task(name="test_private_no_key", scope=TaskScope.PRIVATE)
+ def private_task(arg1: int) -> None:
+ pass
+
+ # Should work without task_key (generates random UUID)
+ private_task.schedule(123)
+ mock_submit.assert_called_once()
+
+
+class TestTaskWrapperCall:
+ """Tests for TaskWrapper.__call__() with scope"""
+
+ def setup_method(self):
+ """Clear task registry before each test"""
+ TaskRegistry._tasks.clear()
+
+ @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
+ @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
+ @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
+ def test_call_uses_default_scope(
+ self, mock_submit_run_with_info, mock_find, mock_update_run
+ ):
+ """Test direct call uses decorator's default scope"""
+ mock_task = MagicMock()
+ mock_task.uuid = TEST_UUID
+ mock_task.status = "in_progress"
+ mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
+ mock_update_run.return_value = mock_task
+ mock_find.return_value = mock_task # Mock the subsequent find call
+
+ @task(name="test_call_default_unique", scope=TaskScope.SHARED)
+ def call_task_1(arg1: int) -> None:
+ pass
+
+ # Shared tasks require explicit task_key
+ call_task_1(123, options=TaskOptions(task_key="test_key"))
+
+ # Verify SubmitTaskCommand.run_with_info was called
+ mock_submit_run_with_info.assert_called_once()
+
+ @patch("superset.utils.core.get_user_id")
+ @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
+ @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
+ @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
+ def test_call_uses_private_scope_by_default(
+ self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
+ ):
+ """Test direct call uses PRIVATE scope when no scope specified"""
+ mock_get_user_id.return_value = 1
+ mock_task = MagicMock()
+ mock_task.uuid = TEST_UUID
+ mock_task.status = "in_progress"
+ mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
+ mock_update_run.return_value = mock_task
+ mock_find.return_value = mock_task # Mock the subsequent find call
+
+ @task(name="test_call_private_default_unique")
+ def call_task_2(arg1: int) -> None:
+ pass
+
+ call_task_2(123)
+
+ # Verify SubmitTaskCommand.run_with_info was called
+ mock_submit_run_with_info.assert_called_once()
+
+ @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
+ @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
+ @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
+ def test_call_with_custom_options(
+ self, mock_submit_run_with_info, mock_find, mock_update_run
+ ):
+ """Test direct call with custom task options"""
+ mock_task = MagicMock()
+ mock_task.uuid = TEST_UUID
+ mock_task.status = "in_progress"
+ mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
+ mock_update_run.return_value = mock_task
+ mock_find.return_value = mock_task # Mock the subsequent find call
+
+ @task(name="test_call_custom_unique", scope=TaskScope.SYSTEM)
+ def call_task_3(arg1: int) -> None:
+ pass
+
+ # Use custom task key and name
+ call_task_3(
+ 123,
+ options=TaskOptions(task_key="custom_key", task_name="Custom Task Name"),
+ )
+
+ # Verify SubmitTaskCommand.run_with_info was called
+ mock_submit_run_with_info.assert_called_once()
+
+ def test_call_shared_task_requires_task_key(self):
+ """Test shared task direct call requires explicit task_key"""
+
+ @task(name="test_shared_call_requires_key", scope=TaskScope.SHARED)
+ def shared_task(arg1: int) -> None:
+ pass
+
+ # Should raise ValueError when no task_key provided
+ with pytest.raises(
+ ValueError,
+ match="Shared task.*requires an explicit task_key.*for deduplication",
+ ):
+ shared_task(123)
+
+ @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
+ @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
+ @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
+ def test_call_shared_task_works_with_task_key(
+ self, mock_submit_run_with_info, mock_find, mock_update_run
+ ):
+ """Test shared task direct call works with task_key"""
+ mock_task = MagicMock()
+ mock_task.uuid = TEST_UUID
+ mock_task.status = "in_progress"
+ mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
+ mock_update_run.return_value = mock_task
+ mock_find.return_value = mock_task
+
+ @task(name="test_shared_call_with_key", scope=TaskScope.SHARED)
+ def shared_task(arg1: int) -> None:
+ pass
+
+ # Should work with task_key provided
+ shared_task(123, options=TaskOptions(task_key="valid_key"))
+ mock_submit_run_with_info.assert_called_once()
+
+ @patch("superset.utils.core.get_user_id")
+ @patch("superset.commands.tasks.update.UpdateTaskCommand.run")
+ @patch("superset.daos.tasks.TaskDAO.find_one_or_none")
+ @patch("superset.commands.tasks.submit.SubmitTaskCommand.run_with_info")
+ def test_call_private_task_allows_no_task_key(
+ self, mock_submit_run_with_info, mock_find, mock_update_run, mock_get_user_id
+ ):
+ """Test private task direct call works without task_key"""
+ mock_get_user_id.return_value = 1
+ mock_task = MagicMock()
+ mock_task.uuid = TEST_UUID
+ mock_task.status = "in_progress"
+ mock_submit_run_with_info.return_value = (mock_task, True) # (task, is_new)
+ mock_update_run.return_value = mock_task
+ mock_find.return_value = mock_task
+
+ @task(name="test_private_call_no_key", scope=TaskScope.PRIVATE)
+ def private_task(arg1: int) -> None:
+ pass
+
+ # Should work without task_key (generates random UUID)
+ private_task(123)
+ mock_submit_run_with_info.assert_called_once()
diff --git a/tests/unit_tests/tasks/test_handlers.py b/tests/unit_tests/tasks/test_handlers.py
new file mode 100644
index 000000000000..1da4b4da93a8
--- /dev/null
+++ b/tests/unit_tests/tasks/test_handlers.py
@@ -0,0 +1,677 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for GTF handlers (abort, cleanup) and related Task model behavior."""
+
+import time
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, Mock, patch
+from uuid import UUID
+
+import pytest
+from freezegun import freeze_time
+from superset_core.api.tasks import TaskStatus
+
+from superset.tasks.context import TaskContext
+
+TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
+
+
+@pytest.fixture
+def mock_task():
+ """Create a mock task for testing."""
+ task = MagicMock()
+ task.uuid = TEST_UUID
+ task.status = TaskStatus.PENDING.value
+ return task
+
+
+@pytest.fixture
+def mock_task_dao(mock_task):
+ """Mock TaskDAO to return our test task."""
+ with patch("superset.daos.tasks.TaskDAO") as mock_dao:
+ mock_dao.find_one_or_none.return_value = mock_task
+ yield mock_dao
+
+
+@pytest.fixture
+def mock_update_command():
+ """Mock UpdateTaskCommand to avoid database operations."""
+ with patch("superset.commands.tasks.update.UpdateTaskCommand") as mock_cmd:
+ mock_cmd.return_value.run.return_value = None
+ yield mock_cmd
+
+
+@pytest.fixture
+def mock_flask_app():
+ """Create a properly configured mock Flask app."""
+ mock_app = MagicMock()
+ mock_app.config = {
+ "TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
+ }
+ # Make app_context() return a proper context manager
+ mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
+ mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
+ # Use regular Mock (not MagicMock) for _get_current_object to avoid
+ # AsyncMockMixin creating unawaited coroutines in Python 3.10+
+ mock_app._get_current_object = Mock(return_value=mock_app)
+ return mock_app
+
+
+@pytest.fixture
+def task_context(mock_task, mock_task_dao, mock_update_command, mock_flask_app):
+ """Create TaskContext with mocked dependencies."""
+ # Ensure mock_task has properties_dict and payload_dict (TaskContext accesses them)
+ mock_task.properties_dict = {"is_abortable": False}
+ mock_task.payload_dict = {}
+
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ # Configure current_app mock
+ mock_current_app.config = mock_flask_app.config
+ # Use regular Mock (not MagicMock) for _get_current_object to avoid
+ # AsyncMockMixin creating unawaited coroutines in Python 3.10+
+ mock_current_app._get_current_object = Mock(return_value=mock_flask_app)
+
+ ctx = TaskContext(mock_task)
+
+ yield ctx
+
+ # Cleanup: stop polling if started
+ if ctx._abort_listener:
+ ctx.stop_abort_polling()
+
+
+class TestTaskStatusEnum:
+ """Test TaskStatus enum values."""
+
+ def test_aborting_status_exists(self):
+ """Test that ABORTING status is defined."""
+ assert hasattr(TaskStatus, "ABORTING")
+ assert TaskStatus.ABORTING.value == "aborting"
+
+ def test_all_statuses_present(self):
+ """Test all expected statuses are present."""
+ expected_statuses = [
+ "pending",
+ "in_progress",
+ "success",
+ "failure",
+ "aborting",
+ "aborted",
+ ]
+ actual_statuses = [s.value for s in TaskStatus]
+
+ for status in expected_statuses:
+ assert status in actual_statuses, f"Missing status: {status}"
+
+
+class TestTaskAbortProperties:
+ """Test Task model abort-related properties via status and properties accessor."""
+
+ def test_aborting_status(self):
+ """Test ABORTING status check."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.status = TaskStatus.ABORTING.value
+
+ assert task.status == TaskStatus.ABORTING.value
+
+ def test_is_abortable_in_properties(self):
+ """Test is_abortable is accessible via properties."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.update_properties({"is_abortable": True})
+
+ assert task.properties_dict.get("is_abortable") is True
+
+ def test_is_abortable_default_none(self):
+ """Test is_abortable defaults to None for new tasks."""
+ from superset.models.tasks import Task
+
+ task = Task()
+
+ assert task.properties_dict.get("is_abortable") is None
+
+
+class TestTaskSetStatus:
+ """Test Task.set_status behavior for abort states."""
+
+ def test_set_status_in_progress_sets_is_abortable_false(self):
+ """Test that transitioning to IN_PROGRESS sets is_abortable to False."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.uuid = "test-uuid"
+ # Default is None
+
+ task.set_status(TaskStatus.IN_PROGRESS)
+
+ assert task.properties_dict.get("is_abortable") is False
+ assert task.started_at is not None
+
+ def test_set_status_in_progress_preserves_existing_is_abortable(self):
+ """Test that re-setting IN_PROGRESS doesn't override is_abortable."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.uuid = "test-uuid"
+ task.update_properties(
+ {"is_abortable": True}
+ ) # Already set by handler registration
+ task.started_at = datetime.now(timezone.utc) # Already started
+
+ task.set_status(TaskStatus.IN_PROGRESS)
+
+ # Should not override since started_at is already set
+ assert task.properties_dict.get("is_abortable") is True
+
+ def test_set_status_aborting_does_not_set_ended_at(self):
+ """Test that ABORTING status does not set ended_at."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.uuid = "test-uuid"
+ task.started_at = datetime.now(timezone.utc)
+
+ task.set_status(TaskStatus.ABORTING)
+
+ assert task.ended_at is None
+
+ def test_set_status_aborted_sets_ended_at(self):
+ """Test that ABORTED status sets ended_at."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.uuid = "test-uuid"
+ task.started_at = datetime.now(timezone.utc)
+
+ task.set_status(TaskStatus.ABORTED)
+
+ assert task.ended_at is not None
+
+
+class TestTaskDuration:
+ """Test Task duration_seconds property with different states."""
+
+ def test_duration_seconds_finished_task(self):
+ """Test duration for finished task returns actual duration."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.status = TaskStatus.SUCCESS.value # Must be finished to use ended_at
+ task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
+ task.ended_at = datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
+
+ # Should use ended_at - started_at = 30 seconds
+ assert task.duration_seconds == 30.0
+
+ @freeze_time("2024-01-01 10:00:30")
+ def test_duration_seconds_running_task(self):
+ """Test duration for running task returns time since start."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.started_at = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
+ task.ended_at = None
+
+ # 30 seconds since start
+ assert task.duration_seconds == 30.0
+
+ @freeze_time("2024-01-01 10:00:15")
+ def test_duration_seconds_pending_task(self):
+ """Test duration for pending task returns queue time."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
+ task.started_at = None
+ task.ended_at = None
+
+ # 15 seconds since creation
+ assert task.duration_seconds == 15.0
+
+ def test_duration_seconds_no_timestamps(self):
+ """Test duration returns None when no timestamps available."""
+ from superset.models.tasks import Task
+
+ task = Task()
+ task.created_on = None
+ task.started_at = None
+ task.ended_at = None
+
+ assert task.duration_seconds is None
+
+
+class TestAbortHandlerRegistration:
+ """Test abort handler registration and is_abortable flag."""
+
+ def test_on_abort_registers_handler(self, task_context):
+ """Test that on_abort registers a handler."""
+ handler_called = False
+
+ @task_context.on_abort
+ def handle_abort():
+ nonlocal handler_called
+ handler_called = True
+
+ assert len(task_context._abort_handlers) == 1
+ assert not handler_called
+
+ @patch("superset.tasks.context.current_app")
+ def test_on_abort_sets_abortable(self, mock_app):
+ """Test on_abort sets is_abortable to True on first handler."""
+ mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
+ mock_app._get_current_object = Mock(return_value=mock_app)
+ mock_task = MagicMock()
+ mock_task.uuid = TEST_UUID
+ mock_task.properties_dict = {"is_abortable": False}
+ mock_task.payload_dict = {}
+
+ with (
+ patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
+ patch.object(TaskContext, "start_abort_polling"),
+ ):
+ ctx = TaskContext(mock_task)
+
+ @ctx.on_abort
+ def handler():
+ pass
+
+ mock_set_abortable.assert_called_once()
+
+ @patch("superset.tasks.context.current_app")
+ def test_on_abort_only_sets_abortable_once(self, mock_app):
+ """Test on_abort only calls _set_abortable for first handler."""
+ mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 1.0}
+ mock_app._get_current_object = Mock(return_value=mock_app)
+ mock_task = MagicMock()
+ mock_task.uuid = TEST_UUID
+ mock_task.properties_dict = {"is_abortable": False}
+ mock_task.payload_dict = {}
+
+ with (
+ patch.object(TaskContext, "_set_abortable") as mock_set_abortable,
+ patch.object(TaskContext, "start_abort_polling"),
+ ):
+ ctx = TaskContext(mock_task)
+
+ @ctx.on_abort
+ def handler1():
+ pass
+
+ @ctx.on_abort
+ def handler2():
+ pass
+
+ # Should only be called once for first handler
+ assert mock_set_abortable.call_count == 1
+
+ def test_abort_handlers_completed_initially_false(self):
+ """Test abort_handlers_completed is False initially."""
+ mock_task = MagicMock()
+ mock_task.uuid = TEST_UUID
+ mock_task.properties_dict = {}
+ mock_task.payload_dict = {}
+
+ with patch("superset.tasks.context.current_app") as mock_app:
+ mock_app._get_current_object = Mock(return_value=mock_app)
+ ctx = TaskContext(mock_task)
+ assert ctx.abort_handlers_completed is False
+
+
+class TestAbortPolling:
+ """Test abort detection polling behavior."""
+
+ def test_on_abort_starts_polling_automatically(self, task_context):
+ """Test that registering first handler starts abort listener."""
+ assert task_context._abort_listener is None
+
+ @task_context.on_abort
+ def handle_abort():
+ pass
+
+ assert task_context._abort_listener is not None
+
+ def test_stop_abort_polling(self, task_context):
+ """Test that stop_abort_polling stops the abort listener."""
+
+ @task_context.on_abort
+ def handle_abort():
+ pass
+
+ assert task_context._abort_listener is not None
+
+ task_context.stop_abort_polling()
+
+ assert task_context._abort_listener is None
+
+ def test_start_abort_polling_only_once(self, task_context):
+ """Test that start_abort_polling is idempotent."""
+ task_context.start_abort_polling(interval=0.1)
+ first_listener = task_context._abort_listener
+
+ # Try to start again
+ task_context.start_abort_polling(interval=0.1)
+ second_listener = task_context._abort_listener
+
+ # Should be the same listener
+ assert first_listener is second_listener
+
+ def test_on_abort_with_custom_interval(self, task_context):
+ """Test that custom interval can be set via start_abort_polling."""
+ with patch("superset.tasks.context.current_app") as mock_app:
+ mock_app.config = {"TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1}
+ mock_app._get_current_object = Mock(return_value=mock_app)
+
+ @task_context.on_abort
+ def handle_abort():
+ pass
+
+ # Override with custom interval
+ task_context.stop_abort_polling()
+ task_context.start_abort_polling(interval=0.05)
+
+ assert task_context._abort_listener is not None
+
+ def test_polling_detects_abort(self, task_context, mock_task):
+ """Test that the polling loop detects an abort and sets the detection flag."""
+
+ @task_context.on_abort
+ def handle_abort():
+ pass
+
+ # Trigger abort
+ mock_task.status = TaskStatus.ABORTED.value
+
+ # Wait for detection
+ time.sleep(0.3)
+
+ # Abort should have been detected
+ assert task_context._abort_detected is True
+
+
+class TestAbortHandlerExecution:
+ """Test abort handler execution behavior."""
+
+ def test_on_abort_handler_fires_when_task_aborted(self, task_context, mock_task):
+ """Test that abort handler fires automatically when task is aborted."""
+ abort_called = False
+
+ @task_context.on_abort
+ def handle_abort():
+ nonlocal abort_called
+ abort_called = True
+
+ # Simulate task being aborted
+ mock_task.status = TaskStatus.ABORTED.value
+
+ # Wait for polling to detect abort (max 0.3s with 0.1s interval)
+ time.sleep(0.3)
+
+ assert abort_called
+ assert task_context._abort_detected
+
+ def test_on_abort_not_called_on_success(self, task_context, mock_task):
+ """Test that abort handlers don't run on success."""
+ abort_called = False
+
+ @task_context.on_abort
+ def handle_abort():
+ nonlocal abort_called
+ abort_called = True
+
+ # Keep task in success state
+ mock_task.status = TaskStatus.SUCCESS.value
+
+ # Wait and verify handler not called
+ time.sleep(0.3)
+
+ assert not abort_called
+
+ def test_multiple_abort_handlers(self, task_context, mock_task):
+ """Test that all abort handlers execute in LIFO order."""
+ calls = []
+
+ @task_context.on_abort
+ def handler1():
+ calls.append(1)
+
+ @task_context.on_abort
+ def handler2():
+ calls.append(2)
+
+ # Trigger abort
+ mock_task.status = TaskStatus.ABORTED.value
+
+ # Wait for detection
+ time.sleep(0.3)
+
+ # LIFO order: handler2 runs first
+ assert calls == [2, 1]
+
+ def test_abort_handler_exception_doesnt_fail_task(self, task_context, mock_task):
+ """Test that exception in abort handler is logged but doesn't fail task."""
+ handler2_called = False
+
+ @task_context.on_abort
+ def bad_handler():
+ raise ValueError("Handler error")
+
+ @task_context.on_abort
+ def good_handler():
+ nonlocal handler2_called
+ handler2_called = True
+
+ # Trigger abort
+ mock_task.status = TaskStatus.ABORTED.value
+
+ # Wait for detection
+ time.sleep(0.3)
+
+ # Second handler should still run despite first handler failing
+ assert handler2_called
+
+
+class TestBestEffortHandlerExecution:
+ """Test that all handlers execute even when some fail (best-effort)."""
+
+ def test_all_abort_handlers_run_even_if_all_fail(self, task_context, mock_task):
+ """Test all abort handlers execute even if every one raises an exception."""
+ calls = []
+
+ @task_context.on_abort
+ def handler1():
+ calls.append(1)
+ raise ValueError("Handler 1 failed")
+
+ @task_context.on_abort
+ def handler2():
+ calls.append(2)
+ raise RuntimeError("Handler 2 failed")
+
+ @task_context.on_abort
+ def handler3():
+ calls.append(3)
+ raise TypeError("Handler 3 failed")
+
+ # Trigger abort handlers directly (simulating abort detection)
+ task_context._trigger_abort_handlers()
+
+ # All handlers should have been called (LIFO order: 3, 2, 1)
+ assert calls == [3, 2, 1]
+
+ # Failures should be collected (abort handlers don't write to DB)
+ assert len(task_context._handler_failures) == 3
+ failure_types = [
+ type(ex).__name__ for _, ex, _ in task_context._handler_failures
+ ]
+ assert "TypeError" in failure_types
+ assert "RuntimeError" in failure_types
+ assert "ValueError" in failure_types
+
+ def test_all_cleanup_handlers_run_even_if_all_fail(self, task_context, mock_task):
+ """Test all cleanup handlers execute even if every one raises an exception."""
+ calls = []
+ captured_failures = []
+
+ # Mock _write_handler_failures_to_db to capture failures before clearing
+ original_write = task_context._write_handler_failures_to_db
+
+ def mock_write():
+ captured_failures.extend(task_context._handler_failures)
+ original_write()
+
+ task_context._write_handler_failures_to_db = mock_write
+
+ @task_context.on_cleanup
+ def cleanup1():
+ calls.append(1)
+ raise ValueError("Cleanup 1 failed")
+
+ @task_context.on_cleanup
+ def cleanup2():
+ calls.append(2)
+ raise RuntimeError("Cleanup 2 failed")
+
+ @task_context.on_cleanup
+ def cleanup3():
+ calls.append(3)
+ raise TypeError("Cleanup 3 failed")
+
+ # Set task to SUCCESS (not aborting) so only cleanup handlers run
+ mock_task.status = TaskStatus.SUCCESS.value
+
+ # Run cleanup
+ task_context._run_cleanup()
+
+ # All handlers should have been called (LIFO order: 3, 2, 1)
+ assert calls == [3, 2, 1]
+
+ # Failures should have been captured before clearing
+ assert len(captured_failures) == 3
+ failure_types = [type(ex).__name__ for _, ex, _ in captured_failures]
+ assert "TypeError" in failure_types
+ assert "RuntimeError" in failure_types
+ assert "ValueError" in failure_types
+
+ def test_mixed_abort_and_cleanup_failures_all_collected(
+ self, task_context, mock_task
+ ):
+ """Test abort and cleanup handler failures are collected together."""
+ calls = []
+ captured_failures = []
+
+ # Mock _write_handler_failures_to_db to capture failures before clearing
+ original_write = task_context._write_handler_failures_to_db
+
+ def mock_write():
+ captured_failures.extend(task_context._handler_failures)
+ original_write()
+
+ task_context._write_handler_failures_to_db = mock_write
+
+ @task_context.on_abort
+ def abort1():
+ calls.append("abort1")
+ raise ValueError("Abort 1 failed")
+
+ @task_context.on_abort
+ def abort2():
+ calls.append("abort2")
+ raise RuntimeError("Abort 2 failed")
+
+ @task_context.on_cleanup
+ def cleanup1():
+ calls.append("cleanup1")
+ raise TypeError("Cleanup 1 failed")
+
+ @task_context.on_cleanup
+ def cleanup2():
+ calls.append("cleanup2")
+ raise KeyError("Cleanup 2 failed")
+
+ # Set task to ABORTING so both abort and cleanup handlers run
+ mock_task.status = TaskStatus.ABORTING.value
+
+ # Run cleanup (which triggers abort handlers first, then cleanup handlers)
+ task_context._run_cleanup()
+
+ # All handlers should have been called
+ # Abort handlers run first (LIFO: abort2, abort1)
+ # Then cleanup handlers (LIFO: cleanup2, cleanup1)
+ assert calls == ["abort2", "abort1", "cleanup2", "cleanup1"]
+
+ # All 4 failures should have been captured
+ assert len(captured_failures) == 4
+
+ # Verify handler types are recorded correctly
+ handler_types = [htype for htype, _, _ in captured_failures]
+ assert handler_types.count("abort") == 2
+ assert handler_types.count("cleanup") == 2
+
+
+class TestCleanupHandlers:
+ """Test cleanup handler behavior."""
+
+ def test_cleanup_triggers_abort_handlers_if_not_detected(
+ self, task_context, mock_task
+ ):
+ """Test that _run_cleanup triggers abort handlers if task ended aborted."""
+ abort_called = False
+
+ @task_context.on_abort
+ def handle_abort():
+ nonlocal abort_called
+ abort_called = True
+
+ # Set task as aborted but don't let polling detect it
+ mock_task.status = TaskStatus.ABORTED.value
+ task_context._abort_detected = False
+
+ # Immediately run cleanup (simulating task ending before poll)
+ task_context._run_cleanup()
+
+ assert abort_called
+
+ def test_cleanup_doesnt_duplicate_abort_handlers(self, task_context, mock_task):
+ """Test that abort handlers only run once even if called from cleanup."""
+ call_count = 0
+
+ @task_context.on_abort
+ def handle_abort():
+ nonlocal call_count
+ call_count += 1
+
+ # Trigger abort via polling
+ mock_task.status = TaskStatus.ABORTED.value
+ time.sleep(0.3)
+
+ # Handlers should have been called once
+ assert call_count == 1
+ assert task_context._abort_detected is True
+
+ # Run cleanup - handlers should NOT be called again
+ task_context._run_cleanup()
+
+ assert call_count == 1 # Still 1, not 2
diff --git a/tests/unit_tests/tasks/test_manager.py b/tests/unit_tests/tasks/test_manager.py
new file mode 100644
index 000000000000..13997a7f113e
--- /dev/null
+++ b/tests/unit_tests/tasks/test_manager.py
@@ -0,0 +1,464 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for TaskManager pub/sub functionality"""
+
+import threading
+import time
+from unittest.mock import MagicMock, patch
+
+import pytest
+import redis
+
+from superset.tasks.manager import AbortListener, TaskManager
+
+
+class TestAbortListener:
+ """Tests for AbortListener class"""
+
+ def test_stop_sets_event(self):
+ """Test that stop() sets the stop event"""
+ stop_event = threading.Event()
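+ # Simulate a listener thread that has already exited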
+ thread = MagicMock(spec=threading.Thread)
+ thread.is_alive.return_value = False
+
+ listener = AbortListener("test-uuid", thread, stop_event)
+
+ assert not stop_event.is_set()
+ listener.stop()
+ assert stop_event.is_set()
+
+ def test_stop_closes_pubsub(self):
+ """Test that stop() closes the pub/sub connection"""
+ stop_event = threading.Event()
+ thread = MagicMock(spec=threading.Thread)
+ thread.is_alive.return_value = False
+ pubsub = MagicMock()
+
+ listener = AbortListener("test-uuid", thread, stop_event, pubsub)
+ listener.stop()
+
+ pubsub.unsubscribe.assert_called_once()
+ pubsub.close.assert_called_once()
+
+ def test_stop_joins_thread(self):
+ """Test that stop() joins the listener thread"""
+ stop_event = threading.Event()
+ thread = MagicMock(spec=threading.Thread)
+ thread.is_alive.return_value = True
+
+ listener = AbortListener("test-uuid", thread, stop_event)
+ listener.stop()
+
+ thread.join.assert_called_once_with(timeout=2.0)
+
+
+class TestTaskManagerInitApp:
+ """Tests for TaskManager.init_app()"""
+
+ def setup_method(self):
+ """Reset TaskManager state before each test"""
+ TaskManager._initialized = False
+ TaskManager._channel_prefix = "gtf:abort:"
+ TaskManager._completion_channel_prefix = "gtf:complete:"
+
+ def teardown_method(self):
+ """Reset TaskManager state after each test"""
+ TaskManager._initialized = False
+ TaskManager._channel_prefix = "gtf:abort:"
+ TaskManager._completion_channel_prefix = "gtf:complete:"
+
+ def test_init_app_sets_channel_prefixes(self):
+ """Test init_app reads channel prefixes from config"""
+ app = MagicMock()
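+ # Emulate Flask's app.config.get(key, default) with a dict-backed lambda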
+ app.config.get.side_effect = lambda key, default=None: {
+ "TASKS_ABORT_CHANNEL_PREFIX": "custom:abort:",
+ "TASKS_COMPLETION_CHANNEL_PREFIX": "custom:complete:",
+ }.get(key, default)
+
+ TaskManager.init_app(app)
+
+ assert TaskManager._initialized is True
+ assert TaskManager._channel_prefix == "custom:abort:"
+ assert TaskManager._completion_channel_prefix == "custom:complete:"
+
+ def test_init_app_skips_if_already_initialized(self):
+ """Test init_app is idempotent"""
+ TaskManager._initialized = True
+
+ app = MagicMock()
+ TaskManager.init_app(app)
+
+ # Should not call app.config.get since already initialized
+ app.config.get.assert_not_called()
+
+
+class TestTaskManagerPubSub:
+ """Tests for TaskManager pub/sub methods"""
+
+ def setup_method(self):
+ """Reset TaskManager state before each test"""
+ TaskManager._initialized = False
+ TaskManager._channel_prefix = "gtf:abort:"
+ TaskManager._completion_channel_prefix = "gtf:complete:"
+
+ def teardown_method(self):
+ """Reset TaskManager state after each test"""
+ TaskManager._initialized = False
+ TaskManager._channel_prefix = "gtf:abort:"
+ TaskManager._completion_channel_prefix = "gtf:complete:"
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_is_pubsub_available_no_redis(self, mock_cache_manager):
+ """Test is_pubsub_available returns False when Redis not configured"""
+ mock_cache_manager.signal_cache = None
+ assert TaskManager.is_pubsub_available() is False
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_is_pubsub_available_with_redis(self, mock_cache_manager):
+ """Test is_pubsub_available returns True when Redis is configured"""
+ mock_cache_manager.signal_cache = MagicMock()
+ assert TaskManager.is_pubsub_available() is True
+
+ def test_get_abort_channel(self):
+ """Test get_abort_channel returns correct channel name"""
+ task_uuid = "abc-123-def-456"
+ channel = TaskManager.get_abort_channel(task_uuid)
+ assert channel == "gtf:abort:abc-123-def-456"
+
+ def test_get_abort_channel_custom_prefix(self):
+ """Test get_abort_channel with custom prefix"""
+ TaskManager._channel_prefix = "custom:prefix:"
+ task_uuid = "test-uuid"
+ channel = TaskManager.get_abort_channel(task_uuid)
+ assert channel == "custom:prefix:test-uuid"
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_publish_abort_no_redis(self, mock_cache_manager):
+ """Test publish_abort returns False when Redis not available"""
+ mock_cache_manager.signal_cache = None
+ result = TaskManager.publish_abort("test-uuid")
+ assert result is False
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_publish_abort_success(self, mock_cache_manager):
+ """Test publish_abort publishes message successfully"""
+ mock_redis = MagicMock()
+ mock_redis.publish.return_value = 1 # One subscriber
+ mock_cache_manager.signal_cache = mock_redis
+
+ result = TaskManager.publish_abort("test-uuid")
+
+ assert result is True
+ mock_redis.publish.assert_called_once_with("gtf:abort:test-uuid", "abort")
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_publish_abort_redis_error(self, mock_cache_manager):
+ """Test publish_abort handles Redis errors gracefully"""
+ mock_redis = MagicMock()
+ mock_redis.publish.side_effect = redis.RedisError("Connection lost")
+ mock_cache_manager.signal_cache = mock_redis
+
+ result = TaskManager.publish_abort("test-uuid")
+
+ assert result is False
+
+
+class TestTaskManagerListenForAbort:
+ """Tests for TaskManager.listen_for_abort()"""
+
+ def setup_method(self):
+ """Reset TaskManager state before each test"""
+ TaskManager._initialized = False
+ TaskManager._channel_prefix = "gtf:abort:"
+ TaskManager._completion_channel_prefix = "gtf:complete:"
+
+ def teardown_method(self):
+ """Reset TaskManager state after each test"""
+ TaskManager._initialized = False
+ TaskManager._channel_prefix = "gtf:abort:"
+ TaskManager._completion_channel_prefix = "gtf:complete:"
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_listen_for_abort_no_redis_uses_polling(self, mock_cache_manager):
+ """Test listen_for_abort falls back to polling when Redis unavailable"""
+ mock_cache_manager.signal_cache = None
+ callback = MagicMock()
+
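+ # Stub the poll loop so the background listener thread exits immediately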
+ with patch.object(TaskManager, "_poll_for_abort", return_value=None):
+ listener = TaskManager.listen_for_abort(
+ task_uuid="test-uuid",
+ callback=callback,
+ poll_interval=1.0,
+ app=None,
+ )
+
+ # Give thread time to start
+ time.sleep(0.1)
+ listener.stop()
+
+ # Should use polling since no Redis
+ assert listener._pubsub is None
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_listen_for_abort_with_redis_uses_pubsub(self, mock_cache_manager):
+ """Test listen_for_abort uses pub/sub when Redis available"""
+ mock_redis = MagicMock()
+ mock_pubsub = MagicMock()
+ mock_redis.pubsub.return_value = mock_pubsub
+ mock_cache_manager.signal_cache = mock_redis
+
+ callback = MagicMock()
+
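+ # Stub the pub/sub listen loop; this test only verifies the subscription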
+ with patch.object(TaskManager, "_listen_pubsub", return_value=None):
+ listener = TaskManager.listen_for_abort(
+ task_uuid="test-uuid",
+ callback=callback,
+ poll_interval=1.0,
+ app=None,
+ )
+
+ # Give thread time to start
+ time.sleep(0.1)
+ listener.stop()
+
+ # Should subscribe to channel
+ mock_pubsub.subscribe.assert_called_once_with("gtf:abort:test-uuid")
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_listen_for_abort_redis_subscribe_failure_raises(self, mock_cache_manager):
+ """Test listen_for_abort raises exception on subscribe failure
+ when Redis configured"""
+
+ mock_redis = MagicMock()
+ mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
+ mock_cache_manager.signal_cache = mock_redis
+
+ callback = MagicMock()
+
+ # With fail-fast behavior, Redis subscribe failure raises exception
+ with pytest.raises(redis.RedisError, match="Connection failed"):
+ TaskManager.listen_for_abort(
+ task_uuid="test-uuid",
+ callback=callback,
+ poll_interval=1.0,
+ app=None,
+ )
+
+
+class TestTaskManagerCompletion:
+ """Tests for TaskManager completion pub/sub and wait_for_completion"""
+
+ def setup_method(self):
+ """Reset TaskManager state before each test"""
+ TaskManager._initialized = False
+ TaskManager._channel_prefix = "gtf:abort:"
+ TaskManager._completion_channel_prefix = "gtf:complete:"
+
+ def teardown_method(self):
+ """Reset TaskManager state after each test"""
+ TaskManager._initialized = False
+ TaskManager._channel_prefix = "gtf:abort:"
+ TaskManager._completion_channel_prefix = "gtf:complete:"
+
+ def test_get_completion_channel(self):
+ """Test get_completion_channel returns correct channel name"""
+ task_uuid = "abc-123-def-456"
+ channel = TaskManager.get_completion_channel(task_uuid)
+ assert channel == "gtf:complete:abc-123-def-456"
+
+ def test_get_completion_channel_custom_prefix(self):
+ """Test get_completion_channel with custom prefix"""
+ TaskManager._completion_channel_prefix = "custom:complete:"
+ task_uuid = "test-uuid"
+ channel = TaskManager.get_completion_channel(task_uuid)
+ assert channel == "custom:complete:test-uuid"
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_publish_completion_no_redis(self, mock_cache_manager):
+ """Test publish_completion returns False when Redis not available"""
+ mock_cache_manager.signal_cache = None
+ result = TaskManager.publish_completion("test-uuid", "success")
+ assert result is False
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_publish_completion_success(self, mock_cache_manager):
+ """Test publish_completion publishes message successfully"""
+ mock_redis = MagicMock()
+ mock_redis.publish.return_value = 1 # One subscriber
+ mock_cache_manager.signal_cache = mock_redis
+
+ result = TaskManager.publish_completion("test-uuid", "success")
+
+ assert result is True
+ mock_redis.publish.assert_called_once_with("gtf:complete:test-uuid", "success")
+
+ @patch("superset.tasks.manager.cache_manager")
+ def test_publish_completion_redis_error(self, mock_cache_manager):
+ """Test publish_completion handles Redis errors gracefully"""
+ mock_redis = MagicMock()
+ mock_redis.publish.side_effect = redis.RedisError("Connection lost")
+ mock_cache_manager.signal_cache = mock_redis
+
+ result = TaskManager.publish_completion("test-uuid", "success")
+
+ assert result is False
+
+ @patch("superset.tasks.manager.cache_manager")
+ @patch("superset.daos.tasks.TaskDAO")
+ def test_wait_for_completion_task_not_found(self, mock_dao, mock_cache_manager):
+ """Test wait_for_completion raises ValueError for missing task"""
+
+ mock_cache_manager.signal_cache = None
+ mock_dao.find_one_or_none.return_value = None
+
+ with pytest.raises(ValueError, match="not found"):
+ TaskManager.wait_for_completion("nonexistent-uuid")
+
+ @patch("superset.tasks.manager.cache_manager")
+ @patch("superset.daos.tasks.TaskDAO")
+ def test_wait_for_completion_already_complete(self, mock_dao, mock_cache_manager):
+ """Test wait_for_completion returns immediately for terminal state"""
+ mock_cache_manager.signal_cache = None
+ mock_task = MagicMock()
+ mock_task.uuid = "test-uuid"
+ mock_task.status = "success"
+ mock_dao.find_one_or_none.return_value = mock_task
+
+ result = TaskManager.wait_for_completion("test-uuid")
+
+ assert result == mock_task
+ # Should only call find_one_or_none once (initial check)
+ mock_dao.find_one_or_none.assert_called_once()
+
+ @patch("superset.tasks.manager.cache_manager")
+ @patch("superset.daos.tasks.TaskDAO")
+ def test_wait_for_completion_timeout(self, mock_dao, mock_cache_manager):
+ """Test wait_for_completion raises TimeoutError when timeout expires"""
+
+ mock_cache_manager.signal_cache = None
+ mock_task = MagicMock()
+ mock_task.uuid = "test-uuid"
+ mock_task.status = "in_progress" # Never completes
+ mock_dao.find_one_or_none.return_value = mock_task
+
+ with pytest.raises(TimeoutError, match="Timeout waiting"):
+ TaskManager.wait_for_completion("test-uuid", timeout=0.1)
+
+ @patch("superset.tasks.manager.cache_manager")
+ @patch("superset.daos.tasks.TaskDAO")
+ def test_wait_for_completion_polling_success(self, mock_dao, mock_cache_manager):
+ """Test wait_for_completion returns when task completes via polling"""
+ mock_cache_manager.signal_cache = None
+ mock_task_pending = MagicMock()
+ mock_task_pending.uuid = "test-uuid"
+ mock_task_pending.status = "pending"
+
+ mock_task_complete = MagicMock()
+ mock_task_complete.uuid = "test-uuid"
+ mock_task_complete.status = "success"
+
+ # First call returns pending, second returns complete
+ mock_dao.find_one_or_none.side_effect = [
+ mock_task_pending,
+ mock_task_complete,
+ ]
+
+ result = TaskManager.wait_for_completion(
+ "test-uuid",
+ timeout=5.0,
+ poll_interval=0.1,
+ )
+
+ assert result.status == "success"
+
+ @patch("superset.tasks.manager.cache_manager")
+ @patch("superset.daos.tasks.TaskDAO")
+ def test_wait_for_completion_with_pubsub(self, mock_dao, mock_cache_manager):
+ """Test wait_for_completion uses pub/sub when Redis available"""
+ mock_task_pending = MagicMock()
+ mock_task_pending.uuid = "test-uuid"
+ mock_task_pending.status = "pending"
+
+ mock_task_complete = MagicMock()
+ mock_task_complete.uuid = "test-uuid"
+ mock_task_complete.status = "success"
+
+ # First call returns pending, second returns complete
+ mock_dao.find_one_or_none.side_effect = [
+ mock_task_pending,
+ mock_task_complete,
+ ]
+
+ # Set up mock Redis with pub/sub
+ mock_redis = MagicMock()
+ mock_pubsub = MagicMock()
+ # Simulate receiving a completion message
+ mock_pubsub.get_message.return_value = {
+ "type": "message",
+ "data": "success",
+ }
+ mock_redis.pubsub.return_value = mock_pubsub
+ mock_cache_manager.signal_cache = mock_redis
+
+ result = TaskManager.wait_for_completion(
+ "test-uuid",
+ timeout=5.0,
+ )
+
+ assert result.status == "success"
+ # Should have subscribed to completion channel
+ mock_pubsub.subscribe.assert_called_once_with("gtf:complete:test-uuid")
+ # Should have cleaned up
+ mock_pubsub.unsubscribe.assert_called_once()
+ mock_pubsub.close.assert_called_once()
+
+ @patch("superset.tasks.manager.cache_manager")
+ @patch("superset.daos.tasks.TaskDAO")
+ def test_wait_for_completion_pubsub_error_raises(
+ self, mock_dao, mock_cache_manager
+ ):
+ """Test wait_for_completion raises exception on Redis error when
+ Redis configured"""
+
+ mock_task_pending = MagicMock()
+ mock_task_pending.uuid = "test-uuid"
+ mock_task_pending.status = "pending"
+
+ mock_dao.find_one_or_none.return_value = mock_task_pending
+
+ # Set up mock Redis that fails
+ mock_redis = MagicMock()
+ mock_redis.pubsub.side_effect = redis.RedisError("Connection failed")
+ mock_cache_manager.signal_cache = mock_redis
+
+ # With fail-fast behavior, Redis error is raised instead of falling back
+ with pytest.raises(redis.RedisError, match="Connection failed"):
+ TaskManager.wait_for_completion(
+ "test-uuid",
+ timeout=5.0,
+ poll_interval=0.1,
+ )
+
+ def test_terminal_states_constant(self):
+ """Test TERMINAL_STATES contains expected values"""
+ expected = {"success", "failure", "aborted", "timed_out"}
+ assert TaskManager.TERMINAL_STATES == expected
diff --git a/tests/unit_tests/tasks/test_timeout.py b/tests/unit_tests/tasks/test_timeout.py
new file mode 100644
index 000000000000..ef8d5f9d7616
--- /dev/null
+++ b/tests/unit_tests/tasks/test_timeout.py
@@ -0,0 +1,614 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for GTF timeout handling."""
+
+import time
+from unittest.mock import MagicMock, patch
+from uuid import UUID
+
+import pytest
+from superset_core.api.tasks import TaskOptions, TaskScope
+
+from superset.tasks.context import TaskContext
+from superset.tasks.decorators import TaskWrapper
+
+TEST_UUID = UUID("b8b61b7b-1cd3-4a31-a74a-0a95341afc06")
+
+# =============================================================================
+# Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def mock_flask_app():
+ """Create a properly configured mock Flask app."""
+ mock_app = MagicMock()
+ mock_app.config = {
+ "TASK_ABORT_POLLING_DEFAULT_INTERVAL": 0.1,
+ }
+ # Make app_context() return a proper context manager
+ mock_app.app_context.return_value.__enter__ = MagicMock(return_value=None)
+ mock_app.app_context.return_value.__exit__ = MagicMock(return_value=None)
+ return mock_app
+
+
+@pytest.fixture
+def mock_task_abortable():
+ """Create a mock task that is abortable."""
+ task = MagicMock()
+ task.uuid = TEST_UUID
+ task.status = "in_progress"
+ task.properties_dict = {"is_abortable": True}
+ task.payload_dict = {}
+ # Set real values for dedup_key generation (used by UpdateTaskCommand lock)
+ task.scope = "shared"
+ task.task_type = "test_task"
+ task.task_key = "test_key"
+ task.user_id = 1
+ return task
+
+
+@pytest.fixture
+def mock_task_not_abortable():
+ """Create a mock task that is NOT abortable."""
+ task = MagicMock()
+ task.uuid = TEST_UUID
+ task.status = "in_progress"
+ task.properties_dict = {} # No is_abortable means it's not abortable
+ task.payload_dict = {}
+ # Set real values for dedup_key generation (used by UpdateTaskCommand lock)
+ task.scope = "shared"
+ task.task_type = "test_task"
+ task.task_key = "test_key"
+ task.user_id = 1
+ return task
+
+
+@pytest.fixture
+def task_context_for_timeout(mock_flask_app, mock_task_abortable):
+ """Create TaskContext with mocked dependencies for timeout tests."""
+ # Ensure mock_task has required attributes for TaskContext
+ mock_task_abortable.payload_dict = {}
+
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.daos.tasks.TaskDAO") as mock_dao,
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ # Configure current_app mock
+ mock_current_app.config = mock_flask_app.config
+ mock_current_app._get_current_object.return_value = mock_flask_app
+
+ # Configure TaskDAO mock
+ mock_dao.find_one_or_none.return_value = mock_task_abortable
+
+ ctx = TaskContext(mock_task_abortable)
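+ # Attach the mock app directly; background timer threads use it for app context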
+ ctx._app = mock_flask_app
+
+ yield ctx
+
+ # Cleanup: stop timers if started
+ ctx.stop_timeout_timer()
+ if ctx._abort_listener:
+ ctx.stop_abort_polling()
+
+
+# =============================================================================
+# TaskWrapper._merge_options Timeout Tests
+# =============================================================================
+
+
+class TestTimeoutMerging:
+ """Test timeout merging behavior in TaskWrapper._merge_options."""
+
+ def test_merge_options_decorator_timeout_used_when_no_override(self):
+ """Test that decorator timeout is used when no override is provided."""
+
+ def dummy_func():
+ pass
+
+ wrapper = TaskWrapper(
+ name="test_task",
+ func=dummy_func,
+ default_options=TaskOptions(),
+ scope=TaskScope.PRIVATE,
+ default_timeout=300, # 5-minute default
+ )
+
+ merged = wrapper._merge_options(None)
+ assert merged.timeout == 300
+
+ def test_merge_options_override_timeout_takes_precedence(self):
+ """Test that TaskOptions timeout overrides decorator default."""
+
+ def dummy_func():
+ pass
+
+ wrapper = TaskWrapper(
+ name="test_task",
+ func=dummy_func,
+ default_options=TaskOptions(),
+ scope=TaskScope.PRIVATE,
+ default_timeout=300, # 5-minute default
+ )
+
+ override = TaskOptions(timeout=600) # 10-minute override
+ merged = wrapper._merge_options(override)
+ assert merged.timeout == 600
+
+ def test_merge_options_no_timeout_when_not_configured(self):
+ """Test that no timeout is set when not configured anywhere."""
+
+ def dummy_func():
+ pass
+
+ wrapper = TaskWrapper(
+ name="test_task",
+ func=dummy_func,
+ default_options=TaskOptions(),
+ scope=TaskScope.PRIVATE,
+ default_timeout=None, # No default timeout
+ )
+
+ merged = wrapper._merge_options(None)
+ assert merged.timeout is None
+
+ def test_merge_options_override_with_other_options_preserves_timeout(self):
+ """Test that setting other options doesn't lose decorator timeout."""
+
+ def dummy_func():
+ pass
+
+ wrapper = TaskWrapper(
+ name="test_task",
+ func=dummy_func,
+ default_options=TaskOptions(),
+ scope=TaskScope.PRIVATE,
+ default_timeout=300,
+ )
+
+ # Override only task_key, not timeout
+ override = TaskOptions(task_key="my-key")
+ merged = wrapper._merge_options(override)
+
+ # Should keep decorator timeout since override.timeout is None
+ assert merged.timeout == 300
+ assert merged.task_key == "my-key"
+
+
+# =============================================================================
+# TaskContext Timeout Timer Tests
+# =============================================================================
+
+
+class TestTimeoutTimer:
+ """Test TaskContext timeout timer behavior."""
+
+ def test_start_timeout_timer_sets_timer(self, task_context_for_timeout):
+ """Test that start_timeout_timer creates a timer."""
+ ctx = task_context_for_timeout
+
+ assert ctx._timeout_timer is None
+
+ ctx.start_timeout_timer(10)
+
+ assert ctx._timeout_timer is not None
+ assert ctx._timeout_triggered is False
+
+ def test_start_timeout_timer_is_idempotent(self, task_context_for_timeout):
+ """Test that starting timer twice doesn't create duplicate timers."""
+ ctx = task_context_for_timeout
+
+ ctx.start_timeout_timer(10)
+ first_timer = ctx._timeout_timer
+
+ ctx.start_timeout_timer(20) # Try to start again
+ second_timer = ctx._timeout_timer
+
+ assert first_timer is second_timer
+
+ def test_stop_timeout_timer_cancels_timer(self, task_context_for_timeout):
+ """Test that stop_timeout_timer cancels the timer."""
+ ctx = task_context_for_timeout
+
+ ctx.start_timeout_timer(10)
+ assert ctx._timeout_timer is not None
+
+ ctx.stop_timeout_timer()
+
+ assert ctx._timeout_timer is None
+
+ def test_stop_timeout_timer_safe_when_no_timer(self, task_context_for_timeout):
+ """Test that stop_timeout_timer doesn't fail when no timer exists."""
+ ctx = task_context_for_timeout
+
+ assert ctx._timeout_timer is None
+ ctx.stop_timeout_timer() # Should not raise
+ assert ctx._timeout_timer is None
+
+ def test_timeout_triggered_property_initially_false(self, task_context_for_timeout):
+ """Test that timeout_triggered is False initially."""
+ ctx = task_context_for_timeout
+ assert ctx.timeout_triggered is False
+
+ def test_cleanup_stops_timeout_timer(self, task_context_for_timeout):
+ """Test that _run_cleanup stops the timeout timer."""
+ ctx = task_context_for_timeout
+
+ ctx.start_timeout_timer(10)
+ assert ctx._timeout_timer is not None
+
+ ctx._run_cleanup()
+
+ assert ctx._timeout_timer is None
+
+
+class TestTimeoutTrigger:
+ """Test timeout trigger behavior when timer fires."""
+
+ def test_timeout_triggers_abort_when_abortable(
+ self, mock_flask_app, mock_task_abortable
+ ):
+ """Test that timeout triggers abort handlers when task is abortable."""
+ abort_called = False
+
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.daos.tasks.TaskDAO") as mock_dao,
+ patch(
+ "superset.commands.tasks.update.UpdateTaskCommand"
+ ) as mock_update_cmd,
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ mock_current_app.config = mock_flask_app.config
+ mock_current_app._get_current_object.return_value = mock_flask_app
+ mock_dao.find_one_or_none.return_value = mock_task_abortable
+
+ ctx = TaskContext(mock_task_abortable)
+ ctx._app = mock_flask_app
+
+ @ctx.on_abort
+ def handle_abort():
+ nonlocal abort_called
+ abort_called = True
+
+ # Start short timeout
+ ctx.start_timeout_timer(1)
+
+ # Wait for timeout to fire
+ time.sleep(1.5)
+
+ # Abort handler should have been called
+ assert abort_called
+ assert ctx._timeout_triggered
+ assert ctx._abort_detected
+
+ # Verify UpdateTaskCommand was called with ABORTING status
+ mock_update_cmd.assert_called()
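+ # call_args[1] is the kwargs dict of the most recent call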
+ call_kwargs = mock_update_cmd.call_args[1]
+ assert call_kwargs.get("status") == "aborting"
+
+ # Cleanup
+ ctx.stop_timeout_timer()
+ if ctx._abort_listener:
+ ctx.stop_abort_polling()
+
+ def test_timeout_logs_warning_when_not_abortable(
+ self, mock_flask_app, mock_task_not_abortable
+ ):
+ """Test that timeout logs warning when task has no abort handler."""
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.daos.tasks.TaskDAO") as mock_dao,
+ patch("superset.tasks.context.logger") as mock_logger,
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ mock_current_app.config = mock_flask_app.config
+ mock_current_app._get_current_object.return_value = mock_flask_app
+ mock_dao.find_one_or_none.return_value = mock_task_not_abortable
+
+ ctx = TaskContext(mock_task_not_abortable)
+ ctx._app = mock_flask_app
+
+ # No abort handler registered
+
+ # Start short timeout
+ ctx.start_timeout_timer(1)
+
+ # Wait for timeout to fire
+ time.sleep(1.5)
+
+ # Should have logged warning
+ mock_logger.warning.assert_called()
+ warning_call = mock_logger.warning.call_args
+ assert "no abort handler" in warning_call[0][0].lower()
+ assert ctx._timeout_triggered
+ assert not ctx._abort_detected # No abort since no handler
+
+ # Cleanup
+ ctx.stop_timeout_timer()
+
+ def test_timeout_does_not_trigger_if_already_aborting(
+ self, mock_flask_app, mock_task_abortable
+ ):
+ """Test that timeout doesn't re-trigger abort if already aborting."""
+ abort_count = 0
+
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.daos.tasks.TaskDAO") as mock_dao,
+ patch("superset.commands.tasks.update.UpdateTaskCommand"),
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ mock_current_app.config = mock_flask_app.config
+ mock_current_app._get_current_object.return_value = mock_flask_app
+ mock_dao.find_one_or_none.return_value = mock_task_abortable
+
+ ctx = TaskContext(mock_task_abortable)
+ ctx._app = mock_flask_app
+
+ @ctx.on_abort
+ def handle_abort():
+ nonlocal abort_count
+ abort_count += 1
+
+ # Pre-set abort detected
+ ctx._abort_detected = True
+
+ # Start short timeout
+ ctx.start_timeout_timer(1)
+
+ # Wait for timeout to fire
+ time.sleep(1.5)
+
+ # Handler should NOT have been called since already aborting
+ assert abort_count == 0
+
+ # Cleanup
+ ctx.stop_timeout_timer()
+ if ctx._abort_listener:
+ ctx.stop_abort_polling()
+
+
+# =============================================================================
+# Task Decorator Timeout Tests
+# =============================================================================
+
+
+class TestTaskDecoratorTimeout:
+ """Test @task decorator timeout parameter."""
+
+ def test_task_decorator_accepts_timeout(self):
+ """Test that @task decorator accepts timeout parameter."""
+ from superset.tasks.decorators import task
+ from superset.tasks.registry import TaskRegistry
+
+ @task(name="test_timeout_task_1", timeout=300)
+ def timeout_test_task_1():
+ pass
+
+ assert isinstance(timeout_test_task_1, TaskWrapper)
+ assert timeout_test_task_1.default_timeout == 300
+
+ # Cleanup registry
+ TaskRegistry._tasks.pop("test_timeout_task_1", None)
+
+ def test_task_decorator_without_timeout(self):
+ """Test that @task decorator works without timeout."""
+ from superset.tasks.decorators import task
+ from superset.tasks.registry import TaskRegistry
+
+ @task(name="test_timeout_task_2")
+ def timeout_test_task_2():
+ pass
+
+ assert isinstance(timeout_test_task_2, TaskWrapper)
+ assert timeout_test_task_2.default_timeout is None
+
+ # Cleanup registry
+ TaskRegistry._tasks.pop("test_timeout_task_2", None)
+
+ def test_task_decorator_with_all_params(self):
+ """Test that @task decorator accepts all parameters together."""
+ from superset.tasks.decorators import task
+ from superset.tasks.registry import TaskRegistry
+
+ @task(name="test_timeout_task_3", scope=TaskScope.SHARED, timeout=600)
+ def timeout_test_task_3():
+ pass
+
+ assert timeout_test_task_3.name == "test_timeout_task_3"
+ assert timeout_test_task_3.scope == TaskScope.SHARED
+ assert timeout_test_task_3.default_timeout == 600
+
+ # Cleanup registry
+ TaskRegistry._tasks.pop("test_timeout_task_3", None)
+
+
+# =============================================================================
+# Timeout Terminal State Tests
+# =============================================================================
+
+
+class TestTimeoutTerminalState:
+ """Test timeout transitions to correct terminal state (TIMED_OUT vs FAILURE)."""
+
+ def test_timeout_triggered_flag_set_on_timeout(
+ self, mock_flask_app, mock_task_abortable
+ ):
+ """Test that timeout_triggered flag is set when timeout fires."""
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.daos.tasks.TaskDAO") as mock_dao,
+ patch("superset.commands.tasks.update.UpdateTaskCommand"),
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ mock_current_app.config = mock_flask_app.config
+ mock_current_app._get_current_object.return_value = mock_flask_app
+ mock_dao.find_one_or_none.return_value = mock_task_abortable
+
+ ctx = TaskContext(mock_task_abortable)
+ ctx._app = mock_flask_app
+
+ @ctx.on_abort
+ def handle_abort():
+ pass
+
+ # Initially not triggered
+ assert ctx.timeout_triggered is False
+
+ # Start short timeout
+ ctx.start_timeout_timer(1)
+
+ # Wait for timeout to fire
+ time.sleep(1.5)
+
+ # Should be set after timeout
+ assert ctx.timeout_triggered is True
+
+ # Cleanup
+ ctx.stop_timeout_timer()
+ if ctx._abort_listener:
+ ctx.stop_abort_polling()
+
+ def test_user_abort_does_not_set_timeout_triggered(
+ self, mock_flask_app, mock_task_abortable
+ ):
+ """Test that user abort doesn't set timeout_triggered flag."""
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.daos.tasks.TaskDAO") as mock_dao,
+ patch("superset.commands.tasks.update.UpdateTaskCommand"),
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ mock_current_app.config = mock_flask_app.config
+ mock_current_app._get_current_object.return_value = mock_flask_app
+ mock_dao.find_one_or_none.return_value = mock_task_abortable
+
+ ctx = TaskContext(mock_task_abortable)
+ ctx._app = mock_flask_app
+
+ @ctx.on_abort
+ def handle_abort():
+ pass
+
+ # Simulate user abort (not timeout)
+ ctx._on_abort_detected()
+
+ # timeout_triggered should still be False
+ assert ctx.timeout_triggered is False
+ # But abort_detected should be True
+ assert ctx._abort_detected is True
+
+ # Cleanup
+ if ctx._abort_listener:
+ ctx.stop_abort_polling()
+
+ def test_abort_handlers_completed_tracks_success(
+ self, mock_flask_app, mock_task_abortable
+ ):
+ """Test that abort_handlers_completed flag tracks successful
+ handler execution."""
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.daos.tasks.TaskDAO") as mock_dao,
+ patch("superset.commands.tasks.update.UpdateTaskCommand"),
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ mock_current_app.config = mock_flask_app.config
+ mock_current_app._get_current_object.return_value = mock_flask_app
+ mock_dao.find_one_or_none.return_value = mock_task_abortable
+
+ ctx = TaskContext(mock_task_abortable)
+ ctx._app = mock_flask_app
+
+ @ctx.on_abort
+ def handle_abort():
+ pass # Successful handler
+
+ # Initially not completed
+ assert ctx.abort_handlers_completed is False
+
+ # Trigger abort handlers
+ ctx._trigger_abort_handlers()
+
+ # Should be marked as completed
+ assert ctx.abort_handlers_completed is True
+
+ # Cleanup
+ if ctx._abort_listener:
+ ctx.stop_abort_polling()
+
+ def test_abort_handlers_completed_false_on_exception(
+ self, mock_flask_app, mock_task_abortable
+ ):
+ """Test that abort_handlers_completed is False when handler throws."""
+ with (
+ patch("superset.tasks.context.current_app") as mock_current_app,
+ patch("superset.daos.tasks.TaskDAO") as mock_dao,
+ patch("superset.commands.tasks.update.UpdateTaskCommand"),
+ patch("superset.tasks.manager.cache_manager") as mock_cache_manager,
+ ):
+ # Disable Redis by making signal_cache return None
+ mock_cache_manager.signal_cache = None
+
+ mock_current_app.config = mock_flask_app.config
+ mock_current_app._get_current_object.return_value = mock_flask_app
+ mock_dao.find_one_or_none.return_value = mock_task_abortable
+
+ ctx = TaskContext(mock_task_abortable)
+ ctx._app = mock_flask_app
+
+ @ctx.on_abort
+ def handle_abort():
+ raise ValueError("Handler failed")
+
+ # Initially not completed
+ assert ctx.abort_handlers_completed is False
+
+ # Trigger abort handlers (will catch the exception internally)
+ ctx._trigger_abort_handlers()
+
+ # Should NOT be marked as completed since handler threw
+ assert ctx.abort_handlers_completed is False
+
+ # Cleanup
+ if ctx._abort_listener:
+ ctx.stop_abort_polling()
diff --git a/tests/unit_tests/tasks/test_utils.py b/tests/unit_tests/tasks/test_utils.py
index d4b24c66661b..bd5fb8282449 100644
--- a/tests/unit_tests/tasks/test_utils.py
+++ b/tests/unit_tests/tasks/test_utils.py
@@ -22,9 +22,19 @@
import pytest
from flask_appbuilder.security.sqla.models import User
+from superset_core.api.tasks import TaskScope
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.types import Executor, ExecutorType, FixedExecutor
+from superset.tasks.utils import (
+ error_update,
+ get_active_dedup_key,
+ get_finished_dedup_key,
+ parse_properties,
+ progress_update,
+ serialize_properties,
+)
+from superset.utils.hashing import hash_from_str
FIXED_USER_ID = 1234
FIXED_USERNAME = "admin"
@@ -330,3 +340,242 @@ def test_get_executor(
)
assert executor_type == expected_executor_type
assert executor == expected_executor
+
+
+@pytest.mark.parametrize(
+ "scope,task_type,task_key,user_id,expected_composite_key",
+ [
+ # Private tasks with TaskScope enum
+ (
+ TaskScope.PRIVATE,
+ "sql_execution",
+ "chart_123",
+ 42,
+ "private|sql_execution|chart_123|42",
+ ),
+ (
+ TaskScope.PRIVATE,
+ "thumbnail_gen",
+ "dash_456",
+ 100,
+ "private|thumbnail_gen|dash_456|100",
+ ),
+ # Private tasks with string scope
+ (
+ "private",
+ "api_call",
+ "endpoint_789",
+ 200,
+ "private|api_call|endpoint_789|200",
+ ),
+ # Shared tasks with TaskScope enum
+ (
+ TaskScope.SHARED,
+ "report_gen",
+ "monthly_report",
+ None,
+ "shared|report_gen|monthly_report",
+ ),
+ (
+ TaskScope.SHARED,
+ "export_csv",
+ "large_export",
+ 999, # user_id should be ignored for shared
+ "shared|export_csv|large_export",
+ ),
+ # Shared tasks with string scope
+ (
+ "shared",
+ "batch_process",
+ "batch_001",
+ 123, # user_id should be ignored for shared
+ "shared|batch_process|batch_001",
+ ),
+ # System tasks with TaskScope enum
+ (
+ TaskScope.SYSTEM,
+ "cleanup_task",
+ "daily_cleanup",
+ None,
+ "system|cleanup_task|daily_cleanup",
+ ),
+ (
+ TaskScope.SYSTEM,
+ "db_migration",
+ "version_123",
+ 1, # user_id should be ignored for system
+ "system|db_migration|version_123",
+ ),
+ # System tasks with string scope
+ (
+ "system",
+ "maintenance",
+ "nightly_job",
+ 2, # user_id should be ignored for system
+ "system|maintenance|nightly_job",
+ ),
+ ],
+)
+def test_get_active_dedup_key(
+ scope, task_type, task_key, user_id, expected_composite_key, app_context
+):
+ """Test get_active_dedup_key generates a hash of the composite key.
+
+ The function hashes the composite key using the configured HASH_ALGORITHM
+ to produce a fixed-length dedup_key for database storage. The result is
+ truncated to 64 chars to fit the database column.
+ """
+ result = get_active_dedup_key(scope, task_type, task_key, user_id)
+
+ # The result should be a hash of the expected composite key, truncated to 64 chars
+ expected_hash = hash_from_str(expected_composite_key)[:64]
+ assert result == expected_hash
+ assert len(result) <= 64
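+
+
+# A minimal sketch of the behavior verified above, assuming the composite key
+# is joined with "|" before hashing and that TaskScope values are the
+# lowercase scope strings. Illustrative only; the real helper is
+# superset.tasks.utils.get_active_dedup_key and may differ internally.
+def _active_dedup_key_sketch(scope, task_type, task_key, user_id=None):
+    scope_str = scope.value if isinstance(scope, TaskScope) else scope
+    parts = [scope_str, task_type, task_key]
+    if scope_str == "private":
+        # private tasks are deduplicated per user, so the id is required
+        if user_id is None:
+            raise ValueError("user_id required for private tasks")
+        parts.append(str(user_id))
+    return hash_from_str("|".join(parts))[:64]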
+
+
+def test_get_active_dedup_key_private_requires_user_id():
+ """Test that private tasks require explicit user_id parameter."""
+ with pytest.raises(ValueError, match="user_id required for private tasks"):
+ get_active_dedup_key(TaskScope.PRIVATE, "test_type", "test_key")
+
+
+def test_get_finished_dedup_key():
+ """Test that finished tasks use UUID as dedup_key"""
+ test_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
+ result = get_finished_dedup_key(test_uuid)
+ assert result == test_uuid
+
+
+@pytest.mark.parametrize(
+ "progress,expected",
+ [
+ # Float (percentage) progress
+ (0.5, {"progress_percent": 0.5}),
+ (0.0, {"progress_percent": 0.0}),
+ (1.0, {"progress_percent": 1.0}),
+ (0.25, {"progress_percent": 0.25}),
+ # Int (count only) progress
+ (42, {"progress_current": 42}),
+ (0, {"progress_current": 0}),
+ (1000, {"progress_current": 1000}),
+ # Tuple (current, total) progress with auto-computed percentage
+ (
+ (50, 100),
+ {"progress_current": 50, "progress_total": 100, "progress_percent": 0.5},
+ ),
+ (
+ (25, 100),
+ {"progress_current": 25, "progress_total": 100, "progress_percent": 0.25},
+ ),
+ (
+ (100, 100),
+ {"progress_current": 100, "progress_total": 100, "progress_percent": 1.0},
+ ),
+ # Tuple with zero total (no percentage computed)
+ ((10, 0), {"progress_current": 10, "progress_total": 0}),
+ ((0, 0), {"progress_current": 0, "progress_total": 0}),
+ ],
+)
+def test_progress_update(progress, expected):
+ """Test progress_update returns correct TaskProperties dict."""
+ result = progress_update(progress)
+ assert result == expected
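+
+
+# A minimal sketch of the type-based dispatch verified above. Illustrative
+# only; the real implementation is superset.tasks.utils.progress_update.
+def _progress_update_sketch(progress):
+    # (current, total) tuple: derive a percentage when total is non-zero
+    if isinstance(progress, tuple):
+        current, total = progress
+        props = {"progress_current": current, "progress_total": total}
+        if total:
+            props["progress_percent"] = current / total
+        return props
+    # bare float: treated as a percentage in [0, 1]
+    if isinstance(progress, float):
+        return {"progress_percent": progress}
+    # bare int: a count with no known total
+    return {"progress_current": progress}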
+
+
+def test_error_update():
+ """Test error_update captures exception details."""
+ try:
+ raise ValueError("Test error message")
+ except ValueError as e:
+ result = error_update(e)
+
+ assert result["error_message"] == "Test error message"
+ assert result["exception_type"] == "ValueError"
+ assert "stack_trace" in result
+ assert "ValueError" in result["stack_trace"]
+
+
+def test_error_update_custom_exception():
+ """Test error_update with custom exception class."""
+
+ class CustomError(Exception):
+ pass
+
+ try:
+ raise CustomError("Custom error")
+ except CustomError as e:
+ result = error_update(e)
+
+ assert result["error_message"] == "Custom error"
+ assert result["exception_type"] == "CustomError"
+
+
+@pytest.mark.parametrize(
+ "json_str,expected",
+ [
+ # Valid JSON
+ (
+ '{"is_abortable": true, "progress_percent": 0.5}',
+ {"is_abortable": True, "progress_percent": 0.5},
+ ),
+ (
+ '{"error_message": "Something failed"}',
+ {"error_message": "Something failed"},
+ ),
+ (
+ '{"progress_current": 50, "progress_total": 100}',
+ {"progress_current": 50, "progress_total": 100},
+ ),
+ # Empty/None cases
+ ("", {}),
+ (None, {}),
+ # Invalid JSON returns empty dict
+ ("not valid json", {}),
+ ("{broken", {}),
+ # Unknown keys are preserved (forward compatibility)
+ (
+ '{"is_abortable": true, "future_field": "value"}',
+ {"is_abortable": True, "future_field": "value"},
+ ),
+ ],
+)
+def test_parse_properties(json_str, expected):
+ """Test parse_properties parses JSON to TaskProperties dict."""
+ result = parse_properties(json_str)
+ assert result == expected
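+
+
+# A minimal sketch of the tolerant parsing verified above: empty input and
+# invalid JSON both degrade to an empty dict. Illustrative only; the real
+# helper is superset.tasks.utils.parse_properties.
+def _parse_properties_sketch(raw):
+    from superset.utils import json
+
+    if not raw:
+        return {}
+    try:
+        parsed = json.loads(raw)
+    except Exception:  # invalid JSON degrades to an empty dict
+        return {}
+    return parsed if isinstance(parsed, dict) else {}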
+
+
+@pytest.mark.parametrize(
+ "props,expected_contains",
+ [
+ # Full properties
+ (
+ {"is_abortable": True, "progress_percent": 0.5},
+ {"is_abortable": True, "progress_percent": 0.5},
+ ),
+ # Empty dict
+ ({}, {}),
+ # Sparse properties
+ ({"is_abortable": True}, {"is_abortable": True}),
+ ({"error_message": "fail"}, {"error_message": "fail"}),
+ ],
+)
+def test_serialize_properties(props, expected_contains):
+ """Test serialize_properties converts TaskProperties to JSON."""
+ from superset.utils import json
+
+ result = serialize_properties(props)
+ parsed = json.loads(result)
+ assert parsed == expected_contains
+
+
+def test_properties_roundtrip():
+ """Test that serialize -> parse roundtrip preserves data."""
+ original = {
+ "is_abortable": True,
+ "progress_percent": 0.75,
+ "error_message": "Test error",
+ }
+ serialized = serialize_properties(original)
+ parsed = parse_properties(serialized)
+ assert parsed == original
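+
+
+# The serialize side is presumably a thin json.dumps wrapper, which is what
+# makes the roundtrip above lossless. Illustrative only; the real helper is
+# superset.tasks.utils.serialize_properties.
+def _serialize_properties_sketch(props):
+    from superset.utils import json
+
+    return json.dumps(props)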
diff --git a/tests/unit_tests/utils/json_tests.py b/tests/unit_tests/utils/json_tests.py
index 33565774077f..89c0a5a16574 100644
--- a/tests/unit_tests/utils/json_tests.py
+++ b/tests/unit_tests/utils/json_tests.py
@@ -54,7 +54,7 @@ def test_json_loads_exception():
def test_json_loads_encoding():
- unicode_data = b'{"a": "\u0073\u0074\u0072"}'
+ unicode_data = rb'{"a": "\u0073\u0074\u0072"}'
data = json.loads(unicode_data)
assert data["a"] == "str"
utf16_data = b'\xff\xfe{\x00"\x00a\x00"\x00:\x00 \x00"\x00s\x00t\x00r\x00"\x00}\x00'
diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py
index 33a0c0c26630..08b7cc9c6e79 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -119,7 +119,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
was revoked), the invalid token should be deleted and the exception re-raised.
"""
db = mocker.patch("superset.utils.oauth2.db")
- mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
+ mocker.patch("superset.utils.oauth2.DistributedLock")
class OAuth2ExceptionError(Exception):
pass
@@ -149,7 +149,7 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
exception re-raised.
"""
db = mocker.patch("superset.utils.oauth2.db")
- mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
+ mocker.patch("superset.utils.oauth2.DistributedLock")
class OAuth2ExceptionError(Exception):
pass
@@ -175,7 +175,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
This can happen when the refresh token was revoked.
"""
mocker.patch("superset.utils.oauth2.db")
- mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
+ mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db_engine_spec.get_oauth2_fresh_token.return_value = {
"error": "invalid_grant",