diff --git a/CHANGES.md b/CHANGES.md index 4a87dedbf..981ac9d9b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,9 @@ releases are available on [Anaconda.org](https://anaconda.org/conda-forge/gettsi ## Unpublished +- {gh}`1050` Allow input template creation when path to root node traverses an + automatically created function ({ghuser}`MImmesberger`) + - {gh}`1048` Raise an error if invalid input paths are provided to main ({ghuser}`hmgaudecker`) diff --git a/src/ttsim/interface_dag_elements/templates.py b/src/ttsim/interface_dag_elements/templates.py index 3e5408df1..0c8c98cc6 100644 --- a/src/ttsim/interface_dag_elements/templates.py +++ b/src/ttsim/interface_dag_elements/templates.py @@ -4,7 +4,11 @@ import dags.tree as dt +from ttsim.interface_dag_elements.automatically_added_functions import TIME_UNIT_LABELS from ttsim.interface_dag_elements.interface_node_objects import interface_function +from ttsim.interface_dag_elements.shared import ( + get_re_pattern_for_all_time_units_and_groupings, +) from ttsim.tt_dag_elements.column_objects_param_function import PolicyInput from ttsim.tt_dag_elements.vectorization import scalar_type_to_array_type @@ -23,8 +27,30 @@ def input_data_dtypes( specialized_environment__with_partialled_params_and_scalars: SpecEnvWithPartialledParamsAndScalars, # noqa: E501 policy_environment: PolicyEnvironment, tt_targets__qname: OrderedQNames, + labels__grouping_levels: OrderedQNames, labels__top_level_namespace: UnorderedQNames, ) -> NestedInputStructureDict: + """ + A template of the required input data and their expected types. + + Parameters + ---------- + specialized_environment__with_partialled_params_and_scalars + The specialized environment with partialled parameters and scalars. + policy_environment + The policy environment containing functions and parameters. + tt_targets__qname + Ordered qualified names of the targets. + labels__grouping_levels + Ordered qualified names of grouping levels. + labels__top_level_namespace + Unordered qualified names of the top-level namespace. + + Returns + ------- + NestedInputStructureDict + A nested dictionary mapping input paths to their data types. + """ base_dtype_tree = dt.create_tree_with_input_types( functions=dt.unflatten_from_qnames( specialized_environment__with_partialled_params_and_scalars, @@ -33,18 +59,56 @@ def input_data_dtypes( top_level_inputs=labels__top_level_namespace, ) - # Replace dtypes of PolicyInputs that have the generic type 'FloatColumn | IntColumn - # | BoolColumn' with the actual dtype found in the policy environment. - flat_policy_env = dt.flatten_to_tree_paths(policy_environment) - flat_dtype_tree = dt.flatten_to_tree_paths(base_dtype_tree) - out = {} - for p, derived_dtype_in_base in flat_dtype_tree.items(): - policy_env_element = flat_policy_env[p] - if p[0] in {"evaluation_year", "evaluation_month", "evaluation_day"}: + qname_policy_env = dt.flatten_to_qnames(policy_environment) + qname_dtype_tree = dt.flatten_to_qnames(base_dtype_tree) + policy_inputs = { + k: v for k, v in qname_policy_env.items() if isinstance(v, PolicyInput) + } + + cleaned_qname_dtype_tree: dict[str, str] = {} + + pattern_all = get_re_pattern_for_all_time_units_and_groupings( + time_units=list(TIME_UNIT_LABELS), + grouping_levels=labels__grouping_levels, + ) + + for qn, derived_dtype_in_base in qname_dtype_tree.items(): + if qn in {"evaluation_year", "evaluation_month", "evaluation_day"}: continue - if isinstance(policy_env_element, PolicyInput) and "|" in derived_dtype_in_base: - out[p] = scalar_type_to_array_type(policy_env_element.data_type) + + match = pattern_all.fullmatch(qn) + base_name = match.group("base_name") + if ( + base_name not in qname_dtype_tree + and base_name not in cleaned_qname_dtype_tree + and base_name in policy_inputs + ): + # If some input data is provided, we create aggregation functions + # automatically only if the source node is part of the input data. Hence, if + # the user provides incomplete input data (i.e. some policy inputs are + # missing) and those policy inputs are sources of automatic aggregation + # functions, dt.create_tree_with_input_types will return the name of the + # aggregation function as root node. The policy input is not in the output. + # We take care of this here. + cleaned_qname_dtype_tree[base_name] = scalar_type_to_array_type( + policy_inputs[base_name].data_type + ) + + # Also add the ID of the grouped variable if grouping exists + grouping = match.group("grouping") + if grouping: + grouping_id = f"{grouping}_id" + if grouping_id not in cleaned_qname_dtype_tree: + cleaned_qname_dtype_tree[grouping_id] = "IntColumn" + + elif qn in policy_inputs: + # Replace dtypes of PolicyInputs that have the generic type 'FloatColumn | + # IntColumn | BoolColumn' with the actual dtype found in the policy + # environment. + cleaned_qname_dtype_tree[qn] = scalar_type_to_array_type( + policy_inputs[qn].data_type + ) else: - out[p] = derived_dtype_in_base + cleaned_qname_dtype_tree[qn] = derived_dtype_in_base - return dt.unflatten_from_tree_paths(out) + return dt.unflatten_from_qnames(cleaned_qname_dtype_tree) diff --git a/tests/ttsim/interface_dag_elements/test_templates.py b/tests/ttsim/interface_dag_elements/test_templates.py index c7b6c3acb..bb61ab488 100644 --- a/tests/ttsim/interface_dag_elements/test_templates.py +++ b/tests/ttsim/interface_dag_elements/test_templates.py @@ -52,6 +52,11 @@ ) +@policy_input() +def kin_id() -> int: + pass + + @policy_input() def inp1() -> int: pass @@ -63,8 +68,8 @@ def inp2() -> float: @policy_function() -def x(inp1: int, p1: int, p2: dict[str, int]) -> int: - return inp1 + p1 + p2["a"] + p2["b"] +def x(inp1_kin: int, p1: int, p2: dict[str, int]) -> int: + return inp1_kin + p1 + p2["a"] + p2["b"] @policy_function() @@ -81,6 +86,7 @@ def test_template_all_outputs_no_inputs(backend): actual = main( main_target="templates__input_data_dtypes", policy_environment={ + "kin_id": kin_id, "inp1": inp1, "p1": p1, "a": {"inp2": inp2, "x": x, "y": y, "p2": p2}, @@ -92,7 +98,11 @@ def test_template_all_outputs_no_inputs(backend): evaluation_date_str="2025-01-01", backend=backend, ) - assert actual == {"a": {"inp2": "FloatColumn"}, "inp1": "IntColumn"} + assert actual == { + "a": {"inp2": "FloatColumn"}, + "inp1": "IntColumn", + "kin_id": "IntColumn", + } def test_template_all_outputs_with_inputs(backend, xnp): @@ -108,6 +118,7 @@ def test_template_all_outputs_with_inputs(backend, xnp): } }, policy_environment={ + "kin_id": kin_id, "inp1": inp1, "p1": p1, "a": {"inp2": inp2, "x": x, "y": y, "p2": p2}, @@ -119,7 +130,11 @@ def test_template_all_outputs_with_inputs(backend, xnp): evaluation_date_str="2025-01-01", backend=backend, ) - assert actual == {"a": {"inp2": "FloatColumn"}, "inp1": "IntColumn"} + assert actual == { + "a": {"inp2": "FloatColumn"}, + "inp1": "IntColumn", + "kin_id": "IntColumn", + } def test_template_output_y_no_inputs(backend): @@ -127,6 +142,7 @@ def test_template_output_y_no_inputs(backend): main_target="templates__input_data_dtypes", tt_targets={"tree": {"a": {"y": None}}}, policy_environment={ + "kin_id": kin_id, "inp1": inp1, "p1": p1, "a": {"inp2": inp2, "x": x, "y": y, "p2": p2}, @@ -155,6 +171,34 @@ def test_template_output_x_with_inputs(backend, xnp): }, tt_targets={"tree": {"a": {"x": None}}}, policy_environment={ + "kin_id": kin_id, + "inp1": inp1, + "p1": p1, + "a": {"inp2": inp2, "x": x, "y": y, "p2": p2}, + "b": { + "z": z, + }, + }, + rounding=True, + evaluation_date_str="2025-01-01", + backend=backend, + ) + assert actual == {"inp1": "IntColumn", "kin_id": "IntColumn"} + + +def test_template_all_outputs_no_input_for_root_of_derived_function(backend, xnp): + actual = main( + main_target="templates__input_data_dtypes", + input_data={ + "tree": { + "p_id": xnp.array([4, 5, 6]), + "a": { + "inp2": xnp.array([1, 2, 3]), + }, + } + }, + policy_environment={ + "kin_id": kin_id, "inp1": inp1, "p1": p1, "a": {"inp2": inp2, "x": x, "y": y, "p2": p2}, @@ -166,4 +210,8 @@ def test_template_output_x_with_inputs(backend, xnp): evaluation_date_str="2025-01-01", backend=backend, ) - assert actual == {"inp1": "IntColumn"} + assert actual == { + "a": {"inp2": "FloatColumn"}, + "inp1": "IntColumn", + "kin_id": "IntColumn", + }