Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions src/ttsim/compute_taxes_and_transfers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from _gettsim.config import (
DEFAULT_TARGETS,
FOREIGN_KEYS,
SUPPORTED_GROUPINGS,
)
from ttsim.combine_functions import (
combine_policy_functions_and_derived_functions,
Expand All @@ -26,14 +27,17 @@
)
from ttsim.policy_environment import PolicyEnvironment
from ttsim.shared import (
all_variations_of_base_name,
assert_valid_ttsim_pytree,
format_errors_and_warnings,
format_list_linewise,
get_name_of_group_by_id,
get_names_of_arguments_without_defaults,
get_re_pattern_for_all_time_units_and_groupings,
merge_trees,
partition_by_reference_dict,
)
from ttsim.time_conversion import TIME_UNITS
from ttsim.typing import (
check_series_has_expected_type,
convert_series_to_internal_type,
Expand All @@ -45,6 +49,7 @@
NestedTargetDict,
QualNameDataDict,
QualNameTargetList,
QualNameTTSIMFunctionDict,
QualNameTTSIMObjectDict,
)

Expand Down Expand Up @@ -88,8 +93,10 @@ def compute_taxes_and_transfers(
_fail_if_environment_not_valid(environment)

# Transform functions tree to qualified names dict with qualified arguments
top_level_namespace = set(environment.raw_objects_tree.keys()) | set(
environment.aggregation_specs_tree.keys()
top_level_namespace = _get_top_level_namespace(
environment=environment,
supported_time_conversions=tuple(TIME_UNITS.keys()),
supported_groupings=tuple(SUPPORTED_GROUPINGS.keys()),
)
functions = dt.functions_without_tree_logic(
functions=environment.functions_tree, top_level_namespace=top_level_namespace
Expand Down Expand Up @@ -131,7 +138,7 @@ def compute_taxes_and_transfers(

# Remove unnecessary elements from user-provided data.
input_data = _create_input_data_for_concatenated_function(
data=data_with_correct_types,
data=data,
functions=functions_with_partialled_parameters,
targets=targets,
)
Expand Down Expand Up @@ -166,6 +173,48 @@ def compute_taxes_and_transfers(
return result_tree


def _get_top_level_namespace(
environment: PolicyEnvironment,
supported_time_conversions: tuple[str, ...],
supported_groupings: tuple[str, ...],
) -> set[str]:
"""Get the top level namespace.

Parameters
----------
environment:
The policy environment.

Returns
-------
top_level_namespace:
The top level namespace.
"""
direct_top_level_names = set(environment.raw_objects_tree.keys()) | set(
environment.aggregation_specs_tree.keys()
)
re_pattern = get_re_pattern_for_all_time_units_and_groupings(
supported_groupings=supported_groupings,
supported_time_units=supported_time_conversions,
)

all_top_level_names = set()
for name in direct_top_level_names:
match = re_pattern.fullmatch(name)
base_name = match.group("base_name")
create_conversions_for_time_units = bool(match.group("time_unit"))

all_top_level_names_for_name = all_variations_of_base_name(
base_name=base_name,
supported_time_conversions=supported_time_conversions,
supported_groupings=supported_groupings,
create_conversions_for_time_units=create_conversions_for_time_units,
)
all_top_level_names.update(all_top_level_names_for_name)

return all_top_level_names


def _convert_data_to_correct_types(
data: QualNameDataDict, functions_overridden: QualNameTTSIMObjectDict
) -> QualNameDataDict:
Expand Down
2 changes: 0 additions & 2 deletions src/ttsim/function_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,6 @@ class DerivedTimeConversionFunction(TTSIMFunction):
def __post_init__(self):
if self.source is None:
raise ValueError("The source must be specified.")
if self.conversion_target is None:
raise ValueError("The conversion target must be specified.")


def _convert_and_validate_dates(
Expand Down
127 changes: 127 additions & 0 deletions src/ttsim/shared.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
import itertools
import re
import textwrap
from typing import TYPE_CHECKING, Any, TypeVar
Expand Down Expand Up @@ -36,6 +37,132 @@ def validate_date_range(start: datetime.date, end: datetime.date):
raise ValueError(f"The start date {start} must be before the end date {end}.")


def get_re_pattern_for_all_time_units_and_groupings(
supported_groupings: tuple[str, ...], supported_time_units: tuple[str, ...]
) -> re.Pattern:
"""Get a regex pattern for time units and groupings.

The pattern matches strings in any of these formats:
- <base_name> (may contain underscores)
- <base_name>_<time_unit>
- <base_name>_<aggregation>
- <base_name>_<time_unit>_<aggregation>

Parameters
----------
supported_groupings
The supported groupings.
supported_time_units
The supported time units.

Returns
-------
pattern
The regex pattern.
"""
units = "".join(supported_time_units)
groupings = "|".join(supported_groupings)
return re.compile(
f"(?P<base_name>.*?)"
f"(?:_(?P<time_unit>[{units}]))?"
f"(?:_(?P<aggregation>{groupings}))?"
f"$"
)


def get_re_pattern_for_specific_time_units_and_groupings(
base_name: str,
supported_time_units: tuple[str, ...],
supported_groupings: tuple[str, ...],
) -> re.Pattern:
"""Get a regex for a specific base name with optional time unit and aggregation.

The pattern matches strings in any of these formats:
- <specific_base_name>
- <specific_base_name>_<time_unit>
- <specific_base_name>_<aggregation>
- <specific_base_name>_<time_unit>_<aggregation>

Parameters
----------
base_name
The specific base name to match.
supported_time_units
The supported time units.
supported_groupings
The supported groupings.

Returns
-------
pattern
The regex pattern.
"""
units = "".join(supported_time_units)
groupings = "|".join(supported_groupings)
return re.compile(
f"(?P<base_name>{re.escape(base_name)})"
f"(?:_(?P<time_unit>[{units}]))?"
f"(?:_(?P<aggregation>{groupings}))?"
f"$"
)


def all_variations_of_base_name(
base_name: str,
supported_time_conversions: list[str],
supported_groupings: list[str],
create_conversions_for_time_units: bool,
) -> set[str]:
"""Get possible derived function names given a base function name.

Examples
--------
>>> all_variations_of_base_name(
base_name="income",
supported_time_conversions=["y", "m"],
supported_groupings=["hh"],
create_conversions_for_time_units=True,
)
{'income_m', 'income_y', 'income_hh_y', 'income_hh_m'}

>>> all_variations_of_base_name(
base_name="claims_benefits",
supported_time_conversions=["y", "m"],
supported_groupings=["hh"],
create_conversions_for_time_units=False,
)
{'claims_benefits_hh'}

Parameters
----------
base_name
The base function name.
supported_time_conversions
The supported time conversions.
supported_groupings
The supported groupings.
create_conversions_for_time_units
Whether to create conversions for time units.

Returns
-------
The names of all potential targets based on the base name.
"""
result = set()
if create_conversions_for_time_units:
for time_unit in supported_time_conversions:
result.add(f"{base_name}_{time_unit}")
for time_unit, aggregation in itertools.product(
supported_time_conversions, supported_groupings
):
result.add(f"{base_name}_{time_unit}_{aggregation}")
else:
result.add(base_name)
for aggregation in supported_groupings:
result.add(f"{base_name}_{aggregation}")
return result


class KeyErrorMessage(str):
"""Subclass str to allow for line breaks in KeyError messages."""

Expand Down
Loading
Loading