Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def run(
results = []
for sql_statement in sql_list:
self._run_command(cur, sql_statement, parameters)
self._post_run_hook(cur, sql_statement, parameters)

if handler is not None:
result = handler(cur)
Expand Down Expand Up @@ -373,6 +374,10 @@ def _run_command(self, cur, sql_statement, parameters):
if cur.rowcount >= 0:
self.log.info("Rows affected: %s", cur.rowcount)

def _post_run_hook(self, cur, sql_statement, parameters) -> None:
"""This method is run after every statement execution"""
return None

def set_autocommit(self, conn, autocommit):
"""Sets the autocommit flag on the connection"""
if not self.supports_autocommit and autocommit:
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/presto/hooks/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ class PrestoHook(DbApiHook):
hook_name = "Presto"
placeholder = "?"

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.query_ids: list[str] = []

def get_conn(self) -> Connection:
"""Returns a connection object"""
db = self.get_connection(self.presto_conn_id) # type: ignore[attr-defined]
Expand Down Expand Up @@ -187,6 +191,7 @@ def run(
split_statements: bool = False,
return_last: bool = True,
) -> Any | list[Any] | None:
self.query_ids = []
return super().run(
sql=sql,
autocommit=autocommit,
Expand All @@ -196,6 +201,9 @@ def run(
return_last=return_last,
)

def _post_run_hook(self, cur, sql_statement, parameters) -> None:
self.query_ids.append(cur.stats["queryId"])

def insert_rows(
self,
table: str,
Expand Down
82 changes: 17 additions & 65 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import annotations

import os
from contextlib import closing
from functools import wraps
from io import StringIO
from pathlib import Path
Expand All @@ -27,12 +26,12 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from snowflake import connector
from snowflake.connector import DictCursor, SnowflakeConnection, util_text
from snowflake.connector import SnowflakeConnection, util_text
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine

from airflow import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.strings import to_boolean


Expand Down Expand Up @@ -321,6 +320,11 @@ def set_autocommit(self, conn, autocommit: Any) -> None:
def get_autocommit(self, conn):
return getattr(conn, "autocommit_mode", False)

@staticmethod
def split_sql_string(sql: str) -> list[str]:
split_statements_tuple = util_text.split_statements(StringIO(sql))
return [sql_string for sql_string, _ in split_statements_tuple if sql_string]

def run(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method still implemented here to keep Snowflake's split_statements: bool = True default parameter

self,
sql: str | Iterable[str],
Expand All @@ -330,67 +334,15 @@ def run(
split_statements: bool = True,
return_last: bool = True,
) -> Any | list[Any] | None:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
sequentially. The variable execution_info is returned so that
it can be used in the Operators to modify the behavior
depending on the result of the query (i.e fail the operator
if the copy has processed 0 files)

:param sql: the sql string to be executed with possibly multiple statements,
or a list of sql statements to execute
:param autocommit: What to set the connection's autocommit setting to
before executing the query.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
:param split_statements: Whether to split a single SQL string into statements and run separately
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the LAST SQL expression if handler was provided.
"""
self.query_ids = []
return super().run(
sql=sql,
autocommit=autocommit,
parameters=parameters,
handler=handler,
split_statements=split_statements,
return_last=return_last,
)

if isinstance(sql, str):
if split_statements:
split_statements_tuple = util_text.split_statements(StringIO(sql))
sql_list: Iterable[str] = [
sql_string for sql_string, _ in split_statements_tuple if sql_string
]
else:
sql_list = [self.strip_sql_string(sql)]
else:
sql_list = sql

if sql_list:
self.log.debug("Executing following statements against Snowflake DB: %s", sql_list)
else:
raise ValueError("List of SQL statements is empty")

with closing(self.get_conn()) as conn:
self.set_autocommit(conn, autocommit)

# SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here
with closing(conn.cursor(DictCursor)) as cur: # type: ignore[type-var]
results = []
for sql_statement in sql_list:
self._run_command(cur, sql_statement, parameters)

if handler is not None:
result = handler(cur)
results.append(result)

query_id = cur.sfqid
self.log.info("Rows affected: %s", cur.rowcount)
self.log.info("Snowflake query id: %s", query_id)
self.query_ids.append(query_id)

# If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()

if handler is None:
return None
elif return_single_query_results(sql, return_last, split_statements):
return results[-1]
else:
return results
def _post_run_hook(self, cur, sql_statement, parameters) -> None:
self.query_ids.append(cur.sfqid)
8 changes: 8 additions & 0 deletions airflow/providers/trino/hooks/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class TrinoHook(DbApiHook):
placeholder = "?"
_test_connection_sql = "select 1"

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.query_ids: list[str] = []

def get_conn(self) -> Connection:
"""Returns a connection object"""
db = self.get_connection(self.trino_conn_id) # type: ignore[attr-defined]
Expand Down Expand Up @@ -202,6 +206,7 @@ def run(
split_statements: bool = False,
return_last: bool = True,
) -> Any | list[Any] | None:
self.query_ids = []
return super().run(
sql=sql,
autocommit=autocommit,
Expand All @@ -211,6 +216,9 @@ def run(
return_last=return_last,
)

def _post_run_hook(self, cur, sql_statement, parameters) -> None:
self.query_ids.append(cur.stats["queryId"])

def insert_rows(
self,
table: str,
Expand Down