|
13 | 13 | from _gettsim.config import ( |
14 | 14 | DEFAULT_TARGETS, |
15 | 15 | FOREIGN_KEYS, |
| 16 | + SUPPORTED_GROUPINGS, |
16 | 17 | ) |
17 | 18 | from ttsim.combine_functions import ( |
18 | 19 | combine_policy_functions_and_derived_functions, |
|
26 | 27 | ) |
27 | 28 | from ttsim.policy_environment import PolicyEnvironment |
28 | 29 | from ttsim.shared import ( |
| 30 | + all_variations_of_base_name, |
29 | 31 | assert_valid_ttsim_pytree, |
30 | 32 | format_errors_and_warnings, |
31 | 33 | format_list_linewise, |
|
34 | 36 | get_re_pattern_for_time_units_and_groupings, |
35 | 37 | merge_trees, |
36 | 38 | partition_by_reference_dict, |
37 | | - potential_target_names_from_base_name, |
38 | 39 | ) |
| 40 | +from ttsim.time_conversion import _TIME_UNITS |
39 | 41 | from ttsim.typing import ( |
40 | 42 | check_series_has_expected_type, |
41 | 43 | convert_series_to_internal_type, |
@@ -90,7 +92,11 @@ def compute_taxes_and_transfers( |
90 | 92 | _fail_if_environment_not_valid(environment) |
91 | 93 |
|
92 | 94 | # Transform functions tree to qualified names dict with qualified arguments |
93 | | - top_level_namespace = _get_top_level_namespace(environment) |
| 95 | + top_level_namespace = _get_top_level_namespace( |
| 96 | + environment=environment, |
| 97 | + supported_time_conversions=list(_TIME_UNITS.keys()), |
| 98 | + supported_groupings=list(SUPPORTED_GROUPINGS.keys()), |
| 99 | + ) |
94 | 100 | functions = dt.functions_without_tree_logic( |
95 | 101 | functions=environment.functions_tree, top_level_namespace=top_level_namespace |
96 | 102 | ) |
@@ -183,32 +189,29 @@ def _get_top_level_namespace( |
183 | 189 | top_level_namespace: |
184 | 190 | The top level namespace. |
185 | 191 | """ |
186 | | - names_from_environment = set(environment.raw_objects_tree.keys()) | set( |
| 192 | + direct_top_level_names = set(environment.raw_objects_tree.keys()) | set( |
187 | 193 | environment.aggregation_specs_tree.keys() |
188 | 194 | ) |
189 | | - potential_function_names = set() |
| 195 | + all_top_level_names = set() |
190 | 196 | re_pattern = get_re_pattern_for_time_units_and_groupings( |
191 | 197 | supported_groupings=supported_groupings, |
192 | 198 | supported_time_units=supported_time_conversions, |
193 | 199 | ) |
194 | 200 |
|
195 | | - for element in names_from_environment: |
196 | | - if match := re_pattern.fullmatch(element): |
197 | | - function_base_name = match.group("base_name") |
198 | | - create_conversions_for_time_units = bool(match.group("time_unit")) |
199 | | - else: |
200 | | - function_base_name = element |
201 | | - create_conversions_for_time_units = False |
| 201 | + for name in direct_top_level_names: |
| 202 | + match = re_pattern.fullmatch(name) |
| 203 | + function_base_name = match.group("base_name") |
| 204 | + create_conversions_for_time_units = bool(match.group("time_unit")) |
202 | 205 |
|
203 | | - potential_derived_functions = potential_target_names_from_base_name( |
| 206 | + all_top_level_names_for_name = all_variations_of_base_name( |
204 | 207 | base_name=function_base_name, |
205 | 208 | supported_time_conversions=supported_time_conversions, |
206 | 209 | supported_groupings=supported_groupings, |
207 | 210 | create_conversions_for_time_units=create_conversions_for_time_units, |
208 | 211 | ) |
209 | | - potential_function_names.update(potential_derived_functions) |
| 212 | + all_top_level_names.update(all_top_level_names_for_name) |
210 | 213 |
|
211 | | - return potential_function_names |
| 214 | + return all_top_level_names |
212 | 215 |
|
213 | 216 |
|
214 | 217 | def _convert_data_to_correct_types( |
|
0 commit comments