Best way to handle a lot of unstructured data in compiled function #14076
Replies: 1 comment
-
I suspect your best approach would be to encapsulate your unstructured data within a custom pytree (see Extending Pytrees for discussion and examples). This would let you use your structure directly within jit-compiled functions, and the pytree registration lets you specify which attributes are treated as traced array data and which as static metadata under JIT. As to your other specific questions:
No, JAX doesn't have any notion of on-disk memory-mapped arrays.
I'm not aware of any such libraries that work with JAX.
Does a JAX-ic way to handle "variable shape" arrays exist? (I'm using a list of arrays right now to deal with it)
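To make the custom pytree suggestion concrete, here is a minimal sketch. The class name `TabulatedModel`, its fields, and the toy tables are hypothetical and not taken from the question; the sketch only assumes that each table is a pair of 1-D arrays whose length differs from table to table, and it uses linear interpolation (`jnp.interp`) as a stand-in for whatever the real model does.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TabulatedModel:
    """Hypothetical container for tables of differing lengths."""

    def __init__(self, xs, ys, label):
        self.xs = list(xs)    # list of 1-D arrays, one grid per table
        self.ys = list(ys)    # list of 1-D arrays, values on each grid
        self.label = label    # static, hashable metadata (never traced)

    def tree_flatten(self):
        children = (self.xs, self.ys)  # array leaves, traced under jit
        aux_data = self.label          # static data, must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        xs, ys = children
        return cls(xs, ys, aux_data)


@jax.jit
def evaluate(model, q):
    # The Python loop is unrolled at trace time: one interpolation per table.
    return sum(jnp.interp(q, x, y) for x, y in zip(model.xs, model.ys))


# Two toy tables with different lengths (3 and 5 points).
model = TabulatedModel(
    xs=[jnp.array([0.0, 1.0, 2.0]), jnp.array([0.0, 0.5, 1.0, 1.5, 2.0])],
    ys=[jnp.array([0.0, 1.0, 4.0]), jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])],
    label="demo",
)

value = evaluate(model, 0.25)                      # model passed straight into jit
slope = jax.grad(evaluate, argnums=1)(model, 0.25) # derivative w.r.t. the query point
```

Because the loop over tables is unrolled during tracing, compile time in a sketch like this still grows with the number of tables; the pytree mainly tidies up how the data is passed around and which parts are static.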
-
Hi,
This is a pretty broad question: I am in the process of writing a JAX-based library for astrophysics, and in this field many models are built using tabulated data from old codes. Of course, these arrays are not structured on regular grids, and I cannot find a better way than nested lists to handle all these various shapes. Building differentiable functions from this tabulated data requires long compilation times and a lot of memory, and performs poorly compared to other implementations.
Would you have some hints or tricks to handle this kind of usage in a better way, so that I can take advantage of the performance of JAX? I would like to know, for instance:
I would be glad to go into more detail if this is not clear enough or if it requires a more specific approach.
Thank you very much in advance for reading this and helping me.