5 changes: 0 additions & 5 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu
@@ -8,11 +8,6 @@

#include "common.cuh"

FBGEMM_OP_DISPATCH(CUDA, "dense_to_jagged", fbgemm_gpu::dense_to_jagged);
FBGEMM_OP_DISPATCH(
CUDA,
"jagged_to_padded_dense",
fbgemm_gpu::jagged_to_padded_dense);
FBGEMM_OP_DISPATCH(
CUDA,
"jagged_dense_elementwise_add",
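These CUDA registrations for the public `dense_to_jagged` and `jagged_to_padded_dense` ops can be dropped because the ops now receive a single CompositeImplicitAutograd implementation (added in the autograd file below) that decomposes into the `*_forward` ops, which keep their per-backend kernels. A toy sketch of that registration pattern, written against the public `torch.library` API on a hypothetical `myops` namespace rather than FBGEMM's own dispatch macros:

```python
# A toy illustration (hypothetical "myops" namespace, not FBGEMM code) of the pattern this
# PR adopts: the public op gets a CompositeImplicitAutograd implementation that simply
# calls a *_forward op, and only the forward op carries per-backend kernels.
import torch
from torch.library import Library

lib = Library("myops", "DEF")
lib.define("pad_forward(Tensor values, int max_len) -> Tensor")
lib.define("pad(Tensor values, int max_len) -> Tensor")

def pad_forward_cpu(values, max_len):
    # Backend kernel: zero-pad the leading dimension up to max_len.
    out = values.new_zeros(max_len, *values.shape[1:])
    out[: values.shape[0]] = values
    return out

lib.impl("pad_forward", pad_forward_cpu, "CPU")

def pad_composite(values, max_len):
    # The public op is backend-agnostic; autograd, meta/fake tensors, and torch.export
    # all see only myops::pad_forward after decomposition.
    return torch.ops.myops.pad_forward(values, max_len)

lib.impl("pad", pad_composite, "CompositeImplicitAutograd")

x = torch.randn(3, 2)
print(torch.ops.myops.pad(x, 5).shape)  # torch.Size([5, 2])
```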
42 changes: 40 additions & 2 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
@@ -48,6 +48,8 @@ class JaggedToPaddedDenseOp
const std::vector<Tensor>& offsets,
at::ArrayRef<at::SymInt> max_lengths,
const double padding_value)>();

at::AutoDispatchBelowAutograd mode;
Tensor padded_values = op.call(values, offsets, max_lengths, padding_value);

return {padded_values};
@@ -286,6 +288,7 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
const Tensor& dense,
const std::vector<Tensor>& offsets,
std::optional<at::SymInt> total_L)>();
at::AutoDispatchBelowAutograd mode;
auto output = op.call(dense, offsets, total_L);

return {output};
@@ -785,14 +788,30 @@ class JaggedSliceOp : public torch::autograd::Function<JaggedSliceOp> {
} // namespace

///@ingroup jagged-tensor-ops-cpu
Tensor jagged_to_padded_dense(
Tensor jagged_to_padded_dense_forward_autograd(
const Tensor& values,
const std::vector<Tensor>& offsets,
const c10::SymIntArrayRef max_lengths,
const double padding_value) {
return JaggedToPaddedDenseOp::apply(
values, offsets, max_lengths, padding_value)[0];
}
Tensor jagged_to_padded_dense(
const Tensor& values,
const std::vector<Tensor>& offsets,
const c10::SymIntArrayRef max_lengths,
const double padding_value) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "")
.typed<at::Tensor(
const Tensor& values,
const std::vector<Tensor>& offsets,
at::ArrayRef<at::SymInt> max_lengths,
const double padding_value)>();
Tensor output = op.call(values, offsets, max_lengths, padding_value);
return output;
}

///@ingroup jagged-tensor-ops-cpu
/// Output = x + y where x is jagged, y and output are dense
@@ -855,7 +874,20 @@ std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged(
const Tensor& dense,
const std::vector<Tensor>& offsets,
std::optional<at::SymInt> total_L) {
return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets};
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "")
.typed<Tensor(
const Tensor& dense,
const std::vector<Tensor>& offsets,
std::optional<at::SymInt> total_L)>();
auto output = op.call(dense, offsets, total_L);
return {output, offsets};
}
Tensor dense_to_jagged_forward_autograd(
const Tensor& dense,
const std::vector<Tensor>& offsets,
std::optional<at::SymInt> total_L) {
return DenseToJaggedOp::apply(dense, offsets, total_L)[0];
}

///@ingroup jagged-tensor-ops-cpu
@@ -973,6 +1005,12 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm));
m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm));
m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice));
m.impl(
"jagged_to_padded_dense_forward",
TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_forward_autograd));
m.impl(
"dense_to_jagged_forward",
TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_autograd));
}

TORCH_LIBRARY_IMPL(fbgemm, CompositeImplicitAutograd, m) {
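After this change, `fbgemm::jagged_to_padded_dense` and `fbgemm::dense_to_jagged` are thin composite wrappers that call the `*_forward` ops, and the autograd kernels live on those forward ops. A minimal smoke test of the intended behaviour, assuming `fbgemm_gpu` is installed (the shapes below are my own illustrative choice):

```python
# A minimal sketch, assuming fbgemm_gpu is importable; it checks that gradients still flow
# through the composite op now that autograd is attached to jagged_to_padded_dense_forward.
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm operators

values = torch.randn(8, 3, requires_grad=True)          # jagged values: 8 rows, D = 3
offsets = torch.tensor([0, 3, 5, 8], dtype=torch.long)  # three bags of lengths 3, 2, 3
padded = torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [4], 0.0)
print(padded.shape)       # expected: torch.Size([3, 4, 3])
padded.sum().backward()   # autograd routes through the forward/backward pair
print(values.grad.shape)  # expected: torch.Size([8, 3])
```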
4 changes: 1 addition & 3 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
@@ -1818,13 +1818,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense);
DISPATCH_TO_CPU("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense);
DISPATCH_TO_CPU("dense_to_jagged", fbgemm_gpu::dense_to_jagged);
DISPATCH_TO_CPU(
"dense_to_jagged_forward", fbgemm_gpu::dense_to_jagged_forward);
DISPATCH_TO_CPU("jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
DISPATCH_TO_CPU(
"jagged_to_padded_dense_forward",
fbgemm_gpu::jagged_to_padded_dense_forward);
fbgemm_gpu::jagged_to_padded_dense_forward_cpu);
DISPATCH_TO_CPU(
"jagged_to_padded_dense_backward",
fbgemm_gpu::jagged_to_padded_dense_backward);
9 changes: 6 additions & 3 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
@@ -53,18 +53,21 @@ Tensor jagged_to_padded_dense_meta(

Tensor jagged_to_padded_dense_backward_meta(
const at::Tensor& grad_output,
const std::vector<Tensor>& /*offsets*/,
const std::vector<Tensor>& offsets,
at::SymInt total_L) {
const auto& grad_padded_values = grad_output;

at::SymInt D = grad_padded_values.sym_size(-1);
const bool D_folded = grad_padded_values.dim() == offsets.size() + 1;
const auto& grad_padded_values_view =
D_folded ? grad_padded_values.unsqueeze(-1) : grad_padded_values;
at::SymInt D = grad_padded_values_view.sym_size(-1);
// Initialize with zeros so output will be zero for the portion truncated
// in forward.
auto grad_values =
at::zeros_symint({std::move(total_L), D}, grad_padded_values.options());

TORCH_CHECK(grad_values.is_meta());
return grad_values;
return D_folded ? grad_values.squeeze(-1) : grad_values;
}

Tensor jagged_dense_dense_elementwise_add_jagged_output_forward_meta(
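The meta kernel now mirrors the real backward's handling of 1-D ("D-folded") values: when the incoming gradient has one dimension per offsets tensor plus the batch dimension, i.e. no trailing embedding dimension, the returned gradient is also 1-D. A rough Python rendering of that shape rule (my own sketch, not FBGEMM source):

```python
# Shape rule followed by the updated jagged_to_padded_dense_backward meta kernel (sketch).
def backward_meta_shape(grad_shape: tuple, num_offsets: int, total_L: int) -> tuple:
    # D is "folded" when the padded gradient has rank num_offsets + 1
    # (batch dim plus one dim per jagged level, with no trailing embedding dim).
    d_folded = len(grad_shape) == num_offsets + 1
    d = 1 if d_folded else grad_shape[-1]
    # The forward truncation is backfilled with zeros, so the values grad is [total_L, D],
    # squeezed back to [total_L] in the folded case.
    return (total_L,) if d_folded else (total_L, d)

assert backward_meta_shape((4, 10), num_offsets=1, total_L=32) == (32,)         # 1-D values
assert backward_meta_shape((4, 10, 16), num_offsets=1, total_L=32) == (32, 16)  # 2-D values
```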
11 changes: 1 addition & 10 deletions fbgemm_gpu/test/jagged/common.py
@@ -10,7 +10,6 @@

import itertools
import sys
import unittest
from typing import Callable

import fbgemm_gpu
@@ -43,15 +42,7 @@
# Please avoid putting tests here, you should put operator-specific
# skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json
# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
additional_decorators: dict[str, list[Callable]] = {
"test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
"test_pt2_compliant_tag_fbgemm_jagged_to_padded_dense": [
unittest.expectedFailure,
],
}
additional_decorators: dict[str, list[Callable]] = {}


def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor:
5 changes: 5 additions & 0 deletions fbgemm_gpu/test/jagged/dense_to_jagged_test.py
@@ -80,6 +80,11 @@ def _test_dense_to_jagged(
jagged_values.backward(ref_output_values)
torch.testing.assert_close(dense.grad, ref_values)

torch.library.opcheck(
torch.ops.fbgemm.dense_to_jagged,
(dense.detach().requires_grad_(True), offsets),
)

@given(
num_jagged_dim=st.integers(1, 5),
outer_dense_size=st.integers(0, 5),
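`torch.library.opcheck` exercises an op against its schema, fake-tensor/meta kernels, autograd registration, and AOT dispatch, and raises if any of them disagrees with eager execution. A standalone version of the pattern used in the test above (shapes are my own illustrative choice; assumes `fbgemm_gpu` is installed):

```python
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm operators

dense = torch.randn(2, 4, 3, requires_grad=True)     # [B, max_L, D]
offsets = torch.tensor([0, 3, 7], dtype=torch.long)  # per-row lengths 3 and 4
# Runs the schema, fake-tensor, autograd-registration, and aot-dispatch checks against the
# eager result; a failure here usually points to a missing or inconsistent registration.
torch.library.opcheck(torch.ops.fbgemm.dense_to_jagged, (dense, [offsets]))
```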
20 changes: 20 additions & 0 deletions fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py
@@ -158,6 +158,26 @@ def test_jagged_index_select_2d(
rtol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
atol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None,
)
if known_shape:
with torch.no_grad():
tmp_output, _ = torch.ops.fbgemm.jagged_index_select(
values, lengths, indices
)
num_dense_output_rows = tmp_output.shape[0]
torch.library.opcheck(
torch.ops.fbgemm.jagged_index_select.default,
(
values.detach().requires_grad_(),
lengths,
indices,
num_dense_output_rows,
),
)
else:
torch.library.opcheck(
torch.ops.fbgemm.jagged_index_select.default,
(values.detach().requires_grad_(), lengths, indices),
)

@given(
max_seq_length=st.integers(5, 10),
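In the known-shape branch above, `num_dense_output_rows` is obtained from a throwaway eager call; it is simply the total length of the selected bags, so the overload that takes it explicitly can also be driven directly. A small sketch (my own example; the sizes, and running on CPU, are illustrative assumptions):

```python
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm operators

values = torch.randn(10, 4)              # 10 jagged rows, D = 4
lengths = torch.tensor([3, 2, 5])        # three bags covering the 10 rows
indices = torch.tensor([2, 0])           # select bags 2 and 0
num_rows = int(lengths[indices].sum())   # 5 + 3 = 8 output rows, known up front
out, out_lengths = torch.ops.fbgemm.jagged_index_select(values, lengths, indices, num_rows)
print(out.shape)  # expected: torch.Size([8, 4])
```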
44 changes: 44 additions & 0 deletions fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py
@@ -113,6 +113,50 @@ def test_jagged_to_padded_dense(
rtol=1e-3,
)

class Mod(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c, d):
return torch.ops.fbgemm.jagged_to_padded_dense(a, b, c, d)

with torch.inference_mode():
gm = torch.export.export(
Mod(),
(
x_values.float().requires_grad_(True),
x_offsets,
max_lengths.astype(int).tolist(),
padding_value,
),
).run_decompositions()
num_fw_ops = len(
[
x
for x in gm.graph.nodes
if x.target is torch.ops.fbgemm.jagged_to_padded_dense_forward.default
]
)
num_composite_ops = len(
[
x
for x in gm.graph.nodes
if x.target is torch.ops.fbgemm.jagged_to_padded_dense.default
]
)
self.assertEqual(num_fw_ops, 1)
self.assertEqual(num_composite_ops, 0)

torch.library.opcheck(
torch.ops.fbgemm.jagged_to_padded_dense,
(
x_values.float().requires_grad_(True),
x_offsets,
max_lengths,
padding_value,
),
)

@given(
num_jagged_dim=st.integers(1, 5),
outer_dense_size=st.integers(0, 5),