diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py index e2b9101789653..d808377cb1ae2 100644 --- a/airflow/providers/databricks/operators/databricks_sql.py +++ b/airflow/providers/databricks/operators/databricks_sql.py @@ -63,7 +63,6 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator): :param output_format: format of output data if ``output_path` is specified. Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``. :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data. - :param do_xcom_push: If True, then the result of SQL executed will be pushed to an XCom. """ template_fields: Sequence[str] = ( @@ -87,7 +86,6 @@ def __init__( http_headers: list[tuple[str, str]] | None = None, catalog: str | None = None, schema: str | None = None, - do_xcom_push: bool = False, output_path: str | None = None, output_format: str = "csv", csv_params: dict[str, Any] | None = None, @@ -99,24 +97,28 @@ def __init__( self._output_path = output_path self._output_format = output_format self._csv_params = csv_params + self.http_path = http_path + self.sql_endpoint_name = sql_endpoint_name + self.session_configuration = session_configuration + self.client_parameters = {} if client_parameters is None else client_parameters + self.hook_params = kwargs.pop("hook_params", {}) + self.http_headers = http_headers + self.catalog = catalog + self.schema = schema - client_parameters = {} if client_parameters is None else client_parameters - hook_params = kwargs.pop("hook_params", {}) - - self.hook_params = { - "http_path": http_path, - "session_configuration": session_configuration, - "sql_endpoint_name": sql_endpoint_name, - "http_headers": http_headers, - "catalog": catalog, - "schema": schema, + def get_db_hook(self) -> DatabricksSqlHook: + hook_params = { + "http_path": self.http_path, + "session_configuration": self.session_configuration, + "sql_endpoint_name": self.sql_endpoint_name, + "http_headers": self.http_headers, + "catalog": self.catalog, + "schema": self.schema, "caller": "DatabricksSqlOperator", - **client_parameters, - **hook_params, + **self.client_parameters, + **self.hook_params, } - - def get_db_hook(self) -> DatabricksSqlHook: - return DatabricksSqlHook(self.databricks_conn_id, **self.hook_params) + return DatabricksSqlHook(self.databricks_conn_id, **hook_params) def _process_output(self, schema, results): if not self._output_path: