RuntimeError: Unimplemented: From /job:tpu_worker/replica:0/task:0 #1775
Can you try this? `os.environ['XLA_USE_32BIT_LONG'] = '1'`
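In case it helps future readers: setting this before the first `torch_xla` import is the safe ordering, since the flag is read once on the C++ side. A minimal sketch:

```python
import os
# Set before torch_xla is (first) imported so the runtime sees the flag;
# it maps torch.long to 32-bit integers on the device.
os.environ['XLA_USE_32BIT_LONG'] = '1'

import torch
import torch_xla.core.xla_model as xm
```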
After I set this parameter, the script produces this output: `2020-03-18 14:19:04.490633: I 6407 torch_xla/csrc/tensor_util.cpp:36] Using 32bit integers for kLong values` and then gets stuck.
Stuck... wait. It is probably compiling.
I think it is still stuck. The script hasn't produced any output for more than half an hour.
Ohh, I remember that model: #1657. Are you trying nightly?
This was the fix: #1770
I updated my conda environment on Google Cloud using `update_nightly_torch_wheels.sh`, but got the same error.
torch version: torch-1.5.0a0+a3de359
And when I set `XLA_USE_32BIT_LONG=1`, it looks like the script is still just stuck. The model has been compiling for more than an hour, and I still get no output.
Can you please select the "nightly" version (and "pytorch-nightly" on the TPU VM)? Also, IIRC, the user at #1657 made some changes to stop the model from issuing hundreds of tensor->str conversions.
Yes, I already set pytorch-nightly on the TPU instance. The Reformer model in general compiles and runs, but only if I set a really small vocabulary and sequence length. If I set normal parameters, the model compiles forever.
I am not sure what the other user did, but can you try replacing this line: with `print(output.cpu())`? If that is still slow, maybe try posting a debug-run tarball: run it and let it go until complete, or Ctrl-C after 5 minutes or so.
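For what it's worth, the suggested replacement is just this (a sketch; `output` stands in for whatever tensor the original line prints):

```python
# Instead of printing the device tensor directly (which can issue many
# small device-to-host transfers), pull it to the host once:
print(output.cpu())
```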
Even the model with this config takes 3 minutes to pass one step:

```python
import os
os.environ['XLA_USE_32BIT_LONG'] = '1'

# imports pytorch
import torch
import torch.nn.functional as F
import torch.optim as optim
# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm
import numpy as np
import reformer_pytorch
from dotdict import dotdict
from tqdm import tqdm

dev = xm.xla_device()

config = dotdict(
    num_tokens = 100,   # vocab size
    dim = 128,          # embed dim
    depth = 3,          # layers
    max_seq_len = 256,  # ctx len
    heads = 4,
    lsh_dropout = 0.1,
    emb_dim = 128,          # embedding factorization for further memory savings
    causal = False,         # auto-regressive or not
    bucket_size = 4,        # average size of qk per bucket, 64 was recommended in paper
    n_hashes = 4,           # 4 is permissible per author, 8 is the best but slower
    ff_chunks = 8,          # number of chunks for feedforward layer, make higher if there are memory issues
    weight_tie = False,     # tie parameters of each layer for no memory per additional depth
    attn_chunks = 2,        # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    num_mem_kv = 8,         # persistent learned memory key values, from all-attention paper
    twin_attention = False, # both branches of the reversible network will be attention
    use_full_attn = False,  # use full self attention, for comparison
    full_attn_thres = 1,    # use full attention if context length is less than set value
    use_scale_norm = True,  # use scale norm from 'Transformers without tears' paper
    axial_position_shape = (4, 64),
    axial_position_dims = (64, 64)
)
batch_size = 2

model = reformer_pytorch.ReformerLM(**config).to(dev)
opt = optim.Adam(model.parameters(), lr=3e-5)

for i in range(3):
    opt.zero_grad()
    x = torch.randint(0, 100, (1, 256)).to(dev)
    print(x)
    y = F.log_softmax(model(x), dim=-1)
    print(y)
    loss = F.nll_loss(y.reshape(y.shape[0] * y.shape[1], -1)[:-1],
                      x.reshape(x.shape[0] * x.shape[1])[1:])
    loss.backward()
    print('LOSS ======================', loss)
    xm.optimizer_step(opt, barrier=True)
```
`print(output.cpu())` doesn't work.
Can you try to remove `print(x)` and `print(y)` from the small model above, and run 10 steps?
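A minimal sketch of that experiment, reusing `model`, `opt`, and `dev` from the snippet above (the timing code is an addition, not from the thread); only host-side scalars are printed, so no per-step tensor-to-string conversions occur:

```python
import time

for step in range(10):
    t0 = time.time()
    opt.zero_grad()
    x = torch.randint(0, 100, (1, 256)).to(dev)
    y = F.log_softmax(model(x), dim=-1)
    loss = F.nll_loss(y.reshape(y.shape[0] * y.shape[1], -1)[:-1],
                      x.reshape(x.shape[0] * x.shape[1])[1:])
    loss.backward()
    xm.optimizer_step(opt, barrier=True)  # barrier=True runs the pending graph now
    print('step %d took %.1fs' % (step, time.time() - t0))  # host-side float only
```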
`output.cpu()` did not work because the model hangs earlier, before the `print()` is reached.
I honestly don't know at this point. |
Another problem: very long compilation of the graph for a normal-size model. Honestly, I have never waited for compilation to finish; my patience lasted an hour at most.
I was able to repro this. Try `export TRIM_GRAPH_SIZE=1000000`. But the issue is that the model is huge, both in number of ops and in memory.
Here is a test Colab I used: https://colab.research.google.com/drive/1yAygNnajc2cskFzrL5PIkXtegdtnIgJ9
It is really strange; on GPU the model graph takes only about 3 GB.
Ok, one more problem, with the regular-size model: training gets stuck on `xm.optimizer_step(opt, barrier=True)`. Code for reproducing:

```python
import os
os.environ['XLA_USE_32BIT_LONG'] = '1'
os.environ['TRIM_GRAPH_SIZE'] = '1000000'

# imports pytorch
import torch
import torch.nn.functional as F
import torch.optim as optim
# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm
import numpy as np
import reformer_pytorch
from dotdict import dotdict
from tqdm import tqdm

dev = xm.xla_device()

config = dotdict(
    num_tokens=30522,
    dim=768,
    depth=12,
    max_seq_len=40960,
    heads=12,
    lsh_dropout=0.1,
    emb_dim=256,           # embedding factorization for further memory savings
    causal=False,          # auto-regressive or not
    bucket_size=64,        # average size of qk per bucket, 64 was recommended in paper
    n_hashes=4,            # 4 is permissible per author, 8 is the best but slower
    ff_chunks=200,         # number of chunks for feedforward layer, make higher if there are memory issues
    weight_tie=False,      # tie parameters of each layer for no memory per additional depth
    attn_chunks=8,         # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    num_mem_kv=128,        # persistent learned memory key values, from all-attention paper
    twin_attention=False,  # both branches of the reversible network will be attention
    use_full_attn=False,   # use full self attention, for comparison
    full_attn_thres=128,   # use full attention if context length is less than set value
    use_scale_norm=True,   # use scale norm from 'Transformers without tears' paper
    axial_position_emb=True,
    axial_position_shape=(640, 64),
    axial_position_dims=(128, 128)
)

model = reformer_pytorch.ReformerLM(**config).to(dev)
opt = optim.Adam(model.parameters(), lr=3e-5)

for i in range(3):
    opt.zero_grad()
    x = torch.randint(0, 30522, (96, 40960)).to(dev)
    y = F.log_softmax(model(x), dim=-1)
    print("output calc")
    loss = F.nll_loss(y.reshape(y.shape[0] * y.shape[1], -1)[:-1],
                      x.reshape(x.shape[0] * x.shape[1])[1:])
    loss.backward()
    print("loss backward")
    xm.optimizer_step(opt, barrier=True)
    print("optim step")
```
I think either the model makes PyTorch/GPU assumptions about how the code executes, or it is just too big. |
🐛 Bug
I am trying to train reformer-pytorch on a TPU, and after starting training I get an error.
To Reproduce
Steps to reproduce the behavior:
```
2020-03-18 13:34:08.627671: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] StackTrace:
2020-03-18 13:34:08.627680: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** Begin stack trace ***
2020-03-18 13:34:08.627687: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]  tensorflow::CurrentStackTrace[abi:cxx11]()
2020-03-18 13:34:08.627695: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]  xla::util::ReportComputationError(tensorflow::Status const&, absl::Span<xla::XlaComputation const* const>, absl::Span<xla::Shape const* const>)
2020-03-18 13:34:08.627703: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]  xla::XrtComputationClient::CheckCompileStatus(tensorflow::Status const&, std::vector<xla::ComputationClient::CompileInstance, std::allocator<xla::ComputationClient::CompileInstance> > const&, xla::XrtComputationClient::SessionWork const&)
2020-03-18 13:34:08.627713: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627721: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627728: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627734: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627743: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627750: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]  clone
2020-03-18 13:34:08.627758: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** End stack trace ***
2020-03-18 13:34:08.627769: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-03-18 13:34:08.627778: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] Status: Unimplemented: From /job:tpu_worker/replica:0/task:0:
2020-03-18 13:34:08.627786: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %sort.18225 = (s64[60,164352]{1,0}, s32[60,164352]{1,0}) sort(s64[60,164352]{1,0} %add.18217, s32[60,164352]{1,0} %iota.18218), dimensions={1}, to_apply=%compare-less-than.18219
2020-03-18 13:34:08.627810: E 5797 tensorflow/compiler/xla/xla_client/xla_util.cc:76] [[{{node XRTCompile}}]]
```
```
Traceback (most recent call last):
  File "simple_test.py", line 47, in <module>
    output = model(inputs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 671, in forward
    x = self.reformer(x, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 634, in forward
    x = self.layers(x, arg_route = arg_route, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reversible.py", line 151, in forward
    return _ReversibleFunction.apply(x, blocks, block_kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reversible.py", line 112, in forward
    x = block(x, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reversible.py", line 53, in forward
    y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reversible.py", line 27, in forward
    return self.net(*args, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 106, in forward
    return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 106, in <listcomp>
    return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 95, in forward
    return self.fn(x, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 534, in forward
    return self.net(x)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/nn/functional.py", line 1593, in linear
    output = input.matmul(weight.t())
RuntimeError: Unimplemented: From /job:tpu_worker/replica:0/task:0:
While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %sort.18225 = (s64[60,164352]{1,0}, s32[60,164352]{1,0}) sort(s64[60,164352]{1,0} %add.18217, s32[60,164352]{1,0} %iota.18218), dimensions={1}, to_apply=%compare-less-than.18219
  [[{{node XRTCompile}}]]
```
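For context on the error itself: the failing HLO is a `sort` whose operand is `s64` (64-bit integer), and the TPU backend's pass that rewrites away 64-bit element types has no rule for `sort`. A minimal sketch that should hit the same lowering (shapes copied from the error message; this illustration is an assumption, not the model's code):

```python
import os
# The workaround from this thread; comment it out to reproduce the
# "X64 element types" error instead of working around it.
os.environ['XLA_USE_32BIT_LONG'] = '1'

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
# torch.randint returns torch.int64 (s64 in HLO) by default; sorting it
# produces the (s64, s32) sort-with-iota pair shown in the stack trace.
t = torch.randint(0, 100, (60, 164352)).to(dev)
values, indices = torch.sort(t, dim=1)
print(values.cpu()[0, :5])  # force one execution, then print host-side
```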
Environment

- v3-8, pytorch-nightly
- torch_xla nightly (0.8+c77641f)