From f614f4738c990dee1a3dfce6ef8631c52b325ee2 Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Thu, 14 Jul 2022 15:29:14 +0300 Subject: [PATCH 1/7] Types change only --- airflow/operators/sql.py | 2 +- airflow/providers/amazon/aws/operators/redshift_sql.py | 4 ++-- airflow/providers/amazon/aws/transfers/redshift_to_s3.py | 2 +- airflow/providers/amazon/aws/transfers/s3_to_redshift.py | 4 ++-- airflow/providers/apache/drill/operators/drill.py | 2 +- airflow/providers/google/cloud/operators/bigquery.py | 2 +- airflow/providers/google/cloud/operators/cloud_sql.py | 6 +++--- airflow/providers/google/suite/transfers/sql_to_sheets.py | 2 +- airflow/providers/microsoft/mssql/operators/mssql.py | 4 ++-- airflow/providers/mysql/operators/mysql.py | 6 +++--- airflow/providers/neo4j/operators/neo4j.py | 2 +- airflow/providers/oracle/operators/oracle.py | 6 +++--- airflow/providers/postgres/operators/postgres.py | 6 +++--- airflow/providers/vertica/operators/vertica.py | 4 ++-- 14 files changed, 26 insertions(+), 26 deletions(-) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index cb8b664875971..fbce9b85a1c97 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -496,7 +496,7 @@ def __init__( follow_task_ids_if_false: List[str], conn_id: str = "default_conn_id", database: Optional[str] = None, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(conn_id=conn_id, database=database, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/redshift_sql.py b/airflow/providers/amazon/aws/operators/redshift_sql.py index c7ad77acb5341..aa324b40bc0b6 100644 --- a/airflow/providers/amazon/aws/operators/redshift_sql.py +++ b/airflow/providers/amazon/aws/operators/redshift_sql.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook @@ -55,7 +55,7 @@ def __init__( *, sql: Union[str, Iterable[str]], redshift_conn_id: str = 'redshift_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, **kwargs, ) -> None: diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index 0bfdda44e78d6..f5a3cd9bf6547 100644 --- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -93,7 +93,7 @@ def __init__( unload_options: Optional[List] = None, autocommit: bool = False, include_header: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, table_as_file_name: bool = True, # Set to True by default for not breaking current workflows **kwargs, ) -> None: diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 014e23ec070f2..747c97e1f1212 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -16,7 +16,7 @@ # under the License. 
import warnings -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -140,7 +140,7 @@ def execute(self, context: 'Context') -> None: copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options) - sql: Union[list, str] + sql: Union[str, Iterable[str]] if self.method == 'REPLACE': sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"] diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py index 791ed546c34fe..ce77332fc87a2 100644 --- a/airflow/providers/apache/drill/operators/drill.py +++ b/airflow/providers/apache/drill/operators/drill.py @@ -52,7 +52,7 @@ def __init__( *, sql: str, drill_conn_id: str = 'drill_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index b7e12dabce641..550c3174068ec 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -550,7 +550,7 @@ def operator_extra_links(self): def __init__( self, *, - sql: Union[str, Iterable], + sql: Union[str, Iterable[str]], destination_dataset_table: Optional[str] = None, write_disposition: str = 'WRITE_EMPTY', allow_large_results: bool = False, diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py index e90a21e296727..a5d40c2ff4889 100644 --- a/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/airflow/providers/google/cloud/operators/cloud_sql.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Cloud SQL operators.""" -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from googleapiclient.errors import HttpError @@ -1054,9 +1054,9 @@ class CloudSQLExecuteQueryOperator(BaseOperator): def __init__( self, *, - sql: Union[List[str], str], + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Union[Dict, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, gcp_conn_id: str = 'google_cloud_default', gcp_cloudsql_conn_id: str = 'google_cloud_sql_default', **kwargs, diff --git a/airflow/providers/google/suite/transfers/sql_to_sheets.py b/airflow/providers/google/suite/transfers/sql_to_sheets.py index 8384868199b6f..8626fb7227182 100644 --- a/airflow/providers/google/suite/transfers/sql_to_sheets.py +++ b/airflow/providers/google/suite/transfers/sql_to_sheets.py @@ -68,7 +68,7 @@ def __init__( sql: str, spreadsheet_id: str, sql_conn_id: str, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, database: Optional[str] = None, spreadsheet_range: str = "Sheet1", gcp_conn_id: str = "google_cloud_default", diff --git a/airflow/providers/microsoft/mssql/operators/mssql.py b/airflow/providers/microsoft/mssql/operators/mssql.py index 5a5738eb6878c..3a6434704f1fa 100644 --- a/airflow/providers/microsoft/mssql/operators/mssql.py +++ b/airflow/providers/microsoft/mssql/operators/mssql.py @@ -57,9 +57,9 @@ class MsSqlOperator(BaseOperator): def __init__( self, *, - sql: str, + sql: Union[str, Iterable[str]], mssql_conn_id: str = 'mssql_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = False, database: Optional[str] = None, **kwargs, diff --git a/airflow/providers/mysql/operators/mysql.py b/airflow/providers/mysql/operators/mysql.py index d51a97e6fe38a..975586cd52991 100644 --- a/airflow/providers/mysql/operators/mysql.py +++ b/airflow/providers/mysql/operators/mysql.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
import ast -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.mysql.hooks.mysql import MySqlHook @@ -59,9 +59,9 @@ class MySqlOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], mysql_conn_id: str = 'mysql_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = False, database: Optional[str] = None, **kwargs, diff --git a/airflow/providers/neo4j/operators/neo4j.py b/airflow/providers/neo4j/operators/neo4j.py index b61f0734f0841..82939ad7790f2 100644 --- a/airflow/providers/neo4j/operators/neo4j.py +++ b/airflow/providers/neo4j/operators/neo4j.py @@ -44,7 +44,7 @@ def __init__( *, sql: str, neo4j_conn_id: str = 'neo4j_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py index b60d4b6e89100..969d69728bbaf 100644 --- a/airflow/providers/oracle/operators/oracle.py +++ b/airflow/providers/oracle/operators/oracle.py @@ -50,9 +50,9 @@ class OracleOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], oracle_conn_id: str = 'oracle_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = False, **kwargs, ) -> None: @@ -98,7 +98,7 @@ def __init__( self.procedure = procedure self.parameters = parameters - def execute(self, context: 'Context') -> Optional[Union[List, Dict]]: + def execute(self, context: 'Context'): self.log.info('Executing: %s', self.procedure) hook = OracleHook(oracle_conn_id=self.oracle_conn_id) return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters) diff --git a/airflow/providers/postgres/operators/postgres.py b/airflow/providers/postgres/operators/postgres.py index e0238aa88204b..7a787498d2dd4 100644 --- a/airflow/providers/postgres/operators/postgres.py +++ b/airflow/providers/postgres/operators/postgres.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from psycopg2.sql import SQL, Identifier @@ -53,10 +53,10 @@ class PostgresOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], postgres_conn_id: str = 'postgres_default', autocommit: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, database: Optional[str] = None, runtime_parameters: Optional[Mapping] = None, **kwargs, diff --git a/airflow/providers/vertica/operators/vertica.py b/airflow/providers/vertica/operators/vertica.py index 3a30e0ee2723e..7a804f8ed738e 100644 --- a/airflow/providers/vertica/operators/vertica.py +++ b/airflow/providers/vertica/operators/vertica.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import TYPE_CHECKING, Any, List, Sequence, Union +from typing import TYPE_CHECKING, Any, Iterable, Sequence, Union from airflow.models import BaseOperator from airflow.providers.vertica.hooks.vertica import VerticaHook @@ -40,7 +40,7 @@ class VerticaOperator(BaseOperator): ui_color = '#b4e0ff' def __init__( - self, *, sql: Union[str, List[str]], vertica_conn_id: str = 'vertica_default', **kwargs: Any + self, *, sql: Union[str, Iterable[str]], vertica_conn_id: str = 'vertica_default', **kwargs: Any ) -> None: super().__init__(**kwargs) self.vertica_conn_id = vertica_conn_id From 368420cbe9d6eb80d95fc2e4b7c6984771ad0c84 Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Thu, 14 Jul 2022 15:59:28 +0300 Subject: [PATCH 2/7] Unify DbApiHook.run() method --- airflow/dependencies.json | 0 .../providers/apache/drill/operators/drill.py | 6 +- airflow/providers/apache/drill/provider.yaml | 1 - airflow/providers/apache/pinot/hooks/pinot.py | 6 +- airflow/providers/common/sql/hooks/sql.py | 56 +++++++++--- airflow/providers/common/sql/provider.yaml | 3 +- .../databricks/hooks/databricks_sql.py | 66 ++++++-------- .../databricks/operators/databricks_sql.py | 12 ++- airflow/providers/exasol/hooks/exasol.py | 52 +++++------ airflow/providers/exasol/operators/exasol.py | 6 +- airflow/providers/jdbc/operators/jdbc.py | 14 +-- airflow/providers/presto/hooks/presto.py | 35 ++++--- .../providers/snowflake/hooks/snowflake.py | 33 +++---- .../snowflake/operators/snowflake.py | 19 ++-- airflow/providers/sqlite/operators/sqlite.py | 6 +- airflow/providers/trino/hooks/trino.py | 91 +++++++------------ generated/provider_dependencies.json | 7 +- .../providers/common/sql/hooks/test_dbapi.py | 2 +- .../databricks/hooks/test_databricks_sql.py | 11 ++- .../operators/test_databricks_sql.py | 9 +- tests/providers/exasol/hooks/test_exasol.py | 12 +-- tests/providers/jdbc/operators/test_jdbc.py | 3 +- tests/providers/oracle/hooks/test_oracle.py | 8 +- 23 files changed, 221 insertions(+), 237 deletions(-) create mode 100644 airflow/dependencies.json diff --git a/airflow/dependencies.json b/airflow/dependencies.json new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py index ce77332fc87a2..44a0a1e2b6583 100644 --- a/airflow/providers/apache/drill/operators/drill.py +++ b/airflow/providers/apache/drill/operators/drill.py @@ -17,8 +17,6 @@ # under the License. 
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union -import sqlparse - from airflow.models import BaseOperator from airflow.providers.apache.drill.hooks.drill import DrillHook @@ -64,6 +62,4 @@ def __init__( def execute(self, context: 'Context'): self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id) self.hook = DrillHook(drill_conn_id=self.drill_conn_id) - sql = sqlparse.split(sqlparse.format(self.sql, strip_comments=True)) - no_term_sql = [s[:-1] for s in sql if s[-1] == ';'] - self.hook.run(no_term_sql, parameters=self.parameters) + self.hook.run(DrillHook.split_sql_string(self.sql), parameters=self.parameters) diff --git a/airflow/providers/apache/drill/provider.yaml b/airflow/providers/apache/drill/provider.yaml index 33235850b8515..0e26ae5186e3c 100644 --- a/airflow/providers/apache/drill/provider.yaml +++ b/airflow/providers/apache/drill/provider.yaml @@ -34,7 +34,6 @@ dependencies: - apache-airflow>=2.2.0 - apache-airflow-providers-common-sql - sqlalchemy-drill>=1.1.0 - - sqlparse>=0.4.1 integrations: - integration-name: Apache Drill diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py index fa31b9f33d18f..794646e46dc62 100644 --- a/airflow/providers/apache/pinot/hooks/pinot.py +++ b/airflow/providers/apache/pinot/hooks/pinot.py @@ -18,7 +18,7 @@ import os import subprocess -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Iterable, List, Mapping, Optional, Union from pinotdb import connect @@ -275,7 +275,7 @@ def get_uri(self) -> str: endpoint = conn.extra_dejson.get('endpoint', 'query/sql') return f'{conn_type}://{host}/{endpoint}' - def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any: + def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any: """ Executes the sql and returns a set of records. @@ -287,7 +287,7 @@ def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Itera cur.execute(sql) return cur.fetchall() - def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any: + def get_first(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any: """ Executes the sql and returns the first resulting row. 
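The hunks that follow are the core of this change: DbApiHook.run() gains sqlparse-based helpers and a uniform handler contract. A minimal sketch of the new contract (the PostgresHook and the `postgres_default` connection id are illustrative assumptions, not part of this patch):

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    hook = PostgresHook(postgres_conn_id="postgres_default")  # assumed connection

    # Without a handler, run() now always returns None.
    hook.run("SELECT 1")

    # With a handler, run() returns a list holding one entry per executed
    # statement, even for a single-string `sql` (previously a lone handler
    # result was unwrapped and returned as a scalar).
    results = hook.run(["SELECT 1", "SELECT 2"], handler=lambda cursor: cursor.fetchall())
    # e.g. results == [[(1,)], [(2,)]]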
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index efd4a9dcfed3b..0ae0d1457ef0a 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -17,8 +17,9 @@
 import warnings
 from contextlib import closing
 from datetime import datetime
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union

+import sqlparse
 from sqlalchemy import create_engine
 from typing_extensions import Protocol

@@ -27,6 +28,17 @@
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.module_loading import import_string

+if TYPE_CHECKING:
+    from sqlalchemy.engine import CursorResult
+
+
+def fetch_all_handler(cursor: 'CursorResult') -> Optional[List[Tuple]]:
+    """Handler for DbApiHook.run() to return results"""
+    if cursor.returns_rows:
+        return cursor.fetchall()
+    else:
+        return None
+

 def _backported_get_hook(connection, *, hook_params=None):
     """Return hook based on conn_type
@@ -201,7 +213,29 @@ def get_first(self, sql, parameters=None):
             cur.execute(sql)
             return cur.fetchone()

-    def run(self, sql, autocommit=False, parameters=None, handler=None):
+    @staticmethod
+    def strip_sql_string(sql: str) -> str:
+        return sql.strip().rstrip(';')
+
+    @staticmethod
+    def split_sql_string(sql: str) -> List[str]:
+        """
+        Splits a string into multiple SQL expressions
+
+        :param sql: SQL string potentially consisting of multiple expressions
+        :return: list of individual expressions
+        """
+        splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
+        statements = [s.rstrip(';') for s in splits if s.endswith(';')]
+        return statements
+
+    def run(
+        self,
+        sql: Union[str, Iterable[str]],
+        autocommit: bool = False,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
+        handler: Optional[Callable] = None,
+    ) -> Optional[list]:
         """
         Runs a command or a list of commands. Pass a list of sql
         statements to the sql parameter to get them to execute
@@ -213,14 +247,13 @@
         before executing the query.
         :param parameters: The parameters to render the SQL query with.
         :param handler: The result handler which is called with the result of each statement.
-        :return: query results if handler was provided.
+        :return: results of ALL SQL expressions if handler was provided.
         """
-        scalar = isinstance(sql, str)
-        if scalar:
+        if isinstance(sql, str):
             sql = [sql]

         if sql:
-            self.log.debug("Executing %d statements", len(sql))
+            self.log.debug("Executing following statements against DB: %s", list(sql))
         else:
             raise ValueError("List of SQL statements is empty")

@@ -232,22 +265,19 @@
             results = []
             for sql_statement in sql:
                 self._run_command(cur, sql_statement, parameters)
+
                 if handler is not None:
                     result = handler(cur)
                     results.append(result)

-            # If autocommit was set to False for db that supports autocommit,
-            # or if db does not supports autocommit, we do a manual commit.
+            # If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn): conn.commit() if handler is None: return None - - if scalar: - return results[0] - - return results + else: + return results def _run_command(self, cur, sql_statement, parameters): """Runs a statement using an already open cursor.""" diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml index a277f327ccac5..39c8d483e41b1 100644 --- a/airflow/providers/common/sql/provider.yaml +++ b/airflow/providers/common/sql/provider.yaml @@ -24,7 +24,8 @@ description: | versions: - 1.0.0 -dependencies: [] +dependencies: + - sqlparse>=0.4.2 additional-extras: - name: pandas diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py index 6c5800170d255..3d9ace8b2319d 100644 --- a/airflow/providers/databricks/hooks/databricks_sql.py +++ b/airflow/providers/databricks/hooks/databricks_sql.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. -import re from contextlib import closing from copy import copy -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union from databricks import sql # type: ignore[attr-defined] from databricks.sql.client import Connection # type: ignore[attr-defined] @@ -139,19 +138,13 @@ def get_conn(self) -> Connection: ) return self._sql_conn - @staticmethod - def maybe_split_sql_string(sql: str) -> List[str]: - """ - Splits strings consisting of multiple SQL expressions into an - TODO: do we need something more sophisticated? - - :param sql: SQL string potentially consisting of multiple expressions - :return: list of individual expressions - """ - splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""] - return splits - - def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None): + def run( + self, + sql: Union[str, Iterable[str]], + autocommit: bool = False, + parameters: Optional[Union[Iterable, Mapping]] = None, + handler: Optional[Callable] = None, + ) -> Optional[List[Tuple[Any, Any]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -163,41 +156,36 @@ def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, hand before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. - :return: query results. + :return: return only result of the LAST SQL expression if handler was provided. 
""" if isinstance(sql, str): - sql = self.maybe_split_sql_string(sql) + sql = self.split_sql_string(sql) if sql: - self.log.debug("Executing %d statements", len(sql)) + self.log.debug("Executing following statements against Databricks DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") - conn = None + results = [] for sql_statement in sql: # when using AAD tokens, it could expire if previous query run longer than token lifetime - conn = self.get_conn() - with closing(conn.cursor()) as cur: - self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters) - if parameters: - cur.execute(sql_statement, parameters) - else: - cur.execute(sql_statement) - schema = cur.description - results = [] - if handler is not None: - cur = handler(cur) - for row in cur: - self.log.debug("Statement results: %s", row) - results.append(row) - - self.log.info("Rows affected: %s", cur.rowcount) - if conn: - conn.close() + with closing(self.get_conn()) as conn: + self.set_autocommit(conn, autocommit) + + with closing(conn.cursor()) as cur: + self._run_command(cur, sql_statement, parameters) + + if handler is not None: + result = handler(cur) + schema = cur.description + results.append((schema, result)) + self._sql_conn = None - # Return only result of the last SQL expression - return schema, results + if handler is None: + return None + else: + return results def test_connection(self): """Test the Databricks SQL connection by running a simple query.""" diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py index 9e6298bc21263..ad4add30fe8e6 100644 --- a/airflow/providers/databricks/operators/databricks_sql.py +++ b/airflow/providers/databricks/operators/databricks_sql.py @@ -20,12 +20,13 @@ import csv import json -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union, cast from databricks.sql.utils import ParamEscaper from airflow.exceptions import AirflowException from airflow.models import BaseOperator +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook if TYPE_CHECKING: @@ -71,11 +72,11 @@ class DatabricksSqlOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], databricks_conn_id: str = DatabricksSqlHook.default_conn_name, http_path: Optional[str] = None, sql_endpoint_name: Optional[str] = None, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, session_configuration=None, http_headers: Optional[List[Tuple[str, str]]] = None, catalog: Optional[str] = None, @@ -147,10 +148,11 @@ def _format_output(self, schema, results): else: raise AirflowException(f"Unsupported output format: '{self._output_format}'") - def execute(self, context: 'Context') -> Any: + def execute(self, context: 'Context'): self.log.info('Executing: %s', self.sql) hook = self._get_hook() - schema, results = hook.run(self.sql, parameters=self.parameters) + response = hook.run(self.sql, parameters=self.parameters, handler=fetch_all_handler) + schema, results = cast(List[Tuple[Any, Any]], response)[0] # self.log.info('Schema: %s', schema) # self.log.info('Results: %s', results) self._format_output(schema, results) diff --git 
a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 784c57cde0661..a2e3f0452320c 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -17,7 +17,7 @@ # under the License. from contextlib import closing -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union import pandas as pd import pyexasol @@ -64,9 +64,7 @@ def get_conn(self) -> ExaConnection: conn = pyexasol.connect(**conn_args) return conn - def get_pandas_df( - self, sql: Union[str, list], parameters: Optional[dict] = None, **kwargs - ) -> pd.DataFrame: + def get_pandas_df(self, sql: str, parameters: Optional[dict] = None, **kwargs) -> pd.DataFrame: """ Executes the sql and returns a pandas dataframe @@ -79,9 +77,7 @@ def get_pandas_df( df = conn.export_to_pandas(sql, query_params=parameters, **kwargs) return df - def get_records( - self, sql: Union[str, list], parameters: Optional[dict] = None - ) -> List[Union[dict, Tuple[Any, ...]]]: + def get_records(self, sql: str, parameters: Optional[dict] = None) -> List[Union[dict, Tuple[Any, ...]]]: """ Executes the sql and returns a set of records. @@ -93,7 +89,7 @@ def get_records( with closing(conn.execute(sql, parameters)) as cur: return cur.fetchall() - def get_first(self, sql: Union[str, list], parameters: Optional[dict] = None) -> Optional[Any]: + def get_first(self, sql: str, parameters: Optional[dict] = None) -> Optional[Any]: """ Executes the sql and returns the first resulting row. @@ -133,7 +129,11 @@ def export_to_file( self.log.info("Data saved to %s", filename) def run( - self, sql: Union[str, list], autocommit: bool = False, parameters: Optional[dict] = None, handler=None + self, + sql: Union[str, Iterable[str]], + autocommit: bool = False, + parameters: Optional[Union[Iterable, Mapping]] = None, + handler: Optional[Callable] = None, ) -> Optional[list]: """ Runs a command or a list of commands. Pass a list of sql @@ -146,38 +146,36 @@ def run( before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. + :return: return only result of the LAST SQL expression if handler was provided. """ if isinstance(sql, str): - sql = [sql] + sql = self.split_sql_string(sql) if sql: - self.log.debug("Executing %d statements against Exasol DB", len(sql)) + self.log.debug("Executing following statements against Exasol DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, autocommit) - - for query in sql: - self.log.info(query) - with closing(conn.execute(query, parameters)) as cur: - results = [] - + self.set_autocommit(conn, autocommit) + results = [] + for sql_statement in sql: + with closing(conn.execute(sql_statement, parameters)) as cur: + self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) if handler is not None: - cur = handler(cur) + result = handler(cur) + results.append(result) - for row in cur: - self.log.info("Statement execution info - %s", row) - results.append(row) + self.log.info("Rows affected: %s", cur.rowcount) - self.log.info(cur.row_count) - # If autocommit was set to False for db that supports autocommit, - # or if db does not support autocommit, we do a manual commit. 
+ # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): conn.commit() - return results + if handler is None: + return None + else: + return results def set_autocommit(self, conn, autocommit: bool) -> None: """ diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py index eecf44885ec42..33d4ab55c64f0 100644 --- a/airflow/providers/exasol/operators/exasol.py +++ b/airflow/providers/exasol/operators/exasol.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.exasol.hooks.exasol import ExasolHook @@ -46,10 +46,10 @@ class ExasolOperator(BaseOperator): def __init__( self, *, - sql: str, + sql: Union[str, Iterable[str]], exasol_conn_id: str = 'exasol_default', autocommit: bool = False, - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, schema: Optional[str] = None, **kwargs, ) -> None: diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py index 2c023d9afe9ba..6b38366b41e81 100644 --- a/airflow/providers/jdbc/operators/jdbc.py +++ b/airflow/providers/jdbc/operators/jdbc.py @@ -16,20 +16,16 @@ # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.jdbc.hooks.jdbc import JdbcHook if TYPE_CHECKING: from airflow.utils.context import Context -def fetch_all_handler(cursor): - """Handler for DbApiHook.run() to return results""" - return cursor.fetchall() - - class JdbcOperator(BaseOperator): """ Executes sql code in a database using jdbc driver. 
@@ -57,10 +53,10 @@ class JdbcOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], jdbc_conn_id: str = 'jdbc_default', autocommit: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -70,7 +66,7 @@ def __init__( self.autocommit = autocommit self.hook = None - def execute(self, context: 'Context') -> None: + def execute(self, context: 'Context'): self.log.info('Executing: %s', self.sql) hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id) return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=fetch_all_handler) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 22afc71577341..70eb2d1c9088d 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -18,7 +18,7 @@ import json import os import warnings -from typing import Any, Callable, Iterable, Optional, overload +from typing import Any, Callable, Iterable, Mapping, Optional, Union, overload import prestodb from prestodb.exceptions import DatabaseError @@ -142,10 +142,6 @@ def get_isolation_level(self) -> Any: isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper() return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT) - @staticmethod - def _strip_sql(sql: str) -> str: - return sql.strip().rstrip(';') - @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None): """Get a set of records from Presto @@ -169,7 +165,7 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str sql = hql try: - return super().get_records(self._strip_sql(sql), parameters) + return super().get_records(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise PrestoException(e) @@ -196,7 +192,7 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = sql = hql try: - return super().get_first(self._strip_sql(sql), parameters) + return super().get_first(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise PrestoException(e) @@ -226,7 +222,7 @@ def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs) cursor = self.get_cursor() try: - cursor.execute(self._strip_sql(sql), parameters) + cursor.execute(self.strip_sql_string(sql), parameters) data = cursor.fetchall() except DatabaseError as e: raise PrestoException(e) @@ -241,32 +237,32 @@ def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs) @overload def run( self, - sql: str = "", + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, - ) -> None: + ) -> Optional[list]: """Execute the statement against Presto. 
Can be used to create views.""" @overload def run( self, - sql: str = "", + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, hql: str = "", - ) -> None: + ) -> Optional[list]: """:sphinx-autoapi-skip:""" def run( self, - sql: str = "", + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, hql: str = "", - ) -> None: + ) -> Optional[list]: """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -276,7 +272,10 @@ def run( ) sql = hql - return super().run(sql=self._strip_sql(sql), parameters=parameters, handler=handler) + if isinstance(sql, str): + sql = self.strip_sql_string(sql) + + return super().run(sql=sql, autocommit=autocommit, parameters=parameters, handler=handler) def insert_rows( self, diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 3dee0989ed210..fd0bcbdce856a 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -19,7 +19,7 @@ from contextlib import closing from io import StringIO from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -286,9 +286,9 @@ def get_autocommit(self, conn): def run( self, - sql: Union[str, list], + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Union[Sequence[Any], Dict[Any, Any]]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, ): """ @@ -305,6 +305,7 @@ def run( before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. + :return: return only result of the LAST SQL expression if handler was provided. """ self.query_ids = [] @@ -313,7 +314,7 @@ def run( sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] if sql: - self.log.debug("Executing %d statements against Snowflake DB", len(sql)) + self.log.debug("Executing following statements against Snowflake DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") @@ -322,33 +323,27 @@ def run( # SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here with closing(conn.cursor(DictCursor)) as cur: # type: ignore[type-var] - + results = [] for sql_statement in sql: + self._run_command(cur, sql_statement, parameters) - self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) - if parameters: - cur.execute(sql_statement, parameters) - else: - cur.execute(sql_statement) - - execution_info = [] if handler is not None: - cur = handler(cur) - for row in cur: - self.log.info("Statement execution info - %s", row) - execution_info.append(row) + result = handler(cur) + results.append(result) query_id = cur.sfqid self.log.info("Rows affected: %s", cur.rowcount) self.log.info("Snowflake query id: %s", query_id) self.query_ids.append(query_id) - # If autocommit was set to False for db that supports autocommit, - # or if db does not supports autocommit, we do a manual commit. 
+ # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): conn.commit() - return execution_info + if handler is None: + return None + else: + return results def test_connection(self): """Test the Snowflake connection by running a simple query.""" diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index 086c1d6fd5bee..dd996cc526b84 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List, Optional, Sequence, SupportsAbs +from typing import Any, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union from airflow.models import BaseOperator from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -83,9 +84,9 @@ class SnowflakeOperator(BaseOperator): def __init__( self, *, - sql: Any, + sql: Union[str, Iterable[str]], snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, do_xcom_push: bool = True, warehouse: Optional[str] = None, @@ -113,11 +114,11 @@ def __init__( def get_db_hook(self) -> SnowflakeHook: return get_db_hook(self) - def execute(self, context: Any) -> None: + def execute(self, context: Any): """Run query on snowflake""" self.log.info('Executing: %s', self.sql) hook = self.get_db_hook() - execution_info = hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + execution_info = hook.run(self.sql, self.autocommit, self.parameters, fetch_all_handler) self.query_ids = hook.query_ids if self.do_xcom_push: @@ -186,9 +187,9 @@ class SnowflakeCheckOperator(SQLCheckOperator): def __init__( self, *, - sql: Any, + sql: str, snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, do_xcom_push: bool = True, warehouse: Optional[str] = None, @@ -257,7 +258,7 @@ def __init__( pass_value: Any, tolerance: Any = None, snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, do_xcom_push: bool = True, warehouse: Optional[str] = None, @@ -334,7 +335,7 @@ def __init__( date_filter_column: str = 'ds', days_back: SupportsAbs[int] = -7, snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, do_xcom_push: bool = True, warehouse: Optional[str] = None, diff --git a/airflow/providers/sqlite/operators/sqlite.py b/airflow/providers/sqlite/operators/sqlite.py index 7ef97ca2963f4..ef20d760c0519 100644 --- a/airflow/providers/sqlite/operators/sqlite.py +++ b/airflow/providers/sqlite/operators/sqlite.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union +from typing import Any, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.sqlite.hooks.sqlite import SqliteHook @@ -45,9 +45,9 @@ class SqliteOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], sqlite_conn_id: str = 'sqlite_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 9170e19a547ae..f0c966b4c9796 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -19,10 +19,8 @@ import os import warnings from contextlib import closing -from itertools import chain -from typing import Any, Callable, Iterable, Optional, Tuple, overload +from typing import Any, Callable, Iterable, Mapping, Optional, Union, overload -import sqlparse import trino from trino.exceptions import DatabaseError from trino.transaction import IsolationLevel @@ -150,12 +148,8 @@ def get_isolation_level(self) -> Any: isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper() return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT) - @staticmethod - def _strip_sql(sql: str) -> str: - return sql.strip().rstrip(';') - @overload - def get_records(self, sql: str = "", parameters: Optional[dict] = None): + def get_records(self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None): """Get a set of records from Trino :param sql: SQL statement to be executed. @@ -163,10 +157,14 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None): """ @overload - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): + def get_records( + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "" + ): """:sphinx-autoapi-skip:""" - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): + def get_records( + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "" + ): """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -177,12 +175,12 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str sql = hql try: - return super().get_records(self._strip_sql(sql), parameters) + return super().get_records(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise TrinoException(e) @overload - def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: + def get_first(self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None) -> Any: """Returns only the first row, regardless of how many rows the query returns. :param sql: SQL statement to be executed. 
@@ -190,10 +188,14 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: """ @overload - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: + def get_first( + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "" + ) -> Any: """:sphinx-autoapi-skip:""" - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: + def get_first( + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "" + ) -> Any: """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -204,13 +206,13 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = sql = hql try: - return super().get_first(self._strip_sql(sql), parameters) + return super().get_first(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise TrinoException(e) @overload def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, **kwargs + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs ): # type: ignore[override] """Get a pandas dataframe from a sql query. @@ -220,12 +222,12 @@ def get_pandas_df( @overload def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "", **kwargs ): # type: ignore[override] """:sphinx-autoapi-skip:""" def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "", **kwargs ): # type: ignore[override] """:sphinx-autoapi-skip:""" if hql: @@ -240,7 +242,7 @@ def get_pandas_df( cursor = self.get_cursor() try: - cursor.execute(self._strip_sql(sql), parameters) + cursor.execute(self.strip_sql_string(sql), parameters) data = cursor.fetchall() except DatabaseError as e: raise TrinoException(e) @@ -255,32 +257,32 @@ def get_pandas_df( @overload def run( self, - sql, + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Tuple] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, - ) -> None: + ) -> Optional[list]: """Execute the statement against Trino. 
Can be used to create views.""" @overload def run( self, - sql, + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Tuple] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, hql: str = "", - ) -> None: + ) -> Optional[list]: """:sphinx-autoapi-skip:""" def run( self, - sql, + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Tuple] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, hql: str = "", - ): + ) -> Optional[list]: """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -289,38 +291,11 @@ def run( stacklevel=2, ) sql = hql - scalar = isinstance(sql, str) - - with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, autocommit) - - if scalar: - sql = sqlparse.split(sql) - - with closing(conn.cursor()) as cur: - results = [] - for sql_statement in sql: - self._run_command(cur, self._strip_sql(sql_statement), parameters) - self.query_id = cur.stats["queryId"] - if handler is not None: - result = handler(cur) - results.append(result) - - # If autocommit was set to False for db that supports autocommit, - # or if db does not supports autocommit, we do a manual commit. - if not self.get_autocommit(conn): - conn.commit() - - self.log.info("Query Execution Result: %s", str(list(chain.from_iterable(results)))) - - if handler is None: - return None - if scalar: - return results[0] + if isinstance(sql, str): + sql = self.strip_sql_string(sql) - return results + return super().run(sql=sql, autocommit=autocommit, parameters=parameters, handler=handler) def insert_rows( self, diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 203fa5e22aff1..743a73d0da786 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -62,8 +62,7 @@ "deps": [ "apache-airflow-providers-common-sql", "apache-airflow>=2.2.0", - "sqlalchemy-drill>=1.1.0", - "sqlparse>=0.4.1" + "sqlalchemy-drill>=1.1.0" ], "cross-providers-deps": [ "common.sql" @@ -191,7 +190,9 @@ "cross-providers-deps": [] }, "common.sql": { - "deps": [], + "deps": [ + "sqlparse>=0.4.2" + ], "cross-providers-deps": [] }, "databricks": { diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py index a44fa57e075a7..05466477ad2ff 100644 --- a/tests/providers/common/sql/hooks/test_dbapi.py +++ b/tests/providers/common/sql/hooks/test_dbapi.py @@ -367,7 +367,7 @@ def handler(cur): result = self.db_hook.run(sql, parameters=param, handler=handler) assert called == 1 assert self.conn.commit.called - assert result == obj + assert result == [obj] def test_run_with_handler_multiple(self): sql = ['SQL', 'SQL'] diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py index 05952beabd5b8..1b845acf6eed3 100644 --- a/tests/providers/databricks/hooks/test_databricks_sql.py +++ b/tests/providers/databricks/hooks/test_databricks_sql.py @@ -23,6 +23,7 @@ import pytest from airflow.models import Connection +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook from airflow.utils.session import provide_session @@ -72,17 +73,17 @@ def test_query(self, mock_requests, mock_conn): test_schema = [(field,) for field in test_fields] conn = mock_conn.return_value - cur = mock.MagicMock(rowcount=0) 
+ cur = mock.MagicMock(rowcount=0, description=test_schema) + cur.fetchall.return_value = [] conn.cursor.return_value = cur - type(cur).description = mock.PropertyMock(return_value=test_schema) - query = "select * from test.test" - schema, results = self.hook.run(sql=query) + query = "select * from test.test;" + schema, results = self.hook.run(sql=query, handler=fetch_all_handler)[0] assert schema == test_schema assert results == [] - cur.execute.assert_has_calls([mock.call(q) for q in [query]]) + cur.execute.assert_has_calls([mock.call(q) for q in [query.rstrip(';')]]) cur.close.assert_called() def test_no_query(self): diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py index 783fa520a79cc..6775ff83988ce 100644 --- a/tests/providers/databricks/operators/test_databricks_sql.py +++ b/tests/providers/databricks/operators/test_databricks_sql.py @@ -25,6 +25,7 @@ from databricks.sql.types import Row from airflow import AirflowException +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.operators.databricks_sql import ( DatabricksCopyIntoOperator, DatabricksSqlOperator, @@ -47,7 +48,7 @@ def test_exec_success(self, db_mock_class): db_mock = db_mock_class.return_value mock_schema = [('id',), ('value',)] mock_results = [Row(id=1, value='value1')] - db_mock.run.return_value = (mock_schema, mock_results) + db_mock.run.return_value = [(mock_schema, mock_results)] results = op.execute(None) @@ -61,7 +62,7 @@ def test_exec_success(self, db_mock_class): catalog=None, schema=None, ) - db_mock.run.assert_called_once_with(sql, parameters=None) + db_mock.run.assert_called_once_with(sql, parameters=None, handler=fetch_all_handler) @mock.patch('airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook') def test_exec_write_file(self, db_mock_class): @@ -74,7 +75,7 @@ def test_exec_write_file(self, db_mock_class): db_mock = db_mock_class.return_value mock_schema = [('id',), ('value',)] mock_results = [Row(id=1, value='value1')] - db_mock.run.return_value = (mock_schema, mock_results) + db_mock.run.return_value = [(mock_schema, mock_results)] try: op.execute(None) @@ -92,7 +93,7 @@ def test_exec_write_file(self, db_mock_class): catalog=None, schema=None, ) - db_mock.run.assert_called_once_with(sql, parameters=None) + db_mock.run.assert_called_once_with(sql, parameters=None, handler=fetch_all_handler) class TestDatabricksSqlCopyIntoOperator(unittest.TestCase): diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py index bf85465629c43..92a14696e8328 100644 --- a/tests/providers/exasol/hooks/test_exasol.py +++ b/tests/providers/exasol/hooks/test_exasol.py @@ -95,29 +95,29 @@ def test_get_autocommit(self): assert not self.db_hook.get_autocommit(self.conn) def test_run_without_autocommit(self): - sql = 'SQL' + sql = 'SELECT 1;' setattr(self.conn, 'attr', {'autocommit': False}) # Default autocommit setting should be False. # Testing default autocommit value as well as run() behavior. 
self.db_hook.run(sql, autocommit=False) self.conn.set_autocommit.assert_called_once_with(False) - self.conn.execute.assert_called_once_with(sql, None) + self.conn.execute.assert_called_once_with(sql.rstrip(';'), None) self.conn.commit.assert_called_once() def test_run_with_autocommit(self): - sql = 'SQL' + sql = 'SELECT 1;' self.db_hook.run(sql, autocommit=True) self.conn.set_autocommit.assert_called_once_with(True) - self.conn.execute.assert_called_once_with(sql, None) + self.conn.execute.assert_called_once_with(sql.rstrip(';'), None) self.conn.commit.assert_not_called() def test_run_with_parameters(self): - sql = 'SQL' + sql = 'SELECT 1;' parameters = ('param1', 'param2') self.db_hook.run(sql, autocommit=True, parameters=parameters) self.conn.set_autocommit.assert_called_once_with(True) - self.conn.execute.assert_called_once_with(sql, parameters) + self.conn.execute.assert_called_once_with(sql.rstrip(';'), parameters) self.conn.commit.assert_not_called() def test_run_multi_queries(self): diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py index 812d60fd4552b..9168674c566a8 100644 --- a/tests/providers/jdbc/operators/test_jdbc.py +++ b/tests/providers/jdbc/operators/test_jdbc.py @@ -19,7 +19,8 @@ import unittest from unittest.mock import patch -from airflow.providers.jdbc.operators.jdbc import JdbcOperator, fetch_all_handler +from airflow.providers.common.sql.hooks.sql import fetch_all_handler +from airflow.providers.jdbc.operators.jdbc import JdbcOperator class TestJdbcOperator(unittest.TestCase): diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index d33dbf79f721a..8c2ed9d5ae256 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -269,7 +269,7 @@ def getvalue(self): self.cur.bindvars = None result = self.db_hook.callproc('proc', True, parameters) assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END;')] - assert result == parameters + assert result == [parameters] def test_callproc_dict(self): parameters = {"a": 1, "b": 2, "c": 3} @@ -281,7 +281,7 @@ def getvalue(self): self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()} result = self.db_hook.callproc('proc', True, parameters) assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)] - assert result == parameters + assert result == [parameters] def test_callproc_list(self): parameters = [1, 2, 3] @@ -293,7 +293,7 @@ def getvalue(self): self.cur.bindvars = list(map(bindvar, parameters)) result = self.db_hook.callproc('proc', True, parameters) assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)] - assert result == parameters + assert result == [parameters] def test_callproc_out_param(self): parameters = [1, int, float, bool, str] @@ -307,7 +307,7 @@ def bindvar(value): result = self.db_hook.callproc('proc', True, parameters) expected = [1, 0, 0.0, False, ''] assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)] - assert result == expected + assert result == [expected] def test_test_connection_use_dual_table(self): status, message = self.db_hook.test_connection() From a184b3456227d69b7af8d1cad3ed99e23042afcf Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Fri, 15 Jul 2022 01:29:13 +0300 Subject: [PATCH 3/7] Add `split_statements` param --- airflow/providers/apache/drill/operators/drill.py | 2 +- 
airflow/providers/common/sql/hooks/sql.py | 7 ++++++-
airflow/providers/databricks/hooks/databricks_sql.py | 7 ++++++-
airflow/providers/exasol/hooks/exasol.py | 7 ++++++-
airflow/providers/presto/hooks/presto.py | 11 ++++++++++-
airflow/providers/snowflake/hooks/snowflake.py | 12 ++++++++----
airflow/providers/trino/hooks/trino.py | 11 ++++++++++-
7 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py
index 44a0a1e2b6583..6dad45cc3c323 100644
--- a/airflow/providers/apache/drill/operators/drill.py
+++ b/airflow/providers/apache/drill/operators/drill.py
@@ -62,4 +62,4 @@ def execute(self, context: 'Context'):
         self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id)
         self.hook = DrillHook(drill_conn_id=self.drill_conn_id)
-        self.hook.run(DrillHook.split_sql_string(self.sql), parameters=self.parameters)
+        self.hook.run(self.sql, parameters=self.parameters, split_statements=True)
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 0ae0d1457ef0a..81111a8c25b91 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -235,6 +235,7 @@ def run(
         autocommit: bool = False,
         parameters: Optional[Union[Iterable, Mapping]] = None,
         handler: Optional[Callable] = None,
+        split_statements: bool = False,
     ) -> Optional[list]:
         """
         Runs a command or a list of commands. Pass a list of sql
@@ -247,10 +248,14 @@
         before executing the query.
         :param parameters: The parameters to render the SQL query with.
         :param handler: The result handler which is called with the result of each statement.
+        :param split_statements: Whether to split a single SQL string into statements and run separately
         :return: results of ALL SQL expressions if handler was provided.
         """
         if isinstance(sql, str):
-            sql = [sql]
+            if split_statements:
+                sql = self.split_sql_string(sql)
+            else:
+                sql = [sql]

         if sql:
             self.log.debug("Executing following statements against DB: %s", list(sql))
diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index 3d9ace8b2319d..9c637cbe6fb7e 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -144,6 +144,7 @@ def run(
         autocommit: bool = False,
         parameters: Optional[Union[Iterable, Mapping]] = None,
         handler: Optional[Callable] = None,
+        split_statements: bool = True,
     ) -> Optional[List[Tuple[Any, Any]]]:
         """
         Runs a command or a list of commands. Pass a list of sql
@@ -156,10 +157,14 @@
         before executing the query.
         :param parameters: The parameters to render the SQL query with.
         :param handler: The result handler which is called with the result of each statement.
+        :param split_statements: Whether to split a single SQL string into statements and run separately
         :return: return only result of the LAST SQL expression if handler was provided.
""" if isinstance(sql, str): - sql = self.split_sql_string(sql) + if split_statements: + sql = self.split_sql_string(sql) + else: + sql = [sql] if sql: self.log.debug("Executing following statements against Databricks DB: %s", list(sql)) diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index a2e3f0452320c..d3550556de5d1 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -134,6 +134,7 @@ def run( autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, ) -> Optional[list]: """ Runs a command or a list of commands. Pass a list of sql @@ -146,10 +147,14 @@ def run( before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately :return: return only result of the LAST SQL expression if handler was provided. """ if isinstance(sql, str): - sql = self.split_sql_string(sql) + if split_statements: + sql = self.split_sql_string(sql) + else: + sql = [sql] if sql: self.log.debug("Executing following statements against Exasol DB: %s", list(sql)) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 70eb2d1c9088d..972c096387379 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -241,6 +241,7 @@ def run( autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, ) -> Optional[list]: """Execute the statement against Presto. 
Can be used to create views.""" @@ -251,6 +252,7 @@ def run( autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, hql: str = "", ) -> Optional[list]: """:sphinx-autoapi-skip:""" @@ -261,6 +263,7 @@ def run( autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, hql: str = "", ) -> Optional[list]: """:sphinx-autoapi-skip:""" @@ -275,7 +278,13 @@ def run( if isinstance(sql, str): sql = self.strip_sql_string(sql) - return super().run(sql=sql, autocommit=autocommit, parameters=parameters, handler=handler) + return super().run( + sql=sql, + autocommit=autocommit, + parameters=parameters, + handler=handler, + split_statements=split_statements, + ) def insert_rows( self, diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index fd0bcbdce856a..a65d85cb4dea8 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -24,8 +24,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from snowflake import connector -from snowflake.connector import DictCursor, SnowflakeConnection -from snowflake.connector.util_text import split_statements +from snowflake.connector import DictCursor, SnowflakeConnection, util_text from snowflake.sqlalchemy import URL from sqlalchemy import create_engine @@ -290,6 +289,7 @@ def run( autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = True, ): """ Runs a command or a list of commands. Pass a list of sql @@ -305,13 +305,17 @@ def run( before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately :return: return only result of the LAST SQL expression if handler was provided. """ self.query_ids = [] if isinstance(sql, str): - split_statements_tuple = split_statements(StringIO(sql)) - sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] + if split_statements: + split_statements_tuple = util_text.split_statements(StringIO(sql)) + sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] + else: + sql = [sql] if sql: self.log.debug("Executing following statements against Snowflake DB: %s", list(sql)) diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index f0c966b4c9796..0c252e9613e5f 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -261,6 +261,7 @@ def run( autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, ) -> Optional[list]: """Execute the statement against Trino. 
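[Editorial note] The defaults are chosen per hook to preserve existing behaviour: DatabricksSqlHook and SnowflakeHook keep split_statements=True (both unconditionally split strings before this series), while the base DbApiHook, Exasol, Presto and Trino default to False. Snowflake delegates the actual splitting to the connector's util_text.split_statements, which the hunk above calls on a stream and which yields (statement, is_put_or_get) tuples. A sketch, with the exact output text treated as an assumption:

    from io import StringIO

    from snowflake.connector import util_text

    sql = "CREATE TABLE t (a INT); INSERT INTO t VALUES (1);"
    statements = [s for s, _ in util_text.split_statements(StringIO(sql)) if s]
    print(statements)
    # Assumed output: ['CREATE TABLE t (a INT);', 'INSERT INTO t VALUES (1);']
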
Can be used to create views.""" @@ -271,6 +272,7 @@ def run( autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, hql: str = "", ) -> Optional[list]: """:sphinx-autoapi-skip:""" @@ -281,6 +283,7 @@ def run( autocommit: bool = False, parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, hql: str = "", ) -> Optional[list]: """:sphinx-autoapi-skip:""" @@ -295,7 +298,13 @@ def run( if isinstance(sql, str): sql = self.strip_sql_string(sql) - return super().run(sql=sql, autocommit=autocommit, parameters=parameters, handler=handler) + return super().run( + sql=sql, + autocommit=autocommit, + parameters=parameters, + handler=handler, + split_statements=split_statements, + ) def insert_rows( self, From f9e1461829b2af7c8cd737bad8a2ad8445442d95 Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Mon, 18 Jul 2022 01:38:45 +0300 Subject: [PATCH 4/7] Add `return_last` param --- airflow/providers/common/sql/hooks/sql.py | 6 +++++- airflow/providers/databricks/hooks/databricks_sql.py | 6 +++++- airflow/providers/exasol/hooks/exasol.py | 6 +++++- airflow/providers/presto/hooks/presto.py | 11 +++++++---- airflow/providers/snowflake/hooks/snowflake.py | 6 +++++- airflow/providers/trino/hooks/trino.py | 11 +++++++---- 6 files changed, 34 insertions(+), 12 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 81111a8c25b91..f6cf83c62db67 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -236,7 +236,8 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = False, - ) -> Optional[list]: + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -249,6 +250,7 @@ def run( :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return handler result for only last statement or for all :return: return only result of the ALL SQL expressions if handler was provided. """ if isinstance(sql, str): @@ -281,6 +283,8 @@ def run( if handler is None: return None + elif return_last: + return results[-1] else: return results diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py index 9c637cbe6fb7e..655ab31a4333c 100644 --- a/airflow/providers/databricks/hooks/databricks_sql.py +++ b/airflow/providers/databricks/hooks/databricks_sql.py @@ -145,7 +145,8 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = True, - ) -> Optional[List[Tuple[Any, Any]]]: + return_last: bool = True, + ) -> Optional[Union[Tuple[str, Any], List[Tuple[str, Any]]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -158,6 +159,7 @@ def run( :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. 
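[Editorial note] As introduced in this patch, return_last collapses the collected handler results to the final one unconditionally, even when the caller passed an explicit list of statements; PATCH 5 below narrows that. A toy model of the PATCH 4 control flow (names hypothetical):

    from typing import Any, Callable, Iterable, List, Optional, Union

    def run_sketch(
        sql: Union[str, Iterable[str]],
        handler: Optional[Callable[[str], Any]] = None,
        return_last: bool = True,
    ) -> Optional[Union[Any, List[Any]]]:
        statements = [sql] if isinstance(sql, str) else list(sql)
        if handler is None:
            return None
        results = [handler(s) for s in statements]
        return results[-1] if return_last else results

    print(run_sketch(["SELECT 1", "SELECT 22"], handler=len))  # 9 -- last result only
    print(run_sketch(["SELECT 1", "SELECT 22"], handler=len, return_last=False))  # [8, 9]
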
:param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return handler result for only last statement or for all :return: return only result of the LAST SQL expression if handler was provided. """ if isinstance(sql, str): @@ -189,6 +191,8 @@ def run( if handler is None: return None + elif return_last: + return results[-1] else: return results diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index d3550556de5d1..5743e309e0f5d 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -135,7 +135,8 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = False, - ) -> Optional[list]: + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -148,6 +149,7 @@ def run( :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return handler result for only last statement or for all :return: return only result of the LAST SQL expression if handler was provided. """ if isinstance(sql, str): @@ -179,6 +181,8 @@ def run( if handler is None: return None + elif return_last: + return results[-1] else: return results diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 972c096387379..51fd26e44ebdd 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -18,7 +18,7 @@ import json import os import warnings -from typing import Any, Callable, Iterable, Mapping, Optional, Union, overload +from typing import Any, Callable, Iterable, List, Mapping, Optional, Union, overload import prestodb from prestodb.exceptions import DatabaseError @@ -242,7 +242,8 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = False, - ) -> Optional[list]: + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """Execute the statement against Presto. 
Can be used to create views.""" @overload @@ -253,8 +254,9 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = False, + return_last: bool = True, hql: str = "", - ) -> Optional[list]: + ) -> Optional[Union[Any, List[Any]]]: """:sphinx-autoapi-skip:""" def run( @@ -264,8 +266,9 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = False, + return_last: bool = True, hql: str = "", - ) -> Optional[list]: + ) -> Optional[Union[Any, List[Any]]]: """:sphinx-autoapi-skip:""" if hql: warnings.warn( diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index a65d85cb4dea8..0582a31610984 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -290,7 +290,8 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = True, - ): + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -306,6 +307,7 @@ def run( :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return handler result for only last statement or for all :return: return only result of the LAST SQL expression if handler was provided. """ self.query_ids = [] @@ -346,6 +348,8 @@ def run( if handler is None: return None + elif return_last: + return results[-1] else: return results diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 0c252e9613e5f..30bafa42b0ac4 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -19,7 +19,7 @@ import os import warnings from contextlib import closing -from typing import Any, Callable, Iterable, Mapping, Optional, Union, overload +from typing import Any, Callable, Iterable, List, Mapping, Optional, Union, overload import trino from trino.exceptions import DatabaseError @@ -262,7 +262,8 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = False, - ) -> Optional[list]: + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """Execute the statement against Trino. 
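[Editorial note] The Presto and Trino hooks repeat the new parameters across two @overload stubs plus the real implementation because run() still accepts the deprecated hql alias; the stubs carry :sphinx-autoapi-skip: so the rendered docs show a single signature. The runtime shim is roughly the following sketch (the actual warning text is not shown in this series, so the message below is a placeholder):

    import warnings


    class _HookSketch:  # stand-in for the Presto/Trino hook classes
        def run(self, sql: str = "", hql: str = "", **kwargs):
            if hql:
                # Placeholder message; the real hook's wording may differ.
                warnings.warn(
                    "'hql' is deprecated, use 'sql' instead",
                    DeprecationWarning,
                    stacklevel=2,
                )
                sql = hql
            print(f"executing: {sql}")  # the real method delegates to super().run()


    _HookSketch().run(hql="SELECT 1")  # warns, then executes as sql
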
Can be used to create views.""" @overload @@ -273,8 +274,9 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = False, + return_last: bool = True, hql: str = "", - ) -> Optional[list]: + ) -> Optional[Union[Any, List[Any]]]: """:sphinx-autoapi-skip:""" def run( @@ -284,8 +286,9 @@ def run( parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, split_statements: bool = False, + return_last: bool = True, hql: str = "", - ) -> Optional[list]: + ) -> Optional[Union[Any, List[Any]]]: """:sphinx-autoapi-skip:""" if hql: warnings.warn( From 6fd8745377318f750a4338f2d982676af82dc183 Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Tue, 19 Jul 2022 00:40:39 +0300 Subject: [PATCH 5/7] Fix tests --- airflow/providers/common/sql/hooks/sql.py | 5 +++-- airflow/providers/databricks/hooks/databricks_sql.py | 5 +++-- airflow/providers/exasol/hooks/exasol.py | 5 +++-- airflow/providers/snowflake/hooks/snowflake.py | 5 +++-- tests/providers/common/sql/hooks/test_dbapi.py | 2 +- .../databricks/hooks/test_databricks_sql.py | 2 +- tests/providers/exasol/hooks/test_exasol.py | 12 ++++++------ tests/providers/oracle/hooks/test_oracle.py | 8 ++++---- 8 files changed, 24 insertions(+), 20 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index f6cf83c62db67..ec396fbf7adad 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -250,9 +250,10 @@ def run( :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. :param split_statements: Whether to split a single SQL string into statements and run separately - :param return_last: Whether to return handler result for only last statement or for all + :param return_last: Whether to return result for only last statement or for all after split :return: return only result of the ALL SQL expressions if handler was provided. """ + scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): if split_statements: sql = self.split_sql_string(sql) @@ -283,7 +284,7 @@ def run( if handler is None: return None - elif return_last: + elif scalar_return_last: return results[-1] else: return results diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py index 655ab31a4333c..2adc3a78dfbde 100644 --- a/airflow/providers/databricks/hooks/databricks_sql.py +++ b/airflow/providers/databricks/hooks/databricks_sql.py @@ -159,9 +159,10 @@ def run( :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. :param split_statements: Whether to split a single SQL string into statements and run separately - :param return_last: Whether to return handler result for only last statement or for all + :param return_last: Whether to return result for only last statement or for all after split :return: return only result of the LAST SQL expression if handler was provided. 
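[Editorial note] The fix in this patch keys the unwrapping on the input type: only a plain SQL string combined with return_last=True yields a scalar, while an explicit list of statements keeps returning one result per statement, restoring the pre-PATCH-4 contract for list callers. A toy model of the guarded branch:

    from typing import Any, Callable, Iterable, List, Optional, Union

    def run_sketch(
        sql: Union[str, Iterable[str]],
        handler: Optional[Callable[[str], Any]] = None,
        return_last: bool = True,
    ) -> Optional[Union[Any, List[Any]]]:
        scalar_return_last = isinstance(sql, str) and return_last
        statements = [sql] if isinstance(sql, str) else list(sql)
        if handler is None:
            return None
        results = [handler(s) for s in statements]
        return results[-1] if scalar_return_last else results

    print(run_sketch("SELECT 1", handler=len))                 # 8 (scalar)
    print(run_sketch(["SELECT 1", "SELECT 22"], handler=len))  # [8, 9] (list kept)
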
""" + scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): if split_statements: sql = self.split_sql_string(sql) @@ -191,7 +192,7 @@ def run( if handler is None: return None - elif return_last: + elif scalar_return_last: return results[-1] else: return results diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 5743e309e0f5d..aabd93101a8ce 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -149,9 +149,10 @@ def run( :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. :param split_statements: Whether to split a single SQL string into statements and run separately - :param return_last: Whether to return handler result for only last statement or for all + :param return_last: Whether to return result for only last statement or for all after split :return: return only result of the LAST SQL expression if handler was provided. """ + scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): if split_statements: sql = self.split_sql_string(sql) @@ -181,7 +182,7 @@ def run( if handler is None: return None - elif return_last: + elif scalar_return_last: return results[-1] else: return results diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 0582a31610984..7a808e3d01b35 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -307,11 +307,12 @@ def run( :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. :param split_statements: Whether to split a single SQL string into statements and run separately - :param return_last: Whether to return handler result for only last statement or for all + :param return_last: Whether to return result for only last statement or for all after split :return: return only result of the LAST SQL expression if handler was provided. 
""" self.query_ids = [] + scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): if split_statements: split_statements_tuple = util_text.split_statements(StringIO(sql)) @@ -348,7 +349,7 @@ def run( if handler is None: return None - elif return_last: + elif scalar_return_last: return results[-1] else: return results diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py index 05466477ad2ff..a44fa57e075a7 100644 --- a/tests/providers/common/sql/hooks/test_dbapi.py +++ b/tests/providers/common/sql/hooks/test_dbapi.py @@ -367,7 +367,7 @@ def handler(cur): result = self.db_hook.run(sql, parameters=param, handler=handler) assert called == 1 assert self.conn.commit.called - assert result == [obj] + assert result == obj def test_run_with_handler_multiple(self): sql = ['SQL', 'SQL'] diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py index 1b845acf6eed3..d70d203e33f66 100644 --- a/tests/providers/databricks/hooks/test_databricks_sql.py +++ b/tests/providers/databricks/hooks/test_databricks_sql.py @@ -78,7 +78,7 @@ def test_query(self, mock_requests, mock_conn): conn.cursor.return_value = cur query = "select * from test.test;" - schema, results = self.hook.run(sql=query, handler=fetch_all_handler)[0] + schema, results = self.hook.run(sql=query, handler=fetch_all_handler) assert schema == test_schema assert results == [] diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py index 92a14696e8328..bf85465629c43 100644 --- a/tests/providers/exasol/hooks/test_exasol.py +++ b/tests/providers/exasol/hooks/test_exasol.py @@ -95,29 +95,29 @@ def test_get_autocommit(self): assert not self.db_hook.get_autocommit(self.conn) def test_run_without_autocommit(self): - sql = 'SELECT 1;' + sql = 'SQL' setattr(self.conn, 'attr', {'autocommit': False}) # Default autocommit setting should be False. # Testing default autocommit value as well as run() behavior. 
self.db_hook.run(sql, autocommit=False) self.conn.set_autocommit.assert_called_once_with(False) - self.conn.execute.assert_called_once_with(sql.rstrip(';'), None) + self.conn.execute.assert_called_once_with(sql, None) self.conn.commit.assert_called_once() def test_run_with_autocommit(self): - sql = 'SELECT 1;' + sql = 'SQL' self.db_hook.run(sql, autocommit=True) self.conn.set_autocommit.assert_called_once_with(True) - self.conn.execute.assert_called_once_with(sql.rstrip(';'), None) + self.conn.execute.assert_called_once_with(sql, None) self.conn.commit.assert_not_called() def test_run_with_parameters(self): - sql = 'SELECT 1;' + sql = 'SQL' parameters = ('param1', 'param2') self.db_hook.run(sql, autocommit=True, parameters=parameters) self.conn.set_autocommit.assert_called_once_with(True) - self.conn.execute.assert_called_once_with(sql.rstrip(';'), parameters) + self.conn.execute.assert_called_once_with(sql, parameters) self.conn.commit.assert_not_called() def test_run_multi_queries(self): diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index 8c2ed9d5ae256..d33dbf79f721a 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -269,7 +269,7 @@ def getvalue(self): self.cur.bindvars = None result = self.db_hook.callproc('proc', True, parameters) assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END;')] - assert result == [parameters] + assert result == parameters def test_callproc_dict(self): parameters = {"a": 1, "b": 2, "c": 3} @@ -281,7 +281,7 @@ def getvalue(self): self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()} result = self.db_hook.callproc('proc', True, parameters) assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)] - assert result == [parameters] + assert result == parameters def test_callproc_list(self): parameters = [1, 2, 3] @@ -293,7 +293,7 @@ def getvalue(self): self.cur.bindvars = list(map(bindvar, parameters)) result = self.db_hook.callproc('proc', True, parameters) assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)] - assert result == [parameters] + assert result == parameters def test_callproc_out_param(self): parameters = [1, int, float, bool, str] @@ -307,7 +307,7 @@ def bindvar(value): result = self.db_hook.callproc('proc', True, parameters) expected = [1, 0, 0.0, False, ''] assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)] - assert result == [expected] + assert result == expected def test_test_connection_use_dual_table(self): status, message = self.db_hook.test_connection() From dcc2debc0b38442e4357b0f48186301a6b4f73b8 Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Tue, 19 Jul 2022 19:07:27 +0300 Subject: [PATCH 6/7] Review suggestions --- airflow/dependencies.json | 0 airflow/providers/common/sql/hooks/sql.py | 2 +- airflow/providers/databricks/hooks/databricks_sql.py | 2 +- airflow/providers/exasol/hooks/exasol.py | 2 +- airflow/providers/presto/hooks/presto.py | 4 +--- airflow/providers/snowflake/hooks/snowflake.py | 2 +- airflow/providers/trino/hooks/trino.py | 4 +--- 7 files changed, 6 insertions(+), 10 deletions(-) delete mode 100644 airflow/dependencies.json diff --git a/airflow/dependencies.json b/airflow/dependencies.json deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 
ec396fbf7adad..e6687fa938a44 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -258,7 +258,7 @@ def run( if split_statements: sql = self.split_sql_string(sql) else: - sql = [sql] + sql = [self.strip_sql_string(sql)] if sql: self.log.debug("Executing following statements against DB: %s", list(sql)) diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py index 2adc3a78dfbde..7a888438e9dc4 100644 --- a/airflow/providers/databricks/hooks/databricks_sql.py +++ b/airflow/providers/databricks/hooks/databricks_sql.py @@ -167,7 +167,7 @@ def run( if split_statements: sql = self.split_sql_string(sql) else: - sql = [sql] + sql = [self.strip_sql_string(sql)] if sql: self.log.debug("Executing following statements against Databricks DB: %s", list(sql)) diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index aabd93101a8ce..537f2fcb6d2b7 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -157,7 +157,7 @@ def run( if split_statements: sql = self.split_sql_string(sql) else: - sql = [sql] + sql = [self.strip_sql_string(sql)] if sql: self.log.debug("Executing following statements against Exasol DB: %s", list(sql)) diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 51fd26e44ebdd..709a378a8d89f 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -278,15 +278,13 @@ def run( ) sql = hql - if isinstance(sql, str): - sql = self.strip_sql_string(sql) - return super().run( sql=sql, autocommit=autocommit, parameters=parameters, handler=handler, split_statements=split_statements, + return_last=return_last, ) def insert_rows( diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 7a808e3d01b35..21655aaf5b1fd 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -318,7 +318,7 @@ def run( split_statements_tuple = util_text.split_statements(StringIO(sql)) sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] else: - sql = [sql] + sql = [self.strip_sql_string(sql)] if sql: self.log.debug("Executing following statements against Snowflake DB: %s", list(sql)) diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 30bafa42b0ac4..d8ac5148de07a 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -298,15 +298,13 @@ def run( ) sql = hql - if isinstance(sql, str): - sql = self.strip_sql_string(sql) - return super().run( sql=sql, autocommit=autocommit, parameters=parameters, handler=handler, split_statements=split_statements, + return_last=return_last, ) def insert_rows( From c3cdcd0ee0961e8ad90a797210633321a560272c Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Thu, 21 Jul 2022 00:29:21 +0300 Subject: [PATCH 7/7] Fix tests --- tests/providers/oracle/hooks/test_oracle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index d33dbf79f721a..254514bc9fe86 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -268,7 +268,7 @@ def getvalue(self): self.cur.bindvars = None result = self.db_hook.callproc('proc', 
True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END;')] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END')] assert result == parameters def test_callproc_dict(self): @@ -280,7 +280,7 @@ def getvalue(self): self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()} result = self.db_hook.callproc('proc', True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END', parameters)] assert result == parameters def test_callproc_list(self): @@ -292,7 +292,7 @@ def getvalue(self): self.cur.bindvars = list(map(bindvar, parameters)) result = self.db_hook.callproc('proc', True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END', parameters)] assert result == parameters def test_callproc_out_param(self): @@ -306,7 +306,7 @@ def bindvar(value): self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters] result = self.db_hook.callproc('proc', True, parameters) expected = [1, 0, 0.0, False, ''] - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END', expected)] assert result == expected def test_test_connection_use_dual_table(self):
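
[Editorial note] The final test fix follows from the review change in PATCH 6: when a single string is executed without splitting, the base hook now normalizes it with self.strip_sql_string(sql), which (per the Exasol test expectations earlier in the series) removes the trailing semicolon, so Oracle's generated "BEGIN proc(); END;" block reaches cursor.execute as "BEGIN proc(); END". A sketch of the normalization, treating the exact implementation as an assumption:

    def strip_sql_string_sketch(sql: str) -> str:
        # Assumption: mirrors DbApiHook.strip_sql_string -- trim surrounding
        # whitespace, then drop the trailing statement terminator.
        return sql.strip().rstrip(";")

    print(strip_sql_string_sketch("BEGIN proc(); END;"))  # BEGIN proc(); END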