diff --git a/airflow/providers/slack/transfers/sql_to_slack.py b/airflow/providers/slack/transfers/sql_to_slack.py index bdd1cddd2b62c..cf5c01b22c9cf 100644 --- a/airflow/providers/slack/transfers/sql_to_slack.py +++ b/airflow/providers/slack/transfers/sql_to_slack.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Iterable, Mapping, Sequence from pandas import DataFrame @@ -25,13 +26,59 @@ from airflow.hooks.base import BaseHook from airflow.models import BaseOperator from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.slack.hooks.slack import SlackHook from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook +from airflow.providers.slack.utils import parse_filename if TYPE_CHECKING: from airflow.utils.context import Context -class SqlToSlackOperator(BaseOperator): +class BaseSqlToSlackOperator(BaseOperator): + """ + Operator implements base sql methods for SQL to Slack Transfer operators. + + :param sql: The SQL query to be executed + :param sql_conn_id: reference to a specific DB-API Connection. + :param sql_hook_params: Extra config params to be passed to the underlying hook. + Should match the desired hook constructor params. + :param parameters: The parameters to pass to the SQL query. 
+ """ + + def __init__( + self, + *, + sql: str, + sql_conn_id: str, + sql_hook_params: dict | None = None, + parameters: Iterable | Mapping | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.sql_conn_id = sql_conn_id + self.sql_hook_params = sql_hook_params + self.sql = sql + self.parameters = parameters + + def _get_hook(self) -> DbApiHook: + self.log.debug("Get connection for %s", self.sql_conn_id) + conn = BaseHook.get_connection(self.sql_conn_id) + hook = conn.get_hook(hook_params=self.sql_hook_params) + if not callable(getattr(hook, "get_pandas_df", None)): + raise AirflowException( + "This hook is not supported. The hook class must have get_pandas_df method." + ) + return hook + + def _get_query_results(self) -> DataFrame: + sql_hook = self._get_hook() + + self.log.info("Running SQL query: %s", self.sql) + df = sql_hook.get_pandas_df(self.sql, parameters=self.parameters) + return df + + +class SqlToSlackOperator(BaseSqlToSlackOperator): """ Executes an SQL statement in a given SQL connection and sends the results to Slack. 
The results of the query are rendered into the 'slack_message' parameter as a Pandas dataframe using a JINJA variable called @@ -79,12 +126,10 @@ def __init__( **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__( + sql=sql, sql_conn_id=sql_conn_id, sql_hook_params=sql_hook_params, parameters=parameters, **kwargs + ) - self.sql_conn_id = sql_conn_id - self.sql_hook_params = sql_hook_params - self.sql = sql - self.parameters = parameters self.slack_conn_id = slack_conn_id self.slack_webhook_token = slack_webhook_token self.slack_channel = slack_channel @@ -97,23 +142,6 @@ def __init__( "SqlToSlackOperator requires either a `slack_conn_id` or a `slack_webhook_token` argument" ) - def _get_hook(self) -> DbApiHook: - self.log.debug("Get connection for %s", self.sql_conn_id) - conn = BaseHook.get_connection(self.sql_conn_id) - hook = conn.get_hook(hook_params=self.sql_hook_params) - if not callable(getattr(hook, "get_pandas_df", None)): - raise AirflowException( - "This hook is not supported. The hook class must have get_pandas_df method." - ) - return hook - - def _get_query_results(self) -> DataFrame: - sql_hook = self._get_hook() - - self.log.info("Running SQL query: %s", self.sql) - df = sql_hook.get_pandas_df(self.sql, parameters=self.parameters) - return df - def _render_and_send_slack_message(self, context, df) -> None: # Put the dataframe into the context and render the JINJA template fields context[self.results_df_name] = df @@ -157,3 +185,115 @@ def execute(self, context: Context) -> None: self._render_and_send_slack_message(context, df) self.log.debug("Finished sending SQL data to Slack") + + +class SqlToSlackApiFileOperator(BaseSqlToSlackOperator): + """ + Executes an SQL statement in a given SQL connection and sends the results to Slack API as file. + + :param sql: The SQL query to be executed + :param sql_conn_id: reference to a specific DB-API Connection. + :param slack_conn_id: :ref:`Slack API Connection `. 
+    :param slack_filename: Filename for display in Slack.
+        Should contain a supported file extension, as listed in ``SUPPORTED_FILE_FORMATS``.
+        It is also possible to set compression in extension:
+        ``filename.csv.gzip``, ``filename.json.zip``, etc.
+    :param sql_hook_params: Extra config params to be passed to the underlying hook.
+        Should match the desired hook constructor params.
+    :param parameters: The parameters to pass to the SQL query.
+    :param slack_channels: Comma-separated list of channel names or IDs where the file will be shared.
+        If this parameter is omitted, the file will be sent to the workspace.
+    :param slack_initial_comment: The message text introducing the file in specified ``slack_channels``.
+    :param slack_title: Title of the file.
+    :param df_kwargs: Keyword arguments forwarded to ``pandas.DataFrame.to_{format}()`` method.
+
+    Example:
+     .. code-block:: python
+
+        SqlToSlackApiFileOperator(
+            task_id="sql_to_slack",
+            sql="SELECT 1 a, 2 b, 3 c",
+            sql_conn_id="sql-connection",
+            slack_conn_id="slack-api-connection",
+            slack_filename="awesome.json.gz",
+            slack_channels="#random,#general",
+            slack_initial_comment="Awesome load to compressed multiline JSON.",
+            df_kwargs={
+                "orient": "records",
+                "lines": True,
+            },
+        )
+    """
+
+    template_fields: Sequence[str] = (
+        "sql",
+        "slack_channels",
+        "slack_filename",
+        "slack_initial_comment",
+        "slack_title",
+    )
+    template_ext: Sequence[str] = (".sql", ".jinja", ".j2")
+    template_fields_renderers = {"sql": "sql", "slack_message": "jinja"}
+
+    SUPPORTED_FILE_FORMATS: Sequence[str] = ("csv", "json", "html")
+
+    def __init__(
+        self,
+        *,
+        sql: str,
+        sql_conn_id: str,
+        sql_hook_params: dict | None = None,
+        parameters: Iterable | Mapping | None = None,
+        slack_conn_id: str,
+        slack_filename: str,
+        slack_channels: str | Sequence[str] | None = None,
+        slack_initial_comment: str | None = None,
+        slack_title: str | None = None,
+        df_kwargs: dict | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            sql=sql, sql_conn_id=sql_conn_id, sql_hook_params=sql_hook_params, parameters=parameters, **kwargs
+        )
+        self.slack_conn_id = slack_conn_id
+        self.slack_filename = slack_filename
+        self.slack_channels = slack_channels
+        self.slack_initial_comment = slack_initial_comment
+        self.slack_title = slack_title
+        self.df_kwargs = df_kwargs or {}
+
+    def execute(self, context: Context) -> None:
+        # Parse file format from filename
+        output_file_format, _ = parse_filename(
+            filename=self.slack_filename,
+            supported_file_formats=self.SUPPORTED_FILE_FORMATS,
+        )
+
+        slack_hook = SlackHook(slack_conn_id=self.slack_conn_id)
+        with NamedTemporaryFile(mode="w+", suffix=f"_{self.slack_filename}") as fp:
+            # tempfile.NamedTemporaryFile is used only to create and remove the temporary file;
+            # pandas will open the file in the correct mode itself depending on the file type.
+            # So we close the file descriptor here to avoid accidentally writing anything.
+            fp.close()
+
+            output_file_name = fp.name
+            output_file_format = output_file_format.upper()
+            df_result = self._get_query_results()
+            if output_file_format == "CSV":
+                df_result.to_csv(output_file_name, **self.df_kwargs)
+            elif output_file_format == "JSON":
+                df_result.to_json(output_file_name, **self.df_kwargs)
+            elif output_file_format == "HTML":
+                df_result.to_html(output_file_name, **self.df_kwargs)
+            else:
+                # This error is not expected to happen: it is only possible
+                # if SUPPORTED_FILE_FORMATS is extended without an actual implementation for the new format.
+ raise AirflowException(f"Unexpected output file format: {output_file_format}") + + slack_hook.send_file( + channels=self.slack_channels, + file=output_file_name, + filename=self.slack_filename, + initial_comment=self.slack_initial_comment, + title=self.slack_title, + ) diff --git a/airflow/providers/slack/utils/__init__.py b/airflow/providers/slack/utils/__init__.py index dda12656d48ea..1071de6299c27 100644 --- a/airflow/providers/slack/utils/__init__.py +++ b/airflow/providers/slack/utils/__init__.py @@ -17,7 +17,7 @@ from __future__ import annotations import warnings -from typing import Any +from typing import Any, Sequence from airflow.utils.types import NOTSET @@ -77,3 +77,41 @@ def getint(self, field, default: Any = NOTSET) -> Any: if value != default: value = int(value) return value + + +def parse_filename( + filename: str, supported_file_formats: Sequence[str], fallback: str | None = None +) -> tuple[str, str | None]: + """ + Parse filetype and compression from given filename. + :param filename: filename to parse. + :param supported_file_formats: list of supported file extensions. + :param fallback: fallback to given file format. + :returns: filetype and compression (if specified) + """ + if not filename: + raise ValueError("Expected 'filename' parameter is missing.") + if fallback and fallback not in supported_file_formats: + raise ValueError(f"Invalid fallback value {fallback!r}, expected one of {supported_file_formats}.") + + parts = filename.rsplit(".", 2) + try: + if len(parts) == 1: + raise ValueError(f"No file extension specified in filename {filename!r}.") + if parts[-1] in supported_file_formats: + return parts[-1], None + elif len(parts) == 2: + raise ValueError( + f"Unsupported file format {parts[-1]!r}, expected one of {supported_file_formats}." 
+ ) + else: + if parts[-2] not in supported_file_formats: + raise ValueError( + f"Unsupported file format '{parts[-2]}.{parts[-1]}', " + f"expected one of {supported_file_formats} with compression extension." + ) + return parts[-2], parts[-1] + except ValueError as ex: + if fallback: + return fallback, None + raise ex from None diff --git a/tests/providers/slack/transfers/test_sql_to_slack.py b/tests/providers/slack/transfers/test_sql_to_slack.py index 307469460b4fd..23efa895a2214 100644 --- a/tests/providers/slack/transfers/test_sql_to_slack.py +++ b/tests/providers/slack/transfers/test_sql_to_slack.py @@ -23,7 +23,11 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, Connection -from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackOperator +from airflow.providers.slack.transfers.sql_to_slack import ( + BaseSqlToSlackOperator, + SqlToSlackApiFileOperator, + SqlToSlackOperator, +) from airflow.utils import timezone TEST_DAG_ID = "sql_to_slack_unit_test" @@ -31,6 +35,77 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) +class TestBaseSqlToSlackOperator: + def setup_method(self): + self.default_op_kwargs = { + "sql": "SELECT 1", + "sql_conn_id": "test-sql-conn-id", + "sql_hook_params": None, + "parameters": None, + } + + def test_execute_not_implemented(self): + """Test that no base implementation for ``BaseSqlToSlackOperator.execute()``.""" + op = BaseSqlToSlackOperator(task_id="test_base_not_implements", **self.default_op_kwargs) + with pytest.raises(NotImplementedError): + op.execute(mock.MagicMock()) + + @mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_hook") + @pytest.mark.parametrize("conn_type", ["postgres", "snowflake"]) + @pytest.mark.parametrize("sql_hook_params", [None, {"foo": "bar"}]) + def test_get_hook(self, mock_get_hook, mock_get_conn, conn_type, sql_hook_params): + class SomeDummyHook: + """Hook which implements 
``get_pandas_df`` method""" + + def get_pandas_df(self): + pass + + expected_hook = SomeDummyHook() + mock_get_conn.return_value = Connection(conn_id=f"test_connection_{conn_type}", conn_type=conn_type) + mock_get_hook.return_value = expected_hook + op_kwargs = { + **self.default_op_kwargs, + "sql_hook_params": sql_hook_params, + } + op = BaseSqlToSlackOperator(task_id="test_get_hook", **op_kwargs) + hook = op._get_hook() + mock_get_hook.assert_called_once_with(hook_params=sql_hook_params) + assert hook == expected_hook + + @mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_hook") + def test_get_not_supported_hook(self, mock_get_hook, mock_get_conn): + class SomeDummyHook: + """Hook which not implemented ``get_pandas_df`` method""" + + mock_get_conn.return_value = Connection(conn_id="test_connection", conn_type="test_connection") + mock_get_hook.return_value = SomeDummyHook() + op = BaseSqlToSlackOperator(task_id="test_get_not_supported_hook", **self.default_op_kwargs) + error_message = r"This hook is not supported. The hook class must have get_pandas_df method\." 
+ with pytest.raises(AirflowException, match=error_message): + op._get_hook() + + @mock.patch("airflow.providers.slack.transfers.sql_to_slack.BaseSqlToSlackOperator._get_hook") + @pytest.mark.parametrize("sql", ["SELECT 42", "SELECT 1 FROM DUMMY WHERE col = ?"]) + @pytest.mark.parametrize("parameters", [None, {"col": "spam-egg"}]) + def test_get_query_results(self, mock_op_get_hook, sql, parameters): + test_df = pd.DataFrame({"a": "1", "b": "2"}, index=[0, 1]) + mock_get_pandas_df = mock.MagicMock(return_value=test_df) + mock_hook = mock.MagicMock() + mock_hook.get_pandas_df = mock_get_pandas_df + mock_op_get_hook.return_value = mock_hook + op_kwargs = { + **self.default_op_kwargs, + "sql": sql, + "parameters": parameters, + } + op = BaseSqlToSlackOperator(task_id="test_get_query_results", **op_kwargs) + df = op._get_query_results() + mock_get_pandas_df.assert_called_once_with(sql, parameters=parameters) + assert df is test_df + + class TestSqlToSlackOperator: def setup_method(self): self.example_dag = DAG(TEST_DAG_ID, start_date=DEFAULT_DATE) @@ -215,3 +290,90 @@ def test_hook_params_snowflake(self, mock_get_conn): assert hook.database == "database" assert hook.role == "role" assert hook.schema == "schema" + + +class TestSqlToSlackApiFileOperator: + def setup_method(self): + self.default_op_kwargs = { + "sql": "SELECT 1", + "sql_conn_id": "test-sql-conn-id", + "slack_conn_id": "test-slack-conn-id", + "sql_hook_params": None, + "parameters": None, + } + + @mock.patch("airflow.providers.slack.transfers.sql_to_slack.BaseSqlToSlackOperator._get_query_results") + @mock.patch("airflow.providers.slack.transfers.sql_to_slack.SlackHook") + @pytest.mark.parametrize( + "filename,df_method", + [ + ("awesome.json", "to_json"), + ("awesome.json.zip", "to_json"), + ("awesome.csv", "to_csv"), + ("awesome.csv.xz", "to_csv"), + ("awesome.html", "to_html"), + ], + ) + @pytest.mark.parametrize("df_kwargs", [None, {}, {"foo": "bar"}]) + @pytest.mark.parametrize("channels", ["#random", 
"#random,#general", None]) + @pytest.mark.parametrize("initial_comment", [None, "Test Comment"]) + @pytest.mark.parametrize("title", [None, "Test File Title"]) + def test_send_file( + self, + mock_slack_hook_cls, + mock_get_query_results, + filename, + df_method, + df_kwargs, + channels, + initial_comment, + title, + ): + # Mock Hook + mock_send_file = mock.MagicMock() + mock_slack_hook_cls.return_value.send_file = mock_send_file + + # Mock returns pandas.DataFrame and expected method + mock_df = mock.MagicMock() + mock_df_output_method = mock.MagicMock() + setattr(mock_df, df_method, mock_df_output_method) + mock_get_query_results.return_value = mock_df + + op_kwargs = { + **self.default_op_kwargs, + "slack_conn_id": "expected-test-slack-conn-id", + "slack_filename": filename, + "slack_channels": channels, + "slack_initial_comment": initial_comment, + "slack_title": title, + "df_kwargs": df_kwargs, + } + op = SqlToSlackApiFileOperator(task_id="test_send_file", **op_kwargs) + op.execute(mock.MagicMock()) + + mock_slack_hook_cls.assert_called_once_with(slack_conn_id="expected-test-slack-conn-id") + mock_get_query_results.assert_called_once_with() + mock_df_output_method.assert_called_once_with(mock.ANY, **(df_kwargs or {})) + mock_send_file.assert_called_once_with( + channels=channels, + filename=filename, + initial_comment=initial_comment, + title=title, + file=mock.ANY, + ) + + @pytest.mark.parametrize( + "filename", + [ + "foo.parquet", + "bat.parquet.snappy", + "spam.xml", + "egg.xlsx", + ], + ) + def test_unsupported_format(self, filename): + op = SqlToSlackApiFileOperator( + task_id="test_send_file", slack_filename=filename, **self.default_op_kwargs + ) + with pytest.raises(ValueError): + op.execute(mock.MagicMock()) diff --git a/tests/providers/slack/utils/test_utils.py b/tests/providers/slack/utils/test_utils.py index d794c80f60c10..bff3dbc658e80 100644 --- a/tests/providers/slack/utils/test_utils.py +++ b/tests/providers/slack/utils/test_utils.py @@ -18,7 
+18,7 @@ import pytest -from airflow.providers.slack.utils import ConnectionExtraConfig +from airflow.providers.slack.utils import ConnectionExtraConfig, parse_filename class TestConnectionExtra: @@ -92,3 +92,55 @@ def test_get_parse_int(self): ) assert extra_config.getint("int_arg_1") == 42 assert extra_config.getint("int_arg_2") == 9000 + + +class TestParseFilename: + SUPPORTED_FORMAT = ("so", "dll", "exe", "sh") + + def test_error_parse_without_extension(self): + with pytest.raises(ValueError, match="No file extension specified in filename"): + assert parse_filename("Untitled File", self.SUPPORTED_FORMAT) + + @pytest.mark.parametrize( + "filename,expected_format", + [ + ("libc.so", "so"), + ("kernel32.dll", "dll"), + ("xxx.mp4.exe", "exe"), + ("init.sh", "sh"), + ], + ) + def test_parse_first_level(self, filename, expected_format): + assert parse_filename(filename, self.SUPPORTED_FORMAT) == (expected_format, None) + + @pytest.mark.parametrize("filename", ["New File.txt", "cats-memes.mp4"]) + def test_error_parse_first_level(self, filename): + with pytest.raises(ValueError, match="Unsupported file format"): + assert parse_filename(filename, self.SUPPORTED_FORMAT) + + @pytest.mark.parametrize( + "filename,expected", + [ + ("libc.so.6", ("so", "6")), + ("kernel32.dll.zip", ("dll", "zip")), + ("explorer.exe.7z", ("exe", "7z")), + ("init.sh.gz", ("sh", "gz")), + ], + ) + def test_parse_second_level(self, filename, expected): + assert parse_filename(filename, self.SUPPORTED_FORMAT) == expected + + @pytest.mark.parametrize("filename", ["example.so.tar.gz", "w.i.e.r.d"]) + def test_error_parse_second_level(self, filename): + with pytest.raises(ValueError, match="Unsupported file format.*with compression extension."): + assert parse_filename(filename, self.SUPPORTED_FORMAT) + + @pytest.mark.parametrize("filename", ["Untitled File", "New File.txt", "example.so.tar.gz"]) + @pytest.mark.parametrize("fallback", SUPPORTED_FORMAT) + def test_fallback(self, filename, fallback): 
+ assert parse_filename(filename, self.SUPPORTED_FORMAT, fallback) == (fallback, None) + + @pytest.mark.parametrize("filename", ["Untitled File", "New File.txt", "example.so.tar.gz"]) + def test_wrong_fallback(self, filename): + with pytest.raises(ValueError, match="Invalid fallback value"): + assert parse_filename(filename, self.SUPPORTED_FORMAT, "mp4")