@hmgaudecker
Collaborator

@hmgaudecker hmgaudecker commented Jun 12, 2025

We currently use global settings for numpy/jax based on whether jax is installed. This is unsatisfactory for many reasons.

This PR allows setting backend: Literal["numpy", "jax"] as part of the inputs of the interface DAG. It defaults to "numpy". All operations on columns that require a module-level call to np / jnp take a parameter xnp, which can be either one depending on the backend setting.

  • Remove all imports of config.numpy_or_jax
  • Defer vectorization until the creation of the specialised environment (backend is unknown when decorator is applied)
  • Add command-line option to tests for using jax
  • Make sure ttsim tests pass
  • Make sure GETTSIM tests pass
  • Remove imports of IS_JAX_INSTALLED and adjust testing infrastructure to be based on command-line option everywhere
  • Use jaxtyping instead of numpy.ndarray for type hints.
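
A minimal sketch of the backend / xnp pattern described above. The helper resolve_xnp and the column operation clipped_income are hypothetical names chosen for illustration, not taken from ttsim, and the actual implementation may differ.

```python
import importlib
from types import ModuleType
from typing import Literal

import numpy as np


def resolve_xnp(backend: Literal["numpy", "jax"] = "numpy") -> ModuleType:
    """Map the backend string to the array module (hypothetical helper)."""
    if backend == "numpy":
        return np
    if backend == "jax":
        # Imported lazily so that numpy-only installations never touch jax.
        return importlib.import_module("jax.numpy")
    raise ValueError(f"backend must be 'numpy' or 'jax', got {backend!r}")


def clipped_income(income, floor: float, xnp: ModuleType):
    """Made-up column operation that calls into whichever module it is handed."""
    return xnp.maximum(income, floor)


xnp = resolve_xnp("numpy")
result = clipped_income(xnp.asarray([-50.0, 120.0]), floor=0.0, xnp=xnp)
```

Because the function never refers to np or jnp at module level, the same source can be specialised for either backend once the setting is known.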

@hmgaudecker hmgaudecker changed the base branch from main to hierarchical-interface June 12, 2025 07:46
@hmgaudecker
Collaborator Author

@mj023, could you do me a favour and check the errors on this branch when running

pixi run tests-jax tests

?

I even copied the code from GETTSIM, which I thought should jit fine after #905, but I get complaints... Also asking for a review so you'll know the updated logic. Thanks!

@hmgaudecker
Collaborator Author

Btw, both of you: Forget about the mypy errors, I will tackle them later on but would love to have some feedback!

@mj023
Collaborator

mj023 commented Jun 12, 2025

The reason seems to be that the variable xnp passed to these functions is still plain numpy instead of jax.numpy. Calling a numpy method on a traced JAX array then throws the error. I still need to find out exactly where this changes, because the new xnp function actually returns jax.numpy.

@hmgaudecker
Collaborator Author

The reason seems to be that the variable xnp passed to these functions is still plain numpy instead of jax.numpy. Calling a numpy method on a traced JAX array then throws the error. I still need to find out exactly where this changes, because the new xnp function actually returns jax.numpy.

Thanks!

Maybe we'll need to defer this line in the group_creation_function decorator:

        func_with_reorder = lambda **kwargs: reorder_ids(
            ids=func(**kwargs), xnp=kwargs["xnp"]
        )

to a later point, similar to what I changed with the vectorisation?

@mj023
Collaborator

mj023 commented Jun 12, 2025

Should be fixed now, it was just a problem with the test execution. The cached policy environment did not receive a backend, so it defaulted to numpy. I hope it was okay for me to just quickly commit this change to your branch.
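
One way to guard against this kind of silent fallback, sketched under assumptions: the function name cached_policy_environment and its return value are hypothetical stand-ins, and the fix actually committed to the branch may look different. Making backend a keyword-only argument without a default means a cached helper cannot quietly fall back to numpy.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def cached_policy_environment(date: str, *, backend: str) -> dict:
    """Stand-in for an expensive environment build (hypothetical).

    backend is keyword-only and has no default, so forgetting it raises a
    TypeError instead of silently producing a numpy environment.
    """
    return {"date": date, "backend": backend}


# backend is part of the cache key, so each backend gets its own entry.
env_numpy = cached_policy_environment("2025-01-01", backend="numpy")
env_jax = cached_policy_environment("2025-01-01", backend="jax")
```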

@hmgaudecker
Collaborator Author

Should be fixed now, it was just a problem with the test execution. The cached policy environment did not receive a backend, so it defaulted to numpy. I hope it was okay for me to just quickly commit this change to your branch.

Excellent, thanks!!! I might have searched forever for that one... No worries at all re committing to the branch.

@hmgaudecker
Collaborator Author

@MImmesberger, please double-check my change in d0a147. The previous version did not work in jax (iteration there is over 1-element arrays rather than scalars). I also removed the checks re numeric contents in arrays (superfluous) and the possibility to return lists (cannot happen except for the case where a parameter happens to return a list of length n obs, right?).

@MImmesberger
Collaborator

cannot happen except for the case where a parameter happens to return a list of length n obs, right?

Yes, that's correct, and the TypeError it raises now is more telling than the ValueError.

"""
return {

@policy_input()
Collaborator

Now that we have backend and num_segments would it make sense to use a different
class for them? I'm sure some users will look at intermediate versions of the function
tree and these showing up might be confusing.

Collaborator Author

Good point. On the one hand, I completely agree. On the other hand, I really don't want to bloat the types. Let's postpone until we have some experience, will open an issue momentarily.

xnp=xnp,
),
)

Collaborator

By the way, we could move "Evaluationsjahr" to the regular TT files, correct?

Collaborator Author

@hmgaudecker hmgaudecker Jun 13, 2025

Yes, still missing the dates handling.

graph TD;
    date_str --> date;
    evaluation_date_str --> evaluation_date;
    date --> evaluation_date;
    date --> policy_date;
    policy_date_str --> policy_date;

All of these would live in the interface DAG; the evaluation date would also go into the taxes & transfers DAG. The str variables would be inputs, the others functions. Raise an error if date is passed together with either of the specific dates.
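
A rough sketch of how the date resolution in the graph above could behave. The function names to_date and resolve_dates are hypothetical, and requiring both specific dates when date is absent is an assumption that the comment does not state.

```python
from __future__ import annotations

import datetime


def to_date(value: datetime.date | str) -> datetime.date:
    """Accept a date object or a dashed ISO string such as "2025-06-13"."""
    if isinstance(value, datetime.date):
        return value
    return datetime.date.fromisoformat(value)


def resolve_dates(
    date: datetime.date | str | None = None,
    policy_date: datetime.date | str | None = None,
    evaluation_date: datetime.date | str | None = None,
) -> dict[str, datetime.date]:
    """Derive policy_date and evaluation_date (hypothetical helper)."""
    if date is not None and (policy_date is not None or evaluation_date is not None):
        raise ValueError("Pass either `date` or the specific dates, not both.")
    if date is not None:
        # The generic date feeds both specific dates, as in the graph.
        common = to_date(date)
        return {"policy_date": common, "evaluation_date": common}
    if policy_date is None or evaluation_date is None:
        # Assumption: without `date`, both specific dates must be given.
        raise ValueError("Without `date`, both specific dates are required.")
    return {
        "policy_date": to_date(policy_date),
        "evaluation_date": to_date(evaluation_date),
    }
```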

orig_policy_objects__column_objects_and_param_functions: NestedColumnObjectsParamFunctions, # noqa: E501
orig_policy_objects__param_specs: FlatOrigParamSpecs,
date: datetime.date | DashedISOString,
backend: Literal["numpy", "jax"],
Collaborator

Feel free to ignore this one, just wanted to express my confusion about needing
backend and xnp (I thought the former is just how the user interacts with the
latter). As far as I can tell, this is only because of _make_vectorizable_ast and I
didn't bother trying to understand this. I'll just assume that this is necessary for some
technical reason.

Collaborator Author

  • backend is the string.
  • xnp is the actual module.

Technically, we could get by with only one of them: either by making imports conditional on backend (certainly not) or by conditioning on xnp.__name__ (no idea whether that exists) instead of on backend.

But users only need to pass backend and the rest is handled automatically, so at least it should not be a user-facing issue.
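
On the question of whether xnp.__name__ exists: every Python module object carries a __name__ attribute, so conditioning on it would work in principle. A small sketch, where backend_from_module is a hypothetical helper name and not part of ttsim:

```python
import types

import numpy as np


def backend_from_module(xnp: types.ModuleType) -> str:
    """Recover the backend string from the array module (hypothetical helper)."""
    # numpy reports "numpy"; jax.numpy reports "jax.numpy".
    names = {"numpy": "numpy", "jax.numpy": "jax"}
    try:
        return names[xnp.__name__]
    except KeyError:
        raise ValueError(f"unsupported array module: {xnp.__name__!r}") from None


print(backend_from_module(np))  # numpy
```

Passing the backend string explicitly, as the PR does, avoids string-matching on module names altogether.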

}
# Add backend so we can decide between numpy and jax for aggregation functions
assert "backend" not in out
out[("backend",)] = backend
Collaborator

I don't quite get why this is here. Shouldn't be relevant if functions
are only collected, right?

Collaborator Author

These are only the placeholders, right? And they should be filled in here. But we could probably just add it to policy_environment, too.

@hmgaudecker hmgaudecker merged commit a0dc9cf into hierarchical-interface Jun 16, 2025
9 checks passed
@hmgaudecker hmgaudecker deleted the set-backend branch June 16, 2025 05:58