How to jit a function with a conditional inside a for loop, and is it possible to chain functions before using jax.grad()? #19650

Question 1: if statements

For your first question: the if statement causes an error because in JAX, we use Python to define the computation graph, not to run it. Here, a "computation graph" is simply some graph of operations that happen on JAX arrays. (And it is this graph which is JIT-compiled for speed.)
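To make the "define, don't run" idea concrete, here is a small sketch (the functions `f` and `g` are hypothetical, not from the original question). `jax.make_jaxpr` shows the graph that Python records while tracing, and a Python `if` on a traced value fails under `jax.jit` because tracing has no concrete `True`/`False` to branch on. The exact error subclass varies by JAX version, so the sketch catches the base `jax.errors.ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp


def f(x):
    # While tracing, `x` is an abstract placeholder: these lines record
    # operations into a graph instead of computing numbers immediately.
    y = jnp.sin(x)
    return y * 2.0


# make_jaxpr returns the traced graph (a "jaxpr") so it can be inspected.
graph = jax.make_jaxpr(f)(1.0)
print(graph)


def g(x):
    # A Python `if` needs a concrete True/False, but under jax.jit the
    # comparison `x > 0` is itself a traced value, so JAX raises an error.
    if x > 0:
        return x
    return -x


try:
    jax.jit(g)(1.0)
    jit_if_failed = False
except jax.errors.ConcretizationTypeError:
    jit_if_failed = True

print("jit with a Python if raised an error:", jit_if_failed)
```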

So having an if statement corresponds to having two different computation graphs that we might want to define. Meanwhile, a jax.lax.cond corresponds to actually doing branching within the graph itself.
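A minimal sketch of branching inside the graph, assuming a float input (both function names are hypothetical). `jax.lax.cond(pred, true_fun, false_fun, operand)` compiles both branches and selects one at run time; the second function puts a conditional inside a loop by pairing it with `jax.lax.fori_loop`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def abs_via_cond(x):
    # lax.cond puts BOTH branches into the compiled graph and picks one
    # at run time based on the (traced) predicate.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


@jax.jit
def sum_positive_doubled(xs):
    # Hypothetical example: a conditional inside a loop. The loop runs via
    # lax.fori_loop and the branch via lax.cond, so the whole function jits.
    def body(i, acc):
        x = xs[i]
        # Both branches must return the same shape and dtype.
        picked = jax.lax.cond(
            x > 0, lambda v: 2.0 * v, lambda v: jnp.zeros_like(v), x
        )
        return acc + picked

    init = jnp.zeros((), dtype=xs.dtype)
    return jax.lax.fori_loop(0, xs.shape[0], body, init)


print(abs_via_cond(-3.0))
print(sum_positive_doubled(jnp.array([1.0, -2.0, 3.0])))
```

The loop itself is not what breaks `jit`: a plain Python `for` over a static range also traces fine (it is simply unrolled into the graph). It is the `if` on a traced value that needs `lax.cond`. For simple elementwise choices, `jnp.where(x > 0, x, -x)` is a common alternative, though unlike `lax.cond` it evaluates both sides.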

(If you've any experience with metaprogramming in other languages, then that's basically what's going on here: we use Python at the metaprogramming level; the…
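To illustrate the metaprogramming view, here is a sketch in which a Python `if` is perfectly legal under `jax.jit`, because it branches on an argument marked static via `static_argnums` rather than on a traced array. The function `scale` and its `double` flag are hypothetical. Each distinct value of the static argument produces a different graph and its own compilation.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=1)
def scale(x, double):
    # `double` is marked static, so it stays a plain Python bool while tracing.
    # This `if` runs at the "metaprogramming" level: it chooses which of two
    # graphs gets built, and each distinct value of `double` is compiled
    # separately.
    if double:
        return 2.0 * x
    return x


x = jnp.arange(3.0)
print(scale(x, True))
print(scale(x, False))
```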

Replies: 1 comment, 3 replies (@MyeloidSol, @patrick-kidger, @MyeloidSol)
Answer selected by MyeloidSol