Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/_gettsim/einkommensteuer/abzüge/vorsorge.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def vorsorgeaufwendungen_keine_kappung_krankenversicherung_y_sn(
@param_function(start_date="2005-01-01", end_date="2022-12-31")
def rate_abzugsfähige_altersvorsorgeaufwendungen(
parameter_einführungsfaktor_altersvorsorgeaufwendungen: PiecewisePolynomialParamValue,
evaluation_year: int,
policy_year: int,
xnp: ModuleType,
) -> dict[str, Any]:
"""Calculate introductory factor for pension expense deductions which depends on the
Expand All @@ -214,7 +214,7 @@ def rate_abzugsfähige_altersvorsorgeaufwendungen(

"""
return piecewise_polynomial(
x=evaluation_year,
x=policy_year,
parameters=parameter_einführungsfaktor_altersvorsorgeaufwendungen,
xnp=xnp,
)
Expand Down
4 changes: 2 additions & 2 deletions src/_gettsim/lohnsteuer/einkommen.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def vorsorge_krankenversicherungsbeiträge_option_b_ab_2019(
@param_function(start_date="2005-01-01", end_date="2022-12-31")
def einführungsfaktor_rentenversicherungsaufwendungen(
parameter_einführungsfaktor_rentenversicherungsaufwendungen: PiecewisePolynomialParamValue,
evaluation_year: int,
policy_year: int,
xnp: ModuleType,
) -> dict[str, Any]:
"""Calculate introductory factor for pension expense deductions which depends on the
Expand All @@ -156,7 +156,7 @@ def einführungsfaktor_rentenversicherungsaufwendungen(

"""
return piecewise_polynomial(
x=evaluation_year,
x=policy_year,
parameters=parameter_einführungsfaktor_rentenversicherungsaufwendungen,
xnp=xnp,
)
Expand Down
3 changes: 3 additions & 0 deletions src/ttsim/interface_dag_elements/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class FailIf(MainTargetABC):
non_convertible_objects_in_results_tree: str = (
"fail_if__non_convertible_objects_in_results_tree"
)
param_function_depends_on_column_objects: str = (
"fail_if__param_function_depends_on_column_objects"
)
paths_are_missing_in_targets_tree_mapper: str = (
"fail_if__paths_are_missing_in_targets_tree_mapper"
)
Expand Down
47 changes: 47 additions & 0 deletions src/ttsim/interface_dag_elements/fail_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy
import optree
import pandas as pd
from dags import get_free_arguments

from ttsim.interface_dag_elements.interface_node_objects import fail_function
from ttsim.interface_dag_elements.shared import get_name_of_group_by_id
Expand Down Expand Up @@ -849,3 +850,49 @@ def _remove_note_and_reference(entry: dict[str | int, Any]) -> dict[str | int, A
)

return out


@fail_function()
def param_function_depends_on_column_objects(
specialized_environment__without_tree_logic_and_with_derived_functions: SpecEnvWithoutTreeLogicAndWithDerivedFunctions,
) -> None:
"""Fail if any ParamFunction depends on ColumnObject arguments.

Parameters
----------
specialized_environment__without_tree_logic_and_with_derived_functions
The specialized environment containing all functions and objects.

Raises
------
ValueError
If any ParamFunction has ColumnObject arguments.
"""
param_functions = {
name: obj
for name, obj in specialized_environment__without_tree_logic_and_with_derived_functions.items()
if isinstance(obj, ParamFunction)
}

column_objects = {
name: obj
for name, obj in specialized_environment__without_tree_logic_and_with_derived_functions.items()
if isinstance(obj, ColumnObject)
}

violations = ""
for param_func_name, param_func in param_functions.items():
func_args = set(get_free_arguments(param_func.function))

for arg in func_args:
if arg in column_objects:
violations += f" `{param_func_name}` depends on `{arg}`\n"

if violations:
msg = (
"ParamFunctions must not depend on ColumnObjects. The following "
f"violations were found:\n\n{violations}\n"
"ParamFunctions may only depend on parameters and scalars, not on "
"ColumnObjects."
)
raise ValueError(msg)
2 changes: 1 addition & 1 deletion src/ttsim/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def cached_policy_environment(
policy_date=policy_date,
orig_policy_objects={"root": root},
backend=backend,
include_fail_nodes=False,
include_fail_nodes=True,
include_warn_nodes=False,
)

Expand Down
118 changes: 118 additions & 0 deletions tests/ttsim/interface_dag_elements/test_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
input_df_has_bool_or_numeric_column_names,
input_df_mapper_columns_missing_in_df,
input_df_mapper_has_incorrect_format,
param_function_depends_on_column_objects,
paths_are_missing_in_targets_tree_mapper,
targets_are_not_in_specialized_environment_or_data,
)
Expand Down Expand Up @@ -1413,3 +1414,120 @@ def test_raise_some_error_without_input_data(
backend=backend,
orig_policy_objects={"root": METTSIM_ROOT},
)


@param_function()
def valid_param_function(x: int) -> int:
"""A valid param function that only depends on parameters."""
return x * 2


@param_function()
def invalid_param_function(some_policy_function: int) -> int:
"""An invalid param function that depends on a column object."""
return some_policy_function * 2


@policy_function()
def some_policy_function(x: int) -> int:
"""A policy function for testing."""
return x + 1


@policy_input()
def some_policy_input() -> int:
"""A policy input for testing."""


@pytest.mark.parametrize(
"specialized_environment",
[
# Valid environment with only param functions and no dependencies
{
"valid_param": valid_param_function,
},
# Valid environment with param functions and column objects but no dependencies
{
"valid_param": valid_param_function,
"some_policy_function": some_policy_function,
},
# Valid environment with mixed types but no violations
{
"valid_param": valid_param_function,
"some_policy_function": some_policy_function,
"policy_input": some_policy_input,
"some_scalar": 42,
"some_dict_param": _SOME_DICT_PARAM,
},
],
)
def test_param_function_depends_on_column_objects_passes(specialized_environment):
"""Test that valid environments pass the validation."""
param_function_depends_on_column_objects(specialized_environment)


@pytest.mark.parametrize(
("specialized_environment", "expected_error_match"),
[
(
{
"invalid_param": invalid_param_function,
"some_policy_function": some_policy_function,
},
"`invalid_param` depends on `some_policy_function`",
),
(
{
"invalid_param": invalid_param_function,
"some_policy_function": some_policy_input,
},
"`invalid_param` depends on `some_policy_function`",
),
(
{
"valid_param": valid_param_function,
"invalid_param": invalid_param_function,
"some_policy_function": some_policy_function,
},
"`invalid_param` depends on `some_policy_function`",
),
],
)
def test_param_function_depends_on_column_objects_raises(
specialized_environment, expected_error_match
):
"""Test that invalid environments raise the expected error."""
with pytest.raises(ValueError, match=expected_error_match):
param_function_depends_on_column_objects(specialized_environment)


def test_param_function_depends_on_column_objects_via_main(
backend: Literal["jax", "numpy"],
xnp: ModuleType,
):
"""Test that the param_function_depends_on_column_objects check works via main."""

with pytest.raises(
ValueError,
match="`invalid_param` depends on `some_policy_function`",
):
main(
policy_date_str="2025-01-01",
main_target=MainTarget.results.df_with_mapper,
tt_targets={
"tree": {
"invalid_param": None,
},
},
input_data={
"tree": {
"p_id": xnp.array([1, 2, 3]),
"x": xnp.array([1, 2, 3]),
},
},
backend=backend,
policy_environment={
"invalid_param": invalid_param_function,
"some_policy_function": some_policy_function,
},
)
Loading