diff --git a/python-sdk/src/astro/databases/snowflake.py b/python-sdk/src/astro/databases/snowflake.py index 20ce9b3cce..3569fb5c85 100644 --- a/python-sdk/src/astro/databases/snowflake.py +++ b/python-sdk/src/astro/databases/snowflake.py @@ -615,13 +615,13 @@ def load_file_to_table_natively( except AttributeError: try: rows = self.hook.run(sql_statement) - except (AttributeError, ValueError) as exe: + except ValueError as exe: raise DatabaseCustomError from exe except ValueError as exe: raise DatabaseCustomError from exe - + finally: + self.drop_stage(stage) self.evaluate_results(rows) - self.drop_stage(stage) @staticmethod def evaluate_results(rows): diff --git a/python-sdk/tests/databases/test_snowflake.py b/python-sdk/tests/databases/test_snowflake.py index 26e3011c0b..3700711f37 100644 --- a/python-sdk/tests/databases/test_snowflake.py +++ b/python-sdk/tests/databases/test_snowflake.py @@ -2,7 +2,7 @@ import os import pathlib from unittest import mock -from unittest.mock import patch +from unittest.mock import call, patch import pandas as pd import pytest @@ -258,7 +258,10 @@ def test_load_file_to_table_natively_for_fallback_raises_exception_if_not_enable mock_stage, mock_hook, database_table_fixture ): """Test loading on files to snowflake natively for fallback raise exception.""" - mock_hook.run.side_effect = ValueError + mock_hook.run.side_effect = [ + ValueError, # 1st run call copies the data + None, # 2nd run call drops the stage + ] mock_stage.return_value = SnowflakeStage( name="mock_stage", url="gcs://bucket/prefix", @@ -271,6 +274,11 @@ def test_load_file_to_table_natively_for_fallback_raises_exception_if_not_enable source_file=File(filepath), target_table=target_table, ) + mock_hook.run.assert_has_calls( + [ + call(f"DROP STAGE IF EXISTS {mock_stage.return_value.qualified_name};", autocommit=True), + ] + ) @pytest.mark.integration diff --git a/python-sdk/tests/sql/operators/test_load_file.py b/python-sdk/tests/sql/operators/test_load_file.py index 7fb2d7f070..a3c41a70e9 100644 --- a/python-sdk/tests/sql/operators/test_load_file.py +++ b/python-sdk/tests/sql/operators/test_load_file.py @@ -1244,7 +1244,11 @@ def test_load_file_snowflake_error_out_provider_3_1_0(sample_dag, database_table "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.run" ) as run: schema_exists.return_value = True - run.side_effect = AttributeError() + run.side_effect = [ + AttributeError, # 1st run call copies the data with handler + ValueError, # 2nd run call copies the data + None, # 3rd run call drops the stage + ] with pytest.raises(DatabaseCustomError): with sample_dag: load_file(