diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index e409eecfa..22e855186 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -34,7 +34,7 @@ jobs: environment: py311, py312 steps: - uses: actions/checkout@v4 - - uses: prefix-dev/setup-pixi@v0.8.10 + - uses: prefix-dev/setup-pixi@v0.8.13 with: pixi-version: v0.49.0 cache: true @@ -62,6 +62,6 @@ jobs: shell: bash -el {0} - name: Upload coverage reports if: runner.os == 'Linux' && matrix.environment == 'py313' - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/pixi.lock b/pixi.lock index ae80671d3..413b4bde3 100644 --- a/pixi.lock +++ b/pixi.lock @@ -16467,8 +16467,8 @@ packages: timestamp: 1733367480074 - pypi: ./ name: ttsim - version: 0.7.1.dev188+g5d736342.d20250724 - sha256: cefa7a79ac4a812de443db006820c39c6a5375f5f14df80edff5f39c58e7b87a + version: 0.7.1.dev202+g87c1768a.d20250724 + sha256: 1897b27ec932ff5bea5e4d458fadb66c76918bebd5a5473d3420a21e138f0807 requires_dist: - dags>=0.4.1 - ipywidgets diff --git a/pyproject.toml b/pyproject.toml index b577341db..c8b0fe2cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -231,6 +231,8 @@ exclude = [] [tool.ruff.lint.per-file-ignores] "conftest.py" = ["ANN"] +# Re-export things for tab-completion explicitly. +"src/ttsim/__init__.py" = ["PLW0127"] # Long lines. "src/ttsim/interface_dag_elements/specialized_environment.py" = ["E501"] "src/ttsim/interface_dag_elements/fail_if.py" = ["E501"] diff --git a/src/ttsim/__init__.py b/src/ttsim/__init__.py index 92e096995..4b9b69827 100644 --- a/src/ttsim/__init__.py +++ b/src/ttsim/__init__.py @@ -1,10 +1,20 @@ from __future__ import annotations -from ttsim._version import __version__, __version_tuple__, version, version_tuple +try: + # Import the version from _version.py which is dynamically created by + # setuptools-scm upon installing the project with pip. + # Do not put it under version control! + from ttsim._version import __version__, __version_tuple__, version, version_tuple +except ImportError: + __version__ = "unknown" + __version_tuple__ = ("unknown", "unknown", "unknown") + version = "unknown" + version_tuple = ("unknown", "unknown", "unknown") + +from ttsim import tt from ttsim.copy_environment import copy_environment -from ttsim.interface_dag import main -from ttsim.interface_dag_elements import MainTarget -from ttsim.interface_dag_elements.shared import merge_trees +from ttsim.interface_dag_elements.shared import merge_trees, upsert_tree +from ttsim.main import main from ttsim.main_args import ( InputData, Labels, @@ -14,6 +24,24 @@ SpecializedEnvironment, TTTargets, ) +from ttsim.main_target import MainTarget +from ttsim.plot_dag import plot_interface_dag, plot_tt_dag + +copy_environment = copy_environment +merge_trees = merge_trees +upsert_tree = upsert_tree +main = main +plot_tt_dag = plot_tt_dag +plot_interface_dag = plot_interface_dag +MainTarget = MainTarget +InputData = InputData +Labels = Labels +OrigPolicyObjects = OrigPolicyObjects +RawResults = RawResults +Results = Results +SpecializedEnvironment = SpecializedEnvironment +TTTargets = TTTargets +tt = tt __all__ = [ "InputData", @@ -29,6 +57,10 @@ "copy_environment", "main", "merge_trees", + "plot_interface_dag", + "plot_tt_dag", + "tt", + "upsert_tree", "version", "version_tuple", ] diff --git a/src/ttsim/copy_environment.py b/src/ttsim/copy_environment.py index c01740ea8..4117ad8bc 100644 --- a/src/ttsim/copy_environment.py +++ b/src/ttsim/copy_environment.py @@ -8,7 +8,7 @@ import optree if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( PolicyEnvironment, SpecEnvWithoutTreeLogicAndWithDerivedFunctions, SpecEnvWithPartialledParamsAndScalars, diff --git a/src/ttsim/interface_dag_elements/__init__.py b/src/ttsim/interface_dag_elements/__init__.py index 398ff7617..e69de29bb 100644 --- a/src/ttsim/interface_dag_elements/__init__.py +++ b/src/ttsim/interface_dag_elements/__init__.py @@ -1,172 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - - -@dataclass(frozen=True) -class MainTargetABC: - @classmethod - def to_dict(cls) -> dict[str, Any]: - return { - k: v.to_dict() if isinstance(v, type(MainTargetABC)) else v - for k, v in cls.__dict__.items() - if not k.startswith("_") - } - - def __post_init__(self) -> None: - raise NotImplementedError("Do not instantiate this class directly.") - - -@dataclass(frozen=True) -class WarnIf(MainTargetABC): - functions_and_data_columns_overlap: str = ( - "warn_if__functions_and_data_columns_overlap" - ) - evaluation_date_set_in_multiple_places: str = ( - "warn_if__evaluation_date_set_in_multiple_places" - ) - - -@dataclass(frozen=True) -class FailIf(MainTargetABC): - active_periods_overlap: str = "fail_if__active_periods_overlap" - any_paths_are_invalid: str = "fail_if__any_paths_are_invalid" - backend_has_changed: str = "fail_if__backend_has_changed" - environment_is_invalid: str = "fail_if__environment_is_invalid" - foreign_keys_are_invalid_in_data: str = "fail_if__foreign_keys_are_invalid_in_data" - group_ids_are_outside_top_level_namespace: str = ( - "fail_if__group_ids_are_outside_top_level_namespace" - ) - group_variables_are_not_constant_within_groups: str = ( - "fail_if__group_variables_are_not_constant_within_groups" - ) - input_data_is_invalid: str = "fail_if__input_data_is_invalid" - input_data_tree_is_invalid: str = "fail_if__input_data_tree_is_invalid" - input_df_has_bool_or_numeric_column_names: str = ( - "fail_if__input_df_has_bool_or_numeric_column_names" - ) - input_df_mapper_columns_missing_in_df: str = ( - "fail_if__input_df_mapper_columns_missing_in_df" - ) - input_df_mapper_has_incorrect_format: str = ( - "fail_if__input_df_mapper_has_incorrect_format" - ) - non_convertible_objects_in_results_tree: str = ( - "fail_if__non_convertible_objects_in_results_tree" - ) - param_function_depends_on_column_objects: str = ( - "fail_if__param_function_depends_on_column_objects" - ) - paths_are_missing_in_targets_tree_mapper: str = ( - "fail_if__paths_are_missing_in_targets_tree_mapper" - ) - tt_root_nodes_are_missing: str = "fail_if__tt_root_nodes_are_missing" - targets_are_not_in_specialized_environment_or_data: str = ( - "fail_if__targets_are_not_in_specialized_environment_or_data" - ) - targets_tree_is_invalid: str = "fail_if__targets_tree_is_invalid" - - -@dataclass(frozen=True) -class Results(MainTargetABC): - df_with_mapper: str = "results__df_with_mapper" - df_with_nested_columns: str = "results__df_with_nested_columns" - tree: str = "results__tree" - - -@dataclass(frozen=True) -class RawResults(MainTargetABC): - columns: str = "raw_results__columns" - combined: str = "raw_results__combined" - from_input_data: str = "raw_results__from_input_data" - params: str = "raw_results__params" - - -@dataclass(frozen=True) -class SpecializedEnvironment(MainTargetABC): - without_tree_logic_and_with_derived_functions: str = ( - "specialized_environment__without_tree_logic_and_with_derived_functions" - ) - with_processed_params_and_scalars: str = ( - "specialized_environment__with_processed_params_and_scalars" - ) - with_partialled_params_and_scalars: str = ( - "specialized_environment__with_partialled_params_and_scalars" - ) - tax_transfer_dag: str = "specialized_environment__tax_transfer_dag" - tax_transfer_function: str = "specialized_environment__tax_transfer_function" - - -@dataclass(frozen=True) -class Targets(MainTargetABC): - qname: str = "tt_targets__qname" - tree: str = "tt_targets__tree" - - -@dataclass(frozen=True) -class Labels(MainTargetABC): - column_targets: str = "labels__column_targets" - grouping_levels: str = "labels__grouping_levels" - input_data_targets: str = "labels__input_data_targets" - param_targets: str = "labels__param_targets" - processed_data_columns: str = "labels__processed_data_columns" - input_columns: str = "labels__input_columns" - root_nodes: str = "labels__root_nodes" - top_level_namespace: str = "labels__top_level_namespace" - - -@dataclass(frozen=True) -class DfAndMapper(MainTargetABC): - df: str = "input_data__df_and_mapper__df" - mapper: str = "input_data__df_and_mapper__mapper" - - -@dataclass(frozen=True) -class InputData(MainTargetABC): - df_and_mapper: type[DfAndMapper] = field(default=DfAndMapper) - df_with_nested_columns: str = "input_data__df_with_nested_columns" - flat: str = "input_data__flat" - tree: str = "input_data__tree" - - -@dataclass(frozen=True) -class OrigPolicyObjects(MainTargetABC): - column_objects_and_param_functions: str = ( - "orig_policy_objects__column_objects_and_param_functions" - ) - param_specs: str = "orig_policy_objects__param_specs" - # Do not include root here, will be pre-defined in user-facing implementations. - - -@dataclass(frozen=True) -class Templates(MainTargetABC): - input_data_dtypes: str = "templates__input_data_dtypes" - - -@dataclass(frozen=True) -class MainTarget(MainTargetABC): - results: type[Results] = field(default=Results) - templates: type[Templates] = field(default=Templates) - policy_environment: str = "policy_environment" - specialized_environment: type[SpecializedEnvironment] = field( - default=SpecializedEnvironment - ) - orig_policy_objects: type[OrigPolicyObjects] = field(default=OrigPolicyObjects) - processed_data: str = "processed_data" - raw_results: type[RawResults] = field(default=RawResults) - labels: type[Labels] = field(default=Labels) - input_data: type[InputData] = field(default=InputData) - tt_targets: type[Targets] = field(default=Targets) - num_segments: str = "num_segments" - backend: str = "backend" - evaluation_date_str: str = "evaluation_date_str" - evaluation_date: str = "evaluation_date" - policy_date_str: str = "policy_date_str" - policy_date: str = "policy_date" - xnp: str = "xnp" - dnp: str = "dnp" - rounding: str = "rounding" - tt_function_set_annotations: str = "tt_function_set_annotations" - warn_if: type[WarnIf] = field(default=WarnIf) - fail_if: type[FailIf] = field(default=FailIf) diff --git a/src/ttsim/interface_dag_elements/automatically_added_functions.py b/src/ttsim/interface_dag_elements/automatically_added_functions.py index 4b60bc6d6..8b9f12912 100644 --- a/src/ttsim/interface_dag_elements/automatically_added_functions.py +++ b/src/ttsim/interface_dag_elements/automatically_added_functions.py @@ -12,8 +12,8 @@ get_re_pattern_for_specific_time_units_and_groupings, group_pattern, ) -from ttsim.tt_dag_elements.aggregation import grouped_sum -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.aggregation import grouped_sum +from ttsim.tt.column_objects_param_function import ( DEFAULT_END_DATE, DEFAULT_START_DATE, AggByGroupFunction, @@ -22,13 +22,13 @@ ParamFunction, TimeConversionFunction, ) -from ttsim.tt_dag_elements.param_objects import ScalarParam +from ttsim.tt.param_objects import ScalarParam if TYPE_CHECKING: import re from collections.abc import Callable - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( BoolColumn, FloatColumn, IntColumn, @@ -584,7 +584,7 @@ def func(x: BoolColumn) -> FloatColumn: ... def func(x: FloatColumn | IntColumn | BoolColumn) -> FloatColumn: return converter(x) - return func # type: ignore[has-type] + return func def create_agg_by_group_functions( diff --git a/src/ttsim/interface_dag_elements/data_converters.py b/src/ttsim/interface_dag_elements/data_converters.py index a16a58726..3cdcbdef3 100644 --- a/src/ttsim/interface_dag_elements/data_converters.py +++ b/src/ttsim/interface_dag_elements/data_converters.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatData, NestedData, NestedInputsMapper, diff --git a/src/ttsim/interface_dag_elements/dates.py b/src/ttsim/interface_dag_elements/dates.py index 2e83d81b1..99166e2a9 100644 --- a/src/ttsim/interface_dag_elements/dates.py +++ b/src/ttsim/interface_dag_elements/dates.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: import datetime - from ttsim.interface_dag_elements.typing import DashedISOString + from ttsim.typing import DashedISOString @interface_input(in_top_level_namespace=True) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index a74143260..7b3040940 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -23,7 +23,7 @@ from ttsim.interface_dag_elements.interface_node_objects import fail_function from ttsim.interface_dag_elements.shared import get_name_of_group_by_id -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.column_objects_param_function import ( DEFAULT_END_DATE, ColumnFunction, ColumnObject, @@ -31,7 +31,7 @@ ParamFunction, PolicyInput, ) -from ttsim.tt_dag_elements.param_objects import ( +from ttsim.tt.param_objects import ( PLACEHOLDER_FIELD, PLACEHOLDER_VALUE, ParamObject, @@ -41,7 +41,7 @@ from collections.abc import Callable from ttsim.interface_dag_elements.input_data import FlatData - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, NestedData, diff --git a/src/ttsim/interface_dag_elements/input_data.py b/src/ttsim/interface_dag_elements/input_data.py index adbcd44ef..2fdb6b680 100644 --- a/src/ttsim/interface_dag_elements/input_data.py +++ b/src/ttsim/interface_dag_elements/input_data.py @@ -18,7 +18,7 @@ import pandas as pd - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatData, NestedData, NestedInputsMapper, diff --git a/src/ttsim/interface_dag_elements/interface_node_objects.py b/src/ttsim/interface_dag_elements/interface_node_objects.py index af64c5862..00a3dd749 100644 --- a/src/ttsim/interface_dag_elements/interface_node_objects.py +++ b/src/ttsim/interface_dag_elements/interface_node_objects.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable - from ttsim.interface_dag_elements.typing import UnorderedQNames + from ttsim.typing import UnorderedQNames FunArgTypes = ParamSpec("FunArgTypes") diff --git a/src/ttsim/interface_dag_elements/labels.py b/src/ttsim/interface_dag_elements/labels.py index 3d0480b4c..023551c2d 100644 --- a/src/ttsim/interface_dag_elements/labels.py +++ b/src/ttsim/interface_dag_elements/labels.py @@ -14,10 +14,10 @@ get_re_pattern_for_all_time_units_and_groupings, group_pattern, ) -from ttsim.tt_dag_elements.column_objects_param_function import PolicyInput +from ttsim.tt.column_objects_param_function import PolicyInput if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( OrderedQNames, PolicyEnvironment, QNameData, diff --git a/src/ttsim/interface_dag_elements/num_segments.py b/src/ttsim/interface_dag_elements/num_segments.py index f96f0dade..52b57e1f5 100644 --- a/src/ttsim/interface_dag_elements/num_segments.py +++ b/src/ttsim/interface_dag_elements/num_segments.py @@ -5,7 +5,7 @@ from ttsim.interface_dag_elements.interface_node_objects import interface_function if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import QNameData + from ttsim.typing import QNameData @interface_function(in_top_level_namespace=True) diff --git a/src/ttsim/interface_dag_elements/orig_policy_objects.py b/src/ttsim/interface_dag_elements/orig_policy_objects.py index 9e225e761..b17204802 100644 --- a/src/ttsim/interface_dag_elements/orig_policy_objects.py +++ b/src/ttsim/interface_dag_elements/orig_policy_objects.py @@ -11,7 +11,7 @@ interface_function, interface_input, ) -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.column_objects_param_function import ( ColumnObject, ParamFunction, ) @@ -20,7 +20,7 @@ from pathlib import Path from types import ModuleType - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, OrigParamSpec, diff --git a/src/ttsim/interface_dag_elements/policy_environment.py b/src/ttsim/interface_dag_elements/policy_environment.py index 114faa77a..c91fe4e5b 100644 --- a/src/ttsim/interface_dag_elements/policy_environment.py +++ b/src/ttsim/interface_dag_elements/policy_environment.py @@ -12,7 +12,7 @@ merge_trees, upsert_tree, ) -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( ConsecutiveIntLookupTableParam, DictParam, ParamObject, @@ -24,15 +24,15 @@ get_month_based_phase_inout_of_age_thresholds_param_value, get_year_based_phase_inout_of_age_thresholds_param_value, ) -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.column_objects_param_function import ( DEFAULT_END_DATE, ) -from ttsim.tt_dag_elements.piecewise_polynomial import get_piecewise_parameters +from ttsim.tt.piecewise_polynomial import get_piecewise_parameters if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, NestedColumnObjectsParamFunctions, diff --git a/src/ttsim/interface_dag_elements/processed_data.py b/src/ttsim/interface_dag_elements/processed_data.py index 9ec206063..5bdbb7185 100644 --- a/src/ttsim/interface_dag_elements/processed_data.py +++ b/src/ttsim/interface_dag_elements/processed_data.py @@ -5,12 +5,12 @@ import dags.tree as dt from ttsim.interface_dag_elements.interface_node_objects import interface_function -from ttsim.tt_dag_elements.column_objects_param_function import reorder_ids +from ttsim.tt.column_objects_param_function import reorder_ids if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import FlatData, QNameData + from ttsim.typing import FlatData, QNameData @interface_function(in_top_level_namespace=True) diff --git a/src/ttsim/interface_dag_elements/raw_results.py b/src/ttsim/interface_dag_elements/raw_results.py index a841b84ab..882c1384d 100644 --- a/src/ttsim/interface_dag_elements/raw_results.py +++ b/src/ttsim/interface_dag_elements/raw_results.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( OrderedQNames, QNameData, SpecEnvWithProcessedParamsAndScalars, diff --git a/src/ttsim/interface_dag_elements/results.py b/src/ttsim/interface_dag_elements/results.py index b4ebb7ee0..c98cc162a 100644 --- a/src/ttsim/interface_dag_elements/results.py +++ b/src/ttsim/interface_dag_elements/results.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: import pandas as pd - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatData, NestedData, NestedStrings, diff --git a/src/ttsim/interface_dag_elements/shared.py b/src/ttsim/interface_dag_elements/shared.py index 1f319ae13..39b59ebff 100644 --- a/src/ttsim/interface_dag_elements/shared.py +++ b/src/ttsim/interface_dag_elements/shared.py @@ -8,7 +8,7 @@ import optree if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( DashedISOString, NestedColumnObjectsParamFunctions, NestedData, diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index dc60787ca..8953bf825 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -15,13 +15,13 @@ interface_input, ) from ttsim.interface_dag_elements.shared import merge_trees -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.column_objects_param_function import ( ColumnFunction, ColumnObject, ParamFunction, PolicyInput, ) -from ttsim.tt_dag_elements.param_objects import ParamObject, RawParam +from ttsim.tt.param_objects import ParamObject, RawParam if TYPE_CHECKING: import datetime @@ -30,7 +30,7 @@ import networkx as nx - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( OrderedQNames, PolicyEnvironment, QNameData, diff --git a/src/ttsim/interface_dag_elements/templates.py b/src/ttsim/interface_dag_elements/templates.py index 1afad2303..afc2a15d6 100644 --- a/src/ttsim/interface_dag_elements/templates.py +++ b/src/ttsim/interface_dag_elements/templates.py @@ -9,11 +9,11 @@ from ttsim.interface_dag_elements.shared import ( get_re_pattern_for_all_time_units_and_groupings, ) -from ttsim.tt_dag_elements.column_objects_param_function import PolicyInput -from ttsim.tt_dag_elements.vectorization import scalar_type_to_array_type +from ttsim.tt.column_objects_param_function import PolicyInput +from ttsim.tt.vectorization import scalar_type_to_array_type if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( NestedInputStructureDict, OrderedQNames, PolicyEnvironment, diff --git a/src/ttsim/interface_dag_elements/tt_targets.py b/src/ttsim/interface_dag_elements/tt_targets.py index 667ed883f..9621d491b 100644 --- a/src/ttsim/interface_dag_elements/tt_targets.py +++ b/src/ttsim/interface_dag_elements/tt_targets.py @@ -7,10 +7,10 @@ from ttsim.interface_dag_elements.interface_node_objects import ( interface_function, ) -from ttsim.tt_dag_elements.column_objects_param_function import ColumnFunction +from ttsim.tt.column_objects_param_function import ColumnFunction if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( NestedStrings, NestedTargetDict, OrderedQNames, diff --git a/src/ttsim/interface_dag_elements/warn_if.py b/src/ttsim/interface_dag_elements/warn_if.py index 1c07cfca7..28f94b267 100644 --- a/src/ttsim/interface_dag_elements/warn_if.py +++ b/src/ttsim/interface_dag_elements/warn_if.py @@ -10,12 +10,12 @@ format_list_linewise, ) from ttsim.interface_dag_elements.interface_node_objects import warn_function -from ttsim.tt_dag_elements.column_objects_param_function import PolicyInput +from ttsim.tt.column_objects_param_function import PolicyInput if TYPE_CHECKING: import datetime - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( PolicyEnvironment, QNameData, UnorderedQNames, diff --git a/src/ttsim/interface_dag.py b/src/ttsim/main.py similarity index 99% rename from src/ttsim/interface_dag.py rename to src/ttsim/main.py index 3653526d9..a118b5b5a 100644 --- a/src/ttsim/interface_dag.py +++ b/src/ttsim/main.py @@ -11,7 +11,6 @@ import networkx as nx import optree -from ttsim.interface_dag_elements import MainTarget, MainTargetABC from ttsim.interface_dag_elements.fail_if import ( format_errors_and_warnings, format_list_linewise, @@ -25,19 +24,12 @@ ) from ttsim.interface_dag_elements.orig_policy_objects import load_module from ttsim.main_args import MainArg +from ttsim.main_target import MainTarget, MainTargetABC if TYPE_CHECKING: import datetime from collections.abc import Iterable - from ttsim.interface_dag_elements.typing import ( - DashedISOString, - FlatInterfaceObjects, - NestedTargetDict, - PolicyEnvironment, - QNameData, - UnorderedQNames, - ) from ttsim.main_args import ( InputData, Labels, @@ -47,6 +39,14 @@ SpecializedEnvironment, TTTargets, ) + from ttsim.typing import ( + DashedISOString, + FlatInterfaceObjects, + NestedTargetDict, + PolicyEnvironment, + QNameData, + UnorderedQNames, + ) def main( diff --git a/src/ttsim/main_args.py b/src/ttsim/main_args.py index f07063d76..ab273fae7 100644 --- a/src/ttsim/main_args.py +++ b/src/ttsim/main_args.py @@ -10,7 +10,7 @@ import networkx as nx import pandas as pd - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatColumnObjectsParamFunctions, FlatData, FlatOrigParamSpecs, diff --git a/src/ttsim/main_target.py b/src/ttsim/main_target.py new file mode 100644 index 000000000..398ff7617 --- /dev/null +++ b/src/ttsim/main_target.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class MainTargetABC: + @classmethod + def to_dict(cls) -> dict[str, Any]: + return { + k: v.to_dict() if isinstance(v, type(MainTargetABC)) else v + for k, v in cls.__dict__.items() + if not k.startswith("_") + } + + def __post_init__(self) -> None: + raise NotImplementedError("Do not instantiate this class directly.") + + +@dataclass(frozen=True) +class WarnIf(MainTargetABC): + functions_and_data_columns_overlap: str = ( + "warn_if__functions_and_data_columns_overlap" + ) + evaluation_date_set_in_multiple_places: str = ( + "warn_if__evaluation_date_set_in_multiple_places" + ) + + +@dataclass(frozen=True) +class FailIf(MainTargetABC): + active_periods_overlap: str = "fail_if__active_periods_overlap" + any_paths_are_invalid: str = "fail_if__any_paths_are_invalid" + backend_has_changed: str = "fail_if__backend_has_changed" + environment_is_invalid: str = "fail_if__environment_is_invalid" + foreign_keys_are_invalid_in_data: str = "fail_if__foreign_keys_are_invalid_in_data" + group_ids_are_outside_top_level_namespace: str = ( + "fail_if__group_ids_are_outside_top_level_namespace" + ) + group_variables_are_not_constant_within_groups: str = ( + "fail_if__group_variables_are_not_constant_within_groups" + ) + input_data_is_invalid: str = "fail_if__input_data_is_invalid" + input_data_tree_is_invalid: str = "fail_if__input_data_tree_is_invalid" + input_df_has_bool_or_numeric_column_names: str = ( + "fail_if__input_df_has_bool_or_numeric_column_names" + ) + input_df_mapper_columns_missing_in_df: str = ( + "fail_if__input_df_mapper_columns_missing_in_df" + ) + input_df_mapper_has_incorrect_format: str = ( + "fail_if__input_df_mapper_has_incorrect_format" + ) + non_convertible_objects_in_results_tree: str = ( + "fail_if__non_convertible_objects_in_results_tree" + ) + param_function_depends_on_column_objects: str = ( + "fail_if__param_function_depends_on_column_objects" + ) + paths_are_missing_in_targets_tree_mapper: str = ( + "fail_if__paths_are_missing_in_targets_tree_mapper" + ) + tt_root_nodes_are_missing: str = "fail_if__tt_root_nodes_are_missing" + targets_are_not_in_specialized_environment_or_data: str = ( + "fail_if__targets_are_not_in_specialized_environment_or_data" + ) + targets_tree_is_invalid: str = "fail_if__targets_tree_is_invalid" + + +@dataclass(frozen=True) +class Results(MainTargetABC): + df_with_mapper: str = "results__df_with_mapper" + df_with_nested_columns: str = "results__df_with_nested_columns" + tree: str = "results__tree" + + +@dataclass(frozen=True) +class RawResults(MainTargetABC): + columns: str = "raw_results__columns" + combined: str = "raw_results__combined" + from_input_data: str = "raw_results__from_input_data" + params: str = "raw_results__params" + + +@dataclass(frozen=True) +class SpecializedEnvironment(MainTargetABC): + without_tree_logic_and_with_derived_functions: str = ( + "specialized_environment__without_tree_logic_and_with_derived_functions" + ) + with_processed_params_and_scalars: str = ( + "specialized_environment__with_processed_params_and_scalars" + ) + with_partialled_params_and_scalars: str = ( + "specialized_environment__with_partialled_params_and_scalars" + ) + tax_transfer_dag: str = "specialized_environment__tax_transfer_dag" + tax_transfer_function: str = "specialized_environment__tax_transfer_function" + + +@dataclass(frozen=True) +class Targets(MainTargetABC): + qname: str = "tt_targets__qname" + tree: str = "tt_targets__tree" + + +@dataclass(frozen=True) +class Labels(MainTargetABC): + column_targets: str = "labels__column_targets" + grouping_levels: str = "labels__grouping_levels" + input_data_targets: str = "labels__input_data_targets" + param_targets: str = "labels__param_targets" + processed_data_columns: str = "labels__processed_data_columns" + input_columns: str = "labels__input_columns" + root_nodes: str = "labels__root_nodes" + top_level_namespace: str = "labels__top_level_namespace" + + +@dataclass(frozen=True) +class DfAndMapper(MainTargetABC): + df: str = "input_data__df_and_mapper__df" + mapper: str = "input_data__df_and_mapper__mapper" + + +@dataclass(frozen=True) +class InputData(MainTargetABC): + df_and_mapper: type[DfAndMapper] = field(default=DfAndMapper) + df_with_nested_columns: str = "input_data__df_with_nested_columns" + flat: str = "input_data__flat" + tree: str = "input_data__tree" + + +@dataclass(frozen=True) +class OrigPolicyObjects(MainTargetABC): + column_objects_and_param_functions: str = ( + "orig_policy_objects__column_objects_and_param_functions" + ) + param_specs: str = "orig_policy_objects__param_specs" + # Do not include root here, will be pre-defined in user-facing implementations. + + +@dataclass(frozen=True) +class Templates(MainTargetABC): + input_data_dtypes: str = "templates__input_data_dtypes" + + +@dataclass(frozen=True) +class MainTarget(MainTargetABC): + results: type[Results] = field(default=Results) + templates: type[Templates] = field(default=Templates) + policy_environment: str = "policy_environment" + specialized_environment: type[SpecializedEnvironment] = field( + default=SpecializedEnvironment + ) + orig_policy_objects: type[OrigPolicyObjects] = field(default=OrigPolicyObjects) + processed_data: str = "processed_data" + raw_results: type[RawResults] = field(default=RawResults) + labels: type[Labels] = field(default=Labels) + input_data: type[InputData] = field(default=InputData) + tt_targets: type[Targets] = field(default=Targets) + num_segments: str = "num_segments" + backend: str = "backend" + evaluation_date_str: str = "evaluation_date_str" + evaluation_date: str = "evaluation_date" + policy_date_str: str = "policy_date_str" + policy_date: str = "policy_date" + xnp: str = "xnp" + dnp: str = "dnp" + rounding: str = "rounding" + tt_function_set_annotations: str = "tt_function_set_annotations" + warn_if: type[WarnIf] = field(default=WarnIf) + fail_if: type[FailIf] = field(default=FailIf) diff --git a/src/ttsim/plot_dag.py b/src/ttsim/plot_dag.py index af4f2a0e8..e755385d5 100644 --- a/src/ttsim/plot_dag.py +++ b/src/ttsim/plot_dag.py @@ -14,7 +14,6 @@ import plotly.graph_objects as go from ttsim import main -from ttsim.interface_dag import load_flat_interface_functions_and_inputs from ttsim.interface_dag_elements.interface_node_objects import ( FailFunction, InputDependentInterfaceFunction, @@ -23,7 +22,8 @@ WarnFunction, interface_function, ) -from ttsim.tt_dag_elements import ( +from ttsim.main import load_flat_interface_functions_and_inputs +from ttsim.tt import ( ColumnFunction, ParamFunction, ParamObject, @@ -38,7 +38,7 @@ from pathlib import Path from types import ModuleType - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( PolicyEnvironment, SpecEnvWithoutTreeLogicAndWithDerivedFunctions, ) @@ -150,6 +150,7 @@ def plot_interface_dag( include_fail_and_warn_nodes: bool = True, show_node_description: bool = False, output_path: Path | None = None, + remove_orig_policy_objects__root: bool = True, ) -> go.Figure: """Plot the full interface DAG.""" interface_functions_and_inputs = load_flat_interface_functions_and_inputs() @@ -194,6 +195,8 @@ def plot_interface_dag( description=description or "No description available.", namespace=namespace, ) + if remove_orig_policy_objects__root: + dag.remove_nodes_from(["orig_policy_objects__root"]) fig = _plot_dag( dag=dag, diff --git a/src/ttsim/testing_utils.py b/src/ttsim/testing_utils.py index a71971220..788a2b006 100644 --- a/src/ttsim/testing_utils.py +++ b/src/ttsim/testing_utils.py @@ -17,7 +17,7 @@ from ttsim.interface_dag_elements.fail_if import format_list_linewise from ttsim.interface_dag_elements.shared import to_datetime from ttsim.plot_dag import dummy_callable -from ttsim.tt_dag_elements.column_objects_param_function import PolicyInput +from ttsim.tt.column_objects_param_function import PolicyInput # Set display options to show all columns without truncation pd.set_option("display.max_columns", None) @@ -28,7 +28,7 @@ from pathlib import Path from types import ModuleType - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, NestedData, diff --git a/src/ttsim/tt_dag_elements/__init__.py b/src/ttsim/tt/__init__.py similarity index 83% rename from src/ttsim/tt_dag_elements/__init__.py rename to src/ttsim/tt/__init__.py index e8a1a3b53..07daf9636 100644 --- a/src/ttsim/tt_dag_elements/__init__.py +++ b/src/ttsim/tt/__init__.py @@ -1,5 +1,5 @@ -from ttsim.tt_dag_elements.aggregation import AggType -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.aggregation import AggType +from ttsim.tt.column_objects_param_function import ( AggByGroupFunction, AggByPIDFunction, ColumnFunction, @@ -17,7 +17,7 @@ policy_function, policy_input, ) -from ttsim.tt_dag_elements.param_objects import ( +from ttsim.tt.param_objects import ( ConsecutiveIntLookupTableParam, ConsecutiveIntLookupTableParamValue, DictParam, @@ -30,12 +30,13 @@ get_month_based_phase_inout_of_age_thresholds_param_value, get_year_based_phase_inout_of_age_thresholds_param_value, ) -from ttsim.tt_dag_elements.piecewise_polynomial import ( +from ttsim.tt.piecewise_polynomial import ( get_piecewise_parameters, + get_piecewise_thresholds, piecewise_polynomial, ) -from ttsim.tt_dag_elements.rounding import RoundingSpec -from ttsim.tt_dag_elements.shared import join +from ttsim.tt.rounding import RoundingSpec +from ttsim.tt.shared import join __all__ = [ "AggByGroupFunction", @@ -63,6 +64,7 @@ "get_consecutive_int_lookup_table_param_value", "get_month_based_phase_inout_of_age_thresholds_param_value", "get_piecewise_parameters", + "get_piecewise_thresholds", "get_year_based_phase_inout_of_age_thresholds_param_value", "group_creation_function", "join", diff --git a/src/ttsim/tt_dag_elements/aggregation.py b/src/ttsim/tt/aggregation.py similarity index 98% rename from src/ttsim/tt_dag_elements/aggregation.py rename to src/ttsim/tt/aggregation.py index 20fb25f84..86f7798fb 100644 --- a/src/ttsim/tt_dag_elements/aggregation.py +++ b/src/ttsim/tt/aggregation.py @@ -3,10 +3,10 @@ from enum import StrEnum from typing import TYPE_CHECKING, Literal, overload -from ttsim.tt_dag_elements import aggregation_jax, aggregation_numpy +from ttsim.tt import aggregation_jax, aggregation_numpy if TYPE_CHECKING: - from ttsim.tt_dag_elements.typing import BoolColumn, FloatColumn, IntColumn + from ttsim.tt import BoolColumn, FloatColumn, IntColumn class AggType(StrEnum): diff --git a/src/ttsim/tt_dag_elements/aggregation_jax.py b/src/ttsim/tt/aggregation_jax.py similarity index 98% rename from src/ttsim/tt_dag_elements/aggregation_jax.py rename to src/ttsim/tt/aggregation_jax.py index ff0566fea..696434650 100644 --- a/src/ttsim/tt_dag_elements/aggregation_jax.py +++ b/src/ttsim/tt/aggregation_jax.py @@ -9,7 +9,7 @@ pass if TYPE_CHECKING: - from ttsim.tt_dag_elements.typing import BoolColumn, FloatColumn, IntColumn + from ttsim.tt import BoolColumn, FloatColumn, IntColumn def grouped_count(group_id: IntColumn, num_segments: int) -> jnp.ndarray: diff --git a/src/ttsim/tt_dag_elements/aggregation_numpy.py b/src/ttsim/tt/aggregation_numpy.py similarity index 99% rename from src/ttsim/tt_dag_elements/aggregation_numpy.py rename to src/ttsim/tt/aggregation_numpy.py index 31afc2ba1..c1cdce1a2 100644 --- a/src/ttsim/tt_dag_elements/aggregation_numpy.py +++ b/src/ttsim/tt/aggregation_numpy.py @@ -6,7 +6,7 @@ import numpy_groupies as npg if TYPE_CHECKING: - from ttsim.tt_dag_elements.typing import BoolColumn, FloatColumn, IntColumn + from ttsim.tt import BoolColumn, FloatColumn, IntColumn def grouped_count(group_id: IntColumn) -> IntColumn: diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt/column_objects_param_function.py similarity index 99% rename from src/ttsim/tt_dag_elements/column_objects_param_function.py rename to src/ttsim/tt/column_objects_param_function.py index c5321aa8c..95991d28f 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt/column_objects_param_function.py @@ -11,7 +11,7 @@ from dags import rename_arguments from ttsim.interface_dag_elements.shared import to_datetime -from ttsim.tt_dag_elements.aggregation import ( +from ttsim.tt.aggregation import ( AggType, all_by_p_id, any_by_p_id, @@ -28,14 +28,14 @@ min_by_p_id, sum_by_p_id, ) -from ttsim.tt_dag_elements.rounding import RoundingSpec -from ttsim.tt_dag_elements.vectorization import vectorize_function +from ttsim.tt.rounding import RoundingSpec +from ttsim.tt.vectorization import vectorize_function if TYPE_CHECKING: from collections.abc import Callable from types import ModuleType - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( DashedISOString, IntColumn, UnorderedQNames, diff --git a/src/ttsim/tt_dag_elements/param_objects.py b/src/ttsim/tt/param_objects.py similarity index 99% rename from src/ttsim/tt_dag_elements/param_objects.py rename to src/ttsim/tt/param_objects.py index 00f0e306e..b0ef9bbd8 100644 --- a/src/ttsim/tt_dag_elements/param_objects.py +++ b/src/ttsim/tt/param_objects.py @@ -12,7 +12,7 @@ from jaxtyping import Array, Bool, Float, Int - from ttsim.tt_dag_elements.typing import NestedLookupDict + from ttsim.typing import NestedLookupDict @dataclass(frozen=True) diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt/piecewise_polynomial.py similarity index 98% rename from src/ttsim/tt_dag_elements/piecewise_polynomial.py rename to src/ttsim/tt/piecewise_polynomial.py index 879d594c7..36e6948f4 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt/piecewise_polynomial.py @@ -5,7 +5,7 @@ import numpy -from ttsim.tt_dag_elements.param_objects import PiecewisePolynomialParamValue +from ttsim.tt.param_objects import PiecewisePolynomialParamValue if TYPE_CHECKING: from types import ModuleType @@ -122,7 +122,7 @@ def get_piecewise_parameters( ) # Extract lower thresholds. - lower_thresholds, upper_thresholds, thresholds = check_and_get_thresholds( + lower_thresholds, upper_thresholds, thresholds = get_piecewise_thresholds( leaf_name=leaf_name, parameter_dict=parameter_dict, xnp=xnp, @@ -151,7 +151,7 @@ def get_piecewise_parameters( ) -def check_and_get_thresholds( # noqa: C901 +def get_piecewise_thresholds( # noqa: C901 leaf_name: str, parameter_dict: dict[int, dict[str, float]], xnp: ModuleType, diff --git a/src/ttsim/tt_dag_elements/rounding.py b/src/ttsim/tt/rounding.py similarity index 97% rename from src/ttsim/tt_dag_elements/rounding.py rename to src/ttsim/tt/rounding.py index 1a9e6cd80..ecb7ce1c0 100644 --- a/src/ttsim/tt_dag_elements/rounding.py +++ b/src/ttsim/tt/rounding.py @@ -8,7 +8,7 @@ from collections.abc import Callable from types import ModuleType - from ttsim.interface_dag_elements.typing import FloatColumn + from ttsim.typing import FloatColumn ROUNDING_DIRECTION = Literal["up", "down", "nearest"] diff --git a/src/ttsim/tt_dag_elements/shared.py b/src/ttsim/tt/shared.py similarity index 96% rename from src/ttsim/tt_dag_elements/shared.py rename to src/ttsim/tt/shared.py index 833760041..b2d18255f 100644 --- a/src/ttsim/tt_dag_elements/shared.py +++ b/src/ttsim/tt/shared.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import BoolColumn, FloatColumn, IntColumn + from ttsim.typing import BoolColumn, FloatColumn, IntColumn @overload diff --git a/src/ttsim/tt_dag_elements/vectorization.py b/src/ttsim/tt/vectorization.py similarity index 100% rename from src/ttsim/tt_dag_elements/vectorization.py rename to src/ttsim/tt/vectorization.py diff --git a/src/ttsim/tt_dag_elements/typing.py b/src/ttsim/tt_dag_elements/typing.py deleted file mode 100644 index 66486f5ad..000000000 --- a/src/ttsim/tt_dag_elements/typing.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Literal, NewType, TypeAlias - -if TYPE_CHECKING: - import datetime - - NestedLookupDict: TypeAlias = dict[int, float | int | bool | "NestedLookupDict"] - OrigParamSpec = ( - # Header - dict[str, str | None | dict[Literal["de", "en"], str | None]] - | - # Parameters at one point in time - dict[ - datetime.date, - dict[Literal["note", "reference"] | str | int, Any], # noqa: PYI051 - ] - ) - DashedISOString = NewType("DashedISOString", str) - """A string representing a date in the format 'YYYY-MM-DD'.""" diff --git a/src/ttsim/interface_dag_elements/typing.py b/src/ttsim/typing.py similarity index 93% rename from src/ttsim/interface_dag_elements/typing.py rename to src/ttsim/typing.py index 1b07fdd7e..107b0d51f 100644 --- a/src/ttsim/interface_dag_elements/typing.py +++ b/src/ttsim/typing.py @@ -2,14 +2,13 @@ from typing import TYPE_CHECKING, Any, Literal, NewType, TypeAlias, TypeVar -from jaxtyping import Array, Bool, Float, Int - -BoolColumn: TypeAlias = Bool[Array, " n_obs"] -IntColumn: TypeAlias = Int[Array, " n_obs"] -FloatColumn: TypeAlias = Float[Array, " n_obs"] +if TYPE_CHECKING: + from jaxtyping import Array, Bool, Float, Int + BoolColumn: TypeAlias = Bool[Array, " n_obs"] + IntColumn: TypeAlias = Int[Array, " n_obs"] + FloatColumn: TypeAlias = Float[Array, " n_obs"] -if TYPE_CHECKING: # Make these available for import from other modules. import datetime from collections.abc import Iterable, Mapping @@ -45,7 +44,7 @@ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Possible leaves of the various trees. # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - from ttsim.tt_dag_elements import ( + from ttsim.tt import ( ColumnFunction, ColumnObject, ParamFunction, @@ -112,3 +111,5 @@ """Map qualified names to column objects and anything that comes out of processing the params.""" # noqa: E501 SpecEnvWithPartialledParamsAndScalars = Mapping[str, ColumnFunction] """Map qualified names to column functions that depend on columns only.""" + + NestedLookupDict: TypeAlias = dict[int, float | int | bool | "NestedLookupDict"] diff --git a/tests/interface_dag_elements/test_automatically_added_functions.py b/tests/interface_dag_elements/test_automatically_added_functions.py index 3952f3149..bac513254 100644 --- a/tests/interface_dag_elements/test_automatically_added_functions.py +++ b/tests/interface_dag_elements/test_automatically_added_functions.py @@ -29,7 +29,7 @@ y_to_q, y_to_w, ) -from ttsim.tt_dag_elements import policy_function +from ttsim.tt import policy_function def return_one() -> int: diff --git a/tests/interface_dag_elements/test_data_converters.py b/tests/interface_dag_elements/test_data_converters.py index 88f66a408..fee93fe97 100644 --- a/tests/interface_dag_elements/test_data_converters.py +++ b/tests/interface_dag_elements/test_data_converters.py @@ -15,7 +15,7 @@ df_with_nested_columns_to_flat_data, nested_data_to_df_with_mapped_columns, ) -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( ScalarParam, param_function, policy_function, diff --git a/tests/interface_dag_elements/test_failures.py b/tests/interface_dag_elements/test_failures.py index 3f4a6a56b..c52a019c0 100644 --- a/tests/interface_dag_elements/test_failures.py +++ b/tests/interface_dag_elements/test_failures.py @@ -35,7 +35,7 @@ paths_are_missing_in_targets_tree_mapper, targets_are_not_in_specialized_environment_or_data, ) -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( ConsecutiveIntLookupTableParam, ConsecutiveIntLookupTableParamValue, DictParam, @@ -52,7 +52,7 @@ from jaxtyping import Array, Float - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, OrigParamSpec, diff --git a/tests/interface_dag_elements/test_labels.py b/tests/interface_dag_elements/test_labels.py index 4dd629929..6e43b094b 100644 --- a/tests/interface_dag_elements/test_labels.py +++ b/tests/interface_dag_elements/test_labels.py @@ -3,7 +3,7 @@ import pytest from ttsim.interface_dag_elements.labels import grouping_levels, top_level_namespace -from ttsim.tt_dag_elements import policy_function, policy_input +from ttsim.tt import policy_function, policy_input def identity(x: int) -> int: diff --git a/tests/interface_dag_elements/test_orig_policy_objects.py b/tests/interface_dag_elements/test_orig_policy_objects.py index fcb4475fc..3d36f44b9 100644 --- a/tests/interface_dag_elements/test_orig_policy_objects.py +++ b/tests/interface_dag_elements/test_orig_policy_objects.py @@ -8,7 +8,7 @@ _find_files_recursively, load_module, ) -from ttsim.tt_dag_elements.param_objects import ( +from ttsim.tt.param_objects import ( ConsecutiveIntLookupTableParam, DictParam, PiecewisePolynomialParam, diff --git a/tests/interface_dag_elements/test_policy_environment.py b/tests/interface_dag_elements/test_policy_environment.py index 71241355c..da9b8dc71 100644 --- a/tests/interface_dag_elements/test_policy_environment.py +++ b/tests/interface_dag_elements/test_policy_environment.py @@ -19,12 +19,12 @@ _active_param_objects, _get_param_value, ) -from ttsim.tt_dag_elements import ScalarParam, policy_function +from ttsim.tt import ScalarParam, policy_function if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( NestedColumnObjectsParamFunctions, ) diff --git a/tests/interface_dag_elements/test_specialized_environment.py b/tests/interface_dag_elements/test_specialized_environment.py index 97740fe76..11e74115a 100644 --- a/tests/interface_dag_elements/test_specialized_environment.py +++ b/tests/interface_dag_elements/test_specialized_environment.py @@ -15,7 +15,7 @@ with_partialled_params_and_scalars, with_processed_params_and_scalars, ) -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( AggType, DictParam, PiecewisePolynomialParam, @@ -30,7 +30,7 @@ ) if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FloatColumn, IntColumn, PolicyEnvironment, diff --git a/tests/interface_dag_elements/test_templates.py b/tests/interface_dag_elements/test_templates.py index 6e1ae43cf..0658ada19 100644 --- a/tests/interface_dag_elements/test_templates.py +++ b/tests/interface_dag_elements/test_templates.py @@ -8,11 +8,11 @@ from ttsim.testing_utils import ( load_policy_test_data, ) -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.column_objects_param_function import ( policy_function, policy_input, ) -from ttsim.tt_dag_elements.param_objects import DictParam, ScalarParam +from ttsim.tt.param_objects import DictParam, ScalarParam METTSIM_ROOT = Path(__file__).parent.parent / "mettsim" diff --git a/tests/interface_dag_elements/test_warnings.py b/tests/interface_dag_elements/test_warnings.py index 0b930cad6..5481a6d45 100644 --- a/tests/interface_dag_elements/test_warnings.py +++ b/tests/interface_dag_elements/test_warnings.py @@ -7,9 +7,9 @@ import pytest from ttsim import main -from ttsim.interface_dag_elements import MainTarget -from ttsim.tt_dag_elements.column_objects_param_function import policy_function -from ttsim.tt_dag_elements.param_objects import ScalarParam +from ttsim.main_target import MainTarget +from ttsim.tt.column_objects_param_function import policy_function +from ttsim.tt.param_objects import ScalarParam @policy_function() diff --git a/tests/mettsim/demographics.py b/tests/mettsim/demographics.py index 3955bb9d5..1705c02a1 100644 --- a/tests/mettsim/demographics.py +++ b/tests/mettsim/demographics.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ttsim.tt_dag_elements import AggType, agg_by_group_function +from ttsim.tt import AggType, agg_by_group_function @agg_by_group_function(agg_type=AggType.COUNT) diff --git a/tests/mettsim/group_by_ids.py b/tests/mettsim/group_by_ids.py index 0fd158f49..e9b3d55bd 100644 --- a/tests/mettsim/group_by_ids.py +++ b/tests/mettsim/group_by_ids.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING -from ttsim.tt_dag_elements import group_creation_function +from ttsim.tt import group_creation_function if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import IntColumn + from ttsim.tt import IntColumn @group_creation_function() diff --git a/tests/mettsim/housing_benefits/amount.py b/tests/mettsim/housing_benefits/amount.py index 81a74c29e..43368b9a4 100644 --- a/tests/mettsim/housing_benefits/amount.py +++ b/tests/mettsim/housing_benefits/amount.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ttsim.tt_dag_elements import policy_function +from ttsim.tt import policy_function @policy_function(vectorization_strategy="vectorize") diff --git a/tests/mettsim/housing_benefits/eligibility/eligibility.py b/tests/mettsim/housing_benefits/eligibility/eligibility.py index 3c7919e5f..8abbd9dc7 100644 --- a/tests/mettsim/housing_benefits/eligibility/eligibility.py +++ b/tests/mettsim/housing_benefits/eligibility/eligibility.py @@ -11,7 +11,7 @@ from __future__ import annotations -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( AggType, agg_by_group_function, policy_function, diff --git a/tests/mettsim/housing_benefits/income/income.py b/tests/mettsim/housing_benefits/income/income.py index 6edc2dad1..e73397a47 100644 --- a/tests/mettsim/housing_benefits/income/income.py +++ b/tests/mettsim/housing_benefits/income/income.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ttsim.tt_dag_elements import RoundingSpec, policy_function +from ttsim.tt import RoundingSpec, policy_function @policy_function( diff --git a/tests/mettsim/inputs.py b/tests/mettsim/inputs.py index dd32cd939..058affe31 100644 --- a/tests/mettsim/inputs.py +++ b/tests/mettsim/inputs.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ttsim.tt_dag_elements import AggType, FKType, agg_by_group_function, policy_input +from ttsim.tt import AggType, FKType, agg_by_group_function, policy_input @policy_input() diff --git a/tests/mettsim/orc_hunting_bounty/orc_hunting_bounty.py b/tests/mettsim/orc_hunting_bounty/orc_hunting_bounty.py index 7039d2bd0..642b1381a 100644 --- a/tests/mettsim/orc_hunting_bounty/orc_hunting_bounty.py +++ b/tests/mettsim/orc_hunting_bounty/orc_hunting_bounty.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from ttsim.tt_dag_elements import param_function, policy_function, policy_input +from ttsim.tt import param_function, policy_function, policy_input if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import RawParam + from ttsim.typing import RawParam @dataclass(frozen=True) diff --git a/tests/mettsim/payroll_tax/amount.py b/tests/mettsim/payroll_tax/amount.py index f8b5fd11e..dd368f73d 100644 --- a/tests/mettsim/payroll_tax/amount.py +++ b/tests/mettsim/payroll_tax/amount.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from types import ModuleType -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( PiecewisePolynomialParamValue, piecewise_polynomial, policy_function, diff --git a/tests/mettsim/payroll_tax/child_tax_credit/child_tax_credit.py b/tests/mettsim/payroll_tax/child_tax_credit/child_tax_credit.py index 0d5e55997..14de99b22 100644 --- a/tests/mettsim/payroll_tax/child_tax_credit/child_tax_credit.py +++ b/tests/mettsim/payroll_tax/child_tax_credit/child_tax_credit.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( AggType, agg_by_p_id_function, join, @@ -12,7 +12,7 @@ if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import BoolColumn, IntColumn + from ttsim.typing import BoolColumn, IntColumn @agg_by_p_id_function(agg_type=AggType.SUM) diff --git a/tests/mettsim/payroll_tax/child_tax_credit/inputs.py b/tests/mettsim/payroll_tax/child_tax_credit/inputs.py index f93d891b4..f166f7d5f 100644 --- a/tests/mettsim/payroll_tax/child_tax_credit/inputs.py +++ b/tests/mettsim/payroll_tax/child_tax_credit/inputs.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ttsim.tt_dag_elements import FKType, policy_input +from ttsim.tt import FKType, policy_input @policy_input(foreign_key_type=FKType.MAY_POINT_TO_SELF) diff --git a/tests/mettsim/payroll_tax/income/amount.py b/tests/mettsim/payroll_tax/income/amount.py index 652a41c33..c51dec082 100644 --- a/tests/mettsim/payroll_tax/income/amount.py +++ b/tests/mettsim/payroll_tax/income/amount.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ttsim.tt_dag_elements import policy_function +from ttsim.tt import policy_function @policy_function(vectorization_strategy="vectorize") diff --git a/tests/mettsim/payroll_tax/income/deductions.py b/tests/mettsim/payroll_tax/income/deductions.py index 601c01c7b..6a4a0e388 100644 --- a/tests/mettsim/payroll_tax/income/deductions.py +++ b/tests/mettsim/payroll_tax/income/deductions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ttsim.tt_dag_elements import policy_function +from ttsim.tt import policy_function @policy_function(vectorization_strategy="vectorize") diff --git a/tests/mettsim/payroll_tax/income/inputs.py b/tests/mettsim/payroll_tax/income/inputs.py index 7ea02d6b2..403f1fe48 100644 --- a/tests/mettsim/payroll_tax/income/inputs.py +++ b/tests/mettsim/payroll_tax/income/inputs.py @@ -2,7 +2,7 @@ from __future__ import annotations -from ttsim.tt_dag_elements import policy_input +from ttsim.tt import policy_input @policy_input() diff --git a/tests/mettsim/property_tax/amount.py b/tests/mettsim/property_tax/amount.py index 44d5eaf03..734e0b518 100644 --- a/tests/mettsim/property_tax/amount.py +++ b/tests/mettsim/property_tax/amount.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from types import ModuleType -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( PiecewisePolynomialParamValue, piecewise_polynomial, policy_function, diff --git a/tests/mettsim_tests/test_mettsim.py b/tests/mettsim_tests/test_mettsim.py index 4c993fd04..99f261b2d 100644 --- a/tests/mettsim_tests/test_mettsim.py +++ b/tests/mettsim_tests/test_mettsim.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: import datetime - from ttsim.interface_dag_elements.typing import ( + from ttsim.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, ) diff --git a/tests/test_copy_environment.py b/tests/test_copy_environment.py index 48520c51f..540dc4451 100644 --- a/tests/test_copy_environment.py +++ b/tests/test_copy_environment.py @@ -8,11 +8,11 @@ import pytest from ttsim import copy_environment, main -from ttsim.interface_dag_elements import MainTarget -from ttsim.tt_dag_elements.param_objects import RawParam, ScalarParam +from ttsim.main_target import MainTarget +from ttsim.tt.param_objects import RawParam, ScalarParam if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import PolicyEnvironment + from ttsim.typing import PolicyEnvironment METTSIM_ROOT = Path(__file__).parent / "mettsim" diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index e426835e2..67d6ab69d 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -8,7 +8,7 @@ import pytest from ttsim import InputData, MainTarget, TTTargets, main -from ttsim.tt_dag_elements.column_objects_param_function import policy_function +from ttsim.tt.column_objects_param_function import policy_function if TYPE_CHECKING: from types import ModuleType diff --git a/tests/test_interface_dag.py b/tests/test_interface_dag.py index 3967cb88e..933c6cdea 100644 --- a/tests/test_interface_dag.py +++ b/tests/test_interface_dag.py @@ -8,7 +8,13 @@ import pytest from ttsim import InputData, OrigPolicyObjects, TTTargets -from ttsim.interface_dag import ( +from ttsim.interface_dag_elements.fail_if import format_list_linewise +from ttsim.interface_dag_elements.interface_node_objects import ( + fail_function, + input_dependent_interface_function, + interface_function, +) +from ttsim.main import ( _fail_if_input_structure_is_invalid, _fail_if_requested_nodes_cannot_be_found, _fail_if_root_nodes_of_interface_dag_are_missing, @@ -18,15 +24,9 @@ _resolve_dynamic_interface_objects_to_static_nodes, load_flat_interface_functions_and_inputs, ) -from ttsim.interface_dag_elements import MainTarget -from ttsim.interface_dag_elements.fail_if import format_list_linewise -from ttsim.interface_dag_elements.interface_node_objects import ( - fail_function, - input_dependent_interface_function, - interface_function, -) +from ttsim.main_target import MainTarget from ttsim.plot_dag import dummy_callable -from ttsim.tt_dag_elements.column_objects_param_function import policy_function +from ttsim.tt.column_objects_param_function import policy_function @interface_function(leaf_name="interface_function_a") diff --git a/tests/test_plot_dag.py b/tests/test_plot_dag.py index 11991f721..07ba47317 100644 --- a/tests/test_plot_dag.py +++ b/tests/test_plot_dag.py @@ -4,13 +4,13 @@ import pytest -from ttsim.interface_dag import main +from ttsim.main import main from ttsim.plot_dag import ( _get_tt_dag_with_node_metadata, _QNameNodeSelector, plot_interface_dag, ) -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( ScalarParam, param_function, policy_function, diff --git a/tests/tt_dag_elements/test_aggregation_functions.py b/tests/tt_dag_elements/test_aggregation_functions.py index 870185dcb..0bb1bcf21 100644 --- a/tests/tt_dag_elements/test_aggregation_functions.py +++ b/tests/tt_dag_elements/test_aggregation_functions.py @@ -13,7 +13,7 @@ my_datetime = lambda x: x # noqa: E731 -from ttsim.tt_dag_elements.aggregation import ( +from ttsim.tt.aggregation import ( grouped_all, grouped_any, grouped_count, diff --git a/tests/tt_dag_elements/test_piecewise_polynomial.py b/tests/tt_dag_elements/test_piecewise_polynomial.py index e2610c56e..ad1921ed7 100644 --- a/tests/tt_dag_elements/test_piecewise_polynomial.py +++ b/tests/tt_dag_elements/test_piecewise_polynomial.py @@ -13,7 +13,7 @@ from types import ModuleType -from ttsim.tt_dag_elements.piecewise_polynomial import ( +from ttsim.tt.piecewise_polynomial import ( PiecewisePolynomialParamValue, get_piecewise_parameters, piecewise_polynomial, diff --git a/tests/tt_dag_elements/test_rounding.py b/tests/tt_dag_elements/test_rounding.py index 42feb586c..f27369741 100644 --- a/tests/tt_dag_elements/test_rounding.py +++ b/tests/tt_dag_elements/test_rounding.py @@ -8,7 +8,7 @@ from pandas._testing import assert_series_equal from ttsim import main -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( RoundingSpec, policy_function, policy_input, diff --git a/tests/tt_dag_elements/test_shared.py b/tests/tt_dag_elements/test_shared.py index 245750bbd..ce0a72f78 100644 --- a/tests/tt_dag_elements/test_shared.py +++ b/tests/tt_dag_elements/test_shared.py @@ -5,12 +5,12 @@ import numpy import pytest -from ttsim.tt_dag_elements import join +from ttsim.tt import join if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import IntColumn + from ttsim.typing import IntColumn @pytest.mark.parametrize( diff --git a/tests/tt_dag_elements/test_ttsim_objects.py b/tests/tt_dag_elements/test_ttsim_objects.py index f1b2f2406..7bb3b6c09 100644 --- a/tests/tt_dag_elements/test_ttsim_objects.py +++ b/tests/tt_dag_elements/test_ttsim_objects.py @@ -4,7 +4,7 @@ import pytest -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( AggType, PolicyFunction, PolicyInput, @@ -13,7 +13,7 @@ policy_function, policy_input, ) -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.column_objects_param_function import ( ParamFunction, param_function, ) diff --git a/tests/tt_dag_elements/test_vectorization.py b/tests/tt_dag_elements/test_vectorization.py index aaca76c5a..78ffe6c45 100644 --- a/tests/tt_dag_elements/test_vectorization.py +++ b/tests/tt_dag_elements/test_vectorization.py @@ -18,16 +18,16 @@ from ttsim.interface_dag_elements.policy_environment import ( _active_column_objects_and_param_functions, ) -from ttsim.tt_dag_elements import ( +from ttsim.tt import ( GroupCreationFunction, PolicyInput, policy_function, ) -from ttsim.tt_dag_elements.column_objects_param_function import ( +from ttsim.tt.column_objects_param_function import ( AggByGroupFunction, AggByPIDFunction, ) -from ttsim.tt_dag_elements.vectorization import ( +from ttsim.tt.vectorization import ( TranslateToVectorizableError, _is_lambda_function, _make_vectorizable, @@ -38,7 +38,7 @@ if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import IntColumn + from ttsim.typing import IntColumn # ====================================================================================== # String comparison