How to jit a function with a conditional inside a for loop, and is it possible to chain functions before using jax.grad()? #19650
Hello! I'm having trouble figuring out how to jit this function, which has a conditional inside a for loop:

```python
from functools import partial

import jax.numpy as jnp
import jax.random as rand
from jax import jit
from jax.numpy.linalg import norm  # assumed source of `norm`

# Calculate first principal direction
# (`gradient` is defined in the second snippet below)
@partial(jit, static_argnums=(1, 2, 3))
def find_pd(mat, nfeat, max_iter, tol, key):
    # Split the PRNG key
    key = rand.split(key, 1)[0]
    # Initialize random vector
    cur_v = rand.uniform(key, shape=(nfeat,))
    # Find principal direction
    for _ in jnp.arange(0, max_iter):
        old_v = cur_v.copy()
        # Update current vector with gradient
        cur_v = cur_v + gradient(mat, cur_v)
        cur_v = cur_v / norm(cur_v)
        # Check if tolerance has been reached
        pred = norm(cur_v - old_v) < tol
        if pred:  # <-- ERROR POPS UP HERE
            break
    # Return principal direction
    return cur_v
```

I get this error:
I don't know how to rewrite the if-statement so that I can break out of the for-loop once the tolerance is reached. I thought the …

My second question is about this set of functions:

```python
import jax.numpy as jnp
from jax import grad, jit

# Project a set of (column) vectors onto another vector
@jit
def orth_proj(mat, vec):
    return jnp.dot(mat, vec)

# Variance of projected (column) vectors
@jit
def op_var(mat, vec):
    return orth_proj(mat, vec).var()

# Gradient of projected variance w.r.t. the input vector
gradient = jit(grad(op_var, argnums=1))
```

Is there any way to do something like:

```python
# Gradient of projected variance w.r.t. the input vector
gradient = jit(grad(jnp.var(orth_proj), argnums=1))
```

where, instead of making the intermediate function …
Question 1: if statements

For your first question: the `if` statement causes an error because in JAX, we use Python to define the computation graph, not to run the computation graph. Here, a "computation graph" is simply some graph of operations that happen on JAX arrays. (And it is this graph which is JIT-compiled for speed.)

So having an `if` statement corresponds to having two different kinds of computation graph that we might want to define. Meanwhile a `jax.lax.cond` corresponds to actually doing branching within the graph itself.

(If you've any experience with metaprogramming in other languages, then that's basically what's going on here: we use Python at the metaprogramming level; the computation graph of JAX arrays is then the program itself. If you're not familiar with metaprogramming then don't worry about this analogy though!)

You might also like points 1 and 2 in this guide for a little more on this. (And I think the core JAX docs also have something about this somewhere.)

Question 2: breaking

Now, how to break a for loop in JAX? The answer is to write things as a while loop: … In particular, note that we're using …

Question 3: avoiding the wrapper function

Unfortunately there's no way to avoid the wrapper function in general. One thing you can do is to use a little …
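To make the Question 1 distinction concrete, here is a minimal sketch that is not from the reply (the function names `python_if` and `graph_cond` are hypothetical). Under `jit`, the Python `if` sees an abstract tracer instead of a number and fails at trace time, while `jax.lax.cond` puts both branches into the graph and chooses between them at run time. The exact exception class for the Python `if` depends on the JAX version (recent releases raise `TracerBoolConversionError`, a subclass of `ConcretizationTypeError`), so the sketch catches the base class.

```python
import jax
import jax.numpy as jnp


@jax.jit
def python_if(x):
    # Python evaluates `x > 0` while tracing, when x is an abstract tracer
    # rather than a concrete number, so turning it into a bool raises.
    if x > 0:
        return x
    return -x


@jax.jit
def graph_cond(x):
    # lax.cond records both branches in the computation graph and picks one
    # at run time from the traced predicate: cond(pred, true_fn, false_fn, operand).
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


print(graph_cond(jnp.float32(-3.0)))  # 3.0

try:
    python_if(jnp.float32(-3.0))
except jax.errors.ConcretizationTypeError as err:
    print("Python `if` failed under jit:", type(err).__name__)
```

For a purely elementwise choice like this one, `jnp.where` would also work; `lax.cond` is the general tool when the branches are whole computations.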
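For Questions 2 and 3, the reply's own while-loop code did not survive, so what follows is only a sketch under stated assumptions, not the author's code. The `for`/`break` loop becomes `jax.lax.while_loop`, whose condition function checks both the iteration cap and the tolerance, so the loop stops inside the compiled graph. The named `op_var` is replaced by an inline lambda. Note that `jnp.var(orth_proj)` cannot work as written, because `jnp.var` expects an array, not a function; the lambda is still a wrapper, just an anonymous one. The loop-state layout, using `key` without splitting, and making only `nfeat` static are assumptions of this sketch.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import grad, jit


def orth_proj(mat, vec):
    # Same computation as the question's orth_proj: mat @ vec.
    return jnp.dot(mat, vec)


# Question 3 sketch: an inline lambda composes jnp.var with orth_proj
# without a separately named op_var function.
gradient = jit(grad(lambda mat, vec: jnp.var(orth_proj(mat, vec)), argnums=1))


@partial(jit, static_argnums=(1,))  # nfeat fixes an array shape, so it is static
def find_pd(mat, nfeat, max_iter, tol, key):
    # Random starting vector, as in the question (the key is used directly here).
    cur_v = jax.random.uniform(key, shape=(nfeat,))

    # Loop state: (iteration counter, current vector, size of the last step).
    init = (jnp.array(0), cur_v, jnp.array(jnp.inf, dtype=cur_v.dtype))

    def cond_fun(state):
        i, _, step = state
        # Keep iterating while under the cap AND the last step is not below tol.
        return (i < max_iter) & (step >= tol)

    def body_fun(state):
        i, old_v, _ = state
        # Gradient-ascent step on the projected variance, then renormalize.
        new_v = old_v + gradient(mat, old_v)
        new_v = new_v / jnp.linalg.norm(new_v)
        return i + 1, new_v, jnp.linalg.norm(new_v - old_v)

    _, cur_v, _ = jax.lax.while_loop(cond_fun, body_fun, init)
    return cur_v


# Usage: data whose largest spread is along the first axis.
X = jax.random.normal(jax.random.PRNGKey(0), (200, 3)) * jnp.array([3.0, 1.0, 0.5])
pd = find_pd(X, 3, 1000, 1e-6, jax.random.PRNGKey(1))
print(pd)
```

Because `max_iter` and `tol` only appear in comparisons inside `cond_fun`, they can stay traced, so changing them does not trigger recompilation; only a new `nfeat` does.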