Unsafe cast u32 -> i32 in threefry_seed #27282

Closed
botev-openai opened this issue Mar 20, 2025 · 3 comments
Labels
question Questions for the JAX team

Comments

@botev-openai

Description

I'm not sure why the threefry_seed code currently casts the seed to int32 (unless x64 is enabled, but almost everyone uses JAX without it). This leads to an unsafe cast from uint32 to int32 here.

I've never seen anyone use negative seeds in my career so far, so I was just wondering whether there is a good reason for this, or whether it is just a rare case that doesn't affect anyone.

System info (python version, jaxlib version, accelerator, etc.)

Python 3.11.8, JAX version 0.5.1

@botev-openai botev-openai added the bug Something isn't working label Mar 20, 2025
@jakevdp
Collaborator

jakevdp commented Mar 20, 2025

Hi, I think this is working as intended. Casting from uint32 to int32 in this context is not unsafe: all the bits are preserved through the cast. Seeds can be specified as signed or unsigned 32-bit integers, and in either case there are $2^{32}$ possible unique seeds: the negative range of int32 maps to the uint32 values from $2^{31}$ to $2^{32}-1$. For example:

```python
import jax
print(jax.random.key(-1))
print(jax.random.key(2 ** 32 - 1))
```

```
Array((), dtype=key<fry>) overlaying:
[         0 4294967295]
Array((), dtype=key<fry>) overlaying:
[         0 4294967295]
```
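The bit-level equivalence between signed and unsigned seeds can also be checked without JAX. Below is a minimal sketch using only Python's standard struct module; the helper names int32_to_uint32 and uint32_to_int32 are hypothetical and not part of JAX.

```python
import struct


def int32_to_uint32(x: int) -> int:
    """Reinterpret the 4 bytes of a signed 32-bit integer as unsigned (hypothetical helper)."""
    return struct.unpack("<I", struct.pack("<i", x))[0]


def uint32_to_int32(x: int) -> int:
    """Reinterpret the 4 bytes of an unsigned 32-bit integer as signed (hypothetical helper)."""
    return struct.unpack("<i", struct.pack("<I", x))[0]


# -1 and 2**32 - 1 share the same bit pattern, 0xFFFFFFFF.
print(int32_to_uint32(-1))         # 4294967295
print(uint32_to_int32(2**32 - 1))  # -1

# The most negative int32 maps to the start of the upper uint32 range.
print(int32_to_uint32(-2**31))     # 2147483648
```

Because the round trip through either helper returns the original value, the cast is a bijection over all $2^{32}$ bit patterns, which is why no seeds are lost.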

@jakevdp jakevdp added question Questions for the JAX team and removed bug Something isn't working labels Mar 20, 2025
@botev-openai
Author

OK, the issue I had was when calling jax.random.key under jit. I now realize this comes from how JAX converts a Python int to int32 and raises when the value overflows; e.g. passing 2 ** 32 - 1 as a raw int leads to:

OverflowError: An overflow was encountered while parsing an argument to a jitted computation, whose argument path is b.

The same thing happened when passing large seeds. I guess the fix is to cast them to int32 manually.

@jakevdp
Collaborator

jakevdp commented Mar 20, 2025

Yes, Python ints are cast to int32 by default, and we raise an OverflowError if they're out of range. At one point we had value-dependent default dtype semantics (i.e. integers in the range 2 ** 31 <= i < 2 ** 32 would result in uint32 rather than int32), but this was removed because it caused subtle bugs and unexpected recompilations.

You could pass np.uint32(2 ** 32 - 1) if you want to use this actual value. For the sake of jax.random.key, the int32 value with the same bits (here np.int32(-1)) will result in equivalent behavior despite the value being wrapped to negative, but this is not the case generally.
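One way to do the manual cast for plain Python ints before they reach a jitted function is sketched below. This is a minimal sketch assuming every seed fits in 32 bits; wrap_seed_to_int32 is a hypothetical helper, not a JAX API. As noted in the thread, the wrapped value only gives an equivalent key for jax.random.key and is not a safe substitute in general.

```python
INT32_MIN = -(2**31)
UINT32_MAX = 2**32 - 1


def wrap_seed_to_int32(seed: int) -> int:
    """Map a signed or unsigned 32-bit seed to the int32 value with the same bits.

    Hypothetical helper: values in [2**31, 2**32 - 1] wrap to negative numbers,
    matching a uint32 -> int32 bit cast; values already in int32 range are unchanged.
    """
    if not INT32_MIN <= seed <= UINT32_MAX:
        # Wider seeds carry information in their upper bits, which this
        # helper would silently drop, so reject them instead.
        raise ValueError(f"seed {seed} does not fit in 32 bits")
    return (seed + 2**31) % 2**32 - 2**31


print(wrap_seed_to_int32(2**32 - 1))  # -1
print(wrap_seed_to_int32(2**31))      # -2147483648
print(wrap_seed_to_int32(12345))      # 12345
```

The wrapped result always lies in the int32 range, so a jitted function that calls jax.random.key on it should no longer hit the OverflowError when the argument is converted to the default int32.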
