diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index de5500a0c5b..4d086645973 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -182,6 +182,7 @@ supported: - dot - elu_backward - embedding_dense_backward + - _embedding_bag_backward - empty.memory_format - empty_strided - expand_copy diff --git a/test/test_operations.py b/test/test_operations.py index e62cd517825..affcf2b6039 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2198,6 +2198,83 @@ def foo(x, is_xla=False): self.assertEqual(out, Xout.cpu()) + def test_embedding_bag_backward_fallback(self): + # Tests whether EmbeddingBag backward function works and computes the expected results. + # + # EmbeddingBag has a 'sparse' flag which dictates what will be the layout of the grad + # returned by its backward function. Unfortunately, PyTorch/XLA doesn't support sparse + # tensors, yet. Therefore, as a work-around, we fall back to the dense backward function. + # + # This test checks whether we correctly compute the backward for sparse=True and + # sparse=False, making sure that we did not introduce any regressions. + + # Run EmbeddingBag forward and backwards. + # Return the forward result + the computed weight grad. + def fn(indices, weight, **kwargs): + out = F.embedding_bag(indices, weight, **kwargs) + out.sum().backward() + return out, weight.grad + + # Clone a tensor, and maybe move it to a different device. + def clone_and_maybe_move(tensor, device=None): + fresh = tensor + # Maybe move to the specified device. + if device is not None: + fresh = fresh.to(device) + # Clone if not cloned already by the previous device move. + if fresh.device == tensor.device and fresh.data_ptr() == tensor.data_ptr( + ): + fresh = tensor.clone() + # Make this tensor a leaf tensor by detaching and resetting its + # requires_grad property.
+ fresh = fresh.detach() + fresh.requires_grad_(tensor.requires_grad) + return fresh + + EMBEDDINGS = 10 + VECTOR_DIM = 5 + N = 5 + + kwargs = { + "indices": torch.randint(0, EMBEDDINGS, (N,)), + "weight": torch.randn((EMBEDDINGS, VECTOR_DIM), requires_grad=True), + "offsets": torch.tensor([0, 3], dtype=torch.long), + } + + # Test all combinations of sparse + mode. + for sparse, mode in itertools.product((False, True), + ("sum", "mean", "max")): + # According to nn.functional.embedding_bag PyTorch documentation, not supported. + if sparse and mode == "max": + continue + + extra_kwargs = { + "mode": mode, + "sparse": sparse, + } + + with self.subTest(sparse=sparse, mode=mode): + kwargs_ = {k: clone_and_maybe_move(v) for k, v in kwargs.items()} + xla_kwargs = { + k: clone_and_maybe_move(v, device=xm.xla_device()) + for k, v in kwargs.items() + } + + expected_out, expected_grad = fn(**kwargs_, **extra_kwargs) + actual_out, actual_grad = fn(**xla_kwargs, **extra_kwargs) + + # PyTorch/XLA doesn't support sparse tensors. + # We explicitly fallback to the dense backward function whenever sparse=True. + # Therefore, we have to convert the expected grad to dense, so that we can + # compare the actual numbers. 
+ if sparse: + self.assertTrue(expected_grad.is_sparse) + self.assertFalse(actual_grad.is_sparse) + expected_grad = expected_grad.to_dense() + + self.assertEqual(actual_out, expected_out) + self.assertEqual(actual_grad, expected_grad) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 5e233d051e3..0428bad6578 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -7,6 +7,7 @@ #include #include #include +#include <ATen/ops/_embedding_bag_backward_native.h> #include #include #include @@ -1489,6 +1490,32 @@ XLANativeFunctions::_embedding_bag_forward_only( bridge::AtenFromXlaTensor(std::get<3>(result))); } +at::Tensor XLANativeFunctions::_embedding_bag_backward( + const at::Tensor& grad, const at::Tensor& indices_, + const at::Tensor& offsets_, const at::Tensor& offset2bag, + const at::Tensor& bag_size_, const at::Tensor& max_indices_, + int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, + const std::optional<at::Tensor>& per_sample_weights_opt, + int64_t padding_idx) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + if (sparse) { + TORCH_WARN( + "XLA does not support EmbeddingBag sparse backward function. " + "Falling back to the dense function."); + } + if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) { + return at::native::_embedding_bag_backward_symint( + grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, + num_weights, scale_grad_by_freq, mode, /*sparse=*/false, + per_sample_weights_opt, padding_idx); + } + return at::native:: + call_fallback_fn<&xla_fallback, ATEN_OP(_embedding_bag_backward)>::call( + grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, + num_weights, scale_grad_by_freq, mode, /*sparse=*/false, + per_sample_weights_opt, padding_idx); +} + at::Tensor XLANativeFunctions::empty_symint( at::SymIntArrayRef sym_size, std::optional<at::ScalarType> dtype, std::optional<at::Layout> layout, std::optional<at::Device> device,