From 67ab31718981acb940e0a84cef349fbb4bb80991 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Mon, 20 Jun 2022 15:20:13 +0530 Subject: [PATCH] Add on_kill() to kill Trino query if Airflow task is killed --- airflow/providers/trino/hooks/trino.py | 2 ++ airflow/providers/trino/operators/trino.py | 17 +++++++++++++++++ tests/system/providers/trino/example_trino.py | 6 +++--- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 781236ae34a66..99ce9eb2df165 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -93,6 +93,7 @@ class TrinoHook(DbApiHook): default_conn_name = 'trino_default' conn_type = 'trino' hook_name = 'Trino' + query_id = '' def get_conn(self) -> Connection: """Returns a connection object""" @@ -301,6 +302,7 @@ def run( results = [] for sql_statement in sql: self._run_command(cur, self._strip_sql(sql_statement), parameters) + self.query_id = cur.stats["queryId"] if handler is not None: result = handler(cur) results.append(result) diff --git a/airflow/providers/trino/operators/trino.py b/airflow/providers/trino/operators/trino.py index 7992eae4edeae..dd305f0e8b104 100644 --- a/airflow/providers/trino/operators/trino.py +++ b/airflow/providers/trino/operators/trino.py @@ -19,6 +19,8 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union +from trino.exceptions import TrinoQueryError + from airflow.models import BaseOperator from airflow.providers.trino.hooks.trino import TrinoHook @@ -79,3 +81,18 @@ def execute(self, context: 'Context') -> None: self.hook.run( sql=self.sql, autocommit=self.autocommit, parameters=self.parameters, handler=self.handler ) + + def on_kill(self) -> None: + if self.hook is not None and isinstance(self.hook, TrinoHook): + query_id = "'" + self.hook.query_id + "'" + try: + self.log.info("Stopping query run with queryId - %s", self.hook.query_id) + self.hook.run( + sql=f"CALL system.runtime.kill_query(query_id => {query_id},message => 'Job " + f"killed by " + f"user');", + handler=list, + ) + except TrinoQueryError as e: + self.log.info(str(e)) + self.log.info("Trino query (%s) terminated", query_id) diff --git a/tests/system/providers/trino/example_trino.py b/tests/system/providers/trino/example_trino.py index 9895b6c249be6..f7916d7c66b30 100644 --- a/tests/system/providers/trino/example_trino.py +++ b/tests/system/providers/trino/example_trino.py @@ -45,7 +45,7 @@ ) trino_create_table = TrinoOperator( task_id="trino_create_table", - sql=f"""CREATE TABLE {SCHEMA}.{TABLE}( + sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE}( cityid bigint, cityname varchar )""", @@ -60,9 +60,9 @@ trino_multiple_queries = TrinoOperator( task_id="trino_multiple_queries", - sql=f"""CREATE TABLE {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar); + sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar); INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose'); - CREATE TABLE {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar); + CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar); INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego');""", handler=list, )