diff --git a/src/ttsim/interface_dag_elements/automatically_added_functions.py b/src/ttsim/interface_dag_elements/automatically_added_functions.py index 6f9a29cb9..4b60bc6d6 100644 --- a/src/ttsim/interface_dag_elements/automatically_added_functions.py +++ b/src/ttsim/interface_dag_elements/automatically_added_functions.py @@ -611,7 +611,7 @@ def create_agg_by_group_functions( potential_agg_by_group_sources = { qn: o for qn, o in all_functions_and_data.items() if not gp.match(qn) } - # Exclude objects that have been explicitly provided.u + # Exclude objects that have been explicitly provided. agg_by_group_function_names = { t @@ -623,6 +623,19 @@ def create_agg_by_group_functions( match = gp.match(abgfn) base_name_with_time_unit = match.group("base_name_with_time_unit") if base_name_with_time_unit in potential_agg_by_group_sources: + # Check if the aggregation target is already a dependency of the source + # function to avoid creating cycles in the DAG. Consider a function `x` that + # takes `x_hh` as an input, assuming it to be provided in the input data. If + # we create a function `x_hh`, which would aggregate `x` by household, we + # create a cycle. If `x_hh` is actually provided as an input, the generated + # function `x_hh` would be overwritten, removing the cycle. However, if + # `x_hh` is not provided as an input, the error would report a cycle + # between `x` and `x_hh`. This hides the actual problem, which is that + # `x_hh` is not provided as an input.
+ source_function = column_functions.get(base_name_with_time_unit) + if source_function and abgfn in get_free_arguments(source_function): + continue + group_id = f"{match.group('group')}_id" mapper = {"group_id": group_id, "column": base_name_with_time_unit} agg_func = rename_arguments( diff --git a/tests/ttsim/interface_dag_elements/test_automatically_added_functions.py b/tests/ttsim/interface_dag_elements/test_automatically_added_functions.py index 327bcfe8b..3952f3149 100644 --- a/tests/ttsim/interface_dag_elements/test_automatically_added_functions.py +++ b/tests/ttsim/interface_dag_elements/test_automatically_added_functions.py @@ -345,7 +345,7 @@ def test_should_apply_converter(self): assert function(1) == 7 -def test_should_not_create_cycle(): +def test_time_conversions_should_not_create_cycle(): # Check for: # https://github.com/iza-institute-of-labor-economics/gettsim/issues/621 def x(test_m: int) -> int: @@ -360,6 +360,28 @@ def x(test_m: int) -> int: assert "test_m" not in time_conversion_functions +def test_grouping_functions_should_not_create_cycle(): + @policy_function() + def x(x_hh: int) -> int: + return x_hh + + @policy_function() + def some_other_function_requiring_x_hh(x_hh: int) -> int: + return x_hh + + grouping_functions = create_agg_by_group_functions( + column_functions={ + "x": x, + "some_other_function_requiring_x_hh": some_other_function_requiring_x_hh, + }, + input_columns=set(), + tt_targets=("some_other_function_requiring_x_hh",), + grouping_levels=("hh",), + ) + + assert "x_hh" not in grouping_functions + + @pytest.mark.parametrize( ( "column_functions",