Large scale 1M x 1M covariance into Kronecker factorization #24677
-
Hello, I’m working with a very large covariance matrix S of size 1 million x 1 million, which I want to approximate as a Kronecker product, S ≈ U⊗V. My goal is to factorize S into smaller matrices U and V with this Kronecker structure. Currently I have a C++ implementation based on an Alternating Least Squares (ALS) approach for Kronecker approximation, but I’m considering switching to JAX for scalability reasons; this is my first time using JAX. From what I understand, JAX could offer several advantages, such as GPU acceleration for parallelizing large matrix operations, automatic differentiation, and batch processing. My questions are: would JAX be able to handle this matrix size (1M x 1M) or even larger with GPU acceleration, and would it offer noticeable advantages over C++ in my case? Any advice on managing such large-scale Kronecker factorizations with JAX (e.g. parallelizing large matrix operations, automatic differentiation, and batch processing) would be greatly appreciated.
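For context, here is a minimal small-scale sketch of the kind of ALS update I mean, in the standard Van Loan–Pitsianis formulation (this is a toy illustration with sizes and names of my choosing, not my actual C++ code; at 1M x 1M neither S nor the rearranged matrix could be materialized, since 10^12 float64 entries is about 8 TB, so the real computation would have to stream over blocks):

```python
import jax
import jax.numpy as jnp

def rearrange(S, m, n):
    """Van Loan-Pitsianis rearrangement: one row per n x n block S_ij of S,
    so that ||S - U (x) V||_F == ||R(S) - vec(U) vec(V)^T||_F."""
    blocks = S.reshape(m, n, m, n)            # blocks[i, :, j, :] == S_ij
    return blocks.transpose(0, 2, 1, 3).reshape(m * m, n * n)

def kron_als(S, m, n, iters=50):
    """Rank-1 ALS on the rearranged matrix; returns U (m x m), V (n x n)."""
    R = rearrange(S, m, n)
    v = jnp.ones(n * n, dtype=S.dtype)
    for _ in range(iters):
        u = R @ v / (v @ v)                   # least-squares update for vec(U)
        v = R.T @ u / (u @ u)                 # least-squares update for vec(V)
    return u.reshape(m, m), v.reshape(n, n)

# Small-scale check: recover an exact Kronecker product.
key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
U0 = jax.random.normal(key_u, (4, 4))
V0 = jax.random.normal(key_v, (5, 5))
U, V = kron_als(jnp.kron(U0, V0), 4, 5)
print(jnp.linalg.norm(jnp.kron(U, V) - jnp.kron(U0, V0)))  # ~0
```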
Replies: 1 comment
-
The new FFI interface would allow you to call your C++ code:
https://jax.readthedocs.io/en/latest/ffi.html
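As a rough sketch of the Python side only (hedged: `kron_als_step` is a hypothetical target name, and this assumes you have built and registered a C++ XLA FFI handler as shown in that tutorial; `jax.ffi.ffi_call` is the entry point in recent JAX versions):

```python
import jax

def kron_als_step(S_blocks, V):
    # Assumes a C++ XLA FFI handler registered under the (hypothetical)
    # name "kron_als_step" via jax.ffi.register_ffi_target(...), and that
    # it returns an updated factor with the same shape/dtype as V.
    out_type = jax.ShapeDtypeStruct(V.shape, V.dtype)
    return jax.ffi.ffi_call("kron_als_step", out_type)(S_blocks, V)
```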
Alternatively, have you considered using PyKeOps to handle such large matrices in a smart way?
https://www.kernel-operations.io/keops/index.html
Although it states that bindings are only for NumPy and PyTorch, you can easily use https://github.com/rdyro/torch2jax as the go-between. This all supports autodiff, vmap, etc., and the extra overhead is minimal; in fact, thanks to the highly efficient way KeOps works, it can even be faster despite the intermediate step.
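For instance, if your covariance is generated by a kernel on data points (an assumption; the Gaussian kernel below is just a placeholder), KeOps can apply S to a vector without ever materializing it, and torch2jax can expose that to JAX (the wrapper call below follows torch2jax's README and is a sketch, not a tested recipe):

```python
import torch
from pykeops.torch import LazyTensor
from torch2jax import torch2jax  # rdyro/torch2jax

def kernel_matvec(x, b, sigma=1.0):
    # Lazy N x N Gaussian kernel matrix; KeOps never materializes it.
    x_i = LazyTensor(x[:, None, :])   # (N, 1, D)
    x_j = LazyTensor(x[None, :, :])   # (1, N, D)
    K_ij = (-((x_i - x_j) ** 2).sum(-1) / (2 * sigma**2)).exp()
    return K_ij @ b                   # reduction over j -> (N, 1)

# Wrap the PyTorch/KeOps function for use from JAX (example args give shapes).
N, D = 10_000, 3
x_ex, b_ex = torch.randn(N, D), torch.randn(N, 1)
matvec_jax = torch2jax(kernel_matvec, x_ex, b_ex)
```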
Finally, there is Pallas for JAX, which allows you to write Triton kernels (Triton is a language that sits above CUDA) to implement the procedure yourself.
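To give a flavor of what that looks like, here is a minimal Pallas kernel (a trivial elementwise op, not the ALS itself):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_add_kernel(x_ref, y_ref, o_ref):
    # Each program instance reads its block from the refs and writes the result.
    o_ref[...] = x_ref[...] + 2.0 * y_ref[...]

@jax.jit
def scale_add(x, y):
    # With no grid specified, one program processes the whole array.
    return pl.pallas_call(
        scale_add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.ones((1024, 1024))
print(scale_add(x, x)[0, 0])  # 3.0
```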