From 41153f5a064d3cda614842a72aee73023b108eb4 Mon Sep 17 00:00:00 2001 From: Marvin Immesberger Date: Tue, 22 Jul 2025 17:44:45 +0200 Subject: [PATCH 1/6] Fail if param functions depend on policy functions. --- src/ttsim/interface_dag_elements/__init__.py | 3 + src/ttsim/interface_dag_elements/fail_if.py | 67 ++++++++ .../interface_dag_elements/test_failures.py | 144 ++++++++++++++++++ 3 files changed, 214 insertions(+) diff --git a/src/ttsim/interface_dag_elements/__init__.py b/src/ttsim/interface_dag_elements/__init__.py index 2c30e9a25..2c11c7e12 100644 --- a/src/ttsim/interface_dag_elements/__init__.py +++ b/src/ttsim/interface_dag_elements/__init__.py @@ -54,6 +54,9 @@ class FailIf(MainTargetABC): non_convertible_objects_in_results_tree: str = ( "fail_if__non_convertible_objects_in_results_tree" ) + param_function_depends_on_column_objects: str = ( + "fail_if__param_function_depends_on_column_objects" + ) paths_are_missing_in_targets_tree_mapper: str = ( "fail_if__paths_are_missing_in_targets_tree_mapper" ) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 6c31d3a73..58d6df8f7 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -12,6 +12,7 @@ import numpy import optree import pandas as pd +from dags import get_free_arguments from ttsim.interface_dag_elements.interface_node_objects import fail_function from ttsim.interface_dag_elements.shared import get_name_of_group_by_id @@ -849,3 +850,69 @@ def _remove_note_and_reference(entry: dict[str | int, Any]) -> dict[str | int, A ) return out + + +@fail_function() +def param_function_depends_on_column_objects( + specialized_environment__without_tree_logic_and_with_derived_functions: SpecEnvWithoutTreeLogicAndWithDerivedFunctions, +) -> None: + """Fail if any ParamFunction depends on ColumnObject arguments. + + Parameters + ---------- + specialized_environment__without_tree_logic_and_with_derived_functions + The specialized environment containing all functions and objects. + + Raises + ------ + ValueError + If any ParamFunction has ColumnObject arguments. + """ + param_functions = { + name: obj + for name, obj in specialized_environment__without_tree_logic_and_with_derived_functions.items() + if isinstance(obj, ParamFunction) + } + + column_objects = { + name: obj + for name, obj in specialized_environment__without_tree_logic_and_with_derived_functions.items() + if isinstance(obj, ColumnObject) + } + + violations: list[tuple[str, str]] = [] + for param_func_name, param_func in param_functions.items(): + func = param_func.function if hasattr(param_func, "function") else param_func + func_args = set(get_free_arguments(func)) + + allowed_column_object_args = [ + "evaluation_date", + "evaluation_year", + "evaluation_month", + "evaluation_day", + "policy_date", + "policy_year", + "policy_month", + "policy_day", + ] + + violations.extend( + (param_func_name, arg) + for arg in func_args + if arg in column_objects and arg not in allowed_column_object_args + ) + + if violations: + formatted_violations = format_list_linewise( + [ + f"{param_func_name} depends on {column_obj_name}" + for param_func_name, column_obj_name in violations + ] + ) + msg = format_errors_and_warnings( + "ParamFunctions should not depend on ColumnObjects. The following " + f"violations were found:\n\n{formatted_violations}\n\n" + "ParamFunctions should only depend on parameters and scalars, not on " + "ColumnObjects." + ) + raise ValueError(msg) diff --git a/tests/ttsim/interface_dag_elements/test_failures.py b/tests/ttsim/interface_dag_elements/test_failures.py index 1481664dc..156b13396 100644 --- a/tests/ttsim/interface_dag_elements/test_failures.py +++ b/tests/ttsim/interface_dag_elements/test_failures.py @@ -26,6 +26,7 @@ input_df_has_bool_or_numeric_column_names, input_df_mapper_columns_missing_in_df, input_df_mapper_has_incorrect_format, + param_function_depends_on_column_objects, paths_are_missing_in_targets_tree_mapper, targets_are_not_in_specialized_environment_or_data, ) @@ -35,6 +36,7 @@ DictParam, PiecewisePolynomialParam, PiecewisePolynomialParamValue, + ScalarParam, group_creation_function, param_function, policy_function, @@ -1413,3 +1415,145 @@ def test_raise_some_error_without_input_data( backend=backend, orig_policy_objects={"root": METTSIM_ROOT}, ) + + +@param_function() +def valid_param_function(x: int) -> int: + """A valid param function that only depends on parameters.""" + return x * 2 + + +@param_function() +def valid_param_function_with_evaluation_and_policy_year( + evaluation_year: int, + policy_year: int, +) -> int: + """A valid param function that only depends on evaluation_year and policy_year.""" + return evaluation_year - policy_year + + +@param_function() +def invalid_param_function(policy_func: int) -> int: + """An invalid param function that depends on a column object.""" + return policy_func * 2 + + +@policy_function() +def some_policy_function(x: int) -> int: + """A policy function for testing.""" + return x + 1 + + +@policy_input() +def some_policy_input() -> int: + """A policy input for testing.""" + return 1 + + +@policy_input() +def evaluation_year() -> int: + """A policy input for testing.""" + + +@policy_input() +def policy_year() -> int: + """A policy input for testing.""" + + +@pytest.mark.parametrize( + "specialized_environment", + [ + # Valid environment with only param functions and no dependencies + { + "valid_param": valid_param_function, + }, + # Valid environment with param functions and column objects but no dependencies + { + "valid_param": valid_param_function, + "policy_func": some_policy_function, + }, + # Valid environment with mixed types but no violations + { + "valid_param": valid_param_function, + "policy_func": some_policy_function, + "policy_input": some_policy_input, + "some_scalar": 42, + "some_dict_param": _SOME_DICT_PARAM, + }, + # Valid environment with evaluation_year and policy_year as scalars + { + "valid_param": valid_param_function_with_evaluation_and_policy_year, + "evaluation_year": ScalarParam(value=2025), + "policy_year": ScalarParam(value=2024), + }, + # Valid environment with evaluation_year and policy_year as policy inputs + { + "valid_param": valid_param_function_with_evaluation_and_policy_year, + "evaluation_year": evaluation_year, + "policy_year": policy_year, + }, + ], +) +def test_param_function_depends_on_column_objects_passes(specialized_environment): + """Test that valid environments pass the validation.""" + param_function_depends_on_column_objects(specialized_environment) + + +@pytest.mark.parametrize( + ("specialized_environment", "expected_error_match"), + [ + ( + { + "invalid_param": invalid_param_function, + "policy_func": some_policy_function, + }, + "invalid_param depends on policy_func", + ), + ( + { + "valid_param": valid_param_function, + "invalid_param": invalid_param_function, + "policy_func": some_policy_function, + }, + "invalid_param depends on policy_func", + ), + ], +) +def test_param_function_depends_on_column_objects_raises( + specialized_environment, expected_error_match +): + """Test that invalid environments raise the expected error.""" + with pytest.raises(ValueError, match=expected_error_match): + param_function_depends_on_column_objects(specialized_environment) + + +def test_param_function_depends_on_column_objects_via_main( + backend: Literal["jax", "numpy"], + xnp: ModuleType, +): + """Test that the param_function_depends_on_column_objects check works via main.""" + + with pytest.raises( + ValueError, + match="invalid_param depends on policy_func", + ): + main( + policy_date_str="2025-01-01", + main_target=MainTarget.results.df_with_mapper, + tt_targets={ + "tree": { + "invalid_param": None, + }, + }, + input_data={ + "tree": { + "p_id": xnp.array([1, 2, 3]), + "x": xnp.array([1, 2, 3]), + }, + }, + backend=backend, + policy_environment={ + "invalid_param": invalid_param_function, + "policy_func": some_policy_function, + }, + ) From 5220d70593154ab8dad95bdea7f6e8efd43e7e05 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 23 Jul 2025 06:15:16 +0200 Subject: [PATCH 2/6] Some renamings, add test for 'policy_input'. --- src/ttsim/interface_dag_elements/fail_if.py | 6 ++-- .../interface_dag_elements/test_failures.py | 28 +++++++++++-------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 58d6df8f7..93b17dfd6 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -905,14 +905,14 @@ def param_function_depends_on_column_objects( if violations: formatted_violations = format_list_linewise( [ - f"{param_func_name} depends on {column_obj_name}" + f"`{param_func_name}` depends on `{column_obj_name}`" for param_func_name, column_obj_name in violations ] ) msg = format_errors_and_warnings( - "ParamFunctions should not depend on ColumnObjects. The following " + "ParamFunctions must not depend on ColumnObjects. The following " f"violations were found:\n\n{formatted_violations}\n\n" - "ParamFunctions should only depend on parameters and scalars, not on " + "ParamFunctions may only depend on parameters and scalars, not on " "ColumnObjects." ) raise ValueError(msg) diff --git a/tests/ttsim/interface_dag_elements/test_failures.py b/tests/ttsim/interface_dag_elements/test_failures.py index 156b13396..6ab895a11 100644 --- a/tests/ttsim/interface_dag_elements/test_failures.py +++ b/tests/ttsim/interface_dag_elements/test_failures.py @@ -1433,9 +1433,9 @@ def valid_param_function_with_evaluation_and_policy_year( @param_function() -def invalid_param_function(policy_func: int) -> int: +def invalid_param_function(some_policy_function: int) -> int: """An invalid param function that depends on a column object.""" - return policy_func * 2 + return some_policy_function * 2 @policy_function() @@ -1447,7 +1447,6 @@ def some_policy_function(x: int) -> int: @policy_input() def some_policy_input() -> int: """A policy input for testing.""" - return 1 @policy_input() @@ -1470,12 +1469,12 @@ def policy_year() -> int: # Valid environment with param functions and column objects but no dependencies { "valid_param": valid_param_function, - "policy_func": some_policy_function, + "some_policy_function": some_policy_function, }, # Valid environment with mixed types but no violations { "valid_param": valid_param_function, - "policy_func": some_policy_function, + "some_policy_function": some_policy_function, "policy_input": some_policy_input, "some_scalar": 42, "some_dict_param": _SOME_DICT_PARAM, @@ -1505,17 +1504,24 @@ def test_param_function_depends_on_column_objects_passes(specialized_environment ( { "invalid_param": invalid_param_function, - "policy_func": some_policy_function, + "some_policy_function": some_policy_function, + }, + "`invalid_param` depends on `some_policy_function`", + ), + ( + { + "invalid_param": invalid_param_function, + "some_policy_function": some_policy_input, }, - "invalid_param depends on policy_func", + "`invalid_param` depends on `some_policy_function`", ), ( { "valid_param": valid_param_function, "invalid_param": invalid_param_function, - "policy_func": some_policy_function, + "some_policy_function": some_policy_function, }, - "invalid_param depends on policy_func", + "`invalid_param` depends on `some_policy_function`", ), ], ) @@ -1535,7 +1541,7 @@ def test_param_function_depends_on_column_objects_via_main( with pytest.raises( ValueError, - match="invalid_param depends on policy_func", + match="`invalid_param` depends on `some_policy_function`", ): main( policy_date_str="2025-01-01", @@ -1554,6 +1560,6 @@ def test_param_function_depends_on_column_objects_via_main( backend=backend, policy_environment={ "invalid_param": invalid_param_function, - "policy_func": some_policy_function, + "some_policy_function": some_policy_function, }, ) From db87be930f578cee17980eb578389d887589d851 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 23 Jul 2025 06:22:58 +0200 Subject: [PATCH 3/6] Simplify, 'ParamFunction's are guaranteed to have a 'function' attribute. --- src/ttsim/interface_dag_elements/fail_if.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 93b17dfd6..58f2806ff 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -882,8 +882,7 @@ def param_function_depends_on_column_objects( violations: list[tuple[str, str]] = [] for param_func_name, param_func in param_functions.items(): - func = param_func.function if hasattr(param_func, "function") else param_func - func_args = set(get_free_arguments(func)) + func_args = set(get_free_arguments(param_func.function)) allowed_column_object_args = [ "evaluation_date", From 47cd8a65557abd79692f9fa442c33a3bd6fc6db5 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 23 Jul 2025 06:45:32 +0200 Subject: [PATCH 4/6] Cannot make exceptions for dates. --- .../abz\303\274ge/vorsorge.py" | 4 +-- src/_gettsim/lohnsteuer/einkommen.py | 4 +-- src/ttsim/interface_dag_elements/fail_if.py | 21 +++--------- src/ttsim/testing_utils.py | 2 +- .../interface_dag_elements/test_failures.py | 32 ------------------- 5 files changed, 9 insertions(+), 54 deletions(-) diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" index 2dddce5af..6c2dbf4a7 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" @@ -200,7 +200,7 @@ def vorsorgeaufwendungen_keine_kappung_krankenversicherung_y_sn( @param_function(start_date="2005-01-01", end_date="2022-12-31") def rate_abzugsfähige_altersvorsorgeaufwendungen( parameter_einführungsfaktor_altersvorsorgeaufwendungen: PiecewisePolynomialParamValue, - evaluation_year: int, + policy_year: int, xnp: ModuleType, ) -> dict[str, Any]: """Calculate introductory factor for pension expense deductions which depends on the @@ -214,7 +214,7 @@ def rate_abzugsfähige_altersvorsorgeaufwendungen( """ return piecewise_polynomial( - x=evaluation_year, + x=policy_year, parameters=parameter_einführungsfaktor_altersvorsorgeaufwendungen, xnp=xnp, ) diff --git a/src/_gettsim/lohnsteuer/einkommen.py b/src/_gettsim/lohnsteuer/einkommen.py index 73c81ebc4..57a270269 100644 --- a/src/_gettsim/lohnsteuer/einkommen.py +++ b/src/_gettsim/lohnsteuer/einkommen.py @@ -142,7 +142,7 @@ def vorsorge_krankenversicherungsbeiträge_option_b_ab_2019( @param_function(start_date="2005-01-01", end_date="2022-12-31") def einführungsfaktor_rentenversicherungsaufwendungen( parameter_einführungsfaktor_rentenversicherungsaufwendungen: PiecewisePolynomialParamValue, - evaluation_year: int, + policy_year: int, xnp: ModuleType, ) -> dict[str, Any]: """Calculate introductory factor for pension expense deductions which depends on the @@ -156,7 +156,7 @@ def einführungsfaktor_rentenversicherungsaufwendungen( """ return piecewise_polynomial( - x=evaluation_year, + x=policy_year, parameters=parameter_einführungsfaktor_rentenversicherungsaufwendungen, xnp=xnp, ) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 58f2806ff..2cb839980 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -884,33 +884,20 @@ def param_function_depends_on_column_objects( for param_func_name, param_func in param_functions.items(): func_args = set(get_free_arguments(param_func.function)) - allowed_column_object_args = [ - "evaluation_date", - "evaluation_year", - "evaluation_month", - "evaluation_day", - "policy_date", - "policy_year", - "policy_month", - "policy_day", - ] - violations.extend( - (param_func_name, arg) - for arg in func_args - if arg in column_objects and arg not in allowed_column_object_args + (param_func_name, arg) for arg in func_args if arg in column_objects ) if violations: - formatted_violations = format_list_linewise( + formatted_violations = "\n ".join( [ f"`{param_func_name}` depends on `{column_obj_name}`" for param_func_name, column_obj_name in violations ] ) - msg = format_errors_and_warnings( + msg = ( "ParamFunctions must not depend on ColumnObjects. The following " - f"violations were found:\n\n{formatted_violations}\n\n" + f"violations were found:\n\n {formatted_violations}\n\n" "ParamFunctions may only depend on parameters and scalars, not on " "ColumnObjects." ) diff --git a/src/ttsim/testing_utils.py b/src/ttsim/testing_utils.py index c662dd0a0..acad2c4b4 100644 --- a/src/ttsim/testing_utils.py +++ b/src/ttsim/testing_utils.py @@ -48,7 +48,7 @@ def cached_policy_environment( policy_date=policy_date, orig_policy_objects={"root": root}, backend=backend, - include_fail_nodes=False, + include_fail_nodes=True, include_warn_nodes=False, ) diff --git a/tests/ttsim/interface_dag_elements/test_failures.py b/tests/ttsim/interface_dag_elements/test_failures.py index 6ab895a11..2508ba7ea 100644 --- a/tests/ttsim/interface_dag_elements/test_failures.py +++ b/tests/ttsim/interface_dag_elements/test_failures.py @@ -36,7 +36,6 @@ DictParam, PiecewisePolynomialParam, PiecewisePolynomialParamValue, - ScalarParam, group_creation_function, param_function, policy_function, @@ -1423,15 +1422,6 @@ def valid_param_function(x: int) -> int: return x * 2 -@param_function() -def valid_param_function_with_evaluation_and_policy_year( - evaluation_year: int, - policy_year: int, -) -> int: - """A valid param function that only depends on evaluation_year and policy_year.""" - return evaluation_year - policy_year - - @param_function() def invalid_param_function(some_policy_function: int) -> int: """An invalid param function that depends on a column object.""" @@ -1449,16 +1439,6 @@ def some_policy_input() -> int: """A policy input for testing.""" -@policy_input() -def evaluation_year() -> int: - """A policy input for testing.""" - - -@policy_input() -def policy_year() -> int: - """A policy input for testing.""" - - @pytest.mark.parametrize( "specialized_environment", [ @@ -1479,18 +1459,6 @@ def policy_year() -> int: "some_scalar": 42, "some_dict_param": _SOME_DICT_PARAM, }, - # Valid environment with evaluation_year and policy_year as scalars - { - "valid_param": valid_param_function_with_evaluation_and_policy_year, - "evaluation_year": ScalarParam(value=2025), - "policy_year": ScalarParam(value=2024), - }, - # Valid environment with evaluation_year and policy_year as policy inputs - { - "valid_param": valid_param_function_with_evaluation_and_policy_year, - "evaluation_year": evaluation_year, - "policy_year": policy_year, - }, ], ) def test_param_function_depends_on_column_objects_passes(specialized_environment): From 22ac49dcbfdc6bc2c6dbaa6ecc17cba7d0570a13 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 23 Jul 2025 06:53:21 +0200 Subject: [PATCH 5/6] Collect violations as strings right away. --- src/ttsim/interface_dag_elements/fail_if.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 2cb839980..c59e615d4 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -880,24 +880,18 @@ def param_function_depends_on_column_objects( if isinstance(obj, ColumnObject) } - violations: list[tuple[str, str]] = [] + violations = "" for param_func_name, param_func in param_functions.items(): func_args = set(get_free_arguments(param_func.function)) - violations.extend( - (param_func_name, arg) for arg in func_args if arg in column_objects - ) + for arg in func_args: + if arg in column_objects: + violations += f" `{param_func_name}` depends on `{arg}`\n" if violations: - formatted_violations = "\n ".join( - [ - f"`{param_func_name}` depends on `{column_obj_name}`" - for param_func_name, column_obj_name in violations - ] - ) msg = ( "ParamFunctions must not depend on ColumnObjects. The following " - f"violations were found:\n\n {formatted_violations}\n\n" + f"violations were found:\n\n {violations}\n" "ParamFunctions may only depend on parameters and scalars, not on " "ColumnObjects." ) From 601482ed0abecb6c487c799aa0159f53b24c29db Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 23 Jul 2025 06:54:26 +0200 Subject: [PATCH 6/6] Formatting. --- src/ttsim/interface_dag_elements/fail_if.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index c59e615d4..a645d707b 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -891,7 +891,7 @@ def param_function_depends_on_column_objects( if violations: msg = ( "ParamFunctions must not depend on ColumnObjects. The following " - f"violations were found:\n\n {violations}\n" + f"violations were found:\n\n{violations}\n" "ParamFunctions may only depend on parameters and scalars, not on " "ColumnObjects." )