diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 76260612a435a..81e804c145ed6 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -337,6 +337,7 @@ def run(
                 results = []
                 for sql_statement in sql_list:
                     self._run_command(cur, sql_statement, parameters)
+                    self._post_run_hook(cur, sql_statement, parameters)
 
                     if handler is not None:
                         result = handler(cur)
@@ -373,6 +374,10 @@ def _run_command(self, cur, sql_statement, parameters):
         if cur.rowcount >= 0:
             self.log.info("Rows affected: %s", cur.rowcount)
 
+    def _post_run_hook(self, cur, sql_statement, parameters) -> None:
+        """This method is run after every statement execution"""
+        return None
+
     def set_autocommit(self, conn, autocommit):
         """Sets the autocommit flag on the connection"""
         if not self.supports_autocommit and autocommit:
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 825b28ad60a54..a23e5326ff6fd 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -90,6 +90,10 @@ class PrestoHook(DbApiHook):
     hook_name = "Presto"
     placeholder = "?"
 
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.query_ids: list[str] = []
+
     def get_conn(self) -> Connection:
         """Returns a connection object"""
         db = self.get_connection(self.presto_conn_id)  # type: ignore[attr-defined]
@@ -187,6 +191,7 @@ def run(
         split_statements: bool = False,
         return_last: bool = True,
     ) -> Any | list[Any] | None:
+        self.query_ids = []
         return super().run(
             sql=sql,
             autocommit=autocommit,
@@ -196,6 +201,9 @@
             return_last=return_last,
         )
 
+    def _post_run_hook(self, cur, sql_statement, parameters) -> None:
+        self.query_ids.append(cur.stats["queryId"])
+
     def insert_rows(
         self,
         table: str,
diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py
index 79bb6893180a9..4475ac0e32552 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 import os
-from contextlib import closing
 from functools import wraps
 from io import StringIO
 from pathlib import Path
@@ -27,12 +26,12 @@
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from snowflake import connector
-from snowflake.connector import DictCursor, SnowflakeConnection, util_text
+from snowflake.connector import SnowflakeConnection, util_text
 from snowflake.sqlalchemy import URL
 from sqlalchemy import create_engine
 
 from airflow import AirflowException
-from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.utils.strings import to_boolean
 
 
@@ -321,6 +320,11 @@ def set_autocommit(self, conn, autocommit: Any) -> None:
     def get_autocommit(self, conn):
         return getattr(conn, "autocommit_mode", False)
 
+    @staticmethod
+    def split_sql_string(sql: str) -> list[str]:
+        split_statements_tuple = util_text.split_statements(StringIO(sql))
+        return [sql_string for sql_string, _ in split_statements_tuple if sql_string]
+
     def run(
         self,
         sql: str | Iterable[str],
@@ -330,67 +334,15 @@
         split_statements: bool = True,
         return_last: bool = True,
     ) -> Any | list[Any] | None:
-        """
-        Runs a command or a list of commands. Pass a list of sql
-        statements to the sql parameter to get them to execute
-        sequentially. The variable execution_info is returned so that
-        it can be used in the Operators to modify the behavior
-        depending on the result of the query (i.e fail the operator
-        if the copy has processed 0 files)
-
-        :param sql: the sql string to be executed with possibly multiple statements,
-            or a list of sql statements to execute
-        :param autocommit: What to set the connection's autocommit setting to
-            before executing the query.
-        :param parameters: The parameters to render the SQL query with.
-        :param handler: The result handler which is called with the result of each statement.
-        :param split_statements: Whether to split a single SQL string into statements and run separately
-        :param return_last: Whether to return result for only last statement or for all after split
-        :return: return only result of the LAST SQL expression if handler was provided.
-        """
         self.query_ids = []
+        return super().run(
+            sql=sql,
+            autocommit=autocommit,
+            parameters=parameters,
+            handler=handler,
+            split_statements=split_statements,
+            return_last=return_last,
+        )
-
-        if isinstance(sql, str):
-            if split_statements:
-                split_statements_tuple = util_text.split_statements(StringIO(sql))
-                sql_list: Iterable[str] = [
-                    sql_string for sql_string, _ in split_statements_tuple if sql_string
-                ]
-            else:
-                sql_list = [self.strip_sql_string(sql)]
-        else:
-            sql_list = sql
-
-        if sql_list:
-            self.log.debug("Executing following statements against Snowflake DB: %s", sql_list)
-        else:
-            raise ValueError("List of SQL statements is empty")
-
-        with closing(self.get_conn()) as conn:
-            self.set_autocommit(conn, autocommit)
-
-            # SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here
-            with closing(conn.cursor(DictCursor)) as cur:  # type: ignore[type-var]
-                results = []
-                for sql_statement in sql_list:
-                    self._run_command(cur, sql_statement, parameters)
-
-                    if handler is not None:
-                        result = handler(cur)
-                        results.append(result)
-
-                    query_id = cur.sfqid
-                    self.log.info("Rows affected: %s", cur.rowcount)
-                    self.log.info("Snowflake query id: %s", query_id)
-                    self.query_ids.append(query_id)
-
-            # If autocommit was set to False or db does not support autocommit, we do a manual commit.
-            if not self.get_autocommit(conn):
-                conn.commit()
-
-        if handler is None:
-            return None
-        elif return_single_query_results(sql, return_last, split_statements):
-            return results[-1]
-        else:
-            return results
+
+    def _post_run_hook(self, cur, sql_statement, parameters) -> None:
+        self.query_ids.append(cur.sfqid)
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 63c75446ead19..8bae878cf0ba3 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -92,6 +92,10 @@ class TrinoHook(DbApiHook):
     placeholder = "?"
     _test_connection_sql = "select 1"
 
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.query_ids: list[str] = []
+
     def get_conn(self) -> Connection:
         """Returns a connection object"""
         db = self.get_connection(self.trino_conn_id)  # type: ignore[attr-defined]
@@ -202,6 +206,7 @@ def run(
         split_statements: bool = False,
         return_last: bool = True,
     ) -> Any | list[Any] | None:
+        self.query_ids = []
         return super().run(
             sql=sql,
             autocommit=autocommit,
@@ -211,6 +216,9 @@
             return_last=return_last,
         )
 
+    def _post_run_hook(self, cur, sql_statement, parameters) -> None:
+        self.query_ids.append(cur.stats["queryId"])
+
     def insert_rows(
         self,
         table: str,