Skip to content
21 changes: 19 additions & 2 deletions cosmos/listeners/dag_run_listener.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from __future__ import annotations

import hashlib

from airflow import __version__ as airflow_version
from airflow.listeners import hookimpl
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from packaging import version

from cosmos import telemetry
from cosmos.constants import _AIRFLOW3_MAJOR_VERSION
from cosmos.log import get_logger

AIRFLOW_VERSION_MAJOR = version.parse(airflow_version).major

logger = get_logger(__name__)


Expand Down Expand Up @@ -50,8 +57,13 @@ def on_dag_run_success(dag_run: DagRun, msg: str) -> None:
logger.debug("The DAG does not use Cosmos")
return

if AIRFLOW_VERSION_MAJOR < _AIRFLOW3_MAJOR_VERSION:
dag_hash = dag_run.dag_hash
else:
dag_hash = hashlib.md5(dag_run.dag_id.encode("utf-8")).hexdigest()

additional_telemetry_metrics = {
"dag_hash": dag_run.dag_hash,
"dag_hash": dag_hash,
"status": EventStatus.SUCCESS,
"task_count": len(serialized_dag.task_ids),
"cosmos_task_count": total_cosmos_tasks(serialized_dag),
Expand All @@ -73,8 +85,13 @@ def on_dag_run_failed(dag_run: DagRun, msg: str) -> None:
logger.debug("The DAG does not use Cosmos")
return

if AIRFLOW_VERSION_MAJOR < _AIRFLOW3_MAJOR_VERSION:
dag_hash = dag_run.dag_hash
else:
dag_hash = hashlib.md5(dag_run.dag_id.encode("utf-8")).hexdigest()

additional_telemetry_metrics = {
"dag_hash": dag_run.dag_hash,
"dag_hash": dag_hash,
"status": EventStatus.FAILED,
"task_count": len(serialized_dag.task_ids),
"cosmos_task_count": total_cosmos_tasks(serialized_dag),
Expand Down
54 changes: 41 additions & 13 deletions tests/listeners/test_dag_run_listener.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import uuid
from datetime import datetime
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import patch

Expand All @@ -13,12 +13,15 @@
from cosmos import DbtRunLocalOperator, ProfileConfig, ProjectConfig
from cosmos.airflow.dag import DbtDag
from cosmos.airflow.task_group import DbtTaskGroup
from cosmos.constants import _AIRFLOW3_MAJOR_VERSION
from cosmos.listeners.dag_run_listener import on_dag_run_failed, on_dag_run_success, total_cosmos_tasks
from cosmos.profiles import PostgresUserPasswordProfileMapping

DBT_ROOT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt"
DBT_PROJECT_NAME = "jaffle_shop"

AIRFLOW_VERSION_MAJOR = version.parse(airflow_version).major

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
Expand Down Expand Up @@ -79,8 +82,6 @@ def test_not_cosmos_dag():
assert total_cosmos_tasks(dag) == 0


# TODO: Make test compatible with Airflow 3.0. Issue:https://github.com/astronomer/astronomer-cosmos/issues/1703
@pytest.mark.skipif(version.parse(airflow_version).major == 3, reason="Test need to be updated for Airflow 3.0")
@pytest.mark.integration
@patch("cosmos.listeners.dag_run_listener.telemetry.emit_usage_metrics_if_enabled")
def test_on_dag_run_success(mock_emit_usage_metrics_if_enabled, caplog):
Expand All @@ -95,19 +96,32 @@ def test_on_dag_run_success(mock_emit_usage_metrics_if_enabled, caplog):
dag_id="basic_cosmos_dag",
)
run_id = str(uuid.uuid1())
dag_run = dag.create_dagrun(
state=State.NONE,
run_id=run_id,
)

run_after = datetime.now(timezone.utc) - timedelta(seconds=1)
if AIRFLOW_VERSION_MAJOR < _AIRFLOW3_MAJOR_VERSION:
# Airflow 2
dag_run = dag.create_dagrun(
state=State.NONE,
run_id=run_id,
)
else:
# Airflow 3
from airflow.utils.types import DagRunTriggeredByType, DagRunType

dag_run = dag.create_dagrun(
state=State.NONE,
run_id=run_id,
run_after=run_after,
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.TIMETABLE,
)

on_dag_run_success(dag_run, msg="test success")
assert "Running on_dag_run_success" in caplog.text
assert "Completed on_dag_run_success" in caplog.text
assert mock_emit_usage_metrics_if_enabled.call_count == 1


# TODO: Make test compatible with Airflow 3.0. Issue:https://github.com/astronomer/astronomer-cosmos/issues/1703
@pytest.mark.skipif(version.parse(airflow_version).major == 3, reason="Test need to be updated for Airflow 3.0")
@pytest.mark.integration
@patch("cosmos.listeners.dag_run_listener.telemetry.emit_usage_metrics_if_enabled")
def test_on_dag_run_failed(mock_emit_usage_metrics_if_enabled, caplog):
Expand All @@ -122,10 +136,24 @@ def test_on_dag_run_failed(mock_emit_usage_metrics_if_enabled, caplog):
dag_id="basic_cosmos_dag",
)
run_id = str(uuid.uuid1())
dag_run = dag.create_dagrun(
state=State.FAILED,
run_id=run_id,
)
run_after = datetime.now(timezone.utc) - timedelta(seconds=1)
if AIRFLOW_VERSION_MAJOR < _AIRFLOW3_MAJOR_VERSION:
# Airflow 2
dag_run = dag.create_dagrun(
state=State.NONE,
run_id=run_id,
)
else:
# Airflow 3
from airflow.utils.types import DagRunTriggeredByType, DagRunType

dag_run = dag.create_dagrun(
state=State.NONE,
run_id=run_id,
run_after=run_after,
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.TIMETABLE,
)

on_dag_run_failed(dag_run, msg="test failed")
assert "Running on_dag_run_failed" in caplog.text
Expand Down