Skip to content

Commit 4bb98dc

Browse files
committed
Fix behavior: Make time unit required in time conversion, not in regex.
1 parent fe1cd94 commit 4bb98dc

File tree

5 files changed

+125
-20
lines changed

5 files changed

+125
-20
lines changed

src/ttsim/compute_taxes_and_transfers.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,17 @@ def _get_top_level_namespace(
196196
if match := re_pattern.fullmatch(element):
197197
function_base_name = match.group("base_name")
198198
create_conversions_for_time_units = bool(match.group("time_unit"))
199-
200-
potential_derived_functions = potential_target_names_from_base_name(
201-
base_name=function_base_name,
202-
supported_time_conversions=supported_time_conversions,
203-
supported_groupings=supported_groupings,
204-
create_conversions_for_time_units=create_conversions_for_time_units,
205-
)
206-
potential_function_names.update(potential_derived_functions)
199+
else:
200+
function_base_name = element
201+
create_conversions_for_time_units = False
202+
203+
potential_derived_functions = potential_target_names_from_base_name(
204+
base_name=function_base_name,
205+
supported_time_conversions=supported_time_conversions,
206+
supported_groupings=supported_groupings,
207+
create_conversions_for_time_units=create_conversions_for_time_units,
208+
)
209+
potential_function_names.update(potential_derived_functions)
207210

208211
return potential_function_names
209212

src/ttsim/shared.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ def get_re_pattern_for_time_units_and_groupings(
4242
) -> re.Pattern:
4343
"""Get a regex pattern for time units and groupings.
4444
45-
The pattern is of the form:
46-
<base_name>_<time_unit>_<aggregation>
47-
where <base_name> is optional, <time_unit> is one of the supported time units, and
48-
<aggregation> is one of the supported groupings.
45+
The pattern matches strings in any of these formats:
46+
- <base_name> (can contain underscores)
47+
- <base_name>_<time_unit>
48+
- <base_name>_<aggregation>
49+
- <base_name>_<time_unit>_<aggregation>
4950
5051
Parameters
5152
----------
@@ -60,9 +61,12 @@ def get_re_pattern_for_time_units_and_groupings(
6061
The regex pattern.
6162
"""
6263
units = "".join(supported_time_units)
63-
groupings = "|".join([f"_{grouping}" for grouping in supported_groupings])
64+
groupings = "|".join(supported_groupings)
6465
return re.compile(
65-
f"(?P<base_name>.*_)(?P<time_unit>[{units}])(?P<aggregation>{groupings})?"
66+
f"(?P<base_name>.*?)"
67+
f"(?:_(?P<time_unit>[{units}]))?"
68+
f"(?:_(?P<aggregation>{groupings}))?"
69+
f"$"
6670
)
6771

6872

src/ttsim/time_conversion.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -453,17 +453,22 @@ def _create_time_conversion_functions(
453453
supported_groupings=SUPPORTED_GROUPINGS,
454454
supported_time_units=all_time_units,
455455
)
456+
456457
match = time_unit_pattern.fullmatch(name)
457-
dependencies = set(inspect.signature(func).parameters) if func else set()
458+
base_name = match.group("base_name") or ""
459+
time_unit = match.group("time_unit") or ""
460+
aggregation = match.group("aggregation") or ""
458461

459-
if match:
460-
base_name = match.group("base_name")
461-
time_unit = match.group("time_unit")
462-
aggregation = match.group("aggregation") or ""
462+
dependencies = set(inspect.signature(func).parameters) if func else set()
463463

464+
if match and time_unit:
464465
missing_time_units = [unit for unit in all_time_units if unit != time_unit]
465466
for missing_time_unit in missing_time_units:
466-
new_name = f"{base_name}{missing_time_unit}{aggregation}"
467+
new_name = (
468+
f"{base_name}_{missing_time_unit}{aggregation}"
469+
if aggregation
470+
else f"{base_name}_{missing_time_unit}"
471+
)
467472

468473
# Without this check, we could create cycles in the DAG: Consider a
469474
# hard-coded function `var_y` that takes `var_m` as an input, assuming it

tests/ttsim/test_compute_taxes_and_transfers.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_fail_if_foreign_keys_are_invalid,
1717
_fail_if_group_variables_not_constant_within_groups,
1818
_fail_if_pid_is_non_unique,
19+
_get_top_level_namespace,
1920
_partial_parameters_to_functions,
2021
compute_taxes_and_transfers,
2122
)
@@ -699,3 +700,58 @@ def test_fail_if_cannot_be_converted_to_correct_type(
699700
def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr):
700701
with pytest.raises(TypeError, match=re.escape(err_substr)):
701702
assert_valid_ttsim_pytree(tree, leaf_checker, "tree")
703+
704+
705+
@pytest.mark.parametrize(
706+
(
707+
"environment",
708+
"supported_time_conversions",
709+
"supported_groupings",
710+
"expected",
711+
),
712+
[
713+
(
714+
PolicyEnvironment(
715+
raw_objects_tree={
716+
"foo_m": policy_function(leaf_name="foo_m")(lambda x: x)
717+
},
718+
aggregation_specs_tree={},
719+
),
720+
["m", "y"],
721+
["hh"],
722+
{"foo_m", "foo_y", "foo_m_hh", "foo_y_hh"},
723+
),
724+
(
725+
PolicyEnvironment(
726+
raw_objects_tree={"foo": policy_function(leaf_name="foo")(lambda x: x)},
727+
aggregation_specs_tree={},
728+
),
729+
["m", "y"],
730+
["hh"],
731+
{"foo", "foo_hh"},
732+
),
733+
(
734+
PolicyEnvironment(
735+
raw_objects_tree={},
736+
aggregation_specs_tree={
737+
"foo_hh": AggregateByGroupSpec(
738+
source="foo",
739+
aggr=AggregationType.SUM,
740+
),
741+
},
742+
),
743+
["m", "y"],
744+
["hh"],
745+
{"foo", "foo_hh"},
746+
),
747+
],
748+
)
749+
def test_get_top_level_namespace(
750+
environment, supported_time_conversions, supported_groupings, expected
751+
):
752+
result = _get_top_level_namespace(
753+
environment=environment,
754+
supported_time_conversions=supported_time_conversions,
755+
supported_groupings=supported_groupings,
756+
)
757+
assert result == expected

tests/ttsim/test_shared.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ttsim.shared import (
66
create_tree_from_path_and_value,
77
get_name_of_group_by_id,
8+
get_re_pattern_for_time_units_and_groupings,
89
insert_path_and_value,
910
merge_trees,
1011
partition_tree_by_reference_tree,
@@ -341,3 +342,39 @@ def test_potential_target_names_from_base_name(
341342
)
342343
== expected
343344
)
345+
346+
347+
@pytest.mark.parametrize(
348+
(
349+
"func_name",
350+
"supported_time_units",
351+
"supported_groupings",
352+
"expected_base_name",
353+
"expected_time_unit",
354+
"expected_aggregation",
355+
),
356+
[
357+
("foo", ["m", "y"], ["hh"], "foo", None, None),
358+
("foo_m_hh", ["m", "y"], ["hh"], "foo", "m", "hh"),
359+
("foo_y_hh", ["m", "y"], ["hh"], "foo", "y", "hh"),
360+
("foo_m", ["m", "y"], ["hh"], "foo", "m", None),
361+
("foo_y", ["m", "y"], ["hh"], "foo", "y", None),
362+
("foo_hh", ["m", "y"], ["hh"], "foo", None, "hh"),
363+
],
364+
)
365+
def test_get_re_pattern_for_time_units_and_groupings(
366+
func_name,
367+
supported_time_units,
368+
supported_groupings,
369+
expected_base_name,
370+
expected_time_unit,
371+
expected_aggregation,
372+
):
373+
result = get_re_pattern_for_time_units_and_groupings(
374+
supported_time_units=supported_time_units,
375+
supported_groupings=supported_groupings,
376+
)
377+
match = result.fullmatch(func_name)
378+
assert match.group("base_name") == expected_base_name
379+
assert match.group("time_unit") == expected_time_unit
380+
assert match.group("aggregation") == expected_aggregation

0 commit comments

Comments
 (0)