From cb5008c85acf9afb06ebc5c0371fe0c6a5961055 Mon Sep 17 00:00:00 2001
From: Marvin Immesberger
Date: Fri, 11 Apr 2025 09:51:58 +0200
Subject: [PATCH 1/3] Add start_date, end_date checks for aggregate by p_id and aggregate by group.

---
 tests/ttsim/test_combine_functions.py | 136 ++++++++++++++++++++++++++
 1 file changed, 136 insertions(+)

diff --git a/tests/ttsim/test_combine_functions.py b/tests/ttsim/test_combine_functions.py
index e2a49ea33..2557b9e6d 100644
--- a/tests/ttsim/test_combine_functions.py
+++ b/tests/ttsim/test_combine_functions.py
@@ -16,6 +16,8 @@
 )
 from ttsim.compute_taxes_and_transfers import compute_taxes_and_transfers
 from ttsim.function_types import (
+    DEFAULT_END_DATE,
+    DEFAULT_START_DATE,
     DerivedAggregationFunction,
     group_by_function,
     policy_function,
@@ -454,6 +456,140 @@ def test_create_aggregation_with_derived_soure_column():
     assert "bar_bg" in inspect.signature(result["foo_hh"]).parameters
 
 
+@pytest.mark.parametrize(
+    (
+        "aggregation_target",
+        "aggregation_spec",
+        "group_by_id",
+        "functions",
+        "top_level_namespace",
+        "expected_start_date",
+        "expected_end_date",
+    ),
+    [
+        (
+            "foo_hh",
+            AggregateByGroupSpec(source="foo", aggr=AggregationType.SUM),
+            "hh_id",
+            {"foo": policy_function(leaf_name="foo")(lambda x: x)},
+            ["foo", "foo_hh", "hh_id"],
+            DEFAULT_START_DATE,
+            DEFAULT_END_DATE,
+        ),
+        (
+            "foo_hh",
+            AggregateByGroupSpec(source="foo", aggr=AggregationType.SUM),
+            "hh_id",
+            {},
+            ["foo", "foo_hh", "hh_id"],
+            DEFAULT_START_DATE,
+            DEFAULT_END_DATE,
+        ),
+        (
+            "foo_hh",
+            AggregateByGroupSpec(source="foo", aggr=AggregationType.SUM),
+            "hh_id",
+            {
+                "foo": policy_function(
+                    leaf_name="foo", start_date="2025-01-01", end_date="2025-12-31"
+                )(lambda x: x)
+            },
+            ["foo", "foo_hh", "hh_id"],
+            datetime.date.fromisoformat("2025-01-01"),
+            datetime.date.fromisoformat("2025-12-31"),
+        ),
+    ],
+)
+def test_aggregate_by_group_function_start_and_end_date(
+    aggregation_target,
+    aggregation_spec,
+    group_by_id,
+    functions,
+    top_level_namespace,
+    expected_start_date,
+    expected_end_date,
+):
+    result = _create_one_aggregate_by_group_func(
+        aggregation_target=aggregation_target,
+        aggregation_spec=aggregation_spec,
+        group_by_id=group_by_id,
+        functions=functions,
+        top_level_namespace=top_level_namespace,
+    )
+    assert result.start_date == expected_start_date
+    assert result.end_date == expected_end_date
+
+
+@pytest.mark.parametrize(
+    (
+        "aggregation_target",
+        "aggregation_spec",
+        "functions",
+        "top_level_namespace",
+        "expected_start_date",
+        "expected_end_date",
+    ),
+    [
+        (
+            "bar",
+            AggregateByPIDSpec(
+                source="foo",
+                aggr=AggregationType.SUM,
+                p_id_to_aggregate_by="foreign_id_col",
+            ),
+            {"foo": policy_function(leaf_name="foo")(lambda x: x)},
+            ["foo", "bar", "foreign_id_col"],
+            DEFAULT_START_DATE,
+            DEFAULT_END_DATE,
+        ),
+        (
+            "bar",
+            AggregateByPIDSpec(
+                source="foo",
+                aggr=AggregationType.SUM,
+                p_id_to_aggregate_by="foreign_id_col",
+            ),
+            {},
+            ["foo", "bar", "foreign_id_col"],
+            DEFAULT_START_DATE,
+            DEFAULT_END_DATE,
+        ),
+        (
+            "bar",
+            AggregateByPIDSpec(
+                source="foo",
+                aggr=AggregationType.SUM,
+                p_id_to_aggregate_by="foreign_id_col",
+            ),
+            {
+                "foo": policy_function(
+                    leaf_name="foo", start_date="2025-01-01", end_date="2025-12-31"
+                )(lambda x: x)
+            },
+            ["foo", "bar", "foreign_id_col"],
+            datetime.date.fromisoformat("2025-01-01"),
+            datetime.date.fromisoformat("2025-12-31"),
+        ),
+    ],
+)
+def test_aggregate_by_p_id_function_start_and_end_date(
+    aggregation_target,
+    aggregation_spec,
+    functions,
+    top_level_namespace,
+    expected_start_date,
+    expected_end_date,
+):
+    result = _create_one_aggregate_by_p_id_func(
+        aggregation_target=aggregation_target,
+        aggregation_spec=aggregation_spec,
+        functions=functions,
+        top_level_namespace=top_level_namespace,
+    )
+    assert result.start_date == expected_start_date
+    assert result.end_date == expected_end_date
+
+
 @pytest.mark.parametrize(
     (
         "aggregation_target",

From 908896059eb35b32718f075cef5677d6ed1d58fb Mon Sep 17 00:00:00 2001
From: Marvin Immesberger
Date: Fri, 11 Apr 2025 09:53:13 +0200
Subject: [PATCH 2/3] Use default start and end date for data cols in aggregate by pid.

---
 src/ttsim/combine_functions.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/ttsim/combine_functions.py b/src/ttsim/combine_functions.py
index cb194767d..a8a659955 100644
--- a/src/ttsim/combine_functions.py
+++ b/src/ttsim/combine_functions.py
@@ -499,15 +499,20 @@ def agg_func(column, p_id_to_aggregate_by, p_id_to_store_by):
         else None
     )
 
-    source_function = functions[qual_name_source]
+    if qual_name_source in functions:
+        start_date = functions[qual_name_source].start_date
+        end_date = functions[qual_name_source].end_date
+    else:
+        start_date = DEFAULT_START_DATE
+        end_date = DEFAULT_END_DATE
 
     return DerivedAggregationFunction(
         leaf_name=dt.tree_path_from_qual_name(aggregation_target)[-1],
         function=wrapped_func,
         source=qual_name_source,
        aggregation_method=aggregation_method,
-        start_date=source_function.start_date,
-        end_date=source_function.end_date,
+        start_date=start_date,
+        end_date=end_date,
     )

From 70d4e9a6c7239cbe09de5d84a1f7d63e1fe8b14b Mon Sep 17 00:00:00 2001
From: Marvin Immesberger
Date: Fri, 11 Apr 2025 10:15:40 +0200
Subject: [PATCH 3/3] Add infrastructure for METTSIM tests but don't add test data yet.
---
 tests/ttsim/_policy_test_utils.py | 204 ++++++++++++++++++++++++++++++
 tests/ttsim/test_mettsim.py       |  17 +++
 2 files changed, 221 insertions(+)
 create mode 100644 tests/ttsim/_policy_test_utils.py
 create mode 100644 tests/ttsim/test_mettsim.py

diff --git a/tests/ttsim/_policy_test_utils.py b/tests/ttsim/_policy_test_utils.py
new file mode 100644
index 000000000..0cd5749d5
--- /dev/null
+++ b/tests/ttsim/_policy_test_utils.py
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+import datetime
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import dags.tree as dt
+import pandas as pd
+import yaml
+
+from ttsim import merge_trees
+
+TEST_DIR = Path(__file__).parent / "test_data"
+
+if TYPE_CHECKING:
+    from ttsim import NestedDataDict, NestedInputStructureDict
+
+
+class PolicyTest:
+    """A class for a single policy test."""
+
+    def __init__(
+        self,
+        info: NestedDataDict,
+        input_tree: NestedDataDict,
+        expected_output_tree: NestedDataDict,
+        path: Path,
+        date: datetime.date,
+    ) -> None:
+        self.info = info
+        self.input_tree = input_tree
+        self.expected_output_tree = expected_output_tree
+        self.path = path
+        self.date = date
+
+    @property
+    def target_structure(self) -> NestedInputStructureDict:
+        flat_target_structure = dict.fromkeys(
+            dt.flatten_to_tree_paths(self.expected_output_tree)
+        )
+        return dt.unflatten_from_tree_paths(flat_target_structure)
+
+    @property
+    def test_name(self) -> str:
+        return self.path.relative_to(TEST_DIR / "test_data").as_posix()
+
+
+def execute_test(test: PolicyTest):
+    from pandas.testing import assert_frame_equal
+
+    from _gettsim_tests._helpers import cached_set_up_policy_environment
+    from ttsim import compute_taxes_and_transfers
+
+    environment = cached_set_up_policy_environment(date=test.date)
+
+    result = compute_taxes_and_transfers(
+        data_tree=test.input_tree,
+        environment=environment,
+        targets_tree=test.target_structure,
+    )
+
+    flat_result = dt.flatten_to_qual_names(result)
+    flat_expected_output_tree = dt.flatten_to_qual_names(test.expected_output_tree)
+
+    if flat_expected_output_tree:
+        result_dataframe = pd.DataFrame(flat_result)
+        expected_dataframe = pd.DataFrame(flat_expected_output_tree)
+        assert_frame_equal(
+            result_dataframe,
+            expected_dataframe,
+            atol=test.info["precision"],
+            check_dtype=False,
+        )
+
+
+def get_policy_test_ids_and_cases() -> dict[str, PolicyTest]:
+    all_policy_tests = load_policy_test_data("")
+    return {policy_test.test_name: policy_test for policy_test in all_policy_tests}
+
+
+def load_policy_test_data(policy_name: str) -> list[PolicyTest]:
+    out = []
+
+    for path_to_yaml in (TEST_DIR / "test_data" / policy_name).glob("**/*.yaml"):
+        if _is_skipped(path_to_yaml):
+            continue
+
+        with path_to_yaml.open("r", encoding="utf-8") as file:
+            raw_test_data: NestedDataDict = yaml.safe_load(file)
+
+        out.extend(
+            _get_policy_tests_from_raw_test_data(
+                raw_test_data=raw_test_data,
+                path_to_yaml=path_to_yaml,
+            )
+        )
+
+    return out
+
+
+def get_test_data_as_tree(test_data: NestedDataDict) -> NestedDataDict:
+    provided_inputs = test_data["inputs"].get("provided", {})
+    assumed_inputs = test_data["inputs"].get("assumed", {})
+
+    unflattened_dict = {}
+    unflattened_dict["inputs"] = {}
+    unflattened_dict["outputs"] = {}
+
+    if provided_inputs:
+        unflattened_dict["inputs"]["provided"] = dt.unflatten_from_qual_names(
+            provided_inputs
+        )
+    else:
+        unflattened_dict["inputs"]["provided"] = {}
+    if assumed_inputs:
+        unflattened_dict["inputs"]["assumed"] = dt.unflatten_from_qual_names(
+            assumed_inputs
+        )
+    else:
+        unflattened_dict["inputs"]["assumed"] = {}
+
+    unflattened_dict["outputs"] = dt.unflatten_from_qual_names(test_data["outputs"])
+
+    return unflattened_dict["inputs"], unflattened_dict["outputs"]
+
+
+def _is_skipped(test_file: Path) -> bool:
+    return "skip" in test_file.stem or "skip" in test_file.parent.name
+
+
+def _get_policy_tests_from_raw_test_data(
+    raw_test_data: NestedDataDict, path_to_yaml: Path
+) -> list[PolicyTest]:
+    """Get a list of PolicyTest objects from raw test data.
+
+    Args:
+        raw_test_data: The raw test data.
+        path_to_yaml: The path to the YAML file.
+
+    Returns:
+        A list of PolicyTest objects.
+    """
+    test_info: NestedDataDict = raw_test_data.get("info", {})
+    inputs: NestedDataDict = raw_test_data.get("inputs", {})
+    input_tree: NestedDataDict = dt.unflatten_from_tree_paths(
+        {
+            k: pd.Series(v)
+            for k, v in dt.flatten_to_tree_paths(
+                merge_trees(inputs.get("provided", {}), inputs.get("assumed", {}))
+            ).items()
+        }
+    )
+
+    expected_output_tree: NestedDataDict = dt.unflatten_from_tree_paths(
+        {
+            k: pd.Series(v)
+            for k, v in dt.flatten_to_tree_paths(
+                raw_test_data.get("outputs", {})
+            ).items()
+        }
+    )
+
+    date: datetime.date = _parse_date(path_to_yaml.parent.name)
+
+    out = []
+    if expected_output_tree == {}:
+        out.append(
+            PolicyTest(
+                info=test_info,
+                input_tree=input_tree,
+                expected_output_tree={},
+                path=path_to_yaml,
+                date=date,
+            )
+        )
+    else:
+        for target_name, output_data in dt.flatten_to_tree_paths(
+            expected_output_tree
+        ).items():
+            one_expected_output: NestedDataDict = dt.unflatten_from_tree_paths(
+                {target_name: output_data}
+            )
+            out.append(
+                PolicyTest(
+                    info=test_info,
+                    input_tree=input_tree,
+                    expected_output_tree=one_expected_output,
+                    path=path_to_yaml,
+                    date=date,
+                )
+            )
+
+    return out
+
+
+def _parse_date(date: str) -> datetime.date:
+    parts = date.split("-")
+
+    if len(parts) == 1:
+        return datetime.date(int(parts[0]), 1, 1)
+    if len(parts) == 2:
+        return datetime.date(int(parts[0]), int(parts[1]), 1)
+    if len(parts) == 3:
+        return datetime.date(int(parts[0]), int(parts[1]), int(parts[2]))
diff --git a/tests/ttsim/test_mettsim.py b/tests/ttsim/test_mettsim.py
new file mode 100644
index 000000000..e08ea18d6
--- /dev/null
+++ b/tests/ttsim/test_mettsim.py
@@ -0,0 +1,17 @@
+import pytest
+from _policy_test_utils import (
+    PolicyTest,
+    execute_test,
+    get_policy_test_ids_and_cases,
+)
+
+policy_test_ids_and_cases = get_policy_test_ids_and_cases()
+
+
+@pytest.mark.parametrize(
+    "test",
+    policy_test_ids_and_cases.values(),
+    ids=policy_test_ids_and_cases.keys(),
+)
+def test_mettsim(test: PolicyTest):
+    execute_test(test)