Conversation

@mj023 (Collaborator) commented May 12, 2025

What problem do you want to solve?

Make the ID creation functions in GETTSIM and METTSIM jittable.

Todo

  • Make GETTSIM IDs jittable.
  • Make METTSIM IDs jittable.
  • Change tests to use the new ID values.

@mj023 (Collaborator, Author) commented May 13, 2025

Some of the tests are failing because I still have to adjust the expected values to the new IDs. Others, like test_groupings[groupings/2023-01-01/mehrere_haushalte_durchmischt.yaml], are failing because they do not use ordered, consecutively numbered p_id values. I can probably come up with some kind of workaround, but it will be computationally expensive, and requiring p_id values to be sorted and consecutively numbered does not seem unreasonable to me. What do you think, @hmgaudecker?

@hmgaudecker (Collaborator)

Thanks! We cannot ask that of users (think of running some query on the data and then having to remember to fix p_id and all its derivatives). But I'd be open to doing that ourselves -- it sounds like it's computationally worth it. The ugly thing is that we'd have to remember that both versions are around, and users would find it difficult to retrieve offending rows in case of errors. So it may be worth doing that inside the functions / just adding it to the graph? ehe_id would then depend on sanitized_p_id and familie_sanitized_p_id_ehepartner or so?

@codecov bot commented May 14, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

@mj023 (Collaborator, Author) commented May 14, 2025

It was only problematic for the fg_id function, so I decided to just look up the correct indices for familie__p_id_elternteil. This is O(n²); if it turns out to be problematic (when the number of p_ids is huge), we can switch to a hash table, but that is probably overkill for now. The only restriction now is that p_id values must not be higher than 2³², or we get an overflow with int64.
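
For concreteness, a minimal sketch of what such a jittable O(n²) index lookup can look like (index_in and the argument names are hypothetical, not the actual GETTSIM code; rows without a match come back as 0 and would need masking in practice):

```python
import jax.numpy as jnp

def index_in(p_id, p_id_to_find):
    """For each entry of `p_id_to_find`, return its position in `p_id`.

    O(n²) pairwise comparison, but jittable: no data-dependent shapes.
    """
    # Boolean matrix: matches[i, j] is True iff p_id_to_find[i] == p_id[j].
    matches = p_id_to_find[:, None] == p_id[None, :]
    # First matching column index per row; 0 if there is no match at all.
    return jnp.argmax(matches, axis=-1)
```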

In 4fbb227 I encountered some tests where parents were married to their children; I assumed that was a mistake and fixed them.

@mj023 mj023 requested a review from hmgaudecker May 14, 2025 16:00
@hmgaudecker (Collaborator)

Thanks a lot!!

> It was only problematic for the fg_id function, so I decided to just look up the correct indices for familie__p_id_elternteil. This is O(n²); if it turns out to be problematic (when the number of p_ids is huge), we can switch to a hash table, but that is probably overkill for now. The only restriction now is that p_id values must not be higher than 2³², or we get an overflow with int64.

Great, thanks! And agreed.

However, we should check these things upfront. Should we add this to the group_creation_function decorator, which can then set up a mechanism to perform that check? If there is no performance penalty to large num_segments, we could even use these maximum values for that and get rid of the computation-before-jitting stuff, which in its current naive form is causing problems when we only need a subset of group ids (lots of those GETTSIM test failures, but also the METTSIM tests in #897). What do you think?

> In 4fbb227 I encountered some tests where parents were married to their children; I assumed that was a mistake and fixed them.

For sure. @MImmesberger, do we still have that open issue for sanity checks? If so, could you please add this to the list?

@hmgaudecker hmgaudecker changed the base branch from collect-components-of-namespaces to collect-unify-parsing-of-params May 16, 2025 06:24
@hmgaudecker (Collaborator)

> It was only problematic for the fg_id function, so I decided to just look up the correct indices for familie__p_id_elternteil. This is O(n²); if it turns out to be problematic (when the number of p_ids is huge), we can switch to a hash table, but that is probably overkill for now. The only restriction now is that p_id values must not be higher than 2³², or we get an overflow with int64.

> Great, thanks! And agreed.

> However, we should check these things upfront. Should we add this to the group_creation_function decorator, which can then set up a mechanism to perform that check? If there is no performance penalty to large num_segments, we could even use these maximum values for that and get rid of the computation-before-jitting stuff, which in its current naive form is causing problems when we only need a subset of group ids (lots of those GETTSIM test failures, but also the METTSIM tests in #897). What do you think?

Apologies, this was a bit cryptic, I guess. I changed the base to the current "HEAD" so the issue becomes apparent. In the tests, we are currently pre-computing the ids so that we have the num_segments parameter: search for def execute_tests and then see what happens under if IS_JAX_INSTALLED:. This is done for all groupings, which is problematic when we only need a subset of targets that doesn't require all groupings. In those cases, the input data will often not include the dependencies of the group_creation_functions. This happens all the time in GETTSIM, and now also in a couple of METTSIM test cases.

It would be best if we could get rid of this altogether. But we need the num_segments.

How about changing the group ids upon return so that they are consecutive numbers? We could then simply pass the number of observations as num_segments. There might be a small performance penalty, but we could do it in a way that users can override.
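
A minimal sketch of that reordering under jit (reorder_ids is a hypothetical name; jnp.unique needs a static size argument to be jittable, padded here with the dtype's maximum as a sentinel):

```python
import jax.numpy as jnp

def reorder_ids(ids):
    """Map arbitrary integer group ids to consecutive numbers starting at 0."""
    # Sorted unique ids, padded to a static shape so this works under jit.
    sorted_unique = jnp.unique(
        ids, size=ids.shape[0], fill_value=jnp.iinfo(ids.dtype).max
    )
    # Each id's rank among the sorted unique values is its new id.
    return jnp.searchsorted(sorted_unique, ids)
```

With consecutive ids, the length of the data is then a safe num_segments.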

Other ideas welcome!

@hmgaudecker (Collaborator) left a comment

Very nice! Just a few small suggestions.

@mj023 (Collaborator, Author) commented May 16, 2025

Thanks for the clarification, I understand the problem now.
By "number of observations", do you mean the number of unique p_ids or the number of fg_ids? The number of p_ids might work if the num_segments argument is allowed to be bigger than the actual number of segments in the data; I will try that out first. Otherwise, we might be able to avoid num_segments altogether by writing our own segment_max, though that would take a lot more time.

@hmgaudecker (Collaborator)

The thing is timing: the best value would be the maximum of the group ids, but we only know that once we have computed them.

To avoid that, I'd use the length of the data, which is an upper bound when redefining the contents as suggested above.
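
For illustration with toy data (not from the test suite): once the ids are consecutive, the length of the data is a valid upper bound for num_segments, and the unused segments simply come back as zeros.

```python
import jax
import jax.numpy as jnp

values = jnp.array([10.0, 20.0, 30.0, 40.0])
group_id = jnp.array([0, 0, 1, 1])  # consecutive, 0-based

# Upper bound: num_segments = len(values); extra segments are just zeros.
sums = jax.ops.segment_sum(values, group_id, num_segments=values.shape[0])
# -> [30., 70., 0., 0.]
```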

@mj023 (Collaborator, Author) commented May 16, 2025

Your idea is probably the best then; I did not realize that segment_sum was implemented so that it returns an array the size of the highest segment_id. I will change all IDs to be consecutive numbers starting at 0. Using the length of the inputs as num_segments will be fine; I don't think using a little bit too much memory there will be a problem.

@hmgaudecker (Collaborator)

> I did not realize that segment_sum was implemented so that it returns an array the size of the highest segment_id.

Me neither! Did you, @timmens ?

> I will change all IDs to be consecutive numbers starting at 0. Using the length of the inputs as num_segments will be fine; I don't think using a little bit too much memory there will be a problem.

Excellent!

@timmens (Collaborator) commented May 16, 2025

This is my last update.

@hmgaudecker (Collaborator)

> This is my last update.

Sure, this was just meant as FYI!

@hmgaudecker (Collaborator) commented May 19, 2025

Thanks!

Changing the two num_segments=... in test_aggregation_functions to num_segments=len(group_id) made me realise that this solution is great when we calculate stuff ourselves, but it is not helpful when users pass [group_id].
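
A toy example (made up, not from the test suite) of why this breaks for user-supplied ids: segment_sum silently drops segment ids outside [0, num_segments), so the upper bound len(group_id) can lose data.

```python
import jax
import jax.numpy as jnp

values = jnp.array([1.0, 2.0, 3.0])
user_group_id = jnp.array([10, 10, 42])  # user-supplied, not consecutive

# All ids are >= num_segments and get dropped silently:
jax.ops.segment_sum(values, user_group_id, num_segments=len(values))
# -> [0., 0., 0.]
```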

I think there are two ways out:

  • creating internal [group]_id columns, which basically means calling reorder_ids on the existing ones, and sticking with one global num_segments=len(p_id)
  • continuing with [group_id]_num_segments and making it depend on whether the [group_id] is user-supplied ([group_id].max() + 1) or calculated by ttsim (len(group_id)).

What do you think? Other ideas welcome, ofc!

@mj023 (Collaborator, Author) commented May 20, 2025

I think the first solution is good. It might make it easier for people to write their own ID functions that depend on other IDs; it's much easier if you know they are consecutively numbered starting at 0. It's also probably just better if num_segments does not get too big.

I tried removing the precomputed IDs from the tests, but for some reason I still got many KeyErrors, strangely only in GETTSIM, not in METTSIM. It turned out I had forgotten to set jit to true when running the tests.

Nearly all GETTSIM tests are now failing, at first glance for reasons unrelated to IDs. Should I try to fix this, or should it be tackled in a different PR?

@hmgaudecker (Collaborator)

Excellent! Apologies, I didn't manage to look much into it today. I just had a brief look at the failures; many indeed seem to occur because we are using some constructs that Jax cannot handle. This will improve with #908 / #914, though I could use a hint on what to do to make lookups work (altersgrenze_gestaffelt[geburtsjahr] with altersgrenze_gestaffelt a dict[int, float] and geburtsjahr an array of ints).

That was an aside; some test failures do seem related. This one, for example:

FAILED src/_gettsim_tests/test_policy.py::test_policy[vorrangpr\xfcfungen/2009-01-01/hh_id_3.yaml] - KeyError: 'hh_id_num_segments'

@mj023 (Collaborator, Author) commented May 21, 2025

In this specific case, you might be able to convert the dictionary to an array, pick a base year (e.g. 1900), and then use geburtsjahr like an index by subtracting the base year. Otherwise, you probably have to loop through all dictionary entries, or, if you can somehow get a sorted array, you can use jax.numpy.searchsorted(). Quite annoying that there are no lookup tables in JAX.
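
A minimal sketch of the base-year idea (the dict contents, the forward-filling semantics, and all values here are made up for illustration; the real parameters will differ):

```python
import jax.numpy as jnp

# Hypothetical staggered thresholds: value applies from that birth year on.
altersgrenze_gestaffelt = {1947: 65.0, 1958: 66.0, 1964: 67.0}

base_year = 1900
n_years = 150

# Build a dense lookup array outside of jit, forward-filling the thresholds.
table = jnp.full(n_years, 65.0)  # value for years before the first threshold
for year, value in sorted(altersgrenze_gestaffelt.items()):
    table = table.at[year - base_year :].set(value)

def altersgrenze(geburtsjahr):
    # Jittable: plain integer indexing instead of a Python dict lookup.
    return table[geburtsjahr - base_year]
```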

I have now updated execute_test to actually jit when Jax is installed. The missing num_segments was only there because jit was always set to false. But most parts of GETTSIM are not yet jit-compatible, so the tests are failing.

@hmgaudecker (Collaborator) commented May 22, 2025

Looks great and I think it can be merged once the comments in the code are addressed. We'll probably want to postpone this mechanism here:

> I think there are two ways out:
>
>   • creating internal [group]_id columns, which basically means calling reorder_ids on the existing ones, and sticking with one global num_segments=len(p_id)

> I think the first solution is good. It might make it easier for people to write their own ID functions that depend on other IDs; it's much easier if you know they are consecutively numbered starting at 0. It's also probably just better if num_segments does not get too big.

until we are in the process of redesigning the interface (early June).

However, please revert the changes to the test cases that changed ids to 0-based consecutive numbers, so that we don't forget. Please also open an issue outlining the solution strategy.

@hmgaudecker hmgaudecker changed the base branch from collect-unify-parsing-of-params to move-gettsim-params-files May 22, 2025 10:36
@mj023 (Collaborator, Author) commented May 22, 2025

I think I addressed all your comments, and I also reverted the test cases that test the IDs. Was this correct, or was I supposed to revert the reordering and the changes to the test execution too?

@hmgaudecker (Collaborator)

Perfect!

We just want to support these cases in Jax, so we should keep them around. Test execution is perfect as is.

@hmgaudecker hmgaudecker merged commit c7d5cf5 into move-gettsim-params-files May 22, 2025
6 of 7 checks passed
@hmgaudecker hmgaudecker deleted the groupings_jax branch May 22, 2025 20:38
hmgaudecker added a commit that referenced this pull request May 22, 2025
@hmgaudecker hmgaudecker mentioned this pull request Jun 12, 2025