Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/ttsim/combine_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,15 +499,20 @@ def agg_func(column, p_id_to_aggregate_by, p_id_to_store_by):
else None
)

source_function = functions[qual_name_source]
if qual_name_source in functions:
start_date = functions[qual_name_source].start_date
end_date = functions[qual_name_source].end_date
else:
start_date = DEFAULT_START_DATE
end_date = DEFAULT_END_DATE

return DerivedAggregationFunction(
leaf_name=dt.tree_path_from_qual_name(aggregation_target)[-1],
function=wrapped_func,
source=qual_name_source,
aggregation_method=aggregation_method,
start_date=source_function.start_date,
end_date=source_function.end_date,
start_date=start_date,
end_date=end_date,
)


Expand Down
204 changes: 204 additions & 0 deletions tests/ttsim/_policy_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
from __future__ import annotations

import datetime
from pathlib import Path
from typing import TYPE_CHECKING

import dags.tree as dt
import pandas as pd
import yaml

from ttsim import merge_trees

TEST_DIR = Path(__file__).parent / "test_data"

if TYPE_CHECKING:
from ttsim import NestedDataDict, NestedInputStructureDict


class PolicyTest:
    """Container for a single policy test case loaded from a YAML file."""

    def __init__(
        self,
        info: NestedDataDict,
        input_tree: NestedDataDict,
        expected_output_tree: NestedDataDict,
        path: Path,
        date: datetime.date,
    ) -> None:
        # Metadata from the YAML "info" section (e.g. numeric precision).
        self.info = info
        # Nested tree of input columns (pandas Series leaves).
        self.input_tree = input_tree
        # Nested tree of expected output columns.
        self.expected_output_tree = expected_output_tree
        # Path of the YAML file this case was loaded from.
        self.path = path
        # Policy date the test should be evaluated at.
        self.date = date

    @property
    def target_structure(self) -> NestedInputStructureDict:
        """Tree with the same paths as the expected outputs, all leaves None."""
        paths = dt.flatten_to_tree_paths(self.expected_output_tree)
        return dt.unflatten_from_tree_paths({p: None for p in paths})

    @property
    def test_name(self) -> str:
        """Test id: the YAML path relative to the test-data root, POSIX style."""
        # NOTE(review): TEST_DIR already ends in "test_data"; the extra
        # "test_data" segment mirrors the glob root in load_policy_test_data.
        # Confirm the on-disk layout really nests test_data/test_data.
        root = TEST_DIR / "test_data"
        return self.path.relative_to(root).as_posix()


def execute_test(test: PolicyTest):
    """Run a single PolicyTest against a freshly set-up policy environment.

    Computes the test's targets from its input tree and, when expected
    outputs are present, compares them column-wise within the test's
    configured precision.
    """
    from pandas.testing import assert_frame_equal

    from _gettsim_tests._helpers import cached_set_up_policy_environment
    from ttsim import compute_taxes_and_transfers

    environment = cached_set_up_policy_environment(date=test.date)

    computed = compute_taxes_and_transfers(
        data_tree=test.input_tree,
        environment=environment,
        targets_tree=test.target_structure,
    )

    actual = dt.flatten_to_qual_names(computed)
    expected = dt.flatten_to_qual_names(test.expected_output_tree)

    # Cases without expected outputs only check that computation succeeds.
    if not expected:
        return

    assert_frame_equal(
        pd.DataFrame(actual),
        pd.DataFrame(expected),
        atol=test.info["precision"],  # per-case numeric tolerance from YAML
        check_dtype=False,
    )


def get_policy_test_ids_and_cases() -> dict[str, PolicyTest]:
    """Map every discoverable policy test's id to its PolicyTest object."""
    tests = load_policy_test_data("")
    return {t.test_name: t for t in tests}


def load_policy_test_data(policy_name: str) -> list[PolicyTest]:
    """Load all PolicyTest cases found under the given policy subdirectory.

    Args:
        policy_name: Subdirectory of the test-data root to search; "" loads all.

    Returns:
        A flat list of PolicyTest objects built from every non-skipped YAML
        file found recursively under the subdirectory.
    """
    root = TEST_DIR / "test_data" / policy_name
    tests: list[PolicyTest] = []

    for yaml_path in root.glob("**/*.yaml"):
        # Files (or parent directories) containing "skip" in the name opt out.
        if _is_skipped(yaml_path):
            continue

        with yaml_path.open("r", encoding="utf-8") as file:
            raw: NestedDataDict = yaml.safe_load(file)

        tests += _get_policy_tests_from_raw_test_data(
            raw_test_data=raw,
            path_to_yaml=yaml_path,
        )

    return tests


def get_test_data_as_tree(
    test_data: NestedDataDict,
) -> tuple[NestedDataDict, NestedDataDict]:
    """Split raw test data into unflattened input and output trees.

    Args:
        test_data: Raw test data with "inputs" (containing optional
            "provided" / "assumed" sections of qualified-name keys) and
            "outputs".

    Returns:
        A 2-tuple ``(inputs, outputs)`` where ``inputs`` has the keys
        "provided" and "assumed" (each an unflattened tree, or ``{}`` when
        the section is absent/empty) and ``outputs`` is the unflattened
        outputs tree.

    Note:
        The original return annotation claimed a single ``NestedDataDict``;
        the function has always returned a tuple, so the annotation was
        corrected (no runtime behavior change).
    """
    inputs: NestedDataDict = {}
    for section in ("provided", "assumed"):
        flat = test_data["inputs"].get(section, {})
        # Keep empty sections as {} instead of unflattening an empty mapping.
        inputs[section] = dt.unflatten_from_qual_names(flat) if flat else {}

    outputs = dt.unflatten_from_qual_names(test_data["outputs"])
    return inputs, outputs


def _is_skipped(test_file: Path) -> bool:
return "skip" in test_file.stem or "skip" in test_file.parent.name


def _get_policy_tests_from_raw_test_data(
    raw_test_data: NestedDataDict, path_to_yaml: Path
) -> list[PolicyTest]:
    """Get a list of PolicyTest objects from raw test data.

    One PolicyTest is created per expected output column so that failures are
    reported per target; a file without expected outputs yields a single
    PolicyTest with an empty output tree.

    Args:
        raw_test_data: The raw test data.
        path_to_yaml: The path to the YAML file.

    Returns:
        A list of PolicyTest objects.
    """

    def _as_series_tree(tree: NestedDataDict) -> NestedDataDict:
        # Convert every leaf to a pandas Series, keeping the tree structure.
        flat = dt.flatten_to_tree_paths(tree)
        return dt.unflatten_from_tree_paths(
            {k: pd.Series(v) for k, v in flat.items()}
        )

    test_info: NestedDataDict = raw_test_data.get("info", {})
    inputs: NestedDataDict = raw_test_data.get("inputs", {})
    input_tree = _as_series_tree(
        merge_trees(inputs.get("provided", {}), inputs.get("assumed", {}))
    )
    expected_output_tree = _as_series_tree(raw_test_data.get("outputs", {}))

    # The policy date is encoded in the name of the directory holding the YAML.
    date: datetime.date = _parse_date(path_to_yaml.parent.name)

    def _make(expected: NestedDataDict) -> PolicyTest:
        return PolicyTest(
            info=test_info,
            input_tree=input_tree,
            expected_output_tree=expected,
            path=path_to_yaml,
            date=date,
        )

    if not expected_output_tree:
        return [_make({})]

    return [
        _make(dt.unflatten_from_tree_paths({target: series}))
        for target, series in dt.flatten_to_tree_paths(
            expected_output_tree
        ).items()
    ]


def _parse_date(date: str) -> datetime.date:
parts = date.split("-")

if len(parts) == 1:
return datetime.date(int(parts[0]), 1, 1)
if len(parts) == 2:
return datetime.date(int(parts[0]), int(parts[1]), 1)
if len(parts) == 3:
return datetime.date(int(parts[0]), int(parts[1]), int(parts[2]))
136 changes: 136 additions & 0 deletions tests/ttsim/test_combine_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
)
from ttsim.compute_taxes_and_transfers import compute_taxes_and_transfers
from ttsim.function_types import (
DEFAULT_END_DATE,
DEFAULT_START_DATE,
DerivedAggregationFunction,
group_by_function,
policy_function,
Expand Down Expand Up @@ -454,6 +456,140 @@ def test_create_aggregation_with_derived_soure_column():
assert "bar_bg" in inspect.signature(result["foo_hh"]).parameters


@pytest.mark.parametrize(
    (
        "aggregation_target",
        "aggregation_spec",
        "group_by_id",
        "functions",
        "top_level_namespace",
        "expected_start_date",
        "expected_end_date",
    ),
    [
        # Source function without explicit dates -> defaults.
        (
            "foo_hh",
            AggregateByGroupSpec(source="foo", aggr=AggregationType.SUM),
            "hh_id",
            {"foo": policy_function(leaf_name="foo")(lambda x: x)},
            ["foo", "foo_hh", "hh_id"],
            DEFAULT_START_DATE,
            DEFAULT_END_DATE,
        ),
        # No source function at all -> defaults.
        (
            "foo_hh",
            AggregateByGroupSpec(source="foo", aggr=AggregationType.SUM),
            "hh_id",
            {},
            ["foo", "foo_hh", "hh_id"],
            DEFAULT_START_DATE,
            DEFAULT_END_DATE,
        ),
        # Source function with explicit dates -> inherited by the aggregation.
        (
            "foo_hh",
            AggregateByGroupSpec(source="foo", aggr=AggregationType.SUM),
            "hh_id",
            {
                "foo": policy_function(
                    leaf_name="foo", start_date="2025-01-01", end_date="2025-12-31"
                )(lambda x: x)
            },
            ["foo", "foo_hh", "hh_id"],
            datetime.date.fromisoformat("2025-01-01"),
            datetime.date.fromisoformat("2025-12-31"),
        ),
    ],
)
def test_aggregate_by_group_function_start_and_end_date(
    aggregation_target,
    aggregation_spec,
    group_by_id,
    functions,
    top_level_namespace,
    expected_start_date,
    expected_end_date,
):
    """Group aggregations inherit the source function's date range.

    When no source function exists, the default start/end dates apply.
    """
    derived = _create_one_aggregate_by_group_func(
        aggregation_target=aggregation_target,
        aggregation_spec=aggregation_spec,
        group_by_id=group_by_id,
        functions=functions,
        top_level_namespace=top_level_namespace,
    )
    assert (derived.start_date, derived.end_date) == (
        expected_start_date,
        expected_end_date,
    )


@pytest.mark.parametrize(
    (
        "aggregation_target",
        "aggregation_spec",
        "functions",
        "top_level_namespace",
        "expected_start_date",
        "expected_end_date",
    ),
    [
        # Source function without explicit dates -> defaults.
        (
            "bar",
            AggregateByPIDSpec(
                source="foo",
                aggr=AggregationType.SUM,
                p_id_to_aggregate_by="foreign_id_col",
            ),
            {"foo": policy_function(leaf_name="foo")(lambda x: x)},
            ["foo", "bar", "foreign_id_col"],
            DEFAULT_START_DATE,
            DEFAULT_END_DATE,
        ),
        # No source function at all -> defaults.
        (
            "bar",
            AggregateByPIDSpec(
                source="foo",
                aggr=AggregationType.SUM,
                p_id_to_aggregate_by="foreign_id_col",
            ),
            {},
            ["foo", "bar", "foreign_id_col"],
            DEFAULT_START_DATE,
            DEFAULT_END_DATE,
        ),
        # Source function with explicit dates -> inherited by the aggregation.
        (
            "bar",
            AggregateByPIDSpec(
                source="foo",
                aggr=AggregationType.SUM,
                p_id_to_aggregate_by="foreign_id_col",
            ),
            {
                "foo": policy_function(
                    leaf_name="foo", start_date="2025-01-01", end_date="2025-12-31"
                )(lambda x: x)
            },
            ["foo", "bar", "foreign_id_col"],
            datetime.date.fromisoformat("2025-01-01"),
            datetime.date.fromisoformat("2025-12-31"),
        ),
    ],
)
def test_aggregate_by_p_id_function_start_and_end_date(
    aggregation_target,
    aggregation_spec,
    functions,
    top_level_namespace,
    expected_start_date,
    expected_end_date,
):
    """P-id aggregations inherit the source function's date range.

    When no source function exists, the default start/end dates apply.
    """
    derived = _create_one_aggregate_by_p_id_func(
        aggregation_target=aggregation_target,
        aggregation_spec=aggregation_spec,
        functions=functions,
        top_level_namespace=top_level_namespace,
    )
    assert (derived.start_date, derived.end_date) == (
        expected_start_date,
        expected_end_date,
    )


@pytest.mark.parametrize(
(
"aggregation_target",
Expand Down
17 changes: 17 additions & 0 deletions tests/ttsim/test_mettsim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
from _policy_test_utils import (
    PolicyTest,
    execute_test,
    get_policy_test_ids_and_cases,
)

# Discover every YAML-defined policy test once, at collection time.
policy_test_ids_and_cases = get_policy_test_ids_and_cases()


@pytest.mark.parametrize(
    "test",
    list(policy_test_ids_and_cases.values()),
    ids=list(policy_test_ids_and_cases),
)
def test_mettsim(test: PolicyTest):
    """Each discovered policy test case passes against its policy date."""
    execute_test(test)
Loading