Unsafe cast u32 -> i32 in threefry_seed #27282
Comments
Hi - I think this is working as intended. Casting from …

```python
import jax

print(jax.random.key(-1))
print(jax.random.key(2 ** 32 - 1))
```
Ok, the issue I had was when you call jax.random.key under …
And this was also happening when passing large seeds. I guess the fix is to manually cast the seed to int32.
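One way to do that manual cast, sketched in plain Python under the assumption that only the low 32 bits of the seed need to be kept: mask the Python int to 32 bits and map it into the signed int32 range (two's complement) before handing it to `jax.random.key`. The helper name `fold_seed_to_int32` is hypothetical, not a JAX function.

```python
def fold_seed_to_int32(seed: int) -> int:
    """Hypothetical helper: keep the low 32 bits of `seed` as a signed int32 value.

    Bits above bit 31 are discarded, so e.g. 2**32 + 5 and 5 fold to the same value.
    """
    low = seed & 0xFFFFFFFF  # low 32 bits in [0, 2**32); also works for negative ints
    # Two's complement: patterns with the top bit set represent negative int32 values.
    return low - (1 << 32) if low & 0x80000000 else low


print(fold_seed_to_int32(2**32 - 1))  # -1
print(fold_seed_to_int32(2**40 + 7))  # 7
print(fold_seed_to_int32(-5))         # -5
```

The folded value could then be passed as `jax.random.key(fold_seed_to_int32(seed))`. The trade-off is that large seeds differing only above bit 31 collide, which is the same information loss the implicit 32-bit conversion would cause, just made explicit.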
Yes - Python ints are cast to … You could pass …
Description
I'm not sure why the threefry_seed code currently casts the seed to int32 (unless x64 is enabled, but almost everyone runs without it). This leads to an unsafe cast from uint32 -> int32 here. I've never seen anyone use negative seeds, so I was just wondering: is there a good reason for this, or is it just a rare edge case that doesn't affect anyone?
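As background for the question, a sketch in plain Python (standard `struct` module only) showing that reinterpreting a 32-bit pattern between uint32 and int32 is one-to-one under two's complement, so a wrapped seed keeps all 32 bits of information. It assumes the conversion wraps modulo 2**32 (as NumPy's integer `astype` does) rather than saturating; the helpers `u32_to_i32` and `i32_to_u32` are hypothetical illustrations, not JAX functions.

```python
import struct


def u32_to_i32(x: int) -> int:
    """Hypothetical helper: reinterpret the 32 bits of a uint32 value as int32."""
    if not 0 <= x < 2**32:
        raise ValueError("x must be in the uint32 range [0, 2**32)")
    # Pack as little-endian unsigned 32-bit, then unpack the same 4 bytes as signed.
    return struct.unpack("<i", struct.pack("<I", x))[0]


def i32_to_u32(x: int) -> int:
    """Hypothetical inverse: reinterpret the 32 bits of an int32 value as uint32."""
    if not -2**31 <= x < 2**31:
        raise ValueError("x must be in the int32 range [-2**31, 2**31)")
    return struct.unpack("<I", struct.pack("<i", x))[0]


for seed in (0, 42, 2**31 - 1, 2**31, 2**32 - 1):
    signed = u32_to_i32(seed)
    # Seeds >= 2**31 come out negative, and the round trip restores the original.
    print(seed, "->", signed, "->", i32_to_u32(signed))
```

Seeds at or above 2**31 become negative (for example, 2**32 - 1 maps to -1), but no two uint32 seeds map to the same int32, which is likely why two such seeds can produce identical keys without any information being lost.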
System info (python version, jaxlib version, accelerator, etc.)
Python 3.11.8, JAX version 0.5.1