-
Notifications
You must be signed in to change notification settings - Fork 16.3k
Add SQLExecuteQueryOperator #25717
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SQLExecuteQueryOperator #25717
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -17,16 +17,17 @@ | |||||
| # under the License. | ||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import ast | ||||||
| import re | ||||||
| from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs | ||||||
| from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Sequence, SupportsAbs | ||||||
|
|
||||||
| from packaging.version import Version | ||||||
|
|
||||||
| from airflow.compat.functools import cached_property | ||||||
| from airflow.exceptions import AirflowException | ||||||
| from airflow.hooks.base import BaseHook | ||||||
| from airflow.models import BaseOperator, SkipMixin | ||||||
| from airflow.providers.common.sql.hooks.sql import DbApiHook, _backported_get_hook | ||||||
| from airflow.providers.common.sql.hooks.sql import DbApiHook, _backported_get_hook, fetch_all_handler | ||||||
| from airflow.version import version | ||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
|
|
@@ -94,7 +95,9 @@ class BaseSQLOperator(BaseOperator): | |||||
|
|
||||||
| The provided method is .get_db_hook(). The default behavior will try to | ||||||
| retrieve the DB hook based on connection type. | ||||||
| You can custom the behavior by overriding the .get_db_hook() method. | ||||||
| You can customize the behavior by overriding the .get_db_hook() method. | ||||||
|
|
||||||
| :param conn_id: reference to a specific database | ||||||
| """ | ||||||
|
|
||||||
| def __init__( | ||||||
|
|
@@ -162,6 +165,78 @@ def get_db_hook(self) -> DbApiHook: | |||||
| return self._hook | ||||||
|
|
||||||
|
|
||||||
| class SQLExecuteQueryOperator(BaseSQLOperator): | ||||||
kazanzhy marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| """ | ||||||
| Executes SQL code in a specific database | ||||||
| :param sql: the SQL code or string pointing to a template file to be executed (templated). | ||||||
| File must have a '.sql' extension. | ||||||
| :param autocommit: (optional) if True, each command is automatically committed (default: False). | ||||||
| :param parameters: (optional) the parameters to render the SQL query with. | ||||||
| :param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler). | ||||||
| :param split_statements: (optional) whether to split a single SQL string into statements (default: False). | ||||||
| :param return_last: (optional) whether to return the result of only the last statement (default: True). | ||||||
|
|
||||||
| .. seealso:: | ||||||
| For more information on how to use this operator, take a look at the guide: | ||||||
| :ref:`howto/operator:SQLExecuteQueryOperator` | ||||||
| """ | ||||||
|
|
||||||
| template_fields: Sequence[str] = ('sql', 'parameters') | ||||||
| template_ext: Sequence[str] = ('.sql', '.json') | ||||||
| template_fields_renderers = {"sql": "sql", "parameters": "json"} | ||||||
| ui_color = '#cdaaed' | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| *, | ||||||
| sql: str | list[str], | ||||||
| autocommit: bool = False, | ||||||
| parameters: Mapping | Iterable | None = None, | ||||||
| handler: Callable[[Any], Any] = fetch_all_handler, | ||||||
| split_statements: bool = False, | ||||||
| return_last: bool = True, | ||||||
|
||||||
| do_xcom_push: bool = True, |
| return_last: bool = True, |
So if a string is passed as the SQL parameter, then the result of that query will be returned. If a list of strings is passed as the SQL parameter, then a list with the results of each statement will be returned. Of course, this applies only if a handler was passed. I think we discussed it here: #23971 (comment)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,20 +20,20 @@ | |
|
|
||
| import csv | ||
| import json | ||
| from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence, Tuple, cast | ||
| from typing import TYPE_CHECKING, Any, Sequence | ||
|
|
||
| from databricks.sql.utils import ParamEscaper | ||
|
|
||
| from airflow.exceptions import AirflowException | ||
| from airflow.models import BaseOperator | ||
| from airflow.providers.common.sql.hooks.sql import fetch_all_handler | ||
| from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator | ||
| from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook | ||
|
|
||
| if TYPE_CHECKING: | ||
| from airflow.utils.context import Context | ||
|
|
||
|
|
||
| class DatabricksSqlOperator(BaseOperator): | ||
| class DatabricksSqlOperator(SQLExecuteQueryOperator): | ||
| """ | ||
| Executes SQL code in a Databricks SQL endpoint or a Databricks cluster | ||
|
|
||
|
|
@@ -80,11 +80,9 @@ class DatabricksSqlOperator(BaseOperator): | |
| def __init__( | ||
| self, | ||
| *, | ||
| sql: str | Iterable[str], | ||
| databricks_conn_id: str = DatabricksSqlHook.default_conn_name, | ||
| http_path: str | None = None, | ||
| sql_endpoint_name: str | None = None, | ||
| parameters: Iterable | Mapping | None = None, | ||
| session_configuration=None, | ||
| http_headers: list[tuple[str, str]] | None = None, | ||
| catalog: str | None = None, | ||
|
|
@@ -96,37 +94,31 @@ def __init__( | |
| client_parameters: dict[str, Any] | None = None, | ||
| **kwargs, | ||
| ) -> None: | ||
| """Creates a new ``DatabricksSqlOperator``.""" | ||
| super().__init__(**kwargs) | ||
| super().__init__(conn_id=databricks_conn_id, **kwargs) | ||
| self.databricks_conn_id = databricks_conn_id | ||
| self.sql = sql | ||
| self._http_path = http_path | ||
| self._sql_endpoint_name = sql_endpoint_name | ||
| self._output_path = output_path | ||
| self._output_format = output_format | ||
| self._csv_params = csv_params | ||
| self.parameters = parameters | ||
| self.do_xcom_push = do_xcom_push | ||
| self.session_config = session_configuration | ||
| self.http_headers = http_headers | ||
| self.catalog = catalog | ||
| self.schema = schema | ||
| self.client_parameters = client_parameters or {} | ||
|
|
||
| def _get_hook(self) -> DatabricksSqlHook: | ||
| return DatabricksSqlHook( | ||
| self.databricks_conn_id, | ||
| http_path=self._http_path, | ||
| session_configuration=self.session_config, | ||
| sql_endpoint_name=self._sql_endpoint_name, | ||
| http_headers=self.http_headers, | ||
| catalog=self.catalog, | ||
| schema=self.schema, | ||
| caller="DatabricksSqlOperator", | ||
| **self.client_parameters, | ||
| ) | ||
| client_parameters = {} if client_parameters is None else client_parameters | ||
| hook_params = kwargs.pop('hook_params', {}) | ||
|
|
||
| def _format_output(self, schema, results): | ||
| self.hook_params = { | ||
| 'http_path': http_path, | ||
| 'session_configuration': session_configuration, | ||
| 'sql_endpoint_name': sql_endpoint_name, | ||
| 'http_headers': http_headers, | ||
| 'catalog': catalog, | ||
| 'schema': schema, | ||
| 'caller': "DatabricksSqlOperator", | ||
| **client_parameters, | ||
| **hook_params, | ||
| } | ||
|
|
||
| def get_db_hook(self) -> DatabricksSqlHook: | ||
| return DatabricksSqlHook(self.databricks_conn_id, **self.hook_params) | ||
|
||
|
|
||
| def _process_output(self, schema, results): | ||
| if not self._output_path: | ||
| return | ||
| if not self._output_format: | ||
|
|
@@ -157,17 +149,6 @@ def _format_output(self, schema, results): | |
| else: | ||
| raise AirflowException(f"Unsupported output format: '{self._output_format}'") | ||
|
|
||
| def execute(self, context: Context): | ||
| self.log.info('Executing: %s', self.sql) | ||
| hook = self._get_hook() | ||
| response = hook.run(self.sql, parameters=self.parameters, handler=fetch_all_handler) | ||
| schema, results = cast(List[Tuple[Any, Any]], response)[0] | ||
| # self.log.info('Schema: %s', schema) | ||
| # self.log.info('Results: %s', results) | ||
| self._format_output(schema, results) | ||
| if self.do_xcom_push: | ||
| return results | ||
|
|
||
|
|
||
| COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"] | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.