Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cosmos/core/airflow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import importlib
from copy import deepcopy

from cosmos._utils.importer import load_method_from_module

try: # Airflow 3
from airflow.sdk.bases.operator import BaseOperator
except ImportError: # Airflow 2
Expand Down Expand Up @@ -33,8 +34,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None)
# first, import the operator class from the
# fully qualified name defined in the task
module_name, class_name = task.operator_class.rsplit(".", 1)
module = importlib.import_module(module_name)
Operator = getattr(module, class_name)
Operator = load_method_from_module(module_name, class_name)

task_kwargs = task.arguments
if task.owner != "":
Expand Down
6 changes: 2 additions & 4 deletions cosmos/operators/_asynchronous/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import importlib
import logging
from typing import Any

from cosmos._utils.importer import load_method_from_module
from cosmos.airflow.graph import _snake_case_to_camelcase
from cosmos.config import ProfileConfig
from cosmos.constants import ExecutionMode
Expand All @@ -27,9 +27,7 @@ def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator"
try:
module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
operator_class = getattr(module, class_name)
return operator_class
return load_method_from_module(module_path, class_name)
except (ModuleNotFoundError, AttributeError) as e:
raise ImportError(f"Error in loading class: {class_path}. Unable to find the specified operator class.") from e

Expand Down
10 changes: 7 additions & 3 deletions tests/operators/_asynchronous/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,16 @@ def profile_config_mock():

def test_create_async_operator_class_valid():
"""Test _create_async_operator_class returns the correct async operator class if available."""
with patch("cosmos.operators._asynchronous.base.importlib.import_module") as mock_import:
with patch("cosmos.operators._asynchronous.base.load_method_from_module") as mock_import:
mock_class = MagicMock()
mock_import.return_value = MagicMock()
setattr(mock_import.return_value, "DbtRunAirflowAsyncBigqueryOperator", mock_class)

mock_import.return_value = mock_class

result = _create_async_operator_class("bigquery", "DbtRun")

mock_import.assert_called_once_with(
"cosmos.operators._asynchronous.bigquery", "DbtRunAirflowAsyncBigqueryOperator"
)
assert result == mock_class


Expand Down