TPU with sharding: grpc initialization failure #24821

Closed
knyazer opened this issue Nov 10, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@knyazer

knyazer commented Nov 10, 2024

Description

I'm running a fairly large project, and when I tried running it on TPUs (a v4-32 node) I hit an interesting error that seems to come from deep inside the stack:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:2519) instructions.size() == 2 channel 30 is used for multiple host send/recv instructions

When enabling internal logs, I get (on one of the machines):

RuntimeError: Unable to initialize backend 'cpu': ALREADY_EXISTS: Config key cpu:local_topology/cpu/3 already exists.
Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/InsertKeyValue:
:{"created":"@1731217318.560649460","description":"Error received from peer ipv4:some ip:8476","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Config key cpu:local_topology/cpu/3 already exists.","grpc_status":6} (set JAX_PLATFORMS='' to automatically choose an available backend)

The main difficulty with this bug is that I cannot provide a reproduction script: the codebase is very large, my attempts to reproduce the error with a small example just don't trigger it, and I cannot even link to the repo because it's confidential.

Here is how I do the partitioning (I think these days you should use jax.device_put, but I'm still using the old API because I'm used to it):

import equinox as eqx
import jax

# image_batch and sharding are defined elsewhere in the codebase.
# Split the PyTree into array leaves (dynamic) and everything else (static),
# constrain the sharding of the array leaves, then reassemble.
dynamic, static = eqx.partition(image_batch, eqx.is_inexact_array)
dynamic = jax.lax.with_sharding_constraint(dynamic, sharding)
image_sharded = eqx.combine(dynamic, static)

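For reference, the jax.device_put route mentioned above would look roughly like this minimal sketch (the mesh axis name and array shapes are made up, and the batch size is assumed to divide the device count):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical one-axis mesh: shard the leading (batch) axis over all devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

image_batch = jnp.zeros((32, 224, 224, 3))  # made-up batch
# device_put places the data eagerly; with_sharding_constraint instead hints
# the compiler about layout inside a jit-compiled function.
image_sharded = jax.device_put(image_batch, sharding)
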
Equinox's partition/combine split a PyTree into two PyTrees and combine two PyTrees back into one, such that one of the trees contains all the leaves satisfying a given predicate (see the small illustration below). This is convenient when passing a partially static PyTree into a vmap, for example. I am about 80% confident that the issue is not with Equinox.
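
A small self-contained illustration of that round trip, with made-up leaves:

import equinox as eqx
import jax.numpy as jnp

# A mixed PyTree: one inexact-array leaf and one static (string) leaf.
tree = {"weights": jnp.ones((2, 2)), "name": "layer0"}

dynamic, static = eqx.partition(tree, eqx.is_inexact_array)
# dynamic == {"weights": Array(...), "name": None}
# static  == {"weights": None,       "name": "layer0"}

restored = eqx.combine(dynamic, static)  # same structure and leaves as `tree`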

I would appreciate any help :)

System info (python version, jaxlib version, accelerator, etc.)

# python3.11 -m pip freeze
jax==0.4.35
jax-smi==1.0.4
jaxlib==0.4.35
jaxtyping==0.2.34

# python3.11 --version
3.11.10

# neofetch
OS: Ubuntu 20.04.4 LTS x86_64
Host: Google Compute Engine
Kernel: 5.13.0-1023-gcp
CPU: AMD EPYC 7B12 (240) @ 2.249GHz
knyazer added the bug label on Nov 10, 2024
knyazer changed the title from "TPU with sharding: grpc initiazliation fails" to "TPU with sharding: grpc initialization failure" on Nov 10, 2024
@knyazer
Author

knyazer commented Nov 10, 2024

Update: downgrading libtpu does not help. However, if I set export JAX_PLATFORMS='', I get a different internal error:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /home/arst/.local/lib/python3.11/site-packages/equinox/_jit.py:55:14: error: All components of the offset index in a gather op must either be a offset dimension or explicitly collapsed or explicitly batched; got len(slice_sizes)=4, output_slice_sizes=2, collapsed_slice_dims=1,2, operand_batching_dims=.: 
...
    @     0x7fb4ce901fe4  (unknown)
    @     0x7fb67c825f8d  xla::InitializeArgsAndCompile()
    @     0x7fb67c8266f6  xla::PjRtCApiClient::Compile()
    @     0x7fb68255566c  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7fb682550a51  xla::ifrt::PjRtCompiler::Compile()
    @     0x7fb681ce452e  xla::PyClient::CompileIfrtProgram()
    @     0x7fb681ce532e  xla::PyClient::Compile()

...
 error: 'mhlo.while' op can't be translated to XLA HLO
...

And an insanely long traceback, which I'm not attaching so as not to dilute the point.
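
For reference, the environment toggle mentioned above can also be set from Python; a minimal sketch (JAX reads JAX_PLATFORMS when it first initializes its backends, so it must be set before the first import of jax, or exported in the shell):

import os
os.environ["JAX_PLATFORMS"] = ""  # equivalent to `export JAX_PLATFORMS=''`

import jax
print(jax.default_backend())  # e.g. "tpu" when the TPU backend initializes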

@knyazer
Author

knyazer commented Nov 10, 2024

Okay, after some more digging it seems that the issue is indeed with Equinox; I will open a duplicate issue for Patrick.

knyazer closed this as not planned on Nov 11, 2024