diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 1c67350c4dfc1..fd45c73355c92 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -114,6 +114,7 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
         # Hook deriving from the DBApiHook to still have access to the field in it's constructor
         self.__schema = schema
         self.log_sql = log_sql
+        self.running_query_ids: list[str] = []
 
     def get_conn(self):
         """Returns a connection object"""
@@ -244,6 +245,7 @@ def run(
         :param return_last: Whether to return result for only last statement or for all after split
         :return: return only result of the ALL SQL expressions if handler was provided.
         """
+        self.running_query_ids = []
         scalar_return_last = isinstance(sql, str) and return_last
         if isinstance(sql, str):
             if split_statements:
@@ -264,6 +266,7 @@ def run(
                 results = []
                 for sql_statement in sql:
                     self._run_command(cur, sql_statement, parameters)
+                    self._update_query_ids(cur)
 
                     if handler is not None:
                         result = handler(cur)
@@ -294,6 +297,22 @@ def _run_command(self, cur, sql_statement, parameters):
         if cur.rowcount >= 0:
             self.log.info("Rows affected: %s", cur.rowcount)
 
+    def _update_query_ids(self, cursor) -> None:
+        """
+        Add the id of the query just executed by the cursor to ``running_query_ids``.
+
+        :param cursor: the cursor after the statement has run
+        :return: None; the base implementation is a no-op
+        """
+        return None
+
+    def kill_query(self, query_id) -> Any:
+        """
+        Cancel the running query with the given identifier.
+
+        :param query_id: identifier of the query to cancel
+        """
+        raise NotImplementedError
+
     def set_autocommit(self, conn, autocommit):
         """Sets the autocommit flag on the connection"""
         if not self.supports_autocommit and autocommit:
diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 8dee6ed968e57..c779c3d7fe94f 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -197,9 +197,8 @@ def __init__(
 
     def execute(self, context):
         self.log.info("Executing: %s", self.sql)
-        hook = self.get_db_hook()
         if self.do_xcom_push:
-            output = hook.run(
+            output = self._hook.run(
                 sql=self.sql,
                 autocommit=self.autocommit,
                 parameters=self.parameters,
@@ -208,7 +207,7 @@ def execute(self, context):
                 return_last=self.return_last,
             )
         else:
-            output = hook.run(
+            output = self._hook.run(
                 sql=self.sql,
                 autocommit=self.autocommit,
                 parameters=self.parameters,
@@ -221,6 +220,16 @@ def execute(self, context):
 
         return output
 
+    def on_kill(self) -> None:
+        for query_id in self._hook.running_query_ids.copy():
+            self.log.info("Stopping query: %s", query_id)
+            try:
+                self._hook.kill_query(query_id)
+            except NotImplementedError:
+                self.log.info("Method 'kill_query' is not implemented for %s", self._hook.__class__.__name__)
+            except Exception as e:
+                self.log.info("The query '%s' cannot be killed: %s", query_id, str(e))
+
     def prepare_template(self) -> None:
         """Parse template file for attribute parameters."""
         if isinstance(self.parameters, str):
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index e2b9101789653..ff01dca1f02fc 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -24,6 +24,7 @@
 
 from databricks.sql.utils import ParamEscaper
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
@@ -115,7 +116,8 @@ def __init__(
             **hook_params,
         }
 
-    def get_db_hook(self) -> DatabricksSqlHook:
+    @cached_property
+    def _hook(self) -> DatabricksSqlHook:
         return DatabricksSqlHook(self.databricks_conn_id, **self.hook_params)
 
     def _process_output(self, schema, results):
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 825b28ad60a54..fab3a8132eedd 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -196,6 +196,16 @@ def run(
             return_last=return_last,
         )
 
+    def _update_query_ids(self, cursor) -> None:
+        self.running_query_ids.append(cursor.stats["queryId"])
+
+    def kill_query(self, query_id) -> Any:
+        result = super().run(
+            sql=f"CALL system.runtime.kill_query(query_id => '{query_id}', message => 'Job killed by user');",
+            handler=list,
+        )
+        return result
+
     def insert_rows(
         self,
         table: str,
diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py
index 138025a455652..24ac99a6b8b64 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 import os
-from contextlib import closing
 from functools import wraps
 from io import StringIO
 from pathlib import Path
@@ -27,7 +26,7 @@
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from snowflake import connector
-from snowflake.connector import DictCursor, SnowflakeConnection, util_text
+from snowflake.connector import SnowflakeConnection, util_text
 from snowflake.sqlalchemy import URL
 from sqlalchemy import create_engine
 
@@ -178,7 +177,7 @@ def __init__(self, *args, **kwargs) -> None:
         self.schema = kwargs.pop("schema", None)
         self.authenticator = kwargs.pop("authenticator", None)
         self.session_parameters = kwargs.pop("session_parameters", None)
-        self.query_ids: list[str] = []
+        self.running_query_ids: list[str] = []
 
     def _get_field(self, extra_dict, field_name):
         backcompat_prefix = "extra__snowflake__"
@@ -321,6 +320,21 @@ def set_autocommit(self, conn, autocommit: Any) -> None:
     def get_autocommit(self, conn):
         return getattr(conn, "autocommit_mode", False)
 
+    @staticmethod
+    def split_sql_string(sql: str) -> list[str]:
+        split_statements_tuple = util_text.split_statements(StringIO(sql))
+        return [sql_string for sql_string, _ in split_statements_tuple if sql_string]
+
+    def _update_query_ids(self, cursor) -> None:
+        self.running_query_ids.append(cursor.sfqid)
+
+    def kill_query(self, query_id) -> Any:
+        result = super().run(
+            sql=f"CALL system$cancel_query('{query_id}');",
+            handler=list,
+        )
+        return result
+
     def run(
         self,
         sql: str | Iterable[str],
@@ -330,64 +344,11 @@ def run(
         self,
         sql: str | Iterable[str],
         autocommit: bool = False,
         parameters: Iterable | Mapping | None = None,
         handler: Callable | None = None,
         split_statements: bool = True,
         return_last: bool = True,
     ) -> Any | list[Any] | None:
-        """
-        Runs a command or a list of commands. Pass a list of sql
-        statements to the sql parameter to get them to execute
-        sequentially. The variable execution_info is returned so that
-        it can be used in the Operators to modify the behavior
-        depending on the result of the query (i.e fail the operator
-        if the copy has processed 0 files)
-
-        :param sql: the sql string to be executed with possibly multiple statements,
-            or a list of sql statements to execute
-        :param autocommit: What to set the connection's autocommit setting to
-            before executing the query.
-        :param parameters: The parameters to render the SQL query with.
-        :param handler: The result handler which is called with the result of each statement.
-        :param split_statements: Whether to split a single SQL string into statements and run separately
-        :param return_last: Whether to return result for only last statement or for all after split
-        :return: return only result of the LAST SQL expression if handler was provided.
-        """
-        self.query_ids = []
-
-        scalar_return_last = isinstance(sql, str) and return_last
-        if isinstance(sql, str):
-            if split_statements:
-                split_statements_tuple = util_text.split_statements(StringIO(sql))
-                sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string]
-            else:
-                sql = [self.strip_sql_string(sql)]
-
-        if sql:
-            self.log.debug("Executing following statements against Snowflake DB: %s", list(sql))
-        else:
-            raise ValueError("List of SQL statements is empty")
-
-        with closing(self.get_conn()) as conn:
-            self.set_autocommit(conn, autocommit)
-
-            # SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here
-            with closing(conn.cursor(DictCursor)) as cur:  # type: ignore[type-var]
-                results = []
-                for sql_statement in sql:
-                    self._run_command(cur, sql_statement, parameters)
-
-                    if handler is not None:
-                        result = handler(cur)
-                        results.append(result)
-
-                    query_id = cur.sfqid
-                    self.log.info("Rows affected: %s", cur.rowcount)
-                    self.log.info("Snowflake query id: %s", query_id)
-                    self.query_ids.append(query_id)
-
-            # If autocommit was set to False or db does not support autocommit, we do a manual commit.
-            if not self.get_autocommit(conn):
-                conn.commit()
-
-        if handler is None:
-            return None
-        elif scalar_return_last:
-            return results[-1]
-        else:
-            return results
+        return super().run(
+            sql=sql,
+            autocommit=autocommit,
+            parameters=parameters,
+            handler=handler,
+            split_statements=split_statements,
+            return_last=return_last,
+        )
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index cf7835ef65518..c171beb3a663c 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -191,7 +191,7 @@ def __init__(
         self.schema = schema
         self.authenticator = authenticator
         self.session_parameters = session_parameters
-        self.query_ids: list[str] = []
+        self.running_query_ids: list[str] = []
 
 
 class SnowflakeValueCheckOperator(SQLValueCheckOperator):
@@ -257,7 +257,7 @@ def __init__(
         self.schema = schema
         self.authenticator = authenticator
         self.session_parameters = session_parameters
-        self.query_ids: list[str] = []
+        self.running_query_ids: list[str] = []
 
 
 class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
@@ -336,4 +336,4 @@ def __init__(
         self.schema = schema
         self.authenticator = authenticator
         self.session_parameters = session_parameters
-        self.query_ids: list[str] = []
+        self.running_query_ids: list[str] = []
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 63c75446ead19..74b7ba5775ca3 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -211,6 +211,16 @@ def run(
             return_last=return_last,
         )
 
+    def _update_query_ids(self, cursor) -> None:
+        self.running_query_ids.append(cursor.stats["queryId"])
+
+    def kill_query(self, query_id) -> Any:
+        result = super().run(
+            sql=f"CALL system.runtime.kill_query(query_id => '{query_id}', message => 'Job killed by user');",
+            handler=list,
+        )
+        return result
+
     def insert_rows(
         self,
         table: str,
diff --git a/airflow/providers/trino/operators/trino.py b/airflow/providers/trino/operators/trino.py
index fc2f496ff8118..f9004e99af9c9 100644
--- a/airflow/providers/trino/operators/trino.py
+++ b/airflow/providers/trino/operators/trino.py
@@ -21,10 +21,7 @@
 import warnings
 from typing import Any, Sequence
 
-from trino.exceptions import TrinoQueryError
-
 from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
-from airflow.providers.trino.hooks.trino import TrinoHook
 
 
 class TrinoOperator(SQLExecuteQueryOperator):
@@ -58,18 +55,3 @@ def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs: Any) -> No
             DeprecationWarning,
             stacklevel=2,
         )
-
-    def on_kill(self) -> None:
-        if self._hook is not None and isinstance(self._hook, TrinoHook):
-            query_id = "'" + self._hook.query_id + "'"
-            try:
-                self.log.info("Stopping query run with queryId - %s", self._hook.query_id)
-                self._hook.run(
-                    sql=f"CALL system.runtime.kill_query(query_id => {query_id},message => 'Job "
-                    f"killed by "
-                    f"user');",
-                    handler=list,
-                )
-            except TrinoQueryError as e:
-                self.log.info(str(e))
-            self.log.info("Trino query (%s) terminated", query_id)
diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py
index 5c323aa8434b4..cf23883fd5720 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -18,7 +18,6 @@
 
 import unittest
 from unittest import mock
-from unittest.mock import MagicMock
 
 from parameterized import parameterized
 
@@ -28,18 +27,14 @@
 
 class TestRedshiftSQLOperator(unittest.TestCase):
     @parameterized.expand([(True, ("a", "b")), (False, ("c", "d"))])
-    @mock.patch("airflow.providers.amazon.aws.operators.redshift_sql.RedshiftSQLOperator.get_db_hook")
-    def test_redshift_operator(self, test_autocommit, test_parameters, mock_get_hook):
-        hook = MagicMock()
-        mock_run = hook.run
-        mock_get_hook.return_value = hook
-        sql = MagicMock()
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_redshift_operator(self, test_autocommit, test_parameters, mock_hook):
         operator = RedshiftSQLOperator(
-            task_id="test", sql=sql, autocommit=test_autocommit, parameters=test_parameters
+            task_id="test", sql="SELECT 1", autocommit=test_autocommit, parameters=test_parameters
         )
         operator.execute(None)
-        mock_run.assert_called_once_with(
-            sql=sql,
+        mock_hook.run.assert_called_once_with(
+            sql="SELECT 1",
             autocommit=test_autocommit,
             parameters=test_parameters,
             handler=fetch_all_handler,
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 3741a93ed36ac..d3c96a2b9ee1e 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -66,12 +66,12 @@ def _construct_operator(self, sql, **kwargs):
             dag=dag,
         )
 
-    @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
-    def test_do_xcom_push(self, mock_get_db_hook):
+    @mock.patch.object(SQLExecuteQueryOperator, "_hook")
+    def test_do_xcom_push(self, mock_hook):
         operator = self._construct_operator("SELECT 1;", do_xcom_push=True)
         operator.execute(context=MagicMock())
 
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql="SELECT 1;",
             autocommit=False,
             handler=fetch_all_handler,
@@ -80,12 +80,12 @@ def test_do_xcom_push(self, mock_get_db_hook):
             split_statements=False,
         )
 
-    @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
-    def test_dont_xcom_push(self, mock_get_db_hook):
+    @mock.patch.object(SQLExecuteQueryOperator, "_hook")
+    def test_dont_xcom_push(self, mock_hook):
         operator = self._construct_operator("SELECT 1;", do_xcom_push=False)
         operator.execute(context=MagicMock())
 
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql="SELECT 1;",
             autocommit=False,
             parameters=None,
diff --git a/tests/providers/exasol/operators/test_exasol.py b/tests/providers/exasol/operators/test_exasol.py
index 6f1d00c999a96..ebd5136b06ad4 100644
--- a/tests/providers/exasol/operators/test_exasol.py
+++ b/tests/providers/exasol/operators/test_exasol.py
@@ -25,11 +25,11 @@
 
 class TestExasol(unittest.TestCase):
-    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_overwrite_autocommit(self, mock_get_db_hook):
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_overwrite_autocommit(self, mock_hook):
         operator = ExasolOperator(task_id="TEST", sql="SELECT 1", autocommit=True)
         operator.execute({})
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql="SELECT 1",
             autocommit=True,
             parameters=None,
@@ -38,11 +38,11 @@ def test_overwrite_autocommit(self, mock_get_db_hook):
             split_statements=False,
         )
 
-    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_pass_parameters(self, mock_get_db_hook):
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_pass_parameters(self, mock_hook):
         operator = ExasolOperator(task_id="TEST", sql="SELECT {value!s}", parameters={"value": 1})
         operator.execute({})
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql="SELECT {value!s}",
             autocommit=False,
             parameters={"value": 1},
diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py
index b9339ce584850..5cab48215f7d2 100644
--- a/tests/providers/jdbc/operators/test_jdbc.py
+++ b/tests/providers/jdbc/operators/test_jdbc.py
@@ -28,12 +28,12 @@ class TestJdbcOperator(unittest.TestCase):
     def setUp(self):
         self.kwargs = dict(sql="sql", task_id="test_jdbc_operator", dag=None)
 
-    @patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_execute_do_push(self, mock_get_db_hook):
+    @patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_execute_do_push(self, mock_hook):
         jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=True)
         jdbc_operator.execute(context={})
 
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql=jdbc_operator.sql,
             autocommit=jdbc_operator.autocommit,
             handler=fetch_all_handler,
@@ -42,12 +42,12 @@ def test_execute_do_push(self, mock_get_db_hook):
             split_statements=False,
         )
 
-    @patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_execute_dont_push(self, mock_get_db_hook):
+    @patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_execute_dont_push(self, mock_hook):
         jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=False)
         jdbc_operator.execute(context={})
 
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql=jdbc_operator.sql,
             autocommit=jdbc_operator.autocommit,
             parameters=jdbc_operator.parameters,
diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py
index ba5e82df3999b..0ed303a2c250e 100644
--- a/tests/providers/oracle/operators/test_oracle.py
+++ b/tests/providers/oracle/operators/test_oracle.py
@@ -25,8 +25,8 @@
 
 class TestOracleOperator(unittest.TestCase):
-    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_execute(self, mock_get_db_hook):
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_execute(self, mock_hook):
         sql = "SELECT * FROM test_table"
         oracle_conn_id = "oracle_default"
         parameters = {"parameter": "value"}
@@ -42,7 +42,7 @@ def test_execute(self, mock_get_db_hook):
             task_id=task_id,
         )
         operator.execute(context=context)
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql=sql,
             autocommit=autocommit,
             parameters=parameters,
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py
index 87c661043a643..7d253a5bebfc1 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -550,7 +550,7 @@ def test_run_storing_query_ids_extra(self, mock_conn, sql, expected_sql, expecte
         hook.run(sql, parameters=mock_params)
 
         cur.execute.assert_has_calls([mock.call(query, mock_params) for query in expected_sql])
-        assert hook.query_ids == expected_query_ids
+        assert hook.running_query_ids == expected_query_ids
         cur.close.assert_called()
 
     @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py
index 7cd2d040cfbc7..23b6a6921dd70 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -44,8 +44,8 @@ def setUp(self):
         dag = DAG(TEST_DAG_ID, default_args=args)
         self.dag = dag
 
-    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_snowflake_operator(self, mock_get_db_hook):
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_snowflake_operator(self, mock_hook):
         sql = """
         CREATE TABLE IF NOT EXISTS test_airflow (
             dummy VARCHAR(50)
diff --git a/tests/providers/trino/operators/test_trino.py b/tests/providers/trino/operators/test_trino.py
index 3caa2dad8ee6c..d491e79acabf3 100644
--- a/tests/providers/trino/operators/test_trino.py
+++ b/tests/providers/trino/operators/test_trino.py
@@ -27,8 +27,8 @@
 
 class TestTrinoOperator(unittest.TestCase):
-    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_execute(self, mock_get_db_hook):
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_execute(self, mock_hook):
         """Asserts that the run method is called when a TrinoOperator task is
         executed"""
         op = TrinoOperator(
@@ -39,7 +39,7 @@
         )
         op.execute(None)
 
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql="SELECT 1;",
             autocommit=False,
             handler=list,
diff --git a/tests/providers/vertica/operators/test_vertica.py b/tests/providers/vertica/operators/test_vertica.py
index 836f324cb8289..1562a422ba9e5 100644
--- a/tests/providers/vertica/operators/test_vertica.py
+++ b/tests/providers/vertica/operators/test_vertica.py
@@ -25,12 +25,12 @@
 
 class TestVerticaOperator(unittest.TestCase):
-    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_execute(self, mock_get_db_hook):
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
+    def test_execute(self, mock_hook):
         sql = "select a, b, c"
         op = VerticaOperator(task_id="test_task_id", sql=sql)
         op.execute(None)
-        mock_get_db_hook.return_value.run.assert_called_once_with(
+        mock_hook.run.assert_called_once_with(
             sql=sql,
             autocommit=False,
             handler=fetch_all_handler,
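
For provider authors, the contract this patch adds to `DbApiHook` is small: `run()` clears `running_query_ids` before executing and calls `_update_query_ids(cur)` after each statement, and `SQLExecuteQueryOperator.on_kill()` walks the recorded ids and calls `kill_query()` on each. Below is a minimal sketch of a third-party hook opting into cancellation, in the spirit of the Presto, Trino, and Snowflake implementations above; the `MyDbHook` class, the `cursor.query_id` attribute, and the `CANCEL QUERY` statement are hypothetical stand-ins for whatever a real driver and backend expose.

```python
from typing import Any

from airflow.providers.common.sql.hooks.sql import DbApiHook


class MyDbHook(DbApiHook):
    """Hypothetical hook for a backend that reports a per-statement query id."""

    def _update_query_ids(self, cursor) -> None:
        # run() calls this after every statement; record the id the DB-API
        # driver exposes (the attribute name here is driver-specific and
        # purely illustrative).
        self.running_query_ids.append(cursor.query_id)

    def kill_query(self, query_id) -> Any:
        # Issue the backend's cancellation command through the base run(),
        # mirroring the Presto/Trino/Snowflake hooks. The CANCEL QUERY syntax
        # is an assumption, not tied to any real database.
        return super().run(sql=f"CANCEL QUERY '{query_id}';", handler=list)
```

With this in place no per-provider operator code is needed: hooks that do not override `kill_query` keep the base `NotImplementedError`, which `SQLExecuteQueryOperator.on_kill()` catches, logs, and skips.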