[Question] Return tensors #894

Closed
see-- opened this issue Aug 2, 2019 · 19 comments

Comments

@see--

see-- commented Aug 2, 2019

What is the best pattern for metrics that, unlike accuracy, can't be accumulated incrementally, e.g. roc_auc_score? Is it OK to just return tensors, as in this pattern?

  def test_loop_fn(model, loader, device, context):
    y_trues = torch.zeros(len(loader._loader._loader) * config.batch_size)
    y_preds = torch.zeros(len(loader._loader._loader) * config.batch_size)
    model.eval()
    for k, (X, y_true) in loader:
      output = model(X)
      y_pred = output.max(1)[1]
      y_preds[k * config.batch_size: (k + 1) * config.batch_size] = y_pred
      y_trues[k * config.batch_size: (k + 1) * config.batch_size] = y_true

    return y_trues, y_preds

I am asking because validation takes 2x as long as my training loop for the same number of samples. Possibly related: shouldn't the code be wrapped in torch.no_grad()?
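
For reference, this is the kind of wrapping I mean; a minimal sketch reusing the names from the snippet above:

  def test_loop_fn(model, loader, device, context):
    # Same loop as above, just with gradient tracking disabled for validation.
    y_trues = torch.zeros(len(loader._loader._loader) * config.batch_size)
    y_preds = torch.zeros(len(loader._loader._loader) * config.batch_size)
    model.eval()
    with torch.no_grad():
      for k, (X, y_true) in loader:
        output = model(X)
        y_pred = output.max(1)[1]
        y_preds[k * config.batch_size: (k + 1) * config.batch_size] = y_pred
        y_trues[k * config.batch_size: (k + 1) * config.batch_size] = y_true
    return y_trues, y_preds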

@see--
Author

see-- commented Aug 2, 2019

I found two problems: the tensors should live on the device, and the correct size is num_val_samples_per_device = num_val_samples // len(devices). It's still not really working: if I remove the print statements I just get zeros. I guess some kind of syncing is missing, but I didn't find it in the guide/examples. This is the modified version:

  def test_loop_fn(model, loader, device, context):
    y_trues = torch.zeros(num_val_samples_per_device, device=device)
    y_preds = torch.zeros(num_val_samples_per_device, device=device)
    model.eval()
    for k, (X, y_true) in loader:
      output = model(X)
      y_pred = output.argmax(-1)
      y_preds[k * config.batch_size: (k + 1) * config.batch_size] = y_pred
      y_trues[k * config.batch_size: (k + 1) * config.batch_size] = y_true

    print(y_trues.max().item(), y_preds.max().item())
    y_trues = y_trues.cpu()
    y_preds = y_preds.cpu()
    print(y_trues.max().item(), y_preds.max().item())
    return y_trues, y_preds

@taylanbil
Collaborator

I think the first version needs to initialize tensors on the device.

I'm trying a variant of this in the test_train_mnist.py file and not having any issues; it seems to be working. What are the issues you're seeing? It's not slow for me with the following version.

diff:

--- a/test/test_train_mnist.py
+++ b/test/test_train_mnist.py
@@ -113,24 +113,39 @@ def train_mnist():
                                                         tracker.rate()))

   def test_loop_fn(model, loader, device, context):
-    total_samples = 0
-    correct = 0
+    y_trues = torch.zeros(1152, device=device)
+    y_preds = torch.zeros(1152, device=device)
     model.eval()
-    for x, (data, target) in loader:
-      output = model(data)
-      pred = output.max(1, keepdim=True)[1]
-      correct += pred.eq(target.view_as(pred)).sum().item()
-      total_samples += data.size()[0]
-
-    print('[{}] Accuracy={:.2f}%'.format(device,
-                                         100.0 * correct / total_samples))
-    return correct / total_samples
+    #size=0
+    for k, (X, y_true) in loader:
+      #size += X.shape[0]  # 1152
+      #continue
+      output = model(X)
+      y_pred = output.argmax(-1)
+      #print(torch_xla._XLAC._xla_metrics_report())
+      y_preds[k * 128: (k + 1) * 128] = y_pred
+      y_trues[k * 128: (k + 1) * 128] = y_true
+      #print(y_preds.sum())  @ this changes every step as one would expect
+      #print('-'*80)
+      #print(torch_xla._XLAC._xla_metrics_report())
+      #import pdb; pdb.set_trace()
+
+    #print(size)
+    #return
+    #print(y_trues.max().item(), y_preds.max().item())
+    #y_trues = y_trues.cpu()
+    #y_preds = y_preds.cpu()
+    #print(y_trues.max().item(), y_preds.max().item())
+    return y_trues, y_preds

-  accuracy = 0.0
   for epoch in range(1, FLAGS.num_epochs + 1):
-    model_parallel(train_loop_fn, train_loader)
-    accuracies = model_parallel(test_loop_fn, test_loader)
-    accuracy = sum(accuracies) / len(accuracies)
+    #model_parallel(train_loop_fn, train_loader)
+    out = model_parallel(test_loop_fn, test_loader)
+    print('device 0, ytrues ypreds :')
+    print('{}'.format(out[0]))
+    print('device 1, ytrues ypreds :')
+    print('{}'.format(out[1]))
+    import pdb; pdb.set_trace()
     if FLAGS.metrics_debug:
       print(torch_xla._XLAC._xla_metrics_report())

output:

device_coordinates: 1
device_coordinates: 1

device 0, ytrues ypreds :
(tensor([8., 7., 3.,  ..., 1., 2., 6.], device='xla:1'), tensor([1., 1., 1.,  ..., 1., 1., 1.], device='xla:1'))
device 1, ytrues ypreds :
(tensor([7., 9., 6.,  ..., 8., 7., 8.], device='xla:2'), tensor([1., 1., 1.,  ..., 1., 1., 1.], device='xla:2'))
> /usr/share/torch-xla-nightly/pytorch/xla/test/tttt.py(149)train_mnist()
-> if FLAGS.metrics_debug:
(Pdb) c
device 0, ytrues ypreds :
(tensor([8., 1., 2.,  ..., 6., 6., 5.], device='xla:1'), tensor([1., 1., 1.,  ..., 1., 1., 1.], device='xla:1'))
device 1, ytrues ypreds :
(tensor([9., 9., 4.,  ..., 2., 2., 4.], device='xla:2'), tensor([1., 1., 1.,  ..., 1., 1., 1.], device='xla:2'))
> /usr/share/torch-xla-nightly/pytorch/xla/test/tttt.py(148)train_mnist()
-> import pdb; pdb.set_trace()
(Pdb) q

So here, I commented out training, and my predictions are mostly 1 but not always.

Btw, if you want to compute ROC (assuming you have a binary classification task), you probably want the scores out instead of the argmax.
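
Something like this, as a minimal sketch, assuming the model outputs two-class logits and column 1 is the positive class:

  def test_loop_fn(model, loader, device, context):
    y_trues = torch.zeros(num_val_samples_per_device, device=device)
    y_scores = torch.zeros(num_val_samples_per_device, device=device)
    model.eval()
    for k, (X, y_true) in loader:
      output = model(X)
      # keep the positive-class probability rather than the argmax label
      score = output.softmax(-1)[:, 1]
      y_scores[k * config.batch_size: (k + 1) * config.batch_size] = score
      y_trues[k * config.batch_size: (k + 1) * config.batch_size] = y_true
    return y_trues, y_scores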

@see--
Author

see-- commented Aug 2, 2019

OK, thanks a lot for helping me here. Accuracy or ROC is just an example, but I think the pattern is quite common. I am using the gcr.io/tpu-pytorch/xla:nightly container.

1.) I get Segmentation fault (core dumped) if I use these changes:
diff /pytorch/xla/test/test_train_mnist.py /pytorch/xla/test/tpu_metrics.py

<       if x % FLAGS.log_steps == 0:
<         print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
<                                                         tracker.rate()))
116,117c113,114
<     total_samples = 0
<     correct = 0
---
>     y_trues = torch.zeros(1152, device=device)
>     y_preds = torch.zeros(1152, device=device)
119,127c116,121
<     for x, (data, target) in loader:
<       output = model(data)
<       pred = output.max(1, keepdim=True)[1]
<       correct += pred.eq(target.view_as(pred)).sum().item()
<       total_samples += data.size()[0]
< 
<     print('[{}] Accuracy={:.2f}%'.format(device,
<                                          100.0 * correct / total_samples))
<     return correct / total_samples
---
>     for k, (X, y_true) in loader:
>       output = model(X)
>       y_pred = output.argmax(-1)
>       y_preds[k * 128: (k + 1) * 128] = y_pred
>       y_trues[k * 128: (k + 1) * 128] = y_true
>     return y_trues, y_preds
132,133c126,129
<     accuracies = model_parallel(test_loop_fn, test_loader)
<     accuracy = sum(accuracies) / len(accuracies)
---
>     ret = model_parallel(test_loop_fn, test_loader)
>     y_trues = torch.cat([r[0] for r in ret]).cpu().numpy()
>     y_preds = torch.cat([r[1] for r in ret]).cpu().numpy()
>     accuracy = (y_trues == y_preds).mean()

output:

2019-08-02 20:29:39.633587: E tensorflow/compiler/xla/xla_client/tf_logging.cc:11] Check failed: device == xrt_data.device() (TPU:0 vs. TPU:7)
*** Begin stack trace ***
	tensorflow::CurrentStackTrace[abi:cxx11]()
	xla::XrtComputationClient::GetArgumentsInputs(absl::Span<std::shared_ptr<xla::ComputationClient::Data> const>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
	xla::XrtComputationClient::CreateExecuteOps(std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, xla::XrtSessionCache::Ref, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, xla::XrtSessionCache::Ref> > >*, xla::XrtComputationClient::XrtComputation const&, std::vector<std::vector<std::shared_ptr<xla::ComputationClient::Data>, std::allocator<std::shared_ptr<xla::ComputationClient::Data> > >, std::allocator<std::vector<std::shared_ptr<xla::ComputationClient::Data>, std::allocator<std::shared_ptr<xla::ComputationClient::Data> > > > > const&, bool, absl::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, std::unordered_map<tensorflow::Output, tensorflow::Input::Initializer, tensorflow::OutputHash, std::equal_to<tensorflow::Output>, std::allocator<std::pair<tensorflow::Output const, tensorflow::Input::Initializer> > >*)
	xla::XrtComputationClient::ExecuteComputation(xla::ComputationClient::Computation const&, absl::Span<std::shared_ptr<xla::ComputationClient::Data> const>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, xla::ComputationClient::ExecuteComputationOptions const&)
	
	
	
	
	
	clone
*** End stack trace ***

Segmentation fault (core dumped)

2.) This works: I move the .cpu() inside the test_loop_fn. Really weird... the .max() calls are needed.

diff /pytorch/xla/test/tpu_metrics.py /pytorch/xla/test/tpu_metrics_v2.py

120a121,124
> 
>     print(y_trues.max().item(), y_preds.max().item())
>     y_trues = y_trues.cpu()
>     y_preds = y_preds.cpu()
129a134
>     print("ep-%03d: %.4f" % (epoch, accuracy))

output:

9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
ep-001: 0.9627
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
9.0 9.0
ep-002: 0.9798

3.) If I remove print(y_trues.max().item(), y_preds.max().item()), the accuracy is 1.0 because y_trues and y_preds are all just zeros.
diff /pytorch/xla/test/tpu_metrics_v2.py /pytorch/xla/test/tpu_metrics_v3.py

122c122
<     print(y_trues.max().item(), y_preds.max().item())
---
>     # print(y_trues.max().item(), y_preds.max().item())
133a134
>     print(y_trues.max(), y_preds.max())

output:

0.0 0.0
ep-001: 1.0000
0.0 0.0
ep-002: 1.0000
0.0 0.0
ep-003: 1.0000
0.0 0.0
ep-004: 1.0000

4.) Regarding the speed: the validation now takes as long as the training, which is still surprising. Shouldn't skipping the backward pass be much faster? I can try to create an example with fake data; I am seeing this with real data.

5.) Unrelated: I often get:

packet_write_wait: Connection to 34.66.216.166 port 22: Broken pipe
ERROR: (gcloud.beta.compute.ssh) [/usr/bin/ssh] exited with return code [255].

And I am disconnected from my instance. Any help is appreciated.

@taylanbil
Collaborator

5/ you can use screen or tmux and reattach to your working session when disconnected.

When was the last time you docker pulled? Validation steps were recently modified; you'll see different behavior if you last pulled more than a week or so ago.

@see--
Author

see-- commented Aug 2, 2019

I ran docker pull gcr.io/tpu-pytorch/xla:nightly ~5 minutes ago on a new instance.

@taylanbil
Collaborator

So I think in 1/ you are concatenating tensors that live on different devices, which is not possible. Moving .cpu() inside takes care of this issue. Alternatively, you can move the tensors to CPU outside test_loop_fn first, then concat.

@taylanbil
Collaborator

I'm able to repro.

4/ my validation speed is not that slow.

begin train 2019-08-02 21:32:47.475127
end train, begin val 2019-08-02 21:32:52.244583
end val 2019-08-02 21:32:54.040544
begin train 2019-08-02 21:32:54.041433
end train, begin val 2019-08-02 21:32:57.109598
end val 2019-08-02 21:32:57.767177
begin train 2019-08-02 21:32:57.767633
end train, begin val 2019-08-02 21:33:00.966285
end val 2019-08-02 21:33:01.585395
begin train 2019-08-02 21:33:01.585923
end train, begin val 2019-08-02 21:33:04.680764
end val 2019-08-02 21:33:05.271830

@taylanbil
Collaborator

I think you are exposing a behavior here where we don't update the data of the CPU counterpart of a TPU tensor unless some .item() call happens.

To compute accuracies we don't need this, since you can compute per-device accuracies and average them appropriately to get the total accuracy. However, to compute roc_auc or a similarly complicated metric, we would need to go back to CPU without loss of data.
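
For the accuracy case, a minimal sketch of what I mean (reduce to counters per device, combine outside):

  def test_loop_fn(model, loader, device, context):
    correct, total = 0, 0
    model.eval()
    for k, (X, y_true) in loader:
      y_pred = model(X).argmax(-1)
      correct += (y_pred == y_true).sum().item()
      total += X.size(0)
    return correct, total

  # outside: combine the per-device counters into the total accuracy
  # counts = model_parallel(test_loop_fn, test_loader)
  # accuracy = sum(c for c, t in counts) / sum(t for c, t in counts)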

@taylanbil
Collaborator

Yeah, I initialized y_preds with torch.ones instead of zeros, commented out the .max(), and accuracy is now 0%.

@taylanbil
Collaborator

@dlibenzi @asuhan @ailzhang do we have an .item() analogue that brings the tensor back to CPU? .values() seems to not be implemented; I'm getting a runtime error.

@see-- I guess for now you can use the .max().item() workaround. And I'm not sure why you're experiencing slowness in validation.

@see--
Author

see-- commented Aug 2, 2019

Hm, I think I might be bound by data loading. To be clear, it's not related to these MNIST examples; they can't be used to benchmark. I will investigate this more.

@dlibenzi
Collaborator

dlibenzi commented Aug 5, 2019

Why do you need an item() call equivalent?
Just do operations in the PyTorch tensor domain without dropping into the Python scalar one.
Also, is this nightly or the original 0.1 version?

@taylanbil
Collaborator

Some metrics such as the ROC curve, precision-recall curve, AUC, etc. require all the predictions in hand before starting to compute (you sort w.r.t. prediction score and then count, etc.), as opposed to something like accuracy, where you can reduce to counters first and then combine.

I was running on nightly.

@dlibenzi
Collaborator

dlibenzi commented Aug 5, 2019

What is the issue with storing all the tensors in a list?
If you are going to call item() on them, it means they are already scalars, so they likely take little space.
How many of those values will you have? 100s? 1000s? 10000s?

@taylanbil
Collaborator

  • Validation set size is usually in the tens of thousands.
  • Storing all the tensors in a list causes the following:
    • we have a list of 1-dim tensors on different devices;
    • we cannot concat them because they are on different devices; it immediately errors;
    • we need to sort their values in order to compute/plot the ROC curve;
    • when we send those tensors to CPU first, the CPU counterparts aren't updated unless we call .max().item() on them.

So we don't normally call item() on those tensors; it's just a weird workaround that @see-- found.

@dlibenzi
Collaborator

dlibenzi commented Aug 5, 2019

The training loop is per-device, so you can have a per-device list.
Then, at the end, get them to CPU. But one at a time will be super slow.
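
Something like this, as a sketch (one concatenation and one device-to-CPU transfer per device instead of many small copies):

  def test_loop_fn(model, loader, device, context):
    trues, preds = [], []
    model.eval()
    for k, (X, y_true) in loader:
      preds.append(model(X).argmax(-1))
      trues.append(y_true)
    # one cat per device, then a single transfer to the host
    return torch.cat(trues).cpu(), torch.cat(preds).cpu()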

@taylanbil
Collaborator

taylanbil commented Aug 5, 2019

Hmm, @see-- did you try this? First send to CPU, then concat, all outside the test_loop_fn:

>     y_trues = torch.cat([r[0].cpu() for r in ret]).numpy()
>     y_preds = torch.cat([r[1].cpu() for r in ret]).numpy()
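
Then the metric itself can be computed on the host from the concatenated arrays; a minimal sketch, assuming r[1] holds prediction scores (not argmax labels) so that roc_auc_score applies:

  from sklearn.metrics import roc_auc_score

  ret = model_parallel(test_loop_fn, test_loader)
  y_trues = torch.cat([r[0].cpu() for r in ret]).numpy()
  y_scores = torch.cat([r[1].cpu() for r in ret]).numpy()
  print('roc_auc: {:.4f}'.format(roc_auc_score(y_trues, y_scores)))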

@Borda

Borda commented Jul 21, 2020

It seems I have the same problem: https://github.com/PyTorchLightning/pytorch-lightning/pull/2632/checks?check_run_id=891921103

File "/content/pytorch-lightning/tests/base/model_valid_steps.py", line 25, in validation_step
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
RuntimeError: tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:991 : Check failed: device == xrt_data.device() (TPU:0 vs. CPU:0)

but when I print their devices

>>> model.device
xla:0
>>> y.device
xla:1
>>> labels_hat.device
xla:1

@taylanbil
Collaborator

@Borda could you open a new issue with (small) repro instructions? There's a lot that's changed since this issue was reported/closed and it's likely to be a different issue.
