Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,23 +197,34 @@ def _get_task_id_and_args(
args: dict[str, Any],
use_task_group: bool,
normalize_task_id: Callable[..., Any] | None,
normalize_task_display_name: Callable[..., Any] | None,
resource_suffix: str,
include_resource_type: bool = False,
) -> tuple[str, dict[str, Any]]:
"""
Generate task ID and update args with display name if needed.
"""
args_update = args
task_display_name = f"{node.name}_{resource_suffix}"
task_name = f"{node.name}_{resource_suffix}"

if include_resource_type:
task_display_name = f"{node.name}_{node.resource_type.value}_{resource_suffix}"
task_name = f"{node.name}_{node.resource_type.value}_{resource_suffix}"

if use_task_group:
task_id = resource_suffix
elif normalize_task_id:

elif normalize_task_id and not normalize_task_display_name:
task_id = normalize_task_id(node)
args_update["task_display_name"] = task_display_name
args_update["task_display_name"] = task_name
elif not normalize_task_id and normalize_task_display_name:
task_id = task_name
args_update["task_display_name"] = normalize_task_display_name(node)
elif normalize_task_id and normalize_task_display_name:
task_id = normalize_task_id(node)
args_update["task_display_name"] = normalize_task_display_name(node)
else:
task_id = task_display_name
task_id = task_name

return task_id, args_update


Expand Down Expand Up @@ -250,6 +261,7 @@ def create_task_metadata(
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
normalize_task_id: Callable[..., Any] | None = None,
normalize_task_display_name: Callable[..., Any] | None = None,
test_behavior: TestBehavior = TestBehavior.AFTER_ALL,
test_indirect_selection: TestIndirectSelection = TestIndirectSelection.EAGER,
on_warning_callback: Callable[..., Any] | None = None,
Expand Down Expand Up @@ -288,10 +300,18 @@ def create_task_metadata(
args["on_warning_callback"] = on_warning_callback
exclude_detached_tests_if_needed(node, args, detached_from_parent)
task_id, args = _get_task_id_and_args(
node, args, use_task_group, normalize_task_id, "build", include_resource_type=True
node,
args,
use_task_group,
normalize_task_id,
normalize_task_display_name,
"build",
include_resource_type=True,
)
elif node.resource_type == DbtResourceType.MODEL:
task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "run")
task_id, args = _get_task_id_and_args(
node, args, use_task_group, normalize_task_id, normalize_task_display_name, "run"
)
elif node.resource_type == DbtResourceType.SOURCE:
args["on_warning_callback"] = on_warning_callback

Expand All @@ -303,7 +323,9 @@ def create_task_metadata(
return None
args["select"] = f"source:{node.resource_name}"
args.pop("models")
task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "source")
task_id, args = _get_task_id_and_args(
node, args, use_task_group, normalize_task_id, normalize_task_display_name, "source"
)
if node.has_freshness is False and source_rendering_behavior == SourceRenderingBehavior.ALL:
# render sources without freshness as empty operators
# empty operator does not accept custom parameters (e.g., profile_args). recreate the args.
Expand All @@ -314,7 +336,7 @@ def create_task_metadata(
return TaskMetadata(id=task_id, operator_class="airflow.operators.empty.EmptyOperator", arguments=args)
else:
task_id, args = _get_task_id_and_args(
node, args, use_task_group, normalize_task_id, node.resource_type.value
node, args, use_task_group, normalize_task_id, normalize_task_display_name, node.resource_type.value
)

_override_profile_if_needed(args, node.profile_config_to_override)
Expand Down Expand Up @@ -365,6 +387,7 @@ def generate_task_or_group(
test_indirect_selection: TestIndirectSelection,
on_warning_callback: Callable[..., Any] | None,
normalize_task_id: Callable[..., Any] | None = None,
normalize_task_display_name: Callable[..., Any] | None = None,
detached_from_parent: dict[str, DbtNode] | None = None,
enable_owner_inheritance: bool | None = None,
**kwargs: Any,
Expand All @@ -386,6 +409,7 @@ def generate_task_or_group(
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
normalize_task_id=normalize_task_id,
normalize_task_display_name=normalize_task_display_name,
test_behavior=test_behavior,
test_indirect_selection=test_indirect_selection,
on_warning_callback=on_warning_callback,
Expand Down Expand Up @@ -595,6 +619,7 @@ def build_airflow_graph(
test_behavior = render_config.test_behavior
source_rendering_behavior = render_config.source_rendering_behavior
normalize_task_id = render_config.normalize_task_id
normalize_task_display_name = render_config.normalize_task_display_name
enable_owner_inheritance = render_config.enable_owner_inheritance
tasks_map: dict[str, Union[TaskGroup, BaseOperator]] = {}
task_or_group: TaskGroup | BaseOperator
Expand Down Expand Up @@ -624,6 +649,7 @@ def build_airflow_graph(
test_indirect_selection=test_indirect_selection,
on_warning_callback=on_warning_callback,
normalize_task_id=normalize_task_id,
normalize_task_display_name=normalize_task_display_name,
node=node,
detached_from_parent=detached_from_parent,
enable_owner_inheritance=enable_owner_inheritance,
Expand Down
2 changes: 2 additions & 0 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class RenderConfig:
:param source_rendering_behavior: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6).
:param airflow_vars_to_purge_dbt_ls_cache: Specify Airflow variables that will affect the LoadMode.DBT_LS cache.
:param normalize_task_id: A callable that takes a dbt node as input and returns the task ID. This allows users to assign a custom node ID separate from the display name.
:param normalize_task_display_name: A callable that takes a dbt node as input and returns the task display name. This allows users to assign a custom task display name separate from the node ID.
:param should_detach_multiple_parents_tests: A boolean that allows users to decide whether to run tests with multiple parent dependencies in separate tasks.
:param enable_owner_inheritance: A boolean that allows users to enable the owner inheritance from dbt models to airflow tasks. Defaults to True.
"""
Expand All @@ -87,6 +88,7 @@ class RenderConfig:
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE
airflow_vars_to_purge_dbt_ls_cache: list[str] = field(default_factory=list)
normalize_task_id: Callable[..., Any] | None = None
normalize_task_display_name: Callable[..., Any] | None = None
should_detach_multiple_parents_tests: bool = False
enable_owner_inheritance: bool | None = True

Expand Down
3 changes: 2 additions & 1 deletion docs/configuration/render-config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ The ``RenderConfig`` class takes the following arguments:
- ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM``
- ``airflow_vars_to_purge_cache``: (new in v1.5) Specify Airflow variables that will affect the ``LoadMode.DBT_LS`` cache. See `Caching <./caching.html>`_ for more information.
- ``source_rendering_behavior``: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). See `Source Nodes Rendering <./source-nodes-rendering.html>`_ for more information.
- ``normalize_task_id``: A callable that takes a dbt node as input and returns the task ID. This function allows users to set a custom task_id independently of the model name, which can be specified as the task's display_name. This way, task_id can be modified using a user-defined function, while the model name remains as the task's display name. The display_name parameter is available in Airflow 2.9 and above. See `Task display name <./task-display-name.html>`_ for more information.
- ``normalize_task_id``: A callable that takes a dbt node as input and returns the task ID. This function allows users to set a custom task_id independently of the model name, which can be specified as the task’s display_name. This way, task_id can be modified using a user-defined function, while the model name remains as the task’s display name. The display_name parameter is available in Airflow 2.9 and above. See `Task display name <./task-display-name.html>`_ for more information.
- ``normalize_task_display_name``: This function allows users to set a custom user-defined function to alter the display name independently of the model name. This way, the task_id can be preserved while the model display name is modified.
- ``should_detach_multiple_parents_tests``: A boolean to control if tests that depend on multiple parents should be run as standalone tasks. See `Parsing Methods <testing-behavior.html>`_ for more information.
- ``enable_owner_inheritance``: A boolean to control if dbt owners should be imported as part of the airflow DAG owners. Defaults to True.

Expand Down
20 changes: 19 additions & 1 deletion docs/configuration/task-display-name.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,22 @@ You can provide a function to convert the model name to an ASCII-compatible form
)

.. note::
Although the slugify example often works, it may not be suitable for use in actual production. Since slugify performs conversions based on pronunciation, there may be cases where task_id is not unique due to homophones and similar issues.
Although the slugify example often works, it may not be suitable for use in actual production. Since slugify performs conversions based on pronunciation, there may be cases where ``task_id`` is not unique due to homophones and similar issues.

Sometimes, users may want to keep the original ``task_id`` but change only the ``display_name`` instead. This can be done using the ``normalize_task_display_name`` field in ``RenderConfig``.

Example:

You can provide a function to display the dbt models in the Airflow UI without any suffixes (e.g., ``_source``, ``_run``, etc).

.. code-block:: python

def normalize_task_display_name(node):
return f"{node.name}"


from cosmos import DbtTaskGroup, RenderConfig

jaffle_shop = DbtTaskGroup(
render_config=RenderConfig(normalize_task_display_name=normalize_task_display_name)
)
111 changes: 109 additions & 2 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,18 +659,24 @@ def _normalize_task_id(node: DbtNode) -> str:
return f"new_task_id_{node.name}_{node.resource_type.value}"


def _normalize_task_display_name(node: DbtNode) -> str:
"""for test_create_task_metadata_normalize_task_id"""
return f"new_task_display_name_{node.name}_{node.resource_type.value}"


@pytest.mark.skipif(
version.parse(airflow_version) < version.parse("2.9"),
reason="Airflow task did not have display_name until the 2.9 release",
)
@pytest.mark.parametrize(
"node_type,node_id,normalize_task_id,use_task_group,test_behavior,expected_node_id,expected_display_name",
"node_type,node_id,normalize_task_id,normalize_task_display_name,use_task_group,test_behavior,expected_node_id,expected_display_name",
[
# normalize_task_id is None (default)
(
DbtResourceType.MODEL,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
None,
None,
False,
None,
"test_node_run",
Expand All @@ -680,6 +686,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SOURCE,
f"{DbtResourceType.SOURCE.value}.my_folder.test_node",
None,
None,
False,
None,
"test_node_source",
Expand All @@ -689,6 +696,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SEED,
f"{DbtResourceType.SEED.value}.my_folder.test_node",
None,
None,
False,
None,
"test_node_seed",
Expand All @@ -698,6 +706,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SEED,
f"{DbtResourceType.SEED.value}.my_folder.test_node",
None,
None,
False,
TestBehavior.BUILD,
"test_node_seed_build",
Expand All @@ -708,6 +717,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.MODEL,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
None,
False,
None,
"new_task_id_test_node_model",
Expand All @@ -717,6 +727,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SOURCE,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
None,
False,
None,
"new_task_id_test_node_source",
Expand All @@ -726,6 +737,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SEED,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
None,
False,
None,
"new_task_id_test_node_seed",
Expand All @@ -735,16 +747,100 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SEED,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
None,
False,
TestBehavior.BUILD,
"new_task_id_test_node_seed",
"test_node_seed_build",
),
# normalize_task_id is passed together with normalize_task_display_name
(
DbtResourceType.MODEL,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
_normalize_task_display_name,
False,
None,
"new_task_id_test_node_model",
"new_task_display_name_test_node_model",
),
(
DbtResourceType.SOURCE,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
_normalize_task_display_name,
False,
None,
"new_task_id_test_node_source",
"new_task_display_name_test_node_source",
),
(
DbtResourceType.SEED,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
_normalize_task_display_name,
False,
None,
"new_task_id_test_node_seed",
"new_task_display_name_test_node_seed",
),
(
DbtResourceType.SEED,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
_normalize_task_display_name,
False,
TestBehavior.BUILD,
"new_task_id_test_node_seed",
"new_task_display_name_test_node_seed",
),
# normalize_task_id is not passed but normalize_task_display_name is passed
(
DbtResourceType.MODEL,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
None,
_normalize_task_display_name,
False,
None,
"test_node_run",
"new_task_display_name_test_node_model",
),
(
DbtResourceType.SOURCE,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
None,
_normalize_task_display_name,
False,
None,
"test_node_source",
"new_task_display_name_test_node_source",
),
(
DbtResourceType.SEED,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
None,
_normalize_task_display_name,
False,
None,
"test_node_seed",
"new_task_display_name_test_node_seed",
),
(
DbtResourceType.SEED,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
None,
_normalize_task_display_name,
False,
TestBehavior.BUILD,
"test_node_seed_build",
"new_task_display_name_test_node_seed",
),
# normalize_task_id is passed and use_task_group is True
(
DbtResourceType.MODEL,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
None,
True,
None,
"run",
Expand All @@ -754,6 +850,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SOURCE,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
None,
True,
None,
"source",
Expand All @@ -763,6 +860,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SEED,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
None,
True,
None,
"seed",
Expand All @@ -772,6 +870,7 @@ def _normalize_task_id(node: DbtNode) -> str:
DbtResourceType.SEED,
f"{DbtResourceType.MODEL.value}.my_folder.test_node",
_normalize_task_id,
None,
True,
TestBehavior.BUILD,
"build",
Expand All @@ -780,7 +879,14 @@ def _normalize_task_id(node: DbtNode) -> str:
],
)
def test_create_task_metadata_normalize_task_id(
node_type, node_id, normalize_task_id, use_task_group, test_behavior, expected_node_id, expected_display_name
node_type,
node_id,
normalize_task_id,
normalize_task_display_name,
use_task_group,
test_behavior,
expected_node_id,
expected_display_name,
):
node = DbtNode(
unique_id=node_id,
Expand All @@ -798,6 +904,7 @@ def test_create_task_metadata_normalize_task_id(
dbt_dag_task_group_identifier="",
use_task_group=use_task_group,
normalize_task_id=normalize_task_id,
normalize_task_display_name=normalize_task_display_name,
source_rendering_behavior=SourceRenderingBehavior.ALL,
test_behavior=test_behavior,
)
Expand Down