77from dags import rename_arguments
88
99from _gettsim .config import SUPPORTED_GROUPINGS
10- from ttsim .function_types import DerivedTimeConversionFunction , PolicyFunction
11- from ttsim .shared import get_re_pattern_for_time_units_and_groupings
10+ from ttsim .function_types import DerivedTimeConversionFunction , TTSIMObject
11+ from ttsim .shared import (
12+ get_re_pattern_for_some_base_name ,
13+ get_re_pattern_for_time_units_and_groupings ,
14+ )
1215
1316if TYPE_CHECKING :
17+ import re
1418 from collections .abc import Callable
1519
1620 from ttsim .typing import QualNameDataDict , QualNameTTSIMObjectDict
@@ -374,7 +378,7 @@ def d_to_w(value: float) -> float:
374378
375379
376380def create_time_conversion_functions (
377- functions : QualNameTTSIMObjectDict ,
381+ ttsim_objects : QualNameTTSIMObjectDict ,
378382 data : QualNameDataDict ,
379383) -> QualNameTTSIMObjectDict :
380384 """
@@ -414,58 +418,64 @@ def create_time_conversion_functions(
414418 The functions dict with the new time conversion functions.
415419 """
416420
417- converted_functions = {}
421+ converted_ttsim_objects = {}
418422
419- # Create time-conversions for existing functions
420- for name , function in functions .items ():
421- all_time_conversions_for_this_function = _create_time_conversion_functions (
422- name = name , func = function
423+ for source_name , ttsim_object in ttsim_objects .items ():
424+ all_time_units = list (_TIME_UNITS )
425+ time_unit_pattern = get_re_pattern_for_time_units_and_groupings (
426+ supported_groupings = SUPPORTED_GROUPINGS ,
427+ supported_time_units = all_time_units ,
423428 )
424- for der_name , der_func in all_time_conversions_for_this_function .items ():
425- # Skip if the function already exists or the data column exists
426- if der_name in converted_functions or der_name in data :
427- continue
428- else :
429- converted_functions [der_name ] = der_func
429+ match = time_unit_pattern .fullmatch (source_name )
430+ base_name = match .group ("base_name" )
431+
432+ # If base_name is in data, make all time conversion depend on it instead of the
433+ # function
434+ for data_name in data :
435+ match = get_re_pattern_for_some_base_name (
436+ base_name = base_name ,
437+ supported_time_units = all_time_units ,
438+ supported_groupings = SUPPORTED_GROUPINGS ,
439+ ).fullmatch (data_name )
440+ if match :
441+ source_name = data_name # noqa: PLW2901
442+ break
430443
431- # Create time-conversions for data columns
432- for name in data :
433- all_time_conversions_for_this_data_column = _create_time_conversion_functions (
434- name = name
444+ all_time_conversions_for_this_function = _create_time_conversion_functions (
445+ source_name = source_name ,
446+ ttsim_object = ttsim_object ,
447+ time_unit_pattern = time_unit_pattern ,
448+ all_time_units = all_time_units ,
435449 )
436- for der_name , der_func in all_time_conversions_for_this_data_column .items ():
437- # Skip if the function already exists or the data column exists
438- if der_name in converted_functions or der_name in data :
450+ for der_name , der_func in all_time_conversions_for_this_function .items ():
451+ if der_name in converted_ttsim_objects or der_name in data :
439452 continue
440453 else :
441- converted_functions [der_name ] = der_func
454+ converted_ttsim_objects [der_name ] = der_func
442455
443- return converted_functions
456+ return converted_ttsim_objects
444457
445458
446459def _create_time_conversion_functions (
447- name : str , func : PolicyFunction | None = None
460+ source_name : str ,
461+ ttsim_object : TTSIMObject ,
462+ time_unit_pattern : re .Pattern ,
463+ all_time_units : list [str ],
448464) -> dict [str , DerivedTimeConversionFunction ]:
449465 result : dict [str , DerivedTimeConversionFunction ] = {}
450- all_time_units = list (_TIME_UNITS )
451-
452- time_unit_pattern = get_re_pattern_for_time_units_and_groupings (
453- supported_groupings = SUPPORTED_GROUPINGS ,
454- supported_time_units = all_time_units ,
455- )
456-
457- match = time_unit_pattern .fullmatch (name )
466+ match = time_unit_pattern .fullmatch (source_name )
458467 base_name = match .group ("base_name" )
459468 time_unit = match .group ("time_unit" ) or ""
460469 aggregation = match .group ("aggregation" ) or ""
461-
462- dependencies = set (inspect .signature (func ).parameters ) if func else set ()
470+ dependencies = (
471+ set (inspect .signature (ttsim_object ).parameters ) if ttsim_object else set ()
472+ )
463473
464474 if match and time_unit :
465475 missing_time_units = [unit for unit in all_time_units if unit != time_unit ]
466476 for missing_time_unit in missing_time_units :
467477 new_name = (
468- f"{ base_name } _{ missing_time_unit } { aggregation } "
478+ f"{ base_name } _{ missing_time_unit } _ { aggregation } "
469479 if aggregation
470480 else f"{ base_name } _{ missing_time_unit } "
471481 )
@@ -485,21 +495,23 @@ def _create_time_conversion_functions(
485495 result [new_name ] = DerivedTimeConversionFunction (
486496 leaf_name = dt .tree_path_from_qual_name (new_name )[- 1 ],
487497 function = _create_function_for_time_unit (
488- name ,
489- _time_conversion_functions [f"{ time_unit } _to_{ missing_time_unit } " ],
498+ source = source_name ,
499+ converter = _time_conversion_functions [
500+ f"{ time_unit } _to_{ missing_time_unit } "
501+ ],
490502 ),
491- source = name ,
492- start_date = func .start_date ,
493- end_date = func .end_date ,
503+ source = source_name ,
504+ start_date = ttsim_object .start_date ,
505+ end_date = ttsim_object .end_date ,
494506 )
495507
496508 return result
497509
498510
499511def _create_function_for_time_unit (
500- function_name : str , converter : Callable [[float ], float ]
512+ source : str , converter : Callable [[float ], float ]
501513) -> Callable [[float ], float ]:
502- @rename_arguments (mapper = {"x" : function_name })
514+ @rename_arguments (mapper = {"x" : source })
503515 def func (x : float ) -> float :
504516 return converter (x )
505517
0 commit comments