Skip to content

Commit

Permalink
Add complex type compatibility for stft api and stft op. (#40113)
Browse files Browse the repository at this point in the history
* Add stft_op.

* Add stft_grad_op.

* Add stft_op unittest.

* [DLTP-45176] Add complex compatibility in static mode for stft api.

* [DLTP-45176] Add complex compatibility in static mode for stft api.

* Add doc.

* Update unitests of stft op.

* Update spectral helper.

* fix coding style.
  • Loading branch information
KPatr1ck authored Mar 23, 2022
1 parent 3d0be93 commit 319f95d
Show file tree
Hide file tree
Showing 20 changed files with 1,901 additions and 1,419 deletions.
20 changes: 14 additions & 6 deletions paddle/fluid/operators/frame_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,26 @@ class FrameOp : public framework::OperatorWithKernel {
end_axis = x_rank - 2;
}

PADDLE_ENFORCE_LE(frame_length, seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length, seq_length));
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(frame_length, seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length, seq_length));
}

// It won't go into for loop when x_rank == 1U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}

n_frames = 1 + (seq_length - frame_length) / hop_length;
if (seq_length == -1) {
n_frames = -1;
} else {
n_frames = 1 + (seq_length - frame_length) / hop_length;
}

if (axis == 0) {
// (n_frames, frame_length, ...)
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/operators/mean_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,17 @@ REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
paddle::platform::bfloat16>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
mean_grad, ops::MeanGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
paddle::platform::bfloat16>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
11 changes: 9 additions & 2 deletions paddle/fluid/operators/mean_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,17 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
mean_grad,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
paddle::platform::complex<float>>,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
23 changes: 16 additions & 7 deletions paddle/fluid/operators/overlap_add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class OverlapAddOp : public framework::OperatorWithKernel {
std::vector<int64_t> output_shape;
int n_frames;
int frame_length;
int seq_length;

int start_axis;
int end_axis;
Expand All @@ -69,14 +70,22 @@ class OverlapAddOp : public framework::OperatorWithKernel {
end_axis = x_rank - 3;
}

PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(
hop_length, frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length, frame_length));
}

const int seq_length = (n_frames - 1) * hop_length + frame_length;
if (n_frames == -1) {
seq_length = -1;
} else {
seq_length = (n_frames - 1) * hop_length + frame_length;
}

// It won't go into for loop when x_rank == 2U.
for (int i = start_axis; i <= end_axis; i++) {
Expand Down
Loading

0 comments on commit 319f95d

Please sign in to comment.