Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions airflow/providers/databricks/operators/databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
:param output_format: format of output data if ``output_path`` is specified.
Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
:param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
:param do_xcom_push: If True, then the result of SQL executed will be pushed to an XCom.
"""

template_fields: Sequence[str] = (
Expand All @@ -87,7 +86,6 @@ def __init__(
http_headers: list[tuple[str, str]] | None = None,
catalog: str | None = None,
schema: str | None = None,
do_xcom_push: bool = False,
output_path: str | None = None,
output_format: str = "csv",
csv_params: dict[str, Any] | None = None,
Expand All @@ -99,24 +97,28 @@ def __init__(
self._output_path = output_path
self._output_format = output_format
self._csv_params = csv_params
self.http_path = http_path
self.sql_endpoint_name = sql_endpoint_name
self.session_configuration = session_configuration
self.client_parameters = {} if client_parameters is None else client_parameters
self.hook_params = kwargs.pop("hook_params", {})
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema

client_parameters = {} if client_parameters is None else client_parameters
hook_params = kwargs.pop("hook_params", {})

self.hook_params = {
"http_path": http_path,
"session_configuration": session_configuration,
"sql_endpoint_name": sql_endpoint_name,
"http_headers": http_headers,
"catalog": catalog,
"schema": schema,
def get_db_hook(self) -> DatabricksSqlHook:
    """Build and return a :class:`DatabricksSqlHook` for this operator.

    The hook parameters are assembled fresh on each call from the
    connection attributes stored on the operator (``http_path``,
    ``session_configuration``, ``sql_endpoint_name``, ``http_headers``,
    ``catalog``, ``schema``). User-supplied ``client_parameters`` and
    ``hook_params`` are merged in last, so on a key collision they
    override the operator-level defaults.

    :return: a configured ``DatabricksSqlHook`` bound to
        ``self.databricks_conn_id``.
    """
    hook_params = {
        "http_path": self.http_path,
        "session_configuration": self.session_configuration,
        "sql_endpoint_name": self.sql_endpoint_name,
        "http_headers": self.http_headers,
        "catalog": self.catalog,
        "schema": self.schema,
        # NOTE(review): presumably identifies this operator to the hook /
        # Databricks client for telemetry — confirm against hook implementation.
        "caller": "DatabricksSqlOperator",
        **self.client_parameters,
        **self.hook_params,
    }
    return DatabricksSqlHook(self.databricks_conn_id, **hook_params)

def _process_output(self, schema, results):
if not self._output_path:
Expand Down