Fallback _embedding_bag_backward and force sparse=false. #7584

Merged
merged 11 commits on Jul 4, 2024
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -182,6 +182,7 @@ supported:
- dot
- elu_backward
- embedding_dense_backward
- _embedding_bag_backward
- empty.memory_format
- empty_strided
- expand_copy
77 changes: 77 additions & 0 deletions test/test_operations.py
@@ -2198,6 +2198,83 @@ def foo(x, is_xla=False):

    self.assertEqual(out, Xout.cpu())

  def test_embedding_bag_backward_fallback(self):
    # Tests whether the EmbeddingBag backward function works and computes the
    # expected results.
    #
    # EmbeddingBag has a 'sparse' flag which dictates the layout of the grad
    # returned by its backward function. Unfortunately, PyTorch/XLA doesn't
    # support sparse tensors yet. Therefore, as a workaround, we fall back to
    # the dense backward function.
    #
    # This test checks that we correctly compute the backward pass for both
    # sparse=True and sparse=False, making sure no regressions were introduced.

    # Run EmbeddingBag forward and backward.
    # Return the forward result and the computed weight grad.
    def fn(indices, weight, **kwargs):
      out = F.embedding_bag(indices, weight, **kwargs)
      out.sum().backward()
      return out, weight.grad

    # Clone a tensor, and maybe move it to a different device.
    def clone_and_maybe_move(tensor, device=None):
      fresh = tensor
      # Maybe move to the specified device.
      if device is not None:
        fresh = fresh.to(device)
      # Clone if not already cloned by the device move above.
      if fresh.device == tensor.device and fresh.data_ptr() == tensor.data_ptr():
        fresh = tensor.clone()
      # Make this tensor a leaf tensor by detaching it and resetting its
      # requires_grad property.
      fresh = fresh.detach()
Collaborator:
Very cool. I wonder why we need to make the tensor a leaf tensor. What would happen if we don't do it?

Collaborator Author:
It's so we can grab its tensor.grad and compare.
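To illustrate that point (a minimal sketch, not part of the PR): autograd only populates .grad on leaf tensors by default, and tensor.to(device) on a tensor that requires grad returns a non-leaf, which is why the helper detaches and re-enables requires_grad to get a fresh leaf.

import torch

w = torch.randn(3, requires_grad=True)  # leaf: created directly by the user
y = w * 2  # non-leaf: produced by a tracked op (just like .to(device))
y.sum().backward()
print(w.grad)  # tensor([2., 2., 2.]): leaves get .grad populated
print(y.grad)  # None (plus a UserWarning): non-leaf .grad is not retained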

      fresh.requires_grad_(tensor.requires_grad)
      return fresh

    EMBEDDINGS = 10
    VECTOR_DIM = 5
    N = 5

    kwargs = {
        "indices": torch.randint(0, EMBEDDINGS, (N,)),
        "weight": torch.randn((EMBEDDINGS, VECTOR_DIM), requires_grad=True),
        "offsets": torch.tensor([0, 3], dtype=torch.long),
    }

    # Test all combinations of sparse + mode.
    for sparse, mode in itertools.product((False, True),
                                          ("sum", "mean", "max")):
      # Per the nn.functional.embedding_bag documentation, sparse gradients
      # are not supported with mode="max".
      if sparse and mode == "max":
        continue

      extra_kwargs = {
          "mode": mode,
          "sparse": sparse,
      }

      with self.subTest(sparse=sparse, mode=mode):
        kwargs_ = {k: clone_and_maybe_move(v) for k, v in kwargs.items()}
        xla_kwargs = {
            k: clone_and_maybe_move(v, device=xm.xla_device())
            for k, v in kwargs.items()
        }

        expected_out, expected_grad = fn(**kwargs_, **extra_kwargs)
        actual_out, actual_grad = fn(**xla_kwargs, **extra_kwargs)

        # PyTorch/XLA doesn't support sparse tensors, and we explicitly fall
        # back to the dense backward function whenever sparse=True. Therefore,
        # we have to convert the expected grad to dense so that we can compare
        # the actual numbers.
        if sparse:
          self.assertTrue(expected_grad.is_sparse)
          self.assertFalse(actual_grad.is_sparse)
          expected_grad = expected_grad.to_dense()

        self.assertEqual(actual_out, expected_out)
        self.assertEqual(actual_grad, expected_grad)
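For context, a minimal CPU-only sketch (assuming stock PyTorch semantics; not part of the diff) of the layout difference the test handles: with sparse=True, the reference backward returns a sparse COO grad, hence the to_dense() conversion before comparison.

import torch
import torch.nn.functional as F

w = torch.randn(10, 5, requires_grad=True)
idx = torch.randint(0, 10, (5,))
off = torch.tensor([0, 3], dtype=torch.long)

F.embedding_bag(idx, w, off, sparse=True).sum().backward()
print(w.grad.layout)  # torch.sparse_coo: the grad the test converts to dense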


class MNISTComparator(nn.Module):

27 changes: 27 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -7,6 +7,7 @@
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/ops/_embedding_bag_backward_native.h>
#include <ATen/ops/expand_copy.h>
#include <c10/core/Contiguity.h>
#include <torch/csrc/lazy/core/shape_inference.h>
@@ -1489,6 +1490,32 @@ XLANativeFunctions::_embedding_bag_forward_only(
      bridge::AtenFromXlaTensor(std::get<3>(result)));
}

at::Tensor XLANativeFunctions::_embedding_bag_backward(
    const at::Tensor& grad, const at::Tensor& indices_,
    const at::Tensor& offsets_, const at::Tensor& offset2bag,
    const at::Tensor& bag_size_, const at::Tensor& max_indices_,
    int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
    const std::optional<at::Tensor>& per_sample_weights_opt,
    int64_t padding_idx) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  if (sparse) {
    TORCH_WARN(
        "XLA does not support EmbeddingBag sparse backward function. "
        "Falling back to the dense function.");
  }
  if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
    return at::native::_embedding_bag_backward_symint(
        grad, indices_, offsets_, offset2bag, bag_size_, max_indices_,
        num_weights, scale_grad_by_freq, mode, /*sparse=*/false,
        per_sample_weights_opt, padding_idx);
  }
  return at::native::
      call_fallback_fn<&xla_fallback, ATEN_OP(_embedding_bag_backward)>::call(
          grad, indices_, offsets_, offset2bag, bag_size_, max_indices_,
          num_weights, scale_grad_by_freq, mode, /*sparse=*/false,
          per_sample_weights_opt, padding_idx);
}
Collaborator:
Can we add a test to test/test_operations.py that runs the embedding bag forward and backward and makes sure it doesn't crash? That way we can prevent it from regressing.
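The test added to test/test_operations.py above covers this. For reference, a minimal smoke-test sketch of the requested check (assuming an XLA device is available; with this change, sparse=True is expected to warn and fall back rather than crash):

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

device = xm.xla_device()
weight = torch.randn(10, 5, requires_grad=True, device=device)
indices = torch.randint(0, 10, (5,), device=device)
offsets = torch.tensor([0, 3], dtype=torch.long, device=device)

out = F.embedding_bag(indices, weight, offsets, sparse=True)
out.sum().backward()  # expected: TORCH_WARN fires, dense backward runs
assert not weight.grad.is_sparse  # the grad comes back dense on XLA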


at::Tensor XLANativeFunctions::empty_symint(
    at::SymIntArrayRef sym_size, std::optional<at::ScalarType> dtype,
    std::optional<at::Layout> layout, std::optional<at::Device> device,