Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions docs/admin_docs/configuration/cache.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,39 @@ instead requires a cachelib object.

See [Async Queries via Celery](/admin-docs/configuration/async-queries-celery) for details.

## Celery beat

Superset has a Celery task that will periodically warm up the cache based on different strategies.
To use it, add the following to your `superset_config.py`:

```python
from celery.schedules import crontab
from superset.config import CeleryConfig

# User that will be used to authenticate and render dashboards for cache warmup
SUPERSET_CACHE_WARMUP_USER = "user_with_permission_to_dashboards"

# Extend the default CeleryConfig to add cache warmup schedule
class CustomCeleryConfig(CeleryConfig):
beat_schedule = {
**CeleryConfig.beat_schedule,
'cache-warmup-hourly': {
'task': 'cache-warmup',
'schedule': crontab(minute=0, hour='*'), # hourly
'kwargs': {
'strategy_name': 'top_n_dashboards',
'top_n': 5,
'since': '7 days ago',
},
},
}

CELERY_CONFIG = CustomCeleryConfig
```

This will cache the top 5 most popular dashboards every hour. For other
strategies, check the `superset/tasks/cache.py` file.

## Caching Thumbnails

This is an optional feature that can be turned on by activating its [feature flag](/admin-docs/configuration/configuring-superset#feature-flags) on config:
Expand Down
5 changes: 5 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,11 @@ class D3TimeFormat(TypedDict, total=False):
}
THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds())

# Cache warmup user — must be set explicitly before enabling the cache-warmup
# Celery task. Intentionally defaults to None so operators pick a dedicated
# least-privilege user rather than inadvertently running warmup as "admin".
SUPERSET_CACHE_WARMUP_USER: str | None = None

# Time before selenium times out after trying to locate an element on the page and wait
# for that element to load for a screenshot.
SCREENSHOT_LOCATE_WAIT = int(timedelta(seconds=10).total_seconds())
Expand Down
6 changes: 5 additions & 1 deletion superset/mcp_service/screenshot/webdriver_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,11 @@ def _health_check_driver(self, pooled_driver: PooledWebDriver) -> bool:
def _destroy_driver(self, pooled_driver: PooledWebDriver) -> None:
"""Safely destroy a WebDriver instance"""
try:
WebDriverSelenium.destroy(pooled_driver.driver)
try:
pooled_driver.driver.close()
except Exception: # pylint: disable=broad-except # noqa: S110
pass
pooled_driver.driver.quit()
self._stats["destroyed"] += 1
logger.debug("Destroyed WebDriver instance")
except Exception as e:
Expand Down
205 changes: 73 additions & 132 deletions superset/tasks/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,65 +17,46 @@
from __future__ import annotations

import logging
from typing import Any, Optional, TypedDict, Union
from urllib import request
from urllib.error import URLError
from typing import Any, Optional, Union

from celery.beat import SchedulingError
from celery.utils.log import get_task_logger
from flask import current_app
from selenium.common.exceptions import WebDriverException
from sqlalchemy import and_, func
from sqlalchemy.orm import selectinload

from superset import db, security_manager
from superset.extensions import celery_app
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tags.models import Tag, TaggedObject
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.utils import fetch_csrf_token, get_executor
from superset.utils import json
from superset.utils.date_parser import parse_human_datetime
from superset.utils.machine_auth import MachineAuthProvider
from superset.utils.urls import get_url_path, is_secure_url
from superset.utils.webdriver import WebDriverSelenium

logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)


class CacheWarmupPayload(TypedDict, total=False):
chart_id: int
dashboard_id: int | None


class CacheWarmupTask(TypedDict):
payload: CacheWarmupPayload
username: str | None


def get_task(chart: Slice, dashboard: Optional[Dashboard] = None) -> CacheWarmupTask:
"""Return task for warming up a given chart/table cache."""
executors = current_app.config["CACHE_WARMUP_EXECUTORS"]
payload: CacheWarmupPayload = {"chart_id": chart.id}
if dashboard:
payload["dashboard_id"] = dashboard.id

username: str | None
try:
executor = get_executor(executors, chart)
username = executor[1]
except (ExecutorNotFoundError, InvalidExecutorError):
username = None

return {"payload": payload, "username": username}
def get_dash_url(dashboard: Dashboard) -> str:
"""Return external URL for warming up a given dashboard cache."""
with current_app.test_request_context():
baseurl = (
# when running this as an async task, drop the request context with
# app.test_request_context()
current_app.config.get("WEBDRIVER_BASEURL")
or "{SUPERSET_WEBSERVER_PROTOCOL}://"
"{SUPERSET_WEBSERVER_ADDRESS}:"
"{SUPERSET_WEBSERVER_PORT}".format(**current_app.config)
)
return f"{baseurl.rstrip('/')}{dashboard.url}"


class Strategy: # pylint: disable=too-few-public-methods
"""
A cache warm up strategy.

Each strategy defines a `get_tasks` method that returns a list of tasks to
send to the `/api/v1/chart/warm_up_cache` endpoint.
Each strategy defines a `get_urls` method that returns a list of dashboard URLs to
warm up using WebDriver.

Strategies can be configured in `superset/config.py`:

Expand All @@ -96,15 +77,16 @@ class Strategy: # pylint: disable=too-few-public-methods
def __init__(self) -> None:
pass

def get_tasks(self) -> list[CacheWarmupTask]:
raise NotImplementedError("Subclasses must implement get_tasks!")
def get_urls(self) -> list[str]:
raise NotImplementedError("Subclasses must implement get_urls!")


class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
"""
Warm up all charts.
Warm up all published dashboards.

This is a dummy strategy that will fetch all charts. Can be configured by:
This is a dummy strategy that will fetch all published dashboards.
Can be configured by:

beat_schedule = {
'cache-warmup-hourly': {
Expand All @@ -118,8 +100,16 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods

name = "dummy"

def get_tasks(self) -> list[CacheWarmupTask]:
return [get_task(chart) for chart in db.session.query(Slice).all()]
def get_urls(self) -> list[str]:
# Use selectinload to avoid N+1 queries when checking dashboard.slices
dashboards = (
db.session.query(Dashboard)
.options(selectinload(Dashboard.slices))
.filter(Dashboard.published.is_(True))
.all()
)

return [get_dash_url(dashboard) for dashboard in dashboards if dashboard.slices]


class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -147,7 +137,7 @@ def __init__(self, top_n: int = 5, since: str = "7 days ago") -> None:
self.top_n = top_n
self.since = parse_human_datetime(since) if since else None

def get_tasks(self) -> list[CacheWarmupTask]:
def get_urls(self) -> list[str]:
records = (
db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
Expand All @@ -161,11 +151,7 @@ def get_tasks(self) -> list[CacheWarmupTask]:
db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)

return [
get_task(chart, dashboard)
for dashboard in dashboards
for chart in dashboard.slices
]
return [get_dash_url(dashboard) for dashboard in dashboards]


class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
Expand All @@ -190,8 +176,8 @@ def __init__(self, tags: Optional[list[str]] = None) -> None:
super().__init__()
self.tags = tags or []

def get_tasks(self) -> list[CacheWarmupTask]:
tasks = []
def get_urls(self) -> list[str]:
urls = []
tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]

Expand All @@ -211,81 +197,22 @@ def get_tasks(self) -> list[CacheWarmupTask]:
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
tasks.append(get_task(chart))

# add charts that are tagged
tagged_objects = (
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
tasks.append(get_task(chart))
urls.append(get_dash_url(dashboard))

return tasks
return urls


strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]


@celery_app.task(name="fetch_url")
def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]:
"""
Celery job to fetch url
"""
result = {}
try:
url = get_url_path("ChartRestApi.warm_up_cache")

if is_secure_url(url):
logger.info("URL '%s' is secure. Adding Referer header.", url)
headers.update({"Referer": url})

# Fetch CSRF token for API request
headers.update(fetch_csrf_token(headers))

logger.info("Fetching %s with payload %s", url, data)
req = request.Request( # noqa: S310
url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
)
response = request.urlopen( # pylint: disable=consider-using-with # noqa: S310
req, timeout=600
)
logger.info(
"Fetched %s with payload %s, status code: %s", url, data, response.code
)
if response.code == 200:
result = {"success": data, "response": response.read().decode("utf-8")}
else:
result = {"error": data, "status_code": response.code}
logger.error(
"Error fetching %s with payload %s, status code: %s",
url,
data,
response.code,
)
except URLError as err:
logger.exception("Error warming up cache!")
result = {"error": data, "exception": str(err)}
return result


@celery_app.task(name="cache-warmup")
def cache_warmup(
strategy_name: str, *args: Any, **kwargs: Any
) -> Union[dict[str, list[str]], str]:
"""
Warm up cache.

This task periodically hits charts to warm up the cache.
This task periodically hits dashboards to warm up the cache.

"""
logger.info("Loading strategy")
Expand All @@ -307,25 +234,39 @@ def cache_warmup(
logger.exception(message)
return message

results: dict[str, list[str]] = {"scheduled": [], "errors": []}
for task in strategy.get_tasks():
username = task["username"]
payload = json.dumps(task["payload"])
if username:
results: dict[str, list[str]] = {"success": [], "errors": []}

warmup_username = current_app.config.get("SUPERSET_CACHE_WARMUP_USER")
if not warmup_username:
message = (
"SUPERSET_CACHE_WARMUP_USER is not configured. Set it to a dedicated "
"least-privilege user with access to the dashboards you want warmed up."
)
logger.error(message)
return message

user = security_manager.find_user(username=warmup_username)
if not user:
message = (
f"Cache warmup user '{warmup_username}' not found. Please configure "
"SUPERSET_CACHE_WARMUP_USER with a valid username."
)
logger.error(message)
return message

wd = WebDriverSelenium(current_app.config["WEBDRIVER_TYPE"], user=user)

try:
for url in strategy.get_urls():
try:
user = security_manager.get_user_by_username(username)
cookies = MachineAuthProvider.get_auth_cookies(user)
headers = {
"Cookie": "session=%s" % cookies.get("session", ""),
"Content-Type": "application/json",
}
logger.info("Scheduling %s", payload)
fetch_url.delay(payload, headers)
results["scheduled"].append(payload)
except SchedulingError:
logger.exception("Error scheduling fetch_url for payload: %s", payload)
results["errors"].append(payload)
else:
logger.warning("Executor not found for %s", payload)
logger.info("Fetching %s", url)
wd.get_screenshot(url, "grid-container")
Comment thread
rusackas marked this conversation as resolved.
results["success"].append(url)
except (WebDriverException, Exception) as ex: # noqa: BLE001
logger.exception("Error warming up cache for %s: %s", url, ex)
results["errors"].append(url)
finally:
# Ensure WebDriver is properly cleaned up
wd.destroy()

return results
Loading
Loading