Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/_gettsim_tests/test_full_taxes_and_transfers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
load_policy_test_data,
)
from ttsim import compute_taxes_and_transfers
from ttsim.function_types import PolicyInput
from ttsim.typing import check_series_has_expected_type
from ttsim.function_types import PolicyInput, check_series_has_expected_type

test_data = load_policy_test_data("full_taxes_and_transfers")

Expand Down
119 changes: 1 addition & 118 deletions src/ttsim/compute_taxes_and_transfers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import inspect
import warnings
from typing import TYPE_CHECKING, Any, get_args
from typing import TYPE_CHECKING, Any

import dags
import dags.tree as dt
Expand All @@ -15,9 +15,7 @@
)
from ttsim.config import numpy_or_jax as np
from ttsim.function_types import (
DerivedAggregationFunction,
GroupByFunction,
PolicyFunction,
PolicyInput,
TTSIMFunction,
)
Expand All @@ -34,10 +32,6 @@
partition_by_reference_dict,
)
from ttsim.time_conversion import TIME_UNITS
from ttsim.typing import (
check_series_has_expected_type,
convert_series_to_internal_type,
)

if TYPE_CHECKING:
from ttsim.typing import (
Expand Down Expand Up @@ -128,10 +122,6 @@ def compute_taxes_and_transfers(
)

_warn_if_functions_overridden_by_data(functions_overridden)
# data_with_correct_types = _convert_data_to_correct_types(
# data=data,
# functions_overridden=functions_overridden,
# )

functions_with_rounding_specs = (
_add_rounding_to_functions(functions=functions_to_be_used)
Expand Down Expand Up @@ -224,113 +214,6 @@ def _get_top_level_namespace(
return all_top_level_names


def _convert_data_to_correct_types(
    data: QualNameDataDict, functions_overridden: QualNameTTSIMFunctionDict
) -> QualNameDataDict:
    """Convert all data columns to the type that is expected by GETTSIM.

    The expected type of a column is looked up (a) in ``TYPES_INPUT_VARIABLES``
    or, failing that, (b) in the return annotation of the overridden function of
    the same name. Columns with no discoverable expected type are passed through
    unchanged.

    Parameters
    ----------
    data
        Data provided by the user.
    functions_overridden
        Functions that are overridden by data.

    Returns
    -------
    Data with correct types.

    """
    # Both lists start with a header element, so ``len(...) > 1`` below means
    # "at least one error / conversion was recorded".
    collected_errors = ["The data types of the following columns are invalid:\n"]
    collected_conversions = [
        "The data types of the following input variables have been converted:"
    ]
    general_warning = (
        "Note that the automatic conversion of data types is unsafe and that"
        " its correctness cannot be guaranteed."
        " The best solution is to convert all columns to the expected data"
        " types yourself."
    )

    data_with_correct_types = {}

    for name, series in data.items():
        internal_type = None

        # Look for column in TYPES_INPUT_VARIABLES
        # NOTE(review): this flattening is loop-invariant and could be hoisted
        # out of the loop.
        types_qualified_input_variables = dt.flatten_to_qual_names(
            TYPES_INPUT_VARIABLES
        )
        if name in types_qualified_input_variables:
            internal_type = types_qualified_input_variables[name]
        # Look for column in functions_tree_overridden
        elif name in functions_overridden:
            func = functions_overridden[name]
            # Unwrap decorated functions before the isinstance checks.
            func_is_group_by_function = isinstance(
                getattr(func, "__wrapped__", func), GroupByFunction
            )
            func_is_policy_function = isinstance(
                getattr(func, "__wrapped__", func), PolicyFunction
            ) and not isinstance(
                getattr(func, "__wrapped__", func), DerivedAggregationFunction
            )
            skip_vectorization = (
                func.skip_vectorization if func_is_policy_function else True
            )
            return_annotation_is_array = (
                func_is_group_by_function or func_is_policy_function
            ) and skip_vectorization
            if return_annotation_is_array:
                # Assumes that things are annotated with numpy.ndarray([dtype]), might
                # require a change if using proper numpy.typing. Not changing for now
                # as we will likely switch to JAX completely.
                internal_type = get_args(func.__annotations__["return"])[0]
            elif "return" in func.__annotations__:
                internal_type = func.__annotations__["return"]
            else:
                # No return annotation: leave internal_type as None (no check).
                pass
        else:
            # Column is neither a known input variable nor an override: pass
            # it through unchanged.
            pass

        # Make conversion if necessary
        if internal_type and not check_series_has_expected_type(
            series=series, internal_type=internal_type
        ):
            try:
                converted_leaf = convert_series_to_internal_type(
                    series=series, internal_type=internal_type
                )
                data_with_correct_types[name] = converted_leaf
                collected_conversions.append(
                    f" - {name} from {series.dtype} to {internal_type.__name__}"
                )
            except ValueError as e:
                collected_errors.append(f"\n - {name}: {e}")
        else:
            data_with_correct_types[name] = series

    # If any error occurred, raise (errors take precedence over the
    # conversion warning below).
    if len(collected_errors) > 1:
        msg = """
Note that conversion from floating point to integers or Booleans inherently
suffers from approximation error. It might well be that your data seemingly
obey the restrictions when scrolling through them, but in fact they do not
(for example, because 1e-15 is displayed as 0.0). \n The best solution is to
convert all columns to the expected data types yourself.
"""
        collected_errors = "\n".join(collected_errors)
        raise ValueError(format_errors_and_warnings(collected_errors + msg))
    # Otherwise raise warning which lists all successful conversions
    elif len(collected_conversions) > 1:
        collected_conversions = format_list_linewise(collected_conversions)
        warnings.warn(
            collected_conversions + "\n" + "\n" + general_warning,
            stacklevel=2,
        )

    return data_with_correct_types


def _create_input_data_for_concatenated_function(
data: QualNameDataDict,
functions: QualNameTTSIMFunctionDict,
Expand Down
37 changes: 37 additions & 0 deletions src/ttsim/function_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,22 @@
from typing import TYPE_CHECKING, Literal, TypeVar

import numpy
from pandas.api.types import (
is_bool_dtype,
is_datetime64_any_dtype,
is_float_dtype,
is_integer_dtype,
)

from ttsim.rounding import RoundingSpec
from ttsim.shared import to_datetime, validate_date_range

if TYPE_CHECKING:
from collections.abc import Callable

import pandas as pd

from ttsim.config import numpy_or_jax as np
from ttsim.typing import DashedISOString

T = TypeVar("T")
Expand Down Expand Up @@ -423,3 +432,31 @@ def _convert_and_validate_dates(
validate_date_range(start_date, end_date)

return start_date, end_date


def check_series_has_expected_type(series: pd.Series, internal_type: type) -> bool:
    """Check whether *series* already has the expected internal type.

    Parameters
    ----------
    series
        One data column provided by the user.
    internal_type
        One of the internal gettsim types: ``float``, ``int``, ``bool``, or
        ``numpy.datetime64``.

    Returns
    -------
    bool
        True if the series' dtype matches ``internal_type``, False otherwise
        (including for any ``internal_type`` outside the four supported ones).

    """
    # Use short-circuiting `and`/`or` rather than bitwise `&`; the operands
    # are plain bools, so no dtype checker is called once a branch matches.
    return (
        (internal_type == float and is_float_dtype(series))
        or (internal_type == int and is_integer_dtype(series))
        or (internal_type == bool and is_bool_dtype(series))
        or (internal_type == numpy.datetime64 and is_datetime64_any_dtype(series))
    )
140 changes: 3 additions & 137 deletions src/ttsim/typing.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
from typing import TYPE_CHECKING, NewType

import numpy
import pandas as pd
from pandas.api.types import (
is_bool_dtype,
is_datetime64_any_dtype,
is_float_dtype,
is_integer_dtype,
is_object_dtype,
)

from ttsim.config import numpy_or_jax as np

if TYPE_CHECKING:
from collections.abc import Mapping

import pandas as pd

# Make these available for import from other modules.
from dags.tree.typing import ( # noqa: F401
GenericCallable,
Expand All @@ -24,6 +14,7 @@
)

from ttsim.aggregation import AggregateByGroupSpec, AggregateByPIDSpec
from ttsim.config import numpy_or_jax as np
from ttsim.function_types import PolicyInput, TTSIMFunction, TTSIMObject

NestedTTSIMObjectDict = Mapping[str, TTSIMObject | "NestedTTSIMObjectDict"]
Expand All @@ -47,128 +38,3 @@

DashedISOString = NewType("DashedISOString", str)
"""A string representing a date in the format 'YYYY-MM-DD'."""


def check_series_has_expected_type(series: pd.Series, internal_type: np.dtype) -> bool:
    """Report whether the series' dtype already matches the internal type.

    Parameters
    ----------
    series : pandas.Series or pandas.DataFrame or dict of pandas.Series
        Data provided by the user.
    internal_type : TypeVar
        One of the internal gettsim types.

    Returns
    -------
    Bool

    """
    # Dispatch on the requested internal type; anything outside the four
    # supported types is reported as "not the expected type".
    if internal_type == float:
        matches = is_float_dtype(series)
    elif internal_type == int:
        matches = is_integer_dtype(series)
    elif internal_type == bool:
        matches = is_bool_dtype(series)
    elif internal_type == numpy.datetime64:
        matches = is_datetime64_any_dtype(series)
    else:
        matches = False
    return bool(matches)


def convert_series_to_internal_type(
    series: pd.Series, internal_type: np.dtype
) -> pd.Series:
    """Check if data type of series fits to the internal type of gettsim and otherwise
    convert data type of series to the internal type of gettsim.

    Parameters
    ----------
    series : pd.Series
        Some data series.
    internal_type : TypeVar
        One of the internal gettsim types (``float``, ``int``, ``bool``,
        ``np.datetime64``).

    Returns
    -------
    out : adjusted pd.Series

    Raises
    ------
    ValueError
        If the series has object dtype, the requested conversion is not
        supported, or the values cannot be represented losslessly in the
        target type.

    """
    # Work on a copy so the caller's series is never mutated.
    out = series.copy()

    basic_error_msg = (
        f"Conversion from input type {out.dtype} to {internal_type.__name__} failed."
    )
    # Object columns are ambiguous (mixed content); refuse up front.
    if is_object_dtype(out):
        raise ValueError(basic_error_msg + " Object type is not supported as input.")

    # Conversion to float
    if internal_type == float:
        # Conversion from boolean to float is deliberately unsupported.
        if is_bool_dtype(out):
            raise ValueError(basic_error_msg + " This conversion is not supported.")
        try:
            out = out.astype(float)
        except ValueError as e:
            raise ValueError(basic_error_msg) from e

    # Conversion to int
    elif internal_type == int:
        if is_float_dtype(out):
            # Cast once (the original cast twice) and wrap the error so that
            # non-finite values (NaN, inf) raise the documented ValueError
            # instead of a raw pandas casting error.
            try:
                as_int = out.astype(np.int64)
            except ValueError as e:
                raise ValueError(basic_error_msg) from e
            # Only lossless conversions are allowed: every value must have
            # zero decimal places.
            if np.array_equal(out, as_int):
                out = as_int
            else:
                raise ValueError(
                    basic_error_msg + " This conversion is only supported if all"
                    " decimal places of input data are equal to 0."
                )
        else:
            try:
                out = out.astype(np.int64)
            except ValueError as e:
                raise ValueError(basic_error_msg) from e

    # Conversion to boolean
    elif internal_type == bool:
        # if input data type is integer
        if is_integer_dtype(out):
            # check if series consists only of 1 or 0 (set-subset test
            # instead of building a throwaway list)
            if set(out.unique()).issubset({1, 0}):
                out = out.astype(bool)
            else:
                raise ValueError(
                    basic_error_msg + " This conversion is only supported if"
                    " input data exclusively contains the values 1 and 0."
                )
        # if input data type is float
        elif is_float_dtype(out):
            # check if series consists only of 1.0 or 0.0
            if set(out.unique()).issubset({1, 0}):
                out = out.astype(bool)
            else:
                raise ValueError(
                    basic_error_msg + " This conversion is only supported if"
                    " input data exclusively contains the values 1.0 and 0.0."
                )
        else:
            raise ValueError(
                basic_error_msg + " Conversion to boolean is only supported for"
                " int and float columns."
            )

    # Conversion to DateTime
    elif internal_type == np.datetime64:
        if not is_datetime64_any_dtype(out):
            try:
                out = out.astype(np.datetime64)
            except ValueError as e:
                raise ValueError(basic_error_msg) from e
    else:
        raise ValueError(f"The internal type {internal_type} is not yet supported.")

    return out
Loading
Loading