diff --git a/cosmos/constants.py b/cosmos/constants.py index 57e570fa8c..148f9cc146 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -4,8 +4,11 @@ from typing import Callable, Dict import aenum +import airflow from packaging.version import Version +AIRFLOW_VERSION = Version(airflow.__version__) + BIGQUERY_PROFILE_TYPE = "bigquery" DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml") DBT_PROJECT_FILENAME = "dbt_project.yml" diff --git a/cosmos/core/graph/entities.py b/cosmos/core/graph/entities.py index 9e6fe10eac..ed9407420e 100644 --- a/cosmos/core/graph/entities.py +++ b/cosmos/core/graph/entities.py @@ -3,17 +3,14 @@ from dataclasses import dataclass, field from typing import Any, Dict, List -import airflow from packaging.version import Version +from cosmos.constants import AIRFLOW_VERSION from cosmos.log import get_logger logger = get_logger(__name__) -AIRFLOW_VERSION = Version(airflow.__version__) - - @dataclass class CosmosEntity: """ diff --git a/cosmos/listeners/dag_run_listener.py b/cosmos/listeners/dag_run_listener.py index eb69d00c0f..4b73481845 100644 --- a/cosmos/listeners/dag_run_listener.py +++ b/cosmos/listeners/dag_run_listener.py @@ -3,20 +3,17 @@ import hashlib from typing import TYPE_CHECKING -from airflow import __version__ as airflow_version from airflow.listeners import hookimpl if TYPE_CHECKING: from airflow.models.dag import DAG from airflow.models.dagrun import DagRun -from packaging import version - from cosmos import telemetry -from cosmos.constants import _AIRFLOW3_MAJOR_VERSION +from cosmos.constants import _AIRFLOW3_MAJOR_VERSION, AIRFLOW_VERSION from cosmos.log import get_logger -AIRFLOW_VERSION_MAJOR = version.parse(airflow_version).major +AIRFLOW_VERSION_MAJOR = AIRFLOW_VERSION.major logger = get_logger(__name__) diff --git a/cosmos/operators/_asynchronous/bigquery.py b/cosmos/operators/_asynchronous/bigquery.py index 56c15b7fcb..f2dd407f32 100644 --- a/cosmos/operators/_asynchronous/bigquery.py +++ b/cosmos/operators/_asynchronous/bigquery.py @@ -6,8 +6,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Sequence -import airflow - from cosmos.operators.base import _sanitize_xcom_key try: @@ -23,12 +21,12 @@ from cosmos import settings from cosmos.config import ProfileConfig +from cosmos.constants import AIRFLOW_VERSION from cosmos.dataset import get_dataset_alias_name from cosmos.exceptions import CosmosValueError from cosmos.operators.local import AbstractDbtLocalBase from cosmos.settings import remote_target_path, remote_target_path_conn_id -AIRFLOW_VERSION = Version(airflow.__version__) DEFAULT_PRODUCER_ASYNC_TASK_ID = "dbt_setup_async" diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 496f982eec..2d0aa97744 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -15,11 +15,11 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence from urllib.parse import urlparse -import airflow import jinja2 from airflow import DAG from airflow.exceptions import AirflowException, AirflowSkipException from airflow.models.taskinstance import TaskInstance +from packaging.version import Version if TYPE_CHECKING: # pragma: no cover try: @@ -27,9 +27,7 @@ except ImportError: from airflow.utils.context import Context # type: ignore[attr-defined] -from airflow.version import version as airflow_version from attrs import define -from packaging.version import Version from cosmos import cache, settings @@ -46,6 +44,7 @@ ) from cosmos.constants import ( _AIRFLOW3_MAJOR_VERSION, + AIRFLOW_VERSION, DBT_DEPENDENCIES_FILE_NAMES, FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode, @@ -117,8 +116,6 @@ _sanitize_xcom_key, ) -AIRFLOW_VERSION = Version(airflow.__version__) - logger = get_logger(__name__) @@ -323,7 +320,7 @@ def _configure_remote_target_path() -> tuple[Path | ObjectStoragePath, str] | tu if not settings.AIRFLOW_IO_AVAILABLE: raise CosmosValueError( f"You're trying to specify remote target path {target_path_str}, but the required " - f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to " + f"Object Storage feature is unavailable in Airflow version {AIRFLOW_VERSION}. Please upgrade to " "Airflow 2.8 or later." ) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 79617279d4..8e9a6a6337 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -9,9 +9,6 @@ from threading import Lock from typing import TYPE_CHECKING, Any, Callable, List, Union -import airflow -from packaging.version import Version - from cosmos._triggers.watcher import WatcherTrigger, _parse_compressed_xcom if TYPE_CHECKING: # pragma: no cover @@ -33,8 +30,10 @@ except ImportError: # pragma: no cover from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] +from packaging.version import Version + from cosmos.config import ProfileConfig -from cosmos.constants import PRODUCER_WATCHER_TASK_ID, InvocationMode +from cosmos.constants import AIRFLOW_VERSION, PRODUCER_WATCHER_TASK_ID, InvocationMode from cosmos.operators.base import ( DbtRunMixin, DbtSeedMixin, @@ -46,9 +45,6 @@ DbtSourceLocalOperator, ) -AIRFLOW_VERSION = Version(airflow.__version__) - - try: from dbt_common.events.base_types import EventMsg except ImportError: # pragma: no cover diff --git a/cosmos/plugin/airflow3.py b/cosmos/plugin/airflow3.py index 91b25d117e..8a0d46f5f1 100644 --- a/cosmos/plugin/airflow3.py +++ b/cosmos/plugin/airflow3.py @@ -22,6 +22,7 @@ from cosmos.plugin.snippets import IFRAME_SCRIPT # Airflow version gating: External views feature for the plugins used here (CosmosAF3Plugin) exist only in >= 3.1 +# Note: We compute AIRFLOW_VERSION locally here (not from constants) so that tests can patch airflow.__version__ and reload this module AIRFLOW_VERSION = Version(airflow.__version__) diff --git a/tests/conftest.py b/tests/conftest.py index 502ccba1d9..ec067649a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,11 @@ import json from unittest.mock import patch -import airflow import pytest from airflow.models.connection import Connection from packaging.version import Version -AIRFLOW_VERSION = Version(airflow.__version__) - +from cosmos.constants import AIRFLOW_VERSION if AIRFLOW_VERSION >= Version("3.1"): # Change introduced in Airflow 3.1.0 diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index a6d522a456..695c8026c0 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -12,15 +12,14 @@ from subprocess import PIPE, Popen from unittest.mock import MagicMock, patch -import airflow import pytest from airflow.models import Variable -from packaging.version import Version from cosmos import settings from cosmos.config import CosmosConfigException, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig from cosmos.constants import ( _AIRFLOW3_MAJOR_VERSION, + AIRFLOW_VERSION, DBT_LOG_FILENAME, DBT_TARGET_DIR_NAME, DbtResourceType, @@ -51,8 +50,6 @@ SAMPLE_DBT_LS_OUTPUT = Path(__file__).parent.parent / "sample/sample_dbt_ls.txt" SOURCE_RENDERING_BEHAVIOR = SourceRenderingBehavior(os.getenv("SOURCE_RENDERING_BEHAVIOR", "none")) -AIRFLOW_VERSION = Version(airflow.__version__) - if AIRFLOW_VERSION.major >= _AIRFLOW3_MAJOR_VERSION: object_storage_path = "airflow.sdk.ObjectStoragePath" else: diff --git a/tests/listeners/test_dag_run_listener.py b/tests/listeners/test_dag_run_listener.py index cc94257c91..5ad94430b5 100644 --- a/tests/listeners/test_dag_run_listener.py +++ b/tests/listeners/test_dag_run_listener.py @@ -5,21 +5,20 @@ from unittest.mock import patch import pytest -from airflow import __version__ as airflow_version from airflow.models import DAG, DagRun from airflow.utils.state import State -from packaging import version +from packaging.version import Version from cosmos import DbtRunLocalOperator, ProfileConfig, ProjectConfig from cosmos.airflow.dag import DbtDag from cosmos.airflow.task_group import DbtTaskGroup +from cosmos.constants import AIRFLOW_VERSION from cosmos.listeners.dag_run_listener import on_dag_run_failed, on_dag_run_success, total_cosmos_tasks from cosmos.profiles import PostgresUserPasswordProfileMapping DBT_ROOT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt" DBT_PROJECT_NAME = "jaffle_shop" -AIRFLOW_VERSION = version.parse(airflow_version) AIRFLOW_VERSION_MAJOR = AIRFLOW_VERSION.major profile_config = ProfileConfig( @@ -83,7 +82,7 @@ def test_not_cosmos_dag(): def create_dag_run(dag: DAG, run_id: str, run_after: datetime) -> DagRun: - if AIRFLOW_VERSION < version.Version("3.0"): + if AIRFLOW_VERSION < Version("3.0"): # Airflow 2 and 3.0 dag_run = dag.create_dagrun( state=State.NONE, @@ -133,7 +132,7 @@ def create_dag_run(dag: DAG, run_id: str, run_after: datetime) -> DagRun: @pytest.mark.skipif( - AIRFLOW_VERSION >= version.Version("3.1.0"), + AIRFLOW_VERSION >= Version("3.1.0"), reason="TODO: Fix create_dag_run to work with AF 3.1 and remove this skip.", ) @pytest.mark.integration @@ -161,7 +160,7 @@ def test_on_dag_run_success(mock_emit_usage_metrics_if_enabled, caplog): @pytest.mark.skipif( - AIRFLOW_VERSION >= version.Version("3.1.0"), reason="TODO: Fix create_dag_run to work with and remove this skip." + AIRFLOW_VERSION >= Version("3.1.0"), reason="TODO: Fix create_dag_run to work with and remove this skip." ) @pytest.mark.integration @patch("cosmos.listeners.dag_run_listener.telemetry.emit_usage_metrics_if_enabled") diff --git a/tests/operators/_asynchronous/test_base.py b/tests/operators/_asynchronous/test_base.py index 71b28cfaad..6810daf20a 100644 --- a/tests/operators/_asynchronous/test_base.py +++ b/tests/operators/_asynchronous/test_base.py @@ -3,11 +3,11 @@ from pathlib import Path from unittest.mock import MagicMock, Mock, mock_open, patch -import airflow import pytest from packaging.version import Version from cosmos.config import ProfileConfig +from cosmos.constants import AIRFLOW_VERSION from cosmos.hooks.subprocess import FullOutputSubprocessResult from cosmos.operators._asynchronous import SetupAsyncOperator, TeardownAsyncOperator from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator, _create_async_operator_class @@ -15,8 +15,6 @@ from cosmos.operators._asynchronous.databricks import DbtRunAirflowAsyncDatabricksOperator from cosmos.operators.local import DbtRunLocalOperator -AIRFLOW_VERSION = Version(airflow.__version__) - @pytest.mark.parametrize( "profile_type, dbt_class, expected_operator_class", diff --git a/tests/operators/test_virtualenv.py b/tests/operators/test_virtualenv.py index 2c25a1aaf2..560c75999c 100644 --- a/tests/operators/test_virtualenv.py +++ b/tests/operators/test_virtualenv.py @@ -6,21 +6,18 @@ from pathlib import Path from unittest.mock import MagicMock, patch -import airflow import pytest from airflow.models import DAG from airflow.models.connection import Connection from packaging.version import Version from cosmos.config import ProfileConfig -from cosmos.constants import _AIRFLOW3_MAJOR_VERSION, InvocationMode +from cosmos.constants import _AIRFLOW3_MAJOR_VERSION, AIRFLOW_VERSION, InvocationMode from cosmos.exceptions import CosmosValueError from cosmos.operators.virtualenv import DbtCloneVirtualenvOperator, DbtVirtualenvBaseOperator from cosmos.profiles import PostgresUserPasswordProfileMapping from tests.utils import test_dag as run_test_dag -AIRFLOW_VERSION = Version(airflow.__version__) - DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" DAGS_FOLDER = Path(__file__).parent.parent.parent / "dev/dags/" diff --git a/tests/test_async_example_dag.py b/tests/test_async_example_dag.py index 5d70af6a64..175e3b692b 100644 --- a/tests/test_async_example_dag.py +++ b/tests/test_async_example_dag.py @@ -16,12 +16,10 @@ from functools import lru_cache as cache -import airflow import pytest from airflow.models.dagbag import DagBag from airflow.utils.db import create_default_connections from airflow.utils.session import provide_session -from packaging.version import Version EXAMPLE_DAGS_DIR = Path(__file__).parent.parent / "dev/dags" ALL_FILES_TO_IGNORE = [ @@ -29,7 +27,6 @@ ] AIRFLOW_IGNORE_FILE = EXAMPLE_DAGS_DIR / ".airflowignore" -AIRFLOW_VERSION = Version(airflow.__version__) @provide_session diff --git a/tests/test_example_dags.py b/tests/test_example_dags.py index 9685c29ebc..04e9c3226f 100644 --- a/tests/test_example_dags.py +++ b/tests/test_example_dags.py @@ -11,7 +11,6 @@ from functools import lru_cache as cache -import airflow import pytest from airflow.models.dagbag import DagBag from airflow.utils.db import create_default_connections @@ -19,14 +18,13 @@ from dbt.version import get_installed_version as get_dbt_version from packaging.version import Version -from cosmos.constants import PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS +from cosmos.constants import AIRFLOW_VERSION, PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS from . import utils as test_utils EXAMPLE_DAGS_DIR = Path(__file__).parent.parent / "dev/dags" AIRFLOW_IGNORE_FILE = EXAMPLE_DAGS_DIR / ".airflowignore" DBT_VERSION = Version(get_dbt_version().to_version_string()[1:]) -AIRFLOW_VERSION = Version(airflow.__version__) KUBERNETES_DAGS = ["jaffle_shop_kubernetes"] MIN_VER_DAG_FILE: dict[str, list[str]] = { diff --git a/tests/test_log.py b/tests/test_log.py index c6823c1f6f..8cd102de55 100644 --- a/tests/test_log.py +++ b/tests/test_log.py @@ -1,13 +1,9 @@ -import airflow import pytest -from packaging.version import Version import cosmos.log from cosmos.log import CosmosRichLogger, get_logger from cosmos.provider_info import get_provider_info -AIRFLOW_VERSION = Version(airflow.__version__) - def test_get_logger(monkeypatch): monkeypatch.setattr(cosmos.log, "rich_logging", False) diff --git a/tests/utils.py b/tests/utils.py index df5960c79a..02bf6bc29c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,6 @@ from typing import Any import sqlalchemy -from airflow import __version__ as airflow_version from airflow.configuration import secrets_backend_list from airflow.exceptions import AirflowSkipException from airflow.models.dag import DAG @@ -22,7 +21,7 @@ from packaging.version import Version from sqlalchemy.orm.session import Session -AIRFLOW_VERSION = version.parse(airflow_version) +from cosmos.constants import AIRFLOW_VERSION log = logging.getLogger(__name__)