-
Notifications
You must be signed in to change notification settings - Fork 31
Vectorize METTSIM #879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Vectorize METTSIM #879
Conversation
…n.to_source. In limited set of experiments, it produced exactly the same result.
…g tests. Fix typing and some isinstance checks.
…d_transfers.py'. Fails.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## collect-components-of-namespaces #879 +/- ##
====================================================================
+ Coverage 83.17% 83.57% +0.39%
====================================================================
Files 148 147 -1
Lines 5749 5704 -45
====================================================================
- Hits 4782 4767 -15
+ Misses 967 937 -30 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
The problem seems to be that JAX's segment functions cannot handle arbitrary inputs when being jitted. The problem here is that the resulting dimension depends on the IDs, and the array dimensions cannot be dynamic in JAX when using jit. There are a few remedies:
I cannot judge which one is preferred, or if this is feasible at all. That depends on whether the ID argument is actually static for a given computation, or whether it changes, which would result in re-compilation. |
|
Thanks!!! We are talking about item 2., right? In that case yes, I suppose we want to make the argument static, which would pretty much be case 3.ii above. That is, as soon as we want to jit a function that comes out of dags, we require that no Thinking about this makes me wonder whether we'll want to treat the |
|
Some of the errors in 2. seem to be caused by policy functions that are vectorized with np.vectorize, i.e. vectorization_strategy is set to 'loop'. The policy functions in the tests in The functions in 3. can be written to be jax compatible, they would need to be completely reworked though. Using all ID's as inputs seems therefore reasonable. Marking all ID's as static will result in recompilation, I think if we calculate num_segments, we would have to do that outside of the jit, but then we would have to mark it as static anyways, no? So I'm not surre if there is a way to do this without recompilation. @hmgaudecker I think the parameters are not changeable without calling |
|
Thanks! I thought I had changed everything in METTSIM to (true) vectorize, but cannot check right now. Yes, returning the jitted function would be the idea. Interface is in flux, anyhow! |
|
I've made a bunch of changes to prepare A lot of things already work, but there are two classes of errors popping up:
|
So this was because we were using the number of the unique values of |
… will do in the office where I had merged things but not pushed.
…-economics/gettsim into vectorize-mettsim
timmens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice! No show stoppers from me, only comment is that IMO we should be a bit more elaborate when jitting the tax_and_transfers_function. I fear that otherwise these assumption and connections might be hidden to future maintainers.
Proposal (rough draft):
...
if jit:
if not IS_JAX_INSTALLED:
raise ImportError(
"JAX is not installed. Please install JAX to use JIT compilation."
)
tax_transfer_function = _jit_tax_and_transfer_function(
tax_transfer_function,
data=data,
)
...
def _jit_tax_and_transfer_function(
tax_transfer_function: Callable,
data: QualNameDataDict,
) -> Callable:
"""Create a JIT-compiled version of the tax transfer function.
This function prepares a tax transfer function for JIT compilation by:
1. Extracting the number of segments for each grouping variable. This is necessary
because JAX requires information about the number of segments for each grouping
variable, in order to be able to compile the function.
2. Partialing this information into the function. This also marks these arguments
as static arguments for JAX.
3. Applying JAX JIT compilation.
Parameters
----------
tax_transfer_function : callable
The tax transfer function to be JIT-compiled
data : dict
The flattened user-provided data.
Returns
-------
callable
The JIT-compiled tax transfer function
"""
import jax
segment_counts = _get_segment_counts_for_grouping_variables(
function=tax_transfer_function,
data=data,
)
tax_function_with_static_args = functools.partial(
tax_transfer_function,
**segment_counts,
)
return jax.jit(tax_function_with_static_args)
def _get_segment_counts_for_grouping_variables(
function: Callable,
data: QualNameDataDict,
) -> dict[str, int]:
segment_counts = {}
for argname in inspect.signature(function).parameters:
# In `create_agg_by_group_functions`, we re-define the argument `num_segments`
# of the aggregation functions as `{group_level}_num_segments`, depending on the
# grouping level.
if argname.endswith("_num_segments"):
grouping_level = argname.removesuffix("_num_segments")
num_segments = _calculate_num_segments(data[grouping_level])
segment_counts[argname] = num_segments
return segment_counts
def _calculate_num_segments(group_level_ids: QualNameDataDict) -> int:
"""Calculate the number of segments for a grouping variable.
This assumes that the grouping level IDs start at 0 and are contiguous.
"""
return int(group_level_ids.max()) + 1
Completely agree! Will push that off for now though, will become part of the interface redesign. I opened #890 with your suggestion and some more thoughts. |
MImmesberger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I won't act like I understand every line of this 😅 (but the draft by Tim really helps!)
I think having _num_segments in the input data is quite confusing but if I understood it correctly this will change anyhow.
Just so I can follow: next steps (in a different PR) would be to make vectorization_strategy="vectorize" work with every TTIM function?
| "exception_match": "The dtype of id columns must be integer.", | ||
| }, | ||
| } | ||
| # We cannot even set up these fixtures in JAX. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we remove this then? Cannot come up with an example why we'd ever need such operations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? First person in the household to retire, ... I'd leave it there in order to delimit the boundaries of what can be done.
"with most GETTSIM functions", yes! |
### What problem do you want to solve? 1. A brief shot at trying to do the same for GETTSIM what #879 did for METTSIM. 2. Change the default of `vectorization_strategy` to `vectorize`, so that looking for `"loop"` allows to find all non-vectorisable cases quickly. - Good news: `test_full_taxes_and_transfers` runs nicely when jitted! - Bad news: Many tests fail because the logic in dividing up the taxes and transfers function is not elaborate enough. In the first pass, we are trying to build all ids. However, in many cases we are missing the required input data (made-up example: Einkommensteuer tests may require calculation of `sn_id`, but won't have all inputs required for `bg_id`). Solving this should be doable (first set up the entire graph, then check which ids are needed), but the required functions are buried inside of `compute_taxes_and_transfers` so we should not waste time on that before implementing the new interface.
Similar to
test_jax_jit_kindergeld.py, we want to be able to jit METTSIM to a large extent. This PR attempts to do so.test_compute_taxes_and_transfersandtest_combine_functions, which only happen whenjit=Trueincompute_taxes_and_transfers(see 1929585 and 7daffb7).sp_id(and likelyfam_id) creation and decide on a strategy to deal with them_gettsim/ids.pyare just as easily converted)group_creation_functions with Jax' jitting turned on and use Numpy explicitly.@timmens @mj023 -- could you have a brief look into items 2. and 3.? I'd hope that 2. is not too difficult but I don't quite see what is going on. My hunch is that on 3., it will probably be option ii., but I want to be sure here.