diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 5f76f02c4a..c5f47d4ab9 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -49,8 +49,13 @@ except ImportError: from airflow.models import BaseOperator # Airflow 2 +try: # Airflow 3 + from airflow.sdk.definitions.asset import Asset +except (ModuleNotFoundError, ImportError): # Airflow 2 + from airflow.datasets import Dataset as Asset # type: ignore + + try: - from airflow.datasets import Dataset from openlineage.common.provider.dbt.local import DbtLocalArtifactProcessor except ModuleNotFoundError: is_openlineage_available = False @@ -59,7 +64,6 @@ is_openlineage_available = True if TYPE_CHECKING: - from airflow.datasets import Dataset # noqa: F811 from dbt.cli.main import dbtRunner, dbtRunnerResult try: # pragma: no cover @@ -489,8 +493,7 @@ def _handle_datasets(self, context: Context) -> None: outlets = self.get_datasets("outputs") self.log.info("Inlets: %s", inlets) self.log.info("Outlets: %s", outlets) - if AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION: - self.register_dataset(inlets, outlets, context) + self.register_dataset(inlets, outlets, context) def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None: if self.cache_dir is None: @@ -572,14 +575,12 @@ def run_command( env=env, cwd=tmp_project_dir, ) - if is_openlineage_available and AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION: - # Airflow 3 does not support associating 'openlineage_events_completes' with task_instance. 
The - # support for this is expected to be worked upon while addressing issue: - # https://github.com/astronomer/astronomer-cosmos/issues/1635 + if is_openlineage_available: self.calculate_openlineage_events_completes(env, tmp_dir_path) - context[ - "task_instance" - ].openlineage_events_completes = self.openlineage_events_completes # type: ignore + if AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION: + # Airflow 3 does not support associating 'openlineage_events_completes' with task_instance, + # in that case we're storing as self.openlineage_events_completes + context["task_instance"].openlineage_events_completes = self.openlineage_events_completes # type: ignore[attr-defined] if self.emit_datasets: self._handle_datasets(context) @@ -629,7 +630,7 @@ def calculate_openlineage_events_completes( except (FileNotFoundError, NotImplementedError, ValueError, KeyError, jinja2.exceptions.UndefinedError): self.log.debug("Unable to parse OpenLineage events", stack_info=True) - def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]: + def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Asset]: """ Use openlineage-integration-common to extract lineage events from the artifacts generated after running the dbt command. Relies on the following files: @@ -640,15 +641,16 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]: Return a list of Dataset URIs (strings). 
""" uris = [] + for completed in self.openlineage_events_completes: for output in getattr(completed, source): dataset_uri = output.namespace + "/" + urllib.parse.quote(output.name) uris.append(dataset_uri) - self.log.debug("URIs to be converted to Dataset: %s", uris) + self.log.debug("URIs to be converted to Asset: %s", uris) - datasets = [] + assets = [] try: - datasets = [Dataset(uri) for uri in uris] + assets = [Asset(uri) for uri in uris] except ValueError: raise AirflowCompatibilityError( """ @@ -659,9 +661,9 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]: By setting ``emit_datasets=False`` in ``RenderConfig``. For more information, see https://astronomer.github.io/astronomer-cosmos/configuration/render-config.html. """ ) - return datasets + return assets - def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset], context: Context) -> None: + def register_dataset(self, new_inlets: list[Asset], new_outlets: list[Asset], context: Context) -> None: """ Register a list of datasets as outlets of the current task, when possible. @@ -676,9 +678,14 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset] with DatasetAlias: https://github.com/apache/airflow/issues/42495 """ - from airflow.utils.session import create_session + if AIRFLOW_VERSION.major >= 3 and not settings.enable_dataset_alias: + logger.error("To emit datasets with Airflow 3, the setting `enable_dataset_alias` must be True.") + raise AirflowCompatibilityError( + "To emit datasets with Airflow 3, the setting `enable_dataset_alias` must be True." 
+ ) + elif AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias: + from airflow.utils.session import create_session - if AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias: logger.info("Assigning inlets/outlets without DatasetAlias") with create_session() as session: self.outlets.extend(new_outlets) # type: ignore[attr-defined] @@ -690,10 +697,21 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset] DAG.bulk_write_to_db([self.dag], session=session) # type: ignore[attr-defined, call-arg, arg-type] session.commit() else: - logger.info("Assigning inlets/outlets with DatasetAlias") dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id) # type: ignore[attr-defined] - for outlet in new_outlets: - context["outlet_events"][dataset_alias_name].add(outlet) # type: ignore[index] + + if AIRFLOW_VERSION.major == 2: + logger.info("Assigning inlets/outlets with DatasetAlias in Airflow 2") + + for outlet in new_outlets: + context["outlet_events"][dataset_alias_name].add(outlet) # type: ignore[index] + else: # AIRFLOW_VERSION.major == 3 + logger.info("Assigning outlets with DatasetAlias in Airflow 3") + from airflow.sdk.definitions.asset import AssetAlias + + # This line was necessary in Airflow 3.0.0, but this may become automatic in newer versions + self.outlets.append(AssetAlias(dataset_alias_name)) # type: ignore[attr-defined, call-arg, arg-type] + for outlet in new_outlets: + context["outlet_events"][AssetAlias(dataset_alias_name)].add(outlet) def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage: """ @@ -806,6 +824,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: operator_kwargs["outlets"] = [ DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id)) ] # type: ignore + if "task_id" in operator_kwargs: operator_kwargs.pop("task_id") BaseOperator.__init__(self, task_id=self.task_id, **operator_kwargs) diff --git 
a/tests/operators/test_local.py b/tests/operators/test_local.py index 9f258022fe..30ea848d78 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -462,18 +462,22 @@ def test_run_operator_dataset_inlets_and_outlets(caplog): assert test_operator.outlets == [] -# TODO: Add compatibility for Airflow 3. Issue: https://github.com/astronomer/astronomer-cosmos/issues/1704. -@pytest.mark.skipif(version.parse(airflow_version).major == 3, reason="Test need to be updated for Airflow 3.0") +@pytest.mark.skipif(version.parse(airflow_version).major >= 3, reason="This test is specific for Airflow 2.10 and 2.11") @pytest.mark.skipif( version.parse(airflow_version) < version.parse("2.10"), reason="From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.", ) +@pytest.mark.skipif( + version.parse(airflow_version).minor == 10, + reason="This test stopped working and we'll investigate why after Cosmos 1.10 release: https://github.com/astronomer/astronomer-cosmos/issues/1730", +) @pytest.mark.integration def test_run_operator_dataset_inlets_and_outlets_airflow_210(caplog): try: from airflow.models.asset import AssetAliasModel except ModuleNotFoundError: from airflow.models.dataset import DatasetAliasModel as AssetAliasModel + from sqlalchemy.orm.exc import FlushError with DAG("test_id_1", start_date=datetime(2022, 1, 1)) as dag: seed_operator = DbtSeedLocalOperator( @@ -510,6 +514,91 @@ def test_run_operator_dataset_inlets_and_outlets_airflow_210(caplog): assert run_operator.outlets == [AssetAliasModel(name="test_id_1__run")] assert test_operator.outlets == [AssetAliasModel(name="test_id_1__test")] + with pytest.raises(FlushError): + run_test_dag(dag) + # This is a known limitation of Airflow 2.10.0 and 2.10.1 + # https://github.com/apache/airflow/issues/42495 + + # When this is solved, we can uncomment the following: + # dag_run, session = run_test_dag(dag) + + # Once this issue is solved, we should do some type of check on the 
actual datasets being emitted, + # so we guarantee Cosmos is backwards compatible via tests using something along these lines, or an alternative, + # based on the resolution of the issue logged in Airflow: + # dataset_model = session.scalars(select(DatasetModel).where(DatasetModel.uri == "")) + # assert dataset_model == 1 + + +@pytest.mark.skipif( + version.parse(airflow_version) < version.Version("3.0.0"), + reason="From Airflow 3.0 onwards, we started using AssetAlias, which changed the original behaviour", +) +@pytest.mark.integration +def test_run_operator_dataset_inlets_and_outlets_airflow_3_onwards(caplog): + + with DAG("test_id_1", start_date=datetime(2022, 1, 1)) as dag: + seed_operator = DbtSeedLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + task_id="seed", + dag=dag, + emit_datasets=False, + dbt_cmd_flags=["--select", "raw_customers"], + install_deps=True, + append_env=True, + ) + run_operator = DbtRunLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + task_id="run", + dag=dag, + dbt_cmd_flags=["--models", "stg_customers"], + install_deps=True, + append_env=True, + ) + test_operator = DbtTestLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + task_id="test", + dag=dag, + dbt_cmd_flags=["--models", "stg_customers"], + install_deps=True, + append_env=True, + ) + seed_operator >> run_operator >> test_operator + + caplog.clear() + dag.test() + assert "Assigning outlets with DatasetAlias in Airflow 3" in caplog.text + assert "Outlets: [Asset(name='postgres://0.0.0.0:5432/postgres/public/stg_customers', uri='postgres://0.0.0.0:5432/postgres/public/stg_customers'" in caplog.text + + +@pytest.mark.skipif( + version.parse(airflow_version).major < 3, + reason="Airflow 3.0 only supports assets when setting enable_dataset_alias=True", +) +@pytest.mark.integration +@patch("cosmos.settings.enable_dataset_alias", 0) +def
test_run_operator_dataset_with_airflow_3_and_enabled_dataset_alias_false_fails(caplog): + with DAG("test-id-1", start_date=datetime(2022, 1, 1)) as dag: + run_operator = DbtRunLocalOperator( + profile_config=real_profile_config, + project_dir=Path(__file__).parent.parent.parent / "dev/dags/dbt/altered_jaffle_shop", + task_id="run", + dbt_cmd_flags=["--models", "customers"], + install_deps=True, + append_env=True, + ) + run_operator + + caplog.set_level(logging.ERROR) + caplog.clear() + run_test_dag(dag) + + assert "AirflowCompatibilityError" in caplog.text + assert "ERROR" in caplog.text + assert "To emit datasets with Airflow 3, the setting `enable_dataset_alias` must be True." in caplog.text + @patch("cosmos.settings.enable_dataset_alias", 0) @pytest.mark.skipif( @@ -567,8 +656,6 @@ def test_run_operator_dataset_emission_is_skipped(caplog): assert run_operator.outlets == [] -# TODO: Make test compatible with Airflow 3.0. Issue:https://github.com/astronomer/astronomer-cosmos/issues/1713 -@pytest.mark.skipif(version.parse(airflow_version).major == 3, reason="Test need to be updated for Airflow 3.0") @pytest.mark.skipif( version.parse(airflow_version) < version.parse("2.4") or version.parse(airflow_version) in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS,