Large scale 1M x 1M covariance into Kronecker factorization #24677

Closed Answered by adam-hartshorne
seongwoohan asked this question in Q&A

The new FFI (foreign function interface) would allow you to call your C++ code from JAX:
https://jax.readthedocs.io/en/latest/ffi.html

Alternatively, have you considered using PyKeOps, which is designed to handle such large matrices without materializing them in memory?

https://www.kernel-operations.io/keops/index.html
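To make the "never build the full matrix" idea concrete, here is a minimal plain-JAX sketch. It is not KeOps and not the asker's code: it assumes the covariance is exactly a Kronecker product of two square factors, and the names `A`, `B`, and `kron_matvec` are hypothetical. It relies on the identity (A ⊗ B) vec(X) = vec(A X Bᵀ) for row-major flattening, so two 1000 x 1000 factors can stand in for a 1M x 1M matrix.

```python
import jax
import jax.numpy as jnp


@jax.jit
def kron_matvec(A, B, x):
    """Compute (A kron B) @ x without forming the Kronecker product.

    Assumes A is (m, m), B is (n, n) and x has length m * n.
    """
    m, n = A.shape[0], B.shape[0]
    X = x.reshape(m, n)  # row-major "unvec" of x
    return (A @ X @ B.T).reshape(-1)


# Small check against the dense product (only feasible for tiny sizes).
key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(key_a, (3, 3))
B = jax.random.normal(key_b, (4, 4))
x = jax.random.normal(key_x, (12,))
dense = jnp.kron(A, B) @ x
```

The dense comparison is only there as a sanity check; at the 1M x 1M scale only the factored form fits in memory.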

Although KeOps states that its bindings are only for NumPy and PyTorch, you can easily use https://github.com/rdyro/torch2jax as the go-between. This all supports autodiff, vmap, etc. The extra overhead is minimal; in fact, because of how efficiently KeOps works, it can be faster even with the intermediate step.

Finally, there is Pallas for JAX, which allows you to write Triton kernels (Triton is a language that sits above …
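For orientation, a Pallas kernel looks roughly like the sketch below. This is a toy elementwise add and has nothing to do with the covariance problem itself; the names `add_kernel` and `add_vectors` are illustrative, and `interpret=True` is an assumption made here so the example also runs on CPU without a GPU backend.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Each argument is a reference to a block of memory:
    # read the inputs, compute, and write the result to the output ref.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add_vectors(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: interpreter mode, so no GPU is needed
    )(x, y)


result = add_vectors(jnp.arange(8.0), jnp.arange(8.0))
```

On a GPU you would typically drop `interpret=True` so the kernel is actually compiled for the device.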

Replies: 1 comment

Answer selected by seongwoohan