
Interface for using Jax / jitting #890

@hmgaudecker


#879 adds support for jitting most of METTSIM's tax-transfer function; support for large parts of GETTSIM should not be far away, either.

The PR deliberately left some rough edges to try things out. Most importantly, we'll need a clean distinction between what can be jitted and what cannot. Currently, we just take out the `group_id` functions in the tests; we will probably need a keyword for this, and maybe we can use `vectorization_strategy`.
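A minimal sketch of such a switch, assuming a hypothetical `jittable` keyword on the decorator that registers each function. The decorator name `policy_function` and the stand-in functions `hh_id` and `income_tax` are illustrative only, not GETTSIM's API; reusing `vectorization_strategy` would work the same way with a different attribute. The idea is that non-jittable functions (such as the `group_id` ones) would run eagerly first, and only the remaining ones would be handed to `jax.jit`:

```python
from collections.abc import Callable


def policy_function(*, jittable: bool = True) -> Callable[[Callable], Callable]:
    """Hypothetical decorator that records whether a function can be jitted."""

    def decorator(func: Callable) -> Callable:
        func.jittable = jittable  # stored as a plain function attribute
        return func

    return decorator


@policy_function(jittable=False)
def hh_id(p_id):
    # Hypothetical stand-in for a group_id function that must run outside jax.jit.
    return p_id


@policy_function()
def income_tax(income):
    # Hypothetical stand-in for an ordinary, jittable tax function.
    return 0.2 * income


def split_by_jittability(
    functions: dict[str, Callable],
) -> tuple[dict[str, Callable], dict[str, Callable]]:
    """Split functions into a jittable and a non-jittable group."""
    jittable = {}
    not_jittable = {}
    for name, func in functions.items():
        # Functions without the attribute are treated as jittable by default.
        if getattr(func, "jittable", True):
            jittable[name] = func
        else:
            not_jittable[name] = func
    return jittable, not_jittable


jittable, not_jittable = split_by_jittability(
    {"hh_id": hh_id, "income_tax": income_tax}
)
```

Defaulting to "jittable" keeps the common case free of annotations, so only the few exceptions need the keyword.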

Beyond that, some details of the code are very imperative and littered with complex if-statements. Here's a suggestion by @timmens regarding `compute_taxes_and_transfers`:

```python
import functools
import inspect
from collections.abc import Callable

import numpy as np

...

    if jit:
        if not IS_JAX_INSTALLED:
            raise ImportError(
                "JAX is not installed. Please install JAX to use JIT compilation."
            )
        tax_transfer_function = _jit_tax_and_transfer_function(
            tax_transfer_function,
            data=data,
        )

...


def _jit_tax_and_transfer_function(
    tax_transfer_function: Callable,
    data: QualNameDataDict,
) -> Callable:
    """Create a JIT-compiled version of the tax transfer function.

    This function prepares a tax transfer function for JIT compilation by:

    1. Extracting the number of segments for each grouping variable. This is
       necessary because JAX's segment operations (e.g., `jax.ops.segment_sum`)
       need `num_segments` as a concrete Python integer to determine output
       shapes at compile time.
    2. Binding these counts to the function via `functools.partial`. They are
       then no longer traced arguments, so JAX treats them as compile-time
       constants (the same effect as marking them static).
    3. Applying JAX JIT compilation.

    Parameters
    ----------
    tax_transfer_function : callable
        The tax transfer function to be JIT-compiled.
    data : dict
        The flattened user-provided data.

    Returns
    -------
    callable
        The JIT-compiled tax transfer function.
    """
    import jax

    segment_counts = _get_segment_counts_for_grouping_variables(
        function=tax_transfer_function,
        data=data,
    )

    tax_function_with_static_args = functools.partial(
        tax_transfer_function,
        **segment_counts,
    )

    return jax.jit(tax_function_with_static_args)


def _get_segment_counts_for_grouping_variables(
    function: Callable,
    data: QualNameDataDict,
) -> dict[str, int]:
    """Return the segment count for each `{group_level}_num_segments` argument."""
    segment_counts = {}

    for argname in inspect.signature(function).parameters:
        # In `create_agg_by_group_functions`, we re-define the argument `num_segments`
        # of the aggregation functions as `{group_level}_num_segments`, depending on the
        # grouping level.
        if argname.endswith("_num_segments"):
            grouping_level = argname.removesuffix("_num_segments")
            segment_counts[argname] = _calculate_num_segments(data[grouping_level])

    return segment_counts


def _calculate_num_segments(group_level_ids: np.ndarray) -> int:
    """Calculate the number of segments for a grouping variable.

    `group_level_ids` is a 1-D array of integer IDs (NumPy or JAX). This assumes
    that the grouping level IDs start at 0 and are contiguous.
    """
    return int(group_level_ids.max()) + 1
```
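The three steps in the docstring can be sketched end to end on toy data. This assumes an aggregation that calls `jax.ops.segment_sum`, which needs `num_segments` as a concrete Python integer under `jax.jit`; passing it as an ordinary (traced) argument raises an error instead. All names here (`hh_income_sum`, `contiguous_ids`, `segment_counts`, and the `group_ids` dict keyed by grouping level) are hypothetical stand-ins for the helpers above. Since `_calculate_num_segments` assumes IDs start at 0 and are contiguous, the sketch first remaps arbitrary IDs with `np.unique(..., return_inverse=True)`:

```python
import functools
import inspect

import jax
import jax.numpy as jnp
import numpy as np


def hh_income_sum(income, hh_id, hh_num_segments):
    # Hypothetical aggregation: total income per household, broadcast back to
    # each person. `hh_num_segments` must be a Python int under jax.jit.
    totals = jax.ops.segment_sum(income, hh_id, num_segments=hh_num_segments)
    return totals[hh_id]


def contiguous_ids(raw_ids):
    # Map arbitrary group IDs to 0, 1, ..., n - 1 so that max + 1 equals the
    # number of groups. Runs eagerly in NumPy, outside of jax.jit.
    _, inverse = np.unique(raw_ids, return_inverse=True)
    return inverse


def segment_counts(function, group_ids):
    # One count per `{group_level}_num_segments` parameter of `function`.
    counts = {}
    for argname in inspect.signature(function).parameters:
        if argname.endswith("_num_segments"):
            level = argname.removesuffix("_num_segments")
            counts[argname] = int(group_ids[level].max()) + 1
    return counts


raw_hh_id = np.array([3, 3, 7, 9])  # hypothetical, non-contiguous IDs
income = jnp.array([100.0, 200.0, 50.0, 0.0])
group_ids = {"hh": jnp.asarray(contiguous_ids(raw_hh_id))}

# Step 1: extract the counts, here {"hh_num_segments": 3}.
counts = segment_counts(hh_income_sum, group_ids)
# Step 2: bind them; they are now compile-time constants, not traced inputs.
static_fn = functools.partial(hh_income_sum, **counts)
# Step 3: compile.
jitted = jax.jit(static_fn)

result = jitted(income, group_ids["hh"])  # result: [300., 300., 50., 0.]
```

Because the counts are bound as constants, a dataset with a different number of households needs a new `partial` and typically triggers a fresh compilation.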
