diff --git a/superset/config.py b/superset/config.py index d36fe30f6965..7b033689d3b2 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1015,6 +1015,7 @@ class CeleryConfig: # pylint: disable=too-few-public-methods "superset.tasks.scheduler", "superset.tasks.thumbnails", "superset.tasks.cache", + "superset.tasks.slack", ) result_backend = "db+sqlite:///celery_results.sqlite" worker_prefetch_multiplier = 1 @@ -1046,6 +1047,11 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # "schedule": crontab(minute="*", hour="*"), # "kwargs": {"retention_period_days": 180}, # }, + # Uncomment to enable Slack channel cache warm-up + # "slack.cache_channels": { + # "task": "slack.cache_channels", + # "schedule": crontab(minute="0", hour="*"), + # }, } @@ -1488,6 +1494,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # noq # Slack API token for the superset reports, either string or callable SLACK_API_TOKEN: Callable[[], str] | str | None = None SLACK_PROXY = None +SLACK_CACHE_TIMEOUT = int(timedelta(days=1).total_seconds()) # The webdriver to use for generating reports. Use one of the following # firefox diff --git a/superset/tasks/slack.py b/superset/tasks/slack.py new file mode 100644 index 000000000000..0b35a721bb51 --- /dev/null +++ b/superset/tasks/slack.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from flask import current_app + +from superset.extensions import celery_app +from superset.utils.slack import get_channels + +logger = logging.getLogger(__name__) + + +@celery_app.task(name="slack.cache_channels") +def cache_channels() -> None: + try: + get_channels( + force=True, cache_timeout=current_app.config["SLACK_CACHE_TIMEOUT"] + ) + except Exception as ex: + logger.exception("An error occurred while caching Slack channels: %s", ex) + raise diff --git a/superset/utils/slack.py b/superset/utils/slack.py index 8125a3ac4019..34d48bef21b1 100644 --- a/superset/utils/slack.py +++ b/superset/utils/slack.py @@ -17,7 +17,7 @@ import logging -from typing import Any, Optional +from typing import Callable, Optional from flask import current_app from slack_sdk import WebClient @@ -60,7 +60,7 @@ def get_slack_client() -> WebClient: key="slack_conversations_list", cache=cache_manager.cache, ) -def get_channels(limit: int, extra_params: dict[str, Any]) -> list[SlackChannelSchema]: +def get_channels() -> list[SlackChannelSchema]: """ Retrieves a list of all conversations accessible by the bot from the Slack API, and caches results (to avoid rate limits). 
@@ -71,11 +71,12 @@ def get_channels(limit: int, extra_params: dict[str, Any]) -> list[SlackChannelS client = get_slack_client() channel_schema = SlackChannelSchema() channels: list[SlackChannelSchema] = [] + extra_params = {"types": ",".join(SlackChannelTypes)} cursor = None while True: response = client.conversations_list( - limit=limit, cursor=cursor, exclude_archived=True, **extra_params + limit=999, cursor=cursor, exclude_archived=True, **extra_params ) channels.extend( channel_schema.load(channel) for channel in response.data["channels"] @@ -89,7 +90,6 @@ def get_channels(limit: int, extra_params: dict[str, Any]) -> list[SlackChannelS def get_channels_with_search( search_string: str = "", - limit: int = 999, types: Optional[list[SlackChannelTypes]] = None, exact_match: bool = False, force: bool = False, @@ -99,18 +99,25 @@ def get_channels_with_search( all channels and filter them ourselves This will search by slack name or id """ - extra_params = {} - extra_params["types"] = ",".join(types) if types else None try: channels = get_channels( - limit=limit, - extra_params=extra_params, force=force, - cache_timeout=86400, + cache_timeout=current_app.config["SLACK_CACHE_TIMEOUT"], ) except (SlackClientError, SlackApiError) as ex: raise SupersetException(f"Failed to list channels: {ex}") from ex + if types and set(types) != set(SlackChannelTypes): # ignore duplicate types + conditions: list[Callable[[SlackChannelSchema], bool]] = [] + if SlackChannelTypes.PUBLIC in types: + conditions.append(lambda channel: not channel["is_private"]) + if SlackChannelTypes.PRIVATE in types: + conditions.append(lambda channel: channel["is_private"]) + + channels = [ + channel for channel in channels if any(cond(channel) for cond in conditions) + ] + # The search string can be multiple channels separated by commas if search_string: search_array = recipients_string_to_list(search_string) diff --git a/tests/unit_tests/utils/slack_test.py b/tests/unit_tests/utils/slack_test.py index 
ed7a82c220c7..024d6cf96ee0 100644 --- a/tests/unit_tests/utils/slack_test.py +++ b/tests/unit_tests/utils/slack_test.py @@ -17,7 +17,7 @@ import pytest -from superset.utils.slack import get_channels_with_search +from superset.utils.slack import get_channels_with_search, SlackChannelTypes class MockResponse: @@ -150,15 +150,35 @@ def test_handle_slack_client_error_listing_channels(self, mocker): The server responded with: missing scope: channels:read""" ) - def test_filter_channels_by_specified_types(self, mocker): + @pytest.mark.parametrize( + "types, expected_channel_ids", + [ + ([SlackChannelTypes.PUBLIC], {"public_channel_id"}), + ([SlackChannelTypes.PRIVATE], {"private_channel_id"}), + ( + [SlackChannelTypes.PUBLIC, SlackChannelTypes.PRIVATE], + {"public_channel_id", "private_channel_id"}, + ), + ([], {"public_channel_id", "private_channel_id"}), + ], + ) + def test_filter_channels_by_specified_types( + self, types: list[SlackChannelTypes], expected_channel_ids: set[str], mocker + ): mock_data = { "channels": [ { - "id": "C12345", - "name": "general", + "id": "public_channel_id", + "name": "open", "is_member": False, "is_private": False, }, + { + "id": "private_channel_id", + "name": "secret", + "is_member": False, + "is_private": True, + }, ], "response_metadata": {"next_cursor": None}, } @@ -168,15 +188,8 @@ def test_filter_channels_by_specified_types(self, mocker): mock_client.conversations_list.return_value = mock_response_instance mocker.patch("superset.utils.slack.get_slack_client", return_value=mock_client) - result = get_channels_with_search(types=["public"]) - assert result == [ - { - "id": "C12345", - "name": "general", - "is_member": False, - "is_private": False, - } - ] + result = get_channels_with_search(types=types) + assert {channel["id"] for channel in result} == expected_channel_ids def test_handle_pagination_multiple_pages(self, mocker): mock_data_page1 = {