Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Caution
This PR is only open to showcase the required changes to enable JAX. When #804 is merged, we will start implementing the changes --that might be here or in a fresh PR.
Enable JAX backend
In this PR, we ensure that
gettsimbecomes fully operational with thejaxbackend and is tested accordingly such that future changes in the codebase that are not JAX-compatible trigger test failures.Closes #515
Issues
To make GETTSIM fully JAX-operable, the tax and transfers function, returned by
dags, needs to JAX-jittable. This requires all functions defined in GETTSIM and called internally to be JAX-jittable. In particular, they cannot use functions fromnumpyor array methods that only work onnumpyarrays.ToDo's
_vectorize_func()is actually used and at correct position in codemake_vectorizablecan handle lambda functionsTesting
jax.jitthe compute taxes and transfers functions)jax.jitindividual functions and call them on JAX arrays)Numpy replacement
numpy.function(...)bynumpy_or_jax.function(...)*HMG: Make sure to find these cases by looking for the whole word "numpy", mostly we are importingconfig.numpy_or_jax as npalready.asarrayallclosesearchsortedarrayzerosinffull_likeminmaxjoinjoin_numpyis implemented. Either (1) implementjoin_jaxand call if required, or (2) rewritejoin_numpyusing Array API standard.fg_idcreation sensitive to order of adults #801 in the processDatetime replacement/refactoring
numpy.datetime64is used before runningdags(We want to allow the user to pass in data frames with DateTime-objects. However, these should be transformed internally to day, month, and year ints or floating point time spans, whatever is relevant.)
Piecewise Functions
Either create (1) one version of the code for
numpyand one of JAX, or (2) try to solve this problem using the Array API standard.