diff --git a/pixi.lock b/pixi.lock index 4fdf83c1a3..bdb100b6e0 100644 --- a/pixi.lock +++ b/pixi.lock @@ -3825,10 +3825,15 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda - pypi: git+https://github.com/opensourceeconomics/dags#d4e4b0b268be13472444cd9e291202513d0b1bcb - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/f4/58/cc0721a1030fcbab0984beea0bf3c4610ec103f738423cdfa9c4ceb40598/jax-0.5.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/2d/cb/11bb92324afb6ba678f388e10b78d6b02196bc8887eb5aa0d85ce398edf9/jaxlib-0.5.0-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/41/90/51523adbedc808e03271c7448fd71da1660cc02603eaaf10b9ab4f102146/kaleido-0.1.0.post1-py2.py3-none-win_amd64.whl + - pypi: https://files.pythonhosted.org/packages/38/bc/c4260e4a6c6bf684d0313308de1c860467275221d5e7daf69b3fcddfdd0b/ml_dtypes-0.5.1-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/93/d56fb9ba5569dc29d8263c72e46d21a2fd38741339ebf03f54cf7561828c/pdbp-1.6.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/f5/6f/e6e5aff77ea2a48dd96808bb51d7450875af154ee7cbe72188afb0b37929/scipy-1.15.2-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - 
pypi: https://files.pythonhosted.org/packages/31/c1/d73ff5900c6b462879039ac92f89424ad1eb544b1f6bd77f12f9c3013e20/types_networkx-3.4.2.20241227-py3-none-any.whl - pypi: . @@ -5339,6 +5344,26 @@ packages: purls: [] size: 21903 timestamp: 1694400856979 +- pypi: . + name: gettsim + version: 0.7.1.dev82+g6c6c0155.d20250222 + sha256: 468d7ee73fd31c448cd102912b902342be58fb2b7298aacca19f8f4eca97a50d + requires_dist: + - astor + - ipywidgets + - networkx + - numpy + - numpy-groupies + - openpyxl + - optree + - pandas + - plotly + - pygments + - pygraphviz + - pytest + - pyyaml + requires_python: '>=3.11' + editable: true - pypi: . name: gettsim version: 0.7.1.dev305+gb67bbac0.d20250215 @@ -6006,6 +6031,35 @@ packages: - pkg:pypi/isoduration?source=hash-mapping size: 17189 timestamp: 1638811664194 +- pypi: https://files.pythonhosted.org/packages/f4/58/cc0721a1030fcbab0984beea0bf3c4610ec103f738423cdfa9c4ceb40598/jax-0.5.0-py3-none-any.whl + name: jax + version: 0.5.0 + sha256: b3907aa87ae2c340b39cdbf80c07a74550369cafcaf7398fb60ba58d167345ab + requires_dist: + - jaxlib<=0.5.0,>=0.5.0 + - ml-dtypes>=0.4.0 + - numpy>=1.25 + - numpy>=1.26.0 ; python_full_version >= '3.12' + - opt-einsum + - scipy>=1.11.1 + - jaxlib==0.5.0 ; extra == 'minimum-jaxlib' + - jaxlib==0.4.38 ; extra == 'ci' + - jaxlib<=0.5.0,>=0.5.0 ; extra == 'tpu' + - libtpu-nightly==0.1.dev20241010+nightly.cleanup ; extra == 'tpu' + - libtpu==0.0.8 ; extra == 'tpu' + - requests ; extra == 'tpu' + - jaxlib==0.5.0 ; extra == 'cuda' + - jax-cuda12-plugin[with-cuda]<=0.5.0,>=0.5.0 ; extra == 'cuda' + - jaxlib==0.5.0 ; extra == 'cuda12' + - jax-cuda12-plugin[with-cuda]<=0.5.0,>=0.5.0 ; extra == 'cuda12' + - jaxlib==0.5.0 ; extra == 'cuda12-pip' + - jax-cuda12-plugin[with-cuda]<=0.5.0,>=0.5.0 ; extra == 'cuda12-pip' + - jaxlib==0.5.0 ; extra == 'cuda12-local' + - jax-cuda12-plugin==0.5.0 ; extra == 'cuda12-local' + - jaxlib==0.5.0 ; extra == 'rocm' + - jax-rocm60-plugin<=0.5.0,>=0.5.0 ; extra == 'rocm' + - kubernetes ; 
extra == 'k8s' + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/noarch/jax-0.4.34-pyhd8ed1ab_0.conda sha256: da3880afc35042b92e0ad214beee372e41162da374dfaa8b1164db1dcee671b2 md5: f0d3c57355acf3f06d93672e57c0c1e8 @@ -6025,6 +6079,15 @@ packages: - pkg:pypi/jax?source=hash-mapping size: 1421248 timestamp: 1729336456855 +- pypi: https://files.pythonhosted.org/packages/2d/cb/11bb92324afb6ba678f388e10b78d6b02196bc8887eb5aa0d85ce398edf9/jaxlib-0.5.0-cp312-cp312-win_amd64.whl + name: jaxlib + version: 0.5.0 + sha256: 5baedbeeb60fa493c7528783254f04c6e986a2826266b198ed37e9336af2ef8c + requires_dist: + - scipy>=1.11.1 + - numpy>=1.25 + - ml-dtypes>=0.2.0 + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/linux-64/jaxlib-0.4.34-cpu_py312haec0345_0.conda sha256: 8e2dce1d39ccb4e6883444f71b7155168c0612d9d086b58616578448aeb33afe md5: 08f587f0f6505671c7715163e18b6f3a @@ -8268,6 +8331,22 @@ packages: purls: [] size: 103019089 timestamp: 1727378392081 +- pypi: https://files.pythonhosted.org/packages/38/bc/c4260e4a6c6bf684d0313308de1c860467275221d5e7daf69b3fcddfdd0b/ml_dtypes-0.5.1-cp312-cp312-win_amd64.whl + name: ml-dtypes + version: 0.5.1 + sha256: 9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 + requires_dist: + - numpy>=1.21 + - numpy>=1.21.2 ; python_full_version >= '3.10' + - numpy>=1.23.3 ; python_full_version >= '3.11' + - numpy>=1.26.0 ; python_full_version >= '3.12' + - numpy>=2.1.0 ; python_full_version >= '3.13' + - absl-py ; extra == 'dev' + - pytest ; extra == 'dev' + - pytest-xdist ; extra == 'dev' + - pylint>=2.6.0 ; extra == 'dev' + - pyink ; extra == 'dev' + requires_python: '>=3.9' - conda: https://conda.anaconda.org/conda-forge/linux-64/ml_dtypes-0.5.0-py312hf9745cd_0.conda sha256: 559c14640ce8e3f2270da6130ba50ae624f3db56176fad29a5436b2dec3fc3b2 md5: 8ca779f3f30b00181aeee820fe8b22d5 @@ -8831,6 +8910,11 @@ packages: purls: [] size: 8491156 timestamp: 1731379715927 +- pypi: 
https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl + name: opt-einsum + version: 3.4.0 + sha256: 69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd + requires_python: '>=3.8' - conda: https://conda.anaconda.org/conda-forge/noarch/opt-einsum-3.4.0-hd8ed1ab_0.conda sha256: 583cb8748a9821e301a404806da0de62e8ba01607feecf12c0ef06d8bc77077e md5: 73d0b1d98a9030bdefe712648af583a0 @@ -9311,6 +9395,15 @@ packages: - tabcompleter>=1.4.0 - colorama>=0.4.6 ; platform_system == 'Windows' requires_python: '>=3.8' +- pypi: https://files.pythonhosted.org/packages/29/93/d56fb9ba5569dc29d8263c72e46d21a2fd38741339ebf03f54cf7561828c/pdbp-1.6.1-py3-none-any.whl + name: pdbp + version: 1.6.1 + sha256: f10bad2ee044c0e5c168cb0825abfdbdc01c50013e9755df5261b060bdd35c22 + requires_dist: + - pygments>=2.18.0 + - tabcompleter>=1.4.0 + - colorama>=0.4.6 ; sys_platform == 'win32' + requires_python: '>=3.8' - conda: https://conda.anaconda.org/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_0.conda sha256: 90a09d134a4a43911b716d4d6eb9d169238aff2349056f7323d9db613812667e md5: 629f3203c99b32e0988910c93e77f3b6 @@ -10858,6 +10951,49 @@ packages: - pkg:pypi/rpds-py?source=hash-mapping size: 210974 timestamp: 1730923229667 +- pypi: https://files.pythonhosted.org/packages/f5/6f/e6e5aff77ea2a48dd96808bb51d7450875af154ee7cbe72188afb0b37929/scipy-1.15.2-cp312-cp312-win_amd64.whl + name: scipy + version: 1.15.2 + sha256: e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded + requires_dist: + - numpy>=1.23.5,<2.5 + - pytest ; extra == 'test' + - pytest-cov ; extra == 'test' + - pytest-timeout ; extra == 'test' + - pytest-xdist ; extra == 'test' + - asv ; extra == 'test' + - mpmath ; extra == 'test' + - gmpy2 ; extra == 'test' + - threadpoolctl ; extra == 'test' + - scikit-umfpack ; extra == 'test' + - pooch ; extra == 'test' + - hypothesis>=6.30 ; extra == 'test' + - array-api-strict>=2.0,<2.1.1 ; 
extra == 'test' + - cython ; extra == 'test' + - meson ; extra == 'test' + - ninja ; sys_platform != 'emscripten' and extra == 'test' + - sphinx>=5.0.0,<8.0.0 ; extra == 'doc' + - intersphinx-registry ; extra == 'doc' + - pydata-sphinx-theme>=0.15.2 ; extra == 'doc' + - sphinx-copybutton ; extra == 'doc' + - sphinx-design>=0.4.0 ; extra == 'doc' + - matplotlib>=3.5 ; extra == 'doc' + - numpydoc ; extra == 'doc' + - jupytext ; extra == 'doc' + - myst-nb ; extra == 'doc' + - pooch ; extra == 'doc' + - jupyterlite-sphinx>=0.16.5 ; extra == 'doc' + - jupyterlite-pyodide-kernel ; extra == 'doc' + - mypy==1.10.0 ; extra == 'dev' + - typing-extensions ; extra == 'dev' + - types-psutil ; extra == 'dev' + - pycodestyle ; extra == 'dev' + - ruff>=0.0.292 ; extra == 'dev' + - cython-lint>=0.12.2 ; extra == 'dev' + - rich-click ; extra == 'dev' + - doit>=0.36.0 ; extra == 'dev' + - pydevtool ; extra == 'dev' + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/linux-64/scipy-1.14.1-py312h62794b6_1.conda sha256: d069a64edade554261672d8febf4756aeb56a6cb44bd91844eaa944e5d9f4eb9 md5: b43233a9e2f62fb94affe5607ea79473 @@ -11139,6 +11275,13 @@ packages: requires_dist: - pyreadline3 ; platform_system == 'Windows' requires_python: '>=3.8' +- pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + name: tabcompleter + version: 1.4.0 + sha256: d744aa735b49c0a6cc2fb8fcd40077fec47425e4388301010b14e6ce3311368b + requires_dist: + - pyreadline3 ; sys_platform == 'win32' + requires_python: '>=3.8' - conda: https://conda.anaconda.org/conda-forge/noarch/tabulate-0.9.0-pyhd8ed1ab_1.tar.bz2 sha256: f6e4a0dd24ba060a4af69ca79d32361a6678e61d78c73eb5e357909b025b4620 md5: 4759805cce2d914c38472f70bf4d8bcb diff --git a/src/_gettsim/combine_functions_in_tree.py b/src/_gettsim/combine_functions_in_tree.py index c3df12e330..ab912c457d 100644 --- a/src/_gettsim/combine_functions_in_tree.py 
+++ b/src/_gettsim/combine_functions_in_tree.py @@ -34,10 +34,10 @@ format_errors_and_warnings, format_list_linewise, get_names_of_arguments_without_defaults, + insert_path_and_value, partition_tree_by_reference_tree, remove_group_suffix, rename_arguments_and_add_annotations, - upsert_path_and_value, upsert_tree, ) from _gettsim.time_conversion import create_time_conversion_functions @@ -87,8 +87,8 @@ def combine_policy_functions_and_derived_functions( aggregation_type="p_id", ) current_functions_tree = upsert_tree( - base=environment.functions_tree, - to_upsert=aggregate_by_p_id_functions, + base=aggregate_by_p_id_functions, + to_upsert=environment.functions_tree, ) # Create functions for different time units @@ -97,8 +97,8 @@ def combine_policy_functions_and_derived_functions( data_tree=data_tree, ) current_functions_tree = upsert_tree( - base=current_functions_tree, - to_upsert=time_conversion_functions, + base=time_conversion_functions, + to_upsert=current_functions_tree, ) # Create aggregation functions @@ -109,15 +109,15 @@ def combine_policy_functions_and_derived_functions( aggregations_tree_provided_by_env=environment.aggregation_specs_tree, ) current_functions_tree = upsert_tree( - base=current_functions_tree, - to_upsert=aggregate_by_group_functions, + base=aggregate_by_group_functions, + to_upsert=current_functions_tree, ) # Create groupings groupings = create_groupings() current_functions_tree = upsert_tree( - base=current_functions_tree, - to_upsert=groupings, + base=groupings, + to_upsert=current_functions_tree, ) _fail_if_targets_not_in_functions_tree(current_functions_tree, targets_tree) @@ -200,10 +200,10 @@ def _create_aggregation_functions( annotations=annotations, ) - out_tree = upsert_path_and_value( + out_tree = insert_path_and_value( base=out_tree, - path_to_upsert=tree_path, - value_to_upsert=derived_func, + path_to_insert=tree_path, + value_to_insert=derived_func, ) return out_tree @@ -269,10 +269,10 @@ def 
_create_derived_aggregations_tree( ) and tree_path not in optree.tree_paths(aggregation_source_tree) if aggregation_specs_needed: - derived_aggregations_tree = upsert_path_and_value( + derived_aggregations_tree = insert_path_and_value( base=derived_aggregations_tree, - path_to_upsert=tree_path, - value_to_upsert=AggregateByGroupSpec( + path_to_insert=tree_path, + value_to_insert=AggregateByGroupSpec( aggr="sum", source_col=remove_group_suffix(leaf_name), ), @@ -310,9 +310,9 @@ def _get_potential_aggregation_function_names_from_function_arguments( name=name, namespace=tree_path[:-1], ) - current_tree = upsert_path_and_value( + current_tree = insert_path_and_value( base=current_tree, - path_to_upsert=path_of_function_argument, + path_to_insert=path_of_function_argument, ) return current_tree diff --git a/src/_gettsim/functions/loader.py b/src/_gettsim/functions/loader.py index 06b842a0eb..c8a93f3053 100644 --- a/src/_gettsim/functions/loader.py +++ b/src/_gettsim/functions/loader.py @@ -13,7 +13,11 @@ ) from _gettsim.functions.policy_function import PolicyFunction from _gettsim.gettsim_typing import NestedAggregationSpecDict, NestedFunctionDict -from _gettsim.shared import upsert_path_and_value +from _gettsim.shared import ( + create_tree_from_path_and_value, + insert_path_and_value, + merge_trees, +) def load_functions_tree_for_date(date: datetime.date) -> NestedFunctionDict: @@ -22,7 +26,9 @@ def load_functions_tree_for_date(date: datetime.date) -> NestedFunctionDict: This function takes the list of root paths and searches for all modules containing PolicyFunctions. Then it loads all PolicyFunctions that are active at the given date - and parses them into the functions tree. + and constructs the functions tree. + + Namespaces are at the directory level. 
Parameters ---------- @@ -40,22 +46,19 @@ def load_functions_tree_for_date(date: datetime.date) -> NestedFunctionDict: functions_tree = {} for path in paths_to_policy_functions: - active_functions_dict = get_active_policy_functions_from_module( + new_functions_tree = get_active_policy_functions_tree_from_module( path=path, date=date, package_root=RESOURCE_DIR ) - tree_path = _convert_path_to_tree_path(path=path, package_root=RESOURCE_DIR) - - functions_tree = upsert_path_and_value( - base=functions_tree, - path_to_upsert=tree_path, - value_to_upsert=active_functions_dict, + functions_tree = merge_trees( + left=functions_tree, + right=new_functions_tree, ) return functions_tree -def get_active_policy_functions_from_module( +def get_active_policy_functions_tree_from_module( path: Path, package_root: Path, date: datetime.date, @@ -73,7 +76,7 @@ def get_active_policy_functions_from_module( Returns ------- - A dictionary of active PolicyFunctions with their leaf names as keys. + A nested dictionary of active PolicyFunctions with their leaf names as keys. """ module = _load_module(path, package_root) module_name = _convert_path_to_importable_module_name(path, package_root) @@ -88,7 +91,13 @@ def get_active_policy_functions_from_module( policy_functions, module_name ) - return {func.leaf_name: func for func in policy_functions if func.is_active(date)} + active_policy_functions = { + func.leaf_name: func for func in policy_functions if func.is_active(date) + } + return create_tree_from_path_and_value( + path=_convert_path_to_tree_path(path=path, package_root=RESOURCE_DIR), + value=active_policy_functions, + ) def _fail_if_multiple_policy_functions_are_active_at_the_same_time( @@ -213,41 +222,36 @@ def _convert_path_to_importable_module_name(path: Path, package_root: Path) -> s def _convert_path_to_tree_path(path: Path, package_root: Path) -> tuple[str, ...]: """ - Convert a system path to a tree path. + Convert the path from the package root to a tree path. 
- The system path is the path to the python module on the user's system. The tree path - are the branches of the tree. + Removes the package root and module name from the path. Parameters ---------- path: - The path to the python module on the user's system. + The path to a Python module. package_root: - The root of the package that contains the functions. + GETTSIM's package root. Returns ------- - The tree path. + The tree path, to be used as a key in the functions tree. Examples -------- - >>> _convert_path_to_tree_path(RESOURCE_DIR / "taxes" / "dir" - / "functions.py") - ("dir", "functions") + >>> _convert_path_to_tree_path( + ... path=RESOURCE_DIR / "taxes" / "dir" / "functions.py", + ... package_root=RESOURCE_DIR, + ... ) + ("dir",) """ - # TODO(@MImmesberger): Simplify after changing directory structure + parts = path.relative_to(package_root.parent / "_gettsim").parts + + # TODO(@MImmesberger): Remove the subsequent line after changing directory structure # https://github.com/iza-institute-of-labor-economics/gettsim/pull/805 - # tree_path = path.relative_to(package_root.parent).parts[1:] # noqa: ERA001 - tree_path = tuple( - path.relative_to(package_root.parent) - .with_suffix("") - .as_posix() - .removeprefix("_gettsim/") - .removeprefix("taxes/") - .removeprefix("transfers/") - .split("/") - ) - return _simplify_tree_path_when_module_name_equals_dir_name(tree_path) + parts = parts[1:] if parts[0] in {"taxes", "transfers"} else parts + + return parts[:-1] def load_aggregation_specs_tree() -> NestedAggregationSpecDict: @@ -276,10 +280,10 @@ def load_aggregation_specs_tree() -> NestedAggregationSpecDict: tree_path = _convert_path_to_tree_path(path=path, package_root=RESOURCE_DIR) - aggregation_specs_tree = upsert_path_and_value( + aggregation_specs_tree = insert_path_and_value( base=aggregation_specs_tree, - path_to_upsert=tree_path, - value_to_upsert=aggregation_specs, + path_to_insert=tree_path, + value_to_insert=aggregation_specs, ) return 
aggregation_specs_tree @@ -328,23 +332,3 @@ def _load_aggregation_specs_from_module( ) return out - - -def _simplify_tree_path_when_module_name_equals_dir_name( - tree_path: tuple[str, ...], -) -> tuple[str, ...]: - """ - Shorten path when a module lives a directory named the same way. - - This is done to avoid namespaces like arbeitslosengeld__arbeitslosengeld if the - file structure looks like: - arbeitslosengeld - | |- arbeitslosengeld.py - | |- ... - """ - if len(tree_path) >= 2: - out = tree_path[:-1] if tree_path[-1] == tree_path[-2] else tree_path - else: - out = tree_path - - return out diff --git a/src/_gettsim/interface.py b/src/_gettsim/interface.py index 659368759a..c1f5aea7a2 100644 --- a/src/_gettsim/interface.py +++ b/src/_gettsim/interface.py @@ -38,10 +38,10 @@ format_errors_and_warnings, format_list_linewise, get_names_of_arguments_without_defaults, + merge_trees, partition_tree_by_reference_tree, qualified_name_reducer, qualified_name_splitter, - upsert_tree, ) @@ -146,9 +146,9 @@ def compute_taxes_and_transfers( results = tax_transfer_function(input_data_tree) if debug: - results = upsert_tree( - base=results, - to_upsert=data_tree_with_correct_types, + results = merge_trees( + left=results, + right=data_tree_with_correct_types, ) return results diff --git a/src/_gettsim/shared.py b/src/_gettsim/shared.py index f74737368c..3f3533e93b 100644 --- a/src/_gettsim/shared.py +++ b/src/_gettsim/shared.py @@ -5,6 +5,7 @@ import flatten_dict import numpy +import optree from dags.signature import rename_arguments from flatten_dict.reducers import make_reducer from flatten_dict.splitters import make_splitter @@ -65,6 +66,28 @@ def create_tree_from_path_and_value(path: tuple[str], value: Any = None) -> dict return nested_dict +def merge_trees(left: dict, right: dict) -> dict: + """ + Merge two pytrees, raising an error if a path is present in both trees. + + Parameters + ---------- + left + The first tree to be merged. 
+ right + The second tree to be merged. + + Returns + ------- + The merged pytree. + """ + + if set(optree.tree_paths(left)) & set(optree.tree_paths(right)): + raise ValueError("Conflicting paths in trees to merge.") + + return upsert_tree(base=left, to_upsert=right) + + def upsert_tree(base: dict, to_upsert: dict) -> dict: """ Upsert a tree into another tree for trees defined by dictionaries only. @@ -116,6 +139,20 @@ def upsert_path_and_value( return upsert_tree(base=base, to_upsert=to_upsert) +def insert_path_and_value( + base: dict[str, Any], path_to_insert: tuple[str], value_to_insert: Any = None +) -> dict[str, Any]: + """Insert a path and value into a tree. + + The path is a list of strings that represent the keys in the nested dictionary. The + path must not exist in base. + """ + to_insert = create_tree_from_path_and_value( + path=path_to_insert, value=value_to_insert + ) + return merge_trees(left=base, right=to_insert) + + def partition_tree_by_reference_tree( tree_to_partition: NestedFunctionDict | NestedDataDict, reference_tree: NestedFunctionDict | NestedDataDict, diff --git a/src/_gettsim/time_conversion.py b/src/_gettsim/time_conversion.py index 7dc606ad8d..668ba4c6ea 100644 --- a/src/_gettsim/time_conversion.py +++ b/src/_gettsim/time_conversion.py @@ -12,6 +12,7 @@ from _gettsim.functions.policy_function import PolicyFunction from _gettsim.gettsim_typing import NestedDataDict, NestedFunctionDict from _gettsim.shared import ( + insert_path_and_value, rename_arguments_and_add_annotations, upsert_path_and_value, ) @@ -284,10 +285,10 @@ def create_time_conversion_functions( if new_path in optree.tree_paths(converted_functions) + data_tree_paths: continue else: - converted_functions = upsert_path_and_value( + converted_functions = insert_path_and_value( base=converted_functions, - path_to_upsert=new_path, - value_to_upsert=der_func, + path_to_insert=new_path, + value_to_insert=der_func, ) # Create time-conversions for data columns @@ -302,6 +303,8 @@ 
def create_time_conversion_functions( if new_path in optree.tree_paths(converted_functions) + data_tree_paths: continue else: + # Upsert because derived functions based on data should overwrite + # derived functions based on other functions. converted_functions = upsert_path_and_value( base=converted_functions, path_to_upsert=new_path, diff --git a/src/_gettsim_tests/test_loader.py b/src/_gettsim_tests/test_loader.py index 0fbd3db9e7..e10267cdfe 100644 --- a/src/_gettsim_tests/test_loader.py +++ b/src/_gettsim_tests/test_loader.py @@ -7,9 +7,9 @@ from _gettsim.config import PATHS_TO_INTERNAL_FUNCTIONS, RESOURCE_DIR from _gettsim.functions.loader import ( + _convert_path_to_tree_path, _find_python_files_recursively, _load_module, - _simplify_tree_path_when_module_name_equals_dir_name, ) from _gettsim.functions.policy_function import ( _vectorize_func, @@ -18,6 +18,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from pathlib import Path def test_load_path(): @@ -59,18 +60,20 @@ def test_vectorize_func(vectorized_function: Callable) -> None: @pytest.mark.parametrize( ( "path", + "package_root", "expected_tree_path", ), [ - (("foo", "bar", "bar"), ("foo", "bar")), - (("foo", "bar", "baz"), ("foo", "bar", "baz")), - (("foo", "bar", "bar", "bar"), ("foo", "bar", "bar")), - (("foo", "bar", "bar", "baz"), ("foo", "bar", "bar", "baz")), + (RESOURCE_DIR / "foo" / "bar.py", RESOURCE_DIR, ("foo",)), + (RESOURCE_DIR / "foo" / "spam" / "bar.py", RESOURCE_DIR, ("foo", "spam")), + (RESOURCE_DIR / "taxes" / "foo" / "bar.py", RESOURCE_DIR, ("foo",)), + (RESOURCE_DIR / "transfers" / "foo" / "bar.py", RESOURCE_DIR, ("foo",)), ], ) -def test_remove_recurring_branch_names( - path: str, expected_tree_path: tuple[str, ...] +def test_convert_path_to_tree_path( + path: Path, package_root: Path, expected_tree_path: tuple[str, ...] 
) -> None: assert ( - _simplify_tree_path_when_module_name_equals_dir_name(path) == expected_tree_path + _convert_path_to_tree_path(path=path, package_root=package_root) + == expected_tree_path ) diff --git a/src/_gettsim_tests/test_policy_environment.py b/src/_gettsim_tests/test_policy_environment.py index 74484a8317..27bf9d248e 100644 --- a/src/_gettsim_tests/test_policy_environment.py +++ b/src/_gettsim_tests/test_policy_environment.py @@ -105,6 +105,7 @@ def test_access_different_date_jahresanfang(): assert params["foo_jahresanfang"] == 2020 +@pytest.mark.xfail(reason="Needs renamings PR.") @pytest.mark.parametrize( "tree, last_day, function_name_last_day, function_name_next_day", [ diff --git a/src/_gettsim_tests/test_shared.py b/src/_gettsim_tests/test_shared.py index 170ca023df..e449bb2a4d 100644 --- a/src/_gettsim_tests/test_shared.py +++ b/src/_gettsim_tests/test_shared.py @@ -4,6 +4,8 @@ from _gettsim.shared import ( create_tree_from_path_and_value, + insert_path_and_value, + merge_trees, partition_tree_by_reference_tree, upsert_path_and_value, upsert_tree, @@ -31,6 +33,33 @@ def test_upsert_path_and_value(base, path_to_upsert, value_to_upsert, expected): assert result == expected +@pytest.mark.parametrize( + "base, path_to_insert, value_to_insert, expected", + [ + ({}, ("a",), 1, {"a": 1}), + ({"a": 1}, ("b",), 2, {"a": 1, "b": 2}), + ], +) +def test_insert_path_and_value(base, path_to_insert, value_to_insert, expected): + result = insert_path_and_value( + base=base, path_to_insert=path_to_insert, value_to_insert=value_to_insert + ) + assert result == expected + + +@pytest.mark.parametrize( + "base, path_to_insert, value_to_insert", + [ + ({"a": 1}, ("a",), 2), + ], +) +def test_insert_path_and_value_invalid(base, path_to_insert, value_to_insert): + with pytest.raises(ValueError, match="Conflicting paths in trees to merge."): + insert_path_and_value( + base=base, path_to_insert=path_to_insert, value_to_insert=value_to_insert + ) + + 
@pytest.mark.parametrize( "paths, expected", [ @@ -43,6 +72,30 @@ def test_create_tree_from_path_and_value(paths, expected): assert create_tree_from_path_and_value(paths) == expected +@pytest.mark.parametrize( + "left, right, expected", + [ + ({}, {"a": 1}, {"a": 1}), + ({"a": 1}, {"b": 2}, {"a": 1, "b": 2}), + ({"a": {"b": 1}}, {"a": {"c": 2}}, {"a": {"b": 1, "c": 2}}), + ({"a": {"b": 1}}, {"a": 3}, {"a": 3}), + ({"a": 3}, {"a": {"b": 1}}, {"a": {"b": 1}}), + ({"a": SampleDataClass(a=1)}, {}, {"a": SampleDataClass(a=1)}), + ], +) +def test_merge_trees_valid(left, right, expected): + assert merge_trees(left=left, right=right) == expected + + +@pytest.mark.parametrize( + "left, right", + [({"a": 1}, {"a": 2}), ({"a": 1}, {"a": 1}), ({"a": {"b": 1}}, {"a": {"b": 5}})], +) +def test_merge_trees_invalid(left, right): + with pytest.raises(ValueError, match="Conflicting paths in trees to merge."): + merge_trees(left=left, right=right) + + @pytest.mark.parametrize( "base_dict, update_dict, expected", [