unify function signature of jagged_xD_to_dense (pytorch#813)
Summary:
Pull Request resolved: pytorch#813

As title: rename the embeddings / grad_padded_embeddings arguments of the jagged_2d_to_dense CPU and CUDA implementations to values / grad_padded_values, so the 2D variant uses the same argument naming as jagged_1d_to_dense(Tensor values, ...).

Reviewed By: jiaqizhai, jianyuh

Differential Revision: D33066551

fbshipit-source-id: 9c186312ed67bc507bf06dba26aceb83ead0a0b2
xing-liu authored and facebook-github-bot committed Dec 16, 2021
1 parent 06feb03 commit d334f65
Showing 5 changed files with 97 additions and 97 deletions.
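
For context (not part of this commit): a minimal caller-side sketch of the unified signature the rename converges on. It assumes a Python environment where the fbgemm_gpu extension is importable, so that importing it registers the torch.ops.fbgemm.* operators declared in the TORCH_LIBRARY_FRAGMENT hunk below; the tensor shapes and contents are made up for illustration.

import torch
import fbgemm_gpu  # noqa: F401  -- assumption: importing registers torch.ops.fbgemm.*

# Jagged layout: all rows concatenated into `values`, with `offsets` marking row boundaries.
values = torch.randn(6, 4)            # [total_L, D]
offsets = torch.tensor([0, 2, 3, 6])  # [B + 1] -> per-bag lengths 2, 1, 3

# After this commit both operators name their first argument `values`:
#   jagged_2d_to_dense(Tensor values, Tensor offsets, int max_sequence_length) -> Tensor
#   jagged_1d_to_dense(Tensor values, Tensor offsets, int max_sequence_length, int padding_value) -> Tensor
padded = torch.ops.fbgemm.jagged_2d_to_dense(values, offsets, 3)  # dense [3, 3, 4], zero padded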
4 changes: 2 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

@@ -154,12 +154,12 @@ at::Tensor batched_unary_embeddings_backward_cuda(
     const at::Tensor& indices);
 
 at::Tensor jagged_2d_to_dense_forward_cuda(
-    at::Tensor embeddings,
+    at::Tensor values,
     at::Tensor offsets,
     int32_t max_L);
 
 at::Tensor jagged_2d_to_dense_backward_cuda(
-    at::Tensor grad_padded_embeddings,
+    at::Tensor grad_padded_values,
     at::Tensor offsets,
     int32_t total_L);
70 changes: 35 additions & 35 deletions fbgemm_gpu/src/sparse_ops.cu

@@ -1283,8 +1283,8 @@ __global__ void jagged_2d_to_dense_forward_kernel(
     int32_t max_L,
     int32_t D,
     at::PackedTensorAccessor32<index_t, 1> offsets,
-    at::PackedTensorAccessor64<scalar_t, 2> embeddings,
-    at::PackedTensorAccessor64<scalar_t, 3> padded_embeddings) {
+    at::PackedTensorAccessor64<scalar_t, 2> values,
+    at::PackedTensorAccessor64<scalar_t, 3> padded_values) {
   int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y;
   int32_t l = b_l / B;
   int32_t b = b_l % B;
@@ -1297,39 +1297,39 @@ __global__ void jagged_2d_to_dense_forward_kernel(
   if (l < length) {
     for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) {
       if (d + threadIdx.x < D) {
-        padded_embeddings[b][l][d + threadIdx.x] =
-            embeddings[row_start + l][d + threadIdx.x];
+        padded_values[b][l][d + threadIdx.x] =
+            values[row_start + l][d + threadIdx.x];
       }
     }
   } else {
     for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) {
       if (d + threadIdx.x < D) {
-        padded_embeddings[b][l][d + threadIdx.x] = 0.0;
+        padded_values[b][l][d + threadIdx.x] = 0.0;
       }
     }
   }
 }
 
 Tensor jagged_2d_to_dense_forward_cuda(
-    Tensor embeddings,
+    Tensor values,
     Tensor offsets,
     int32_t max_L) {
-  TORCH_CHECK(embeddings.dim() == 2);
+  TORCH_CHECK(values.dim() == 2);
   TORCH_CHECK(offsets.dim() == 1);
   TORCH_CHECK(max_L > 0);
   at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(embeddings.get_device());
-  int32_t D = embeddings.size(1);
+  device_guard.set_index(values.get_device());
+  int32_t D = values.size(1);
   int32_t B = offsets.numel() - 1;
-  auto padded_embeddings = at::empty({B, max_L, D}, embeddings.options());
-  const auto embeddings_contig = embeddings.contiguous();
+  auto padded_values = at::empty({B, max_L, D}, values.options());
+  const auto values_contig = values.contiguous();
   const auto offsets_contig = offsets.contiguous();
   AT_DISPATCH_INDEX_TYPES(
       offsets.scalar_type(), "jagged_2d_to_dense_forward_kernel_1", ([&]() {
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            embeddings.scalar_type(),
+            values.scalar_type(),
             "jagged_2d_to_dense_forward_kernel_2",
             ([&]() {
               jagged_2d_to_dense_forward_kernel<index_t, scalar_t>
@@ -1345,12 +1345,12 @@ Tensor jagged_2d_to_dense_forward_cuda(
                      max_L,
                      D,
                      offsets_contig.packed_accessor32<index_t, 1>(),
-                     embeddings_contig.packed_accessor64<scalar_t, 2>(),
-                     padded_embeddings.packed_accessor64<scalar_t, 3>());
+                     values_contig.packed_accessor64<scalar_t, 2>(),
+                     padded_values.packed_accessor64<scalar_t, 3>());
             }));
       }));
-  return padded_embeddings;
+  return padded_values;
 }
 
 template <typename index_t, typename scalar_t>
@@ -1359,8 +1359,8 @@ __global__ void jagged_2d_to_dense_backward_kernel(
     int32_t max_L,
     int32_t D,
     at::PackedTensorAccessor32<index_t, 1> offsets,
-    at::PackedTensorAccessor64<scalar_t, 3> grad_padded_embeddings,
-    at::PackedTensorAccessor64<scalar_t, 2> grad_embeddings) {
+    at::PackedTensorAccessor64<scalar_t, 3> grad_padded_values,
+    at::PackedTensorAccessor64<scalar_t, 2> grad_values) {
   int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y;
   int32_t l = b_l / B;
   int32_t b = b_l % B;
@@ -1373,37 +1373,37 @@ __global__ void jagged_2d_to_dense_backward_kernel(
   if (l < length) {
     for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) {
       if (d + threadIdx.x < D) {
-        grad_embeddings[row_start + l][d + threadIdx.x] =
-            grad_padded_embeddings[b][l][d + threadIdx.x];
+        grad_values[row_start + l][d + threadIdx.x] =
+            grad_padded_values[b][l][d + threadIdx.x];
       }
     }
   }
 }
 
 Tensor jagged_2d_to_dense_backward_cuda(
-    Tensor grad_padded_embeddings,
+    Tensor grad_padded_values,
     Tensor offsets,
     int32_t total_L) {
-  TORCH_CHECK(grad_padded_embeddings.dim() == 3);
+  TORCH_CHECK(grad_padded_values.dim() == 3);
   TORCH_CHECK(offsets.dim() == 1);
   TORCH_CHECK(total_L >= 0);
-  TORCH_CHECK(offsets.numel() == grad_padded_embeddings.size(0) + 1);
+  TORCH_CHECK(offsets.numel() == grad_padded_values.size(0) + 1);
   at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(grad_padded_embeddings.get_device());
-  int32_t B = grad_padded_embeddings.size(0);
-  int32_t max_L = grad_padded_embeddings.size(1);
-  int32_t D = grad_padded_embeddings.size(2);
-  auto grad_embeddings =
-      at::zeros({total_L, D}, grad_padded_embeddings.options());
-  const auto grad_padded_embeddings_config =
-      grad_padded_embeddings.contiguous();
+  device_guard.set_index(grad_padded_values.get_device());
+  int32_t B = grad_padded_values.size(0);
+  int32_t max_L = grad_padded_values.size(1);
+  int32_t D = grad_padded_values.size(2);
+  auto grad_values =
+      at::zeros({total_L, D}, grad_padded_values.options());
+  const auto grad_padded_values_config =
+      grad_padded_values.contiguous();
   const auto offsets_contig = offsets.contiguous();
   AT_DISPATCH_INDEX_TYPES(
       offsets.scalar_type(), "jagged_2d_to_dense_backward_kernel_1", ([&]() {
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            grad_padded_embeddings.scalar_type(),
+            grad_padded_values.scalar_type(),
             "jagged_2d_to_dense_backward_kernel_2",
             ([&]() {
               jagged_2d_to_dense_backward_kernel<index_t, scalar_t>
@@ -1419,13 +1419,13 @@ Tensor jagged_2d_to_dense_backward_cuda(
                      max_L,
                      D,
                      offsets_contig.packed_accessor32<index_t, 1>(),
-                     grad_padded_embeddings_config
+                     grad_padded_values_config
                          .packed_accessor64<scalar_t, 3>(),
-                     grad_embeddings.packed_accessor64<scalar_t, 2>());
+                     grad_values.packed_accessor64<scalar_t, 2>());
             }));
       }));
-  return grad_embeddings;
+  return grad_values;
 }
 
 template <typename index_t, typename data_t>
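
To make the kernel semantics above easier to follow, here is a pure-PyTorch reference sketch of what the CUDA forward and backward paths compute (illustration only, not part of this commit). The clamp of each row length to max_L is an assumption about the context elided from the hunks above.

import torch

def jagged_2d_to_dense_forward_ref(values, offsets, max_L):
    # values:  [total_L, D] -- all rows concatenated; offsets: [B + 1] row boundaries.
    B = offsets.numel() - 1
    D = values.size(1)
    padded_values = values.new_zeros(B, max_L, D)  # padding positions stay zero
    for b in range(B):
        start, end = int(offsets[b]), int(offsets[b + 1])
        length = min(end - start, max_L)  # assumed truncation of over-long rows
        padded_values[b, :length] = values[start:start + length]
    return padded_values

def jagged_2d_to_dense_backward_ref(grad_padded_values, offsets, total_L):
    # Scatter the padded gradient back to the jagged layout; gradients of padding are dropped.
    B, max_L, D = grad_padded_values.shape
    grad_values = grad_padded_values.new_zeros(total_L, D)
    for b in range(B):
        start, end = int(offsets[b]), int(offsets[b + 1])
        length = min(end - start, max_L)
        grad_values[start:start + length] = grad_padded_values[b, :length]
    return grad_values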
36 changes: 18 additions & 18 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp

@@ -838,8 +838,8 @@ void jagged_2d_to_dense_forward_kernel(
     int32_t max_L,
     int32_t D,
     const index_t* offsets,
-    const scalar_t* embeddings_data,
-    scalar_t* padded_embeddings_data) {
+    const scalar_t* values_data,
+    scalar_t* padded_values_data) {
   const auto block_size = max_L * D;
   const auto embedding_byte_size = D * sizeof(scalar_t);
   for (auto b = 0; b < B; ++b) {
@@ -851,53 +851,53 @@ void jagged_2d_to_dense_forward_kernel(
     }
     auto padding_length = max_L - length;
     memcpy(
-        &padded_embeddings_data[b * block_size],
-        &embeddings_data[start_idx * D],
+        &padded_values_data[b * block_size],
+        &values_data[start_idx * D],
         length * embedding_byte_size);
     memset(
-        &padded_embeddings_data[b * block_size + length * D],
+        &padded_values_data[b * block_size + length * D],
         0,
         padding_length * embedding_byte_size);
   }
 }
 
 Tensor jagged_2d_to_dense_forward_cpu(
-    Tensor embeddings,
+    Tensor values,
     Tensor offsets,
     int64_t max_L) {
-  TORCH_CHECK(embeddings.dim() == 2);
+  TORCH_CHECK(values.dim() == 2);
   TORCH_CHECK(offsets.dim() == 1);
   TORCH_CHECK(max_L > 0);
 
   const auto B = offsets.numel() - 1;
-  const auto D = embeddings.size(1);
-  const auto embeddings_contig = embeddings.expect_contiguous();
+  const auto D = values.size(1);
+  const auto values_contig = values.expect_contiguous();
   const auto offsets_contig = offsets.expect_contiguous();
 
-  if (embeddings.size(0) == 0) {
-    return at::zeros({B, max_L, D}, embeddings.options());
+  if (values.size(0) == 0) {
+    return at::zeros({B, max_L, D}, values.options());
   }
 
-  auto padded_embeddings = at::empty({B, max_L, D}, embeddings.options());
+  auto padded_values = at::empty({B, max_L, D}, values.options());
   AT_DISPATCH_INDEX_TYPES(
       offsets_contig->scalar_type(),
       "jagged_2d_to_dense_forward_by_offsets",
       ([&]() {
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            embeddings_contig->scalar_type(),
-            "jagged_2d_to_dense_forward_by_embeddings",
+            values_contig->scalar_type(),
+            "jagged_2d_to_dense_forward_by_values",
             ([&]() {
               jagged_2d_to_dense_forward_kernel(
                   B,
                   max_L,
                   D,
                   offsets_contig->data_ptr<index_t>(),
-                  embeddings_contig->data_ptr<scalar_t>(),
-                  padded_embeddings.data_ptr<scalar_t>());
+                  values_contig->data_ptr<scalar_t>(),
+                  padded_values.data_ptr<scalar_t>());
             }));
       }));
 
-  return padded_embeddings;
+  return padded_values;
 }
 
 template <typename index_t, typename scalar_t>
@@ -1190,7 +1190,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor");
   m.def(
-      "jagged_2d_to_dense(Tensor embeddings, Tensor offsets, int max_sequence_length) -> Tensor");
+      "jagged_2d_to_dense(Tensor values, Tensor offsets, int max_sequence_length) -> Tensor");
   m.def(
       "jagged_1d_to_dense(Tensor values, Tensor offsets, int max_sequence_length, int padding_value) -> Tensor");
   m.def(
18 changes: 9 additions & 9 deletions fbgemm_gpu/src/sparse_ops_gpu.cpp

@@ -62,15 +62,15 @@ class Jagged2DToDenseGPUOp
  public:
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
-      Tensor embeddings,
+      Tensor values,
       Tensor offsets,
       int32_t max_sequence_length) {
-    int32_t total_L = embeddings.size(0);
+    int32_t total_L = values.size(0);
     ctx->save_for_backward({offsets});
     ctx->saved_data["total_L"] = total_L;
 
     return {jagged_2d_to_dense_forward_cuda(
-        embeddings, offsets, max_sequence_length)};
+        values, offsets, max_sequence_length)};
   }
 
   static torch::autograd::variable_list backward(
@@ -82,23 +82,23 @@ class Jagged2DToDenseGPUOp
     int32_t total_L = ctx->saved_data["total_L"].toInt();
 
     using torch::autograd::Variable;
-    auto grad_padded_embeddings = grad_outputs[0];
-    auto grad_embeddings = jagged_2d_to_dense_backward_cuda(
-        grad_padded_embeddings, offsets, total_L);
+    auto grad_padded_values = grad_outputs[0];
+    auto grad_values = jagged_2d_to_dense_backward_cuda(
+        grad_padded_values, offsets, total_L);
     return {
-        grad_embeddings,
+        grad_values,
         Variable(), // offsets
        Variable() // max_sequence_length
     };
   }
 };
 
 Tensor jagged_2d_to_dense_gpu(
-    Tensor embeddings,
+    Tensor values,
     Tensor offsets,
     int64_t max_sequence_length) {
   return Jagged2DToDenseGPUOp::apply(
-      embeddings, offsets, static_cast<int32_t>(max_sequence_length))[0];
+      values, offsets, static_cast<int32_t>(max_sequence_length))[0];
 }
 
 } // namespace fbgemm_gpu
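
A hypothetical end-to-end check of the autograd wrapper above (illustration only; assumes a CUDA build of fbgemm_gpu whose import registers the operator and that the CUDA dispatch routes jagged_2d_to_dense through Jagged2DToDenseGPUOp):

import torch
import fbgemm_gpu  # noqa: F401  -- assumption: importing registers torch.ops.fbgemm.*

values = torch.randn(6, 4, device="cuda", requires_grad=True)           # [total_L, D]
offsets = torch.tensor([0, 2, 3, 6], dtype=torch.int32, device="cuda")  # [B + 1]

padded = torch.ops.fbgemm.jagged_2d_to_dense(values, offsets, 3)        # [B, max_L, D]
padded.sum().backward()

# backward() scatters the padded gradient back via jagged_2d_to_dense_backward_cuda,
# so values.grad has the same [total_L, D] shape as values.
assert values.grad.shape == values.shape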