Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/ttsim/interface_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def main(
"""
Main function that processes the inputs and returns the outputs.
"""

input_qnames = _harmonize_inputs(locals())
if main_target is not None:
if main_targets is not None:
Expand Down
5 changes: 4 additions & 1 deletion src/ttsim/interface_dag_elements/input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,15 @@ def flat_from_df_with_nested_columns(
include_if_all_inputs_present=["input_data__tree"],
leaf_name="flat",
)
def flat_from_tree(tree: NestedData) -> FlatData:
def flat_from_tree(tree: NestedData, xnp: ModuleType) -> FlatData: # noqa: ARG001
"""The input DataFrame as a flattened data structure.

Args:
tree:
The input tree.
xnp:
The backend to use, just put here so that fail_if.input_data_tree_is_invalid
runs before this.

Returns
-------
Expand Down
231 changes: 204 additions & 27 deletions tests/ttsim/interface_dag_elements/test_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_ParamWithActivePeriod,
active_periods_overlap,
assert_valid_ttsim_pytree,
environment_is_invalid,
foreign_keys_are_invalid_in_data,
group_ids_are_outside_top_level_namespace,
group_variables_are_not_constant_within_groups,
Expand Down Expand Up @@ -649,32 +650,6 @@ def test_fail_if_group_variables_are_not_constant_within_groups():
)


def test_fail_if_p_id_is_missing(xnp):
data = {("fam_id",): xnp.array([1, 2, 3])}

with pytest.raises(
ValueError,
match="The input data must contain the `p_id` column.",
):
input_data_is_invalid(data)


def test_fail_if_p_id_is_missing_via_main(backend):
data = {"fam_id": pd.Series([1, 2, 3], name="fam_id")}
with pytest.raises(
ValueError,
match="The input data must contain the `p_id` column.",
):
main(
main_target="fail_if__input_data_is_invalid",
input_data={"tree": data},
policy_environment={},
tt_targets={"tree": {}},
rounding=False,
backend=backend,
)


@pytest.mark.parametrize(
"df",
[
Expand Down Expand Up @@ -831,7 +806,7 @@ def test_fail_if_p_id_does_not_exist(xnp):
input_data_is_invalid(data)


def test_fail_if_p_id_does_not_exist_via_main(backend):
def test_fail_if_p_id_is_missing_via_main(backend):
data = {"fam_id": pd.Series([1, 2, 3], name="fam_id")}
with pytest.raises(
ValueError,
Expand All @@ -842,6 +817,7 @@ def test_fail_if_p_id_does_not_exist_via_main(backend):
input_data={"tree": data},
policy_environment={},
tt_targets={"tree": {}},
date=datetime.date(2025, 1, 1),
rounding=False,
backend=backend,
)
Expand Down Expand Up @@ -870,6 +846,7 @@ def test_fail_if_p_id_is_not_unique_via_main(minimal_input_data, backend):
input_data={"tree": data},
policy_environment={},
tt_targets={"tree": {}},
date=datetime.date(2025, 1, 1),
rounding=False,
backend=backend,
)
Expand Down Expand Up @@ -920,6 +897,7 @@ def test_fail_if_input_data_has_different_lengths(backend):
input_data={"tree": data},
policy_environment={},
tt_targets={"tree": {}},
date=datetime.date(2025, 1, 1),
rounding=False,
backend=backend,
)
Expand Down Expand Up @@ -1028,6 +1006,7 @@ def test_fail_if_targets_are_not_in_specialized_environment_or_data_via_main(
input_data={"tree": minimal_input_data},
policy_environment={},
tt_targets={"tree": {"unknown_target": None}},
date=datetime.date(2025, 1, 1),
rounding=False,
backend=backend,
)
Expand Down Expand Up @@ -1188,3 +1167,201 @@ def test_fail_if_input_df_mapper_columns_missing_in_df_via_main(
date_str="2025-01-01",
backend=backend,
)


@pytest.mark.parametrize(
(
"tt_targets__tree",
"match",
),
[
(
{
1: None,
"number_of_individuals_kin": None,
},
"Key 1 in tt_targets__tree must be a string but",
),
(
{
"number_of_individuals_kin": 1,
},
r"Leaf at tt_targets__tree\[number_of_individuals_kin\] is invalid",
),
(
["number_of_individuals_kin"],
"tt_targets__tree must be a dict, got",
),
(
"number_of_individuals_kin",
"tt_targets__tree must be a dict, got",
),
],
)
def test_invalid_tt_targets_tree(
tt_targets__tree,
match,
backend: Literal["jax", "numpy"],
xnp: ModuleType,
minimal_data_tree,
):
with pytest.raises(TypeError, match=match):
main(
main_target=MainTarget.results.df_with_nested_columns,
backend=backend,
input_data=InputData.tree(
tree={
**minimal_data_tree,
"kin_id": xnp.array([0, 1, 2]),
}
),
orig_policy_objects={"root": METTSIM_ROOT},
date_str="2025-01-01",
tt_targets={"tree": tt_targets__tree},
)


@pytest.mark.parametrize(
(
"input_data_tree",
"match",
),
[
(
{
"number_of_individuals_kin": [1],
},
r"Leaf at input_data__tree\[number_of_individuals_kin\] is invalid",
),
(
{"number_of_individuals_kin": "1"},
r"Leaf at input_data__tree\[number_of_individuals_kin\] is invalid",
),
],
)
def test_invalid_input_data_tree_via_main(
input_data_tree, match, backend: Literal["jax", "numpy"], xnp: ModuleType
):
input_data_tree_with_p_id = {
**input_data_tree,
"p_id": xnp.array([2]),
}
with pytest.raises(TypeError, match=match):
main(
main_target=MainTarget.results.df_with_nested_columns,
backend=backend,
input_data=InputData.tree(tree=input_data_tree_with_p_id),
orig_policy_objects={"root": METTSIM_ROOT},
date_str="2025-01-01",
)


@pytest.mark.parametrize(
(
"policy_environment",
"match",
),
[
(
{
"invalid_leaf": 42,
},
r"Leaf at policy_environment\[invalid_leaf\] is invalid",
),
(
{
"nested": {
"invalid_leaf": "not_allowed_string",
},
},
r"Leaf at policy_environment\[nested\]\[invalid_leaf\] is invalid",
),
(
{
"nested": {
"another_invalid": [1, 2, 3],
},
},
r"Leaf at policy_environment\[nested\]\[another_invalid\] is invalid",
),
(
{
"nested": {
"yet_another": {"dict": "not_allowed"},
},
},
r"Leaf at policy_environment\[nested\]\[yet_another\]\[dict\] is invalid",
),
(
{
1: "valid_string",
},
"Key 1 in policy_environment must be a string but",
),
(
["not_a_dict"],
"policy_environment must be a dict, got",
),
],
)
def test_fail_if_environment_is_invalid(policy_environment, match):
with pytest.raises(TypeError, match=match):
environment_is_invalid(policy_environment)


@pytest.mark.parametrize(
"policy_environment",
[
# Valid environment with policy functions
{
"valid_func": policy_function(leaf_name="valid_func")(identity),
"another_func": policy_function(leaf_name="another_func")(return_one),
},
# Valid environment with param functions
{
"valid_param": some_param_func_returning_array_of_length_2,
"another_param": some_param_func_returning_list_of_length_2,
},
# Valid environment with param objects
{
"valid_param_obj": _SOME_DICT_PARAM,
"another_param_obj": _SOME_PIECEWISE_POLYNOMIAL_PARAM,
},
# Valid environment with module types
{
"numpy_module": numpy,
"jax_string": "jax",
"numpy_string": "numpy",
},
# Valid environment with nested structure
{
"nested": {
"valid_func": policy_function(leaf_name="nested_func")(identity),
"valid_param": some_param_func_returning_array_of_length_2,
},
"top_level": _SOME_DICT_PARAM,
},
# Valid environment with mixed types
{
"func": policy_function(leaf_name="func")(identity),
"param": some_param_func_returning_array_of_length_2,
"param_obj": _SOME_DICT_PARAM,
"module": numpy,
"backend": "jax",
},
],
)
def test_environment_is_invalid_passes(policy_environment):
"""Test that valid environments pass the validation."""
environment_is_invalid(policy_environment)


def test_invalid_input_data_as_object_via_main(backend: Literal["jax", "numpy"]):
with pytest.raises(TypeError, match="input_data__tree must be a dict, got"):
main(
main_target=MainTarget.results.df_with_nested_columns,
backend=backend,
input_data=InputData.tree(tree=object()),
orig_policy_objects={"root": METTSIM_ROOT},
date_str="2025-01-01",
)
3 changes: 3 additions & 0 deletions tests/ttsim/interface_dag_elements/test_warnings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import warnings

import pandas as pd
Expand Down Expand Up @@ -35,6 +36,7 @@ def test_warn_if_functions_and_data_columns_overlap(backend):
"some_target": another_func,
},
tt_targets={"tree": {"some_target": None}},
date=datetime.date(2025, 1, 1),
rounding=False,
include_fail_nodes=False,
backend=backend,
Expand All @@ -61,6 +63,7 @@ def test_warn_if_functions_and_columns_overlap_no_warning_if_no_overlap(backend)
},
policy_environment={"some_func": some_func},
tt_targets={"tree": {"some_func": None}},
date=datetime.date(2025, 1, 1),
rounding=False,
include_fail_nodes=False,
backend=backend,
Expand Down
Loading