Skip to content

Commit

Permalink
fix(celery cache warmup): add auth and use warm_up_cache endpoint (ap…
Browse files Browse the repository at this point in the history
  • Loading branch information
nytai authored Aug 30, 2022
1 parent b354f22 commit 04dd8d4
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 181 deletions.
10 changes: 10 additions & 0 deletions docker/pythonpath_dev/superset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:

RESULTS_BACKEND = FileSystemCache("/app/superset_home/sqllab")

CACHE_CONFIG = {
"CACHE_TYPE": "redis",
"CACHE_DEFAULT_TIMEOUT": 300,
"CACHE_KEY_PREFIX": "superset_",
"CACHE_REDIS_HOST": REDIS_HOST,
"CACHE_REDIS_PORT": REDIS_PORT,
"CACHE_REDIS_DB": REDIS_RESULTS_DB,
}
DATA_CACHE_CONFIG = CACHE_CONFIG


class CeleryConfig(object):
BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
Expand Down
98 changes: 44 additions & 54 deletions superset/tasks/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,73 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from typing import Any, Dict, List, Optional, Union
from urllib import request
from urllib.error import URLError

from celery.beat import SchedulingError
from celery.utils.log import get_task_logger
from sqlalchemy import and_, func

from superset import app, db
from superset import app, db, security_manager
from superset.extensions import celery_app
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.tags import Tag, TaggedObject
from superset.utils.date_parser import parse_human_datetime
from superset.views.utils import build_extra_filters
from superset.utils.machine_auth import MachineAuthProvider

logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)


def get_form_data(
chart_id: int, dashboard: Optional[Dashboard] = None
) -> Dict[str, Any]:
"""
Build `form_data` for chart GET request from dashboard's `default_filters`.
When a dashboard has `default_filters` they need to be added as extra
filters in the GET request for charts.
"""
form_data: Dict[str, Any] = {"slice_id": chart_id}

if dashboard is None or not dashboard.json_metadata:
return form_data

json_metadata = json.loads(dashboard.json_metadata)
default_filters = json.loads(json_metadata.get("default_filters", "null"))
if not default_filters:
return form_data

filter_scopes = json_metadata.get("filter_scopes", {})
layout = json.loads(dashboard.position_json or "{}")
if (
isinstance(layout, dict)
and isinstance(filter_scopes, dict)
and isinstance(default_filters, dict)
):
extra_filters = build_extra_filters(
layout, filter_scopes, default_filters, chart_id
)
if extra_filters:
form_data["extra_filters"] = extra_filters

return form_data


def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str:
def get_url(chart: Slice, dashboard: Optional[Dashboard] = None) -> str:
"""Return external URL for warming up a given chart/table cache."""
with app.test_request_context():
baseurl = (
"{SUPERSET_WEBSERVER_PROTOCOL}://"
"{SUPERSET_WEBSERVER_ADDRESS}:"
"{SUPERSET_WEBSERVER_PORT}".format(**app.config)
)
return f"{baseurl}{chart.get_explore_url(overrides=extra_filters)}"
baseurl = "{WEBDRIVER_BASEURL}".format(**app.config)
url = f"{baseurl}superset/warm_up_cache/?slice_id={chart.id}"
if dashboard:
url += f"&dashboard_id={dashboard.id}"
return url


class Strategy: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -179,8 +142,7 @@ def get_urls(self) -> List[str]:
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
form_data_with_filters = get_form_data(chart.id, dashboard)
urls.append(get_url(chart, form_data_with_filters))
urls.append(get_url(chart, dashboard))

return urls

Expand Down Expand Up @@ -253,6 +215,30 @@ def get_urls(self) -> List[str]:
strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]


@celery_app.task(name="fetch_url")
def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]:
"""
Celery job to fetch url
"""
result = {}
try:
logger.info("Fetching %s", url)
req = request.Request(url, headers=headers)
response = request.urlopen( # pylint: disable=consider-using-with
req, timeout=600
)
logger.info("Fetched %s, status code: %s", url, response.code)
if response.code == 200:
result = {"success": url, "response": response.read().decode("utf-8")}
else:
result = {"error": url, "status_code": response.code}
logger.error("Error fetching %s, status code: %s", url, response.code)
except URLError as err:
logger.exception("Error warming up cache!")
result = {"error": url, "exception": str(err)}
return result


@celery_app.task(name="cache-warmup")
def cache_warmup(
strategy_name: str, *args: Any, **kwargs: Any
Expand Down Expand Up @@ -282,14 +268,18 @@ def cache_warmup(
logger.exception(message)
return message

results: Dict[str, List[str]] = {"success": [], "errors": []}
user = security_manager.get_user_by_username(app.config["THUMBNAIL_SELENIUM_USER"])
cookies = MachineAuthProvider.get_auth_cookies(user)
headers = {"Cookie": f"session={cookies.get('session', '')}"}

results: Dict[str, List[str]] = {"scheduled": [], "errors": []}
for url in strategy.get_urls():
try:
logger.info("Fetching %s", url)
request.urlopen(url) # pylint: disable=consider-using-with
results["success"].append(url)
except URLError:
logger.exception("Error warming up cache!")
logger.info("Scheduling %s", url)
fetch_url.delay(url, headers)
results["scheduled"].append(url)
except SchedulingError:
logger.exception("Error scheduling fetch_url: %s", url)
results["errors"].append(url)

return results
141 changes: 14 additions & 127 deletions tests/integration_tests/strategy_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes
from superset.tasks.cache import (
DashboardTagsStrategy,
get_form_data,
TopNDashboardsStrategy,
)
from superset.utils.urls import get_url_host

from .base_tests import SupersetTestCase
from .dashboard_utils import create_dashboard, create_slice, create_table_metadata
Expand All @@ -49,7 +49,6 @@
load_unicode_data,
)

URL_PREFIX = "http://0.0.0.0:8081"

mock_positions = {
"DASHBOARD_VERSION_KEY": "v2",
Expand All @@ -69,128 +68,6 @@


class TestCacheWarmUp(SupersetTestCase):
def test_get_form_data_chart_only(self):
chart_id = 1
result = get_form_data(chart_id, None)
expected = {"slice_id": chart_id}
self.assertEqual(result, expected)

def test_get_form_data_no_dashboard_metadata(self):
chart_id = 1
dashboard = MagicMock()
dashboard.json_metadata = None
dashboard.position_json = json.dumps(mock_positions)
result = get_form_data(chart_id, dashboard)
expected = {"slice_id": chart_id}
self.assertEqual(result, expected)

def test_get_form_data_immune_slice(self):
chart_id = 1
filter_box_id = 2
dashboard = MagicMock()
dashboard.position_json = json.dumps(mock_positions)
dashboard.json_metadata = json.dumps(
{
"filter_scopes": {
str(filter_box_id): {
"name": {"scope": ["ROOT_ID"], "immune": [chart_id]}
}
},
"default_filters": json.dumps(
{str(filter_box_id): {"name": ["Alice", "Bob"]}}
),
}
)
result = get_form_data(chart_id, dashboard)
expected = {"slice_id": chart_id}
self.assertEqual(result, expected)

def test_get_form_data_no_default_filters(self):
chart_id = 1
dashboard = MagicMock()
dashboard.json_metadata = json.dumps({})
dashboard.position_json = json.dumps(mock_positions)
result = get_form_data(chart_id, dashboard)
expected = {"slice_id": chart_id}
self.assertEqual(result, expected)

def test_get_form_data_immune_fields(self):
chart_id = 1
filter_box_id = 2
dashboard = MagicMock()
dashboard.position_json = json.dumps(mock_positions)
dashboard.json_metadata = json.dumps(
{
"default_filters": json.dumps(
{
str(filter_box_id): {
"name": ["Alice", "Bob"],
"__time_range": "100 years ago : today",
}
}
),
"filter_scopes": {
str(filter_box_id): {
"__time_range": {"scope": ["ROOT_ID"], "immune": [chart_id]}
}
},
}
)
result = get_form_data(chart_id, dashboard)
expected = {
"slice_id": chart_id,
"extra_filters": [{"col": "name", "op": "in", "val": ["Alice", "Bob"]}],
}
self.assertEqual(result, expected)

def test_get_form_data_no_extra_filters(self):
chart_id = 1
filter_box_id = 2
dashboard = MagicMock()
dashboard.position_json = json.dumps(mock_positions)
dashboard.json_metadata = json.dumps(
{
"default_filters": json.dumps(
{str(filter_box_id): {"__time_range": "100 years ago : today"}}
),
"filter_scopes": {
str(filter_box_id): {
"__time_range": {"scope": ["ROOT_ID"], "immune": [chart_id]}
}
},
}
)
result = get_form_data(chart_id, dashboard)
expected = {"slice_id": chart_id}
self.assertEqual(result, expected)

def test_get_form_data(self):
chart_id = 1
filter_box_id = 2
dashboard = MagicMock()
dashboard.position_json = json.dumps(mock_positions)
dashboard.json_metadata = json.dumps(
{
"default_filters": json.dumps(
{
str(filter_box_id): {
"name": ["Alice", "Bob"],
"__time_range": "100 years ago : today",
}
}
)
}
)
result = get_form_data(chart_id, dashboard)
expected = {
"slice_id": chart_id,
"extra_filters": [
{"col": "name", "op": "in", "val": ["Alice", "Bob"]},
{"col": "__time_range", "op": "==", "val": "100 years ago : today"},
],
}
self.assertEqual(result, expected)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_top_n_dashboards_strategy(self):
# create a top visited dashboard
Expand All @@ -202,7 +79,12 @@ def test_top_n_dashboards_strategy(self):

strategy = TopNDashboardsStrategy(1)
result = sorted(strategy.get_urls())
expected = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
expected = sorted(
[
f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}&dashboard_id={dash.id}"
for slc in dash.slices
]
)
self.assertEqual(result, expected)

def reset_tag(self, tag):
Expand All @@ -228,7 +110,12 @@ def test_dashboard_tags(self):
# tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagTypes.custom)
dash = self.get_dash_by_slug("births")
tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
tag1_urls = sorted(
[
f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"
for slc in dash.slices
]
)
tagged_object = TaggedObject(
tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard
)
Expand All @@ -248,7 +135,7 @@ def test_dashboard_tags(self):
# tag first slice
dash = self.get_dash_by_slug("unicode-test")
slc = dash.slices[0]
tag2_urls = [f"{URL_PREFIX}{slc.url}"]
tag2_urls = [f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"]
object_id = slc.id
tagged_object = TaggedObject(
tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart
Expand Down
2 changes: 2 additions & 0 deletions tests/integration_tests/superset_test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
"DRILL_TO_DETAIL": True,
}

WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"


def GET_FEATURE_FLAGS_FUNC(ff):
ff_copy = copy(ff)
Expand Down

0 comments on commit 04dd8d4

Please sign in to comment.