Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid unnecessary copy in TensorSource #8849

Merged
merged 9 commits into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ function run_xla_op_tests1 {
run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py"
run_pt_xla_debug_level1 "$CDIR/debug_tool/test_pt_xla_debug.py"
run_test "$CDIR/test_async_closures.py"
run_test "$CDIR/test_data_transfer.py"
run_test "$CDIR/test_hlo_metadata.py"
# TODO(https://github.com/pytorch/xla/issues/8796): Re-enable this test
# run_test "$CDIR/test_profiler.py"
Expand Down
40 changes: 40 additions & 0 deletions test/test_data_transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import glob
import os
from absl.testing import absltest

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr


class TestDataTransfer(absltest.TestCase):

def setUp(self):
met.clear_all()

def test_h2d_tensor_no_copy(self):
t = torch.zeros(10, 10)
t = t.to('xla')
self.assertNotIn('AtenSourceTensorCopy', met.counter_names())

def test_h2d_tensor_copy(self):
# Non-contiguous tensor will trigger a copy.
t = torch.zeros(10, 10).transpose(0, 1)
t = t.to('xla')
self.assertIn('AtenSourceTensorCopy', met.counter_names())
self.assertEqual(met.counter_value('AtenSourceTensorCopy'), 1)

@absltest.skipUnless(xr.device_type() == 'CUDA',
"This test only runs on CUDA.")
def test_h2d_tensor_cuda(self):
# If a torch tensor is on cuda, now it will be copied to CPU
# before sending to GPU via PJRT.
t = torch.zeros(10, 10).to('cuda')
t = t.to('xla')
self.assertIn('AtenSourceTensorCopy', met.counter_names())
self.assertEqual(met.counter_value('AtenSourceTensorCopy'), 1)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xl
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"
run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py"
python3 "$TEST_CDIR/test_data_transfer.py"
python3 "$TEST_CDIR/test_data_type.py"
python3 "$TEST_CDIR/test_compilation_cache_utils.py"

Expand Down
26 changes: 19 additions & 7 deletions torch_xla/csrc/runtime/tensor_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,25 @@ class AtenSource : public TensorSource {
if (target_torch_type != tensor.type().scalarType()) {
TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
}
// TODO(ysiraichi): check, first, if tensor lives in a device that the
// current PjRt client has access. If so, we don't need to go through the
// CPU.
tensor_ = std::move(
tensor.to(at::TensorOptions().device(at::kCPU).dtype(target_torch_type),
/*non_blocking=*/false,
/*copy=*/true, at::MemoryFormat::Contiguous));
// The purposes of copy are:
// 1. Ensure the memory is contiguous, which is expected by PJRT.
// 2. Move CUDA tensor to CPU since we cannot pass CUDA memory to PJRT now.
// 3. Cast data type.
// We can avoid if copy is not needed.
if (tensor.device() == at::kCPU && tensor.is_contiguous() &&
tensor.dtype() == target_torch_type) {
// Skip copying a CPU tensor to CPU.
tensor_ = std::move(tensor);
} else {
TORCH_LAZY_COUNTER("AtenSourceTensorCopy", 1);
// TODO(ysiraichi): check, first, if tensor lives in a device that the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: I personally prefer to have TODOs linked to issues, and then having those assigned to people. That way, things can be more easily followed-up if a contributor is no longer active

// current PjRt client has access. If so, we don't need to go through the
// CPU.
tensor_ = std::move(tensor.to(
at::TensorOptions().device(at::kCPU).dtype(target_torch_type),
/*non_blocking=*/false,
/*copy=*/true, at::MemoryFormat::Contiguous));
}
}

const void* data() const override { return tensor_.const_data_ptr(); }
Expand Down
Loading