Allow setting backend at runtime #960
…. Vectorisation tests pass.
…destroys renaming), make rounding tests pass.
…destroys renaming), make rounding tests pass (numpy only, though).
…. Still many errors in GETTSIM.
…JAX, probably some earlier version did not jit correctly...
Btw, both of you: Forget about the mypy errors, I will tackle them later on but would love to have some feedback!
The reason seems to be that the variable
Thanks! Maybe we'll need to defer this line in the
    func_with_reorder = lambda **kwargs: reorder_ids(
        ids=func(**kwargs), xnp=kwargs["xnp"]
    )
to a later point, similar to what I changed with the vectorisation?
Should be fixed now, it was just a problem with the test execution. The cached policy environment did not receive a backend, so it defaulted to numpy. I hope it was okay for me to just quickly commit this change to your branch.
Excellent, thanks!!! I might have searched forever for that one... No worries at all re committing to the branch.
@MImmesberger, please double-check my change in d0a147. The previous version did not work in jax (iteration there is over 1-element arrays rather than scalars). I also removed the checks re numeric contents in arrays (superfluous) and the possibility to return lists (cannot happen except for the case where a parameter happens to return a list of length n obs, right?).
Yes, that's correct, and the
    """
    return {
    @policy_input()
Now that we have backend and num_segments would it make sense to use a different
class for them? I'm sure some users will look at intermediate versions of the function
tree and these showing up might be confusing.
Good point. On the one hand, I completely agree. On the other hand, I really don't want to bloat the types. Let's postpone until we have some experience, will open an issue momentarily.
    xnp=xnp,
    ),
    )
By the way, we could move "Evaluationsjahr" to the regular TT files, correct?
Yes, still missing the dates handling.
graph TD;
date_str --> date;
evaluation_date_str --> evaluation_date;
date --> evaluation_date;
date --> policy_date;
policy_date_str --> policy_date;
All these would live in the interface DAG; the evaluation date would also go into the taxes & transfers DAG. The str variables would be inputs, the others functions. Raise an error if both date and either of the specific dates are passed.
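The rule in the last sentence could be sketched roughly as follows. This is only an illustration using plain functions rather than DAG nodes; `to_date` and `resolve_dates` are hypothetical names, not part of GETTSIM, and the assumption is that the string inputs are ISO-formatted (YYYY-MM-DD).

```python
import datetime
from typing import Optional, Tuple


def to_date(value: Optional[str]) -> Optional[datetime.date]:
    """Parse an ISO string (YYYY-MM-DD) into a date; pass None through."""
    return None if value is None else datetime.date.fromisoformat(value)


def resolve_dates(
    date_str: Optional[str] = None,
    policy_date_str: Optional[str] = None,
    evaluation_date_str: Optional[str] = None,
) -> Tuple[Optional[datetime.date], Optional[datetime.date]]:
    """Return (policy_date, evaluation_date) following the graph above."""
    specific_given = policy_date_str is not None or evaluation_date_str is not None
    if date_str is not None and specific_given:
        # Hypothetical error message; the point is only that mixing is rejected.
        raise ValueError("Pass either the generic date or the specific dates, not both.")
    date = to_date(date_str)
    # The generic date feeds both specific dates when it is given.
    policy_date = date if date is not None else to_date(policy_date_str)
    evaluation_date = date if date is not None else to_date(evaluation_date_str)
    return policy_date, evaluation_date
```

Passing only `date_str` yields the same date twice; passing the two specific strings yields them separately; mixing the two styles raises.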
    orig_policy_objects__column_objects_and_param_functions: NestedColumnObjectsParamFunctions,  # noqa: E501
    orig_policy_objects__param_specs: FlatOrigParamSpecs,
    date: datetime.date | DashedISOString,
    backend: Literal["numpy", "jax"],
Feel free to ignore this one, just wanted to express my confusion about needing
backend and xnp (I thought the former is just how the user interacts with the
latter). As far as I can tell, this is only because of _make_vectorizable_ast and I
didn't bother trying to understand this. I'll just assume that this is necessary for some
technical reason.
- backend is the string.
- xnp is the actual module.
Technically we could do without either one by doing imports conditional on backend (certainly not) or by conditioning on xnp.__name__ (no idea whether that exists) instead of conditioning on backend.
But users only need to pass backend and the rest is handled automatically, so at least it should not be a user-facing issue.
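For illustration, the relationship between the two could look roughly like this. It is a sketch only: `get_xnp` is a hypothetical helper, not necessarily how GETTSIM resolves the module. It also shows that Python modules do carry a `__name__` attribute.

```python
import importlib
from types import ModuleType
from typing import Literal


def get_xnp(backend: Literal["numpy", "jax"]) -> ModuleType:
    """Map the user-facing backend string to the array module (hypothetical helper)."""
    if backend == "numpy":
        return importlib.import_module("numpy")
    if backend == "jax":
        # Imported lazily, so jax only needs to be installed when it is requested.
        return importlib.import_module("jax.numpy")
    raise ValueError(f"Unknown backend: {backend!r}")


xnp = get_xnp("numpy")
# Every module object has __name__, so conditioning on it would be possible.
is_jax = xnp.__name__ == "jax.numpy"
```

Conditioning on the string rather than on `xnp.__name__` keeps the branch explicit and independent of how the module was imported.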
    }
    # Add backend so we can decide between numpy and jax for aggregation functions
    assert "backend" not in out
    out[("backend",)] = backend
I don't quite get why this is here. Shouldn't be relevant if functions
are only collected, right?
These are only the placeholders, right? And they should be filled up here. But probably we could just add to policy_environment, too.
…mics/gettsim into set-backend
We currently use global settings for numpy/jax based on whether jax is installed. This is unsatisfactory for many reasons.
This PR allows setting backend: Literal["numpy", "jax"] as part of the inputs of the interface DAG. It defaults to "numpy". All operations on columns that require a module-level call to np / jnp take a parameter xnp, which can be either one depending on the backend setting.
- config.numpy_or_jax
- IS_JAX_INSTALLED and adjust testing infrastructure to be based on command-line option everywhere
- jaxtyping instead of numpy.ndarray for type hints.
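As a rough sketch of the pattern described above (the function names here are hypothetical and not taken from the GETTSIM code base): an elementwise column operation only needs `xnp`, while an aggregation may additionally branch on `backend`, because numpy and jax offer different primitives for grouped sums.

```python
import numpy as np


def cap_at(column, upper, xnp):
    """Elementwise cap; works with numpy or jax.numpy passed in as xnp."""
    return xnp.minimum(column, upper)


def grouped_sum(column, group_ids, num_segments, backend, xnp):
    """Sum column within groups given by integer ids in [0, num_segments)."""
    if backend == "jax":
        # Imported lazily so that the numpy path works without jax installed.
        import jax

        return jax.ops.segment_sum(column, group_ids, num_segments=num_segments)
    # numpy path: bincount with weights adds up the values per group id.
    return xnp.bincount(group_ids, weights=column, minlength=num_segments)


income = np.array([1.0, 2.0, 3.0])
household_id = np.array([0, 1, 0])
capped = cap_at(income, 2.5, xnp=np)
per_household = grouped_sum(income, household_id, 2, backend="numpy", xnp=np)
```

Because `xnp` is an ordinary argument, the same function body can be exercised with either backend in tests without relying on a global setting.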