Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 76 additions & 12 deletions src/ttsim/interface_dag_elements/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

import dags.tree as dt

from ttsim.interface_dag_elements.automatically_added_functions import TIME_UNIT_LABELS
from ttsim.interface_dag_elements.interface_node_objects import interface_function
from ttsim.interface_dag_elements.shared import (
get_re_pattern_for_all_time_units_and_groupings,
)
from ttsim.tt_dag_elements.column_objects_param_function import PolicyInput
from ttsim.tt_dag_elements.vectorization import scalar_type_to_array_type

Expand All @@ -23,8 +27,30 @@ def input_data_dtypes(
specialized_environment__with_partialled_params_and_scalars: SpecEnvWithPartialledParamsAndScalars, # noqa: E501
policy_environment: PolicyEnvironment,
tt_targets__qname: OrderedQNames,
labels__grouping_levels: OrderedQNames,
labels__top_level_namespace: UnorderedQNames,
) -> NestedInputStructureDict:
"""
A template of the required input data and their expected types.

Parameters
----------
specialized_environment__with_partialled_params_and_scalars
The specialized environment with partialled parameters and scalars.
policy_environment
The policy environment containing functions and parameters.
tt_targets__qname
Ordered qualified names of the targets.
labels__grouping_levels
Ordered qualified names of grouping levels.
labels__top_level_namespace
Unordered qualified names of the top-level namespace.

Returns
-------
NestedInputStructureDict
A nested dictionary mapping input paths to their data types.
"""
base_dtype_tree = dt.create_tree_with_input_types(
functions=dt.unflatten_from_qnames(
specialized_environment__with_partialled_params_and_scalars,
Expand All @@ -33,18 +59,56 @@ def input_data_dtypes(
top_level_inputs=labels__top_level_namespace,
)

# Replace dtypes of PolicyInputs that have the generic type 'FloatColumn | IntColumn
# | BoolColumn' with the actual dtype found in the policy environment.
flat_policy_env = dt.flatten_to_tree_paths(policy_environment)
flat_dtype_tree = dt.flatten_to_tree_paths(base_dtype_tree)
out = {}
for p, derived_dtype_in_base in flat_dtype_tree.items():
policy_env_element = flat_policy_env[p]
if p[0] in {"evaluation_year", "evaluation_month", "evaluation_day"}:
qname_policy_env = dt.flatten_to_qnames(policy_environment)
qname_dtype_tree = dt.flatten_to_qnames(base_dtype_tree)
policy_inputs = {
k: v for k, v in qname_policy_env.items() if isinstance(v, PolicyInput)
}

cleaned_qname_dtype_tree: dict[str, str] = {}

pattern_all = get_re_pattern_for_all_time_units_and_groupings(
time_units=list(TIME_UNIT_LABELS),
grouping_levels=labels__grouping_levels,
)

for qn, derived_dtype_in_base in qname_dtype_tree.items():
if qn in {"evaluation_year", "evaluation_month", "evaluation_day"}:
continue
if isinstance(policy_env_element, PolicyInput) and "|" in derived_dtype_in_base:
out[p] = scalar_type_to_array_type(policy_env_element.data_type)

match = pattern_all.fullmatch(qn)
base_name = match.group("base_name")
if (
base_name not in qname_dtype_tree
and base_name not in cleaned_qname_dtype_tree
and base_name in policy_inputs
):
# If some input data is provided, we create aggregation functions
# automatically only if the source node is part of the input data. Hence, if
# the user provides incomplete input data (i.e. some policy inputs are
# missing) and those policy inputs are sources of automatic aggregation
# functions, dt.create_tree_with_input_types will return the name of the
# aggregation function as root node. The policy input is not in the output.
# We take care of this here.
cleaned_qname_dtype_tree[base_name] = scalar_type_to_array_type(
policy_inputs[base_name].data_type
)

# Also add the ID of the grouped variable if grouping exists
grouping = match.group("grouping")
if grouping:
grouping_id = f"{grouping}_id"
if grouping_id not in cleaned_qname_dtype_tree:
cleaned_qname_dtype_tree[grouping_id] = "IntColumn"

elif qn in policy_inputs:
# Replace dtypes of PolicyInputs that have the generic type 'FloatColumn |
# IntColumn | BoolColumn' with the actual dtype found in the policy
# environment.
cleaned_qname_dtype_tree[qn] = scalar_type_to_array_type(
policy_inputs[qn].data_type
)
else:
out[p] = derived_dtype_in_base
cleaned_qname_dtype_tree[qn] = derived_dtype_in_base

return dt.unflatten_from_tree_paths(out)
return dt.unflatten_from_qnames(cleaned_qname_dtype_tree)
58 changes: 53 additions & 5 deletions tests/ttsim/interface_dag_elements/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@
)


@policy_input()
def kin_id() -> int:
pass


@policy_input()
def inp1() -> int:
pass
Expand All @@ -63,8 +68,8 @@ def inp2() -> float:


@policy_function()
def x(inp1: int, p1: int, p2: dict[str, int]) -> int:
return inp1 + p1 + p2["a"] + p2["b"]
def x(inp1_kin: int, p1: int, p2: dict[str, int]) -> int:
return inp1_kin + p1 + p2["a"] + p2["b"]


@policy_function()
Expand All @@ -81,6 +86,7 @@ def test_template_all_outputs_no_inputs(backend):
actual = main(
main_target="templates__input_data_dtypes",
policy_environment={
"kin_id": kin_id,
"inp1": inp1,
"p1": p1,
"a": {"inp2": inp2, "x": x, "y": y, "p2": p2},
Expand All @@ -92,7 +98,11 @@ def test_template_all_outputs_no_inputs(backend):
evaluation_date_str="2025-01-01",
backend=backend,
)
assert actual == {"a": {"inp2": "FloatColumn"}, "inp1": "IntColumn"}
assert actual == {
"a": {"inp2": "FloatColumn"},
"inp1": "IntColumn",
"kin_id": "IntColumn",
}


def test_template_all_outputs_with_inputs(backend, xnp):
Expand All @@ -108,6 +118,7 @@ def test_template_all_outputs_with_inputs(backend, xnp):
}
},
policy_environment={
"kin_id": kin_id,
"inp1": inp1,
"p1": p1,
"a": {"inp2": inp2, "x": x, "y": y, "p2": p2},
Expand All @@ -119,14 +130,19 @@ def test_template_all_outputs_with_inputs(backend, xnp):
evaluation_date_str="2025-01-01",
backend=backend,
)
assert actual == {"a": {"inp2": "FloatColumn"}, "inp1": "IntColumn"}
assert actual == {
"a": {"inp2": "FloatColumn"},
"inp1": "IntColumn",
"kin_id": "IntColumn",
}


def test_template_output_y_no_inputs(backend):
actual = main(
main_target="templates__input_data_dtypes",
tt_targets={"tree": {"a": {"y": None}}},
policy_environment={
"kin_id": kin_id,
"inp1": inp1,
"p1": p1,
"a": {"inp2": inp2, "x": x, "y": y, "p2": p2},
Expand Down Expand Up @@ -155,6 +171,34 @@ def test_template_output_x_with_inputs(backend, xnp):
},
tt_targets={"tree": {"a": {"x": None}}},
policy_environment={
"kin_id": kin_id,
"inp1": inp1,
"p1": p1,
"a": {"inp2": inp2, "x": x, "y": y, "p2": p2},
"b": {
"z": z,
},
},
rounding=True,
evaluation_date_str="2025-01-01",
backend=backend,
)
assert actual == {"inp1": "IntColumn", "kin_id": "IntColumn"}


def test_template_all_outputs_no_input_for_root_of_derived_function(backend, xnp):
actual = main(
main_target="templates__input_data_dtypes",
input_data={
"tree": {
"p_id": xnp.array([4, 5, 6]),
"a": {
"inp2": xnp.array([1, 2, 3]),
},
}
},
policy_environment={
"kin_id": kin_id,
"inp1": inp1,
"p1": p1,
"a": {"inp2": inp2, "x": x, "y": y, "p2": p2},
Expand All @@ -166,4 +210,8 @@ def test_template_output_x_with_inputs(backend, xnp):
evaluation_date_str="2025-01-01",
backend=backend,
)
assert actual == {"inp1": "IntColumn"}
assert actual == {
"a": {"inp2": "FloatColumn"},
"inp1": "IntColumn",
"kin_id": "IntColumn",
}
Loading