changes coming to JAX's RNG – auto-parallelizable by default #18480
froystig announced in Announcements
Partitionable Threefry RNG upgrade

You can stop reading now if:
- you already set jax_threefry_partitionable to True; or
- you use the unsafe_rbg JAX RNG algorithm.

What is happening?
JAX's default RNG algorithm ("Threefry") is changing under the hood to make random number generation efficiently auto-parallelizable ("partitionable"). This makes random number generation faster on multiple devices.
The current behavior corresponds to setting the config field jax_threefry_partitionable to False, its current default value. The new behavior corresponds to setting jax_threefry_partitionable to True, its future default value.

Soon, we will change the default value of the configuration flag jax_threefry_partitionable from False to True. This will cause a one-time change in the random values generated from a given RNG key.

Want to ensure this goes smoothly? Try flipping jax_threefry_partitionable to True today to detect any issues in your code ahead of the upgrade.

What do we mean by a one-time change to random values? The same key generates a different value under today's default (JAX_THREEFRY_PARTITIONABLE=False) than under the upcoming default (JAX_THREEFRY_PARTITIONABLE=True).

JAX's RNG will remain deterministic. We try to change the output of JAX's pseudorandom functions only rarely. That said, our API policy promises stability in distribution, not in value. This particular change is broad-based, and is the first of its kind in a long time, so we're drawing extra attention to it.
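A minimal sketch of the one-time change, assuming a JAX version that still exposes the jax_threefry_partitionable flag and the jax.threefry_partitionable context manager. It draws from the same key under each setting and checks that each setting is still deterministic on its own. The helper name sample is hypothetical, and the printed numbers depend on your JAX version, so none are reproduced here.

```python
import jax
import jax.numpy as jnp


def sample(partitionable: bool) -> jax.Array:
    """Draw three standard normals from key 1 under the given flag setting."""
    # The context manager scopes the flag to this block only.
    with jax.threefry_partitionable(partitionable):
        key = jax.random.key(1)
        return jax.random.normal(key, (3,))


old = sample(False)  # default at the time of this announcement
new = sample(True)   # behavior after the default flips

print("partitionable=False:", old)
print("partitionable=True: ", new)

# Same key, different values across the two settings...
print("values changed:", not bool(jnp.array_equal(old, new)))
# ...but each setting on its own remains deterministic.
print("deterministic:",
      bool(jnp.array_equal(old, sample(False)))
      and bool(jnp.array_equal(new, sample(True))))
```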
This change will break code that depends on specific RNG keys generating specific RNG values. Common examples include reference tests, high-variance randomized tests, or a machine learning experiment that depends on random values (e.g. model initialization) that you want to reproduce exactly.
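One way to make a randomized test robust to this kind of change, in line with the stability-in-distribution policy, is to assert on statistics of many samples rather than on exact values. A minimal sketch, assuming the jax.threefry_partitionable context manager is available; the helper name check_standard_normal, the sample size, and the tolerance of 0.05 are illustrative choices, not JAX recommendations.

```python
import jax


def check_standard_normal(partitionable: bool, n: int = 100_000,
                          tol: float = 0.05) -> bool:
    """Return True if samples look standard normal under the given flag."""
    with jax.threefry_partitionable(partitionable):
        x = jax.random.normal(jax.random.key(0), (n,))
    # Compare summary statistics, not individual values. With n = 100_000,
    # the standard error of the mean is about 0.003, well inside tol.
    return bool(abs(x.mean()) < tol) and bool(abs(x.std() - 1.0) < tol)


# An exact-value assertion would fail after the upgrade;
# this distribution-level check should pass under either setting.
for flag in (False, True):
    print(f"partitionable={flag}: {check_standard_normal(flag)}")
```

Tolerances should be chosen relative to the sample size; a check that is too tight turns into one of the high-variance randomized tests this change can break.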
Non-partitionable Threefry will later be deprecated. That is, at some point after we've upgraded the default value of the jax_threefry_partitionable setting, we will deprecate the flag entirely.

Who is affected?
JAX supports several RNG schemes. Its current three built-in modes are called threefry2x32, rbg and unsafe_rbg. These modes can be set using the impl argument to jax.random.PRNGKey and jax.random.key, or with the configuration flag jax_default_prng_impl.

You are affected if you use either threefry2x32 or rbg. Specifically:
- Random values generated from a given key will change if you use threefry2x32.
- Derived keys will change if you use the rbg mode. That is, the keys generated from a specific key using jax.random.split or jax.random.fold_in will be different from before. In turn, random values generated from such derived keys will change.

Opting in early
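Before opting in, it may help to check which of your keys are affected. A minimal sketch, assuming a JAX version that still has the jax_threefry_partitionable flag and accepts implementation names as strings in the impl argument of jax.random.key. It splits seed 0 under each setting for each built-in implementation and reports whether the derived keys changed. Based on the description of who is affected, threefry2x32 and rbg should report a change and unsafe_rbg should not. The helper name split_data is hypothetical.

```python
import jax
import jax.numpy as jnp


def split_data(impl: str, partitionable: bool) -> jax.Array:
    """Raw uint32 data of two keys split from seed 0 with the given impl."""
    with jax.threefry_partitionable(partitionable):
        key = jax.random.key(0, impl=impl)
        # key_data exposes the underlying integers so keys can be compared.
        return jax.random.key_data(jax.random.split(key, 2))


for impl in ("threefry2x32", "rbg", "unsafe_rbg"):
    changed = not bool(jnp.array_equal(split_data(impl, False),
                                       split_data(impl, True)))
    print(f"{impl}: derived keys changed = {changed}")
```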
To try the upgrade now, you can explicitly set the configuration flag jax_threefry_partitionable to True in your code. This can be done with the environment variable JAX_THREEFRY_PARTITIONABLE=True, the command-line flag --jax_threefry_partitionable=True, or programmatically, using jax.config.update or the jax.threefry_partitionable context manager.
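For example, a minimal sketch of the programmatic options, assuming a JAX version that still has this flag. jax.config.update sets the flag for the rest of the process (ideally near program start, before any random numbers are generated), while the jax.threefry_partitionable context manager scopes the setting to a block. The environment-variable form is shown as a comment; my_script.py is a placeholder name.

```python
# Option 1: environment variable, set before starting Python:
#   JAX_THREEFRY_PARTITIONABLE=True python my_script.py
import jax
import jax.numpy as jnp

# Option 2: opt in globally, near the start of your program.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.key(0)
x_new = jax.random.uniform(key, (2,))

# Option 3: scope a setting to a block with the context manager.
# Here it temporarily restores the old behavior, e.g. to compare outputs.
with jax.threefry_partitionable(False):
    x_old = jax.random.uniform(key, (2,))

print("partitionable=True: ", x_new)
print("partitionable=False:", x_old)
print("values differ:", not bool(jnp.array_equal(x_new, x_old)))
```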