diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index cb8b664875971..fbce9b85a1c97 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -496,7 +496,7 @@ def __init__( follow_task_ids_if_false: List[str], conn_id: str = "default_conn_id", database: Optional[str] = None, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(conn_id=conn_id, database=database, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/redshift_sql.py b/airflow/providers/amazon/aws/operators/redshift_sql.py index c7ad77acb5341..aa324b40bc0b6 100644 --- a/airflow/providers/amazon/aws/operators/redshift_sql.py +++ b/airflow/providers/amazon/aws/operators/redshift_sql.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook @@ -55,7 +55,7 @@ def __init__( *, sql: Union[str, Iterable[str]], redshift_conn_id: str = 'redshift_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, **kwargs, ) -> None: diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index 0bfdda44e78d6..f5a3cd9bf6547 100644 --- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -93,7 +93,7 @@ def __init__( unload_options: Optional[List] = None, autocommit: bool = False, include_header: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, table_as_file_name: bool = True, # Set to True by default for not breaking current workflows **kwargs, ) -> None: diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 014e23ec070f2..747c97e1f1212 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -16,7 +16,7 @@ # under the License. import warnings -from typing import TYPE_CHECKING, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -140,7 +140,7 @@ def execute(self, context: 'Context') -> None: copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options) - sql: Union[list, str] + sql: Union[str, Iterable[str]] if self.method == 'REPLACE': sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"] diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py index 791ed546c34fe..6dad45cc3c323 100644 --- a/airflow/providers/apache/drill/operators/drill.py +++ b/airflow/providers/apache/drill/operators/drill.py @@ -17,8 +17,6 @@ # under the License. 
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union -import sqlparse - from airflow.models import BaseOperator from airflow.providers.apache.drill.hooks.drill import DrillHook @@ -52,7 +50,7 @@ def __init__( *, sql: str, drill_conn_id: str = 'drill_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -64,6 +62,4 @@ def __init__( def execute(self, context: 'Context'): self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id) self.hook = DrillHook(drill_conn_id=self.drill_conn_id) - sql = sqlparse.split(sqlparse.format(self.sql, strip_comments=True)) - no_term_sql = [s[:-1] for s in sql if s[-1] == ';'] - self.hook.run(no_term_sql, parameters=self.parameters) + self.hook.run(self.sql, parameters=self.parameters, split_statements=True) diff --git a/airflow/providers/apache/drill/provider.yaml b/airflow/providers/apache/drill/provider.yaml index 33235850b8515..0e26ae5186e3c 100644 --- a/airflow/providers/apache/drill/provider.yaml +++ b/airflow/providers/apache/drill/provider.yaml @@ -34,7 +34,6 @@ dependencies: - apache-airflow>=2.2.0 - apache-airflow-providers-common-sql - sqlalchemy-drill>=1.1.0 - - sqlparse>=0.4.1 integrations: - integration-name: Apache Drill diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py index fa31b9f33d18f..794646e46dc62 100644 --- a/airflow/providers/apache/pinot/hooks/pinot.py +++ b/airflow/providers/apache/pinot/hooks/pinot.py @@ -18,7 +18,7 @@ import os import subprocess -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Iterable, List, Mapping, Optional, Union from pinotdb import connect @@ -275,7 +275,7 @@ def get_uri(self) -> str: endpoint = conn.extra_dejson.get('endpoint', 'query/sql') return f'{conn_type}://{host}/{endpoint}' - def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any: + def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any: """ Executes the sql and returns a set of records. @@ -287,7 +287,7 @@ def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Itera cur.execute(sql) return cur.fetchall() - def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any: + def get_first(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any: """ Executes the sql and returns the first resulting row. 
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index efd4a9dcfed3b..e6687fa938a44 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -17,8 +17,9 @@ import warnings from contextlib import closing from datetime import datetime -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union +import sqlparse from sqlalchemy import create_engine from typing_extensions import Protocol @@ -27,6 +28,17 @@ from airflow.providers_manager import ProvidersManager from airflow.utils.module_loading import import_string +if TYPE_CHECKING: + from sqlalchemy.engine import CursorResult + + +def fetch_all_handler(cursor: 'CursorResult') -> Optional[List[Tuple]]: + """Handler for DbApiHook.run() to return results""" + if cursor.returns_rows: + return cursor.fetchall() + else: + return None + def _backported_get_hook(connection, *, hook_params=None): """Return hook based on conn_type @@ -201,7 +213,31 @@ def get_first(self, sql, parameters=None): cur.execute(sql) return cur.fetchone() - def run(self, sql, autocommit=False, parameters=None, handler=None): + @staticmethod + def strip_sql_string(sql: str) -> str: + return sql.strip().rstrip(';') + + @staticmethod + def split_sql_string(sql: str) -> List[str]: + """ + Splits a string into multiple SQL expressions + + :param sql: SQL string potentially consisting of multiple expressions + :return: list of individual expressions + """ + splits = sqlparse.split(sqlparse.format(sql, strip_comments=True)) + statements = [s.rstrip(';') for s in splits if s.endswith(';')] + return statements + + def run( + self, + sql: Union[str, Iterable[str]], + autocommit: bool = False, + parameters: Optional[Union[Iterable, Mapping]] = None, + handler: Optional[Callable] = None, + split_statements: bool = False, + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -213,14 +249,19 @@ def run(self, sql, autocommit=False, parameters=None, handler=None): before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. - :return: query results if handler was provided. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: return only result of the LAST SQL expression if handler was provided. """ - scalar = isinstance(sql, str) - if scalar: - sql = [sql] + scalar_return_last = isinstance(sql, str) and return_last + if isinstance(sql, str): + if split_statements: + sql = self.split_sql_string(sql) + else: + sql = [self.strip_sql_string(sql)] if sql: - self.log.debug("Executing %d statements", len(sql)) + self.log.debug("Executing following statements against DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") @@ -232,22 +273,21 @@ def run(self, sql, autocommit=False, parameters=None, handler=None): results = [] for sql_statement in sql: self._run_command(cur, sql_statement, parameters) + if handler is not None: result = handler(cur) results.append(result) - # If autocommit was set to False for db that supports autocommit, - # or if db does not supports autocommit, we do a manual commit.
+ # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): conn.commit() if handler is None: return None - - if scalar: - return results[0] - - return results + elif scalar_return_last: + return results[-1] + else: + return results def _run_command(self, cur, sql_statement, parameters): """Runs a statement using an already open cursor.""" diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml index a277f327ccac5..39c8d483e41b1 100644 --- a/airflow/providers/common/sql/provider.yaml +++ b/airflow/providers/common/sql/provider.yaml @@ -24,7 +24,8 @@ description: | versions: - 1.0.0 -dependencies: [] +dependencies: + - sqlparse>=0.4.2 additional-extras: - name: pandas diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py index 6c5800170d255..7a888438e9dc4 100644 --- a/airflow/providers/databricks/hooks/databricks_sql.py +++ b/airflow/providers/databricks/hooks/databricks_sql.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. -import re from contextlib import closing from copy import copy -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union from databricks import sql # type: ignore[attr-defined] from databricks.sql.client import Connection # type: ignore[attr-defined] @@ -139,19 +138,15 @@ def get_conn(self) -> Connection: ) return self._sql_conn - @staticmethod - def maybe_split_sql_string(sql: str) -> List[str]: - """ - Splits strings consisting of multiple SQL expressions into an - TODO: do we need something more sophisticated? - - :param sql: SQL string potentially consisting of multiple expressions - :return: list of individual expressions - """ - splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""] - return splits - - def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None): + def run( + self, + sql: Union[str, Iterable[str]], + autocommit: bool = False, + parameters: Optional[Union[Iterable, Mapping]] = None, + handler: Optional[Callable] = None, + split_statements: bool = True, + return_last: bool = True, + ) -> Optional[Union[Tuple[str, Any], List[Tuple[str, Any]]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -163,41 +158,44 @@ def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, hand before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. - :return: query results. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: return only result of the LAST SQL expression if handler was provided. 
""" + scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): - sql = self.maybe_split_sql_string(sql) + if split_statements: + sql = self.split_sql_string(sql) + else: + sql = [self.strip_sql_string(sql)] if sql: - self.log.debug("Executing %d statements", len(sql)) + self.log.debug("Executing following statements against Databricks DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") - conn = None + results = [] for sql_statement in sql: # when using AAD tokens, it could expire if previous query run longer than token lifetime - conn = self.get_conn() - with closing(conn.cursor()) as cur: - self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters) - if parameters: - cur.execute(sql_statement, parameters) - else: - cur.execute(sql_statement) - schema = cur.description - results = [] - if handler is not None: - cur = handler(cur) - for row in cur: - self.log.debug("Statement results: %s", row) - results.append(row) - - self.log.info("Rows affected: %s", cur.rowcount) - if conn: - conn.close() + with closing(self.get_conn()) as conn: + self.set_autocommit(conn, autocommit) + + with closing(conn.cursor()) as cur: + self._run_command(cur, sql_statement, parameters) + + if handler is not None: + result = handler(cur) + schema = cur.description + results.append((schema, result)) + self._sql_conn = None - # Return only result of the last SQL expression - return schema, results + if handler is None: + return None + elif scalar_return_last: + return results[-1] + else: + return results def test_connection(self): """Test the Databricks SQL connection by running a simple query.""" diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py index 9e6298bc21263..ad4add30fe8e6 100644 --- a/airflow/providers/databricks/operators/databricks_sql.py +++ b/airflow/providers/databricks/operators/databricks_sql.py @@ -20,12 +20,13 @@ import csv import json -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union, cast from databricks.sql.utils import ParamEscaper from airflow.exceptions import AirflowException from airflow.models import BaseOperator +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook if TYPE_CHECKING: @@ -71,11 +72,11 @@ class DatabricksSqlOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], databricks_conn_id: str = DatabricksSqlHook.default_conn_name, http_path: Optional[str] = None, sql_endpoint_name: Optional[str] = None, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, session_configuration=None, http_headers: Optional[List[Tuple[str, str]]] = None, catalog: Optional[str] = None, @@ -147,10 +148,11 @@ def _format_output(self, schema, results): else: raise AirflowException(f"Unsupported output format: '{self._output_format}'") - def execute(self, context: 'Context') -> Any: + def execute(self, context: 'Context'): self.log.info('Executing: %s', self.sql) hook = self._get_hook() - schema, results = hook.run(self.sql, parameters=self.parameters) + response = hook.run(self.sql, parameters=self.parameters, handler=fetch_all_handler) + schema, results = cast(List[Tuple[Any, Any]], 
response)[0] # self.log.info('Schema: %s', schema) # self.log.info('Results: %s', results) self._format_output(schema, results) diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 784c57cde0661..537f2fcb6d2b7 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -17,7 +17,7 @@ # under the License. from contextlib import closing -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union import pandas as pd import pyexasol @@ -64,9 +64,7 @@ def get_conn(self) -> ExaConnection: conn = pyexasol.connect(**conn_args) return conn - def get_pandas_df( - self, sql: Union[str, list], parameters: Optional[dict] = None, **kwargs - ) -> pd.DataFrame: + def get_pandas_df(self, sql: str, parameters: Optional[dict] = None, **kwargs) -> pd.DataFrame: """ Executes the sql and returns a pandas dataframe @@ -79,9 +77,7 @@ def get_pandas_df( df = conn.export_to_pandas(sql, query_params=parameters, **kwargs) return df - def get_records( - self, sql: Union[str, list], parameters: Optional[dict] = None - ) -> List[Union[dict, Tuple[Any, ...]]]: + def get_records(self, sql: str, parameters: Optional[dict] = None) -> List[Union[dict, Tuple[Any, ...]]]: """ Executes the sql and returns a set of records. @@ -93,7 +89,7 @@ def get_records( with closing(conn.execute(sql, parameters)) as cur: return cur.fetchall() - def get_first(self, sql: Union[str, list], parameters: Optional[dict] = None) -> Optional[Any]: + def get_first(self, sql: str, parameters: Optional[dict] = None) -> Optional[Any]: """ Executes the sql and returns the first resulting row. @@ -133,8 +129,14 @@ def export_to_file( self.log.info("Data saved to %s", filename) def run( - self, sql: Union[str, list], autocommit: bool = False, parameters: Optional[dict] = None, handler=None - ) -> Optional[list]: + self, + sql: Union[str, Iterable[str]], + autocommit: bool = False, + parameters: Optional[Union[Iterable, Mapping]] = None, + handler: Optional[Callable] = None, + split_statements: bool = False, + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -146,38 +148,44 @@ def run( before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: return only result of the LAST SQL expression if handler was provided. 
""" + scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): - sql = [sql] + if split_statements: + sql = self.split_sql_string(sql) + else: + sql = [self.strip_sql_string(sql)] if sql: - self.log.debug("Executing %d statements against Exasol DB", len(sql)) + self.log.debug("Executing following statements against Exasol DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, autocommit) - - for query in sql: - self.log.info(query) - with closing(conn.execute(query, parameters)) as cur: - results = [] - + self.set_autocommit(conn, autocommit) + results = [] + for sql_statement in sql: + with closing(conn.execute(sql_statement, parameters)) as cur: + self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) if handler is not None: - cur = handler(cur) + result = handler(cur) + results.append(result) - for row in cur: - self.log.info("Statement execution info - %s", row) - results.append(row) + self.log.info("Rows affected: %s", cur.rowcount) - self.log.info(cur.row_count) - # If autocommit was set to False for db that supports autocommit, - # or if db does not support autocommit, we do a manual commit. + # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): conn.commit() - return results + if handler is None: + return None + elif scalar_return_last: + return results[-1] + else: + return results def set_autocommit(self, conn, autocommit: bool) -> None: """ diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py index eecf44885ec42..33d4ab55c64f0 100644 --- a/airflow/providers/exasol/operators/exasol.py +++ b/airflow/providers/exasol/operators/exasol.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.exasol.hooks.exasol import ExasolHook @@ -46,10 +46,10 @@ class ExasolOperator(BaseOperator): def __init__( self, *, - sql: str, + sql: Union[str, Iterable[str]], exasol_conn_id: str = 'exasol_default', autocommit: bool = False, - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, schema: Optional[str] = None, **kwargs, ) -> None: diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index b7e12dabce641..550c3174068ec 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -550,7 +550,7 @@ def operator_extra_links(self): def __init__( self, *, - sql: Union[str, Iterable], + sql: Union[str, Iterable[str]], destination_dataset_table: Optional[str] = None, write_disposition: str = 'WRITE_EMPTY', allow_large_results: bool = False, diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py index e90a21e296727..a5d40c2ff4889 100644 --- a/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/airflow/providers/google/cloud/operators/cloud_sql.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
"""This module contains Google Cloud SQL operators.""" -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from googleapiclient.errors import HttpError @@ -1054,9 +1054,9 @@ class CloudSQLExecuteQueryOperator(BaseOperator): def __init__( self, *, - sql: Union[List[str], str], + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Union[Dict, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, gcp_conn_id: str = 'google_cloud_default', gcp_cloudsql_conn_id: str = 'google_cloud_sql_default', **kwargs, diff --git a/airflow/providers/google/suite/transfers/sql_to_sheets.py b/airflow/providers/google/suite/transfers/sql_to_sheets.py index 8384868199b6f..8626fb7227182 100644 --- a/airflow/providers/google/suite/transfers/sql_to_sheets.py +++ b/airflow/providers/google/suite/transfers/sql_to_sheets.py @@ -68,7 +68,7 @@ def __init__( sql: str, spreadsheet_id: str, sql_conn_id: str, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, database: Optional[str] = None, spreadsheet_range: str = "Sheet1", gcp_conn_id: str = "google_cloud_default", diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py index 2c023d9afe9ba..6b38366b41e81 100644 --- a/airflow/providers/jdbc/operators/jdbc.py +++ b/airflow/providers/jdbc/operators/jdbc.py @@ -16,20 +16,16 @@ # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.jdbc.hooks.jdbc import JdbcHook if TYPE_CHECKING: from airflow.utils.context import Context -def fetch_all_handler(cursor): - """Handler for DbApiHook.run() to return results""" - return cursor.fetchall() - - class JdbcOperator(BaseOperator): """ Executes sql code in a database using jdbc driver. 
@@ -57,10 +53,10 @@ class JdbcOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], jdbc_conn_id: str = 'jdbc_default', autocommit: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -70,7 +66,7 @@ def __init__( self.autocommit = autocommit self.hook = None - def execute(self, context: 'Context') -> None: + def execute(self, context: 'Context'): self.log.info('Executing: %s', self.sql) hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id) return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=fetch_all_handler) diff --git a/airflow/providers/microsoft/mssql/operators/mssql.py b/airflow/providers/microsoft/mssql/operators/mssql.py index 5a5738eb6878c..3a6434704f1fa 100644 --- a/airflow/providers/microsoft/mssql/operators/mssql.py +++ b/airflow/providers/microsoft/mssql/operators/mssql.py @@ -57,9 +57,9 @@ class MsSqlOperator(BaseOperator): def __init__( self, *, - sql: str, + sql: Union[str, Iterable[str]], mssql_conn_id: str = 'mssql_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = False, database: Optional[str] = None, **kwargs, diff --git a/airflow/providers/mysql/operators/mysql.py b/airflow/providers/mysql/operators/mysql.py index d51a97e6fe38a..975586cd52991 100644 --- a/airflow/providers/mysql/operators/mysql.py +++ b/airflow/providers/mysql/operators/mysql.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. import ast -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.mysql.hooks.mysql import MySqlHook @@ -59,9 +59,9 @@ class MySqlOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], mysql_conn_id: str = 'mysql_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = False, database: Optional[str] = None, **kwargs, diff --git a/airflow/providers/neo4j/operators/neo4j.py b/airflow/providers/neo4j/operators/neo4j.py index b61f0734f0841..82939ad7790f2 100644 --- a/airflow/providers/neo4j/operators/neo4j.py +++ b/airflow/providers/neo4j/operators/neo4j.py @@ -44,7 +44,7 @@ def __init__( *, sql: str, neo4j_conn_id: str = 'neo4j_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py index b60d4b6e89100..969d69728bbaf 100644 --- a/airflow/providers/oracle/operators/oracle.py +++ b/airflow/providers/oracle/operators/oracle.py @@ -50,9 +50,9 @@ class OracleOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], oracle_conn_id: str = 'oracle_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = False, **kwargs, ) -> None: @@ -98,7 +98,7 @@ def __init__( self.procedure = procedure self.parameters = parameters - def execute(self, context: 'Context') -> Optional[Union[List, Dict]]: + 
def execute(self, context: 'Context'): self.log.info('Executing: %s', self.procedure) hook = OracleHook(oracle_conn_id=self.oracle_conn_id) return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters) diff --git a/airflow/providers/postgres/operators/postgres.py b/airflow/providers/postgres/operators/postgres.py index e0238aa88204b..7a787498d2dd4 100644 --- a/airflow/providers/postgres/operators/postgres.py +++ b/airflow/providers/postgres/operators/postgres.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union from psycopg2.sql import SQL, Identifier @@ -53,10 +53,10 @@ class PostgresOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], postgres_conn_id: str = 'postgres_default', autocommit: bool = False, - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, database: Optional[str] = None, runtime_parameters: Optional[Mapping] = None, **kwargs, diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 22afc71577341..709a378a8d89f 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -18,7 +18,7 @@ import json import os import warnings -from typing import Any, Callable, Iterable, Optional, overload +from typing import Any, Callable, Iterable, List, Mapping, Optional, Union, overload import prestodb from prestodb.exceptions import DatabaseError @@ -142,10 +142,6 @@ def get_isolation_level(self) -> Any: isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper() return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT) - @staticmethod - def _strip_sql(sql: str) -> str: - return sql.strip().rstrip(';') - @overload def get_records(self, sql: str = "", parameters: Optional[dict] = None): """Get a set of records from Presto @@ -169,7 +165,7 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str sql = hql try: - return super().get_records(self._strip_sql(sql), parameters) + return super().get_records(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise PrestoException(e) @@ -196,7 +192,7 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = sql = hql try: - return super().get_first(self._strip_sql(sql), parameters) + return super().get_first(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise PrestoException(e) @@ -226,7 +222,7 @@ def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs) cursor = self.get_cursor() try: - cursor.execute(self._strip_sql(sql), parameters) + cursor.execute(self.strip_sql_string(sql), parameters) data = cursor.fetchall() except DatabaseError as e: raise PrestoException(e) @@ -241,32 +237,38 @@ def get_pandas_df(self, sql: str = "", parameters=None, hql: str = "", **kwargs) @overload def run( self, - sql: str = "", + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, - ) -> None: + split_statements: bool = False, + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: 
"""Execute the statement against Presto. Can be used to create views.""" @overload def run( self, - sql: str = "", + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, + return_last: bool = True, hql: str = "", - ) -> None: + ) -> Optional[Union[Any, List[Any]]]: """:sphinx-autoapi-skip:""" def run( self, - sql: str = "", + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, + return_last: bool = True, hql: str = "", - ) -> None: + ) -> Optional[Union[Any, List[Any]]]: """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -276,7 +278,14 @@ def run( ) sql = hql - return super().run(sql=self._strip_sql(sql), parameters=parameters, handler=handler) + return super().run( + sql=sql, + autocommit=autocommit, + parameters=parameters, + handler=handler, + split_statements=split_statements, + return_last=return_last, + ) def insert_rows( self, diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 3dee0989ed210..21655aaf5b1fd 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -19,13 +19,12 @@ from contextlib import closing from io import StringIO from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from snowflake import connector -from snowflake.connector import DictCursor, SnowflakeConnection -from snowflake.connector.util_text import split_statements +from snowflake.connector import DictCursor, SnowflakeConnection, util_text from snowflake.sqlalchemy import URL from sqlalchemy import create_engine @@ -286,11 +285,13 @@ def get_autocommit(self, conn): def run( self, - sql: Union[str, list], + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Union[Sequence[Any], Dict[Any, Any]]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, - ): + split_statements: bool = True, + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute @@ -305,15 +306,22 @@ def run( before executing the query. :param parameters: The parameters to render the SQL query with. :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: return only result of the LAST SQL expression if handler was provided. 
""" self.query_ids = [] + scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): - split_statements_tuple = split_statements(StringIO(sql)) - sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] + if split_statements: + split_statements_tuple = util_text.split_statements(StringIO(sql)) + sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string] + else: + sql = [self.strip_sql_string(sql)] if sql: - self.log.debug("Executing %d statements against Snowflake DB", len(sql)) + self.log.debug("Executing following statements against Snowflake DB: %s", list(sql)) else: raise ValueError("List of SQL statements is empty") @@ -322,33 +330,29 @@ def run( # SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here with closing(conn.cursor(DictCursor)) as cur: # type: ignore[type-var] - + results = [] for sql_statement in sql: + self._run_command(cur, sql_statement, parameters) - self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) - if parameters: - cur.execute(sql_statement, parameters) - else: - cur.execute(sql_statement) - - execution_info = [] if handler is not None: - cur = handler(cur) - for row in cur: - self.log.info("Statement execution info - %s", row) - execution_info.append(row) + result = handler(cur) + results.append(result) query_id = cur.sfqid self.log.info("Rows affected: %s", cur.rowcount) self.log.info("Snowflake query id: %s", query_id) self.query_ids.append(query_id) - # If autocommit was set to False for db that supports autocommit, - # or if db does not supports autocommit, we do a manual commit. + # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): conn.commit() - return execution_info + if handler is None: + return None + elif scalar_return_last: + return results[-1] + else: + return results def test_connection(self): """Test the Snowflake connection by running a simple query.""" diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py index 086c1d6fd5bee..dd996cc526b84 100644 --- a/airflow/providers/snowflake/operators/snowflake.py +++ b/airflow/providers/snowflake/operators/snowflake.py @@ -15,10 +15,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, List, Optional, Sequence, SupportsAbs +from typing import Any, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union from airflow.models import BaseOperator from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -83,9 +84,9 @@ class SnowflakeOperator(BaseOperator): def __init__( self, *, - sql: Any, + sql: Union[str, Iterable[str]], snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, do_xcom_push: bool = True, warehouse: Optional[str] = None, @@ -113,11 +114,11 @@ def __init__( def get_db_hook(self) -> SnowflakeHook: return get_db_hook(self) - def execute(self, context: Any) -> None: + def execute(self, context: Any): """Run query on snowflake""" self.log.info('Executing: %s', self.sql) hook = self.get_db_hook() - execution_info = hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + execution_info = hook.run(self.sql, self.autocommit, self.parameters, fetch_all_handler) self.query_ids = hook.query_ids if self.do_xcom_push: @@ -186,9 +187,9 @@ class SnowflakeCheckOperator(SQLCheckOperator): def __init__( self, *, - sql: Any, + sql: str, snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, do_xcom_push: bool = True, warehouse: Optional[str] = None, @@ -257,7 +258,7 @@ def __init__( pass_value: Any, tolerance: Any = None, snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, do_xcom_push: bool = True, warehouse: Optional[str] = None, @@ -334,7 +335,7 @@ def __init__( date_filter_column: str = 'ds', days_back: SupportsAbs[int] = -7, snowflake_conn_id: str = 'snowflake_default', - parameters: Optional[dict] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, autocommit: bool = True, do_xcom_push: bool = True, warehouse: Optional[str] = None, diff --git a/airflow/providers/sqlite/operators/sqlite.py b/airflow/providers/sqlite/operators/sqlite.py index 7ef97ca2963f4..ef20d760c0519 100644 --- a/airflow/providers/sqlite/operators/sqlite.py +++ b/airflow/providers/sqlite/operators/sqlite.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union +from typing import Any, Iterable, Mapping, Optional, Sequence, Union from airflow.models import BaseOperator from airflow.providers.sqlite.hooks.sqlite import SqliteHook @@ -45,9 +45,9 @@ class SqliteOperator(BaseOperator): def __init__( self, *, - sql: Union[str, List[str]], + sql: Union[str, Iterable[str]], sqlite_conn_id: str = 'sqlite_default', - parameters: Optional[Union[Mapping, Iterable]] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py index 9170e19a547ae..d8ac5148de07a 100644 --- a/airflow/providers/trino/hooks/trino.py +++ b/airflow/providers/trino/hooks/trino.py @@ -19,10 +19,8 @@ import os import warnings from contextlib import closing -from itertools import chain -from typing import Any, Callable, Iterable, Optional, Tuple, overload +from typing import Any, Callable, Iterable, List, Mapping, Optional, Union, overload -import sqlparse import trino from trino.exceptions import DatabaseError from trino.transaction import IsolationLevel @@ -150,12 +148,8 @@ def get_isolation_level(self) -> Any: isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper() return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT) - @staticmethod - def _strip_sql(sql: str) -> str: - return sql.strip().rstrip(';') - @overload - def get_records(self, sql: str = "", parameters: Optional[dict] = None): + def get_records(self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None): """Get a set of records from Trino :param sql: SQL statement to be executed. @@ -163,10 +157,14 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None): """ @overload - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): + def get_records( + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "" + ): """:sphinx-autoapi-skip:""" - def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""): + def get_records( + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "" + ): """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -177,12 +175,12 @@ def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str sql = hql try: - return super().get_records(self._strip_sql(sql), parameters) + return super().get_records(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise TrinoException(e) @overload - def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: + def get_first(self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None) -> Any: """Returns only the first row, regardless of how many rows the query returns. :param sql: SQL statement to be executed. 
@@ -190,10 +188,14 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any: """ @overload - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: + def get_first( + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "" + ) -> Any: """:sphinx-autoapi-skip:""" - def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any: + def get_first( + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "" + ) -> Any: """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -204,13 +206,13 @@ def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = sql = hql try: - return super().get_first(self._strip_sql(sql), parameters) + return super().get_first(self.strip_sql_string(sql), parameters) except DatabaseError as e: raise TrinoException(e) @overload def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, **kwargs + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs ): # type: ignore[override] """Get a pandas dataframe from a sql query. @@ -220,12 +222,12 @@ def get_pandas_df( @overload def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "", **kwargs ): # type: ignore[override] """:sphinx-autoapi-skip:""" def get_pandas_df( - self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs + self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "", **kwargs ): # type: ignore[override] """:sphinx-autoapi-skip:""" if hql: @@ -240,7 +242,7 @@ def get_pandas_df( cursor = self.get_cursor() try: - cursor.execute(self._strip_sql(sql), parameters) + cursor.execute(self.strip_sql_string(sql), parameters) data = cursor.fetchall() except DatabaseError as e: raise TrinoException(e) @@ -255,32 +257,38 @@ def get_pandas_df( @overload def run( self, - sql, + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Tuple] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, - ) -> None: + split_statements: bool = False, + return_last: bool = True, + ) -> Optional[Union[Any, List[Any]]]: """Execute the statement against Trino. 
Can be used to create views.""" @overload def run( self, - sql, + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Tuple] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, + return_last: bool = True, hql: str = "", - ) -> None: + ) -> Optional[Union[Any, List[Any]]]: """:sphinx-autoapi-skip:""" def run( self, - sql, + sql: Union[str, Iterable[str]], autocommit: bool = False, - parameters: Optional[Tuple] = None, + parameters: Optional[Union[Iterable, Mapping]] = None, handler: Optional[Callable] = None, + split_statements: bool = False, + return_last: bool = True, hql: str = "", - ): + ) -> Optional[Union[Any, List[Any]]]: """:sphinx-autoapi-skip:""" if hql: warnings.warn( @@ -289,38 +297,15 @@ def run( stacklevel=2, ) sql = hql - scalar = isinstance(sql, str) - - with closing(self.get_conn()) as conn: - if self.supports_autocommit: - self.set_autocommit(conn, autocommit) - if scalar: - sql = sqlparse.split(sql) - - with closing(conn.cursor()) as cur: - results = [] - for sql_statement in sql: - self._run_command(cur, self._strip_sql(sql_statement), parameters) - self.query_id = cur.stats["queryId"] - if handler is not None: - result = handler(cur) - results.append(result) - - # If autocommit was set to False for db that supports autocommit, - # or if db does not supports autocommit, we do a manual commit. - if not self.get_autocommit(conn): - conn.commit() - - self.log.info("Query Execution Result: %s", str(list(chain.from_iterable(results)))) - - if handler is None: - return None - - if scalar: - return results[0] - - return results + return super().run( + sql=sql, + autocommit=autocommit, + parameters=parameters, + handler=handler, + split_statements=split_statements, + return_last=return_last, + ) def insert_rows( self, diff --git a/airflow/providers/vertica/operators/vertica.py b/airflow/providers/vertica/operators/vertica.py index 3a30e0ee2723e..7a804f8ed738e 100644 --- a/airflow/providers/vertica/operators/vertica.py +++ b/airflow/providers/vertica/operators/vertica.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import TYPE_CHECKING, Any, List, Sequence, Union +from typing import TYPE_CHECKING, Any, Iterable, Sequence, Union from airflow.models import BaseOperator from airflow.providers.vertica.hooks.vertica import VerticaHook @@ -40,7 +40,7 @@ class VerticaOperator(BaseOperator): ui_color = '#b4e0ff' def __init__( - self, *, sql: Union[str, List[str]], vertica_conn_id: str = 'vertica_default', **kwargs: Any + self, *, sql: Union[str, Iterable[str]], vertica_conn_id: str = 'vertica_default', **kwargs: Any ) -> None: super().__init__(**kwargs) self.vertica_conn_id = vertica_conn_id diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 203fa5e22aff1..743a73d0da786 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -62,8 +62,7 @@ "deps": [ "apache-airflow-providers-common-sql", "apache-airflow>=2.2.0", - "sqlalchemy-drill>=1.1.0", - "sqlparse>=0.4.1" + "sqlalchemy-drill>=1.1.0" ], "cross-providers-deps": [ "common.sql" @@ -191,7 +190,9 @@ "cross-providers-deps": [] }, "common.sql": { - "deps": [], + "deps": [ + "sqlparse>=0.4.2" + ], "cross-providers-deps": [] }, "databricks": { diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py index 05952beabd5b8..d70d203e33f66 100644 --- a/tests/providers/databricks/hooks/test_databricks_sql.py +++ b/tests/providers/databricks/hooks/test_databricks_sql.py @@ -23,6 +23,7 @@ import pytest from airflow.models import Connection +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook from airflow.utils.session import provide_session @@ -72,17 +73,17 @@ def test_query(self, mock_requests, mock_conn): test_schema = [(field,) for field in test_fields] conn = mock_conn.return_value - cur = mock.MagicMock(rowcount=0) + cur = mock.MagicMock(rowcount=0, description=test_schema) + cur.fetchall.return_value = [] conn.cursor.return_value = cur - type(cur).description = mock.PropertyMock(return_value=test_schema) - query = "select * from test.test" - schema, results = self.hook.run(sql=query) + query = "select * from test.test;" + schema, results = self.hook.run(sql=query, handler=fetch_all_handler) assert schema == test_schema assert results == [] - cur.execute.assert_has_calls([mock.call(q) for q in [query]]) + cur.execute.assert_has_calls([mock.call(q) for q in [query.rstrip(';')]]) cur.close.assert_called() def test_no_query(self): diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py index 783fa520a79cc..6775ff83988ce 100644 --- a/tests/providers/databricks/operators/test_databricks_sql.py +++ b/tests/providers/databricks/operators/test_databricks_sql.py @@ -25,6 +25,7 @@ from databricks.sql.types import Row from airflow import AirflowException +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.databricks.operators.databricks_sql import ( DatabricksCopyIntoOperator, DatabricksSqlOperator, @@ -47,7 +48,7 @@ def test_exec_success(self, db_mock_class): db_mock = db_mock_class.return_value mock_schema = [('id',), ('value',)] mock_results = [Row(id=1, value='value1')] - db_mock.run.return_value = (mock_schema, mock_results) + db_mock.run.return_value = [(mock_schema, mock_results)] results = op.execute(None) @@ -61,7 +62,7 @@ def test_exec_success(self, db_mock_class): catalog=None, schema=None, 
) - db_mock.run.assert_called_once_with(sql, parameters=None) + db_mock.run.assert_called_once_with(sql, parameters=None, handler=fetch_all_handler) @mock.patch('airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook') def test_exec_write_file(self, db_mock_class): @@ -74,7 +75,7 @@ def test_exec_write_file(self, db_mock_class): db_mock = db_mock_class.return_value mock_schema = [('id',), ('value',)] mock_results = [Row(id=1, value='value1')] - db_mock.run.return_value = (mock_schema, mock_results) + db_mock.run.return_value = [(mock_schema, mock_results)] try: op.execute(None) @@ -92,7 +93,7 @@ def test_exec_write_file(self, db_mock_class): catalog=None, schema=None, ) - db_mock.run.assert_called_once_with(sql, parameters=None) + db_mock.run.assert_called_once_with(sql, parameters=None, handler=fetch_all_handler) class TestDatabricksSqlCopyIntoOperator(unittest.TestCase): diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py index 812d60fd4552b..9168674c566a8 100644 --- a/tests/providers/jdbc/operators/test_jdbc.py +++ b/tests/providers/jdbc/operators/test_jdbc.py @@ -19,7 +19,8 @@ import unittest from unittest.mock import patch -from airflow.providers.jdbc.operators.jdbc import JdbcOperator, fetch_all_handler +from airflow.providers.common.sql.hooks.sql import fetch_all_handler +from airflow.providers.jdbc.operators.jdbc import JdbcOperator class TestJdbcOperator(unittest.TestCase): diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index d33dbf79f721a..254514bc9fe86 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -268,7 +268,7 @@ def getvalue(self): self.cur.bindvars = None result = self.db_hook.callproc('proc', True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END;')] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END')] assert result == parameters def test_callproc_dict(self): @@ -280,7 +280,7 @@ def getvalue(self): self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()} result = self.db_hook.callproc('proc', True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END', parameters)] assert result == parameters def test_callproc_list(self): @@ -292,7 +292,7 @@ def getvalue(self): self.cur.bindvars = list(map(bindvar, parameters)) result = self.db_hook.callproc('proc', True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END', parameters)] assert result == parameters def test_callproc_out_param(self): @@ -306,7 +306,7 @@ def bindvar(value): self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters] result = self.db_hook.callproc('proc', True, parameters) expected = [1, 0, 0.0, False, ''] - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END', expected)] assert result == expected def test_test_connection_use_dual_table(self):