
RuntimeError: Unimplemented: From /job:tpu_worker/replica:0/task:0 #1775

Closed
blizda opened this issue Mar 18, 2020 · 26 comments
Labels
stale Has not had recent activity

Comments

@blizda commented Mar 18, 2020

🐛 Bug

I am trying to train reformer-pytorch on a TPU, and after training starts I get an error.

To Reproduce

Steps to reproduce the behavior:

  1. Clone reformer-pytorch from https://github.com/blizda/reformer-pytorch
  2. Install reformer-pytorch
  3. Install transformers from pip
  4. Run simple_test.py from https://github.com/blizda/reformer-pytorch

2020-03-18 13:34:08.627671: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] StackTrace:
2020-03-18 13:34:08.627680: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** Begin stack trace ***
2020-03-18 13:34:08.627687: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] tensorflow::CurrentStackTrace[abi:cxx11]
2020-03-18 13:34:08.627695: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] xla::util::ReportComputationError(tensorflow::Status const&, absl::Span<xla::XlaComputation const* const>, absl::Span<xla::Shape const* const>)
2020-03-18 13:34:08.627703: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] xla::XrtComputationClient::CheckCompileStatus(tensorflow::Status const&, std::vector<xla::ComputationClient::CompileInstance, std::allocator<xla::ComputationClient::CompileInstance> > const&, xla::XrtComputationClient::SessionWork const&)
2020-03-18 13:34:08.627713: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627721: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627728: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627734: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627743: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627750: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] clone
2020-03-18 13:34:08.627758: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** End stack trace ***
2020-03-18 13:34:08.627769: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627778: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] Status: Unimplemented: From /job:tpu_worker/replica:0/task:0:
2020-03-18 13:34:08.627786: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %sort.18225 = (s64[60,164352]{1,0}, s32[60,164352]{1,0}) sort(s64[60,164352]{1,0} %add.18217, s32[60,164352]{1,0} %iota.18218), dimensions={1}, to_apply=%compare-less-than.18219
2020-03-18 13:34:08.627810: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] [[{{node XRTCompile}}]]
Traceback (most recent call last):
  File "simple_test.py", line 47, in <module>
    output = model(inputs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 671, in forward
    x = self.reformer(x, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 634, in forward
    x = self.layers(x, arg_route = arg_route, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reversible.py", line 151, in forward
    return _ReversibleFunction.apply(x, blocks, block_kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reversible.py", line 112, in forward
    x = block(x, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reversible.py", line 53, in forward
    y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reversible.py", line 27, in forward
    return self.net(*args, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 106, in forward
    return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 106, in <listcomp>
    return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 95, in forward
    return self.fn(x, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 534, in forward
    return self.net(x)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/functional.py", line 1593, in linear
    output = input.matmul(weight.t())
RuntimeError: Unimplemented: From /job:tpu_worker/replica:0/task:0:
While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %sort.18225 = (s64[60,164352]{1,0}, s32[60,164352]{1,0}) sort(s64[60,164352]{1,0} %add.18217, s32[60,164352]{1,0} %iota.18218), dimensions={1}, to_apply=%compare-less-than.18219
[[{{node XRTCompile}}]]

Environment

  • reproducible on XLA backend [CPU/TPU]:
    v3-8, pytorch-nightly
  • torch_xla version:
    torch_xla nightly (0.8+c77641f)
@dlibenzi (Collaborator)

Can you try this?

os.environ['XLA_USE_32BIT_LONG'] = '1'
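
(A minimal sketch of where that would go; the variable has to be set before torch_xla creates the device, and the rest of the snippet is just placeholder setup:)

import os
os.environ['XLA_USE_32BIT_LONG'] = '1'  # map 64-bit long tensors to 32-bit on the TPU

import torch
import torch_xla
import torch_xla.core.xla_model as xm

dev = xm.xla_device()  # set the flag before any tensors are moved to this device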

@blizda (Author) commented Mar 18, 2020

Can you try this?

os.environ['XLA_USE_32BIT_LONG'] = '1'

After I set this parameter, the script produces this output:

2020-03-18 14:19:04.490633: I 6407 torch_xla/csrc/tensor_util.cpp:36] Using 32bit integers for kLong values

and then gets stuck.

@dlibenzi (Collaborator)

Stuck ... wait. It is probably compiling.

@blizda (Author) commented Mar 18, 2020

Stuck ... wait. It is probably compiling.

I think it is still stuck. The script hasn't produced any output for more than half an hour.
When this flag wasn't set, the script showed output during compilation.

@dlibenzi (Collaborator)

Ohh, I remember that model: #1657

Are you trying nightly?
Just yesterday the fix for continuous compile went in.
Also, I think that user did other changes to avoid a lot of tensor->str conversions happening.

@dlibenzi (Collaborator)

This was the fix: #1770

@blizda (Author) commented Mar 18, 2020

This was the fix: #1770

I updated my conda environment on Google Cloud using update_nightly_torch_wheels.sh, but I get the same error.

@blizda (Author) commented Mar 18, 2020

torch version: torch-1.5.0a0+a3de359
torch-xla version: torch-xla 0.8+2821d27

@blizda (Author) commented Mar 18, 2020

And when I set XLA_USE_32BIT_LONG=1, it looks like the script still just hangs.

@blizda (Author) commented Mar 18, 2020

The model has been compiling for more than an hour, and I still haven't gotten any output.

@dlibenzi (Collaborator)

Can you please select "nightly" version (and "pytorch-nightly" on the TPU VM)?
If you are using Colab, select "nightly" as VERSION in the version switcher.

Also, IIRC, the user at #1657 made some changes to keep the model from issuing hundreds of tensor->str conversions.

@blizda (Author) commented Mar 18, 2020

Can you please select "nightly" version (and "pytorch-nightly" on the TPU VM)?
If you are using Colab, select "nightly" as VERSION in the version switcher.

Also, IIRC, the user at #1657 made some changes to keep the model from issuing hundreds of tensor->str conversions.

Yes, I already set pytorch-nightly on the TPU instance. The Reformer model does, in general, compile and run, but only if I set a really small vocabulary and sequence length. If I set normal parameters, the model compiles forever.

@dlibenzi (Collaborator)

I am not sure what the other user did, but can you try replacing this line:

https://github.com/blizda/reformer-pytorch/blob/f7187d887c3522124d265dd11e4bb42b2f2906c6/simple_test.py#L48

With:

print(output.cpu())

If that is still slow, maybe try posting a debug-run tarball:

https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#using-debug_runpy-to-collect-debug-information

Run it and let it go until complete, or Ctrl-C after 5 minutes or so.

@blizda (Author) commented Mar 19, 2020

Even the model with this config

import os
os.environ['XLA_USE_32BIT_LONG'] = '1'

# imports pytorch
import torch

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

import torch
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import reformer_pytorch
from dotdict import dotdict
from tqdm import tqdm

config = dotdict(
    num_tokens = 100, # vocab size
    dim = 128,  # embed dim
    depth = 3, # layers
    max_seq_len = 256, # ctx len
    heads = 4,
    lsh_dropout = 0.1,
    emb_dim = 128,        # embedding factorization for further memory savings
    causal = False,        # auto-regressive or not
    bucket_size = 4,     # average size of qk per bucket, 64 was recommended in paper
    n_hashes = 4,         # 4 is permissible per author, 8 is the best but slower
    ff_chunks = 8,      # number of chunks for feedforward layer, make higher if there are memory issues
    weight_tie = False,   # tie parameters of each layer for no memory per additional depth
    attn_chunks = 2,        # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    num_mem_kv = 8,       # persistent learned memory key values, from all-attention paper
    twin_attention = False, # both branches of the reversible network will be attention
    use_full_attn = False,  # use full self attention, for comparison
    full_attn_thres = 1, # use full attention if context length is less than set value
    use_scale_norm = True,  # use scale norm from 'Transformers without tears' paper
    axial_position_shape = (4, 64),
    axial_position_dims = (64, 64)
)
batch_size = 2
model = reformer_pytorch.ReformerLM(**config).to(dev)
opt = optim.Adam(model.parameters(), lr=3e-5)

for i in range(3):
  opt.zero_grad()
  x = torch.randint(0, 100, (1, 256)).to(dev)
  print(x)
  y = F.log_softmax(model(x), dim=-1)
  print(y)

  loss = F.nll_loss(y.reshape(y.shape[0] * y.shape[1], -1)[:-1], x.reshape(x.shape[0] * x.shape[1])[1:])
  loss.backward()
  print('LOSS ======================', loss)

  xm.optimizer_step(opt, barrier=True)

takes about 3 minutes per step.

@blizda (Author) commented Mar 19, 2020

I am not sure what the other user did, but can you try replacing this line:

https://github.com/blizda/reformer-pytorch/blob/f7187d887c3522124d265dd11e4bb42b2f2906c6/simple_test.py#L48

With:

print(output.cpu())

If that is still slow, maybe try posting a debug-run tarball:

https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#using-debug_runpy-to-collect-debug-information

Run it and let it go until complete, or Ctrl-C after 5 minutes or so.

print(output.cpu())

doesn't work.
I tried the scripts with debug_run. At the link you can find the run with the model posted above.
The regular model's debug run is at the other link; that model was compiling for about 40 minutes, after which I interrupted it.

@dlibenzi (Collaborator)

Can you try to remove print(x) and print(y) from the small model above, and run 10 steps?
To see how it changes.
When printing big tensors, pytorch issues A LOT of ops.
For 3 steps you have done, there have been 900+ local-scalar-dense (item() calls) and hundreds of compiles.
Why did print(output.cpu()) not work?
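
(As a sketch, something like this keeps the loop from your small config but only pulls one scalar back per step instead of printing full tensors; it assumes the model/opt/dev from your snippet above:)

for i in range(10):
  opt.zero_grad()
  x = torch.randint(0, 100, (1, 256)).to(dev)
  y = F.log_softmax(model(x), dim=-1)
  loss = F.nll_loss(y.reshape(y.shape[0] * y.shape[1], -1)[:-1], x.reshape(x.shape[0] * x.shape[1])[1:])
  loss.backward()
  xm.optimizer_step(opt, barrier=True)
  print('step', i, 'loss', loss.item())  # one scalar device-to-host transfer per step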

@blizda (Author) commented Mar 19, 2020

Can you try to remove print(x) and print(y) from the small model above, and run 10 steps?
To see how it changes.
When printing big tensors, pytorch issues A LOT of ops.
For 3 steps you have done, there have been 900+ local-scalar-dense (item() calls) and hundreds of compiles.
Why did print(output.cpu()) not work?

output.cpu() did not work because the model hangs earlier, before the print() is reached.

@blizda (Author) commented Mar 19, 2020

Can you try to remove print(x) and print(y) from the small model above, and run 10 steps?
To see how it changes.
When printing big tensors, pytorch issues A LOT of ops.
For 3 steps you have done, there have been 900+ local-scalar-dense (item() calls) and hundreds of compiles.
Why did print(output.cpu()) not work?

The small model runs 10 steps fine without the prints.

@dlibenzi (Collaborator)

I honestly don't know at this point.
Looking at the graph report, it seems there is quite a bit of randomness, which seems to be creating different computation graphs at every step.
You can try contacting the user at the issue I linked above, as IIRC he seemed to have gotten further.
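
(If it helps to confirm that locally, torch_xla can dump its metrics report between steps; a minimal sketch, assuming the training loop from the snippets above:)

import torch_xla.debug.metrics as met

# After each training step, print the report and watch the compile-related
# counters; if they keep growing every step, a new graph is being built
# each iteration.
print(met.metrics_report())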

@blizda (Author) commented Mar 19, 2020

I honestly don't know at this point.
Looking at the graph report, it seems there is quite a bit of randomness, which seems to be creating different computation graphs at every step.
You can try contacting the user at the issue I linked above, as IIRC he seemed to have gotten further.

Another problem is the very long compilation of the graph for a normal-size model.

@blizda (Author) commented Mar 19, 2020

Honestly, I have never waited for the end of compilation. My patience lasted an hour at most.

@dlibenzi (Collaborator)

I was able to repro this.
The issue is that the model is so big that graph trimming happens.
You can:

export TRIM_GRAPH_SIZE=1000000

But the issue is that the model is huge, both in number of ops and in memory.
Only the root tuple (the graph output) is 100GB:

ROOT %tuple.588596 = (f32[10,40960,30522]{2,1,0}) tuple(f32[10,40960,30522]{2,1,0} %reshape.588595)

Here is a test Colab I used:

https://colab.research.google.com/drive/1yAygNnajc2cskFzrL5PIkXtegdtnIgJ9

@blizda (Author) commented Mar 23, 2020

I was able to repro this.
The issue is that the model is so big that graph trimming happens.
You can:

export TRIM_GRAPH_SIZE=1000000

But the issue is that the model is huge, both in number of ops and in memory.
Only the root tuple (the graph output) is 100GB:

ROOT %tuple.588596 = (f32[10,40960,30522]{2,1,0}) tuple(f32[10,40960,30522]{2,1,0} %reshape.588595)

Here is a test Colab I used:

https://colab.research.google.com/drive/1yAygNnajc2cskFzrL5PIkXtegdtnIgJ9

It is really strange; on GPU the model graph only takes about 3GB.

@blizda (Author) commented Mar 23, 2020

OK, one more problem with the regular-size model. Training gets stuck on

xm.optimizer_step(opt, barrier=True)

Code to reproduce:

import os
os.environ['XLA_USE_32BIT_LONG'] = '1'
os.environ['TRIM_GRAPH_SIZE'] = '1000000'

# imports pytorch
import torch

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

import torch
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import reformer_pytorch
from dotdict import dotdict
from tqdm import tqdm

config = dotdict(
    num_tokens=30522,
    dim=768,
    depth=12,
    max_seq_len=40960,
    heads=12,
    lsh_dropout=0.1,
    emb_dim=256,  # embedding factorization for further memory savings
    causal=False,  # auto-regressive or not
    bucket_size=64,  # average size of qk per bucket, 64 was recommended in paper
    n_hashes=4,  # 4 is permissible per author, 8 is the best but slower
    ff_chunks=200,  # number of chunks for feedforward layer, make higher if there are memory issues
    weight_tie=False,  # tie parameters of each layer for no memory per additional depth
    attn_chunks=8,  # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    num_mem_kv=128,  # persistent learned memory key values, from all-attention paper
    twin_attention=False,  # both branches of the reversible network will be attention
    use_full_attn=False,  # use full self attention, for comparison
    full_attn_thres=128,  # use full attention if context length is less than set value
    use_scale_norm=True,  # use scale norm from 'Transformers without tears' paper
    axial_position_emb=True,
    axial_position_shape=(640, 64),
    axial_position_dims=(128, 128)
)
model = reformer_pytorch.ReformerLM(**config).to(dev)
opt = optim.Adam(model.parameters(), lr=3e-5)

for i in range(3):
  opt.zero_grad()
  x = torch.randint(0, 30522, (96, 40960)).to(dev)
  y = F.log_softmax(model(x), dim=-1)
  print("output calc")

  loss = F.nll_loss(y.reshape(y.shape[0] * y.shape[1], -1)[:-1], x.reshape(x.shape[0] * x.shape[1])[1:])
  loss.backward()
  print("loss backward")

  xm.optimizer_step(opt, barrier=True)
  print("optim step") 

@dlibenzi (Collaborator)

I think either the model makes PyTorch/GPU assumptions about how the code executes, or it is just too big.
Within the model code, tensors with shape f32[10,40960,30522] are created, each of which is 50GB in size.
So either the code somehow uses PyTorch sparse tensors, which we do not support (and which PyTorch may map to dense tensors for XLA), or the code is wrong.
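
(For scale, a back-of-the-envelope check of that 50GB figure, assuming 4 bytes per f32 element and the shape from the HLO dump above:)

elems = 10 * 40960 * 30522    # f32[10,40960,30522]
print(elems * 4 / 1e9, 'GB')  # ~50.0 GB for a single such tensor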

stale bot commented Apr 27, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the stale label Apr 27, 2020
stale bot closed this as completed May 4, 2020