19 changes: 19 additions & 0 deletions airflow/providers/common/sql/hooks/sql.py
@@ -114,6 +114,7 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
# Hook deriving from the DBApiHook to still have access to the field in its constructor
self.__schema = schema
self.log_sql = log_sql
self.running_query_ids: list[str] = []

def get_conn(self):
"""Returns a connection object"""
@@ -244,6 +245,7 @@ def run(
:param return_last: Whether to return result for only last statement or for all after split
:return: the results of all SQL expressions if handler was provided.
"""
self.running_query_ids = []
scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
if split_statements:
@@ -264,6 +266,7 @@
results = []
for sql_statement in sql:
self._run_command(cur, sql_statement, parameters)
self._update_query_ids(cur)
Contributor:

At this point, the query is done, so what is the utility of updating the query id list here?

Contributor Author:

You're right. I checked it for Snowflake.

[screenshot: photo_2022-11-13_00-53-28]

Contributor Author:

Do you have access to any Trino or Presto instances to check this for those databases as well?

Contributor:

I think we need to address this issue, where the queries aren't actually running and therefore will not be killed, before adding this interface.
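The crux of the thread: _run_command calls cursor.execute(), which blocks until the statement completes, so any id recorded afterwards belongs to a finished query. A minimal sketch of capturing an id that is still cancellable, assuming snowflake-connector-python's asynchronous API (execute_async and sfqid are connector features; the connection parameters and procedure name are placeholders):

    # Sketch only: record the query id while the statement is still in flight.
    import snowflake.connector

    conn = snowflake.connector.connect(account="xy12345", user="me", password="***")  # placeholders

    cur = conn.cursor()
    # execute_async submits the statement and returns immediately, so cur.sfqid
    # identifies a query that may still be running on the warehouse.
    cur.execute_async("CALL long_running_procedure()")
    running_query_ids = [cur.sfqid]

    # An id captured this way can still be cancelled later:
    conn.cursor().execute(f"SELECT system$cancel_query('{running_query_ids[0]}')")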


if handler is not None:
result = handler(cur)
@@ -294,6 +297,22 @@ def _run_command(self, cur, sql_statement, parameters):
if cur.rowcount >= 0:
self.log.info("Rows affected: %s", cur.rowcount)

    def _update_query_ids(self, cursor) -> None:
        """
        Adds the query ids from the cursor to ``running_query_ids``.

        :param cursor: the cursor for the statement that has just run
        """
        return None

    def kill_query(self, query_id) -> Any:
        """
        Stops the query with the given identifier.

        :param query_id: identifier of the query to stop
        """
        raise NotImplementedError

def set_autocommit(self, conn, autocommit):
"""Sets the autocommit flag on the connection"""
if not self.supports_autocommit and autocommit:
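Taken together, the three additions form a small contract: run() clears running_query_ids, _update_query_ids is the per-statement extension point, and kill_query is the backend-specific cancel. A hedged sketch of what a provider hook would override (the class name, cursor attribute, and KILL QUERY syntax are assumptions of this sketch, not part of the PR; the real implementations follow in the Presto, Trino, and Snowflake hooks):

    # Illustrative subclass showing the override contract; not a real provider hook.
    from typing import Any

    from airflow.providers.common.sql.hooks.sql import DbApiHook


    class FakeHook(DbApiHook):
        conn_name_attr = "fake_conn_id"  # required by DbApiHook

        def _update_query_ids(self, cursor) -> None:
            # run() calls this after each statement executes.
            self.running_query_ids.append(cursor.query_id)  # hypothetical cursor attribute

        def kill_query(self, query_id) -> Any:
            # Issue the backend-specific cancellation statement.
            return self.run(sql=f"KILL QUERY '{query_id}'", handler=list)  # hypothetical syntax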
15 changes: 12 additions & 3 deletions airflow/providers/common/sql/operators/sql.py
@@ -197,9 +197,8 @@ def __init__(

def execute(self, context):
self.log.info("Executing: %s", self.sql)
hook = self.get_db_hook()
if self.do_xcom_push:
output = hook.run(
output = self._hook.run(
sql=self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
@@ -208,7 +207,7 @@
return_last=self.return_last,
)
else:
output = hook.run(
output = self._hook.run(
sql=self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
@@ -221,6 +220,16 @@

return output

    def on_kill(self) -> None:
        for query_id in self._hook.running_query_ids.copy():
            self.log.info("Stopping query: %s", query_id)
            try:
                self._hook.kill_query(query_id)
            except NotImplementedError:
                self.log.info("Method 'kill_query' is not implemented for %s", self._hook.__class__.__name__)
            except Exception as e:
                self.log.info("The query '%s' cannot be killed due to: %s", query_id, str(e))

def prepare_template(self) -> None:
"""Parse template file for attribute parameters."""
if isinstance(self.parameters, str):
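Airflow invokes on_kill when a running task instance is terminated externally, which is how the loop above reaches kill_query. A minimal sketch of exercising that path with a mocked hook, mirroring the patching style used in the tests later in this PR (the test name and query ids are illustrative):

    # Illustrative test of the new on_kill path.
    from unittest import mock

    from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator


    @mock.patch.object(SQLExecuteQueryOperator, "_hook")
    def test_on_kill_cancels_running_queries(mock_hook):
        mock_hook.running_query_ids = ["query-1", "query-2"]
        operator = SQLExecuteQueryOperator(task_id="test", sql="SELECT 1")
        operator.on_kill()
        mock_hook.kill_query.assert_has_calls([mock.call("query-1"), mock.call("query-2")])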
4 changes: 3 additions & 1 deletion airflow/providers/databricks/operators/databricks_sql.py
@@ -24,6 +24,7 @@

from databricks.sql.utils import ParamEscaper

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
@@ -115,7 +116,8 @@ def __init__(
**hook_params,
}

def get_db_hook(self) -> DatabricksSqlHook:
@cached_property
def _hook(self) -> DatabricksSqlHook:
return DatabricksSqlHook(self.databricks_conn_id, **self.hook_params)

def _process_output(self, schema, results):
10 changes: 10 additions & 0 deletions airflow/providers/presto/hooks/presto.py
@@ -196,6 +196,16 @@ def run(
return_last=return_last,
)

    def _update_query_ids(self, cursor) -> None:
        self.running_query_ids.append(cursor.stats["queryId"])

    def kill_query(self, query_id) -> Any:
        result = super().run(
            sql=f"CALL system.runtime.kill_query(query_id => '{query_id}', message => 'Job killed by user');",
            handler=list,
        )
        return result

def insert_rows(
self,
table: str,
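Because the cancellation is itself an ordinary statement routed through run(), the method can also be exercised directly against the hook. A hedged usage sketch (the query id is a made-up example in Presto's yyyyMMdd_HHmmss_index_suffix format):

    # Illustrative: cancelling a known Presto query id through the new hook method.
    from airflow.providers.presto.hooks.presto import PrestoHook

    hook = PrestoHook()  # uses the presto_default connection
    hook.kill_query("20221113_005328_00042_abcde")  # example id, not a real query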
89 changes: 25 additions & 64 deletions airflow/providers/snowflake/hooks/snowflake.py
@@ -18,7 +18,6 @@
from __future__ import annotations

import os
from contextlib import closing
from functools import wraps
from io import StringIO
from pathlib import Path
@@ -27,7 +26,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from snowflake import connector
from snowflake.connector import DictCursor, SnowflakeConnection, util_text
from snowflake.connector import SnowflakeConnection, util_text
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine

@@ -178,7 +177,7 @@ def __init__(self, *args, **kwargs) -> None:
self.schema = kwargs.pop("schema", None)
self.authenticator = kwargs.pop("authenticator", None)
self.session_parameters = kwargs.pop("session_parameters", None)
self.query_ids: list[str] = []
self.running_query_ids: list[str] = []

def _get_field(self, extra_dict, field_name):
backcompat_prefix = "extra__snowflake__"
@@ -321,6 +320,21 @@ def set_autocommit(self, conn, autocommit: Any) -> None:
def get_autocommit(self, conn):
return getattr(conn, "autocommit_mode", False)

    @staticmethod
    def split_sql_string(sql: str) -> list[str]:
        split_statements_tuple = util_text.split_statements(StringIO(sql))
        return [sql_string for sql_string, _ in split_statements_tuple if sql_string]

    def _update_query_ids(self, cursor) -> None:
        self.running_query_ids.append(cursor.sfqid)

    def kill_query(self, query_id) -> Any:
        result = super().run(
            sql=f"CALL system$cancel_query('{query_id}');",
            handler=list,
        )
        return result

def run(
self,
sql: str | Iterable[str],
@@ -330,64 +344,11 @@
split_statements: bool = True,
return_last: bool = True,
) -> Any | list[Any] | None:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
sequentially. The variable execution_info is returned so that
it can be used in the Operators to modify the behavior
depending on the result of the query (i.e fail the operator
if the copy has processed 0 files)

:param sql: the sql string to be executed with possibly multiple statements,
or a list of sql statements to execute
:param autocommit: What to set the connection's autocommit setting to
before executing the query.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
:param split_statements: Whether to split a single SQL string into statements and run separately
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the LAST SQL expression if handler was provided.
"""
self.query_ids = []

scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
if split_statements:
split_statements_tuple = util_text.split_statements(StringIO(sql))
sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string]
else:
sql = [self.strip_sql_string(sql)]

if sql:
self.log.debug("Executing following statements against Snowflake DB: %s", list(sql))
else:
raise ValueError("List of SQL statements is empty")

with closing(self.get_conn()) as conn:
self.set_autocommit(conn, autocommit)

# SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here
with closing(conn.cursor(DictCursor)) as cur: # type: ignore[type-var]
results = []
for sql_statement in sql:
self._run_command(cur, sql_statement, parameters)

if handler is not None:
result = handler(cur)
results.append(result)

query_id = cur.sfqid
self.log.info("Rows affected: %s", cur.rowcount)
self.log.info("Snowflake query id: %s", query_id)
self.query_ids.append(query_id)

# If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()

if handler is None:
return None
elif scalar_return_last:
return results[-1]
else:
return results
return super().run(
sql=sql,
autocommit=autocommit,
parameters=parameters,
handler=handler,
split_statements=split_statements,
return_last=return_last,
)
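With run() now delegating to the common DbApiHook implementation, the Snowflake-specific logic is reduced to the splitter and the query-id plumbing above. A short sketch of the new static helper, which wraps snowflake.connector's util_text.split_statements (the statements are illustrative):

    # Illustrative: splitting a multi-statement string with the new static helper.
    from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

    statements = SnowflakeHook.split_sql_string(
        "CREATE TABLE t (x INT); INSERT INTO t VALUES (1);"
    )
    for stmt in statements:
        print(stmt)  # one executable statement per element; empty strings are dropped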
6 changes: 3 additions & 3 deletions airflow/providers/snowflake/operators/snowflake.py
@@ -191,7 +191,7 @@ def __init__(
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids: list[str] = []
self.running_query_ids: list[str] = []


class SnowflakeValueCheckOperator(SQLValueCheckOperator):
@@ -257,7 +257,7 @@ def __init__(
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids: list[str] = []
self.running_query_ids: list[str] = []


class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
@@ -336,4 +336,4 @@ def __init__(
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids: list[str] = []
self.running_query_ids: list[str] = []
10 changes: 10 additions & 0 deletions airflow/providers/trino/hooks/trino.py
@@ -211,6 +211,16 @@ def run(
return_last=return_last,
)

    def _update_query_ids(self, cursor) -> None:
        self.running_query_ids.append(cursor.stats["queryId"])

    def kill_query(self, query_id) -> Any:
        result = super().run(
            sql=f"CALL system.runtime.kill_query(query_id => '{query_id}', message => 'Job killed by user');",
            handler=list,
        )
        return result

def insert_rows(
self,
table: str,
18 changes: 0 additions & 18 deletions airflow/providers/trino/operators/trino.py
@@ -21,10 +21,7 @@
import warnings
from typing import Any, Sequence

from trino.exceptions import TrinoQueryError

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.trino.hooks.trino import TrinoHook


class TrinoOperator(SQLExecuteQueryOperator):
@@ -58,18 +55,3 @@ def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs: Any) -> No
DeprecationWarning,
stacklevel=2,
)

def on_kill(self) -> None:
if self._hook is not None and isinstance(self._hook, TrinoHook):
query_id = "'" + self._hook.query_id + "'"
try:
self.log.info("Stopping query run with queryId - %s", self._hook.query_id)
self._hook.run(
sql=f"CALL system.runtime.kill_query(query_id => {query_id},message => 'Job "
f"killed by "
f"user');",
handler=list,
)
except TrinoQueryError as e:
self.log.info(str(e))
self.log.info("Trino query (%s) terminated", query_id)
15 changes: 5 additions & 10 deletions tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -18,7 +18,6 @@

import unittest
from unittest import mock
from unittest.mock import MagicMock

from parameterized import parameterized

@@ -28,18 +27,14 @@

class TestRedshiftSQLOperator(unittest.TestCase):
@parameterized.expand([(True, ("a", "b")), (False, ("c", "d"))])
@mock.patch("airflow.providers.amazon.aws.operators.redshift_sql.RedshiftSQLOperator.get_db_hook")
def test_redshift_operator(self, test_autocommit, test_parameters, mock_get_hook):
hook = MagicMock()
mock_run = hook.run
mock_get_hook.return_value = hook
sql = MagicMock()
@mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator._hook")
def test_redshift_operator(self, test_autocommit, test_parameters, mock_hook):
operator = RedshiftSQLOperator(
task_id="test", sql=sql, autocommit=test_autocommit, parameters=test_parameters
task_id="test", sql="SELECT 1", autocommit=test_autocommit, parameters=test_parameters
)
operator.execute(None)
mock_run.assert_called_once_with(
sql=sql,
mock_hook.run.assert_called_once_with(
sql="SELECT 1",
autocommit=test_autocommit,
parameters=test_parameters,
handler=fetch_all_handler,
12 changes: 6 additions & 6 deletions tests/providers/common/sql/operators/test_sql.py
@@ -66,12 +66,12 @@ def _construct_operator(self, sql, **kwargs):
dag=dag,
)

@mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
def test_do_xcom_push(self, mock_get_db_hook):
@mock.patch.object(SQLExecuteQueryOperator, "_hook")
def test_do_xcom_push(self, mock_hook):
operator = self._construct_operator("SELECT 1;", do_xcom_push=True)
operator.execute(context=MagicMock())

mock_get_db_hook.return_value.run.assert_called_once_with(
mock_hook.run.assert_called_once_with(
sql="SELECT 1;",
autocommit=False,
handler=fetch_all_handler,
@@ -80,12 +80,12 @@ def test_do_xcom_push(self, mock_get_db_hook):
split_statements=False,
)

@mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
def test_dont_xcom_push(self, mock_get_db_hook):
@mock.patch.object(SQLExecuteQueryOperator, "_hook")
def test_dont_xcom_push(self, mock_hook):
operator = self._construct_operator("SELECT 1;", do_xcom_push=False)
operator.execute(context=MagicMock())

mock_get_db_hook.return_value.run.assert_called_once_with(
mock_hook.run.assert_called_once_with(
sql="SELECT 1;",
autocommit=False,
parameters=None,