Skip to content

Commit 6a437a1

Browse files
committed
Make time conversion a function of data if available.
1 parent f39c983 commit 6a437a1

File tree

5 files changed

+114
-72
lines changed

5 files changed

+114
-72
lines changed

src/ttsim/function_types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,6 @@ class DerivedTimeConversionFunction(TTSIMFunction):
379379
def __post_init__(self):
380380
if self.source is None:
381381
raise ValueError("The source must be specified.")
382-
if self.conversion_target is None:
383-
raise ValueError("The conversion target must be specified.")
384382

385383

386384
def _convert_and_validate_dates(

src/ttsim/shared.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,41 @@ def get_re_pattern_for_time_units_and_groupings(
7070
)
7171

7272

73+
def get_re_pattern_for_some_base_name(
74+
base_name: str, supported_time_units: list[str], supported_groupings: list[str]
75+
) -> re.Pattern:
76+
"""Get a regex for a specific base name with optional time unit and aggregation.
77+
78+
The pattern matches strings in any of these formats:
79+
- <specific_base_name>
80+
- <specific_base_name>_<time_unit>
81+
- <specific_base_name>_<aggregation>
82+
- <specific_base_name>_<time_unit>_<aggregation>
83+
84+
Parameters
85+
----------
86+
base_name
87+
The specific base name to match.
88+
supported_time_units
89+
The supported time units.
90+
supported_groupings
91+
The supported groupings.
92+
93+
Returns
94+
-------
95+
pattern
96+
The regex pattern.
97+
"""
98+
units = "".join(supported_time_units)
99+
groupings = "|".join(supported_groupings)
100+
return re.compile(
101+
f"(?P<base_name>{re.escape(base_name)})"
102+
f"(?:_(?P<time_unit>[{units}]))?"
103+
f"(?:_(?P<aggregation>{groupings}))?"
104+
f"$"
105+
)
106+
107+
73108
def all_variations_of_base_name(
74109
base_name: str,
75110
supported_time_conversions: list[str],

src/ttsim/time_conversion.py

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
from dags import rename_arguments
88

99
from _gettsim.config import SUPPORTED_GROUPINGS
10-
from ttsim.function_types import DerivedTimeConversionFunction, PolicyFunction
11-
from ttsim.shared import get_re_pattern_for_time_units_and_groupings
10+
from ttsim.function_types import DerivedTimeConversionFunction, TTSIMObject
11+
from ttsim.shared import (
12+
get_re_pattern_for_some_base_name,
13+
get_re_pattern_for_time_units_and_groupings,
14+
)
1215

1316
if TYPE_CHECKING:
17+
import re
1418
from collections.abc import Callable
1519

1620
from ttsim.typing import QualNameDataDict, QualNameTTSIMObjectDict
@@ -374,7 +378,7 @@ def d_to_w(value: float) -> float:
374378

375379

376380
def create_time_conversion_functions(
377-
functions: QualNameTTSIMObjectDict,
381+
ttsim_objects: QualNameTTSIMObjectDict,
378382
data: QualNameDataDict,
379383
) -> QualNameTTSIMObjectDict:
380384
"""
@@ -414,58 +418,64 @@ def create_time_conversion_functions(
414418
The functions dict with the new time conversion functions.
415419
"""
416420

417-
converted_functions = {}
421+
converted_ttsim_objects = {}
418422

419-
# Create time-conversions for existing functions
420-
for name, function in functions.items():
421-
all_time_conversions_for_this_function = _create_time_conversion_functions(
422-
name=name, func=function
423+
for source_name, ttsim_object in ttsim_objects.items():
424+
all_time_units = list(_TIME_UNITS)
425+
time_unit_pattern = get_re_pattern_for_time_units_and_groupings(
426+
supported_groupings=SUPPORTED_GROUPINGS,
427+
supported_time_units=all_time_units,
423428
)
424-
for der_name, der_func in all_time_conversions_for_this_function.items():
425-
# Skip if the function already exists or the data column exists
426-
if der_name in converted_functions or der_name in data:
427-
continue
428-
else:
429-
converted_functions[der_name] = der_func
429+
match = time_unit_pattern.fullmatch(source_name)
430+
base_name = match.group("base_name")
431+
432+
# If base_name is in data, make all time conversion depend on it instead of the
433+
# function
434+
for data_name in data:
435+
match = get_re_pattern_for_some_base_name(
436+
base_name=base_name,
437+
supported_time_units=all_time_units,
438+
supported_groupings=SUPPORTED_GROUPINGS,
439+
).fullmatch(data_name)
440+
if match:
441+
source_name = data_name # noqa: PLW2901
442+
break
430443

431-
# Create time-conversions for data columns
432-
for name in data:
433-
all_time_conversions_for_this_data_column = _create_time_conversion_functions(
434-
name=name
444+
all_time_conversions_for_this_function = _create_time_conversion_functions(
445+
source_name=source_name,
446+
ttsim_object=ttsim_object,
447+
time_unit_pattern=time_unit_pattern,
448+
all_time_units=all_time_units,
435449
)
436-
for der_name, der_func in all_time_conversions_for_this_data_column.items():
437-
# Skip if the function already exists or the data column exists
438-
if der_name in converted_functions or der_name in data:
450+
for der_name, der_func in all_time_conversions_for_this_function.items():
451+
if der_name in converted_ttsim_objects or der_name in data:
439452
continue
440453
else:
441-
converted_functions[der_name] = der_func
454+
converted_ttsim_objects[der_name] = der_func
442455

443-
return converted_functions
456+
return converted_ttsim_objects
444457

445458

446459
def _create_time_conversion_functions(
447-
name: str, func: PolicyFunction | None = None
460+
source_name: str,
461+
ttsim_object: TTSIMObject,
462+
time_unit_pattern: re.Pattern,
463+
all_time_units: list[str],
448464
) -> dict[str, DerivedTimeConversionFunction]:
449465
result: dict[str, DerivedTimeConversionFunction] = {}
450-
all_time_units = list(_TIME_UNITS)
451-
452-
time_unit_pattern = get_re_pattern_for_time_units_and_groupings(
453-
supported_groupings=SUPPORTED_GROUPINGS,
454-
supported_time_units=all_time_units,
455-
)
456-
457-
match = time_unit_pattern.fullmatch(name)
466+
match = time_unit_pattern.fullmatch(source_name)
458467
base_name = match.group("base_name")
459468
time_unit = match.group("time_unit") or ""
460469
aggregation = match.group("aggregation") or ""
461-
462-
dependencies = set(inspect.signature(func).parameters) if func else set()
470+
dependencies = (
471+
set(inspect.signature(ttsim_object).parameters) if ttsim_object else set()
472+
)
463473

464474
if match and time_unit:
465475
missing_time_units = [unit for unit in all_time_units if unit != time_unit]
466476
for missing_time_unit in missing_time_units:
467477
new_name = (
468-
f"{base_name}_{missing_time_unit}{aggregation}"
478+
f"{base_name}_{missing_time_unit}_{aggregation}"
469479
if aggregation
470480
else f"{base_name}_{missing_time_unit}"
471481
)
@@ -485,21 +495,23 @@ def _create_time_conversion_functions(
485495
result[new_name] = DerivedTimeConversionFunction(
486496
leaf_name=dt.tree_path_from_qual_name(new_name)[-1],
487497
function=_create_function_for_time_unit(
488-
name,
489-
_time_conversion_functions[f"{time_unit}_to_{missing_time_unit}"],
498+
source=source_name,
499+
converter=_time_conversion_functions[
500+
f"{time_unit}_to_{missing_time_unit}"
501+
],
490502
),
491-
source=name,
492-
start_date=func.start_date,
493-
end_date=func.end_date,
503+
source=source_name,
504+
start_date=ttsim_object.start_date,
505+
end_date=ttsim_object.end_date,
494506
)
495507

496508
return result
497509

498510

499511
def _create_function_for_time_unit(
500-
function_name: str, converter: Callable[[float], float]
512+
source: str, converter: Callable[[float], float]
501513
) -> Callable[[float], float]:
502-
@rename_arguments(mapper={"x": function_name})
514+
@rename_arguments(mapper={"x": source})
503515
def func(x: float) -> float:
504516
return converter(x)
505517

tests/ttsim/test_shared.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
all_variations_of_base_name,
77
create_tree_from_path_and_value,
88
get_name_of_group_by_id,
9+
get_re_pattern_for_some_base_name,
910
get_re_pattern_for_time_units_and_groupings,
1011
insert_path_and_value,
1112
merge_trees,
@@ -379,3 +380,27 @@ def test_get_re_pattern_for_time_units_and_groupings(
379380
assert match.group("base_name") == expected_base_name
380381
assert match.group("time_unit") == expected_time_unit
381382
assert match.group("aggregation") == expected_aggregation
383+
384+
385+
@pytest.mark.parametrize(
386+
(
387+
"base_name",
388+
"supported_time_units",
389+
"supported_groupings",
390+
"expected_match",
391+
),
392+
[
393+
("foo", ["m", "y"], ["hh"], "foo_m_hh"),
394+
("foo", ["m", "y"], ["hh", "x"], "foo_m"),
395+
("foo", ["m", "y"], ["hh", "x"], "foo_hh"),
396+
],
397+
)
398+
def test_get_re_pattern_for_some_base_name(
399+
base_name, supported_time_units, supported_groupings, expected_match
400+
):
401+
re_pattern = get_re_pattern_for_some_base_name(
402+
base_name=base_name,
403+
supported_time_units=supported_time_units,
404+
supported_groupings=supported_groupings,
405+
)
406+
assert re_pattern.fullmatch(expected_match)

tests/ttsim/test_time_conversion.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -280,34 +280,6 @@ def test_should_create_functions_for_other_time_units_for_functions(
280280
for expected_name in expected:
281281
assert expected_name in time_conversion_functions
282282

283-
@pytest.mark.parametrize(
284-
("name", "expected"),
285-
[
286-
("test_y", ["test_m", "test_q", "test_w", "test_d"]),
287-
("test_y_hh", ["test_m_hh", "test_q_hh", "test_w_hh", "test_d_hh"]),
288-
("test_y_sn", ["test_m_sn", "test_q_sn", "test_w_sn", "test_d_sn"]),
289-
("test_q", ["test_y", "test_m", "test_w", "test_d"]),
290-
("test_q_hh", ["test_y_hh", "test_m_hh", "test_w_hh", "test_d_hh"]),
291-
("test_q_sn", ["test_y_sn", "test_m_sn", "test_w_sn", "test_d_sn"]),
292-
("test_m", ["test_y", "test_q", "test_w", "test_d"]),
293-
("test_m_hh", ["test_y_hh", "test_q_hh", "test_w_hh", "test_d_hh"]),
294-
("test_m_sn", ["test_y_sn", "test_q_sn", "test_w_sn", "test_d_sn"]),
295-
("test_w", ["test_y", "test_q", "test_m", "test_d"]),
296-
("test_w_hh", ["test_y_hh", "test_q_hh", "test_m_hh", "test_d_hh"]),
297-
("test_w_sn", ["test_y_sn", "test_q_sn", "test_m_sn", "test_d_sn"]),
298-
("test_d", ["test_y", "test_q", "test_m", "test_w"]),
299-
("test_d_hh", ["test_y_hh", "test_q_hh", "test_m_hh", "test_w_hh"]),
300-
("test_d_sn", ["test_y_sn", "test_q_sn", "test_m_sn", "test_w_sn"]),
301-
],
302-
)
303-
def test_should_create_functions_for_other_time_units_for_data_cols(
304-
self, name: str, expected: list[str]
305-
) -> None:
306-
time_conversion_functions = create_time_conversion_functions({}, {name: None})
307-
308-
for expected_name in expected:
309-
assert expected_name in time_conversion_functions
310-
311283
def test_should_not_create_functions_automatically_that_exist_already(self) -> None:
312284
time_conversion_functions = create_time_conversion_functions(
313285
{"test1_d": policy_function(leaf_name="test1_d")(lambda: 1)},

0 commit comments

Comments
 (0)