diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index f042435943889..7f50a16071fd4 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -175,8 +175,10 @@ def run(
         else:
             raise ValueError("List of SQL statements is empty")
 
-        results = []
+        results: list[Any] = []
         for sql_statement in sql:
+            if return_last:
+                results = []  # it makes no sense to pile up previous results
             # when using AAD tokens, it could expire if previous query run longer than token lifetime
             with closing(self.get_conn()) as conn:
                 self.set_autocommit(conn, autocommit)
@@ -193,7 +195,8 @@ def run(
 
         if handler is None:
             return None
-        elif self.scalar_return_last:
+
+        if self.scalar_return_last:
             return results[-1]
         else:
             return results
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 379b0fd2c9bbb..3fd7fd082aff8 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -131,10 +131,11 @@ def _process_output(
             self.log.warning("Description of the cursor is missing. Will not process the output")
             return description, results
         field_names = [field[0] for field in description]
+        # We always need only the last results as we have only one description
         if scalar_results:
-            list_results: list[Any] = [results]
-        else:
             list_results = results
+        else:
+            list_results = results[-1]
         if self._output_format.lower() == "csv":
             with open(self._output_path, "w", newline="") as file:
                 if self._csv_params:
@@ -159,7 +160,7 @@ def _process_output(
                     file.write("\n")
         else:
             raise AirflowException(f"Unsupported output format: '{self._output_format}'")
-        return description, results
+        return results
 
 
 COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index 9a989dfae368a..4f33932a3bb2b 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -48,7 +48,7 @@ def test_exec_success(self, db_mock_class):
         op = DatabricksSqlOperator(task_id=TASK_ID, sql=sql, do_xcom_push=True)
         db_mock = db_mock_class.return_value
         mock_description = [("id",), ("value",)]
-        mock_results = [Row(id=1, value="value1")]
+        mock_results = [[Row(id=1, value="value1")]]
         db_mock.run.return_value = mock_results
         db_mock.last_description = mock_description
         db_mock.scalar_return_last = False
@@ -85,7 +85,7 @@ def test_exec_write_file(self, db_mock_class):
         op = DatabricksSqlOperator(task_id=TASK_ID, sql=sql, output_path=tempfile_path)
         db_mock = db_mock_class.return_value
         mock_description = [("id",), ("value",)]
-        mock_results = [Row(id=1, value="value1")]
+        mock_results = [[Row(id=1, value="value1")]]
         db_mock.run.return_value = mock_results
         db_mock.last_description = mock_description
         db_mock.scalar_return_last = False
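
The sketch below is a minimal, standalone illustration (not the actual Airflow classes) of the result-shape contract this diff establishes: `run()` collects one entry per SQL statement, resets its accumulator when `return_last` is set, and the operator side only ever unwraps the last statement's rows, mirroring `list_results = results[-1]` in `_process_output`. The names `run_statements` and `fake_fetchall` are hypothetical stand-ins introduced here for illustration only.

```python
from __future__ import annotations

from typing import Any, Callable


def run_statements(
    statements: list[str],
    handler: Callable[[str], list[tuple]] | None = None,
    return_last: bool = True,
    scalar_return_last: bool = False,
) -> Any:
    """Hypothetical stand-in for DatabricksSqlHook.run after this change."""
    if not statements:
        raise ValueError("List of SQL statements is empty")
    results: list[Any] = []
    for statement in statements:
        if return_last:
            results = []  # keep only the most recent statement's results
        if handler is not None:
            results.append(handler(statement))
    if handler is None:
        return None
    if scalar_return_last:
        return results[-1]  # the last statement's row list, not a list of row lists
    return results


# Fake "handler" standing in for a cursor fetchall: two rows per statement.
def fake_fetchall(stmt: str) -> list[tuple]:
    return [("id", 1), ("id", 2)]


# With return_last=True, only the final statement's rows survive, wrapped in a
# one-element outer list -- the same [[Row(...)]] shape the updated tests mock.
out = run_statements(["SELECT 1", "SELECT 2"], handler=fake_fetchall)
assert out == [[("id", 1), ("id", 2)]]

# The operator side then takes the last entry, as _process_output now does.
last_rows = out[-1]
assert last_rows == [("id", 1), ("id", 2)]
```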