From 66e1d71618d04dbadcbb74325df9ea230538f549 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Tue, 10 Jun 2025 17:57:20 +0200 Subject: [PATCH 01/25] Got started setting xnp as the backend at runtime. --- pixi.lock | 60 +++-- pyproject.toml | 3 + .../kindergeld\303\274bertrag.py" | 16 +- src/_gettsim/ids.py | 156 +++++++------ src/_gettsim/kindergeld/kindergeld.py | 9 +- src/_gettsim/lohnsteuer/lohnsteuer.py | 24 +- .../rente/altersrente/altersgrenzen.py | 43 ++-- src/_gettsim/wohngeld/einkommen.py | 10 +- src/_gettsim/wohngeld/miete.py | 34 ++- src/_gettsim/wohngeld/wohngeld.py | 13 +- src/_gettsim_tests/test_policy.py | 5 +- src/ttsim/interface_dag.py | 10 +- src/ttsim/interface_dag_elements/backend.py | 48 ++++ .../interface_dag_elements/data_converters.py | 22 +- src/ttsim/interface_dag_elements/fail_if.py | 47 +++- .../interface_dag_elements/input_data.py | 3 + .../interface_dag_elements/processed_data.py | 16 +- src/ttsim/testing_utils.py | 17 +- .../column_objects_param_function.py | 2 +- src/ttsim/tt_dag_elements/param_objects.py | 61 ++--- .../tt_dag_elements/piecewise_polynomial.py | 110 +++++---- src/ttsim/tt_dag_elements/rounding.py | 31 +-- src/ttsim/tt_dag_elements/shared.py | 30 ++- src/ttsim/tt_dag_elements/vectorization.py | 35 ++- tests/ttsim/mettsim/group_by_ids.py | 67 +++--- tests/ttsim/test_convert_nested_data.py | 21 +- tests/ttsim/test_failures.py | 22 +- tests/ttsim/test_mettsim.py | 5 +- tests/ttsim/test_specialized_environment.py | 36 +-- .../test_aggregation_functions.py | 213 +++++++++--------- .../test_piecewise_polynomial.py | 17 +- tests/ttsim/tt_dag_elements/test_rounding.py | 50 ++-- tests/ttsim/tt_dag_elements/test_shared.py | 44 ++-- .../tt_dag_elements/test_ttsim_objects.py | 6 +- 34 files changed, 784 insertions(+), 502 deletions(-) create mode 100644 src/ttsim/interface_dag_elements/backend.py diff --git a/pixi.lock b/pixi.lock index 192775cf4..89b447502 100644 --- a/pixi.lock +++ b/pixi.lock @@ -271,7 +271,7 @@ 
environments: - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -507,7 +507,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -743,7 +743,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . 
win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -987,7 +987,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . mypy: channels: - url: https://conda.anaconda.org/conda-forge/ @@ -1263,7 +1263,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl - - pypi: ./ + - pypi: . 
osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -1503,7 +1503,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl - - pypi: ./ + - pypi: . osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -1743,7 +1743,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl - - pypi: ./ + - pypi: . 
win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -1991,7 +1991,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl - - pypi: ./ + - pypi: . py311: channels: - url: https://conda.anaconda.org/conda-forge/ @@ -2263,7 +2263,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/aa/3d/52a75740d6c449073d4bb54da382f6368553f285fb5a680b27dd198dd839/optree-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . 
osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -2499,7 +2499,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/e8/89/1267444a074b6e4402b5399b73b930a7b86cde054a41cecb9694be726a92/optree-0.15.0-cp311-cp311-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -2735,7 +2735,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/98/a5/f8d6c278ce72b2ed8c1ebac968c3c652832bd2d9e65ec81fe6a21082c313/optree-0.15.0-cp311-cp311-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . 
win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -2979,7 +2979,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . py312: channels: - url: https://conda.anaconda.org/conda-forge/ @@ -3251,7 +3251,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . 
osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -3487,7 +3487,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -3723,7 +3723,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . 
win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -3967,7 +3967,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . py312-jax: channels: - url: https://conda.anaconda.org/conda-forge/ @@ -4245,12 +4245,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 - pypi: https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -4492,12 +4493,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.7-h8210216_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 - pypi: https://files.pythonhosted.org/packages/e0/f7/0ccaa596ec341963adbb4f839774c36d5659e75a0812d946732b927d480e/kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . 
osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -4739,12 +4741,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-h6491c7d_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 - pypi: https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . 
win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -4983,6 +4986,7 @@ environments: - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/dc/89/99805cd801919b4535e023bfe2de651f5a3ec4f5846a867cbc08006db455/jax-0.6.1-py3-none-any.whl + - pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 - pypi: https://files.pythonhosted.org/packages/1b/12/2bc629d530ee1b333edc81a1d68d262bad2f813ce60fdd46e98d48cc8a20/jaxlib-0.6.1-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/41/90/51523adbedc808e03271c7448fd71da1660cc02603eaaf10b9ab4f102146/kaleido-0.1.0.post1-py2.py3-none-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/38/bc/c4260e4a6c6bf684d0313308de1c860467275221d5e7daf69b3fcddfdd0b/ml_dtypes-0.5.1-cp312-cp312-win_amd64.whl @@ -4993,7 +4997,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: ./ + - pypi: . 
packages: - conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 sha256: fe51de6107f9edc7aa4f786a70f4a883943bc9d39b3bb7307c04c41410990726 @@ -6572,10 +6576,10 @@ packages: purls: [] size: 21903 timestamp: 1694400856979 -- pypi: ./ +- pypi: . name: gettsim - version: 0.7.1.dev428+gc84a692f.d20250610 - sha256: 7d1606ea5cc943af5952dc9eaae3b8a6f2b102958f2955fddca40bbf1bb438d7 + version: 0.7.1.dev433+g45b17d643.d20250610 + sha256: 3c87f5e1b0dbb950bd0bd46d24bb9ebc09061dc62200009645b901e24004348b requires_dist: - ipywidgets - networkx @@ -7443,6 +7447,16 @@ packages: - pkg:pypi/jax?source=hash-mapping size: 1580534 timestamp: 1747653718316 +- pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 + name: jax-datetime + version: 0.1.0 + requires_dist: + - jax + - numpy + - absl-py ; extra == 'tests' + - chex ; extra == 'tests' + - pytest ; extra == 'tests' + requires_python: '>=3.11' - pypi: https://files.pythonhosted.org/packages/1b/12/2bc629d530ee1b333edc81a1d68d262bad2f813ce60fdd46e98d48cc8a20/jaxlib-0.6.1-cp312-cp312-win_amd64.whl name: jaxlib version: 0.6.1 diff --git a/pyproject.toml b/pyproject.toml index 8d740ccfe..73bd642ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,6 +150,9 @@ python = "3.12.*" jax = ">=0.4.20" jaxlib = ">=0.4.20" +[tool.pixi.feature.jax.pypi-dependencies] +jax-datetime = { git = "https://github.com/google/jax-datetime.git" } + [tool.pixi.feature.jax.target.win-64.pypi-dependencies] jax = { version = ">=0.4.20", extras = ["cpu"] } jaxlib = ">=0.4.20" diff --git "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" index 47797fa18..6d64930c0 100644 --- "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" +++ "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" @@ -7,7 +7,7 @@ from ttsim.tt_dag_elements import AggType, agg_by_p_id_function, join, 
policy_function if TYPE_CHECKING: - from ttsim.config import numpy_or_jax as np + import numpy @agg_by_p_id_function(start_date="2005-01-01", agg_type=AggType.SUM) @@ -57,9 +57,9 @@ def _mean_kindergeld_per_child_ohne_staffelung_m( @policy_function(start_date="2005-01-01", vectorization_strategy="not_required") def kindergeld_zur_bedarfsdeckung_m( kindergeld_pro_kind_m: float, - kindergeld__p_id_empfänger: np.ndarray, # int - p_id: np.ndarray, # int -) -> np.ndarray: # float + kindergeld__p_id_empfänger: numpy.ndarray, # int + p_id: numpy.ndarray, # int +) -> numpy.ndarray: # float """Kindergeld that is used to cover the SGB II Regelbedarf of the child. Even though the Kindergeld is paid to the parent (see function @@ -119,10 +119,10 @@ def differenz_kindergeld_kindbedarf_m( @policy_function(start_date="2005-01-01", vectorization_strategy="not_required") def in_anderer_bg_als_kindergeldempfänger( - p_id: np.ndarray, # int - kindergeld__p_id_empfänger: np.ndarray, # int - bg_id: np.ndarray, # int -) -> np.ndarray: # bool + p_id: numpy.ndarray, # int + kindergeld__p_id_empfänger: numpy.ndarray, # int + bg_id: numpy.ndarray, # int +) -> numpy.ndarray: # bool """True if the person is in a different Bedarfsgemeinschaft than the Kindergeldempfänger of that person. 
""" diff --git a/src/_gettsim/ids.py b/src/_gettsim/ids.py index ac420fad4..872c72486 100644 --- a/src/_gettsim/ids.py +++ b/src/_gettsim/ids.py @@ -2,7 +2,12 @@ from __future__ import annotations -from ttsim.config import numpy_or_jax as np +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + + import numpy from ttsim.tt_dag_elements import group_creation_function, policy_input @@ -18,81 +23,82 @@ def hh_id() -> int: @group_creation_function() def ehe_id( - p_id: np.ndarray, - familie__p_id_ehepartner: np.ndarray, -) -> np.ndarray: + p_id: numpy.ndarray, familie__p_id_ehepartner: numpy.ndarray, xnp: ModuleType +) -> numpy.ndarray: """Couples that are either married or in a civil union.""" - n = np.max(p_id) + 1 - p_id_ehepartner_or_own_p_id = np.where( + n = numpy.max(p_id) + 1 + p_id_ehepartner_or_own_p_id = numpy.where( familie__p_id_ehepartner < 0, p_id, familie__p_id_ehepartner ) result = ( - np.maximum(p_id, p_id_ehepartner_or_own_p_id) - + np.minimum(p_id, p_id_ehepartner_or_own_p_id) * n + numpy.maximum(p_id, p_id_ehepartner_or_own_p_id) + + numpy.minimum(p_id, p_id_ehepartner_or_own_p_id) * n ) - return _reorder_ids(result) + return _reorder_ids(result, xnp) @group_creation_function() def fg_id( - arbeitslosengeld_2__p_id_einstandspartner: np.ndarray, - p_id: np.ndarray, - hh_id: np.ndarray, - alter: np.ndarray, - familie__p_id_elternteil_1: np.ndarray, - familie__p_id_elternteil_2: np.ndarray, -) -> np.ndarray: + arbeitslosengeld_2__p_id_einstandspartner: numpy.ndarray, + p_id: numpy.ndarray, + hh_id: numpy.ndarray, + alter: numpy.ndarray, + familie__p_id_elternteil_1: numpy.ndarray, + familie__p_id_elternteil_2: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """Familiengemeinschaft. Base unit for some transfers. Maximum of two generations, the relevant base unit for Bürgergeld / Arbeitslosengeld 2, before excluding children who have enough income fend for themselves. 
""" - n = np.max(p_id) + 1 + n = numpy.max(p_id) + 1 # Get the array index for all p_ids of parents p_id_elternteil_1_loc = familie__p_id_elternteil_1 p_id_elternteil_2_loc = familie__p_id_elternteil_2 for i in range(p_id.shape[0]): - p_id_elternteil_1_loc = np.where( + p_id_elternteil_1_loc = numpy.where( familie__p_id_elternteil_1 == p_id[i], i, p_id_elternteil_1_loc ) - p_id_elternteil_2_loc = np.where( + p_id_elternteil_2_loc = numpy.where( familie__p_id_elternteil_2 == p_id[i], i, p_id_elternteil_2_loc ) - children = np.isin(p_id, familie__p_id_elternteil_1) | np.isin( + children = numpy.isin(p_id, familie__p_id_elternteil_1) | numpy.isin( p_id, familie__p_id_elternteil_2 ) # Assign the same fg_id to everybody who has an Einstandspartner, # otherwise create a new one from p_id - fg_id = np.where( + fg_id = numpy.where( arbeitslosengeld_2__p_id_einstandspartner < 0, p_id + p_id * n, - np.maximum(p_id, arbeitslosengeld_2__p_id_einstandspartner) - + np.minimum(p_id, arbeitslosengeld_2__p_id_einstandspartner) * n, + numpy.maximum(p_id, arbeitslosengeld_2__p_id_einstandspartner) + + numpy.minimum(p_id, arbeitslosengeld_2__p_id_einstandspartner) * n, ) fg_id = _assign_parents_fg_id( - fg_id, p_id, p_id_elternteil_1_loc, hh_id, alter, children, n + fg_id, p_id, p_id_elternteil_1_loc, hh_id, alter, children, n, xnp ) fg_id = _assign_parents_fg_id( - fg_id, p_id, p_id_elternteil_2_loc, hh_id, alter, children, n + fg_id, p_id, p_id_elternteil_2_loc, hh_id, alter, children, n, xnp ) - return _reorder_ids(fg_id) + return _reorder_ids(fg_id, xnp) def _assign_parents_fg_id( - fg_id: np.ndarray, - p_id: np.ndarray, - p_id_elternteil_loc: np.ndarray, - hh_id: np.ndarray, - alter: np.ndarray, - children: np.ndarray, - n: np.ndarray, -) -> np.ndarray: + fg_id: numpy.ndarray, + p_id: numpy.ndarray, + p_id_elternteil_loc: numpy.ndarray, + hh_id: numpy.ndarray, + alter: numpy.ndarray, + children: numpy.ndarray, + n: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """Get the 
fg_id of the childs parents. If the child is not married, has no children, is under 25 and in the same household, @@ -104,7 +110,7 @@ def _assign_parents_fg_id( # TODO(@MImmesberger): Remove hard-coded number # https://github.com/iza-institute-of-labor-economics/gettsim/issues/668 - return np.where( + return numpy.where( (p_id_elternteil_loc >= 0) * (fg_id == p_id + p_id * n) * (hh_id == hh_id[p_id_elternteil_loc]) @@ -117,11 +123,12 @@ def _assign_parents_fg_id( @group_creation_function() def bg_id( - fg_id: np.ndarray, - p_id: np.ndarray, - arbeitslosengeld_2__eigenbedarf_gedeckt: np.ndarray, - alter: np.ndarray, -) -> np.ndarray: + fg_id: numpy.ndarray, + p_id: numpy.ndarray, + arbeitslosengeld_2__eigenbedarf_gedeckt: numpy.ndarray, + alter: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """Bedarfsgemeinschaft Familiengemeinschaft except for children who have enough income to fend for @@ -133,95 +140,98 @@ def bg_id( # TODO(@MImmesberger): Remove hard-coded number # https://github.com/iza-institute-of-labor-economics/gettsim/issues/668 - offset = np.max(fg_id) + 1 + offset = numpy.max(fg_id) + 1 # Create new id for everyone who is not part of the Bedarfsgemeinschaft - bg_id = np.where( + bg_id = numpy.where( (arbeitslosengeld_2__eigenbedarf_gedeckt) * (alter < 25), offset + p_id, fg_id, ) - return _reorder_ids(bg_id) + return _reorder_ids(bg_id, xnp) @group_creation_function() def eg_id( - arbeitslosengeld_2__p_id_einstandspartner: np.ndarray, - p_id: np.ndarray, -) -> np.ndarray: + arbeitslosengeld_2__p_id_einstandspartner: numpy.ndarray, + p_id: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """Einstandsgemeinschaft / Einstandspartner according to SGB II. A couple whose members are deemed to be responsible for each other. 
""" - n = np.max(p_id) + 1 - p_id_einstandspartner__or_own_p_id = np.where( + n = numpy.max(p_id) + 1 + p_id_einstandspartner__or_own_p_id = numpy.where( arbeitslosengeld_2__p_id_einstandspartner < 0, p_id, arbeitslosengeld_2__p_id_einstandspartner, ) result = ( - np.maximum(p_id, p_id_einstandspartner__or_own_p_id) - + np.minimum(p_id, p_id_einstandspartner__or_own_p_id) * n + numpy.maximum(p_id, p_id_einstandspartner__or_own_p_id) + + numpy.minimum(p_id, p_id_einstandspartner__or_own_p_id) * n ) - return _reorder_ids(result) + return _reorder_ids(result, xnp) @group_creation_function() def wthh_id( - hh_id: np.ndarray, - vorrangprüfungen__wohngeld_vorrang_vor_arbeitslosengeld_2_bg: np.ndarray, - vorrangprüfungen__wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg: np.ndarray, -) -> np.ndarray: + hh_id: numpy.ndarray, + vorrangprüfungen__wohngeld_vorrang_vor_arbeitslosengeld_2_bg: numpy.ndarray, + vorrangprüfungen__wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """Wohngeldrechtlicher Teilhaushalt. The relevant unit for Wohngeld. Members of a household for whom the Wohngeld priority check compared to Bürgergeld yields the same result ∈ {True, False}. """ - offset = np.max(hh_id) + 1 - wthh_id = np.where( + offset = numpy.max(hh_id) + 1 + wthh_id = numpy.where( vorrangprüfungen__wohngeld_vorrang_vor_arbeitslosengeld_2_bg | vorrangprüfungen__wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg, hh_id + offset, hh_id, ) - return _reorder_ids(wthh_id) + return _reorder_ids(wthh_id, xnp) @group_creation_function() def sn_id( - p_id: np.ndarray, - familie__p_id_ehepartner: np.ndarray, - einkommensteuer__gemeinsam_veranlagt: np.ndarray, -) -> np.ndarray: + p_id: numpy.ndarray, + familie__p_id_ehepartner: numpy.ndarray, + einkommensteuer__gemeinsam_veranlagt: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """Steuernummer. Spouses filing taxes jointly or individuals. 
""" - n = np.max(p_id) + 1 + n = numpy.max(p_id) + 1 - p_id_ehepartner_or_own_p_id = np.where( + p_id_ehepartner_or_own_p_id = numpy.where( (familie__p_id_ehepartner >= 0) * (einkommensteuer__gemeinsam_veranlagt), familie__p_id_ehepartner, p_id, ) result = ( - np.maximum(p_id, p_id_ehepartner_or_own_p_id) - + np.minimum(p_id, p_id_ehepartner_or_own_p_id) * n + numpy.maximum(p_id, p_id_ehepartner_or_own_p_id) + + numpy.minimum(p_id, p_id_ehepartner_or_own_p_id) * n ) - return _reorder_ids(result) + return _reorder_ids(result, xnp) -def _reorder_ids(ids: np.ndarray) -> np.ndarray: +def _reorder_ids(ids: numpy.ndarray, xnp: ModuleType) -> numpy.ndarray: """Make ID's consecutively numbered.""" - sorting = np.argsort(ids) + sorting = xnp.argsort(ids) ids_sorted = ids[sorting] - index_after_sort = np.arange(ids.shape[0])[sorting] + index_after_sort = xnp.arange(ids.shape[0])[sorting] # Look for difference from previous entry in sorted array - diff_to_prev = np.where(np.diff(ids_sorted) >= 1, 1, 0) + diff_to_prev = xnp.where(xnp.diff(ids_sorted) >= 1, 1, 0) # Sum up all differences to get new id - cons_ids = np.concatenate((np.asarray([0]), np.cumsum(diff_to_prev))) - return cons_ids[np.argsort(index_after_sort)] + cons_ids = xnp.concatenate((xnp.asarray([0]), xnp.cumsum(diff_to_prev))) + return cons_ids[xnp.argsort(index_after_sort)] diff --git a/src/_gettsim/kindergeld/kindergeld.py b/src/_gettsim/kindergeld/kindergeld.py index 7075f5d8b..d03056476 100644 --- a/src/_gettsim/kindergeld/kindergeld.py +++ b/src/_gettsim/kindergeld/kindergeld.py @@ -14,7 +14,8 @@ ) if TYPE_CHECKING: - from ttsim.config import numpy_or_jax as np + import numpy + from ttsim.tt_dag_elements import ConsecutiveInt1dLookupTableParamValue @@ -120,9 +121,9 @@ def kind_bis_10_mit_kindergeld( @policy_function(vectorization_strategy="not_required") def gleiche_fg_wie_empfänger( - p_id: np.ndarray, # int - p_id_empfänger: np.ndarray, # int - fg_id: np.ndarray, # int + p_id: numpy.ndarray, # int + 
p_id_empfänger: numpy.ndarray, # int + fg_id: numpy.ndarray, # int ) -> np.ndarray: # bool """The child's Kindergeldempfänger is in the same Familiengemeinschaft.""" fg_id_kindergeldempfänger = join( diff --git a/src/_gettsim/lohnsteuer/lohnsteuer.py b/src/_gettsim/lohnsteuer/lohnsteuer.py index 2a05ac99e..76f5a7d4d 100644 --- a/src/_gettsim/lohnsteuer/lohnsteuer.py +++ b/src/_gettsim/lohnsteuer/lohnsteuer.py @@ -2,7 +2,8 @@ from __future__ import annotations -from ttsim.config import numpy_or_jax as np +from typing import TYPE_CHECKING + from ttsim.tt_dag_elements import ( PiecewisePolynomialParamValue, param_function, @@ -10,6 +11,9 @@ policy_function, ) +if TYPE_CHECKING: + import numpy + def basis_für_klassen_5_6( einkommen_y: float, parameter_einkommensteuertarif: PiecewisePolynomialParamValue @@ -58,7 +62,7 @@ def parameter_max_lohnsteuer_klasse_5_6( einkommensgrenzwerte_steuerklassen_5_6[3], einkommensteuer__parameter_einkommensteuertarif, ) - thresholds = np.asarray( + thresholds = numpy.asarray( [ 0, einkommensgrenzwerte_steuerklassen_5_6[1], @@ -66,7 +70,7 @@ def parameter_max_lohnsteuer_klasse_5_6( einkommensgrenzwerte_steuerklassen_5_6[3], ] ) - intercepts = np.asarray( + intercepts = numpy.asarray( [ 0, lohnsteuer_bis_erste_grenze, @@ -74,9 +78,9 @@ def parameter_max_lohnsteuer_klasse_5_6( lohnsteuer_bis_dritte_grenze, ] ) - rates = np.expand_dims( + rates = numpy.expand_dims( einkommensteuer__parameter_einkommensteuertarif.rates[0][ - np.array([3, 3, 3, 4]) + numpy.array([3, 3, 3, 4]) ], axis=0, ) @@ -126,7 +130,7 @@ def tarif_klassen_5_und_6( min_lohnsteuer = ( einkommensteuer__parameter_einkommensteuertarif.rates[0, 1] * einkommen_y ) - return np.minimum(np.maximum(min_lohnsteuer, basis), max_lohnsteuer) + return numpy.minimum(numpy.maximum(min_lohnsteuer, basis), max_lohnsteuer) @policy_function(start_date="2015-01-01") @@ -155,7 +159,7 @@ def basistarif_mit_kinderfreibetrag( kinderfreibetrag_soli_y: float, ) -> float: """Lohnsteuer in the 
Basistarif deducting the Kindefreibetrag.""" - einkommen_abzüglich_kinderfreibetrag_soli = np.maximum( + einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( einkommen_y - kinderfreibetrag_soli_y, 0 ) return piecewise_polynomial( @@ -171,7 +175,7 @@ def splittingtarif_mit_kinderfreibetrag( kinderfreibetrag_soli_y: float, ) -> float: """Lohnsteuer in the Splittingtarif deducting the Kindefreibetrag.""" - einkommen_abzüglich_kinderfreibetrag_soli = np.maximum( + einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( einkommen_y - kinderfreibetrag_soli_y, 0 ) return 2 * piecewise_polynomial( @@ -188,7 +192,7 @@ def tarif_klassen_5_und_6_mit_kinderfreibetrag( kinderfreibetrag_soli_y: float, ) -> float: """Lohnsteuer for Lohnsteuerklassen 5 and 6 deducting the Kindefreibetrag.""" - einkommen_abzüglich_kinderfreibetrag_soli = np.maximum( + einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( einkommen_y - kinderfreibetrag_soli_y, 0 ) @@ -204,7 +208,7 @@ def tarif_klassen_5_und_6_mit_kinderfreibetrag( einkommensteuer__parameter_einkommensteuertarif.rates[0, 1] * einkommen_abzüglich_kinderfreibetrag_soli ) - return np.minimum(np.maximum(min_lohnsteuer, basis), max_lohnsteuer) + return numpy.minimum(numpy.maximum(min_lohnsteuer, basis), max_lohnsteuer) @policy_function(start_date="2015-01-01") diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py b/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py index 225b6f626..8b16f0c29 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py @@ -2,9 +2,13 @@ from __future__ import annotations -from ttsim.config import numpy_or_jax as np +from typing import TYPE_CHECKING + from ttsim.tt_dag_elements import policy_function +if TYPE_CHECKING: + from types import ModuleType + @policy_function( end_date="2011-12-31", @@ -18,6 +22,7 @@ def 
altersgrenze_mit_arbeitslosigkeit_frauen_ohne_besonders_langjährig( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze: float, regelaltersrente__altersgrenze: float, + xnp: ModuleType, ) -> float: """Full retirement age after eligibility checks, assuming eligibility for Regelaltersrente. @@ -29,14 +34,14 @@ def altersgrenze_mit_arbeitslosigkeit_frauen_ohne_besonders_langjährig( out = regelaltersrente__altersgrenze if für_frauen__grundsätzlich_anspruchsberechtigt: - out = np.minimum(out, für_frauen__altersgrenze) + out = xnp.minimum(out, für_frauen__altersgrenze) if wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt: - out = np.minimum( + out = xnp.minimum( out, wegen_arbeitslosigkeit__altersgrenze, ) if langjährig__grundsätzlich_anspruchsberechtigt: - out = np.minimum(out, langjährig__altersgrenze) + out = xnp.minimum(out, langjährig__altersgrenze) return out @@ -56,6 +61,7 @@ def altersgrenze_mit_arbeitslosigkeit_frauen_besonders_langjährig( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze: float, regelaltersrente__altersgrenze: float, + xnp: ModuleType, ) -> float: """Full retirement age after eligibility checks, assuming eligibility for Regelaltersrente. 
@@ -72,16 +78,16 @@ def altersgrenze_mit_arbeitslosigkeit_frauen_besonders_langjährig( out = regelaltersrente__altersgrenze if für_frauen__grundsätzlich_anspruchsberechtigt: - out = np.minimum(out, für_frauen__altersgrenze) + out = xnp.minimum(out, für_frauen__altersgrenze) if wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt: - out = np.minimum( + out = xnp.minimum( out, wegen_arbeitslosigkeit__altersgrenze, ) if langjährig__grundsätzlich_anspruchsberechtigt: - out = np.minimum(out, langjährig__altersgrenze) + out = xnp.minimum(out, langjährig__altersgrenze) if besonders_langjährig__grundsätzlich_anspruchsberechtigt: - out = np.minimum( + out = xnp.minimum( out, besonders_langjährig__altersgrenze, ) @@ -99,6 +105,7 @@ def altersgrenze_mit_besonders_langjährig_ohne_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze: float, regelaltersrente__altersgrenze: float, + xnp: ModuleType, ) -> float: """Full retirement age after eligibility checks, assuming eligibility for Regelaltersrente. @@ -110,9 +117,9 @@ def altersgrenze_mit_besonders_langjährig_ohne_arbeitslosigkeit_frauen( out = regelaltersrente__altersgrenze if langjährig__grundsätzlich_anspruchsberechtigt: - out = np.minimum(out, langjährig__altersgrenze) + out = xnp.minimum(out, langjährig__altersgrenze) if besonders_langjährig__grundsätzlich_anspruchsberechtigt: - out = np.minimum( + out = xnp.minimum( out, besonders_langjährig__altersgrenze, ) @@ -132,6 +139,7 @@ def altersgrenze_vorzeitig_mit_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze_vorzeitig: float, regelaltersrente__altersgrenze: float, + xnp: ModuleType, ) -> float: """Earliest possible retirement age after checking for eligibility. 
@@ -151,9 +159,9 @@ def altersgrenze_vorzeitig_mit_arbeitslosigkeit_frauen( if langjährig__grundsätzlich_anspruchsberechtigt: out = langjährig_vorzeitig if für_frauen__grundsätzlich_anspruchsberechtigt: - out = np.minimum(out, frauen_vorzeitig) + out = xnp.minimum(out, frauen_vorzeitig) if wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt: - out = np.minimum(out, arbeitsl_vorzeitig) + out = xnp.minimum(out, arbeitsl_vorzeitig) return out @@ -163,6 +171,7 @@ def altersgrenze_vorzeitig_ohne_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze_vorzeitig: float, regelaltersrente__altersgrenze: float, + xnp: ModuleType, ) -> float: """Earliest possible retirement age after checking for eligibility. @@ -187,6 +196,7 @@ def vorzeitig_grundsätzlich_anspruchsberechtigt_mit_arbeitslosigkeit_frauen( für_frauen__grundsätzlich_anspruchsberechtigt: bool, langjährig__grundsätzlich_anspruchsberechtigt: bool, wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt: bool, + xnp: ModuleType, ) -> bool: """Eligibility for some form ofearly retirement. @@ -208,6 +218,7 @@ def vorzeitig_grundsätzlich_anspruchsberechtigt_mit_arbeitslosigkeit_frauen( ) def vorzeitig_grundsätzlich_anspruchsberechtigt_vorzeitig_ohne_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, + xnp: ModuleType, ) -> bool: """Eligibility for early retirement. @@ -226,6 +237,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze: float, regelaltersrente__altersgrenze: float, + xnp: ModuleType, ) -> float: """Reference age for deduction calculation in case of early retirement (Zugangsfaktor). 
@@ -239,7 +251,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( and für_frauen__grundsätzlich_anspruchsberechtigt and wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt ): - out = min( + out = xnp.min( [ für_frauen__altersgrenze, langjährig__altersgrenze, @@ -250,7 +262,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt and für_frauen__grundsätzlich_anspruchsberechtigt ): - out = min( + out = xnp.min( [ für_frauen__altersgrenze, langjährig__altersgrenze, @@ -260,7 +272,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt and wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt ): - out = min( + out = xnp.min( [ langjährig__altersgrenze, wegen_arbeitslosigkeit__altersgrenze, @@ -283,6 +295,7 @@ def referenzalter_abschlag_ohne_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze: float, regelaltersrente__altersgrenze: float, + xnp: ModuleType, ) -> float: """Reference age for deduction calculation in case of early retirement (Zugangsfaktor). 
diff --git a/src/_gettsim/wohngeld/einkommen.py b/src/_gettsim/wohngeld/einkommen.py index 60dde52d6..c358d82d6 100644 --- a/src/_gettsim/wohngeld/einkommen.py +++ b/src/_gettsim/wohngeld/einkommen.py @@ -2,7 +2,8 @@ from __future__ import annotations -from ttsim.config import numpy_or_jax as np +from typing import TYPE_CHECKING + from ttsim.tt_dag_elements import ( AggType, ConsecutiveInt1dLookupTableParamValue, @@ -14,6 +15,9 @@ policy_function, ) +if TYPE_CHECKING: + import numpy + @agg_by_p_id_function(agg_type=AggType.SUM) def alleinerziehendenbonus( @@ -44,13 +48,13 @@ def einkommen( """ eink_nach_abzug_m_hh = einkommen_vor_freibetrag - einkommensfreibetrag unteres_eink = min_einkommen_lookup_table.values_to_look_up[ - np.minimum( + numpy.minimum( anzahl_personen, min_einkommen_lookup_table.values_to_look_up.shape[0] ) - min_einkommen_lookup_table.base_to_subtract ] - return np.maximum(eink_nach_abzug_m_hh, unteres_eink) + return numpy.maximum(eink_nach_abzug_m_hh, unteres_eink) @policy_function() diff --git a/src/_gettsim/wohngeld/miete.py b/src/_gettsim/wohngeld/miete.py index 94462b31d..457d46edc 100644 --- a/src/_gettsim/wohngeld/miete.py +++ b/src/_gettsim/wohngeld/miete.py @@ -3,8 +3,12 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING -from ttsim.config import numpy_or_jax as np +if TYPE_CHECKING: + from types import ModuleType + + import numpy from ttsim.tt_dag_elements import ( ConsecutiveInt1dLookupTableParamValue, ConsecutiveInt2dLookupTableParamValue, @@ -17,10 +21,10 @@ @dataclass(frozen=True) class LookupTableBaujahr: - baujahre: np.ndarray - lookup_table: np.ndarray - lookup_base_to_subtract_cols: np.ndarray - lookup_base_to_subtract_rows: np.ndarray + baujahre: numpy.ndarray + lookup_table: numpy.ndarray + lookup_base_to_subtract_cols: numpy.ndarray + lookup_base_to_subtract_rows: numpy.ndarray @param_function( @@ -29,6 +33,7 @@ class LookupTableBaujahr: def 
max_miete_m_lookup_mit_baujahr( raw_max_miete_m_nach_baujahr: dict[int | str, dict[int, dict[int, float]]], max_anzahl_personen: dict[str, int], + xnp: ModuleType, ) -> LookupTableBaujahr: """Maximum rent considered in Wohngeld calculation.""" tmp = raw_max_miete_m_nach_baujahr.copy() @@ -52,12 +57,12 @@ def max_miete_m_lookup_mit_baujahr( subtract_cols.append(lookup_table.base_to_subtract_cols) subtract_rows.append(lookup_table.base_to_subtract_rows) - full_lookup_table = np.stack(values, axis=0) - full_lookup_base_to_subtract_cols = np.asarray(subtract_cols) - full_lookup_base_to_subtract_rows = np.asarray(subtract_rows) + full_lookup_table = xnp.stack(values, axis=0) + full_lookup_base_to_subtract_cols = xnp.asarray(subtract_cols) + full_lookup_base_to_subtract_rows = xnp.asarray(subtract_rows) return LookupTableBaujahr( - baujahre=np.asarray(baujahre), + baujahre=xnp.asarray(baujahre), lookup_table=full_lookup_table, lookup_base_to_subtract_cols=full_lookup_base_to_subtract_cols, lookup_base_to_subtract_rows=full_lookup_base_to_subtract_rows, @@ -68,6 +73,7 @@ def max_miete_m_lookup_mit_baujahr( def max_miete_m_lookup_ohne_baujahr( raw_max_miete_m: dict[int | str, dict[int, float]], max_anzahl_personen: dict[str, int], + xnp: ModuleType, ) -> ConsecutiveInt2dLookupTableParamValue: """Maximum rent considered in Wohngeld calculation.""" expanded = raw_max_miete_m.copy() @@ -87,6 +93,7 @@ def max_miete_m_lookup_ohne_baujahr( def min_miete_lookup( raw_min_miete_m: dict[int, float], max_anzahl_personen: dict[str, int], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Minimum rent considered in Wohngeld calculation.""" max_n_p_normal = max_anzahl_personen["normale_berechnung"] @@ -107,6 +114,7 @@ def min_miete_lookup( def heizkostenentlastung_m_lookup( raw_heizkostenentlastung_m: dict[int | str, float], max_anzahl_personen: dict[str, int], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Heizkostenentlastung as a lookup table.""" 
expanded = raw_heizkostenentlastung_m.copy() @@ -124,6 +132,7 @@ def heizkostenentlastung_m_lookup( def dauerhafte_heizkostenkomponente_m_lookup( raw_dauerhafte_heizkostenkomponente_m: dict[int | str, float], max_anzahl_personen: dict[str, int], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Dauerhafte Heizkostenenkomponente as a lookup table.""" expanded = raw_dauerhafte_heizkostenkomponente_m.copy() @@ -141,6 +150,7 @@ def dauerhafte_heizkostenkomponente_m_lookup( def klimakomponente_m_lookup( raw_klimakomponente_m: dict[int | str, float], max_anzahl_personen: dict[str, int], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Klimakomponente as a lookup table.""" expanded = raw_klimakomponente_m.copy() @@ -205,10 +215,11 @@ def miete_m_hh_mit_baujahr( wohnen__bruttokaltmiete_m_hh: float, min_miete_m_hh: float, max_miete_m_lookup: LookupTableBaujahr, + xnp: ModuleType, ) -> float: """Rent considered in housing benefit calculation on household level until 2008.""" - selected_bin_index = np.searchsorted( + selected_bin_index = xnp.searchsorted( max_miete_m_lookup.baujahre, wohnen__baujahr_immobilie_hh, side="left", @@ -232,6 +243,7 @@ def miete_m_hh_ohne_baujahr_ohne_heizkostenentlastung( wohnen__bruttokaltmiete_m_hh: float, min_miete_m_hh: float, max_miete_m_lookup: ConsecutiveInt2dLookupTableParamValue, + xnp: ModuleType, ) -> float: """Rent considered in housing benefit since 2009.""" @@ -255,6 +267,7 @@ def miete_m_hh_mit_heizkostenentlastung( min_miete_m_hh: float, max_miete_m_lookup: ConsecutiveInt2dLookupTableParamValue, heizkostenentlastung_m_lookup: ConsecutiveInt1dLookupTableParamValue, + xnp: ModuleType, ) -> float: """Rent considered in housing benefit since 2009.""" max_miete_m = max_miete_m_lookup.values_to_look_up[ @@ -285,6 +298,7 @@ def miete_m_hh_mit_heizkostenentlastung_dauerhafte_heizkostenkomponente_klimakom heizkostenentlastung_m_lookup: ConsecutiveInt1dLookupTableParamValue, 
dauerhafte_heizkostenkomponente_m_lookup: ConsecutiveInt1dLookupTableParamValue, klimakomponente_m_lookup: ConsecutiveInt1dLookupTableParamValue, + xnp: ModuleType, ) -> float: """Rent considered in housing benefit since 2009.""" max_miete_m = max_miete_m_lookup.values_to_look_up[ diff --git a/src/_gettsim/wohngeld/wohngeld.py b/src/_gettsim/wohngeld/wohngeld.py index f248b14bd..60edea0a7 100644 --- a/src/_gettsim/wohngeld/wohngeld.py +++ b/src/_gettsim/wohngeld/wohngeld.py @@ -21,7 +21,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from ttsim.config import numpy_or_jax as np from ttsim.tt_dag_elements import ( AggType, RoundingSpec, @@ -32,6 +31,8 @@ ) if TYPE_CHECKING: + from types import ModuleType + from _gettsim.param_types import ConsecutiveInt1dLookupTableParamValue @@ -84,6 +85,7 @@ def basisformel( einkommen_m: float, miete_m: float, params: BasisformelParamValues, + xnp: ModuleType, ) -> float: """Basic formula for housing benefit calculation. @@ -99,12 +101,12 @@ def basisformel( anzahl_personen - params.zusatzbetrag_nach_haushaltsgröße.base_to_subtract ] ) - out = np.maximum( + out = xnp.maximum( 0.0, params.skalierungsfaktor * (miete_m - ((a + (b * miete_m) + (c * einkommen_m)) * einkommen_m)), ) - return np.minimum(miete_m, out + zusatzbetrag_nach_haushaltsgröße) + return xnp.minimum(miete_m, out + zusatzbetrag_nach_haushaltsgröße) @policy_function( @@ -120,6 +122,7 @@ def anspruchshöhe_m_wthh( miete_m_wthh: float, grundsätzlich_anspruchsberechtigt_wthh: bool, basisformel_params: BasisformelParamValues, + xnp: ModuleType, ) -> float: """Housing benefit after wealth and income check. 
@@ -134,6 +137,7 @@ def anspruchshöhe_m_wthh( einkommen_m=einkommen_m_wthh, miete_m=miete_m_wthh, params=basisformel_params, + xnp=xnp, ) else: out = 0.0 @@ -154,6 +158,7 @@ def anspruchshöhe_m_bg( miete_m_bg: float, grundsätzlich_anspruchsberechtigt_bg: bool, basisformel_params: BasisformelParamValues, + xnp: ModuleType, ) -> float: """Housing benefit after wealth and income check. @@ -166,6 +171,7 @@ def anspruchshöhe_m_bg( einkommen_m=einkommen_m_bg, miete_m=miete_m_bg, params=basisformel_params, + xnp=xnp, ) else: out = 0.0 @@ -179,6 +185,7 @@ def basisformel_params( koeffizienten_berechnungsformel: dict[int, dict[str, float]], max_anzahl_personen: dict[str, int], zusatzbetrag_pro_person_in_großen_haushalten: float, + xnp: ModuleType, ) -> BasisformelParamValues: """Convert the parameters of the Wohngeld basis formula to a format that can be used by Numpy and Jax. diff --git a/src/_gettsim_tests/test_policy.py b/src/_gettsim_tests/test_policy.py index b680370f8..1351461dd 100644 --- a/src/_gettsim_tests/test_policy.py +++ b/src/_gettsim_tests/test_policy.py @@ -2,6 +2,7 @@ from pathlib import Path +import numpy import pytest from _gettsim.config import GETTSIM_ROOT @@ -14,7 +15,9 @@ TEST_DIR = Path(__file__).parent -POLICY_TEST_IDS_AND_CASES = load_policy_test_data(test_dir=TEST_DIR, policy_name="") +POLICY_TEST_IDS_AND_CASES = load_policy_test_data( + test_dir=TEST_DIR, policy_name="", xnp=numpy +) @pytest.mark.parametrize( diff --git a/src/ttsim/interface_dag.py b/src/ttsim/interface_dag.py index 8060e4816..4908f551a 100644 --- a/src/ttsim/interface_dag.py +++ b/src/ttsim/interface_dag.py @@ -22,10 +22,18 @@ from ttsim.interface_dag_elements.typing import UnorderedQNames -def main(inputs: dict[str, Any], targets: list[str] | None = None) -> dict[str, Any]: +def main( + inputs: dict[str, Any], + targets: list[str] | None = None, + backend: Literal[numpy, jax] = "numpy", +) -> dict[str, Any]: """ Main function that processes the inputs and returns the outputs. 
""" + + if "backend" not in inputs: + inputs["backend"] = backend + nodes = { p: n for p, n in load_interface_functions_and_inputs().items() diff --git a/src/ttsim/interface_dag_elements/backend.py b/src/ttsim/interface_dag_elements/backend.py new file mode 100644 index 000000000..731fc6f5d --- /dev/null +++ b/src/ttsim/interface_dag_elements/backend.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from types import ModuleType + +import numpy + +from ttsim.interface_dag_elements.interface_node_objects import interface_function + + +@interface_function(in_top_level_namespace=True) +def xnp(backend: str) -> ModuleType: + """ + Return the backend for numerical operations (either NumPy or jax). + """ + + if backend == "numpy": + xnp = numpy + elif backend == "jax": + try: + import jax + except ImportError: + raise ImportError( + "jax is not installed. Please install jax to use the 'jax' backend." + ) + xnp = jax.numpy + else: + raise ValueError(f"Unsupported backend: {backend}. Choose 'numpy' or 'jax'.") + return xnp + + +@interface_function(in_top_level_namespace=True) +def dnp(backend: str) -> ModuleType: + """ + Return the backend for datetime objects (either NumPy or jax-datetime) + """ + global dnp + + if backend == "numpy": + dnp = numpy + elif backend == "jax": + try: + import jax_datetime + except ImportError: + raise ImportError( + "jax-datetime is not installed. Please install jax-datetime to use the 'jax' backend." 
+ ) + dnp = jax_datetime + return dnp diff --git a/src/ttsim/interface_dag_elements/data_converters.py b/src/ttsim/interface_dag_elements/data_converters.py index 5909da4c1..0b2627a6a 100644 --- a/src/ttsim/interface_dag_elements/data_converters.py +++ b/src/ttsim/interface_dag_elements/data_converters.py @@ -1,5 +1,6 @@ from __future__ import annotations +from types import ModuleType from typing import TYPE_CHECKING import dags.tree as dt @@ -66,6 +67,7 @@ def nested_data_to_df_with_mapped_columns( def dataframe_to_nested_data( mapper: NestedInputsMapper, df: pd.DataFrame, + xnp: ModuleType, ) -> NestedData: """Transform a pandas DataFrame to a nested dictionary expected by TTSIM. ` @@ -107,26 +109,28 @@ def dataframe_to_nested_data( >>> result { "n1": { - "n2": pd.Series([1, 2, 3]), - "n3": pd.Series([4, 5, 6]), + "n2": np.array([1, 2, 3]), + "n3": np.array([4, 5, 6]), }, - "n4": pd.Series([3, 3, 3]), + "n4": np.array([3, 3, 3]), } """ qualified_inputs_tree_to_df_columns = dt.flatten_to_qual_names(mapper) - name_to_input_series = {} + name_to_input_array = {} for ( qualified_input_name, input_value, ) in qualified_inputs_tree_to_df_columns.items(): if input_value in df.columns: - name_to_input_series[qualified_input_name] = df[input_value] + name_to_input_array[qualified_input_name] = xnp.asarray(df[input_value]) else: - name_to_input_series[qualified_input_name] = pd.Series( - [input_value] * len(df), - index=df.index, + name_to_input_array[qualified_input_name] = xnp.asarray( + pd.Series( + [input_value] * len(df), + index=df.index, + ) ) - return dt.unflatten_from_qual_names(name_to_input_series) + return dt.unflatten_from_qual_names(name_to_input_array) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index a5f9eaf1f..778ed1b92 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -11,7 +11,6 @@ import optree import pandas as pd -from ttsim.config 
import numpy_or_jax as np from ttsim.interface_dag_elements.interface_node_objects import interface_function from ttsim.interface_dag_elements.shared import get_name_of_group_by_id from ttsim.tt_dag_elements.column_objects_param_function import ( @@ -25,6 +24,10 @@ from ttsim.tt_dag_elements.param_objects import ParamObject if TYPE_CHECKING: + from types import ModuleType + + import numpy + from ttsim.interface_dag_elements.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, @@ -260,7 +263,7 @@ def input_data_tree_is_invalid(input_data__tree: NestedData) -> None: """ assert_valid_ttsim_pytree( tree=input_data__tree, - leaf_checker=lambda leaf: isinstance(leaf, int | pd.Series | np.ndarray), + leaf_checker=lambda leaf: isinstance(leaf, int | pd.Series | numpy.ndarray), tree_name="input_data__tree", ) p_id = input_data__tree.get("p_id", None) @@ -427,13 +430,13 @@ def non_convertible_objects_in_results_tree( ) -> None: """Fail if results should be converted to a DataFrame but contain non-convertible objects.""" - _numeric_types = (int, float, bool, np.integer, np.floating, np.bool_) + _numeric_types = (int, float, bool, numpy.integer, numpy.floating, numpy.bool_) faulty_paths = [] # TODO: HM doesn't think this will work as is, we'll need to check the length of # the data. Someone might request a policy parameter that is a 3-element array. 
for path, data in dt.flatten_to_tree_paths(results__tree).items(): - if isinstance(data, (pd.Series, np.ndarray, list)): + if isinstance(data, (pd.Series, numpy.ndarray, list)): if all(isinstance(item, _numeric_types) for item in data): continue else: @@ -741,3 +744,39 @@ def _remove_note_and_reference(entry: dict[str | int, Any]) -> dict[str | int, A ) return out + + +def fail_if__dtype_not_int( + data: numpy.ndarray, + agg_func: str, + xnp: ModuleType, +) -> None: + """Check if data is of integer type.""" + if not xnp.issubdtype(data.dtype, xnp.integer): + raise TypeError( + f"Data in {agg_func} must be of integer type, but is {data.dtype}." + ) + + +def fail_if__dtype_not_numeric_or_datetime( + data: numpy.ndarray, + agg_func: str, + xnp: ModuleType, +) -> None: + """Check if data is of numeric or datetime type.""" + if not xnp.issubdtype(data.dtype, (xnp.number, xnp.datetime64)): + raise TypeError( + f"Data in {agg_func} must be of numeric or datetime type, but is {data.dtype}." + ) + + +def fail_if__dtype_not_numeric_or_boolean( + data: numpy.ndarray, + agg_func: str, + xnp: ModuleType, +) -> None: + """Check if data is of numeric or boolean type.""" + if not xnp.issubdtype(data.dtype, (xnp.number, xnp.bool_)): + raise TypeError( + f"Data in {agg_func} must be of numeric or boolean type, but is {data.dtype}." 
+ ) diff --git a/src/ttsim/interface_dag_elements/input_data.py b/src/ttsim/interface_dag_elements/input_data.py index e6ff9368c..1a61e3180 100644 --- a/src/ttsim/interface_dag_elements/input_data.py +++ b/src/ttsim/interface_dag_elements/input_data.py @@ -1,5 +1,6 @@ from __future__ import annotations +from types import ModuleType from typing import TYPE_CHECKING from ttsim.interface_dag_elements.data_converters import dataframe_to_nested_data @@ -33,6 +34,7 @@ def df_with_nested_columns() -> pd.DataFrame: def tree( df_and_mapper__df: pd.DataFrame, df_and_mapper__mapper: NestedInputsMapper, + xnp: ModuleType, ) -> NestedData: """The input DataFrame as a nested data structure. @@ -48,4 +50,5 @@ def tree( return dataframe_to_nested_data( df=df_and_mapper__df, mapper=df_and_mapper__mapper, + xnp=xnp, ) diff --git a/src/ttsim/interface_dag_elements/processed_data.py b/src/ttsim/interface_dag_elements/processed_data.py index 0a1a61683..bd1f5757a 100644 --- a/src/ttsim/interface_dag_elements/processed_data.py +++ b/src/ttsim/interface_dag_elements/processed_data.py @@ -4,16 +4,18 @@ import dags.tree as dt -from ttsim.config import numpy_or_jax as np from ttsim.interface_dag_elements.interface_node_objects import interface_function if TYPE_CHECKING: + from types import ModuleType + from ttsim.interface_dag_elements.typing import NestedData, QNameData @interface_function(in_top_level_namespace=True) def processed_data( input_data__tree: NestedData, + xnp: ModuleType, ) -> QNameData: """Process the data for use in the taxes and transfers function. @@ -27,5 +29,15 @@ def processed_data( A DataFrame. 
""" return { - k: np.asarray(v) for k, v in dt.flatten_to_qual_names(input_data__tree).items() + k: xnp.asarray(v) for k, v in dt.flatten_to_qual_names(input_data__tree).items() + } + + +def process_input_data( + input_data__tree: dict, + xnp: ModuleType, +) -> dict: + """Process input data.""" + return { + k: xnp.asarray(v) for k, v in dt.flatten_to_qual_names(input_data__tree).items() } diff --git a/src/ttsim/testing_utils.py b/src/ttsim/testing_utils.py index df9278f93..11f7946fc 100644 --- a/src/ttsim/testing_utils.py +++ b/src/ttsim/testing_utils.py @@ -9,7 +9,6 @@ import yaml from ttsim import main, merge_trees -from ttsim.config import numpy_or_jax as np from ttsim.interface_dag_elements.data_converters import ( nested_data_to_df_with_nested_columns, ) @@ -22,6 +21,7 @@ if TYPE_CHECKING: import datetime from pathlib import Path + from types import ModuleType from ttsim.interface_dag_elements.typing import ( NestedData, @@ -54,13 +54,15 @@ def __init__( path: Path, date: datetime.date, test_dir: Path, + xnp: ModuleType, ) -> None: self.info = info - self.input_tree = optree.tree_map(np.array, input_tree) + self.input_tree = optree.tree_map(xnp.array, input_tree) self.expected_output_tree = expected_output_tree self.path = path self.date = date self.test_dir = test_dir + self.xnp = xnp @property def target_structure(self) -> NestedInputStructureDict: @@ -133,7 +135,9 @@ def execute_test(test: PolicyTest, root: Path, jit: bool = False) -> None: # no ) from e -def load_policy_test_data(test_dir: Path, policy_name: str) -> dict[str, PolicyTest]: +def load_policy_test_data( + test_dir: Path, policy_name: str, xnp: ModuleType +) -> dict[str, PolicyTest]: """Load all tests found by recursively searching test_dir / "test_data" / policy_name @@ -155,6 +159,7 @@ def load_policy_test_data(test_dir: Path, policy_name: str) -> dict[str, PolicyT test_dir=test_dir, raw_test_data=raw_test_data, path_to_yaml=path_to_yaml, + xnp=xnp, ) out[this_test.name] = this_test @@ -169,6 
+174,7 @@ def _get_policy_test_from_raw_test_data( test_dir: Path, path_to_yaml: Path, raw_test_data: NestedData, + xnp: ModuleType, ) -> PolicyTest: """Get a list of PolicyTest objects from raw test data. @@ -182,7 +188,7 @@ def _get_policy_test_from_raw_test_data( test_info: NestedData = raw_test_data.get("info", {}) input_tree: NestedData = dt.unflatten_from_tree_paths( { - k: np.array(v) + k: xnp.array(v) for k, v in dt.flatten_to_tree_paths( merge_trees( left=raw_test_data["inputs"].get("provided", {}), @@ -194,7 +200,7 @@ def _get_policy_test_from_raw_test_data( expected_output_tree: NestedData = dt.unflatten_from_tree_paths( { - k: np.array(v) + k: xnp.array(v) for k, v in dt.flatten_to_tree_paths( raw_test_data.get("outputs", {}) ).items() @@ -210,4 +216,5 @@ def _get_policy_test_from_raw_test_data( path=path_to_yaml, date=date, test_dir=test_dir, + xnp=xnp, ) diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index 679a63f59..164f1297b 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -40,9 +40,9 @@ from ttsim.tt_dag_elements.vectorization import vectorize_function if TYPE_CHECKING: + import numpy import pandas as pd - from ttsim.config import numpy_or_jax as np from ttsim.interface_dag_elements.typing import ( DashedISOString, GenericCallable, diff --git a/src/ttsim/tt_dag_elements/param_objects.py b/src/ttsim/tt_dag_elements/param_objects.py index 3c98cebfc..cc3056877 100644 --- a/src/ttsim/tt_dag_elements/param_objects.py +++ b/src/ttsim/tt_dag_elements/param_objects.py @@ -6,7 +6,10 @@ import numpy -from ttsim.config import numpy_or_jax as np +if TYPE_CHECKING: + from types import ModuleType + + import numpy from ttsim.tt_dag_elements.column_objects_param_function import param_function if TYPE_CHECKING: @@ -143,9 +146,9 @@ def __post_init__(self) -> None: class 
PiecewisePolynomialParamValue: """The parameters expected by piecewise_polynomial""" - thresholds: np.ndarray - intercepts: np.ndarray - rates: np.ndarray + thresholds: numpy.ndarray + intercepts: numpy.ndarray + rates: numpy.ndarray @dataclass(frozen=True) @@ -153,7 +156,7 @@ class ConsecutiveInt1dLookupTableParamValue: """The parameters expected by lookup_table""" base_to_subtract: int - values_to_look_up: np.ndarray + values_to_look_up: numpy.ndarray @dataclass(frozen=True) @@ -162,44 +165,46 @@ class ConsecutiveInt2dLookupTableParamValue: base_to_subtract_rows: int base_to_subtract_cols: int - values_to_look_up: np.ndarray + values_to_look_up: numpy.ndarray def get_consecutive_int_1d_lookup_table_param_value( raw: dict[int, float | int | bool], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Get the parameters for a 1-dimensional lookup table.""" - lookup_keys = numpy.asarray(sorted(raw)) - assert (lookup_keys - min(lookup_keys) == np.arange(len(lookup_keys))).all(), ( + lookup_keys = xnp.asarray(sorted(raw)) + assert (lookup_keys - xnp.min(lookup_keys) == xnp.arange(len(lookup_keys))).all(), ( "Dictionary keys must be consecutive integers." 
) return ConsecutiveInt1dLookupTableParamValue( - base_to_subtract=min(lookup_keys), - values_to_look_up=np.asarray([raw[k] for k in lookup_keys]), + base_to_subtract=xnp.min(lookup_keys), + values_to_look_up=xnp.asarray([raw[k] for k in lookup_keys]), ) def get_consecutive_int_2d_lookup_table_param_value( raw: dict[int, dict[int, float | int | bool]], + xnp: ModuleType, ) -> ConsecutiveInt2dLookupTableParamValue: """Get the parameters for a 2-dimensional lookup table.""" - lookup_keys_rows = numpy.asarray(sorted(raw.keys())) - lookup_keys_cols = numpy.asarray(sorted(raw[lookup_keys_rows[0]].keys())) + lookup_keys_rows = xnp.asarray(sorted(raw.keys())) + lookup_keys_cols = xnp.asarray(sorted(raw[lookup_keys_rows[0]].keys())) for col_value in raw.values(): - lookup_keys_this_col = numpy.asarray(sorted(col_value.keys())) + lookup_keys_this_col = xnp.asarray(sorted(col_value.keys())) assert (lookup_keys_cols == lookup_keys_this_col).all(), ( "Column keys must be the same in each column, got:" f"{lookup_keys_cols} and {lookup_keys_this_col}" ) for lookup_keys in lookup_keys_rows, lookup_keys_cols: - assert (lookup_keys - min(lookup_keys) == np.arange(len(lookup_keys))).all(), ( - f"Dictionary keys must be consecutive integers, got: {lookup_keys}" - ) + assert ( + lookup_keys - xnp.min(lookup_keys) == xnp.arange(len(lookup_keys)) + ).all(), f"Dictionary keys must be consecutive integers, got: {lookup_keys}" return ConsecutiveInt2dLookupTableParamValue( - base_to_subtract_rows=min(lookup_keys_rows), - base_to_subtract_cols=min(lookup_keys_cols), - values_to_look_up=np.array( + base_to_subtract_rows=xnp.min(lookup_keys_rows), + base_to_subtract_cols=xnp.min(lookup_keys_cols), + values_to_look_up=xnp.array( [ raw[row][col] for row, col in itertools.product(lookup_keys_rows, lookup_keys_cols) @@ -214,6 +219,7 @@ def _year_fraction(r: dict[Literal["years", "months"], int]) -> float: def get_month_based_phase_inout_of_age_thresholds_param_value( raw: dict[str | int, Any], + 
xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Get the parameters for month-based phase-in/phase-out of age thresholds. @@ -240,13 +246,13 @@ def _fill_phase_inout( first_m_since_ad_to_consider = _m_since_ad(y=raw.pop("first_year_to_consider"), m=1) last_m_since_ad_to_consider = _m_since_ad(y=raw.pop("last_year_to_consider"), m=12) assert all(isinstance(k, int) for k in raw) - first_year_phase_inout: int = min(raw.keys()) # type: ignore[assignment] - first_month_phase_inout: int = min(raw[first_year_phase_inout].keys()) + first_year_phase_inout: int = min(raw.keys()) # type: ignore[assignment] + first_month_phase_inout: int = min(raw[first_year_phase_inout].keys()) first_m_since_ad_phase_inout = _m_since_ad( y=first_year_phase_inout, m=first_month_phase_inout ) - last_year_phase_inout: int = max(raw.keys()) # type: ignore[assignment] - last_month_phase_inout: int = max(raw[last_year_phase_inout].keys()) + last_year_phase_inout: int = max(raw.keys()) # type: ignore[assignment] + last_month_phase_inout: int = max(raw[last_year_phase_inout].keys()) last_m_since_ad_phase_inout = _m_since_ad( y=last_year_phase_inout, m=last_month_phase_inout ) @@ -268,12 +274,13 @@ def _fill_phase_inout( ) } return get_consecutive_int_1d_lookup_table_param_value( - {**before_phase_inout, **during_phase_inout, **after_phase_inout} + {**before_phase_inout, **during_phase_inout, **after_phase_inout}, xnp ) def get_year_based_phase_inout_of_age_thresholds_param_value( raw: dict[str | int, Any], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Get the parameters for year-based phase-in/phase-out of age thresholds. 
@@ -283,8 +290,8 @@ def get_year_based_phase_inout_of_age_thresholds_param_value( first_year_to_consider = raw.pop("first_year_to_consider") last_year_to_consider = raw.pop("last_year_to_consider") assert all(isinstance(k, int) for k in raw) - first_year_phase_inout: int = min(raw.keys()) # type: ignore[assignment] - last_year_phase_inout: int = max(raw.keys()) # type: ignore[assignment] + first_year_phase_inout: int = min(raw.keys()) # type: ignore[assignment] + last_year_phase_inout: int = max(raw.keys()) # type: ignore[assignment] assert first_year_to_consider <= first_year_phase_inout assert last_year_to_consider >= last_year_phase_inout before_phase_inout: dict[int, float] = { @@ -300,5 +307,5 @@ def get_year_based_phase_inout_of_age_thresholds_param_value( for b_y in range(last_year_phase_inout + 1, last_year_to_consider + 1) } return get_consecutive_int_1d_lookup_table_param_value( - {**before_phase_inout, **during_phase_inout, **after_phase_inout} + {**before_phase_inout, **during_phase_inout, **after_phase_inout}, xnp ) diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index 0735ac0d2..5420923e9 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -1,11 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal, get_args +from typing import TYPE_CHECKING, Literal, get_args import numpy -from ttsim.config import numpy_or_jax as np +if TYPE_CHECKING: + from types import ModuleType + + import numpy from ttsim.tt_dag_elements.param_objects import PiecewisePolynomialParamValue FUNC_TYPES = Literal[ @@ -47,41 +50,44 @@ class RatesOptions: def piecewise_polynomial( - x: np.ndarray, + x: numpy.ndarray, parameters: PiecewisePolynomialParamValue, - rates_multiplier: np.ndarray = 1.0, -) -> np.ndarray: + rates_multiplier: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: 
"""Calculate value of the piecewise function at `x`. If the first interval begins at -inf the polynomial of that interval can only have slope of 0. Requesting a value outside of the provided thresholds will lead to undefined behaviour. Parameters ---------- - x : np.ndarray + x : numpy.ndarray Array with values at which the piecewise polynomial is to be calculated. thresholds : np.array A one-dimensional array containing the thresholds for all intervals. - coefficients : np.ndarray + coefficients : numpy.ndarray A two-dimensional array where columns are interval sections and rows correspond to the coefficient of the nth polynomial. - intercepts : np.ndarray + intercepts : numpy.ndarray The intercepts at the lower threshold of each interval. - rates_multiplier : np.ndarray + rates_multiplier : numpy.ndarray Multiplier to create individual or scaled rates. + xnp : ModuleType + The numpy module to use for calculations. Returns ------- - out : np.ndarray + out : numpy.ndarray The value of `x` under the piecewise function. 
""" order = parameters.rates.shape[0] # Get interval of requested value - selected_bin = np.searchsorted(parameters.thresholds, x, side="right") - 1 + selected_bin = xnp.searchsorted(parameters.thresholds, x, side="right") - 1 coefficients = parameters.rates[:, selected_bin].T # Calculate distance from X to lower threshold - increment_to_calc = np.where( - parameters.thresholds[selected_bin] == -np.inf, + increment_to_calc = xnp.where( + parameters.thresholds[selected_bin] == -xnp.inf, 0, x - parameters.thresholds[selected_bin], ) @@ -89,11 +95,11 @@ def piecewise_polynomial( out = ( parameters.intercepts[selected_bin] + ( - ((increment_to_calc.reshape(-1, 1)) ** np.arange(1, order + 1, 1)) + ((increment_to_calc.reshape(-1, 1)) ** xnp.arange(1, order + 1, 1)) * (coefficients) ).sum(axis=1) ) * rates_multiplier - return np.squeeze(out) + return xnp.squeeze(out) def get_piecewise_parameters( @@ -151,7 +157,8 @@ def get_piecewise_parameters( def check_and_get_thresholds( leaf_name: str, parameter_dict: dict[int, dict[str, float]], -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + xnp: ModuleType, +) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: """Check and transfer raw threshold data. Transfer and check raw threshold data, which needs to be specified in a @@ -162,6 +169,8 @@ def check_and_get_thresholds( parameter_dict leaf_name keys + xnp : ModuleType + The numpy module to use for calculations. Returns ------- @@ -186,7 +195,7 @@ def check_and_get_thresholds( upper_thresholds[keys[-1]] = parameter_dict[keys[-1]]["upper_threshold"] # Check if the function is defined on the complete real line - if (upper_thresholds[keys[-1]] != numpy.inf) | (lower_thresholds[0] != -numpy.inf): + if (upper_thresholds[keys[-1]] != xnp.inf) | (lower_thresholds[0] != -xnp.inf): raise ValueError(f"{leaf_name} needs to be defined on the entire real line.") for interval in keys[1:]: @@ -211,19 +220,24 @@ def check_and_get_thresholds( f" threshold in the piece after." 
) - if not numpy.allclose(lower_thresholds[1:], upper_thresholds[:-1]): + if not xnp.allclose(lower_thresholds[1:], upper_thresholds[:-1]): raise ValueError( f"The lower and upper thresholds of {leaf_name} have to coincide" ) thresholds = sorted([lower_thresholds[0], *upper_thresholds]) - return np.array(lower_thresholds), np.array(upper_thresholds), np.array(thresholds) + return ( + xnp.array(lower_thresholds), + xnp.array(upper_thresholds), + xnp.array(thresholds), + ) def _check_and_get_rates( leaf_name: str, func_type: FUNC_TYPES, parameter_dict: dict[int, dict[str, float]], -) -> np.ndarray: + xnp: ModuleType, +) -> numpy.ndarray: """Check and transfer raw rates data. Transfer and check raw rates data, which needs to be specified in a @@ -235,6 +249,8 @@ def _check_and_get_rates( leaf_name keys func_type + xnp : ModuleType + The numpy module to use for calculations. Returns ------- @@ -250,16 +266,17 @@ def _check_and_get_rates( raise ValueError( f"In interval {interval} of {leaf_name}, {rate_type} is missing." ) - return np.array(rates) + return xnp.array(rates) def _check_and_get_intercepts( leaf_name: str, parameter_dict: dict[int, dict[str, float]], - lower_thresholds: np.ndarray, - upper_thresholds: np.ndarray, - rates: np.ndarray, -) -> np.ndarray: + lower_thresholds: numpy.ndarray, + upper_thresholds: numpy.ndarray, + rates: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """Check and transfer raw intercept data. If necessary create intercepts. Transfer and check raw rates data, which needs to be specified in a @@ -273,6 +290,8 @@ def _check_and_get_intercepts( upper_thresholds rates keys + xnp : ModuleType + The numpy module to use for calculations. 
Returns ------- @@ -304,17 +323,18 @@ def _check_and_get_intercepts( else: intercepts = _create_intercepts( - lower_thresholds, upper_thresholds, rates, intercepts[0] + lower_thresholds, upper_thresholds, rates, intercepts[0], xnp=xnp ) - return np.array(intercepts) + return xnp.array(intercepts) def _create_intercepts( - lower_thresholds: np.ndarray, - upper_thresholds: np.ndarray, - rates: np.ndarray, - intercept_at_lowest_threshold: np.ndarray, -) -> np.ndarray: + lower_thresholds: numpy.ndarray, + upper_thresholds: numpy.ndarray, + rates: numpy.ndarray, + intercept_at_lowest_threshold: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """Create intercepts from raw data. Parameters @@ -335,12 +355,14 @@ def _create_intercepts( fun: function handle (currently only piecewise_linear, will need to think about whether we can have a generic function with a different interface or make it specific ) + xnp : ModuleType + The numpy module to use for calculations. Returns ------- """ - intercepts = numpy.full_like(upper_thresholds, numpy.nan) + intercepts = numpy.full_like(upper_thresholds, xnp.nan) intercepts[0] = intercept_at_lowest_threshold for i, up_thr in enumerate(upper_thresholds[:-1]): intercepts[i + 1] = _calculate_one_intercept( @@ -349,16 +371,18 @@ def _create_intercepts( upper_thresholds=upper_thresholds, rates=rates, intercepts=intercepts, + xnp=xnp, ) - return np.array(intercepts) + return xnp.array(intercepts) def _calculate_one_intercept( x: float, - lower_thresholds: np.ndarray, - upper_thresholds: np.ndarray, - rates: np.ndarray, - intercepts: np.ndarray, + lower_thresholds: numpy.ndarray, + upper_thresholds: numpy.ndarray, + rates: numpy.ndarray, + intercepts: numpy.ndarray, + xnp: ModuleType, ) -> float: """Calculate the intercepts from the raw data. @@ -375,6 +399,8 @@ def _calculate_one_intercept( to the nth polynomial. intercepts : numpy.ndarray The intercepts at the lower threshold of each interval. 
+ xnp : ModuleType + The numpy module to use for calculations. Returns ------- @@ -384,15 +410,15 @@ def _calculate_one_intercept( """ # Check if value lies within the defined range. - if (x < lower_thresholds[0]) or (x > upper_thresholds[-1]) or numpy.isnan(x): - return numpy.nan - index_interval = numpy.searchsorted(upper_thresholds, x, side="left") + if (x < lower_thresholds[0]) or (x > upper_thresholds[-1]) or xnp.isnan(x): + return xnp.nan + index_interval = xnp.searchsorted(upper_thresholds, x, side="left") intercept_interval = intercepts[index_interval] # Select threshold and calculate corresponding increment into interval lower_threshold_interval = lower_thresholds[index_interval] - if lower_threshold_interval == -numpy.inf: + if lower_threshold_interval == -xnp.inf: return intercept_interval increment_to_calc = x - lower_threshold_interval diff --git a/src/ttsim/tt_dag_elements/rounding.py b/src/ttsim/tt_dag_elements/rounding.py index 9d26d8ac1..b09c5df4d 100644 --- a/src/ttsim/tt_dag_elements/rounding.py +++ b/src/ttsim/tt_dag_elements/rounding.py @@ -4,12 +4,12 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, ParamSpec, get_args -from ttsim.config import numpy_or_jax as np - -ROUNDING_DIRECTION = Literal["up", "down", "nearest"] - if TYPE_CHECKING: from collections.abc import Callable + from types import ModuleType + + +ROUNDING_DIRECTION = Literal["up", "down", "nearest"] P = ParamSpec("P") @@ -35,30 +35,35 @@ def __post_init__(self) -> None: f"Additive part must be a number, got {self.to_add_after_rounding!r}" ) - def apply_rounding(self, func: Callable[P, np.ndarray]) -> Callable[P, np.ndarray]: - """Decorator to round the output of a function. + def apply_rounding( + self, func: Callable[P, numpy.ndarray] + ) -> Callable[P, numpy.ndarray]: + """Decorator to round the output of a function. The wrapped function must accept an xnp: ModuleType argument for numpy operations. 
Parameters ---------- func - Function to be rounded. + Function to be rounded. Must accept xnp: ModuleType as a parameter. Returns ------- Function with rounding applied. """ - # Make sure that signature is preserved. @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> np.ndarray: - out = func(*args, **kwargs) + def wrapper( + *args: P.args, xnp: ModuleType, **kwargs: P.kwargs + ) -> numpy.ndarray: + out = func(*args, xnp=xnp, **kwargs) if self.direction == "up": - rounded_out = self.base * np.ceil(out / self.base) + rounded_out = self.base * xnp.ceil(out / self.base) elif self.direction == "down": - rounded_out = self.base * np.floor(out / self.base) + rounded_out = self.base * xnp.floor(out / self.base) elif self.direction == "nearest": - rounded_out = self.base * (np.asarray(out) / self.base).round() + rounded_out = self.base * (xnp.asarray(out) / self.base).round() + else: + raise ValueError(f"Invalid rounding direction: {self.direction}") rounded_out += self.to_add_after_rounding return rounded_out diff --git a/src/ttsim/tt_dag_elements/shared.py b/src/ttsim/tt_dag_elements/shared.py index cb98046e5..66388d267 100644 --- a/src/ttsim/tt_dag_elements/shared.py +++ b/src/ttsim/tt_dag_elements/shared.py @@ -1,28 +1,36 @@ from __future__ import annotations -from ttsim.config import numpy_or_jax as np +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + + import numpy def join( - foreign_key: np.ndarray, - primary_key: np.ndarray, - target: np.ndarray, + foreign_key: numpy.ndarray, + primary_key: numpy.ndarray, + target: numpy.ndarray, value_if_foreign_key_is_missing: float | bool, -) -> np.ndarray: + xnp: ModuleType, +) -> numpy.ndarray: """ Given a foreign key, find the corresponding primary key, and return the target at the same index as the primary key. When using Jax, does not work on String Arrays. Parameters ---------- - foreign_key : np.ndarray[Key] + foreign_key : numpy.ndarray[Key] The foreign keys. 
- primary_key : np.ndarray[Key] + primary_key : numpy.ndarray[Key] The primary keys. - target : np.ndarray[Out] + target : numpy.ndarray[Out] The targets in the same order as the primary keys. value_if_foreign_key_is_missing : Out The value to return if no matching primary key is found. + xnp : ModuleType + The numpy module to use for calculations. Returns ------- @@ -33,15 +41,15 @@ def join( # For each foreign key, add a column with True at the end, to later fall back to # the value for unresolved foreign keys - padded_matches_foreign_key = np.pad( + padded_matches_foreign_key = xnp.pad( matches_foreign_key, ((0, 0), (0, 1)), "constant", constant_values=True ) # For each foreign key, compute the index of the first matching primary key - indices = np.argmax(padded_matches_foreign_key, axis=1) + indices = xnp.argmax(padded_matches_foreign_key, axis=1) # Add the value for unresolved foreign keys at the end of the target array - padded_targets = np.pad( + padded_targets = xnp.pad( target, (0, 1), "constant", constant_values=value_if_foreign_key_is_missing ) diff --git a/src/ttsim/tt_dag_elements/vectorization.py b/src/ttsim/tt_dag_elements/vectorization.py index 36b7bc872..7cc9cd120 100644 --- a/src/ttsim/tt_dag_elements/vectorization.py +++ b/src/ttsim/tt_dag_elements/vectorization.py @@ -6,12 +6,15 @@ import textwrap import types from importlib import import_module +from types import ModuleType from typing import TYPE_CHECKING, Literal, cast import numpy -from ttsim.config import IS_JAX_INSTALLED -from ttsim.config import numpy_or_jax as np +if TYPE_CHECKING: + from types import ModuleType + + import numpy if TYPE_CHECKING: from ttsim.interface_dag_elements.typing import GenericCallable @@ -22,6 +25,8 @@ def vectorize_function( func: GenericCallable, vectorization_strategy: Literal["loop", "vectorize"], + backend: Literal["numpy", "jax"], + xnp: ModuleType, ) -> GenericCallable: vectorized: GenericCallable if vectorization_strategy == "loop": @@ -30,8 +35,7 @@ def 
vectorize_function( vectorized.__globals__ = func.__globals__ vectorized.__closure__ = func.__closure__ elif vectorization_strategy == "vectorize": - backend = "jax" if IS_JAX_INSTALLED else "numpy" - vectorized = _make_vectorizable(func, backend=backend) + vectorized = _make_vectorizable(func, backend=backend, xnp=xnp) else: raise ValueError( f"Vectorization strategy {vectorization_strategy} is not supported. " @@ -40,7 +44,9 @@ def vectorize_function( return vectorized -def _make_vectorizable(func: GenericCallable, backend: str) -> GenericCallable: +def _make_vectorizable( + func: GenericCallable, backend: str, xnp: ModuleType +) -> GenericCallable: """Redefine function to be vectorizable given backend. Args: @@ -58,7 +64,7 @@ def _make_vectorizable(func: GenericCallable, backend: str) -> GenericCallable: ) module = _module_from_backend(backend) - tree = _make_vectorizable_ast(func, module=module) + tree = _make_vectorizable_ast(func, module=module, xnp=xnp) # recreate scope of function, add array library scope = dict(func.__globals__) @@ -101,7 +107,9 @@ def make_vectorizable_source(func: GenericCallable, backend: str) -> str: return ast.unparse(tree) -def _make_vectorizable_ast(func: GenericCallable, module: str) -> ast.Module: +def _make_vectorizable_ast( + func: GenericCallable, module: str, xnp: ModuleType +) -> ast.Module: """Change if statement to where call in the ast of func and return new ast. 
Args: @@ -117,7 +125,7 @@ def _make_vectorizable_ast(func: GenericCallable, module: str) -> ast.Module: func_loc = f"{func.__module__}/{func.__name__}" # transform tree nodes - new_tree = Transformer(module, func_loc).visit(tree) + new_tree = Transformer(module, func_loc, xnp).visit(tree) return ast.fix_missing_locations(new_tree) @@ -142,14 +150,15 @@ def _remove_decorator_lines(source: str) -> str: class Transformer(ast.NodeTransformer): - def __init__(self, module: str, func_loc: str) -> None: + def __init__(self, module: str, func_loc: str, xnp: ModuleType) -> None: self.module = module self.func_loc = func_loc + self.xnp = xnp def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802 self.generic_visit(node) return _call_to_call_from_module( - node, module=self.module, func_loc=self.func_loc + node, module=self.module, func_loc=self.func_loc, xnp=self.xnp ) def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.UnaryOp | ast.Call: # noqa: N802 @@ -283,7 +292,9 @@ def _constructor(left: ast.Call | ast.expr, right: ast.Call | ast.expr) -> ast.C return cast("ast.Call", functools.reduce(_constructor, values)) -def _call_to_call_from_module(node: ast.Call, module: str, func_loc: str) -> ast.AST: +def _call_to_call_from_module( + node: ast.Call, module: str, func_loc: str, xnp: ModuleType +) -> ast.AST: """Transform built-in Calls to Calls from module.""" to_transform = ("sum", "any", "all", "max", "min") @@ -297,7 +308,7 @@ def _call_to_call_from_module(node: ast.Call, module: str, func_loc: str) -> ast args = node.args if len(args) == 1: - if type(args) not in (list, tuple, np.ndarray): + if type(args) not in (list, tuple, xnp.ndarray): raise TranslateToVectorizableError( f"Argument of function {func_id} is not a list, tuple, or valid array." 
f"\n\nFunction: {func_loc}\n\n" diff --git a/tests/ttsim/mettsim/group_by_ids.py b/tests/ttsim/mettsim/group_by_ids.py index 566d4aaaa..2ac6b8161 100644 --- a/tests/ttsim/mettsim/group_by_ids.py +++ b/tests/ttsim/mettsim/group_by_ids.py @@ -1,50 +1,59 @@ from __future__ import annotations -from ttsim.config import numpy_or_jax as np +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + + import numpy from ttsim.tt_dag_elements import group_creation_function @group_creation_function() def sp_id( - p_id: np.ndarray, - p_id_spouse: np.ndarray, -) -> np.ndarray: + p_id: numpy.ndarray, p_id_spouse: numpy.ndarray, xnp: ModuleType +) -> numpy.ndarray: """ Compute the spouse (sp) group ID for each person. """ - n = np.max(p_id) - p_id_spouse = np.where(p_id_spouse < 0, p_id, p_id_spouse) - sp_id = np.maximum(p_id, p_id_spouse) + np.minimum(p_id, p_id_spouse) * n + n = numpy.max(p_id) + p_id_spouse = numpy.where(p_id_spouse < 0, p_id, p_id_spouse) + sp_id = numpy.maximum(p_id, p_id_spouse) + numpy.minimum(p_id, p_id_spouse) * n - return __reorder_ids(sp_id) + return __reorder_ids(sp_id, xnp) @group_creation_function() def fam_id( - p_id_spouse: np.ndarray, - p_id: np.ndarray, - age: np.ndarray, - p_id_parent_1: np.ndarray, - p_id_parent_2: np.ndarray, -) -> np.ndarray: + p_id_spouse: numpy.ndarray, + p_id: numpy.ndarray, + age: numpy.ndarray, + p_id_parent_1: numpy.ndarray, + p_id_parent_2: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: """ Compute the family ID for each person. 
""" - n = np.max(p_id) + n = numpy.max(p_id) p_id_parent_1_loc = p_id_parent_1 p_id_parent_2_loc = p_id_parent_2 for i in range(p_id.shape[0]): - p_id_parent_1_loc = np.where(p_id_parent_1_loc == p_id[i], i, p_id_parent_1_loc) - p_id_parent_2_loc = np.where(p_id_parent_2_loc == p_id[i], i, p_id_parent_2_loc) + p_id_parent_1_loc = numpy.where( + p_id_parent_1_loc == p_id[i], i, p_id_parent_1_loc + ) + p_id_parent_2_loc = numpy.where( + p_id_parent_2_loc == p_id[i], i, p_id_parent_2_loc + ) - children = np.isin(p_id, p_id_parent_1) + np.isin(p_id, p_id_parent_2) - fam_id = np.where( + children = numpy.isin(p_id, p_id_parent_1) + numpy.isin(p_id, p_id_parent_2) + fam_id = numpy.where( p_id_spouse < 0, p_id + p_id * n, - np.maximum(p_id, p_id_spouse) + np.minimum(p_id, p_id_spouse) * n, + numpy.maximum(p_id, p_id_spouse) + numpy.minimum(p_id, p_id_spouse) * n, ) - fam_id = np.where( + fam_id = numpy.where( (fam_id == p_id + p_id * n) * (p_id_parent_1_loc >= 0) * (age < 25) @@ -52,7 +61,7 @@ def fam_id( fam_id[p_id_parent_1_loc], fam_id, ) - fam_id = np.where( + fam_id = numpy.where( (fam_id == p_id + p_id * n) * (p_id_parent_2_loc >= 0) * (age < 25) @@ -61,14 +70,14 @@ def fam_id( fam_id, ) - return __reorder_ids(fam_id) + return __reorder_ids(fam_id, xnp) -def __reorder_ids(ids: np.ndarray) -> np.ndarray: +def __reorder_ids(ids: numpy.ndarray, xnp: ModuleType) -> numpy.ndarray: """Make ID's consecutively numbered.""" - sorting = np.argsort(ids) + sorting = xnp.argsort(ids) ids_sorted = ids[sorting] - index_after_sort = np.arange(ids.shape[0])[sorting] - diff_to_prev = np.where(np.diff(ids_sorted) >= 1, 1, 0) - cons_ids = np.concatenate((np.asarray([0]), np.cumsum(diff_to_prev))) - return cons_ids[np.argsort(index_after_sort)] + index_after_sort = xnp.arange(ids.shape[0])[sorting] + diff_to_prev = xnp.where(xnp.diff(ids_sorted) >= 1, 1, 0) + cons_ids = xnp.concatenate((xnp.asarray([0]), xnp.cumsum(diff_to_prev))) + return cons_ids[xnp.argsort(index_after_sort)] diff 
--git a/tests/ttsim/test_convert_nested_data.py b/tests/ttsim/test_convert_nested_data.py index 0c55f7192..4b651c6f9 100644 --- a/tests/ttsim/test_convert_nested_data.py +++ b/tests/ttsim/test_convert_nested_data.py @@ -1,7 +1,7 @@ from __future__ import annotations import dags.tree as dt -import numpy as np +import numpy import pandas as pd import pytest @@ -54,8 +54,8 @@ def int_param_function() -> int: @pytest.fixture def minimal_data_tree(): return { - "hh_id": np.array([1, 2, 3]), - "p_id": np.array([1, 2, 3]), + "hh_id": numpy.array([1, 2, 3]), + "p_id": numpy.array([1, 2, 3]), } @@ -105,6 +105,7 @@ def test_dataframe_to_nested_data( result = dataframe_to_nested_data( mapper=inputs_tree_to_df_columns, df=df, + xnp=numpy, ) flat_result = dt.flatten_to_qual_names(result) flat_expected_output = dt.flatten_to_qual_names(expected_output) @@ -112,7 +113,7 @@ def test_dataframe_to_nested_data( assert set(flat_result.keys()) == set(flat_expected_output.keys()) for key in flat_result: pd.testing.assert_series_equal( - flat_result[key], flat_expected_output[key], check_names=False + pd.Series(flat_result[key]), flat_expected_output[key], check_names=False ) @@ -134,7 +135,7 @@ def test_dataframe_to_nested_data( "another_policy_function": "res2", }, pd.DataFrame( - {"res1": np.array([1, 1, 1]), "res2": np.array([1, 1, 1])}, + {"res1": numpy.array([1, 1, 1]), "res2": numpy.array([1, 1, 1])}, index=pd.Index([1, 2, 3], name="p_id"), ), ), @@ -147,7 +148,7 @@ def test_dataframe_to_nested_data( "some_policy_function": "res1", }, pd.DataFrame( - {"res1": np.array([1, 1, 1])}, + {"res1": numpy.array([1, 1, 1])}, index=pd.Index([1, 2, 3], name="p_id"), ), ), @@ -160,7 +161,7 @@ def test_dataframe_to_nested_data( "some_param_function": "res1", }, pd.DataFrame( - {"res1": np.array([1, 1, 1])}, + {"res1": numpy.array([1, 1, 1])}, index=pd.Index([1, 2, 3], name="p_id"), ), ), @@ -175,7 +176,7 @@ def test_dataframe_to_nested_data( "some_policy_function": "res2", }, pd.DataFrame( - 
{"res1": np.array([1, 1, 1]), "res2": np.array([1, 1, 1])}, + {"res1": numpy.array([1, 1, 1]), "res2": numpy.array([1, 1, 1])}, index=pd.Index([1, 2, 3], name="p_id"), ), ), @@ -186,7 +187,7 @@ def test_dataframe_to_nested_data( }, {"some_scalar_param": "res1"}, pd.DataFrame( - {"res1": np.array([1, 1, 1])}, + {"res1": numpy.array([1, 1, 1])}, index=pd.Index([1, 2, 3], name="p_id"), ), ), @@ -201,7 +202,7 @@ def test_dataframe_to_nested_data( "some_policy_function": "res2", }, pd.DataFrame( - {"res1": np.array([1, 1, 1]), "res2": np.array([1, 1, 1])}, + {"res1": numpy.array([1, 1, 1]), "res2": numpy.array([1, 1, 1])}, index=pd.Index([1, 2, 3], name="p_id"), ), ), diff --git a/tests/ttsim/test_failures.py b/tests/ttsim/test_failures.py index bf5f0ac77..8012a9926 100644 --- a/tests/ttsim/test_failures.py +++ b/tests/ttsim/test_failures.py @@ -12,7 +12,9 @@ from mettsim.config import METTSIM_ROOT from ttsim import main -from ttsim.config import numpy_or_jax as np + +if TYPE_CHECKING: + import numpy from ttsim.interface_dag_elements.fail_if import ( ConflictingActivePeriodsError, _param_with_active_periods, @@ -63,7 +65,7 @@ leaf_name="some_consecutive_int_1d_lookup_table_param", value=ConsecutiveInt1dLookupTableParamValue( base_to_subtract=1, - values_to_look_up=np.array([1, 2, 3]), + values_to_look_up=numpy.array([1, 2, 3]), ), **_GENERIC_PARAM_SPEC, ) @@ -78,9 +80,9 @@ _SOME_PIECEWISE_POLYNOMIAL_PARAM = PiecewisePolynomialParam( leaf_name="some_piecewise_polynomial_param", value=PiecewisePolynomialParamValue( - thresholds=np.array([1, 2, 3]), - intercepts=np.array([1, 2, 3]), - rates=np.array([1, 2, 3]), + thresholds=numpy.array([1, 2, 3]), + intercepts=numpy.array([1, 2, 3]), + rates=numpy.array([1, 2, 3]), ), **_GENERIC_PARAM_SPEC, ) @@ -89,8 +91,8 @@ @pytest.fixture def minimal_data_tree(): return { - "hh_id": np.array([1, 2, 3]), - "p_id": np.array([1, 2, 3]), + "hh_id": numpy.array([1, 2, 3]), + "p_id": numpy.array([1, 2, 3]), } @@ -615,9 +617,9 @@ def 
test_fail_if_group_ids_are_outside_top_level_namespace(): def test_fail_if_group_variables_are_not_constant_within_groups(): data = { - "p_id": np.array([0, 1, 2]), - "foo_kin": np.array([1, 2, 2]), - "kin_id": np.array([1, 1, 2]), + "p_id": numpy.array([0, 1, 2]), + "foo_kin": numpy.array([1, 2, 2]), + "kin_id": numpy.array([1, 1, 2]), } with pytest.raises(ValueError): group_variables_are_not_constant_within_groups( diff --git a/tests/ttsim/test_mettsim.py b/tests/ttsim/test_mettsim.py index 5eaddde4c..26359ba4a 100644 --- a/tests/ttsim/test_mettsim.py +++ b/tests/ttsim/test_mettsim.py @@ -2,6 +2,7 @@ from pathlib import Path +import numpy import pytest from mettsim.config import METTSIM_ROOT @@ -15,7 +16,9 @@ TEST_DIR = Path(__file__).parent -POLICY_TEST_IDS_AND_CASES = load_policy_test_data(test_dir=TEST_DIR, policy_name="") +POLICY_TEST_IDS_AND_CASES = load_policy_test_data( + test_dir=TEST_DIR, policy_name="", xnp=numpy +) @pytest.mark.parametrize( diff --git a/tests/ttsim/test_specialized_environment.py b/tests/ttsim/test_specialized_environment.py index 3b9bf63a0..7bb942f77 100644 --- a/tests/ttsim/test_specialized_environment.py +++ b/tests/ttsim/test_specialized_environment.py @@ -15,7 +15,9 @@ merge_trees, ) from ttsim.config import IS_JAX_INSTALLED -from ttsim.config import numpy_or_jax as np + +if TYPE_CHECKING: + import numpy from ttsim.interface_dag_elements.specialized_environment import ( with_partialled_params_and_scalars, with_processed_params_and_scalars, @@ -422,7 +424,7 @@ def test_create_agg_by_group_functions( "input_data__tree": input_data__tree, "targets__tree": targets__tree, "rounding": False, - # "jit": jit, + "backend": "numpy", }, targets=["results__tree"], )["results__tree"] @@ -440,7 +442,7 @@ def test_output_is_tree(minimal_input_data): "policy_environment": policy_environment, "targets__tree": {"module": {"some_func": None}}, "rounding": False, - # "jit": jit, + "backend": "numpy", }, targets=["results__tree"], )["results__tree"] 
@@ -474,7 +476,7 @@ def test_params_target_is_allowed(minimal_input_data): "policy_environment": policy_environment, "targets__tree": {"some_param": None, "module": {"some_func": None}}, "rounding": False, - # "jit": jit, + "backend": "numpy", }, targets=["results__tree"], )["results__tree"] @@ -486,8 +488,8 @@ def test_params_target_is_allowed(minimal_input_data): def test_function_without_data_dependency_is_not_mistaken_for_data(minimal_input_data): @policy_function(leaf_name="a", vectorization_strategy="not_required") - def a() -> np.ndarray: - return np.array(minimal_input_data["p_id"]) + def a() -> numpy.ndarray: + return numpy.array(minimal_input_data["p_id"]) @policy_function(leaf_name="b") def b(a): @@ -503,12 +505,12 @@ def b(a): "policy_environment": policy_environment, "targets__tree": {"b": None}, "rounding": False, - # "jit": jit, + "backend": "numpy", }, targets=["results__tree"], )["results__tree"] numpy.testing.assert_array_almost_equal( - results__tree["b"], np.array(minimal_input_data["p_id"]) + results__tree["b"], numpy.array(minimal_input_data["p_id"]) ) @@ -696,7 +698,7 @@ def test_user_provided_aggregate_by_p_id_specs( ): @policy_function(leaf_name=leaf_name, vectorization_strategy="not_required") def source() -> int: - return np.array([100, 200, 300]) + return numpy.array([100, 200, 300]) policy_environment = merge_trees( agg_functions, @@ -769,10 +771,10 @@ def test_policy_environment_with_params_and_scalars_is_processed(): "identity_plus_one": identity_plus_one, }, { - "identity": np.array([1, 2, 3, 4, 5]), + "identity": numpy.array([1, 2, 3, 4, 5]), }, {"identity_plus_one": None}, - {"identity_plus_one": np.array([2, 3, 4, 5, 6])}, + {"identity_plus_one": numpy.array([2, 3, 4, 5, 6])}, ), # Overwriting parameter ( @@ -781,10 +783,10 @@ def test_policy_environment_with_params_and_scalars_is_processed(): "some_policy_function_taking_int_param": some_policy_function_taking_int_param, # noqa: E501 }, { - "some_int_param": np.array([1, 2, 3, 4, 
5]), + "some_int_param": numpy.array([1, 2, 3, 4, 5]), }, {"some_policy_function_taking_int_param": None}, - {"some_policy_function_taking_int_param": np.array([1, 2, 3, 4, 5])}, + {"some_policy_function_taking_int_param": numpy.array([1, 2, 3, 4, 5])}, ), # Overwriting parameter function ( @@ -794,10 +796,14 @@ def test_policy_environment_with_params_and_scalars_is_processed(): "some_policy_func_taking_scalar_params_func": some_policy_func_taking_scalar_params_func, # noqa: E501 }, { - "some_scalar_params_func": np.array([1, 2, 3, 4, 5]), + "some_scalar_params_func": numpy.array([1, 2, 3, 4, 5]), }, {"some_policy_func_taking_scalar_params_func": None}, - {"some_policy_func_taking_scalar_params_func": np.array([1, 2, 3, 4, 5])}, + { + "some_policy_func_taking_scalar_params_func": numpy.array( + [1, 2, 3, 4, 5] + ) + }, ), ], ) diff --git a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py index 9e1075595..bcc162791 100644 --- a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py +++ b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py @@ -1,12 +1,15 @@ from __future__ import annotations import copy +from typing import TYPE_CHECKING import numpy import pytest from ttsim.config import IS_JAX_INSTALLED -from ttsim.config import numpy_or_jax as np + +if TYPE_CHECKING: + import numpy from ttsim.tt_dag_elements.aggregation import ( grouped_all, grouped_any, @@ -44,109 +47,109 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): test_grouped_specs = { "constant_column": { - "column_to_aggregate": np.array([1, 1, 1, 1, 1]), - "group_id": np.array([0, 0, 1, 1, 1]), - "expected_res_count": np.array([2, 2, 3, 3, 3]), - "expected_res_sum": np.array([2, 2, 3, 3, 3]), - "expected_res_max": np.array([1, 1, 1, 1, 1]), - "expected_res_min": np.array([1, 1, 1, 1, 1]), - "expected_res_any": np.array([True, True, True, True, True]), - "expected_res_all": np.array([True, True, True, 
True, True]), + "column_to_aggregate": numpy.array([1, 1, 1, 1, 1]), + "group_id": numpy.array([0, 0, 1, 1, 1]), + "expected_res_count": numpy.array([2, 2, 3, 3, 3]), + "expected_res_sum": numpy.array([2, 2, 3, 3, 3]), + "expected_res_max": numpy.array([1, 1, 1, 1, 1]), + "expected_res_min": numpy.array([1, 1, 1, 1, 1]), + "expected_res_any": numpy.array([True, True, True, True, True]), + "expected_res_all": numpy.array([True, True, True, True, True]), }, "constant_column_group_id_unsorted": { - "column_to_aggregate": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), - "group_id": np.array([0, 1, 0, 1, 0]), - "expected_res_count": np.array([3, 2, 3, 2, 3]), - "expected_res_sum": np.array([3.0, 2.0, 3.0, 2.0, 3.0]), - "expected_res_mean": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), - "expected_res_max": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), - "expected_res_min": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "column_to_aggregate": numpy.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "group_id": numpy.array([0, 1, 0, 1, 0]), + "expected_res_count": numpy.array([3, 2, 3, 2, 3]), + "expected_res_sum": numpy.array([3.0, 2.0, 3.0, 2.0, 3.0]), + "expected_res_mean": numpy.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "expected_res_max": numpy.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "expected_res_min": numpy.array([1.0, 1.0, 1.0, 1.0, 1.0]), }, "basic_case": { - "column_to_aggregate": np.array([0, 1, 2, 3, 4]), - "group_id": np.array([0, 0, 1, 1, 1]), - "expected_res_sum": np.array([1, 1, 9, 9, 9]), - "expected_res_max": np.array([1, 1, 4, 4, 4]), - "expected_res_min": np.array([0, 0, 2, 2, 2]), - "expected_res_any": np.array([True, True, True, True, True]), - "expected_res_all": np.array([False, False, True, True, True]), + "column_to_aggregate": numpy.array([0, 1, 2, 3, 4]), + "group_id": numpy.array([0, 0, 1, 1, 1]), + "expected_res_sum": numpy.array([1, 1, 9, 9, 9]), + "expected_res_max": numpy.array([1, 1, 4, 4, 4]), + "expected_res_min": numpy.array([0, 0, 2, 2, 2]), + "expected_res_any": numpy.array([True, True, True, 
True, True]), + "expected_res_all": numpy.array([False, False, True, True, True]), }, "unique_group_ids_with_gaps": { - "column_to_aggregate": np.array([0.0, 1.0, 2.0, 3.0, 4.0]), - "group_id": np.array([0, 0, 3, 3, 3]), - "expected_res_count": np.array([2, 2, 3, 3, 3]), - "expected_res_sum": np.array([1.0, 1.0, 9.0, 9.0, 9.0]), - "expected_res_mean": np.array([0.5, 0.5, 3.0, 3.0, 3.0]), - "expected_res_max": np.array([1.0, 1.0, 4.0, 4.0, 4.0]), - "expected_res_min": np.array([0.0, 0.0, 2.0, 2.0, 2.0]), + "column_to_aggregate": numpy.array([0.0, 1.0, 2.0, 3.0, 4.0]), + "group_id": numpy.array([0, 0, 3, 3, 3]), + "expected_res_count": numpy.array([2, 2, 3, 3, 3]), + "expected_res_sum": numpy.array([1.0, 1.0, 9.0, 9.0, 9.0]), + "expected_res_mean": numpy.array([0.5, 0.5, 3.0, 3.0, 3.0]), + "expected_res_max": numpy.array([1.0, 1.0, 4.0, 4.0, 4.0]), + "expected_res_min": numpy.array([0.0, 0.0, 2.0, 2.0, 2.0]), }, "float_column": { - "column_to_aggregate": np.array([0.0, 1.5, 2.0, 3.0, 4.0]), - "group_id": np.array([0, 0, 3, 3, 3]), - "expected_res_sum": np.array([1.5, 1.5, 9.0, 9.0, 9.0]), - "expected_res_mean": np.array([0.75, 0.75, 3.0, 3.0, 3.0]), - "expected_res_max": np.array([1.5, 1.5, 4.0, 4.0, 4.0]), - "expected_res_min": np.array([0.0, 0.0, 2.0, 2.0, 2.0]), + "column_to_aggregate": numpy.array([0.0, 1.5, 2.0, 3.0, 4.0]), + "group_id": numpy.array([0, 0, 3, 3, 3]), + "expected_res_sum": numpy.array([1.5, 1.5, 9.0, 9.0, 9.0]), + "expected_res_mean": numpy.array([0.75, 0.75, 3.0, 3.0, 3.0]), + "expected_res_max": numpy.array([1.5, 1.5, 4.0, 4.0, 4.0]), + "expected_res_min": numpy.array([0.0, 0.0, 2.0, 2.0, 2.0]), }, "more_than_two_groups": { - "column_to_aggregate": np.array([0.0, 1.0, 2.0, 3.0, 4.0]), - "group_id": np.array([1, 0, 1, 1, 3]), - "expected_res_count": np.array([3, 1, 3, 3, 1]), - "expected_res_sum": np.array([5.0, 1.0, 5.0, 5.0, 4.0]), - "expected_res_mean": np.array([5.0 / 3.0, 1.0, 5.0 / 3.0, 5.0 / 3.0, 4.0]), - "expected_res_max": 
np.array([3.0, 1.0, 3.0, 3.0, 4.0]), - "expected_res_min": np.array([0.0, 1.0, 0.0, 0.0, 4.0]), + "column_to_aggregate": numpy.array([0.0, 1.0, 2.0, 3.0, 4.0]), + "group_id": numpy.array([1, 0, 1, 1, 3]), + "expected_res_count": numpy.array([3, 1, 3, 3, 1]), + "expected_res_sum": numpy.array([5.0, 1.0, 5.0, 5.0, 4.0]), + "expected_res_mean": numpy.array([5.0 / 3.0, 1.0, 5.0 / 3.0, 5.0 / 3.0, 4.0]), + "expected_res_max": numpy.array([3.0, 1.0, 3.0, 3.0, 4.0]), + "expected_res_min": numpy.array([0.0, 1.0, 0.0, 0.0, 4.0]), }, "basic_case_bool": { - "column_to_aggregate": np.array([True, False, True, False, False]), - "group_id": np.array([0, 0, 1, 1, 1]), - "expected_res_any": np.array([True, True, True, True, True]), - "expected_res_all": np.array([False, False, False, False, False]), - "expected_res_sum": np.array([1, 1, 1, 1, 1]), + "column_to_aggregate": numpy.array([True, False, True, False, False]), + "group_id": numpy.array([0, 0, 1, 1, 1]), + "expected_res_any": numpy.array([True, True, True, True, True]), + "expected_res_all": numpy.array([False, False, False, False, False]), + "expected_res_sum": numpy.array([1, 1, 1, 1, 1]), }, "group_id_unsorted_bool": { - "column_to_aggregate": np.array([True, False, True, True, True]), - "group_id": np.array([0, 1, 0, 1, 0]), - "expected_res_any": np.array([True, True, True, True, True]), - "expected_res_all": np.array([True, False, True, False, True]), - "expected_res_sum": np.array([3, 1, 3, 1, 3]), + "column_to_aggregate": numpy.array([True, False, True, True, True]), + "group_id": numpy.array([0, 1, 0, 1, 0]), + "expected_res_any": numpy.array([True, True, True, True, True]), + "expected_res_all": numpy.array([True, False, True, False, True]), + "expected_res_sum": numpy.array([3, 1, 3, 1, 3]), }, "unique_group_ids_with_gaps_bool": { - "column_to_aggregate": np.array([True, False, False, False, False]), - "group_id": np.array([0, 0, 3, 3, 3]), - "expected_res_any": np.array([True, True, False, False, False]), - 
"expected_res_all": np.array([False, False, False, False, False]), - "expected_res_sum": np.array([1, 1, 0, 0, 0]), + "column_to_aggregate": numpy.array([True, False, False, False, False]), + "group_id": numpy.array([0, 0, 3, 3, 3]), + "expected_res_any": numpy.array([True, True, False, False, False]), + "expected_res_all": numpy.array([False, False, False, False, False]), + "expected_res_sum": numpy.array([1, 1, 0, 0, 0]), }, "sum_by_p_id_float": { - "column_to_aggregate": np.array([10.0, 20.0, 30.0, 40.0, 50.0]), - "p_id_to_aggregate_by": np.array([-1, -1, 8, 8, 10]), - "p_id_to_store_by": np.array([7, 8, 9, 10, 11]), - "expected_res": np.array([0.0, 70.0, 0.0, 50.0, 0.0]), + "column_to_aggregate": numpy.array([10.0, 20.0, 30.0, 40.0, 50.0]), + "p_id_to_aggregate_by": numpy.array([-1, -1, 8, 8, 10]), + "p_id_to_store_by": numpy.array([7, 8, 9, 10, 11]), + "expected_res": numpy.array([0.0, 70.0, 0.0, 50.0, 0.0]), "expected_type": numpy.floating, }, "sum_by_p_id_int": { - "column_to_aggregate": np.array([10, 20, 30, 40, 50]), - "p_id_to_aggregate_by": np.array([-1, -1, 8, 8, 10]), - "p_id_to_store_by": np.array([7, 8, 9, 10, 11]), - "expected_res": np.array([0, 70, 0, 50, 0]), + "column_to_aggregate": numpy.array([10, 20, 30, 40, 50]), + "p_id_to_aggregate_by": numpy.array([-1, -1, 8, 8, 10]), + "p_id_to_store_by": numpy.array([7, 8, 9, 10, 11]), + "expected_res": numpy.array([0, 70, 0, 50, 0]), "expected_type": numpy.integer, }, } test_grouped_raises_specs = { "dtype_boolean": { - "column_to_aggregate": np.array([True, True, True, False, False]), - "group_id": np.array([0, 0, 1, 1, 1]), + "column_to_aggregate": numpy.array([True, True, True, False, False]), + "group_id": numpy.array([0, 0, 1, 1, 1]), "error_mean": TypeError, "error_max": TypeError, "error_min": TypeError, "exception_match": "grouped_", }, "float_group_id": { - "column_to_aggregate": np.array([0, 1, 2, 3, 4]), - "group_id": np.array([0, 0, 3.5, 3.5, 3.5]), - "p_id_to_store_by": np.array([0, 1, 2, 
3, 4]), + "column_to_aggregate": numpy.array([0, 1, 2, 3, 4]), + "group_id": numpy.array([0, 0, 3.5, 3.5, 3.5]), + "p_id_to_store_by": numpy.array([0, 1, 2, 3, 4]), "error_sum": TypeError, "error_mean": TypeError, "error_max": TypeError, @@ -155,21 +158,21 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): "exception_match": "The dtype of id columns must be integer.", }, "dtype_float": { - "column_to_aggregate": np.array([1.5, 2, 3.5, 4, 5]), - "group_id": np.array([0, 0, 1, 1, 1]), + "column_to_aggregate": numpy.array([1.5, 2, 3.5, 4, 5]), + "group_id": numpy.array([0, 0, 1, 1, 1]), "error_any": TypeError, "error_all": TypeError, "exception_match": "grouped_", }, "dtype_integer": { - "column_to_aggregate": np.array([1, 2, 3, 4, 5]), - "group_id": np.array([0, 0, 1, 1, 1]), + "column_to_aggregate": numpy.array([1, 2, 3, 4, 5]), + "group_id": numpy.array([0, 0, 1, 1, 1]), "error_mean": TypeError, "exception_match": "grouped_", }, "float_group_id_bool": { - "column_to_aggregate": np.array([True, True, True, False, False]), - "group_id": np.array([0, 0, 3.5, 3.5, 3.5]), + "column_to_aggregate": numpy.array([True, True, True, False, False]), + "group_id": numpy.array([0, 0, 3.5, 3.5, 3.5]), "error_any": TypeError, "error_all": TypeError, "exception_match": "The dtype of id columns must be integer.", @@ -178,39 +181,39 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): # We cannot even set up these fixtures in JAX. 
if not IS_JAX_INSTALLED: test_grouped_specs["datetime"] = { - "column_to_aggregate": np.array( + "column_to_aggregate": numpy.array( [ - np.datetime64("2000"), - np.datetime64("2001"), - np.datetime64("2002"), - np.datetime64("2003"), - np.datetime64("2004"), + numpy.datetime64("2000"), + numpy.datetime64("2001"), + numpy.datetime64("2002"), + numpy.datetime64("2003"), + numpy.datetime64("2004"), ] ), - "group_id": np.array([1, 0, 1, 1, 1]), - "expected_res_max": np.array( + "group_id": numpy.array([1, 0, 1, 1, 1]), + "expected_res_max": numpy.array( [ - np.datetime64("2004"), - np.datetime64("2001"), - np.datetime64("2004"), - np.datetime64("2004"), - np.datetime64("2004"), + numpy.datetime64("2004"), + numpy.datetime64("2001"), + numpy.datetime64("2004"), + numpy.datetime64("2004"), + numpy.datetime64("2004"), ] ), - "expected_res_min": np.array( + "expected_res_min": numpy.array( [ - np.datetime64("2000"), - np.datetime64("2001"), - np.datetime64("2000"), - np.datetime64("2000"), - np.datetime64("2000"), + numpy.datetime64("2000"), + numpy.datetime64("2001"), + numpy.datetime64("2000"), + numpy.datetime64("2000"), + numpy.datetime64("2000"), ] ), } test_grouped_raises_specs["dtype_string"] = { - "column_to_aggregate": np.array(["0", "1", "2", "3", "4"]), - "group_id": np.array([0, 0, 1, 1, 1]), + "column_to_aggregate": numpy.array(["0", "1", "2", "3", "4"]), + "group_id": numpy.array([0, 0, 1, 1, 1]), "error_sum": TypeError, "error_mean": TypeError, "error_max": TypeError, @@ -220,16 +223,16 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): "exception_match": "grouped_", } test_grouped_raises_specs["datetime"] = { - "column_to_aggregate": np.array( + "column_to_aggregate": numpy.array( [ - np.datetime64("2000"), - np.datetime64("2001"), - np.datetime64("2002"), - np.datetime64("2003"), - np.datetime64("2004"), + numpy.datetime64("2000"), + numpy.datetime64("2001"), + numpy.datetime64("2002"), + numpy.datetime64("2003"), + 
numpy.datetime64("2004"), ] ), - "group_id": np.array([0, 0, 1, 1, 1]), + "group_id": numpy.array([0, 0, 1, 1, 1]), "error_sum": TypeError, "error_mean": TypeError, "error_any": TypeError, diff --git a/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py b/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py index 8f8e16ff2..5b8db9741 100644 --- a/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py +++ b/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py @@ -4,10 +4,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy import pytest -from ttsim.config import numpy_or_jax as np +if TYPE_CHECKING: + import numpy from ttsim.tt_dag_elements.piecewise_polynomial import ( PiecewisePolynomialParamValue, get_piecewise_parameters, @@ -18,8 +21,10 @@ @pytest.fixture def parameters(): params = PiecewisePolynomialParamValue( - thresholds=np.array([-np.inf, 9168.0, 14254.0, 55960.0, 265326.0, np.inf]), - rates=np.array( + thresholds=numpy.array( + [-numpy.inf, 9168.0, 14254.0, 55960.0, 265326.0, numpy.inf] + ), + rates=numpy.array( [ [ 0.00000000e00, @@ -37,7 +42,7 @@ def parameters(): ], ] ), - intercepts=np.array([0.0, 0.0, 965.5771, 14722.3012, 102656.0212]), + intercepts=numpy.array([0.0, 0.0, 965.5771, 14722.3012, 102656.0212]), ) return params @@ -81,8 +86,8 @@ def test_get_piecewise_parameters_all_intercepts_supplied(): def test_piecewise_polynomial(parameters: PiecewisePolynomialParamValue): - x = np.array([-1_000, 1_000, 10_000, 30_000, 100_000, 1_000_000]) - expected = np.array([0.0, 0.0, 246.53, 10551.65, 66438.2, 866518.64]) + x = numpy.array([-1_000, 1_000, 10_000, 30_000, 100_000, 1_000_000]) + expected = numpy.array([0.0, 0.0, 246.53, 10551.65, 66438.2, 866518.64]) actual = piecewise_polynomial( x=x, diff --git a/tests/ttsim/tt_dag_elements/test_rounding.py b/tests/ttsim/tt_dag_elements/test_rounding.py index da50d5012..2dfc811c7 100644 --- a/tests/ttsim/tt_dag_elements/test_rounding.py +++ 
b/tests/ttsim/tt_dag_elements/test_rounding.py @@ -1,12 +1,12 @@ from __future__ import annotations +import numpy import pandas as pd import pytest from pandas._testing import assert_series_equal from ttsim import main from ttsim.config import IS_JAX_INSTALLED -from ttsim.config import numpy_or_jax as np from ttsim.interface_dag_elements.policy_environment import policy_environment from ttsim.tt_dag_elements import ( RoundingSpec, @@ -33,48 +33,48 @@ def p_id() -> int: rounding_specs_and_exp_results = [ ( RoundingSpec(base=1, direction="up"), - np.array([100.24, 100.78]), - np.array([101.0, 101.0]), + numpy.array([100.24, 100.78]), + numpy.array([101.0, 101.0]), ), ( RoundingSpec(base=1, direction="down"), - np.array([100.24, 100.78]), - np.array([100.0, 100.0]), + numpy.array([100.24, 100.78]), + numpy.array([100.0, 100.0]), ), ( RoundingSpec(base=1, direction="nearest"), - np.array([100.24, 100.78]), - np.array([100.0, 101.0]), + numpy.array([100.24, 100.78]), + numpy.array([100.0, 101.0]), ), ( RoundingSpec(base=5, direction="up"), - np.array([100.24, 100.78]), - np.array([105.0, 105.0]), + numpy.array([100.24, 100.78]), + numpy.array([105.0, 105.0]), ), ( RoundingSpec(base=0.1, direction="down"), - np.array([100.24, 100.78]), - np.array([100.2, 100.7]), + numpy.array([100.24, 100.78]), + numpy.array([100.2, 100.7]), ), ( RoundingSpec(base=0.001, direction="nearest"), - np.array([100.24, 100.78]), - np.array([100.24, 100.78]), + numpy.array([100.24, 100.78]), + numpy.array([100.24, 100.78]), ), ( RoundingSpec(base=1, direction="up", to_add_after_rounding=10), - np.array([100.24, 100.78]), - np.array([111.0, 111.0]), + numpy.array([100.24, 100.78]), + numpy.array([111.0, 111.0]), ), ( RoundingSpec(base=1, direction="down", to_add_after_rounding=10), - np.array([100.24, 100.78]), - np.array([110.0, 110.0]), + numpy.array([100.24, 100.78]), + numpy.array([110.0, 110.0]), ), ( RoundingSpec(base=1, direction="nearest", to_add_after_rounding=10), - np.array([100.24, 
100.78]), - np.array([110.0, 111.0]), + numpy.array([100.24, 100.78]), + numpy.array([110.0, 111.0]), ), ] @@ -116,8 +116,8 @@ def test_func(x): return x input_data__tree = { - "p_id": np.array([1, 2]), - "namespace": {"x": np.array(input_values)}, + "p_id": numpy.array([1, 2]), + "namespace": {"x": numpy.array(input_values)}, } policy_environment = {"namespace": {"test_func": test_func, "x": x}, "p_id": p_id} @@ -146,8 +146,8 @@ def test_func_m(x): return x data = { - "p_id": np.array([1, 2]), - "x": np.array([1.2, 1.5]), + "p_id": numpy.array([1, 2]), + "x": numpy.array([1.2, 1.5]), } policy_environment = { @@ -186,8 +186,8 @@ def test_no_rounding( def test_func(x): return x - data = {"p_id": np.array([1, 2])} - data["x"] = np.array(input_values_exp_output) + data = {"p_id": numpy.array([1, 2])} + data["x"] = numpy.array(input_values_exp_output) policy_environment = { "test_func": test_func, "x": x, diff --git a/tests/ttsim/tt_dag_elements/test_shared.py b/tests/ttsim/tt_dag_elements/test_shared.py index 290ecbfaa..cf3cdbc76 100644 --- a/tests/ttsim/tt_dag_elements/test_shared.py +++ b/tests/ttsim/tt_dag_elements/test_shared.py @@ -1,8 +1,8 @@ from __future__ import annotations +import numpy import pytest -from ttsim.config import numpy_or_jax as np from ttsim.tt_dag_elements import join @@ -10,43 +10,43 @@ "foreign_key, primary_key, target, value_if_foreign_key_is_missing, expected", [ ( - np.array([1, 2, 3]), - np.array([1, 2, 3]), - np.array([1, 2, 3]), + numpy.array([1, 2, 3]), + numpy.array([1, 2, 3]), + numpy.array([1, 2, 3]), 4, - np.array([1, 2, 3]), + numpy.array([1, 2, 3]), ), ( - np.array([3, 2, 1]), - np.array([1, 2, 3]), - np.array([1, 2, 3]), + numpy.array([3, 2, 1]), + numpy.array([1, 2, 3]), + numpy.array([1, 2, 3]), 4, - np.array([3, 2, 1]), + numpy.array([3, 2, 1]), ), ( - np.array([1, 1, 1]), - np.array([1, 2, 3]), - np.array([1, 2, 3]), + numpy.array([1, 1, 1]), + numpy.array([1, 2, 3]), + numpy.array([1, 2, 3]), 4, - np.array([1, 1, 1]), + 
numpy.array([1, 1, 1]), ), ( - np.array([-1]), - np.array([1]), - np.array([1]), + numpy.array([-1]), + numpy.array([1]), + numpy.array([1]), 4, - np.array([4]), + numpy.array([4]), ), ], ) def test_join( - foreign_key: np.ndarray, - primary_key: np.ndarray, - target: np.ndarray, + foreign_key: numpy.ndarray, + primary_key: numpy.ndarray, + target: numpy.ndarray, value_if_foreign_key_is_missing: int, - expected: np.ndarray, + expected: numpy.ndarray, ): - assert np.array_equal( + assert numpy.array_equal( join(foreign_key, primary_key, target, value_if_foreign_key_is_missing), expected, ) diff --git a/tests/ttsim/tt_dag_elements/test_ttsim_objects.py b/tests/ttsim/tt_dag_elements/test_ttsim_objects.py index 1cc838ba6..b51432eaa 100644 --- a/tests/ttsim/tt_dag_elements/test_ttsim_objects.py +++ b/tests/ttsim/tt_dag_elements/test_ttsim_objects.py @@ -2,9 +2,9 @@ import inspect +import numpy import pytest -from ttsim.config import numpy_or_jax as np from ttsim.tt_dag_elements import ( AggType, PolicyFunction, @@ -288,5 +288,7 @@ def aggregate_by_p_id_multiple_other_p_ids_present( def test_agg_by_p_id_sum_with_all_missing_p_ids(): aggregate_by_p_id_sum( - p_id=np.array([180]), p_id_specifier=np.array([-1]), source=np.array([False]) + p_id=numpy.array([180]), + p_id_specifier=numpy.array([-1]), + source=numpy.array([False]), ) From 279fc7a28240dab2fefadd22a708dccd5ff5d37d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 11 Jun 2025 06:35:29 +0200 Subject: [PATCH 02/25] Defer vectorization until the creation of the specialised environment. Vectorisation tests pass. 
--- pixi.lock | 46 ++++----- pyproject.toml | 3 + .../specialized_environment.py | 15 ++- .../column_objects_param_function.py | 30 +++++- src/ttsim/tt_dag_elements/vectorization.py | 6 +- tests/ttsim/conftest.py | 24 +++++ .../tt_dag_elements/test_vectorization.py | 97 ++++++++----------- 7 files changed, 133 insertions(+), 88 deletions(-) create mode 100644 tests/ttsim/conftest.py diff --git a/pixi.lock b/pixi.lock index 89b447502..473dd6565 100644 --- a/pixi.lock +++ b/pixi.lock @@ -271,7 +271,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -507,7 +507,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . 
+ - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -743,7 +743,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -987,7 +987,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ mypy: channels: - url: https://conda.anaconda.org/conda-forge/ @@ -1263,7 +1263,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl - - pypi: . 
+ - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -1503,7 +1503,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl - - pypi: . + - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -1743,7 +1743,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl - - pypi: . 
+ - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -1991,7 +1991,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl - - pypi: . + - pypi: ./ py311: channels: - url: https://conda.anaconda.org/conda-forge/ @@ -2263,7 +2263,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/aa/3d/52a75740d6c449073d4bb54da382f6368553f285fb5a680b27dd198dd839/optree-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . 
+ - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -2499,7 +2499,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/e8/89/1267444a074b6e4402b5399b73b930a7b86cde054a41cecb9694be726a92/optree-0.15.0-cp311-cp311-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -2735,7 +2735,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/98/a5/f8d6c278ce72b2ed8c1ebac968c3c652832bd2d9e65ec81fe6a21082c313/optree-0.15.0-cp311-cp311-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . 
+ - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -2979,7 +2979,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ py312: channels: - url: https://conda.anaconda.org/conda-forge/ @@ -3251,7 +3251,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . 
+ - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -3487,7 +3487,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -3723,7 +3723,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . 
+ - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -3967,7 +3967,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ py312-jax: channels: - url: https://conda.anaconda.org/conda-forge/ @@ -4251,7 +4251,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . 
+ - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -4499,7 +4499,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -4747,7 +4747,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . 
+ - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda @@ -4997,7 +4997,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - - pypi: . + - pypi: ./ packages: - conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 sha256: fe51de6107f9edc7aa4f786a70f4a883943bc9d39b3bb7307c04c41410990726 @@ -6576,10 +6576,10 @@ packages: purls: [] size: 21903 timestamp: 1694400856979 -- pypi: . +- pypi: ./ name: gettsim - version: 0.7.1.dev433+g45b17d643.d20250610 - sha256: 3c87f5e1b0dbb950bd0bd46d24bb9ebc09061dc62200009645b901e24004348b + version: 0.7.1.dev437+gede66cd1.d20250611 + sha256: 9f7117cd7089efbebbc12d4aaf941403d1e5ceb8537993948ac2eded7efa2988 requires_dist: - ipywidgets - networkx diff --git a/pyproject.toml b/pyproject.toml index 73bd642ca..8fe5b7c0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,6 +168,9 @@ types-pytz = "*" [tool.pixi.feature.test.tasks] tests = "pytest" +[tool.pixi.feature.jax.tasks] +tests-jax = "pytest --backend=jax" + [tool.pixi.feature.mypy.tasks] mypy = "mypy --ignore-missing-imports" diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index 48188fce9..e018cb3e2 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -22,11 +22,13 @@ ColumnFunction, ColumnObject, ParamFunction, + 
PolicyFunction, ) from ttsim.tt_dag_elements.param_objects import ParamObject, RawParam if TYPE_CHECKING: from collections.abc import Callable + from types import ModuleType import networkx as nx @@ -64,6 +66,8 @@ def with_derived_functions_and_processed_input_nodes( targets__tree: NestedStrings, names__top_level_namespace: UnorderedQNames, names__grouping_levels: OrderedQNames, + backend: str, + xnp: ModuleType, ) -> QNameCombinedEnvironment0: """Return a flat policy environment with derived functions. @@ -77,8 +81,17 @@ def with_derived_functions_and_processed_input_nodes( policy_environment=policy_environment, names__top_level_namespace=names__top_level_namespace, ) + flat_vectorized = { + k: f.vectorize( + backend=backend, + xnp=xnp, + ) + if isinstance(f, PolicyFunction) + else f + for k, f in flat.items() + } flat_with_derived = _add_derived_functions( - qual_name_policy_environment=flat, + qual_name_policy_environment=flat_vectorized, targets=dt.qual_names(targets__tree), names__processed_data_columns=names__processed_data_columns, grouping_levels=names__grouping_levels, diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index e7f832019..bdce15b31 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -289,6 +289,8 @@ class PolicyFunction(ColumnFunction): # type: ignore[type-arg] The rounding specification. 
""" + vectorization_strategy: Literal["loop", "vectorize", "not_required"] = "vectorize" + def remove_tree_logic( self, tree_path: tuple[str, ...], @@ -306,6 +308,28 @@ def remove_tree_logic( end_date=self.end_date, rounding_spec=self.rounding_spec, foreign_key_type=self.foreign_key_type, + vectorization_strategy=self.vectorization_strategy, + ) + + def vectorize(self, backend: str, xnp: ModuleType) -> PolicyFunction: + func = ( + self.function + if self.vectorization_strategy == "not_required" + else vectorize_function( + self.function, + vectorization_strategy=self.vectorization_strategy, + backend=backend, + xnp=xnp, + ) + ) + return PolicyFunction( + leaf_name=self.leaf_name, + function=func, + start_date=self.start_date, + end_date=self.end_date, + rounding_spec=self.rounding_spec, + foreign_key_type=self.foreign_key_type, + vectorization_strategy="not_required", ) @@ -355,11 +379,6 @@ def policy_function( start_date, end_date = _convert_and_validate_dates(start_date, end_date) def inner(func: GenericCallable) -> PolicyFunction: - func = ( - func - if vectorization_strategy == "not_required" - else vectorize_function(func, vectorization_strategy=vectorization_strategy) - ) return PolicyFunction( leaf_name=leaf_name if leaf_name else func.__name__, function=func, @@ -367,6 +386,7 @@ def inner(func: GenericCallable) -> PolicyFunction: end_date=end_date, rounding_spec=rounding_spec, foreign_key_type=foreign_key_type, + vectorization_strategy=vectorization_strategy, ) return inner diff --git a/src/ttsim/tt_dag_elements/vectorization.py b/src/ttsim/tt_dag_elements/vectorization.py index 7cc9cd120..e736d04b6 100644 --- a/src/ttsim/tt_dag_elements/vectorization.py +++ b/src/ttsim/tt_dag_elements/vectorization.py @@ -84,7 +84,9 @@ def _make_vectorizable( return functools.wraps(func)(new_func) -def make_vectorizable_source(func: GenericCallable, backend: str) -> str: +def make_vectorizable_source( + func: GenericCallable, backend: str, xnp: ModuleType +) -> str: 
"""Redefine function source to be vectorizable given backend. Args: @@ -103,7 +105,7 @@ def make_vectorizable_source(func: GenericCallable, backend: str) -> str: ) module = _module_from_backend(backend) - tree = _make_vectorizable_ast(func, module=module) + tree = _make_vectorizable_ast(func, module=module, xnp=xnp) return ast.unparse(tree) diff --git a/tests/ttsim/conftest.py b/tests/ttsim/conftest.py new file mode 100644 index 000000000..6b57dd18f --- /dev/null +++ b/tests/ttsim/conftest.py @@ -0,0 +1,24 @@ +import numpy +import pytest + + +# content of conftest.py +def pytest_addoption(parser): + parser.addoption( + "--backend", + action="store", + default="numpy", + help="The backend to test against (e.g., --backend=numpy --backend=jax)", + ) + + +@pytest.fixture +def backend_xnp(request): + backend = request.config.getoption("--backend") + if backend == "numpy": + xnp = numpy + else: + import jax + + xnp = jax.numpy + return backend, xnp diff --git a/tests/ttsim/tt_dag_elements/test_vectorization.py b/tests/ttsim/tt_dag_elements/test_vectorization.py index f44fa1df7..546797f4a 100644 --- a/tests/ttsim/tt_dag_elements/test_vectorization.py +++ b/tests/ttsim/tt_dag_elements/test_vectorization.py @@ -10,15 +10,6 @@ import numpy import pytest from dags import concatenate_functions - -from ttsim.config import IS_JAX_INSTALLED -from ttsim.tt_dag_elements.column_objects_param_function import ( - AggByGroupFunction, - AggByPIDFunction, -) - -if IS_JAX_INSTALLED: - import jax.numpy from mettsim.config import METTSIM_ROOT from numpy.testing import assert_array_equal @@ -33,6 +24,10 @@ PolicyInput, policy_function, ) +from ttsim.tt_dag_elements.column_objects_param_function import ( + AggByGroupFunction, + AggByPIDFunction, +) from ttsim.tt_dag_elements.vectorization import ( TranslateToVectorizableError, _is_lambda_function, @@ -45,16 +40,6 @@ from collections.abc import Callable -# ====================================================================================== 
-# Backend -# ====================================================================================== - -backends = ["jax", "numpy"] if IS_JAX_INSTALLED else ["numpy"] - -modules = {"numpy": numpy} -if IS_JAX_INSTALLED: - modules["jax"] = jax.numpy - # ====================================================================================== # String comparison # ====================================================================================== @@ -317,13 +302,13 @@ def f18_exp(x): def test_change_if_to_where_source(func, expected, args): # noqa: ARG001 exp = inspect.getsource(expected) exp = exp.replace("_exp", "") - got = make_vectorizable_source(func, backend="numpy") + got = make_vectorizable_source(func, backend="numpy", xnp=numpy) assert string_equal(exp, got) @pytest.mark.parametrize("func, expected, args", TEST_CASES) def test_change_if_to_where_wrapper(func, expected, args): - got_func = _make_vectorizable(func, backend="numpy") + got_func = _make_vectorizable(func, backend="numpy", xnp=numpy) got = got_func(*args) exp = expected(*args) assert_array_equal(got, exp) @@ -366,19 +351,19 @@ def g4(x): def test_notimplemented_error(): with pytest.raises(NotImplementedError): - _make_vectorizable(f1, backend="dask") + _make_vectorizable(f1, backend="dask", xnp=numpy) @pytest.mark.parametrize("func", [g1, g2, g3, g4]) def test_disallowed_operation_source(func): with pytest.raises(TranslateToVectorizableError): - make_vectorizable_source(func, backend="numpy") + make_vectorizable_source(func, backend="numpy", xnp=numpy) @pytest.mark.parametrize("func", [g1, g2, g3, g4]) def test_disallowed_operation_wrapper(func): with pytest.raises(TranslateToVectorizableError): - _make_vectorizable(func, backend="numpy") + _make_vectorizable(func, backend="numpy", xnp=numpy) # ====================================================================================== @@ -394,9 +379,7 @@ def test_disallowed_operation_wrapper(func): (funcname, pf.function) for funcname, pf in 
dt.flatten_to_tree_paths( _active_column_objects_and_param_functions( - orig=column_objects_and_param_functions( - root=METTSIM_ROOT / "mettsim" - ), + orig=column_objects_and_param_functions(root=METTSIM_ROOT), date=datetime.date(year=year, month=1, day=1), ) ).items() @@ -409,10 +392,10 @@ def test_disallowed_operation_wrapper(func): ) ], ) - @pytest.mark.parametrize("backend", backends) - def test_convertible(funcname, func, backend): # noqa: ARG001 + def test_convertible(funcname, func, backend_xnp): # noqa: ARG001 # Leave funcname for debugging purposes. - _make_vectorizable(func, backend=backend) + backend, xnp = backend_xnp + _make_vectorizable(func, backend=backend, xnp=xnp) # ====================================================================================== @@ -436,12 +419,10 @@ def mock__elterngeld__geschwisterbonus_m( return out -@pytest.mark.parametrize("backend", backends) -def test_geschwisterbonus_m(backend): - full = modules[backend].full - +def test_geschwisterbonus_m(backend_xnp): + backend, xnp = backend_xnp # Test original METTSIM function on scalar input - # ================================================================================== + # ============================================================================== basisbetrag_m = 3.0 geschwisterbonus_grundsätzlich_anspruchsberechtigt_fg = True geschwisterbonus_aufschlag = 1.0 @@ -456,10 +437,10 @@ def test_geschwisterbonus_m(backend): assert exp == 3.0 # Create array inputs and assert that METTSIM functions raises error - # ================================================================================== + # ============================================================================== shape = (10, 2) - basisbetrag_m = full(shape, basisbetrag_m) - geschwisterbonus_grundsätzlich_anspruchsberechtigt_fg = full( + basisbetrag_m = xnp.full(shape, basisbetrag_m) + geschwisterbonus_grundsätzlich_anspruchsberechtigt_fg = xnp.full( shape, geschwisterbonus_grundsätzlich_anspruchsberechtigt_fg 
) @@ -472,9 +453,9 @@ def test_geschwisterbonus_m(backend): ) # Call converted function on array input and test result - # ================================================================================== + # ============================================================================== converted = _make_vectorizable( - mock__elterngeld__geschwisterbonus_m, backend=backend + mock__elterngeld__geschwisterbonus_m, backend=backend, xnp=xnp ) got = converted( basisbetrag_m=basisbetrag_m, @@ -482,7 +463,7 @@ def test_geschwisterbonus_m(backend): geschwisterbonus_aufschlag=geschwisterbonus_aufschlag, geschwisterbonus_minimum=geschwisterbonus_minimum, ) - assert_array_equal(got, full(shape, exp)) + assert_array_equal(got, xnp.full(shape, exp)) def mock__elterngeld__grundsätzlich_anspruchsberechtigt( @@ -502,12 +483,10 @@ def mock__elterngeld__grundsätzlich_anspruchsberechtigt( ) -@pytest.mark.parametrize("backend", backends) -def test_grundsätzlich_anspruchsberechtigt(backend): - full = modules[backend].full - +def test_grundsätzlich_anspruchsberechtigt(backend_xnp): + backend, xnp = backend_xnp # Test original METTSIM function on scalar input - # ================================================================================== + # ============================================================================== claimed = True arbeitsstunden_w = 20.0 kind_grundsätzlich_anspruchsberechtigt_fg = True @@ -527,9 +506,9 @@ def test_grundsätzlich_anspruchsberechtigt(backend): assert exp is True # Create array inputs and assert that METTSIM functions raises error - # ================================================================================== + # ============================================================================== shape = (10, 1) - arbeitsstunden_w = full(shape, arbeitsstunden_w) + arbeitsstunden_w = xnp.full(shape, arbeitsstunden_w) with pytest.raises(ValueError, match="truth value of an array with more than"): mock__elterngeld__grundsätzlich_anspruchsberechtigt( 
@@ -542,9 +521,11 @@ def test_grundsätzlich_anspruchsberechtigt(backend): ) # Call converted function on array input and test result - # ================================================================================== + # ============================================================================== converted = _make_vectorizable( - mock__elterngeld__grundsätzlich_anspruchsberechtigt, backend=backend + mock__elterngeld__grundsätzlich_anspruchsberechtigt, + backend=backend, + xnp=xnp, ) got = converted( claimed=claimed, @@ -554,7 +535,7 @@ def test_grundsätzlich_anspruchsberechtigt(backend): bezugsmonate_unter_grenze_fg=bezugsmonate_unter_grenze_fg, max_arbeitsstunden_w=max_arbeitsstunden_w, ) - assert_array_equal(got, full(shape, exp)) + assert_array_equal(got, xnp.full(shape, exp)) # ====================================================================================== @@ -594,12 +575,12 @@ def test_is_lambda_function_non_function_input(): def test_lambda_functions_disallowed_make_vectorizable(): with pytest.raises(TranslateToVectorizableError, match="Lambda functions are not"): - _make_vectorizable(lambda x: x, backend="numpy") + _make_vectorizable(lambda x: x, backend="numpy", xnp=numpy) def test_lambda_functions_disallowed_make_vectorizable_source(): with pytest.raises(TranslateToVectorizableError, match="Lambda functions are not"): - make_vectorizable_source(lambda x: x, backend="numpy") + make_vectorizable_source(lambda x: x, backend="numpy", xnp=numpy) # ====================================================================================== @@ -612,7 +593,7 @@ def test_make_vectorizable_policy_func(): def alter_bis_24(alter: int) -> bool: return alter <= 24 - vectorized = _make_vectorizable(alter_bis_24, backend="numpy") + vectorized = alter_bis_24.vectorize(backend="numpy", xnp=numpy) got = vectorized(numpy.array([20, 25, 30])) exp = numpy.array([True, False, False]) @@ -634,7 +615,7 @@ def f_b(a: int) -> int: def f_manual(x: int) -> int: return 
f_b(f_a(x)) - vectorized = _make_vectorizable(f_manual, backend="numpy") + vectorized = _make_vectorizable(f_manual, backend="numpy", xnp=numpy) got = vectorized(numpy.array([1, 2, 3])) exp = numpy.array([3, 4, 5]) assert_array_equal(got, exp) @@ -656,7 +637,7 @@ def f_b(a: int) -> int: targets=["b"], ) - vectorized = _make_vectorizable(f_dags, backend="numpy") + vectorized = _make_vectorizable(f_dags, backend="numpy", xnp=numpy) got = vectorized(numpy.array([1, 2, 3])) exp = numpy.array([3, 4, 5]) assert_array_equal(got, exp) @@ -677,7 +658,9 @@ def already_vectorized_func(x: numpy.ndarray) -> numpy.ndarray: # type: ignore[ @pytest.mark.parametrize( "vectorized_function", [ - vectorize_function(scalar_func, vectorization_strategy="loop"), + vectorize_function( + scalar_func, vectorization_strategy="loop", backend="numpy", xnp=numpy + ), already_vectorized_func, ], ) From 2527f7a669471d85b505349951204374f0635b09 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 11 Jun 2025 08:38:09 +0200 Subject: [PATCH 03/25] Improve fixtures. --- src/ttsim/interface_dag.py | 4 +-- src/ttsim/interface_dag_elements/backend.py | 36 +++++++++---------- .../interface_dag_elements/processed_data.py | 4 +-- .../specialized_environment.py | 3 +- tests/ttsim/conftest.py | 24 ++++++++----- .../tt_dag_elements/test_vectorization.py | 9 ++--- 6 files changed, 43 insertions(+), 37 deletions(-) diff --git a/src/ttsim/interface_dag.py b/src/ttsim/interface_dag.py index 4908f551a..f7e2776fd 100644 --- a/src/ttsim/interface_dag.py +++ b/src/ttsim/interface_dag.py @@ -2,7 +2,7 @@ import inspect from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import dags @@ -25,7 +25,7 @@ def main( inputs: dict[str, Any], targets: list[str] | None = None, - backend: Literal[numpy, jax] = "numpy", + backend: Literal["numpy", "jax"] = "numpy", ) -> dict[str, Any]: """ Main function that processes the inputs and returns the outputs. 
diff --git a/src/ttsim/interface_dag_elements/backend.py b/src/ttsim/interface_dag_elements/backend.py index 731fc6f5d..2962e79c3 100644 --- a/src/ttsim/interface_dag_elements/backend.py +++ b/src/ttsim/interface_dag_elements/backend.py @@ -1,14 +1,24 @@ from __future__ import annotations -from types import ModuleType +from typing import TYPE_CHECKING, Literal +if TYPE_CHECKING: + from types import ModuleType import numpy -from ttsim.interface_dag_elements.interface_node_objects import interface_function +from ttsim.interface_dag_elements.interface_node_objects import ( + interface_function, + interface_input, +) + + +@interface_input(in_top_level_namespace=True) +def backend() -> Literal["numpy", "jax"]: + """The computing backend to use for the taxes and transfers function.""" @interface_function(in_top_level_namespace=True) -def xnp(backend: str) -> ModuleType: +def xnp(backend: Literal["numpy", "jax"]) -> ModuleType: """ Return the backend for numerical operations (either NumPy or jax). """ @@ -16,12 +26,8 @@ def xnp(backend: str) -> ModuleType: if backend == "numpy": xnp = numpy elif backend == "jax": - try: - import jax - except ImportError: - raise ImportError( - "jax is not installed. Please install jax to use the 'jax' backend." - ) + import jax + xnp = jax.numpy else: raise ValueError(f"Unsupported backend: {backend}. Choose 'numpy' or 'jax'.") @@ -29,20 +35,14 @@ def xnp(backend: str) -> ModuleType: @interface_function(in_top_level_namespace=True) -def dnp(backend: str) -> ModuleType: +def dnp(backend: Literal["numpy", "jax"]) -> ModuleType: """ Return the backend for datetime objects (either NumPy or jax-datetime) """ - global dnp - if backend == "numpy": dnp = numpy elif backend == "jax": - try: - import jax_datetime - except ImportError: - raise ImportError( - "jax-datetime is not installed. Please install jax-datetime to use the 'jax' backend." 
- ) + import jax_datetime + dnp = jax_datetime return dnp diff --git a/src/ttsim/interface_dag_elements/processed_data.py b/src/ttsim/interface_dag_elements/processed_data.py index ec52acc57..243f483d9 100644 --- a/src/ttsim/interface_dag_elements/processed_data.py +++ b/src/ttsim/interface_dag_elements/processed_data.py @@ -31,11 +31,11 @@ def processed_data(input_data__flat: FlatData, xnp: ModuleType) -> QNameData: processed_input_data = {} old_p_ids = xnp.asarray(input_data__flat[("p_id",)]) - new_p_ids = reorder_ids(old_p_ids) + new_p_ids = reorder_ids(ids=old_p_ids, xnp=xnp) for path, data in input_data__flat.items(): qname = dt.qual_name_from_tree_path(path) if path[-1].endswith("_id"): - processed_input_data[qname] = reorder_ids(xnp.asarray(data)) + processed_input_data[qname] = reorder_ids(ids=xnp.asarray(data), xnp=xnp) elif path[-1].startswith("p_id_"): variable_with_new_ids = xnp.asarray(data) for i in range(new_p_ids.shape[0]): diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index e018cb3e2..277515b78 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -98,7 +98,8 @@ def with_derived_functions_and_processed_input_nodes( ) out = {} for n, f in flat_with_derived.items(): - # Put scalars into the policy environment, else skip the key + # Put scalars into the policy environment, else remove the element because it + # will be passed into the `tax_transfer_function` as an input. 
if n in processed_data: if isinstance(processed_data[n], int | float | bool): out[n] = processed_data[n] diff --git a/tests/ttsim/conftest.py b/tests/ttsim/conftest.py index 6b57dd18f..265ede2ff 100644 --- a/tests/ttsim/conftest.py +++ b/tests/ttsim/conftest.py @@ -1,6 +1,8 @@ -import numpy import pytest +from ttsim.interface_dag_elements.backend import dnp as ttsim_dnp +from ttsim.interface_dag_elements.backend import xnp as ttsim_xnp + # content of conftest.py def pytest_addoption(parser): @@ -13,12 +15,18 @@ def pytest_addoption(parser): @pytest.fixture -def backend_xnp(request): +def backend(request): + backend = request.config.getoption("--backend") + return backend + + +@pytest.fixture +def xnp(request): backend = request.config.getoption("--backend") - if backend == "numpy": - xnp = numpy - else: - import jax + return ttsim_xnp(backend) + - xnp = jax.numpy - return backend, xnp +@pytest.fixture +def dnp(request): + backend = request.config.getoption("--backend") + return ttsim_dnp(backend) diff --git a/tests/ttsim/tt_dag_elements/test_vectorization.py b/tests/ttsim/tt_dag_elements/test_vectorization.py index 546797f4a..68c08c52e 100644 --- a/tests/ttsim/tt_dag_elements/test_vectorization.py +++ b/tests/ttsim/tt_dag_elements/test_vectorization.py @@ -392,9 +392,8 @@ def test_disallowed_operation_wrapper(func): ) ], ) - def test_convertible(funcname, func, backend_xnp): # noqa: ARG001 + def test_convertible(funcname, func, backend, xnp): # noqa: ARG001 # Leave funcname for debugging purposes. 
- backend, xnp = backend_xnp _make_vectorizable(func, backend=backend, xnp=xnp) @@ -419,8 +418,7 @@ def mock__elterngeld__geschwisterbonus_m( return out -def test_geschwisterbonus_m(backend_xnp): - backend, xnp = backend_xnp +def test_geschwisterbonus_m(backend, xnp): # Test original METTSIM function on scalar input # ============================================================================== basisbetrag_m = 3.0 @@ -483,8 +481,7 @@ def mock__elterngeld__grundsätzlich_anspruchsberechtigt( ) -def test_grundsätzlich_anspruchsberechtigt(backend_xnp): - backend, xnp = backend_xnp +def test_grundsätzlich_anspruchsberechtigt(backend, xnp): # Test original METTSIM function on scalar input # ============================================================================== claimed = True From cc98a2dc855d43a6d6d24fabf0f77e5ad9b0d017 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 11 Jun 2025 08:52:51 +0200 Subject: [PATCH 04/25] Convert test_piecewise_polynomial. --- .../tt_dag_elements/piecewise_polynomial.py | 45 +++++++++---------- .../test_piecewise_polynomial.py | 30 +++++++------ 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index 5420923e9..66490543b 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: from types import ModuleType - import numpy from ttsim.tt_dag_elements.param_objects import PiecewisePolynomialParamValue FUNC_TYPES = Literal[ @@ -106,6 +105,7 @@ def get_piecewise_parameters( leaf_name: str, func_type: FUNC_TYPES, parameter_dict: dict[int, dict[str, float]], + xnp: ModuleType, ) -> PiecewisePolynomialParamValue: """Create the objects for piecewise polynomial. 
@@ -131,6 +131,7 @@ def get_piecewise_parameters( lower_thresholds, upper_thresholds, thresholds = check_and_get_thresholds( leaf_name=leaf_name, parameter_dict=parameter_dict, + xnp=xnp, ) # Create and fill rates-array @@ -138,6 +139,7 @@ def get_piecewise_parameters( parameter_dict=parameter_dict, leaf_name=leaf_name, func_type=func_type, + xnp=xnp, ) # Create and fill intercept-array intercepts = _check_and_get_intercepts( @@ -146,6 +148,7 @@ def get_piecewise_parameters( lower_thresholds=lower_thresholds, upper_thresholds=upper_thresholds, rates=rates, + xnp=xnp, ) return PiecewisePolynomialParamValue( thresholds=thresholds, @@ -195,7 +198,7 @@ def check_and_get_thresholds( upper_thresholds[keys[-1]] = parameter_dict[keys[-1]]["upper_threshold"] # Check if the function is defined on the complete real line - if (upper_thresholds[keys[-1]] != xnp.inf) | (lower_thresholds[0] != -xnp.inf): + if (upper_thresholds[keys[-1]] != numpy.inf) | (lower_thresholds[0] != -numpy.inf): raise ValueError(f"{leaf_name} needs to be defined on the entire real line.") for interval in keys[1:]: @@ -220,7 +223,7 @@ def check_and_get_thresholds( f" threshold in the piece after." ) - if not xnp.allclose(lower_thresholds[1:], upper_thresholds[:-1]): + if not numpy.allclose(lower_thresholds[1:], upper_thresholds[:-1]): raise ValueError( f"The lower and upper thresholds of {leaf_name} have to coincide" ) @@ -339,30 +342,26 @@ def _create_intercepts( Parameters ---------- - lower_thresholds : numpy.array - The lower thresholds defining the intervals + lower_thresholds: + The lower thresholds defining the intervals - upper_thresholds : numpy.array - The upper thresholds defining the intervals + upper_thresholds: + The upper thresholds defining the intervals - rates : numpy.array - The slope in the interval below the corresponding element of - *upper_thresholds*. + rates: + The slope in the interval below the corresponding element of *upper_thresholds*. 
- intercept_at_lowest_threshold : numpy.array - Intercept at the lowest threshold + intercept_at_lowest_threshold: + Intercept at the lowest threshold - fun: function handle (currently only piecewise_linear, will need to think about - whether we can have a generic function with a different interface or make - it specific ) - xnp : ModuleType - The numpy module to use for calculations. + xnp: ModuleType + The module to use for calculations. Returns ------- """ - intercepts = numpy.full_like(upper_thresholds, xnp.nan) + intercepts = numpy.full_like(upper_thresholds, numpy.nan) intercepts[0] = intercept_at_lowest_threshold for i, up_thr in enumerate(upper_thresholds[:-1]): intercepts[i + 1] = _calculate_one_intercept( @@ -371,7 +370,6 @@ def _create_intercepts( upper_thresholds=upper_thresholds, rates=rates, intercepts=intercepts, - xnp=xnp, ) return xnp.array(intercepts) @@ -382,7 +380,6 @@ def _calculate_one_intercept( upper_thresholds: numpy.ndarray, rates: numpy.ndarray, intercepts: numpy.ndarray, - xnp: ModuleType, ) -> float: """Calculate the intercepts from the raw data. @@ -410,15 +407,15 @@ def _calculate_one_intercept( """ # Check if value lies within the defined range. 
- if (x < lower_thresholds[0]) or (x > upper_thresholds[-1]) or xnp.isnan(x): - return xnp.nan - index_interval = xnp.searchsorted(upper_thresholds, x, side="left") + if (x < lower_thresholds[0]) or (x > upper_thresholds[-1]) or numpy.isnan(x): + return numpy.nan + index_interval = numpy.searchsorted(upper_thresholds, x, side="left") intercept_interval = intercepts[index_interval] # Select threshold and calculate corresponding increment into interval lower_threshold_interval = lower_thresholds[index_interval] - if lower_threshold_interval == -xnp.inf: + if lower_threshold_interval == -numpy.inf: return intercept_interval increment_to_calc = x - lower_threshold_interval diff --git a/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py b/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py index 5b8db9741..8fa6c976b 100644 --- a/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py +++ b/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py @@ -10,7 +10,9 @@ import pytest if TYPE_CHECKING: - import numpy + from types import ModuleType + + from ttsim.tt_dag_elements.piecewise_polynomial import ( PiecewisePolynomialParamValue, get_piecewise_parameters, @@ -19,12 +21,10 @@ @pytest.fixture -def parameters(): +def parameters(xnp): params = PiecewisePolynomialParamValue( - thresholds=numpy.array( - [-numpy.inf, 9168.0, 14254.0, 55960.0, 265326.0, numpy.inf] - ), - rates=numpy.array( + thresholds=xnp.array([-xnp.inf, 9168.0, 14254.0, 55960.0, 265326.0, xnp.inf]), + rates=xnp.array( [ [ 0.00000000e00, @@ -42,12 +42,12 @@ def parameters(): ], ] ), - intercepts=numpy.array([0.0, 0.0, 965.5771, 14722.3012, 102656.0212]), + intercepts=xnp.array([0.0, 0.0, 965.5771, 14722.3012, 102656.0212]), ) return params -def test_get_piecewise_parameters_all_intercepts_supplied(): +def test_get_piecewise_parameters_all_intercepts_supplied(xnp): parameter_dict = { 0: { "lower_threshold": "-inf", @@ -79,19 +79,23 @@ def test_get_piecewise_parameters_all_intercepts_supplied(): 
leaf_name="test", func_type="piecewise_linear", parameter_dict=parameter_dict, + xnp=xnp, ) - expected = numpy.array([0.27, 0.5, 0.8, 1]) + expected = xnp.array([0.27, 0.5, 0.8, 1]) numpy.testing.assert_allclose(actual.intercepts, expected, atol=1e-7) -def test_piecewise_polynomial(parameters: PiecewisePolynomialParamValue): - x = numpy.array([-1_000, 1_000, 10_000, 30_000, 100_000, 1_000_000]) - expected = numpy.array([0.0, 0.0, 246.53, 10551.65, 66438.2, 866518.64]) +def test_piecewise_polynomial( + parameters: PiecewisePolynomialParamValue, xnp: ModuleType +): + x = xnp.array([-1_000, 1_000, 10_000, 30_000, 100_000, 1_000_000]) + expected = xnp.array([0.0, 0.0, 246.53, 10551.65, 66438.2, 866518.64]) actual = piecewise_polynomial( x=x, parameters=parameters, rates_multiplier=2, + xnp=xnp, ) - numpy.testing.assert_allclose(numpy.array(actual), expected, atol=0.01) + numpy.testing.assert_allclose(xnp.array(actual), expected, atol=0.01) From 4f52606faafd6bca2907747d28ca3509c5bda079 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 11 Jun 2025 09:18:30 +0200 Subject: [PATCH 05/25] Change order of vectorization and removing tree logic (vectorization destroys renaming), make rounding tests pass. --- .../specialized_environment.py | 23 ++++++++++--------- src/ttsim/interface_dag_elements/typing.py | 5 ++++ src/ttsim/tt_dag_elements/rounding.py | 16 +++++++------ tests/ttsim/tt_dag_elements/test_rounding.py | 11 +++++---- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index 277515b78..9d85dea2a 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -77,10 +77,6 @@ def with_derived_functions_and_processed_input_nodes( 3. Remove all functions that are overridden by data columns. 
""" - flat = _remove_tree_logic_from_policy_environment( - policy_environment=policy_environment, - names__top_level_namespace=names__top_level_namespace, - ) flat_vectorized = { k: f.vectorize( backend=backend, @@ -88,10 +84,14 @@ def with_derived_functions_and_processed_input_nodes( ) if isinstance(f, PolicyFunction) else f - for k, f in flat.items() + for k, f in dt.flatten_to_qual_names(policy_environment).items() } + flat_without_tree_logic = _remove_tree_logic_from_policy_environment( + policy_environment=flat_vectorized, + names__top_level_namespace=names__top_level_namespace, + ) flat_with_derived = _add_derived_functions( - qual_name_policy_environment=flat_vectorized, + qual_name_policy_environment=flat_without_tree_logic, targets=dt.qual_names(targets__tree), names__processed_data_columns=names__processed_data_columns, grouping_levels=names__grouping_levels, @@ -110,12 +110,12 @@ def with_derived_functions_and_processed_input_nodes( def _remove_tree_logic_from_policy_environment( - policy_environment: NestedPolicyEnvironment, + policy_environment: QNamePolicyEnvironment, names__top_level_namespace: UnorderedQNames, ) -> QNamePolicyEnvironment: """Map qualified names to column objects / param functions without tree logic.""" out = {} - for name, obj in dt.flatten_to_qual_names(policy_environment).items(): + for name, obj in policy_environment.items(): if isinstance(obj, ParamObject): out[name] = obj else: @@ -243,6 +243,7 @@ def with_processed_params_and_scalars( def with_partialled_params_and_scalars( with_processed_params_and_scalars: QNameCombinedEnvironment1, rounding: bool, + xnp: ModuleType, ) -> QNameCombinedEnvironment2: """Partial parameters to functions such that they disappear from the DAG. 
@@ -262,7 +263,7 @@ def with_partialled_params_and_scalars( processed_functions = {} for name, _func in with_processed_params_and_scalars.items(): if isinstance(_func, ColumnFunction): - func = _apply_rounding(_func) if rounding else _func + func = _apply_rounding(_func, xnp) if rounding else _func partial_params = {} for arg in [ a @@ -281,9 +282,9 @@ def with_partialled_params_and_scalars( return processed_functions -def _apply_rounding(element: Any) -> Any: +def _apply_rounding(element: Any, xnp: ModuleType) -> Any: return ( - element.rounding_spec.apply_rounding(element) + element.rounding_spec.apply_rounding(element, xnp=xnp) if getattr(element, "rounding_spec", False) else element ) diff --git a/src/ttsim/interface_dag_elements/typing.py b/src/ttsim/interface_dag_elements/typing.py index bcc345d75..52eece69e 100644 --- a/src/ttsim/interface_dag_elements/typing.py +++ b/src/ttsim/interface_dag_elements/typing.py @@ -77,6 +77,11 @@ ColumnObject | ParamFunction | ParamObject | "NestedPolicyEnvironment", ] """Tree of column objects, param functions, and param objects.""" + QNamePolicyEnvironment = dict[ + str, + ColumnObject | ParamFunction | ParamObject, + ] + """Tree of column objects, param functions, and param objects.""" QNameCombinedEnvironment0 = Mapping[ str, ColumnObject | ParamFunction | ParamObject | int | float | bool ] diff --git a/src/ttsim/tt_dag_elements/rounding.py b/src/ttsim/tt_dag_elements/rounding.py index b09c5df4d..3a550f830 100644 --- a/src/ttsim/tt_dag_elements/rounding.py +++ b/src/ttsim/tt_dag_elements/rounding.py @@ -8,6 +8,8 @@ from collections.abc import Callable from types import ModuleType + import numpy + ROUNDING_DIRECTION = Literal["up", "down", "nearest"] @@ -36,14 +38,16 @@ def __post_init__(self) -> None: ) def apply_rounding( - self, func: Callable[P, numpy.ndarray] + self, func: Callable[P, numpy.ndarray], xnp: ModuleType ) -> Callable[P, numpy.ndarray]: - """Decorator to round the output of a function. 
The wrapped function must accept an xnp: ModuleType argument for numpy operations. + """Decorator to round the output of a function. Parameters ---------- func - Function to be rounded. Must accept xnp: ModuleType as a parameter. + Function to be rounded. + xnp + The computing module to use. Returns ------- @@ -51,10 +55,8 @@ def apply_rounding( """ @functools.wraps(func) - def wrapper( - *args: P.args, xnp: ModuleType, **kwargs: P.kwargs - ) -> numpy.ndarray: - out = func(*args, xnp=xnp, **kwargs) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> numpy.ndarray: + out = func(*args, **kwargs) if self.direction == "up": rounded_out = self.base * xnp.ceil(out / self.base) diff --git a/tests/ttsim/tt_dag_elements/test_rounding.py b/tests/ttsim/tt_dag_elements/test_rounding.py index 2dfc811c7..ec3339778 100644 --- a/tests/ttsim/tt_dag_elements/test_rounding.py +++ b/tests/ttsim/tt_dag_elements/test_rounding.py @@ -107,7 +107,7 @@ def test_func(): "rounding_spec, input_values, exp_output", rounding_specs_and_exp_results, ) -def test_rounding(rounding_spec, input_values, exp_output): +def test_rounding(rounding_spec, input_values, exp_output, backend): """Check if rounding is correct.""" # Define function that should be rounded @@ -127,6 +127,7 @@ def test_func(x): "policy_environment": policy_environment, "targets__tree": {"namespace": {"test_func": None}}, "rounding": True, + "backend": backend, }, targets=["results__tree"], )["results__tree"] @@ -214,13 +215,13 @@ def test_func(x): "rounding_spec, input_values, exp_output", rounding_specs_and_exp_results, ) -def test_rounding_callable(rounding_spec, input_values, exp_output): +def test_rounding_callable(rounding_spec, input_values, exp_output, xnp): """Check if callable is rounded correctly.""" def test_func(income): return income - func_with_rounding = rounding_spec.apply_rounding(test_func) + func_with_rounding = rounding_spec.apply_rounding(test_func, xnp=xnp) assert_series_equal( 
pd.Series(func_with_rounding(input_values)), @@ -233,13 +234,13 @@ def test_func(income): "rounding_spec, input_values, exp_output", rounding_specs_and_exp_results, ) -def test_rounding_spec(rounding_spec, input_values, exp_output): +def test_rounding_spec(rounding_spec, input_values, exp_output, xnp): """Test RoundingSpec directly.""" def test_func(income): return income - rounded_func = rounding_spec.apply_rounding(test_func) + rounded_func = rounding_spec.apply_rounding(test_func, xnp=xnp) result = rounded_func(input_values) assert_series_equal( From d0bb42ed3e203be6b952811c7beb9533163446a4 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 11 Jun 2025 09:24:58 +0200 Subject: [PATCH 06/25] Change order of vectorization and removing tree logic (vectorization destroys renaming), make rounding tests pass (numpy only, though). --- src/ttsim/tt_dag_elements/rounding.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/ttsim/tt_dag_elements/rounding.py b/src/ttsim/tt_dag_elements/rounding.py index 3a550f830..677c65d65 100644 --- a/src/ttsim/tt_dag_elements/rounding.py +++ b/src/ttsim/tt_dag_elements/rounding.py @@ -64,10 +64,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> numpy.ndarray: rounded_out = self.base * xnp.floor(out / self.base) elif self.direction == "nearest": rounded_out = self.base * (xnp.asarray(out) / self.base).round() - else: - raise ValueError(f"Invalid rounding direction: {self.direction}") - rounded_out += self.to_add_after_rounding - return rounded_out + return rounded_out + self.to_add_after_rounding return wrapper From 601c4cfc1e2a26c3b46e0bf30fb0be59821de600 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 11 Jun 2025 10:12:34 +0200 Subject: [PATCH 07/25] More updates, getting there. 
--- src/_gettsim/arbeitslosengeld_2/einkommen.py | 6 ++++ src/ttsim/interface_dag_elements/fail_if.py | 9 +++--- .../policy_environment.py | 21 ++++++++++--- .../specialized_environment.py | 6 ++-- .../column_objects_param_function.py | 4 ++- tests/{ttsim => }/conftest.py | 0 .../child_tax_credit/child_tax_credit.py | 7 +++++ tests/ttsim/test_failures.py | 22 ++++++------- tests/ttsim/test_policy_environment.py | 3 +- tests/ttsim/test_specialized_environment.py | 31 ++++++++++++------- tests/ttsim/tt_dag_elements/test_rounding.py | 11 +++++-- tests/ttsim/tt_dag_elements/test_shared.py | 14 ++++++++- 12 files changed, 94 insertions(+), 40 deletions(-) rename tests/{ttsim => }/conftest.py (100%) diff --git a/src/_gettsim/arbeitslosengeld_2/einkommen.py b/src/_gettsim/arbeitslosengeld_2/einkommen.py index 87604a1e6..e1afcd61c 100644 --- a/src/_gettsim/arbeitslosengeld_2/einkommen.py +++ b/src/_gettsim/arbeitslosengeld_2/einkommen.py @@ -14,6 +14,8 @@ ) if TYPE_CHECKING: + from types import ModuleType + from ttsim.interface_dag_elements.typing import RawParam @@ -196,6 +198,7 @@ def anrechnungsfreies_einkommen_m( @param_function(start_date="2005-01-01") def parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg( raw_parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg: RawParam, + xnp: ModuleType, ) -> PiecewisePolynomialParamValue: """Parameter for calculation of income not subject to transfer withdrawal when children are not in the Bedarfsgemeinschaft.""" @@ -203,6 +206,7 @@ def parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg( leaf_name="parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg", func_type="piecewise_linear", parameter_dict=raw_parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg, + xnp=xnp, ) @@ -210,6 +214,7 @@ def parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg( def parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg( raw_parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg: RawParam, 
raw_parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg: RawParam, + xnp: ModuleType, ) -> PiecewisePolynomialParamValue: """Parameter for calculation of income not subject to transfer withdrawal when children are in the Bedarfsgemeinschaft.""" @@ -221,4 +226,5 @@ def parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg( leaf_name="parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg", func_type="piecewise_linear", parameter_dict=updated_parameters, + xnp=xnp, ) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 778ed1b92..10ec70a2d 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -8,6 +8,7 @@ import dags.tree as dt import networkx as nx +import numpy import optree import pandas as pd @@ -26,8 +27,6 @@ if TYPE_CHECKING: from types import ModuleType - import numpy - from ttsim.interface_dag_elements.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, @@ -243,7 +242,7 @@ def data_paths_are_missing_in_paths_to_column_names( @interface_function() -def input_data_tree_is_invalid(input_data__tree: NestedData) -> None: +def input_data_tree_is_invalid(input_data__tree: NestedData, xnp: ModuleType) -> None: """ Validate the basic structure of the data tree. 
@@ -263,7 +262,9 @@ def input_data_tree_is_invalid(input_data__tree: NestedData) -> None: """ assert_valid_ttsim_pytree( tree=input_data__tree, - leaf_checker=lambda leaf: isinstance(leaf, int | pd.Series | numpy.ndarray), + leaf_checker=lambda leaf: isinstance( + leaf, int | pd.Series | numpy.ndarray | xnp.ndarray + ), tree_name="input_data__tree", ) p_id = input_data__tree.get("p_id", None) diff --git a/src/ttsim/interface_dag_elements/policy_environment.py b/src/ttsim/interface_dag_elements/policy_environment.py index 652475759..42d823fcb 100644 --- a/src/ttsim/interface_dag_elements/policy_environment.py +++ b/src/ttsim/interface_dag_elements/policy_environment.py @@ -31,6 +31,8 @@ from ttsim.tt_dag_elements.piecewise_polynomial import get_piecewise_parameters if TYPE_CHECKING: + from types import ModuleType + from ttsim.interface_dag_elements.typing import ( DashedISOString, FlatColumnObjectsParamFunctions, @@ -47,6 +49,7 @@ def policy_environment( orig_policy_objects__column_objects_and_param_functions: NestedColumnObjectsParamFunctions, # noqa: E501 orig_policy_objects__param_specs: FlatOrigParamSpecs, date: datetime.date | DashedISOString, + xnp: ModuleType, ) -> NestedPolicyEnvironment: """ Set up the policy environment for a particular date. 
@@ -74,6 +77,7 @@ def policy_environment( right=_active_param_objects( orig=orig_policy_objects__param_specs, date=date, + xnp=xnp, ), ) @@ -90,6 +94,7 @@ def policy_environment( note=None, reference=None, ) + a_tree["xnp"] = xnp return a_tree @@ -124,6 +129,7 @@ def _active_column_objects_and_param_functions( def _active_param_objects( orig: FlatOrigParamSpecs, date: datetime.date, + xnp: ModuleType, ) -> NestedParamObjects: """Parse the original yaml tree.""" flat_tree_with_params = {} @@ -134,6 +140,7 @@ def _active_param_objects( leaf_name=leaf_name, spec=orig_params_spec, date=date, + xnp=xnp, ) if param is not None: flat_tree_with_params[(*path_to_keep, leaf_name)] = param @@ -144,6 +151,7 @@ def _active_param_objects( leaf_name=leaf_name_jan1, spec=orig_params_spec, date=date_jan1, + xnp=xnp, ) if param is not None: flat_tree_with_params[(*path_to_keep, leaf_name_jan1)] = param @@ -154,6 +162,7 @@ def _get_one_param( # noqa: PLR0911 leaf_name: str, spec: OrigParamSpec, date: datetime.date, + xnp: ModuleType, ) -> ParamObject: """Parse the original specification found in the yaml tree to a ParamObject.""" cleaned_spec = _clean_one_param_spec(leaf_name=leaf_name, spec=spec, date=date) @@ -169,29 +178,33 @@ def _get_one_param( # noqa: PLR0911 leaf_name=leaf_name, func_type=spec["type"], parameter_dict=cleaned_spec["value"], + xnp=xnp, ) return PiecewisePolynomialParam(**cleaned_spec) elif spec["type"] == "consecutive_int_1d_lookup_table": cleaned_spec["value"] = get_consecutive_int_1d_lookup_table_param_value( - cleaned_spec["value"] + raw=cleaned_spec["value"], + xnp=xnp, ) return ConsecutiveInt1dLookupTableParam(**cleaned_spec) elif spec["type"] == "consecutive_int_2d_lookup_table": cleaned_spec["value"] = get_consecutive_int_2d_lookup_table_param_value( - cleaned_spec["value"] + raw=cleaned_spec["value"], + xnp=xnp, ) return ConsecutiveInt1dLookupTableParam(**cleaned_spec) elif spec["type"] == "month_based_phase_inout_of_age_thresholds": 
cleaned_spec["value"] = ( get_month_based_phase_inout_of_age_thresholds_param_value( - cleaned_spec["value"] + raw=cleaned_spec["value"], xnp=xnp ) ) return ConsecutiveInt1dLookupTableParam(**cleaned_spec) elif spec["type"] == "year_based_phase_inout_of_age_thresholds": cleaned_spec["value"] = ( get_year_based_phase_inout_of_age_thresholds_param_value( - cleaned_spec["value"] + raw=cleaned_spec["value"], + xnp=xnp, ) ) return ConsecutiveInt1dLookupTableParam(**cleaned_spec) diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index 9d85dea2a..c34d4cd0c 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -116,13 +116,13 @@ def _remove_tree_logic_from_policy_environment( """Map qualified names to column objects / param functions without tree logic.""" out = {} for name, obj in policy_environment.items(): - if isinstance(obj, ParamObject): - out[name] = obj - else: + if hasattr(obj, "remove_tree_logic"): out[name] = obj.remove_tree_logic( tree_path=dt.tree_path_from_qual_name(name), top_level_namespace=names__top_level_namespace, ) + else: + out[name] = obj return out diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index bdce15b31..d4ad95888 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -478,7 +478,9 @@ def group_creation_function( def decorator(func: GenericCallable) -> GroupCreationFunction: _leaf_name = func.__name__ if leaf_name is None else leaf_name - func_with_reorder = lambda **kwargs: reorder_ids(func(**kwargs)) + func_with_reorder = lambda **kwargs: reorder_ids( + ids=func(**kwargs), xnp=kwargs["xnp"] + ) functools.update_wrapper(func_with_reorder, func) return GroupCreationFunction( diff --git 
a/tests/ttsim/conftest.py b/tests/conftest.py similarity index 100% rename from tests/ttsim/conftest.py rename to tests/conftest.py diff --git a/tests/ttsim/mettsim/payroll_tax/child_tax_credit/child_tax_credit.py b/tests/ttsim/mettsim/payroll_tax/child_tax_credit/child_tax_credit.py index 722da8da3..7076e773d 100644 --- a/tests/ttsim/mettsim/payroll_tax/child_tax_credit/child_tax_credit.py +++ b/tests/ttsim/mettsim/payroll_tax/child_tax_credit/child_tax_credit.py @@ -1,5 +1,10 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements import ( AggType, agg_by_p_id_function, @@ -42,6 +47,7 @@ def in_same_household_as_recipient( p_id: int, kin_id: int, p_id_recipient: int, + xnp: ModuleType, ) -> bool: return ( join( @@ -49,6 +55,7 @@ def in_same_household_as_recipient( primary_key=p_id, target=kin_id, value_if_foreign_key_is_missing=-1, + xnp=xnp, ) == kin_id ) diff --git a/tests/ttsim/test_failures.py b/tests/ttsim/test_failures.py index 8012a9926..77c128fab 100644 --- a/tests/ttsim/test_failures.py +++ b/tests/ttsim/test_failures.py @@ -12,9 +12,6 @@ from mettsim.config import METTSIM_ROOT from ttsim import main - -if TYPE_CHECKING: - import numpy from ttsim.interface_dag_elements.fail_if import ( ConflictingActivePeriodsError, _param_with_active_periods, @@ -629,13 +626,13 @@ def test_fail_if_group_variables_are_not_constant_within_groups(): ) -def test_fail_if_input_data_tree_is_invalid(): +def test_fail_if_input_data_tree_is_invalid(xnp): data = {"fam_id": pd.Series(data=numpy.arange(8), name="fam_id")} with pytest.raises( ValueError, match="The input data must contain the `p_id` column." 
): - input_data_tree_is_invalid(input_data__tree=data) + input_data_tree_is_invalid(input_data__tree=data, xnp=xnp) def test_fail_if_input_data_tree_is_invalid_via_main(): @@ -754,16 +751,16 @@ def test_fail_if_non_convertible_objects_in_results_tree( non_convertible_objects_in_results_tree(results__tree) -def test_fail_if_p_id_does_not_exist(): +def test_fail_if_p_id_does_not_exist(xnp): data = {"fam_id": pd.Series(data=numpy.arange(8), name="fam_id")} with pytest.raises( ValueError, match="The input data must contain the `p_id` column." ): - input_data_tree_is_invalid(input_data__tree=data) + input_data_tree_is_invalid(input_data__tree=data, xnp=xnp) -def test_fail_if_p_id_does_not_exist_via_main(): +def test_fail_if_p_id_does_not_exist_via_main(backend): data = {"fam_id": pd.Series([1, 2, 3], name="fam_id")} with pytest.raises( ValueError, @@ -775,22 +772,22 @@ def test_fail_if_p_id_does_not_exist_via_main(): "policy_environment": {}, "targets__tree": {}, "rounding": False, - # "jit": jit, + "backend": backend, }, targets=["fail_if__input_data_tree_is_invalid"], )["fail_if__input_data_tree_is_invalid"] -def test_fail_if_p_id_is_not_unique(): +def test_fail_if_p_id_is_not_unique(xnp): data = {"p_id": pd.Series(data=numpy.arange(4).repeat(2), name="p_id")} with pytest.raises( ValueError, match="The following `p_id`s are not unique in the input data" ): - input_data_tree_is_invalid(input_data__tree=data) + input_data_tree_is_invalid(input_data__tree=data, xnp=xnp) -def test_fail_if_p_id_is_not_unique_via_main(minimal_input_data): +def test_fail_if_p_id_is_not_unique_via_main(minimal_input_data, backend): data = copy.deepcopy(minimal_input_data) data["p_id"][:] = 1 @@ -804,6 +801,7 @@ def test_fail_if_p_id_is_not_unique_via_main(minimal_input_data): "policy_environment": {}, "targets__tree": {}, "rounding": False, + "backend": backend, }, targets=["fail_if__input_data_tree_is_invalid"], )["fail_if__input_data_tree_is_invalid"] diff --git 
a/tests/ttsim/test_policy_environment.py b/tests/ttsim/test_policy_environment.py index 84f734ae9..5314ba499 100644 --- a/tests/ttsim/test_policy_environment.py +++ b/tests/ttsim/test_policy_environment.py @@ -59,12 +59,13 @@ def some_int_param(): ) -def test_add_jahresanfang(): +def test_add_jahresanfang(xnp): orig = param_specs(root=Path(__file__).parent / "test_parameters") k = ("test_add_jahresanfang.yaml", "foo") _active_ttsim_tree_with_params = _active_param_objects( orig={k: orig[k]}, date=pd.to_datetime("2020-07-01").date(), + xnp=xnp, ) assert _active_ttsim_tree_with_params["foo"].value == 2 assert _active_ttsim_tree_with_params["foo_jahresanfang"].value == 1 diff --git a/tests/ttsim/test_specialized_environment.py b/tests/ttsim/test_specialized_environment.py index 7bb942f77..51370c2a3 100644 --- a/tests/ttsim/test_specialized_environment.py +++ b/tests/ttsim/test_specialized_environment.py @@ -232,15 +232,6 @@ def func_before_partial(arg_1, some_param): return arg_1 + some_param -func_after_partial = with_partialled_params_and_scalars( - with_processed_params_and_scalars={ - "some_func": func_before_partial, - "some_param": SOME_INT_PARAM.value, - }, - rounding=False, -)["some_func"] - - @pytest.fixture @policy_function(leaf_name="foo") def function_with_bool_return(x: bool) -> bool: @@ -514,12 +505,30 @@ def b(a): ) -def test_partial_params_to_functions(): +def test_partial_params_to_functions(xnp): # Partial function produces correct result + func_after_partial = with_partialled_params_and_scalars( + with_processed_params_and_scalars={ + "some_func": func_before_partial, + "some_param": SOME_INT_PARAM.value, + }, + rounding=False, + xnp=xnp, + )["some_func"] + assert func_after_partial(2) == 3 -def test_partial_params_to_functions_removes_argument(): +def test_partial_params_to_functions_removes_argument(xnp): + func_after_partial = with_partialled_params_and_scalars( + with_processed_params_and_scalars={ + "some_func": func_before_partial, + 
"some_param": SOME_INT_PARAM.value, + }, + rounding=False, + xnp=xnp, + )["some_func"] + # Fails if params is added to partial function with pytest.raises( TypeError, diff --git a/tests/ttsim/tt_dag_elements/test_rounding.py b/tests/ttsim/tt_dag_elements/test_rounding.py index ec3339778..f66e80fc5 100644 --- a/tests/ttsim/tt_dag_elements/test_rounding.py +++ b/tests/ttsim/tt_dag_elements/test_rounding.py @@ -138,7 +138,7 @@ def test_func(x): ) -def test_rounding_with_time_conversion(): +def test_rounding_with_time_conversion(backend, xnp): """Check if rounding is correct for time-converted functions.""" # Define function that should be rounded @@ -147,8 +147,8 @@ def test_func_m(x): return x data = { - "p_id": numpy.array([1, 2]), - "x": numpy.array([1.2, 1.5]), + "p_id": xnp.array([1, 2]), + "x": xnp.array([1.2, 1.5]), } policy_environment = { @@ -163,6 +163,7 @@ def test_func_m(x): "policy_environment": policy_environment, "targets__tree": {"test_func_y": None}, "rounding": True, + "backend": backend, }, targets=["results__tree"], )["results__tree"] @@ -170,6 +171,7 @@ def test_func_m(x): pd.Series(results__tree["test_func_y"]), pd.Series([12.0, 12.0], dtype=DTYPE), check_names=False, + check_dtype=False, ) @@ -208,6 +210,7 @@ def test_func(x): pd.Series(results__tree["test_func"]), pd.Series(input_values_exp_output, dtype=DTYPE), check_names=False, + check_dtype=False, ) @@ -227,6 +230,7 @@ def test_func(income): pd.Series(func_with_rounding(input_values)), pd.Series(exp_output), check_names=False, + check_dtype=False, ) @@ -247,6 +251,7 @@ def test_func(income): pd.Series(result), pd.Series(exp_output), check_names=False, + check_dtype=False, ) diff --git a/tests/ttsim/tt_dag_elements/test_shared.py b/tests/ttsim/tt_dag_elements/test_shared.py index cf3cdbc76..90a7a303a 100644 --- a/tests/ttsim/tt_dag_elements/test_shared.py +++ b/tests/ttsim/tt_dag_elements/test_shared.py @@ -1,5 +1,10 @@ from __future__ import annotations +from typing import TYPE_CHECKING + 
+if TYPE_CHECKING: + from types import ModuleType + import numpy import pytest @@ -45,8 +50,15 @@ def test_join( target: numpy.ndarray, value_if_foreign_key_is_missing: int, expected: numpy.ndarray, + xnp: ModuleType, ): assert numpy.array_equal( - join(foreign_key, primary_key, target, value_if_foreign_key_is_missing), + join( + foreign_key=xnp.array(foreign_key), + primary_key=xnp.array(primary_key), + target=xnp.array(target), + value_if_foreign_key_is_missing=value_if_foreign_key_is_missing, + xnp=xnp, + ), expected, ) From 21821a8f26d6b39fe31f05962554d4cf86d08d59 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Wed, 11 Jun 2025 21:48:28 +0200 Subject: [PATCH 08/25] Fix tests in METTSIM, adjust calls to piecewise polynomial everywhere. Still many errors in GETTSIM. --- src/_gettsim/arbeitslosengeld_2/einkommen.py | 5 +++ .../abz\303\274ge/behinderung.py" | 7 +++- .../abz\303\274ge/vorsorge.py" | 4 ++ .../einkommensteuer/einkommensteuer.py | 8 +++- .../eink\303\274nfte/sonstige/sonstige.py" | 7 ++++ .../grundsicherung/im_alter/einkommen.py | 6 +++ src/_gettsim/interface.py | 3 ++ src/_gettsim/lohnsteuer/einkommen.py | 7 +++- src/_gettsim/lohnsteuer/lohnsteuer.py | 40 +++++++++++++++---- .../solidarit\303\244tszuschlag.py" | 7 ++++ .../arbeitslosen/arbeitslosengeld.py | 5 +++ .../rente/grundrente/grundrente.py | 7 ++++ src/_gettsim/wohngeld/einkommen.py | 6 ++- src/ttsim/plot_dag.py | 10 ++++- src/ttsim/tt_dag_elements/param_objects.py | 30 +++++++------- .../tt_dag_elements/piecewise_polynomial.py | 2 +- tests/ttsim/mettsim/payroll_tax/amount.py | 9 +++++ tests/ttsim/mettsim/property_tax/amount.py | 7 ++++ 18 files changed, 138 insertions(+), 32 deletions(-) diff --git a/src/_gettsim/arbeitslosengeld_2/einkommen.py b/src/_gettsim/arbeitslosengeld_2/einkommen.py index e1afcd61c..54eece24c 100644 --- a/src/_gettsim/arbeitslosengeld_2/einkommen.py +++ b/src/_gettsim/arbeitslosengeld_2/einkommen.py @@ -149,12 +149,14 @@ def 
anrechnungsfreies_einkommen_m_basierend_auf_nettoquote( einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_m: float, nettoquote: float, parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Share of income which remains to the individual.""" return piecewise_polynomial( x=einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_m, parameters=parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg, rates_multiplier=nettoquote, + xnp=xnp, ) @@ -166,6 +168,7 @@ def anrechnungsfreies_einkommen_m( einkommensteuer__anzahl_kinderfreibeträge: int, parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg: PiecewisePolynomialParamValue, parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Calculate share of income, which remains to the individual since 10/2005. @@ -186,11 +189,13 @@ def anrechnungsfreies_einkommen_m( out = piecewise_polynomial( x=eink_erwerbstätigkeit, parameters=parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg, + xnp=xnp, ) else: out = piecewise_polynomial( x=eink_erwerbstätigkeit, parameters=parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg, + xnp=xnp, ) return out diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/behinderung.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/behinderung.py" index 626f47524..08394fac4 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/behinderung.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/behinderung.py" @@ -7,15 +7,20 @@ from ttsim.tt_dag_elements import piecewise_polynomial, policy_function if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements import PiecewisePolynomialParam @policy_function() def pauschbetrag_behinderung_y( - behinderungsgrad: int, parameter_behindertenpauschbetrag: PiecewisePolynomialParam + behinderungsgrad: int, + parameter_behindertenpauschbetrag: 
PiecewisePolynomialParam, + xnp: ModuleType, ) -> float: """Assign tax deduction allowance for handicaped to different handicap degrees.""" return piecewise_polynomial( x=behinderungsgrad, parameters=parameter_behindertenpauschbetrag, + xnp=xnp, ) diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" index 7d9dea452..81d18c5d7 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" @@ -6,6 +6,8 @@ from ttsim.tt_dag_elements.column_objects_param_function import param_function if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements.param_objects import PiecewisePolynomialParamValue @@ -201,6 +203,7 @@ def vorsorgeaufwendungen_keine_kappung_krankenversicherung_y_sn( def rate_abzugsfähige_altersvorsorgeaufwendungen( evaluationsjahr: int, parameter_einführungsfaktor_altersvorsorgeaufwendungen: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> dict[str, Any]: """Calculate introductory factor for pension expense deductions which depends on the current year as follows: @@ -215,6 +218,7 @@ def rate_abzugsfähige_altersvorsorgeaufwendungen( return piecewise_polynomial( x=evaluationsjahr, parameters=parameter_einführungsfaktor_altersvorsorgeaufwendungen, + xnp=xnp, ) diff --git a/src/_gettsim/einkommensteuer/einkommensteuer.py b/src/_gettsim/einkommensteuer/einkommensteuer.py index 27bc30123..d3f3c67f9 100644 --- a/src/_gettsim/einkommensteuer/einkommensteuer.py +++ b/src/_gettsim/einkommensteuer/einkommensteuer.py @@ -22,6 +22,8 @@ ) if TYPE_CHECKING: + from types import ModuleType + from ttsim.interface_dag_elements.typing import RawParam from ttsim.tt_dag_elements import ConsecutiveInt1dLookupTableParamValue @@ -128,6 +130,7 @@ def betrag_mit_kinderfreibetrag_y_sn_ab_2002( zu_versteuerndes_einkommen_mit_kinderfreibetrag_y_sn: float, anzahl_personen_sn: int, parameter_einkommensteuertarif: 
PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Taxes with child allowance on Steuernummer level. @@ -138,7 +141,7 @@ def betrag_mit_kinderfreibetrag_y_sn_ab_2002( zu_versteuerndes_einkommen_mit_kinderfreibetrag_y_sn / anzahl_personen_sn ) return anzahl_personen_sn * piecewise_polynomial( - x=zu_verst_eink_per_indiv, parameters=parameter_einkommensteuertarif + x=zu_verst_eink_per_indiv, parameters=parameter_einkommensteuertarif, xnp=xnp ) @@ -151,6 +154,7 @@ def betrag_ohne_kinderfreibetrag_y_sn( gesamteinkommen_y: float, anzahl_personen_sn: int, parameter_einkommensteuertarif: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Taxes without child allowance on Steuernummer level. Also referred to as "tarifliche ESt II". @@ -158,7 +162,7 @@ def betrag_ohne_kinderfreibetrag_y_sn( """ zu_verst_eink_per_indiv = gesamteinkommen_y / anzahl_personen_sn return anzahl_personen_sn * piecewise_polynomial( - x=zu_verst_eink_per_indiv, parameters=parameter_einkommensteuertarif + x=zu_verst_eink_per_indiv, parameters=parameter_einkommensteuertarif, xnp=xnp ) diff --git "a/src/_gettsim/einkommensteuer/eink\303\274nfte/sonstige/sonstige.py" "b/src/_gettsim/einkommensteuer/eink\303\274nfte/sonstige/sonstige.py" index aa57d6787..ae33362be 100644 --- "a/src/_gettsim/einkommensteuer/eink\303\274nfte/sonstige/sonstige.py" +++ "b/src/_gettsim/einkommensteuer/eink\303\274nfte/sonstige/sonstige.py" @@ -2,6 +2,11 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements import ( PiecewisePolynomialParamValue, piecewise_polynomial, @@ -35,9 +40,11 @@ def renteneinkünfte_m( def ertragsanteil_an_rente( sozialversicherung__rente__jahr_renteneintritt: int, parameter_ertragsanteil_an_rente: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Share of pensions subject to income taxation.""" return piecewise_polynomial( 
x=sozialversicherung__rente__jahr_renteneintritt, parameters=parameter_ertragsanteil_an_rente, + xnp=xnp, ) diff --git a/src/_gettsim/grundsicherung/im_alter/einkommen.py b/src/_gettsim/grundsicherung/im_alter/einkommen.py index 0a7fb77ab..c754433b0 100644 --- a/src/_gettsim/grundsicherung/im_alter/einkommen.py +++ b/src/_gettsim/grundsicherung/im_alter/einkommen.py @@ -7,6 +7,8 @@ from ttsim.tt_dag_elements import piecewise_polynomial, policy_function if TYPE_CHECKING: + from types import ModuleType + from _gettsim.grundsicherung.bedarfe import Regelbedarfsstufen from ttsim.tt_dag_elements import PiecewisePolynomialParam @@ -113,6 +115,7 @@ def private_rente_betrag_m( sozialversicherung__rente__private_rente_betrag_m: float, anrechnungsfreier_anteil_private_renteneinkünfte: PiecewisePolynomialParam, grundsicherung__regelbedarfsstufen: Regelbedarfsstufen, + xnp: ModuleType, ) -> float: """Calculate individual private pension benefits considered in the calculation of Grundsicherung im Alter. @@ -123,6 +126,7 @@ def private_rente_betrag_m( piecewise_polynomial( x=sozialversicherung__rente__private_rente_betrag_m, parameters=anrechnungsfreier_anteil_private_renteneinkünfte, + xnp=xnp, ) ) upper = grundsicherung__regelbedarfsstufen.rbs_1 / 2 @@ -150,6 +154,7 @@ def gesetzliche_rente_m_ab_2021( sozialversicherung__rente__grundrente__grundsätzlich_anspruchsberechtigt: bool, grundsicherung__regelbedarfsstufen: Regelbedarfsstufen, anrechnungsfreier_anteil_gesetzliche_rente: PiecewisePolynomialParam, + xnp: ModuleType, ) -> float: """Calculate individual public pension benefits which are considered in the calculation of Grundsicherung im Alter since 2021. 
@@ -161,6 +166,7 @@ def gesetzliche_rente_m_ab_2021( angerechnete_rente = piecewise_polynomial( x=sozialversicherung__rente__altersrente__betrag_m, parameters=anrechnungsfreier_anteil_gesetzliche_rente, + xnp=xnp, ) upper = grundsicherung__regelbedarfsstufen.rbs_1 / 2 diff --git a/src/_gettsim/interface.py b/src/_gettsim/interface.py index fd97a0cdc..b555643ca 100644 --- a/src/_gettsim/interface.py +++ b/src/_gettsim/interface.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +import numpy + from _gettsim.config import GETTSIM_ROOT from ttsim import main from ttsim.interface_dag_elements.data_converters import ( @@ -106,6 +108,7 @@ def oss( input_data__tree = dataframe_to_nested_data( mapper=inputs_tree_to_inputs_df_columns, df=inputs_df, + xnp=numpy, ) nested_result = main( inputs={ diff --git a/src/_gettsim/lohnsteuer/einkommen.py b/src/_gettsim/lohnsteuer/einkommen.py index 5c9bf062f..86f4f17a7 100644 --- a/src/_gettsim/lohnsteuer/einkommen.py +++ b/src/_gettsim/lohnsteuer/einkommen.py @@ -2,7 +2,10 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from types import ModuleType from ttsim.tt_dag_elements import ( PiecewisePolynomialParamValue, @@ -142,6 +145,7 @@ def vorsorge_krankenversicherungsbeiträge_option_b_ab_2019( def einführungsfaktor_rentenversicherungsaufwendungen( evaluationsjahr: int, parameter_einführungsfaktor_rentenversicherungsaufwendungen: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> dict[str, Any]: """Calculate introductory factor for pension expense deductions which depends on the current year as follows: @@ -156,6 +160,7 @@ def einführungsfaktor_rentenversicherungsaufwendungen( return piecewise_polynomial( x=evaluationsjahr, parameters=parameter_einführungsfaktor_rentenversicherungsaufwendungen, + xnp=xnp, ) diff --git a/src/_gettsim/lohnsteuer/lohnsteuer.py b/src/_gettsim/lohnsteuer/lohnsteuer.py index 76f5a7d4d..4e9e0c611 100644 --- 
a/src/_gettsim/lohnsteuer/lohnsteuer.py +++ b/src/_gettsim/lohnsteuer/lohnsteuer.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING +import numpy + from ttsim.tt_dag_elements import ( PiecewisePolynomialParamValue, param_function, @@ -12,11 +14,13 @@ ) if TYPE_CHECKING: - import numpy + from types import ModuleType def basis_für_klassen_5_6( - einkommen_y: float, parameter_einkommensteuertarif: PiecewisePolynomialParamValue + einkommen_y: float, + parameter_einkommensteuertarif: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Calculate base for Lohnsteuer for steuerklasse 5 and 6, by applying obtaining twice the difference between applying the factors 1.25 and 0.75 @@ -34,10 +38,10 @@ def basis_für_klassen_5_6( return 2 * ( piecewise_polynomial( - x=einkommen_y * 1.25, parameters=parameter_einkommensteuertarif + x=einkommen_y * 1.25, parameters=parameter_einkommensteuertarif, xnp=xnp ) - piecewise_polynomial( - x=einkommen_y * 0.75, parameters=parameter_einkommensteuertarif + x=einkommen_y * 0.75, parameters=parameter_einkommensteuertarif, xnp=xnp ) ) @@ -46,6 +50,7 @@ def basis_für_klassen_5_6( def parameter_max_lohnsteuer_klasse_5_6( einkommensteuer__parameter_einkommensteuertarif: PiecewisePolynomialParamValue, einkommensgrenzwerte_steuerklassen_5_6: dict[int, float], + xnp: ModuleType, ) -> PiecewisePolynomialParamValue: """Create paramter values for the piecewise polynomial that represents the maximum amount of Lohnsteuer that can be paid on incomes higher than the income thresholds for Steuerklasse 5 and 6. 
@@ -53,14 +58,17 @@ def parameter_max_lohnsteuer_klasse_5_6( lohnsteuer_bis_erste_grenze = basis_für_klassen_5_6( einkommensgrenzwerte_steuerklassen_5_6[1], einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) lohnsteuer_bis_zweite_grenze = basis_für_klassen_5_6( einkommensgrenzwerte_steuerklassen_5_6[2], einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) lohnsteuer_bis_dritte_grenze = basis_für_klassen_5_6( einkommensgrenzwerte_steuerklassen_5_6[3], einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) thresholds = numpy.asarray( [ @@ -95,10 +103,13 @@ def parameter_max_lohnsteuer_klasse_5_6( def basistarif( einkommen_y: float, einkommensteuer__parameter_einkommensteuertarif: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Lohnsteuer in the Basistarif.""" return piecewise_polynomial( - x=einkommen_y, parameters=einkommensteuer__parameter_einkommensteuertarif + x=einkommen_y, + parameters=einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) @@ -106,10 +117,13 @@ def basistarif( def splittingtarif( einkommen_y: float, einkommensteuer__parameter_einkommensteuertarif: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Lohnsteuer in the Splittingtarif.""" return 2 * piecewise_polynomial( - x=einkommen_y / 2, parameters=einkommensteuer__parameter_einkommensteuertarif + x=einkommen_y / 2, + parameters=einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) @@ -118,14 +132,15 @@ def tarif_klassen_5_und_6( einkommen_y: float, einkommensteuer__parameter_einkommensteuertarif: PiecewisePolynomialParamValue, parameter_max_lohnsteuer_klasse_5_6: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Lohnsteuer for Lohnsteuerklassen 5 and 6.""" basis = basis_für_klassen_5_6( - einkommen_y, einkommensteuer__parameter_einkommensteuertarif + einkommen_y, einkommensteuer__parameter_einkommensteuertarif, xnp=xnp ) max_lohnsteuer = piecewise_polynomial( - x=einkommen_y, 
parameters=parameter_max_lohnsteuer_klasse_5_6 + x=einkommen_y, parameters=parameter_max_lohnsteuer_klasse_5_6, xnp=xnp ) min_lohnsteuer = ( einkommensteuer__parameter_einkommensteuertarif.rates[0, 1] * einkommen_y @@ -157,6 +172,7 @@ def basistarif_mit_kinderfreibetrag( einkommen_y: float, einkommensteuer__parameter_einkommensteuertarif: PiecewisePolynomialParamValue, kinderfreibetrag_soli_y: float, + xnp: ModuleType, ) -> float: """Lohnsteuer in the Basistarif deducting the Kindefreibetrag.""" einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( @@ -165,6 +181,7 @@ def basistarif_mit_kinderfreibetrag( return piecewise_polynomial( x=einkommen_abzüglich_kinderfreibetrag_soli, parameters=einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) @@ -173,6 +190,7 @@ def splittingtarif_mit_kinderfreibetrag( einkommen_y: float, einkommensteuer__parameter_einkommensteuertarif: PiecewisePolynomialParamValue, kinderfreibetrag_soli_y: float, + xnp: ModuleType, ) -> float: """Lohnsteuer in the Splittingtarif deducting the Kindefreibetrag.""" einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( @@ -181,6 +199,7 @@ def splittingtarif_mit_kinderfreibetrag( return 2 * piecewise_polynomial( x=einkommen_abzüglich_kinderfreibetrag_soli / 2, parameters=einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) @@ -190,6 +209,7 @@ def tarif_klassen_5_und_6_mit_kinderfreibetrag( einkommensteuer__parameter_einkommensteuertarif: PiecewisePolynomialParamValue, parameter_max_lohnsteuer_klasse_5_6: PiecewisePolynomialParamValue, kinderfreibetrag_soli_y: float, + xnp: ModuleType, ) -> float: """Lohnsteuer for Lohnsteuerklassen 5 and 6 deducting the Kindefreibetrag.""" einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( @@ -199,10 +219,12 @@ def tarif_klassen_5_und_6_mit_kinderfreibetrag( basis = basis_für_klassen_5_6( einkommen_abzüglich_kinderfreibetrag_soli, einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) max_lohnsteuer = piecewise_polynomial( 
x=einkommen_abzüglich_kinderfreibetrag_soli, parameters=parameter_max_lohnsteuer_klasse_5_6, + xnp=xnp, ) min_lohnsteuer = ( einkommensteuer__parameter_einkommensteuertarif.rates[0, 1] @@ -237,12 +259,14 @@ def betrag_mit_kinderfreibetrag_y( def betrag_soli_y( betrag_mit_kinderfreibetrag_y: float, solidaritätszuschlag__parameter_solidaritätszuschlag: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Solidarity surcharge on Lohnsteuer (withholding tax on earnings).""" return piecewise_polynomial( x=betrag_mit_kinderfreibetrag_y, parameters=solidaritätszuschlag__parameter_solidaritätszuschlag, + xnp=xnp, ) diff --git "a/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" "b/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" index d2050d237..ba5f4ea56 100644 --- "a/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" +++ "b/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" @@ -2,6 +2,11 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements import ( PiecewisePolynomialParamValue, piecewise_polynomial, @@ -13,12 +18,14 @@ def solidaritätszuschlagstarif( steuer_pro_person: float, einkommensteuer__anzahl_personen_sn: int, parameter_solidaritätszuschlag: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """The isolated function for Solidaritätszuschlag.""" return einkommensteuer__anzahl_personen_sn * piecewise_polynomial( x=steuer_pro_person / einkommensteuer__anzahl_personen_sn, parameters=parameter_solidaritätszuschlag, + xnp=xnp, ) diff --git a/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py b/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py index f51548090..f3506c5a8 100644 --- a/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py +++ b/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py @@ -12,6 +12,8 
@@ ) if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements import ( ConsecutiveInt1dLookupTableParamValue, PiecewisePolynomialParamValue, @@ -121,6 +123,7 @@ def einkommen_vorjahr_proxy_m( einkommensteuer__parameter_einkommensteuertarif: PiecewisePolynomialParamValue, einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__werbungskostenpauschale: float, solidaritätszuschlag__parameter_solidaritätszuschlag: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Approximate last years income for unemployment benefit.""" # Relevant wage is capped at the contribution thresholds @@ -141,10 +144,12 @@ def einkommen_vorjahr_proxy_m( x=12 * max_wage - einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__werbungskostenpauschale, parameters=einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) prox_soli = piecewise_polynomial( x=prox_tax, parameters=solidaritätszuschlag__parameter_solidaritätszuschlag, + xnp=xnp, ) out = max_wage - prox_ssc - prox_tax / 12 - prox_soli / 12 return max(out, 0.0) diff --git a/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py b/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py index 8aeaec617..384e59a21 100644 --- a/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py +++ b/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from ttsim.tt_dag_elements import ( PiecewisePolynomialParamValue, RoundingSpec, @@ -7,6 +9,9 @@ policy_function, ) +if TYPE_CHECKING: + from types import ModuleType + @policy_function( rounding_spec=RoundingSpec( @@ -63,11 +68,13 @@ def _anzurechnendes_einkommen_m( einkommen_m_ehe: float, rentenwert: float, parameter_anzurechnendes_einkommen: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """The isolated function for the relevant income for the Grundrentezuschlag.""" return rentenwert * piecewise_polynomial( 
x=einkommen_m_ehe / rentenwert, parameters=parameter_anzurechnendes_einkommen, + xnp=xnp, ) diff --git a/src/_gettsim/wohngeld/einkommen.py b/src/_gettsim/wohngeld/einkommen.py index c358d82d6..b5c5c9292 100644 --- a/src/_gettsim/wohngeld/einkommen.py +++ b/src/_gettsim/wohngeld/einkommen.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING +import numpy + from ttsim.tt_dag_elements import ( AggType, ConsecutiveInt1dLookupTableParamValue, @@ -16,7 +18,7 @@ ) if TYPE_CHECKING: - import numpy + from types import ModuleType @agg_by_p_id_function(agg_type=AggType.SUM) @@ -225,6 +227,7 @@ def freibetrag_m_bis_2015( alleinerziehendenbonus: int, freibetrag_bei_behinderung_gestaffelt_y: PiecewisePolynomialParamValue, freibetrag_kinder_m: dict[str, float], + xnp: ModuleType, ) -> float: """Calculate housing benefit subtractions for one individual until 2015.""" @@ -232,6 +235,7 @@ def freibetrag_m_bis_2015( piecewise_polynomial( x=behinderungsgrad, parameters=freibetrag_bei_behinderung_gestaffelt_y, + xnp=xnp, ) / 12 ) diff --git a/src/ttsim/plot_dag.py b/src/ttsim/plot_dag.py index 8e0b45c8d..e4a4b25bd 100644 --- a/src/ttsim/plot_dag.py +++ b/src/ttsim/plot_dag.py @@ -30,7 +30,8 @@ def plot_tt_dag( with_params: bool, inputs_for_main: dict[str, Any], title: str, output_path: Path ) -> None: """Plot the taxes & transfers DAG, with or without parameters.""" - + if "backend" not in inputs_for_main: + inputs_for_main["backend"] = "numpy" policy_environment = main(inputs=inputs_for_main, targets=["policy_environment"])[ "policy_environment" ] @@ -67,6 +68,9 @@ def plot_tt_dag( )["specialized_environment__with_derived_functions_and_processed_input_nodes"] # Replace input nodes by PolicyInputs again env.update(policy_inputs) + for element in "backend", "xnp", "dnp": + if element in env: + del env[element] nodes = { qn: n.dummy_callable() if isinstance(n, PolicyInput | ParamObject) else n for qn, n in env.items() @@ -92,7 +96,9 @@ def plot_tt_dag( enforce_signature=False, 
set_annotations=False, ) - args = inspect.signature(f).parameters + args = dict(inspect.signature(f).parameters) + args.pop("xnp", None) + args.pop("dnp", None) if args: raise ValueError( "The policy environment DAG should include all root nodes but requires " diff --git a/src/ttsim/tt_dag_elements/param_objects.py b/src/ttsim/tt_dag_elements/param_objects.py index c76f9d8d9..dc4173c3f 100644 --- a/src/ttsim/tt_dag_elements/param_objects.py +++ b/src/ttsim/tt_dag_elements/param_objects.py @@ -15,8 +15,6 @@ import datetime from types import ModuleType - import numpy - @dataclass(frozen=True) class ParamObject: @@ -175,13 +173,13 @@ def get_consecutive_int_1d_lookup_table_param_value( xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Get the parameters for a 1-dimensional lookup table.""" - lookup_keys = xnp.asarray(sorted(raw)) - assert (lookup_keys - xnp.min(lookup_keys) == xnp.arange(len(lookup_keys))).all(), ( + lookup_keys = numpy.asarray(sorted(raw)) + assert (lookup_keys - min(lookup_keys) == numpy.arange(len(lookup_keys))).all(), ( "Dictionary keys must be consecutive integers." 
) return ConsecutiveInt1dLookupTableParamValue( - base_to_subtract=xnp.min(lookup_keys), + base_to_subtract=min(lookup_keys), values_to_look_up=xnp.asarray([raw[k] for k in lookup_keys]), ) @@ -200,12 +198,12 @@ def get_consecutive_int_2d_lookup_table_param_value( f"{lookup_keys_cols} and {lookup_keys_this_col}" ) for lookup_keys in lookup_keys_rows, lookup_keys_cols: - assert ( - lookup_keys - xnp.min(lookup_keys) == xnp.arange(len(lookup_keys)) - ).all(), f"Dictionary keys must be consecutive integers, got: {lookup_keys}" + assert (lookup_keys - min(lookup_keys) == xnp.arange(len(lookup_keys))).all(), ( + f"Dictionary keys must be consecutive integers, got: {lookup_keys}" + ) return ConsecutiveInt2dLookupTableParamValue( - base_to_subtract_rows=xnp.min(lookup_keys_rows), - base_to_subtract_cols=xnp.min(lookup_keys_cols), + base_to_subtract_rows=min(lookup_keys_rows), + base_to_subtract_cols=min(lookup_keys_cols), values_to_look_up=xnp.array( [ raw[row][col] @@ -248,13 +246,13 @@ def _fill_phase_inout( first_m_since_ad_to_consider = _m_since_ad(y=raw.pop("first_year_to_consider"), m=1) last_m_since_ad_to_consider = _m_since_ad(y=raw.pop("last_year_to_consider"), m=12) assert all(isinstance(k, int) for k in raw) - first_year_phase_inout: int = xnp.min(raw.keys()) # type: ignore[assignment] - first_month_phase_inout: int = xnp.min(raw[first_year_phase_inout].keys()) + first_year_phase_inout: int = min(raw.keys()) # type: ignore[assignment] + first_month_phase_inout: int = min(raw[first_year_phase_inout].keys()) first_m_since_ad_phase_inout = _m_since_ad( y=first_year_phase_inout, m=first_month_phase_inout ) - last_year_phase_inout: int = xnp.max(raw.keys()) # type: ignore[assignment] - last_month_phase_inout: int = xnp.max(raw[last_year_phase_inout].keys()) + last_year_phase_inout: int = max(raw.keys()) # type: ignore[assignment] + last_month_phase_inout: int = max(raw[last_year_phase_inout].keys()) last_m_since_ad_phase_inout = _m_since_ad( y=last_year_phase_inout, 
m=last_month_phase_inout ) @@ -292,8 +290,8 @@ def get_year_based_phase_inout_of_age_thresholds_param_value( first_year_to_consider = raw.pop("first_year_to_consider") last_year_to_consider = raw.pop("last_year_to_consider") assert all(isinstance(k, int) for k in raw) - first_year_phase_inout: int = xnp.min(raw.keys()) # type: ignore[assignment] - last_year_phase_inout: int = xnp.max(raw.keys()) # type: ignore[assignment] + first_year_phase_inout: int = sorted(raw)[0] # type: ignore[assignment] + last_year_phase_inout: int = sorted(raw)[-1] # type: ignore[assignment] assert first_year_to_consider <= first_year_phase_inout assert last_year_to_consider >= last_year_phase_inout before_phase_inout: dict[int, float] = { diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index 66490543b..d69faab16 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -51,8 +51,8 @@ class RatesOptions: def piecewise_polynomial( x: numpy.ndarray, parameters: PiecewisePolynomialParamValue, - rates_multiplier: numpy.ndarray, xnp: ModuleType, + rates_multiplier: numpy.ndarray | float = 1.0, ) -> numpy.ndarray: """Calculate value of the piecewise function at `x`. If the first interval begins at -inf the polynomial of that interval can only have slope of 0. 
Requesting a diff --git a/tests/ttsim/mettsim/payroll_tax/amount.py b/tests/ttsim/mettsim/payroll_tax/amount.py index 327cffa92..f8b5fd11e 100644 --- a/tests/ttsim/mettsim/payroll_tax/amount.py +++ b/tests/ttsim/mettsim/payroll_tax/amount.py @@ -1,5 +1,10 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements import ( PiecewisePolynomialParamValue, piecewise_polynomial, @@ -27,11 +32,13 @@ def amount_y( def amount_standard_y( income__amount_y: float, tax_schedule_standard: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Payroll tax amount for the standard tax schedule.""" return piecewise_polynomial( x=income__amount_y, parameters=tax_schedule_standard, + xnp=xnp, ) @@ -39,9 +46,11 @@ def amount_standard_y( def amount_reduced_y( income__amount_y: float, tax_schedule_reduced: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Payroll tax amount for the reduced tax schedule.""" return piecewise_polynomial( x=income__amount_y, parameters=tax_schedule_reduced, + xnp=xnp, ) diff --git a/tests/ttsim/mettsim/property_tax/amount.py b/tests/ttsim/mettsim/property_tax/amount.py index 583612fb5..0938f2d7d 100644 --- a/tests/ttsim/mettsim/property_tax/amount.py +++ b/tests/ttsim/mettsim/property_tax/amount.py @@ -8,6 +8,11 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements import ( PiecewisePolynomialParamValue, piecewise_polynomial, @@ -25,9 +30,11 @@ def acre_size_in_hectares() -> float: def amount_y( acre_size_in_hectares: float, tax_schedule: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Property tax amount for the standard tax schedule.""" return piecewise_polynomial( x=acre_size_in_hectares, parameters=tax_schedule, + xnp=xnp, ) From 1f180256bdfaab1c2a46369aa6c79d75e9fa875f Mon Sep 17 00:00:00 2001 From: Hans-Martin von 
Gaudecker Date: Thu, 12 Jun 2025 09:34:25 +0200 Subject: [PATCH 09/25] Fix usage of numpy/xnp in Lohnsteuer module. --- src/_gettsim/lohnsteuer/lohnsteuer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/_gettsim/lohnsteuer/lohnsteuer.py b/src/_gettsim/lohnsteuer/lohnsteuer.py index 4e9e0c611..825c9e3d5 100644 --- a/src/_gettsim/lohnsteuer/lohnsteuer.py +++ b/src/_gettsim/lohnsteuer/lohnsteuer.py @@ -145,7 +145,7 @@ def tarif_klassen_5_und_6( min_lohnsteuer = ( einkommensteuer__parameter_einkommensteuertarif.rates[0, 1] * einkommen_y ) - return numpy.minimum(numpy.maximum(min_lohnsteuer, basis), max_lohnsteuer) + return xnp.minimum(xnp.maximum(min_lohnsteuer, basis), max_lohnsteuer) @policy_function(start_date="2015-01-01") @@ -175,7 +175,7 @@ def basistarif_mit_kinderfreibetrag( xnp: ModuleType, ) -> float: """Lohnsteuer in the Basistarif deducting the Kindefreibetrag.""" - einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( + einkommen_abzüglich_kinderfreibetrag_soli = xnp.maximum( einkommen_y - kinderfreibetrag_soli_y, 0 ) return piecewise_polynomial( @@ -193,7 +193,7 @@ def splittingtarif_mit_kinderfreibetrag( xnp: ModuleType, ) -> float: """Lohnsteuer in the Splittingtarif deducting the Kindefreibetrag.""" - einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( + einkommen_abzüglich_kinderfreibetrag_soli = xnp.maximum( einkommen_y - kinderfreibetrag_soli_y, 0 ) return 2 * piecewise_polynomial( @@ -212,7 +212,7 @@ def tarif_klassen_5_und_6_mit_kinderfreibetrag( xnp: ModuleType, ) -> float: """Lohnsteuer for Lohnsteuerklassen 5 and 6 deducting the Kindefreibetrag.""" - einkommen_abzüglich_kinderfreibetrag_soli = numpy.maximum( + einkommen_abzüglich_kinderfreibetrag_soli = xnp.maximum( einkommen_y - kinderfreibetrag_soli_y, 0 ) @@ -230,7 +230,7 @@ def tarif_klassen_5_und_6_mit_kinderfreibetrag( einkommensteuer__parameter_einkommensteuertarif.rates[0, 1] * einkommen_abzüglich_kinderfreibetrag_soli ) - return 
numpy.minimum(numpy.maximum(min_lohnsteuer, basis), max_lohnsteuer) + return xnp.minimum(xnp.maximum(min_lohnsteuer, basis), max_lohnsteuer) @policy_function(start_date="2015-01-01") From d671cf75a8b44b64a9a15446c8fc27ae24c43b8f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 12 Jun 2025 09:43:34 +0200 Subject: [PATCH 10/25] Small cleanups and docstrings. --- src/ttsim/config.py | 8 -------- .../column_objects_param_function.py | 9 +++++---- .../tt_dag_elements/piecewise_polynomial.py | 18 +++++++++--------- 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/src/ttsim/config.py b/src/ttsim/config.py index 70f8a8a9b..2e57b874f 100644 --- a/src/ttsim/config.py +++ b/src/ttsim/config.py @@ -6,11 +6,3 @@ IS_JAX_INSTALLED = False else: IS_JAX_INSTALLED = True - - -if IS_JAX_INSTALLED: - numpy_or_jax = jax.numpy -else: - import numpy - - numpy_or_jax = numpy diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index d4ad95888..0a1f2eaf9 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -17,7 +17,6 @@ ) from ttsim.config import IS_JAX_INSTALLED -from ttsim.config import numpy_or_jax as np from ttsim.interface_dag_elements.shared import to_datetime from ttsim.tt_dag_elements.aggregation import ( AggType, @@ -42,7 +41,7 @@ if TYPE_CHECKING: from types import ModuleType - import numpy # noqa: TC004 + import numpy import pandas as pd from ttsim.interface_dag_elements.typing import ( @@ -828,7 +827,9 @@ def _convert_and_validate_dates( return start_date, end_date -def check_series_has_expected_type(series: pd.Series, internal_type: np.dtype) -> bool: +def check_series_has_expected_type( + series: pd.Series, internal_type: numpy.dtype, dnp: ModuleType +) -> bool: """Checks whether used series has already expected internal type. Currently not used, but might become useful again. 
@@ -849,7 +850,7 @@ def check_series_has_expected_type(series: pd.Series, internal_type: np.dtype) - (internal_type == float) & (is_float_dtype(series)) or (internal_type == int) & (is_integer_dtype(series)) or (internal_type == bool) & (is_bool_dtype(series)) - or (internal_type == numpy.datetime64) & (is_datetime64_any_dtype(series)) + or (internal_type == dnp.datetime64) & (is_datetime64_any_dtype(series)) ): out = True else: diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index d69faab16..200dcc493 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -60,23 +60,23 @@ def piecewise_polynomial( Parameters ---------- - x : numpy.ndarray + x: Array with values at which the piecewise polynomial is to be calculated. - thresholds : np.array + thresholds: A one-dimensional array containing the thresholds for all intervals. - coefficients : numpy.ndarray + coefficients: A two-dimensional array where columns are interval sections and rows correspond to the coefficient of the nth polynomial. - intercepts : numpy.ndarray + intercepts: The intercepts at the lower threshold of each interval. - rates_multiplier : numpy.ndarray - Multiplier to create individual or scaled rates. - xnp : ModuleType - The numpy module to use for calculations. + xnp: + The backend module to use for calculations. + rates_multiplier: + Multiplier to create individual or scaled rates. Returns ------- - out : numpy.ndarray + out: The value of `x` under the piecewise function. """ From 1d8afadd3e83add30014e412b86e7e5dced63eec Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 12 Jun 2025 10:44:36 +0200 Subject: [PATCH 11/25] Adjust GETTSIM s.t. tests pass. 
--- .../kindergeld\303\274bertrag.py" | 11 +++++--- .../arbeitslosengeld_2/regelbedarf.py | 5 +++- .../einkommensteuer/abz\303\274ge/alter.py" | 10 ++++++- .../einkommensteuer/einkommensteuer.py | 3 +++ src/_gettsim/ids.py | 21 ++++++++++++--- src/_gettsim/kindergeld/kindergeld.py | 15 +++++++---- src/_gettsim/lohnsteuer/lohnsteuer.py | 4 ++- .../solidarit\303\244tszuschlag.py" | 4 +++ .../arbeitslosen/arbeitslosengeld.py | 6 +++-- .../rente/altersrente/altersgrenzen.py | 11 +++----- .../rente/grundrente/grundrente.py | 3 +++ .../unterhaltsvorschuss.py | 14 +++++++--- src/_gettsim/wohngeld/einkommen.py | 14 ++++++---- src/_gettsim/wohngeld/miete.py | 26 +++++++++---------- src/_gettsim/wohngeld/wohngeld.py | 8 +++--- .../policy_environment.py | 2 ++ .../specialized_environment.py | 8 +++++- src/ttsim/tt_dag_elements/param_objects.py | 4 +-- 18 files changed, 116 insertions(+), 53 deletions(-) diff --git "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" index 6d64930c0..817455df0 100644 --- "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" +++ "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" @@ -7,6 +7,8 @@ from ttsim.tt_dag_elements import AggType, agg_by_p_id_function, join, policy_function if TYPE_CHECKING: + from types import ModuleType + import numpy @@ -59,6 +61,7 @@ def kindergeld_zur_bedarfsdeckung_m( kindergeld_pro_kind_m: float, kindergeld__p_id_empfänger: numpy.ndarray, # int p_id: numpy.ndarray, # int + xnp: ModuleType, ) -> numpy.ndarray: # float """Kindergeld that is used to cover the SGB II Regelbedarf of the child. 
@@ -71,10 +74,11 @@ def kindergeld_zur_bedarfsdeckung_m( """ return join( - kindergeld__p_id_empfänger, - p_id, - kindergeld_pro_kind_m, + foreign_key=kindergeld__p_id_empfänger, + primary_key=p_id, + target=kindergeld_pro_kind_m, value_if_foreign_key_is_missing=0.0, + xnp=xnp, ) @@ -122,6 +126,7 @@ def in_anderer_bg_als_kindergeldempfänger( p_id: numpy.ndarray, # int kindergeld__p_id_empfänger: numpy.ndarray, # int bg_id: numpy.ndarray, # int + xnp: ModuleType, # Will become necessary for Jax. # noqa: ARG001 ) -> numpy.ndarray: # bool """True if the person is in a different Bedarfsgemeinschaft than the Kindergeldempfänger of that person. diff --git a/src/_gettsim/arbeitslosengeld_2/regelbedarf.py b/src/_gettsim/arbeitslosengeld_2/regelbedarf.py index f741ab199..0ca60facd 100644 --- a/src/_gettsim/arbeitslosengeld_2/regelbedarf.py +++ b/src/_gettsim/arbeitslosengeld_2/regelbedarf.py @@ -14,6 +14,8 @@ ) if TYPE_CHECKING: + from types import ModuleType + from _gettsim.grundsicherung.bedarfe import Regelbedarfsstufen from ttsim.tt_dag_elements import RawParam @@ -450,6 +452,7 @@ def regelsatz_anteilsbasiert( def berechtigte_wohnfläche_eigentum( parameter_berechtigte_wohnfläche_eigentum: RawParam, wohngeld__max_anzahl_personen: dict[str, int], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Berechtigte Wohnfläche für Eigenheim.""" tmp = parameter_berechtigte_wohnfläche_eigentum.copy() @@ -457,4 +460,4 @@ def berechtigte_wohnfläche_eigentum( max_anzahl_direkt = tmp.pop("max_anzahl_direkt") for i in range(wohngeld__max_anzahl_personen["indizierung"] - max_anzahl_direkt): tmp[i] = tmp[max_anzahl_direkt] + i * je_weitere_person - return get_consecutive_int_1d_lookup_table_param_value(raw=tmp) + return get_consecutive_int_1d_lookup_table_param_value(raw=tmp, xnp=xnp) diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/alter.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/alter.py" index 681b384b4..76ac71f31 100644 --- 
"a/src/_gettsim/einkommensteuer/abz\303\274ge/alter.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/alter.py" @@ -11,6 +11,8 @@ ) if TYPE_CHECKING: + from types import ModuleType + from ttsim.tt_dag_elements import ConsecutiveInt1dLookupTableParamValue @@ -94,6 +96,7 @@ def altersfreibetrag_y_ab_2005( @param_function(start_date="2005-01-01") def altersentlastungsquote_gestaffelt( raw_altersentlastungsquote_gestaffelt: dict[str | int, int | float], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Convert the raw parameters for the age-based tax deduction allowance to a dict.""" spec = raw_altersentlastungsquote_gestaffelt.copy() @@ -104,12 +107,14 @@ def altersentlastungsquote_gestaffelt( raw=spec_int_float, left_tail_key=first_birthyear_to_consider, right_tail_key=last_birthyear_to_consider, + xnp=xnp, ) @param_function(start_date="2005-01-01") def maximaler_altersentlastungsbetrag_gestaffelt( raw_maximaler_altersentlastungsbetrag_gestaffelt: dict[str | int, int | float], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Convert the raw parameters for the age-based tax deduction allowance to a dict.""" spec = raw_maximaler_altersentlastungsbetrag_gestaffelt.copy() @@ -120,6 +125,7 @@ def maximaler_altersentlastungsbetrag_gestaffelt( raw=spec_int_float, left_tail_key=first_birthyear_to_consider, right_tail_key=last_birthyear_to_consider, + xnp=xnp, ) @@ -127,6 +133,7 @@ def get_consecutive_int_1d_lookup_table_with_filled_up_tails( raw: dict[int, float], left_tail_key: int, right_tail_key: int, + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Create a consecutive integer lookup table with filled tails. 
@@ -148,5 +155,6 @@ def get_consecutive_int_1d_lookup_table_with_filled_up_tails( range(max_key_in_spec + 1, right_tail_key + 1), raw[max_key_in_spec] ) return get_consecutive_int_1d_lookup_table_param_value( - {**consecutive_dict_start, **raw, **consecutive_dict_end} + raw={**consecutive_dict_start, **raw, **consecutive_dict_end}, + xnp=xnp, ) diff --git a/src/_gettsim/einkommensteuer/einkommensteuer.py b/src/_gettsim/einkommensteuer/einkommensteuer.py index d3f3c67f9..4932d44e4 100644 --- a/src/_gettsim/einkommensteuer/einkommensteuer.py +++ b/src/_gettsim/einkommensteuer/einkommensteuer.py @@ -218,6 +218,7 @@ def relevantes_kindergeld_ohne_staffelung_m( @param_function(start_date="2002-01-01") def parameter_einkommensteuertarif( raw_parameter_einkommensteuertarif: RawParam, + xnp: ModuleType, ) -> PiecewisePolynomialParamValue: """Add the quadratic terms to tax tariff function. @@ -240,6 +241,7 @@ def parameter_einkommensteuertarif( lower_thresholds, upper_thresholds = check_and_get_thresholds( leaf_name="parameter_einkommensteuertarif", parameter_dict=expanded, + xnp=xnp, )[:2] for key in sorted(raw_parameter_einkommensteuertarif.keys()): if "rate_quadratic" not in raw_parameter_einkommensteuertarif[key]: @@ -250,4 +252,5 @@ def parameter_einkommensteuertarif( leaf_name="parameter_einkommensteuertarif", func_type="piecewise_quadratic", parameter_dict=expanded, + xnp=xnp, ) diff --git a/src/_gettsim/ids.py b/src/_gettsim/ids.py index 483c11b44..405318092 100644 --- a/src/_gettsim/ids.py +++ b/src/_gettsim/ids.py @@ -4,11 +4,12 @@ from typing import TYPE_CHECKING +from ttsim.tt_dag_elements import group_creation_function, policy_input + if TYPE_CHECKING: from types import ModuleType import numpy -from ttsim.tt_dag_elements import group_creation_function, policy_input @policy_input() @@ -80,10 +81,24 @@ def fg_id( ) out = _assign_parents_fg_id( - out, p_id, p_id_elternteil_1_loc, hh_id, alter, children, n + fg_id=out, + p_id=p_id, + 
p_id_elternteil_loc=p_id_elternteil_1_loc, + hh_id=hh_id, + alter=alter, + children=children, + n=n, + xnp=xnp, ) out = _assign_parents_fg_id( - out, p_id, p_id_elternteil_2_loc, hh_id, alter, children, n + fg_id=out, + p_id=p_id, + p_id_elternteil_loc=p_id_elternteil_2_loc, + hh_id=hh_id, + alter=alter, + children=children, + n=n, + xnp=xnp, ) return out diff --git a/src/_gettsim/kindergeld/kindergeld.py b/src/_gettsim/kindergeld/kindergeld.py index d03056476..60e6378ee 100644 --- a/src/_gettsim/kindergeld/kindergeld.py +++ b/src/_gettsim/kindergeld/kindergeld.py @@ -14,6 +14,8 @@ ) if TYPE_CHECKING: + from types import ModuleType + import numpy from ttsim.tt_dag_elements import ConsecutiveInt1dLookupTableParamValue @@ -124,13 +126,15 @@ def gleiche_fg_wie_empfänger( p_id: numpy.ndarray, # int p_id_empfänger: numpy.ndarray, # int fg_id: numpy.ndarray, # int -) -> np.ndarray: # bool + xnp: ModuleType, +) -> numpy.ndarray: # bool """The child's Kindergeldempfänger is in the same Familiengemeinschaft.""" fg_id_kindergeldempfänger = join( - p_id_empfänger, - p_id, - fg_id, + foreign_key=p_id_empfänger, + primary_key=p_id, + target=fg_id, value_if_foreign_key_is_missing=-1, + xnp=xnp, ) return fg_id_kindergeldempfänger == fg_id @@ -139,6 +143,7 @@ def gleiche_fg_wie_empfänger( @param_function(end_date="2022-12-31") def satz_nach_anzahl_kinder( satz_gestaffelt: dict[int, float], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Convert the Kindergeld-Satz by child to the amount of Kindergeld by number of children.""" @@ -154,5 +159,5 @@ def satz_nach_anzahl_kinder( for k in range(max_num_children_in_spec + 1, max_num_children) } return get_consecutive_int_1d_lookup_table_param_value( - {0: 0.0, **base_spec, **extended_spec} + raw={0: 0.0, **base_spec, **extended_spec}, xnp=xnp ) diff --git a/src/_gettsim/lohnsteuer/lohnsteuer.py b/src/_gettsim/lohnsteuer/lohnsteuer.py index 825c9e3d5..092f93f97 100644 --- a/src/_gettsim/lohnsteuer/lohnsteuer.py +++ 
b/src/_gettsim/lohnsteuer/lohnsteuer.py @@ -93,7 +93,9 @@ def parameter_max_lohnsteuer_klasse_5_6( axis=0, ) parameter_max_lohnsteuer_klasse_5_6 = PiecewisePolynomialParamValue( - thresholds=thresholds, intercepts=intercepts, rates=rates + thresholds=xnp.asarray(thresholds), + intercepts=xnp.asarray(intercepts), + rates=xnp.asarray(rates), ) return parameter_max_lohnsteuer_klasse_5_6 diff --git "a/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" "b/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" index ba5f4ea56..ce3e8d295 100644 --- "a/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" +++ "b/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" @@ -34,6 +34,7 @@ def betrag_y_sn_ohne_abgelt_st( einkommensteuer__betrag_mit_kinderfreibetrag_y_sn: float, einkommensteuer__anzahl_personen_sn: int, parameter_solidaritätszuschlag: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Calculate the Solidarity Surcharge on Steuernummer level. @@ -53,6 +54,7 @@ def betrag_y_sn_ohne_abgelt_st( steuer_pro_person=einkommensteuer__betrag_mit_kinderfreibetrag_y_sn, einkommensteuer__anzahl_personen_sn=einkommensteuer__anzahl_personen_sn, parameter_solidaritätszuschlag=parameter_solidaritätszuschlag, + xnp=xnp, ) @@ -62,6 +64,7 @@ def betrag_y_sn_mit_abgelt_st( einkommensteuer__anzahl_personen_sn: int, einkommensteuer__abgeltungssteuer__betrag_y_sn: float, parameter_solidaritätszuschlag: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Calculate the Solidarity Surcharge on Steuernummer level. 
@@ -82,6 +85,7 @@ def betrag_y_sn_mit_abgelt_st( steuer_pro_person=einkommensteuer__betrag_mit_kinderfreibetrag_y_sn, einkommensteuer__anzahl_personen_sn=einkommensteuer__anzahl_personen_sn, parameter_solidaritätszuschlag=parameter_solidaritätszuschlag, + xnp=xnp, ) + parameter_solidaritätszuschlag.rates[0, -1] * einkommensteuer__abgeltungssteuer__betrag_y_sn diff --git a/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py b/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py index f3506c5a8..0f1f6dd00 100644 --- a/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py +++ b/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py @@ -158,6 +158,7 @@ def einkommen_vorjahr_proxy_m( @param_function(start_date="1997-03-24") def anspruchsdauer_nach_alter( raw_anspruchsdauer_nach_alter: dict[str | int, int], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Amount of potential months of unemployment benefit claims by age.""" tmp = raw_anspruchsdauer_nach_alter.copy() @@ -171,12 +172,13 @@ def anspruchsdauer_nach_alter( else: full_spec[a] = tmp[a] - return get_consecutive_int_1d_lookup_table_param_value(full_spec) + return get_consecutive_int_1d_lookup_table_param_value(raw=full_spec, xnp=xnp) @param_function(start_date="1997-03-24") def anspruchsdauer_nach_versicherungspflichtigen_monaten( raw_anspruchsdauer_nach_versicherungspflichtigen_monaten: dict[str | int, int], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Amount of potential months of unemployment benefit claims by age.""" tmp = raw_anspruchsdauer_nach_versicherungspflichtigen_monaten.copy() @@ -190,4 +192,4 @@ def anspruchsdauer_nach_versicherungspflichtigen_monaten( else: full_spec[a] = tmp[a] - return get_consecutive_int_1d_lookup_table_param_value(full_spec) + return get_consecutive_int_1d_lookup_table_param_value(raw=full_spec, xnp=xnp) diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py 
b/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py index 8b16f0c29..5f574ee38 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py @@ -171,7 +171,6 @@ def altersgrenze_vorzeitig_ohne_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze_vorzeitig: float, regelaltersrente__altersgrenze: float, - xnp: ModuleType, ) -> float: """Earliest possible retirement age after checking for eligibility. @@ -196,7 +195,6 @@ def vorzeitig_grundsätzlich_anspruchsberechtigt_mit_arbeitslosigkeit_frauen( für_frauen__grundsätzlich_anspruchsberechtigt: bool, langjährig__grundsätzlich_anspruchsberechtigt: bool, wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt: bool, - xnp: ModuleType, ) -> bool: """Eligibility for some form ofearly retirement. @@ -218,7 +216,6 @@ def vorzeitig_grundsätzlich_anspruchsberechtigt_mit_arbeitslosigkeit_frauen( ) def vorzeitig_grundsätzlich_anspruchsberechtigt_vorzeitig_ohne_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, - xnp: ModuleType, ) -> bool: """Eligibility for early retirement. @@ -237,7 +234,6 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze: float, regelaltersrente__altersgrenze: float, - xnp: ModuleType, ) -> float: """Reference age for deduction calculation in case of early retirement (Zugangsfaktor). 
@@ -251,7 +247,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( and für_frauen__grundsätzlich_anspruchsberechtigt and wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt ): - out = xnp.min( + out = min( [ für_frauen__altersgrenze, langjährig__altersgrenze, @@ -262,7 +258,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt and für_frauen__grundsätzlich_anspruchsberechtigt ): - out = xnp.min( + out = min( [ für_frauen__altersgrenze, langjährig__altersgrenze, @@ -272,7 +268,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt and wegen_arbeitslosigkeit__grundsätzlich_anspruchsberechtigt ): - out = xnp.min( + out = min( [ langjährig__altersgrenze, wegen_arbeitslosigkeit__altersgrenze, @@ -295,7 +291,6 @@ def referenzalter_abschlag_ohne_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, langjährig__altersgrenze: float, regelaltersrente__altersgrenze: float, - xnp: ModuleType, ) -> float: """Reference age for deduction calculation in case of early retirement (Zugangsfaktor). diff --git a/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py b/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py index 384e59a21..92a8c79a2 100644 --- a/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py +++ b/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py @@ -90,6 +90,7 @@ def anzurechnendes_einkommen_m( sozialversicherung__rente__altersrente__rentenwert: float, anzurechnendes_einkommen_ohne_partner: PiecewisePolynomialParamValue, anzurechnendes_einkommen_mit_partner: PiecewisePolynomialParamValue, + xnp: ModuleType, ) -> float: """Income which is deducted from Grundrentenzuschlag. 
@@ -110,12 +111,14 @@ def anzurechnendes_einkommen_m( einkommen_m_ehe=einkommen_m_ehe, rentenwert=sozialversicherung__rente__altersrente__rentenwert, parameter_anzurechnendes_einkommen=anzurechnendes_einkommen_mit_partner, + xnp=xnp, ) else: out = _anzurechnendes_einkommen_m( einkommen_m_ehe=einkommen_m_ehe, rentenwert=sozialversicherung__rente__altersrente__rentenwert, parameter_anzurechnendes_einkommen=anzurechnendes_einkommen_ohne_partner, + xnp=xnp, ) return out diff --git a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py index bd97046e3..66ffa7948 100644 --- a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py +++ b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py @@ -15,6 +15,8 @@ ) if TYPE_CHECKING: + from types import ModuleType + from ttsim.interface_dag_elements.typing import TTSIMArray from ttsim.tt_dag_elements import ConsecutiveInt1dLookupTableParamValue, RawParam @@ -67,6 +69,7 @@ def elternteil_alleinerziehend( kindergeld__p_id_empfänger: TTSIMArray, # int p_id: TTSIMArray, # int familie__alleinerziehend: TTSIMArray, # bool + xnp: ModuleType, ) -> TTSIMArray: # bool """Check if parent that receives Kindergeld is a single parent. 
@@ -77,6 +80,7 @@ def elternteil_alleinerziehend( primary_key=p_id, target=familie__alleinerziehend, value_if_foreign_key_is_missing=False, + xnp=xnp, ) @@ -266,14 +270,16 @@ def elternteil_mindesteinkommen_erreicht( kindergeld__p_id_empfänger: TTSIMArray, # int p_id: TTSIMArray, # int mindesteinkommen_erreicht: TTSIMArray, # bool -) -> TTSIMArray: # bool + xnp: ModuleType, +) -> TTSIMArray: # bool """Income of Unterhaltsvorschuss recipient above threshold (this variable is defined on child level).""" return join( - kindergeld__p_id_empfänger, - p_id, - mindesteinkommen_erreicht, + foreign_key=kindergeld__p_id_empfänger, + primary_key=p_id, + target=mindesteinkommen_erreicht, value_if_foreign_key_is_missing=False, + xnp=xnp, ) diff --git a/src/_gettsim/wohngeld/einkommen.py b/src/_gettsim/wohngeld/einkommen.py index b5c5c9292..81f18edd6 100644 --- a/src/_gettsim/wohngeld/einkommen.py +++ b/src/_gettsim/wohngeld/einkommen.py @@ -4,8 +4,6 @@ from typing import TYPE_CHECKING -import numpy - from ttsim.tt_dag_elements import ( AggType, ConsecutiveInt1dLookupTableParamValue, @@ -33,9 +31,10 @@ def alleinerziehendenbonus( @param_function() def min_einkommen_lookup_table( min_einkommen: dict[int, float], + xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Create a LookupTable for the min income thresholds.""" - return get_consecutive_int_1d_lookup_table_param_value(min_einkommen) + return get_consecutive_int_1d_lookup_table_param_value(raw=min_einkommen, xnp=xnp) def einkommen( @@ -43,6 +42,7 @@ def einkommen( einkommensfreibetrag: float, anzahl_personen: int, min_einkommen_lookup_table: ConsecutiveInt1dLookupTableParamValue, + xnp: ModuleType, ) -> float: """Calculate final income relevant for calculation of housing benefit on household level. 
@@ -50,13 +50,13 @@ def einkommen( """ eink_nach_abzug_m_hh = einkommen_vor_freibetrag - einkommensfreibetrag unteres_eink = min_einkommen_lookup_table.values_to_look_up[ - numpy.minimum( + xnp.minimum( anzahl_personen, min_einkommen_lookup_table.values_to_look_up.shape[0] ) - min_einkommen_lookup_table.base_to_subtract ] - return numpy.maximum(eink_nach_abzug_m_hh, unteres_eink) + return xnp.maximum(eink_nach_abzug_m_hh, unteres_eink) @policy_function() @@ -65,6 +65,7 @@ def einkommen_m_wthh( freibetrag_m_wthh: float, einkommen_vor_freibetrag_m_wthh: float, min_einkommen_lookup_table: ConsecutiveInt1dLookupTableParamValue, + xnp: ModuleType, ) -> float: """Income relevant for Wohngeld calculation. @@ -79,6 +80,7 @@ def einkommen_m_wthh( einkommensfreibetrag=freibetrag_m_wthh, einkommen_vor_freibetrag=einkommen_vor_freibetrag_m_wthh, min_einkommen_lookup_table=min_einkommen_lookup_table, + xnp=xnp, ) @@ -88,6 +90,7 @@ def einkommen_m_bg( freibetrag_m_bg: float, einkommen_vor_freibetrag_m_bg: float, min_einkommen_lookup_table: ConsecutiveInt1dLookupTableParamValue, + xnp: ModuleType, ) -> float: """Income relevant for Wohngeld calculation. 
@@ -102,6 +105,7 @@ def einkommen_m_bg( einkommensfreibetrag=freibetrag_m_bg, einkommen_vor_freibetrag=einkommen_vor_freibetrag_m_bg, min_einkommen_lookup_table=min_einkommen_lookup_table, + xnp=xnp, ) diff --git a/src/_gettsim/wohngeld/miete.py b/src/_gettsim/wohngeld/miete.py index 457d46edc..544dc304a 100644 --- a/src/_gettsim/wohngeld/miete.py +++ b/src/_gettsim/wohngeld/miete.py @@ -5,10 +5,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -if TYPE_CHECKING: - from types import ModuleType - - import numpy from ttsim.tt_dag_elements import ( ConsecutiveInt1dLookupTableParamValue, ConsecutiveInt2dLookupTableParamValue, @@ -18,6 +14,11 @@ policy_function, ) +if TYPE_CHECKING: + from types import ModuleType + + import numpy + @dataclass(frozen=True) class LookupTableBaujahr: @@ -52,7 +53,9 @@ def max_miete_m_lookup_mit_baujahr( + (n_p - max_n_p_defined) * per_additional_person[baujahr][ms] # type: ignore[operator] for ms in this_dict[max_n_p_defined] } - lookup_table = get_consecutive_int_2d_lookup_table_param_value(this_dict) + lookup_table = get_consecutive_int_2d_lookup_table_param_value( + raw=this_dict, xnp=xnp + ) values.append(lookup_table.values_to_look_up) subtract_cols.append(lookup_table.base_to_subtract_cols) subtract_rows.append(lookup_table.base_to_subtract_rows) @@ -86,7 +89,7 @@ def max_miete_m_lookup_ohne_baujahr( + (n_p - max_n_p_defined) * per_additional_person[ms] # type: ignore[operator] for ms in expanded[max_n_p_defined] } - return get_consecutive_int_2d_lookup_table_param_value(expanded) + return get_consecutive_int_2d_lookup_table_param_value(raw=expanded, xnp=xnp) @param_function(start_date="1984-01-01") @@ -107,7 +110,7 @@ def min_miete_lookup( expanded = raw_min_miete_m.copy() for n_p in range(max_n_p_normal + 1, max_anzahl_personen["indizierung"] + 1): expanded[n_p] = raw_min_miete_m[max_n_p_normal] - return get_consecutive_int_1d_lookup_table_param_value(expanded) + return 
get_consecutive_int_1d_lookup_table_param_value(raw=expanded, xnp=xnp) @param_function(start_date="2021-01-01") @@ -125,7 +128,7 @@ def heizkostenentlastung_m_lookup( expanded[n_p] = ( expanded[max_n_p_defined] + (n_p - max_n_p_defined) * per_additional_person # type: ignore[operator] ) - return get_consecutive_int_1d_lookup_table_param_value(expanded) + return get_consecutive_int_1d_lookup_table_param_value(raw=expanded, xnp=xnp) @param_function(start_date="2023-01-01") @@ -143,7 +146,7 @@ def dauerhafte_heizkostenkomponente_m_lookup( expanded[n_p] = ( expanded[max_n_p_defined] + (n_p - max_n_p_defined) * per_additional_person # type: ignore[operator] ) - return get_consecutive_int_1d_lookup_table_param_value(expanded) + return get_consecutive_int_1d_lookup_table_param_value(raw=expanded, xnp=xnp) @param_function(start_date="2023-01-01") @@ -161,7 +164,7 @@ def klimakomponente_m_lookup( expanded[n_p] = ( expanded[max_n_p_defined] + (n_p - max_n_p_defined) * per_additional_person # type: ignore[operator] ) - return get_consecutive_int_1d_lookup_table_param_value(expanded) + return get_consecutive_int_1d_lookup_table_param_value(raw=expanded, xnp=xnp) @policy_function() @@ -243,7 +246,6 @@ def miete_m_hh_ohne_baujahr_ohne_heizkostenentlastung( wohnen__bruttokaltmiete_m_hh: float, min_miete_m_hh: float, max_miete_m_lookup: ConsecutiveInt2dLookupTableParamValue, - xnp: ModuleType, ) -> float: """Rent considered in housing benefit since 2009.""" @@ -267,7 +269,6 @@ def miete_m_hh_mit_heizkostenentlastung( min_miete_m_hh: float, max_miete_m_lookup: ConsecutiveInt2dLookupTableParamValue, heizkostenentlastung_m_lookup: ConsecutiveInt1dLookupTableParamValue, - xnp: ModuleType, ) -> float: """Rent considered in housing benefit since 2009.""" max_miete_m = max_miete_m_lookup.values_to_look_up[ @@ -298,7 +299,6 @@ def miete_m_hh_mit_heizkostenentlastung_dauerhafte_heizkostenkomponente_klimakom heizkostenentlastung_m_lookup: ConsecutiveInt1dLookupTableParamValue, 
dauerhafte_heizkostenkomponente_m_lookup: ConsecutiveInt1dLookupTableParamValue, klimakomponente_m_lookup: ConsecutiveInt1dLookupTableParamValue, - xnp: ModuleType, ) -> float: """Rent considered in housing benefit since 2009.""" max_miete_m = max_miete_m_lookup.values_to_look_up[ diff --git a/src/_gettsim/wohngeld/wohngeld.py b/src/_gettsim/wohngeld/wohngeld.py index 60edea0a7..c01d59300 100644 --- a/src/_gettsim/wohngeld/wohngeld.py +++ b/src/_gettsim/wohngeld/wohngeld.py @@ -211,10 +211,10 @@ def basisformel_params( return BasisformelParamValues( skalierungsfaktor=skalierungsfaktor, - a=get_consecutive_int_1d_lookup_table_param_value(a), - b=get_consecutive_int_1d_lookup_table_param_value(b), - c=get_consecutive_int_1d_lookup_table_param_value(c), + a=get_consecutive_int_1d_lookup_table_param_value(raw=a, xnp=xnp), + b=get_consecutive_int_1d_lookup_table_param_value(raw=b, xnp=xnp), + c=get_consecutive_int_1d_lookup_table_param_value(raw=c, xnp=xnp), zusatzbetrag_nach_haushaltsgröße=get_consecutive_int_1d_lookup_table_param_value( - zusatzbetrag_nach_haushaltsgröße + raw=zusatzbetrag_nach_haushaltsgröße, xnp=xnp ), ) diff --git a/src/ttsim/interface_dag_elements/policy_environment.py b/src/ttsim/interface_dag_elements/policy_environment.py index 42d823fcb..a36ab872d 100644 --- a/src/ttsim/interface_dag_elements/policy_environment.py +++ b/src/ttsim/interface_dag_elements/policy_environment.py @@ -50,6 +50,7 @@ def policy_environment( orig_policy_objects__param_specs: FlatOrigParamSpecs, date: datetime.date | DashedISOString, xnp: ModuleType, + dnp: ModuleType, ) -> NestedPolicyEnvironment: """ Set up the policy environment for a particular date. 
@@ -95,6 +96,7 @@ def policy_environment( reference=None, ) a_tree["xnp"] = xnp + a_tree["dnp"] = dnp return a_tree diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index c34d4cd0c..8748d698c 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -3,6 +3,7 @@ import datetime import functools import inspect +from types import ModuleType from typing import TYPE_CHECKING, Any import dags.tree as dt @@ -28,7 +29,6 @@ if TYPE_CHECKING: from collections.abc import Callable - from types import ModuleType import networkx as nx @@ -206,6 +206,11 @@ def with_processed_params_and_scalars( for k, v in with_derived_functions_and_processed_input_nodes.items() if isinstance(v, float | int | bool) } + modules = { + k: v + for k, v in with_derived_functions_and_processed_input_nodes.items() + if isinstance(v, ModuleType) + } param_functions = { k: v for k, v in with_derived_functions_and_processed_input_nodes.items() @@ -224,6 +229,7 @@ def with_processed_params_and_scalars( processed_param_functions = process( **{k: v.value for k, v in params.items()}, **scalars, + **modules, ) processed_params = merge_trees( left={k: v.value for k, v in params.items() if not isinstance(v, RawParam)}, diff --git a/src/ttsim/tt_dag_elements/param_objects.py b/src/ttsim/tt_dag_elements/param_objects.py index dc4173c3f..7cfbbd924 100644 --- a/src/ttsim/tt_dag_elements/param_objects.py +++ b/src/ttsim/tt_dag_elements/param_objects.py @@ -274,7 +274,7 @@ def _fill_phase_inout( ) } return get_consecutive_int_1d_lookup_table_param_value( - {**before_phase_inout, **during_phase_inout, **after_phase_inout}, xnp + raw={**before_phase_inout, **during_phase_inout, **after_phase_inout}, xnp=xnp ) @@ -307,5 +307,5 @@ def get_year_based_phase_inout_of_age_thresholds_param_value( for b_y in range(last_year_phase_inout + 1, 
last_year_to_consider + 1) } return get_consecutive_int_1d_lookup_table_param_value( - {**before_phase_inout, **during_phase_inout, **after_phase_inout}, xnp + raw={**before_phase_inout, **during_phase_inout, **after_phase_inout}, xnp=xnp ) From 471a151bfd7c011bd37897e75398c8abb8c7a4f6 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 12 Jun 2025 13:28:16 +0200 Subject: [PATCH 12/25] Almost done removing IS_JAX_INSTALLED, missing the aggregations. --- tests/conftest.py => conftest.py | 11 + .../unterhaltsvorschuss.py | 19 +- src/_gettsim_tests/test_policy.py | 9 +- .../automatically_added_functions.py | 3 - src/ttsim/interface_dag_elements/fail_if.py | 6 +- src/ttsim/interface_dag_elements/names.py | 3 - .../orig_policy_objects.py | 13 +- src/ttsim/interface_dag_elements/results.py | 5 +- .../specialized_environment.py | 26 +-- src/ttsim/interface_dag_elements/typing.py | 9 +- src/ttsim/stale_code_storage.py | 3 +- src/ttsim/testing_utils.py | 75 +++--- src/ttsim/tt_dag_elements/__init__.py | 8 - .../column_objects_param_function.py | 3 - .../test_automatically_added_functions.py | 2 + tests/ttsim/test_mettsim.py | 9 +- tests/ttsim/test_specialized_environment.py | 112 ++++----- .../test_aggregation_functions.py | 217 +++++++++++------- tests/ttsim/tt_dag_elements/test_rounding.py | 13 +- 19 files changed, 290 insertions(+), 256 deletions(-) rename tests/conftest.py => conftest.py (65%) diff --git a/tests/conftest.py b/conftest.py similarity index 65% rename from tests/conftest.py rename to conftest.py index 265ede2ff..7fc4aa00a 100644 --- a/tests/conftest.py +++ b/conftest.py @@ -14,6 +14,10 @@ def pytest_addoption(parser): ) +def pytest_configure(config): + config.addinivalue_line("markers", "skipif_jax: skip test if backend is jax") + + @pytest.fixture def backend(request): backend = request.config.getoption("--backend") @@ -30,3 +34,10 @@ def xnp(request): def dnp(request): backend = request.config.getoption("--backend") return 
ttsim_dnp(backend) + + +@pytest.fixture(autouse=True) +def skipif_jax(request, backend): + """Automatically skip tests marked with skipif_jax when backend is jax.""" + if request.node.get_closest_marker("skipif_jax") and backend == "jax": + pytest.skip("Cannot run this test with Jax") diff --git a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py index 66ffa7948..af58768be 100644 --- a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py +++ b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py @@ -17,7 +17,8 @@ if TYPE_CHECKING: from types import ModuleType - from ttsim.interface_dag_elements.typing import TTSIMArray + import numpy + from ttsim.tt_dag_elements import ConsecutiveInt1dLookupTableParamValue, RawParam @@ -66,11 +67,11 @@ def betrag_m( @policy_function(vectorization_strategy="not_required") def elternteil_alleinerziehend( - kindergeld__p_id_empfänger: TTSIMArray, # int - p_id: TTSIMArray, # int - familie__alleinerziehend: TTSIMArray, # bool + kindergeld__p_id_empfänger: numpy.ndarray, # int + p_id: numpy.ndarray, # int + familie__alleinerziehend: numpy.ndarray, # bool xnp: ModuleType, -) -> TTSIMArray: # bool +) -> numpy.ndarray: # bool """Check if parent that receives Kindergeld is a single parent. Only single parents receive Kindergeld. 
@@ -267,11 +268,11 @@ def anspruchshöhe_m_ab_2017_07( @policy_function(start_date="2017-07-01", vectorization_strategy="not_required") def elternteil_mindesteinkommen_erreicht( - kindergeld__p_id_empfänger: TTSIMArray, # int - p_id: TTSIMArray, # int - mindesteinkommen_erreicht: TTSIMArray, # bool + kindergeld__p_id_empfänger: numpy.ndarray, # int + p_id: numpy.ndarray, # int + mindesteinkommen_erreicht: numpy.ndarray, # bool xnp: ModuleType, -) -> TTSIMArray: # bool +) -> numpy.ndarray: # bool """Income of Unterhaltsvorschuss recipient above threshold (this variable is defined on child level).""" return join( diff --git a/src/_gettsim_tests/test_policy.py b/src/_gettsim_tests/test_policy.py index 1351461dd..b202aad6e 100644 --- a/src/_gettsim_tests/test_policy.py +++ b/src/_gettsim_tests/test_policy.py @@ -1,12 +1,12 @@ from __future__ import annotations from pathlib import Path +from typing import Literal import numpy import pytest from _gettsim.config import GETTSIM_ROOT -from ttsim.config import IS_JAX_INSTALLED from ttsim.testing_utils import ( PolicyTest, execute_test, @@ -25,8 +25,5 @@ POLICY_TEST_IDS_AND_CASES.values(), ids=POLICY_TEST_IDS_AND_CASES.keys(), ) -def test_policy(test: PolicyTest): - if IS_JAX_INSTALLED: - execute_test(test, root=GETTSIM_ROOT, jit=True) - else: - execute_test(test, root=GETTSIM_ROOT, jit=False) +def test_policy(test: PolicyTest, backend: Literal["numpy", "jax"]): + execute_test(test=test, root=GETTSIM_ROOT, backend=backend) diff --git a/src/ttsim/interface_dag_elements/automatically_added_functions.py b/src/ttsim/interface_dag_elements/automatically_added_functions.py index 2d5d6340a..d4c403dd2 100644 --- a/src/ttsim/interface_dag_elements/automatically_added_functions.py +++ b/src/ttsim/interface_dag_elements/automatically_added_functions.py @@ -6,7 +6,6 @@ import dags.tree as dt from dags import get_free_arguments, rename_arguments -from ttsim.config import IS_JAX_INSTALLED from ttsim.interface_dag_elements.shared import ( 
get_base_name_and_grouping_suffix, get_re_pattern_for_all_time_units_and_groupings, @@ -606,8 +605,6 @@ def create_agg_by_group_functions( if base_name_with_time_unit in potential_agg_by_group_sources: group_id = f"{match.group('group')}_id" mapper = {"group_id": group_id, "column": base_name_with_time_unit} - if IS_JAX_INSTALLED: - mapper["num_segments"] = f"{group_id}_num_segments" agg_func = rename_arguments( func=grouped_sum, mapper=mapper, diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 10ec70a2d..6de4996d1 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -586,11 +586,7 @@ def root_nodes_are_missing( == 0, ).nodes - missing_nodes = [ - node - for node in root_nodes - if node not in processed_data and not node.endswith("_num_segments") - ] + missing_nodes = [node for node in root_nodes if node not in processed_data] if missing_nodes: formatted = format_list_linewise( diff --git a/src/ttsim/interface_dag_elements/names.py b/src/ttsim/interface_dag_elements/names.py index 21d505b4e..8d097dc1a 100644 --- a/src/ttsim/interface_dag_elements/names.py +++ b/src/ttsim/interface_dag_elements/names.py @@ -138,9 +138,6 @@ def top_level_namespace( for g in grouping_levels: all_top_level_names.add(f"{name}_{g}") - # Add num_segments to grouping variables - for g in grouping_levels: - all_top_level_names.add(f"{g}_id_num_segments") return all_top_level_names diff --git a/src/ttsim/interface_dag_elements/orig_policy_objects.py b/src/ttsim/interface_dag_elements/orig_policy_objects.py index 12419360f..d93a4e324 100644 --- a/src/ttsim/interface_dag_elements/orig_policy_objects.py +++ b/src/ttsim/interface_dag_elements/orig_policy_objects.py @@ -14,6 +14,7 @@ from ttsim.tt_dag_elements.column_objects_param_function import ( ColumnObject, ParamFunction, + policy_input, ) if TYPE_CHECKING: @@ -48,13 +49,23 @@ def column_objects_and_param_functions( root: The 
resource directory to load the ColumnObjectParamFunctions tree from. """ - return { + + @policy_input() + def num_segments() -> int: + """The number of segments for segment sums in jax.""" + + out = { k: v for path in _find_files_recursively(root=root, suffix=".py") for k, v in _tree_path_to_orig_column_objects_params_functions( path=path, root=root ).items() } + # Add num_segments for segment sums in jax. + assert "num_segments" not in out + out[("num_segments",)] = num_segments + + return out @interface_function() diff --git a/src/ttsim/interface_dag_elements/results.py b/src/ttsim/interface_dag_elements/results.py index c0bf42e04..6d87dbe99 100644 --- a/src/ttsim/interface_dag_elements/results.py +++ b/src/ttsim/interface_dag_elements/results.py @@ -65,7 +65,9 @@ def df_with_mapper( @interface_function() -def df_with_nested_columns(tree: NestedData) -> pd.DataFrame: +def df_with_nested_columns( + tree: NestedData, input_data__tree: NestedData +) -> pd.DataFrame: """The results DataFrame with mapped column names. 
Args: @@ -81,4 +83,5 @@ def df_with_nested_columns(tree: NestedData) -> pd.DataFrame: """ return nested_data_to_df_with_nested_columns( nested_data_to_convert=tree, + data_with_p_id=input_data__tree, ) diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index 8748d698c..b75910059 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -2,14 +2,12 @@ import datetime import functools -import inspect from types import ModuleType -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import dags.tree as dt from dags import concatenate_functions, create_dag, get_free_arguments -from ttsim.config import IS_JAX_INSTALLED from ttsim.interface_dag_elements.automatically_added_functions import ( create_agg_by_group_functions, create_time_conversion_functions, @@ -95,6 +93,7 @@ def with_derived_functions_and_processed_input_nodes( targets=dt.qual_names(targets__tree), names__processed_data_columns=names__processed_data_columns, grouping_levels=names__grouping_levels, + backend=backend, ) out = {} for n, f in flat_with_derived.items(): @@ -105,7 +104,9 @@ def with_derived_functions_and_processed_input_nodes( out[n] = processed_data[n] else: out[n] = f - + # The number of segments for jax' segment sum. After processing the data, we know + # that the number of ids is at most the length of the data. + out["num_segments"] = len(next(iter(processed_data.values()))) return out @@ -131,6 +132,7 @@ def _add_derived_functions( targets: OrderedQNames, names__processed_data_columns: QNameDataColumns, grouping_levels: OrderedQNames, + backend: Literal["numpy", "jax"], ) -> UnorderedQNames: """Return a mapping of qualified names to functions operating on columns. 
@@ -181,6 +183,7 @@ def _add_derived_functions( names__processed_data_columns=names__processed_data_columns, targets=targets, grouping_levels=grouping_levels, + backend=backend, ) out = { **qual_name_policy_environment, @@ -313,7 +316,7 @@ def tax_transfer_function( tax_transfer_dag: nx.DiGraph, with_partialled_params_and_scalars: QNameCombinedEnvironment2, names__target_columns: OrderedQNames, - # backend: numpy | jax, + backend: Literal["numpy", "jax"], ) -> Callable[[QNameData], QNameData]: """Returns a function that takes a dictionary of arrays and unpacks them as keyword arguments.""" ttf_with_keyword_args = concatenate_functions( @@ -326,20 +329,9 @@ def tax_transfer_function( set_annotations=False, ) - # if backend == jax: - # if not IS_JAX_INSTALLED: - # raise ImportError( - # "JAX is not installed. Please install JAX to use JIT compilation." - # ) - if IS_JAX_INSTALLED: + if backend == "jax": import jax - static_args = { - argname: 1000 - for argname in inspect.signature(ttf_with_keyword_args).parameters - if argname.endswith("_num_segments") - } - ttf_with_keyword_args = functools.partial(ttf_with_keyword_args, **static_args) ttf_with_keyword_args = jax.jit(ttf_with_keyword_args) def wrapper(processed_data: QNameData) -> QNameData: diff --git a/src/ttsim/interface_dag_elements/typing.py b/src/ttsim/interface_dag_elements/typing.py index 52eece69e..753ecdd4c 100644 --- a/src/ttsim/interface_dag_elements/typing.py +++ b/src/ttsim/interface_dag_elements/typing.py @@ -7,6 +7,8 @@ import datetime from collections.abc import Mapping + import numpy + OrigParamSpec = ( # Header dict[str, str | None | dict[Literal["de", "en"], str | None]] @@ -32,19 +34,18 @@ ColumnObject, ParamFunction, ParamObject, - TTSIMArray, ) # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Tree-like data structures for input, processing, and output; including metadata. 
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - NestedData = Mapping[str, TTSIMArray | "NestedData"] + NestedData = Mapping[str, numpy.ndarray | "NestedData"] """Tree mapping TTSIM paths to 1d arrays.""" - FlatData = Mapping[str, TTSIMArray | "FlatData"] + FlatData = Mapping[str, numpy.ndarray | "FlatData"] """Flattened tree mapping TTSIM paths to 1d arrays.""" NestedInputsMapper = Mapping[str, str | bool | int | float | "NestedInputsMapper"] """Tree mapping TTSIM paths to df columns or constants.""" - QNameData = Mapping[str, TTSIMArray] + QNameData = Mapping[str, numpy.ndarray] """Mapping of qualified name paths to 1d arrays.""" # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # diff --git a/src/ttsim/stale_code_storage.py b/src/ttsim/stale_code_storage.py index 0300e5802..f9584b2ce 100644 --- a/src/ttsim/stale_code_storage.py +++ b/src/ttsim/stale_code_storage.py @@ -8,7 +8,6 @@ ColumnObject, ParamFunction, ParamObject, - TTSIMArray, policy_function, ) @@ -25,7 +24,7 @@ | int | float | bool - | TTSIMArray + | numpy.ndarray | "NestedAnyTTSIMObject", ] NestedAny = Mapping[str, Any | "NestedAnyTTSIMObject"] diff --git a/src/ttsim/testing_utils.py b/src/ttsim/testing_utils.py index 11f7946fc..5b49651f3 100644 --- a/src/ttsim/testing_utils.py +++ b/src/ttsim/testing_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import lru_cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import dags.tree as dt import optree @@ -76,53 +76,50 @@ def name(self) -> str: return self.path.relative_to(self.test_dir / "test_data").as_posix() -def execute_test(test: PolicyTest, root: Path, jit: bool = False) -> None: # noqa: ARG001 +def execute_test( + test: PolicyTest, root: Path, backend: Literal["numpy", "jax"] +) -> None: environment = cached_policy_environment(date=test.date, root=root) if test.target_structure: - nested_result = main( + result_df = main( 
inputs={ "input_data__tree": test.input_tree, "policy_environment": environment, "targets__tree": test.target_structure, "rounding": True, - # "jit": jit, + "backend": backend, }, - targets=["results__tree"], - )["results__tree"] - else: - nested_result = {} - - if test.expected_output_tree: - expected_df = nested_data_to_df_with_nested_columns( - nested_data_to_convert=test.expected_output_tree, - data_with_p_id=test.input_tree, - ) - result_df = nested_data_to_df_with_nested_columns( - nested_data_to_convert=nested_result, data_with_p_id=test.input_tree - ) - try: - pd.testing.assert_frame_equal( - result_df.sort_index(axis="columns"), - expected_df.sort_index(axis="columns"), - atol=test.info["precision_atol"], - check_dtype=False, + targets=["results__df_with_nested_columns"], + )["results__df_with_nested_columns"] + + if test.expected_output_tree: + expected_df = nested_data_to_df_with_nested_columns( + nested_data_to_convert=test.expected_output_tree, + data_with_p_id=test.input_tree, ) - except AssertionError as e: - assert set(result_df.columns) == set(expected_df.columns) - cols_with_differences = [] - for col in expected_df.columns: - try: - pd.testing.assert_series_equal( - result_df[col], - expected_df[col], - atol=test.info["precision_atol"], - check_dtype=False, - ) - except AssertionError: - cols_with_differences.append(col) - raise AssertionError( - f"""actual != expected in columns: {cols_with_differences}. 
+ try: + pd.testing.assert_frame_equal( + result_df.sort_index(axis="columns"), + expected_df.sort_index(axis="columns"), + atol=test.info["precision_atol"], + check_dtype=False, + ) + except AssertionError as e: + assert set(result_df.columns) == set(expected_df.columns) + cols_with_differences = [] + for col in expected_df.columns: + try: + pd.testing.assert_series_equal( + result_df[col], + expected_df[col], + atol=test.info["precision_atol"], + check_dtype=False, + ) + except AssertionError: + cols_with_differences.append(col) + raise AssertionError( + f"""actual != expected in columns: {cols_with_differences}. actual[cols_with_differences]: @@ -132,7 +129,7 @@ def execute_test(test: PolicyTest, root: Path, jit: bool = False) -> None: # no {expected_df[cols_with_differences]} """ - ) from e + ) from e def load_policy_test_data( diff --git a/src/ttsim/tt_dag_elements/__init__.py b/src/ttsim/tt_dag_elements/__init__.py index 80aba6b2e..a69a25898 100644 --- a/src/ttsim/tt_dag_elements/__init__.py +++ b/src/ttsim/tt_dag_elements/__init__.py @@ -1,10 +1,3 @@ -from ttsim.config import IS_JAX_INSTALLED - -if IS_JAX_INSTALLED: - from jax import Array as TTSIMArray -else: - from numpy import ndarray as TTSIMArray # noqa: N812 - from ttsim.tt_dag_elements.aggregation import AggType from ttsim.tt_dag_elements.column_objects_param_function import ( AggByGroupFunction, @@ -69,7 +62,6 @@ "RawParam", "RoundingSpec", "ScalarParam", - "TTSIMArray", "TimeConversionFunction", "agg_by_group_function", "agg_by_p_id_function", diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index 0a1f2eaf9..c8b72b7bf 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -16,7 +16,6 @@ is_integer_dtype, ) -from ttsim.config import IS_JAX_INSTALLED from ttsim.interface_dag_elements.shared import to_datetime from 
ttsim.tt_dag_elements.aggregation import ( AggType, @@ -571,8 +570,6 @@ def inner(func: GenericCallable) -> AggByGroupFunction: else: _fail_if__other_arg_is_invalid(other_args, orig_location) mapper = {"group_id": group_id, "column": other_args.pop()} - if IS_JAX_INSTALLED: - mapper["num_segments"] = f"{group_id}_num_segments" agg_func = rename_arguments( func=agg_registry[agg_type], mapper=mapper, diff --git a/tests/ttsim/test_automatically_added_functions.py b/tests/ttsim/test_automatically_added_functions.py index 2233810e9..d12604c4f 100644 --- a/tests/ttsim/test_automatically_added_functions.py +++ b/tests/ttsim/test_automatically_added_functions.py @@ -391,6 +391,7 @@ def test_derived_aggregation_functions_are_in_correct_namespace( targets, names__processed_data_columns, expected, + backend, ): """Test that the derived aggregation functions are in the correct namespace. @@ -402,5 +403,6 @@ def test_derived_aggregation_functions_are_in_correct_namespace( names__processed_data_columns=names__processed_data_columns, targets=targets, grouping_levels=("kin",), + backend=backend, ) assert expected in result diff --git a/tests/ttsim/test_mettsim.py b/tests/ttsim/test_mettsim.py index 26359ba4a..36a1f38d3 100644 --- a/tests/ttsim/test_mettsim.py +++ b/tests/ttsim/test_mettsim.py @@ -1,12 +1,12 @@ from __future__ import annotations from pathlib import Path +from typing import Literal import numpy import pytest from mettsim.config import METTSIM_ROOT -from ttsim.config import IS_JAX_INSTALLED from ttsim.plot_dag import plot_tt_dag from ttsim.testing_utils import ( PolicyTest, @@ -26,11 +26,8 @@ POLICY_TEST_IDS_AND_CASES.values(), ids=POLICY_TEST_IDS_AND_CASES.keys(), ) -def test_mettsim(test: PolicyTest): - if IS_JAX_INSTALLED: - execute_test(test, root=METTSIM_ROOT, jit=True) - else: - execute_test(test, root=METTSIM_ROOT, jit=False) +def test_mettsim(test: PolicyTest, backend: Literal["numpy", "jax"]): + execute_test(test=test, root=METTSIM_ROOT, backend=backend) def 
test_mettsim_policy_environment_dag_with_params(): diff --git a/tests/ttsim/test_specialized_environment.py b/tests/ttsim/test_specialized_environment.py index 51370c2a3..20e8b26a2 100644 --- a/tests/ttsim/test_specialized_environment.py +++ b/tests/ttsim/test_specialized_environment.py @@ -10,14 +10,7 @@ import pytest from mettsim.config import METTSIM_ROOT -from ttsim import ( - main, - merge_trees, -) -from ttsim.config import IS_JAX_INSTALLED - -if TYPE_CHECKING: - import numpy +from ttsim import main, merge_trees from ttsim.interface_dag_elements.specialized_environment import ( with_partialled_params_and_scalars, with_processed_params_and_scalars, @@ -29,7 +22,6 @@ PiecewisePolynomialParamValue, RawParam, ScalarParam, - TTSIMArray, agg_by_group_function, agg_by_p_id_function, param_function, @@ -40,11 +32,6 @@ if TYPE_CHECKING: from ttsim.interface_dag_elements.typing import NestedPolicyEnvironment -if IS_JAX_INSTALLED: - jit = True -else: - jit = False - @policy_input() def p_id() -> int: @@ -337,7 +324,7 @@ def return_n1__x_kin(n1__x_kin: int) -> int: "n1": {"x": pd.Series([1, 1, 1])}, "kin_id": pd.Series([0, 0, 0]), "p_id": pd.Series([0, 1, 2]), - "num_segments": 1, + "num_segments": 3, }, ), ( @@ -357,7 +344,7 @@ def return_n1__x_kin(n1__x_kin: int) -> int: "n1": {"x": pd.Series([1, 1, 1])}, "kin_id": pd.Series([0, 0, 0]), "p_id": pd.Series([0, 1, 2]), - "num_segments": 1, + "num_segments": 3, }, ), ( @@ -378,7 +365,7 @@ def return_n1__x_kin(n1__x_kin: int) -> int: "n1": {"x": pd.Series([1, 1, 1])}, "kin_id": pd.Series([0, 0, 0]), "p_id": pd.Series([0, 1, 2]), - "num_segments": 1, + "num_segments": 3, }, ), ( @@ -399,7 +386,7 @@ def return_n1__x_kin(n1__x_kin: int) -> int: "inputs": {"x": pd.Series([1, 1, 1])}, "kin_id": pd.Series([0, 0, 0]), "p_id": pd.Series([0, 1, 2]), - "num_segments": 1, + "num_segments": 3, }, ), ], @@ -440,7 +427,7 @@ def test_output_is_tree(minimal_input_data): assert isinstance(out, dict) assert "some_func" in out["module"] - 
assert isinstance(out["module"]["some_func"], TTSIMArray) + assert isinstance(out["module"]["some_func"], numpy.ndarray) def test_params_target_is_allowed(minimal_input_data): @@ -540,7 +527,7 @@ def test_partial_params_to_functions_removes_argument(xnp): func_before_partial(2, 1) -def test_user_provided_aggregate_by_group_specs(): +def test_user_provided_aggregate_by_group_specs(backend): data = { "p_id": pd.Series([1, 2, 3], name="p_id"), "fam_id": pd.Series([1, 1, 2], name="fam_id"), @@ -553,33 +540,35 @@ def test_user_provided_aggregate_by_group_specs(): "module_name": {"betrag_m": betrag_m}, } - expected_res = pd.Series([200, 200, 100]) + expected = pd.Series([200, 200, 100], index=pd.Index(data["p_id"], name="p_id")) - out = main( + actual = main( inputs={ "input_data__tree": data, "policy_environment": policy_environment, "targets__tree": {"module_name": {"betrag_m_fam": None}}, "rounding": False, - # "jit": jit, + "backend": backend, }, - targets=["results__tree"], - )["results__tree"] - - numpy.testing.assert_array_almost_equal( - out["module_name"]["betrag_m_fam"], expected_res + targets=["results__df_with_nested_columns"], + )["results__df_with_nested_columns"] + + pd.testing.assert_series_equal( + actual[("module_name", "betrag_m_fam")], + expected, + check_names=False, + check_dtype=False, ) -def test_user_provided_aggregation(): +def test_user_provided_aggregation(backend): data = { "p_id": pd.Series([1, 2, 3], name="p_id"), "fam_id": pd.Series([1, 1, 2], name="fam_id"), "module_name": {"betrag_m": pd.Series([200, 100, 100], name="betrag_m")}, } - data["num_segments"] = len(data["fam_id"].unique()) # Double up, then take max fam_id - expected = pd.Series([400, 400, 200]) + expected = pd.Series([400, 400, 200], index=pd.Index(data["p_id"], name="p_id")) @policy_function(vectorization_strategy="vectorize") def betrag_m_double(betrag_m): @@ -604,17 +593,20 @@ def betrag_m_double_fam(betrag_m_double, fam_id) -> float: "policy_environment": 
policy_environment, "targets__tree": {"module_name": {"betrag_m_double_fam": None}}, "rounding": False, - # "jit": jit, + "backend": backend, }, - targets=["results__tree"], - )["results__tree"] - - numpy.testing.assert_array_almost_equal( - actual["module_name"]["betrag_m_double_fam"], expected + targets=["results__df_with_nested_columns"], + )["results__df_with_nested_columns"] + + pd.testing.assert_series_equal( + actual[("module_name", "betrag_m_double_fam")], + expected, + check_names=False, + check_dtype=False, ) -def test_user_provided_aggregation_with_time_conversion(): +def test_user_provided_aggregation_with_time_conversion(backend): data = { "p_id": pd.Series([1, 2, 3], name="p_id"), "fam_id": pd.Series([1, 1, 2], name="fam_id"), @@ -624,7 +616,9 @@ def test_user_provided_aggregation_with_time_conversion(): } # Double up, convert to quarter, then take max fam_id - expected = pd.Series([400 * 12, 400 * 12, 200 * 12]) + expected = pd.Series( + [400 * 12, 400 * 12, 200 * 12], index=pd.Index(data["p_id"], name="p_id") + ) @policy_function(vectorization_strategy="vectorize") def betrag_double_m(betrag_m): @@ -649,13 +643,16 @@ def max_betrag_double_m_fam(betrag_double_m, fam_id) -> float: "policy_environment": policy_environment, "targets__tree": {"module_name": {"max_betrag_double_y_fam": None}}, "rounding": False, - # "jit": jit, + "backend": backend, }, - targets=["results__tree"], - )["results__tree"] - - numpy.testing.assert_array_almost_equal( - actual["module_name"]["max_betrag_double_y_fam"], expected + targets=["results__df_with_nested_columns"], + )["results__df_with_nested_columns"] + + pd.testing.assert_series_equal( + actual[("module_name", "max_betrag_double_y_fam")], + expected, + check_names=False, + check_dtype=False, ) @@ -684,7 +681,7 @@ def sum_source_m_by_p_id_someone_else( }, "source", {"module": {"sum_source_by_p_id_someone_else": None}}, - pd.Series([200, 100, 0]), + pd.Series([200, 100, 0], index=pd.Index([0, 1, 2], name="p_id")), ), 
( { @@ -694,7 +691,7 @@ def sum_source_m_by_p_id_someone_else( }, "source_m", {"module": {"sum_source_m_by_p_id_someone_else": None}}, - pd.Series([200, 100, 0]), + pd.Series([200, 100, 0], index=pd.Index([0, 1, 2], name="p_id")), ), ], ) @@ -704,10 +701,12 @@ def test_user_provided_aggregate_by_p_id_specs( target_tree, expected, minimal_input_data_shared_fam, + backend, + xnp, ): @policy_function(leaf_name=leaf_name, vectorization_strategy="not_required") def source() -> int: - return numpy.array([100, 200, 300]) + return xnp.array([100, 200, 300]) policy_environment = merge_trees( agg_functions, @@ -718,18 +717,23 @@ def source() -> int: }, ) - out = main( + actual = main( inputs={ "input_data__tree": minimal_input_data_shared_fam, "policy_environment": policy_environment, "targets__tree": target_tree, "rounding": False, - # "jit": jit, + "backend": backend, }, - targets=["results__tree"], - )["results__tree"]["module"][next(iter(target_tree["module"].keys()))] - - numpy.testing.assert_array_almost_equal(out, expected) + targets=["results__df_with_nested_columns"], + )["results__df_with_nested_columns"] + + pd.testing.assert_series_equal( + actual[("module", next(iter(target_tree["module"].keys())))], + expected, + check_names=False, + check_dtype=False, + ) def test_policy_environment_with_params_and_scalars_is_processed(): diff --git a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py index bcc162791..793edaa51 100644 --- a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py +++ b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py @@ -1,15 +1,18 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING import numpy import pytest -from ttsim.config import IS_JAX_INSTALLED +try: + import jax_datetime + + my_datetime = jax_datetime.to_datetime +except ImportError: + my_datetime = lambda x: x + -if TYPE_CHECKING: - import numpy from 
ttsim.tt_dag_elements.aggregation import ( grouped_all, grouped_any, @@ -178,90 +181,88 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): "exception_match": "The dtype of id columns must be integer.", }, } -# We cannot even set up these fixtures in JAX. -if not IS_JAX_INSTALLED: - test_grouped_specs["datetime"] = { - "column_to_aggregate": numpy.array( - [ - numpy.datetime64("2000"), - numpy.datetime64("2001"), - numpy.datetime64("2002"), - numpy.datetime64("2003"), - numpy.datetime64("2004"), - ] - ), - "group_id": numpy.array([1, 0, 1, 1, 1]), - "expected_res_max": numpy.array( - [ - numpy.datetime64("2004"), - numpy.datetime64("2001"), - numpy.datetime64("2004"), - numpy.datetime64("2004"), - numpy.datetime64("2004"), - ] - ), - "expected_res_min": numpy.array( - [ - numpy.datetime64("2000"), - numpy.datetime64("2001"), - numpy.datetime64("2000"), - numpy.datetime64("2000"), - numpy.datetime64("2000"), - ] - ), - } - - test_grouped_raises_specs["dtype_string"] = { - "column_to_aggregate": numpy.array(["0", "1", "2", "3", "4"]), - "group_id": numpy.array([0, 0, 1, 1, 1]), - "error_sum": TypeError, - "error_mean": TypeError, - "error_max": TypeError, - "error_min": TypeError, - "error_any": TypeError, - "error_all": TypeError, - "exception_match": "grouped_", - } - test_grouped_raises_specs["datetime"] = { - "column_to_aggregate": numpy.array( - [ - numpy.datetime64("2000"), - numpy.datetime64("2001"), - numpy.datetime64("2002"), - numpy.datetime64("2003"), - numpy.datetime64("2004"), - ] - ), - "group_id": numpy.array([0, 0, 1, 1, 1]), - "error_sum": TypeError, - "error_mean": TypeError, - "error_any": TypeError, - "error_all": TypeError, - "exception_match": "grouped_", - } +test_grouped_specs["datetime"] = { + "column_to_aggregate": numpy.array( + [ + numpy.datetime64("2000"), + numpy.datetime64("2001"), + numpy.datetime64("2002"), + numpy.datetime64("2003"), + numpy.datetime64("2004"), + ] + ), + "group_id": numpy.array([1, 0, 1, 1, 1]), 
+ "expected_res_max": numpy.array( + [ + numpy.datetime64("2004"), + numpy.datetime64("2001"), + numpy.datetime64("2004"), + numpy.datetime64("2004"), + numpy.datetime64("2004"), + ] + ), + "expected_res_min": numpy.array( + [ + numpy.datetime64("2000"), + numpy.datetime64("2001"), + numpy.datetime64("2000"), + numpy.datetime64("2000"), + numpy.datetime64("2000"), + ] + ), +} + +test_grouped_raises_specs["dtype_string"] = { + "column_to_aggregate": numpy.array(["0", "1", "2", "3", "4"]), + "group_id": numpy.array([0, 0, 1, 1, 1]), + "error_sum": TypeError, + "error_mean": TypeError, + "error_max": TypeError, + "error_min": TypeError, + "error_any": TypeError, + "error_all": TypeError, + "exception_match": "grouped_", +} +test_grouped_raises_specs["datetime"] = { + "column_to_aggregate": numpy.array( + [ + numpy.datetime64("2000"), + numpy.datetime64("2001"), + numpy.datetime64("2002"), + numpy.datetime64("2003"), + numpy.datetime64("2004"), + ] + ), + "group_id": numpy.array([0, 0, 1, 1, 1]), + "error_sum": TypeError, + "error_mean": TypeError, + "error_any": TypeError, + "error_all": TypeError, + "exception_match": "grouped_", +} @parameterize_based_on_dict( test_grouped_specs, keys_of_test_cases=["group_id", "expected_res_count"], ) -def test_grouped_count(group_id, expected_res_count): - if IS_JAX_INSTALLED: +def test_grouped_count(group_id, expected_res_count, backend): + if backend == "jax": result = grouped_count( group_id=group_id, - num_segments=group_id.max() + 1, + num_segments=len(group_id), ) else: result = grouped_count(group_id=group_id) numpy.testing.assert_array_almost_equal(result, expected_res_count) -def _run_agg_by_group(agg_func, column_to_aggregate, group_id): - if IS_JAX_INSTALLED: +def _run_agg_by_group(agg_func, column_to_aggregate, group_id, backend): + if backend == "jax": return agg_func( column=column_to_aggregate, group_id=group_id, - num_segments=group_id.max() + 1, + num_segments=len(group_id), ) else: return 
agg_func(column=column_to_aggregate, group_id=group_id) @@ -275,11 +276,12 @@ def _run_agg_by_group(agg_func, column_to_aggregate, group_id): "expected_res_sum", ], ) -def test_grouped_sum(column_to_aggregate, group_id, expected_res_sum): +def test_grouped_sum(column_to_aggregate, group_id, expected_res_sum, backend): result = _run_agg_by_group( agg_func=grouped_sum, column_to_aggregate=column_to_aggregate, group_id=group_id, + backend=backend, ) numpy.testing.assert_array_almost_equal(result, expected_res_sum) @@ -292,45 +294,86 @@ def test_grouped_sum(column_to_aggregate, group_id, expected_res_sum): "expected_res_mean", ], ) -def test_grouped_mean(column_to_aggregate, group_id, expected_res_mean): +def test_grouped_mean(column_to_aggregate, group_id, expected_res_mean, backend): result = _run_agg_by_group( agg_func=grouped_mean, column_to_aggregate=column_to_aggregate, group_id=group_id, + backend=backend, ) numpy.testing.assert_array_almost_equal(result, expected_res_mean) @parameterize_based_on_dict( - test_grouped_specs, + {k: v for k, v in test_grouped_specs.items() if "datetime" not in k}, keys_of_test_cases=[ "column_to_aggregate", "group_id", "expected_res_max", ], ) -def test_grouped_max(column_to_aggregate, group_id, expected_res_max): +def test_grouped_max(column_to_aggregate, group_id, expected_res_max, backend): result = _run_agg_by_group( agg_func=grouped_max, column_to_aggregate=column_to_aggregate, group_id=group_id, + backend=backend, ) numpy.testing.assert_array_equal(result, expected_res_max) +@pytest.mark.skipif_jax @parameterize_based_on_dict( - test_grouped_specs, + {k: v for k, v in test_grouped_specs.items() if "datetime" in k}, + keys_of_test_cases=[ + "column_to_aggregate", + "group_id", + "expected_res_max", + ], +) +def test_grouped_max_datetime(column_to_aggregate, group_id, expected_res_max, backend): + result = _run_agg_by_group( + agg_func=grouped_max, + column_to_aggregate=my_datetime(column_to_aggregate), + group_id=group_id, + 
backend=backend, + ) + numpy.testing.assert_array_equal(result, expected_res_max) + + +@parameterize_based_on_dict( + {k: v for k, v in test_grouped_specs.items() if "datetime" not in k}, keys_of_test_cases=[ "column_to_aggregate", "group_id", "expected_res_min", ], ) -def test_grouped_min(column_to_aggregate, group_id, expected_res_min): +def test_grouped_min(column_to_aggregate, group_id, expected_res_min, backend): result = _run_agg_by_group( agg_func=grouped_min, column_to_aggregate=column_to_aggregate, group_id=group_id, + backend=backend, + ) + numpy.testing.assert_array_equal(result, expected_res_min) + + +@pytest.mark.skipif_jax +@parameterize_based_on_dict( + {k: v for k, v in test_grouped_specs.items() if "datetime" in k}, + keys_of_test_cases=[ + "column_to_aggregate", + "group_id", + "expected_res_min", + ], +) +def test_grouped_min_datetime(column_to_aggregate, group_id, expected_res_min, backend): + result = _run_agg_by_group( + agg_func=grouped_min, + column_to_aggregate=my_datetime(column_to_aggregate), + group_id=group_id, + backend=backend, ) numpy.testing.assert_array_equal(result, expected_res_min) @@ -343,11 +386,12 @@ def test_grouped_min(column_to_aggregate, group_id, expected_res_min): "expected_res_any", ], ) -def test_grouped_any(column_to_aggregate, group_id, expected_res_any): +def test_grouped_any(column_to_aggregate, group_id, expected_res_any, backend): result = _run_agg_by_group( agg_func=grouped_any, column_to_aggregate=column_to_aggregate, group_id=group_id, + backend=backend, ) numpy.testing.assert_array_almost_equal(result, expected_res_any) @@ -360,11 +404,12 @@ def test_grouped_any(column_to_aggregate, group_id, expected_res_any): "expected_res_all", ], ) -def test_grouped_all(column_to_aggregate, group_id, expected_res_all): +def test_grouped_all(column_to_aggregate, group_id, expected_res_all, backend): result = _run_agg_by_group( agg_func=grouped_all, column_to_aggregate=column_to_aggregate, group_id=group_id, + 
backend=backend, ) numpy.testing.assert_array_almost_equal(result, expected_res_all) @@ -378,7 +423,7 @@ def test_grouped_all(column_to_aggregate, group_id, expected_res_all): "exception_match", ], ) -@pytest.mark.skipif(IS_JAX_INSTALLED, reason="Cannot raise errors in jitted JAX.") +@pytest.mark.skipif_jax def test_grouped_sum_raises(column_to_aggregate, group_id, error_sum, exception_match): with pytest.raises( error_sum, @@ -396,7 +441,7 @@ def test_grouped_sum_raises(column_to_aggregate, group_id, error_sum, exception_ "exception_match", ], ) -@pytest.mark.skipif(IS_JAX_INSTALLED, reason="Cannot raise errors in jitted JAX.") +@pytest.mark.skipif_jax def test_grouped_mean_raises( column_to_aggregate, group_id, error_mean, exception_match ): @@ -416,7 +461,7 @@ def test_grouped_mean_raises( "exception_match", ], ) -@pytest.mark.skipif(IS_JAX_INSTALLED, reason="Cannot raise errors in jitted JAX.") +@pytest.mark.skipif_jax def test_grouped_max_raises(column_to_aggregate, group_id, error_max, exception_match): with pytest.raises( error_max, @@ -434,7 +479,7 @@ def test_grouped_max_raises(column_to_aggregate, group_id, error_max, exception_ "exception_match", ], ) -@pytest.mark.skipif(IS_JAX_INSTALLED, reason="Cannot raise errors in jitted JAX.") +@pytest.mark.skipif_jax def test_grouped_min_raises(column_to_aggregate, group_id, error_min, exception_match): with pytest.raises( error_min, @@ -452,7 +497,7 @@ def test_grouped_min_raises(column_to_aggregate, group_id, error_min, exception_ "exception_match", ], ) -@pytest.mark.skipif(IS_JAX_INSTALLED, reason="Cannot raise errors in jitted JAX.") +@pytest.mark.skipif_jax def test_grouped_any_raises(column_to_aggregate, group_id, error_any, exception_match): with pytest.raises( error_any, @@ -470,7 +515,7 @@ def test_grouped_any_raises(column_to_aggregate, group_id, error_any, exception_ "exception_match", ], ) -@pytest.mark.skipif(IS_JAX_INSTALLED, reason="Cannot raise errors in jitted JAX.") +@pytest.mark.skipif_jax def 
test_grouped_all_raises(column_to_aggregate, group_id, error_all, exception_match): with pytest.raises( error_all, @@ -517,7 +562,7 @@ def test_sum_by_p_id( "exception_match", ], ) -@pytest.mark.skipif(IS_JAX_INSTALLED, reason="Cannot raise errors in jitted JAX.") +@pytest.mark.skipif_jax def test_sum_by_p_id_raises( column_to_aggregate, group_id, p_id_to_store_by, error_sum_by_p_id, exception_match ): diff --git a/tests/ttsim/tt_dag_elements/test_rounding.py b/tests/ttsim/tt_dag_elements/test_rounding.py index f66e80fc5..d9272a0af 100644 --- a/tests/ttsim/tt_dag_elements/test_rounding.py +++ b/tests/ttsim/tt_dag_elements/test_rounding.py @@ -6,7 +6,6 @@ from pandas._testing import assert_series_equal from ttsim import main -from ttsim.config import IS_JAX_INSTALLED from ttsim.interface_dag_elements.policy_environment import policy_environment from ttsim.tt_dag_elements import ( RoundingSpec, @@ -14,11 +13,6 @@ policy_input, ) -if IS_JAX_INSTALLED: - DTYPE = "float32" -else: - DTYPE = "float64" - @policy_input() def x() -> int: @@ -133,8 +127,9 @@ def test_func(x): )["results__tree"] assert_series_equal( pd.Series(results__tree["namespace"]["test_func"]), - pd.Series(exp_output, dtype=DTYPE), + pd.Series(exp_output), check_names=False, + check_dtype=False, ) @@ -169,7 +164,7 @@ def test_func_m(x): )["results__tree"] assert_series_equal( pd.Series(results__tree["test_func_y"]), - pd.Series([12.0, 12.0], dtype=DTYPE), + pd.Series([12.0, 12.0]), check_names=False, check_dtype=False, ) @@ -208,7 +203,7 @@ def test_func(x): )["results__tree"] assert_series_equal( pd.Series(results__tree["test_func"]), - pd.Series(input_values_exp_output, dtype=DTYPE), + pd.Series(input_values_exp_output), check_names=False, check_dtype=False, ) From 86c8d2b1d939aa5816bf579cb25e2c89c1816a81 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Thu, 12 Jun 2025 18:06:20 +0200 Subject: [PATCH 13/25] Remove IS_JAX_INSTALLED. 
For some reason, tests are not working with JAX, probably some earlier version did not jit correctly... --- src/ttsim/config.py | 8 - .../automatically_added_functions.py | 4 +- .../orig_policy_objects.py | 8 + .../policy_environment.py | 4 +- .../specialized_environment.py | 6 +- src/ttsim/plot_dag.py | 1 + src/ttsim/tt_dag_elements/aggregation.py | 216 ++++++++++++++++-- src/ttsim/tt_dag_elements/aggregation_jax.py | 10 +- .../column_objects_param_function.py | 45 ++-- tests/ttsim/mettsim/group_by_ids.py | 68 ++++-- .../test_automatically_added_functions.py | 2 - tests/ttsim/test_specialized_environment.py | 35 ++- .../test_aggregation_functions.py | 102 ++++++--- .../tt_dag_elements/test_ttsim_objects.py | 12 +- 14 files changed, 400 insertions(+), 121 deletions(-) delete mode 100644 src/ttsim/config.py diff --git a/src/ttsim/config.py b/src/ttsim/config.py deleted file mode 100644 index 2e57b874f..000000000 --- a/src/ttsim/config.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import annotations - -try: - import jax -except ImportError: - IS_JAX_INSTALLED = False -else: - IS_JAX_INSTALLED = True diff --git a/src/ttsim/interface_dag_elements/automatically_added_functions.py b/src/ttsim/interface_dag_elements/automatically_added_functions.py index d4c403dd2..aacd716ed 100644 --- a/src/ttsim/interface_dag_elements/automatically_added_functions.py +++ b/src/ttsim/interface_dag_elements/automatically_added_functions.py @@ -574,6 +574,7 @@ def create_agg_by_group_functions( names__processed_data_columns: QNameDataColumns, targets: OrderedQNames, grouping_levels: OrderedQNames, + # backend: Literal["numpy", "jax"], ) -> UnorderedQNames: gp = group_pattern(grouping_levels) all_functions_and_data = { @@ -592,7 +593,8 @@ def create_agg_by_group_functions( potential_agg_by_group_sources = { qn: o for qn, o in all_functions_and_data.items() if not gp.match(qn) } - # Exclude objects that have been explicitly provided. 
+ # Exclude objects that have been explicitly provided. + agg_by_group_function_names = { t for t in potential_agg_by_group_function_names diff --git a/src/ttsim/interface_dag_elements/orig_policy_objects.py b/src/ttsim/interface_dag_elements/orig_policy_objects.py index d93a4e324..b8f163152 100644 --- a/src/ttsim/interface_dag_elements/orig_policy_objects.py +++ b/src/ttsim/interface_dag_elements/orig_policy_objects.py @@ -50,6 +50,10 @@ def column_objects_and_param_functions( The resource directory to load the ColumnObjectParamFunctions tree from. """ + @policy_input() + def backend() -> Literal["numpy", "jax"]: + """The backend to use for computations.""" + @policy_input() + def num_segments() -> int: + """The number of segments for segment sums in jax.""" @@ -61,6 +65,10 @@ def num_segments() -> int: path=path, root=root ).items() } + # Add backend so we can decide between numpy and jax for aggregation functions + assert "backend" not in out + out[("backend",)] = backend + # Add num_segments for segment sums in jax. 
assert "num_segments" not in out out[("num_segments",)] = num_segments diff --git a/src/ttsim/interface_dag_elements/policy_environment.py b/src/ttsim/interface_dag_elements/policy_environment.py index a36ab872d..8f6d36ff0 100644 --- a/src/ttsim/interface_dag_elements/policy_environment.py +++ b/src/ttsim/interface_dag_elements/policy_environment.py @@ -2,7 +2,7 @@ import copy import datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import dags.tree as dt import numpy @@ -49,6 +49,7 @@ def policy_environment( orig_policy_objects__column_objects_and_param_functions: NestedColumnObjectsParamFunctions, # noqa: E501 orig_policy_objects__param_specs: FlatOrigParamSpecs, date: datetime.date | DashedISOString, + backend: Literal["numpy", "jax"], xnp: ModuleType, dnp: ModuleType, ) -> NestedPolicyEnvironment: @@ -95,6 +96,7 @@ def policy_environment( note=None, reference=None, ) + a_tree["backend"] = backend a_tree["xnp"] = xnp a_tree["dnp"] = dnp return a_tree diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index b75910059..6982e54eb 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -93,7 +93,6 @@ def with_derived_functions_and_processed_input_nodes( targets=dt.qual_names(targets__tree), names__processed_data_columns=names__processed_data_columns, grouping_levels=names__grouping_levels, - backend=backend, ) out = {} for n, f in flat_with_derived.items(): @@ -107,6 +106,7 @@ def with_derived_functions_and_processed_input_nodes( # The number of segments for jax' segment sum. After processing the data, we know # that the number of ids is at most the length of the data. 
out["num_segments"] = len(next(iter(processed_data.values()))) + out["backend"] = backend return out @@ -132,7 +132,6 @@ def _add_derived_functions( targets: OrderedQNames, names__processed_data_columns: QNameDataColumns, grouping_levels: OrderedQNames, - backend: Literal["numpy", "jax"], ) -> UnorderedQNames: """Return a mapping of qualified names to functions operating on columns. @@ -183,7 +182,6 @@ def _add_derived_functions( names__processed_data_columns=names__processed_data_columns, targets=targets, grouping_levels=grouping_levels, - backend=backend, ) out = { **qual_name_policy_environment, @@ -207,7 +205,7 @@ def with_processed_params_and_scalars( scalars = { k: v for k, v in with_derived_functions_and_processed_input_nodes.items() - if isinstance(v, float | int | bool) + if isinstance(v, float | int | bool) or k == "backend" } modules = { k: v diff --git a/src/ttsim/plot_dag.py b/src/ttsim/plot_dag.py index e4a4b25bd..4167f1672 100644 --- a/src/ttsim/plot_dag.py +++ b/src/ttsim/plot_dag.py @@ -97,6 +97,7 @@ def plot_tt_dag( set_annotations=False, ) args = dict(inspect.signature(f).parameters) + args.pop("backend", None) args.pop("xnp", None) args.pop("dnp", None) if args: diff --git a/src/ttsim/tt_dag_elements/aggregation.py b/src/ttsim/tt_dag_elements/aggregation.py index f1b624b68..9f0ecd184 100644 --- a/src/ttsim/tt_dag_elements/aggregation.py +++ b/src/ttsim/tt_dag_elements/aggregation.py @@ -1,10 +1,13 @@ from __future__ import annotations from enum import StrEnum +from typing import TYPE_CHECKING, Literal -from ttsim.config import IS_JAX_INSTALLED from ttsim.tt_dag_elements import aggregation_jax, aggregation_numpy +if TYPE_CHECKING: + import numpy + class AggType(StrEnum): """ @@ -20,21 +23,200 @@ class AggType(StrEnum): ALL = "all" -aggregation_module = aggregation_jax if IS_JAX_INSTALLED else aggregation_numpy - # The signature of the functions must be the same in both modules, except that all JAX # functions have the additional `num_segments` 
argument. -grouped_count = aggregation_module.grouped_count -grouped_sum = aggregation_module.grouped_sum -grouped_mean = aggregation_module.grouped_mean -grouped_max = aggregation_module.grouped_max -grouped_min = aggregation_module.grouped_min -grouped_any = aggregation_module.grouped_any -grouped_all = aggregation_module.grouped_all -count_by_p_id = aggregation_module.count_by_p_id -sum_by_p_id = aggregation_module.sum_by_p_id -mean_by_p_id = aggregation_module.mean_by_p_id -max_by_p_id = aggregation_module.max_by_p_id -min_by_p_id = aggregation_module.min_by_p_id -any_by_p_id = aggregation_module.any_by_p_id -all_by_p_id = aggregation_module.all_by_p_id +def grouped_count( + group_id: numpy.ndarray, num_segments: int, backend: Literal["numpy", "jax"] +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.grouped_count(group_id) + else: + return aggregation_jax.grouped_count(group_id, num_segments) + + +def grouped_sum( + column: numpy.ndarray, + group_id: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.grouped_sum(column, group_id) + else: + return aggregation_jax.grouped_sum(column, group_id, num_segments) + + +def grouped_mean( + column: numpy.ndarray, + group_id: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.grouped_mean(column, group_id) + else: + return aggregation_jax.grouped_mean(column, group_id, num_segments) + + +def grouped_max( + column: numpy.ndarray, + group_id: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.grouped_max(column, group_id) + else: + return aggregation_jax.grouped_max(column, group_id, num_segments) + + +def grouped_min( + column: numpy.ndarray, + group_id: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], 
+) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.grouped_min(column, group_id) + else: + return aggregation_jax.grouped_min(column, group_id, num_segments) + + +def grouped_any( + column: numpy.ndarray, + group_id: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.grouped_any(column, group_id) + else: + return aggregation_jax.grouped_any(column, group_id, num_segments) + + +def grouped_all( + column: numpy.ndarray, + group_id: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.grouped_all(column, group_id) + else: + return aggregation_jax.grouped_all(column, group_id, num_segments) + + +def count_by_p_id( + p_id_to_aggregate_by: numpy.ndarray, + p_id_to_store_by: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.count_by_p_id(p_id_to_aggregate_by, p_id_to_store_by) + else: + return aggregation_jax.count_by_p_id( + p_id_to_aggregate_by, p_id_to_store_by, num_segments + ) + + +def sum_by_p_id( + column: numpy.ndarray, + p_id_to_aggregate_by: numpy.ndarray, + p_id_to_store_by: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.sum_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by + ) + else: + return aggregation_jax.sum_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + ) + + +def mean_by_p_id( + column: numpy.ndarray, + p_id_to_aggregate_by: numpy.ndarray, + p_id_to_store_by: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.mean_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by + ) + else: + return aggregation_jax.mean_by_p_id( + column, 
p_id_to_aggregate_by, p_id_to_store_by, num_segments + ) + + +def max_by_p_id( + column: numpy.ndarray, + p_id_to_aggregate_by: numpy.ndarray, + p_id_to_store_by: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.max_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by + ) + else: + return aggregation_jax.max_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + ) + + +def min_by_p_id( + column: numpy.ndarray, + p_id_to_aggregate_by: numpy.ndarray, + p_id_to_store_by: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.min_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by + ) + else: + return aggregation_jax.min_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + ) + + +def any_by_p_id( + column: numpy.ndarray, + p_id_to_aggregate_by: numpy.ndarray, + p_id_to_store_by: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.any_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by + ) + else: + return aggregation_jax.any_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + ) + + +def all_by_p_id( + column: numpy.ndarray, + p_id_to_aggregate_by: numpy.ndarray, + p_id_to_store_by: numpy.ndarray, + num_segments: int, + backend: Literal["numpy", "jax"], +) -> numpy.ndarray: + if backend == "numpy": + return aggregation_numpy.all_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by + ) + else: + return aggregation_jax.all_by_p_id( + column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + ) diff --git a/src/ttsim/tt_dag_elements/aggregation_jax.py b/src/ttsim/tt_dag_elements/aggregation_jax.py index d2b79d094..597974c2f 100644 --- a/src/ttsim/tt_dag_elements/aggregation_jax.py +++ 
b/src/ttsim/tt_dag_elements/aggregation_jax.py @@ -94,7 +94,9 @@ def grouped_all( def count_by_p_id( - p_id_to_aggregate_by: jnp.ndarray, p_id_to_store_by: jnp.ndarray + p_id_to_aggregate_by: jnp.ndarray, + p_id_to_store_by: jnp.ndarray, + num_segments: int, ) -> jnp.ndarray: raise NotImplementedError @@ -103,6 +105,7 @@ def sum_by_p_id( column: jnp.ndarray, p_id_to_aggregate_by: jnp.ndarray, p_id_to_store_by: jnp.ndarray, + num_segments: int, # noqa: ARG001 ) -> jnp.ndarray: if column.dtype == bool: column = column.astype(int) @@ -131,6 +134,7 @@ def mean_by_p_id( column: jnp.ndarray, p_id_to_aggregate_by: jnp.ndarray, p_id_to_store_by: jnp.ndarray, + num_segments: int, ) -> jnp.ndarray: raise NotImplementedError @@ -139,6 +143,7 @@ def max_by_p_id( column: jnp.ndarray, p_id_to_aggregate_by: jnp.ndarray, p_id_to_store_by: jnp.ndarray, + num_segments: int, ) -> jnp.ndarray: raise NotImplementedError @@ -147,6 +152,7 @@ def min_by_p_id( column: jnp.ndarray, p_id_to_aggregate_by: jnp.ndarray, p_id_to_store_by: jnp.ndarray, + num_segments: int, ) -> jnp.ndarray: raise NotImplementedError @@ -155,6 +161,7 @@ def any_by_p_id( column: jnp.ndarray, p_id_to_aggregate_by: jnp.ndarray, p_id_to_store_by: jnp.ndarray, + num_segments: int, ) -> jnp.ndarray: raise NotImplementedError @@ -163,5 +170,6 @@ def all_by_p_id( column: jnp.ndarray, p_id_to_aggregate_by: jnp.ndarray, p_id_to_store_by: jnp.ndarray, + num_segments: int, ) -> jnp.ndarray: raise NotImplementedError diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index c8b72b7bf..472cb9a20 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -225,11 +225,11 @@ class ColumnFunction(ColumnObject, Generic[FunArgTypes, ReturnType]): foreign_key_type: FKType = FKType.IRRELEVANT def __post_init__(self) -> None: - 
self._fail_if__rounding_has_wrong_type(self.rounding_spec) + self._fail_if_rounding_has_wrong_type(self.rounding_spec) # Expose the signature of the wrapped function for dependency resolution _frozen_safe_update_wrapper(self, self.function) - def _fail_if__rounding_has_wrong_type( + def _fail_if_rounding_has_wrong_type( self, rounding_spec: RoundingSpec | None ) -> None: """Check if rounding_spec has the correct type. @@ -561,20 +561,19 @@ def inner(func: GenericCallable) -> AggByGroupFunction: orig_location = f"{func.__module__}.{func.__name__}" args = set(inspect.signature(func).parameters) group_ids = {p for p in args if p.endswith("_id")} - _fail_if__group_id_is_invalid(group_ids, orig_location) + _fail_if_group_id_is_invalid(group_ids, orig_location) group_id = group_ids.pop() - other_args = args - {group_id} + other_args = args - {group_id, "num_segments", "backend"} if agg_type == AggType.COUNT: - _fail_if__other_arg_is_present(other_args, orig_location) + _fail_if_other_arg_is_present(other_args, orig_location) mapper = {"group_id": group_id} else: - _fail_if__other_arg_is_invalid(other_args, orig_location) + _fail_if_other_arg_is_invalid(other_args, orig_location) mapper = {"group_id": group_id, "column": other_args.pop()} agg_func = rename_arguments( func=agg_registry[agg_type], mapper=mapper, ) - return AggByGroupFunction( leaf_name=leaf_name if leaf_name else func.__name__, function=agg_func, @@ -587,7 +586,7 @@ def inner(func: GenericCallable) -> AggByGroupFunction: return inner -def _fail_if__group_id_is_invalid( +def _fail_if_group_id_is_invalid( group_ids: UnorderedQNames, orig_location: str ) -> None: if len(group_ids) != 1: @@ -598,7 +597,7 @@ def _fail_if__group_id_is_invalid( ) -def _fail_if__other_arg_is_present( +def _fail_if_other_arg_is_present( other_args: UnorderedQNames, orig_location: str ) -> None: if other_args: @@ -608,13 +607,13 @@ def _fail_if__other_arg_is_present( ) -def _fail_if__other_arg_is_invalid( +def 
_fail_if_other_arg_is_invalid( other_args: UnorderedQNames, orig_location: str ) -> None: if len(other_args) != 1: raise ValueError( - "There must be exactly one argument besides identifiers for aggregations. " - "Got: " + "There must be exactly one argument besides identifiers, num_segments, and " + "backend for aggregations. Got: " f"{', '.join(other_args) if other_args else 'nothing'} in {orig_location}." ) @@ -693,30 +692,30 @@ def inner(func: GenericCallable) -> AggByPIDFunction: for p in args if any(e.startswith("p_id_") for e in dt.tree_path_from_qual_name(p)) } - other_args = args - {*other_p_ids, "p_id"} - _fail_if__p_id_is_not_present(args, orig_location) - _fail_if__other_p_id_is_invalid(other_p_ids, orig_location) + other_args = args - {*other_p_ids, "p_id", "num_segments", "backend"} + _fail_if_p_id_is_not_present(args, orig_location) + _fail_if_other_p_id_is_invalid(other_p_ids, orig_location) if agg_type == AggType.COUNT: - _fail_if__other_arg_is_present(other_args, orig_location) + _fail_if_other_arg_is_present(other_args, orig_location) mapper = { "p_id_to_aggregate_by": other_p_ids.pop(), "p_id_to_store_by": "p_id", + "num_segments": "num_segments", + "backend": "backend", } else: - _fail_if__other_arg_is_invalid(other_args, orig_location) + _fail_if_other_arg_is_invalid(other_args, orig_location) mapper = { "column": other_args.pop(), "p_id_to_aggregate_by": other_p_ids.pop(), "p_id_to_store_by": "p_id", + "num_segments": "num_segments", + "backend": "backend", } agg_func = rename_arguments( func=agg_registry[agg_type], mapper=mapper, ) - - functools.update_wrapper(agg_func, func) - agg_func.__signature__ = inspect.signature(func) - return AggByPIDFunction( leaf_name=leaf_name if leaf_name else func.__name__, function=agg_func, @@ -729,7 +728,7 @@ def inner(func: GenericCallable) -> AggByPIDFunction: return inner -def _fail_if__p_id_is_not_present(args: UnorderedQNames, orig_location: str) -> None: +def _fail_if_p_id_is_not_present(args: 
UnorderedQNames, orig_location: str) -> None: if "p_id" not in args: raise ValueError( "The function must have the argument named 'p_id' for aggregation by p_id. " @@ -737,7 +736,7 @@ def _fail_if__p_id_is_not_present(args: UnorderedQNames, orig_location: str) -> ) -def _fail_if__other_p_id_is_invalid( +def _fail_if_other_p_id_is_invalid( other_p_ids: UnorderedQNames, orig_location: str ) -> None: if len(other_p_ids) != 1: diff --git a/tests/ttsim/mettsim/group_by_ids.py b/tests/ttsim/mettsim/group_by_ids.py index 5161e632b..1b30ee58e 100644 --- a/tests/ttsim/mettsim/group_by_ids.py +++ b/tests/ttsim/mettsim/group_by_ids.py @@ -33,39 +33,63 @@ def fam_id( """ Compute the family ID for each person. """ - n = xnp.max(p_id) + n = xnp.max(p_id) + 1 + # Get the array index for all p_ids of parents p_id_parent_1_loc = p_id_parent_1 p_id_parent_2_loc = p_id_parent_2 for i in range(p_id.shape[0]): - p_id_parent_1_loc = xnp.where( - p_id_parent_1_loc == p_id[i], i, p_id_parent_1_loc - ) - p_id_parent_2_loc = xnp.where( - p_id_parent_2_loc == p_id[i], i, p_id_parent_2_loc - ) + p_id_parent_1_loc = xnp.where(p_id_parent_1 == p_id[i], i, p_id_parent_1_loc) + p_id_parent_2_loc = xnp.where(p_id_parent_2 == p_id[i], i, p_id_parent_2_loc) + + children = xnp.isin(p_id, p_id_parent_1) | xnp.isin(p_id, p_id_parent_2) - children = xnp.isin(p_id, p_id_parent_1) + xnp.isin(p_id, p_id_parent_2) + # Assign the same fam_id to everybody who has a spouse, + # otherwise create a new one from p_id out = xnp.where( p_id_spouse < 0, p_id + p_id * n, xnp.maximum(p_id, p_id_spouse) + xnp.minimum(p_id, p_id_spouse) * n, ) - out = xnp.where( - (out == p_id + p_id * n) - * (p_id_parent_1_loc >= 0) - * (age < 25) - * (1 - children), - out[p_id_parent_1_loc], - out, + + out = _assign_parents_fam_id( + fam_id=out, + p_id=p_id, + p_id_parent_loc=p_id_parent_1_loc, + age=age, + children=children, + n=n, + xnp=xnp, ) - out = xnp.where( - (out == p_id + p_id * n) - * (p_id_parent_2_loc >= 0) - * (age < 25) 
- * (1 - children), - out[p_id_parent_2_loc], - out, + out = _assign_parents_fam_id( + fam_id=out, + p_id=p_id, + p_id_parent_loc=p_id_parent_2_loc, + age=age, + children=children, + n=n, + xnp=xnp, ) return out + + +def _assign_parents_fam_id( + fam_id: numpy.ndarray, + p_id: numpy.ndarray, + p_id_parent_loc: numpy.ndarray, + age: numpy.ndarray, + children: numpy.ndarray, + n: numpy.ndarray, + xnp: ModuleType, +) -> numpy.ndarray: + """Return the fam_id of the child's parents.""" + + return xnp.where( + (fam_id == p_id + p_id * n) + * (p_id_parent_loc >= 0) + * (age < 25) + * (1 - children), + fam_id[p_id_parent_loc], + fam_id, + ) diff --git a/tests/ttsim/test_automatically_added_functions.py b/tests/ttsim/test_automatically_added_functions.py index d12604c4f..2233810e9 100644 --- a/tests/ttsim/test_automatically_added_functions.py +++ b/tests/ttsim/test_automatically_added_functions.py @@ -391,7 +391,6 @@ def test_derived_aggregation_functions_are_in_correct_namespace( targets, names__processed_data_columns, expected, - backend, ): """Test that the derived aggregation functions are in the correct namespace. 
@@ -403,6 +402,5 @@ def test_derived_aggregation_functions_are_in_correct_namespace( names__processed_data_columns=names__processed_data_columns, targets=targets, grouping_levels=("kin",), - backend=backend, ) assert expected in result diff --git a/tests/ttsim/test_specialized_environment.py b/tests/ttsim/test_specialized_environment.py index 20e8b26a2..cc58b22a3 100644 --- a/tests/ttsim/test_specialized_environment.py +++ b/tests/ttsim/test_specialized_environment.py @@ -2,7 +2,7 @@ import datetime from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import dags.tree as dt import numpy @@ -180,8 +180,8 @@ def some_policy_function_taking_int_param(some_int_param: int) -> float: def minimal_input_data(): n_individuals = 5 out = { - "p_id": pd.Series(numpy.arange(n_individuals), name="p_id"), - "fam_id": pd.Series(numpy.arange(n_individuals), name="fam_id"), + "p_id": numpy.arange(n_individuals), + "fam_id": numpy.arange(n_individuals), } return out @@ -190,9 +190,9 @@ def minimal_input_data(): def minimal_input_data_shared_fam(): n_individuals = 3 out = { - "p_id": pd.Series(numpy.arange(n_individuals), name="p_id"), - "fam_id": pd.Series([0, 0, 1], name="fam_id"), - "p_id_someone_else": pd.Series([1, 0, -1], name="p_id_someone_else"), + "p_id": numpy.arange(n_individuals), + "fam_id": numpy.array([0, 0, 1]), + "p_id_someone_else": numpy.array([1, 0, -1]), } return out @@ -395,14 +395,17 @@ def test_create_agg_by_group_functions( policy_environment, targets__tree, input_data__tree, + backend, ): + policy_environment["backend"] = backend + policy_environment["num_segments"] = len(input_data__tree["p_id"]) main( inputs={ "policy_environment": policy_environment, "input_data__tree": input_data__tree, "targets__tree": targets__tree, "rounding": False, - "backend": "numpy", + "backend": backend, }, targets=["results__tree"], )["results__tree"] @@ -585,6 +588,8 @@ def betrag_m_double_fam(betrag_m_double, fam_id) -> 
float: "betrag_m_double": betrag_m_double, "betrag_m_double_fam": betrag_m_double_fam, }, + "backend": backend, + "num_segments": len(data["p_id"]), } actual = main( @@ -635,6 +640,8 @@ def max_betrag_double_m_fam(betrag_double_m, fam_id) -> float: "betrag_double_m": betrag_double_m, "max_betrag_double_m_fam": max_betrag_double_m_fam, }, + "backend": backend, + "num_segments": len(data["p_id"]), } actual = main( @@ -658,14 +665,22 @@ def max_betrag_double_m_fam(betrag_double_m, fam_id) -> float: @agg_by_p_id_function(agg_type=AggType.SUM) def sum_source_by_p_id_someone_else( - source: int, p_id: int, p_id_someone_else: int + source: int, + p_id: int, + p_id_someone_else: int, + num_segments: int, + backend: Literal["numpy", "jax"], ) -> int: pass @agg_by_p_id_function(agg_type=AggType.SUM) def sum_source_m_by_p_id_someone_else( - source_m: int, p_id: int, p_id_someone_else: int + source_m: int, + p_id: int, + p_id_someone_else: int, + num_segments: int, + backend: Literal["numpy", "jax"], ) -> int: pass @@ -714,6 +729,8 @@ def source() -> int: "module": {leaf_name: source}, "p_id": p_id, "p_id_someone_else": p_id_someone_else, + "backend": backend, + "num_segments": len(minimal_input_data_shared_fam["p_id"]), }, ) diff --git a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py index 793edaa51..76edbd0a0 100644 --- a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py +++ b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py @@ -247,25 +247,21 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): keys_of_test_cases=["group_id", "expected_res_count"], ) def test_grouped_count(group_id, expected_res_count, backend): - if backend == "jax": - result = grouped_count( - group_id=group_id, - num_segments=len(group_id), - ) - else: - result = grouped_count(group_id=group_id) + result = grouped_count( + group_id=group_id, + num_segments=len(group_id), + backend=backend, + ) 
numpy.testing.assert_array_almost_equal(result, expected_res_count) def _run_agg_by_group(agg_func, column_to_aggregate, group_id, backend): - if backend == "jax": - return agg_func( - column=column_to_aggregate, - group_id=group_id, - num_segments=len(group_id), - ) - else: - return agg_func(column=column_to_aggregate, group_id=group_id) + return agg_func( + column=column_to_aggregate, + group_id=group_id, + num_segments=len(group_id), + backend=backend, + ) @parameterize_based_on_dict( @@ -424,12 +420,19 @@ def test_grouped_all(column_to_aggregate, group_id, expected_res_all, backend): ], ) @pytest.mark.skipif_jax -def test_grouped_sum_raises(column_to_aggregate, group_id, error_sum, exception_match): +def test_grouped_sum_raises( + column_to_aggregate, group_id, error_sum, exception_match, backend +): with pytest.raises( error_sum, match=exception_match, ): - grouped_sum(column_to_aggregate, group_id) + grouped_sum( + column=column_to_aggregate, + group_id=group_id, + num_segments=len(group_id), + backend=backend, + ) @parameterize_based_on_dict( @@ -443,13 +446,18 @@ def test_grouped_sum_raises(column_to_aggregate, group_id, error_sum, exception_ ) @pytest.mark.skipif_jax def test_grouped_mean_raises( - column_to_aggregate, group_id, error_mean, exception_match + column_to_aggregate, group_id, error_mean, exception_match, backend ): with pytest.raises( error_mean, match=exception_match, ): - grouped_mean(column_to_aggregate, group_id) + grouped_mean( + column=column_to_aggregate, + group_id=group_id, + num_segments=len(group_id), + backend=backend, + ) @parameterize_based_on_dict( @@ -462,12 +470,19 @@ def test_grouped_mean_raises( ], ) @pytest.mark.skipif_jax -def test_grouped_max_raises(column_to_aggregate, group_id, error_max, exception_match): +def test_grouped_max_raises( + column_to_aggregate, group_id, error_max, exception_match, backend +): with pytest.raises( error_max, match=exception_match, ): - grouped_max(column_to_aggregate, group_id) + 
grouped_max( + column=column_to_aggregate, + group_id=group_id, + num_segments=len(group_id), + backend=backend, + ) @parameterize_based_on_dict( @@ -480,12 +495,19 @@ def test_grouped_max_raises(column_to_aggregate, group_id, error_max, exception_ ], ) @pytest.mark.skipif_jax -def test_grouped_min_raises(column_to_aggregate, group_id, error_min, exception_match): +def test_grouped_min_raises( + column_to_aggregate, group_id, error_min, exception_match, backend +): with pytest.raises( error_min, match=exception_match, ): - grouped_min(column_to_aggregate, group_id) + grouped_min( + column=column_to_aggregate, + group_id=group_id, + num_segments=len(group_id), + backend=backend, + ) @parameterize_based_on_dict( @@ -498,12 +520,19 @@ def test_grouped_min_raises(column_to_aggregate, group_id, error_min, exception_ ], ) @pytest.mark.skipif_jax -def test_grouped_any_raises(column_to_aggregate, group_id, error_any, exception_match): +def test_grouped_any_raises( + column_to_aggregate, group_id, error_any, exception_match, backend +): with pytest.raises( error_any, match=exception_match, ): - grouped_any(column_to_aggregate, group_id) + grouped_any( + column=column_to_aggregate, + group_id=group_id, + num_segments=len(group_id), + backend=backend, + ) @parameterize_based_on_dict( @@ -516,12 +545,19 @@ def test_grouped_any_raises(column_to_aggregate, group_id, error_any, exception_ ], ) @pytest.mark.skipif_jax -def test_grouped_all_raises(column_to_aggregate, group_id, error_all, exception_match): +def test_grouped_all_raises( + column_to_aggregate, group_id, error_all, exception_match, backend +): with pytest.raises( error_all, match=exception_match, ): - grouped_all(column_to_aggregate, group_id) + grouped_all( + column=column_to_aggregate, + group_id=group_id, + num_segments=len(group_id), + backend=backend, + ) @parameterize_based_on_dict( @@ -540,11 +576,14 @@ def test_sum_by_p_id( p_id_to_store_by, expected_res, expected_type, + backend, ): result = sum_by_p_id( 
column=column_to_aggregate, p_id_to_aggregate_by=p_id_to_aggregate_by, p_id_to_store_by=p_id_to_store_by, + num_segments=len(p_id_to_aggregate_by), + backend=backend, ) numpy.testing.assert_array_almost_equal(result, expected_res) assert numpy.issubdtype(result.dtype.type, expected_type), ( @@ -564,7 +603,12 @@ def test_sum_by_p_id( ) @pytest.mark.skipif_jax def test_sum_by_p_id_raises( - column_to_aggregate, group_id, p_id_to_store_by, error_sum_by_p_id, exception_match + column_to_aggregate, + group_id, + p_id_to_store_by, + error_sum_by_p_id, + exception_match, + backend, ): with pytest.raises( error_sum_by_p_id, @@ -574,4 +618,6 @@ def test_sum_by_p_id_raises( column=column_to_aggregate, p_id_to_aggregate_by=group_id, p_id_to_store_by=p_id_to_store_by, + num_segments=len(group_id), + backend=backend, ) diff --git a/tests/ttsim/tt_dag_elements/test_ttsim_objects.py b/tests/ttsim/tt_dag_elements/test_ttsim_objects.py index b51432eaa..c39a370c9 100644 --- a/tests/ttsim/tt_dag_elements/test_ttsim_objects.py +++ b/tests/ttsim/tt_dag_elements/test_ttsim_objects.py @@ -178,7 +178,7 @@ def test_agg_by_group_function_type(function, expected_group_id, expected_other_ def test_agg_by_group_count_other_arg_present(): - match = "There must be no argument besides identifiers for counting." + match = "There must be no argument besides identifiers" with pytest.raises(ValueError, match=match): @agg_by_group_function(agg_type=AggType.COUNT) @@ -187,7 +187,7 @@ def aggregate_by_group_count_other_arg_present(group_id, wrong_arg): def test_agg_by_group_sum_wrong_amount_of_args(): - match = "There must be exactly one argument besides identifiers for aggregations." 
+ match = "There must be exactly one argument besides identifiers" with pytest.raises(ValueError, match=match): @agg_by_group_function(agg_type=AggType.SUM) @@ -252,7 +252,7 @@ def test_agg_by_p_id_function_type(function, expected_foreign_p_id, expected_oth def test_agg_by_p_id_count_other_arg_present(): - match = "There must be no argument besides identifiers for counting." + match = "There must be no argument besides identifiers" with pytest.raises(ValueError, match=match): @agg_by_p_id_function(agg_type=AggType.COUNT) @@ -261,7 +261,7 @@ def aggregate_by_p_id_count_other_arg_present(p_id, p_id_specifier, wrong_arg): def test_agg_by_p_id_sum_wrong_amount_of_args(): - match = "There must be exactly one argument besides identifiers for aggregations." + match = "There must be exactly one argument besides identifiers" with pytest.raises(ValueError, match=match): @agg_by_p_id_function(agg_type=AggType.SUM) @@ -286,9 +286,11 @@ def aggregate_by_p_id_multiple_other_p_ids_present( pass -def test_agg_by_p_id_sum_with_all_missing_p_ids(): +def test_agg_by_p_id_sum_with_all_missing_p_ids(backend): aggregate_by_p_id_sum( p_id=numpy.array([180]), p_id_specifier=numpy.array([-1]), source=numpy.array([False]), + num_segments=1, + backend=backend, ) From c233b4e5fe919210ca9ae3beeb9ba2577a5b6ff8 Mon Sep 17 00:00:00 2001 From: mj023 Date: Thu, 12 Jun 2025 20:08:19 +0200 Subject: [PATCH 14/25] Add backend to cached policy env --- src/ttsim/testing_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ttsim/testing_utils.py b/src/ttsim/testing_utils.py index 5b49651f3..41225c761 100644 --- a/src/ttsim/testing_utils.py +++ b/src/ttsim/testing_utils.py @@ -32,12 +32,13 @@ @lru_cache(maxsize=100) def cached_policy_environment( - date: datetime.date, root: Path + date: datetime.date, root: Path, backend: Literal["numpy", "jax"] ) -> NestedPolicyEnvironment: return main( inputs={ "date": date, "orig_policy_objects__root": root, + "backend": backend, }, 
targets=["policy_environment"], )["policy_environment"] @@ -79,7 +80,7 @@ def name(self) -> str: def execute_test( test: PolicyTest, root: Path, backend: Literal["numpy", "jax"] ) -> None: - environment = cached_policy_environment(date=test.date, root=root) + environment = cached_policy_environment(date=test.date, root=root, backend=backend) if test.target_structure: result_df = main( From d0a147b6f45bfe99b485e55cfd2a244b08e3e7d4 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 13 Jun 2025 05:53:45 +0200 Subject: [PATCH 15/25] Remove lists from possible outputs; make jax tests pass. --- src/ttsim/interface_dag_elements/fail_if.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index d4804ec1e..845485883 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -439,11 +439,11 @@ def non_convertible_objects_in_results_tree( paths_with_incorrect_types = [] paths_with_incorrect_length = [] for path, data in dt.flatten_to_tree_paths(results__tree).items(): - if isinstance(data, (xnp.ndarray, list)): - if not all(isinstance(item, _numeric_types) for item in data): - paths_with_incorrect_types.append(str(path)) - if len(data) != expected_object_length: - paths_with_incorrect_length.append(str(path)) + if isinstance(data, xnp.ndarray) and len(data) not in { + 1, + expected_object_length, + }: + paths_with_incorrect_length.append(str(path)) elif isinstance(data, _numeric_types): continue else: From 76dd10e23d7a084b9ea2fe0dd478013ad8100af1 Mon Sep 17 00:00:00 2001 From: Marvin Immesberger Date: Fri, 13 Jun 2025 10:00:31 +0200 Subject: [PATCH 16/25] Add test case for TypeError because of list parameter. 
--- tests/ttsim/test_failures.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/ttsim/test_failures.py b/tests/ttsim/test_failures.py index 392663a47..bfd49c3a4 100644 --- a/tests/ttsim/test_failures.py +++ b/tests/ttsim/test_failures.py @@ -35,6 +35,7 @@ PiecewisePolynomialParam, PiecewisePolynomialParamValue, group_creation_function, + param_function, policy_function, ) @@ -142,10 +143,15 @@ def some_x(x): @policy_function() -def some_policy_func_returning_array_of_length_2(xnp: ModuleType) -> numpy.ndarray: +def some_policy_func_returning_array_of_length_2(xnp: ModuleType) -> numpy.ndarray: # type: ignore[type-arg] return xnp.array([1, 2]) +@param_function() +def some_param_func_returning_list_of_length_2() -> list[int]: + return [1, 2] + + @pytest.mark.parametrize( ("tree", "leaf_checker", "err_substr"), [ @@ -739,6 +745,13 @@ def test_fail_if_input_df_mapper_has_incorrect_format( {"some_consecutive_int_1d_lookup_table_param": "res1"}, "The data contains objects that cannot be cast to a pandas.DataFrame", ), + ( + { + "some_param_func_returning_list_of_length_2": some_param_func_returning_list_of_length_2, + }, + {"some_param_func_returning_list_of_length_2": "res1"}, + "The data contains objects that cannot be cast to a pandas.DataFrame", + ), ], ) def test_fail_if_non_convertible_objects_in_results_tree_because_of_object_type( From cebf3f192111302689287b610cd67af3b7053f05 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 13 Jun 2025 10:43:54 +0200 Subject: [PATCH 17/25] Move to jaxtyping, except for aggregation. 
--- .pre-commit-config.yaml | 5 +- conftest.py | 18 ++- pixi.lock | 129 ++++++++++++++---- pyproject.toml | 10 +- .../kindergeld\303\274bertrag.py" | 16 +-- src/_gettsim/ids.py | 66 ++++----- src/_gettsim/kindergeld/kindergeld.py | 11 +- .../unterhaltsvorschuss.py | 19 ++- src/_gettsim/wohngeld/miete.py | 10 +- .../test_warn_if_repeated_execution.py | 12 -- src/gettsim/__init__.py | 36 +---- src/ttsim/interface_dag.py | 1 - .../interface_dag_elements/data_converters.py | 3 +- src/ttsim/interface_dag_elements/fail_if.py | 40 +----- .../interface_node_objects.py | 6 +- src/ttsim/interface_dag_elements/typing.py | 14 +- src/ttsim/stale_code_storage.py | 68 +++++++++ .../column_objects_param_function.py | 85 +++--------- src/ttsim/tt_dag_elements/param_objects.py | 14 +- .../tt_dag_elements/piecewise_polynomial.py | 80 ++++------- src/ttsim/tt_dag_elements/rounding.py | 8 +- src/ttsim/tt_dag_elements/shared.py | 22 +-- src/ttsim/tt_dag_elements/vectorization.py | 3 - tests/ttsim/mettsim/group_by_ids.py | 35 +++-- tests/ttsim/test_failures.py | 7 +- tests/ttsim/test_specialized_environment.py | 22 +-- tests/ttsim/test_warnings.py | 8 +- tests/ttsim/tt_dag_elements/test_shared.py | 22 +-- .../tt_dag_elements/test_vectorization.py | 38 +++--- 29 files changed, 422 insertions(+), 386 deletions(-) delete mode 100644 src/_gettsim_tests/test_warn_if_repeated_execution.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a1afbc511..e600c0ff7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -86,16 +86,19 @@ repos: - '88' files: (docs/.|CHANGES.md|CODE_OF_CONDUCT.md) - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.16.0 hooks: - id: mypy args: - --ignore-missing-imports - --config=pyproject.toml + - --allow-redefinition-new + - --local-partial-types additional_dependencies: - types-PyYAML - types-pytz - numpy >= 2 + - jaxtyping # - dags >= 0.3 - optree >= 0.15 - repo: 
https://github.com/python-jsonschema/check-jsonschema diff --git a/conftest.py b/conftest.py index 7fc4aa00a..c5db1f2f9 100644 --- a/conftest.py +++ b/conftest.py @@ -1,8 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest from ttsim.interface_dag_elements.backend import dnp as ttsim_dnp from ttsim.interface_dag_elements.backend import xnp as ttsim_xnp +if TYPE_CHECKING: + from types import ModuleType + from typing import Literal + # content of conftest.py def pytest_addoption(parser): @@ -14,24 +22,20 @@ def pytest_addoption(parser): ) -def pytest_configure(config): - config.addinivalue_line("markers", "skipif_jax: skip test if backend is jax") - - @pytest.fixture -def backend(request): +def backend(request) -> Literal["numpy", "jax"]: backend = request.config.getoption("--backend") return backend @pytest.fixture -def xnp(request): +def xnp(request) -> ModuleType: backend = request.config.getoption("--backend") return ttsim_xnp(backend) @pytest.fixture -def dnp(request): +def dnp(request) -> ModuleType: backend = request.config.getoption("--backend") return ttsim_dnp(backend) diff --git a/pixi.lock b/pixi.lock index 473dd6565..e316a17b6 100644 --- a/pixi.lock +++ b/pixi.lock @@ -266,11 +266,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -502,11 +504,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.7-h8210216_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e0/f7/0ccaa596ec341963adbb4f839774c36d5659e75a0812d946732b927d480e/kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: 
https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -738,11 +742,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-h6491c7d_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: 
https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda @@ -981,12 +987,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-hbeecb71_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/90/51523adbedc808e03271c7448fd71da1660cc02603eaaf10b9ab4f102146/kaleido-0.1.0.post1-py2.py3-none-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/25/33/cd41ab38ef313874eb2000f1037ccce001dd680873713cc2d1a2ae5d0041/optree-0.15.0-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ mypy: channels: @@ -1254,15 +1262,18 @@ environments: - conda: 
https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/b3/d0/92ae4cde706923a2d3f2d6c39629134063ff64b9dedca9c1388363da072d/mypy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/6f/5f/b392f7b4f659f5b619ce5994c5c43caab3d80df2296ae54fa888b3d17f5a/mypy-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -1494,15 +1505,18 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.7-h8210216_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e0/f7/0ccaa596ec341963adbb4f839774c36d5659e75a0812d946732b927d480e/kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/98/3a/03c74331c5eb8bd025734e04c9840532226775c47a2c39b56a0c8d4f128d/mypy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/70/cf/158e5055e60ca2be23aec54a3010f89dcffd788732634b344fc9cb1e85a0/mypy-1.16.0-cp312-cp312-macosx_10_13_x86_64.whl - pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl + - pypi: https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -1734,15 +1748,18 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-h6491c7d_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl - - pypi: https://files.pythonhosted.org/packages/f0/1a/41759b18f2cfd568848a37c89030aeb03534411eef981df621d8fad08a1d/mypy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/94/34/cfff7a56be1609f5d10ef386342ce3494158e4d506516890142007e6472c/mypy-1.16.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda @@ -1981,16 +1998,19 @@ 
environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-hbeecb71_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/90/51523adbedc808e03271c7448fd71da1660cc02603eaaf10b9ab4f102146/kaleido-0.1.0.post1-py2.py3-none-win_amd64.whl - - pypi: https://files.pythonhosted.org/packages/13/50/da5203fcf6c53044a0b699939f31075c45ae8a4cadf538a9069b165c1050/mypy-1.15.0-cp312-cp312-win_amd64.whl + - pypi: https://files.pythonhosted.org/packages/6d/38/52f4b808b3fef7f0ef840ee8ff6ce5b5d77381e65425758d515cdd4f5bb5/mypy-1.16.0-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/25/33/cd41ab38ef313874eb2000f1037ccce001dd680873713cc2d1a2ae5d0041/optree-0.15.0-cp312-cp312-win_amd64.whl + - pypi: https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ba/e205cd11c1c7183b23c97e4bcd1de7bc0633e2e867601c32ecfc6ad42675/types_pytz-2025.2.0.20250516-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ py311: channels: @@ -2258,11 +2278,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/3d/52a75740d6c449073d4bb54da382f6368553f285fb5a680b27dd198dd839/optree-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -2494,11 +2516,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.7-h8210216_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e0/f7/0ccaa596ec341963adbb4f839774c36d5659e75a0812d946732b927d480e/kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e8/89/1267444a074b6e4402b5399b73b930a7b86cde054a41cecb9694be726a92/optree-0.15.0-cp311-cp311-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-arm64: - conda: 
https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -2730,11 +2754,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-h6491c7d_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/a5/f8d6c278ce72b2ed8c1ebac968c3c652832bd2d9e65ec81fe6a21082c313/optree-0.15.0-cp311-cp311-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda @@ -2973,12 +2999,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-hbeecb71_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: 
https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/90/51523adbedc808e03271c7448fd71da1660cc02603eaaf10b9ab4f102146/kaleido-0.1.0.post1-py2.py3-none-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/86/9743be6eac8cc5ef69fa2b6585a36254aca0815714f57a0763bcfa774906/optree-0.15.0-cp311-cp311-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ py312: channels: @@ -3246,11 +3274,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -3482,11 +3512,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.7-h8210216_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e0/f7/0ccaa596ec341963adbb4f839774c36d5659e75a0812d946732b927d480e/kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -3718,11 +3750,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-h6491c7d_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda @@ -3961,12 +3995,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-hbeecb71_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/90/51523adbedc808e03271c7448fd71da1660cc02603eaaf10b9ab4f102146/kaleido-0.1.0.post1-py2.py3-none-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/25/33/cd41ab38ef313874eb2000f1037ccce001dd680873713cc2d1a2ae5d0041/optree-0.15.0-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: 
https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ py312-jax: channels: @@ -4245,12 +4281,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - - pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 + - pypi: git+https://github.com/google/jax-datetime.git#e79cec944828e71f9faf790a4725545edff7e3b7 + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/86/f0/a00cf9f2cf1e8d54f71116ad5eea73fc5b1177644283704535bb8e43090e/optree-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-64: - conda: 
https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -4493,12 +4531,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.7-h8210216_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - - pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 + - pypi: git+https://github.com/google/jax-datetime.git#e79cec944828e71f9faf790a4725545edff7e3b7 + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e0/f7/0ccaa596ec341963adbb4f839774c36d5659e75a0812d946732b927d480e/kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/a5/2589d9790a6dd7c4b1dd22bd228238c575ec5384ce5bc16a30e7f43cdd99/optree-0.15.0-cp312-cp312-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/adwaita-icon-theme-48.0-unix_0.conda @@ -4741,12 +4781,14 @@ environments: - conda: 
https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-h6491c7d_2.conda - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - - pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 + - pypi: git+https://github.com/google/jax-datetime.git#e79cec944828e71f9faf790a4725545edff7e3b7 + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a5/2c/5363abf03c8d47ad7bc3b45a735cbdf24a10f99f82e776ef2949ffce77c6/optree-0.15.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/86/2f/1f0144b14553ad32a8d0afa38b832c4b117694484c32aef2d939dc96f20a/pdbp-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ win-64: - conda: https://conda.anaconda.org/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda @@ -4986,8 +5028,9 @@ environments: - pypi: git+https://github.com/OpenSourceEconomics/dags.git?branch=allow-passing-dag-to-concatenate_functions#ccbe36e4946e8963cd760114d6a5216b9e237989 - pypi: 
https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/dc/89/99805cd801919b4535e023bfe2de651f5a3ec4f5846a867cbc08006db455/jax-0.6.1-py3-none-any.whl - - pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 + - pypi: git+https://github.com/google/jax-datetime.git#e79cec944828e71f9faf790a4725545edff7e3b7 - pypi: https://files.pythonhosted.org/packages/1b/12/2bc629d530ee1b333edc81a1d68d262bad2f813ce60fdd46e98d48cc8a20/jaxlib-0.6.1-cp312-cp312-win_amd64.whl + - pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/90/51523adbedc808e03271c7448fd71da1660cc02603eaaf10b9ab4f102146/kaleido-0.1.0.post1-py2.py3-none-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/38/bc/c4260e4a6c6bf684d0313308de1c860467275221d5e7daf69b3fcddfdd0b/ml_dtypes-0.5.1-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl @@ -4997,6 +5040,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/65/44/bb509c3d2c0b5a87e7a5af1d5917a402a32ff026f777a6d7cb6990746cbb/tabcompleter-1.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl - pypi: ./ packages: - conda: 
https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 @@ -6578,8 +6622,8 @@ packages: timestamp: 1694400856979 - pypi: ./ name: gettsim - version: 0.7.1.dev437+gede66cd1.d20250611 - sha256: 9f7117cd7089efbebbc12d4aaf941403d1e5ceb8537993948ac2eded7efa2988 + version: 0.7.1.dev452+gd0a147b6.d20250613 + sha256: b92f9d48c586dde9df09c5f901d6c28be6eb01fb808b2a79b2b3065bc3f98f99 requires_dist: - ipywidgets - networkx @@ -7447,7 +7491,7 @@ packages: - pkg:pypi/jax?source=hash-mapping size: 1580534 timestamp: 1747653718316 -- pypi: git+https://github.com/google/jax-datetime.git#a2e9e7dfc68629915c0ff8b4d91222c452984a90 +- pypi: git+https://github.com/google/jax-datetime.git#e79cec944828e71f9faf790a4725545edff7e3b7 name: jax-datetime version: 0.1.0 requires_dist: @@ -7540,6 +7584,20 @@ packages: - pkg:pypi/jaxlib?source=hash-mapping size: 56004263 timestamp: 1747478692111 +- pypi: https://files.pythonhosted.org/packages/c9/b9/281e10e2d967ea5e481683eaec99f55ac5a61085ee60551c36942ef32bef/jaxtyping-0.3.2-py3-none-any.whl + name: jaxtyping + version: 0.3.2 + sha256: 6a020fd276226ddb5ac4f5725323843dd65e3c7e85c64fd62431e5f738c74e04 + requires_dist: + - wadler-lindig>=0.1.3 + - hippogriffe==0.2.0 ; extra == 'docs' + - mkdocs-include-exclude-files==0.1.0 ; extra == 'docs' + - mkdocs-ipynb==0.1.0 ; extra == 'docs' + - mkdocs-material==9.6.7 ; extra == 'docs' + - mkdocs==1.6.1 ; extra == 'docs' + - mkdocstrings[python]==0.28.3 ; extra == 'docs' + - pymdown-extensions==10.14.3 ; extra == 'docs' + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.2-pyhd8ed1ab_1.conda sha256: 92c4d217e2dc68983f724aa983cca5464dcb929c566627b26a2511159667dba8 md5: a4f4c5dc9b80bc50e0d3dc4e6e8f1bd9 @@ -10010,13 +10068,14 @@ packages: - pkg:pypi/ml-dtypes?source=hash-mapping size: 200130 timestamp: 1736539205286 -- pypi: 
https://files.pythonhosted.org/packages/13/50/da5203fcf6c53044a0b699939f31075c45ae8a4cadf538a9069b165c1050/mypy-1.15.0-cp312-cp312-win_amd64.whl +- pypi: https://files.pythonhosted.org/packages/6d/38/52f4b808b3fef7f0ef840ee8ff6ce5b5d77381e65425758d515cdd4f5bb5/mypy-1.16.0-cp312-cp312-win_amd64.whl name: mypy - version: 1.15.0 - sha256: 171a9ca9a40cd1843abeca0e405bc1940cd9b305eaeea2dda769ba096932bb22 + version: 1.16.0 + sha256: bd4e1ebe126152a7bbaa4daedd781c90c8f9643c79b9748caa270ad542f12bec requires_dist: - typing-extensions>=4.6.0 - mypy-extensions>=1.0.0 + - pathspec>=0.9.0 - tomli>=1.1.0 ; python_full_version < '3.11' - psutil>=4.0 ; extra == 'dmypy' - setuptools>=50 ; extra == 'mypyc' @@ -10024,13 +10083,14 @@ packages: - pip ; extra == 'install-types' - orjson ; extra == 'faster-cache' requires_python: '>=3.9' -- pypi: https://files.pythonhosted.org/packages/98/3a/03c74331c5eb8bd025734e04c9840532226775c47a2c39b56a0c8d4f128d/mypy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl +- pypi: https://files.pythonhosted.org/packages/6f/5f/b392f7b4f659f5b619ce5994c5c43caab3d80df2296ae54fa888b3d17f5a/mypy-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl name: mypy - version: 1.15.0 - sha256: aea39e0583d05124836ea645f412e88a5c7d0fd77a6d694b60d9b6b2d9f184fd + version: 1.16.0 + sha256: b4968f14f44c62e2ec4a038c8797a87315be8df7740dc3ee8d3bfe1c6bf5dba8 requires_dist: - typing-extensions>=4.6.0 - mypy-extensions>=1.0.0 + - pathspec>=0.9.0 - tomli>=1.1.0 ; python_full_version < '3.11' - psutil>=4.0 ; extra == 'dmypy' - setuptools>=50 ; extra == 'mypyc' @@ -10038,13 +10098,14 @@ packages: - pip ; extra == 'install-types' - orjson ; extra == 'faster-cache' requires_python: '>=3.9' -- pypi: https://files.pythonhosted.org/packages/b3/d0/92ae4cde706923a2d3f2d6c39629134063ff64b9dedca9c1388363da072d/mypy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl +- pypi: 
https://files.pythonhosted.org/packages/70/cf/158e5055e60ca2be23aec54a3010f89dcffd788732634b344fc9cb1e85a0/mypy-1.16.0-cp312-cp312-macosx_10_13_x86_64.whl name: mypy - version: 1.15.0 - sha256: 8023ff13985661b50a5928fc7a5ca15f3d1affb41e5f0a9952cb68ef090b31ee + version: 1.16.0 + sha256: c5436d11e89a3ad16ce8afe752f0f373ae9620841c50883dc96f8b8805620b13 requires_dist: - typing-extensions>=4.6.0 - mypy-extensions>=1.0.0 + - pathspec>=0.9.0 - tomli>=1.1.0 ; python_full_version < '3.11' - psutil>=4.0 ; extra == 'dmypy' - setuptools>=50 ; extra == 'mypyc' @@ -10052,13 +10113,14 @@ packages: - pip ; extra == 'install-types' - orjson ; extra == 'faster-cache' requires_python: '>=3.9' -- pypi: https://files.pythonhosted.org/packages/f0/1a/41759b18f2cfd568848a37c89030aeb03534411eef981df621d8fad08a1d/mypy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl +- pypi: https://files.pythonhosted.org/packages/94/34/cfff7a56be1609f5d10ef386342ce3494158e4d506516890142007e6472c/mypy-1.16.0-cp312-cp312-macosx_11_0_arm64.whl name: mypy - version: 1.15.0 - sha256: 2f2147ab812b75e5b5499b01ade1f4a81489a147c01585cda36019102538615f + version: 1.16.0 + sha256: f2622af30bf01d8fc36466231bdd203d120d7a599a6d88fb22bdcb9dbff84090 requires_dist: - typing-extensions>=4.6.0 - mypy-extensions>=1.0.0 + - pathspec>=0.9.0 - tomli>=1.1.0 ; python_full_version < '3.11' - psutil>=4.0 ; extra == 'dmypy' - setuptools>=50 ; extra == 'mypyc' @@ -11448,6 +11510,11 @@ packages: - pkg:pypi/parso?source=hash-mapping size: 75295 timestamp: 1733271352153 +- pypi: https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl + name: pathspec + version: 0.12.1 + sha256: a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08 + requires_python: '>=3.8' - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.45-hc749103_0.conda sha256: 27c4014f616326240dcce17b5f3baca3953b6bc5f245ceb49c3fa1e6320571eb md5: b90bece58b4c2bf25969b70f3be42d25 
@@ -13903,6 +13970,22 @@ packages: purls: [] size: 17873 timestamp: 1743195097269 +- pypi: https://files.pythonhosted.org/packages/d1/9a/937038f3efc70871fb26b0ee6148efcfcfb96643c517c2aaddd7ed07ad76/wadler_lindig-0.1.6-py3-none-any.whl + name: wadler-lindig + version: 0.1.6 + sha256: d707f63994c7d3e1e125e7fb7e196f4adb6f80f4a11beb955c6da937754026a3 + requires_dist: + - numpy ; extra == 'dev' + - pre-commit ; extra == 'dev' + - pytest ; extra == 'dev' + - hippogriffe==0.1.0 ; extra == 'docs' + - mkdocs-include-exclude-files==0.1.0 ; extra == 'docs' + - mkdocs-ipynb==0.1.0 ; extra == 'docs' + - mkdocs-material==9.6.7 ; extra == 'docs' + - mkdocs==1.6.1 ; extra == 'docs' + - mkdocstrings[python]==0.28.3 ; extra == 'docs' + - pymdown-extensions==10.14.3 ; extra == 'docs' + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/linux-64/wayland-1.23.1-h3e06ad9_1.conda sha256: 73d809ec8056c2f08e077f9d779d7f4e4c2b625881cad6af303c33dc1562ea01 md5: a37843723437ba75f42c9270ffe800b1 diff --git a/pyproject.toml b/pyproject.toml index 8fe5b7c0e..7016c6854 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,6 +129,7 @@ snakeviz = ">=2.2.2,<3" [tool.pixi.pypi-dependencies] gettsim = {path = ".", editable = true} dags = { git = "https://github.com/OpenSourceEconomics/dags.git", branch = "allow-passing-dag-to-concatenate_functions"} +jaxtyping = "*" pdbp = "*" [tool.pixi.target.unix.pypi-dependencies] @@ -158,7 +159,7 @@ jax = { version = ">=0.4.20", extras = ["cpu"] } jaxlib = ">=0.4.20" [tool.pixi.feature.mypy.pypi-dependencies] -mypy = "==1.15.0" +mypy = "~=1.16" types-PyYAML = "*" types-pytz = "*" @@ -208,6 +209,7 @@ extend-ignore = [ "D415", "D416", "D417", + "F722", # https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error # Others. "D404", # Do not start module docstring with "This". "RET504", # unnecessary variable assignment before return. 
@@ -226,7 +228,7 @@ extend-ignore = [ "DTZ001", # use of `datetime.datetime()` without `tzinfo` argument is not allowed "DTZ002", # use of `datetime.datetime.today()` is not allowed "PT012", # `pytest.raises()` block should contain a single simple statement - "PLR5501", # elif not supported by Jax converter + "PLR5501", # elif not supported by vectorization converter for Jax "TRY003", # Avoid specifying long messages outside the exception class "FIX002", # Line contains TODO -- Use stuff from TD area. "PLC2401", # Allow non-ASCII characters in variable names. @@ -252,8 +254,6 @@ extend-ignore = [ "PT011", # pytest raises without match statement "INP001", # implicit namespace packages without init. "E721", # Use `is` and `is not` for type comparisons - "TD003", # Missing issue link -- remove again once we got rid of ad-hoc TODOs. - "ERA001", # Commented out code. # Things ignored to avoid conflicts with ruff-format # ================================================== @@ -298,6 +298,7 @@ disallow_empty_bodies = false [[tool.mypy.overrides]] module = [ + "conftest", "src.ttsim.plot_dag_old", "src.ttsim.stale_code_storage", "src._gettsim_tests.test_docs", @@ -370,6 +371,7 @@ markers = [ "unit: Flag for unit tests which target mainly a single function.", "integration: Flag for integration tests which may comprise of multiple unit tests.", "end_to_end: Flag for tests that cover the whole program.", + "skipif_jax: skip test if backend is jax" ] norecursedirs = ["docs"] testpaths = [ diff --git "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" index 817455df0..ae0aeb052 100644 --- "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" +++ "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" @@ -9,7 +9,7 @@ if TYPE_CHECKING: from types import ModuleType - import numpy + from ttsim.interface_dag_elements.typing import BoolColumn, FloatColumn, IntColumn 
@agg_by_p_id_function(start_date="2005-01-01", agg_type=AggType.SUM) @@ -59,10 +59,10 @@ def _mean_kindergeld_per_child_ohne_staffelung_m( @policy_function(start_date="2005-01-01", vectorization_strategy="not_required") def kindergeld_zur_bedarfsdeckung_m( kindergeld_pro_kind_m: float, - kindergeld__p_id_empfänger: numpy.ndarray, # int - p_id: numpy.ndarray, # int + kindergeld__p_id_empfänger: IntColumn, + p_id: IntColumn, xnp: ModuleType, -) -> numpy.ndarray: # float +) -> FloatColumn: """Kindergeld that is used to cover the SGB II Regelbedarf of the child. Even though the Kindergeld is paid to the parent (see function @@ -123,11 +123,11 @@ def differenz_kindergeld_kindbedarf_m( @policy_function(start_date="2005-01-01", vectorization_strategy="not_required") def in_anderer_bg_als_kindergeldempfänger( - p_id: numpy.ndarray, # int - kindergeld__p_id_empfänger: numpy.ndarray, # int - bg_id: numpy.ndarray, # int + p_id: IntColumn, + kindergeld__p_id_empfänger: IntColumn, + bg_id: IntColumn, xnp: ModuleType, # Will become necessary for Jax. # noqa: ARG001 -) -> numpy.ndarray: # bool +) -> BoolColumn: """True if the person is in a different Bedarfsgemeinschaft than the Kindergeldempfänger of that person. 
""" diff --git a/src/_gettsim/ids.py b/src/_gettsim/ids.py index 405318092..9b8a356f4 100644 --- a/src/_gettsim/ids.py +++ b/src/_gettsim/ids.py @@ -11,6 +11,8 @@ import numpy + from ttsim.interface_dag_elements.typing import BoolColumn, IntColumn + @policy_input() def p_id() -> int: @@ -24,8 +26,8 @@ def hh_id() -> int: @group_creation_function() def ehe_id( - p_id: numpy.ndarray, familie__p_id_ehepartner: numpy.ndarray, xnp: ModuleType -) -> numpy.ndarray: + p_id: IntColumn, familie__p_id_ehepartner: IntColumn, xnp: ModuleType +) -> IntColumn: """Couples that are either married or in a civil union.""" n = xnp.max(p_id) + 1 p_id_ehepartner_or_own_p_id = xnp.where( @@ -41,14 +43,14 @@ def ehe_id( @group_creation_function() def fg_id( - arbeitslosengeld_2__p_id_einstandspartner: numpy.ndarray, - p_id: numpy.ndarray, - hh_id: numpy.ndarray, - alter: numpy.ndarray, - familie__p_id_elternteil_1: numpy.ndarray, - familie__p_id_elternteil_2: numpy.ndarray, + arbeitslosengeld_2__p_id_einstandspartner: IntColumn, + p_id: IntColumn, + hh_id: IntColumn, + alter: IntColumn, + familie__p_id_elternteil_1: IntColumn, + familie__p_id_elternteil_2: IntColumn, xnp: ModuleType, -) -> numpy.ndarray: +) -> IntColumn: """Familiengemeinschaft. Base unit for some transfers. 
Maximum of two generations, the relevant base unit for Bürgergeld / Arbeitslosengeld @@ -105,15 +107,15 @@ def fg_id( def _assign_parents_fg_id( - fg_id: numpy.ndarray, - p_id: numpy.ndarray, - p_id_elternteil_loc: numpy.ndarray, - hh_id: numpy.ndarray, - alter: numpy.ndarray, - children: numpy.ndarray, - n: numpy.ndarray, + fg_id: IntColumn, + p_id: IntColumn, + p_id_elternteil_loc: IntColumn, + hh_id: IntColumn, + alter: IntColumn, + children: IntColumn, + n: IntColumn, xnp: ModuleType, -) -> numpy.ndarray: +) -> IntColumn: """Return the fg_id of the child's parents.""" # TODO(@MImmesberger): Remove hard-coded number @@ -132,10 +134,10 @@ def _assign_parents_fg_id( @group_creation_function() def bg_id( - fg_id: numpy.ndarray, - p_id: numpy.ndarray, - arbeitslosengeld_2__eigenbedarf_gedeckt: numpy.ndarray, - alter: numpy.ndarray, + fg_id: IntColumn, + p_id: IntColumn, + arbeitslosengeld_2__eigenbedarf_gedeckt: BoolColumn, + alter: IntColumn, xnp: ModuleType, ) -> numpy.ndarray: """Bedarfsgemeinschaft. Relevant unit for Bürgergeld / Arbeitslosengeld 2. @@ -159,10 +161,10 @@ def bg_id( @group_creation_function() def eg_id( - arbeitslosengeld_2__p_id_einstandspartner: numpy.ndarray, - p_id: numpy.ndarray, + arbeitslosengeld_2__p_id_einstandspartner: IntColumn, + p_id: IntColumn, xnp: ModuleType, -) -> numpy.ndarray: +) -> IntColumn: """Einstandsgemeinschaft / Einstandspartner according to SGB II. A couple whose members are deemed to be responsible for each other. 
@@ -182,11 +184,11 @@ def eg_id( @group_creation_function() def wthh_id( - hh_id: numpy.ndarray, - vorrangprüfungen__wohngeld_vorrang_vor_arbeitslosengeld_2_bg: numpy.ndarray, - vorrangprüfungen__wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg: numpy.ndarray, + hh_id: IntColumn, + vorrangprüfungen__wohngeld_vorrang_vor_arbeitslosengeld_2_bg: BoolColumn, + vorrangprüfungen__wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg: BoolColumn, xnp: ModuleType, -) -> numpy.ndarray: +) -> IntColumn: """Wohngeldrechtlicher Teilhaushalt. The relevant unit for Wohngeld. Members of a household for whom the Wohngeld @@ -204,11 +206,11 @@ def wthh_id( @group_creation_function() def sn_id( - p_id: numpy.ndarray, - familie__p_id_ehepartner: numpy.ndarray, - einkommensteuer__gemeinsam_veranlagt: numpy.ndarray, + p_id: IntColumn, + familie__p_id_ehepartner: IntColumn, + einkommensteuer__gemeinsam_veranlagt: BoolColumn, xnp: ModuleType, -) -> numpy.ndarray: +) -> IntColumn: """Steuernummer. 
Spouses filing taxes jointly or individuals.""" n = xnp.max(p_id) + 1 diff --git a/src/_gettsim/kindergeld/kindergeld.py b/src/_gettsim/kindergeld/kindergeld.py index 60e6378ee..ca51e7392 100644 --- a/src/_gettsim/kindergeld/kindergeld.py +++ b/src/_gettsim/kindergeld/kindergeld.py @@ -16,8 +16,7 @@ if TYPE_CHECKING: from types import ModuleType - import numpy - + from ttsim.interface_dag_elements.typing import BoolColumn, IntColumn from ttsim.tt_dag_elements import ConsecutiveInt1dLookupTableParamValue @@ -123,11 +122,11 @@ def kind_bis_10_mit_kindergeld( @policy_function(vectorization_strategy="not_required") def gleiche_fg_wie_empfänger( - p_id: numpy.ndarray, # int - p_id_empfänger: numpy.ndarray, # int - fg_id: numpy.ndarray, # int + p_id: IntColumn, + p_id_empfänger: IntColumn, + fg_id: IntColumn, xnp: ModuleType, -) -> numpy.ndarray: # bool +) -> BoolColumn: """The child's Kindergeldempfänger is in the same Familiengemeinschaft.""" fg_id_kindergeldempfänger = join( foreign_key=p_id_empfänger, diff --git a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py index af58768be..8cfdf8164 100644 --- a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py +++ b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py @@ -17,8 +17,7 @@ if TYPE_CHECKING: from types import ModuleType - import numpy - + from ttsim.interface_dag_elements.typing import BoolColumn, IntColumn from ttsim.tt_dag_elements import ConsecutiveInt1dLookupTableParamValue, RawParam @@ -67,11 +66,11 @@ def betrag_m( @policy_function(vectorization_strategy="not_required") def elternteil_alleinerziehend( - kindergeld__p_id_empfänger: numpy.ndarray, # int - p_id: numpy.ndarray, # int - familie__alleinerziehend: numpy.ndarray, # bool + kindergeld__p_id_empfänger: IntColumn, + p_id: IntColumn, + familie__alleinerziehend: BoolColumn, xnp: ModuleType, -) -> numpy.ndarray: # bool +) -> BoolColumn: """Check if parent that receives Kindergeld is a 
single parent. Only single parents receive Kindergeld. @@ -268,11 +267,11 @@ def anspruchshöhe_m_ab_2017_07( @policy_function(start_date="2017-07-01", vectorization_strategy="not_required") def elternteil_mindesteinkommen_erreicht( - kindergeld__p_id_empfänger: numpy.ndarray, # int - p_id: numpy.ndarray, # int - mindesteinkommen_erreicht: numpy.ndarray, # bool + kindergeld__p_id_empfänger: IntColumn, + p_id: IntColumn, + mindesteinkommen_erreicht: BoolColumn, xnp: ModuleType, -) -> numpy.ndarray: # bool +) -> BoolColumn: """Income of Unterhaltsvorschuss recipient above threshold (this variable is defined on child level).""" return join( diff --git a/src/_gettsim/wohngeld/miete.py b/src/_gettsim/wohngeld/miete.py index 544dc304a..91f3da2ae 100644 --- a/src/_gettsim/wohngeld/miete.py +++ b/src/_gettsim/wohngeld/miete.py @@ -17,15 +17,15 @@ if TYPE_CHECKING: from types import ModuleType - import numpy + from jaxtyping import Array, Float, Int @dataclass(frozen=True) class LookupTableBaujahr: - baujahre: numpy.ndarray - lookup_table: numpy.ndarray - lookup_base_to_subtract_cols: numpy.ndarray - lookup_base_to_subtract_rows: numpy.ndarray + baujahre: Int[Array, " n_baujahr_categories"] + lookup_base_to_subtract_cols: Int[Array, " n_baujahr_categories"] + lookup_base_to_subtract_rows: Int[Array, " n_baujahr_categories"] + lookup_table: Float[Array, "n_baujahr_categories max_n_p_indizierung_n_mietstufen"] @param_function( diff --git a/src/_gettsim_tests/test_warn_if_repeated_execution.py b/src/_gettsim_tests/test_warn_if_repeated_execution.py deleted file mode 100644 index 014a1ae0a..000000000 --- a/src/_gettsim_tests/test_warn_if_repeated_execution.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -import pytest - - -def test_warn_when_internal_tests_are_executed_repeatedly(): - from gettsim import test - - test("--collect-only") - - with pytest.warns(UserWarning, match="Repeated execution of the test suite"): - test("--collect-only") diff --git 
a/src/gettsim/__init__.py b/src/gettsim/__init__.py index 8ca08ac64..7ff33ea07 100644 --- a/src/gettsim/__init__.py +++ b/src/gettsim/__init__.py @@ -11,46 +11,18 @@ __version__ = "unknown" -import itertools -import warnings -from typing import Any +from typing import Literal import pytest from _gettsim_tests import TEST_DIR -# from ttsim.tt_dag_elements import ( -# GroupCreationFunction, -# PolicyFunction, -# group_creation_function, -# plot_dag, -# policy_environment, -# policy_function, -# ) -COUNTER_TEST_EXECUTIONS = itertools.count() - - -def test(*args: Any) -> None: - n_test_executions = next(COUNTER_TEST_EXECUTIONS) - - if n_test_executions == 0: - pytest.main([str(TEST_DIR), "--noconftest", *args]) - else: - warnings.warn( - "Repeated execution of the test suite is not possible. Start a new Python " - "session or restart the kernel in a Jupyter/IPython notebook to re-run the " - "tests.", - stacklevel=2, - ) +def test(backend: Literal["numpy", "jax"] = "numpy") -> None: + pytest.main([str(TEST_DIR), "--backend", backend]) __all__ = [ - # "GroupCreationFunction", - # "PolicyFunction", "__version__", - # "group_creation_function", - # "plot_dag", - # "policy_environment", - # "policy_function", + "test", ] diff --git a/src/ttsim/interface_dag.py b/src/ttsim/interface_dag.py index f7e2776fd..1ed380e6c 100644 --- a/src/ttsim/interface_dag.py +++ b/src/ttsim/interface_dag.py @@ -51,7 +51,6 @@ def main( functions=functions, targets=targets, ) - # draw_dag(dag) f = dags.concatenate_functions( dag=dag, functions=functions, diff --git a/src/ttsim/interface_dag_elements/data_converters.py b/src/ttsim/interface_dag_elements/data_converters.py index 0b2627a6a..58956fc68 100644 --- a/src/ttsim/interface_dag_elements/data_converters.py +++ b/src/ttsim/interface_dag_elements/data_converters.py @@ -1,12 +1,13 @@ from __future__ import annotations -from types import ModuleType from typing import TYPE_CHECKING import dags.tree as dt import pandas as pd if TYPE_CHECKING: + 
from types import ModuleType + from ttsim.interface_dag_elements.typing import ( NestedData, NestedInputsMapper, diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 845485883..a22bf1dc7 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -436,8 +436,8 @@ def non_convertible_objects_in_results_tree( _numeric_types = (int, float, bool, xnp.integer, xnp.floating, xnp.bool_) expected_object_length = len(next(iter(processed_data.values()))) - paths_with_incorrect_types = [] - paths_with_incorrect_length = [] + paths_with_incorrect_types: list[str] = [] + paths_with_incorrect_length: list[str] = [] for path, data in dt.flatten_to_tree_paths(results__tree).items(): if isinstance(data, xnp.ndarray) and len(data) not in { 1, @@ -750,39 +750,3 @@ def _remove_note_and_reference(entry: dict[str | int, Any]) -> dict[str | int, A ) return out - - -def fail_if__dtype_not_int( - data: numpy.ndarray, - agg_func: str, - xnp: ModuleType, -) -> None: - """Check if data is of integer type.""" - if not xnp.issubdtype(data.dtype, xnp.integer): - raise TypeError( - f"Data in {agg_func} must be of integer type, but is {data.dtype}." - ) - - -def fail_if__dtype_not_numeric_or_datetime( - data: numpy.ndarray, - agg_func: str, - xnp: ModuleType, -) -> None: - """Check if data is of numeric or datetime type.""" - if not xnp.issubdtype(data.dtype, (xnp.number, xnp.datetime64)): - raise TypeError( - f"Data in {agg_func} must be of numeric or datetime type, but is {data.dtype}." - ) - - -def fail_if__dtype_not_numeric_or_boolean( - data: numpy.ndarray, - agg_func: str, - xnp: ModuleType, -) -> None: - """Check if data is of numeric or boolean type.""" - if not xnp.issubdtype(data.dtype, (xnp.number, xnp.bool_)): - raise TypeError( - f"Data in {agg_func} must be of numeric or boolean type, but is {data.dtype}." 
- ) diff --git a/src/ttsim/interface_dag_elements/interface_node_objects.py b/src/ttsim/interface_dag_elements/interface_node_objects.py index d08871623..66a02e000 100644 --- a/src/ttsim/interface_dag_elements/interface_node_objects.py +++ b/src/ttsim/interface_dag_elements/interface_node_objects.py @@ -2,7 +2,7 @@ import inspect from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar import dags.tree as dt @@ -54,10 +54,10 @@ def remove_tree_logic( ) -> InterfaceInput: return self - def dummy_callable(self) -> InterfaceFunction: + def dummy_callable(self) -> InterfaceFunction: # type: ignore[type-arg] """Dummy callable for the interface input. Just used for plotting.""" - def dummy() -> self.return_type: + def dummy() -> Any: pass return interface_function( diff --git a/src/ttsim/interface_dag_elements/typing.py b/src/ttsim/interface_dag_elements/typing.py index 753ecdd4c..ccf6aa520 100644 --- a/src/ttsim/interface_dag_elements/typing.py +++ b/src/ttsim/interface_dag_elements/typing.py @@ -1,13 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, NewType, TypeVar +from typing import TYPE_CHECKING, Any, Literal, NewType, TypeAlias, TypeVar if TYPE_CHECKING: # Make these available for import from other modules. import datetime from collections.abc import Mapping - import numpy + from jaxtyping import Array, Bool, Float, Int OrigParamSpec = ( # Header @@ -26,6 +26,10 @@ NestedTargetDict, ) + BoolColumn: TypeAlias = Array[Bool, " n_obs"] # type: ignore[name-defined] + IntColumn: TypeAlias = Array[Int, " n_obs"] # type: ignore[name-defined] + FloatColumn: TypeAlias = Array[Float, " n_obs"] # type: ignore[name-defined] + # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Possible leaves of the various trees. 
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # @@ -39,13 +43,13 @@ # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # Tree-like data structures for input, processing, and output; including metadata. # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # - NestedData = Mapping[str, numpy.ndarray | "NestedData"] + NestedData = Mapping[str, BoolColumn | IntColumn | FloatColumn | "NestedData"] """Tree mapping TTSIM paths to 1d arrays.""" - FlatData = Mapping[str, numpy.ndarray | "FlatData"] + FlatData = Mapping[str, BoolColumn | IntColumn | FloatColumn | "FlatData"] """Flattened tree mapping TTSIM paths to 1d arrays.""" NestedInputsMapper = Mapping[str, str | bool | int | float | "NestedInputsMapper"] """Tree mapping TTSIM paths to df columns or constants.""" - QNameData = Mapping[str, numpy.ndarray] + QNameData = Mapping[str, BoolColumn | IntColumn | FloatColumn] """Mapping of qualified name paths to 1d arrays.""" # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # diff --git a/src/ttsim/stale_code_storage.py b/src/ttsim/stale_code_storage.py index f9584b2ce..dde38a004 100644 --- a/src/ttsim/stale_code_storage.py +++ b/src/ttsim/stale_code_storage.py @@ -112,3 +112,71 @@ def test_fail_if_name_of_last_branch_element_is_not_the_functions_leaf_name( ): with pytest.raises(KeyError): name_of_last_branch_element_is_not_the_functions_leaf_name(functions_tree) + + +def check_series_has_expected_type( + series: pd.Series, internal_type: numpy.dtype, dnp: ModuleType +) -> bool: + """Checks whether used series has already expected internal type. + + Currently not used, but might become useful again. + + Parameters + ---------- + series: pandas.Series or pandas.DataFrame or dict of pandas.Series + Data provided by the user. + internal_type: TypeVar + One of the types used by TTSIM. 
+ + Returns + ------- + Bool + + """ + if ( + (internal_type == float) & (is_float_dtype(series)) + or (internal_type == int) & (is_integer_dtype(series)) + or (internal_type == bool) & (is_bool_dtype(series)) + or (internal_type == dnp.datetime64) & (is_datetime64_any_dtype(series)) + ): + out = True + else: + out = False + + return out + + +def fail_if__dtype_not_int( + data: IntColumn, + agg_func: str, + xnp: ModuleType, +) -> None: + """Check if data is of integer type.""" + if not xnp.issubdtype(data.dtype, xnp.integer): + raise TypeError( + f"Data in {agg_func} must be of integer type, but is {data.dtype}." + ) + + +def fail_if__dtype_not_numeric_or_datetime( + data: FloatColumn | IntColumn, # | DatetimeColumn, + agg_func: str, + xnp: ModuleType, +) -> None: + """Check if data is of numeric or datetime type.""" + if not xnp.issubdtype(data.dtype, (xnp.number, xnp.datetime64)): + raise TypeError( + f"Data in {agg_func} must be of numeric or datetime type, but is {data.dtype}." + ) + + +def fail_if__dtype_not_numeric_or_boolean( + data: FloatColumn | IntColumn | BoolColumn, + agg_func: str, + xnp: ModuleType, +) -> None: + """Check if data is of numeric or boolean type.""" + if not xnp.issubdtype(data.dtype, (xnp.number, xnp.bool_)): + raise TypeError( + f"Data in {agg_func} must be of numeric or boolean type, but is {data.dtype}." 
+ ) diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index 472cb9a20..1352d0286 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -9,12 +9,6 @@ import dags.tree as dt from dags import rename_arguments -from pandas.api.types import ( - is_bool_dtype, - is_datetime64_any_dtype, - is_float_dtype, - is_integer_dtype, -) from ttsim.interface_dag_elements.shared import to_datetime from ttsim.tt_dag_elements.aggregation import ( @@ -40,12 +34,10 @@ if TYPE_CHECKING: from types import ModuleType - import numpy - import pandas as pd - from ttsim.interface_dag_elements.typing import ( DashedISOString, GenericCallable, + IntColumn, UnorderedQNames, ) @@ -129,7 +121,7 @@ def remove_tree_logic( def dummy_callable(self) -> PolicyFunction: """Dummy callable for the interface input. Just used for plotting.""" - def dummy() -> self.data_type: + def dummy(): # type: ignore[no-untyped-def] pass return policy_function( @@ -225,29 +217,10 @@ class ColumnFunction(ColumnObject, Generic[FunArgTypes, ReturnType]): foreign_key_type: FKType = FKType.IRRELEVANT def __post_init__(self) -> None: - self._fail_if_rounding_has_wrong_type(self.rounding_spec) + _fail_if_rounding_has_wrong_type(self.rounding_spec) # Expose the signature of the wrapped function for dependency resolution _frozen_safe_update_wrapper(self, self.function) - def _fail_if_rounding_has_wrong_type( - self, rounding_spec: RoundingSpec | None - ) -> None: - """Check if rounding_spec has the correct type. - - Parameters - ---------- - rounding_spec - The rounding specification to check. - - Raises - ------ - AssertionError - If rounding_spec is not a RoundingSpec or None. 
- """ - assert isinstance(rounding_spec, RoundingSpec | None), ( - f"rounding_spec must be a RoundingSpec or None, got {rounding_spec}" - ) - def __call__( self, *args: FunArgTypes.args, **kwargs: FunArgTypes.kwargs ) -> ReturnType: @@ -268,6 +241,24 @@ def is_active(self, date: datetime.date) -> bool: return self.start_date <= date <= self.end_date +def _fail_if_rounding_has_wrong_type(rounding_spec: RoundingSpec | None) -> None: + """Check if rounding_spec has the correct type. + + Parameters + ---------- + rounding_spec + The rounding specification to check. + + Raises + ------ + AssertionError + If rounding_spec is not a RoundingSpec or None. + """ + assert isinstance(rounding_spec, RoundingSpec | None), ( + f"rounding_spec must be a RoundingSpec or None, got {rounding_spec}" + ) + + @dataclass(frozen=True) class PolicyFunction(ColumnFunction): # type: ignore[type-arg] """ @@ -390,7 +381,7 @@ def inner(func: GenericCallable) -> PolicyFunction: return inner -def reorder_ids(ids: numpy.ndarray, xnp: ModuleType) -> numpy.ndarray: +def reorder_ids(ids: IntColumn, xnp: ModuleType) -> IntColumn: """Make ID's consecutively numbered. Takes the given IDs and replaces them by consecutive numbers @@ -823,38 +814,6 @@ def _convert_and_validate_dates( return start_date, end_date -def check_series_has_expected_type( - series: pd.Series, internal_type: numpy.dtype, dnp: ModuleType -) -> bool: - """Checks whether used series has already expected internal type. - - Currently not used, but might become useful again. - - Parameters - ---------- - series: pandas.Series or pandas.DataFrame or dict of pandas.Series - Data provided by the user. - internal_type: TypeVar - One of the types used by TTSIM. 
- - Returns - ------- - Bool - - """ - if ( - (internal_type == float) & (is_float_dtype(series)) - or (internal_type == int) & (is_integer_dtype(series)) - or (internal_type == bool) & (is_bool_dtype(series)) - or (internal_type == dnp.datetime64) & (is_datetime64_any_dtype(series)) - ): - out = True - else: - out = False - - return out - - @dataclass(frozen=True) class ParamFunction(Generic[FunArgTypes, ReturnType]): """ diff --git a/src/ttsim/tt_dag_elements/param_objects.py b/src/ttsim/tt_dag_elements/param_objects.py index 7cfbbd924..4b9b4f913 100644 --- a/src/ttsim/tt_dag_elements/param_objects.py +++ b/src/ttsim/tt_dag_elements/param_objects.py @@ -15,6 +15,8 @@ import datetime from types import ModuleType + from jaxtyping import Array, Float + @dataclass(frozen=True) class ParamObject: @@ -46,7 +48,7 @@ class ParamObject: def dummy_callable(self) -> ParamFunction: """Dummy callable for the policy input. Just used for plotting.""" - def dummy() -> str(type(self).__name__): + def dummy(): # type: ignore[no-untyped-def] pass return param_function( @@ -146,9 +148,9 @@ def __post_init__(self) -> None: class PiecewisePolynomialParamValue: """The parameters expected by piecewise_polynomial""" - thresholds: numpy.ndarray - intercepts: numpy.ndarray - rates: numpy.ndarray + thresholds: Float[Array, " n_segments"] + intercepts: Float[Array, " n_segments"] + rates: Float[Array, " n_segments"] @dataclass(frozen=True) @@ -156,7 +158,7 @@ class ConsecutiveInt1dLookupTableParamValue: """The parameters expected by lookup_table""" base_to_subtract: int - values_to_look_up: numpy.ndarray + values_to_look_up: Float[Array, " n_values_to_look_up"] @dataclass(frozen=True) @@ -165,7 +167,7 @@ class ConsecutiveInt2dLookupTableParamValue: base_to_subtract_rows: int base_to_subtract_cols: int - values_to_look_up: numpy.ndarray + values_to_look_up: Float[Array, "n_rows n_cols"] def get_consecutive_int_1d_lookup_table_param_value( diff --git 
a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index 200dcc493..8667978cc 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -5,10 +5,12 @@ import numpy +from ttsim.tt_dag_elements.param_objects import PiecewisePolynomialParamValue + if TYPE_CHECKING: from types import ModuleType -from ttsim.tt_dag_elements.param_objects import PiecewisePolynomialParamValue + from jaxtyping import Array, Float FUNC_TYPES = Literal[ "piecewise_constant", @@ -49,11 +51,11 @@ class RatesOptions: def piecewise_polynomial( - x: numpy.ndarray, + x: Float[Array, " n_pp_values"], parameters: PiecewisePolynomialParamValue, xnp: ModuleType, - rates_multiplier: numpy.ndarray | float = 1.0, -) -> numpy.ndarray: + rates_multiplier: Float[Array, " n_segments"] | float = 1.0, +) -> Float[Array, " n_pp_values"]: """Calculate value of the piecewise function at `x`. If the first interval begins at -inf the polynomial of that interval can only have slope of 0. Requesting a value outside of the provided thresholds will lead to undefined behaviour. @@ -62,13 +64,8 @@ def piecewise_polynomial( ---------- x: Array with values at which the piecewise polynomial is to be calculated. - thresholds: - A one-dimensional array containing the thresholds for all intervals. - coefficients: - A two-dimensional array where columns are interval sections and rows - correspond to the coefficient of the nth polynomial. - intercepts: - The intercepts at the lower threshold of each interval. + parameters: + The parameters of the piecewise polynomial. xnp: The backend module to use for calculations. 
rates_multiplier: @@ -84,13 +81,13 @@ def piecewise_polynomial( # Get interval of requested value selected_bin = xnp.searchsorted(parameters.thresholds, x, side="right") - 1 coefficients = parameters.rates[:, selected_bin].T - # Calculate distance from X to lower threshold + # Calculate distance from x to lower threshold increment_to_calc = xnp.where( parameters.thresholds[selected_bin] == -xnp.inf, 0, x - parameters.thresholds[selected_bin], ) - # Evaluate polynomial at X + # Evaluate polynomial at x out = ( parameters.intercepts[selected_bin] + ( @@ -275,30 +272,15 @@ def _check_and_get_rates( def _check_and_get_intercepts( leaf_name: str, parameter_dict: dict[int, dict[str, float]], - lower_thresholds: numpy.ndarray, - upper_thresholds: numpy.ndarray, - rates: numpy.ndarray, + lower_thresholds: Float[Array, " n_segments"], + upper_thresholds: Float[Array, " n_segments"], + rates: Float[Array, " n_segments"], xnp: ModuleType, -) -> numpy.ndarray: +) -> Float[Array, " n_segments"]: """Check and transfer raw intercept data. If necessary create intercepts. Transfer and check raw rates data, which needs to be specified in a piecewise_polynomial layout in the yaml file. - - Parameters - ---------- - parameter_dict - leaf_name - lower_thresholds - upper_thresholds - rates - keys - xnp : ModuleType - The numpy module to use for calculations. - - Returns - ------- - """ keys = sorted(parameter_dict.keys()) intercepts = numpy.zeros(len(keys)) @@ -332,12 +314,12 @@ def _check_and_get_intercepts( def _create_intercepts( - lower_thresholds: numpy.ndarray, - upper_thresholds: numpy.ndarray, - rates: numpy.ndarray, - intercept_at_lowest_threshold: numpy.ndarray, + lower_thresholds: Float[Array, " n_segments"], + upper_thresholds: Float[Array, " n_segments"], + rates: Float[Array, " n_segments"], + intercept_at_lowest_threshold: float, xnp: ModuleType, -) -> numpy.ndarray: +) -> Float[Array, " n_segments"]: """Create intercepts from raw data. 
Parameters @@ -376,32 +358,30 @@ def _create_intercepts( def _calculate_one_intercept( x: float, - lower_thresholds: numpy.ndarray, - upper_thresholds: numpy.ndarray, - rates: numpy.ndarray, - intercepts: numpy.ndarray, + lower_thresholds: Float[Array, " n_segments"], + upper_thresholds: Float[Array, " n_segments"], + rates: Float[Array, " n_segments"], + intercepts: Float[Array, " n_segments"], ) -> float: - """Calculate the intercepts from the raw data. + """Calculate the intercept for the segment `x` lies in. Parameters ---------- - x : float + x The value that the function is applied to. - lower_thresholds : numpy.ndarray + lower_thresholds A one-dimensional array containing lower thresholds of each interval. - upper_thresholds : numpy.ndarray + upper_thresholds A one-dimensional array containing upper thresholds each interval. - rates : numpy.ndarray + rates A two-dimensional array where columns are interval sections and rows correspond to the nth polynomial. - intercepts : numpy.ndarray + intercepts The intercepts at the lower threshold of each interval. - xnp : ModuleType - The numpy module to use for calculations. Returns ------- - out : float + out The value of `x` under the piecewise function. """ diff --git a/src/ttsim/tt_dag_elements/rounding.py b/src/ttsim/tt_dag_elements/rounding.py index 677c65d65..c8f142f3d 100644 --- a/src/ttsim/tt_dag_elements/rounding.py +++ b/src/ttsim/tt_dag_elements/rounding.py @@ -8,7 +8,7 @@ from collections.abc import Callable from types import ModuleType - import numpy + from ttsim.interface_dag_elements.typing import FloatColumn ROUNDING_DIRECTION = Literal["up", "down", "nearest"] @@ -38,8 +38,8 @@ def __post_init__(self) -> None: ) def apply_rounding( - self, func: Callable[P, numpy.ndarray], xnp: ModuleType - ) -> Callable[P, numpy.ndarray]: + self, func: Callable[P, FloatColumn], xnp: ModuleType + ) -> Callable[P, FloatColumn]: """Decorator to round the output of a function. 
Parameters @@ -55,7 +55,7 @@ def apply_rounding( """ @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> numpy.ndarray: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> FloatColumn: out = func(*args, **kwargs) if self.direction == "up": diff --git a/src/ttsim/tt_dag_elements/shared.py b/src/ttsim/tt_dag_elements/shared.py index 66388d267..a6de2e2d9 100644 --- a/src/ttsim/tt_dag_elements/shared.py +++ b/src/ttsim/tt_dag_elements/shared.py @@ -5,31 +5,31 @@ if TYPE_CHECKING: from types import ModuleType - import numpy + from ttsim.interface_dag_elements.typing import BoolColumn, FloatColumn, IntColumn def join( - foreign_key: numpy.ndarray, - primary_key: numpy.ndarray, - target: numpy.ndarray, + foreign_key: IntColumn, + primary_key: IntColumn, + target: BoolColumn | IntColumn | FloatColumn, value_if_foreign_key_is_missing: float | bool, xnp: ModuleType, -) -> numpy.ndarray: +) -> BoolColumn | IntColumn | FloatColumn: """ Given a foreign key, find the corresponding primary key, and return the target at the same index as the primary key. When using Jax, does not work on String Arrays. Parameters ---------- - foreign_key : numpy.ndarray[Key] + foreign_key: The foreign keys. - primary_key : numpy.ndarray[Key] + primary_key: The primary keys. - target : numpy.ndarray[Out] - The targets in the same order as the primary keys. - value_if_foreign_key_is_missing : Out + target: + The targets, in the same order as the primary keys. + value_if_foreign_key_is_missing: The value to return if no matching primary key is found. - xnp : ModuleType + xnp: The numpy module to use for calculations. 
Returns diff --git a/src/ttsim/tt_dag_elements/vectorization.py b/src/ttsim/tt_dag_elements/vectorization.py index e736d04b6..ad60a5b2e 100644 --- a/src/ttsim/tt_dag_elements/vectorization.py +++ b/src/ttsim/tt_dag_elements/vectorization.py @@ -14,9 +14,6 @@ if TYPE_CHECKING: from types import ModuleType - import numpy - -if TYPE_CHECKING: from ttsim.interface_dag_elements.typing import GenericCallable BACKEND_TO_MODULE = {"jax": "jax.numpy", "numpy": "numpy"} diff --git a/tests/ttsim/mettsim/group_by_ids.py b/tests/ttsim/mettsim/group_by_ids.py index 1b30ee58e..8b458d240 100644 --- a/tests/ttsim/mettsim/group_by_ids.py +++ b/tests/ttsim/mettsim/group_by_ids.py @@ -2,17 +2,16 @@ from typing import TYPE_CHECKING +from ttsim.tt_dag_elements import group_creation_function + if TYPE_CHECKING: from types import ModuleType - import numpy -from ttsim.tt_dag_elements import group_creation_function + from ttsim.interface_dag_elements.typing import IntColumn @group_creation_function() -def sp_id( - p_id: numpy.ndarray, p_id_spouse: numpy.ndarray, xnp: ModuleType -) -> numpy.ndarray: +def sp_id(p_id: IntColumn, p_id_spouse: IntColumn, xnp: ModuleType) -> IntColumn: """ Compute the spouse (sp) group ID for each person. """ @@ -23,13 +22,13 @@ def sp_id( @group_creation_function() def fam_id( - p_id_spouse: numpy.ndarray, - p_id: numpy.ndarray, - age: numpy.ndarray, - p_id_parent_1: numpy.ndarray, - p_id_parent_2: numpy.ndarray, + p_id_spouse: IntColumn, + p_id: IntColumn, + age: IntColumn, + p_id_parent_1: IntColumn, + p_id_parent_2: IntColumn, xnp: ModuleType, -) -> numpy.ndarray: +) -> IntColumn: """ Compute the family ID for each person. 
""" @@ -75,14 +74,14 @@ def fam_id( def _assign_parents_fam_id( - fam_id: numpy.ndarray, - p_id: numpy.ndarray, - p_id_parent_loc: numpy.ndarray, - age: numpy.ndarray, - children: numpy.ndarray, - n: numpy.ndarray, + fam_id: IntColumn, + p_id: IntColumn, + p_id_parent_loc: IntColumn, + age: IntColumn, + children: IntColumn, + n: int, xnp: ModuleType, -) -> numpy.ndarray: +) -> IntColumn: """Return the fam_id of the child's parents.""" return xnp.where( diff --git a/tests/ttsim/test_failures.py b/tests/ttsim/test_failures.py index 392663a47..ee5f27751 100644 --- a/tests/ttsim/test_failures.py +++ b/tests/ttsim/test_failures.py @@ -44,6 +44,7 @@ from ttsim.interface_dag_elements.typing import ( FlatColumnObjectsParamFunctions, FlatOrigParamSpecs, + IntColumn, NestedPolicyEnvironment, OrigParamSpec, ) @@ -142,7 +143,7 @@ def some_x(x): @policy_function() -def some_policy_func_returning_array_of_length_2(xnp: ModuleType) -> numpy.ndarray: +def some_policy_func_returning_array_of_length_2(xnp: ModuleType) -> IntColumn: return xnp.array([1, 2]) @@ -869,7 +870,7 @@ def test_fail_if_p_id_is_not_unique_via_main(minimal_input_data, backend): )["fail_if__input_data_tree_is_invalid"] -def test_fail_if_root_nodes_are_missing_via_main(minimal_input_data): +def test_fail_if_root_nodes_are_missing_via_main(minimal_input_data, backend): def b(a): return a @@ -891,7 +892,7 @@ def c(b): "policy_environment": policy_environment, "targets__tree": {"c": None}, "rounding": False, - # "jit": jit, + "backend": backend, }, targets=["results__tree", "fail_if__root_nodes_are_missing"], ) diff --git a/tests/ttsim/test_specialized_environment.py b/tests/ttsim/test_specialized_environment.py index cc58b22a3..6dbda6fff 100644 --- a/tests/ttsim/test_specialized_environment.py +++ b/tests/ttsim/test_specialized_environment.py @@ -30,7 +30,7 @@ ) if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import NestedPolicyEnvironment + from ttsim.interface_dag_elements.typing import IntColumn, 
NestedPolicyEnvironment @policy_input() @@ -411,7 +411,7 @@ def test_create_agg_by_group_functions( )["results__tree"] -def test_output_is_tree(minimal_input_data): +def test_output_is_tree(minimal_input_data, backend, xnp): policy_environment = { "p_id": p_id, "module": {"some_func": some_func}, @@ -423,14 +423,14 @@ def test_output_is_tree(minimal_input_data): "policy_environment": policy_environment, "targets__tree": {"module": {"some_func": None}}, "rounding": False, - "backend": "numpy", + "backend": backend, }, targets=["results__tree"], )["results__tree"] assert isinstance(out, dict) assert "some_func" in out["module"] - assert isinstance(out["module"]["some_func"], numpy.ndarray) + assert isinstance(out["module"]["some_func"], xnp.ndarray) def test_params_target_is_allowed(minimal_input_data): @@ -467,13 +467,15 @@ def test_params_target_is_allowed(minimal_input_data): assert out["some_param"] == 1 -def test_function_without_data_dependency_is_not_mistaken_for_data(minimal_input_data): +def test_function_without_data_dependency_is_not_mistaken_for_data( + minimal_input_data, backend, xnp +): @policy_function(leaf_name="a", vectorization_strategy="not_required") - def a() -> numpy.ndarray: - return numpy.array(minimal_input_data["p_id"]) + def a() -> IntColumn: + return xnp.array(minimal_input_data["p_id"]) @policy_function(leaf_name="b") - def b(a): + def b(a: int) -> int: return a policy_environment = { @@ -486,12 +488,12 @@ def b(a): "policy_environment": policy_environment, "targets__tree": {"b": None}, "rounding": False, - "backend": "numpy", + "backend": backend, }, targets=["results__tree"], )["results__tree"] numpy.testing.assert_array_almost_equal( - results__tree["b"], numpy.array(minimal_input_data["p_id"]) + results__tree["b"], xnp.array(minimal_input_data["p_id"]) ) diff --git a/tests/ttsim/test_warnings.py b/tests/ttsim/test_warnings.py index 68e07fd14..0fc194d74 100644 --- a/tests/ttsim/test_warnings.py +++ b/tests/ttsim/test_warnings.py @@ 
-19,7 +19,7 @@ def another_func(some_func: int) -> int: return some_func -def test_warn_if_functions_and_data_columns_overlap(): +def test_warn_if_functions_and_data_columns_overlap(backend): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") main( @@ -34,7 +34,7 @@ def test_warn_if_functions_and_data_columns_overlap(): }, "targets__tree": {"some_target": None}, "rounding": False, - # "jit": jit, + "backend": backend, }, targets=["warn_if__functions_and_data_columns_overlap"], ) @@ -44,7 +44,7 @@ def test_warn_if_functions_and_data_columns_overlap(): assert w[0].category.__name__ == "FunctionsAndDataColumnsOverlapWarning" -def test_warn_if_functions_and_columns_overlap_no_warning_if_no_overlap(): +def test_warn_if_functions_and_columns_overlap_no_warning_if_no_overlap(backend): with warnings.catch_warnings(): warnings.filterwarnings( "error", category=warn_if.FunctionsAndDataColumnsOverlapWarning @@ -58,7 +58,7 @@ def test_warn_if_functions_and_columns_overlap_no_warning_if_no_overlap(): "policy_environment": {"some_func": some_func}, "targets__tree": {"some_func": None}, "rounding": False, - # "jit": jit, + "backend": backend, }, targets=["warn_if__functions_and_data_columns_overlap"], ) diff --git a/tests/ttsim/tt_dag_elements/test_shared.py b/tests/ttsim/tt_dag_elements/test_shared.py index 90a7a303a..136d85542 100644 --- a/tests/ttsim/tt_dag_elements/test_shared.py +++ b/tests/ttsim/tt_dag_elements/test_shared.py @@ -2,14 +2,16 @@ from typing import TYPE_CHECKING -if TYPE_CHECKING: - from types import ModuleType - import numpy import pytest from ttsim.tt_dag_elements import join +if TYPE_CHECKING: + from types import ModuleType + + from ttsim.interface_dag_elements.typing import IntColumn + @pytest.mark.parametrize( "foreign_key, primary_key, target, value_if_foreign_key_is_missing, expected", @@ -45,18 +47,18 @@ ], ) def test_join( - foreign_key: numpy.ndarray, - primary_key: numpy.ndarray, - target: numpy.ndarray, + foreign_key: 
IntColumn, + primary_key: IntColumn, + target: IntColumn, value_if_foreign_key_is_missing: int, - expected: numpy.ndarray, + expected: IntColumn, xnp: ModuleType, ): assert numpy.array_equal( join( - foreign_key=xnp.array(foreign_key), - primary_key=xnp.array(primary_key), - target=xnp.array(target), + foreign_key=xnp.asarray(foreign_key), + primary_key=xnp.asarray(primary_key), + target=xnp.asarray(target), value_if_foreign_key_is_missing=value_if_foreign_key_is_missing, xnp=xnp, ), diff --git a/tests/ttsim/tt_dag_elements/test_vectorization.py b/tests/ttsim/tt_dag_elements/test_vectorization.py index 68c08c52e..d84470469 100644 --- a/tests/ttsim/tt_dag_elements/test_vectorization.py +++ b/tests/ttsim/tt_dag_elements/test_vectorization.py @@ -37,8 +37,9 @@ ) if TYPE_CHECKING: - from collections.abc import Callable + from types import ModuleType + from ttsim.interface_dag_elements.typing import IntColumn # ====================================================================================== # String comparison @@ -648,20 +649,25 @@ def scalar_func(x: int) -> int: @policy_function(vectorization_strategy="not_required") -def already_vectorized_func(x: numpy.ndarray) -> numpy.ndarray: # type: ignore[type-arg] - return numpy.where(x < 0, 0, x * 2) - - -@pytest.mark.parametrize( - "vectorized_function", - [ - vectorize_function( - scalar_func, vectorization_strategy="loop", backend="numpy", xnp=numpy - ), - already_vectorized_func, - ], -) -def test_vectorize_func(vectorized_function: Callable): # type: ignore[type-arg] +def already_vectorized_func(x: IntColumn, xnp: ModuleType) -> IntColumn: + return xnp.where(x < 0, 0, x * 2) + + +def test_loop_vectorize_scalar_func(backend, xnp): + fun = vectorize_function( + scalar_func, vectorization_strategy="loop", backend=backend, xnp=numpy + ) + assert numpy.array_equal(fun(xnp.array([-1, 0, 2, 3])), xnp.array([0, 0, 4, 6])) + + +def test_vectorize_scalar_func(backend, xnp): + fun = vectorize_function( + scalar_func, 
vectorization_strategy="vectorize", backend=backend, xnp=numpy + ) + assert numpy.array_equal(fun(xnp.array([-1, 0, 2, 3])), xnp.array([0, 0, 4, 6])) + + +def test_already_vectorized_func(xnp): assert numpy.array_equal( - vectorized_function(numpy.array([-1, 0, 2, 3])), numpy.array([0, 0, 4, 6]) + already_vectorized_func(xnp.array([-1, 0, 2, 3]), xnp), xnp.array([0, 0, 4, 6]) ) From 1e8f55829a2928b2c7de79188dfd4b3a5d1e245d Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 13 Jun 2025 10:54:20 +0200 Subject: [PATCH 18/25] Rename, we have many more groups than 'hh' now... --- src/ttsim/tt_dag_elements/aggregation_jax.py | 30 +++++++------- .../tt_dag_elements/aggregation_numpy.py | 40 +++++++++---------- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/ttsim/tt_dag_elements/aggregation_jax.py b/src/ttsim/tt_dag_elements/aggregation_jax.py index 597974c2f..e3e24cd38 100644 --- a/src/ttsim/tt_dag_elements/aggregation_jax.py +++ b/src/ttsim/tt_dag_elements/aggregation_jax.py @@ -16,10 +16,10 @@ def grouped_count(group_id: jnp.ndarray, num_segments: int) -> jnp.ndarray: - out_on_hh = segment_sum( + out_grouped = segment_sum( data=jnp.ones(len(group_id)), segment_ids=group_id, num_segments=num_segments ) - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_sum( @@ -28,41 +28,41 @@ def grouped_sum( if column.dtype in ["bool"]: column = column.astype(int) - out_on_hh = segment_sum( + out_grouped = segment_sum( data=column, segment_ids=group_id, num_segments=num_segments ) - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_mean( column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int ) -> jnp.ndarray: - sum_on_hh = segment_sum( + sum_grouped = segment_sum( data=column, segment_ids=group_id, num_segments=num_segments ) sizes = segment_sum( data=jnp.ones(len(column)), segment_ids=group_id, num_segments=num_segments ) - mean_on_hh = sum_on_hh / sizes - return mean_on_hh[group_id] + mean_grouped = 
sum_grouped / sizes + return mean_grouped[group_id] def grouped_max( column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int ) -> jnp.ndarray: - out_on_hh = segment_max( + out_grouped = segment_max( data=column, segment_ids=group_id, num_segments=num_segments ) - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_min( column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int ) -> jnp.ndarray: - out_on_hh = segment_min( + out_grouped = segment_min( data=column, segment_ids=group_id, num_segments=num_segments ) - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_any( @@ -74,10 +74,10 @@ def grouped_any( else: my_col = column - out_on_hh = segment_max( + out_grouped = segment_max( data=my_col, segment_ids=group_id, num_segments=num_segments ) - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_all( @@ -87,10 +87,10 @@ def grouped_all( if jnp.issubdtype(column.dtype, jnp.integer): column = column.astype("bool") - out_on_hh = segment_min( + out_grouped = segment_min( data=column, segment_ids=group_id, num_segments=num_segments ) - return out_on_hh[group_id] + return out_grouped[group_id] def count_by_p_id( diff --git a/src/ttsim/tt_dag_elements/aggregation_numpy.py b/src/ttsim/tt_dag_elements/aggregation_numpy.py index f8886b0d0..0de52732d 100644 --- a/src/ttsim/tt_dag_elements/aggregation_numpy.py +++ b/src/ttsim/tt_dag_elements/aggregation_numpy.py @@ -6,11 +6,11 @@ def grouped_count(group_id: numpy.ndarray) -> numpy.ndarray: fail_if__dtype_not_int(group_id, agg_func="grouped_count") - out_on_hh = npg.aggregate( + out_grouped = npg.aggregate( group_id, numpy.ones(len(group_id), dtype=int), func="sum", fill_value=0 ) - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_sum(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: @@ -18,20 +18,20 @@ def grouped_sum(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray fail_if__dtype_not_numeric_or_boolean(column, 
agg_func="grouped_sum") if column.dtype == bool: column = column.astype(int) - out_on_hh = npg.aggregate(group_id, column, func="sum", fill_value=0) + out_grouped = npg.aggregate(group_id, column, func="sum", fill_value=0) # Expand to individual level - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_mean(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: fail_if__dtype_not_int(group_id, agg_func="grouped_mean") fail_if__dtype_not_float(column, agg_func="grouped_mean") - out_on_hh = npg.aggregate(group_id, column, func="mean", fill_value=0) + out_grouped = npg.aggregate(group_id, column, func="mean", fill_value=0) # Expand to individual level - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_max(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: @@ -44,18 +44,18 @@ def grouped_max(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray dtype = column.dtype float_col = column.astype("datetime64[D]").astype(int) - out_on_hh_float = npg.aggregate(group_id, float_col, func="max") + out_grouped_float = npg.aggregate(group_id, float_col, func="max") - out_on_hh = out_on_hh_float.astype("datetime64[D]").astype(dtype) + out_grouped = out_grouped_float.astype("datetime64[D]").astype(dtype) # Expand to individual level - out = out_on_hh[group_id] + out = out_grouped[group_id] else: - out_on_hh = npg.aggregate(group_id, column, func="max") + out_grouped = npg.aggregate(group_id, column, func="max") # Expand to individual level - out = out_on_hh[group_id] + out = out_grouped[group_id] return out @@ -72,18 +72,18 @@ def grouped_min(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray dtype = column.dtype float_col = column.astype("datetime64[D]").astype(int) - out_on_hh_float = npg.aggregate(group_id, float_col, func="min") + out_grouped_float = npg.aggregate(group_id, float_col, func="min") - out_on_hh = out_on_hh_float.astype("datetime64[D]").astype(dtype) + out_grouped = 
out_grouped_float.astype("datetime64[D]").astype(dtype) # Expand to individual level - out = out_on_hh[group_id] + out = out_grouped[group_id] else: - out_on_hh = npg.aggregate(group_id, column, func="min") + out_grouped = npg.aggregate(group_id, column, func="min") # Expand to individual level - out = out_on_hh[group_id] + out = out_grouped[group_id] return out @@ -91,20 +91,20 @@ def grouped_any(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray fail_if__dtype_not_int(group_id, agg_func="grouped_any") fail_if__dtype_not_boolean_or_int(column, agg_func="grouped_any") - out_on_hh = npg.aggregate(group_id, column, func="any", fill_value=0) + out_grouped = npg.aggregate(group_id, column, func="any", fill_value=0) # Expand to individual level - return out_on_hh[group_id] + return out_grouped[group_id] def grouped_all(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: fail_if__dtype_not_int(group_id, agg_func="grouped_all") fail_if__dtype_not_boolean_or_int(column, agg_func="grouped_all") - out_on_hh = npg.aggregate(group_id, column, func="all", fill_value=0) + out_grouped = npg.aggregate(group_id, column, func="all", fill_value=0) # Expand to individual level - return out_on_hh[group_id] + return out_grouped[group_id] def count_by_p_id( From e3d933d94112a666733f3a256e5a2adaf66cffb8 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 13 Jun 2025 12:25:25 +0200 Subject: [PATCH 19/25] Got rid of remaining numpy.ndarray type hints. 
--- pixi.lock | 4 +- pyproject.toml | 9 - src/_gettsim/ids.py | 4 +- src/ttsim/stale_code_storage.py | 41 +---- src/ttsim/tt_dag_elements/aggregation.py | 96 +++++----- src/ttsim/tt_dag_elements/aggregation_jax.py | 87 +++++---- .../tt_dag_elements/aggregation_numpy.py | 167 ++++++++++-------- .../tt_dag_elements/piecewise_polynomial.py | 8 +- tests/ttsim/test_failures.py | 20 +-- .../test_aggregation_functions.py | 13 +- .../tt_dag_elements/test_ttsim_objects.py | 13 +- 11 files changed, 218 insertions(+), 244 deletions(-) diff --git a/pixi.lock b/pixi.lock index e316a17b6..6283f1e80 100644 --- a/pixi.lock +++ b/pixi.lock @@ -6622,8 +6622,8 @@ packages: timestamp: 1694400856979 - pypi: ./ name: gettsim - version: 0.7.1.dev452+gd0a147b6.d20250613 - sha256: b92f9d48c586dde9df09c5f901d6c28be6eb01fb808b2a79b2b3065bc3f98f99 + version: 0.7.1.dev454+g1e8f5582.d20250613 + sha256: 1811ce4f8ffd372f6c4510a2c2d43eb3499ed35c31f584814c33d33774ab0b5f requires_dist: - ipywidgets - networkx diff --git a/pyproject.toml b/pyproject.toml index 7016c6854..161234343 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -314,15 +314,6 @@ disable_error_code = [ "no-untyped-def", # All tests return None, don't clutter source code. ] -[[tool.mypy.overrides]] -module = [ - "tests.ttsim.test_failures", -] -disable_error_code = [ - "misc", # Happens when constructing param dictionaries on the fly. -] - - [[tool.mypy.overrides]] module = [ "src.ttsim.tt_dag_elements.aggregation_numpy", diff --git a/src/_gettsim/ids.py b/src/_gettsim/ids.py index 9b8a356f4..df5a69a9e 100644 --- a/src/_gettsim/ids.py +++ b/src/_gettsim/ids.py @@ -9,8 +9,6 @@ if TYPE_CHECKING: from types import ModuleType - import numpy - from ttsim.interface_dag_elements.typing import BoolColumn, IntColumn @@ -139,7 +137,7 @@ def bg_id( arbeitslosengeld_2__eigenbedarf_gedeckt: BoolColumn, alter: IntColumn, xnp: ModuleType, -) -> numpy.ndarray: +) -> IntColumn: """Bedarfsgemeinschaft. 
Relevant unit for Bürgergeld / Arbeitslosengeld 2. Familiengemeinschaft except for children who have enough income to fend for diff --git a/src/ttsim/stale_code_storage.py b/src/ttsim/stale_code_storage.py index dde38a004..bdda345a6 100644 --- a/src/ttsim/stale_code_storage.py +++ b/src/ttsim/stale_code_storage.py @@ -24,7 +24,10 @@ | int | float | bool - | numpy.ndarray + | BoolColumn + | IntColumn + | FloatColumn + | DatetimeColumn | "NestedAnyTTSIMObject", ] NestedAny = Mapping[str, Any | "NestedAnyTTSIMObject"] @@ -144,39 +147,3 @@ def check_series_has_expected_type( out = False return out - - -def fail_if__dtype_not_int( - data: IntColumn, - agg_func: str, - xnp: ModuleType, -) -> None: - """Check if data is of integer type.""" - if not xnp.issubdtype(data.dtype, xnp.integer): - raise TypeError( - f"Data in {agg_func} must be of integer type, but is {data.dtype}." - ) - - -def fail_if__dtype_not_numeric_or_datetime( - data: FloatColumn | IntColumn, # | DatetimeColumn, - agg_func: str, - xnp: ModuleType, -) -> None: - """Check if data is of numeric or datetime type.""" - if not xnp.issubdtype(data.dtype, (xnp.number, xnp.datetime64)): - raise TypeError( - f"Data in {agg_func} must be of numeric or datetime type, but is {data.dtype}." - ) - - -def fail_if__dtype_not_numeric_or_boolean( - data: FloatColumn | IntColumn | BoolColumn, - agg_func: str, - xnp: ModuleType, -) -> None: - """Check if data is of numeric or boolean type.""" - if not xnp.issubdtype(data.dtype, (xnp.number, xnp.bool_)): - raise TypeError( - f"Data in {agg_func} must be of numeric or boolean type, but is {data.dtype}." 
- ) diff --git a/src/ttsim/tt_dag_elements/aggregation.py b/src/ttsim/tt_dag_elements/aggregation.py index 9f0ecd184..92f18bcb2 100644 --- a/src/ttsim/tt_dag_elements/aggregation.py +++ b/src/ttsim/tt_dag_elements/aggregation.py @@ -6,7 +6,7 @@ from ttsim.tt_dag_elements import aggregation_jax, aggregation_numpy if TYPE_CHECKING: - import numpy + from ttsim.tt_dag_elements.typing import BoolColumn, FloatColumn, IntColumn class AggType(StrEnum): @@ -26,8 +26,8 @@ class AggType(StrEnum): # The signature of the functions must be the same in both modules, except that all JAX # functions have the additional `num_segments` argument. def grouped_count( - group_id: numpy.ndarray, num_segments: int, backend: Literal["numpy", "jax"] -) -> numpy.ndarray: + group_id: IntColumn, num_segments: int, backend: Literal["numpy", "jax"] +) -> IntColumn: if backend == "numpy": return aggregation_numpy.grouped_count(group_id) else: @@ -35,11 +35,11 @@ def grouped_count( def grouped_sum( - column: numpy.ndarray, - group_id: numpy.ndarray, + column: FloatColumn | IntColumn | BoolColumn, + group_id: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.grouped_sum(column, group_id) else: @@ -47,11 +47,11 @@ def grouped_sum( def grouped_mean( - column: numpy.ndarray, - group_id: numpy.ndarray, + column: FloatColumn | IntColumn | BoolColumn, + group_id: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> FloatColumn: if backend == "numpy": return aggregation_numpy.grouped_mean(column, group_id) else: @@ -59,11 +59,11 @@ def grouped_mean( def grouped_max( - column: numpy.ndarray, - group_id: numpy.ndarray, + column: FloatColumn | IntColumn, + group_id: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.grouped_max(column, group_id) 
else: @@ -71,11 +71,11 @@ def grouped_max( def grouped_min( - column: numpy.ndarray, - group_id: numpy.ndarray, + column: FloatColumn | IntColumn, + group_id: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.grouped_min(column, group_id) else: @@ -83,11 +83,11 @@ def grouped_min( def grouped_any( - column: numpy.ndarray, - group_id: numpy.ndarray, + column: BoolColumn | IntColumn, + group_id: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> BoolColumn: if backend == "numpy": return aggregation_numpy.grouped_any(column, group_id) else: @@ -95,11 +95,11 @@ def grouped_any( def grouped_all( - column: numpy.ndarray, - group_id: numpy.ndarray, + column: BoolColumn | IntColumn, + group_id: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> BoolColumn: if backend == "numpy": return aggregation_numpy.grouped_all(column, group_id) else: @@ -107,11 +107,11 @@ def grouped_all( def count_by_p_id( - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> IntColumn: if backend == "numpy": return aggregation_numpy.count_by_p_id(p_id_to_aggregate_by, p_id_to_store_by) else: @@ -121,12 +121,12 @@ def count_by_p_id( def sum_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, + column: FloatColumn | IntColumn | BoolColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.sum_by_p_id( column, p_id_to_aggregate_by, p_id_to_store_by @@ -138,12 +138,12 @@ def sum_by_p_id( def mean_by_p_id( - column: numpy.ndarray, 
- p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, + column: FloatColumn | IntColumn | BoolColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> FloatColumn: if backend == "numpy": return aggregation_numpy.mean_by_p_id( column, p_id_to_aggregate_by, p_id_to_store_by @@ -155,12 +155,12 @@ def mean_by_p_id( def max_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, + column: FloatColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.max_by_p_id( column, p_id_to_aggregate_by, p_id_to_store_by @@ -172,12 +172,12 @@ def max_by_p_id( def min_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, + column: FloatColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.min_by_p_id( column, p_id_to_aggregate_by, p_id_to_store_by @@ -189,12 +189,12 @@ def min_by_p_id( def any_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, + column: BoolColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> BoolColumn: if backend == "numpy": return aggregation_numpy.any_by_p_id( column, p_id_to_aggregate_by, p_id_to_store_by @@ -206,12 +206,12 @@ def any_by_p_id( def all_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, + column: BoolColumn | IntColumn, + p_id_to_aggregate_by: 
IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, backend: Literal["numpy", "jax"], -) -> numpy.ndarray: +) -> BoolColumn: if backend == "numpy": return aggregation_numpy.all_by_p_id( column, p_id_to_aggregate_by, p_id_to_store_by diff --git a/src/ttsim/tt_dag_elements/aggregation_jax.py b/src/ttsim/tt_dag_elements/aggregation_jax.py index e3e24cd38..1044f488c 100644 --- a/src/ttsim/tt_dag_elements/aggregation_jax.py +++ b/src/ttsim/tt_dag_elements/aggregation_jax.py @@ -9,13 +9,10 @@ pass if TYPE_CHECKING: - try: - import jax.numpy as jnp - except ImportError: - import numpy as jnp # noqa: TC004 + from ttsim.tt_dag_elements.typing import BoolColumn, FloatColumn, IntColumn -def grouped_count(group_id: jnp.ndarray, num_segments: int) -> jnp.ndarray: +def grouped_count(group_id: IntColumn, num_segments: int) -> jnp.ndarray: out_grouped = segment_sum( data=jnp.ones(len(group_id)), segment_ids=group_id, num_segments=num_segments ) @@ -23,8 +20,8 @@ def grouped_count(group_id: jnp.ndarray, num_segments: int) -> jnp.ndarray: def grouped_sum( - column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int -) -> jnp.ndarray: + column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn, num_segments: int +) -> FloatColumn | IntColumn: if column.dtype in ["bool"]: column = column.astype(int) @@ -35,8 +32,10 @@ def grouped_sum( def grouped_mean( - column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int -) -> jnp.ndarray: + column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn, num_segments: int +) -> FloatColumn: + if column.dtype in ["bool"]: + column = column.astype(int) sum_grouped = segment_sum( data=column, segment_ids=group_id, num_segments=num_segments ) @@ -48,8 +47,8 @@ def grouped_mean( def grouped_max( - column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int -) -> jnp.ndarray: + column: FloatColumn | IntColumn, group_id: IntColumn, num_segments: int +) -> FloatColumn | IntColumn: out_grouped = segment_max( data=column, 
segment_ids=group_id, num_segments=num_segments ) @@ -57,8 +56,8 @@ def grouped_max( def grouped_min( - column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int -) -> jnp.ndarray: + column: FloatColumn | IntColumn, group_id: IntColumn, num_segments: int +) -> FloatColumn | IntColumn: out_grouped = segment_min( data=column, segment_ids=group_id, num_segments=num_segments ) @@ -66,8 +65,8 @@ def grouped_min( def grouped_any( - column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int -) -> jnp.ndarray: + column: BoolColumn | IntColumn, group_id: IntColumn, num_segments: int +) -> BoolColumn: # Convert to boolean if necessary if jnp.issubdtype(column.dtype, jnp.integer): my_col = column.astype("bool") @@ -81,8 +80,8 @@ def grouped_any( def grouped_all( - column: jnp.ndarray, group_id: jnp.ndarray, num_segments: int -) -> jnp.ndarray: + column: BoolColumn | IntColumn, group_id: IntColumn, num_segments: int +) -> BoolColumn: # Convert to boolean if necessary if jnp.issubdtype(column.dtype, jnp.integer): column = column.astype("bool") @@ -94,19 +93,19 @@ def grouped_all( def count_by_p_id( - p_id_to_aggregate_by: jnp.ndarray, - p_id_to_store_by: jnp.ndarray, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, -) -> jnp.ndarray: +) -> IntColumn: raise NotImplementedError def sum_by_p_id( - column: jnp.ndarray, - p_id_to_aggregate_by: jnp.ndarray, - p_id_to_store_by: jnp.ndarray, + column: FloatColumn | IntColumn | BoolColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, # noqa: ARG001 -) -> jnp.ndarray: +) -> FloatColumn | IntColumn: if column.dtype == bool: column = column.astype(int) @@ -131,45 +130,45 @@ def sum_by_p_id( def mean_by_p_id( - column: jnp.ndarray, - p_id_to_aggregate_by: jnp.ndarray, - p_id_to_store_by: jnp.ndarray, + column: FloatColumn | IntColumn | BoolColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, -) -> jnp.ndarray: +) -> 
FloatColumn: raise NotImplementedError def max_by_p_id( - column: jnp.ndarray, - p_id_to_aggregate_by: jnp.ndarray, - p_id_to_store_by: jnp.ndarray, + column: FloatColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, -) -> jnp.ndarray: +) -> FloatColumn | IntColumn: raise NotImplementedError def min_by_p_id( - column: jnp.ndarray, - p_id_to_aggregate_by: jnp.ndarray, - p_id_to_store_by: jnp.ndarray, + column: FloatColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, -) -> jnp.ndarray: +) -> FloatColumn | IntColumn: raise NotImplementedError def any_by_p_id( - column: jnp.ndarray, - p_id_to_aggregate_by: jnp.ndarray, - p_id_to_store_by: jnp.ndarray, + column: BoolColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, -) -> jnp.ndarray: +) -> BoolColumn: raise NotImplementedError def all_by_p_id( - column: jnp.ndarray, - p_id_to_aggregate_by: jnp.ndarray, - p_id_to_store_by: jnp.ndarray, + column: BoolColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, num_segments: int, -) -> jnp.ndarray: +) -> BoolColumn: raise NotImplementedError diff --git a/src/ttsim/tt_dag_elements/aggregation_numpy.py b/src/ttsim/tt_dag_elements/aggregation_numpy.py index 0de52732d..029ad4406 100644 --- a/src/ttsim/tt_dag_elements/aggregation_numpy.py +++ b/src/ttsim/tt_dag_elements/aggregation_numpy.py @@ -1,11 +1,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy import numpy_groupies as npg +if TYPE_CHECKING: + from ttsim.tt_dag_elements.typing import BoolColumn, FloatColumn, IntColumn + -def grouped_count(group_id: numpy.ndarray) -> numpy.ndarray: - fail_if__dtype_not_int(group_id, agg_func="grouped_count") +def grouped_count(group_id: IntColumn) -> IntColumn: + fail_if_dtype_not_int(group_id, agg_func="grouped_count") out_grouped = npg.aggregate( group_id, 
numpy.ones(len(group_id), dtype=int), func="sum", fill_value=0 ) @@ -13,9 +18,11 @@ def grouped_count(group_id: numpy.ndarray) -> numpy.ndarray: return out_grouped[group_id] -def grouped_sum(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: - fail_if__dtype_not_int(group_id, agg_func="grouped_sum") - fail_if__dtype_not_numeric_or_boolean(column, agg_func="grouped_sum") +def grouped_sum( + column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn +) -> FloatColumn | IntColumn: + fail_if_dtype_not_numeric_or_boolean(column, agg_func="grouped_sum") + fail_if_dtype_not_int(group_id, agg_func="grouped_sum") if column.dtype == bool: column = column.astype(int) out_grouped = npg.aggregate(group_id, column, func="sum", fill_value=0) @@ -24,9 +31,11 @@ def grouped_sum(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray return out_grouped[group_id] -def grouped_mean(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: - fail_if__dtype_not_int(group_id, agg_func="grouped_mean") - fail_if__dtype_not_float(column, agg_func="grouped_mean") +def grouped_mean( + column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn +) -> FloatColumn: + fail_if_dtype_not_numeric_or_boolean(column, agg_func="grouped_mean") + fail_if_dtype_not_int(group_id, agg_func="grouped_mean") out_grouped = npg.aggregate(group_id, column, func="mean", fill_value=0) @@ -34,9 +43,11 @@ def grouped_mean(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarra return out_grouped[group_id] -def grouped_max(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: - fail_if__dtype_not_int(group_id, agg_func="grouped_max") - fail_if__dtype_not_numeric_or_datetime(column, agg_func="grouped_max") +def grouped_max( + column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn +) -> FloatColumn | IntColumn: + fail_if_dtype_not_numeric_or_datetime(column, agg_func="grouped_max") + fail_if_dtype_not_int(group_id, agg_func="grouped_max") # For 
datetime, convert to integer (as numpy_groupies can handle datetime only if # numba is installed) @@ -59,9 +70,11 @@ def grouped_max(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray return out -def grouped_min(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: - fail_if__dtype_not_int(group_id, agg_func="grouped_min") - fail_if__dtype_not_numeric_or_datetime(column, agg_func="grouped_min") +def grouped_min( + column: FloatColumn | IntColumn, group_id: IntColumn +) -> FloatColumn | IntColumn: + fail_if_dtype_not_numeric_or_datetime(column, agg_func="grouped_min") + fail_if_dtype_not_int(group_id, agg_func="grouped_min") # For datetime, convert to integer (as numpy_groupies can handle datetime only if # numba is installed) @@ -87,9 +100,9 @@ def grouped_min(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray return out -def grouped_any(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: - fail_if__dtype_not_int(group_id, agg_func="grouped_any") - fail_if__dtype_not_boolean_or_int(column, agg_func="grouped_any") +def grouped_any(column: BoolColumn | IntColumn, group_id: IntColumn) -> BoolColumn: + fail_if_dtype_not_boolean_or_int(column, agg_func="grouped_any") + fail_if_dtype_not_int(group_id, agg_func="grouped_any") out_grouped = npg.aggregate(group_id, column, func="any", fill_value=0) @@ -97,9 +110,9 @@ def grouped_any(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray return out_grouped[group_id] -def grouped_all(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray: - fail_if__dtype_not_int(group_id, agg_func="grouped_all") - fail_if__dtype_not_boolean_or_int(column, agg_func="grouped_all") +def grouped_all(column: BoolColumn | IntColumn, group_id: IntColumn) -> BoolColumn: + fail_if_dtype_not_boolean_or_int(column, agg_func="grouped_all") + fail_if_dtype_not_int(group_id, agg_func="grouped_all") out_grouped = npg.aggregate(group_id, column, func="all", fill_value=0) @@ -108,22 
+121,22 @@ def grouped_all(column: numpy.ndarray, group_id: numpy.ndarray) -> numpy.ndarray def count_by_p_id( - p_id_to_aggregate_by: numpy.ndarray, p_id_to_store_by: numpy.ndarray -) -> numpy.ndarray: - fail_if__dtype_not_int(p_id_to_aggregate_by, agg_func="count_by_p_id") - fail_if__dtype_not_int(p_id_to_store_by, agg_func="count_by_p_id") + p_id_to_aggregate_by: IntColumn, p_id_to_store_by: IntColumn +) -> IntColumn: + fail_if_dtype_not_int(p_id_to_aggregate_by, agg_func="count_by_p_id") + fail_if_dtype_not_int(p_id_to_store_by, agg_func="count_by_p_id") raise NotImplementedError def sum_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, -) -> numpy.ndarray: - fail_if__dtype_not_int(p_id_to_aggregate_by, agg_func="sum_by_p_id") - fail_if__dtype_not_int(p_id_to_store_by, agg_func="sum_by_p_id") - fail_if__dtype_not_numeric_or_boolean(column, agg_func="sum_by_p_id") + column: FloatColumn | IntColumn | BoolColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, +) -> FloatColumn | IntColumn: + fail_if_dtype_not_numeric_or_boolean(column, agg_func="sum_by_p_id") + fail_if_dtype_not_int(p_id_to_aggregate_by, agg_func="sum_by_p_id") + fail_if_dtype_not_int(p_id_to_store_by, agg_func="sum_by_p_id") if column.dtype in ["bool"]: column = column.astype(int) @@ -138,61 +151,63 @@ def sum_by_p_id( def mean_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, -) -> numpy.ndarray: - fail_if__dtype_not_int(p_id_to_aggregate_by, agg_func="mean_by_p_id") - fail_if__dtype_not_int(p_id_to_store_by, agg_func="mean_by_p_id") - fail_if__dtype_not_float(column, agg_func="mean_by_p_id") + column: FloatColumn | IntColumn | BoolColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, +) -> FloatColumn: + fail_if_dtype_not_numeric_or_boolean(column, agg_func="mean_by_p_id") + fail_if_dtype_not_int(p_id_to_aggregate_by, agg_func="mean_by_p_id") 
+ fail_if_dtype_not_int(p_id_to_store_by, agg_func="mean_by_p_id") raise NotImplementedError def max_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, -) -> numpy.ndarray: - fail_if__dtype_not_int(p_id_to_aggregate_by, agg_func="max_by_p_id") - fail_if__dtype_not_int(p_id_to_store_by, agg_func="max_by_p_id") - fail_if__dtype_not_numeric_or_datetime(column, agg_func="max_by_p_id") + column: FloatColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, +) -> FloatColumn | IntColumn: + fail_if_dtype_not_numeric_or_datetime(column, agg_func="max_by_p_id") + fail_if_dtype_not_int(p_id_to_aggregate_by, agg_func="max_by_p_id") + fail_if_dtype_not_int(p_id_to_store_by, agg_func="max_by_p_id") raise NotImplementedError def min_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, -) -> numpy.ndarray: - fail_if__dtype_not_int(p_id_to_aggregate_by, agg_func="min_by_p_id") - fail_if__dtype_not_int(p_id_to_store_by, agg_func="min_by_p_id") - fail_if__dtype_not_numeric_or_datetime(column, agg_func="min_by_p_id") + column: FloatColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, +) -> FloatColumn | IntColumn: + fail_if_dtype_not_numeric_or_datetime(column, agg_func="min_by_p_id") + fail_if_dtype_not_int(p_id_to_aggregate_by, agg_func="min_by_p_id") + fail_if_dtype_not_int(p_id_to_store_by, agg_func="min_by_p_id") raise NotImplementedError def any_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, -) -> numpy.ndarray: - fail_if__dtype_not_int(p_id_to_aggregate_by, agg_func="any_by_p_id") - fail_if__dtype_not_int(p_id_to_store_by, agg_func="any_by_p_id") - fail_if__dtype_not_boolean_or_int(column, agg_func="any_by_p_id") + column: BoolColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, +) -> BoolColumn: + 
fail_if_dtype_not_boolean_or_int(column, agg_func="any_by_p_id") + fail_if_dtype_not_int(p_id_to_aggregate_by, agg_func="any_by_p_id") + fail_if_dtype_not_int(p_id_to_store_by, agg_func="any_by_p_id") raise NotImplementedError def all_by_p_id( - column: numpy.ndarray, - p_id_to_aggregate_by: numpy.ndarray, - p_id_to_store_by: numpy.ndarray, -) -> numpy.ndarray: - fail_if__dtype_not_int(p_id_to_store_by, agg_func="all_by_p_id") - fail_if__dtype_not_int(p_id_to_aggregate_by, agg_func="all_by_p_id") - fail_if__dtype_not_boolean_or_int(column, agg_func="all_by_p_id") + column: BoolColumn | IntColumn, + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, +) -> BoolColumn: + fail_if_dtype_not_boolean_or_int(column, agg_func="all_by_p_id") + fail_if_dtype_not_int(p_id_to_store_by, agg_func="all_by_p_id") + fail_if_dtype_not_int(p_id_to_aggregate_by, agg_func="all_by_p_id") raise NotImplementedError -def fail_if__dtype_not_numeric(column: numpy.ndarray, agg_func: str) -> None: +def fail_if_dtype_not_numeric( + column: FloatColumn | IntColumn | BoolColumn, agg_func: str +) -> None: if not numpy.issubdtype(column.dtype, numpy.number): raise TypeError( f"Aggregation function {agg_func} was applied to a column " @@ -200,7 +215,9 @@ def fail_if__dtype_not_numeric(column: numpy.ndarray, agg_func: str) -> None: ) -def fail_if__dtype_not_float(column: numpy.ndarray, agg_func: str) -> None: +def fail_if_dtype_not_float( + column: FloatColumn | IntColumn | BoolColumn, agg_func: str +) -> None: if not numpy.issubdtype(column.dtype, numpy.floating): raise TypeError( f"Aggregation function {agg_func} was applied to a column " @@ -208,7 +225,7 @@ def fail_if__dtype_not_float(column: numpy.ndarray, agg_func: str) -> None: ) -def fail_if__dtype_not_int(p_id_to_aggregate_by: numpy.ndarray, agg_func: str) -> None: +def fail_if_dtype_not_int(p_id_to_aggregate_by: IntColumn, agg_func: str) -> None: if not numpy.issubdtype(p_id_to_aggregate_by.dtype, numpy.integer): raise 
TypeError( f"The dtype of id columns must be integer. Aggregation function {agg_func} " @@ -216,7 +233,9 @@ def fail_if__dtype_not_int(p_id_to_aggregate_by: numpy.ndarray, agg_func: str) - ) -def fail_if__dtype_not_numeric_or_boolean(column: numpy.ndarray, agg_func: str) -> None: +def fail_if_dtype_not_numeric_or_boolean( + column: FloatColumn | IntColumn | BoolColumn, agg_func: str +) -> None: if not (numpy.issubdtype(column.dtype, numpy.number) or column.dtype == "bool"): raise TypeError( f"Aggregation function {agg_func} was applied to a column with dtype " @@ -224,8 +243,8 @@ def fail_if__dtype_not_numeric_or_boolean(column: numpy.ndarray, agg_func: str) ) -def fail_if__dtype_not_numeric_or_datetime( - column: numpy.ndarray, agg_func: str +def fail_if_dtype_not_numeric_or_datetime( + column: FloatColumn | IntColumn | BoolColumn, agg_func: str ) -> None: if not ( numpy.issubdtype(column.dtype, numpy.number) @@ -237,7 +256,9 @@ def fail_if__dtype_not_numeric_or_datetime( ) -def fail_if__dtype_not_boolean_or_int(column: numpy.ndarray, agg_func: str) -> None: +def fail_if_dtype_not_boolean_or_int( + column: BoolColumn | IntColumn, agg_func: str +) -> None: if not ( numpy.issubdtype(column.dtype, numpy.integer) or numpy.issubdtype(column.dtype, numpy.bool_) diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index 8667978cc..7fa9d9782 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -158,7 +158,11 @@ def check_and_get_thresholds( leaf_name: str, parameter_dict: dict[int, dict[str, float]], xnp: ModuleType, -) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: +) -> tuple[ + Float[Array, " n_segments"], + Float[Array, " n_segments"], + Float[Array, " n_segments"], +]: """Check and transfer raw threshold data. 
Transfer and check raw threshold data, which needs to be specified in a @@ -237,7 +241,7 @@ def _check_and_get_rates( func_type: FUNC_TYPES, parameter_dict: dict[int, dict[str, float]], xnp: ModuleType, -) -> numpy.ndarray: +) -> Float[Array, " n_segments"]: """Check and transfer raw rates data. Transfer and check raw rates data, which needs to be specified in a diff --git a/tests/ttsim/test_failures.py b/tests/ttsim/test_failures.py index ee5f27751..ba8bc6fb4 100644 --- a/tests/ttsim/test_failures.py +++ b/tests/ttsim/test_failures.py @@ -197,7 +197,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): )(identity), }, { - ("c", "g"): { + ("c", "g"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 1}, } @@ -218,7 +218,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): )(identity), }, { - ("x", "c", "h"): { + ("x", "c", "h"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 2}, } @@ -237,7 +237,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): )(identity), }, { - ("x", "c", "g"): { + ("x", "c", "g"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 3}, } @@ -258,7 +258,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): )(identity), }, { - ("z", "a", "f"): { + ("z", "a", "f"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 4}, } @@ -268,11 +268,11 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): ( {}, { - ("x", "a", "f"): { + ("x", "a", "f"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 5}, }, - ("x", "b", "g"): { + ("x", "b", "g"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 6}, }, @@ -456,7 +456,7 @@ def test_fail_if_active_periods_overlap_passes( )(identity), }, { - ("c", "f"): { + ("c", "f"): { # type: ignore[misc] 
**_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 1}, } @@ -477,7 +477,7 @@ def test_fail_if_active_periods_overlap_passes( )(identity), }, { - ("x", "a", "f"): { + ("x", "a", "f"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 2}, } @@ -487,11 +487,11 @@ def test_fail_if_active_periods_overlap_passes( ( {}, { - ("x", "a", "f"): { + ("x", "a", "f"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 3}, }, - ("x", "b", "f"): { + ("x", "b", "f"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 4}, }, diff --git a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py index 76edbd0a0..7b4272e1e 100644 --- a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py +++ b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py @@ -68,10 +68,11 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): "expected_res_max": numpy.array([1.0, 1.0, 1.0, 1.0, 1.0]), "expected_res_min": numpy.array([1.0, 1.0, 1.0, 1.0, 1.0]), }, - "basic_case": { + "int_column": { "column_to_aggregate": numpy.array([0, 1, 2, 3, 4]), "group_id": numpy.array([0, 0, 1, 1, 1]), "expected_res_sum": numpy.array([1, 1, 9, 9, 9]), + "expected_res_mean": numpy.array([0.5, 0.5, 3, 3, 3]), "expected_res_max": numpy.array([1, 1, 4, 4, 4]), "expected_res_min": numpy.array([0, 0, 2, 2, 2]), "expected_res_any": numpy.array([True, True, True, True, True]), @@ -103,12 +104,13 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): "expected_res_max": numpy.array([3.0, 1.0, 3.0, 3.0, 4.0]), "expected_res_min": numpy.array([0.0, 1.0, 0.0, 0.0, 4.0]), }, - "basic_case_bool": { + "bool_column": { "column_to_aggregate": numpy.array([True, False, True, False, False]), "group_id": numpy.array([0, 0, 1, 1, 1]), "expected_res_any": numpy.array([True, True, True, True, True]), "expected_res_all": 
numpy.array([False, False, False, False, False]), "expected_res_sum": numpy.array([1, 1, 1, 1, 1]), + "expected_res_mean": numpy.array([0.5, 0.5, 1 / 3, 1 / 3, 1 / 3]), }, "group_id_unsorted_bool": { "column_to_aggregate": numpy.array([True, False, True, True, True]), @@ -144,7 +146,6 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): "dtype_boolean": { "column_to_aggregate": numpy.array([True, True, True, False, False]), "group_id": numpy.array([0, 0, 1, 1, 1]), - "error_mean": TypeError, "error_max": TypeError, "error_min": TypeError, "exception_match": "grouped_", @@ -167,12 +168,6 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): "error_all": TypeError, "exception_match": "grouped_", }, - "dtype_integer": { - "column_to_aggregate": numpy.array([1, 2, 3, 4, 5]), - "group_id": numpy.array([0, 0, 1, 1, 1]), - "error_mean": TypeError, - "exception_match": "grouped_", - }, "float_group_id_bool": { "column_to_aggregate": numpy.array([True, True, True, False, False]), "group_id": numpy.array([0, 0, 3.5, 3.5, 3.5]), diff --git a/tests/ttsim/tt_dag_elements/test_ttsim_objects.py b/tests/ttsim/tt_dag_elements/test_ttsim_objects.py index c39a370c9..a4ba38a0e 100644 --- a/tests/ttsim/tt_dag_elements/test_ttsim_objects.py +++ b/tests/ttsim/tt_dag_elements/test_ttsim_objects.py @@ -2,7 +2,6 @@ import inspect -import numpy import pytest from ttsim.tt_dag_elements import ( @@ -229,7 +228,7 @@ def aggregate_by_p_id_count(p_id, p_id_specifier): @agg_by_p_id_function(agg_type=AggType.SUM) -def aggregate_by_p_id_sum(p_id, p_id_specifier, source): +def aggregate_by_p_id_sum(p_id, p_id_specifier, column): pass @@ -241,7 +240,7 @@ def aggregate_by_p_id_sum(p_id, p_id_specifier, source): ), [ (aggregate_by_p_id_count, "p_id", None), - (aggregate_by_p_id_sum, "p_id", "source"), + (aggregate_by_p_id_sum, "p_id", "column"), ], ) def test_agg_by_p_id_function_type(function, expected_foreign_p_id, expected_other_arg): @@ -286,11 +285,11 @@ def 
aggregate_by_p_id_multiple_other_p_ids_present( pass -def test_agg_by_p_id_sum_with_all_missing_p_ids(backend): +def test_agg_by_p_id_sum_with_all_missing_p_ids(backend, xnp): aggregate_by_p_id_sum( - p_id=numpy.array([180]), - p_id_specifier=numpy.array([-1]), - source=numpy.array([False]), + p_id=xnp.array([180]), + p_id_specifier=xnp.array([-1]), + column=xnp.array([0]), num_segments=1, backend=backend, ) From f13ccd5f8c3b4501d36c8204fb347c3f41d24f31 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 13 Jun 2025 14:10:37 +0200 Subject: [PATCH 20/25] Cut down on ruff exceptions. --- pixi.lock | 4 +- pyproject.toml | 79 +++++++------------ src/_gettsim/arbeitslosengeld_2/einkommen.py | 6 +- .../kindergeld\303\274bertrag.py" | 2 +- .../abz\303\274ge/sonderausgaben.py" | 3 - .../abz\303\274ge/vorsorge.py" | 2 - src/_gettsim/einkommensteuer/einkommen.py | 3 +- ...aus_nichtselbstst\303\244ndiger_arbeit.py" | 1 - .../zu_versteuerndes_einkommen.py | 1 - src/_gettsim/elterngeld/elterngeld.py | 1 - src/_gettsim/erziehungsgeld/erziehungsgeld.py | 2 - .../grundsicherung/im_alter/einkommen.py | 6 +- .../grundsicherung/im_alter/im_alter.py | 2 - src/_gettsim/ids.py | 2 - src/_gettsim/interface.py | 5 +- src/_gettsim/kinderbonus/kinderbonus.py | 1 - src/_gettsim/kindergeld/kindergeld.py | 4 +- src/_gettsim/kinderzuschlag/einkommen.py | 1 - src/_gettsim/kinderzuschlag/kinderzuschlag.py | 7 -- src/_gettsim/lohnsteuer/einkommen.py | 4 - src/_gettsim/lohnsteuer/lohnsteuer.py | 5 -- .../solidarit\303\244tszuschlag.py" | 1 - .../arbeitslosen/arbeitslosengeld.py | 4 +- .../arbeitslosen/beitrag/beitrag.py | 6 +- .../kranken/beitrag/beitrag.py | 1 - .../kranken/beitrag/beitragssatz.py | 9 --- .../kranken/beitrag/einkommen.py | 1 - src/_gettsim/sozialversicherung/midijob.py | 1 - .../pflege/beitrag/beitrag.py | 8 -- .../pflege/beitrag/beitragssatz.py | 1 - .../rente/altersrente/altersgrenzen.py | 6 -- .../rente/altersrente/altersrente.py | 3 - 
.../besonders_langj\303\244hrig.py" | 1 - .../rente/altersrente/entgeltpunkte.py | 1 - .../f\303\274r_frauen/f\303\274r_frauen.py" | 2 - .../rente/altersrente/inputs.py | 3 +- .../regelaltersrente/regelaltersrente.py | 1 - .../wegen_arbeitslosigkeit/inputs.py | 6 +- .../wegen_arbeitslosigkeit.py | 2 - .../rente/beitrag/beitrag.py | 4 - .../erwerbsminderung/erwerbsminderung.py | 6 -- .../rente/grundrente/grundrente.py | 3 - .../sozialversicherung/rente/inputs.py | 9 ++- src/_gettsim/unterhalt/unterhalt.py | 3 +- .../unterhaltsvorschuss.py | 3 +- .../vorrangpr\303\274fungen.py" | 1 - src/_gettsim/wohngeld/einkommen.py | 1 - src/_gettsim/wohngeld/miete.py | 2 - src/_gettsim/wohngeld/voraussetzungen.py | 2 - src/ttsim/interface_dag.py | 1 - .../automatically_added_functions.py | 1 - src/ttsim/interface_dag_elements/backend.py | 1 - .../interface_dag_elements/data_converters.py | 14 ++-- src/ttsim/interface_dag_elements/fail_if.py | 12 +-- .../interface_dag_elements/input_data.py | 6 +- .../interface_node_objects.py | 9 ++- src/ttsim/interface_dag_elements/names.py | 2 - .../policy_environment.py | 1 - .../interface_dag_elements/processed_data.py | 4 +- src/ttsim/interface_dag_elements/results.py | 6 +- src/ttsim/interface_dag_elements/shared.py | 13 +-- .../specialized_environment.py | 4 +- src/ttsim/plot_dag.py | 2 - src/ttsim/testing_utils.py | 4 +- .../column_objects_param_function.py | 7 +- src/ttsim/tt_dag_elements/param_objects.py | 3 +- .../tt_dag_elements/piecewise_polynomial.py | 4 +- src/ttsim/tt_dag_elements/vectorization.py | 11 ++- tests/ttsim/mettsim/group_by_ids.py | 1 - tests/ttsim/test_failures.py | 4 +- tests/ttsim/test_orig_policy_objects.py | 3 +- tests/ttsim/test_policy_environment.py | 10 ++- tests/ttsim/test_shared.py | 2 +- tests/ttsim/tt_dag_elements/test_rounding.py | 14 ++-- .../tt_dag_elements/test_vectorization.py | 4 +- 75 files changed, 141 insertions(+), 234 deletions(-) diff --git a/pixi.lock b/pixi.lock index 6283f1e80..fa52bddfa 
100644 --- a/pixi.lock +++ b/pixi.lock @@ -6622,8 +6622,8 @@ packages: timestamp: 1694400856979 - pypi: ./ name: gettsim - version: 0.7.1.dev454+g1e8f5582.d20250613 - sha256: 1811ce4f8ffd372f6c4510a2c2d43eb3499ed35c31f584814c33d33774ab0b5f + version: 0.7.1.dev457+g905f28b1.d20250613 + sha256: bdcd4b8079d89df7ef67b55be4ee15400ab06021e05e8dfbe3f8817b130725b5 requires_dist: - ipywidgets - networkx diff --git a/pyproject.toml b/pyproject.toml index 161234343..42e0d354f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -210,16 +210,19 @@ extend-ignore = [ "D416", "D417", "F722", # https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error + "FBT001", # Boolean-typed positional argument in function definition + "ISC001", # Avoid conflicts with ruff-format + "N999", # Allow non-ASCII characters in file names. + "PLC2401", # Allow non-ASCII characters in variable names. + "PLC2403", # Allow non-ASCII function names for imports. + "PLR0913", # Allow too many arguments in function definitions. + "FIX002", # Line contains TODO -- Use stuff from TD area. + "TRY003", # Avoid specifying long messages outside the exception class + "PLR5501", # elif not supported by vectorization converter for Jax + "EM101", # Exception must not use a string literal + "EM102", # Exception must not use an f-string literal # Others. - "D404", # Do not start module docstring with "This". - "RET504", # unnecessary variable assignment before return. - "S101", # raise errors for asserts. - "B905", # strict parameter for zip that was implemented in py310. - - "FBT", # flake8-boolean-trap - "EM", # flake8-errmsg - "ANN401", # flake8-annotate typing.Any - "PD", # pandas-vet + "E731", # do not assign a lambda expression, use a def "RET", # unnecessary elif or else statements after return, raise, continue, ... "S324", # Probable use of insecure hash function. 
@@ -228,46 +231,34 @@ extend-ignore = [ "DTZ001", # use of `datetime.datetime()` without `tzinfo` argument is not allowed "DTZ002", # use of `datetime.datetime.today()` is not allowed "PT012", # `pytest.raises()` block should contain a single simple statement - "PLR5501", # elif not supported by vectorization converter for Jax - "TRY003", # Avoid specifying long messages outside the exception class - "FIX002", # Line contains TODO -- Use stuff from TD area. - "PLC2401", # Allow non-ASCII characters in variable names. - "PLC2403", # Allow non-ASCII function names for imports. - "PLR0913", # Allow too many arguments in function definitions. - "N999", # Allow non-ASCII characters in file names. - "PLR0913", # Too many arguments in function definition. - - # Things we are not sure we want - # ============================== - "SIM102", # Use single if statement instead of nested if statements - "SIM108", # Use ternary operator instead of if-else block - "SIM117", # do not use nested with statements - "BLE001", # Do not catch blind exceptions (even after handling some specific ones) - "PLR2004", # Magic values used in comparison - "PT006", # Allows only lists of tuples in parametrize, even if single argument - # Things ignored during transition phase + + # Ignored during transition phase # ====================================== "D", # docstrings - "ANN", # missing annotations - "C901", # function too complex - "PT011", # pytest raises without match statement + "PLR2004", # Magic values used in comparison "INP001", # implicit namespace packages without init. 
- "E721", # Use `is` and `is not` for type comparisons + "PT006", # Allows only lists of tuples in parametrize, even if single argument + "S101", # use of asserts outside of tests - # Things ignored to avoid conflicts with ruff-format - # ================================================== - "ISC001", ] exclude = [] [tool.ruff.lint.per-file-ignores] +"conftest.py" = ["ANN"] "docs/**/*.ipynb" = ["T201"] -"src/_gettsim/*" = ["E501", "PLR1714", "PLR1716"] # Vectorization can't handle x <= y <= z or x in {x,y} +# Mostly things vectorization can't handle +"src/_gettsim/*" = ["E501", "PLR1714", "PLR1716", "E721", "SIM108"] +# All tests return None and use asserts +"src/_gettsim_tests/**/*.py" = ["ANN", "S101"] "src/ttsim/interface_dag_elements/specialized_environment.py" = ["E501"] "src/ttsim/interface_dag_elements/fail_if.py" = ["E501"] "src/ttsim/interface_dag_elements/typing.py" = ["PGH", "PLR", "SIM114"] -"tests/ttsim/mettsim/*" = ["PLR1714", "PLR1716"] # Vectorization can't handle x <= y <= z or x in {x,y} +# Mostly things vectorization can't handle +"tests/ttsim/mettsim/**/*.py" = ["PLR1714", "PLR1716", "E721", "SIM108"] +# All tests return None and use asserts +"tests/ttsim/**/*.py" = ["ANN", "S101"] +"tests/ttsim/tt_dag_elements/test_vectorization.py" = ["PLR1714", "PLR1716", "E721", "SIM108"] "tests/ttsim/test_failures.py" = ["E501"] # TODO: remove once ported nicely "src/ttsim/stale_code_storage.py" = ["ALL"] @@ -314,14 +305,6 @@ disable_error_code = [ "no-untyped-def", # All tests return None, don't clutter source code. ] -[[tool.mypy.overrides]] -module = [ - "src.ttsim.tt_dag_elements.aggregation_numpy", -] -disable_error_code = [ - "type-arg" # ndarray is not typed further. -] - [[tool.mypy.overrides]] module = [ "src._gettsim_tests.*", @@ -330,14 +313,6 @@ disable_error_code = [ "no-untyped-def", # All tests return None, don't clutter source code. 
] -[[tool.mypy.overrides]] -module = [ - "tests.ttsim.tt_dag_elements.test_vectorization", # doing some numpy mixing; vectorization does not like inline comments. -] -disable_error_code = [ - "assignment" -] - [tool.check-manifest] ignore = ["src/_gettsim/_version.py"] diff --git a/src/_gettsim/arbeitslosengeld_2/einkommen.py b/src/_gettsim/arbeitslosengeld_2/einkommen.py index 54eece24c..3f7df5c77 100644 --- a/src/_gettsim/arbeitslosengeld_2/einkommen.py +++ b/src/_gettsim/arbeitslosengeld_2/einkommen.py @@ -206,7 +206,8 @@ def parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg( xnp: ModuleType, ) -> PiecewisePolynomialParamValue: """Parameter for calculation of income not subject to transfer withdrawal when - children are not in the Bedarfsgemeinschaft.""" + children are not in the Bedarfsgemeinschaft. + """ return get_piecewise_parameters( leaf_name="parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg", func_type="piecewise_linear", @@ -222,7 +223,8 @@ def parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg( xnp: ModuleType, ) -> PiecewisePolynomialParamValue: """Parameter for calculation of income not subject to transfer withdrawal when - children are in the Bedarfsgemeinschaft.""" + children are in the Bedarfsgemeinschaft. + """ updated_parameters: dict[int, dict[str, float]] = upsert_tree( base=raw_parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg, to_upsert=raw_parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg, diff --git "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" index ae0aeb052..6eaaede17 100644 --- "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" +++ "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" @@ -132,7 +132,7 @@ def in_anderer_bg_als_kindergeldempfänger( Kindergeldempfänger of that person. 
""" # Create a dictionary to map p_id to bg_id - p_id_to_bg_id = dict(zip(p_id, bg_id)) + p_id_to_bg_id = dict(zip(p_id, bg_id, strict=False)) # Map each kindergeld__p_id_empfänger to its corresponding bg_id empf_bg_id = [ diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/sonderausgaben.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/sonderausgaben.py" index 04a0fddee..94e78c7db 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/sonderausgaben.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/sonderausgaben.py" @@ -28,7 +28,6 @@ def sonderausgaben_y_sn_nur_pauschale( """ - return sonderausgabenpauschbetrag * einkommensteuer__anzahl_personen_sn @@ -44,7 +43,6 @@ def sonderausgaben_y_sn_mit_kinderbetreuung( details here https://www.buzer.de/s1.htm?a=10&g=estg. """ - return max( absetzbare_kinderbetreuungskosten_y_sn, sonderausgabenpauschbetrag * einkommensteuer__anzahl_personen_sn, @@ -76,7 +74,6 @@ def absetzbare_kinderbetreuungskosten_y_sn( """ - return ( gedeckelte_kinderbetreuungskosten_y_sn * parameter_absetzbare_kinderbetreuungskosten["anteil"] diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" index 81d18c5d7..5d9601dad 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/vorsorge.py" @@ -38,7 +38,6 @@ def vorsorgeaufwendungen_y_sn_ab_2005_bis_2009( Günstigerprüfung against the regime until 2004. """ - return max( vorsorgeaufwendungen_regime_bis_2004_y_sn, vorsorgeaufwendungen_globale_kappung_y_sn, @@ -60,7 +59,6 @@ def vorsorgeaufwendungen_y_sn_ab_2010_bis_2019( Günstigerprüfung against the regime until 2004. 
""" - return max( vorsorgeaufwendungen_regime_bis_2004_y_sn, vorsorgeaufwendungen_keine_kappung_krankenversicherung_y_sn, diff --git a/src/_gettsim/einkommensteuer/einkommen.py b/src/_gettsim/einkommensteuer/einkommen.py index ee094f2b5..e6e746863 100644 --- a/src/_gettsim/einkommensteuer/einkommen.py +++ b/src/_gettsim/einkommensteuer/einkommen.py @@ -1,7 +1,8 @@ """Einkommen. Einkommen are Einkünfte minus Sonderausgaben, Vorsorgeaufwendungen, außergewöhnliche -Belastungen and sonstige Abzüge.""" +Belastungen and sonstige Abzüge. +""" from __future__ import annotations diff --git "a/src/_gettsim/einkommensteuer/eink\303\274nfte/aus_nichtselbstst\303\244ndiger_arbeit/aus_nichtselbstst\303\244ndiger_arbeit.py" "b/src/_gettsim/einkommensteuer/eink\303\274nfte/aus_nichtselbstst\303\244ndiger_arbeit/aus_nichtselbstst\303\244ndiger_arbeit.py" index 630ee4f7a..6bcd1de40 100644 --- "a/src/_gettsim/einkommensteuer/eink\303\274nfte/aus_nichtselbstst\303\244ndiger_arbeit/aus_nichtselbstst\303\244ndiger_arbeit.py" +++ "b/src/_gettsim/einkommensteuer/eink\303\274nfte/aus_nichtselbstst\303\244ndiger_arbeit/aus_nichtselbstst\303\244ndiger_arbeit.py" @@ -24,5 +24,4 @@ def betrag_y( @policy_function() def betrag_ohne_minijob_y(bruttolohn_y: float, werbungskostenpauschale: float) -> float: """Take gross wage and deduct Werbungskostenpauschale.""" - return max(bruttolohn_y - werbungskostenpauschale, 0.0) diff --git a/src/_gettsim/einkommensteuer/zu_versteuerndes_einkommen.py b/src/_gettsim/einkommensteuer/zu_versteuerndes_einkommen.py index 6b6c8e0bf..ee42b6a29 100644 --- a/src/_gettsim/einkommensteuer/zu_versteuerndes_einkommen.py +++ b/src/_gettsim/einkommensteuer/zu_versteuerndes_einkommen.py @@ -81,6 +81,5 @@ def zu_versteuerndes_einkommen_mit_kinderfreibetrag_y_sn( kinderfreibetrag_y_sn: float, ) -> float: """Calculate taxable income with child allowance on Steuernummer level.""" - out = gesamteinkommen_y - kinderfreibetrag_y_sn return max(out, 0.0) diff --git 
a/src/_gettsim/elterngeld/elterngeld.py b/src/_gettsim/elterngeld/elterngeld.py index cee38fda4..a63d3032d 100644 --- a/src/_gettsim/elterngeld/elterngeld.py +++ b/src/_gettsim/elterngeld/elterngeld.py @@ -220,7 +220,6 @@ def lohnersatzanteil( decreases above the second step until prozent_minimum. """ - # Higher replacement rate if considered income is below a threshold if ( nettoeinkommen_vorjahr_m diff --git a/src/_gettsim/erziehungsgeld/erziehungsgeld.py b/src/_gettsim/erziehungsgeld/erziehungsgeld.py index c849291c4..bd94b62e3 100644 --- a/src/_gettsim/erziehungsgeld/erziehungsgeld.py +++ b/src/_gettsim/erziehungsgeld/erziehungsgeld.py @@ -241,7 +241,6 @@ def anzurechnendes_einkommen_y( There is special rule for "Beamte, Soldaten und Richter" which is not implemented yet. """ - if kind_grundsätzlich_anspruchsberechtigt: out = ( einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_vorjahr_y_fg @@ -264,7 +263,6 @@ def einkommensgrenze_y( Legal reference: Bundesgesetzblatt Jahrgang 2004 Teil I Nr. 6 (pp.208) """ - out = ( einkommensgrenze_ohne_geschwisterbonus + (arbeitslosengeld_2__anzahl_kinder_fg - 1) * aufschlag_einkommen diff --git a/src/_gettsim/grundsicherung/im_alter/einkommen.py b/src/_gettsim/grundsicherung/im_alter/einkommen.py index c754433b0..d9103bdfc 100644 --- a/src/_gettsim/grundsicherung/im_alter/einkommen.py +++ b/src/_gettsim/grundsicherung/im_alter/einkommen.py @@ -30,7 +30,6 @@ def einkommen_m( """Calculate individual income considered in the calculation of Grundsicherung im Alter. """ - # Income total_income = ( erwerbseinkommen_m @@ -64,8 +63,8 @@ def erwerbseinkommen_m( Legal reference: § 82 SGB XII Abs. 3 - Notes: - + Notes + ----- - Freibeträge for income are currently not considered - Start date is 2011 because of the reference to regelbedarfsstufen, which was introduced in 2011. 
@@ -162,7 +161,6 @@ def gesetzliche_rente_m_ab_2021( Starting from 2021: If eligible for Grundrente, can deduct 100€ completely and 30% of private pension above 100 (but no more than 1/2 of regelbedarf) """ - angerechnete_rente = piecewise_polynomial( x=sozialversicherung__rente__altersrente__betrag_m, parameters=anrechnungsfreier_anteil_gesetzliche_rente, diff --git a/src/_gettsim/grundsicherung/im_alter/im_alter.py b/src/_gettsim/grundsicherung/im_alter/im_alter.py index 0ca7de5a0..c2809af2a 100644 --- a/src/_gettsim/grundsicherung/im_alter/im_alter.py +++ b/src/_gettsim/grundsicherung/im_alter/im_alter.py @@ -32,7 +32,6 @@ def betrag_m_eg( # ToDo: currently not implemented for retirees. """ - # TODO(@ChristianZimpelmann): Treatment of Bedarfsgemeinschaften with both retirees # and unemployed job seekers probably incorrect # https://github.com/iza-institute-of-labor-economics/gettsim/issues/703 @@ -80,7 +79,6 @@ def mehrbedarf_schwerbehinderung_g_m( grundsicherung__regelbedarfsstufen: Regelbedarfsstufen, ) -> float: """Calculate additional allowance for individuals with disabled person's pass G.""" - mehrbedarf_single = ( grundsicherung__regelbedarfsstufen.rbs_1 ) * mehrbedarf_bei_schwerbehinderungsgrad_g diff --git a/src/_gettsim/ids.py b/src/_gettsim/ids.py index df5a69a9e..c3d2da17a 100644 --- a/src/_gettsim/ids.py +++ b/src/_gettsim/ids.py @@ -115,7 +115,6 @@ def _assign_parents_fg_id( xnp: ModuleType, ) -> IntColumn: """Return the fg_id of the child's parents.""" - # TODO(@MImmesberger): Remove hard-coded number # https://github.com/iza-institute-of-labor-economics/gettsim/issues/668 @@ -210,7 +209,6 @@ def sn_id( xnp: ModuleType, ) -> IntColumn: """Steuernummer. 
Spouses filing taxes jointly or individuals.""" - n = xnp.max(p_id) + 1 p_id_ehepartner_or_own_p_id = xnp.where( diff --git a/src/_gettsim/interface.py b/src/_gettsim/interface.py index b555643ca..cf9add064 100644 --- a/src/_gettsim/interface.py +++ b/src/_gettsim/interface.py @@ -40,10 +40,11 @@ def oss( A tree that has the desired targets as the path (sequence of keys) and maps them to the data columns the user would like to have. - Returns: + Returns + ------- A DataFrame with the results. - Examples: + Examples -------- >>> inputs_df = pd.DataFrame( ... { diff --git a/src/_gettsim/kinderbonus/kinderbonus.py b/src/_gettsim/kinderbonus/kinderbonus.py index b65bd4362..b09b986a0 100644 --- a/src/_gettsim/kinderbonus/kinderbonus.py +++ b/src/_gettsim/kinderbonus/kinderbonus.py @@ -12,7 +12,6 @@ def betrag_y(kindergeld__betrag_y: float, satz: float) -> float: (one-time payment, non-allowable against transfer payments) """ - if kindergeld__betrag_y > 0: out = satz else: diff --git a/src/_gettsim/kindergeld/kindergeld.py b/src/_gettsim/kindergeld/kindergeld.py index ca51e7392..bb9e1589e 100644 --- a/src/_gettsim/kindergeld/kindergeld.py +++ b/src/_gettsim/kindergeld/kindergeld.py @@ -38,7 +38,6 @@ def betrag_ohne_staffelung_m( of children. """ - return satz * anzahl_ansprüche @@ -145,7 +144,8 @@ def satz_nach_anzahl_kinder( xnp: ModuleType, ) -> ConsecutiveInt1dLookupTableParamValue: """Convert the Kindergeld-Satz by child to the amount of Kindergeld by number of - children.""" + children. + """ max_num_children = 30 max_num_children_in_spec = max(satz_gestaffelt.keys()) base_spec = { diff --git a/src/_gettsim/kinderzuschlag/einkommen.py b/src/_gettsim/kinderzuschlag/einkommen.py index a39789eac..d78c4cd76 100644 --- a/src/_gettsim/kinderzuschlag/einkommen.py +++ b/src/_gettsim/kinderzuschlag/einkommen.py @@ -287,7 +287,6 @@ def wohnbedarf_anteil_eltern_bg( Reference: § 6a Abs. 5 S. 
3 BKGG """ - if familie__alleinerziehend_bg: elternbetrag = ( existenzminimum.kosten_der_unterkunft.single diff --git a/src/_gettsim/kinderzuschlag/kinderzuschlag.py b/src/_gettsim/kinderzuschlag/kinderzuschlag.py index c69e37861..0de67f687 100644 --- a/src/_gettsim/kinderzuschlag/kinderzuschlag.py +++ b/src/_gettsim/kinderzuschlag/kinderzuschlag.py @@ -27,7 +27,6 @@ def satz_mit_gestaffeltem_kindergeld( For 2023 the amount is once again explicitly specified as a parameter. """ - return max( ( existenzminimum.regelsatz.kind @@ -53,7 +52,6 @@ def satz_mit_einheitlichem_kindergeld_und_kindersofortzuschlag( Formula according to § 6a (2) BKGG. """ - current_formula = ( existenzminimum.regelsatz.kind + existenzminimum.kosten_der_unterkunft.kind @@ -103,7 +101,6 @@ def anspruchshöhe_m_bg( vermögensfreibetrag_bg: float, ) -> float: """Kinderzuschlag claim at the Bedarfsgemeinschaft level.""" - if vermögen_bg > vermögensfreibetrag_bg: out = max( basisbetrag_m_bg - (vermögen_bg - vermögensfreibetrag_bg), @@ -121,7 +118,6 @@ def vermögensfreibetrag_bg_bis_2022( arbeitslosengeld_2__vermögensfreibetrag_bg: float, ) -> float: """Wealth exemptions for Kinderzuschlag until 2022.""" - return arbeitslosengeld_2__vermögensfreibetrag_bg @@ -130,7 +126,6 @@ def vermögensfreibetrag_bg_ab_2023( arbeitslosengeld_2__vermögensfreibetrag_in_karenzzeit_bg: float, ) -> float: """Wealth exemptions for Kinderzuschlag since 2023.""" - return arbeitslosengeld_2__vermögensfreibetrag_in_karenzzeit_bg @@ -156,7 +151,6 @@ def basisbetrag_m_bg_check_maximales_netteinkommen( (arbeitslosengeld_2__anzahl_personen_bg > 1). """ - if ( nettoeinkommen_eltern_m_bg <= maximales_nettoeinkommen_m_bg ) and arbeitslosengeld_2__anzahl_personen_bg > 1: @@ -192,7 +186,6 @@ def basisbetrag_m_bg_check_mindestbruttoeinkommen_und_maximales_nettoeinkommen( (arbeitslosengeld_2__anzahl_personen_bg > 1). 
""" - if ( (bruttoeinkommen_eltern_m_bg >= mindestbruttoeinkommen_m_bg) and (nettoeinkommen_eltern_m_bg <= maximales_nettoeinkommen_m_bg) diff --git a/src/_gettsim/lohnsteuer/einkommen.py b/src/_gettsim/lohnsteuer/einkommen.py index 86f4f17a7..c69be4d15 100644 --- a/src/_gettsim/lohnsteuer/einkommen.py +++ b/src/_gettsim/lohnsteuer/einkommen.py @@ -66,7 +66,6 @@ def vorsorge_krankenversicherungsbeiträge_option_a( but only up to a certain threshold. """ - vorsorge_krankenversicherungsbeiträge_option_a_basis = ( vorsorgepauschale_mindestanteil * sozialversicherung__kranken__beitrag__einkommen_bis_beitragsbemessungsgrenze_y @@ -130,7 +129,6 @@ def vorsorge_krankenversicherungsbeiträge_option_b_ab_2019( a" and "Option b". This function calculates option b where the actual contributions are used. """ - return ( sozialversicherung__kranken__beitrag__einkommen_bis_beitragsbemessungsgrenze_y * ( @@ -182,7 +180,6 @@ def vorsorgepauschale_y_ab_2010_bis_2022( used when calculating Einkommensteuer. """ - rente = ( sozialversicherung__rente__beitrag__einkommen_y * sozialversicherung__rente__beitrag__beitragssatz @@ -213,7 +210,6 @@ def vorsorgepauschale_y_ab_2023( used when calculating Einkommensteuer. """ - rente = ( sozialversicherung__rente__beitrag__einkommen_y * sozialversicherung__rente__beitrag__beitragssatz diff --git a/src/_gettsim/lohnsteuer/lohnsteuer.py b/src/_gettsim/lohnsteuer/lohnsteuer.py index 092f93f97..dafc230fb 100644 --- a/src/_gettsim/lohnsteuer/lohnsteuer.py +++ b/src/_gettsim/lohnsteuer/lohnsteuer.py @@ -35,7 +35,6 @@ def basis_für_klassen_5_6( Jahresbetrags. 
""" - return 2 * ( piecewise_polynomial( x=einkommen_y * 1.25, parameters=parameter_einkommensteuertarif, xnp=xnp @@ -137,7 +136,6 @@ def tarif_klassen_5_und_6( xnp: ModuleType, ) -> float: """Lohnsteuer for Lohnsteuerklassen 5 and 6.""" - basis = basis_für_klassen_5_6( einkommen_y, einkommensteuer__parameter_einkommensteuertarif, xnp=xnp ) @@ -158,7 +156,6 @@ def betrag_y( tarif_klassen_5_und_6: float, ) -> float: """Withholding tax on earnings (Lohnsteuer)""" - if steuerklasse == 1 or steuerklasse == 2 or steuerklasse == 4: out = basistarif elif steuerklasse == 3: @@ -264,7 +261,6 @@ def betrag_soli_y( xnp: ModuleType, ) -> float: """Solidarity surcharge on Lohnsteuer (withholding tax on earnings).""" - return piecewise_polynomial( x=betrag_mit_kinderfreibetrag_y, parameters=solidaritätszuschlag__parameter_solidaritätszuschlag, @@ -283,7 +279,6 @@ def kinderfreibetrag_soli_y( benefit, Steuerklasse 4 gets the child benefit once, and Steuerklasse 5/6 gets nothing. """ - if steuerklasse == 1 or steuerklasse == 2 or steuerklasse == 3: out = 2 * einkommensteuer__kinderfreibetrag_y elif steuerklasse == 4: diff --git "a/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" "b/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" index ce3e8d295..3bb96420f 100644 --- "a/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" +++ "b/src/_gettsim/solidarit\303\244tszuschlag/solidarit\303\244tszuschlag.py" @@ -21,7 +21,6 @@ def solidaritätszuschlagstarif( xnp: ModuleType, ) -> float: """The isolated function for Solidaritätszuschlag.""" - return einkommensteuer__anzahl_personen_sn * piecewise_polynomial( x=steuer_pro_person / einkommensteuer__anzahl_personen_sn, parameters=parameter_solidaritätszuschlag, diff --git a/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py b/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py index 0f1f6dd00..ec5d52eb6 100644 --- 
a/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py +++ b/src/_gettsim/sozialversicherung/arbeitslosen/arbeitslosengeld.py @@ -34,7 +34,6 @@ def betrag_m( satz: dict[str, float], ) -> float: """Calculate individual unemployment benefit.""" - if einkommensteuer__anzahl_kinderfreibeträge == 0: arbeitsl_geld_satz = satz["allgemein"] else: @@ -86,7 +85,8 @@ def mindestversicherungszeit_erreicht( mindestversicherungsmonate: int, ) -> bool: """At least 12 months of unemployment contributions in the 30 months before claiming - unemployment insurance.""" + unemployment insurance. + """ return ( monate_beitragspflichtig_versichert_in_letzten_30_monaten >= mindestversicherungsmonate diff --git a/src/_gettsim/sozialversicherung/arbeitslosen/beitrag/beitrag.py b/src/_gettsim/sozialversicherung/arbeitslosen/beitrag/beitrag.py index 45671ac42..766c9261c 100644 --- a/src/_gettsim/sozialversicherung/arbeitslosen/beitrag/beitrag.py +++ b/src/_gettsim/sozialversicherung/arbeitslosen/beitrag/beitrag.py @@ -99,7 +99,8 @@ def betrag_gesamt_in_gleitzone_m( beitragssatz: float, ) -> float: """Sum of employee's and employer's unemployment insurance contribution - for Midijobs.""" + for Midijobs. + """ return sozialversicherung__midijob_bemessungsentgelt_m * beitragssatz @@ -113,7 +114,8 @@ def betrag_arbeitgeber_in_gleitzone_m_anteil_bruttolohn( beitragssatz: float, ) -> float: """Employers' unemployment insurance contribution for Midijobs until September - 2022.""" + 2022. 
+ """ return ( einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_m * beitragssatz diff --git a/src/_gettsim/sozialversicherung/kranken/beitrag/beitrag.py b/src/_gettsim/sozialversicherung/kranken/beitrag/beitrag.py index ff4ba41ea..6e84e6df4 100644 --- a/src/_gettsim/sozialversicherung/kranken/beitrag/beitrag.py +++ b/src/_gettsim/sozialversicherung/kranken/beitrag/beitrag.py @@ -193,7 +193,6 @@ def betrag_rentner_m( beitragssatz_arbeitnehmer: float, ) -> float: """Health insurance contributions for pension incomes.""" - return beitragssatz_arbeitnehmer * bemessungsgrundlage_rente_m diff --git a/src/_gettsim/sozialversicherung/kranken/beitrag/beitragssatz.py b/src/_gettsim/sozialversicherung/kranken/beitrag/beitragssatz.py index 30b727901..aeaf135b1 100644 --- a/src/_gettsim/sozialversicherung/kranken/beitrag/beitragssatz.py +++ b/src/_gettsim/sozialversicherung/kranken/beitrag/beitragssatz.py @@ -11,7 +11,6 @@ def beitragssatz_arbeitnehmer(beitragssatz: float) -> float: Basic split between employees and employers. """ - return beitragssatz / 2 @@ -188,7 +187,6 @@ def beitragssatz_arbeitgeber_bis_06_2005(beitragssatz: float) -> float: Until 2008, the top-up contribution rate (Zusatzbeitrag) was not considered. """ - return beitragssatz / 2 @@ -203,7 +201,6 @@ def beitragssatz_arbeitgeber_jahresanfang_bis_06_2005( Until 2008, the top-up contribution rate (Zusatzbeitrag) was not considered. """ - return beitragssatz_jahresanfang / 2 @@ -219,7 +216,6 @@ def beitragssatz_arbeitgeber_mittlerer_kassenspezifischer( Until 2008, the top-up contribution rate (Zusatzbeitrag) was not considered. """ - return parameter_beitragssatz["mean_allgemein"] / 2 @@ -235,7 +231,6 @@ def beitragssatz_arbeitgeber_jahresanfang_mittlerer_kassenspezifischer( Until 2008, the top-up contribution rate (Zusatzbeitrag) was not considered. 
""" - return parameter_beitragssatz_jahresanfang["mean_allgemein"] / 2 @@ -252,7 +247,6 @@ def beitragssatz_arbeitgeber_einheitlicher_zusatzbeitrag( From 2009 until 2018, the contribution rate was uniform for all health insurers, Zusatzbeitrag irrelevant. """ - return parameter_beitragssatz["allgemein"] / 2 @@ -269,7 +263,6 @@ def beitragssatz_arbeitgeber_jahresanfang_einheitlicher_zusatzbeitrag( From 2009 until 2018, the contribution rate was uniform for all health insurers, Zusatzbeitrag irrelevant. """ - return parameter_beitragssatz_jahresanfang["allgemein"] / 2 @@ -312,7 +305,6 @@ def zusatzbeitragssatz_von_sonderbeitrag( parameter_beitragssatz: dict[str, float], ) -> float: """Health insurance top-up (Zusatzbeitrag) rate until December 2014.""" - return parameter_beitragssatz["sonderbeitrag"] @@ -324,5 +316,4 @@ def zusatzbeitragssatz_von_mean_zusatzbeitrag( parameter_beitragssatz: dict[str, float], ) -> float: """Health insurance top-up rate (Zusatzbeitrag) since January 2015.""" - return parameter_beitragssatz["mean_zusatzbeitrag"] diff --git a/src/_gettsim/sozialversicherung/kranken/beitrag/einkommen.py b/src/_gettsim/sozialversicherung/kranken/beitrag/einkommen.py index 086199f1e..794ebcec7 100644 --- a/src/_gettsim/sozialversicherung/kranken/beitrag/einkommen.py +++ b/src/_gettsim/sozialversicherung/kranken/beitrag/einkommen.py @@ -32,7 +32,6 @@ def einkommen_bis_beitragsbemessungsgrenze_m( This does not consider reduced contributions for Mini- and Midijobs. Relevant for the computation of payroll taxes. 
""" - return min( einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_m, beitragsbemessungsgrenze_m, diff --git a/src/_gettsim/sozialversicherung/midijob.py b/src/_gettsim/sozialversicherung/midijob.py index 3e74db61a..10ac21505 100644 --- a/src/_gettsim/sozialversicherung/midijob.py +++ b/src/_gettsim/sozialversicherung/midijob.py @@ -239,7 +239,6 @@ def midijob_bemessungsentgelt_m_ab_10_2022( Legal reference: Changes in § 20 SGB IV from 01.10.2022 """ - quotient1 = (midijobgrenze) / (midijobgrenze - minijobgrenze) quotient2 = (minijobgrenze) / (midijobgrenze - minijobgrenze) einkommen_diff = ( diff --git a/src/_gettsim/sozialversicherung/pflege/beitrag/beitrag.py b/src/_gettsim/sozialversicherung/pflege/beitrag/beitrag.py index d5f2e511e..a61b2ad06 100644 --- a/src/_gettsim/sozialversicherung/pflege/beitrag/beitrag.py +++ b/src/_gettsim/sozialversicherung/pflege/beitrag/beitrag.py @@ -18,7 +18,6 @@ def betrag_versicherter_m_ohne_midijob( betrag_rentner_m: float, ) -> float: """Long-term care insurance contributions paid by the insured person.""" - if einkommensteuer__einkünfte__ist_selbstständig: out = betrag_selbstständig_m elif sozialversicherung__geringfügig_beschäftigt: @@ -44,7 +43,6 @@ def betrag_versicherter_m_mit_midijob( betrag_rentner_m: float, ) -> float: """Long-term care insurance contributions paid by the insured person.""" - if einkommensteuer__einkünfte__ist_selbstständig: out = betrag_selbstständig_m elif sozialversicherung__geringfügig_beschäftigt: @@ -72,7 +70,6 @@ def betrag_arbeitgeber_m_ohne_midijob( Before Midijob introduction in April 2003. """ - if ( einkommensteuer__einkünfte__ist_selbstständig or sozialversicherung__geringfügig_beschäftigt @@ -99,7 +96,6 @@ def betrag_arbeitgeber_m_mit_midijob( After Midijob introduction in April 2003. 
""" - if ( einkommensteuer__einkünfte__ist_selbstständig or sozialversicherung__geringfügig_beschäftigt @@ -137,7 +133,6 @@ def betrag_versicherter_regulär_beschäftigt_m( """Long-term care insurance contributions paid by the insured person if regularly employed. """ - return sozialversicherung__kranken__beitrag__einkommen_m * beitragssatz_arbeitnehmer @@ -149,7 +144,6 @@ def betrag_arbeitgeber_regulär_beschäftigt_m( """Long-term care insurance contributions paid by the employer under regular employment. """ - return sozialversicherung__kranken__beitrag__einkommen_m * beitragssatz_arbeitgeber @@ -162,7 +156,6 @@ def betrag_gesamt_in_gleitzone_m( beitragssatz_arbeitgeber: float, ) -> float: """Sum of employee and employer long-term care insurance contributions.""" - return sozialversicherung__midijob_bemessungsentgelt_m * ( beitragssatz_arbeitnehmer + beitragssatz_arbeitgeber ) @@ -209,7 +202,6 @@ def betrag_versicherter_midijob_m_mit_verringertem_beitrag_für_eltern_mit_mehre beitragssatz_nach_kinderzahl: dict[str, float], ) -> float: """Employee's long-term care insurance contribution.""" - base = ( sozialversicherung__beitragspflichtige_einnahmen_aus_midijob_arbeitnehmer_m * beitragssatz_nach_kinderzahl["standard"] diff --git a/src/_gettsim/sozialversicherung/pflege/beitrag/beitragssatz.py b/src/_gettsim/sozialversicherung/pflege/beitrag/beitragssatz.py index 9f57e547f..77a6cd570 100644 --- a/src/_gettsim/sozialversicherung/pflege/beitrag/beitragssatz.py +++ b/src/_gettsim/sozialversicherung/pflege/beitrag/beitragssatz.py @@ -29,7 +29,6 @@ def beitragssatz_arbeitnehmer_zusatz_kinderlos_dummy( Since 2005, the contribution rate is increased for childless individuals. 
""" - # Add additional contribution for childless individuals if zahlt_zusatzbetrag_kinderlos: out = ( diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py b/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py index 5f574ee38..9a02137f5 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py @@ -31,7 +31,6 @@ def altersgrenze_mit_arbeitslosigkeit_frauen_ohne_besonders_langjährig( to the normal retirement age (FRA<=NRA) and depends on personal characteristics as gender, insurance duration, health/disability, employment status. """ - out = regelaltersrente__altersgrenze if für_frauen__grundsätzlich_anspruchsberechtigt: out = xnp.minimum(out, für_frauen__altersgrenze) @@ -75,7 +74,6 @@ def altersgrenze_mit_arbeitslosigkeit_frauen_besonders_langjährig( because then all potential beneficiaries of the Rente wg. Arbeitslosigkeit and Rente für Frauen have reached the normal retirement age. """ - out = regelaltersrente__altersgrenze if für_frauen__grundsätzlich_anspruchsberechtigt: out = xnp.minimum(out, für_frauen__altersgrenze) @@ -114,7 +112,6 @@ def altersgrenze_mit_besonders_langjährig_ohne_arbeitslosigkeit_frauen( to the normal retirement age (FRA<=NRA) and depends on personal characteristics as gender, insurance duration, health/disability, employment status. """ - out = regelaltersrente__altersgrenze if langjährig__grundsätzlich_anspruchsberechtigt: out = xnp.minimum(out, langjährig__altersgrenze) @@ -177,7 +174,6 @@ def altersgrenze_vorzeitig_ohne_arbeitslosigkeit_frauen( Early retirement age depends on personal characteristics as gender, insurance duration, health/disability, employment status. 
""" - out = regelaltersrente__altersgrenze if langjährig__grundsätzlich_anspruchsberechtigt: @@ -203,7 +199,6 @@ def vorzeitig_grundsätzlich_anspruchsberechtigt_mit_arbeitslosigkeit_frauen( becomes inactive in 2018 because then all potential beneficiaries of the Rente wg. Arbeitslosigkeit and Rente für Frauen have reached the normal retirement age. """ - return ( für_frauen__grundsätzlich_anspruchsberechtigt or langjährig__grundsätzlich_anspruchsberechtigt @@ -221,7 +216,6 @@ def vorzeitig_grundsätzlich_anspruchsberechtigt_vorzeitig_ohne_arbeitslosigkeit Can only be claimed if eligible for "Rente für langjährig Versicherte". """ - return langjährig__grundsätzlich_anspruchsberechtigt diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/altersrente.py b/src/_gettsim/sozialversicherung/rente/altersrente/altersrente.py index 838738063..52dceac82 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/altersrente.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/altersrente.py @@ -60,7 +60,6 @@ def bruttorente_basisbetrag_m_nach_wohnort( - https://de.wikipedia.org/wiki/Rentenformel - https://de.wikipedia.org/wiki/Rentenanpassungsformel """ - if sozialversicherung__rente__bezieht_rente: out = ( sozialversicherung__rente__entgeltpunkte_west @@ -97,7 +96,6 @@ def bruttorente_basisbetrag_m( - https://de.wikipedia.org/wiki/Rentenformel - https://de.wikipedia.org/wiki/Rentenanpassungsformel """ - if sozialversicherung__rente__bezieht_rente: out = ( ( @@ -161,7 +159,6 @@ def zugangsfaktor( `regelaltersrente__grundsätzlich_anspruchsberechtigt` is False. 
""" - if regelaltersrente__grundsätzlich_anspruchsberechtigt: # Early retirement (before full retirement age): Zugangsfaktor < 1 if ( diff --git "a/src/_gettsim/sozialversicherung/rente/altersrente/besonders_langj\303\244hrig/besonders_langj\303\244hrig.py" "b/src/_gettsim/sozialversicherung/rente/altersrente/besonders_langj\303\244hrig/besonders_langj\303\244hrig.py" index 41399f1fd..adceb370a 100644 --- "a/src/_gettsim/sozialversicherung/rente/altersrente/besonders_langj\303\244hrig/besonders_langj\303\244hrig.py" +++ "b/src/_gettsim/sozialversicherung/rente/altersrente/besonders_langj\303\244hrig/besonders_langj\303\244hrig.py" @@ -34,5 +34,4 @@ def grundsätzlich_anspruchsberechtigt( """Determining the eligibility for Altersrente für besonders langjährig Versicherte (pension for very long-term insured). Wartezeit 45 years. aka "Rente mit 63". """ - return sozialversicherung__rente__wartezeit_45_jahre_erfüllt diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/entgeltpunkte.py b/src/_gettsim/sozialversicherung/rente/altersrente/entgeltpunkte.py index b8cf4d2ed..7e891fdde 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/entgeltpunkte.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/entgeltpunkte.py @@ -52,7 +52,6 @@ def neue_entgeltpunkte( umrechnung_entgeltpunkte_beitrittsgebiet: float, ) -> float: """Return earning points for the wages earned in the last year.""" - # Scale bruttolohn up if earned in eastern Germany if wohnort_ost: bruttolohn_scaled_east = ( diff --git "a/src/_gettsim/sozialversicherung/rente/altersrente/f\303\274r_frauen/f\303\274r_frauen.py" "b/src/_gettsim/sozialversicherung/rente/altersrente/f\303\274r_frauen/f\303\274r_frauen.py" index 0bc939d34..58a281f15 100644 --- "a/src/_gettsim/sozialversicherung/rente/altersrente/f\303\274r_frauen/f\303\274r_frauen.py" +++ "b/src/_gettsim/sozialversicherung/rente/altersrente/f\303\274r_frauen/f\303\274r_frauen.py" @@ -68,7 +68,6 @@ def 
grundsätzlich_anspruchsberechtigt_ohne_prüfung_geburtsjahr( Policy becomes inactive in 2018 because then all potential beneficiaries have reached the normal retirement age. """ - return ( weiblich and sozialversicherung__rente__wartezeit_15_jahre_erfüllt @@ -97,7 +96,6 @@ def grundsätzlich_anspruchsberechtigt_mit_prüfung_geburtsjahr( becomes inactive in 2018 because then all potential beneficiaries have reached the normal retirement age. """ - return ( weiblich and sozialversicherung__rente__wartezeit_15_jahre_erfüllt diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/inputs.py b/src/_gettsim/sozialversicherung/rente/altersrente/inputs.py index 7da518e77..4ebdce173 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/inputs.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/inputs.py @@ -9,4 +9,5 @@ def höchster_bruttolohn_letzte_15_jahre_vor_rente_y() -> float: """Highest gross income from regular employment in the last 15 years before pension benefit claiming. Relevant to determine pension benefit deductions for retirees in - early retirement.""" + early retirement. 
+ """ diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/regelaltersrente/regelaltersrente.py b/src/_gettsim/sozialversicherung/rente/altersrente/regelaltersrente/regelaltersrente.py index 038899061..5fad8bf96 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/regelaltersrente/regelaltersrente.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/regelaltersrente/regelaltersrente.py @@ -29,5 +29,4 @@ def grundsätzlich_anspruchsberechtigt( sozialversicherung__rente__mindestwartezeit_erfüllt: bool, ) -> bool: """Determining the eligibility for the Regelaltersrente.""" - return sozialversicherung__rente__mindestwartezeit_erfüllt diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/wegen_arbeitslosigkeit/inputs.py b/src/_gettsim/sozialversicherung/rente/altersrente/wegen_arbeitslosigkeit/inputs.py index 4ca64bbe2..e444ed34f 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/wegen_arbeitslosigkeit/inputs.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/wegen_arbeitslosigkeit/inputs.py @@ -18,10 +18,12 @@ def pflichtbeitragsjahre_8_von_10() -> bool: @policy_input() def vertrauensschutz_1997() -> bool: """Is covered by Vertrauensschutz rules for the Altersrente wegen Arbeitslosigkeit - implemented in 1997 (§ 237 SGB VI Abs. 4).""" + implemented in 1997 (§ 237 SGB VI Abs. 4). + """ @policy_input() def vertrauensschutz_2004() -> bool: """Is covered by Vertrauensschutz rules for the Altersrente wegen Arbeitslosigkeit - enacted in July 2004 (§ 237 SGB VI Abs. 5).""" + enacted in July 2004 (§ 237 SGB VI Abs. 5). 
+ """ diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/wegen_arbeitslosigkeit/wegen_arbeitslosigkeit.py b/src/_gettsim/sozialversicherung/rente/altersrente/wegen_arbeitslosigkeit/wegen_arbeitslosigkeit.py index d09a801ed..4b614857a 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/wegen_arbeitslosigkeit/wegen_arbeitslosigkeit.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/wegen_arbeitslosigkeit/wegen_arbeitslosigkeit.py @@ -234,7 +234,6 @@ def grundsätzlich_anspruchsberechtigt_ohne_prüfung_geburtsjahr( regarding voluntary unemployment this requirement may be viewed as always satisfied and is therefore not included when checking for eligibility. """ - return ( arbeitslos_für_1_jahr_nach_alter_58_ein_halb and sozialversicherung__rente__wartezeit_15_jahre_erfüllt @@ -264,7 +263,6 @@ def grundsätzlich_anspruchsberechtigt_mit_prüfung_geburtsjahr( becomes inactive in 2018 because then all potential beneficiaries have reached the Regelaltersgrenze. """ - return ( arbeitslos_für_1_jahr_nach_alter_58_ein_halb and sozialversicherung__rente__wartezeit_15_jahre_erfüllt diff --git a/src/_gettsim/sozialversicherung/rente/beitrag/beitrag.py b/src/_gettsim/sozialversicherung/rente/beitrag/beitrag.py index 4eb51dd65..51a028046 100644 --- a/src/_gettsim/sozialversicherung/rente/beitrag/beitrag.py +++ b/src/_gettsim/sozialversicherung/rente/beitrag/beitrag.py @@ -33,7 +33,6 @@ def betrag_versicherter_m_mit_midijob( After Midijob introduction in April 2003. """ - if sozialversicherung__geringfügig_beschäftigt: out = 0.0 elif sozialversicherung__in_gleitzone: @@ -68,7 +67,6 @@ def betrag_arbeitgeber_m_ohne_arbeitgeberpauschale( Before Minijobs were subject to pension contributions. """ - if sozialversicherung__geringfügig_beschäftigt: out = 0.0 else: @@ -92,7 +90,6 @@ def betrag_arbeitgeber_m_mit_arbeitgeberpauschale( Before Midijob introduction in April 2003. 
""" - if sozialversicherung__geringfügig_beschäftigt: out = ( einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_m @@ -117,7 +114,6 @@ def betrag_arbeitgeber_m_mit_midijob( After Midijob introduction in April 2003. """ - if sozialversicherung__geringfügig_beschäftigt: out = ( einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_m diff --git a/src/_gettsim/sozialversicherung/rente/erwerbsminderung/erwerbsminderung.py b/src/_gettsim/sozialversicherung/rente/erwerbsminderung/erwerbsminderung.py index c55f2d3d7..f8fe5d1ff 100644 --- a/src/_gettsim/sozialversicherung/rente/erwerbsminderung/erwerbsminderung.py +++ b/src/_gettsim/sozialversicherung/rente/erwerbsminderung/erwerbsminderung.py @@ -25,7 +25,6 @@ def betrag_m_nach_wohnort( Legal reference: SGB VI § 64: Rentenformel für Monatsbetrag der Rente """ - if grundsätzlich_anspruchsberechtigt: out = ( ( @@ -63,7 +62,6 @@ def betrag_m_einheitlich( Legal reference: SGB VI § 64: Rentenformel für Monatsbetrag der Rente """ - if grundsätzlich_anspruchsberechtigt: out = ( (entgeltpunkte_ost + entgeltpunkte_west) @@ -88,7 +86,6 @@ def grundsätzlich_anspruchsberechtigt( Legal reference: § 43 Abs. 1 SGB VI. """ - anspruch_erwerbsm_rente = ( (voll_erwerbsgemindert or teilweise_erwerbsgemindert) and sozialversicherung__rente__mindestwartezeit_erfüllt @@ -114,7 +111,6 @@ def entgeltpunkte_west( additional earning points. They receive their average earned income points for each year between their age of retirement and the "zurechnungszeitgrenze". """ - return sozialversicherung__rente__entgeltpunkte_west + ( zurechnungszeit * (1 - anteil_entgeltpunkte_ost) ) @@ -138,7 +134,6 @@ def entgeltpunkte_ost( additional earning points. They receive their average earned income points for each year between their age of retirement and the "zurechnungszeitgrenze". 
""" - return sozialversicherung__rente__entgeltpunkte_ost + ( zurechnungszeit * anteil_entgeltpunkte_ost ) @@ -403,7 +398,6 @@ def mean_entgeltpunkte_pro_bewertungsmonat( Legal reference: SGB VI § 72: Grundbewertung """ - belegungsfähiger_gesamtzeitraum = ( sozialversicherung__rente__alter_bei_renteneintritt - altersgrenze_grundbewertung diff --git a/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py b/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py index 92a8c79a2..2830f2d3f 100644 --- a/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py +++ b/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py @@ -53,7 +53,6 @@ def einkommen_m( Reference: § 97a Abs. 2 S. 1 SGB VI """ - # Sum income over different income sources. return ( einkommensteuer__einkünfte__sonstige__renteneinkünfte_vorjahr_m @@ -101,7 +100,6 @@ def anzurechnendes_einkommen_m( Reference: § 97a Abs. 4 S. 2, 4 SGB VI """ - # Calculate relevant income following the crediting rules using the values for # singles and those for married subjects # Note: Thresholds are defined relativ to rentenwert which is implemented by @@ -142,7 +140,6 @@ def basisbetrag_m( The Zugangsfaktor is limited to 1 and considered Grundrentezeiten are limited to 35 years (420 months). """ - bewertungszeiten = min( bewertungszeiten_monate, berücksichtigte_wartezeit_monate["max"], diff --git a/src/_gettsim/sozialversicherung/rente/inputs.py b/src/_gettsim/sozialversicherung/rente/inputs.py index 0878a235c..f5069f2d3 100644 --- a/src/_gettsim/sozialversicherung/rente/inputs.py +++ b/src/_gettsim/sozialversicherung/rente/inputs.py @@ -29,7 +29,8 @@ def entgeltpunkte_west() -> float: @policy_input() def ersatzzeiten_monate() -> float: """Total months during military, persecution/escape, internment, and consecutive - sickness.""" + sickness. 
+ """ @policy_input() @@ -70,7 +71,8 @@ def monate_in_arbeitslosigkeit() -> float: @policy_input() def monate_in_arbeitsunfähigkeit() -> float: """Total months of sickness, rehabilitation, measures for worklife - participation(Teilhabe).""" + participation(Teilhabe). + """ @policy_input() @@ -91,7 +93,8 @@ def monate_in_schulausbildung() -> float: @policy_input() def monate_mit_bezug_entgeltersatzleistungen_wegen_arbeitslosigkeit() -> float: """Total months of unemployment (only time of Entgeltersatzleistungen, not - ALGII),i.e. Arbeitslosengeld, Unterhaltsgeld, Übergangsgeld.""" + ALGII),i.e. Arbeitslosengeld, Unterhaltsgeld, Übergangsgeld. + """ @policy_input() diff --git a/src/_gettsim/unterhalt/unterhalt.py b/src/_gettsim/unterhalt/unterhalt.py index b193d38a5..3a6facbda 100644 --- a/src/_gettsim/unterhalt/unterhalt.py +++ b/src/_gettsim/unterhalt/unterhalt.py @@ -13,7 +13,8 @@ def kind_festgelegter_zahlbetrag_m( abzugsrate_kindergeld: dict[str, float], ) -> float: """Monthly actual child alimony payments to be received by the child after - deductions.""" + deductions. + """ if familie__kind: abzugsrate = abzugsrate_kindergeld["minderjährig"] else: diff --git a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py index 8cfdf8164..11368358b 100644 --- a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py +++ b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py @@ -273,7 +273,8 @@ def elternteil_mindesteinkommen_erreicht( xnp: ModuleType, ) -> BoolColumn: """Income of Unterhaltsvorschuss recipient above threshold (this variable is - defined on child level).""" + defined on child level). 
+ """ return join( foreign_key=kindergeld__p_id_empfänger, primary_key=p_id, diff --git "a/src/_gettsim/vorrangpr\303\274fungen/vorrangpr\303\274fungen.py" "b/src/_gettsim/vorrangpr\303\274fungen/vorrangpr\303\274fungen.py" index 81c12adb6..1602b8064 100644 --- "a/src/_gettsim/vorrangpr\303\274fungen/vorrangpr\303\274fungen.py" +++ "b/src/_gettsim/vorrangpr\303\274fungen/vorrangpr\303\274fungen.py" @@ -58,7 +58,6 @@ def wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg( wohngeld__anspruchshöhe_m_bg: float, ) -> bool: """Check if housing and child benefit have priority.""" - return ( arbeitslosengeld_2__anzurechnendes_einkommen_m_bg + wohngeld__anspruchshöhe_m_bg diff --git a/src/_gettsim/wohngeld/einkommen.py b/src/_gettsim/wohngeld/einkommen.py index 81f18edd6..0a8f31684 100644 --- a/src/_gettsim/wohngeld/einkommen.py +++ b/src/_gettsim/wohngeld/einkommen.py @@ -234,7 +234,6 @@ def freibetrag_m_bis_2015( xnp: ModuleType, ) -> float: """Calculate housing benefit subtractions for one individual until 2015.""" - freibetrag_bei_behinderung = ( piecewise_polynomial( x=behinderungsgrad, diff --git a/src/_gettsim/wohngeld/miete.py b/src/_gettsim/wohngeld/miete.py index 91f3da2ae..20054a889 100644 --- a/src/_gettsim/wohngeld/miete.py +++ b/src/_gettsim/wohngeld/miete.py @@ -221,7 +221,6 @@ def miete_m_hh_mit_baujahr( xnp: ModuleType, ) -> float: """Rent considered in housing benefit calculation on household level until 2008.""" - selected_bin_index = xnp.searchsorted( max_miete_m_lookup.baujahre, wohnen__baujahr_immobilie_hh, @@ -248,7 +247,6 @@ def miete_m_hh_ohne_baujahr_ohne_heizkostenentlastung( max_miete_m_lookup: ConsecutiveInt2dLookupTableParamValue, ) -> float: """Rent considered in housing benefit since 2009.""" - max_miete_m = max_miete_m_lookup.values_to_look_up[ anzahl_personen_hh - max_miete_m_lookup.base_to_subtract_rows, mietstufe - max_miete_m_lookup.base_to_subtract_cols, diff --git a/src/_gettsim/wohngeld/voraussetzungen.py 
b/src/_gettsim/wohngeld/voraussetzungen.py index 5713d3ff2..69a2abeb0 100644 --- a/src/_gettsim/wohngeld/voraussetzungen.py +++ b/src/_gettsim/wohngeld/voraussetzungen.py @@ -92,7 +92,6 @@ def vermögensgrenze_unterschritten_bg( parameter_vermögensfreibetrag: dict[str, float], ) -> bool: """Wealth is below the eligibility threshold for housing benefits.""" - vermögensfreibetrag = parameter_vermögensfreibetrag[ "grundfreibetrag" ] + parameter_vermögensfreibetrag["je_weitere_person"] * ( @@ -160,7 +159,6 @@ def einkommen_für_mindesteinkommen_m( Kindergeld count as income for this check. """ - return ( arbeitslosengeld_2__nettoeinkommen_vor_abzug_freibetrag_m + unterhalt__tatsächlich_erhaltener_betrag_m diff --git a/src/ttsim/interface_dag.py b/src/ttsim/interface_dag.py index 1ed380e6c..bacb5f162 100644 --- a/src/ttsim/interface_dag.py +++ b/src/ttsim/interface_dag.py @@ -30,7 +30,6 @@ def main( """ Main function that processes the inputs and returns the outputs. """ - if "backend" not in inputs: inputs["backend"] = backend diff --git a/src/ttsim/interface_dag_elements/automatically_added_functions.py b/src/ttsim/interface_dag_elements/automatically_added_functions.py index aacd716ed..7391545e0 100644 --- a/src/ttsim/interface_dag_elements/automatically_added_functions.py +++ b/src/ttsim/interface_dag_elements/automatically_added_functions.py @@ -465,7 +465,6 @@ def create_time_conversion_functions( ------- The functions dict with the new time conversion functions. 
""" - time_units = tuple(TIME_UNIT_LABELS) pattern_all = get_re_pattern_for_all_time_units_and_groupings( grouping_levels=grouping_levels, diff --git a/src/ttsim/interface_dag_elements/backend.py b/src/ttsim/interface_dag_elements/backend.py index 2962e79c3..69f567b89 100644 --- a/src/ttsim/interface_dag_elements/backend.py +++ b/src/ttsim/interface_dag_elements/backend.py @@ -22,7 +22,6 @@ def xnp(backend: Literal["numpy", "jax"]) -> ModuleType: """ Return the backend for numerical operations (either NumPy or jax). """ - if backend == "numpy": xnp = numpy elif backend == "jax": diff --git a/src/ttsim/interface_dag_elements/data_converters.py b/src/ttsim/interface_dag_elements/data_converters.py index 58956fc68..f9cb0b9b3 100644 --- a/src/ttsim/interface_dag_elements/data_converters.py +++ b/src/ttsim/interface_dag_elements/data_converters.py @@ -28,7 +28,8 @@ def nested_data_to_df_with_nested_columns( data_with_p_id: Some data structure with a "p_id" column. - Returns: + Returns + ------- A DataFrame. """ flat_data_to_convert = dt.flatten_to_tree_paths(nested_data_to_convert) @@ -53,7 +54,8 @@ def nested_data_to_df_with_mapped_columns( data_with_p_id: Some data structure with a "p_id" column. - Returns: + Returns + ------- A DataFrame. """ flat_data_to_convert = dt.flatten_to_tree_paths(nested_data_to_convert) @@ -84,13 +86,13 @@ def dataframe_to_nested_data( df: The pandas DataFrame containing the source data. - Returns - ------- + Returns + ------- A nested dictionary structure containing the data organized according to the mapping definition. - Examples - -------- + Examples + -------- >>> df = pd.DataFrame({ ... "a": [1, 2, 3], ... 
"b": [4, 5, 6], diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index a22bf1dc7..df447f2f2 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -98,7 +98,9 @@ class _ParamWithActivePeriod(ParamObject): def assert_valid_ttsim_pytree( - tree: Any, leaf_checker: GenericCallable, tree_name: str + tree: Any, # noqa: ANN401 + leaf_checker: GenericCallable, + tree_name: str, ) -> None: """ Recursively assert that a pytree meets the following conditions: @@ -121,7 +123,7 @@ def assert_valid_ttsim_pytree( If any branch or leaf does not meet the expected requirements. """ - def _assert_valid_ttsim_pytree(subtree: Any, current_key: tuple[str, ...]) -> None: + def _assert_valid_ttsim_pytree(subtree: Any, current_key: tuple[str, ...]) -> None: # noqa: ANN401 def format_key_path(key_tuple: tuple[str, ...]) -> str: return "".join(f"[{k}]" for k in key_tuple) @@ -319,7 +321,6 @@ def foreign_keys_are_invalid_in_data( We need processed_data because we cannot guarantee that `p_id` is present in the input data. """ - valid_ids = set(processed_data["p_id"].tolist()) | {-1} relevant_objects = { k: v @@ -348,6 +349,7 @@ def foreign_keys_are_invalid_in_data( for i, j in zip( processed_data[fk_name].tolist(), processed_data["p_id"].tolist(), + strict=False, ) if i == j ] @@ -432,7 +434,8 @@ def non_convertible_objects_in_results_tree( xnp: ModuleType, ) -> None: """Fail if results should be converted to a DataFrame but contain non-convertible - objects.""" + objects. + """ _numeric_types = (int, float, bool, xnp.integer, xnp.floating, xnp.bool_) expected_object_length = len(next(iter(processed_data.values()))) @@ -587,7 +590,6 @@ def root_nodes_are_missing( ValueError If root nodes are missing. 
""" - # Obtain root nodes root_nodes = nx.subgraph_view( specialized_environment__tax_transfer_dag, diff --git a/src/ttsim/interface_dag_elements/input_data.py b/src/ttsim/interface_dag_elements/input_data.py index 5095caab9..b54563b8f 100644 --- a/src/ttsim/interface_dag_elements/input_data.py +++ b/src/ttsim/interface_dag_elements/input_data.py @@ -51,7 +51,8 @@ def tree( df_and_mapper__mapper: A tree that maps paths (sequence of keys) to data columns names. - Returns: + Returns + ------- A nested data structure. """ return dataframe_to_nested_data( @@ -69,7 +70,8 @@ def flat(tree: NestedData) -> FlatData: tree: The input tree. - Returns: + Returns + ------- Mapping of tree paths to input data. """ return dt.flatten_to_tree_paths(tree) diff --git a/src/ttsim/interface_dag_elements/interface_node_objects.py b/src/ttsim/interface_dag_elements/interface_node_objects.py index 66a02e000..d21b47dba 100644 --- a/src/ttsim/interface_dag_elements/interface_node_objects.py +++ b/src/ttsim/interface_dag_elements/interface_node_objects.py @@ -2,7 +2,7 @@ import inspect from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar import dags.tree as dt @@ -17,7 +17,8 @@ class InterfaceNodeObject: """Base class for all objects operating on columns of data. - Examples: + Examples + -------- - PolicyInputs - PolicyFunctions - GroupCreationFunctions @@ -57,7 +58,7 @@ def remove_tree_logic( def dummy_callable(self) -> InterfaceFunction: # type: ignore[type-arg] """Dummy callable for the interface input. 
Just used for plotting.""" - def dummy() -> Any: + def dummy(): # type: ignore[no-untyped-def] # noqa: ANN202 pass return interface_function( @@ -67,7 +68,7 @@ def dummy() -> Any: def interface_input( - in_top_level_namespace: bool = False, + in_top_level_namespace: bool = False, # noqa: FBT002 ) -> GenericCallable[[GenericCallable], InterfaceInput]: """ Decorator that makes a (dummy) function an `InterfaceInput`. diff --git a/src/ttsim/interface_dag_elements/names.py b/src/ttsim/interface_dag_elements/names.py index 8d097dc1a..1f2a9aea1 100644 --- a/src/ttsim/interface_dag_elements/names.py +++ b/src/ttsim/interface_dag_elements/names.py @@ -103,7 +103,6 @@ def top_level_namespace( top_level_namespace: The top level namespace. """ - time_units = tuple(TIME_UNIT_LABELS) direct_top_level_names = set(policy_environment) @@ -165,7 +164,6 @@ def root_nodes( The names of the columns in `processed_data` required for the tax transfer function. """ - # Obtain root nodes root_nodes = nx.subgraph_view( specialized_environment__tax_transfer_dag, diff --git a/src/ttsim/interface_dag_elements/policy_environment.py b/src/ttsim/interface_dag_elements/policy_environment.py index 8f6d36ff0..4825dc766 100644 --- a/src/ttsim/interface_dag_elements/policy_environment.py +++ b/src/ttsim/interface_dag_elements/policy_environment.py @@ -120,7 +120,6 @@ def _active_column_objects_and_param_functions( ------- A tree of active ColumnObjectParamFunctions. """ - flat_objects_tree = { (*orig_path[:-2], obj.leaf_name): obj for orig_path, obj in orig.items() diff --git a/src/ttsim/interface_dag_elements/processed_data.py b/src/ttsim/interface_dag_elements/processed_data.py index 243f483d9..2187dca29 100644 --- a/src/ttsim/interface_dag_elements/processed_data.py +++ b/src/ttsim/interface_dag_elements/processed_data.py @@ -25,10 +25,10 @@ def processed_data(input_data__flat: FlatData, xnp: ModuleType) -> QNameData: input_data__tree: The input data provided by the user. 
- Returns: + Returns + ------- A DataFrame. """ - processed_input_data = {} old_p_ids = xnp.asarray(input_data__flat[("p_id",)]) new_p_ids = reorder_ids(ids=old_p_ids, xnp=xnp) diff --git a/src/ttsim/interface_dag_elements/results.py b/src/ttsim/interface_dag_elements/results.py index 6d87dbe99..d22540a8a 100644 --- a/src/ttsim/interface_dag_elements/results.py +++ b/src/ttsim/interface_dag_elements/results.py @@ -54,7 +54,8 @@ def df_with_mapper( nested_outputs_df_column_names: A tree that maps paths (sequence of keys) to data columns names. - Returns: + Returns + ------- A DataFrame. """ return nested_data_to_df_with_mapped_columns( @@ -78,7 +79,8 @@ def df_with_nested_columns( nested_outputs_df_column_names: A tree that maps paths (sequence of keys) to data columns names. - Returns: + Returns + ------- A DataFrame. """ return nested_data_to_df_with_nested_columns( diff --git a/src/ttsim/interface_dag_elements/shared.py b/src/ttsim/interface_dag_elements/shared.py index f7edd79e8..eee428340 100644 --- a/src/ttsim/interface_dag_elements/shared.py +++ b/src/ttsim/interface_dag_elements/shared.py @@ -114,7 +114,8 @@ def get_base_name_and_grouping_suffix(match: re.Match[str]) -> tuple[str, str]: def create_tree_from_path_and_value( - path: tuple[str], value: Any = None + path: tuple[str], + value: Any = None, # noqa: ANN401 ) -> dict[str, Any]: """Create a nested dict with 'path' as keys and 'value' as leaf. @@ -139,7 +140,6 @@ def create_tree_from_path_and_value( ------- The tree structure. """ - nested_dict = value for entry in reversed(path): nested_dict = {entry: nested_dict} @@ -161,7 +161,6 @@ def merge_trees(left: dict[str, Any], right: dict[str, Any]) -> dict[str, Any]: ------- The merged pytree. 
""" - if set(optree.tree_paths(left)) & set(optree.tree_paths(right)): # type: ignore[arg-type] raise ValueError("Conflicting paths in trees to merge.") @@ -205,7 +204,9 @@ def upsert_tree(base: dict[str, Any], to_upsert: dict[str, Any]) -> dict[str, An def upsert_path_and_value( - base: dict[str, Any], path_to_upsert: tuple[str], value_to_upsert: Any = None + base: dict[str, Any], + path_to_upsert: tuple[str], + value_to_upsert: Any = None, # noqa: ANN401 ) -> dict[str, Any]: """Update tree with a path and value. @@ -220,7 +221,9 @@ def upsert_path_and_value( def insert_path_and_value( - base: dict[str, Any], path_to_insert: tuple[str], value_to_insert: Any = None + base: dict[str, Any], + path_to_insert: tuple[str], + value_to_insert: Any = None, # noqa: ANN401 ) -> dict[str, Any]: """Insert a path and value into a tree. diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index 6982e54eb..3864920a9 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -3,7 +3,7 @@ import datetime import functools from types import ModuleType -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Literal import dags.tree as dt from dags import concatenate_functions, create_dag, get_free_arguments @@ -289,7 +289,7 @@ def with_partialled_params_and_scalars( return processed_functions -def _apply_rounding(element: Any, xnp: ModuleType) -> Any: +def _apply_rounding(element: ColumnFunction, xnp: ModuleType) -> ColumnFunction: return ( element.rounding_spec.apply_rounding(element, xnp=xnp) if getattr(element, "rounding_spec", False) diff --git a/src/ttsim/plot_dag.py b/src/ttsim/plot_dag.py index 4167f1672..e0071965d 100644 --- a/src/ttsim/plot_dag.py +++ b/src/ttsim/plot_dag.py @@ -109,7 +109,6 @@ def plot_tt_dag( def plot_full_interface_dag(output_path: Path) -> None: """Plot the full 
interface DAG.""" - nodes = { p: n.dummy_callable() if isinstance(n, InterfaceInput) else n for p, n in load_interface_functions_and_inputs().items() @@ -139,7 +138,6 @@ def plot_full_interface_dag(output_path: Path) -> None: def _plot_dag(dag: nx.DiGraph, title: str) -> go.Figure: """Plot the DAG.""" - nice_dag = nx.relabel_nodes( dag, {qn: qn.replace("__", "
") for qn in dag.nodes()} ) diff --git a/src/ttsim/testing_utils.py b/src/ttsim/testing_utils.py index 41225c761..4ed13d190 100644 --- a/src/ttsim/testing_utils.py +++ b/src/ttsim/testing_utils.py @@ -144,7 +144,6 @@ def load_policy_test_data( If policy_name is empty, all tests found in test_dir / "test_data" are loaded. """ - out = {} for path_to_yaml in (test_dir / "test_data" / policy_name).glob("**/*.yaml"): if _is_skipped(path_to_yaml): @@ -180,7 +179,8 @@ def _get_policy_test_from_raw_test_data( raw_test_data: The raw test data. path_to_yaml: The path to the YAML file. - Returns: + Returns + ------- A list of PolicyTest objects. """ test_info: NestedData = raw_test_data.get("info", {}) diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index 1352d0286..1429e0ff7 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -62,7 +62,8 @@ class FKType(StrEnum): class ColumnObject: """Base class for all objects operating on columns of data. - Examples: + Examples + -------- - PolicyInputs - PolicyFunctions - GroupCreationFunctions @@ -121,7 +122,7 @@ def remove_tree_logic( def dummy_callable(self) -> PolicyFunction: """Dummy callable for the interface input. Just used for plotting.""" - def dummy(): # type: ignore[no-untyped-def] + def dummy(): # type: ignore[no-untyped-def] # noqa: ANN202 pass return policy_function( @@ -364,7 +365,6 @@ def policy_function( ------- A decorator that returns a PolicyFunction object. 
""" - start_date, end_date = _convert_and_validate_dates(start_date, end_date) def inner(func: GenericCallable) -> PolicyFunction: @@ -390,7 +390,6 @@ def reorder_ids(ids: IntColumn, xnp: ModuleType) -> IntColumn: [43,44,70,50] -> [0,1,3,2] """ - sorting = xnp.argsort(ids) ids_sorted = ids[sorting] index_after_sort = xnp.arange(ids.shape[0])[sorting] diff --git a/src/ttsim/tt_dag_elements/param_objects.py b/src/ttsim/tt_dag_elements/param_objects.py index 4b9b4f913..ba514b429 100644 --- a/src/ttsim/tt_dag_elements/param_objects.py +++ b/src/ttsim/tt_dag_elements/param_objects.py @@ -48,7 +48,7 @@ class ParamObject: def dummy_callable(self) -> ParamFunction: """Dummy callable for the policy input. Just used for plotting.""" - def dummy(): # type: ignore[no-untyped-def] + def dummy(): # type: ignore[no-untyped-def] # noqa: ANN202 pass return param_function( @@ -288,7 +288,6 @@ def get_year_based_phase_inout_of_age_thresholds_param_value( Requires all years to be given. """ - first_year_to_consider = raw.pop("first_year_to_consider") last_year_to_consider = raw.pop("last_year_to_consider") assert all(isinstance(k, int) for k in raw) diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index 7fa9d9782..6f0a8c269 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -116,7 +116,6 @@ def get_piecewise_parameters( ------- """ - # Check if keys are consecutive numbers and starting at 0. if sorted(parameter_dict) != list(range(len(parameter_dict))): raise ValueError( @@ -154,7 +153,7 @@ def get_piecewise_parameters( ) -def check_and_get_thresholds( +def check_and_get_thresholds( # noqa: C901 leaf_name: str, parameter_dict: dict[int, dict[str, float]], xnp: ModuleType, @@ -389,7 +388,6 @@ def _calculate_one_intercept( The value of `x` under the piecewise function. """ - # Check if value lies within the defined range. 
if (x < lower_thresholds[0]) or (x > upper_thresholds[-1]) or numpy.isnan(x): return numpy.nan diff --git a/src/ttsim/tt_dag_elements/vectorization.py b/src/ttsim/tt_dag_elements/vectorization.py index ad60a5b2e..b12ed3f66 100644 --- a/src/ttsim/tt_dag_elements/vectorization.py +++ b/src/ttsim/tt_dag_elements/vectorization.py @@ -51,7 +51,8 @@ def _make_vectorizable( backend: Backend library. Currently supported backends are 'jax' and 'numpy'. Array module must export function `where` that behaves as `numpy.where`. - Returns: + Returns + ------- New function with altered ast. """ if _is_lambda_function(func): @@ -68,7 +69,7 @@ def _make_vectorizable( if func.__closure__: closure_vars = func.__code__.co_freevars closure_cells = [c.cell_contents for c in func.__closure__] - scope.update(dict(zip(closure_vars, closure_cells))) + scope.update(dict(zip(closure_vars, closure_cells, strict=False))) scope[module] = import_module(module) @@ -92,7 +93,8 @@ def make_vectorizable_source( backends. Array module must export function `where` that behaves as `numpy.where`. - Returns: + Returns + ------- Source code of new function with altered ast. """ if _is_lambda_function(func): @@ -115,7 +117,8 @@ def _make_vectorizable_ast( func: Function. module: Module which exports the function `where` that behaves as `numpy.where`. - Returns: + Returns + ------- AST of new function with altered ast. 
""" tree = _func_to_ast(func) diff --git a/tests/ttsim/mettsim/group_by_ids.py b/tests/ttsim/mettsim/group_by_ids.py index 8b458d240..61c3da794 100644 --- a/tests/ttsim/mettsim/group_by_ids.py +++ b/tests/ttsim/mettsim/group_by_ids.py @@ -83,7 +83,6 @@ def _assign_parents_fam_id( xnp: ModuleType, ) -> IntColumn: """Return the fam_id of the child's parents.""" - return xnp.where( (fam_id == p_id + p_id * n) * (p_id_parent_loc >= 0) diff --git a/tests/ttsim/test_failures.py b/tests/ttsim/test_failures.py index 6262ca123..0fd147da5 100644 --- a/tests/ttsim/test_failures.py +++ b/tests/ttsim/test_failures.py @@ -632,7 +632,9 @@ def test_fail_if_group_variables_are_not_constant_within_groups(): "foo_kin": numpy.array([1, 2, 2]), "kin_id": numpy.array([1, 1, 2]), } - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="The following data inputs do not have a unique value within" + ): group_variables_are_not_constant_within_groups( names__grouping_levels=("kin",), names__root_nodes={n for n in data if n != "p_id"}, diff --git a/tests/ttsim/test_orig_policy_objects.py b/tests/ttsim/test_orig_policy_objects.py index b481dfacf..8f8506ef8 100644 --- a/tests/ttsim/test_orig_policy_objects.py +++ b/tests/ttsim/test_orig_policy_objects.py @@ -17,6 +17,7 @@ def test_load_path(): def test_dont_load_init_py(): """Don't load __init__.py files as sources for PolicyFunctions and - AggregationSpecs.""" + AggregationSpecs. 
+ """ all_files = _find_files_recursively(root=METTSIM_ROOT, suffix=".py") assert "__init__.py" not in [file.name for file in all_files] diff --git a/tests/ttsim/test_policy_environment.py b/tests/ttsim/test_policy_environment.py index 5314ba499..43a3ed245 100644 --- a/tests/ttsim/test_policy_environment.py +++ b/tests/ttsim/test_policy_environment.py @@ -120,7 +120,9 @@ def test_func(): ], ) def test_start_date_invalid(date_string: str): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="neither matches the format YYYY-MM-DD nor is a datetime.date" + ): @policy_function(start_date=date_string) def test_func(): @@ -158,7 +160,9 @@ def test_func(): ], ) def test_end_date_invalid(date_string: str): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="neither matches the format YYYY-MM-DD nor is a datetime.date" + ): @policy_function(end_date=date_string) def test_func(): @@ -174,7 +178,7 @@ def test_func(): def test_active_period_is_empty(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="must be before the end date"): @policy_function(start_date="2023-01-20", end_date="2023-01-19") def test_func(): diff --git a/tests/ttsim/test_shared.py b/tests/ttsim/test_shared.py index ace8ee00b..9d223fc7a 100644 --- a/tests/ttsim/test_shared.py +++ b/tests/ttsim/test_shared.py @@ -28,7 +28,7 @@ def test_leap_year_correctly_handled(): def test_fail_if_invalid_date(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="day is out of range for month"): to_datetime(date="2020-02-30") diff --git a/tests/ttsim/tt_dag_elements/test_rounding.py b/tests/ttsim/tt_dag_elements/test_rounding.py index d9272a0af..59cc5e2b2 100644 --- a/tests/ttsim/tt_dag_elements/test_rounding.py +++ b/tests/ttsim/tt_dag_elements/test_rounding.py @@ -251,17 +251,17 @@ def test_func(income): @pytest.mark.parametrize( - "base, direction, to_add_after_rounding", + "base, direction, to_add_after_rounding, match", [ - 
(1, "upper", 0), - ("0.1", "down", 0), - (5, "closest", 0), - (5, "up", "0"), + (1, "upper", 0, "`direction` must be one of"), + (5, "closest", 0, "`direction` must be one of"), + ("0.1", "down", 0, "base needs to be a number"), + (5, "up", "0", "Additive part must be a number"), ], ) -def test_rounding_spec_validation(base, direction, to_add_after_rounding): +def test_rounding_spec_validation(base, direction, to_add_after_rounding, match): """Test validation of RoundingSpec parameters.""" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=match): RoundingSpec( base=base, direction=direction, diff --git a/tests/ttsim/tt_dag_elements/test_vectorization.py b/tests/ttsim/tt_dag_elements/test_vectorization.py index d84470469..cd7d895f5 100644 --- a/tests/ttsim/tt_dag_elements/test_vectorization.py +++ b/tests/ttsim/tt_dag_elements/test_vectorization.py @@ -376,7 +376,7 @@ def test_disallowed_operation_wrapper(func): @pytest.mark.parametrize( "funcname, func", - [ + ( (funcname, pf.function) for funcname, pf in dt.flatten_to_tree_paths( _active_column_objects_and_param_functions( @@ -391,7 +391,7 @@ def test_disallowed_operation_wrapper(func): | AggByPIDFunction | PolicyInput, ) - ], + ), ) def test_convertible(funcname, func, backend, xnp): # noqa: ARG001 # Leave funcname for debugging purposes. From 141cf48a1d3c323fe21e9df941f0570e841d2860 Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 13 Jun 2025 14:34:31 +0200 Subject: [PATCH 21/25] Further cut down on ruff exceptions. 
--- conftest.py | 9 +- docs/conf.py | 4 +- pixi.lock | 4 +- pyproject.toml | 58 +++------ .../freibetr\303\244ge_verm\303\266gen.py" | 4 +- .../kindergeld\303\274bertrag.py" | 8 +- .../arbeitslosengeld_2/regelbedarf.py | 4 +- .../abz\303\274ge/alleinerziehend.py" | 3 +- .../einkommensteuer/abz\303\274ge/alter.py" | 6 +- .../abz\303\274ge/sonderausgaben.py" | 4 +- .../einkommensteuer/einkommensteuer.py | 31 +++-- .../aus_kapitalverm\303\266gen.py" | 3 +- .../zu_versteuerndes_einkommen.py | 4 +- src/_gettsim/elterngeld/einkommen.py | 6 +- src/_gettsim/elterngeld/elterngeld.py | 6 +- src/_gettsim/erziehungsgeld/erziehungsgeld.py | 10 +- .../grundsicherung/im_alter/einkommen.py | 3 +- src/_gettsim/household_characteristics.py | 6 +- src/_gettsim/ids.py | 19 ++- src/_gettsim/individual_characteristics.py | 4 +- src/_gettsim/kindergeld/kindergeld.py | 7 +- src/_gettsim/kinderzuschlag/einkommen.py | 15 ++- src/_gettsim/kinderzuschlag/kinderzuschlag.py | 7 +- src/_gettsim/lohnsteuer/lohnsteuer.py | 29 +++-- .../arbeitslosen/beitrag/beitrag.py | 3 +- src/_gettsim/sozialversicherung/minijob.py | 11 +- .../pflege/beitrag/beitragssatz.py | 7 +- .../rente/altersrente/altersgrenzen.py | 12 +- .../rente/altersrente/altersrente.py | 11 +- .../besonders_langj\303\244hrig.py" | 3 +- .../altersrente/hinzuverdienstgrenzen.py | 12 +- .../regelaltersrente/regelaltersrente.py | 3 +- .../rente/beitrag/beitrag.py | 3 +- .../erwerbsminderung/erwerbsminderung.py | 17 ++- .../rente/grundrente/grundrente.py | 15 ++- .../unterhaltsvorschuss.py | 18 ++- .../vorrangpr\303\274fungen.py" | 6 +- src/_gettsim/wohngeld/einkommen.py | 3 +- src/_gettsim/wohngeld/miete.py | 10 +- src/_gettsim/wohngeld/voraussetzungen.py | 6 +- src/_gettsim/wohngeld/wohngeld.py | 3 +- src/_gettsim_tests/test_interface.py | 4 +- src/_gettsim_tests/test_policy.py | 4 +- src/ttsim/interface_dag.py | 5 +- .../automatically_added_functions.py | 3 +- .../interface_dag_elements/data_converters.py | 5 +- 
src/ttsim/interface_dag_elements/fail_if.py | 64 +++++----- .../interface_node_objects.py | 4 +- .../orig_policy_objects.py | 6 +- .../policy_environment.py | 35 +++--- .../interface_dag_elements/processed_data.py | 4 +- .../interface_dag_elements/raw_results.py | 2 +- src/ttsim/interface_dag_elements/results.py | 3 +- src/ttsim/interface_dag_elements/shared.py | 26 ++-- .../specialized_environment.py | 4 +- src/ttsim/interface_dag_elements/typing.py | 12 +- src/ttsim/interface_dag_elements/warn_if.py | 8 +- src/ttsim/plot_dag.py | 18 +-- src/ttsim/testing_utils.py | 24 ++-- src/ttsim/tt_dag_elements/aggregation.py | 118 ++++++++++-------- src/ttsim/tt_dag_elements/aggregation_jax.py | 56 ++++++--- .../tt_dag_elements/aggregation_numpy.py | 50 +++++--- .../column_objects_param_function.py | 37 +++--- src/ttsim/tt_dag_elements/param_objects.py | 17 ++- .../tt_dag_elements/piecewise_polynomial.py | 59 ++++----- src/ttsim/tt_dag_elements/rounding.py | 9 +- src/ttsim/tt_dag_elements/shared.py | 10 +- src/ttsim/tt_dag_elements/typing.py | 3 +- src/ttsim/tt_dag_elements/vectorization.py | 53 +++++--- .../test_automatically_added_functions.py | 10 +- tests/ttsim/test_convert_nested_data.py | 4 +- tests/ttsim/test_end_to_end.py | 4 +- tests/ttsim/test_failures.py | 54 ++++---- tests/ttsim/test_mettsim.py | 4 +- tests/ttsim/test_policy_environment.py | 6 +- tests/ttsim/test_shared.py | 20 ++- tests/ttsim/test_specialized_environment.py | 44 ++++--- tests/ttsim/test_warnings.py | 3 +- .../test_aggregation_functions.py | 46 +++++-- .../test_piecewise_polynomial.py | 8 +- tests/ttsim/tt_dag_elements/test_rounding.py | 7 -- .../tt_dag_elements/test_ttsim_objects.py | 19 ++- .../tt_dag_elements/test_vectorization.py | 22 +++- 83 files changed, 800 insertions(+), 491 deletions(-) diff --git a/conftest.py b/conftest.py index c5db1f2f9..1e8149ab3 100644 --- a/conftest.py +++ b/conftest.py @@ -24,20 +24,17 @@ def pytest_addoption(parser): @pytest.fixture def backend(request) -> 
Literal["numpy", "jax"]: - backend = request.config.getoption("--backend") - return backend + return request.config.getoption("--backend") @pytest.fixture def xnp(request) -> ModuleType: - backend = request.config.getoption("--backend") - return ttsim_xnp(backend) + return ttsim_xnp(request.config.getoption("--backend")) @pytest.fixture def dnp(request) -> ModuleType: - backend = request.config.getoption("--backend") - return ttsim_dnp(backend) + return ttsim_dnp(request.config.getoption("--backend")) @pytest.fixture(autouse=True) diff --git a/docs/conf.py b/docs/conf.py index 7973aee26..53be2a516 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,7 +14,7 @@ # -- Project information ----------------------------------------------------- project = "GETTSIM" -copyright = f"2019-{datetime.today().year}, GETTSIM team" # noqa: A001 +copyright = f"2019-{datetime.today().year}, GETTSIM team" # noqa: A001, DTZ002 author = "GETTSIM team" release = "0.7.0" version = ".".join(release.split(".")[:2]) @@ -122,7 +122,7 @@ "**": [ "relations.html", # needs 'show_related': True theme option to display "searchbox.html", - ] + ], } # Napoleon settings diff --git a/pixi.lock b/pixi.lock index fa52bddfa..0dcfe26e2 100644 --- a/pixi.lock +++ b/pixi.lock @@ -6622,8 +6622,8 @@ packages: timestamp: 1694400856979 - pypi: ./ name: gettsim - version: 0.7.1.dev457+g905f28b1.d20250613 - sha256: bdcd4b8079d89df7ef67b55be4ee15400ab06021e05e8dfbe3f8817b130725b5 + version: 0.7.1.dev459+g8412c6d3 + sha256: 4a06a834c54fa7149d951da87f6d6bd1386b4abceb6e2f01021dbc58d50f3fda requires_dist: - ipywidgets - networkx diff --git a/pyproject.toml b/pyproject.toml index 42e0d354f..7acd6d4fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -197,48 +197,28 @@ unsafe-fixes = false [tool.ruff.lint] select = ["ALL"] extend-ignore = [ - "ICN001", # numpy should be np, but different convention here. 
- # Docstrings - "D103", # missing docstring in public function - "D107", - "D203", - "D212", - "D213", - "D402", - "D413", - "D415", - "D416", - "D417", + "COM812", # Avoid conflicts with ruff-format + "EM101", # Exception must not use a string literal + "EM102", # Exception must not use an f-string literal "F722", # https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error "FBT001", # Boolean-typed positional argument in function definition + "FIX002", # Line contains TODO -- Use stuff from TD area. + "ICN001", # numpy should be np, but different convention here. "ISC001", # Avoid conflicts with ruff-format "N999", # Allow non-ASCII characters in file names. "PLC2401", # Allow non-ASCII characters in variable names. "PLC2403", # Allow non-ASCII function names for imports. "PLR0913", # Allow too many arguments in function definitions. - "FIX002", # Line contains TODO -- Use stuff from TD area. - "TRY003", # Avoid specifying long messages outside the exception class "PLR5501", # elif not supported by vectorization converter for Jax - "EM101", # Exception must not use a string literal - "EM102", # Exception must not use an f-string literal - # Others. - - "E731", # do not assign a lambda expression, use a def - "RET", # unnecessary elif or else statements after return, raise, continue, ... - "S324", # Probable use of insecure hash function. 
- "COM812", # trailing comma missing, but black takes care of that - "PT007", # wrong type in parametrize, gave false positives - "DTZ001", # use of `datetime.datetime()` without `tzinfo` argument is not allowed - "DTZ002", # use of `datetime.datetime.today()` is not allowed - "PT012", # `pytest.raises()` block should contain a single simple statement - + "TRY003", # Avoid specifying long messages outside the exception class # Ignored during transition phase # ====================================== "D", # docstrings - "PLR2004", # Magic values used in comparison "INP001", # implicit namespace packages without init. + "PLR2004", # Magic values used in comparison "PT006", # Allows only lists of tuples in parametrize, even if single argument + "PT007", # wrong type in parametrize "S101", # use of asserts outside of tests ] @@ -248,17 +228,17 @@ exclude = [] "conftest.py" = ["ANN"] "docs/**/*.ipynb" = ["T201"] # Mostly things vectorization can't handle -"src/_gettsim/*" = ["E501", "PLR1714", "PLR1716", "E721", "SIM108"] +"src/_gettsim/*" = ["E501", "PLR1714", "PLR1716", "E721", "SIM108", "RET"] # All tests return None and use asserts "src/_gettsim_tests/**/*.py" = ["ANN", "S101"] "src/ttsim/interface_dag_elements/specialized_environment.py" = ["E501"] "src/ttsim/interface_dag_elements/fail_if.py" = ["E501"] "src/ttsim/interface_dag_elements/typing.py" = ["PGH", "PLR", "SIM114"] # Mostly things vectorization can't handle -"tests/ttsim/mettsim/**/*.py" = ["PLR1714", "PLR1716", "E721", "SIM108"] +"tests/ttsim/mettsim/**/*.py" = ["PLR1714", "PLR1716", "E721", "SIM108", "RET"] +"tests/ttsim/tt_dag_elements/test_vectorization.py" = ["PLR1714", "PLR1716", "E721", "SIM108", "RET"] # All tests return None and use asserts "tests/ttsim/**/*.py" = ["ANN", "S101"] -"tests/ttsim/tt_dag_elements/test_vectorization.py" = ["PLR1714", "PLR1716", "E721", "SIM108"] "tests/ttsim/test_failures.py" = ["E501"] # TODO: remove once ported nicely "src/ttsim/stale_code_storage.py" = ["ALL"] @@ 
-298,20 +278,12 @@ disallow_untyped_defs = false ignore_errors = true [[tool.mypy.overrides]] -module = [ - "tests.*", -] -disable_error_code = [ - "no-untyped-def", # All tests return None, don't clutter source code. -] +module = ["tests.*",] +disable_error_code = ["no-untyped-def"] # All tests return None, don't clutter source code. [[tool.mypy.overrides]] -module = [ - "src._gettsim_tests.*", -] -disable_error_code = [ - "no-untyped-def", # All tests return None, don't clutter source code. -] +module = ["src._gettsim_tests.*",] +disable_error_code = ["no-untyped-def"] # All tests return None, don't clutter source code. [tool.check-manifest] ignore = ["src/_gettsim/_version.py"] diff --git "a/src/_gettsim/arbeitslosengeld_2/freibetr\303\244ge_verm\303\266gen.py" "b/src/_gettsim/arbeitslosengeld_2/freibetr\303\244ge_verm\303\266gen.py" index ceec0ff98..5212d9fbb 100644 --- "a/src/_gettsim/arbeitslosengeld_2/freibetr\303\244ge_verm\303\266gen.py" +++ "b/src/_gettsim/arbeitslosengeld_2/freibetr\303\244ge_verm\303\266gen.py" @@ -75,7 +75,9 @@ def vermögensfreibetrag_in_karenzzeit_bg( @policy_function( - start_date="2005-01-01", end_date="2022-12-31", leaf_name="vermögensfreibetrag_bg" + start_date="2005-01-01", + end_date="2022-12-31", + leaf_name="vermögensfreibetrag_bg", ) def vermögensfreibetrag_bg_bis_2022( grundfreibetrag_vermögen_bg: float, diff --git "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" index 6eaaede17..0cdfa9bb5 100644 --- "a/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" +++ "b/src/_gettsim/arbeitslosengeld_2/kindergeld\303\274bertrag.py" @@ -14,13 +14,17 @@ @agg_by_p_id_function(start_date="2005-01-01", agg_type=AggType.SUM) def kindergeldübertrag_m( - differenz_kindergeld_kindbedarf_m: float, kindergeld__p_id_empfänger: int, p_id: int + differenz_kindergeld_kindbedarf_m: float, + kindergeld__p_id_empfänger: int, + p_id: int, ) -> float: pass 
@policy_function( - start_date="2005-01-01", end_date="2022-12-31", leaf_name="kindergeld_pro_kind_m" + start_date="2005-01-01", + end_date="2022-12-31", + leaf_name="kindergeld_pro_kind_m", ) def _mean_kindergeld_per_child_gestaffelt_m( kindergeld__betrag_m: float, diff --git a/src/_gettsim/arbeitslosengeld_2/regelbedarf.py b/src/_gettsim/arbeitslosengeld_2/regelbedarf.py index 0ca60facd..33ce82452 100644 --- a/src/_gettsim/arbeitslosengeld_2/regelbedarf.py +++ b/src/_gettsim/arbeitslosengeld_2/regelbedarf.py @@ -76,7 +76,9 @@ def mehrbedarf_alleinerziehend_m( @policy_function( - start_date="2005-01-01", end_date="2010-12-31", leaf_name="kindersatz_m" + start_date="2005-01-01", + end_date="2010-12-31", + leaf_name="kindersatz_m", ) def kindersatz_m_anteilsbasiert( alter: int, diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/alleinerziehend.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/alleinerziehend.py" index e7ef0373b..f965941fc 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/alleinerziehend.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/alleinerziehend.py" @@ -7,7 +7,8 @@ @policy_function(end_date="2014-12-31", leaf_name="alleinerziehend_betrag_y") def alleinerziehend_betrag_y_pauschal( - einkommensteuer__alleinerziehend_sn: bool, alleinerziehendenfreibetrag_basis: float + einkommensteuer__alleinerziehend_sn: bool, + alleinerziehendenfreibetrag_basis: float, ) -> float: """Calculate tax deduction allowance for single parents until 2014""" if einkommensteuer__alleinerziehend_sn: diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/alter.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/alter.py" index 76ac71f31..cc1cdc816 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/alter.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/alter.py" @@ -149,10 +149,12 @@ def get_consecutive_int_1d_lookup_table_with_filled_up_tails( "Dictionary keys must be consecutive integers." 
) consecutive_dict_start = dict.fromkeys( - range(left_tail_key, min_key_in_spec), raw[min_key_in_spec] + range(left_tail_key, min_key_in_spec), + raw[min_key_in_spec], ) consecutive_dict_end = dict.fromkeys( - range(max_key_in_spec + 1, right_tail_key + 1), raw[max_key_in_spec] + range(max_key_in_spec + 1, right_tail_key + 1), + raw[max_key_in_spec], ) return get_consecutive_int_1d_lookup_table_param_value( raw={**consecutive_dict_start, **raw, **consecutive_dict_end}, diff --git "a/src/_gettsim/einkommensteuer/abz\303\274ge/sonderausgaben.py" "b/src/_gettsim/einkommensteuer/abz\303\274ge/sonderausgaben.py" index 94e78c7db..37d17874c 100644 --- "a/src/_gettsim/einkommensteuer/abz\303\274ge/sonderausgaben.py" +++ "b/src/_gettsim/einkommensteuer/abz\303\274ge/sonderausgaben.py" @@ -12,7 +12,9 @@ @agg_by_p_id_function(agg_type=AggType.SUM) def kinderbetreuungskosten_elternteil_m( - kinderbetreuungskosten_m: float, p_id_kinderbetreuungskostenträger: int, p_id: int + kinderbetreuungskosten_m: float, + p_id_kinderbetreuungskostenträger: int, + p_id: int, ) -> float: pass diff --git a/src/_gettsim/einkommensteuer/einkommensteuer.py b/src/_gettsim/einkommensteuer/einkommensteuer.py index 4932d44e4..9931c9f5a 100644 --- a/src/_gettsim/einkommensteuer/einkommensteuer.py +++ b/src/_gettsim/einkommensteuer/einkommensteuer.py @@ -60,7 +60,9 @@ def anzahl_kindergeld_ansprüche_2( end_date="1996-12-31", leaf_name="betrag_y_sn", rounding_spec=RoundingSpec( - base=1, direction="down", reference="§ 32a Abs. 1 S. 6 EStG" + base=1, + direction="down", + reference="§ 32a Abs. 1 S. 6 EStG", ), ) def betrag_y_sn_kindergeld_kinderfreibetrag_parallel( @@ -76,7 +78,9 @@ def betrag_y_sn_kindergeld_kinderfreibetrag_parallel( start_date="1997-01-01", leaf_name="betrag_y_sn", rounding_spec=RoundingSpec( - base=1, direction="down", reference="§ 32a Abs. 1 S.6 EStG" + base=1, + direction="down", + reference="§ 32a Abs. 
1 S.6 EStG", ), ) def betrag_y_sn_kindergeld_oder_kinderfreibetrag( @@ -112,7 +116,9 @@ def kinderfreibetrag_günstiger_sn( end_date="2001-12-31", leaf_name="betrag_mit_kinderfreibetrag_y_sn", rounding_spec=RoundingSpec( - base=1, direction="down", reference="§ 32a Abs. 1 S.6 EStG" + base=1, + direction="down", + reference="§ 32a Abs. 1 S.6 EStG", ), ) def betrag_mit_kinderfreibetrag_y_sn_bis_2001() -> float: @@ -123,7 +129,9 @@ def betrag_mit_kinderfreibetrag_y_sn_bis_2001() -> float: start_date="2002-01-01", leaf_name="betrag_mit_kinderfreibetrag_y_sn", rounding_spec=RoundingSpec( - base=1, direction="down", reference="§ 32a Abs. 1 S.6 EStG" + base=1, + direction="down", + reference="§ 32a Abs. 1 S.6 EStG", ), ) def betrag_mit_kinderfreibetrag_y_sn_ab_2002( @@ -141,13 +149,17 @@ def betrag_mit_kinderfreibetrag_y_sn_ab_2002( zu_versteuerndes_einkommen_mit_kinderfreibetrag_y_sn / anzahl_personen_sn ) return anzahl_personen_sn * piecewise_polynomial( - x=zu_verst_eink_per_indiv, parameters=parameter_einkommensteuertarif, xnp=xnp + x=zu_verst_eink_per_indiv, + parameters=parameter_einkommensteuertarif, + xnp=xnp, ) @policy_function( rounding_spec=RoundingSpec( - base=1, direction="down", reference="§ 32a Abs. 1 S.6 EStG" + base=1, + direction="down", + reference="§ 32a Abs. 1 S.6 EStG", ), ) def betrag_ohne_kinderfreibetrag_y_sn( @@ -162,7 +174,9 @@ def betrag_ohne_kinderfreibetrag_y_sn( """ zu_verst_eink_per_indiv = gesamteinkommen_y / anzahl_personen_sn return anzahl_personen_sn * piecewise_polynomial( - x=zu_verst_eink_per_indiv, parameters=parameter_einkommensteuertarif, xnp=xnp + x=zu_verst_eink_per_indiv, + parameters=parameter_einkommensteuertarif, + xnp=xnp, ) @@ -234,7 +248,8 @@ def parameter_einkommensteuertarif( """ expanded: dict[int, dict[str, float]] = optree.tree_map( # type: ignore[assignment] - float, raw_parameter_einkommensteuertarif + float, + raw_parameter_einkommensteuertarif, ) # Check and extract lower thresholds. 
diff --git "a/src/_gettsim/einkommensteuer/eink\303\274nfte/aus_kapitalverm\303\266gen/aus_kapitalverm\303\266gen.py" "b/src/_gettsim/einkommensteuer/eink\303\274nfte/aus_kapitalverm\303\266gen/aus_kapitalverm\303\266gen.py" index 3f76d7090..f31755a03 100644 --- "a/src/_gettsim/einkommensteuer/eink\303\274nfte/aus_kapitalverm\303\266gen/aus_kapitalverm\303\266gen.py" +++ "b/src/_gettsim/einkommensteuer/eink\303\274nfte/aus_kapitalverm\303\266gen/aus_kapitalverm\303\266gen.py" @@ -13,7 +13,8 @@ def betrag_y_mit_sparerfreibetrag_und_werbungskostenpauschbetrag( ) -> float: """Calculate taxable capital income on Steuernummer level.""" return max( - kapitalerträge_y - sparerfreibetrag + sparer_werbungskostenpauschbetrag, 0.0 + kapitalerträge_y - sparerfreibetrag + sparer_werbungskostenpauschbetrag, + 0.0, ) diff --git a/src/_gettsim/einkommensteuer/zu_versteuerndes_einkommen.py b/src/_gettsim/einkommensteuer/zu_versteuerndes_einkommen.py index ee42b6a29..feb6540ce 100644 --- a/src/_gettsim/einkommensteuer/zu_versteuerndes_einkommen.py +++ b/src/_gettsim/einkommensteuer/zu_versteuerndes_einkommen.py @@ -7,7 +7,9 @@ @policy_function( rounding_spec=RoundingSpec( - base=1, direction="down", reference="§ 32a Abs. 1 S.1 EStG" + base=1, + direction="down", + reference="§ 32a Abs. 
1 S.1 EStG", ), start_date="2004-01-01", leaf_name="zu_versteuerndes_einkommen_y_sn", diff --git a/src/_gettsim/elterngeld/einkommen.py b/src/_gettsim/elterngeld/einkommen.py index c7f4b36ba..f3bbc2c6c 100644 --- a/src/_gettsim/elterngeld/einkommen.py +++ b/src/_gettsim/elterngeld/einkommen.py @@ -61,7 +61,8 @@ def einkommen_vorjahr_unter_bezugsgrenze_mit_unterscheidung_single_paar( familie__alleinerziehend: bool, zu_versteuerndes_einkommen_vorjahr_y_sn: float, max_zu_versteuerndes_einkommen_vorjahr_nach_alleinerziehendenstatus: dict[ - str, float + str, + float, ], ) -> bool: """Income before birth is below income threshold for Elterngeld.""" @@ -83,7 +84,8 @@ def einkommen_vorjahr_unter_bezugsgrenze_mit_unterscheidung_single_paar( @policy_function( - start_date="2024-04-01", leaf_name="einkommen_vorjahr_unter_bezugsgrenze" + start_date="2024-04-01", + leaf_name="einkommen_vorjahr_unter_bezugsgrenze", ) def einkommen_vorjahr_unter_bezugsgrenze_ohne_unterscheidung_single_paar( zu_versteuerndes_einkommen_vorjahr_y_sn: float, diff --git a/src/_gettsim/elterngeld/elterngeld.py b/src/_gettsim/elterngeld/elterngeld.py index a63d3032d..32f91768e 100644 --- a/src/_gettsim/elterngeld/elterngeld.py +++ b/src/_gettsim/elterngeld/elterngeld.py @@ -13,7 +13,8 @@ @agg_by_group_function(agg_type=AggType.ANY) def kind_grundsätzlich_anspruchsberechtigt_fg( - kind_grundsätzlich_anspruchsberechtigt: bool, fg_id: int + kind_grundsätzlich_anspruchsberechtigt: bool, + fg_id: int, ) -> bool: pass @@ -49,7 +50,8 @@ def anzahl_kinder_bis_5_fg(familie__kind_bis_5: bool, fg_id: int) -> int: @agg_by_group_function(agg_type=AggType.SUM) def anzahl_mehrlinge_jüngstes_kind_fg( - jüngstes_kind_oder_mehrling: bool, fg_id: int + jüngstes_kind_oder_mehrling: bool, + fg_id: int, ) -> int: pass diff --git a/src/_gettsim/erziehungsgeld/erziehungsgeld.py b/src/_gettsim/erziehungsgeld/erziehungsgeld.py index bd94b62e3..df0f64f1e 100644 --- a/src/_gettsim/erziehungsgeld/erziehungsgeld.py +++ 
b/src/_gettsim/erziehungsgeld/erziehungsgeld.py @@ -44,7 +44,9 @@ def einkommensgrenze( @agg_by_p_id_function(agg_type=AggType.SUM) def anspruchshöhe_m( - anspruchshöhe_kind_m: float, p_id_empfänger: int, p_id: int + anspruchshöhe_kind_m: float, + p_id_empfänger: int, + p_id: int, ) -> float: pass @@ -76,7 +78,7 @@ def erziehungsgeld_kind_ohne_budgetsatz_m() -> NotImplementedError: """ Erziehungsgeld is not implemented yet prior to 2004, see https://github.com/iza-institute-of-labor-economics/gettsim/issues/673 - """ + """, ) @@ -332,6 +334,8 @@ def einkommensgrenze_ohne_geschwisterbonus_kind_älter_als_reduzierungsgrenze( @agg_by_p_id_function(agg_type=AggType.SUM) def erziehungsgeld_spec_target( - erziehungsgeld_source_field: bool, p_id_field: int, p_id: int + erziehungsgeld_source_field: bool, + p_id_field: int, + p_id: int, ) -> int: pass diff --git a/src/_gettsim/grundsicherung/im_alter/einkommen.py b/src/_gettsim/grundsicherung/im_alter/einkommen.py index d9103bdfc..988a8f109 100644 --- a/src/_gettsim/grundsicherung/im_alter/einkommen.py +++ b/src/_gettsim/grundsicherung/im_alter/einkommen.py @@ -131,7 +131,8 @@ def private_rente_betrag_m( upper = grundsicherung__regelbedarfsstufen.rbs_1 / 2 return sozialversicherung__rente__private_rente_betrag_m - min( - sozialversicherung__rente__private_rente_betrag_m_amount_exempt, upper + sozialversicherung__rente__private_rente_betrag_m_amount_exempt, + upper, ) diff --git a/src/_gettsim/household_characteristics.py b/src/_gettsim/household_characteristics.py index 5160b33d7..740d74d8e 100644 --- a/src/_gettsim/household_characteristics.py +++ b/src/_gettsim/household_characteristics.py @@ -10,7 +10,8 @@ def anzahl_erwachsene_hh(familie__erwachsen: bool, hh_id: int) -> int: @agg_by_group_function(agg_type=AggType.SUM) def anzahl_rentenbezieher_hh( - sozialversicherung__rente__bezieht_rente: bool, hh_id: int + sozialversicherung__rente__bezieht_rente: bool, + hh_id: int, ) -> int: pass @@ -22,7 +23,8 @@ def 
anzahl_personen_hh(hh_id: int) -> int: @policy_function() def erwachsene_alle_rentenbezieher_hh( - anzahl_erwachsene_hh: int, anzahl_rentenbezieher_hh: int + anzahl_erwachsene_hh: int, + anzahl_rentenbezieher_hh: int, ) -> bool: """Calculate if all adults in the household are pensioners.""" return anzahl_erwachsene_hh == anzahl_rentenbezieher_hh diff --git a/src/_gettsim/ids.py b/src/_gettsim/ids.py index c3d2da17a..75e467351 100644 --- a/src/_gettsim/ids.py +++ b/src/_gettsim/ids.py @@ -24,12 +24,16 @@ def hh_id() -> int: @group_creation_function() def ehe_id( - p_id: IntColumn, familie__p_id_ehepartner: IntColumn, xnp: ModuleType + p_id: IntColumn, + familie__p_id_ehepartner: IntColumn, + xnp: ModuleType, ) -> IntColumn: """Couples that are either married or in a civil union.""" n = xnp.max(p_id) + 1 p_id_ehepartner_or_own_p_id = xnp.where( - familie__p_id_ehepartner < 0, p_id, familie__p_id_ehepartner + familie__p_id_ehepartner < 0, + p_id, + familie__p_id_ehepartner, ) result = ( xnp.maximum(p_id, p_id_ehepartner_or_own_p_id) @@ -61,14 +65,19 @@ def fg_id( p_id_elternteil_2_loc = familie__p_id_elternteil_2 for i in range(p_id.shape[0]): p_id_elternteil_1_loc = xnp.where( - familie__p_id_elternteil_1 == p_id[i], i, p_id_elternteil_1_loc + familie__p_id_elternteil_1 == p_id[i], + i, + p_id_elternteil_1_loc, ) p_id_elternteil_2_loc = xnp.where( - familie__p_id_elternteil_2 == p_id[i], i, p_id_elternteil_2_loc + familie__p_id_elternteil_2 == p_id[i], + i, + p_id_elternteil_2_loc, ) children = xnp.isin(p_id, familie__p_id_elternteil_1) | xnp.isin( - p_id, familie__p_id_elternteil_2 + p_id, + familie__p_id_elternteil_2, ) # Assign the same fg_id to everybody who has an Einstandspartner, diff --git a/src/_gettsim/individual_characteristics.py b/src/_gettsim/individual_characteristics.py index 534074207..1817b200b 100644 --- a/src/_gettsim/individual_characteristics.py +++ b/src/_gettsim/individual_characteristics.py @@ -15,11 +15,11 @@ def geburtsdatum( ) -> 
numpy.datetime64: """Create date of birth datetime variable.""" return numpy.datetime64( - datetime.datetime( + datetime.datetime( # noqa: DTZ001 geburtsjahr, geburtsmonat, geburtstag, - ) + ), ).astype("datetime64[D]") diff --git a/src/_gettsim/kindergeld/kindergeld.py b/src/_gettsim/kindergeld/kindergeld.py index bb9e1589e..41965f971 100644 --- a/src/_gettsim/kindergeld/kindergeld.py +++ b/src/_gettsim/kindergeld/kindergeld.py @@ -22,7 +22,9 @@ @agg_by_p_id_function(agg_type=AggType.SUM) def anzahl_ansprüche( - grundsätzlich_anspruchsberechtigt: bool, p_id_empfänger: int, p_id: int + grundsätzlich_anspruchsberechtigt: bool, + p_id_empfänger: int, + p_id: int, ) -> int: pass @@ -158,5 +160,6 @@ def satz_nach_anzahl_kinder( for k in range(max_num_children_in_spec + 1, max_num_children) } return get_consecutive_int_1d_lookup_table_param_value( - raw={0: 0.0, **base_spec, **extended_spec}, xnp=xnp + raw={0: 0.0, **base_spec, **extended_spec}, + xnp=xnp, ) diff --git a/src/_gettsim/kinderzuschlag/einkommen.py b/src/_gettsim/kinderzuschlag/einkommen.py index d78c4cd76..052c94c60 100644 --- a/src/_gettsim/kinderzuschlag/einkommen.py +++ b/src/_gettsim/kinderzuschlag/einkommen.py @@ -24,7 +24,8 @@ @agg_by_group_function(agg_type=AggType.SUM, start_date="2005-01-01") def arbeitslosengeld_2__anzahl_kinder_bg( - kindergeld__anzahl_ansprüche: int, bg_id: int + kindergeld__anzahl_ansprüche: int, + bg_id: int, ) -> int: pass @@ -215,7 +216,9 @@ def kosten_der_unterkunft_m_bg( @param_function( - start_date="2005-01-01", end_date="2011-12-31", leaf_name="existenzminimum" + start_date="2005-01-01", + end_date="2011-12-31", + leaf_name="existenzminimum", ) def existenzminimum_ohne_bildung_und_teilhabe( parameter_existenzminimum: RawParam, @@ -268,7 +271,7 @@ def existenzminimum_mit_bildung_und_teilhabe( kosten_der_unterkunft=kosten_der_unterkunft, heizkosten=heizkosten, bildung_und_teilhabe=ElementExistenzminimumNurKind( - 
kind=parameter_existenzminimum["bildung_und_teilhabe"]["kind"] + kind=parameter_existenzminimum["bildung_und_teilhabe"]["kind"], ), ) @@ -298,7 +301,8 @@ def wohnbedarf_anteil_eltern_bg( ) kinderbetrag = min( - arbeitslosengeld_2__anzahl_kinder_bg, wohnbedarf_anteil_berücksichtigte_kinder + arbeitslosengeld_2__anzahl_kinder_bg, + wohnbedarf_anteil_berücksichtigte_kinder, ) * (existenzminimum.kosten_der_unterkunft.kind + existenzminimum.heizkosten.kind) return elternbetrag / (elternbetrag + kinderbetrag) @@ -306,7 +310,8 @@ def wohnbedarf_anteil_eltern_bg( @policy_function(start_date="2005-01-01") def erwachsenenbedarf_m_bg( - arbeitslosengeld_2__regelsatz_m_bg: float, kosten_der_unterkunft_m_bg: float + arbeitslosengeld_2__regelsatz_m_bg: float, + kosten_der_unterkunft_m_bg: float, ) -> float: """Aggregate relevant income and rental costs.""" return arbeitslosengeld_2__regelsatz_m_bg + kosten_der_unterkunft_m_bg diff --git a/src/_gettsim/kinderzuschlag/kinderzuschlag.py b/src/_gettsim/kinderzuschlag/kinderzuschlag.py index 0de67f687..7f12497fe 100644 --- a/src/_gettsim/kinderzuschlag/kinderzuschlag.py +++ b/src/_gettsim/kinderzuschlag/kinderzuschlag.py @@ -59,7 +59,8 @@ def satz_mit_einheitlichem_kindergeld_und_kindersofortzuschlag( ) / 12 - kindergeld__satz satz_ohne_kindersofortzuschlag = max( - current_formula, satz_vorjahr_ohne_kindersofortzuschlag + current_formula, + satz_vorjahr_ohne_kindersofortzuschlag, ) return satz_ohne_kindersofortzuschlag + arbeitslosengeld_2__kindersofortzuschlag @@ -112,7 +113,9 @@ def anspruchshöhe_m_bg( @policy_function( - start_date="2005-01-01", end_date="2022-12-31", leaf_name="vermögensfreibetrag_bg" + start_date="2005-01-01", + end_date="2022-12-31", + leaf_name="vermögensfreibetrag_bg", ) def vermögensfreibetrag_bg_bis_2022( arbeitslosengeld_2__vermögensfreibetrag_bg: float, diff --git a/src/_gettsim/lohnsteuer/lohnsteuer.py b/src/_gettsim/lohnsteuer/lohnsteuer.py index dafc230fb..526f517c4 100644 --- 
a/src/_gettsim/lohnsteuer/lohnsteuer.py +++ b/src/_gettsim/lohnsteuer/lohnsteuer.py @@ -37,10 +37,14 @@ def basis_für_klassen_5_6( """ return 2 * ( piecewise_polynomial( - x=einkommen_y * 1.25, parameters=parameter_einkommensteuertarif, xnp=xnp + x=einkommen_y * 1.25, + parameters=parameter_einkommensteuertarif, + xnp=xnp, ) - piecewise_polynomial( - x=einkommen_y * 0.75, parameters=parameter_einkommensteuertarif, xnp=xnp + x=einkommen_y * 0.75, + parameters=parameter_einkommensteuertarif, + xnp=xnp, ) ) @@ -75,7 +79,7 @@ def parameter_max_lohnsteuer_klasse_5_6( einkommensgrenzwerte_steuerklassen_5_6[1], einkommensgrenzwerte_steuerklassen_5_6[2], einkommensgrenzwerte_steuerklassen_5_6[3], - ] + ], ) intercepts = numpy.asarray( [ @@ -83,7 +87,7 @@ def parameter_max_lohnsteuer_klasse_5_6( lohnsteuer_bis_erste_grenze, lohnsteuer_bis_zweite_grenze, lohnsteuer_bis_dritte_grenze, - ] + ], ) rates = numpy.expand_dims( einkommensteuer__parameter_einkommensteuertarif.rates[0][ @@ -137,10 +141,14 @@ def tarif_klassen_5_und_6( ) -> float: """Lohnsteuer for Lohnsteuerklassen 5 and 6.""" basis = basis_für_klassen_5_6( - einkommen_y, einkommensteuer__parameter_einkommensteuertarif, xnp=xnp + einkommen_y, + einkommensteuer__parameter_einkommensteuertarif, + xnp=xnp, ) max_lohnsteuer = piecewise_polynomial( - x=einkommen_y, parameters=parameter_max_lohnsteuer_klasse_5_6, xnp=xnp + x=einkommen_y, + parameters=parameter_max_lohnsteuer_klasse_5_6, + xnp=xnp, ) min_lohnsteuer = ( einkommensteuer__parameter_einkommensteuertarif.rates[0, 1] * einkommen_y @@ -175,7 +183,8 @@ def basistarif_mit_kinderfreibetrag( ) -> float: """Lohnsteuer in the Basistarif deducting the Kindefreibetrag.""" einkommen_abzüglich_kinderfreibetrag_soli = xnp.maximum( - einkommen_y - kinderfreibetrag_soli_y, 0 + einkommen_y - kinderfreibetrag_soli_y, + 0, ) return piecewise_polynomial( x=einkommen_abzüglich_kinderfreibetrag_soli, @@ -193,7 +202,8 @@ def splittingtarif_mit_kinderfreibetrag( ) -> float: 
"""Lohnsteuer in the Splittingtarif deducting the Kindefreibetrag.""" einkommen_abzüglich_kinderfreibetrag_soli = xnp.maximum( - einkommen_y - kinderfreibetrag_soli_y, 0 + einkommen_y - kinderfreibetrag_soli_y, + 0, ) return 2 * piecewise_polynomial( x=einkommen_abzüglich_kinderfreibetrag_soli / 2, @@ -212,7 +222,8 @@ def tarif_klassen_5_und_6_mit_kinderfreibetrag( ) -> float: """Lohnsteuer for Lohnsteuerklassen 5 and 6 deducting the Kindefreibetrag.""" einkommen_abzüglich_kinderfreibetrag_soli = xnp.maximum( - einkommen_y - kinderfreibetrag_soli_y, 0 + einkommen_y - kinderfreibetrag_soli_y, + 0, ) basis = basis_für_klassen_5_6( diff --git a/src/_gettsim/sozialversicherung/arbeitslosen/beitrag/beitrag.py b/src/_gettsim/sozialversicherung/arbeitslosen/beitrag/beitrag.py index 766c9261c..e8979302c 100644 --- a/src/_gettsim/sozialversicherung/arbeitslosen/beitrag/beitrag.py +++ b/src/_gettsim/sozialversicherung/arbeitslosen/beitrag/beitrag.py @@ -148,7 +148,8 @@ def betrag_versicherter_in_gleitzone_m_als_differenz_von_gesamt_und_arbeitgeberb @policy_function( - start_date="2022-10-01", leaf_name="betrag_versicherter_in_gleitzone_m" + start_date="2022-10-01", + leaf_name="betrag_versicherter_in_gleitzone_m", ) def betrag_versicherter_in_gleitzone_m_mit_festem_beitragssatz( sozialversicherung__beitragspflichtige_einnahmen_aus_midijob_arbeitnehmer_m: float, diff --git a/src/_gettsim/sozialversicherung/minijob.py b/src/_gettsim/sozialversicherung/minijob.py index 6b8029064..e0bbf1b77 100644 --- a/src/_gettsim/sozialversicherung/minijob.py +++ b/src/_gettsim/sozialversicherung/minijob.py @@ -27,11 +27,14 @@ def geringfügig_beschäftigt( end_date="1999-12-31", leaf_name="minijobgrenze", rounding_spec=RoundingSpec( - base=1, direction="up", reference="§ 8 Abs. 1a Satz 2 SGB IV" + base=1, + direction="up", + reference="§ 8 Abs. 
1a Satz 2 SGB IV", ), ) def minijobgrenze_unterscheidung_ost_west( - wohnort_ost: bool, parameter_minijobgrenze_ost_west_unterschied: dict[str, float] + wohnort_ost: bool, + parameter_minijobgrenze_ost_west_unterschied: dict[str, float], ) -> float: """Minijob income threshold depending on place of living (East or West Germany). @@ -48,7 +51,9 @@ def minijobgrenze_unterscheidung_ost_west( start_date="2022-10-01", leaf_name="minijobgrenze", rounding_spec=RoundingSpec( - base=1, direction="up", reference="§ 8 Abs. 1a Satz 2 SGB IV" + base=1, + direction="up", + reference="§ 8 Abs. 1a Satz 2 SGB IV", ), ) def minijobgrenze_abgeleitet_von_mindestlohn( diff --git a/src/_gettsim/sozialversicherung/pflege/beitrag/beitragssatz.py b/src/_gettsim/sozialversicherung/pflege/beitrag/beitragssatz.py index 77a6cd570..41707593e 100644 --- a/src/_gettsim/sozialversicherung/pflege/beitrag/beitragssatz.py +++ b/src/_gettsim/sozialversicherung/pflege/beitrag/beitragssatz.py @@ -62,7 +62,8 @@ def beitragssatz_arbeitnehmer_mit_abschlag_nach_kinderzahl( add = add + beitragssatz_nach_kinderzahl["zusatz_kinderlos"] if anzahl_kinder_bis_24 >= 2: add = add - beitragssatz_nach_kinderzahl["abschlag_für_kinder_bis_24"] * min( - anzahl_kinder_bis_24 - 1, 4 + anzahl_kinder_bis_24 - 1, + 4, ) return base + add @@ -110,7 +111,9 @@ def anzahl_kinder_bis_24( @param_function( - start_date="1995-01-01", end_date="2004-12-31", leaf_name="beitragssatz_arbeitgeber" + start_date="1995-01-01", + end_date="2004-12-31", + leaf_name="beitragssatz_arbeitgeber", ) def beitragssatz_arbeitgeber_einheitliche_basis(beitragssatz: float) -> float: """Employer's long-term care insurance contribution rate.""" diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py b/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py index 9a02137f5..48e76af71 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py +++ 
b/src/_gettsim/sozialversicherung/rente/altersrente/altersgrenzen.py @@ -185,7 +185,8 @@ def altersgrenze_vorzeitig_ohne_arbeitslosigkeit_frauen( @policy_function( - end_date="2017-12-31", leaf_name="vorzeitig_grundsätzlich_anspruchsberechtigt" + end_date="2017-12-31", + leaf_name="vorzeitig_grundsätzlich_anspruchsberechtigt", ) def vorzeitig_grundsätzlich_anspruchsberechtigt_mit_arbeitslosigkeit_frauen( für_frauen__grundsätzlich_anspruchsberechtigt: bool, @@ -207,7 +208,8 @@ def vorzeitig_grundsätzlich_anspruchsberechtigt_mit_arbeitslosigkeit_frauen( @policy_function( - start_date="2018-01-01", leaf_name="vorzeitig_grundsätzlich_anspruchsberechtigt" + start_date="2018-01-01", + leaf_name="vorzeitig_grundsätzlich_anspruchsberechtigt", ) def vorzeitig_grundsätzlich_anspruchsberechtigt_vorzeitig_ohne_arbeitslosigkeit_frauen( langjährig__grundsätzlich_anspruchsberechtigt: bool, @@ -246,7 +248,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( für_frauen__altersgrenze, langjährig__altersgrenze, wegen_arbeitslosigkeit__altersgrenze, - ] + ], ) elif ( langjährig__grundsätzlich_anspruchsberechtigt @@ -256,7 +258,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( [ für_frauen__altersgrenze, langjährig__altersgrenze, - ] + ], ) elif ( langjährig__grundsätzlich_anspruchsberechtigt @@ -266,7 +268,7 @@ def referenzalter_abschlag_mit_arbeitslosigkeit_frauen( [ langjährig__altersgrenze, wegen_arbeitslosigkeit__altersgrenze, - ] + ], ) elif langjährig__grundsätzlich_anspruchsberechtigt: out = langjährig__altersgrenze diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/altersrente.py b/src/_gettsim/sozialversicherung/rente/altersrente/altersrente.py index 52dceac82..335fbf62d 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/altersrente.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/altersrente.py @@ -8,12 +8,15 @@ @policy_function( end_date="2020-12-31", rounding_spec=RoundingSpec( - base=0.01, direction="nearest", 
reference="§ 123 SGB VI Abs. 1" + base=0.01, + direction="nearest", + reference="§ 123 SGB VI Abs. 1", ), leaf_name="betrag_m", ) def betrag_m( - bruttorente_m: float, sozialversicherung__rente__bezieht_rente: bool + bruttorente_m: float, + sozialversicherung__rente__bezieht_rente: bool, ) -> float: return bruttorente_m if sozialversicherung__rente__bezieht_rente else 0.0 @@ -21,7 +24,9 @@ def betrag_m( @policy_function( start_date="2021-01-01", rounding_spec=RoundingSpec( - base=0.01, direction="nearest", reference="§ 123 SGB VI Abs. 1" + base=0.01, + direction="nearest", + reference="§ 123 SGB VI Abs. 1", ), leaf_name="betrag_m", ) diff --git "a/src/_gettsim/sozialversicherung/rente/altersrente/besonders_langj\303\244hrig/besonders_langj\303\244hrig.py" "b/src/_gettsim/sozialversicherung/rente/altersrente/besonders_langj\303\244hrig/besonders_langj\303\244hrig.py" index adceb370a..dda7a4b30 100644 --- "a/src/_gettsim/sozialversicherung/rente/altersrente/besonders_langj\303\244hrig/besonders_langj\303\244hrig.py" +++ "b/src/_gettsim/sozialversicherung/rente/altersrente/besonders_langj\303\244hrig/besonders_langj\303\244hrig.py" @@ -10,7 +10,8 @@ end_date="2028-12-31", ) def altersgrenze( - geburtsjahr: int, altersgrenze_gestaffelt: ConsecutiveInt1dLookupTableParamValue + geburtsjahr: int, + altersgrenze_gestaffelt: ConsecutiveInt1dLookupTableParamValue, ) -> float: """ Full retirement age (FRA) for very long term insured. diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/hinzuverdienstgrenzen.py b/src/_gettsim/sozialversicherung/rente/altersrente/hinzuverdienstgrenzen.py index 5406c869d..100644c49 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/hinzuverdienstgrenzen.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/hinzuverdienstgrenzen.py @@ -4,7 +4,9 @@ @policy_function( end_date="2016-12-31", rounding_spec=RoundingSpec( - base=0.01, direction="nearest", reference="§ 123 SGB VI Abs. 
1" + base=0.01, + direction="nearest", + reference="§ 123 SGB VI Abs. 1", ), leaf_name="bruttorente_m", ) @@ -37,7 +39,9 @@ def bruttorente_m_mit_harter_hinzuverdienstgrenze( end_date="2022-12-31", leaf_name="bruttorente_m", rounding_spec=RoundingSpec( - base=0.01, direction="nearest", reference="§ 123 SGB VI Abs. 1" + base=0.01, + direction="nearest", + reference="§ 123 SGB VI Abs. 1", ), ) def bruttorente_m_mit_hinzuverdienstdeckel( @@ -143,7 +147,9 @@ def differenz_bruttolohn_hinzuverdienstdeckel_y( start_date="2023-01-01", leaf_name="bruttorente_m", rounding_spec=RoundingSpec( - base=0.01, direction="nearest", reference="§ 123 SGB VI Abs. 1" + base=0.01, + direction="nearest", + reference="§ 123 SGB VI Abs. 1", ), ) def bruttorente_m_ohne_einkommensanrechnung( diff --git a/src/_gettsim/sozialversicherung/rente/altersrente/regelaltersrente/regelaltersrente.py b/src/_gettsim/sozialversicherung/rente/altersrente/regelaltersrente/regelaltersrente.py index 5fad8bf96..c90339d23 100644 --- a/src/_gettsim/sozialversicherung/rente/altersrente/regelaltersrente/regelaltersrente.py +++ b/src/_gettsim/sozialversicherung/rente/altersrente/regelaltersrente/regelaltersrente.py @@ -7,7 +7,8 @@ @policy_function(start_date="2007-04-20", end_date="2030-12-31") def altersgrenze( - geburtsjahr: int, altersgrenze_gestaffelt: ConsecutiveInt1dLookupTableParamValue + geburtsjahr: int, + altersgrenze_gestaffelt: ConsecutiveInt1dLookupTableParamValue, ) -> float: """Normal retirement age (NRA) during the phase-in period. 
diff --git a/src/_gettsim/sozialversicherung/rente/beitrag/beitrag.py b/src/_gettsim/sozialversicherung/rente/beitrag/beitrag.py index 51a028046..933bcb7ec 100644 --- a/src/_gettsim/sozialversicherung/rente/beitrag/beitrag.py +++ b/src/_gettsim/sozialversicherung/rente/beitrag/beitrag.py @@ -207,7 +207,8 @@ def betrag_in_gleitzone_arbeitnehmer_m_als_differenz_von_gesamt_und_arbeitgeberb @policy_function( - start_date="2022-10-01", leaf_name="betrag_in_gleitzone_arbeitnehmer_m" + start_date="2022-10-01", + leaf_name="betrag_in_gleitzone_arbeitnehmer_m", ) def betrag_in_gleitzone_arbeitnehmer_m_mit_festem_beitragssatz( sozialversicherung__beitragspflichtige_einnahmen_aus_midijob_arbeitnehmer_m: float, diff --git a/src/_gettsim/sozialversicherung/rente/erwerbsminderung/erwerbsminderung.py b/src/_gettsim/sozialversicherung/rente/erwerbsminderung/erwerbsminderung.py index f8fe5d1ff..ed3cbd2db 100644 --- a/src/_gettsim/sozialversicherung/rente/erwerbsminderung/erwerbsminderung.py +++ b/src/_gettsim/sozialversicherung/rente/erwerbsminderung/erwerbsminderung.py @@ -18,7 +18,8 @@ def betrag_m_nach_wohnort( rentenartfaktor: float, grundsätzlich_anspruchsberechtigt: bool, sozialversicherung__rente__altersrente__parameter_rentenwert_nach_wohnort: dict[ - str, float + str, + float, ], ) -> float: """Erwerbsminderungsrente (public disability insurance). 
@@ -140,7 +141,9 @@ def entgeltpunkte_ost( @policy_function( - start_date="2000-12-23", end_date="2014-06-30", leaf_name="zurechnungszeit" + start_date="2000-12-23", + end_date="2014-06-30", + leaf_name="zurechnungszeit", ) def zurechnungszeit_mit_gestaffelter_altersgrenze_bis_06_2014( mean_entgeltpunkte_pro_bewertungsmonat: float, @@ -171,7 +174,9 @@ def zurechnungszeit_mit_gestaffelter_altersgrenze_bis_06_2014( @policy_function( - start_date="2014-07-01", end_date="2017-07-16", leaf_name="zurechnungszeit" + start_date="2014-07-01", + end_date="2017-07-16", + leaf_name="zurechnungszeit", ) def zurechnungszeit_mit_einheitlicher_altersgrenze( mean_entgeltpunkte_pro_bewertungsmonat: float, @@ -241,7 +246,8 @@ def zugangsfaktor_ohne_gestaffelte_altersgrenze( altersgrenze: float, min_zugangsfaktor: float, sozialversicherung__rente__altersrente__zugangsfaktor_veränderung_pro_jahr: dict[ - str, float + str, + float, ], ) -> float: """Zugangsfaktor. @@ -274,7 +280,8 @@ def zugangsfaktor_mit_gestaffelter_altersgrenze( altersgrenze_langjährig_versichert: float, min_zugangsfaktor: float, sozialversicherung__rente__altersrente__zugangsfaktor_veränderung_pro_jahr: dict[ - str, float + str, + float, ], ) -> float: """Zugangsfaktor. diff --git a/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py b/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py index 2830f2d3f..e40ae628a 100644 --- a/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py +++ b/src/_gettsim/sozialversicherung/rente/grundrente/grundrente.py @@ -15,7 +15,9 @@ @policy_function( rounding_spec=RoundingSpec( - base=0.01, direction="nearest", reference="§ 123 SGB VI Abs. 1" + base=0.01, + direction="nearest", + reference="§ 123 SGB VI Abs. 1", ), start_date="2021-01-01", ) @@ -79,7 +81,9 @@ def _anzurechnendes_einkommen_m( @policy_function( rounding_spec=RoundingSpec( - base=0.01, direction="nearest", reference="§ 123 SGB VI Abs. 
1" + base=0.01, + direction="nearest", + reference="§ 123 SGB VI Abs. 1", ), start_date="2021-01-01", ) @@ -123,7 +127,9 @@ def anzurechnendes_einkommen_m( @policy_function( rounding_spec=RoundingSpec( - base=0.01, direction="nearest", reference="§ 123 SGB VI Abs. 1" + base=0.01, + direction="nearest", + reference="§ 123 SGB VI Abs. 1", ), start_date="2021-01-01", ) @@ -159,7 +165,8 @@ def basisbetrag_m( @policy_function(start_date="2021-01-01") def mean_entgeltpunkte_pro_bewertungsmonat( - mean_entgeltpunkte: float, bewertungszeiten_monate: int + mean_entgeltpunkte: float, + bewertungszeiten_monate: int, ) -> float: """Average number of Entgeltpunkte earned per month of Grundrentenbewertungszeiten.""" if bewertungszeiten_monate > 0: diff --git a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py index 11368358b..a47ef30a7 100644 --- a/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py +++ b/src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py @@ -23,7 +23,9 @@ @agg_by_p_id_function(agg_type=AggType.SUM) def an_elternteil_auszuzahlender_betrag_m( - betrag_m: float, kindergeld__p_id_empfänger: int, p_id: int + betrag_m: float, + kindergeld__p_id_empfänger: int, + p_id: int, ) -> float: pass @@ -31,7 +33,9 @@ def an_elternteil_auszuzahlender_betrag_m( @policy_function( start_date="2009-01-01", rounding_spec=RoundingSpec( - base=1, direction="up", reference="§ 9 Abs. 3 UhVorschG" + base=1, + direction="up", + reference="§ 9 Abs. 3 UhVorschG", ), ) def betrag_m( @@ -88,14 +92,16 @@ def elternteil_alleinerziehend( end_date="2008-12-31", leaf_name="betrag_m", rounding_spec=RoundingSpec( - base=1, direction="down", reference="§ 9 Abs. 3 UhVorschG" + base=1, + direction="down", + reference="§ 9 Abs. 3 UhVorschG", ), ) def not_implemented_m() -> float: raise NotImplementedError( """ Unterhaltsvorschuss is not implemented prior to 2009. 
- """ + """, ) @@ -319,7 +325,9 @@ def einkommen_m( @agg_by_p_id_function(agg_type=AggType.SUM) def unterhaltsvorschuss_spec_target( - unterhaltsvorschuss_source_field: bool, p_id_field: int, p_id: int + unterhaltsvorschuss_source_field: bool, + p_id_field: int, + p_id: int, ) -> int: pass diff --git "a/src/_gettsim/vorrangpr\303\274fungen/vorrangpr\303\274fungen.py" "b/src/_gettsim/vorrangpr\303\274fungen/vorrangpr\303\274fungen.py" index 1602b8064..bbeae1f51 100644 --- "a/src/_gettsim/vorrangpr\303\274fungen/vorrangpr\303\274fungen.py" +++ "b/src/_gettsim/vorrangpr\303\274fungen/vorrangpr\303\274fungen.py" @@ -7,14 +7,16 @@ @agg_by_group_function(agg_type=AggType.ANY) def wohngeld_vorrang_wthh( - wohngeld_vorrang_vor_arbeitslosengeld_2_bg: bool, wthh_id: int + wohngeld_vorrang_vor_arbeitslosengeld_2_bg: bool, + wthh_id: int, ) -> bool: pass @agg_by_group_function(agg_type=AggType.ANY) def wohngeld_kinderzuschlag_vorrang_wthh( - wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg: bool, wthh_id: int + wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg: bool, + wthh_id: int, ) -> bool: pass diff --git a/src/_gettsim/wohngeld/einkommen.py b/src/_gettsim/wohngeld/einkommen.py index 0a8f31684..bb6b5be7f 100644 --- a/src/_gettsim/wohngeld/einkommen.py +++ b/src/_gettsim/wohngeld/einkommen.py @@ -51,7 +51,8 @@ def einkommen( eink_nach_abzug_m_hh = einkommen_vor_freibetrag - einkommensfreibetrag unteres_eink = min_einkommen_lookup_table.values_to_look_up[ xnp.minimum( - anzahl_personen, min_einkommen_lookup_table.values_to_look_up.shape[0] + anzahl_personen, + min_einkommen_lookup_table.values_to_look_up.shape[0], ) - min_einkommen_lookup_table.base_to_subtract ] diff --git a/src/_gettsim/wohngeld/miete.py b/src/_gettsim/wohngeld/miete.py index 20054a889..f31f23d30 100644 --- a/src/_gettsim/wohngeld/miete.py +++ b/src/_gettsim/wohngeld/miete.py @@ -29,7 +29,9 @@ class LookupTableBaujahr: @param_function( - start_date="1984-01-01", 
end_date="2008-12-31", leaf_name="max_miete_m_lookup" + start_date="1984-01-01", + end_date="2008-12-31", + leaf_name="max_miete_m_lookup", ) def max_miete_m_lookup_mit_baujahr( raw_max_miete_m_nach_baujahr: dict[int | str, dict[int, dict[int, float]]], @@ -54,7 +56,8 @@ def max_miete_m_lookup_mit_baujahr( for ms in this_dict[max_n_p_defined] } lookup_table = get_consecutive_int_2d_lookup_table_param_value( - raw=this_dict, xnp=xnp + raw=this_dict, + xnp=xnp, ) values.append(lookup_table.values_to_look_up) subtract_cols.append(lookup_table.base_to_subtract_cols) @@ -198,7 +201,8 @@ def miete_m_bg( @policy_function() def min_miete_m_hh( - anzahl_personen_hh: int, min_miete_lookup: ConsecutiveInt1dLookupTableParamValue + anzahl_personen_hh: int, + min_miete_lookup: ConsecutiveInt1dLookupTableParamValue, ) -> float: """Minimum rent considered in Wohngeld calculation.""" return min_miete_lookup.values_to_look_up[ diff --git a/src/_gettsim/wohngeld/voraussetzungen.py b/src/_gettsim/wohngeld/voraussetzungen.py index 69a2abeb0..8724d293d 100644 --- a/src/_gettsim/wohngeld/voraussetzungen.py +++ b/src/_gettsim/wohngeld/voraussetzungen.py @@ -23,7 +23,8 @@ def grundsätzlich_anspruchsberechtigt_wthh_ohne_vermögensprüfung( @policy_function( - start_date="2009-01-01", leaf_name="grundsätzlich_anspruchsberechtigt_wthh" + start_date="2009-01-01", + leaf_name="grundsätzlich_anspruchsberechtigt_wthh", ) def grundsätzlich_anspruchsberechtigt_wthh_mit_vermögensprüfung( mindesteinkommen_erreicht_wthh: bool, @@ -56,7 +57,8 @@ def grundsätzlich_anspruchsberechtigt_bg_ohne_vermögensprüfung( @policy_function( - start_date="2009-01-01", leaf_name="grundsätzlich_anspruchsberechtigt_bg" + start_date="2009-01-01", + leaf_name="grundsätzlich_anspruchsberechtigt_bg", ) def grundsätzlich_anspruchsberechtigt_bg_mit_vermögensprüfung( mindesteinkommen_erreicht_bg: bool, diff --git a/src/_gettsim/wohngeld/wohngeld.py b/src/_gettsim/wohngeld/wohngeld.py index c01d59300..f99e7c19f 100644 --- 
a/src/_gettsim/wohngeld/wohngeld.py +++ b/src/_gettsim/wohngeld/wohngeld.py @@ -215,6 +215,7 @@ def basisformel_params( b=get_consecutive_int_1d_lookup_table_param_value(raw=b, xnp=xnp), c=get_consecutive_int_1d_lookup_table_param_value(raw=c, xnp=xnp), zusatzbetrag_nach_haushaltsgröße=get_consecutive_int_1d_lookup_table_param_value( - raw=zusatzbetrag_nach_haushaltsgröße, xnp=xnp + raw=zusatzbetrag_nach_haushaltsgröße, + xnp=xnp, ), ) diff --git a/src/_gettsim_tests/test_interface.py b/src/_gettsim_tests/test_interface.py index 4f4760f98..e39d23ee3 100644 --- a/src/_gettsim_tests/test_interface.py +++ b/src/_gettsim_tests/test_interface.py @@ -17,7 +17,7 @@ def example_inputs_df(): "recipient_child_benefits_id": [-1, 0, 0], "is_single_parent": [True, False, False], "has_children": [True, False, False], - } + }, ) @@ -87,7 +87,7 @@ def example_inputs_tree_to_inputs_df_columns(): "kranken": { "beitrag": { "privat_versichert": False, - } + }, }, }, "wohnort_ost": False, diff --git a/src/_gettsim_tests/test_policy.py b/src/_gettsim_tests/test_policy.py index b202aad6e..716d8306c 100644 --- a/src/_gettsim_tests/test_policy.py +++ b/src/_gettsim_tests/test_policy.py @@ -16,7 +16,9 @@ TEST_DIR = Path(__file__).parent POLICY_TEST_IDS_AND_CASES = load_policy_test_data( - test_dir=TEST_DIR, policy_name="", xnp=numpy + test_dir=TEST_DIR, + policy_name="", + xnp=numpy, ) diff --git a/src/ttsim/interface_dag.py b/src/ttsim/interface_dag.py index bacb5f162..355ea9d0c 100644 --- a/src/ttsim/interface_dag.py +++ b/src/ttsim/interface_dag.py @@ -62,7 +62,8 @@ def main( def load_interface_functions_and_inputs() -> dict[ - str, InterfaceFunction | InterfaceInput + str, + InterfaceFunction | InterfaceInput, ]: """Load the collection of functions and inputs from the current directory.""" orig_functions = _load_orig_functions() @@ -134,6 +135,6 @@ def _fail_if_targets_are_not_among_interface_functions( formatted = format_list_linewise(sorted(missing_targets)) msg = 
format_errors_and_warnings( "The following targets have no corresponding function in the interface " - f"DAG:\n\n{formatted}" + f"DAG:\n\n{formatted}", ) raise ValueError(msg) diff --git a/src/ttsim/interface_dag_elements/automatically_added_functions.py b/src/ttsim/interface_dag_elements/automatically_added_functions.py index 7391545e0..f2fd8cae4 100644 --- a/src/ttsim/interface_dag_elements/automatically_added_functions.py +++ b/src/ttsim/interface_dag_elements/automatically_added_functions.py @@ -559,7 +559,8 @@ def _create_one_set_of_time_conversion_functions( def _create_function_for_time_unit( - source: str, converter: Callable[[float], float] + source: str, + converter: Callable[[float], float], ) -> Callable[[float], float]: @rename_arguments(mapper={"x": source}) def func(x: float) -> float: diff --git a/src/ttsim/interface_dag_elements/data_converters.py b/src/ttsim/interface_dag_elements/data_converters.py index f9cb0b9b3..d1c7518ac 100644 --- a/src/ttsim/interface_dag_elements/data_converters.py +++ b/src/ttsim/interface_dag_elements/data_converters.py @@ -35,7 +35,8 @@ def nested_data_to_df_with_nested_columns( flat_data_to_convert = dt.flatten_to_tree_paths(nested_data_to_convert) return pd.DataFrame( - flat_data_to_convert, index=pd.Index(data_with_p_id["p_id"], name="p_id") + flat_data_to_convert, + index=pd.Index(data_with_p_id["p_id"], name="p_id"), ) @@ -133,7 +134,7 @@ def dataframe_to_nested_data( pd.Series( [input_value] * len(df), index=df.index, - ) + ), ) return dt.unflatten_from_qual_names(name_to_input_array) diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index df447f2f2..876941f02 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -130,7 +130,7 @@ def format_key_path(key_tuple: tuple[str, ...]) -> str: if not isinstance(subtree, dict): path_str = format_key_path(current_key) msg = format_errors_and_warnings( - 
f"{tree_name}{path_str} must be a dict, got {type(subtree)}." + f"{tree_name}{path_str} must be a dict, got {type(subtree)}.", ) raise TypeError(msg) @@ -139,7 +139,7 @@ def format_key_path(key_tuple: tuple[str, ...]) -> str: if not isinstance(key, str): msg = format_errors_and_warnings( f"Key {key} in {tree_name}{format_key_path(current_key)} must be a " - f"string but got {type(key)}." + f"string but got {type(key)}.", ) raise TypeError(msg) if isinstance(value, dict): @@ -148,7 +148,7 @@ def format_key_path(key_tuple: tuple[str, ...]) -> str: if not leaf_checker(value): msg = format_errors_and_warnings( f"Leaf at {tree_name}{format_key_path(new_key_path)} is " - f"invalid: got {value} of type {type(value)}." + f"invalid: got {value} of type {type(value)}.", ) raise TypeError(msg) @@ -173,7 +173,8 @@ def active_periods_overlap( """ # Create mapping from leaf names to objects. overlap_checker: dict[ - tuple[str, ...], list[ColumnObject | ParamFunction | _ParamWithActivePeriod] + tuple[str, ...], + list[ColumnObject | ParamFunction | _ParamWithActivePeriod], ] = {} for ( orig_path, @@ -189,11 +190,12 @@ def active_periods_overlap( path = (*orig_path[:-2], orig_path[-1]) if path in overlap_checker: overlap_checker[path].extend( - _param_with_active_periods(param_spec=obj, leaf_name=orig_path[-1]) + _param_with_active_periods(param_spec=obj, leaf_name=orig_path[-1]), ) else: overlap_checker[path] = _param_with_active_periods( - param_spec=obj, leaf_name=orig_path[-1] + param_spec=obj, + leaf_name=orig_path[-1], ) # Check for overlapping start and end dates for time-dependent functions. 
@@ -238,7 +240,7 @@ def data_paths_are_missing_in_paths_to_column_names( msg = format_errors_and_warnings( "Converting the nested data to a DataFrame failed because the following " "paths are not mapped to a column name: " - f"{format_list_linewise(list(missing_paths))}" + f"{format_list_linewise(list(missing_paths))}", ) raise ValueError(msg) @@ -265,7 +267,8 @@ def input_data_tree_is_invalid(input_data__tree: NestedData, xnp: ModuleType) -> assert_valid_ttsim_pytree( tree=input_data__tree, leaf_checker=lambda leaf: isinstance( - leaf, int | pd.Series | numpy.ndarray | xnp.ndarray + leaf, + int | pd.Series | numpy.ndarray | xnp.ndarray, ), tree_name="input_data__tree", ) @@ -300,7 +303,8 @@ def environment_is_invalid( assert_valid_ttsim_pytree( tree=policy_environment, leaf_checker=lambda leaf: isinstance( - leaf, ColumnObject | ParamFunction | ParamObject + leaf, + ColumnObject | ParamFunction | ParamObject, ), tree_name="policy_environment", ) @@ -331,7 +335,7 @@ def foreign_keys_are_invalid_in_data( for fk_name, fk in relevant_objects.items(): if fk.foreign_key_type == FKType.IRRELEVANT: continue - elif fk_name in names__root_nodes: + if fk_name in names__root_nodes: path = dt.tree_path_from_qual_name(fk_name) # Referenced `p_id` must exist in the input data if not all(i in valid_ids for i in processed_data[fk_name].tolist()): @@ -339,7 +343,7 @@ def foreign_keys_are_invalid_in_data( f""" For {path}, the following are not a valid p_id in the input data: {[i for i in processed_data[fk_name] if i not in valid_ids]}. - """ + """, ) raise ValueError(message) @@ -358,7 +362,7 @@ def foreign_keys_are_invalid_in_data( f""" For {path}, the following are equal to the p_id in the same row: {equal_to_pid_in_same_row}. - """ + """, ) raise ValueError(message) @@ -377,7 +381,7 @@ def group_ids_are_outside_top_level_namespace( raise ValueError( "Group identifiers must live in the top-level namespace. 
Got:\n\n" f"{group_ids_outside_top_level_namespace}\n\n" - "To fix this error, move the group identifiers to the top-level namespace." + "To fix this error, move the group identifiers to the top-level namespace.", ) @@ -408,7 +412,7 @@ def group_variables_are_not_constant_within_groups( group_by_id_series = pd.Series(processed_data[group_by_id]) leaf_series = pd.Series(processed_data[name]) unique_counts = leaf_series.groupby(group_by_id_series).nunique( - dropna=False + dropna=False, ) if not (unique_counts == 1).all(): faulty_data_columns.append(name) @@ -422,7 +426,7 @@ def group_variables_are_not_constant_within_groups( {formatted} To fix this error, assign the same value to each group. - """ + """, ) raise ValueError(msg) @@ -457,14 +461,14 @@ def non_convertible_objects_in_results_tree( "The data contains objects that cannot be cast to a pandas.DataFrame " "column. Make sure that the requested targets return scalars or arrays of " "scalars only. The following paths contain incompatible objects: " - f"{format_list_linewise(paths_with_incorrect_types)}" + f"{format_list_linewise(paths_with_incorrect_types)}", ) raise TypeError(msg) if paths_with_incorrect_length: msg = format_errors_and_warnings( "The data contains paths that don't have the same length as the input data " "and are not scalars. The following paths are faulty: " - f"{format_list_linewise(paths_with_incorrect_length)}" + f"{format_list_linewise(paths_with_incorrect_length)}", ) raise ValueError(msg) @@ -478,7 +482,7 @@ def input_df_has_bool_or_numeric_column_names( """DataFrame column names cannot be booleans or numbers. This restriction prevents ambiguity between actual column references and values intended for broadcasting. - """ + """, ) bool_column_names = [ col for col in input_data__df_and_mapper__df.columns if isinstance(col, bool) @@ -496,7 +500,7 @@ def input_df_has_bool_or_numeric_column_names( Boolean column names: {bool_column_names}. Numeric column names: {numeric_column_names}. 
- """ + """, ) raise ValueError(msg) @@ -514,7 +518,7 @@ def input_df_mapper_columns_missing_in_df( if missing_columns: msg = format_errors_and_warnings( "All columns in the input mapper must be present in the input dataframe. " - f"The following columns are missing: {missing_columns}" + f"The following columns are missing: {missing_columns}", ) raise ValueError(msg) @@ -527,7 +531,7 @@ def input_df_mapper_has_incorrect_format( if not isinstance(input_data__df_and_mapper__mapper, dict): msg = format_errors_and_warnings( """The inputs tree to column mapping must be a (nested) dictionary. Call - `dags.tree.create_tree_with_input_types` to create a template.""" + `dags.tree.create_tree_with_input_types` to create a template.""", ) raise TypeError(msg) @@ -547,7 +551,7 @@ def input_df_mapper_has_incorrect_format( {format_list_linewise(non_string_paths)} Call `dags.tree.create_tree_with_input_types` to create a template. - """ + """, ) raise TypeError(msg) @@ -566,7 +570,7 @@ def input_df_mapper_has_incorrect_format( Found the following incorrect types: {formatted_incorrect_types} - """ + """, ) raise TypeError(msg) @@ -601,7 +605,7 @@ def root_nodes_are_missing( if missing_nodes: formatted = format_list_linewise( - [str(dt.tree_path_from_qual_name(mn)) for mn in missing_nodes] + [str(dt.tree_path_from_qual_name(mn)) for mn in missing_nodes], ) raise ValueError(f"The following data columns are missing.\n{formatted}") @@ -638,7 +642,7 @@ def targets_are_not_in_policy_environment_or_data( if targets_not_in_policy_environment_or_data: formatted = format_list_linewise(targets_not_in_policy_environment_or_data) msg = format_errors_and_warnings( - f"The following targets have no corresponding function:\n\n{formatted}" + f"The following targets have no corresponding function:\n\n{formatted}", ) raise ValueError(msg) @@ -680,9 +684,7 @@ def format_errors_and_warnings(text: str, width: int = 79) -> str: wrapped_paragraph = textwrap.fill(dedented_paragraph, width=width) 
wrapped_paragraphs.append(wrapped_paragraph) - formatted_text = "\n\n".join(wrapped_paragraphs) - - return formatted_text + return "\n\n".join(wrapped_paragraphs) def format_list_linewise(some_list: list[Any]) -> str: # type: ignore[type-arg, unused-ignore] @@ -692,7 +694,7 @@ def format_list_linewise(some_list: list[Any]) -> str: # type: ignore[type-arg, [ "{formatted_list}", ] - """ + """, ).format(formatted_list=formatted_list) @@ -736,7 +738,7 @@ def _remove_note_and_reference(entry: dict[str | int, Any]) -> dict[str | int, A end_date=end_date, original_function_name=leaf_name, **params_header, - ) + ), ) start_date = None end_date = date - datetime.timedelta(days=1) @@ -748,7 +750,7 @@ def _remove_note_and_reference(entry: dict[str | int, Any]) -> dict[str | int, A start_date=start_date, end_date=end_date, **params_header, - ) + ), ) return out diff --git a/src/ttsim/interface_dag_elements/interface_node_objects.py b/src/ttsim/interface_dag_elements/interface_node_objects.py index d21b47dba..eff5e5eb4 100644 --- a/src/ttsim/interface_dag_elements/interface_node_objects.py +++ b/src/ttsim/interface_dag_elements/interface_node_objects.py @@ -132,7 +132,9 @@ def __post_init__(self) -> None: _frozen_safe_update_wrapper(self, self.function) def __call__( - self, *args: FunArgTypes.args, **kwargs: FunArgTypes.kwargs + self, + *args: FunArgTypes.args, + **kwargs: FunArgTypes.kwargs, ) -> ReturnType: return self.function(*args, **kwargs) diff --git a/src/ttsim/interface_dag_elements/orig_policy_objects.py b/src/ttsim/interface_dag_elements/orig_policy_objects.py index b8f163152..befadda1f 100644 --- a/src/ttsim/interface_dag_elements/orig_policy_objects.py +++ b/src/ttsim/interface_dag_elements/orig_policy_objects.py @@ -62,7 +62,8 @@ def num_segments() -> int: k: v for path in _find_files_recursively(root=root, suffix=".py") for k, v in _tree_path_to_orig_column_objects_params_functions( - path=path, root=root + path=path, + root=root, ).items() } # Add backend so we 
can decide between numpy and jax for aggregation functions @@ -120,7 +121,8 @@ def _find_files_recursively(root: Path, suffix: Literal[".py", ".yaml"]) -> list def _tree_path_to_orig_column_objects_params_functions( - path: Path, root: Path + path: Path, + root: Path, ) -> FlatColumnObjectsParamFunctions: """Extract all active PolicyFunctions and GroupByFunctions from a module. diff --git a/src/ttsim/interface_dag_elements/policy_environment.py b/src/ttsim/interface_dag_elements/policy_environment.py index 4825dc766..547438676 100644 --- a/src/ttsim/interface_dag_elements/policy_environment.py +++ b/src/ttsim/interface_dag_elements/policy_environment.py @@ -172,11 +172,11 @@ def _get_one_param( # noqa: PLR0911 if cleaned_spec is None: return None - elif spec["type"] == "scalar": + if spec["type"] == "scalar": return ScalarParam(**cleaned_spec) - elif spec["type"] == "dict": + if spec["type"] == "dict": return DictParam(**cleaned_spec) - elif spec["type"].startswith("piecewise_"): + if spec["type"].startswith("piecewise_"): cleaned_spec["value"] = get_piecewise_parameters( leaf_name=leaf_name, func_type=spec["type"], @@ -184,26 +184,27 @@ def _get_one_param( # noqa: PLR0911 xnp=xnp, ) return PiecewisePolynomialParam(**cleaned_spec) - elif spec["type"] == "consecutive_int_1d_lookup_table": + if spec["type"] == "consecutive_int_1d_lookup_table": cleaned_spec["value"] = get_consecutive_int_1d_lookup_table_param_value( raw=cleaned_spec["value"], xnp=xnp, ) return ConsecutiveInt1dLookupTableParam(**cleaned_spec) - elif spec["type"] == "consecutive_int_2d_lookup_table": + if spec["type"] == "consecutive_int_2d_lookup_table": cleaned_spec["value"] = get_consecutive_int_2d_lookup_table_param_value( raw=cleaned_spec["value"], xnp=xnp, ) return ConsecutiveInt1dLookupTableParam(**cleaned_spec) - elif spec["type"] == "month_based_phase_inout_of_age_thresholds": + if spec["type"] == "month_based_phase_inout_of_age_thresholds": cleaned_spec["value"] = ( 
get_month_based_phase_inout_of_age_thresholds_param_value( - raw=cleaned_spec["value"], xnp=xnp + raw=cleaned_spec["value"], + xnp=xnp, ) ) return ConsecutiveInt1dLookupTableParam(**cleaned_spec) - elif spec["type"] == "year_based_phase_inout_of_age_thresholds": + if spec["type"] == "year_based_phase_inout_of_age_thresholds": cleaned_spec["value"] = ( get_year_based_phase_inout_of_age_thresholds_param_value( raw=cleaned_spec["value"], @@ -211,14 +212,15 @@ def _get_one_param( # noqa: PLR0911 ) ) return ConsecutiveInt1dLookupTableParam(**cleaned_spec) - elif spec["type"] == "require_converter": + if spec["type"] == "require_converter": return RawParam(**cleaned_spec) - else: - raise ValueError(f"Unknown parameter type: {spec['type']} for {leaf_name}") + raise ValueError(f"Unknown parameter type: {spec['type']} for {leaf_name}") def _clean_one_param_spec( - leaf_name: str, spec: OrigParamSpec, date: datetime.date + leaf_name: str, + spec: OrigParamSpec, + date: datetime.date, ) -> dict[str, Any] | None: """Prepare the specification of one parameter for creating a ParamObject.""" policy_dates = numpy.sort([key for key in spec if isinstance(key, datetime.date)]) @@ -243,12 +245,12 @@ def _clean_one_param_spec( out["reference"] = current_spec.pop("reference", None) if len(current_spec) == 0: return None - elif len(current_spec) == 1 and "updates_previous" in current_spec: + if len(current_spec) == 1 and "updates_previous" in current_spec: raise ValueError( - f"'updates_previous' cannot be specified as the only element, found{spec}" + f"'updates_previous' cannot be specified as the only element, found{spec}", ) # Parameter ceased to exist - elif spec["type"] == "scalar": + if spec["type"] == "scalar": assert "updates_previous" not in current_spec, ( "'updates_previous' cannot be specified for scalar parameters" ) @@ -279,5 +281,4 @@ def _get_param_value( base=_get_param_value(relevant_specs=relevant_specs[:-1]), to_upsert=current_spec, ) - else: - return current_spec + 
return current_spec diff --git a/src/ttsim/interface_dag_elements/processed_data.py b/src/ttsim/interface_dag_elements/processed_data.py index 2187dca29..629c889d3 100644 --- a/src/ttsim/interface_dag_elements/processed_data.py +++ b/src/ttsim/interface_dag_elements/processed_data.py @@ -40,7 +40,9 @@ def processed_data(input_data__flat: FlatData, xnp: ModuleType) -> QNameData: variable_with_new_ids = xnp.asarray(data) for i in range(new_p_ids.shape[0]): variable_with_new_ids = xnp.where( - data == old_p_ids[i], new_p_ids[i], variable_with_new_ids + data == old_p_ids[i], + new_p_ids[i], + variable_with_new_ids, ) processed_input_data[qname] = variable_with_new_ids else: diff --git a/src/ttsim/interface_dag_elements/raw_results.py b/src/ttsim/interface_dag_elements/raw_results.py index 82adb128a..03014ce68 100644 --- a/src/ttsim/interface_dag_elements/raw_results.py +++ b/src/ttsim/interface_dag_elements/raw_results.py @@ -22,7 +22,7 @@ def columns( specialized_environment__tax_transfer_function: Callable[[QNameData], QNameData], ) -> QNameData: return specialized_environment__tax_transfer_function( - {k: v for k, v in processed_data.items() if k in names__root_nodes} + {k: v for k, v in processed_data.items() if k in names__root_nodes}, ) diff --git a/src/ttsim/interface_dag_elements/results.py b/src/ttsim/interface_dag_elements/results.py index d22540a8a..a16fed0eb 100644 --- a/src/ttsim/interface_dag_elements/results.py +++ b/src/ttsim/interface_dag_elements/results.py @@ -67,7 +67,8 @@ def df_with_mapper( @interface_function() def df_with_nested_columns( - tree: NestedData, input_data__tree: NestedData + tree: NestedData, + input_data__tree: NestedData, ) -> pd.DataFrame: """The results DataFrame with mapped column names. 
diff --git a/src/ttsim/interface_dag_elements/shared.py b/src/ttsim/interface_dag_elements/shared.py index eee428340..1f319ae13 100644 --- a/src/ttsim/interface_dag_elements/shared.py +++ b/src/ttsim/interface_dag_elements/shared.py @@ -24,14 +24,14 @@ def to_datetime(date: datetime.date | DashedISOString) -> datetime.date: return date if isinstance(date, str) and _DASHED_ISO_DATE_REGEX.fullmatch(date): return datetime.date.fromisoformat(date) - else: - raise ValueError( - f"Date {date} neither matches the format YYYY-MM-DD nor is a datetime.date." - ) + raise ValueError( + f"Date {date} neither matches the format YYYY-MM-DD nor is a datetime.date.", + ) def get_re_pattern_for_all_time_units_and_groupings( - time_units: OrderedQNames, grouping_levels: OrderedQNames + time_units: OrderedQNames, + grouping_levels: OrderedQNames, ) -> re.Pattern[str]: """Get a regex pattern for time units and grouping_levels. @@ -59,13 +59,13 @@ def get_re_pattern_for_all_time_units_and_groupings( f"(?P.*?)" f"(?:_(?P[{re_units}]))?" f"(?:_(?P{re_groupings}))?" - f"$" + f"$", ) def group_pattern(grouping_levels: OrderedQNames) -> re.Pattern[str]: return re.compile( - f"(?P.*)_(?P{'|'.join(grouping_levels)})$" + f"(?P.*)_(?P{'|'.join(grouping_levels)})$", ) @@ -102,7 +102,7 @@ def get_re_pattern_for_specific_time_units_and_groupings( f"(?P{re.escape(base_name)})" f"(?:_(?P[{re_units}]))?" f"(?:_(?P{re_groupings}))?" - f"$" + f"$", ) @@ -215,7 +215,8 @@ def upsert_path_and_value( will be updated. """ to_upsert = create_tree_from_path_and_value( - path=path_to_upsert, value=value_to_upsert + path=path_to_upsert, + value=value_to_upsert, ) return upsert_tree(base=base, to_upsert=to_upsert) @@ -231,7 +232,8 @@ def insert_path_and_value( path must not exist in base. 
""" to_insert = create_tree_from_path_and_value( - path=path_to_insert, value=value_to_insert + path=path_to_insert, + value=value_to_insert, ) return merge_trees(left=base, right=to_insert) @@ -262,10 +264,10 @@ def partition_tree_by_reference_tree( ref_paths = set(dt.tree_paths(reference_tree)) flat = dt.flatten_to_tree_paths(tree_to_partition) intersection = dt.unflatten_from_tree_paths( - {path: leaf for path, leaf in flat.items() if path in ref_paths} + {path: leaf for path, leaf in flat.items() if path in ref_paths}, ) difference = dt.unflatten_from_tree_paths( - {path: leaf for path, leaf in flat.items() if path not in ref_paths} + {path: leaf for path, leaf in flat.items() if path not in ref_paths}, ) return intersection, difference diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index 3864920a9..dbdde7ce4 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -183,14 +183,12 @@ def _add_derived_functions( targets=targets, grouping_levels=grouping_levels, ) - out = { + return { **qual_name_policy_environment, **time_conversion_functions, **aggregate_by_group_functions, } - return out - @interface_function() def with_processed_params_and_scalars( diff --git a/src/ttsim/interface_dag_elements/typing.py b/src/ttsim/interface_dag_elements/typing.py index ccf6aa520..bdb5f65df 100644 --- a/src/ttsim/interface_dag_elements/typing.py +++ b/src/ttsim/interface_dag_elements/typing.py @@ -15,7 +15,8 @@ | # Parameters at one point in time dict[ - datetime.date, dict[Literal["note", "reference"] | str | int, Any] # noqa: PYI051 + datetime.date, + dict[Literal["note", "reference"] | str | int, Any], # noqa: PYI051 ] ) DashedISOString = NewType("DashedISOString", str) @@ -66,11 +67,13 @@ # Tree-like data structures for policy objects # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
# # # # # # # # FlatColumnObjectsParamFunctions = Mapping[ - tuple[str, ...], ColumnObject | ParamFunction + tuple[str, ...], + ColumnObject | ParamFunction, ] """Flat mapping of paths to column objects or param functions.""" NestedColumnObjectsParamFunctions = Mapping[ - str, ColumnObject | ParamFunction | "NestedColumnObjectsParamFunctions" + str, + ColumnObject | ParamFunction | "NestedColumnObjectsParamFunctions", ] """Tree of column objects or param functions.""" FlatOrigParamSpecs = dict[tuple[str, ...], OrigParamSpec] @@ -88,7 +91,8 @@ ] """Tree of column objects, param functions, and param objects.""" QNameCombinedEnvironment0 = Mapping[ - str, ColumnObject | ParamFunction | ParamObject | int | float | bool + str, + ColumnObject | ParamFunction | ParamObject | int | float | bool, ] """Map qualified names to column objects, param functions, param objects, or scalars from processed data.""" # noqa: E501 QNameCombinedEnvironment1 = Mapping[str, ColumnObject | Any] diff --git a/src/ttsim/interface_dag_elements/warn_if.py b/src/ttsim/interface_dag_elements/warn_if.py index fe0a10fc9..07b1b7191 100644 --- a/src/ttsim/interface_dag_elements/warn_if.py +++ b/src/ttsim/interface_dag_elements/warn_if.py @@ -41,7 +41,7 @@ def __init__(self, columns_overriding_functions: OrderedQNames) -> None: data column to be calculated by hard-coded functions, remove it from the *data* you pass to TTSIM. You need to pick one option for each column that appears in the list above. - """ + """, ) else: first_part = format_errors_and_warnings("Your data provides the columns:") @@ -53,13 +53,13 @@ def __init__(self, columns_overriding_functions: OrderedQNames) -> None: want data columns to be calculated by hard-coded functions, remove them from the *data* you pass to TTSIM. You need to pick one option for each column that appears in the list above. 
- """ + """, ) formatted = format_list_linewise(columns_overriding_functions) how_to_ignore = format_errors_and_warnings( """ In order to not perform this check, you can ... TODO - """ + """, ) super().__init__(f"{first_part}\n{formatted}\n{second_part}\n{how_to_ignore}") @@ -75,7 +75,7 @@ def functions_and_data_columns_overlap( col for col in names__processed_data_columns if col in dt.flatten_to_qual_names(policy_environment) - } + }, ) if len(overridden_elements) > 0: warnings.warn( diff --git a/src/ttsim/plot_dag.py b/src/ttsim/plot_dag.py index e0071965d..3a60651c3 100644 --- a/src/ttsim/plot_dag.py +++ b/src/ttsim/plot_dag.py @@ -27,7 +27,10 @@ def plot_tt_dag( - with_params: bool, inputs_for_main: dict[str, Any], title: str, output_path: Path + with_params: bool, + inputs_for_main: dict[str, Any], + title: str, + output_path: Path, ) -> None: """Plot the taxes & transfers DAG, with or without parameters.""" if "backend" not in inputs_for_main: @@ -63,7 +66,7 @@ def plot_tt_dag( ), }, targets=[ - "specialized_environment__with_derived_functions_and_processed_input_nodes" + "specialized_environment__with_derived_functions_and_processed_input_nodes", ], )["specialized_environment__with_derived_functions_and_processed_input_nodes"] # Replace input nodes by PolicyInputs again @@ -79,7 +82,7 @@ def plot_tt_dag( # Only keep nodes that are column objects if not with_params: dag.remove_nodes_from( - [qn for qn, n in env.items() if not isinstance(n, ColumnObject)] + [qn for qn, n in env.items() if not isinstance(n, ColumnObject)], ) fig = _plot_dag(dag=dag, title=title) if output_path.suffix == ".html": @@ -103,7 +106,7 @@ def plot_tt_dag( if args: raise ValueError( "The policy environment DAG should include all root nodes but requires " - f"inputs:\n\n{format_list_linewise(args.keys())}" + f"inputs:\n\n{format_list_linewise(args.keys())}", ) @@ -127,7 +130,7 @@ def plot_full_interface_dag(output_path: Path) -> None: if args: raise ValueError( "The full interface DAG 
should include all root nodes but requires inputs:" - f"\n\n{format_list_linewise(args.keys())}" + f"\n\n{format_list_linewise(args.keys())}", ) fig = _plot_dag(dag=dag, title="Full Interface DAG") if output_path.suffix == ".html": @@ -139,7 +142,8 @@ def plot_full_interface_dag(output_path: Path) -> None: def _plot_dag(dag: nx.DiGraph, title: str) -> go.Figure: """Plot the DAG.""" nice_dag = nx.relabel_nodes( - dag, {qn: qn.replace("__", "
") for qn in dag.nodes()} + dag, + {qn: qn.replace("__", "
") for qn in dag.nodes()}, ) pos = nx.nx_agraph.pygraphviz_layout(nice_dag, prog="dot", args="-Grankdir=LR") # Create edge traces with arrows @@ -194,7 +198,7 @@ def _plot_dag(dag: nx.DiGraph, title: str) -> go.Figure: "arrowcolor": "#888", "showarrow": True, "text": "", - } + }, ) # Create node trace diff --git a/src/ttsim/testing_utils.py b/src/ttsim/testing_utils.py index 4ed13d190..93850de53 100644 --- a/src/ttsim/testing_utils.py +++ b/src/ttsim/testing_utils.py @@ -32,7 +32,9 @@ @lru_cache(maxsize=100) def cached_policy_environment( - date: datetime.date, root: Path, backend: Literal["numpy", "jax"] + date: datetime.date, + root: Path, + backend: Literal["numpy", "jax"], ) -> NestedPolicyEnvironment: return main( inputs={ @@ -68,7 +70,7 @@ def __init__( @property def target_structure(self) -> NestedInputStructureDict: flat_target_structure = dict.fromkeys( - dt.flatten_to_tree_paths(self.expected_output_tree) + dt.flatten_to_tree_paths(self.expected_output_tree), ) return dt.unflatten_from_tree_paths(flat_target_structure) @@ -78,7 +80,9 @@ def name(self) -> str: def execute_test( - test: PolicyTest, root: Path, backend: Literal["numpy", "jax"] + test: PolicyTest, + root: Path, + backend: Literal["numpy", "jax"], ) -> None: environment = cached_policy_environment(date=test.date, root=root, backend=backend) @@ -129,12 +133,14 @@ def execute_test( expected[cols_with_differences]: {expected_df[cols_with_differences]} -""" +""", ) from e def load_policy_test_data( - test_dir: Path, policy_name: str, xnp: ModuleType + test_dir: Path, + policy_name: str, + xnp: ModuleType, ) -> dict[str, PolicyTest]: """Load all tests found by recursively searching @@ -191,18 +197,18 @@ def _get_policy_test_from_raw_test_data( merge_trees( left=raw_test_data["inputs"].get("provided", {}), right=raw_test_data["inputs"].get("assumed", {}), - ) + ), ).items() - } + }, ) expected_output_tree: NestedData = dt.unflatten_from_tree_paths( { k: xnp.array(v) for k, v in 
dt.flatten_to_tree_paths( - raw_test_data.get("outputs", {}) + raw_test_data.get("outputs", {}), ).items() - } + }, ) date: datetime.date = to_datetime(path_to_yaml.parent.name) diff --git a/src/ttsim/tt_dag_elements/aggregation.py b/src/ttsim/tt_dag_elements/aggregation.py index 92f18bcb2..6341f2d49 100644 --- a/src/ttsim/tt_dag_elements/aggregation.py +++ b/src/ttsim/tt_dag_elements/aggregation.py @@ -26,12 +26,13 @@ class AggType(StrEnum): # The signature of the functions must be the same in both modules, except that all JAX # functions have the additional `num_segments` argument. def grouped_count( - group_id: IntColumn, num_segments: int, backend: Literal["numpy", "jax"] + group_id: IntColumn, + num_segments: int, + backend: Literal["numpy", "jax"], ) -> IntColumn: if backend == "numpy": return aggregation_numpy.grouped_count(group_id) - else: - return aggregation_jax.grouped_count(group_id, num_segments) + return aggregation_jax.grouped_count(group_id, num_segments) def grouped_sum( @@ -42,8 +43,7 @@ def grouped_sum( ) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.grouped_sum(column, group_id) - else: - return aggregation_jax.grouped_sum(column, group_id, num_segments) + return aggregation_jax.grouped_sum(column, group_id, num_segments) def grouped_mean( @@ -54,8 +54,7 @@ def grouped_mean( ) -> FloatColumn: if backend == "numpy": return aggregation_numpy.grouped_mean(column, group_id) - else: - return aggregation_jax.grouped_mean(column, group_id, num_segments) + return aggregation_jax.grouped_mean(column, group_id, num_segments) def grouped_max( @@ -66,8 +65,7 @@ def grouped_max( ) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.grouped_max(column, group_id) - else: - return aggregation_jax.grouped_max(column, group_id, num_segments) + return aggregation_jax.grouped_max(column, group_id, num_segments) def grouped_min( @@ -78,8 +76,7 @@ def grouped_min( ) -> FloatColumn | IntColumn: if backend == 
"numpy": return aggregation_numpy.grouped_min(column, group_id) - else: - return aggregation_jax.grouped_min(column, group_id, num_segments) + return aggregation_jax.grouped_min(column, group_id, num_segments) def grouped_any( @@ -90,8 +87,7 @@ def grouped_any( ) -> BoolColumn: if backend == "numpy": return aggregation_numpy.grouped_any(column, group_id) - else: - return aggregation_jax.grouped_any(column, group_id, num_segments) + return aggregation_jax.grouped_any(column, group_id, num_segments) def grouped_all( @@ -102,8 +98,7 @@ def grouped_all( ) -> BoolColumn: if backend == "numpy": return aggregation_numpy.grouped_all(column, group_id) - else: - return aggregation_jax.grouped_all(column, group_id, num_segments) + return aggregation_jax.grouped_all(column, group_id, num_segments) def count_by_p_id( @@ -114,10 +109,11 @@ def count_by_p_id( ) -> IntColumn: if backend == "numpy": return aggregation_numpy.count_by_p_id(p_id_to_aggregate_by, p_id_to_store_by) - else: - return aggregation_jax.count_by_p_id( - p_id_to_aggregate_by, p_id_to_store_by, num_segments - ) + return aggregation_jax.count_by_p_id( + p_id_to_aggregate_by, + p_id_to_store_by, + num_segments, + ) def sum_by_p_id( @@ -129,12 +125,16 @@ def sum_by_p_id( ) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.sum_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by - ) - else: - return aggregation_jax.sum_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + column, + p_id_to_aggregate_by, + p_id_to_store_by, ) + return aggregation_jax.sum_by_p_id( + column, + p_id_to_aggregate_by, + p_id_to_store_by, + num_segments, + ) def mean_by_p_id( @@ -146,12 +146,16 @@ def mean_by_p_id( ) -> FloatColumn: if backend == "numpy": return aggregation_numpy.mean_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by - ) - else: - return aggregation_jax.mean_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + column, + p_id_to_aggregate_by, + 
p_id_to_store_by, ) + return aggregation_jax.mean_by_p_id( + column, + p_id_to_aggregate_by, + p_id_to_store_by, + num_segments, + ) def max_by_p_id( @@ -163,12 +167,16 @@ def max_by_p_id( ) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.max_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by - ) - else: - return aggregation_jax.max_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + column, + p_id_to_aggregate_by, + p_id_to_store_by, ) + return aggregation_jax.max_by_p_id( + column, + p_id_to_aggregate_by, + p_id_to_store_by, + num_segments, + ) def min_by_p_id( @@ -180,12 +188,16 @@ def min_by_p_id( ) -> FloatColumn | IntColumn: if backend == "numpy": return aggregation_numpy.min_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by - ) - else: - return aggregation_jax.min_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + column, + p_id_to_aggregate_by, + p_id_to_store_by, ) + return aggregation_jax.min_by_p_id( + column, + p_id_to_aggregate_by, + p_id_to_store_by, + num_segments, + ) def any_by_p_id( @@ -197,12 +209,16 @@ def any_by_p_id( ) -> BoolColumn: if backend == "numpy": return aggregation_numpy.any_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by - ) - else: - return aggregation_jax.any_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + column, + p_id_to_aggregate_by, + p_id_to_store_by, ) + return aggregation_jax.any_by_p_id( + column, + p_id_to_aggregate_by, + p_id_to_store_by, + num_segments, + ) def all_by_p_id( @@ -214,9 +230,13 @@ def all_by_p_id( ) -> BoolColumn: if backend == "numpy": return aggregation_numpy.all_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by - ) - else: - return aggregation_jax.all_by_p_id( - column, p_id_to_aggregate_by, p_id_to_store_by, num_segments + column, + p_id_to_aggregate_by, + p_id_to_store_by, ) + return aggregation_jax.all_by_p_id( + column, + p_id_to_aggregate_by, + p_id_to_store_by, + 
num_segments, + ) diff --git a/src/ttsim/tt_dag_elements/aggregation_jax.py b/src/ttsim/tt_dag_elements/aggregation_jax.py index 1044f488c..616305401 100644 --- a/src/ttsim/tt_dag_elements/aggregation_jax.py +++ b/src/ttsim/tt_dag_elements/aggregation_jax.py @@ -14,58 +14,80 @@ def grouped_count(group_id: IntColumn, num_segments: int) -> jnp.ndarray: out_grouped = segment_sum( - data=jnp.ones(len(group_id)), segment_ids=group_id, num_segments=num_segments + data=jnp.ones(len(group_id)), + segment_ids=group_id, + num_segments=num_segments, ) return out_grouped[group_id] def grouped_sum( - column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn, num_segments: int + column: FloatColumn | IntColumn | BoolColumn, + group_id: IntColumn, + num_segments: int, ) -> FloatColumn | IntColumn: if column.dtype in ["bool"]: column = column.astype(int) out_grouped = segment_sum( - data=column, segment_ids=group_id, num_segments=num_segments + data=column, + segment_ids=group_id, + num_segments=num_segments, ) return out_grouped[group_id] def grouped_mean( - column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn, num_segments: int + column: FloatColumn | IntColumn | BoolColumn, + group_id: IntColumn, + num_segments: int, ) -> FloatColumn: if column.dtype in ["bool"]: column = column.astype(int) sum_grouped = segment_sum( - data=column, segment_ids=group_id, num_segments=num_segments + data=column, + segment_ids=group_id, + num_segments=num_segments, ) sizes = segment_sum( - data=jnp.ones(len(column)), segment_ids=group_id, num_segments=num_segments + data=jnp.ones(len(column)), + segment_ids=group_id, + num_segments=num_segments, ) mean_grouped = sum_grouped / sizes return mean_grouped[group_id] def grouped_max( - column: FloatColumn | IntColumn, group_id: IntColumn, num_segments: int + column: FloatColumn | IntColumn, + group_id: IntColumn, + num_segments: int, ) -> FloatColumn | IntColumn: out_grouped = segment_max( - data=column, segment_ids=group_id, 
num_segments=num_segments + data=column, + segment_ids=group_id, + num_segments=num_segments, ) return out_grouped[group_id] def grouped_min( - column: FloatColumn | IntColumn, group_id: IntColumn, num_segments: int + column: FloatColumn | IntColumn, + group_id: IntColumn, + num_segments: int, ) -> FloatColumn | IntColumn: out_grouped = segment_min( - data=column, segment_ids=group_id, num_segments=num_segments + data=column, + segment_ids=group_id, + num_segments=num_segments, ) return out_grouped[group_id] def grouped_any( - column: BoolColumn | IntColumn, group_id: IntColumn, num_segments: int + column: BoolColumn | IntColumn, + group_id: IntColumn, + num_segments: int, ) -> BoolColumn: # Convert to boolean if necessary if jnp.issubdtype(column.dtype, jnp.integer): @@ -74,20 +96,26 @@ def grouped_any( my_col = column out_grouped = segment_max( - data=my_col, segment_ids=group_id, num_segments=num_segments + data=my_col, + segment_ids=group_id, + num_segments=num_segments, ) return out_grouped[group_id] def grouped_all( - column: BoolColumn | IntColumn, group_id: IntColumn, num_segments: int + column: BoolColumn | IntColumn, + group_id: IntColumn, + num_segments: int, ) -> BoolColumn: # Convert to boolean if necessary if jnp.issubdtype(column.dtype, jnp.integer): column = column.astype("bool") out_grouped = segment_min( - data=column, segment_ids=group_id, num_segments=num_segments + data=column, + segment_ids=group_id, + num_segments=num_segments, ) return out_grouped[group_id] diff --git a/src/ttsim/tt_dag_elements/aggregation_numpy.py b/src/ttsim/tt_dag_elements/aggregation_numpy.py index 029ad4406..31afc2ba1 100644 --- a/src/ttsim/tt_dag_elements/aggregation_numpy.py +++ b/src/ttsim/tt_dag_elements/aggregation_numpy.py @@ -12,14 +12,18 @@ def grouped_count(group_id: IntColumn) -> IntColumn: fail_if_dtype_not_int(group_id, agg_func="grouped_count") out_grouped = npg.aggregate( - group_id, numpy.ones(len(group_id), dtype=int), func="sum", fill_value=0 + 
group_id, + numpy.ones(len(group_id), dtype=int), + func="sum", + fill_value=0, ) return out_grouped[group_id] def grouped_sum( - column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn + column: FloatColumn | IntColumn | BoolColumn, + group_id: IntColumn, ) -> FloatColumn | IntColumn: fail_if_dtype_not_numeric_or_boolean(column, agg_func="grouped_sum") fail_if_dtype_not_int(group_id, agg_func="grouped_sum") @@ -32,7 +36,8 @@ def grouped_sum( def grouped_mean( - column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn + column: FloatColumn | IntColumn | BoolColumn, + group_id: IntColumn, ) -> FloatColumn: fail_if_dtype_not_numeric_or_boolean(column, agg_func="grouped_mean") fail_if_dtype_not_int(group_id, agg_func="grouped_mean") @@ -44,7 +49,8 @@ def grouped_mean( def grouped_max( - column: FloatColumn | IntColumn | BoolColumn, group_id: IntColumn + column: FloatColumn | IntColumn | BoolColumn, + group_id: IntColumn, ) -> FloatColumn | IntColumn: fail_if_dtype_not_numeric_or_datetime(column, agg_func="grouped_max") fail_if_dtype_not_int(group_id, agg_func="grouped_max") @@ -71,7 +77,8 @@ def grouped_max( def grouped_min( - column: FloatColumn | IntColumn, group_id: IntColumn + column: FloatColumn | IntColumn, + group_id: IntColumn, ) -> FloatColumn | IntColumn: fail_if_dtype_not_numeric_or_datetime(column, agg_func="grouped_min") fail_if_dtype_not_int(group_id, agg_func="grouped_min") @@ -80,7 +87,8 @@ def grouped_min( # numba is installed) if numpy.issubdtype(column.dtype, numpy.datetime64) or numpy.issubdtype( - column.dtype, numpy.timedelta64 + column.dtype, + numpy.timedelta64, ): dtype = column.dtype float_col = column.astype("datetime64[D]").astype(int) @@ -121,7 +129,8 @@ def grouped_all(column: BoolColumn | IntColumn, group_id: IntColumn) -> BoolColu def count_by_p_id( - p_id_to_aggregate_by: IntColumn, p_id_to_store_by: IntColumn + p_id_to_aggregate_by: IntColumn, + p_id_to_store_by: IntColumn, ) -> IntColumn: 
fail_if_dtype_not_int(p_id_to_aggregate_by, agg_func="count_by_p_id") fail_if_dtype_not_int(p_id_to_store_by, agg_func="count_by_p_id") @@ -206,22 +215,24 @@ def all_by_p_id( def fail_if_dtype_not_numeric( - column: FloatColumn | IntColumn | BoolColumn, agg_func: str + column: FloatColumn | IntColumn | BoolColumn, + agg_func: str, ) -> None: if not numpy.issubdtype(column.dtype, numpy.number): raise TypeError( f"Aggregation function {agg_func} was applied to a column " - f"with dtype {column.dtype}. Allowed are only numerical dtypes." + f"with dtype {column.dtype}. Allowed are only numerical dtypes.", ) def fail_if_dtype_not_float( - column: FloatColumn | IntColumn | BoolColumn, agg_func: str + column: FloatColumn | IntColumn | BoolColumn, + agg_func: str, ) -> None: if not numpy.issubdtype(column.dtype, numpy.floating): raise TypeError( f"Aggregation function {agg_func} was applied to a column " - f"with dtype {column.dtype}. Allowed is only float." + f"with dtype {column.dtype}. Allowed is only float.", ) @@ -229,22 +240,24 @@ def fail_if_dtype_not_int(p_id_to_aggregate_by: IntColumn, agg_func: str) -> Non if not numpy.issubdtype(p_id_to_aggregate_by.dtype, numpy.integer): raise TypeError( f"The dtype of id columns must be integer. Aggregation function {agg_func} " - f"was applied to a id columns that has dtype {p_id_to_aggregate_by.dtype}." + f"was applied to a id columns that has dtype {p_id_to_aggregate_by.dtype}.", ) def fail_if_dtype_not_numeric_or_boolean( - column: FloatColumn | IntColumn | BoolColumn, agg_func: str + column: FloatColumn | IntColumn | BoolColumn, + agg_func: str, ) -> None: if not (numpy.issubdtype(column.dtype, numpy.number) or column.dtype == "bool"): raise TypeError( f"Aggregation function {agg_func} was applied to a column with dtype " - f"{column.dtype}. Allowed are only numerical or Boolean dtypes." + f"{column.dtype}. 
Allowed are only numerical or Boolean dtypes.", ) def fail_if_dtype_not_numeric_or_datetime( - column: FloatColumn | IntColumn | BoolColumn, agg_func: str + column: FloatColumn | IntColumn | BoolColumn, + agg_func: str, ) -> None: if not ( numpy.issubdtype(column.dtype, numpy.number) @@ -252,12 +265,13 @@ def fail_if_dtype_not_numeric_or_datetime( ): raise TypeError( f"Aggregation function {agg_func} was applied to a column with dtype " - f"{column.dtype}. Allowed are only numerical or datetime dtypes." + f"{column.dtype}. Allowed are only numerical or datetime dtypes.", ) def fail_if_dtype_not_boolean_or_int( - column: BoolColumn | IntColumn, agg_func: str + column: BoolColumn | IntColumn, + agg_func: str, ) -> None: if not ( numpy.issubdtype(column.dtype, numpy.integer) @@ -265,5 +279,5 @@ def fail_if_dtype_not_boolean_or_int( ): raise TypeError( f"Aggregation function {agg_func} was applied to a column with dtype " - f"{column.dtype}. Allowed are only Boolean and int dtypes." + f"{column.dtype}. 
Allowed are only Boolean and int dtypes.", ) diff --git a/src/ttsim/tt_dag_elements/column_objects_param_function.py b/src/ttsim/tt_dag_elements/column_objects_param_function.py index 1429e0ff7..b6f7ae776 100644 --- a/src/ttsim/tt_dag_elements/column_objects_param_function.py +++ b/src/ttsim/tt_dag_elements/column_objects_param_function.py @@ -223,7 +223,9 @@ def __post_init__(self) -> None: _frozen_safe_update_wrapper(self, self.function) def __call__( - self, *args: FunArgTypes.args, **kwargs: FunArgTypes.kwargs + self, + *args: FunArgTypes.args, + **kwargs: FunArgTypes.kwargs, ) -> ReturnType: return self.function(*args, **kwargs) @@ -466,8 +468,9 @@ def group_creation_function( def decorator(func: GenericCallable) -> GroupCreationFunction: _leaf_name = func.__name__ if leaf_name is None else leaf_name - func_with_reorder = lambda **kwargs: reorder_ids( - ids=func(**kwargs), xnp=kwargs["xnp"] + func_with_reorder = lambda **kwargs: reorder_ids( # noqa: E731 + ids=func(**kwargs), + xnp=kwargs["xnp"], ) functools.update_wrapper(func_with_reorder, func) @@ -577,34 +580,37 @@ def inner(func: GenericCallable) -> AggByGroupFunction: def _fail_if_group_id_is_invalid( - group_ids: UnorderedQNames, orig_location: str + group_ids: UnorderedQNames, + orig_location: str, ) -> None: if len(group_ids) != 1: raise ValueError( "Require exactly one group identifier ending with '_id' for " "aggregation by group. Got " - f"{', '.join(group_ids) if group_ids else 'nothing'} in {orig_location}." + f"{', '.join(group_ids) if group_ids else 'nothing'} in {orig_location}.", ) def _fail_if_other_arg_is_present( - other_args: UnorderedQNames, orig_location: str + other_args: UnorderedQNames, + orig_location: str, ) -> None: if other_args: raise ValueError( "There must be no argument besides identifiers for counting. Got: " - f"{', '.join(other_args) if other_args else 'nothing'} in {orig_location}." 
+ f"{', '.join(other_args) if other_args else 'nothing'} in {orig_location}.", ) def _fail_if_other_arg_is_invalid( - other_args: UnorderedQNames, orig_location: str + other_args: UnorderedQNames, + orig_location: str, ) -> None: if len(other_args) != 1: raise ValueError( "There must be exactly one argument besides identifiers, num_segments, and " "backend for aggregations. Got: " - f"{', '.join(other_args) if other_args else 'nothing'} in {orig_location}." + f"{', '.join(other_args) if other_args else 'nothing'} in {orig_location}.", ) @@ -722,18 +728,19 @@ def _fail_if_p_id_is_not_present(args: UnorderedQNames, orig_location: str) -> N if "p_id" not in args: raise ValueError( "The function must have the argument named 'p_id' for aggregation by p_id. " - f"Got {', '.join(args) if args else 'nothing'} in {orig_location}." + f"Got {', '.join(args) if args else 'nothing'} in {orig_location}.", ) def _fail_if_other_p_id_is_invalid( - other_p_ids: UnorderedQNames, orig_location: str + other_p_ids: UnorderedQNames, + orig_location: str, ) -> None: if len(other_p_ids) != 1: raise ValueError( "Require exactly one identifier starting with 'p_id_' for " "aggregation by p_id. Got: " - f"{', '.join(other_p_ids) if other_p_ids else 'nothing'} in {orig_location}." # noqa: E501 + f"{', '.join(other_p_ids) if other_p_ids else 'nothing'} in {orig_location}.", # noqa: E501 ) @@ -807,7 +814,7 @@ def _convert_and_validate_dates( if start_date > end_date: raise ValueError( - f"The start date {start_date} must be before the end date {end_date}." 
+ f"The start date {start_date} must be before the end date {end_date}.", ) return start_date, end_date @@ -840,7 +847,9 @@ def __post_init__(self) -> None: _frozen_safe_update_wrapper(self, self.function) def __call__( - self, *args: FunArgTypes.args, **kwargs: FunArgTypes.kwargs + self, + *args: FunArgTypes.args, + **kwargs: FunArgTypes.kwargs, ) -> ReturnType: return self.function(*args, **kwargs) diff --git a/src/ttsim/tt_dag_elements/param_objects.py b/src/ttsim/tt_dag_elements/param_objects.py index ba514b429..7e204d831 100644 --- a/src/ttsim/tt_dag_elements/param_objects.py +++ b/src/ttsim/tt_dag_elements/param_objects.py @@ -210,7 +210,7 @@ def get_consecutive_int_2d_lookup_table_param_value( [ raw[row][col] for row, col in itertools.product(lookup_keys_rows, lookup_keys_cols) - ] + ], ).reshape(len(lookup_keys_rows), len(lookup_keys_cols)), ) @@ -251,12 +251,14 @@ def _fill_phase_inout( first_year_phase_inout: int = min(raw.keys()) # type: ignore[assignment] first_month_phase_inout: int = min(raw[first_year_phase_inout].keys()) first_m_since_ad_phase_inout = _m_since_ad( - y=first_year_phase_inout, m=first_month_phase_inout + y=first_year_phase_inout, + m=first_month_phase_inout, ) last_year_phase_inout: int = max(raw.keys()) # type: ignore[assignment] last_month_phase_inout: int = max(raw[last_year_phase_inout].keys()) last_m_since_ad_phase_inout = _m_since_ad( - y=last_year_phase_inout, m=last_month_phase_inout + y=last_year_phase_inout, + m=last_month_phase_inout, ) assert first_m_since_ad_to_consider <= first_m_since_ad_phase_inout assert last_m_since_ad_to_consider >= last_m_since_ad_phase_inout @@ -272,11 +274,13 @@ def _fill_phase_inout( after_phase_inout: dict[int, float] = { b_m: _year_fraction(raw[last_year_phase_inout][last_month_phase_inout]) for b_m in range( - last_m_since_ad_phase_inout + 1, last_m_since_ad_to_consider + 1 + last_m_since_ad_phase_inout + 1, + last_m_since_ad_to_consider + 1, ) } return 
get_consecutive_int_1d_lookup_table_param_value( - raw={**before_phase_inout, **during_phase_inout, **after_phase_inout}, xnp=xnp + raw={**before_phase_inout, **during_phase_inout, **after_phase_inout}, + xnp=xnp, ) @@ -308,5 +312,6 @@ def get_year_based_phase_inout_of_age_thresholds_param_value( for b_y in range(last_year_phase_inout + 1, last_year_to_consider + 1) } return get_consecutive_int_1d_lookup_table_param_value( - raw={**before_phase_inout, **during_phase_inout, **after_phase_inout}, xnp=xnp + raw={**before_phase_inout, **during_phase_inout, **after_phase_inout}, + xnp=xnp, ) diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index 6f0a8c269..3f315631c 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -120,7 +120,7 @@ def get_piecewise_parameters( if sorted(parameter_dict) != list(range(len(parameter_dict))): raise ValueError( f"The keys of {leaf_name} do not start with 0 or are not consecutive" - f" numbers." + f" numbers.", ) # Extract lower thresholds. @@ -186,14 +186,14 @@ def check_and_get_thresholds( # noqa: C901 # Check if lowest threshold exists. if "lower_threshold" not in parameter_dict[0]: raise ValueError( - f"The first piece of {leaf_name} needs to contain a lower_threshold value." + f"The first piece of {leaf_name} needs to contain a lower_threshold value.", ) lower_thresholds[0] = parameter_dict[0]["lower_threshold"] # Check if highest upper_threshold exists. if "upper_threshold" not in parameter_dict[keys[-1]]: raise ValueError( - f"The last piece of {leaf_name} needs to contain an upper_threshold value." 
+ f"The last piece of {leaf_name} needs to contain an upper_threshold value.", ) upper_thresholds[keys[-1]] = parameter_dict[keys[-1]]["upper_threshold"] @@ -209,7 +209,7 @@ def check_and_get_thresholds( # noqa: C901 else: raise ValueError( f"In {interval} of {leaf_name} is no lower upper threshold or an upper" - f" in the piece before." + f" in the piece before.", ) for interval in keys[:-1]: @@ -220,12 +220,12 @@ def check_and_get_thresholds( # noqa: C901 else: raise ValueError( f"In {interval} of {leaf_name} is no upper threshold or a lower" - f" threshold in the piece after." + f" threshold in the piece after.", ) if not numpy.allclose(lower_thresholds[1:], upper_thresholds[:-1]): raise ValueError( - f"The lower and upper thresholds of {leaf_name} have to coincide" + f"The lower and upper thresholds of {leaf_name} have to coincide", ) thresholds = sorted([lower_thresholds[0], *upper_thresholds]) return ( @@ -267,7 +267,7 @@ def _check_and_get_rates( rates[i, interval] = parameter_dict[interval][rate_type] else: raise ValueError( - f"In interval {interval} of {leaf_name}, {rate_type} is missing." + f"In interval {interval} of {leaf_name}, {rate_type} is missing.", ) return xnp.array(rates) @@ -291,28 +291,31 @@ def _check_and_get_intercepts( if "intercept_at_lower_threshold" not in parameter_dict[0]: raise ValueError(f"The first piece of {leaf_name} needs an intercept.") - else: - intercepts[0] = parameter_dict[0]["intercept_at_lower_threshold"] - # Check if all intercepts are supplied. - for interval in keys[1:]: - if "intercept_at_lower_threshold" in parameter_dict[interval]: - count_intercepts_supplied += 1 - intercepts[interval] = parameter_dict[interval][ - "intercept_at_lower_threshold" - ] - if (count_intercepts_supplied > 1) & (count_intercepts_supplied != len(keys)): - raise ValueError( - "More than one, but not all intercepts are supplied. " - "The dictionaries should contain either only the lowest intercept " - "or all intercepts." 
- ) - elif count_intercepts_supplied == len(keys): - pass + intercepts[0] = parameter_dict[0]["intercept_at_lower_threshold"] + # Check if all intercepts are supplied. + for interval in keys[1:]: + if "intercept_at_lower_threshold" in parameter_dict[interval]: + count_intercepts_supplied += 1 + intercepts[interval] = parameter_dict[interval][ + "intercept_at_lower_threshold" + ] + if (count_intercepts_supplied > 1) & (count_intercepts_supplied != len(keys)): + raise ValueError( + "More than one, but not all intercepts are supplied. " + "The dictionaries should contain either only the lowest intercept " + "or all intercepts.", + ) + if count_intercepts_supplied == len(keys): + pass - else: - intercepts = _create_intercepts( - lower_thresholds, upper_thresholds, rates, intercepts[0], xnp=xnp - ) + else: + intercepts = _create_intercepts( + lower_thresholds, + upper_thresholds, + rates, + intercepts[0], + xnp=xnp, + ) return xnp.array(intercepts) diff --git a/src/ttsim/tt_dag_elements/rounding.py b/src/ttsim/tt_dag_elements/rounding.py index c8f142f3d..1a9e6cd80 100644 --- a/src/ttsim/tt_dag_elements/rounding.py +++ b/src/ttsim/tt_dag_elements/rounding.py @@ -30,15 +30,18 @@ def __post_init__(self) -> None: valid_directions = get_args(ROUNDING_DIRECTION) if self.direction not in valid_directions: raise ValueError( - f"`direction` must be one of {valid_directions}, got {self.direction!r}" + f"`direction` must be one of {valid_directions}, " + f"got {self.direction!r}", ) if type(self.to_add_after_rounding) not in [int, float]: raise ValueError( - f"Additive part must be a number, got {self.to_add_after_rounding!r}" + f"Additive part must be a number, got {self.to_add_after_rounding!r}", ) def apply_rounding( - self, func: Callable[P, FloatColumn], xnp: ModuleType + self, + func: Callable[P, FloatColumn], + xnp: ModuleType, ) -> Callable[P, FloatColumn]: """Decorator to round the output of a function. 
diff --git a/src/ttsim/tt_dag_elements/shared.py b/src/ttsim/tt_dag_elements/shared.py index a6de2e2d9..cafdbcfeb 100644 --- a/src/ttsim/tt_dag_elements/shared.py +++ b/src/ttsim/tt_dag_elements/shared.py @@ -42,7 +42,10 @@ def join( # For each foreign key, add a column with True at the end, to later fall back to # the value for unresolved foreign keys padded_matches_foreign_key = xnp.pad( - matches_foreign_key, ((0, 0), (0, 1)), "constant", constant_values=True + matches_foreign_key, + ((0, 0), (0, 1)), + "constant", + constant_values=True, ) # For each foreign key, compute the index of the first matching primary key @@ -50,7 +53,10 @@ def join( # Add the value for unresolved foreign keys at the end of the target array padded_targets = xnp.pad( - target, (0, 1), "constant", constant_values=value_if_foreign_key_is_missing + target, + (0, 1), + "constant", + constant_values=value_if_foreign_key_is_missing, ) # Return the target at the index of the first matching primary key diff --git a/src/ttsim/tt_dag_elements/typing.py b/src/ttsim/tt_dag_elements/typing.py index f399a2633..2a2688ec3 100644 --- a/src/ttsim/tt_dag_elements/typing.py +++ b/src/ttsim/tt_dag_elements/typing.py @@ -11,7 +11,8 @@ | # Parameters at one point in time dict[ - datetime.date, dict[Literal["note", "reference"] | str | int, Any] # noqa: PYI051 + datetime.date, + dict[Literal["note", "reference"] | str | int, Any], # noqa: PYI051 ] ) DashedISOString = NewType("DashedISOString", str) diff --git a/src/ttsim/tt_dag_elements/vectorization.py b/src/ttsim/tt_dag_elements/vectorization.py index b12ed3f66..2916470ed 100644 --- a/src/ttsim/tt_dag_elements/vectorization.py +++ b/src/ttsim/tt_dag_elements/vectorization.py @@ -36,13 +36,15 @@ def vectorize_function( else: raise ValueError( f"Vectorization strategy {vectorization_strategy} is not supported. " - "Use 'loop' or 'vectorize'." 
+ "Use 'loop' or 'vectorize'.", ) return vectorized def _make_vectorizable( - func: GenericCallable, backend: str, xnp: ModuleType + func: GenericCallable, + backend: str, + xnp: ModuleType, ) -> GenericCallable: """Redefine function to be vectorizable given backend. @@ -58,7 +60,7 @@ def _make_vectorizable( if _is_lambda_function(func): raise TranslateToVectorizableError( "Lambda functions are not supported for vectorization. Please define a " - "named function and use that." + "named function and use that.", ) module = _module_from_backend(backend) @@ -83,7 +85,9 @@ def _make_vectorizable( def make_vectorizable_source( - func: GenericCallable, backend: str, xnp: ModuleType + func: GenericCallable, + backend: str, + xnp: ModuleType, ) -> str: """Redefine function source to be vectorizable given backend. @@ -100,7 +104,7 @@ def make_vectorizable_source( if _is_lambda_function(func): raise TranslateToVectorizableError( "Lambda functions are not supported for vectorization. Please define a " - "named function and use that." + "named function and use that.", ) module = _module_from_backend(backend) @@ -109,7 +113,9 @@ def make_vectorizable_source( def _make_vectorizable_ast( - func: GenericCallable, module: str, xnp: ModuleType + func: GenericCallable, + module: str, + xnp: ModuleType, ) -> ast.Module: """Change if statement to where call in the ast of func and return new ast. 
@@ -142,8 +148,7 @@ def _remove_decorator_lines(source: str) -> str: """Removes leading decorator lines from function source code.""" if source.startswith("def "): return source - else: - return "def " + source.split("\ndef ")[1] + return "def " + source.split("\ndef ")[1] # ====================================================================================== @@ -160,21 +165,24 @@ def __init__(self, module: str, func_loc: str, xnp: ModuleType) -> None: def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802 self.generic_visit(node) return _call_to_call_from_module( - node, module=self.module, func_loc=self.func_loc, xnp=self.xnp + node, + module=self.module, + func_loc=self.func_loc, + xnp=self.xnp, ) def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.UnaryOp | ast.Call: # noqa: N802 if isinstance(node.op, ast.Not): return _not_to_call(node, module=self.module) - else: - return node + return node def visit_BoolOp(self, node: ast.BoolOp) -> ast.Call: # noqa: N802 self.generic_visit(node) return _boolop_to_call(node, module=self.module) def visit_If( # noqa: N802 - self, node: ast.If + self, + node: ast.If, ) -> ast.Call | ast.Return | ast.Assign | ast.AugAssign: self.generic_visit(node) call = _if_to_call(node, module=self.module, func_loc=self.func_loc) @@ -218,11 +226,11 @@ def _if_to_call(node: ast.If, module: str, func_loc: str) -> ast.Call: if len(node.orelse) > 1 or len(node.body) > 1: msg = _too_many_operations_error_message(node, func_loc=func_loc) raise TranslateToVectorizableError(msg) - elif node.orelse == []: + if node.orelse == []: if isinstance(node.body[0], ast.Return): msg = _return_and_no_else_error_message(node.body[0], func_loc=func_loc) raise TranslateToVectorizableError(msg) - elif hasattr(node.body[0], "targets"): + if hasattr(node.body[0], "targets"): name = ast.Name(id=node.body[0].targets[0].id, ctx=ast.Load()) else: name = ast.Name(id=node.body[0].target.id, ctx=ast.Load()) # type: ignore[attr-defined] @@ -244,7 +252,9 @@ def 
_if_to_call(node: ast.If, module: str, func_loc: str) -> ast.Call: return ast.Call( func=ast.Attribute( - value=ast.Name(id=module, ctx=ast.Load()), attr="where", ctx=ast.Load() + value=ast.Name(id=module, ctx=ast.Load()), + attr="where", + ctx=ast.Load(), ), args=args, keywords=[], @@ -263,7 +273,9 @@ def _ifexp_to_call(node: ast.IfExp, module: str) -> ast.Call: return ast.Call( func=ast.Attribute( - value=ast.Name(id=module, ctx=ast.Load()), attr="where", ctx=ast.Load() + value=ast.Name(id=module, ctx=ast.Load()), + attr="where", + ctx=ast.Load(), ), args=args, keywords=[], @@ -295,7 +307,10 @@ def _constructor(left: ast.Call | ast.expr, right: ast.Call | ast.expr) -> ast.C def _call_to_call_from_module( - node: ast.Call, module: str, func_loc: str, xnp: ModuleType + node: ast.Call, + module: str, + func_loc: str, + xnp: ModuleType, ) -> ast.AST: """Transform built-in Calls to Calls from module.""" to_transform = ("sum", "any", "all", "max", "min") @@ -314,7 +329,7 @@ def _call_to_call_from_module( raise TranslateToVectorizableError( f"Argument of function {func_id} is not a list, tuple, or valid array." f"\n\nFunction: {func_loc}\n\n" - f"Problematic source code: \n\n{_node_to_formatted_source(node)}\n" + f"Problematic source code: \n\n{_node_to_formatted_source(node)}\n", ) call.func = ast.Attribute( @@ -413,5 +428,5 @@ def _module_from_backend(backend: str) -> str: return BACKEND_TO_MODULE[backend] raise NotImplementedError( - f"Argument 'backend' is {backend} but must be in {BACKEND_TO_MODULE.keys()}." 
+ f"Argument 'backend' is {backend} but must be in {BACKEND_TO_MODULE.keys()}.", ) diff --git a/tests/ttsim/test_automatically_added_functions.py b/tests/ttsim/test_automatically_added_functions.py index 2233810e9..dedb29d36 100644 --- a/tests/ttsim/test_automatically_added_functions.py +++ b/tests/ttsim/test_automatically_added_functions.py @@ -286,11 +286,13 @@ class TestCreateFunctionsForTimeUnits: ], ) def test_should_create_functions_for_other_time_units_for_functions( - self, name: str, expected: list[str] + self, + name: str, + expected: list[str], ) -> None: time_conversion_functions = create_time_conversion_functions( qual_name_policy_environment={ - name: policy_function(leaf_name=name)(return_one) + name: policy_function(leaf_name=name)(return_one), }, processed_data_columns=set(), grouping_levels=("sn", "kin"), @@ -302,7 +304,7 @@ def test_should_create_functions_for_other_time_units_for_functions( def test_should_not_create_functions_automatically_that_exist_already(self) -> None: time_conversion_functions = create_time_conversion_functions( qual_name_policy_environment={ - "test1_d": policy_function(leaf_name="test1_d")(return_one) + "test1_d": policy_function(leaf_name="test1_d")(return_one), }, processed_data_columns={"test2_y"}, grouping_levels=("sn", "kin"), @@ -316,7 +318,7 @@ def test_should_overwrite_functions_with_data_cols_that_only_differ_in_time_peri ) -> None: time_conversion_functions = create_time_conversion_functions( qual_name_policy_environment={ - "test_d": policy_function(leaf_name="test_d")(return_one) + "test_d": policy_function(leaf_name="test_d")(return_one), }, processed_data_columns={"test_y"}, grouping_levels=("sn", "kin"), diff --git a/tests/ttsim/test_convert_nested_data.py b/tests/ttsim/test_convert_nested_data.py index 4b651c6f9..d5d930f42 100644 --- a/tests/ttsim/test_convert_nested_data.py +++ b/tests/ttsim/test_convert_nested_data.py @@ -113,7 +113,9 @@ def test_dataframe_to_nested_data( assert set(flat_result.keys()) 
== set(flat_expected_output.keys()) for key in flat_result: pd.testing.assert_series_equal( - pd.Series(flat_result[key]), flat_expected_output[key], check_names=False + pd.Series(flat_result[key]), + flat_expected_output[key], + check_names=False, ) diff --git a/tests/ttsim/test_end_to_end.py b/tests/ttsim/test_end_to_end.py index 88d1acbf4..81e41a90b 100644 --- a/tests/ttsim/test_end_to_end.py +++ b/tests/ttsim/test_end_to_end.py @@ -15,7 +15,7 @@ "child_tax_credit_recipient": [-1, -1, 0], "gross_wage_y": [10000, 0, 0], "wealth": [0.0, 0.0, 0.0], - } + }, ) @@ -45,7 +45,7 @@ "child_tax_credit": { "amount_m": "payroll_tax_child_tax_credit_amount_m", }, - } + }, } diff --git a/tests/ttsim/test_failures.py b/tests/ttsim/test_failures.py index 0fd147da5..66d3ce519 100644 --- a/tests/ttsim/test_failures.py +++ b/tests/ttsim/test_failures.py @@ -121,11 +121,10 @@ def fam_id() -> int: @pytest.fixture(scope="module") def minimal_input_data(): n_individuals = 5 - out = { + return { "p_id": pd.Series(numpy.arange(n_individuals), name="p_id"), "fam_id": pd.Series(numpy.arange(n_individuals), name="fam_id"), } - return out @pytest.fixture(scope="module") @@ -181,7 +180,9 @@ def some_param_func_returning_list_of_length_2() -> list[int]: def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): with pytest.raises(TypeError, match=re.escape(err_substr)): assert_valid_ttsim_pytree( - tree=tree, leaf_checker=leaf_checker, tree_name="tree" + tree=tree, + leaf_checker=leaf_checker, + tree_name="tree", ) @@ -206,7 +207,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): ("c", "g"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 1}, - } + }, }, ), # Same submodule, overlapping periods, different leaf names so no name clashes. 
@@ -227,7 +228,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): ("x", "c", "h"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 2}, - } + }, }, ), # Different submodules, no overlapping periods, no name clashes. @@ -246,7 +247,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): ("x", "c", "g"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 3}, - } + }, }, ), # Different paths, overlapping periods, same names but no clashes. @@ -267,7 +268,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): ("z", "a", "f"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 4}, - } + }, }, ), # Different yaml files, no name clashes because of different names. @@ -318,7 +319,7 @@ def test_assert_valid_ttsim_pytree(tree, leaf_checker, err_substr): "value": 13, "note": "Complex didn't last long.", }, - } + }, }, ), # Different periods specified in different files. @@ -465,7 +466,7 @@ def test_fail_if_active_periods_overlap_passes( ("c", "f"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 1}, - } + }, }, ), # Same paths, no overlap in functions, name clashes leaf name / yaml. @@ -486,7 +487,7 @@ def test_fail_if_active_periods_overlap_passes( ("x", "a", "f"): { # type: ignore[misc] **_GENERIC_PARAM_HEADER, datetime.date(2023, 1, 1): {"value": 2}, - } + }, }, ), # Same paths, name clashes within params from different yaml files. @@ -621,7 +622,8 @@ def test_fail_if_foreign_keys_are_invalid_in_data_when_foreign_key_points_to_sam def test_fail_if_group_ids_are_outside_top_level_namespace(): with pytest.raises( - ValueError, match="Group identifiers must live in the top-level namespace. Got:" + ValueError, + match="Group identifiers must live in the top-level namespace. 
Got:", ): group_ids_are_outside_top_level_namespace({"n1": {"fam_id": fam_id}}) @@ -633,7 +635,8 @@ def test_fail_if_group_variables_are_not_constant_within_groups(): "kin_id": numpy.array([1, 1, 2]), } with pytest.raises( - ValueError, match="The following data inputs do not have a unique value within" + ValueError, + match="The following data inputs do not have a unique value within", ): group_variables_are_not_constant_within_groups( names__grouping_levels=("kin",), @@ -646,7 +649,8 @@ def test_fail_if_input_data_tree_is_invalid(xnp): data = {"fam_id": pd.Series(data=numpy.arange(8), name="fam_id")} with pytest.raises( - ValueError, match="The input data must contain the `p_id` column." + ValueError, + match="The input data must contain the `p_id` column.", ): input_data_tree_is_invalid(input_data__tree=data, xnp=xnp) @@ -677,7 +681,8 @@ def test_fail_if_input_data_tree_is_invalid_via_main(): ) def test_fail_if_input_df_has_bool_or_numeric_column_names(df): with pytest.raises( - ValueError, match="DataFrame column names cannot be booleans or numbers." + ValueError, + match="DataFrame column names cannot be booleans or numbers.", ): input_df_has_bool_or_numeric_column_names(df) @@ -719,7 +724,8 @@ def test_fail_if_input_df_has_bool_or_numeric_column_names(df): ], ) def test_fail_if_input_df_mapper_has_incorrect_format( - input_data__df_and_mapper__mapper, expected_error_message + input_data__df_and_mapper__mapper, + expected_error_message, ): with pytest.raises(TypeError, match=expected_error_message): input_df_mapper_has_incorrect_format(input_data__df_and_mapper__mapper) @@ -833,7 +839,8 @@ def test_fail_if_p_id_does_not_exist(xnp): data = {"fam_id": pd.Series(data=numpy.arange(8), name="fam_id")} with pytest.raises( - ValueError, match="The input data must contain the `p_id` column." 
+ ValueError, + match="The input data must contain the `p_id` column.", ): input_data_tree_is_invalid(input_data__tree=data, xnp=xnp) @@ -860,7 +867,8 @@ def test_fail_if_p_id_is_not_unique(xnp): data = {"p_id": pd.Series(data=numpy.arange(4).repeat(2), name="p_id")} with pytest.raises( - ValueError, match="The following `p_id`s are not unique in the input data" + ValueError, + match="The following `p_id`s are not unique in the input data", ): input_data_tree_is_invalid(input_data__tree=data, xnp=xnp) @@ -923,10 +931,14 @@ def c(b): ], ) def test_fail_if_targets_are_not_in_policy_environment_or_data( - policy_environment, targets, names__processed_data_columns, expected_error_match + policy_environment, + targets, + names__processed_data_columns, + expected_error_match, ): with pytest.raises( - ValueError, match="The following targets have no corresponding function" + ValueError, + match="The following targets have no corresponding function", ) as e: targets_are_not_in_policy_environment_or_data( policy_environment=policy_environment, @@ -986,7 +998,7 @@ def test_fail_if_targets_are_not_in_policy_environment_or_data_via_main( start_date=datetime.date(1984, 1, 1), end_date=datetime.date(2099, 12, 31), **_GENERIC_PARAM_HEADER, - ) + ), ], ), ( @@ -1007,7 +1019,7 @@ def test_fail_if_targets_are_not_in_policy_environment_or_data_via_main( start_date=datetime.date(1984, 1, 1), end_date=datetime.date(1984, 12, 31), **_GENERIC_PARAM_HEADER, - ) + ), ], ), ( diff --git a/tests/ttsim/test_mettsim.py b/tests/ttsim/test_mettsim.py index 36a1f38d3..f604dde4c 100644 --- a/tests/ttsim/test_mettsim.py +++ b/tests/ttsim/test_mettsim.py @@ -17,7 +17,9 @@ TEST_DIR = Path(__file__).parent POLICY_TEST_IDS_AND_CASES = load_policy_test_data( - test_dir=TEST_DIR, policy_name="", xnp=numpy + test_dir=TEST_DIR, + policy_name="", + xnp=numpy, ) diff --git a/tests/ttsim/test_policy_environment.py b/tests/ttsim/test_policy_environment.py index 43a3ed245..23c186fe6 100644 --- 
a/tests/ttsim/test_policy_environment.py +++ b/tests/ttsim/test_policy_environment.py @@ -121,7 +121,8 @@ def test_func(): ) def test_start_date_invalid(date_string: str): with pytest.raises( - ValueError, match="neither matches the format YYYY-MM-DD nor is a datetime.date" + ValueError, + match="neither matches the format YYYY-MM-DD nor is a datetime.date", ): @policy_function(start_date=date_string) @@ -161,7 +162,8 @@ def test_func(): ) def test_end_date_invalid(date_string: str): with pytest.raises( - ValueError, match="neither matches the format YYYY-MM-DD nor is a datetime.date" + ValueError, + match="neither matches the format YYYY-MM-DD nor is a datetime.date", ): @policy_function(end_date=date_string) diff --git a/tests/ttsim/test_shared.py b/tests/ttsim/test_shared.py index 9d223fc7a..4565b9167 100644 --- a/tests/ttsim/test_shared.py +++ b/tests/ttsim/test_shared.py @@ -43,7 +43,9 @@ def test_fail_if_invalid_date(): ) def test_upsert_path_and_value(base, path_to_upsert, value_to_upsert, expected): result = upsert_path_and_value( - base=base, path_to_upsert=path_to_upsert, value_to_upsert=value_to_upsert + base=base, + path_to_upsert=path_to_upsert, + value_to_upsert=value_to_upsert, ) assert result == expected @@ -57,7 +59,9 @@ def test_upsert_path_and_value(base, path_to_upsert, value_to_upsert, expected): ) def test_insert_path_and_value(base, path_to_insert, value_to_insert, expected): result = insert_path_and_value( - base=base, path_to_insert=path_to_insert, value_to_insert=value_to_insert + base=base, + path_to_insert=path_to_insert, + value_to_insert=value_to_insert, ) assert result == expected @@ -71,7 +75,9 @@ def test_insert_path_and_value(base, path_to_insert, value_to_insert, expected): def test_insert_path_and_value_invalid(base, path_to_insert, value_to_insert): with pytest.raises(ValueError, match="Conflicting paths in trees to merge."): insert_path_and_value( - base=base, path_to_insert=path_to_insert, value_to_insert=value_to_insert + 
base=base, + path_to_insert=path_to_insert, + value_to_insert=value_to_insert, ) @@ -195,7 +201,8 @@ def test_upsert_tree(base_dict, update_dict, expected): ) def test_partition_tree_by_reference_tree(tree_to_partition, reference_tree, expected): in_reference_tree, not_in_reference_tree = partition_tree_by_reference_tree( - tree_to_partition=tree_to_partition, reference_tree=reference_tree + tree_to_partition=tree_to_partition, + reference_tree=reference_tree, ) assert in_reference_tree == expected[0] @@ -272,7 +279,10 @@ def test_get_re_pattern_for_time_units_and_groupings( ], ) def test_get_re_pattern_for_some_base_name( - base_name, time_units, grouping_levels, expected_match + base_name, + time_units, + grouping_levels, + expected_match, ): re_pattern = get_re_pattern_for_specific_time_units_and_groupings( base_name=base_name, diff --git a/tests/ttsim/test_specialized_environment.py b/tests/ttsim/test_specialized_environment.py index 6dbda6fff..85350783f 100644 --- a/tests/ttsim/test_specialized_environment.py +++ b/tests/ttsim/test_specialized_environment.py @@ -103,7 +103,9 @@ def some_converting_params_func( @param_function() def some_param_function_taking_scalar( - some_int_scalar: int, some_float_scalar: float, some_bool_scalar: bool + some_int_scalar: int, + some_float_scalar: float, + some_bool_scalar: bool, ) -> float: return some_int_scalar + some_float_scalar + int(some_bool_scalar) @@ -179,22 +181,20 @@ def some_policy_function_taking_int_param(some_int_param: int) -> float: @pytest.fixture(scope="module") def minimal_input_data(): n_individuals = 5 - out = { + return { "p_id": numpy.arange(n_individuals), "fam_id": numpy.arange(n_individuals), } - return out @pytest.fixture(scope="module") def minimal_input_data_shared_fam(): n_individuals = 3 - out = { + return { "p_id": numpy.arange(n_individuals), "fam_id": numpy.array([0, 0, 1]), "p_id_someone_else": numpy.array([1, 0, -1]), } - return out @agg_by_group_function(agg_type=AggType.SUM) @@ -297,7 
+297,8 @@ def return_n1__x_kin(n1__x_kin: int) -> int: "p_id": p_id, "n1": { "f": policy_function( - leaf_name="f", vectorization_strategy="vectorize" + leaf_name="f", + vectorization_strategy="vectorize", )(return_n1__x_kin), "x": x, }, @@ -334,7 +335,8 @@ def return_n1__x_kin(n1__x_kin: int) -> int: "p_id": p_id, "n1": { "f": policy_function( - leaf_name="f", vectorization_strategy="vectorize" + leaf_name="f", + vectorization_strategy="vectorize", )(some_x), "x": x, }, @@ -354,7 +356,8 @@ def return_n1__x_kin(n1__x_kin: int) -> int: "p_id": p_id, "n1": { "f": policy_function( - leaf_name="f", vectorization_strategy="vectorize" + leaf_name="f", + vectorization_strategy="vectorize", )(some_x), "x": x, }, @@ -375,7 +378,8 @@ def return_n1__x_kin(n1__x_kin: int) -> int: "p_id": p_id, "n1": { "f": policy_function( - leaf_name="f", vectorization_strategy="vectorize" + leaf_name="f", + vectorization_strategy="vectorize", )(return_y_kin), "y_kin": y_kin_namespaced_input, }, @@ -468,7 +472,9 @@ def test_params_target_is_allowed(minimal_input_data): def test_function_without_data_dependency_is_not_mistaken_for_data( - minimal_input_data, backend, xnp + minimal_input_data, + backend, + xnp, ): @policy_function(leaf_name="a", vectorization_strategy="not_required") def a() -> IntColumn: @@ -493,7 +499,8 @@ def b(a: int) -> int: targets=["results__tree"], )["results__tree"] numpy.testing.assert_array_almost_equal( - results__tree["b"], xnp.array(minimal_input_data["p_id"]) + results__tree["b"], + xnp.array(minimal_input_data["p_id"]), ) @@ -624,7 +631,8 @@ def test_user_provided_aggregation_with_time_conversion(backend): # Double up, convert to quarter, then take max fam_id expected = pd.Series( - [400 * 12, 400 * 12, 200 * 12], index=pd.Index(data["p_id"], name="p_id") + [400 * 12, 400 * 12, 200 * 12], + index=pd.Index(data["p_id"], name="p_id"), ) @policy_function(vectorization_strategy="vectorize") @@ -693,8 +701,8 @@ def sum_source_m_by_p_id_someone_else( ( { "module": { - 
"sum_source_by_p_id_someone_else": sum_source_by_p_id_someone_else - } + "sum_source_by_p_id_someone_else": sum_source_by_p_id_someone_else, + }, }, "source", {"module": {"sum_source_by_p_id_someone_else": None}}, @@ -703,8 +711,8 @@ def sum_source_m_by_p_id_someone_else( ( { "module": { - "sum_source_m_by_p_id_someone_else": sum_source_m_by_p_id_someone_else # noqa: E501 - } + "sum_source_m_by_p_id_someone_else": sum_source_m_by_p_id_someone_else, # noqa: E501 + }, }, "source_m", {"module": {"sum_source_m_by_p_id_someone_else": None}}, @@ -833,8 +841,8 @@ def test_policy_environment_with_params_and_scalars_is_processed(): {"some_policy_func_taking_scalar_params_func": None}, { "some_policy_func_taking_scalar_params_func": numpy.array( - [1, 2, 3, 4, 5] - ) + [1, 2, 3, 4, 5], + ), }, ), ], diff --git a/tests/ttsim/test_warnings.py b/tests/ttsim/test_warnings.py index 0fc194d74..62b5f78f2 100644 --- a/tests/ttsim/test_warnings.py +++ b/tests/ttsim/test_warnings.py @@ -47,7 +47,8 @@ def test_warn_if_functions_and_data_columns_overlap(backend): def test_warn_if_functions_and_columns_overlap_no_warning_if_no_overlap(backend): with warnings.catch_warnings(): warnings.filterwarnings( - "error", category=warn_if.FunctionsAndDataColumnsOverlapWarning + "error", + category=warn_if.FunctionsAndDataColumnsOverlapWarning, ) main( inputs={ diff --git a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py index 7b4272e1e..870185dcb 100644 --- a/tests/ttsim/tt_dag_elements/test_aggregation_functions.py +++ b/tests/ttsim/tt_dag_elements/test_aggregation_functions.py @@ -10,7 +10,7 @@ my_datetime = jax_datetime.to_datetime except ImportError: - my_datetime = lambda x: x + my_datetime = lambda x: x # noqa: E731 from ttsim.tt_dag_elements.aggregation import ( @@ -184,7 +184,7 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): numpy.datetime64("2002"), numpy.datetime64("2003"), 
numpy.datetime64("2004"), - ] + ], ), "group_id": numpy.array([1, 0, 1, 1, 1]), "expected_res_max": numpy.array( @@ -194,7 +194,7 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): numpy.datetime64("2004"), numpy.datetime64("2004"), numpy.datetime64("2004"), - ] + ], ), "expected_res_min": numpy.array( [ @@ -203,7 +203,7 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): numpy.datetime64("2000"), numpy.datetime64("2000"), numpy.datetime64("2000"), - ] + ], ), } @@ -226,7 +226,7 @@ def parameterize_based_on_dict(test_cases, keys_of_test_cases=None): numpy.datetime64("2002"), numpy.datetime64("2003"), numpy.datetime64("2004"), - ] + ], ), "group_id": numpy.array([0, 0, 1, 1, 1]), "error_sum": TypeError, @@ -416,7 +416,11 @@ def test_grouped_all(column_to_aggregate, group_id, expected_res_all, backend): ) @pytest.mark.skipif_jax def test_grouped_sum_raises( - column_to_aggregate, group_id, error_sum, exception_match, backend + column_to_aggregate, + group_id, + error_sum, + exception_match, + backend, ): with pytest.raises( error_sum, @@ -441,7 +445,11 @@ def test_grouped_sum_raises( ) @pytest.mark.skipif_jax def test_grouped_mean_raises( - column_to_aggregate, group_id, error_mean, exception_match, backend + column_to_aggregate, + group_id, + error_mean, + exception_match, + backend, ): with pytest.raises( error_mean, @@ -466,7 +474,11 @@ def test_grouped_mean_raises( ) @pytest.mark.skipif_jax def test_grouped_max_raises( - column_to_aggregate, group_id, error_max, exception_match, backend + column_to_aggregate, + group_id, + error_max, + exception_match, + backend, ): with pytest.raises( error_max, @@ -491,7 +503,11 @@ def test_grouped_max_raises( ) @pytest.mark.skipif_jax def test_grouped_min_raises( - column_to_aggregate, group_id, error_min, exception_match, backend + column_to_aggregate, + group_id, + error_min, + exception_match, + backend, ): with pytest.raises( error_min, @@ -516,7 +532,11 @@ def 
test_grouped_min_raises( ) @pytest.mark.skipif_jax def test_grouped_any_raises( - column_to_aggregate, group_id, error_any, exception_match, backend + column_to_aggregate, + group_id, + error_any, + exception_match, + backend, ): with pytest.raises( error_any, @@ -541,7 +561,11 @@ def test_grouped_any_raises( ) @pytest.mark.skipif_jax def test_grouped_all_raises( - column_to_aggregate, group_id, error_all, exception_match, backend + column_to_aggregate, + group_id, + error_all, + exception_match, + backend, ): with pytest.raises( error_all, diff --git a/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py b/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py index 8fa6c976b..acf3b8781 100644 --- a/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py +++ b/tests/ttsim/tt_dag_elements/test_piecewise_polynomial.py @@ -22,7 +22,7 @@ @pytest.fixture def parameters(xnp): - params = PiecewisePolynomialParamValue( + return PiecewisePolynomialParamValue( thresholds=xnp.array([-xnp.inf, 9168.0, 14254.0, 55960.0, 265326.0, xnp.inf]), rates=xnp.array( [ @@ -40,11 +40,10 @@ def parameters(xnp): 0.00000000e00, 0.00000000e00, ], - ] + ], ), intercepts=xnp.array([0.0, 0.0, 965.5771, 14722.3012, 102656.0212]), ) - return params def test_get_piecewise_parameters_all_intercepts_supplied(xnp): @@ -87,7 +86,8 @@ def test_get_piecewise_parameters_all_intercepts_supplied(xnp): def test_piecewise_polynomial( - parameters: PiecewisePolynomialParamValue, xnp: ModuleType + parameters: PiecewisePolynomialParamValue, + xnp: ModuleType, ): x = xnp.array([-1_000, 1_000, 10_000, 30_000, 100_000, 1_000_000]) expected = xnp.array([0.0, 0.0, 246.53, 10551.65, 66438.2, 866518.64]) diff --git a/tests/ttsim/tt_dag_elements/test_rounding.py b/tests/ttsim/tt_dag_elements/test_rounding.py index 59cc5e2b2..da3064b15 100644 --- a/tests/ttsim/tt_dag_elements/test_rounding.py +++ b/tests/ttsim/tt_dag_elements/test_rounding.py @@ -6,7 +6,6 @@ from pandas._testing import assert_series_equal from 
ttsim import main -from ttsim.interface_dag_elements.policy_environment import policy_environment from ttsim.tt_dag_elements import ( RoundingSpec, policy_function, @@ -90,12 +89,6 @@ def test_malformed_rounding_specs(): def test_func(): return 0 - policy_environment( - active_tree_with_column_objects_and_param_functions={ - "x.py": {"test_func": test_func} - }, - ) - @pytest.mark.parametrize( "rounding_spec, input_values, exp_output", diff --git a/tests/ttsim/tt_dag_elements/test_ttsim_objects.py b/tests/ttsim/tt_dag_elements/test_ttsim_objects.py index a4ba38a0e..f1b2f2406 100644 --- a/tests/ttsim/tt_dag_elements/test_ttsim_objects.py +++ b/tests/ttsim/tt_dag_elements/test_ttsim_objects.py @@ -206,7 +206,8 @@ def test_wrong_number_of_group_ids_present(): @agg_by_group_function(agg_type=AggType.COUNT) def aggregate_by_group_count_multiple_group_ids_present( - group_id, another_group_id + group_id, + another_group_id, ): pass @@ -259,7 +260,7 @@ def aggregate_by_p_id_count_other_arg_present(p_id, p_id_specifier, wrong_arg): pass -def test_agg_by_p_id_sum_wrong_amount_of_args(): +def test_agg_by_p_id_sum_no_arg_present(): match = "There must be exactly one argument besides identifiers" with pytest.raises(ValueError, match=match): @@ -267,9 +268,17 @@ def test_agg_by_p_id_sum_wrong_amount_of_args(): def aggregate_by_p_id_sum_no_arg_present(p_id, p_id_specifier): pass + +def test_agg_by_p_id_sum_multiple_args_present(): + match = "There must be exactly one argument besides identifiers" + with pytest.raises(ValueError, match=match): + @agg_by_p_id_function(agg_type=AggType.SUM) def aggregate_by_p_id_sum_multiple_args_present( - p_id, p_id_specifier, arg, another_arg + p_id, + p_id_specifier, + arg, + another_arg, ): pass @@ -280,7 +289,9 @@ def test_agg_by_p_id_multiple_other_p_ids_present(): @agg_by_p_id_function(agg_type=AggType.SUM) def aggregate_by_p_id_multiple_other_p_ids_present( - p_id, p_id_specifier_one, p_id_specifier_two + p_id, + p_id_specifier_one, + 
p_id_specifier_two, ): pass diff --git a/tests/ttsim/tt_dag_elements/test_vectorization.py b/tests/ttsim/tt_dag_elements/test_vectorization.py index cd7d895f5..a16e26732 100644 --- a/tests/ttsim/tt_dag_elements/test_vectorization.py +++ b/tests/ttsim/tt_dag_elements/test_vectorization.py @@ -382,7 +382,7 @@ def test_disallowed_operation_wrapper(func): _active_column_objects_and_param_functions( orig=column_objects_and_param_functions(root=METTSIM_ROOT), date=datetime.date(year=year, month=1, day=1), - ) + ), ).items() if not isinstance( pf, @@ -440,7 +440,8 @@ def test_geschwisterbonus_m(backend, xnp): shape = (10, 2) basisbetrag_m = xnp.full(shape, basisbetrag_m) geschwisterbonus_grundsätzlich_anspruchsberechtigt_fg = xnp.full( - shape, geschwisterbonus_grundsätzlich_anspruchsberechtigt_fg + shape, + geschwisterbonus_grundsätzlich_anspruchsberechtigt_fg, ) with pytest.raises(ValueError, match="truth value of an array with more than"): @@ -454,7 +455,9 @@ def test_geschwisterbonus_m(backend, xnp): # Call converted function on array input and test result # ============================================================================== converted = _make_vectorizable( - mock__elterngeld__geschwisterbonus_m, backend=backend, xnp=xnp + mock__elterngeld__geschwisterbonus_m, + backend=backend, + xnp=xnp, ) got = converted( basisbetrag_m=basisbetrag_m, @@ -655,19 +658,26 @@ def already_vectorized_func(x: IntColumn, xnp: ModuleType) -> IntColumn: def test_loop_vectorize_scalar_func(backend, xnp): fun = vectorize_function( - scalar_func, vectorization_strategy="loop", backend=backend, xnp=numpy + scalar_func, + vectorization_strategy="loop", + backend=backend, + xnp=numpy, ) assert numpy.array_equal(fun(xnp.array([-1, 0, 2, 3])), xnp.array([0, 0, 4, 6])) def test_vectorize_scalar_func(backend, xnp): fun = vectorize_function( - scalar_func, vectorization_strategy="vectorize", backend=backend, xnp=numpy + scalar_func, + vectorization_strategy="vectorize", + backend=backend, + 
xnp=numpy, ) assert numpy.array_equal(fun(xnp.array([-1, 0, 2, 3])), xnp.array([0, 0, 4, 6])) def test_already_vectorized_func(xnp): assert numpy.array_equal( - already_vectorized_func(xnp.array([-1, 0, 2, 3]), xnp), xnp.array([0, 0, 4, 6]) + already_vectorized_func(xnp.array([-1, 0, 2, 3]), xnp), + xnp.array([0, 0, 4, 6]), ) From 8f9ce346c27da8a3129f679963bd68feedd574ec Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 13 Jun 2025 20:22:55 +0200 Subject: [PATCH 22/25] Apply naming suggestions made by @MImmesberger in #951. --- src/ttsim/interface_dag_elements/names.py | 16 ++++++++-------- src/ttsim/interface_dag_elements/raw_results.py | 8 ++++---- .../specialized_environment.py | 8 ++++---- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/ttsim/interface_dag_elements/names.py b/src/ttsim/interface_dag_elements/names.py index 1f2a9aea1..f3b4e49f6 100644 --- a/src/ttsim/interface_dag_elements/names.py +++ b/src/ttsim/interface_dag_elements/names.py @@ -35,7 +35,7 @@ def fail_if_multiple_time_units_for_same_base_name_and_group( @interface_function() -def target_columns( +def column_targets( specialized_environment__with_partialled_params_and_scalars: UnorderedQNames, targets__qname: OrderedQNames, ) -> OrderedQNames: @@ -48,12 +48,12 @@ def target_columns( @interface_function() -def target_params( +def param_targets( specialized_environment__with_derived_functions_and_processed_input_nodes: QNamePolicyEnvironment, # noqa: E501 targets__qname: OrderedQNames, - target_columns: OrderedQNames, + column_targets: OrderedQNames, ) -> OrderedQNames: - possible_targets = set(targets__qname) - set(target_columns) + possible_targets = set(targets__qname) - set(column_targets) return [ t for t in targets__qname @@ -64,12 +64,12 @@ def target_params( @interface_function() -def targets_from_input_data( +def input_data_targets( targets__qname: OrderedQNames, - target_columns: OrderedQNames, - target_params: OrderedQNames, + column_targets: 
OrderedQNames, + param_targets: OrderedQNames, ) -> OrderedQNames: - possible_targets = set(targets__qname) - set(target_columns) - set(target_params) + possible_targets = set(targets__qname) - set(column_targets) - set(param_targets) return [t for t in targets__qname if t in possible_targets] diff --git a/src/ttsim/interface_dag_elements/raw_results.py b/src/ttsim/interface_dag_elements/raw_results.py index 03014ce68..908251b92 100644 --- a/src/ttsim/interface_dag_elements/raw_results.py +++ b/src/ttsim/interface_dag_elements/raw_results.py @@ -28,21 +28,21 @@ def columns( @interface_function() def params( - names__target_params: OrderedQNames, + names__param_targets: OrderedQNames, specialized_environment__with_processed_params_and_scalars: QNameCombinedEnvironment1, # noqa: E501 ) -> QNameData: return { pt: specialized_environment__with_processed_params_and_scalars[pt] - for pt in names__target_params + for pt in names__param_targets } @interface_function() def from_input_data( - names__targets_from_input_data: OrderedQNames, + names__input_data_targets: OrderedQNames, processed_data: QNameData, ) -> QNameData: - return {ot: processed_data[ot] for ot in names__targets_from_input_data} + return {ot: processed_data[ot] for ot in names__input_data_targets} @interface_function() diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index dbdde7ce4..e5229da3b 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -298,12 +298,12 @@ def _apply_rounding(element: ColumnFunction, xnp: ModuleType) -> ColumnFunction: @interface_function() def tax_transfer_dag( with_partialled_params_and_scalars: QNameCombinedEnvironment2, - names__target_columns: OrderedQNames, + names__column_targets: OrderedQNames, ) -> nx.DiGraph: """Thin wrapper around `create_dag`.""" return create_dag( 
functions=with_partialled_params_and_scalars, - targets=names__target_columns, + targets=names__column_targets, ) @@ -311,14 +311,14 @@ def tax_transfer_dag( def tax_transfer_function( tax_transfer_dag: nx.DiGraph, with_partialled_params_and_scalars: QNameCombinedEnvironment2, - names__target_columns: OrderedQNames, + names__column_targets: OrderedQNames, backend: Literal["numpy", "jax"], ) -> Callable[[QNameData], QNameData]: """Returns a function that takes a dictionary of arrays and unpacks them as keyword arguments.""" ttf_with_keyword_args = concatenate_functions( dag=tax_transfer_dag, functions=with_partialled_params_and_scalars, - targets=list(names__target_columns), + targets=list(names__column_targets), return_type="dict", aggregator=None, enforce_signature=True, From c721a375f47b61dcb6afb5775162e0810346644f Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Fri, 13 Jun 2025 20:32:07 +0200 Subject: [PATCH 23/25] Rename 'names' -> 'labels'. Sort-of another naming suggestions made by @MImmesberger in #951. 
--- .../automatically_added_functions.py | 4 +- src/ttsim/interface_dag_elements/fail_if.py | 22 +-- src/ttsim/interface_dag_elements/names.py | 175 ------------------ .../interface_dag_elements/raw_results.py | 12 +- .../specialized_environment.py | 32 ++-- src/ttsim/interface_dag_elements/warn_if.py | 4 +- .../test_automatically_added_functions.py | 6 +- tests/ttsim/test_failures.py | 18 +- tests/ttsim/test_names.py | 2 +- tests/ttsim/test_policy_environment.py | 8 +- 10 files changed, 54 insertions(+), 229 deletions(-) delete mode 100644 src/ttsim/interface_dag_elements/names.py diff --git a/src/ttsim/interface_dag_elements/automatically_added_functions.py b/src/ttsim/interface_dag_elements/automatically_added_functions.py index f2fd8cae4..4d1b3971a 100644 --- a/src/ttsim/interface_dag_elements/automatically_added_functions.py +++ b/src/ttsim/interface_dag_elements/automatically_added_functions.py @@ -571,7 +571,7 @@ def func(x: float) -> float: def create_agg_by_group_functions( column_functions: UnorderedQNames, - names__processed_data_columns: QNameDataColumns, + labels__processed_data_columns: QNameDataColumns, targets: OrderedQNames, grouping_levels: OrderedQNames, # backend: Literal["numpy", "jax"], @@ -579,7 +579,7 @@ def create_agg_by_group_functions( gp = group_pattern(grouping_levels) all_functions_and_data = { **column_functions, - **dict.fromkeys(names__processed_data_columns), + **dict.fromkeys(labels__processed_data_columns), } potential_agg_by_group_function_names = { # Targets that end with a grouping suffix are potential aggregation targets. 
diff --git a/src/ttsim/interface_dag_elements/fail_if.py b/src/ttsim/interface_dag_elements/fail_if.py index 876941f02..c2b67f252 100644 --- a/src/ttsim/interface_dag_elements/fail_if.py +++ b/src/ttsim/interface_dag_elements/fail_if.py @@ -216,14 +216,14 @@ def any_paths_are_invalid( policy_environment: NestedPolicyEnvironment, input_data__tree: NestedData, targets__tree: NestedTargetDict, - names__top_level_namespace: UnorderedQNames, + labels__top_level_namespace: UnorderedQNames, ) -> None: """Thin wrapper around `dt.fail_if_paths_are_invalid`.""" return dt.fail_if_paths_are_invalid( functions=policy_environment, input_data__tree=input_data__tree, targets=targets__tree, - names__top_level_namespace=names__top_level_namespace, + labels__top_level_namespace=labels__top_level_namespace, ) @@ -312,7 +312,7 @@ def environment_is_invalid( @interface_function() def foreign_keys_are_invalid_in_data( - names__root_nodes: UnorderedQNames, + labels__root_nodes: UnorderedQNames, processed_data: QNameData, specialized_environment__with_derived_functions_and_processed_input_nodes: QNamePolicyEnvironment, ) -> None: @@ -335,7 +335,7 @@ def foreign_keys_are_invalid_in_data( for fk_name, fk in relevant_objects.items(): if fk.foreign_key_type == FKType.IRRELEVANT: continue - if fk_name in names__root_nodes: + if fk_name in labels__root_nodes: path = dt.tree_path_from_qual_name(fk_name) # Referenced `p_id` must exist in the input data if not all(i in valid_ids for i in processed_data[fk_name].tolist()): @@ -387,8 +387,8 @@ def group_ids_are_outside_top_level_namespace( @interface_function() def group_variables_are_not_constant_within_groups( - names__grouping_levels: OrderedQNames, - names__root_nodes: UnorderedQNames, + labels__grouping_levels: OrderedQNames, + labels__root_nodes: UnorderedQNames, processed_data: QNameData, ) -> None: """ @@ -403,10 +403,10 @@ def group_variables_are_not_constant_within_groups( """ faulty_data_columns = [] - for name in names__root_nodes: + for 
name in labels__root_nodes: group_by_id = get_name_of_group_by_id( target_name=name, - grouping_levels=names__grouping_levels, + grouping_levels=labels__grouping_levels, ) if group_by_id in processed_data: group_by_id_series = pd.Series(processed_data[group_by_id]) @@ -613,7 +613,7 @@ def root_nodes_are_missing( @interface_function() def targets_are_not_in_policy_environment_or_data( policy_environment: QNamePolicyEnvironment, - names__processed_data_columns: QNameDataColumns, + labels__processed_data_columns: QNameDataColumns, targets__qname: OrderedQNames, ) -> None: """Fail if some target is not among functions. @@ -622,7 +622,7 @@ def targets_are_not_in_policy_environment_or_data( ---------- functions Dictionary containing functions to build the DAG. - names__processed_data_columns + labels__processed_data_columns The columns which are available in the data tree. targets The targets which should be computed. They limit the DAG in the way that only @@ -637,7 +637,7 @@ def targets_are_not_in_policy_environment_or_data( targets_not_in_policy_environment_or_data = [ str(dt.tree_path_from_qual_name(n)) for n in targets__qname - if n not in policy_environment and n not in names__processed_data_columns + if n not in policy_environment and n not in labels__processed_data_columns ] if targets_not_in_policy_environment_or_data: formatted = format_list_linewise(targets_not_in_policy_environment_or_data) diff --git a/src/ttsim/interface_dag_elements/names.py b/src/ttsim/interface_dag_elements/names.py deleted file mode 100644 index f3b4e49f6..000000000 --- a/src/ttsim/interface_dag_elements/names.py +++ /dev/null @@ -1,175 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import networkx as nx - -from ttsim.interface_dag_elements.automatically_added_functions import ( - TIME_UNIT_LABELS, -) -from ttsim.interface_dag_elements.interface_node_objects import interface_function -from ttsim.interface_dag_elements.shared import ( - 
get_base_name_and_grouping_suffix, - get_re_pattern_for_all_time_units_and_groupings, - group_pattern, -) - -if TYPE_CHECKING: - from ttsim.interface_dag_elements.typing import ( - NestedPolicyEnvironment, - OrderedQNames, - QNameData, - QNamePolicyEnvironment, - UnorderedQNames, - ) - - -def fail_if_multiple_time_units_for_same_base_name_and_group( - base_names_and_groups_to_variations: dict[tuple[str, str], list[str]], -) -> None: - invalid = { - b: q for b, q in base_names_and_groups_to_variations.items() if len(q) > 1 - } - if invalid: - raise ValueError(f"Multiple time units for base names: {invalid}") - - -@interface_function() -def column_targets( - specialized_environment__with_partialled_params_and_scalars: UnorderedQNames, - targets__qname: OrderedQNames, -) -> OrderedQNames: - """All targets that are column functions.""" - return [ - t - for t in targets__qname - if t in specialized_environment__with_partialled_params_and_scalars - ] - - -@interface_function() -def param_targets( - specialized_environment__with_derived_functions_and_processed_input_nodes: QNamePolicyEnvironment, # noqa: E501 - targets__qname: OrderedQNames, - column_targets: OrderedQNames, -) -> OrderedQNames: - possible_targets = set(targets__qname) - set(column_targets) - return [ - t - for t in targets__qname - if t in possible_targets - and t - in specialized_environment__with_derived_functions_and_processed_input_nodes - ] - - -@interface_function() -def input_data_targets( - targets__qname: OrderedQNames, - column_targets: OrderedQNames, - param_targets: OrderedQNames, -) -> OrderedQNames: - possible_targets = set(targets__qname) - set(column_targets) - set(param_targets) - return [t for t in targets__qname if t in possible_targets] - - -@interface_function() -def grouping_levels( - policy_environment: QNamePolicyEnvironment, -) -> OrderedQNames: - """The grouping levels of the policy environment.""" - return tuple( - name.rsplit("_", 1)[0] - for name in policy_environment - if 
name.endswith("_id") and name != "p_id" - ) - - -@interface_function() -def top_level_namespace( - policy_environment: NestedPolicyEnvironment, - grouping_levels: OrderedQNames, -) -> UnorderedQNames: - """Get the top level namespace. - - Parameters - ---------- - policy_environment: - The policy environment. - - - Returns - ------- - top_level_namespace: - The top level namespace. - """ - time_units = tuple(TIME_UNIT_LABELS) - direct_top_level_names = set(policy_environment) - - # Do not create variations for lower-level namespaces. - top_level_objects_for_variations = direct_top_level_names - { - k for k, v in policy_environment.items() if isinstance(v, dict) - } - - pattern_all = get_re_pattern_for_all_time_units_and_groupings( - time_units=time_units, - grouping_levels=grouping_levels, - ) - bngs_to_variations = {} - all_top_level_names = direct_top_level_names.copy() - for name in top_level_objects_for_variations: - match = pattern_all.fullmatch(name) - # We must not find multiple time units for the same base name and group. 
- bngs = get_base_name_and_grouping_suffix(match) - if match.group("time_unit"): - if bngs not in bngs_to_variations: - bngs_to_variations[bngs] = [name] - else: - bngs_to_variations[bngs].append(name) - for time_unit in time_units: - all_top_level_names.add(f"{bngs[0]}_{time_unit}{bngs[1]}") - fail_if_multiple_time_units_for_same_base_name_and_group(bngs_to_variations) - - gp = group_pattern(grouping_levels) - potential_base_names = {n for n in all_top_level_names if not gp.match(n)} - - for name in potential_base_names: - for g in grouping_levels: - all_top_level_names.add(f"{name}_{g}") - - return all_top_level_names - - -@interface_function() -def processed_data_columns(processed_data: QNameData) -> UnorderedQNames: - return set(processed_data.keys()) - - -@interface_function() -def root_nodes( - specialized_environment__tax_transfer_dag: nx.DiGraph, - processed_data: QNameData, -) -> UnorderedQNames: - """Names of the columns in `processed_data` required for the tax transfer function. - - Parameters - ---------- - specialized_environment__tax_transfer_dag: - The tax transfer DAG. - processed_data: - The processed data. - - Returns - ------- - The names of the columns in `processed_data` required for the tax transfer function. - - """ - # Obtain root nodes - root_nodes = nx.subgraph_view( - specialized_environment__tax_transfer_dag, - filter_node=lambda n: specialized_environment__tax_transfer_dag.in_degree(n) - == 0, - ).nodes - - # Restrict the passed data to the subset that is actually used. 
- return {k for k in processed_data if k in root_nodes} diff --git a/src/ttsim/interface_dag_elements/raw_results.py b/src/ttsim/interface_dag_elements/raw_results.py index 908251b92..f81fcdb30 100644 --- a/src/ttsim/interface_dag_elements/raw_results.py +++ b/src/ttsim/interface_dag_elements/raw_results.py @@ -17,32 +17,32 @@ @interface_function() def columns( - names__root_nodes: UnorderedQNames, + labels__root_nodes: UnorderedQNames, processed_data: QNameData, specialized_environment__tax_transfer_function: Callable[[QNameData], QNameData], ) -> QNameData: return specialized_environment__tax_transfer_function( - {k: v for k, v in processed_data.items() if k in names__root_nodes}, + {k: v for k, v in processed_data.items() if k in labels__root_nodes}, ) @interface_function() def params( - names__param_targets: OrderedQNames, + labels__param_targets: OrderedQNames, specialized_environment__with_processed_params_and_scalars: QNameCombinedEnvironment1, # noqa: E501 ) -> QNameData: return { pt: specialized_environment__with_processed_params_and_scalars[pt] - for pt in names__param_targets + for pt in labels__param_targets } @interface_function() def from_input_data( - names__input_data_targets: OrderedQNames, + labels__input_data_targets: OrderedQNames, processed_data: QNameData, ) -> QNameData: - return {ot: processed_data[ot] for ot in names__input_data_targets} + return {ot: processed_data[ot] for ot in labels__input_data_targets} @interface_function() diff --git a/src/ttsim/interface_dag_elements/specialized_environment.py b/src/ttsim/interface_dag_elements/specialized_environment.py index e5229da3b..c5a816d99 100644 --- a/src/ttsim/interface_dag_elements/specialized_environment.py +++ b/src/ttsim/interface_dag_elements/specialized_environment.py @@ -60,10 +60,10 @@ def rounding() -> bool: def with_derived_functions_and_processed_input_nodes( policy_environment: NestedPolicyEnvironment, processed_data: QNameData, - names__processed_data_columns: QNameDataColumns, 
+ labels__processed_data_columns: QNameDataColumns, targets__tree: NestedStrings, - names__top_level_namespace: UnorderedQNames, - names__grouping_levels: OrderedQNames, + labels__top_level_namespace: UnorderedQNames, + labels__grouping_levels: OrderedQNames, backend: str, xnp: ModuleType, ) -> QNameCombinedEnvironment0: @@ -86,13 +86,13 @@ def with_derived_functions_and_processed_input_nodes( } flat_without_tree_logic = _remove_tree_logic_from_policy_environment( policy_environment=flat_vectorized, - names__top_level_namespace=names__top_level_namespace, + labels__top_level_namespace=labels__top_level_namespace, ) flat_with_derived = _add_derived_functions( qual_name_policy_environment=flat_without_tree_logic, targets=dt.qual_names(targets__tree), - names__processed_data_columns=names__processed_data_columns, - grouping_levels=names__grouping_levels, + labels__processed_data_columns=labels__processed_data_columns, + grouping_levels=labels__grouping_levels, ) out = {} for n, f in flat_with_derived.items(): @@ -112,7 +112,7 @@ def with_derived_functions_and_processed_input_nodes( def _remove_tree_logic_from_policy_environment( policy_environment: QNamePolicyEnvironment, - names__top_level_namespace: UnorderedQNames, + labels__top_level_namespace: UnorderedQNames, ) -> QNamePolicyEnvironment: """Map qualified names to column objects / param functions without tree logic.""" out = {} @@ -120,7 +120,7 @@ def _remove_tree_logic_from_policy_environment( if hasattr(obj, "remove_tree_logic"): out[name] = obj.remove_tree_logic( tree_path=dt.tree_path_from_qual_name(name), - top_level_namespace=names__top_level_namespace, + top_level_namespace=labels__top_level_namespace, ) else: out[name] = obj @@ -130,7 +130,7 @@ def _remove_tree_logic_from_policy_environment( def _add_derived_functions( qual_name_policy_environment: QNamePolicyEnvironment, targets: OrderedQNames, - names__processed_data_columns: QNameDataColumns, + labels__processed_data_columns: QNameDataColumns, 
grouping_levels: OrderedQNames, ) -> UnorderedQNames: """Return a mapping of qualified names to functions operating on columns. @@ -153,7 +153,7 @@ def _add_derived_functions( The list of targets with qualified names. data Dict with qualified data names as keys and arrays as values. - names__top_level_namespace + labels__top_level_namespace Set of top-level namespaces. Returns @@ -164,7 +164,7 @@ def _add_derived_functions( # Create functions for different time units time_conversion_functions = create_time_conversion_functions( qual_name_policy_environment=qual_name_policy_environment, - processed_data_columns=names__processed_data_columns, + processed_data_columns=labels__processed_data_columns, grouping_levels=grouping_levels, ) column_functions = { @@ -179,7 +179,7 @@ def _add_derived_functions( # Create aggregation functions by group. aggregate_by_group_functions = create_agg_by_group_functions( column_functions=column_functions, - names__processed_data_columns=names__processed_data_columns, + labels__processed_data_columns=labels__processed_data_columns, targets=targets, grouping_levels=grouping_levels, ) @@ -298,12 +298,12 @@ def _apply_rounding(element: ColumnFunction, xnp: ModuleType) -> ColumnFunction: @interface_function() def tax_transfer_dag( with_partialled_params_and_scalars: QNameCombinedEnvironment2, - names__column_targets: OrderedQNames, + labels__column_targets: OrderedQNames, ) -> nx.DiGraph: """Thin wrapper around `create_dag`.""" return create_dag( functions=with_partialled_params_and_scalars, - targets=names__column_targets, + targets=labels__column_targets, ) @@ -311,14 +311,14 @@ def tax_transfer_dag( def tax_transfer_function( tax_transfer_dag: nx.DiGraph, with_partialled_params_and_scalars: QNameCombinedEnvironment2, - names__column_targets: OrderedQNames, + labels__column_targets: OrderedQNames, backend: Literal["numpy", "jax"], ) -> Callable[[QNameData], QNameData]: """Returns a function that takes a dictionary of arrays and unpacks 
them as keyword arguments.""" ttf_with_keyword_args = concatenate_functions( dag=tax_transfer_dag, functions=with_partialled_params_and_scalars, - targets=list(names__column_targets), + targets=list(labels__column_targets), return_type="dict", aggregator=None, enforce_signature=True, diff --git a/src/ttsim/interface_dag_elements/warn_if.py b/src/ttsim/interface_dag_elements/warn_if.py index 07b1b7191..3b8c79f7f 100644 --- a/src/ttsim/interface_dag_elements/warn_if.py +++ b/src/ttsim/interface_dag_elements/warn_if.py @@ -67,13 +67,13 @@ def __init__(self, columns_overriding_functions: OrderedQNames) -> None: @interface_function() def functions_and_data_columns_overlap( policy_environment: NestedPolicyEnvironment, - names__processed_data_columns: QNameDataColumns, + labels__processed_data_columns: QNameDataColumns, ) -> None: """Warn if functions are overridden by data.""" overridden_elements = sorted( { col - for col in names__processed_data_columns + for col in labels__processed_data_columns if col in dt.flatten_to_qual_names(policy_environment) }, ) diff --git a/tests/ttsim/test_automatically_added_functions.py b/tests/ttsim/test_automatically_added_functions.py index dedb29d36..17e64a6fc 100644 --- a/tests/ttsim/test_automatically_added_functions.py +++ b/tests/ttsim/test_automatically_added_functions.py @@ -364,7 +364,7 @@ def x(test_m: int) -> int: ( "column_functions", "targets", - "names__processed_data_columns", + "labels__processed_data_columns", "expected", ), [ @@ -391,7 +391,7 @@ def x(test_m: int) -> int: def test_derived_aggregation_functions_are_in_correct_namespace( column_functions, targets, - names__processed_data_columns, + labels__processed_data_columns, expected, ): """Test that the derived aggregation functions are in the correct namespace. 
@@ -401,7 +401,7 @@ def test_derived_aggregation_functions_are_in_correct_namespace( """ result = create_agg_by_group_functions( column_functions=column_functions, - names__processed_data_columns=names__processed_data_columns, + labels__processed_data_columns=labels__processed_data_columns, targets=targets, grouping_levels=("kin",), ) diff --git a/tests/ttsim/test_failures.py b/tests/ttsim/test_failures.py index 66d3ce519..ca3a2502d 100644 --- a/tests/ttsim/test_failures.py +++ b/tests/ttsim/test_failures.py @@ -565,7 +565,7 @@ def test_fail_if_foreign_keys_are_invalid_in_data_allow_minus_one_as_foreign_key } foreign_keys_are_invalid_in_data( - names__root_nodes={n for n in data if n != "p_id"}, + labels__root_nodes={n for n in data if n != "p_id"}, processed_data=data, specialized_environment__with_derived_functions_and_processed_input_nodes=flat_objects_tree, ) @@ -582,7 +582,7 @@ def test_fail_if_foreign_keys_are_invalid_in_data_when_foreign_key_points_to_non with pytest.raises(ValueError, match=r"not a valid p_id in the\sinput data"): foreign_keys_are_invalid_in_data( - names__root_nodes={n for n in data if n != "p_id"}, + labels__root_nodes={n for n in data if n != "p_id"}, processed_data=data, specialized_environment__with_derived_functions_and_processed_input_nodes=flat_objects_tree, ) @@ -598,7 +598,7 @@ def test_fail_if_foreign_keys_are_invalid_in_data_when_foreign_key_points_to_sam } foreign_keys_are_invalid_in_data( - names__root_nodes={n for n in data if n != "p_id"}, + labels__root_nodes={n for n in data if n != "p_id"}, processed_data=data, specialized_environment__with_derived_functions_and_processed_input_nodes=flat_objects_tree, ) @@ -614,7 +614,7 @@ def test_fail_if_foreign_keys_are_invalid_in_data_when_foreign_key_points_to_sam } foreign_keys_are_invalid_in_data( - names__root_nodes={n for n in data if n != "p_id"}, + labels__root_nodes={n for n in data if n != "p_id"}, processed_data=data, 
specialized_environment__with_derived_functions_and_processed_input_nodes=flat_objects_tree, ) @@ -639,8 +639,8 @@ def test_fail_if_group_variables_are_not_constant_within_groups(): match="The following data inputs do not have a unique value within", ): group_variables_are_not_constant_within_groups( - names__grouping_levels=("kin",), - names__root_nodes={n for n in data if n != "p_id"}, + labels__grouping_levels=("kin",), + labels__root_nodes={n for n in data if n != "p_id"}, processed_data=data, ) @@ -922,7 +922,7 @@ def c(b): @pytest.mark.parametrize( - "policy_environment, targets, names__processed_data_columns, expected_error_match", + "policy_environment, targets, labels__processed_data_columns, expected_error_match", [ ({"foo": some_x}, {"bar": None}, set(), "('bar',)"), ({"foo__baz": some_x}, {"foo__bar": None}, set(), "('foo', 'bar')"), @@ -933,7 +933,7 @@ def c(b): def test_fail_if_targets_are_not_in_policy_environment_or_data( policy_environment, targets, - names__processed_data_columns, + labels__processed_data_columns, expected_error_match, ): with pytest.raises( @@ -943,7 +943,7 @@ def test_fail_if_targets_are_not_in_policy_environment_or_data( targets_are_not_in_policy_environment_or_data( policy_environment=policy_environment, targets__qname=targets, - names__processed_data_columns=names__processed_data_columns, + labels__processed_data_columns=labels__processed_data_columns, ) assert expected_error_match in str(e.value) diff --git a/tests/ttsim/test_names.py b/tests/ttsim/test_names.py index 2f1a17ec7..4dd629929 100644 --- a/tests/ttsim/test_names.py +++ b/tests/ttsim/test_names.py @@ -2,7 +2,7 @@ import pytest -from ttsim.interface_dag_elements.names import grouping_levels, top_level_namespace +from ttsim.interface_dag_elements.labels import grouping_levels, top_level_namespace from ttsim.tt_dag_elements import policy_function, policy_input diff --git a/tests/ttsim/test_policy_environment.py b/tests/ttsim/test_policy_environment.py index 
23c186fe6..c60aaa4b2 100644 --- a/tests/ttsim/test_policy_environment.py +++ b/tests/ttsim/test_policy_environment.py @@ -79,8 +79,8 @@ def test_input_is_recognized_as_potential_group_id(): "orig_policy_objects__root": METTSIM_ROOT, "date": datetime.date(2020, 1, 1), }, - targets=["names__grouping_levels"], - )["names__grouping_levels"] + targets=["labels__grouping_levels"], + )["labels__grouping_levels"] ) @@ -92,8 +92,8 @@ def test_p_id_not_recognized_as_potential_group_id(): "orig_policy_objects__root": METTSIM_ROOT, "date": datetime.date(2020, 1, 1), }, - targets=["names__grouping_levels"], - )["names__grouping_levels"] + targets=["labels__grouping_levels"], + )["labels__grouping_levels"] ) From 9868e9fa0c3d3bc47ebf27e3cd2cb1eba35c7f8e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 14 Jun 2025 06:58:37 +0200 Subject: [PATCH 24/25] Perils of git commit -a ... --- src/ttsim/interface_dag_elements/labels.py | 175 +++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 src/ttsim/interface_dag_elements/labels.py diff --git a/src/ttsim/interface_dag_elements/labels.py b/src/ttsim/interface_dag_elements/labels.py new file mode 100644 index 000000000..f3b4e49f6 --- /dev/null +++ b/src/ttsim/interface_dag_elements/labels.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import networkx as nx + +from ttsim.interface_dag_elements.automatically_added_functions import ( + TIME_UNIT_LABELS, +) +from ttsim.interface_dag_elements.interface_node_objects import interface_function +from ttsim.interface_dag_elements.shared import ( + get_base_name_and_grouping_suffix, + get_re_pattern_for_all_time_units_and_groupings, + group_pattern, +) + +if TYPE_CHECKING: + from ttsim.interface_dag_elements.typing import ( + NestedPolicyEnvironment, + OrderedQNames, + QNameData, + QNamePolicyEnvironment, + UnorderedQNames, + ) + + +def fail_if_multiple_time_units_for_same_base_name_and_group( + 
base_names_and_groups_to_variations: dict[tuple[str, str], list[str]], +) -> None: + invalid = { + b: q for b, q in base_names_and_groups_to_variations.items() if len(q) > 1 + } + if invalid: + raise ValueError(f"Multiple time units for base names: {invalid}") + + +@interface_function() +def column_targets( + specialized_environment__with_partialled_params_and_scalars: UnorderedQNames, + targets__qname: OrderedQNames, +) -> OrderedQNames: + """All targets that are column functions.""" + return [ + t + for t in targets__qname + if t in specialized_environment__with_partialled_params_and_scalars + ] + + +@interface_function() +def param_targets( + specialized_environment__with_derived_functions_and_processed_input_nodes: QNamePolicyEnvironment, # noqa: E501 + targets__qname: OrderedQNames, + column_targets: OrderedQNames, +) -> OrderedQNames: + possible_targets = set(targets__qname) - set(column_targets) + return [ + t + for t in targets__qname + if t in possible_targets + and t + in specialized_environment__with_derived_functions_and_processed_input_nodes + ] + + +@interface_function() +def input_data_targets( + targets__qname: OrderedQNames, + column_targets: OrderedQNames, + param_targets: OrderedQNames, +) -> OrderedQNames: + possible_targets = set(targets__qname) - set(column_targets) - set(param_targets) + return [t for t in targets__qname if t in possible_targets] + + +@interface_function() +def grouping_levels( + policy_environment: QNamePolicyEnvironment, +) -> OrderedQNames: + """The grouping levels of the policy environment.""" + return tuple( + name.rsplit("_", 1)[0] + for name in policy_environment + if name.endswith("_id") and name != "p_id" + ) + + +@interface_function() +def top_level_namespace( + policy_environment: NestedPolicyEnvironment, + grouping_levels: OrderedQNames, +) -> UnorderedQNames: + """Get the top level namespace. + + Parameters + ---------- + policy_environment: + The policy environment. 
+ + + Returns + ------- + top_level_namespace: + The top level namespace. + """ + time_units = tuple(TIME_UNIT_LABELS) + direct_top_level_names = set(policy_environment) + + # Do not create variations for lower-level namespaces. + top_level_objects_for_variations = direct_top_level_names - { + k for k, v in policy_environment.items() if isinstance(v, dict) + } + + pattern_all = get_re_pattern_for_all_time_units_and_groupings( + time_units=time_units, + grouping_levels=grouping_levels, + ) + bngs_to_variations = {} + all_top_level_names = direct_top_level_names.copy() + for name in top_level_objects_for_variations: + match = pattern_all.fullmatch(name) + # We must not find multiple time units for the same base name and group. + bngs = get_base_name_and_grouping_suffix(match) + if match.group("time_unit"): + if bngs not in bngs_to_variations: + bngs_to_variations[bngs] = [name] + else: + bngs_to_variations[bngs].append(name) + for time_unit in time_units: + all_top_level_names.add(f"{bngs[0]}_{time_unit}{bngs[1]}") + fail_if_multiple_time_units_for_same_base_name_and_group(bngs_to_variations) + + gp = group_pattern(grouping_levels) + potential_base_names = {n for n in all_top_level_names if not gp.match(n)} + + for name in potential_base_names: + for g in grouping_levels: + all_top_level_names.add(f"{name}_{g}") + + return all_top_level_names + + +@interface_function() +def processed_data_columns(processed_data: QNameData) -> UnorderedQNames: + return set(processed_data.keys()) + + +@interface_function() +def root_nodes( + specialized_environment__tax_transfer_dag: nx.DiGraph, + processed_data: QNameData, +) -> UnorderedQNames: + """Names of the columns in `processed_data` required for the tax transfer function. + + Parameters + ---------- + specialized_environment__tax_transfer_dag: + The tax transfer DAG. + processed_data: + The processed data. + + Returns + ------- + The names of the columns in `processed_data` required for the tax transfer function. 
+ + """ + # Obtain root nodes + root_nodes = nx.subgraph_view( + specialized_environment__tax_transfer_dag, + filter_node=lambda n: specialized_environment__tax_transfer_dag.in_degree(n) + == 0, + ).nodes + + # Restrict the passed data to the subset that is actually used. + return {k for k in processed_data if k in root_nodes} From 7719e0f5439f8f2c9bf709990f090a57819d3c5e Mon Sep 17 00:00:00 2001 From: Hans-Martin von Gaudecker Date: Sat, 14 Jun 2025 07:16:04 +0200 Subject: [PATCH 25/25] Docstrings, as suggested by @mj023. --- src/ttsim/tt_dag_elements/param_objects.py | 10 +++++++++- src/ttsim/tt_dag_elements/piecewise_polynomial.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/ttsim/tt_dag_elements/param_objects.py b/src/ttsim/tt_dag_elements/param_objects.py index 7e204d831..4f631335e 100644 --- a/src/ttsim/tt_dag_elements/param_objects.py +++ b/src/ttsim/tt_dag_elements/param_objects.py @@ -146,7 +146,15 @@ def __post_init__(self) -> None: @dataclass(frozen=True) class PiecewisePolynomialParamValue: - """The parameters expected by piecewise_polynomial""" + """The parameters expected by `piecewise_polynomial`. + + thresholds: + Thresholds defining the pieces / different segments on the real line. + intercepts: + Intercepts of the polynomial on each segment. + rates: + Slope and higher-order coefficients of the polynomial on each segment. + """ thresholds: Float[Array, " n_segments"] intercepts: Float[Array, " n_segments"] diff --git a/src/ttsim/tt_dag_elements/piecewise_polynomial.py b/src/ttsim/tt_dag_elements/piecewise_polynomial.py index 3f315631c..335fa2eda 100644 --- a/src/ttsim/tt_dag_elements/piecewise_polynomial.py +++ b/src/ttsim/tt_dag_elements/piecewise_polynomial.py @@ -65,7 +65,7 @@ def piecewise_polynomial( x: Array with values at which the piecewise polynomial is to be calculated. parameters: - The parameters of the piecewise polynomial. + Thresholds defining the pieces and coefficients on each piece. 
xnp: The backend module to use for calculations. rates_multiplier: