I noticed that the only `XlaOp` that seems to take a random state is `XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, const Shape& shape);`, but for other ops like `RngUniform` and `RngNormal` it is not clear to me where the seed could be set. I thought it would be initialized either in the builder or when the PjRt executable is invoked, but, at least from my search through the code, there doesn't seem to be a method for that. What is the expected way for a higher-level API to set an RNG seed manually?
As I understand it, ops like `RngUniform` are deprecated and are not used by JAX; JAX uses the ops that require the random state (seed) to be generated and passed explicitly.
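To make the explicit-seed pattern concrete, here is a minimal, self-contained C++ sketch of the contract `RngBitGenerator` follows: the caller owns the state, passes it in, and gets back both random bits and an updated state to use for the next draw. It does not call XLA. `RngState`, `SplitMix64Next`, `BitGenerator` and `DemoExplicitState` are hypothetical names, and SplitMix64 is only a stand-in generator, not one of XLA's algorithms.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for the RNG state. In XLA the state is an XlaOp
// passed to RngBitGenerator; here it is just a plain 64-bit word.
struct RngState {
  std::uint64_t word;
};

// One SplitMix64 step: advances `x` and returns 64 pseudo-random bits.
// SplitMix64 is NOT an XLA algorithm; it only illustrates the pattern.
inline std::uint64_t SplitMix64Next(std::uint64_t& x) {
  x += 0x9E3779B97F4A7C15ULL;
  std::uint64_t z = x;
  z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
  z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
  return z ^ (z >> 31);
}

// Pure function with the same shape of contract as RngBitGenerator:
// (state in) -> (updated state, random bits). No hidden or global seed.
inline std::pair<RngState, std::vector<std::uint32_t>> BitGenerator(
    RngState state, std::size_t n) {
  std::vector<std::uint32_t> bits;
  bits.reserve(n);
  std::uint64_t x = state.word;
  for (std::size_t i = 0; i < n; ++i) {
    // Keep the high 32 bits of each 64-bit output.
    bits.push_back(static_cast<std::uint32_t>(SplitMix64Next(x) >> 32));
  }
  return {RngState{x}, bits};
}

// Usage: the caller picks the seed, replaying it reproduces the bits, and
// threading the returned state forward yields a fresh batch.
inline void DemoExplicitState() {
  const RngState seed{42};
  auto [state1, first] = BitGenerator(seed, 4);
  auto [state2, second] = BitGenerator(state1, 4);
  auto [state1_again, replay] = BitGenerator(seed, 4);
  assert(first == replay);
  assert(state1.word == state1_again.word);
  assert(first != second);
  (void)state2;
}
```

Under this model, "setting the seed" from a higher-level API just means choosing the initial state value; reproducibility comes from replaying it, and new randomness comes from passing the returned state into the next call.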
Thanks for the info! Looking at the headers in `xla_extension`, the only op that takes a seed appears to be `RngBitGenerator`, but I still cannot find where `RandomAlgorithm` (class, enum, or typedef) is declared. I guess this argument determines which kind of distribution is sampled, but without knowing how to instantiate it, it is hard to understand how this op interacts with the rest of the computation.
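One way to read the algorithm argument, sketched below under stated assumptions: it selects the *bit-generation* algorithm, and `RngBitGenerator` returns uniformly distributed raw bits plus an updated state; turning those bits into a particular distribution is a separate step. In XLA itself, `RandomAlgorithm` is a protobuf enum declared in `xla_data.proto` (values such as `RNG_DEFAULT`, `RNG_THREE_FRY`, `RNG_PHILOX`). The C++ below does not use XLA: `BitAlgorithm`, `BitsAndState`, `GenerateBits`, `BitsToUnitFloat` and `SampleUniform` are hypothetical, and the two generators (SplitMix64 and a 64-bit LCG) are stand-ins for ThreeFry and Philox, not implementations of them.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical enum: like the algorithm argument of RngBitGenerator, it picks
// HOW raw bits are produced, not WHICH distribution is sampled.
enum class BitAlgorithm { kSplitMix64, kLcg64 };

struct BitsAndState {
  std::uint64_t new_state;          // state to pass to the next call
  std::vector<std::uint32_t> bits;  // uniformly distributed raw bits
};

// Stand-in bit generators (neither is ThreeFry nor Philox).
inline BitsAndState GenerateBits(BitAlgorithm algorithm, std::uint64_t state,
                                 std::size_t n) {
  BitsAndState out{state, {}};
  out.bits.reserve(n);
  for (std::size_t i = 0; i < n; ++i) {
    std::uint64_t z = 0;
    switch (algorithm) {
      case BitAlgorithm::kSplitMix64:
        out.new_state += 0x9E3779B97F4A7C15ULL;
        z = out.new_state;
        z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
        z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
        z ^= z >> 31;
        break;
      case BitAlgorithm::kLcg64:
        // Knuth's MMIX linear congruential generator constants.
        out.new_state = out.new_state * 6364136223846793005ULL +
                        1442695040888963407ULL;
        z = out.new_state;
        break;
    }
    out.bits.push_back(static_cast<std::uint32_t>(z >> 32));  // high 32 bits
  }
  return out;
}

// Distribution step, independent of the bit algorithm: put the top 23 bits
// into the mantissa of 1.0f (a float in [1, 2)), then subtract 1.
inline float BitsToUnitFloat(std::uint32_t bits) {
  const std::uint32_t pattern = 0x3F800000u | (bits >> 9);  // 0x3F800000 is 1.0f
  float f;
  std::memcpy(&f, &pattern, sizeof f);
  return f - 1.0f;  // exact subtraction; result lies in [0, 1)
}

// Uniform floats in [0, 1) from any bit algorithm and an explicit seed.
inline std::vector<float> SampleUniform(BitAlgorithm algorithm,
                                        std::uint64_t seed, std::size_t n) {
  const BitsAndState raw = GenerateBits(algorithm, seed, n);
  std::vector<float> out;
  out.reserve(n);
  for (std::uint32_t b : raw.bits) {
    out.push_back(BitsToUnitFloat(b));
  }
  return out;
}
```

A normal sampler would sit at the same layer as `BitsToUnitFloat`, built from uniform values (for example with the Box-Muller transform); in this model, switching the bit algorithm changes the random stream, but never the distribution being sampled.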