I'm running a decently large project, and I've just attempted to run it on TPUs (v4-32 node), and I'm getting an interesting error, which seems to be pretty internal:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:2519) instructions.size() == 2 channel 30 is used for multiple host send/recv instructions
When enabling internal logs, I get (on one of the machines):
RuntimeError: Unable to initialize backend 'cpu': ALREADY_EXISTS: Config key cpu:local_topology/cpu/3 already exists.
Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/InsertKeyValue:
:{"created":"@1731217318.560649460","description":"Error received from peer ipv4:some ip:8476","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Config key cpu:local_topology/cpu/3 already exists.","grpc_status":6} (set JAX_PLATFORMS='' to automatically choose an available backend)
The main issue with this bug is that I cannot provide a script to reproduce it: the codebase is very large, and my attempts at reproducing with a small example just don't lead to the error. And, I cannot even give you a link to the repo, because it's confidential.
Here is how I do the partitioning (I think these days you should use jax.device_put, but I'm still using the old API because I'm used to it):
Equinox's partition/combine just split a PyTree into two / combine two PyTrees into one, such that one of the resulting PyTrees contains all the leaves satisfying a given condition. This is convenient when, for example, passing a partially static PyTree into a vmap. I am 0.8 confident that the issue is not with Equinox.
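For context, here is a minimal pure-Python sketch of the partition/combine semantics described above, using flat dicts in place of real PyTrees. This is only an illustration of the idea, not Equinox's actual implementation (in the real library the calls are eqx.partition(model, eqx.is_array) and eqx.combine):

```python
# Illustrative sketch of Equinox-style partition/combine semantics.
# Flat dicts stand in for PyTrees; this is NOT Equinox's real code.

def partition(tree, filter_fn):
    """Split one tree into two: leaves passing filter_fn, and the rest.
    Non-selected positions are filled with None as placeholders, so both
    outputs keep the original tree structure."""
    selected = {k: (v if filter_fn(v) else None) for k, v in tree.items()}
    rest = {k: (None if filter_fn(v) else v) for k, v in tree.items()}
    return selected, rest

def combine(a, b):
    """Merge two trees produced by partition back into one."""
    return {k: (a[k] if a[k] is not None else b[k]) for k in a}

tree = {"weights": [1.0, 2.0], "activation": "relu", "bias": [0.5]}
arrays, static = partition(tree, lambda v: isinstance(v, list))
# `arrays` holds the list-valued leaves; `static` holds everything else.
assert combine(arrays, static) == tree
```

The point of the split is that the static half can be closed over (treated as a compile-time constant) while only the array half flows through jax.vmap or jax.jit.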
I would appreciate any help :)
System info (python version, jaxlib version, accelerator, etc.)
Update: downgrading libtpu does not help. However, if I set export JAX_PLATFORMS='' I get a different internal error:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /home/arst/.local/lib/python3.11/site-packages/equinox/_jit.py:55:14: error: All components of the offset index in a gather op must either be a offset dimension or explicitly collapsed or explicitly batched; got len(slice_sizes)=4, output_slice_sizes=2, collapsed_slice_dims=1,2, operand_batching_dims=.:
...
@ 0x7fb4ce901fe4 (unknown)
@ 0x7fb67c825f8d xla::InitializeArgsAndCompile()
@ 0x7fb67c8266f6 xla::PjRtCApiClient::Compile()
@ 0x7fb68255566c xla::ifrt::PjRtLoadedExecutable::Create()
@ 0x7fb682550a51 xla::ifrt::PjRtCompiler::Compile()
@ 0x7fb681ce452e xla::PyClient::CompileIfrtProgram()
@ 0x7fb681ce532e xla::PyClient::Compile()
...
error: 'mhlo.while' op can't be translated to XLA HLO
....
And an insanely long traceback, which I'm not going to attach so as not to dilute the point.