diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index fae0dde00d..646e1ceb26 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -1,8 +1,9 @@ from __future__ import annotations -import importlib from copy import deepcopy +from cosmos._utils.importer import load_method_from_module + try: # Airflow 3 from airflow.sdk.bases.operator import BaseOperator except ImportError: # Airflow 2 @@ -33,8 +34,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) # first, import the operator class from the # fully qualified name defined in the task module_name, class_name = task.operator_class.rsplit(".", 1) - module = importlib.import_module(module_name) - Operator = getattr(module, class_name) + Operator = load_method_from_module(module_name, class_name) task_kwargs = task.arguments if task.owner != "": diff --git a/cosmos/operators/_asynchronous/base.py b/cosmos/operators/_asynchronous/base.py index 235878ef87..277da4387e 100644 --- a/cosmos/operators/_asynchronous/base.py +++ b/cosmos/operators/_asynchronous/base.py @@ -1,9 +1,9 @@ from __future__ import annotations -import importlib import logging from typing import Any +from cosmos._utils.importer import load_method_from_module from cosmos.airflow.graph import _snake_case_to_camelcase from cosmos.config import ProfileConfig from cosmos.constants import ExecutionMode @@ -27,9 +27,7 @@ def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any: class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator" try: module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - operator_class = getattr(module, class_name) - return operator_class + return load_method_from_module(module_path, class_name) except (ModuleNotFoundError, AttributeError) as e: raise ImportError(f"Error in loading class: {class_path}. Unable to find the specified operator class.") from e diff --git a/tests/operators/_asynchronous/test_base.py b/tests/operators/_asynchronous/test_base.py index 6810daf20a..cfd94245c1 100644 --- a/tests/operators/_asynchronous/test_base.py +++ b/tests/operators/_asynchronous/test_base.py @@ -47,12 +47,16 @@ def profile_config_mock(): def test_create_async_operator_class_valid(): """Test _create_async_operator_class returns the correct async operator class if available.""" - with patch("cosmos.operators._asynchronous.base.importlib.import_module") as mock_import: + with patch("cosmos.operators._asynchronous.base.load_method_from_module") as mock_import: mock_class = MagicMock() - mock_import.return_value = MagicMock() - setattr(mock_import.return_value, "DbtRunAirflowAsyncBigqueryOperator", mock_class) + + mock_import.return_value = mock_class result = _create_async_operator_class("bigquery", "DbtRun") + + mock_import.assert_called_once_with( + "cosmos.operators._asynchronous.bigquery", "DbtRunAirflowAsyncBigqueryOperator" + ) assert result == mock_class