
Commit

Align CTC grad scale same with ESPNet (PaddlePaddle#34729)
* dygraph support more ctc grad scale

* scale for 1.x

* fix unittest

* fix unittest

* format code

* fix unittest

* fix log info

* unittest cov

* fix format;notest,test=cpu,coverage

* skip ctc_loss egs;test=cpu

* warpctc grad cov;test=coverage

* add dygraph test;test=coverage

* format;test=cpu,coverage

* format;test=cpu

* add api compat;test=cpu

* add cpu test

* rename

* rename

* fix

* fix test

* format

* eigen cpu

* eigen gpu grad pass

* cuda gpu pass

* format

* fix ci
zh794390558 authored Aug 17, 2021
1 parent 8046e33 commit 10f9644
Showing 15 changed files with 602 additions and 104 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/operators/CMakeLists.txt
@@ -81,10 +81,10 @@ op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS execu

if (WITH_GPU OR WITH_ROCM)
if(WITH_ROCM)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu)
# warpctc_op needs cudnn 7 above
elseif(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu)
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
31 changes: 27 additions & 4 deletions paddle/fluid/operators/math/sequence_padding.cc
Original file line number Diff line number Diff line change
@@ -33,7 +33,8 @@ void CopyValidData(framework::Tensor* dst_tensor,
const framework::Tensor* src_tensor,
const framework::Vector<size_t>& seq_offsets,
int pad_seq_len, int step_width, bool norm_by_len,
CopyType type, PadLayout layout) {
bool norm_by_batchsize, bool norm_by_total_logits_len,
int total_logits_len, CopyType type, PadLayout layout) {
int seq_num = seq_offsets.size() - 1;
const T* src_data = src_tensor->data<T>();
T* dst_data = dst_tensor->data<T>();
@@ -54,7 +55,21 @@ void CopyValidData(framework::Tensor* dst_tensor,
int pad_data_offset = layout == kBatchLengthWidth
? seq_idx * pad_seq_len * step_width
: seq_idx * step_width;
float scale = 1.0f / static_cast<float>(valid_seq_len);

float scale = 1.0f;
if (norm_by_total_logits_len) {
scale = 1.0f / static_cast<float>(total_logits_len);
VLOG(3) << "[warpctc grad][norm_by_total_logits_len]: scale " << scale
<< "total_logits_len " << total_logits_len;
} else if (norm_by_batchsize) {
scale = 1.0f / static_cast<float>(seq_num);
VLOG(3) << "[warpctc grad][norm_by_batchsize]: scale " << scale << "B "
<< seq_num;
} else if (norm_by_len) {
scale = 1.0f / static_cast<float>(valid_seq_len);
VLOG(3) << "[warpctc grad][norm_by_len]: scale " << scale << "T "
<< valid_seq_len;
}

for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) {
const T* src =
@@ -97,6 +112,8 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool norm_by_batchsize = false,
bool norm_by_total_logits_len = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_lod = seq_tensor.lod();
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
@@ -131,7 +148,8 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
}

CopyValidData<T>(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kSeqToPad, layout);
step_width, norm_by_times, false, false, 0, kSeqToPad,
layout);
}
};

@@ -142,20 +160,25 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool norm_by_batchsize = false,
bool norm_by_total_logits_len = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
const auto& seq_tensor_dims = seq_tensor->dims();
const auto& pad_tensor_dims = pad_tensor.dims();
if (pad_seq_len == -1) {
pad_seq_len = MaximumSequenceLength(seq_offsets);
}
int total_logits_len = TotalSequenceLength(seq_offsets);
int step_width = seq_tensor->numel() / seq_tensor_dims[0];

CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);

CopyValidData<T>(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kPadToSeq, layout);
step_width, norm_by_times, norm_by_batchsize,
norm_by_total_logits_len, total_logits_len, kPadToSeq,
layout);
}
};

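For readers skimming the diff, the new branching in CopyValidData reduces to one precedence rule for the gradient scale: norm_by_total_logits_len wins over norm_by_batchsize, which wins over norm_by_len (norm_by_times at the op level), and the default is no scaling. The standalone sketch below mirrors that rule; ComputeGradScale and the sample lengths are illustrative only and are not part of this commit.

```cpp
#include <cstdio>

// Hypothetical helper mirroring the precedence CopyValidData now applies when
// scaling the unpadded warpctc gradient (not code from this commit).
float ComputeGradScale(bool norm_by_total_logits_len, bool norm_by_batchsize,
                       bool norm_by_len, int total_logits_len, int seq_num,
                       int valid_seq_len) {
  if (norm_by_total_logits_len) {
    return 1.0f / static_cast<float>(total_logits_len);  // 1 / (sum of all T_i)
  } else if (norm_by_batchsize) {
    return 1.0f / static_cast<float>(seq_num);           // 1 / B
  } else if (norm_by_len) {
    return 1.0f / static_cast<float>(valid_seq_len);     // 1 / T_i, per sequence
  }
  return 1.0f;                                           // unscaled
}

int main() {
  // Two sequences of lengths 3 and 5, so total_logits_len = 8 and seq_num = 2.
  std::printf("%.4f\n", ComputeGradScale(true, false, false, 8, 2, 3));   // 0.1250
  std::printf("%.4f\n", ComputeGradScale(false, true, false, 8, 2, 3));   // 0.5000
  std::printf("%.4f\n", ComputeGradScale(false, false, true, 8, 2, 3));   // 0.3333
  std::printf("%.4f\n", ComputeGradScale(false, false, false, 8, 2, 3));  // 1.0000
  return 0;
}
```
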
24 changes: 20 additions & 4 deletions paddle/fluid/operators/math/sequence_padding.cu
@@ -23,7 +23,9 @@ template <typename T, CopyType Type>
__global__ void SequencePaddingKernel(
T* dst, const T* src, const T* pad_value, bool is_constant_pad,
const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
const size_t step_width, bool norm_by_len, const PadLayout layout) {
const size_t step_width, bool norm_by_len, bool norm_by_batchsize,
bool norm_by_total_logits_len, int total_logits_len,
const PadLayout layout) {
size_t seq_idx = blockIdx.y;
size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];

@@ -38,7 +40,15 @@ __global__ void SequencePaddingKernel(
src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);

if (step_idx < seq_len) {
float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
float scale = 1.0f;
if (norm_by_total_logits_len) {
scale = 1.0f / static_cast<float>(total_logits_len);
} else if (norm_by_batchsize) {
scale = 1.0f / static_cast<float>(seq_num);
} else if (norm_by_len) {
scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
}

for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
dst_data[i] = scale * src_data[i];
}
@@ -57,6 +67,8 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool norm_by_batchsize = false,
bool norm_by_total_logits_len = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_lod = seq_tensor.lod();
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
@@ -107,7 +119,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
step_width, norm_by_times, layout);
step_width, norm_by_times, false, false, 0, layout);
}
};

@@ -118,6 +130,8 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool norm_by_batchsize = false,
bool norm_by_total_logits_len = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
const auto& seq_tensor_dims = seq_tensor->dims();
@@ -126,6 +140,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
if (pad_seq_len == -1) {
pad_seq_len = max_seq_len;
}
int total_logits_len = TotalSequenceLength(seq_offsets);
int step_width = seq_tensor->numel() / seq_tensor_dims[0];
int seq_num = seq_offsets.size() - 1;

@@ -159,7 +174,8 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
seq_data, pad_data, nullptr, false,
seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
step_width, norm_by_times, layout);
step_width, norm_by_times, norm_by_batchsize, norm_by_total_logits_len,
total_logits_len, layout);
}
};

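The CUDA kernel applies the same scale selection per thread block; the only layout-specific part is how a (seq_idx, step_idx) pair maps into the padded and flattened tensors. Below is a minimal CPU sketch of that mapping for the kBatchLengthWidth layout in the kPadToSeq direction, assuming the usual absolute LoD offsets; the function name, shapes, and test values are illustrative, not from this commit.

```cpp
#include <cstddef>
#include <vector>

// Illustrative CPU version of the kPadToSeq copy for the kBatchLengthWidth
// layout: padded data is [seq_num, pad_seq_len, step_width] (row-major) and the
// flattened sequence tensor is [total_len, step_width], sequences concatenated.
// seq_offsets are absolute LoD offsets of size seq_num + 1. Not code from this commit.
void UnpadBatchLengthWidth(const std::vector<float>& pad,
                           const std::vector<std::size_t>& seq_offsets,
                           std::size_t pad_seq_len, std::size_t step_width,
                           float scale, std::vector<float>* seq) {
  const std::size_t seq_num = seq_offsets.size() - 1;
  for (std::size_t seq_idx = 0; seq_idx < seq_num; ++seq_idx) {
    const std::size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
    for (std::size_t step = 0; step < seq_len; ++step) {   // valid steps only
      const std::size_t pad_off = (seq_idx * pad_seq_len + step) * step_width;
      const std::size_t seq_off = (seq_offsets[seq_idx] + step) * step_width;
      for (std::size_t i = 0; i < step_width; ++i) {
        (*seq)[seq_off + i] = scale * pad[pad_off + i];     // scale chosen as above
      }
    }
  }
}

int main() {
  const std::vector<std::size_t> offsets = {0, 2, 5};  // two sequences, lengths 2 and 3
  const std::size_t pad_seq_len = 3, step_width = 2;
  std::vector<float> pad(2 * pad_seq_len * step_width, 1.0f);
  std::vector<float> seq(5 * step_width, 0.0f);
  UnpadBatchLengthWidth(pad, offsets, pad_seq_len, step_width, 0.5f, &seq);
  return 0;
}
```
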
4 changes: 4 additions & 0 deletions paddle/fluid/operators/math/sequence_padding.h
@@ -107,6 +107,8 @@ class PaddingLoDTensorFunctor {
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool norm_by_batchsize = false,
bool norm_by_total_logits_len = false,
const PadLayout layout = kBatchLengthWidth);
};

@@ -117,6 +119,8 @@ class UnpaddingLoDTensorFunctor {
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
bool norm_by_batchsize = false,
bool norm_by_total_logits_len = false,
const PadLayout layout = kBatchLengthWidth);
};

4 changes: 2 additions & 2 deletions paddle/fluid/operators/math/sequence_padding_test.cc
@@ -66,13 +66,13 @@ void TestSequencePadding(const DeviceContext &context,
}

paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
context, seq, &padding, pad_value, -1, 0, false,
context, seq, &padding, pad_value, -1, 0, false, false, false,
paddle::operators::math::kLengthBatchWidth);

seq_back.set_lod(lod);
seq_back.mutable_data<T>(seq_dims, place);
paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
context, padding, &seq_back, -1, 0, false,
context, padding, &seq_back, -1, 0, false, false, false,
paddle::operators::math::kLengthBatchWidth);

if (paddle::platform::is_cpu_place(place)) {
4 changes: 2 additions & 2 deletions paddle/fluid/operators/sequence_ops/sequence_pad_op.h
@@ -46,7 +46,7 @@ class SequencePadOpKernel : public framework::OpKernel<T> {

math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *x, out, *pad_value,
padded_length, 0, false, math::kBatchLengthWidth);
padded_length, 0, false, false, false, math::kBatchLengthWidth);

LoDTensor seq_len;
seq_len.Resize(len_t->dims());
@@ -72,7 +72,7 @@ class SequencePadGradOpKernel : public framework::OpKernel<T> {

math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *d_out, d_x,
padded_length, 0, false, math::kBatchLengthWidth);
padded_length, 0, false, false, false, math::kBatchLengthWidth);
}
}
};
5 changes: 3 additions & 2 deletions paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
@@ -69,7 +69,8 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {

int64_t padded_length = x_t->dims()[1];
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
dev_ctx, *x_t, out_t, padded_length, 0, false, math::kBatchLengthWidth);
dev_ctx, *x_t, out_t, padded_length, 0, false, false, false,
math::kBatchLengthWidth);
}
};

@@ -93,7 +94,7 @@ class SequenceUnpadGradOpKernel : public framework::OpKernel<T> {

math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *d_out, d_x, zero_pads,
padded_length, 0, false, math::kBatchLengthWidth);
padded_length, 0, false, false, false, math::kBatchLengthWidth);
}
}
};
29 changes: 29 additions & 0 deletions paddle/fluid/operators/warpctc_op.cc
@@ -125,6 +125,17 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
"normalize the gradients by the number of time-step, "
"which is also the sequence's length.")
.SetDefault(false);
AddAttr<bool>(
"norm_by_batchsize",
"(bool, default: false), normalize the loss by the batch size."
" If True, supersedes norm_by_times")
.SetDefault(false);
AddAttr<bool>(
"norm_by_total_logits_len",
"(bool, default: false), normalize the loss by the total number of "
"frames"
" in the batch. If True, supersedes norm_by_batchsize and norm_by_times")
.SetDefault(false);
AddComment(R"DOC(
An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
@@ -206,3 +217,21 @@ REGISTER_OP_CPU_KERNEL(
warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_VERSION(warpctc)
.AddCheckpoint(
R"ROC(
Upgrade warpctc add a new attribute [norm_by_batchsize] and [norm_by_total_logits_len])ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr(
"norm_by_batchsize",
"(bool, default: false), normalize the loss by the batch size."
" If True, supersedes norm_by_times",
false)
.NewAttr("norm_by_total_logits_len",
"(bool, default: false), normalize the loss by the total "
"number of "
"frames"
" in the batch. If True, supersedes norm_by_batchsize and "
"norm_by_times",
false));
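
Taken together, the new attributes give warpctc the following gradient scaling per sequence i in a batch of B sequences with logit lengths T_1 … T_B. This is only a summary of the attribute documentation and the CopyValidData branching above, written out for reference:

$$
\frac{\partial \mathcal{L}}{\partial x_i} \;\leftarrow\; s_i \cdot \frac{\partial \mathcal{L}}{\partial x_i},
\qquad
s_i =
\begin{cases}
1 / \sum_{j=1}^{B} T_j & \text{if norm\_by\_total\_logits\_len} \\
1 / B & \text{else if norm\_by\_batchsize} \\
1 / T_i & \text{else if norm\_by\_times} \\
1 & \text{otherwise.}
\end{cases}
$$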