diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py index a4ad7fd523ce1..2879a194a698c 100644 --- a/airflow-core/src/airflow/utils/cli.py +++ b/airflow-core/src/airflow/utils/cli.py @@ -240,6 +240,7 @@ def get_dag_by_file_location(dag_id: str): raise AirflowException( f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse." ) + # This method is called only when we explicitly do not have a bundle name dagbag = DagBag(dag_folder=dag_model.fileloc) return dagbag.dags[dag_id] diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index f79f36e55d632..5c1d624949e3d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -628,6 +628,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: include_examples=False, safe_mode=False, load_op_links=False, + bundle_name=bundle_info.name, ) if TYPE_CHECKING: assert what.ti.dag_id diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 62554b061a895..34028655574e7 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -182,8 +182,57 @@ def test_parse(test_dags_dir: Path, make_ti_context): assert ti.task assert ti.task.dag - assert isinstance(ti.task, BaseOperator) - assert isinstance(ti.task.dag, DAG) + + +@mock.patch("airflow.dag_processing.dagbag.DagBag") +def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context): + """Test that checks that the dagbag is constructed as expected during parsing""" + mock_bag_instance = mock.Mock() + mock_dagbag.return_value = mock_bag_instance + mock_dag = mock.Mock(spec=DAG) + mock_task = mock.Mock(spec=BaseOperator) + + mock_bag_instance.dags = {"super_basic": mock_dag} + mock_dag.task_dict = {"a": mock_task} + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="a", + dag_id="super_basic", + run_id="c", + try_number=1, + dag_version_id=uuid7(), + ), + dag_rel_path="super_basic.py", + bundle_info=BundleInfo(name="my-bundle", version=None), + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + ) + + with patch.dict( + os.environ, + { + "AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps( + [ + { + "name": "my-bundle", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"path": str(test_dags_dir), "refresh_interval": 1}, + } + ] + ), + }, + ): + parse(what, mock.Mock()) + + mock_dagbag.assert_called_once_with( + dag_folder=mock.ANY, + include_examples=False, + safe_mode=False, + load_op_links=False, + bundle_name="my-bundle", + ) @pytest.mark.parametrize(