From 516228c799d0a9f05586339dc2686be5f2e2d6a8 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 27 Jun 2024 16:23:11 -0300 Subject: [PATCH 01/11] Fallback on `_embedding_bag_backward`. --- codegen/xla_native_functions.yaml | 1 + torch_xla/csrc/aten_xla_type.cpp | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index de5500a0c5b..4d086645973 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -182,6 +182,7 @@ supported: - dot - elu_backward - embedding_dense_backward + - _embedding_bag_backward - empty.memory_format - empty_strided - expand_copy diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 5e233d051e3..30e9bf4d55f 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1489,6 +1489,24 @@ XLANativeFunctions::_embedding_bag_forward_only( bridge::AtenFromXlaTensor(std::get<3>(result))); } +at::Tensor XLANativeFunctions::_embedding_bag_backward( + const at::Tensor& grad, const at::Tensor& indices_, + const at::Tensor& offsets_, const at::Tensor& offset2bag, + const at::Tensor& bag_size_, const at::Tensor& max_indices_, + int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, + const std::optional& per_sample_weights_opt, + int64_t padding_idx) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + TORCH_WARN( + "XLA does not support EmbeddingBag sparse backward function. " + "Falling back to the dense function."); + return at::native::call_fallback_fn<&xla_cpu_fallback, + ATEN_OP(_embedding_bag_backward)>:: + call(grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, + num_weights, scale_grad_by_freq, mode, /*sparse=*/false, + per_sample_weights_opt, padding_idx); +} + at::Tensor XLANativeFunctions::empty_symint( at::SymIntArrayRef sym_size, std::optional dtype, std::optional layout, std::optional device, From 5ffaadd91e66d2bf8afe31eff9b584478e7987e6 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 27 Jun 2024 16:45:48 -0300 Subject: [PATCH 02/11] Add torch_pin. --- .torch_pin | 1 + 1 file changed, 1 insertion(+) create mode 100644 .torch_pin diff --git a/.torch_pin b/.torch_pin new file mode 100644 index 00000000000..1ab02ede7a2 --- /dev/null +++ b/.torch_pin @@ -0,0 +1 @@ +#129691 \ No newline at end of file From dc6ab01bb463fb4162da7a49020be1c9d14756e8 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 28 Jun 2024 16:27:24 -0300 Subject: [PATCH 03/11] Add test. --- test/test_operations.py | 76 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/test/test_operations.py b/test/test_operations.py index e62cd517825..22074010301 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2198,6 +2198,82 @@ def foo(x, is_xla=False): self.assertEqual(out, Xout.cpu()) + def test_embedding_bag_backward_fallback(self): + # Tests whether EmbeddingBag backward function works and computes the expected results. + # + # EmbeddingBag has a 'sparse' flag which dictates what will be the layout of the grad + # returned by its backward function. Unfortunately, PyTorch/XLA doesn't support sparse + # tensors, yet. Therefore, as a work-around, we fallback to the dense backward function. + # + # This test tests whether we correctly compute the backward for sparse=True and + # sparse=False, making sure that we did not introduce any regressions. + + # Run EmbeddingBag forward and backwards. + # Return the forward result + the computed weight grad. + def fn(indices, weight, **kwargs): + out = F.embedding_bag(indices, weight, **kwargs) + out.sum().backward() + return out, weight.grad + + # Clone a tensor, and maybe move it to a different device. + def clone_and_maybe_move(tensor, device=None): + fresh = tensor + # Maybe move to the specified device. + if device is not None: + fresh = fresh.to(device) + # Clone if not cloned already by the previous device move. + if fresh.data_ptr() == tensor.data_ptr(): + fresh = tensor.clone() + # Make this tensor a leaf tensor by detaching and reseting its + # requires_grad property. + fresh = fresh.detach() + fresh.requires_grad_(tensor.requires_grad) + return fresh + + EMBEDDINGS = 10 + VECTOR_DIM = 5 + N = 5 + + kwargs = { + "indices": torch.randint(0, EMBEDDINGS, (N,)), + "weight": torch.randn((EMBEDDINGS, VECTOR_DIM), requires_grad=True), + "offsets": torch.tensor([0, 3], dtype=torch.long), + } + + # Test all combinations of sparse + mode. + for sparse, mode in itertools.product((False, True), + ("sum", "mean", "max")): + # According to nn.functional.embedding_bag PyTorch documentation, not supported. + if sparse and mode == "max": + continue + + extra_kwargs = { + "mode": mode, + "sparse": sparse, + } + + with self.subTest(sparse=sparse, mode=mode): + kwargs_ = {k: clone_and_maybe_move(v) for k, v in kwargs.items()} + xla_kwargs = { + k: clone_and_maybe_move(v, device=xm.xla_device()) + for k, v in kwargs.items() + } + + expected_out, expected_grad = fn(**kwargs_, **extra_kwargs) + actual_out, actual_grad = fn(**xla_kwargs, **extra_kwargs) + + # PyTorch/XLA doesn't support sparse tensors. + # We explicitly fallback to the dense backward function whenever sparse=True. + # Therefore, we have to convert the expected grad to dense, so that we can + # compare the actual numbers. + if sparse: + self.assertTrue(expected_grad.is_sparse) + self.assertFalse(actual_grad.is_sparse) + expected_grad = expected_grad.to_dense() + + self.assertEqual(actual_out, expected_out) + self.assertEqual(actual_grad, expected_grad) + class MNISTComparator(nn.Module): From b59e5b0d1aa70ac0e1ec5e5741fb153a362c8b60 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 1 Jul 2024 09:52:48 -0300 Subject: [PATCH 04/11] Warn only if sparse. --- torch_xla/csrc/aten_xla_type.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 30e9bf4d55f..f09da96894f 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1497,9 +1497,11 @@ at::Tensor XLANativeFunctions::_embedding_bag_backward( const std::optional& per_sample_weights_opt, int64_t padding_idx) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - TORCH_WARN( - "XLA does not support EmbeddingBag sparse backward function. " - "Falling back to the dense function."); + if (sparse) { + TORCH_WARN( + "XLA does not support EmbeddingBag sparse backward function. " + "Falling back to the dense function."); + } return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(_embedding_bag_backward)>:: call(grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, From 46371e53c9a5665bfe2c76c24781d24c00295dd9 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Jul 2024 19:01:16 -0300 Subject: [PATCH 05/11] Remove torch_pin. --- .torch_pin | 1 - 1 file changed, 1 deletion(-) delete mode 100644 .torch_pin diff --git a/.torch_pin b/.torch_pin deleted file mode 100644 index 1ab02ede7a2..00000000000 --- a/.torch_pin +++ /dev/null @@ -1 +0,0 @@ -#129691 \ No newline at end of file From a35b53a108df61584961a18eeb5a8e7b23627408 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Jul 2024 20:10:45 -0300 Subject: [PATCH 06/11] Propagate fallback function name change. --- torch_xla/csrc/aten_xla_type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index f09da96894f..9a8655e625a 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1502,7 +1502,7 @@ at::Tensor XLANativeFunctions::_embedding_bag_backward( "XLA does not support EmbeddingBag sparse backward function. " "Falling back to the dense function."); } - return at::native::call_fallback_fn<&xla_cpu_fallback, + return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(_embedding_bag_backward)>:: call(grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, num_weights, scale_grad_by_freq, mode, /*sparse=*/false, From 6c3f77885397d1f5a2798751fa6c5d0978ba10e9 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Jul 2024 20:17:53 -0300 Subject: [PATCH 07/11] Fix lint issues. --- torch_xla/csrc/aten_xla_type.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 9a8655e625a..fad0f735df7 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1502,11 +1502,11 @@ at::Tensor XLANativeFunctions::_embedding_bag_backward( "XLA does not support EmbeddingBag sparse backward function. " "Falling back to the dense function."); } - return at::native::call_fallback_fn<&xla_fallback, - ATEN_OP(_embedding_bag_backward)>:: - call(grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, - num_weights, scale_grad_by_freq, mode, /*sparse=*/false, - per_sample_weights_opt, padding_idx); + return at::native:: + call_fallback_fn<&xla_fallback, ATEN_OP(_embedding_bag_backward)>::call( + grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, + num_weights, scale_grad_by_freq, mode, /*sparse=*/false, + per_sample_weights_opt, padding_idx); } at::Tensor XLANativeFunctions::empty_symint( From 9f05200019a94be4a7fa013b31cee73ecb3d6457 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 4 Jul 2024 14:13:37 -0300 Subject: [PATCH 08/11] Add path for `XLA_DISABLE_FUNCTIONALIZATION=1`. --- test/test_operations.py | 2 +- torch_xla/csrc/aten_xla_type.cpp | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 22074010301..eff040ba36a 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2222,7 +2222,7 @@ def clone_and_maybe_move(tensor, device=None): if device is not None: fresh = fresh.to(device) # Clone if not cloned already by the previous device move. - if fresh.data_ptr() == tensor.data_ptr(): + if fresh.device == tensor.device and fresh.data_ptr() == tensor.data_ptr(): fresh = tensor.clone() # Make this tensor a leaf tensor by detaching and reseting its # requires_grad property. diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index fad0f735df7..139a3760c39 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -1502,11 +1503,17 @@ at::Tensor XLANativeFunctions::_embedding_bag_backward( "XLA does not support EmbeddingBag sparse backward function. " "Falling back to the dense function."); } - return at::native:: - call_fallback_fn<&xla_fallback, ATEN_OP(_embedding_bag_backward)>::call( - grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, - num_weights, scale_grad_by_freq, mode, /*sparse=*/false, - per_sample_weights_opt, padding_idx); + if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) { + return at::native::_embedding_bag_backward_symint( + grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, + num_weights, scale_grad_by_freq, mode, /*sparse=*/false, + per_sample_weights_opt, padding_idx); + } + return at::native::call_fallback_fn<&xla_cpu_fallback, + ATEN_OP(_embedding_bag_backward)>:: + call(grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, + num_weights, scale_grad_by_freq, mode, /*sparse=*/false, + per_sample_weights_opt, padding_idx); } at::Tensor XLANativeFunctions::empty_symint( From e85be64a24252f7dcf8cbf9d89f4444df90880a0 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 4 Jul 2024 15:19:01 -0300 Subject: [PATCH 09/11] Fix fallback function rename. --- torch_xla/csrc/aten_xla_type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 139a3760c39..ed2b91826ae 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1509,7 +1509,7 @@ at::Tensor XLANativeFunctions::_embedding_bag_backward( num_weights, scale_grad_by_freq, mode, /*sparse=*/false, per_sample_weights_opt, padding_idx); } - return at::native::call_fallback_fn<&xla_cpu_fallback, + return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(_embedding_bag_backward)>:: call(grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, num_weights, scale_grad_by_freq, mode, /*sparse=*/false, From c87e76a9c6a32143a79bfaf25b790c8ce76f3f00 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 4 Jul 2024 15:20:30 -0300 Subject: [PATCH 10/11] Fix lint issues. --- test/test_operations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_operations.py b/test/test_operations.py index eff040ba36a..affcf2b6039 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2222,7 +2222,8 @@ def clone_and_maybe_move(tensor, device=None): if device is not None: fresh = fresh.to(device) # Clone if not cloned already by the previous device move. - if fresh.device == tensor.device and fresh.data_ptr() == tensor.data_ptr(): + if fresh.device == tensor.device and fresh.data_ptr() == tensor.data_ptr( + ): fresh = tensor.clone() # Make this tensor a leaf tensor by detaching and reseting its # requires_grad property. From 8d61cc1117f648343c5b04aab2a1b1509b387d12 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 4 Jul 2024 15:22:54 -0300 Subject: [PATCH 11/11] Fix lint issues. --- torch_xla/csrc/aten_xla_type.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index ed2b91826ae..0428bad6578 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1509,11 +1509,11 @@ at::Tensor XLANativeFunctions::_embedding_bag_backward( num_weights, scale_grad_by_freq, mode, /*sparse=*/false, per_sample_weights_opt, padding_idx); } - return at::native::call_fallback_fn<&xla_fallback, - ATEN_OP(_embedding_bag_backward)>:: - call(grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, - num_weights, scale_grad_by_freq, mode, /*sparse=*/false, - per_sample_weights_opt, padding_idx); + return at::native:: + call_fallback_fn<&xla_fallback, ATEN_OP(_embedding_bag_backward)>::call( + grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, + num_weights, scale_grad_by_freq, mode, /*sparse=*/false, + per_sample_weights_opt, padding_idx); } at::Tensor XLANativeFunctions::empty_symint(