Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def create_agg_by_group_functions(
potential_agg_by_group_sources = {
qn: o for qn, o in all_functions_and_data.items() if not gp.match(qn)
}
# Exclude objects that have been explicitly provided.u
# Exclude objects that have been explicitly provided.

agg_by_group_function_names = {
t
Expand All @@ -623,6 +623,19 @@ def create_agg_by_group_functions(
match = gp.match(abgfn)
base_name_with_time_unit = match.group("base_name_with_time_unit")
if base_name_with_time_unit in potential_agg_by_group_sources:
# Check if the aggregation target is already a dependency of the source
# function to avoid creating cycles in the DAG. Consider a function `x` that
# takes `x_hh` as an input, assuming it to be provided in the input data. If
# we create a function `x_hh`, which would aggregate `x` by household, we
# create a cycle. If `x_hh` is actually provided as an input, `x_hh` would
# be overwritten, removing the cycle. However, if `x_hh` is not provided as
# an input, an error message would be shown that a cycle between `x` and
# `x_hh` was detected. This hides the actual problem, which is that `x_hh`
# is not provided as an input.
source_function = column_functions.get(base_name_with_time_unit)
if source_function and abgfn in get_free_arguments(source_function):
continue

group_id = f"{match.group('group')}_id"
mapper = {"group_id": group_id, "column": base_name_with_time_unit}
agg_func = rename_arguments(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def test_should_apply_converter(self):
assert function(1) == 7


def test_should_not_create_cycle():
def test_time_conversions_should_not_create_cycle():
# Check for:
# https://github.com/iza-institute-of-labor-economics/gettsim/issues/621
def x(test_m: int) -> int:
Expand All @@ -360,6 +360,28 @@ def x(test_m: int) -> int:
assert "test_m" not in time_conversion_functions


def test_grouping_functions_should_not_create_cycle():
    """Check that no `x_hh` aggregation function is created from `x`.

    `x` itself takes `x_hh` as an argument, so deriving `x_hh` by aggregating
    `x` over the `hh` grouping level would make the two depend on each other.
    A second function requesting `x_hh` is the target; `x_hh` is not part of
    the input columns.
    """

    @policy_function()
    def x(x_hh: int) -> int:
        return x_hh

    @policy_function()
    def some_other_function_requiring_x_hh(x_hh: int) -> int:
        return x_hh

    # Both functions consume `x_hh`; only `x` could serve as its source.
    functions_by_name = {
        "x": x,
        "some_other_function_requiring_x_hh": some_other_function_requiring_x_hh,
    }

    created_functions = create_agg_by_group_functions(
        column_functions=functions_by_name,
        input_columns=set(),
        tt_targets=("some_other_function_requiring_x_hh",),
        grouping_levels=("hh",),
    )

    assert "x_hh" not in created_functions


@pytest.mark.parametrize(
(
"column_functions",
Expand Down
Loading