
Speed up computations with dynamic input shapes #14786

Answered by jakevdp
mar-muel asked this question in Q&A

One possibility: if you create a padded matrix with the smaller matrix in its upper-left corner, the SVD of the padded matrix contains the SVD of the smaller one (each singular vector up to a sign flip):

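```python
import jax.numpy as jnp
import numpy as np

# Embed a 3x4 matrix in the upper-left corner of a 10x10 zero matrix
x = jnp.array(np.random.rand(3, 4))
x_padded = jnp.zeros((10, 10)).at[:3, :4].set(x)

u, s, vt = jnp.linalg.svd(x)
u_padded, s_padded, vt_padded = jnp.linalg.svd(x_padded)

# Left singular vectors: upper-left 3x3 block of u_padded
print(u)
print(u_padded[:3, :3])
print()
# Singular values: the padding only adds zeros, which sort to the end
print(s)
print(s_padded[:3])
print()
# Right singular vectors: first 3 rows, first 4 columns of vt_padded.
# The 4th row of vt is a null-space vector of x; it is compared with the last
# row of vt_padded, but the padded null space is larger, so that row is not
# uniquely determined and may not match in general.
print(vt)
print(jnp.vstack([vt_padded[:3, :4], vt_padded[-1:, :4]]))
```

```
[[-0.6057347  -0.5485058   0.57639146]
 [-0.6427847  -0.08961529 -0.7607872 ]
 [-0.46894974  0.8313307   0.29828787]]
[[-0.6057346  -0.54850584 -0.5763916 ]
 [-0.64278454 -0.0896153   …
```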
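The point of padding, given the thread's goal of handling dynamic input shapes, is that `jax.jit` traces and compiles a function again for every new input shape, so padding all inputs to one fixed shape lets a single compiled SVD serve matrices of different sizes. Below is a minimal sketch of that idea, assuming every input fits inside a 10x10 bucket. `MAX_ROWS`, `MAX_COLS`, `padded_svd`, `svd_via_padding` and the `trace_count` counter are hypothetical names for illustration, not part of the original answer. Because singular vectors are only defined up to sign, the sketch checks the sliced factors by reconstructing the original matrix rather than by comparing vectors directly.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical fixed "bucket" shape; every input matrix must fit inside it.
MAX_ROWS, MAX_COLS = 10, 10

trace_count = 0  # incremented only when jax.jit (re)traces padded_svd


@jax.jit
def padded_svd(x_padded):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on every call
    return jnp.linalg.svd(x_padded)


def svd_via_padding(x):
    """Hypothetical helper: thin SVD of an (m, n) matrix via one fixed-shape jitted SVD."""
    m, n = x.shape
    k = min(m, n)
    # Pad on the host with NumPy so the jitted function always sees the same shape.
    x_padded = np.zeros((MAX_ROWS, MAX_COLS), dtype=np.float32)
    x_padded[:m, :n] = x
    u_p, s_p, vt_p = padded_svd(x_padded)
    # Slice outside jit: the padding contributes only zero singular values,
    # which sort last, so the first k components belong to x.
    return u_p[:m, :k], s_p[:k], vt_p[:k, :n]


rng = np.random.default_rng(0)
for shape in [(3, 4), (5, 2), (7, 7)]:
    x = rng.random(shape, dtype=np.float32)
    u, s, vt = svd_via_padding(x)
    # Reconstruction is insensitive to the sign ambiguity of singular vectors.
    x_rec = (u * s) @ vt
    print(shape,
          np.allclose(x_rec, x, atol=1e-4),
          np.allclose(s, np.linalg.svd(x, compute_uv=False), atol=1e-4))

print("traces:", trace_count)
```

Padding happens in NumPy before the jitted call and slicing happens after it, so the varying sizes never reach the compiler, and `trace_count` should stay at 1 across all three shapes. The trade-off is that every call pays for an SVD of the full padded size; if input sizes vary widely, a few bucket sizes (one compilation each) may fit better than a single large bucket.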

Replies: 2 comments 1 reply

Answer selected by mar-muel
Category: Q&A
Labels: None yet
2 participants