diff --git a/airflow/providers/amazon/aws/operators/redshift_sql.py b/airflow/providers/amazon/aws/operators/redshift_sql.py index d5002e37488fb..77fd125aeeeb5 100644 --- a/airflow/providers/amazon/aws/operators/redshift_sql.py +++ b/airflow/providers/amazon/aws/operators/redshift_sql.py @@ -16,17 +16,14 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Mapping, Sequence +import warnings +from typing import Sequence -from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.www import utils as wwwutils -if TYPE_CHECKING: - from airflow.utils.context import Context - -class RedshiftSQLOperator(BaseOperator): +class RedshiftSQLOperator(SQLExecuteQueryOperator): """ Executes SQL Statements against an Amazon Redshift cluster @@ -54,29 +51,11 @@ class RedshiftSQLOperator(BaseOperator): "sql": "postgresql" if "postgresql" in wwwutils.get_attr_renderer() else "sql" } - def __init__( - self, - *, - sql: str | Iterable[str], - redshift_conn_id: str = 'redshift_default', - parameters: Iterable | Mapping | None = None, - autocommit: bool = True, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.redshift_conn_id = redshift_conn_id - self.sql = sql - self.autocommit = autocommit - self.parameters = parameters - - def get_hook(self) -> RedshiftSQLHook: - """Create and return RedshiftSQLHook. - :return RedshiftSQLHook: A RedshiftSQLHook instance. - """ - return RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) - - def execute(self, context: Context) -> None: - """Execute a statement against Amazon Redshift""" - self.log.info("Executing statement: %s", self.sql) - hook = self.get_hook() - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + def __init__(self, *, redshift_conn_id: str = 'redshift_default', **kwargs) -> None: + super().__init__(conn_id=redshift_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index eeca76e47a6a6..ad645239a3472 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -47,7 +47,7 @@ versions: dependencies: - apache-airflow>=2.2.0 - - apache-airflow-providers-common-sql>=1.2.0 + - apache-airflow-providers-common-sql>=1.3.0 - boto3>=1.15.0 # watchtower 3 has been released end Jan and introduced breaking change across the board that might # change logging behaviour: diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py index 1aad91967b136..d70f35ecfe043 100644 --- a/airflow/providers/apache/drill/operators/drill.py +++ b/airflow/providers/apache/drill/operators/drill.py @@ -17,16 +17,13 @@ # under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Mapping, Sequence +import warnings +from typing import Sequence -from airflow.models import BaseOperator -from airflow.providers.apache.drill.hooks.drill import DrillHook +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -if TYPE_CHECKING: - from airflow.utils.context import Context - -class DrillOperator(BaseOperator): +class DrillOperator(SQLExecuteQueryOperator): """ Executes the provided SQL in the identified Drill environment. @@ -47,21 +44,11 @@ class DrillOperator(BaseOperator): template_ext: Sequence[str] = ('.sql',) ui_color = '#ededed' - def __init__( - self, - *, - sql: str, - drill_conn_id: str = 'drill_default', - parameters: Iterable | Mapping | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.sql = sql - self.drill_conn_id = drill_conn_id - self.parameters = parameters - self.hook: DrillHook | None = None - - def execute(self, context: Context): - self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id) - self.hook = DrillHook(drill_conn_id=self.drill_conn_id) - self.hook.run(self.sql, parameters=self.parameters, split_statements=True) + def __init__(self, *, drill_conn_id: str = 'drill_default', **kwargs) -> None: + super().__init__(conn_id=drill_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/apache/drill/provider.yaml b/airflow/providers/apache/drill/provider.yaml index 1d6c58bad943a..c94df8458c24e 100644 --- a/airflow/providers/apache/drill/provider.yaml +++ b/airflow/providers/apache/drill/provider.yaml @@ -34,7 +34,7 @@ versions: dependencies: - apache-airflow>=2.2.0 - - apache-airflow-providers-common-sql>=1.2.0 + - apache-airflow-providers-common-sql>=1.3.0 - sqlalchemy-drill>=1.1.0 integrations: diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index 536b72fbff8e8..f287fb6d86962 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -17,8 +17,9 @@ # under the License. from __future__ import annotations +import ast import re -from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Sequence, SupportsAbs from packaging.version import Version @@ -26,7 +27,7 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import BaseOperator, SkipMixin -from airflow.providers.common.sql.hooks.sql import DbApiHook, _backported_get_hook +from airflow.providers.common.sql.hooks.sql import DbApiHook, _backported_get_hook, fetch_all_handler from airflow.version import version if TYPE_CHECKING: @@ -94,7 +95,9 @@ class BaseSQLOperator(BaseOperator): The provided method is .get_db_hook(). The default behavior will try to retrieve the DB hook based on connection type. - You can custom the behavior by overriding the .get_db_hook() method. + You can customize the behavior by overriding the .get_db_hook() method. 
+
+    :param conn_id: reference to a specific database
     """
 
     def __init__(
@@ -162,6 +165,79 @@ def get_db_hook(self) -> DbApiHook:
         return self._hook
 
 
+class SQLExecuteQueryOperator(BaseSQLOperator):
+    """
+    Executes SQL code in a specific database.
+
+    :param sql: the SQL code or string pointing to a template file to be executed (templated).
+        File must have a '.sql' extension.
+    :param autocommit: (optional) if True, each command is automatically committed (default: False).
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler).
+    :param split_statements: (optional) whether to split a single SQL string into statements (default: False).
+    :param return_last: (optional) whether to return the result of only the last statement (default: True).
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:SQLExecuteQueryOperator`
+    """
+
+    template_fields: Sequence[str] = ('sql', 'parameters')
+    template_ext: Sequence[str] = ('.sql', '.json')
+    template_fields_renderers = {"sql": "sql", "parameters": "json"}
+    ui_color = '#cdaaed'
+
+    def __init__(
+        self,
+        *,
+        sql: str | list[str],
+        autocommit: bool = False,
+        parameters: Mapping | Iterable | None = None,
+        handler: Callable[[Any], Any] = fetch_all_handler,
+        split_statements: bool = False,
+        return_last: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.sql = sql
+        self.autocommit = autocommit
+        self.parameters = parameters
+        self.handler = handler
+        self.split_statements = split_statements
+        self.return_last = return_last
+
+    def execute(self, context):
+        self.log.info('Executing: %s', self.sql)
+        hook = self.get_db_hook()
+        if self.do_xcom_push:
+            output = hook.run(
+                sql=self.sql,
+                autocommit=self.autocommit,
+                parameters=self.parameters,
+                handler=self.handler,
+                split_statements=self.split_statements,
+                return_last=self.return_last,
+            )
+        else:
+            output = hook.run(
+                sql=self.sql,
+                autocommit=self.autocommit,
+                parameters=self.parameters,
+                split_statements=self.split_statements,
+            )
+
+        if hasattr(self, '_process_output'):
+            for out in output:
+                self._process_output(*out)
+
+        return output
+
+    def prepare_template(self) -> None:
+        """Parse template file for attribute parameters."""
+        if isinstance(self.parameters, str):
+            self.parameters = ast.literal_eval(self.parameters)
+
+
 class SQLColumnCheckOperator(BaseSQLOperator):
     """
     Performs one or more of the templated checks in the column_checks dictionary.
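Note: a minimal usage sketch of the new SQLExecuteQueryOperator above, showing the new parameters together. The DAG metadata, connection id, and table names here are hypothetical illustrations, not part of this patch:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

    with DAG(dag_id="sql_execute_query_demo", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
        # Single templated statement; the default fetch_all_handler pushes all fetched rows to XCom.
        select_rows = SQLExecuteQueryOperator(
            task_id="select_rows",
            conn_id="my_db_conn",  # hypothetical connection id
            sql="SELECT * FROM my_table WHERE load_date = '{{ ds }}'",
        )

        # Multi-statement string: split_statements=True runs each statement separately and
        # return_last=True (the default) keeps only the result of the final SELECT.
        create_and_count = SQLExecuteQueryOperator(
            task_id="create_and_count",
            conn_id="my_db_conn",
            sql="CREATE TABLE IF NOT EXISTS t (a INT); SELECT COUNT(*) FROM t;",
            split_statements=True,
        )

The deprecated provider operators in the rest of this patch reduce to exactly this call path: each forwards its `*_conn_id` argument to `conn_id` and passes engine-specific options (schema, warehouse, role, ...) through `hook_params`.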
diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml index 3086abae1b297..bed15a0bd2842 100644 --- a/airflow/providers/common/sql/provider.yaml +++ b/airflow/providers/common/sql/provider.yaml @@ -22,6 +22,7 @@ description: | `Common SQL Provider `__ versions: + - 1.3.0 - 1.2.0 - 1.1.0 - 1.0.0 diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py index 3ee7d60878666..c98d23d172d3b 100644 --- a/airflow/providers/databricks/operators/databricks_sql.py +++ b/airflow/providers/databricks/operators/databricks_sql.py @@ -20,20 +20,20 @@ import csv import json -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence, Tuple, cast +from typing import TYPE_CHECKING, Any, Sequence from databricks.sql.utils import ParamEscaper from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.providers.common.sql.hooks.sql import fetch_all_handler +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook if TYPE_CHECKING: from airflow.utils.context import Context -class DatabricksSqlOperator(BaseOperator): +class DatabricksSqlOperator(SQLExecuteQueryOperator): """ Executes SQL code in a Databricks SQL endpoint or a Databricks cluster @@ -80,11 +80,9 @@ class DatabricksSqlOperator(BaseOperator): def __init__( self, *, - sql: str | Iterable[str], databricks_conn_id: str = DatabricksSqlHook.default_conn_name, http_path: str | None = None, sql_endpoint_name: str | None = None, - parameters: Iterable | Mapping | None = None, session_configuration=None, http_headers: list[tuple[str, str]] | None = None, catalog: str | None = None, @@ -96,37 +94,31 @@ def __init__( client_parameters: dict[str, Any] | None = None, **kwargs, ) -> None: - """Creates a new ``DatabricksSqlOperator``.""" - super().__init__(**kwargs) + super().__init__(conn_id=databricks_conn_id, **kwargs) self.databricks_conn_id = databricks_conn_id - self.sql = sql - self._http_path = http_path - self._sql_endpoint_name = sql_endpoint_name self._output_path = output_path self._output_format = output_format self._csv_params = csv_params - self.parameters = parameters - self.do_xcom_push = do_xcom_push - self.session_config = session_configuration - self.http_headers = http_headers - self.catalog = catalog - self.schema = schema - self.client_parameters = client_parameters or {} - def _get_hook(self) -> DatabricksSqlHook: - return DatabricksSqlHook( - self.databricks_conn_id, - http_path=self._http_path, - session_configuration=self.session_config, - sql_endpoint_name=self._sql_endpoint_name, - http_headers=self.http_headers, - catalog=self.catalog, - schema=self.schema, - caller="DatabricksSqlOperator", - **self.client_parameters, - ) + client_parameters = {} if client_parameters is None else client_parameters + hook_params = kwargs.pop('hook_params', {}) - def _format_output(self, schema, results): + self.hook_params = { + 'http_path': http_path, + 'session_configuration': session_configuration, + 'sql_endpoint_name': sql_endpoint_name, + 'http_headers': http_headers, + 'catalog': catalog, + 'schema': schema, + 'caller': "DatabricksSqlOperator", + **client_parameters, + **hook_params, + } + + def get_db_hook(self) -> DatabricksSqlHook: + return DatabricksSqlHook(self.databricks_conn_id, **self.hook_params) + + def _process_output(self, schema, results): if not 
self._output_path: return if not self._output_format: @@ -157,17 +149,6 @@ def _format_output(self, schema, results): else: raise AirflowException(f"Unsupported output format: '{self._output_format}'") - def execute(self, context: Context): - self.log.info('Executing: %s', self.sql) - hook = self._get_hook() - response = hook.run(self.sql, parameters=self.parameters, handler=fetch_all_handler) - schema, results = cast(List[Tuple[Any, Any]], response)[0] - # self.log.info('Schema: %s', schema) - # self.log.info('Results: %s', results) - self._format_output(schema, results) - if self.do_xcom_push: - return results - COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"] diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index d72f882feb6c8..ad25de857e960 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -41,7 +41,7 @@ versions: dependencies: - apache-airflow>=2.2.0 - - apache-airflow-providers-common-sql>=1.2.0 + - apache-airflow-providers-common-sql>=1.3.0 - requests>=2.27,<3 - databricks-sql-connector>=2.0.0, <3.0.0 - aiohttp>=3.6.3, <4 diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py index e251413f03a21..6ae28a5f44f28 100644 --- a/airflow/providers/exasol/operators/exasol.py +++ b/airflow/providers/exasol/operators/exasol.py @@ -17,16 +17,13 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Mapping, Sequence +import warnings +from typing import Sequence -from airflow.models import BaseOperator -from airflow.providers.exasol.hooks.exasol import ExasolHook +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -if TYPE_CHECKING: - from airflow.utils.context import Context - -class ExasolOperator(BaseOperator): +class ExasolOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Exasol database @@ -46,23 +43,16 @@ class ExasolOperator(BaseOperator): ui_color = '#ededed' def __init__( - self, - *, - sql: str | Iterable[str], - exasol_conn_id: str = 'exasol_default', - autocommit: bool = False, - parameters: Iterable | Mapping | None = None, - schema: str | None = None, - **kwargs, + self, *, exasol_conn_id: str = 'exasol_default', schema: str | None = None, **kwargs ) -> None: - super().__init__(**kwargs) - self.exasol_conn_id = exasol_conn_id - self.sql = sql - self.autocommit = autocommit - self.parameters = parameters - self.schema = schema - - def execute(self, context: Context) -> None: - self.log.info('Executing: %s', self.sql) - hook = ExasolHook(exasol_conn_id=self.exasol_conn_id, schema=self.schema) - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + if schema is not None: + hook_params = kwargs.pop('hook_params', {}) + kwargs['hook_params'] = {'schema': schema, **hook_params} + + super().__init__(conn_id=exasol_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. 
+ Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/exasol/provider.yaml b/airflow/providers/exasol/provider.yaml index 370a9d821d557..02c37aa3b8e77 100644 --- a/airflow/providers/exasol/provider.yaml +++ b/airflow/providers/exasol/provider.yaml @@ -38,7 +38,7 @@ versions: dependencies: - apache-airflow>=2.2.0 - - apache-airflow-providers-common-sql>=1.2.0 + - apache-airflow-providers-common-sql>=1.3.0 - pyexasol>=0.5.1 - pandas>=0.17.1 diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py index a744d8fa7b478..f8f645856db45 100644 --- a/airflow/providers/jdbc/operators/jdbc.py +++ b/airflow/providers/jdbc/operators/jdbc.py @@ -17,17 +17,13 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Sequence +import warnings +from typing import Sequence -from airflow.models import BaseOperator -from airflow.providers.common.sql.hooks.sql import fetch_all_handler -from airflow.providers.jdbc.hooks.jdbc import JdbcHook +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -if TYPE_CHECKING: - from airflow.utils.context import Context - -class JdbcOperator(BaseOperator): +class JdbcOperator(SQLExecuteQueryOperator): """ Executes sql code in a database using jdbc driver. @@ -51,28 +47,11 @@ class JdbcOperator(BaseOperator): template_fields_renderers = {'sql': 'sql'} ui_color = '#ededed' - def __init__( - self, - *, - sql: str | Iterable[str], - jdbc_conn_id: str = 'jdbc_default', - autocommit: bool = False, - parameters: Iterable | Mapping | None = None, - handler: Callable[[Any], Any] = fetch_all_handler, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.parameters = parameters - self.sql = sql - self.jdbc_conn_id = jdbc_conn_id - self.autocommit = autocommit - self.handler = handler - self.hook = None - - def execute(self, context: Context): - self.log.info('Executing: %s', self.sql) - hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id) - if self.do_xcom_push: - return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=self.handler) - else: - return hook.run(self.sql, self.autocommit, parameters=self.parameters) + def __init__(self, *, jdbc_conn_id: str = 'jdbc_default', **kwargs) -> None: + super().__init__(conn_id=jdbc_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/jdbc/provider.yaml b/airflow/providers/jdbc/provider.yaml index bb0ed61f9753e..30c564e100d89 100644 --- a/airflow/providers/jdbc/provider.yaml +++ b/airflow/providers/jdbc/provider.yaml @@ -37,7 +37,7 @@ versions: dependencies: - apache-airflow>=2.2.0 - - apache-airflow-providers-common-sql>=1.2.0 + - apache-airflow-providers-common-sql>=1.3.0 - jaydebeapi>=1.1.1 integrations: diff --git a/airflow/providers/microsoft/mssql/operators/mssql.py b/airflow/providers/microsoft/mssql/operators/mssql.py index b33002a9e8fbf..df5b1a4800b6b 100644 --- a/airflow/providers/microsoft/mssql/operators/mssql.py +++ b/airflow/providers/microsoft/mssql/operators/mssql.py @@ -17,19 +17,14 @@ # under the License. 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+import warnings
+from typing import Sequence
 
-from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
-from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.www import utils as wwwutils
 
-if TYPE_CHECKING:
-    from airflow.providers.common.sql.hooks.sql import DbApiHook
-    from airflow.utils.context import Context
-
 
-class MsSqlOperator(BaseOperator):
+class MsSqlOperator(SQLExecuteQueryOperator):
     """
     Executes sql code in a specific Microsoft SQL database
 
@@ -57,42 +52,17 @@ class MsSqlOperator(BaseOperator):
     ui_color = '#ededed'
 
     def __init__(
-        self,
-        *,
-        sql: str | Iterable[str],
-        mssql_conn_id: str = 'mssql_default',
-        parameters: Iterable | Mapping | None = None,
-        autocommit: bool = False,
-        database: str | None = None,
-        **kwargs,
+        self, *, mssql_conn_id: str = 'mssql_default', database: str | None = None, **kwargs
     ) -> None:
-        super().__init__(**kwargs)
-        self.mssql_conn_id = mssql_conn_id
-        self.sql = sql
-        self.parameters = parameters
-        self.autocommit = autocommit
-        self.database = database
-        self._hook: MsSqlHook | DbApiHook | None = None
-
-    def get_hook(self) -> MsSqlHook | DbApiHook | None:
-        """
-        Will retrieve hook as determined by :meth:`~.Connection.get_hook` if one is defined, and
-        :class:`~.MsSqlHook` otherwise.
-
-        For example, if the connection ``conn_type`` is ``'odbc'``, :class:`~.OdbcHook` will be used.
-        """
-        if not self._hook:
-            conn = MsSqlHook.get_connection(conn_id=self.mssql_conn_id)
-            try:
-                self._hook = conn.get_hook()
-                self._hook.schema = self.database  # type: ignore[union-attr]
-            except AirflowException:
-                self._hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id, schema=self.database)
-        return self._hook
+        if database is not None:
+            hook_params = kwargs.pop('hook_params', {})
+            kwargs['hook_params'] = {'schema': database, **hook_params}
 
-    def execute(self, context: Context) -> None:
-        self.log.info('Executing: %s', self.sql)
-        hook = self.get_hook()
-        hook.run(  # type: ignore[union-attr]
-            sql=self.sql, autocommit=self.autocommit, parameters=self.parameters
+        super().__init__(conn_id=mssql_conn_id, **kwargs)
+        warnings.warn(
+            """This class is deprecated.
+            Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.
+            Also, you can provide `hook_params={'schema': <database>}`.""",
+            DeprecationWarning,
+            stacklevel=2,
         )
diff --git a/airflow/providers/microsoft/mssql/provider.yaml b/airflow/providers/microsoft/mssql/provider.yaml
index 61f7f8cd81f42..9d3c9a4feb379 100644
--- a/airflow/providers/microsoft/mssql/provider.yaml
+++ b/airflow/providers/microsoft/mssql/provider.yaml
@@ -38,7 +38,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql>=1.2.0
+  - apache-airflow-providers-common-sql>=1.3.0
   - pymssql>=2.1.5; platform_machine != "aarch64"
 
 integrations:
diff --git a/airflow/providers/mysql/operators/mysql.py b/airflow/providers/mysql/operators/mysql.py
index da919561db73d..0c2b482e0b62d 100644
--- a/airflow/providers/mysql/operators/mysql.py
+++ b/airflow/providers/mysql/operators/mysql.py
@@ -17,18 +17,14 @@
 # under the License.
 from __future__ import annotations
 
-import ast
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+import warnings
+from typing import Sequence
 
-from airflow.models import BaseOperator
-from airflow.providers.mysql.hooks.mysql import MySqlHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.www import utils as wwwutils
 
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
-
 
-class MySqlOperator(BaseOperator):
+class MySqlOperator(SQLExecuteQueryOperator):
     """
     Executes sql code in a specific MySQL database
 
@@ -59,28 +55,17 @@ class MySqlOperator(BaseOperator):
     ui_color = '#ededed'
 
     def __init__(
-        self,
-        *,
-        sql: str | Iterable[str],
-        mysql_conn_id: str = 'mysql_default',
-        parameters: Iterable | Mapping | None = None,
-        autocommit: bool = False,
-        database: str | None = None,
-        **kwargs,
+        self, *, mysql_conn_id: str = 'mysql_default', database: str | None = None, **kwargs
     ) -> None:
-        super().__init__(**kwargs)
-        self.mysql_conn_id = mysql_conn_id
-        self.sql = sql
-        self.autocommit = autocommit
-        self.parameters = parameters
-        self.database = database
-
-    def prepare_template(self) -> None:
-        """Parse template file for attribute parameters."""
-        if isinstance(self.parameters, str):
-            self.parameters = ast.literal_eval(self.parameters)
+        if database is not None:
+            hook_params = kwargs.pop('hook_params', {})
+            kwargs['hook_params'] = {'schema': database, **hook_params}
 
-    def execute(self, context: Context) -> None:
-        self.log.info('Executing: %s', self.sql)
-        hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database)
-        hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+        super().__init__(conn_id=mysql_conn_id, **kwargs)
+        warnings.warn(
+            """This class is deprecated.
+            Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.
+            Also, you can provide `hook_params={'schema': <database>}`.""",
+            DeprecationWarning,
+            stacklevel=2,
+        )
diff --git a/airflow/providers/mysql/provider.yaml b/airflow/providers/mysql/provider.yaml
index 69759d4e6da88..965234657b901 100644
--- a/airflow/providers/mysql/provider.yaml
+++ b/airflow/providers/mysql/provider.yaml
@@ -40,7 +40,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql>=1.2.0
+  - apache-airflow-providers-common-sql>=1.3.0
   - mysql-connector-python>=8.0.11; platform_machine != "aarch64"
   - mysqlclient>=1.3.6; platform_machine != "aarch64"
 
diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py
index e31f564653467..a98f93b9528f1 100644
--- a/airflow/providers/oracle/operators/oracle.py
+++ b/airflow/providers/oracle/operators/oracle.py
@@ -17,16 +17,18 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+import warnings
+from typing import TYPE_CHECKING, Sequence
 
 from airflow.models import BaseOperator
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.providers.oracle.hooks.oracle import OracleHook
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class OracleOperator(BaseOperator):
+class OracleOperator(SQLExecuteQueryOperator):
     """
     Executes sql code in a specific Oracle database.
@@ -49,26 +51,14 @@ class OracleOperator(BaseOperator): template_fields_renderers = {'sql': 'sql'} ui_color = '#ededed' - def __init__( - self, - *, - sql: str | Iterable[str], - oracle_conn_id: str = 'oracle_default', - parameters: Iterable | Mapping | None = None, - autocommit: bool = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.oracle_conn_id = oracle_conn_id - self.sql = sql - self.autocommit = autocommit - self.parameters = parameters - - def execute(self, context: Context) -> None: - self.log.info('Executing: %s', self.sql) - hook = OracleHook(oracle_conn_id=self.oracle_conn_id) - if self.sql: - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + def __init__(self, *, oracle_conn_id: str = 'oracle_default', **kwargs) -> None: + super().__init__(conn_id=oracle_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) class OracleStoredProcedureOperator(BaseOperator): diff --git a/airflow/providers/oracle/provider.yaml b/airflow/providers/oracle/provider.yaml index 4acdd7c1eb0a8..07419afcb5557 100644 --- a/airflow/providers/oracle/provider.yaml +++ b/airflow/providers/oracle/provider.yaml @@ -40,7 +40,7 @@ versions: dependencies: - apache-airflow>=2.2.0 - - apache-airflow-providers-common-sql>=1.2.0 + - apache-airflow-providers-common-sql>=1.3.0 - oracledb>=1.0.0 integrations: diff --git a/airflow/providers/postgres/operators/postgres.py b/airflow/providers/postgres/operators/postgres.py index e12ee58d8ce45..1ac574da9fecd 100644 --- a/airflow/providers/postgres/operators/postgres.py +++ b/airflow/providers/postgres/operators/postgres.py @@ -17,19 +17,16 @@ # under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Mapping, Sequence +import warnings +from typing import Mapping, Sequence from psycopg2.sql import SQL, Identifier -from airflow.models import BaseOperator -from airflow.providers.postgres.hooks.postgres import PostgresHook +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.www import utils as wwwutils -if TYPE_CHECKING: - from airflow.utils.context import Context - -class PostgresOperator(BaseOperator): +class PostgresOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Postgres database @@ -55,42 +52,42 @@ class PostgresOperator(BaseOperator): def __init__( self, *, - sql: str | Iterable[str], postgres_conn_id: str = 'postgres_default', - autocommit: bool = False, - parameters: Iterable | Mapping | None = None, database: str | None = None, runtime_parameters: Mapping | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) - self.sql = sql - self.postgres_conn_id = postgres_conn_id - self.autocommit = autocommit - self.parameters = parameters - self.database = database - self.runtime_parameters = runtime_parameters - self.hook: PostgresHook | None = None + if database is not None: + hook_params = kwargs.pop('hook_params', {}) + kwargs['hook_params'] = {'schema': database, **hook_params} + + if runtime_parameters: + sql = kwargs.pop('sql') + parameters = kwargs.pop('parameters', {}) - def execute(self, context: Context): - self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, schema=self.database) - if self.runtime_parameters: final_sql = [] sql_param = {} - for param in self.runtime_parameters: + for param in runtime_parameters: set_param_sql = f"SET {{}} TO %({param})s;" dynamic_sql = SQL(set_param_sql).format(Identifier(f"{param}")) final_sql.append(dynamic_sql) - for param, val in self.runtime_parameters.items(): + for param, val in runtime_parameters.items(): sql_param.update({f"{param}": f"{val}"}) - if self.parameters: - sql_param.update(self.parameters) - if isinstance(self.sql, str): - final_sql.append(SQL(self.sql)) + if parameters: + sql_param.update(parameters) + if isinstance(sql, str): + final_sql.append(SQL(sql)) else: - final_sql.extend(list(map(SQL, self.sql))) - self.hook.run(final_sql, self.autocommit, parameters=sql_param) - else: - self.hook.run(self.sql, self.autocommit, parameters=self.parameters) - for output in self.hook.conn.notices: - self.log.info(output) + final_sql.extend(list(map(SQL, sql))) + + kwargs['sql'] = final_sql + kwargs['parameters'] = sql_param + + super().__init__(conn_id=postgres_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`. 
+            Also, you can provide `hook_params={'schema': <database>}`.""",
+            DeprecationWarning,
+            stacklevel=2,
+        )
diff --git a/airflow/providers/postgres/provider.yaml b/airflow/providers/postgres/provider.yaml
index af155d7486228..ebf65ab7d23fe 100644
--- a/airflow/providers/postgres/provider.yaml
+++ b/airflow/providers/postgres/provider.yaml
@@ -42,7 +42,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql>=1.2.0
+  - apache-airflow-providers-common-sql>=1.3.0
   - psycopg2>=2.8.0
 
 integrations:
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index a7a7309aa34d7..e686e3fd6ff02 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -17,37 +17,18 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any, Callable, Iterable, Mapping, Sequence, SupportsAbs
+import warnings
+from typing import Any, Iterable, Mapping, Sequence, SupportsAbs
 
-from airflow.models import BaseOperator
-from airflow.providers.common.sql.hooks.sql import fetch_all_handler
 from airflow.providers.common.sql.operators.sql import (
     SQLCheckOperator,
+    SQLExecuteQueryOperator,
     SQLIntervalCheckOperator,
     SQLValueCheckOperator,
 )
-from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
 
 
-def get_db_hook(self) -> SnowflakeHook:
-    """
-    Create and return SnowflakeHook.
-
-    :return: a SnowflakeHook instance.
-    :rtype: SnowflakeHook
-    """
-    return SnowflakeHook(
-        snowflake_conn_id=self.snowflake_conn_id,
-        warehouse=self.warehouse,
-        database=self.database,
-        role=self.role,
-        schema=self.schema,
-        authenticator=self.authenticator,
-        session_parameters=self.session_parameters,
-    )
-
-
-class SnowflakeOperator(BaseOperator):
+class SnowflakeOperator(SQLExecuteQueryOperator):
     """
     Executes SQL code in a Snowflake database
 
@@ -80,8 +61,6 @@ class SnowflakeOperator(BaseOperator):
         through native Okta.
     :param session_parameters: You can set session-level parameters at
         the time you connect to Snowflake
-    :param handler: A Python callable that will act on cursor result.
- By default, it will use ``fetchall`` """ template_fields: Sequence[str] = ('sql',) @@ -92,48 +71,37 @@ class SnowflakeOperator(BaseOperator): def __init__( self, *, - sql: str | Iterable[str], snowflake_conn_id: str = 'snowflake_default', - parameters: Iterable | Mapping | None = None, - autocommit: bool = True, - do_xcom_push: bool = True, warehouse: str | None = None, database: str | None = None, role: str | None = None, schema: str | None = None, authenticator: str | None = None, session_parameters: dict | None = None, - handler: Callable | None = None, **kwargs, ) -> None: - super().__init__(**kwargs) - self.snowflake_conn_id = snowflake_conn_id - self.sql = sql - self.autocommit = autocommit - self.do_xcom_push = do_xcom_push - self.parameters = parameters - self.warehouse = warehouse - self.database = database - self.role = role - self.schema = schema - self.authenticator = authenticator - self.session_parameters = session_parameters - self.query_ids: list[str] = [] - self.handler = handler - - def get_db_hook(self) -> SnowflakeHook: - return get_db_hook(self) - - def execute(self, context: Any): - """Run query on snowflake""" - self.log.info('Executing: %s', self.sql) - hook = self.get_db_hook() - handler = self.handler or fetch_all_handler - execution_info = hook.run(self.sql, self.autocommit, self.parameters, handler) - self.query_ids = hook.query_ids - - if self.do_xcom_push: - return execution_info + if any([warehouse, database, role, schema, authenticator, session_parameters]): + hook_params = kwargs.pop('hook_params', {}) + kwargs['hook_params'] = { + 'warehouse': warehouse, + 'database': database, + 'role': role, + 'schema': schema, + 'authenticator': authenticator, + 'session_parameters': session_parameters, + **hook_params, + } + + super().__init__(conn_id=snowflake_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. + Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`. 
+            Also, you can provide `hook_params={'warehouse': <warehouse>, 'database': <database>,
+            'role': <role>, 'schema': <schema>, 'authenticator': <authenticator>,
+            'session_parameters': <session_parameters>}`.""",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
 
 class SnowflakeCheckOperator(SQLCheckOperator):
@@ -225,9 +193,6 @@ def __init__(
         self.session_parameters = session_parameters
         self.query_ids: list[str] = []
 
-    def get_db_hook(self) -> SnowflakeHook:
-        return get_db_hook(self)
-
 
 class SnowflakeValueCheckOperator(SQLValueCheckOperator):
     """
@@ -294,9 +259,6 @@ def __init__(
         self.session_parameters = session_parameters
         self.query_ids: list[str] = []
 
-    def get_db_hook(self) -> SnowflakeHook:
-        return get_db_hook(self)
-
 
 class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
     """
@@ -375,6 +337,3 @@ def __init__(
         self.authenticator = authenticator
         self.session_parameters = session_parameters
         self.query_ids: list[str] = []
-
-    def get_db_hook(self) -> SnowflakeHook:
-        return get_db_hook(self)
diff --git a/airflow/providers/snowflake/provider.yaml b/airflow/providers/snowflake/provider.yaml
index 5672a037249a6..98f640a20ed97 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -46,7 +46,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql>=1.2.0
+  - apache-airflow-providers-common-sql>=1.3.0
   - snowflake-connector-python>=2.4.1
   - snowflake-sqlalchemy>=1.1.0
 
diff --git a/airflow/providers/sqlite/operators/sqlite.py b/airflow/providers/sqlite/operators/sqlite.py
index c382d53ee3fa3..6d8d1497dca4c 100644
--- a/airflow/providers/sqlite/operators/sqlite.py
+++ b/airflow/providers/sqlite/operators/sqlite.py
@@ -17,13 +17,13 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any, Iterable, Mapping, Sequence
+import warnings
+from typing import Sequence
 
-from airflow.models import BaseOperator
-from airflow.providers.sqlite.hooks.sqlite import SqliteHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 
 
-class SqliteOperator(BaseOperator):
+class SqliteOperator(SQLExecuteQueryOperator):
     """
     Executes sql code in a specific Sqlite database
 
@@ -44,20 +44,11 @@
     template_fields_renderers = {'sql': 'sql'}
     ui_color = '#cdaaed'
 
-    def __init__(
-        self,
-        *,
-        sql: str | Iterable[str],
-        sqlite_conn_id: str = 'sqlite_default',
-        parameters: Iterable | Mapping | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.sqlite_conn_id = sqlite_conn_id
-        self.sql = sql
-        self.parameters = parameters or []
-
-    def execute(self, context: Mapping[Any, Any]) -> None:
-        self.log.info('Executing: %s', self.sql)
-        hook = SqliteHook(sqlite_conn_id=self.sqlite_conn_id)
-        hook.run(self.sql, parameters=self.parameters)
+    def __init__(self, *, sqlite_conn_id: str = 'sqlite_default', **kwargs) -> None:
+        super().__init__(conn_id=sqlite_conn_id, **kwargs)
+        warnings.warn(
+            """This class is deprecated.
+ Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, + ) diff --git a/airflow/providers/sqlite/provider.yaml b/airflow/providers/sqlite/provider.yaml index 18bd98582edfa..f07f36db551e2 100644 --- a/airflow/providers/sqlite/provider.yaml +++ b/airflow/providers/sqlite/provider.yaml @@ -37,7 +37,7 @@ versions: - 1.0.0 dependencies: - - apache-airflow-providers-common-sql>=1.2.0 + - apache-airflow-providers-common-sql>=1.3.0 integrations: - integration-name: SQLite diff --git a/airflow/providers/trino/operators/trino.py b/airflow/providers/trino/operators/trino.py index 7ed5b5e68f95b..40741b048e5a0 100644 --- a/airflow/providers/trino/operators/trino.py +++ b/airflow/providers/trino/operators/trino.py @@ -18,18 +18,16 @@ """This module contains the Trino operator.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Sequence +import warnings +from typing import Any, Sequence from trino.exceptions import TrinoQueryError -from airflow.models import BaseOperator +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.trino.hooks.trino import TrinoHook -if TYPE_CHECKING: - from airflow.utils.context import Context - -class TrinoOperator(BaseOperator): +class TrinoOperator(SQLExecuteQueryOperator): """ Executes sql code using a specific Trino query Engine. @@ -52,43 +50,21 @@ class TrinoOperator(BaseOperator): template_ext: Sequence[str] = ('.sql',) ui_color = '#ededed' - def __init__( - self, - *, - sql: str | list[str], - trino_conn_id: str = "trino_default", - autocommit: bool = False, - parameters: tuple | None = None, - handler: Callable | None = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.sql = sql - self.trino_conn_id = trino_conn_id - self.hook: TrinoHook | None = None - self.autocommit = autocommit - self.parameters = parameters - self.handler = handler - - def get_hook(self) -> TrinoHook: - """Get Trino hook""" - return TrinoHook( - trino_conn_id=self.trino_conn_id, - ) - - def execute(self, context: Context) -> None: - """Execute Trino SQL""" - self.hook = self.get_hook() - self.hook.run( - sql=self.sql, autocommit=self.autocommit, parameters=self.parameters, handler=self.handler + def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs: Any) -> None: + super().__init__(conn_id=trino_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. 
+ Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", + DeprecationWarning, + stacklevel=2, ) def on_kill(self) -> None: - if self.hook is not None and isinstance(self.hook, TrinoHook): - query_id = "'" + self.hook.query_id + "'" + if self._hook is not None and isinstance(self._hook, TrinoHook): + query_id = "'" + self._hook.query_id + "'" try: - self.log.info("Stopping query run with queryId - %s", self.hook.query_id) - self.hook.run( + self.log.info("Stopping query run with queryId - %s", self._hook.query_id) + self._hook.run( sql=f"CALL system.runtime.kill_query(query_id => {query_id},message => 'Job " f"killed by " f"user');", diff --git a/airflow/providers/trino/provider.yaml b/airflow/providers/trino/provider.yaml index 46b1bf58d57d4..57da2ba254ee9 100644 --- a/airflow/providers/trino/provider.yaml +++ b/airflow/providers/trino/provider.yaml @@ -39,7 +39,7 @@ versions: dependencies: - apache-airflow>=2.2.0 - - apache-airflow-providers-common-sql>=1.2.0 + - apache-airflow-providers-common-sql>=1.3.0 - pandas>=0.17.1 - trino>=0.301.0 diff --git a/airflow/providers/vertica/operators/vertica.py b/airflow/providers/vertica/operators/vertica.py index 50463a7771ece..f9825435ede1a 100644 --- a/airflow/providers/vertica/operators/vertica.py +++ b/airflow/providers/vertica/operators/vertica.py @@ -17,16 +17,13 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Sequence +import warnings +from typing import Any, Sequence -from airflow.models import BaseOperator -from airflow.providers.vertica.hooks.vertica import VerticaHook +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -if TYPE_CHECKING: - from airflow.utils.context import Context - -class VerticaOperator(BaseOperator): +class VerticaOperator(SQLExecuteQueryOperator): """ Executes sql code in a specific Vertica database. @@ -41,14 +38,11 @@ class VerticaOperator(BaseOperator): template_fields_renderers = {'sql': 'sql'} ui_color = '#b4e0ff' - def __init__( - self, *, sql: str | Iterable[str], vertica_conn_id: str = 'vertica_default', **kwargs: Any - ) -> None: - super().__init__(**kwargs) - self.vertica_conn_id = vertica_conn_id - self.sql = sql - - def execute(self, context: Context) -> None: - self.log.info('Executing: %s', self.sql) - hook = VerticaHook(vertica_conn_id=self.vertica_conn_id, log_sql=False) - hook.run(sql=self.sql) + def __init__(self, *, vertica_conn_id: str = 'vertica_default', **kwargs: Any) -> None: + super().__init__(conn_id=vertica_conn_id, **kwargs) + warnings.warn( + """This class is deprecated. 
+            Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""",
+            DeprecationWarning,
+            stacklevel=2,
+        )
diff --git a/airflow/providers/vertica/provider.yaml b/airflow/providers/vertica/provider.yaml
index d675874a0c8b3..ccadf7b385848 100644
--- a/airflow/providers/vertica/provider.yaml
+++ b/airflow/providers/vertica/provider.yaml
@@ -37,7 +37,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.2.0
-  - apache-airflow-providers-common-sql>=1.2.0
+  - apache-airflow-providers-common-sql>=1.3.0
   - vertica-python>=0.5.1
 
 integrations:
diff --git a/docs/apache-airflow-providers-common-sql/operators.rst b/docs/apache-airflow-providers-common-sql/operators.rst
index 402197fb5204f..e10759117e0b1 100644
--- a/docs/apache-airflow-providers-common-sql/operators.rst
+++ b/docs/apache-airflow-providers-common-sql/operators.rst
@@ -21,6 +21,29 @@ SQL Operators
 These operators perform various queries against a SQL database, including column- and table-level
 data quality checks.
 
+.. _howto/operator:SQLExecuteQueryOperator:
+
+Execute SQL query
+~~~~~~~~~~~~~~~~~
+
+Use the :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator` to run SQL queries
+against different databases. Parameters of the operator are:
+
+- ``sql`` - single string, list of strings or string pointing to a template file to be executed;
+- ``autocommit`` (optional) - if True, each command is automatically committed (default: False);
+- ``parameters`` (optional) - the parameters to render the SQL query with;
+- ``handler`` (optional) - the function applied to the cursor; if it is ``None``, results won't be returned (default: fetch_all_handler);
+- ``split_statements`` (optional) - whether to split a single SQL string into statements and run them separately (default: False);
+- ``return_last`` (optional) - only applies when ``split_statements`` is True; if True, returns the result of only the last statement, otherwise of all split statements (default: True).
+
+The example below shows how to instantiate the SQLExecuteQueryOperator task.
+
+.. exampleinclude:: /../../tests/system/providers/common/sql/example_sql_execute_query.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sql_execute_query]
+    :end-before: [END howto_operator_sql_execute_query]
+
 ..
_howto/operator:SQLColumnCheckOperator: Check SQL Table Columns diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 851a3de73c2aa..b9c1e04b0268b 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -17,7 +17,7 @@ }, "amazon": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "boto3>=1.15.0", "jsonpath_ng>=1.5.3", @@ -60,7 +60,7 @@ }, "apache.drill": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "sqlalchemy-drill>=1.1.0" ], @@ -206,7 +206,7 @@ "databricks": { "deps": [ "aiohttp>=3.6.3, <4", - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "databricks-sql-connector>=2.0.0, <3.0.0", "requests>=2.27,<3" @@ -270,7 +270,7 @@ }, "exasol": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "pandas>=0.17.1", "pyexasol>=0.5.1" @@ -409,7 +409,7 @@ }, "jdbc": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "jaydebeapi>=1.1.1" ], @@ -460,7 +460,7 @@ }, "microsoft.mssql": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "pymssql>=2.1.5; platform_machine != \"aarch64\"" ], @@ -491,7 +491,7 @@ }, "mysql": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "mysql-connector-python>=8.0.11; platform_machine != \"aarch64\"", "mysqlclient>=1.3.6; platform_machine != \"aarch64\"" @@ -536,7 +536,7 @@ }, "oracle": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "oracledb>=1.0.0" ], @@ -568,7 +568,7 @@ }, "postgres": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "psycopg2>=2.8.0" ], @@ -663,7 +663,7 @@ }, "snowflake": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "snowflake-connector-python>=2.4.1", "snowflake-sqlalchemy>=1.1.0" @@ -675,7 +675,7 @@ }, "sqlite": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0" + "apache-airflow-providers-common-sql>=1.3.0" ], "cross-providers-deps": [ "common.sql" @@ -711,7 +711,7 @@ }, "trino": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "pandas>=0.17.1", "trino>=0.301.0" @@ -723,7 +723,7 @@ }, "vertica": { "deps": [ - "apache-airflow-providers-common-sql>=1.2.0", + "apache-airflow-providers-common-sql>=1.3.0", "apache-airflow>=2.2.0", "vertica-python>=0.5.1" ], diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py index 839827a29d610..9fc9aa70810a7 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_sql.py +++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py @@ -23,11 +23,12 @@ from parameterized import parameterized from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator +from airflow.providers.common.sql.hooks.sql import fetch_all_handler class 
TestRedshiftSQLOperator(unittest.TestCase): @parameterized.expand([(True, ('a', 'b')), (False, ('c', 'd'))]) - @mock.patch("airflow.providers.amazon.aws.operators.redshift_sql.RedshiftSQLOperator.get_hook") + @mock.patch("airflow.providers.amazon.aws.operators.redshift_sql.RedshiftSQLOperator.get_db_hook") def test_redshift_operator(self, test_autocommit, test_parameters, mock_get_hook): hook = MagicMock() mock_run = hook.run @@ -38,7 +39,10 @@ def test_redshift_operator(self, test_autocommit, test_parameters, mock_get_hook ) operator.execute(None) mock_run.assert_called_once_with( - sql, + sql=sql, autocommit=test_autocommit, parameters=test_parameters, + handler=fetch_all_handler, + return_last=True, + split_statements=False, ) diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py index 373b1b99aed2b..910f295d18d4f 100644 --- a/tests/providers/common/sql/operators/test_sql.py +++ b/tests/providers/common/sql/operators/test_sql.py @@ -28,10 +28,12 @@ from airflow.exceptions import AirflowException from airflow.models import Connection, DagRun, TaskInstance as TI, XCom from airflow.operators.empty import EmptyOperator +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.common.sql.operators.sql import ( BranchSQLOperator, SQLCheckOperator, SQLColumnCheckOperator, + SQLExecuteQueryOperator, SQLIntervalCheckOperator, SQLTableCheckOperator, SQLThresholdCheckOperator, @@ -56,6 +58,44 @@ def _get_mock_db_hook(): return MockHook() +class TestSQLExecuteQueryOperator(unittest.TestCase): + def _construct_operator(self, sql, **kwargs): + dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1)) + return SQLExecuteQueryOperator( + task_id="test_task", + conn_id='default_conn', + sql=sql, + **kwargs, + dag=dag, + ) + + @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook") + def test_do_xcom_push(self, mock_get_db_hook): + operator = self._construct_operator("SELECT 1;", do_xcom_push=True) + operator.execute(context=MagicMock()) + + mock_get_db_hook.return_value.run.assert_called_once_with( + sql='SELECT 1;', + autocommit=False, + handler=fetch_all_handler, + parameters=None, + return_last=True, + split_statements=False, + ) + + @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook") + def test_dont_xcom_push(self, mock_get_db_hook): + operator = self._construct_operator("SELECT 1;", do_xcom_push=False) + operator.execute(context=MagicMock()) + + mock_get_db_hook.return_value.run.assert_called_once_with( + sql='SELECT 1;', + autocommit=False, + parameters=None, + split_statements=False, + ) + + class TestColumnCheckOperator: valid_column_mapping = { diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py index 04a5b930d81ff..452542490e310 100644 --- a/tests/providers/databricks/operators/test_databricks_sql.py +++ b/tests/providers/databricks/operators/test_databricks_sql.py @@ -53,7 +53,7 @@ def test_exec_success(self, db_mock_class): results = op.execute(None) - assert results == mock_results + assert results[0][1] == mock_results db_mock_class.assert_called_once_with( DEFAULT_CONN_ID, http_path=None, @@ -64,7 +64,14 @@ def test_exec_success(self, db_mock_class): schema=None, caller='DatabricksSqlOperator', ) - db_mock.run.assert_called_once_with(sql, parameters=None, handler=fetch_all_handler) + db_mock.run.assert_called_once_with( + sql=sql, + parameters=None, + handler=fetch_all_handler, + autocommit=False, + 
return_last=True, + split_statements=False, + ) @mock.patch('airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook') def test_exec_write_file(self, db_mock_class): @@ -96,7 +103,14 @@ def test_exec_write_file(self, db_mock_class): schema=None, caller='DatabricksSqlOperator', ) - db_mock.run.assert_called_once_with(sql, parameters=None, handler=fetch_all_handler) + db_mock.run.assert_called_once_with( + sql=sql, + parameters=None, + handler=fetch_all_handler, + autocommit=False, + return_last=True, + split_statements=False, + ) class TestDatabricksSqlCopyIntoOperator(unittest.TestCase): diff --git a/tests/providers/exasol/operators/test_exasol.py b/tests/providers/exasol/operators/test_exasol.py index 5f93e33a75bbc..2150e55b3e0b2 100644 --- a/tests/providers/exasol/operators/test_exasol.py +++ b/tests/providers/exasol/operators/test_exasol.py @@ -20,24 +20,43 @@ import unittest from unittest import mock +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.exasol.operators.exasol import ExasolOperator class TestExasol(unittest.TestCase): - @mock.patch('airflow.providers.exasol.hooks.exasol.ExasolHook.run') - def test_overwrite_autocommit(self, mock_run): + @mock.patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook') + def test_overwrite_autocommit(self, mock_get_db_hook): operator = ExasolOperator(task_id='TEST', sql='SELECT 1', autocommit=True) operator.execute({}) - mock_run.assert_called_once_with('SELECT 1', autocommit=True, parameters=None) + mock_get_db_hook.return_value.run.assert_called_once_with( + sql='SELECT 1', + autocommit=True, + parameters=None, + handler=fetch_all_handler, + return_last=True, + split_statements=False, + ) - @mock.patch('airflow.providers.exasol.hooks.exasol.ExasolHook.run') - def test_pass_parameters(self, mock_run): + @mock.patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook') + def test_pass_parameters(self, mock_get_db_hook): operator = ExasolOperator(task_id='TEST', sql='SELECT {value!s}', parameters={'value': 1}) operator.execute({}) - mock_run.assert_called_once_with('SELECT {value!s}', autocommit=False, parameters={'value': 1}) + mock_get_db_hook.return_value.run.assert_called_once_with( + sql='SELECT {value!s}', + autocommit=False, + parameters={'value': 1}, + handler=fetch_all_handler, + return_last=True, + split_statements=False, + ) - @mock.patch('airflow.providers.exasol.operators.exasol.ExasolHook') - def test_overwrite_schema(self, mock_hook): - operator = ExasolOperator(task_id='TEST', sql='SELECT 1', schema='dummy') - operator.execute({}) - mock_hook.assert_called_once_with(exasol_conn_id='exasol_default', schema='dummy') + @mock.patch('airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__') + def test_overwrite_schema(self, mock_base_op): + ExasolOperator(task_id='TEST', sql='SELECT 1', schema='dummy') + mock_base_op.assert_called_once_with( + conn_id='exasol_default', + hook_params={'schema': 'dummy'}, + default_args={}, + task_id='TEST', + ) diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py index 0ba9d545fdaed..cdf094eec565c 100644 --- a/tests/providers/jdbc/operators/test_jdbc.py +++ b/tests/providers/jdbc/operators/test_jdbc.py @@ -28,27 +28,28 @@ class TestJdbcOperator(unittest.TestCase): def setUp(self): self.kwargs = dict(sql='sql', task_id='test_jdbc_operator', dag=None) - @patch('airflow.providers.jdbc.operators.jdbc.JdbcHook') - def 
test_execute_do_push(self, mock_jdbc_hook): + @patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook') + def test_execute_do_push(self, mock_get_db_hook): jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=True) jdbc_operator.execute(context={}) - mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id) - mock_jdbc_hook.return_value.run.assert_called_once_with( - jdbc_operator.sql, - jdbc_operator.autocommit, - parameters=jdbc_operator.parameters, + mock_get_db_hook.return_value.run.assert_called_once_with( + sql=jdbc_operator.sql, + autocommit=jdbc_operator.autocommit, handler=fetch_all_handler, + parameters=jdbc_operator.parameters, + return_last=True, + split_statements=False, ) - @patch('airflow.providers.jdbc.operators.jdbc.JdbcHook') - def test_execute_dont_push(self, mock_jdbc_hook): + @patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook') + def test_execute_dont_push(self, mock_get_db_hook): jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=False) jdbc_operator.execute(context={}) - mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id) - mock_jdbc_hook.return_value.run.assert_called_once_with( - jdbc_operator.sql, - jdbc_operator.autocommit, + mock_get_db_hook.return_value.run.assert_called_once_with( + sql=jdbc_operator.sql, + autocommit=jdbc_operator.autocommit, parameters=jdbc_operator.parameters, + split_statements=False, ) diff --git a/tests/providers/microsoft/mssql/operators/test_mssql.py b/tests/providers/microsoft/mssql/operators/test_mssql.py index 131d5aaedb9c5..444af0cf822c8 100644 --- a/tests/providers/microsoft/mssql/operators/test_mssql.py +++ b/tests/providers/microsoft/mssql/operators/test_mssql.py @@ -21,12 +21,13 @@ from unittest.mock import MagicMock, Mock from airflow import AirflowException +from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator class TestMsSqlOperator: - @mock.patch('airflow.hooks.base.BaseHook.get_connection') - def test_get_hook_from_conn(self, get_connection): + @mock.patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook') + def test_get_hook_from_conn(self, mock_get_db_hook): """ :class:`~.MsSqlOperator` should use the hook returned by :meth:`airflow.models.Connection.get_hook` if one is returned. @@ -37,18 +38,20 @@ def test_get_hook_from_conn(self, get_connection): call of ``get_hook`` on the object returned from :meth:`~.BaseHook.get_connection`. """ mock_hook = MagicMock() - get_connection.return_value.get_hook.return_value = mock_hook + mock_get_db_hook.return_value = mock_hook op = MsSqlOperator(task_id='test', sql='') - assert op.get_hook() == mock_hook + assert op.get_db_hook() == mock_hook - @mock.patch('airflow.hooks.base.BaseHook.get_connection') - def test_get_hook_default(self, get_connection): + @mock.patch( + 'airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook', autospec=MsSqlHook + ) + def test_get_hook_default(self, mock_get_db_hook): """ If :meth:`airflow.models.Connection.get_hook` does not return a hook (e.g. because of an invalid conn type), then :class:`~.MsSqlHook` should be used. 
""" - get_connection.return_value.get_hook.side_effect = Mock(side_effect=AirflowException()) + mock_get_db_hook.return_value.side_effect = Mock(side_effect=AirflowException()) op = MsSqlOperator(task_id='test', sql='') - assert op.get_hook().__class__.__name__ == 'MsSqlHook' + assert op.get_db_hook().__class__.__name__ == 'MsSqlHook' diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py index 1fd4a82880611..21221cd01f7cb 100644 --- a/tests/providers/oracle/operators/test_oracle.py +++ b/tests/providers/oracle/operators/test_oracle.py @@ -19,13 +19,14 @@ import unittest from unittest import mock +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.oracle.hooks.oracle import OracleHook from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator class TestOracleOperator(unittest.TestCase): - @mock.patch.object(OracleHook, 'run', autospec=OracleHook.run) - def test_execute(self, mock_run): + @mock.patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook') + def test_execute(self, mock_get_db_hook): sql = 'SELECT * FROM test_table' oracle_conn_id = 'oracle_default' parameters = {'parameter': 'value'} @@ -41,11 +42,13 @@ def test_execute(self, mock_run): task_id=task_id, ) operator.execute(context=context) - mock_run.assert_called_once_with( - mock.ANY, - sql, + mock_get_db_hook.return_value.run.assert_called_once_with( + sql=sql, autocommit=autocommit, parameters=parameters, + handler=fetch_all_handler, + return_last=True, + split_statements=False, ) diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index d4de0feed1f7f..d027451b86ead 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -35,8 +35,6 @@ DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat() DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10] TEST_DAG_ID = 'unit_test_dag' -LONG_MOCK_PATH = "airflow.providers.snowflake.operators.snowflake." 
diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py
index d4de0feed1f7f..d027451b86ead 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -35,8 +35,6 @@
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
 DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
 TEST_DAG_ID = 'unit_test_dag'
-LONG_MOCK_PATH = "airflow.providers.snowflake.operators.snowflake."
-LONG_MOCK_PATH += 'SnowflakeOperator.get_db_hook'
 
 
 class TestSnowflakeOperator(unittest.TestCase):
@@ -46,7 +44,7 @@ def setUp(self):
         dag = DAG(TEST_DAG_ID, default_args=args)
         self.dag = dag
 
-    @mock.patch(LONG_MOCK_PATH)
+    @mock.patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook')
     def test_snowflake_operator(self, mock_get_db_hook):
         sql = """
         CREATE TABLE IF NOT EXISTS test_airflow (
@@ -67,7 +65,7 @@ def test_snowflake_operator(self, mock_get_db_hook):
     ],
 )
 class TestSnowflakeCheckOperators:
-    @mock.patch("airflow.providers.snowflake.operators.snowflake.get_db_hook")
+    @mock.patch('airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook')
     def test_get_db_hook(
         self,
         mock_get_db_hook,
diff --git a/tests/providers/trino/operators/test_trino.py b/tests/providers/trino/operators/test_trino.py
index ba32e876b65bc..6b10c6de1d563 100644
--- a/tests/providers/trino/operators/test_trino.py
+++ b/tests/providers/trino/operators/test_trino.py
@@ -27,8 +27,8 @@
 
 
 class TestTrinoOperator(unittest.TestCase):
-    @mock.patch('airflow.providers.trino.operators.trino.TrinoHook')
-    def test_execute(self, mock_trino_hook):
+    @mock.patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook')
+    def test_execute(self, mock_get_db_hook):
         """Asserts that the run method is called when a TrinoOperator task is executed"""
 
         op = TrinoOperator(
@@ -39,6 +39,11 @@ def test_execute(self, mock_trino_hook):
         )
         op.execute(None)
 
-        mock_trino_hook.assert_called_once_with(trino_conn_id=TRINO_CONN_ID)
-        mock_run = mock_trino_hook.return_value.run
-        mock_run.assert_called_once()
+        mock_get_db_hook.return_value.run.assert_called_once_with(
+            sql="SELECT 1;",
+            autocommit=False,
+            handler=list,
+            parameters=None,
+            return_last=True,
+            split_statements=False,
+        )
diff --git a/tests/providers/vertica/operators/test_vertica.py b/tests/providers/vertica/operators/test_vertica.py
index 5b2a3073627f3..0dcb58bf9a25e 100644
--- a/tests/providers/vertica/operators/test_vertica.py
+++ b/tests/providers/vertica/operators/test_vertica.py
@@ -20,13 +20,21 @@
 import unittest
 from unittest import mock
 
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
 from airflow.providers.vertica.operators.vertica import VerticaOperator
 
 
 class TestVerticaOperator(unittest.TestCase):
-    @mock.patch('airflow.providers.vertica.operators.vertica.VerticaHook')
-    def test_execute(self, mock_hook):
+    @mock.patch('airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook')
+    def test_execute(self, mock_get_db_hook):
         sql = "select a, b, c"
         op = VerticaOperator(task_id='test_task_id', sql=sql)
         op.execute(None)
-        mock_hook.return_value.run.assert_called_once_with(sql=sql)
+        mock_get_db_hook.return_value.run.assert_called_once_with(
+            sql=sql,
+            autocommit=False,
+            handler=fetch_all_handler,
+            parameters=None,
+            return_last=True,
+            split_statements=False,
+        )
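
Most of the rewritten assertions expect handler=fetch_all_handler. Conceptually it is just the following — a sketch of the idea, not a verbatim copy of airflow.providers.common.sql.hooks.sql:

    def fetch_all_handler_sketch(cursor):
        # Return every row when the statement produced a result set
        # (cursor.description is set), otherwise return None.
        if cursor.description is not None:
            return cursor.fetchall()
        return None

The Trino assertion above expects handler=list instead, which shows that an operator can substitute its own row handler while keeping the rest of the unified run() signature.
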
diff --git a/tests/system/providers/common/sql/example_sql_execute_query.py b/tests/system/providers/common/sql/example_sql_execute_query.py
new file mode 100644
index 0000000000000..694ca3f2ef99c
--- /dev/null
+++ b/tests/system/providers/common/sql/example_sql_execute_query.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+from airflow import DAG
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+
+AIRFLOW_DB_METADATA_TABLE = "ab_role"
+connection_args = {
+    "conn_id": "airflow_db",
+    "conn_type": "Postgres",
+    "host": "postgres",
+    "schema": "postgres",
+    "login": "postgres",
+    "password": "postgres",
+    "port": 5432,
+}
+
+with DAG(
+    "example_sql_execute_query",
+    description="Example DAG for SQLExecuteQueryOperator.",
+    default_args=connection_args,
+    start_date=datetime(2021, 1, 1),
+    schedule=None,
+    catchup=False,
+) as dag:
+    """
+    ### Example SQL execute query DAG
+
+    Runs the SQLExecuteQueryOperator against the Airflow metadata DB.
+    """
+
+    # [START howto_operator_sql_execute_query]
+    execute_query = SQLExecuteQueryOperator(
+        task_id="execute_query",
+        sql=f"SELECT 1; SELECT * FROM {AIRFLOW_DB_METADATA_TABLE} LIMIT 1;",
+        split_statements=True,
+        return_last=False,
+    )
+    # [END howto_operator_sql_execute_query]
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
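
A note on the settings in the new system test: with split_statements=True and return_last=False, run() returns one result set per statement, so the task's pushed XCom value is a list of per-statement results. A hypothetical shape — the actual rows depend on the metadata database's ab_role contents:

    # Hypothetical XCom value for the execute_query task above:
    # [
    #     [(1,)],          # SELECT 1
    #     [(1, 'Admin')],  # SELECT * FROM ab_role LIMIT 1
    # ]

The trailing get_test_run(dag) hook is what makes the DAG runnable under pytest, per the pointer in the file itself (tests/system/README.md#run_via_pytest).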