
Which XlaOps are used for random number generation in JAX? #22708

Answered by jaro-sevcik
Ebanflo42 asked this question in Q&A

The random generators and distributions are implemented entirely in JAX; see the jax/_src/random.py file. For example, that file contains the code for jax.random.normal.
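One way to check this without reading the source is to trace the sampler and look at the resulting jaxpr. The snippet below is a minimal sketch, not part of the original answer. It relies only on the public `jax.make_jaxpr` API, and the primitive names it prints depend on the installed JAX version.

```python
import jax

# Typed PRNG key; the seed value is arbitrary.
key = jax.random.key(0)

# Trace jax.random.normal to a jaxpr: the JAX-level program, before any XLA lowering.
closed_jaxpr = jax.make_jaxpr(jax.random.normal)(key)
print(closed_jaxpr)

# Names of the top-level primitives. The sampler body is usually wrapped in a
# nested jit call, so this list may be short; the printed jaxpr shows the rest.
top_level = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(top_level)
```

The printed jaxpr should show the sampler expressed in ordinary JAX primitives rather than a dedicated XLA random-number op.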

If you are interested in seeing the XLA code for random number generation, you can dump the MLIR code resulting from JIT compilation of the particular random function. Example:

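import jax

# Typed PRNG key; the seed value is arbitrary.
key = jax.random.key(0)
# Lower the jitted sampler (without running it) and print the resulting MLIR module.
print(jax.jit(jax.random.normal).lower(key).as_text())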

This will produce something like this (note that the body of the _normal_real function indeed corresponds to the JAX code for _normal_real):

module @jit_normal attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0:…
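To answer the original question more directly, the dumped text can be scanned for the operations it contains. This is a rough sketch under two assumptions: that the installed JAX version emits the `stablehlo` and `chlo` dialects (older releases may print `mhlo` ops instead), and that a regular expression over the text is good enough for a quick overview. The exact set of ops depends on the JAX version and the configured PRNG implementation.

```python
import re
from collections import Counter

import jax

key = jax.random.key(0)
text = jax.jit(jax.random.normal).lower(key).as_text()

# Count every op mention of the form `stablehlo.<name>` or `chlo.<name>`.
# Assumption: the module is printed in these dialects; this is a text scan,
# not a real MLIR parse, so treat the counts as approximate.
op_counts = Counter(re.findall(r"\b((?:stablehlo|chlo)\.[a-z_0-9]+)", text))

for op, count in sorted(op_counts.items()):
    print(f"{op}: {count}")
```

With the default PRNG implementation the list would be expected to consist mostly of integer bit-manipulation and arithmetic ops, which is consistent with the generator being written in JAX itself.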

Answer selected by Ebanflo42