Symbolic shape tracing on jagged op (#1758)
Summary:
Pull Request resolved: #1758

fbgemm's jagged ops don't have SymInt support in their meta functions, so all of the shapes get specialized: concrete integers are baked into the traced model instead of symbolic shapes. This diff fixes them.
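
Note (editor's sketch, not part of this diff): the snippet below illustrates the failure mode with hypothetical helper names. Once a size is narrowed to int64_t, a symbolic tracer can only record the concrete value it saw; keeping it as c10::SymInt lets the traced graph carry an expression over the input's symbolic dimensions instead.

#include <ATen/ATen.h>

// Specialized: returns a concrete integer, so tracing bakes in e.g. 128 * 64.
int64_t flattened_len_specialized(const at::Tensor& dense) {
  return dense.size(0) * dense.size(1);
}

// Symbolic: under tracing the result stays an expression such as s0 * s1,
// while still evaluating to the same concrete value in eager mode.
c10::SymInt flattened_len_symbolic(const at::Tensor& dense) {
  return dense.sym_size(0) * dense.sym_size(1);
}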

Reviewed By: q10

Differential Revision: D44736488

Privacy Context Container: L1156430

fbshipit-source-id: af969188c914aae32651beccc491f6c76bd23536
xw285cornell authored and facebook-github-bot committed May 18, 2023
1 parent f46904e commit 0a3380f
Showing 8 changed files with 413 additions and 68 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -431,7 +431,7 @@ std::tuple<at::Tensor, std::vector<at::Tensor>> jagged_dense_elementwise_mul(
std::tuple<at::Tensor, std::vector<at::Tensor>> dense_to_jagged(
const at::Tensor& dense,
const std::vector<at::Tensor>& offsets,
- const c10::optional<int64_t>& total_L);
+ const c10::optional<at::SymInt>& total_L);

std::tuple<at::Tensor, std::vector<at::Tensor>>
jagged_dense_elementwise_add_jagged_output(
6 changes: 3 additions & 3 deletions fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu
@@ -15,18 +15,18 @@ namespace fbgemm_gpu {
Tensor dense_to_jagged_forward(
const Tensor& dense,
const std::vector<Tensor>& offsets,
- const c10::optional<int64_t>& total_L) {
+ const c10::optional<at::SymInt>& total_L) {
// D is the embedding dimension
auto D = dense.size(-1);

// If total_L is not given then compute it
- int64_t total_L_computed;
+ at::SymInt total_L_computed;
if (total_L.has_value()) {
total_L_computed = total_L.value();
} else {
total_L_computed = (int64_t)offsets.back().max().item<int64_t>();
}
- auto values = at::empty({total_L_computed, D}, dense.options());
+ auto values = at::empty_symint({total_L_computed, D}, dense.options());
auto output = at::empty_like(values);

at::cuda::OptionalCUDAGuard device_guard;
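
Note (editor's sketch, not part of the diff): the hunk above switches the allocation to the SymInt-aware factory overload. at::empty only takes concrete int64_t sizes, so it would force total_L to be materialized; at::empty_symint (and at::zeros_symint, used in the backward below) accepts c10::SymInt sizes and keeps them symbolic. A minimal version of the pattern, with illustrative names:

#include <ATen/ATen.h>

// Allocate a values buffer whose leading dimension may be symbolic.
at::Tensor alloc_values(const at::SymInt& total_L,
                        int64_t D,
                        const at::TensorOptions& opts) {
  // int64_t converts implicitly to SymInt, so eager callers with concrete
  // sizes go through the same code path unchanged.
  return at::empty_symint({total_L, D}, opts);
}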
fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
@@ -15,7 +15,7 @@ namespace fbgemm_gpu {
at::Tensor jagged_to_padded_dense_backward(
const Tensor& grad_output,
const std::vector<Tensor>& offsets,
- const int64_t total_L) {
+ const at::SymInt& total_L) {
auto grad_padded_values = grad_output;
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_padded_values.get_device());
@@ -29,7 +29,8 @@ at::Tensor jagged_to_padded_dense_backward(

// Initialize with zeros so output will be zero for the portion truncated
// in forward.
- auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());
+ auto grad_values =
+     at::zeros_symint({total_L, D}, grad_padded_values.options());

AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
19 changes: 12 additions & 7 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
@@ -21,7 +21,7 @@ namespace fbgemm_gpu {
at::Tensor jagged_to_padded_dense_forward(
const Tensor& values,
const std::vector<Tensor>& offsets,
- const std::vector<int64_t>& max_lengths,
+ const at::ArrayRef<at::SymInt>& max_lengths,
const double padding_value) {
const size_t num_jagged_dim = offsets.size();
TORCH_CHECK(
@@ -40,7 +40,7 @@ at::Tensor jagged_to_padded_dense_forward(
values.sizes().end(),
1,
std::multiplies<size_t>())});
- at::DimVector padded_values_shape({offsets[0].size(0) - 1});
+ at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)});
padded_values_shape.insert(
padded_values_shape.end(), max_lengths.begin(), max_lengths.end());

@@ -50,7 +50,8 @@
if (!D_folded) {
padded_values_shape.push_back(values.size(-1));
}
- Tensor padded_values = at::empty(padded_values_shape, values.options());
+ Tensor padded_values =
+     at::empty_symint(padded_values_shape, values.options());
Tensor padded_values_view =
D_folded ? padded_values.unsqueeze(-1) : padded_values;

@@ -121,7 +122,7 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
padded_values_per_key.push_back(jagged_to_padded_dense_forward(
values.slice(0, offset_per_key[t], offset_per_key[t + 1]),
{offsets},
- {max_L},
+ at::ArrayRef<at::SymInt>({max_L}),
padding_value));
}
return padded_values_per_key;
@@ -179,7 +180,7 @@ stacked_jagged_2d_to_dense_forward_cuda(
padded_values_per_key.push_back(jagged_to_padded_dense_forward(
values.slice(0, offset_per_key[t], offset_per_key[t + 1]),
{offsets},
- {max_L},
+ at::ArrayRef<at::SymInt>({max_L}),
padding_value));
}

@@ -301,7 +302,10 @@ Tensor jagged_2d_to_dense_gpu_forward(
Tensor offsets,
int64_t max_sequence_length) {
return jagged_to_padded_dense_forward(
- values, {offsets}, {max_sequence_length}, /*padding_value=*/0);
+ values,
+ {offsets},
+ c10::ArrayRef<c10::SymInt>({max_sequence_length}),
+ /*padding_value=*/0);
}

namespace {
@@ -369,7 +373,8 @@ class JaggedDenseAddJaggedOutputGPUOp
Tensor dense_values_grad = jagged_to_padded_dense_forward(
grad_outputs[0],
offsets,
- std::vector<int64_t>(dense_shape.begin() + 1, dense_shape.end() - 1),
+ c10::fromIntArrayRefKnownNonNegative(std::vector<int64_t>(
+     dense_shape.begin() + 1, dense_shape.end() - 1)),
/*padding_value=*/0);
TORCH_CHECK(dense_values_grad.sizes() == dense_shape);

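
Note (editor's sketch, not part of the diff): the forward hunks above assemble the padded output shape as an at::SymDimVector (a small-vector of SymInt) so that both the batch dimension and the max lengths can stay symbolic. A condensed version of that shape assembly, using a plain std::vector<c10::SymInt> for simplicity; the names are illustrative and offsets0 stands in for offsets[0]:

#include <ATen/ATen.h>
#include <vector>

at::Tensor alloc_padded(const at::Tensor& values,
                        const at::Tensor& offsets0,
                        c10::SymIntArrayRef max_lengths) {
  // Leading batch dimension: number of offsets minus one.
  std::vector<c10::SymInt> shape{c10::SymInt(offsets0.size(0) - 1)};
  // One (possibly symbolic) max length per jagged dimension...
  shape.insert(shape.end(), max_lengths.begin(), max_lengths.end());
  // ...then the trailing embedding dimension of values, kept symbolic too.
  shape.push_back(values.sym_sizes().back());
  return at::empty_symint(shape, values.options());
}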
57 changes: 45 additions & 12 deletions fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
@@ -11,6 +11,7 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/library.h>
+ #include <torch/torch.h>

#include "ATen/TensorUtils.h"
#include "fbgemm_gpu/sparse_ops.h"
@@ -35,17 +36,18 @@ class JaggedToPaddedDenseOp
const std::vector<int64_t>& max_lengths,
const double padding_value) {
ctx->save_for_backward(offsets);
- ctx->saved_data["total_L"] = values.size(0);
+ ctx->saved_data["total_L"] = values.sym_size(0);

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "")
.typed<at::Tensor(
const Tensor& values,
const std::vector<Tensor>& offsets,
- const std::vector<int64_t>& max_lengths,
+ const at::ArrayRef<at::SymInt>& max_lengths,
const double padding_value)>();
- Tensor padded_values = op.call(values, offsets, max_lengths, padding_value);
+ Tensor padded_values = op.call(
+     values, offsets, c10::fromIntArrayRefSlow(max_lengths), padding_value);

return {padded_values};
}
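
Note (editor's sketch, not fbgemm code): a minimal custom autograd function showing the saved_data pattern adopted above. The length is stashed as a SymInt via sym_size(0) in forward, read back with toSymInt() in backward, and the gradient is allocated with a *_symint factory so its shape stays symbolic; it assumes a 2-D values tensor and a recent libtorch.

#include <torch/torch.h>

struct TruncateRows : public torch::autograd::Function<TruncateRows> {
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& values,
      int64_t keep) {
    // sym_size(0) stays symbolic under tracing; size(0) would bake in an int.
    ctx->saved_data["total_L"] = values.sym_size(0);
    return values.narrow(0, 0, keep);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    at::SymInt total_L = ctx->saved_data["total_L"].toSymInt();
    // Zero-filled so rows dropped by the forward narrow() get zero gradient.
    auto grad_values = at::zeros_symint(
        {total_L, grad_outputs[0].sym_size(1)}, grad_outputs[0].options());
    grad_values.narrow(0, 0, grad_outputs[0].size(0)).copy_(grad_outputs[0]);
    // One entry per forward input: values gets a gradient, keep does not.
    return {grad_values, torch::Tensor()};
  }
};

TruncateRows::apply(values, keep) would be the eager entry point; the real op instead routes through the dispatcher, as in the hunks above.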
@@ -54,7 +56,7 @@ class JaggedToPaddedDenseOp
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
- int32_t total_L = ctx->saved_data["total_L"].toInt();
+ at::SymInt total_L = ctx->saved_data["total_L"].toSymInt();
TORCH_CHECK(grad_outputs.size() == 1);

TORCH_CHECK(total_L >= 0);
@@ -64,7 +66,7 @@
.typed<at::Tensor(
const Tensor& grad_output,
const std::vector<Tensor>& offsets,
- const int64_t total_L)>();
+ const at::SymInt& total_L)>();
auto grad_values = op.call(grad_outputs[0], {offsets}, total_L);

return {
@@ -86,7 +88,15 @@ class JaggedDenseDenseAddJaggedOutputOp
const Tensor& dense_0,
const Tensor& dense_1) {
ctx->save_for_backward(offsets);
+ #if TORCH_VERSION_MAJOR > 2 || \
+     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+ // toSymIntVector support is from a recent PR
+ // https://github.com/pytorch/pytorch/pull/101056,
+ // so protect it under a version guard for compatibility
+ ctx->saved_data["dense_shape"] = dense_0.sym_sizes();
+ #else
ctx->saved_data["dense_shape"] = dense_0.sizes();
+ #endif

static auto op =
c10::Dispatcher::singleton()
@@ -107,7 +117,12 @@ class JaggedDenseDenseAddJaggedOutputOp
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
+ #if TORCH_VERSION_MAJOR > 2 || \
+     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+ auto dense_shape = ctx->saved_data["dense_shape"].toSymIntVector();
+ #else
auto dense_shape = ctx->saved_data["dense_shape"].toIntVector();
+ #endif
TORCH_CHECK(grad_outputs.size() == 1);

static auto op =
@@ -116,12 +131,12 @@ class JaggedDenseDenseAddJaggedOutputOp
.typed<at::Tensor(
const Tensor& values,
const std::vector<Tensor>& offsets,
- const std::vector<int64_t>& max_lengths,
+ const at::ArrayRef<at::SymInt>& max_lengths,
const double padding_value)>();
Tensor dense_values_grad_0 = op.call(
grad_outputs[0],
offsets,
- std::vector<int64_t>(dense_shape.begin() + 1, dense_shape.end() - 1),
+ std::vector<at::SymInt>(dense_shape.begin() + 1, dense_shape.end() - 1),
/*padding_value=*/0);
Tensor dense_values_grad_1 = dense_values_grad_0;

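Note (editor's sketch, not part of the diff): IValue::toSymIntVector() only became available with pytorch/pytorch#101056, so the hunks above and below compile the symbolic path only when the TORCH_VERSION macros (brought in by the new <torch/torch.h> include) indicate PyTorch 2.1 or newer. The same guard can be factored into a small helper; the names below are illustrative.

#include <torch/torch.h>  // TORCH_VERSION_MAJOR / TORCH_VERSION_MINOR
#include <vector>

#if TORCH_VERSION_MAJOR > 2 || \
    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
// New enough: shapes round-trip through saved_data symbolically.
using saved_shape_t = std::vector<c10::SymInt>;
inline saved_shape_t load_saved_shape(const c10::IValue& v) {
  return v.toSymIntVector();
}
#else
// Older libtorch: fall back to concrete sizes so the extension still builds.
using saved_shape_t = std::vector<int64_t>;
inline saved_shape_t load_saved_shape(const c10::IValue& v) {
  return v.toIntVector();
}
#endif
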
@@ -249,19 +264,27 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
torch::autograd::AutogradContext* ctx,
const Tensor& dense,
const std::vector<Tensor>& offsets,
- const c10::optional<int64_t>& total_L) {
+ const c10::optional<at::SymInt>& total_L) {
ctx->save_for_backward(offsets);

// dims of dense tensor: <batch, [maxlen0, maxlen1, ...], embedding_dim>
+ #if TORCH_VERSION_MAJOR > 2 || \
+     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+ // toSymIntVector support is from a recent PR
+ // https://github.com/pytorch/pytorch/pull/101056,
+ // so protect it under a version guard for compatibility
+ ctx->saved_data["dense_shape"] = dense.sym_sizes();
+ #else
ctx->saved_data["dense_shape"] = dense.sizes();
+ #endif

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "")
.typed<Tensor(
const Tensor& dense,
const std::vector<Tensor>& offsets,
- const c10::optional<int64_t>& total_L)>();
+ const c10::optional<at::SymInt>& total_L)>();
auto output = op.call(dense, offsets, total_L);

return {output};
@@ -271,7 +294,12 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
+ #if TORCH_VERSION_MAJOR > 2 || \
+     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+ auto dense_shape = ctx->saved_data["dense_shape"].toSymIntVector();
+ #else
auto dense_shape = ctx->saved_data["dense_shape"].toIntVector();
+ #endif
TORCH_CHECK(grad_outputs.size() == 1);

static auto op =
@@ -280,15 +308,20 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
.typed<Tensor(
const Tensor& values,
const std::vector<Tensor>& offsets,
- const std::vector<int64_t>& max_lengths,
+ const at::ArrayRef<at::SymInt>& max_lengths,
const double padding_value)>();
auto dense_values_grad = op.call(
grad_outputs[0],
offsets,
- std::vector<int64_t>(dense_shape.begin() + 1, dense_shape.end() - 1),
+ std::vector<at::SymInt>(dense_shape.begin() + 1, dense_shape.end() - 1),
/*padding_value=*/0);

+ #if TORCH_VERSION_MAJOR > 2 || \
+     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+ TORCH_CHECK(dense_values_grad.sym_sizes() == dense_shape);
+ #else
TORCH_CHECK(dense_values_grad.sizes() == dense_shape);
+ #endif

return {
dense_values_grad,
@@ -730,7 +763,7 @@ Tensor batched_dense_vec_jagged_2d_mul(
std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged(
const Tensor& dense,
const std::vector<Tensor>& offsets,
- const c10::optional<int64_t>& total_L) {
+ const c10::optional<at::SymInt>& total_L) {
return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets};
}

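Note (editor's sketch of a caller, not part of the diff): with the signature above, total_L can be threaded through as a SymInt taken from another tensor's symbolic size, or left as c10::nullopt to let the op derive it from the offsets. This assumes the dense_to_jagged declaration from sparse_ops.h shown at the top of this commit lives in the fbgemm_gpu namespace.

#include <ATen/ATen.h>
#include "fbgemm_gpu/sparse_ops.h"  // declares dense_to_jagged (see header hunk above)

std::tuple<at::Tensor, std::vector<at::Tensor>> to_jagged(
    const at::Tensor& dense,
    const std::vector<at::Tensor>& offsets,
    const at::Tensor& reference_values) {
  // Taking total_L from sym_size(0) keeps it symbolic under tracing;
  // passing c10::nullopt instead would let the op compute it from offsets.
  c10::optional<at::SymInt> total_L = reference_values.sym_size(0);
  return fbgemm_gpu::dense_to_jagged(dense, offsets, total_L);
}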
(Diffs for the remaining 3 changed files are not shown here.)
