Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
66e1d71
Got started setting xnp as the backend at runtime.
hmgaudecker Jun 10, 2025
1d914e5
Merge branch 'hierarchical-interface' into set-backend
hmgaudecker Jun 11, 2025
279fc7a
Defer vectorization until the creation of the specialised environment…
hmgaudecker Jun 11, 2025
2527f7a
Improve fixtures.
hmgaudecker Jun 11, 2025
cc98a2d
Convert test_piecewise_polynomial.
hmgaudecker Jun 11, 2025
4f52606
Change order of vectorization and removing tree logic (vectorization …
hmgaudecker Jun 11, 2025
d0bb42e
Change order of vectorization and removing tree logic (vectorization …
hmgaudecker Jun 11, 2025
601c4cf
More updates, getting there.
hmgaudecker Jun 11, 2025
21821a8
Fix tests in METTSIM, adjust calls to piecewise polynomial everywhere…
hmgaudecker Jun 11, 2025
1f18025
Fix usage of numpy/xnp in Lohnsteuer module.
hmgaudecker Jun 12, 2025
d671cf7
Small cleanups and docstrings.
hmgaudecker Jun 12, 2025
1d8afad
Adjust GETTSIM s.t. tests pass.
hmgaudecker Jun 12, 2025
471a151
Almost done removing IS_JAX_INSTALLED, missing the aggregations.
hmgaudecker Jun 12, 2025
86c8d2b
Remove IS_JAX_INSTALLED. For some reason, tests are not working with …
hmgaudecker Jun 12, 2025
c233b4e
Add backend to cached policy env
mj023 Jun 12, 2025
8afc429
Merge branch 'hierarchical-interface' into set-backend
hmgaudecker Jun 13, 2025
d0a147b
Remove lists from possible outputs; make jax tests pass.
hmgaudecker Jun 13, 2025
76dd10e
Add test case for TypeError because of list parameter.
MImmesberger Jun 13, 2025
cebf3f1
Move to jaxtyping, except for aggregation.
hmgaudecker Jun 13, 2025
1e8f558
Rename, we have many more groups than 'hh' now...
hmgaudecker Jun 13, 2025
e3d933d
Got rid of remaining numpy.ndarray type hints.
hmgaudecker Jun 13, 2025
905f28b
Merge branch 'set-backend' of github.com:iza-institute-of-labor-econo…
hmgaudecker Jun 13, 2025
f13ccd5
Cut down on ruff exceptions.
hmgaudecker Jun 13, 2025
141cf48
Further cut down on ruff exceptions.
hmgaudecker Jun 13, 2025
8f9ce34
Apply naming suggestions made by @MImmesberger in #951.
hmgaudecker Jun 13, 2025
c721a37
Rename 'names' -> 'labels'. Sort-of another naming suggestions made b…
hmgaudecker Jun 13, 2025
9868e9f
Perils of git commit -a ...
hmgaudecker Jun 14, 2025
7719e0f
Docstrings, as suggested by @mj023.
hmgaudecker Jun 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,19 @@ repos:
- '88'
files: (docs/.|CHANGES.md|CODE_OF_CONDUCT.md)
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
rev: v1.16.0
hooks:
- id: mypy
args:
- --ignore-missing-imports
- --config=pyproject.toml
- --allow-redefinition-new
- --local-partial-types
additional_dependencies:
- types-PyYAML
- types-pytz
- numpy >= 2
- jaxtyping
# - dags >= 0.3
- optree >= 0.15
- repo: https://github.com/python-jsonschema/check-jsonschema
Expand Down
44 changes: 44 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from ttsim.interface_dag_elements.backend import dnp as ttsim_dnp
from ttsim.interface_dag_elements.backend import xnp as ttsim_xnp

if TYPE_CHECKING:
from types import ModuleType
from typing import Literal


# content of conftest.py
def pytest_addoption(parser):
parser.addoption(
"--backend",
action="store",
default="numpy",
help="The backend to test against (e.g., --backend=numpy --backend=jax)",
)


@pytest.fixture
def backend(request) -> Literal["numpy", "jax"]:
return request.config.getoption("--backend")


@pytest.fixture
def xnp(request) -> ModuleType:
return ttsim_xnp(request.config.getoption("--backend"))


@pytest.fixture
def dnp(request) -> ModuleType:
return ttsim_dnp(request.config.getoption("--backend"))


@pytest.fixture(autouse=True)
def skipif_jax(request, backend):
"""Automatically skip tests marked with skipif_jax when backend is jax."""
if request.node.get_closest_marker("skipif_jax") and backend == "jax":
pytest.skip("Cannot run this test with Jax")
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# -- Project information -----------------------------------------------------

project = "GETTSIM"
copyright = f"2019-{datetime.today().year}, GETTSIM team" # noqa: A001
copyright = f"2019-{datetime.today().year}, GETTSIM team" # noqa: A001, DTZ002
author = "GETTSIM team"
release = "0.7.0"
version = ".".join(release.split(".")[:2])
Expand Down Expand Up @@ -122,7 +122,7 @@
"**": [
"relations.html", # needs 'show_related': True theme option to display
"searchbox.html",
]
],
}

# Napoleon settings
Expand Down
133 changes: 115 additions & 18 deletions pixi.lock

Large diffs are not rendered by default.

132 changes: 39 additions & 93 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ snakeviz = ">=2.2.2,<3"
[tool.pixi.pypi-dependencies]
gettsim = {path = ".", editable = true}
dags = { git = "https://github.com/OpenSourceEconomics/dags.git", branch = "allow-passing-dag-to-concatenate_functions"}
jaxtyping = "*"
pdbp = "*"

[tool.pixi.target.unix.pypi-dependencies]
Expand All @@ -150,12 +151,15 @@ python = "3.12.*"
jax = ">=0.4.20"
jaxlib = ">=0.4.20"

[tool.pixi.feature.jax.pypi-dependencies]
jax-datetime = { git = "https://github.com/google/jax-datetime.git" }

[tool.pixi.feature.jax.target.win-64.pypi-dependencies]
jax = { version = ">=0.4.20", extras = ["cpu"] }
jaxlib = ">=0.4.20"

[tool.pixi.feature.mypy.pypi-dependencies]
mypy = "==1.15.0"
mypy = "~=1.16"
types-PyYAML = "*"
types-pytz = "*"

Expand All @@ -165,6 +169,9 @@ types-pytz = "*"
[tool.pixi.feature.test.tasks]
tests = "pytest"

[tool.pixi.feature.jax.tasks]
tests-jax = "pytest --backend=jax"

[tool.pixi.feature.mypy.tasks]
mypy = "mypy --ignore-missing-imports"

Expand All @@ -190,78 +197,48 @@ unsafe-fixes = false
[tool.ruff.lint]
select = ["ALL"]
extend-ignore = [
"ICN001", # numpy should be np, but different convention here.
# Docstrings
"D103", # missing docstring in public function
"D107",
"D203",
"D212",
"D213",
"D402",
"D413",
"D415",
"D416",
"D417",
# Others.
"D404", # Do not start module docstring with "This".
"RET504", # unnecessary variable assignment before return.
"S101", # raise errors for asserts.
"B905", # strict parameter for zip that was implemented in py310.

"FBT", # flake8-boolean-trap
"EM", # flake8-errmsg
"ANN401", # flake8-annotate typing.Any
"PD", # pandas-vet
"E731", # do not assign a lambda expression, use a def
"RET", # unnecessary elif or else statements after return, raise, continue, ...
"S324", # Probable use of insecure hash function.
"COM812", # trailing comma missing, but black takes care of that
"PT007", # wrong type in parametrize, gave false positives
"DTZ001", # use of `datetime.datetime()` without `tzinfo` argument is not allowed
"DTZ002", # use of `datetime.datetime.today()` is not allowed
"PT012", # `pytest.raises()` block should contain a single simple statement
"PLR5501", # elif not supported by Jax converter
"TRY003", # Avoid specifying long messages outside the exception class
"COM812", # Avoid conflicts with ruff-format
"EM101", # Exception must not use a string literal
"EM102", # Exception must not use an f-string literal
"F722", # https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
"FBT001", # Boolean-typed positional argument in function definition
"FIX002", # Line contains TODO -- Use stuff from TD area.
"ICN001", # numpy should be np, but different convention here.
"ISC001", # Avoid conflicts with ruff-format
"N999", # Allow non-ASCII characters in file names.
"PLC2401", # Allow non-ASCII characters in variable names.
"PLC2403", # Allow non-ASCII function names for imports.
"PLR0913", # Allow too many arguments in function definitions.
"N999", # Allow non-ASCII characters in file names.
"PLR0913", # Too many arguments in function definition.

# Things we are not sure we want
# ==============================
"SIM102", # Use single if statement instead of nested if statements
"SIM108", # Use ternary operator instead of if-else block
"SIM117", # do not use nested with statements
"BLE001", # Do not catch blind exceptions (even after handling some specific ones)
"PLR2004", # Magic values used in comparison
"PT006", # Allows only lists of tuples in parametrize, even if single argument
"PLR5501", # elif not supported by vectorization converter for Jax
"TRY003", # Avoid specifying long messages outside the exception class

# Things ignored during transition phase
# Ignored during transition phase
# ======================================
"D", # docstrings
"ANN", # missing annotations
"C901", # function too complex
"PT011", # pytest raises without match statement
"INP001", # implicit namespace packages without init.
"E721", # Use `is` and `is not` for type comparisons
"TD003", # Missing issue link -- remove again once we got rid of ad-hoc TODOs.
"ERA001", # Commented out code.
"PLR2004", # Magic values used in comparison
"PT006", # Allows only lists of tuples in parametrize, even if single argument
"PT007", # wrong type in parametrize
"S101", # use of asserts outside of tests

# Things ignored to avoid conflicts with ruff-format
# ==================================================
"ISC001",
]
exclude = []

[tool.ruff.lint.per-file-ignores]
"conftest.py" = ["ANN"]
"docs/**/*.ipynb" = ["T201"]
"src/_gettsim/*" = ["E501", "PLR1714", "PLR1716"] # Vectorization can't handle x <= y <= z or x in {x,y}
# Mostly things vectorization can't handle
"src/_gettsim/*" = ["E501", "PLR1714", "PLR1716", "E721", "SIM108", "RET"]
# All tests return None and use asserts
"src/_gettsim_tests/**/*.py" = ["ANN", "S101"]
"src/ttsim/interface_dag_elements/specialized_environment.py" = ["E501"]
"src/ttsim/interface_dag_elements/fail_if.py" = ["E501"]
"src/ttsim/interface_dag_elements/typing.py" = ["PGH", "PLR", "SIM114"]
"tests/ttsim/mettsim/*" = ["PLR1714", "PLR1716"] # Vectorization can't handle x <= y <= z or x in {x,y}
# Mostly things vectorization can't handle
"tests/ttsim/mettsim/**/*.py" = ["PLR1714", "PLR1716", "E721", "SIM108", "RET"]
"tests/ttsim/tt_dag_elements/test_vectorization.py" = ["PLR1714", "PLR1716", "E721", "SIM108", "RET"]
# All tests return None and use asserts
"tests/ttsim/**/*.py" = ["ANN", "S101"]
"tests/ttsim/test_failures.py" = ["E501"]
# TODO: remove once ported nicely
"src/ttsim/stale_code_storage.py" = ["ALL"]
Expand Down Expand Up @@ -292,6 +269,7 @@ disallow_empty_bodies = false

[[tool.mypy.overrides]]
module = [
"conftest",
"src.ttsim.plot_dag_old",
"src.ttsim.stale_code_storage",
"src._gettsim_tests.test_docs",
Expand All @@ -300,45 +278,12 @@ disallow_untyped_defs = false
ignore_errors = true

[[tool.mypy.overrides]]
module = [
"tests.*",
]
disable_error_code = [
"no-untyped-def", # All tests return None, don't clutter source code.
]

[[tool.mypy.overrides]]
module = [
"tests.ttsim.test_failures",
]
disable_error_code = [
"misc", # Happens when constructing param dictionaries on the fly.
]


[[tool.mypy.overrides]]
module = [
"src.ttsim.tt_dag_elements.aggregation_numpy",
]
disable_error_code = [
"type-arg" # ndarray is not typed further.
]
module = ["tests.*",]
disable_error_code = ["no-untyped-def"] # All tests return None, don't clutter source code.

[[tool.mypy.overrides]]
module = [
"src._gettsim_tests.*",
]
disable_error_code = [
"no-untyped-def", # All tests return None, don't clutter source code.
]

[[tool.mypy.overrides]]
module = [
"tests.ttsim.tt_dag_elements.test_vectorization", # doing some numpy mixing; vectorization does not like inline comments.
]
disable_error_code = [
"assignment"
]
module = ["src._gettsim_tests.*",]
disable_error_code = ["no-untyped-def"] # All tests return None, don't clutter source code.

[tool.check-manifest]
ignore = ["src/_gettsim/_version.py"]
Expand All @@ -364,6 +309,7 @@ markers = [
"unit: Flag for unit tests which target mainly a single function.",
"integration: Flag for integration tests which may comprise of multiple unit tests.",
"end_to_end: Flag for tests that cover the whole program.",
"skipif_jax: skip test if backend is jax"
]
norecursedirs = ["docs"]
testpaths = [
Expand Down
17 changes: 15 additions & 2 deletions src/_gettsim/arbeitslosengeld_2/einkommen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
)

if TYPE_CHECKING:
from types import ModuleType

from ttsim.interface_dag_elements.typing import RawParam


Expand Down Expand Up @@ -147,12 +149,14 @@ def anrechnungsfreies_einkommen_m_basierend_auf_nettoquote(
einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_m: float,
nettoquote: float,
parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg: PiecewisePolynomialParamValue,
xnp: ModuleType,
) -> float:
"""Share of income which remains to the individual."""
return piecewise_polynomial(
x=einkommensteuer__einkünfte__aus_nichtselbstständiger_arbeit__bruttolohn_m,
parameters=parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg,
rates_multiplier=nettoquote,
xnp=xnp,
)


Expand All @@ -164,6 +168,7 @@ def anrechnungsfreies_einkommen_m(
einkommensteuer__anzahl_kinderfreibeträge: int,
parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg: PiecewisePolynomialParamValue,
parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg: PiecewisePolynomialParamValue,
xnp: ModuleType,
) -> float:
"""Calculate share of income, which remains to the individual since 10/2005.

Expand All @@ -184,35 +189,42 @@ def anrechnungsfreies_einkommen_m(
out = piecewise_polynomial(
x=eink_erwerbstätigkeit,
parameters=parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg,
xnp=xnp,
)
else:
out = piecewise_polynomial(
x=eink_erwerbstätigkeit,
parameters=parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg,
xnp=xnp,
)
return out


@param_function(start_date="2005-01-01")
def parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg(
raw_parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg: RawParam,
xnp: ModuleType,
) -> PiecewisePolynomialParamValue:
"""Parameter for calculation of income not subject to transfer withdrawal when
children are not in the Bedarfsgemeinschaft."""
children are not in the Bedarfsgemeinschaft.
"""
return get_piecewise_parameters(
leaf_name="parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg",
func_type="piecewise_linear",
parameter_dict=raw_parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg,
xnp=xnp,
)


@param_function(start_date="2005-10-01")
def parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg(
raw_parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg: RawParam,
raw_parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg: RawParam,
xnp: ModuleType,
) -> PiecewisePolynomialParamValue:
"""Parameter for calculation of income not subject to transfer withdrawal when
children are in the Bedarfsgemeinschaft."""
children are in the Bedarfsgemeinschaft.
"""
updated_parameters: dict[int, dict[str, float]] = upsert_tree(
base=raw_parameter_anrechnungsfreies_einkommen_ohne_kinder_in_bg,
to_upsert=raw_parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg,
Expand All @@ -221,4 +233,5 @@ def parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg(
leaf_name="parameter_anrechnungsfreies_einkommen_mit_kindern_in_bg",
func_type="piecewise_linear",
parameter_dict=updated_parameters,
xnp=xnp,
)
4 changes: 3 additions & 1 deletion src/_gettsim/arbeitslosengeld_2/freibeträge_vermögen.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def vermögensfreibetrag_in_karenzzeit_bg(


@policy_function(
start_date="2005-01-01", end_date="2022-12-31", leaf_name="vermögensfreibetrag_bg"
start_date="2005-01-01",
end_date="2022-12-31",
leaf_name="vermögensfreibetrag_bg",
)
def vermögensfreibetrag_bg_bis_2022(
grundfreibetrag_vermögen_bg: float,
Expand Down
Loading
Loading