Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@
except ImportError:
from airflow.models import BaseOperator # Airflow 2

try: # Airflow 3
from airflow.sdk.definitions.asset import Asset
except (ModuleNotFoundError, ImportError): # Airflow 2
from airflow.datasets import Dataset as Asset # type: ignore


try:
from airflow.datasets import Dataset
from openlineage.common.provider.dbt.local import DbtLocalArtifactProcessor
except ModuleNotFoundError:
is_openlineage_available = False
Expand All @@ -59,7 +64,6 @@
is_openlineage_available = True

if TYPE_CHECKING:
from airflow.datasets import Dataset # noqa: F811
from dbt.cli.main import dbtRunner, dbtRunnerResult

try: # pragma: no cover
Expand Down Expand Up @@ -489,8 +493,7 @@ def _handle_datasets(self, context: Context) -> None:
outlets = self.get_datasets("outputs")
self.log.info("Inlets: %s", inlets)
self.log.info("Outlets: %s", outlets)
if AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION:
self.register_dataset(inlets, outlets, context)
self.register_dataset(inlets, outlets, context)

def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None:
if self.cache_dir is None:
Expand Down Expand Up @@ -572,14 +575,12 @@ def run_command(
env=env,
cwd=tmp_project_dir,
)
if is_openlineage_available and AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION:
# Airflow 3 does not support associating 'openlineage_events_completes' with task_instance. The
# support for this is expected to be worked upon while addressing issue:
# https://github.com/astronomer/astronomer-cosmos/issues/1635
if is_openlineage_available:
self.calculate_openlineage_events_completes(env, tmp_dir_path)
context[
"task_instance"
].openlineage_events_completes = self.openlineage_events_completes # type: ignore
if AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION:
# Airflow 3 does not support associating 'openlineage_events_completes' with task_instance;
# in that case the events are only kept on self.openlineage_events_completes.
context["task_instance"].openlineage_events_completes = self.openlineage_events_completes # type: ignore[attr-defined]

if self.emit_datasets:
self._handle_datasets(context)
Expand Down Expand Up @@ -629,7 +630,7 @@ def calculate_openlineage_events_completes(
except (FileNotFoundError, NotImplementedError, ValueError, KeyError, jinja2.exceptions.UndefinedError):
self.log.debug("Unable to parse OpenLineage events", stack_info=True)

def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]:
def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Asset]:
"""
Use openlineage-integration-common to extract lineage events from the artifacts generated after running the dbt
command. Relies on the following files:
Expand All @@ -640,15 +641,16 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]:
Return a list of Asset objects (Dataset in Airflow 2), one per URI extracted from the lineage events.
"""
uris = []

for completed in self.openlineage_events_completes:
for output in getattr(completed, source):
dataset_uri = output.namespace + "/" + urllib.parse.quote(output.name)
uris.append(dataset_uri)
self.log.debug("URIs to be converted to Dataset: %s", uris)
self.log.debug("URIs to be converted to Asset: %s", uris)

datasets = []
assets = []
try:
datasets = [Dataset(uri) for uri in uris]
assets = [Asset(uri) for uri in uris]
except ValueError:
raise AirflowCompatibilityError(
"""
Expand All @@ -659,9 +661,9 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]:
By setting ``emit_datasets=False`` in ``RenderConfig``. For more information, see https://astronomer.github.io/astronomer-cosmos/configuration/render-config.html.
"""
)
return datasets
return assets

def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset], context: Context) -> None:
def register_dataset(self, new_inlets: list[Asset], new_outlets: list[Asset], context: Context) -> None:
"""
Register a list of datasets as outlets of the current task, when possible.

Expand All @@ -676,9 +678,14 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]
with DatasetAlias:
https://github.com/apache/airflow/issues/42495
"""
from airflow.utils.session import create_session
if AIRFLOW_VERSION.major >= 3 and not settings.enable_dataset_alias:
logger.error("To emit datasets with Airflow 3, the setting `enable_dataset_alias` must be True.")
raise AirflowCompatibilityError(
"To emit datasets with Airflow 3, the setting `enable_dataset_alias` must be True."
)
elif AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias:
from airflow.utils.session import create_session

if AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias:
logger.info("Assigning inlets/outlets without DatasetAlias")
with create_session() as session:
self.outlets.extend(new_outlets) # type: ignore[attr-defined]
Expand All @@ -690,10 +697,21 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]
DAG.bulk_write_to_db([self.dag], session=session) # type: ignore[attr-defined, call-arg, arg-type]
session.commit()
else:
logger.info("Assigning inlets/outlets with DatasetAlias")
dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id) # type: ignore[attr-defined]
for outlet in new_outlets:
context["outlet_events"][dataset_alias_name].add(outlet) # type: ignore[index]

if AIRFLOW_VERSION.major == 2:
logger.info("Assigning inlets/outlets with DatasetAlias in Airflow 2")

for outlet in new_outlets:
context["outlet_events"][dataset_alias_name].add(outlet) # type: ignore[index]
else: # AIRFLOW_VERSION.major == 3
Comment thread
pankajkoti marked this conversation as resolved.
logger.info("Assigning outlets with DatasetAlias in Airflow 3")
from airflow.sdk.definitions.asset import AssetAlias

# This line was necessary in Airflow 3.0.0, but this may become automatic in newer versions
self.outlets.append(AssetAlias(dataset_alias_name)) # type: ignore[attr-defined, call-arg, arg-type]
for outlet in new_outlets:
context["outlet_events"][AssetAlias(dataset_alias_name)].add(outlet)

def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage:
"""
Expand Down Expand Up @@ -806,6 +824,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
operator_kwargs["outlets"] = [
DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id))
] # type: ignore

if "task_id" in operator_kwargs:
operator_kwargs.pop("task_id")
BaseOperator.__init__(self, task_id=self.task_id, **operator_kwargs)
Expand Down
95 changes: 91 additions & 4 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,22 @@ def test_run_operator_dataset_inlets_and_outlets(caplog):
assert test_operator.outlets == []


# TODO: Add compatibility for Airflow 3. Issue: https://github.com/astronomer/astronomer-cosmos/issues/1704.
@pytest.mark.skipif(version.parse(airflow_version).major == 3, reason="Test need to be updated for Airflow 3.0")
@pytest.mark.skipif(version.parse(airflow_version).major >= 3, reason="This test is specific for Airflow 2.10 and 2.11")
@pytest.mark.skipif(
version.parse(airflow_version) < version.parse("2.10"),
reason="From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.",
)
@pytest.mark.skipif(
version.parse(airflow_version).minor == 10,
reason="This test stopped working and we'll investigate why after Cosmos 1.10 release: https://github.com/astronomer/astronomer-cosmos/issues/1730",
)
@pytest.mark.integration
def test_run_operator_dataset_inlets_and_outlets_airflow_210(caplog):
try:
from airflow.models.asset import AssetAliasModel
except ModuleNotFoundError:
from airflow.models.dataset import DatasetAliasModel as AssetAliasModel
from sqlalchemy.orm.exc import FlushError

with DAG("test_id_1", start_date=datetime(2022, 1, 1)) as dag:
seed_operator = DbtSeedLocalOperator(
Expand Down Expand Up @@ -510,6 +514,91 @@ def test_run_operator_dataset_inlets_and_outlets_airflow_210(caplog):
assert run_operator.outlets == [AssetAliasModel(name="test_id_1__run")]
assert test_operator.outlets == [AssetAliasModel(name="test_id_1__test")]

with pytest.raises(FlushError):
run_test_dag(dag)
# This is a known limitation of Airflow 2.10.0 and 2.10.1
# https://github.com/apache/airflow/issues/42495

# When this is solved, we can uncomment the following:
# dag_run, session = run_test_dag(dag)

# Once this issue is solved, we should add a check on the actual datasets being emitted,
# so that the tests guarantee Cosmos is backwards compatible. Something along the lines of the
# following (or an alternative, depending on how the issue logged in Airflow is resolved):
# dataset_model = session.scalars(select(DatasetModel).where(DatasetModel.uri == "<something>"))
# assert dataset_model == 1


@pytest.mark.skipif(
    version.parse(airflow_version) < version.Version("3.0.0"),
    reason="From Airflow 3.0 onwards, we started using AssetAlias, which changed the original behaviour",
)
@pytest.mark.integration
def test_run_operator_dataset_inlets_and_outlets_airflow_3_onwards(caplog):
    """
    Run a seed >> run >> test chain of dbt local operators with ``dag.test()`` on Airflow 3 and
    check, through the captured logs, that outlets are assigned via the alias code path and that
    the expected ``stg_customers`` asset is reported as an outlet.

    The seed operator is created with ``emit_datasets=False``; the run and test operators keep
    the operator default for that argument.
    """
    with DAG("test_id_1", start_date=datetime(2022, 1, 1)) as dag:
        seed_operator = DbtSeedLocalOperator(
            profile_config=real_profile_config,
            project_dir=DBT_PROJ_DIR,
            task_id="seed",
            dag=dag,
            emit_datasets=False,
            dbt_cmd_flags=["--select", "raw_customers"],
            install_deps=True,
            append_env=True,
        )
        run_operator = DbtRunLocalOperator(
            profile_config=real_profile_config,
            project_dir=DBT_PROJ_DIR,
            task_id="run",
            dag=dag,
            dbt_cmd_flags=["--models", "stg_customers"],
            install_deps=True,
            append_env=True,
        )
        test_operator = DbtTestLocalOperator(
            profile_config=real_profile_config,
            project_dir=DBT_PROJ_DIR,
            task_id="test",
            dag=dag,
            dbt_cmd_flags=["--models", "stg_customers"],
            install_deps=True,
            append_env=True,
        )
        seed_operator >> run_operator >> test_operator

    # Start from an empty capture so only messages emitted by this DAG run are inspected.
    caplog.clear()
    dag.test()

    assert "Assigning outlets with DatasetAlias in Airflow 3" in caplog.text
    # This membership check used to be a bare expression statement (no ``assert``), so its
    # result was discarded and the outlet log line was never actually verified.
    # Only the beginning of the Asset repr is matched; the closing part is deliberately left out.
    assert (
        "Outlets: [Asset(name='postgres://0.0.0.0:5432/postgres/public/stg_customers', uri='postgres://0.0.0.0:5432/postgres/public/stg_customers'"
        in caplog.text
    )


@pytest.mark.skipif(
    version.parse(airflow_version).major < 3,
    reason="Airflow 3.0 only supports assets when setting enable_dataset_alias=True",
)
@pytest.mark.integration
@patch("cosmos.settings.enable_dataset_alias", 0)
def test_run_operator_dataset_with_airflow_3_and_enabled_dataset_alias_false_fails(caplog):
    """
    With ``cosmos.settings.enable_dataset_alias`` patched to a falsy value on Airflow 3, running
    a DAG with a single ``DbtRunLocalOperator`` must leave an ERROR-level record in the captured
    logs that mentions ``AirflowCompatibilityError`` and explains the required setting.
    """
    altered_project_dir = Path(__file__).parent.parent.parent / "dev/dags/dbt/altered_jaffle_shop"

    # The operator registers itself on the DAG through the context manager, so no reference
    # to it needs to be kept.
    with DAG("test-id-1", start_date=datetime(2022, 1, 1)) as dag:
        DbtRunLocalOperator(
            task_id="run",
            profile_config=real_profile_config,
            project_dir=altered_project_dir,
            dbt_cmd_flags=["--models", "customers"],
            install_deps=True,
            append_env=True,
        )

    # Capture only ERROR and above, and drop anything logged while building the DAG.
    caplog.set_level(logging.ERROR)
    caplog.clear()
    run_test_dag(dag)

    expected_fragments = (
        "AirflowCompatibilityError",
        "ERROR",
        "To emit datasets with Airflow 3, the setting `enable_dataset_alias` must be True.",
    )
    for fragment in expected_fragments:
        assert fragment in caplog.text


@patch("cosmos.settings.enable_dataset_alias", 0)
@pytest.mark.skipif(
Expand Down Expand Up @@ -567,8 +656,6 @@ def test_run_operator_dataset_emission_is_skipped(caplog):
assert run_operator.outlets == []


# TODO: Make test compatible with Airflow 3.0. Issue:https://github.com/astronomer/astronomer-cosmos/issues/1713
@pytest.mark.skipif(version.parse(airflow_version).major == 3, reason="Test need to be updated for Airflow 3.0")
@pytest.mark.skipif(
version.parse(airflow_version) < version.parse("2.4")
or version.parse(airflow_version) in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS,
Expand Down
Loading