Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
environment: py311, py312
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected].10
- uses: prefix-dev/[email protected].13
with:
pixi-version: v0.49.0
cache: true
Expand Down Expand Up @@ -62,6 +62,6 @@ jobs:
shell: bash -el {0}
- name: Upload coverage reports
if: runner.os == 'Linux' && matrix.environment == 'py313'
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
4 changes: 2 additions & 2 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ exclude = []

[tool.ruff.lint.per-file-ignores]
"conftest.py" = ["ANN"]
# Re-export things for tab-completion explicitly.
"src/ttsim/__init__.py" = ["PLW0127"]
# Long lines.
"src/ttsim/interface_dag_elements/specialized_environment.py" = ["E501"]
"src/ttsim/interface_dag_elements/fail_if.py" = ["E501"]
Expand Down
40 changes: 36 additions & 4 deletions src/ttsim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from __future__ import annotations

from ttsim._version import __version__, __version_tuple__, version, version_tuple
try:
# Import the version from _version.py which is dynamically created by
# setuptools-scm upon installing the project with pip.
# Do not put it under version control!
from ttsim._version import __version__, __version_tuple__, version, version_tuple
except ImportError:
__version__ = "unknown"
__version_tuple__ = ("unknown", "unknown", "unknown")
version = "unknown"
version_tuple = ("unknown", "unknown", "unknown")

from ttsim import tt
from ttsim.copy_environment import copy_environment
from ttsim.interface_dag import main
from ttsim.interface_dag_elements import MainTarget
from ttsim.interface_dag_elements.shared import merge_trees
from ttsim.interface_dag_elements.shared import merge_trees, upsert_tree
from ttsim.main import main
from ttsim.main_args import (
InputData,
Labels,
Expand All @@ -14,6 +24,24 @@
SpecializedEnvironment,
TTTargets,
)
from ttsim.main_target import MainTarget
from ttsim.plot_dag import plot_interface_dag, plot_tt_dag

copy_environment = copy_environment
merge_trees = merge_trees
upsert_tree = upsert_tree
main = main
plot_tt_dag = plot_tt_dag
plot_interface_dag = plot_interface_dag
MainTarget = MainTarget
InputData = InputData
Labels = Labels
OrigPolicyObjects = OrigPolicyObjects
RawResults = RawResults
Results = Results
SpecializedEnvironment = SpecializedEnvironment
TTTargets = TTTargets
tt = tt

__all__ = [
"InputData",
Expand All @@ -29,6 +57,10 @@
"copy_environment",
"main",
"merge_trees",
"plot_interface_dag",
"plot_tt_dag",
"tt",
"upsert_tree",
"version",
"version_tuple",
]
2 changes: 1 addition & 1 deletion src/ttsim/copy_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import optree

if TYPE_CHECKING:
from ttsim.interface_dag_elements.typing import (
from ttsim.typing import (
PolicyEnvironment,
SpecEnvWithoutTreeLogicAndWithDerivedFunctions,
SpecEnvWithPartialledParamsAndScalars,
Expand Down
172 changes: 0 additions & 172 deletions src/ttsim/interface_dag_elements/__init__.py
Original file line number Diff line number Diff line change
@@ -1,172 +0,0 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass(frozen=True)
class MainTargetABC:
@classmethod
def to_dict(cls) -> dict[str, Any]:
return {
k: v.to_dict() if isinstance(v, type(MainTargetABC)) else v
for k, v in cls.__dict__.items()
if not k.startswith("_")
}

def __post_init__(self) -> None:
raise NotImplementedError("Do not instantiate this class directly.")


@dataclass(frozen=True)
class WarnIf(MainTargetABC):
functions_and_data_columns_overlap: str = (
"warn_if__functions_and_data_columns_overlap"
)
evaluation_date_set_in_multiple_places: str = (
"warn_if__evaluation_date_set_in_multiple_places"
)


@dataclass(frozen=True)
class FailIf(MainTargetABC):
active_periods_overlap: str = "fail_if__active_periods_overlap"
any_paths_are_invalid: str = "fail_if__any_paths_are_invalid"
backend_has_changed: str = "fail_if__backend_has_changed"
environment_is_invalid: str = "fail_if__environment_is_invalid"
foreign_keys_are_invalid_in_data: str = "fail_if__foreign_keys_are_invalid_in_data"
group_ids_are_outside_top_level_namespace: str = (
"fail_if__group_ids_are_outside_top_level_namespace"
)
group_variables_are_not_constant_within_groups: str = (
"fail_if__group_variables_are_not_constant_within_groups"
)
input_data_is_invalid: str = "fail_if__input_data_is_invalid"
input_data_tree_is_invalid: str = "fail_if__input_data_tree_is_invalid"
input_df_has_bool_or_numeric_column_names: str = (
"fail_if__input_df_has_bool_or_numeric_column_names"
)
input_df_mapper_columns_missing_in_df: str = (
"fail_if__input_df_mapper_columns_missing_in_df"
)
input_df_mapper_has_incorrect_format: str = (
"fail_if__input_df_mapper_has_incorrect_format"
)
non_convertible_objects_in_results_tree: str = (
"fail_if__non_convertible_objects_in_results_tree"
)
param_function_depends_on_column_objects: str = (
"fail_if__param_function_depends_on_column_objects"
)
paths_are_missing_in_targets_tree_mapper: str = (
"fail_if__paths_are_missing_in_targets_tree_mapper"
)
tt_root_nodes_are_missing: str = "fail_if__tt_root_nodes_are_missing"
targets_are_not_in_specialized_environment_or_data: str = (
"fail_if__targets_are_not_in_specialized_environment_or_data"
)
targets_tree_is_invalid: str = "fail_if__targets_tree_is_invalid"


@dataclass(frozen=True)
class Results(MainTargetABC):
df_with_mapper: str = "results__df_with_mapper"
df_with_nested_columns: str = "results__df_with_nested_columns"
tree: str = "results__tree"


@dataclass(frozen=True)
class RawResults(MainTargetABC):
columns: str = "raw_results__columns"
combined: str = "raw_results__combined"
from_input_data: str = "raw_results__from_input_data"
params: str = "raw_results__params"


@dataclass(frozen=True)
class SpecializedEnvironment(MainTargetABC):
without_tree_logic_and_with_derived_functions: str = (
"specialized_environment__without_tree_logic_and_with_derived_functions"
)
with_processed_params_and_scalars: str = (
"specialized_environment__with_processed_params_and_scalars"
)
with_partialled_params_and_scalars: str = (
"specialized_environment__with_partialled_params_and_scalars"
)
tax_transfer_dag: str = "specialized_environment__tax_transfer_dag"
tax_transfer_function: str = "specialized_environment__tax_transfer_function"


@dataclass(frozen=True)
class Targets(MainTargetABC):
qname: str = "tt_targets__qname"
tree: str = "tt_targets__tree"


@dataclass(frozen=True)
class Labels(MainTargetABC):
column_targets: str = "labels__column_targets"
grouping_levels: str = "labels__grouping_levels"
input_data_targets: str = "labels__input_data_targets"
param_targets: str = "labels__param_targets"
processed_data_columns: str = "labels__processed_data_columns"
input_columns: str = "labels__input_columns"
root_nodes: str = "labels__root_nodes"
top_level_namespace: str = "labels__top_level_namespace"


@dataclass(frozen=True)
class DfAndMapper(MainTargetABC):
df: str = "input_data__df_and_mapper__df"
mapper: str = "input_data__df_and_mapper__mapper"


@dataclass(frozen=True)
class InputData(MainTargetABC):
df_and_mapper: type[DfAndMapper] = field(default=DfAndMapper)
df_with_nested_columns: str = "input_data__df_with_nested_columns"
flat: str = "input_data__flat"
tree: str = "input_data__tree"


@dataclass(frozen=True)
class OrigPolicyObjects(MainTargetABC):
column_objects_and_param_functions: str = (
"orig_policy_objects__column_objects_and_param_functions"
)
param_specs: str = "orig_policy_objects__param_specs"
# Do not include root here, will be pre-defined in user-facing implementations.


@dataclass(frozen=True)
class Templates(MainTargetABC):
input_data_dtypes: str = "templates__input_data_dtypes"


@dataclass(frozen=True)
class MainTarget(MainTargetABC):
results: type[Results] = field(default=Results)
templates: type[Templates] = field(default=Templates)
policy_environment: str = "policy_environment"
specialized_environment: type[SpecializedEnvironment] = field(
default=SpecializedEnvironment
)
orig_policy_objects: type[OrigPolicyObjects] = field(default=OrigPolicyObjects)
processed_data: str = "processed_data"
raw_results: type[RawResults] = field(default=RawResults)
labels: type[Labels] = field(default=Labels)
input_data: type[InputData] = field(default=InputData)
tt_targets: type[Targets] = field(default=Targets)
num_segments: str = "num_segments"
backend: str = "backend"
evaluation_date_str: str = "evaluation_date_str"
evaluation_date: str = "evaluation_date"
policy_date_str: str = "policy_date_str"
policy_date: str = "policy_date"
xnp: str = "xnp"
dnp: str = "dnp"
rounding: str = "rounding"
tt_function_set_annotations: str = "tt_function_set_annotations"
warn_if: type[WarnIf] = field(default=WarnIf)
fail_if: type[FailIf] = field(default=FailIf)
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
get_re_pattern_for_specific_time_units_and_groupings,
group_pattern,
)
from ttsim.tt_dag_elements.aggregation import grouped_sum
from ttsim.tt_dag_elements.column_objects_param_function import (
from ttsim.tt.aggregation import grouped_sum
from ttsim.tt.column_objects_param_function import (
DEFAULT_END_DATE,
DEFAULT_START_DATE,
AggByGroupFunction,
Expand All @@ -22,13 +22,13 @@
ParamFunction,
TimeConversionFunction,
)
from ttsim.tt_dag_elements.param_objects import ScalarParam
from ttsim.tt.param_objects import ScalarParam

if TYPE_CHECKING:
import re
from collections.abc import Callable

from ttsim.interface_dag_elements.typing import (
from ttsim.typing import (
BoolColumn,
FloatColumn,
IntColumn,
Expand Down Expand Up @@ -584,7 +584,7 @@ def func(x: BoolColumn) -> FloatColumn: ...
def func(x: FloatColumn | IntColumn | BoolColumn) -> FloatColumn:
return converter(x)

return func # type: ignore[has-type]
return func


def create_agg_by_group_functions(
Expand Down
2 changes: 1 addition & 1 deletion src/ttsim/interface_dag_elements/data_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if TYPE_CHECKING:
from types import ModuleType

from ttsim.interface_dag_elements.typing import (
from ttsim.typing import (
FlatData,
NestedData,
NestedInputsMapper,
Expand Down
2 changes: 1 addition & 1 deletion src/ttsim/interface_dag_elements/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
if TYPE_CHECKING:
import datetime

from ttsim.interface_dag_elements.typing import DashedISOString
from ttsim.typing import DashedISOString


@interface_input(in_top_level_namespace=True)
Expand Down
6 changes: 3 additions & 3 deletions src/ttsim/interface_dag_elements/fail_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@

from ttsim.interface_dag_elements.interface_node_objects import fail_function
from ttsim.interface_dag_elements.shared import get_name_of_group_by_id
from ttsim.tt_dag_elements.column_objects_param_function import (
from ttsim.tt.column_objects_param_function import (
DEFAULT_END_DATE,
ColumnFunction,
ColumnObject,
FKType,
ParamFunction,
PolicyInput,
)
from ttsim.tt_dag_elements.param_objects import (
from ttsim.tt.param_objects import (
PLACEHOLDER_FIELD,
PLACEHOLDER_VALUE,
ParamObject,
Expand All @@ -41,7 +41,7 @@
from collections.abc import Callable

from ttsim.interface_dag_elements.input_data import FlatData
from ttsim.interface_dag_elements.typing import (
from ttsim.typing import (
FlatColumnObjectsParamFunctions,
FlatOrigParamSpecs,
NestedData,
Expand Down
2 changes: 1 addition & 1 deletion src/ttsim/interface_dag_elements/input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pandas as pd

from ttsim.interface_dag_elements.typing import (
from ttsim.typing import (
FlatData,
NestedData,
NestedInputsMapper,
Expand Down
2 changes: 1 addition & 1 deletion src/ttsim/interface_dag_elements/interface_node_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from collections.abc import Callable, Iterable

from ttsim.interface_dag_elements.typing import UnorderedQNames
from ttsim.typing import UnorderedQNames


FunArgTypes = ParamSpec("FunArgTypes")
Expand Down
4 changes: 2 additions & 2 deletions src/ttsim/interface_dag_elements/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
get_re_pattern_for_all_time_units_and_groupings,
group_pattern,
)
from ttsim.tt_dag_elements.column_objects_param_function import PolicyInput
from ttsim.tt.column_objects_param_function import PolicyInput

if TYPE_CHECKING:
from ttsim.interface_dag_elements.typing import (
from ttsim.typing import (
OrderedQNames,
PolicyEnvironment,
QNameData,
Expand Down
Loading
Loading