Skip to content

Commit fa84709

Browse files
Make top-level namespace include potential derived functions (#865)
### What problem do you want to solve? The top-level namespace did not account for derived functions that are created later. We now add place-holders when creating the top-level namespace for those derived functions. --------- Co-authored-by: Hans-Martin von Gaudecker <[email protected]>
1 parent 3468be9 commit fa84709

File tree

7 files changed

+416
-80
lines changed

7 files changed

+416
-80
lines changed

src/ttsim/compute_taxes_and_transfers.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from _gettsim.config import (
1414
DEFAULT_TARGETS,
1515
FOREIGN_KEYS,
16+
SUPPORTED_GROUPINGS,
1617
)
1718
from ttsim.combine_functions import (
1819
combine_policy_functions_and_derived_functions,
@@ -26,14 +27,17 @@
2627
)
2728
from ttsim.policy_environment import PolicyEnvironment
2829
from ttsim.shared import (
30+
all_variations_of_base_name,
2931
assert_valid_ttsim_pytree,
3032
format_errors_and_warnings,
3133
format_list_linewise,
3234
get_name_of_group_by_id,
3335
get_names_of_arguments_without_defaults,
36+
get_re_pattern_for_all_time_units_and_groupings,
3437
merge_trees,
3538
partition_by_reference_dict,
3639
)
40+
from ttsim.time_conversion import TIME_UNITS
3741
from ttsim.typing import (
3842
check_series_has_expected_type,
3943
convert_series_to_internal_type,
@@ -45,6 +49,7 @@
4549
NestedTargetDict,
4650
QualNameDataDict,
4751
QualNameTargetList,
52+
QualNameTTSIMFunctionDict,
4853
QualNameTTSIMObjectDict,
4954
)
5055

@@ -88,8 +93,10 @@ def compute_taxes_and_transfers(
8893
_fail_if_environment_not_valid(environment)
8994

9095
# Transform functions tree to qualified names dict with qualified arguments
91-
top_level_namespace = set(environment.raw_objects_tree.keys()) | set(
92-
environment.aggregation_specs_tree.keys()
96+
top_level_namespace = _get_top_level_namespace(
97+
environment=environment,
98+
supported_time_conversions=tuple(TIME_UNITS.keys()),
99+
supported_groupings=tuple(SUPPORTED_GROUPINGS.keys()),
93100
)
94101
functions = dt.functions_without_tree_logic(
95102
functions=environment.functions_tree, top_level_namespace=top_level_namespace
@@ -131,7 +138,7 @@ def compute_taxes_and_transfers(
131138

132139
# Remove unnecessary elements from user-provided data.
133140
input_data = _create_input_data_for_concatenated_function(
134-
data=data_with_correct_types,
141+
data=data,
135142
functions=functions_with_partialled_parameters,
136143
targets=targets,
137144
)
@@ -166,6 +173,48 @@ def compute_taxes_and_transfers(
166173
return result_tree
167174

168175

176+
def _get_top_level_namespace(
177+
environment: PolicyEnvironment,
178+
supported_time_conversions: tuple[str, ...],
179+
supported_groupings: tuple[str, ...],
180+
) -> set[str]:
181+
"""Get the top level namespace.
182+
183+
Parameters
184+
----------
185+
environment:
186+
The policy environment.
187+
188+
Returns
189+
-------
190+
top_level_namespace:
191+
The top level namespace.
192+
"""
193+
direct_top_level_names = set(environment.raw_objects_tree.keys()) | set(
194+
environment.aggregation_specs_tree.keys()
195+
)
196+
re_pattern = get_re_pattern_for_all_time_units_and_groupings(
197+
supported_groupings=supported_groupings,
198+
supported_time_units=supported_time_conversions,
199+
)
200+
201+
all_top_level_names = set()
202+
for name in direct_top_level_names:
203+
match = re_pattern.fullmatch(name)
204+
base_name = match.group("base_name")
205+
create_conversions_for_time_units = bool(match.group("time_unit"))
206+
207+
all_top_level_names_for_name = all_variations_of_base_name(
208+
base_name=base_name,
209+
supported_time_conversions=supported_time_conversions,
210+
supported_groupings=supported_groupings,
211+
create_conversions_for_time_units=create_conversions_for_time_units,
212+
)
213+
all_top_level_names.update(all_top_level_names_for_name)
214+
215+
return all_top_level_names
216+
217+
169218
def _convert_data_to_correct_types(
170219
data: QualNameDataDict, functions_overridden: QualNameTTSIMObjectDict
171220
) -> QualNameDataDict:

src/ttsim/function_types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,6 @@ class DerivedTimeConversionFunction(TTSIMFunction):
379379
def __post_init__(self):
380380
if self.source is None:
381381
raise ValueError("The source must be specified.")
382-
if self.conversion_target is None:
383-
raise ValueError("The conversion target must be specified.")
384382

385383

386384
def _convert_and_validate_dates(

src/ttsim/shared.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import inspect
4+
import itertools
45
import re
56
import textwrap
67
from typing import TYPE_CHECKING, Any, TypeVar
@@ -36,6 +37,132 @@ def validate_date_range(start: datetime.date, end: datetime.date):
3637
raise ValueError(f"The start date {start} must be before the end date {end}.")
3738

3839

40+
def get_re_pattern_for_all_time_units_and_groupings(
41+
supported_groupings: tuple[str, ...], supported_time_units: tuple[str, ...]
42+
) -> re.Pattern:
43+
"""Get a regex pattern for time units and groupings.
44+
45+
The pattern matches strings in any of these formats:
46+
- <base_name> (may contain underscores)
47+
- <base_name>_<time_unit>
48+
- <base_name>_<aggregation>
49+
- <base_name>_<time_unit>_<aggregation>
50+
51+
Parameters
52+
----------
53+
supported_groupings
54+
The supported groupings.
55+
supported_time_units
56+
The supported time units.
57+
58+
Returns
59+
-------
60+
pattern
61+
The regex pattern.
62+
"""
63+
units = "".join(supported_time_units)
64+
groupings = "|".join(supported_groupings)
65+
return re.compile(
66+
f"(?P<base_name>.*?)"
67+
f"(?:_(?P<time_unit>[{units}]))?"
68+
f"(?:_(?P<aggregation>{groupings}))?"
69+
f"$"
70+
)
71+
72+
73+
def get_re_pattern_for_specific_time_units_and_groupings(
74+
base_name: str,
75+
supported_time_units: tuple[str, ...],
76+
supported_groupings: tuple[str, ...],
77+
) -> re.Pattern:
78+
"""Get a regex for a specific base name with optional time unit and aggregation.
79+
80+
The pattern matches strings in any of these formats:
81+
- <specific_base_name>
82+
- <specific_base_name>_<time_unit>
83+
- <specific_base_name>_<aggregation>
84+
- <specific_base_name>_<time_unit>_<aggregation>
85+
86+
Parameters
87+
----------
88+
base_name
89+
The specific base name to match.
90+
supported_time_units
91+
The supported time units.
92+
supported_groupings
93+
The supported groupings.
94+
95+
Returns
96+
-------
97+
pattern
98+
The regex pattern.
99+
"""
100+
units = "".join(supported_time_units)
101+
groupings = "|".join(supported_groupings)
102+
return re.compile(
103+
f"(?P<base_name>{re.escape(base_name)})"
104+
f"(?:_(?P<time_unit>[{units}]))?"
105+
f"(?:_(?P<aggregation>{groupings}))?"
106+
f"$"
107+
)
108+
109+
110+
def all_variations_of_base_name(
111+
base_name: str,
112+
supported_time_conversions: list[str],
113+
supported_groupings: list[str],
114+
create_conversions_for_time_units: bool,
115+
) -> set[str]:
116+
"""Get possible derived function names given a base function name.
117+
118+
Examples
119+
--------
120+
>>> all_variations_of_base_name(
121+
base_name="income",
122+
supported_time_conversions=["y", "m"],
123+
supported_groupings=["hh"],
124+
create_conversions_for_time_units=True,
125+
)
126+
{'income_m', 'income_y', 'income_hh_y', 'income_hh_m'}
127+
128+
>>> all_variations_of_base_name(
129+
base_name="claims_benefits",
130+
supported_time_conversions=["y", "m"],
131+
supported_groupings=["hh"],
132+
create_conversions_for_time_units=False,
133+
)
134+
{'claims_benefits_hh'}
135+
136+
Parameters
137+
----------
138+
base_name
139+
The base function name.
140+
supported_time_conversions
141+
The supported time conversions.
142+
supported_groupings
143+
The supported groupings.
144+
create_conversions_for_time_units
145+
Whether to create conversions for time units.
146+
147+
Returns
148+
-------
149+
The names of all potential targets based on the base name.
150+
"""
151+
result = set()
152+
if create_conversions_for_time_units:
153+
for time_unit in supported_time_conversions:
154+
result.add(f"{base_name}_{time_unit}")
155+
for time_unit, aggregation in itertools.product(
156+
supported_time_conversions, supported_groupings
157+
):
158+
result.add(f"{base_name}_{time_unit}_{aggregation}")
159+
else:
160+
result.add(base_name)
161+
for aggregation in supported_groupings:
162+
result.add(f"{base_name}_{aggregation}")
163+
return result
164+
165+
39166
class KeyErrorMessage(str):
40167
"""Subclass str to allow for line breaks in KeyError messages."""
41168

0 commit comments

Comments
 (0)