diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" index 2dddce5af..6c2dbf4a7 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" @@ -200,7 +200,7 @@ def vorsorgeaufwendungen_keine_kappung_krankenversicherung_y_sn( @param_function(start_date="2005-01-01", end_date="2022-12-31") def rate_abzugsfähige_altersvorsorgeaufwendungen( parameter_einführungsfaktor_altersvorsorgeaufwendungen: PiecewisePolynomialParamValue, - evaluation_year: int, + policy_year: int, xnp: ModuleType, ) -> dict[str, Any]: """Calculate introductory factor for pension expense deductions which depends on the @@ -214,7 +214,7 @@ def rate_abzugsfähige_altersvorsorgeaufwendungen( """ return piecewise_polynomial( - x=evaluation_year, + x=policy_year, parameters=parameter_einführungsfaktor_altersvorsorgeaufwendungen, xnp=xnp, ) diff --git a/src/_gettsim/lohnsteuer/einkommen.py b/src/_gettsim/lohnsteuer/einkommen.py index 73c81ebc4..57a270269 100644 --- a/src/_gettsim/lohnsteuer/einkommen.py +++ b/src/_gettsim/lohnsteuer/einkommen.py @@ -142,7 +142,7 @@ def vorsorge_krankenversicherungsbeiträge_option_b_ab_2019( @param_function(start_date="2005-01-01", end_date="2022-12-31") def einführungsfaktor_rentenversicherungsaufwendungen( parameter_einführungsfaktor_rentenversicherungsaufwendungen: PiecewisePolynomialParamValue, - evaluation_year: int, + policy_year: int, xnp: ModuleType, ) -> dict[str, Any]: """Calculate introductory factor for pension expense deductions which depends on the @@ -156,7 +156,7 @@ def einführungsfaktor_rentenversicherungsaufwendungen( """ return piecewise_polynomial( - x=evaluation_year, + x=policy_year, parameters=parameter_einführungsfaktor_rentenversicherungsaufwendungen, xnp=xnp, ) diff --git a/src/ttsim/interface_dag_elements/__init__.py b/src/ttsim/interface_dag_elements/__init__.py index 2c30e9a25..2c11c7e12 100644 --- a/src/ttsim/interface_dag_elements/__init__.py +++ b/src/ttsim/interface_dag_elements/__init__.py @@ -54,6 +54,9 @@ class FailIf(MainTargetABC): non_convertible_objects_in_results_tree: str = ( "fail_if__non_convertible_objects_in_results_tree" ) + param_function_depends_on_column_objects: str = ( + "fail_if__param_function_depends_on_column_objects" + ) paths_are_missing_in_targets_tree_mapper: str = ( "fail_if__paths_are_missing_in_targets_tree_mapper" ) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 6c31d3a73..a645d707b 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -12,6 +12,7 @@ import numpy import optree import pandas as pd +from dags import get_free_arguments from ttsim.interface_dag_elements.interface_node_objects import fail_function from ttsim.interface_dag_elements.shared import get_name_of_group_by_id @@ -849,3 +850,49 @@ def _remove_note_and_reference(entry: dict[str | int, Any]) -> dict[str | int, A ) return out + + +@fail_function() +def param_function_depends_on_column_objects( + specialized_environment__without_tree_logic_and_with_derived_functions: SpecEnvWithoutTreeLogicAndWithDerivedFunctions, +) -> None: + """Fail if any ParamFunction depends on ColumnObject arguments. + + Parameters + ---------- + specialized_environment__without_tree_logic_and_with_derived_functions + The specialized environment containing all functions and objects. + + Raises + ------ + ValueError + If any ParamFunction has ColumnObject arguments. + """ + param_functions = { + name: obj + for name, obj in specialized_environment__without_tree_logic_and_with_derived_functions.items() + if isinstance(obj, ParamFunction) + } + + column_objects = { + name: obj + for name, obj in specialized_environment__without_tree_logic_and_with_derived_functions.items() + if isinstance(obj, ColumnObject) + } + + violations = "" + for param_func_name, param_func in param_functions.items(): + func_args = set(get_free_arguments(param_func.function)) + + for arg in func_args: + if arg in column_objects: + violations += f" `{param_func_name}` depends on `{arg}`\n" + + if violations: + msg = ( + "ParamFunctions must not depend on ColumnObjects. The following " + f"violations were found:\n\n{violations}\n" + "ParamFunctions may only depend on parameters and scalars, not on " + "ColumnObjects." + ) + raise ValueError(msg) diff --git a/src/ttsim/testing_utils.py b/src/ttsim/testing_utils.py index c662dd0a0..acad2c4b4 100644 --- a/src/ttsim/testing_utils.py +++ b/src/ttsim/testing_utils.py @@ -48,7 +48,7 @@ def cached_policy_environment( policy_date=policy_date, orig_policy_objects={"root": root}, backend=backend, - include_fail_nodes=False, + include_fail_nodes=True, include_warn_nodes=False, ) diff --git a/tests/ttsim/interface_dag_elements/test_failures.py b/tests/ttsim/interface_dag_elements/test_failures.py index 1481664dc..2508ba7ea 100644 --- a/tests/ttsim/interface_dag_elements/test_failures.py +++ b/tests/ttsim/interface_dag_elements/test_failures.py @@ -26,6 +26,7 @@ input_df_has_bool_or_numeric_column_names, input_df_mapper_columns_missing_in_df, input_df_mapper_has_incorrect_format, + param_function_depends_on_column_objects, paths_are_missing_in_targets_tree_mapper, targets_are_not_in_specialized_environment_or_data, ) @@ -1413,3 +1414,120 @@ def test_raise_some_error_without_input_data( backend=backend, orig_policy_objects={"root": METTSIM_ROOT}, ) + + +@param_function() +def valid_param_function(x: int) -> int: + """A valid param function that only depends on parameters.""" + return x * 2 + + +@param_function() +def invalid_param_function(some_policy_function: int) -> int: + """An invalid param function that depends on a column object.""" + return some_policy_function * 2 + + +@policy_function() +def some_policy_function(x: int) -> int: + """A policy function for testing.""" + return x + 1 + + +@policy_input() +def some_policy_input() -> int: + """A policy input for testing.""" + + +@pytest.mark.parametrize( + "specialized_environment", + [ + # Valid environment with only param functions and no dependencies + { + "valid_param": valid_param_function, + }, + # Valid environment with param functions and column objects but no dependencies + { + "valid_param": valid_param_function, + "some_policy_function": some_policy_function, + }, + # Valid environment with mixed types but no violations + { + "valid_param": valid_param_function, + "some_policy_function": some_policy_function, + "policy_input": some_policy_input, + "some_scalar": 42, + "some_dict_param": _SOME_DICT_PARAM, + }, + ], +) +def test_param_function_depends_on_column_objects_passes(specialized_environment): + """Test that valid environments pass the validation.""" + param_function_depends_on_column_objects(specialized_environment) + + +@pytest.mark.parametrize( + ("specialized_environment", "expected_error_match"), + [ + ( + { + "invalid_param": invalid_param_function, + "some_policy_function": some_policy_function, + }, + "`invalid_param` depends on `some_policy_function`", + ), + ( + { + "invalid_param": invalid_param_function, + "some_policy_function": some_policy_input, + }, + "`invalid_param` depends on `some_policy_function`", + ), + ( + { + "valid_param": valid_param_function, + "invalid_param": invalid_param_function, + "some_policy_function": some_policy_function, + }, + "`invalid_param` depends on `some_policy_function`", + ), + ], +) +def test_param_function_depends_on_column_objects_raises( + specialized_environment, expected_error_match +): + """Test that invalid environments raise the expected error.""" + with pytest.raises(ValueError, match=expected_error_match): + param_function_depends_on_column_objects(specialized_environment) + + +def test_param_function_depends_on_column_objects_via_main( + backend: Literal["jax", "numpy"], + xnp: ModuleType, +): + """Test that the param_function_depends_on_column_objects check works via main.""" + + with pytest.raises( + ValueError, + match="`invalid_param` depends on `some_policy_function`", + ): + main( + policy_date_str="2025-01-01", + main_target=MainTarget.results.df_with_mapper, + tt_targets={ + "tree": { + "invalid_param": None, + }, + }, + input_data={ + "tree": { + "p_id": xnp.array([1, 2, 3]), + "x": xnp.array([1, 2, 3]), + }, + }, + backend=backend, + policy_environment={ + "invalid_param": invalid_param_function, + "some_policy_function": some_policy_function, + }, + )