45 changes: 12 additions & 33 deletions airflow/providers/amazon/aws/operators/redshift_sql.py
@@ -16,17 +16,14 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
import warnings
from typing import Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.www import utils as wwwutils

if TYPE_CHECKING:
from airflow.utils.context import Context


class RedshiftSQLOperator(BaseOperator):
class RedshiftSQLOperator(SQLExecuteQueryOperator):
"""
Executes SQL Statements against an Amazon Redshift cluster

@@ -54,29 +51,11 @@ class RedshiftSQLOperator(BaseOperator):
"sql": "postgresql" if "postgresql" in wwwutils.get_attr_renderer() else "sql"
}

def __init__(
self,
*,
sql: str | Iterable[str],
redshift_conn_id: str = 'redshift_default',
parameters: Iterable | Mapping | None = None,
autocommit: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.redshift_conn_id = redshift_conn_id
self.sql = sql
self.autocommit = autocommit
self.parameters = parameters

def get_hook(self) -> RedshiftSQLHook:
"""Create and return RedshiftSQLHook.
:return RedshiftSQLHook: A RedshiftSQLHook instance.
"""
return RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)

def execute(self, context: Context) -> None:
"""Execute a statement against Amazon Redshift"""
self.log.info("Executing statement: %s", self.sql)
hook = self.get_hook()
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
def __init__(self, *, redshift_conn_id: str = 'redshift_default', **kwargs) -> None:
super().__init__(conn_id=redshift_conn_id, **kwargs)
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""",
DeprecationWarning,
stacklevel=2,
)
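Since RedshiftSQLOperator is now only a thin, deprecated wrapper around SQLExecuteQueryOperator, existing DAGs can switch to the common operator directly. A minimal migration sketch follows (task id and SQL file are made up; note the new operator defaults autocommit to False, whereas RedshiftSQLOperator defaulted it to True). The same pattern applies to the other operators deprecated in this PR.

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

# Before (deprecated):
#   RedshiftSQLOperator(task_id="transfer", sql="sql/transfer.sql", redshift_conn_id="redshift_default")
# After: pass the Redshift connection through the generic conn_id parameter.
transfer = SQLExecuteQueryOperator(
    task_id="transfer",                 # hypothetical task id
    conn_id="redshift_default",         # connection previously passed as redshift_conn_id
    sql="sql/transfer.sql",             # hypothetical templated .sql file
    autocommit=True,                    # keep the old RedshiftSQLOperator default
)
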
2 changes: 1 addition & 1 deletion airflow/providers/amazon/provider.yaml
@@ -47,7 +47,7 @@ versions:

dependencies:
- apache-airflow>=2.2.0
- apache-airflow-providers-common-sql>=1.2.0
- apache-airflow-providers-common-sql>=1.3.0
- boto3>=1.15.0
# watchtower 3 has been released end Jan and introduced breaking change across the board that might
# change logging behaviour:
37 changes: 12 additions & 25 deletions airflow/providers/apache/drill/operators/drill.py
@@ -17,16 +17,13 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
import warnings
from typing import Sequence

from airflow.models import BaseOperator
from airflow.providers.apache.drill.hooks.drill import DrillHook
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class DrillOperator(BaseOperator):
class DrillOperator(SQLExecuteQueryOperator):
"""
Executes the provided SQL in the identified Drill environment.

@@ -47,21 +44,11 @@ class DrillOperator(BaseOperator):
template_ext: Sequence[str] = ('.sql',)
ui_color = '#ededed'

def __init__(
self,
*,
sql: str,
drill_conn_id: str = 'drill_default',
parameters: Iterable | Mapping | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.sql = sql
self.drill_conn_id = drill_conn_id
self.parameters = parameters
self.hook: DrillHook | None = None

def execute(self, context: Context):
self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id)
self.hook = DrillHook(drill_conn_id=self.drill_conn_id)
self.hook.run(self.sql, parameters=self.parameters, split_statements=True)
def __init__(self, *, drill_conn_id: str = 'drill_default', **kwargs) -> None:
super().__init__(conn_id=drill_conn_id, **kwargs)
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""",
DeprecationWarning,
stacklevel=2,
)
2 changes: 1 addition & 1 deletion airflow/providers/apache/drill/provider.yaml
@@ -34,7 +34,7 @@ versions:

dependencies:
- apache-airflow>=2.2.0
- apache-airflow-providers-common-sql>=1.2.0
- apache-airflow-providers-common-sql>=1.3.0
- sqlalchemy-drill>=1.1.0

integrations:
81 changes: 78 additions & 3 deletions airflow/providers/common/sql/operators/sql.py
@@ -17,16 +17,17 @@
# under the License.
from __future__ import annotations

import ast
import re
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Sequence, SupportsAbs

from packaging.version import Version

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator, SkipMixin
from airflow.providers.common.sql.hooks.sql import DbApiHook, _backported_get_hook
from airflow.providers.common.sql.hooks.sql import DbApiHook, _backported_get_hook, fetch_all_handler
from airflow.version import version

if TYPE_CHECKING:
@@ -94,7 +95,9 @@ class BaseSQLOperator(BaseOperator):

The provided method is .get_db_hook(). The default behavior will try to
retrieve the DB hook based on connection type.
You can custom the behavior by overriding the .get_db_hook() method.
You can customize the behavior by overriding the .get_db_hook() method.

:param conn_id: reference to a specific database
"""

def __init__(
@@ -162,6 +165,78 @@ def get_db_hook(self) -> DbApiHook:
return self._hook


class SQLExecuteQueryOperator(BaseSQLOperator):
"""
Executes SQL code in a specific database
:param sql: the SQL code or string pointing to a template file to be executed (templated).
File must have a '.sql' extension.
:param autocommit: (optional) if True, each command is automatically committed (default: False).
:param parameters: (optional) the parameters to render the SQL query with.
:param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler).
:param split_statements: (optional) whether to split a single SQL string into statements (default: False).
:param return_last: (optional) whether to return the result of only the last statement (default: True).

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:SQLExecuteQueryOperator`
"""

template_fields: Sequence[str] = ('sql', 'parameters')
template_ext: Sequence[str] = ('.sql', '.json')
template_fields_renderers = {"sql": "sql", "parameters": "json"}
ui_color = '#cdaaed'

def __init__(
self,
*,
sql: str | list[str],
autocommit: bool = False,
parameters: Mapping | Iterable | None = None,
handler: Callable[[Any], Any] = fetch_all_handler,
split_statements: bool = False,
return_last: bool = True,
Contributor

This parameter works only with do_xcom_push=True.
I have two thoughts on this:

  1. If return_last=True, should we set do_xcom_push=True under the hood?
  2. Assuming we go with (1), should the default of return_last then be False?

My goal here is to prevent the undesired state of do_xcom_push=False with return_last=True.
If the user sets return_last=True, we know for sure an XCom needs to be pushed.
If the user sets return_last=False, we don't know what the user wants for XCom, so we also need to look at the do_xcom_push value.

Contributor Author

Based on my previous PR, which broke compatibility, I decided to avoid that here.

As for return_last: it only has an effect when do_xcom_push=True (True by default) and split_statements=True (False by default).
It should probably be renamed to return_last_if_split_statements :)
Why is it True by default?
I think when we added split_statements we wanted to keep the return type the same regardless of what was passed there.
#23971 (comment)

Say the user has, in legacy code, a PostgresOperator. By default
it has do_xcom_push = True, split_statements = False and return_last = True.

do_xcom_push: bool = True,

return_last: bool = True,

So if a single string is passed as the SQL parameter, the result of that query is returned. If a list of strings is passed as the SQL parameter, a list with the result of each statement is returned (assuming a handler was passed, of course). I think we discussed it here: #23971 (comment)
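A short sketch of the flag interplay described in this thread (connection id, task ids and SQL are illustrative only):

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

# Defaults: do_xcom_push=True, split_statements=False, return_last=True.
# execute() runs the hook with the handler, so the query result is pushed to XCom.
simple = SQLExecuteQueryOperator(
    task_id="simple_select",
    conn_id="my_db",
    sql="SELECT 1",
)

# return_last only matters once a single SQL string is split into several statements:
# with return_last=True only the last statement's result is kept; with False,
# a result per statement is kept.
script = SQLExecuteQueryOperator(
    task_id="run_script",
    conn_id="my_db",
    sql="CREATE TABLE t (x INT); INSERT INTO t VALUES (1); SELECT * FROM t",
    split_statements=True,
)

# With do_xcom_push=False, execute() calls hook.run() without the handler, so nothing
# is fetched or pushed and return_last has no effect.
cleanup = SQLExecuteQueryOperator(
    task_id="cleanup",
    conn_id="my_db",
    sql="DROP TABLE IF EXISTS t",
    do_xcom_push=False,
)
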

**kwargs,
) -> None:
super().__init__(**kwargs)
self.sql = sql
self.autocommit = autocommit
self.parameters = parameters
self.handler = handler
self.split_statements = split_statements
self.return_last = return_last

def execute(self, context):
self.log.info('Executing: %s', self.sql)
hook = self.get_db_hook()
if self.do_xcom_push:
output = hook.run(
sql=self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
handler=self.handler,
split_statements=self.split_statements,
return_last=self.return_last,
)
else:
output = hook.run(
sql=self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
split_statements=self.split_statements,
)

if hasattr(self, '_process_output'):
for out in output:
self._process_output(*out)

return output

def prepare_template(self) -> None:
"""Parse template file for attribute parameters."""
if isinstance(self.parameters, str):
self.parameters = ast.literal_eval(self.parameters)


class SQLColumnCheckOperator(BaseSQLOperator):
"""
Performs one or more of the templated checks in the column_checks dictionary.
1 change: 1 addition & 0 deletions airflow/providers/common/sql/provider.yaml
@@ -22,6 +22,7 @@ description: |
`Common SQL Provider <https://en.wikipedia.org/wiki/SQL>`__

versions:
- 1.3.0
- 1.2.0
- 1.1.0
- 1.0.0
63 changes: 22 additions & 41 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -20,20 +20,20 @@

import csv
import json
from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence, Tuple, cast
from typing import TYPE_CHECKING, Any, Sequence

from databricks.sql.utils import ParamEscaper

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

if TYPE_CHECKING:
from airflow.utils.context import Context


class DatabricksSqlOperator(BaseOperator):
class DatabricksSqlOperator(SQLExecuteQueryOperator):
"""
Executes SQL code in a Databricks SQL endpoint or a Databricks cluster

@@ -80,11 +80,9 @@ class DatabricksSqlOperator(BaseOperator):
def __init__(
self,
*,
sql: str | Iterable[str],
databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
http_path: str | None = None,
sql_endpoint_name: str | None = None,
parameters: Iterable | Mapping | None = None,
session_configuration=None,
http_headers: list[tuple[str, str]] | None = None,
catalog: str | None = None,
@@ -96,37 +94,31 @@ def __init__(
client_parameters: dict[str, Any] | None = None,
**kwargs,
) -> None:
"""Creates a new ``DatabricksSqlOperator``."""
super().__init__(**kwargs)
super().__init__(conn_id=databricks_conn_id, **kwargs)
self.databricks_conn_id = databricks_conn_id
self.sql = sql
self._http_path = http_path
self._sql_endpoint_name = sql_endpoint_name
self._output_path = output_path
self._output_format = output_format
self._csv_params = csv_params
self.parameters = parameters
self.do_xcom_push = do_xcom_push
self.session_config = session_configuration
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema
self.client_parameters = client_parameters or {}

def _get_hook(self) -> DatabricksSqlHook:
return DatabricksSqlHook(
self.databricks_conn_id,
http_path=self._http_path,
session_configuration=self.session_config,
sql_endpoint_name=self._sql_endpoint_name,
http_headers=self.http_headers,
catalog=self.catalog,
schema=self.schema,
caller="DatabricksSqlOperator",
**self.client_parameters,
)
client_parameters = {} if client_parameters is None else client_parameters
hook_params = kwargs.pop('hook_params', {})

def _format_output(self, schema, results):
self.hook_params = {
'http_path': http_path,
'session_configuration': session_configuration,
'sql_endpoint_name': sql_endpoint_name,
'http_headers': http_headers,
'catalog': catalog,
'schema': schema,
'caller': "DatabricksSqlOperator",
**client_parameters,
**hook_params,
}

def get_db_hook(self) -> DatabricksSqlHook:
return DatabricksSqlHook(self.databricks_conn_id, **self.hook_params)
Comment on lines +118 to +119
Contributor

In the other operators, defining only the hook_params seems to be enough, e.g. kwargs['hook_params'] = {'schema': schema, **hook_params}. Why do you need to override get_db_hook in this case?

Contributor Author

By default, the Connection will return airflow.providers.databricks.hooks.databricks.DatabricksHook,
and I couldn't find a way to fix that.
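For context, a minimal sketch of the hook_params pattern the reviewer refers to (class and parameter names are made up). It works wherever the connection type already resolves to the right DbApiHook subclass; for Databricks that lookup yields DatabricksHook rather than DatabricksSqlHook, which is why get_db_hook is overridden here instead.

from __future__ import annotations

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

class ExampleSQLOperator(SQLExecuteQueryOperator):    # hypothetical provider operator
    def __init__(self, *, schema: str | None = None, **kwargs) -> None:
        hook_params = kwargs.pop('hook_params', {})
        # BaseSQLOperator forwards hook_params to the hook it resolves from the
        # connection type, so no get_db_hook override is needed in this case.
        kwargs['hook_params'] = {'schema': schema, **hook_params}
        super().__init__(**kwargs)
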


def _process_output(self, schema, results):
if not self._output_path:
return
if not self._output_format:
@@ -157,17 +149,6 @@ def _format_output(self, schema, results):
else:
raise AirflowException(f"Unsupported output format: '{self._output_format}'")

def execute(self, context: Context):
self.log.info('Executing: %s', self.sql)
hook = self._get_hook()
response = hook.run(self.sql, parameters=self.parameters, handler=fetch_all_handler)
schema, results = cast(List[Tuple[Any, Any]], response)[0]
# self.log.info('Schema: %s', schema)
# self.log.info('Results: %s', results)
self._format_output(schema, results)
if self.do_xcom_push:
return results


COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]

2 changes: 1 addition & 1 deletion airflow/providers/databricks/provider.yaml
@@ -41,7 +41,7 @@ versions:

dependencies:
- apache-airflow>=2.2.0
- apache-airflow-providers-common-sql>=1.2.0
- apache-airflow-providers-common-sql>=1.3.0
- requests>=2.27,<3
- databricks-sql-connector>=2.0.0, <3.0.0
- aiohttp>=3.6.3, <4