@hmgaudecker
Collaborator

@hmgaudecker hmgaudecker commented Apr 22, 2025

Similar to test_jax_jit_kindergeld.py, we want to be able to jit METTSIM to a large extent. This PR attempts to do so.

  1. Put in all the basic infrastructure and fix the most obvious errors.
  2. Fix errors in test_compute_taxes_and_transfers and test_combine_functions, which only happen when jit=True in compute_taxes_and_transfers (see 1929585 and 7daffb7).
  3. Check failures in METTSIM's sp_id (and likely fam_id) creation and decide on a strategy to deal with them:
    i. Fix them (only if similar functions in _gettsim/ids.py are just as easily converted).
    ii. Fail gracefully when there is an attempt to use group_creation_functions with JAX's jitting turned on, and use NumPy explicitly.

@timmens @mj023 -- could you have a brief look into items 2. and 3.? I'd hope that 2. is not too difficult but I don't quite see what is going on. My hunch is that on 3., it will probably be option ii., but I want to be sure here.

@codecov

codecov bot commented Apr 22, 2025

Codecov Report

Attention: Patch coverage is 78.51852% with 29 lines in your changes missing coverage. Please review.

Project coverage is 83.57%. Comparing base (da62dfd) to head (757ba3c).
Report is 1 commits behind head on collect-components-of-namespaces.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| tests/ttsim/utils.py | 60.86% | 9 Missing ⚠️ |
| src/ttsim/compute_taxes_and_transfers.py | 45.45% | 6 Missing ⚠️ |
| src/ttsim/typing.py | 50.00% | 3 Missing ⚠️ |
| tests/ttsim/test_compute_taxes_and_transfers.py | 84.21% | 3 Missing ⚠️ |
| src/ttsim/automatically_added_functions.py | 50.00% | 2 Missing ⚠️ |
| tests/ttsim/test_aggregation_functions.py | 92.00% | 2 Missing ⚠️ |
| src/ttsim/ttsim_objects.py | 83.33% | 1 Missing ⚠️ |
| tests/ttsim/test_combine_functions.py | 75.00% | 1 Missing ⚠️ |
| tests/ttsim/test_mettsim.py | 75.00% | 1 Missing ⚠️ |
| tests/ttsim/test_rounding.py | 88.88% | 1 Missing ⚠️ |
Additional details and impacted files
@@                         Coverage Diff                          @@
##           collect-components-of-namespaces     #879      +/-   ##
====================================================================
+ Coverage                             83.17%   83.57%   +0.39%     
====================================================================
  Files                                   148      147       -1     
  Lines                                  5749     5704      -45     
====================================================================
- Hits                                   4782     4767      -15     
+ Misses                                  967      937      -30     


@timmens
Collaborator

timmens commented Apr 23, 2025

The problem seems to be that JAX's segment functions cannot handle arbitrary inputs when jitted: the dimension of the result depends on the IDs, and array dimensions cannot be dynamic in JAX when using jit.

There are a few remedies:

  1. We could compute num_segments beforehand (if it is known) and add that as an argument to the segment calls
  2. We make the ID argument static

I cannot judge which one is preferred, or if this is feasible at all. That depends on whether the ID argument is actually static for a given computation, or whether it changes, which would result in re-compilation.
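
A minimal sketch of the first remedy, using `jax.ops.segment_sum` with `num_segments` marked as a static argument. The function name `sum_by_group` and the data are made up for illustration; only the JAX calls are real.

```python
from functools import partial

import jax
import jax.numpy as jnp


# num_segments is static: it fixes the output shape at trace time, and a new
# value triggers a recompilation.
@partial(jax.jit, static_argnames="num_segments")
def sum_by_group(values, group_id, num_segments):
    return jax.ops.segment_sum(values, group_id, num_segments=num_segments)


values = jnp.array([10.0, 20.0, 30.0, 40.0])
fam_id = jnp.array([0, 0, 1, 2])  # hypothetical family IDs

# Computed outside of jit, from concrete data.
num_segments = int(fam_id.max()) + 1
totals = sum_by_group(values, fam_id, num_segments=num_segments)
```

As long as `num_segments` stays the same, repeated calls with arrays of the same shape and dtype reuse the compiled function, even if the IDs themselves change.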

@hmgaudecker
Collaborator Author

Thanks!!! We are talking about item 2., right?

In that case yes, I suppose we want to make the argument static, which would pretty much be case 3.ii above. That is, as soon as we want to jit a function that comes out of dags, we require that no group_creation_functions are in there. Which implies that any [x]_id elements must be inputs and, hence, can be marked as static. These inputs could of course be pre-computed in a first step, not using jitting.

Thinking about this makes me wonder whether we'll want to treat the params_[y] arguments as static. For a new policy environment, we'll always want to recompile, anyhow. So maybe we could include all of these (checking for group_by_functions and params, partialling these into functions and marking the arguments as static) in one go?

@mj023
Collaborator

mj023 commented Apr 23, 2025

Some of the errors in 2. seem to be caused by policy functions that are vectorized with np.vectorize, i.e., vectorization_strategy is set to "loop". Either the policy functions in the tests in test_combine_functions need to be updated to explicitly pass vectorization_strategy="vectorize", or we could change the default vectorization_strategy in the policy function decorator depending on whether JAX is installed, because np.vectorize won't work anyway.

The functions in 3. can be written to be JAX-compatible, but they would need to be completely reworked. Using all IDs as inputs therefore seems reasonable. Marking all IDs as static will result in recompilation. If we calculate num_segments, I think we would have to do that outside of the jit, but then we would have to mark it as static anyway, no? So I'm not sure if there is a way to do this without recompilation.

@hmgaudecker I think the parameters cannot be changed without calling compute_taxes_and_transfers, right? Then they are basically static anyway. The way the function is currently written, every call to compute_taxes_and_transfers will recompile everything. We will have to return the jitted function if we want to actually make use of the jitting.
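
To illustrate the difference between the two vectorization styles, here is a small sketch with a made-up benefit rule (the function names and thresholds are hypothetical, not taken from METTSIM). The scalar version branches in Python on a value and therefore only works element by element; the second version expresses the same rule with `jnp.where` and can be jitted.

```python
import jax
import jax.numpy as jnp


def benefit_scalar(income: float) -> float:
    # Python branching on a value: fine in an element-wise loop, but jax.jit
    # cannot trace an `if` on an abstract array value.
    if income < 1000.0:
        return 200.0
    return 0.0


@jax.jit
def benefit_vectorized(income):
    # Same rule, expressed on whole arrays.
    return jnp.where(income < 1000.0, 200.0, 0.0)


income = jnp.array([500.0, 1500.0, 900.0])
looped = [benefit_scalar(float(x)) for x in income]
vectorized = benefit_vectorized(income)
```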

@hmgaudecker
Collaborator Author

hmgaudecker commented Apr 23, 2025

Thanks! I thought I had changed everything in METTSIM to (true) vectorize, but cannot check right now.

Yes, returning the jitted function would be the idea. Interface is in flux, anyhow!
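
A rough sketch of why returning the jitted function matters. `build_function` is a hypothetical stand-in for assembling the tax and transfer function; the `trace_log` list only records how often JAX traces the Python function, which happens on a cache miss.

```python
import jax
import jax.numpy as jnp


def build_function(rate, trace_log):
    """Hypothetical stand-in for building and jitting the combined function."""

    def tax(income):
        # Python side effect: runs only while JAX traces the function.
        trace_log.append("traced")
        return rate * income

    return jax.jit(tax)


income = jnp.array([100.0, 200.0])

# Rebuild and re-jit on every call: every call works on a fresh function object.
rebuild_log = []
for _ in range(3):
    build_function(0.2, rebuild_log)(income)

# Build once and reuse the returned jitted function.
reuse_log = []
jitted_tax = build_function(0.2, reuse_log)
for _ in range(3):
    jitted_tax(income)
```

In the second variant the function is traced on the first call only, while the first variant traces again for each freshly built function.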

@timmens
Collaborator

timmens commented Apr 29, 2025

I've made a bunch of changes to prepare compute_taxes_and_transfers to be jittable (please review carefully). Mainly, I introduced a new input data argument, num_segments. For the moment it is assumed to be required when jitting, but it will most likely be computed at function creation time using, e.g., fam_id.

A lot of things already work, but there are two classes of errors popping up:

  1. The num_segments input data argument does not get passed through and is missing in the input data in compute_taxes_and_transfers before jitting. Re-create by calling
    pixi run tests tests/ttsim/test_compute_taxes_and_transfers.py::test_user_provided_aggregation
  2. For some tests the jitting works, however, the test is not successful anymore. Maybe you have an idea. Re-create by calling
    pixi run tests tests/ttsim/test_compute_taxes_and_transfers.py::test_user_provided_aggregate_by_group_specs

@hmgaudecker
Collaborator Author

> For some tests the jitting works, however, the test is not successful anymore. Maybe you have an idea. Re-create by calling

So this was because we were using the number of unique values of group_id for num_segments. However, it needs to be group_id.max() + 1.
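
A small illustration with made-up data: when the IDs have gaps, the number of unique values is smaller than the largest ID plus one, and `jax.ops.segment_sum` drops entries whose segment ID falls outside `[0, num_segments)`.

```python
import jax
import jax.numpy as jnp

values = jnp.array([1.0, 2.0, 3.0])
group_id = jnp.array([0, 2, 2])  # hypothetical IDs; nobody has ID 1

n_unique = int(jnp.unique(group_id).size)
n_max_plus_one = int(group_id.max()) + 1

# Too few segments: the entries with ID 2 are silently dropped.
too_few = jax.ops.segment_sum(values, group_id, num_segments=n_unique)
# One slot per possible ID, including the unused ID 1.
correct = jax.ops.segment_sum(values, group_id, num_segments=n_max_plus_one)
```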

Collaborator

@timmens timmens left a comment

Very nice! No show stoppers from me. My only comment is that IMO we should be a bit more elaborate when jitting the tax_and_transfers_function. I fear that otherwise these assumptions and connections might be hidden from future maintainers.

Proposal (rough draft):

...

    if jit:
        if not IS_JAX_INSTALLED:
            raise ImportError(
                "JAX is not installed. Please install JAX to use JIT compilation."
            )
        tax_transfer_function = _jit_tax_and_transfer_function(
            tax_transfer_function,
            data=data,
        )
        
...
        
def _jit_tax_and_transfer_function(
    tax_transfer_function: Callable,
    data: QualNameDataDict,
) -> Callable:
    """Create a JIT-compiled version of the tax transfer function.
    
    This function prepares a tax transfer function for JIT compilation by:

    1. Extracting the number of segments for each grouping variable. This is necessary
       because JAX requires information about the number of segments for each grouping
       variable, in order to be able to compile the function.
    2. Partialing this information into the function. This also marks these arguments
       as static arguments for JAX.
    3. Applying JAX JIT compilation.
    
    Parameters
    ----------
    tax_transfer_function : callable
        The tax transfer function to be JIT-compiled
    data : dict
        The flattened user-provided data.
        
    Returns
    -------
    callable
        The JIT-compiled tax transfer function
    """
    import jax
    
    segment_counts = _get_segment_counts_for_grouping_variables(
        function=tax_transfer_function,
        data=data,
    )
    
    tax_function_with_static_args = functools.partial(
        tax_transfer_function,
        **segment_counts,
    )
    
    return jax.jit(tax_function_with_static_args)


def _get_segment_counts_for_grouping_variables(
    function: Callable,
    data: QualNameDataDict,
) -> dict[str, int]:
    segment_counts = {}

    for argname in inspect.signature(function).parameters:

        # In `create_agg_by_group_functions`, we re-define the argument `num_segments`
        # of the aggregation functions as `{group_level}_num_segments`, depending on the
        # grouping level.
        if argname.endswith("_num_segments"):

            grouping_level = argname.removesuffix("_num_segments")
            
            num_segments = _calculate_num_segments(data[grouping_level])

            segment_counts[argname] = num_segments
            
    return segment_counts


def _calculate_num_segments(group_level_ids: QualNameDataDict) -> int:
    """Calculate the number of segments for a grouping variable.
    
    This assumes that the grouping level IDs start at 0 and are contiguous.
    
    """
    return int(group_level_ids.max()) + 1
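
A self-contained usage sketch of the idea in the draft above, in plain Python and without JAX. The toy function and data are hypothetical; the sketch only shows how arguments ending in `_num_segments` can be discovered from the signature and partialled in.

```python
import functools
import inspect


def toy_tax_transfer_function(income, fam_id, fam_id_num_segments):
    # Hypothetical stand-in: sum income within each family.
    totals = [0.0] * fam_id_num_segments
    for value, group in zip(income, fam_id):
        totals[group] += value
    return totals


data = {"income": [10.0, 20.0, 30.0], "fam_id": [0, 0, 2]}

# Find every `{group_level}_num_segments` argument and derive its value
# from the corresponding ID column as max + 1.
segment_counts = {}
for argname in inspect.signature(toy_tax_transfer_function).parameters:
    if argname.endswith("_num_segments"):
        grouping_level = argname.removesuffix("_num_segments")
        segment_counts[argname] = int(max(data[grouping_level])) + 1

function_with_counts = functools.partial(toy_tax_transfer_function, **segment_counts)
result = function_with_counts(income=data["income"], fam_id=data["fam_id"])
```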

@hmgaudecker
Collaborator Author

> IMO we should be a bit more elaborate when jitting the tax_and_transfers_function

Completely agree! Will push that off for now though, will become part of the interface redesign. I opened #890 with your suggestion and some more thoughts.

Collaborator

@MImmesberger MImmesberger left a comment

I won't act like I understand every line of this 😅 (but the draft by Tim really helps!)

I think having _num_segments in the input data is quite confusing but if I understood it correctly this will change anyhow.

Just so I can follow: next steps (in a different PR) would be to make vectorization_strategy="vectorize" work with every TTSIM function?

"exception_match": "The dtype of id columns must be integer.",
},
}
# We cannot even set up these fixtures in JAX.
Collaborator

Should we remove this then? Cannot come up with an example why we'd ever need such operations.

Collaborator Author

Why? First person in the household to retire, ... I'd leave it there in order to delimit the boundaries of what can be done.

@hmgaudecker
Collaborator Author

> Just so I can follow: next steps (in a different PR) would be to make vectorization_strategy="vectorize" work with every TTSIM function?

"with most GETTSIM functions", yes!

@hmgaudecker hmgaudecker merged commit 7980a57 into collect-components-of-namespaces May 1, 2025
7 checks passed
@hmgaudecker hmgaudecker deleted the vectorize-mettsim branch May 1, 2025 04:57
hmgaudecker added a commit that referenced this pull request May 1, 2025
### What problem do you want to solve?

1. A brief shot at doing for GETTSIM what #879 did for METTSIM.
2. Change the default of `vectorization_strategy` to `vectorize`, so that searching for `"loop"` quickly finds all non-vectorisable cases.

- Good news: `test_full_taxes_and_transfers` runs nicely when jitted!
- Bad news: Many tests fail because the logic in dividing up the taxes and transfers function is not elaborate enough. In the first pass, we are trying to build all ids. However, in many cases we are missing the required input data (made-up example: Einkommensteuer tests may require calculation of `sn_id`, but won't have all inputs required for `bg_id`).

Solving this should be doable (first set up the entire graph, then check which ids are needed), but the required functions are buried inside of `compute_taxes_and_transfers` so we should not waste time on that before implementing the new interface.
@timmens timmens mentioned this pull request May 16, 2025
3 tasks