-
Hi team! I know I'm not the first one to ask questions about dynamic input shapes, but maybe someone has some further ideas on how to optimize this (and maybe it will help someone else with a similar problem). I'm trying to run SVD on a bunch of matrices with different shapes. Here are the shapes and counts of my input data:

Counter({(768, 768): 48, (1280, 1280): 48, (320, 320): 40, (640, 640): 40, (768, 3072): 12, (1280, 768): 12, (3072, 768): 12, (320, 768): 10, (640, 768): 10, (1280, 5120): 6, (10240, 1280): 6, (320, 1280): 5, (640, 2560): 5, (2560, 320): 5, (5120, 640): 5})

The naive approach of going through these matrices one by one and computing the SVD is slow: as you know, the issue is that JAX re-compiles the SVD function for every shape. When using jit and saving the compilation cache with

from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("jax_cache")

I bring it down to 38s. When sorting/grouping the matrices by shape, I get a further speed-up and end up with 33s. I've also tried to use vmap to vectorize, but this leads to a slow-down (36s). I could use pmap, but that leads to issues when the number of items is not divisible by the number of devices. As of 2023, are there other ways to speed up this computation? Any help is welcome 🙏 !
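For context, a minimal sketch of the sorting/grouping approach mentioned above might look like this (the list name matrices is an assumption for illustration; the idea is that jit compiles once per distinct shape and reuses the compiled executable for every matrix of that shape):

import jax
import jax.numpy as jnp

jit_svd = jax.jit(jnp.linalg.svd)

def svd_grouped(matrices):
    # Process matrices in shape order, so each distinct shape is compiled once
    # and all matrices sharing that shape hit the cached executable.
    order = sorted(range(len(matrices)), key=lambda i: matrices[i].shape)
    results = [None] * len(matrices)
    for i in order:
        results[i] = jit_svd(matrices[i])
    return results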
-
One possibility: if you create a padded matrix with the smaller matrix in the upper left corner, the padded SVD output will contain the smaller SVD:

import jax.numpy as jnp
import numpy as np

# Embed a small example matrix in the upper-left corner of a larger zero matrix.
x = jnp.array(np.random.rand(3, 4))
x_padded = jnp.zeros((10, 10)).at[:3, :4].set(x)

u, s, vt = jnp.linalg.svd(x)
u_padded, s_padded, vt_padded = jnp.linalg.svd(x_padded)

# Compare each factor of the small SVD with the corresponding slice of the padded SVD.
print(u)
print(u_padded[:3, :3])
print()
print(s)
print(s_padded[:3])
print()
print(vt)
print(jnp.vstack([vt_padded[:3, :4], vt_padded[-1:, :4]]))

This would allow you to compute the outputs using only a single, fixed-size SVD call.
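As a rough illustration of how this trick could be wired up for the original batch of matrices (a sketch, not a benchmarked solution: matrices and pad are assumed names, and in practice one would likely pad within size buckets rather than to one global size):

import jax.numpy as jnp

def padded_batch_svd(matrices, pad):
    # Pad every matrix into the upper-left corner of a pad x pad zero matrix.
    padded = jnp.stack([
        jnp.zeros((pad, pad)).at[:m.shape[0], :m.shape[1]].set(m)
        for m in matrices
    ])
    # jnp.linalg.svd accepts stacked matrices, so this is a single
    # fixed-shape (and hence single-compilation) call.
    U, S, Vt = jnp.linalg.svd(padded)
    results = []
    for i, m in enumerate(matrices):
        k = min(m.shape)
        # For full-rank inputs, the singular values/vectors of the original
        # matrix sit in the leading entries/rows/columns of the padded factors
        # (up to sign flips of paired singular vectors).
        results.append((U[i, :m.shape[0], :k], S[i, :k], Vt[i, :k, :m.shape[1]]))
    return results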
-
Hi Jake - That's great, I didn't know about this trick to pad matrices for SVD. Comparing the padded and unpadded results, they indeed match. However, no matter what padding strategy (or bucketing) I apply, I don't get a speedup, as there's additional overhead from running the large padded matrices through SVD. Thanks for your help anyway!

PS. For the third output I used a slightly different transformation than yours. Note this is not based on any algebraic derivation (but it seemed to work 🤷 ):

x0, x1 = x.shape  # shape before padding
pad = 10
x_padded = jnp.zeros((pad, pad)).at[:x0, :x1].set(x)

U, S, Vh = jnp.linalg.svd(x_padded)

# Extract the factors of the original (unpadded) matrix from the padded SVD.
U = U[:x0, :x0]
S = S[:x0]
# Flip the sign of row x1-1, then slice Vh down to the original column count.
Vh = Vh.at[x1-1, :x1].set(-1 * Vh[x1-1, :x1])[:x1, :x1]
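For completeness, a quick sanity check of this extraction could look like the following (a sketch, assuming x0 <= x1 and a full-rank x as in the 3x4 example; the extracted vectors may differ from the direct SVD by sign flips of paired singular vectors, but should still reconstruct x):

import numpy as np

u_ref, s_ref, vh_ref = jnp.linalg.svd(x)                        # direct, unpadded SVD
print(np.allclose(S, s_ref, atol=1e-5))                         # singular values agree
print(np.allclose(U @ jnp.diag(S) @ Vh[:x0, :], x, atol=1e-5))  # reconstruction of x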