From f9e3309feedadfdd2f4239669606375d91c3b594 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Fri, 10 Sep 2021 21:26:09 +0800 Subject: [PATCH] Add deframe op and stft/istft api. (#23) * Add frame api * Add deframe op and kernels. * Add stft and istft apis. * Add deframe api. Update stft and istft apis. * Fix bug in frame_from_librosa function when input dims >= 3 * Rename deframe to overlap_add. * Update istft. * Update after code review. --- paddle/fluid/operators/frame_op.cc | 11 +- paddle/fluid/operators/frame_op.cu | 3 +- paddle/fluid/operators/frame_op.h | 179 +--------- paddle/fluid/operators/math/seq2col.h | 186 ++++++++++ paddle/fluid/operators/overlap_add_op.cc | 183 ++++++++++ paddle/fluid/operators/overlap_add_op.cu | 43 +++ paddle/fluid/operators/overlap_add_op.h | 304 ++++++++++++++++ .../fluid/tests/unittests/test_frame_op.py | 57 ++- python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/signal.py | 325 ++++++++++++++++++ 10 files changed, 1088 insertions(+), 204 deletions(-) create mode 100644 paddle/fluid/operators/math/seq2col.h create mode 100644 paddle/fluid/operators/overlap_add_op.cc create mode 100644 paddle/fluid/operators/overlap_add_op.cu create mode 100644 paddle/fluid/operators/overlap_add_op.h create mode 100644 python/paddle/tensor/signal.py diff --git a/paddle/fluid/operators/frame_op.cc b/paddle/fluid/operators/frame_op.cc index 9120c62b3cd84..56420bde7cba5 100644 --- a/paddle/fluid/operators/frame_op.cc +++ b/paddle/fluid/operators/frame_op.cc @@ -32,6 +32,11 @@ class FrameOp : public framework::OperatorWithKernel { const auto x_dims = ctx->GetInputDim("X"); const int x_rank = x_dims.size(); + PADDLE_ENFORCE_GE( + x_rank, 1, platform::errors::InvalidArgument( + "Input(X) of FrameOp should be a tensor which contains " + "at least 1 dimension, but got rank %s.", + x_rank)); PADDLE_ENFORCE_GT(hop_length, 0, platform::errors::InvalidArgument( "Attribute(hop_length) of FrameOp should be greater " @@ -111,7 +116,7 @@ class FrameOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Frame Operator. - Frame op slices frames from input sequence $X$. + Frame op convert time sequences into frames. )DOC"); } @@ -174,7 +179,9 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( - frame_grad, ops::FrameGradKernel, + frame_grad, ops::FrameGradKernel, + ops::FrameGradKernel, + ops::FrameGradKernel, ops::FrameGradKernel, ops::FrameGradKernel>, diff --git a/paddle/fluid/operators/frame_op.cu b/paddle/fluid/operators/frame_op.cu index 203cc757ce687..797e0aa0111d8 100644 --- a/paddle/fluid/operators/frame_op.cu +++ b/paddle/fluid/operators/frame_op.cu @@ -29,7 +29,8 @@ REGISTER_OP_CUDA_KERNEL( paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( - frame_grad, + frame_grad, ops::FrameGradKernel, + ops::FrameGradKernel, ops::FrameGradKernel, ops::FrameGradKernel, ops::FrameGradKernel -struct DataMappingFunctor { - DataMappingFunctor(const T* x, T* out, size_t seq_length, size_t frame_length, - size_t n_frames, size_t hop_length) - : x_(x), - out_(out), - seq_length_(seq_length), - frame_length_(frame_length), - n_frames_(n_frames), - hop_length_(hop_length) {} - - /* - Convert sequences to frames. - - 1. Dimension infomation: - - Sequences Frames - (N, seq_length) -> (N, frame_length, n_frames) - - 2. Mapping from `i` to `src_idx` and `trg_idx` can be derived from: - - a. Notion - - `i` stands for the flattened index of a bunch of frames. - - `src_idx` and `trg_idx` are the 1D indices of seqs and frames - respectivly. 
- - b. Sample idx - ```cpp - sample_idx = i / (n_frames_ * frame_length_); - ``` - - c. Maps `i` to `f` and `n`. - ```cpp - f = i % (n_frames_ * frame_length_) / n_frames_; - n = i % (n_frames_ * frame_length_) % n_frames_; - ``` - - d. Replace `sample_idx`, `f` and `n` in the following eqations: - ```cpp - src_idx = sample_idx * seq_length_ + n * hop_length_ + f; - trg_idx = sample_idx * n_frames_ * frame_length_ + f * n_frames_ + n; - out_[trg_idx] = x_[src_idx]; - ``` - - e. Result can be deduced shown in the function body below. - */ - HOSTDEVICE void operator()(size_t i) const { - size_t src_idx; - size_t trg_idx; - src_idx = i / (n_frames_ * frame_length_) * seq_length_ + - i % (n_frames_ * frame_length_) % n_frames_ * hop_length_ + - i % (n_frames_ * frame_length_) / n_frames_; - trg_idx = i / (n_frames_ * frame_length_) * n_frames_ * frame_length_ + - i % (n_frames_ * frame_length_) / n_frames_ * n_frames_ + - i % (n_frames_ * frame_length_) % n_frames_; - out_[trg_idx] = x_[src_idx]; - } - - const T* x_; - T* out_; - size_t seq_length_; - size_t frame_length_; - size_t n_frames_; - size_t hop_length_; -}; - -template -struct DataMappingGradFunctor { - DataMappingGradFunctor(const T* d_out, T* d_x, size_t seq_length, - size_t frame_length, size_t n_frames, - size_t hop_length) - : d_out_(d_out), - d_x_(d_x), - seq_length_(seq_length), - frame_length_(frame_length), - n_frames_(n_frames), - hop_length_(hop_length) {} - - /* - Accumulate output gradient d_out to d_x. - - 1. Dimension infomation: - - d_out d_x - (N, frame_length, n_frames) -> (N, seq_length) - - 2. Using a sliding window to find source indices from `d_out` according to - `i`: - - a. Notion - - `i` stands for the flattened index of `d_x`. - - `seq_i` stands for a relative index of a `d_x` sample. - - `left`: Starting index of a frame window. - - `right`: Ending index of a frame window. - - b. Sample idx - ```cpp - sample_idx = i / seq_length_; - ``` - - c. Slides a window with length of `frame_length` to find `f` and `n`. - - `n`: The idx of num_frames_, increases in each hop. - - `f`: The idx of frame_lengths_, relative idx from left of a sliding - window. - - d. Accumulate all grads from d_out. - ```cpp - d_x_[i] += - d_out_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n]; - ``` - */ - HOSTDEVICE void operator()(size_t i) const { - size_t sample_idx = i / seq_length_; - size_t seq_i = i % seq_length_; - - // Sliding window - d_x_[i] = 0; // Init d_x_[i] to 0, and sums up all - // grads from d_out_ in the while loop. - - size_t n = get_start_frame_idx(seq_i); - size_t f; - size_t left = n * hop_length_; - size_t right = left + frame_length_ - 1; - - while (left <= seq_i && right < seq_length_) { - f = seq_i - left; - d_x_[i] += - d_out_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n]; - // Next frame. 
- left += hop_length_; - right += hop_length_; - n += 1; - } - } - - /* - Calculate minimum value of frame index `n` to satisfy the inequality: - - seq_i <= right - ==> seq_i <= left + frame_length - 1 - ==> seq_i <= hop_length_ * n + frame_length_ - 1 - */ - HOSTDEVICE size_t get_start_frame_idx(size_t seq_i) const { - int64_t tmp = seq_i + 1 - frame_length_; - if (tmp > 0) { - size_t n = tmp / hop_length_; - if (tmp % hop_length_ == 0) { - return n; - } else { - return n + 1; - } - } else { - return 0; - } - } - - const T* d_out_; - T* d_x_; - size_t seq_length_; - size_t frame_length_; - size_t n_frames_; - size_t hop_length_; -}; - template struct FrameFunctor { void operator()(const DeviceContext& dev_ctx, const Tensor* input, @@ -203,12 +40,12 @@ struct FrameFunctor { platform::ForRange for_range(dev_ctx, numel); if (!is_grad) { - DataMappingFunctor functor(input_data, output_data, seq_length, - frame_length, n_frames, hop_length); + math::Seq2ColFunctor functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); for_range(functor); } else { - DataMappingGradFunctor functor(input_data, output_data, seq_length, - frame_length, n_frames, hop_length); + math::Col2SeqFunctor functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); for_range(functor); } } @@ -385,10 +222,8 @@ class FrameGradKernel : public framework::OpKernel { falls into Case 2. Finally, it restores the dims of `d_x` tensor. */ void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor* d_out = - ctx.Input(framework::GradVarName("Out")); - framework::Tensor* d_x = - ctx.Output(framework::GradVarName("X")); + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); d_x->mutable_data(ctx.GetPlace()); const size_t d_out_rank = d_out->dims().size(); const size_t d_x_rank = d_x->dims().size(); diff --git a/paddle/fluid/operators/math/seq2col.h b/paddle/fluid/operators/math/seq2col.h new file mode 100644 index 0000000000000..56134b6f0ea5c --- /dev/null +++ b/paddle/fluid/operators/math/seq2col.h @@ -0,0 +1,186 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace operators { +namespace math { + +template +struct Seq2ColFunctor { + Seq2ColFunctor(const T* seq, T* col, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length) + : seq_(seq), + col_(col), + seq_length_(seq_length), + frame_length_(frame_length), + n_frames_(n_frames), + hop_length_(hop_length) {} + + /* + Convert sequences to frames. + + 1. Dimension infomation: + + Sequences Frames + (N, seq_length) -> (N, frame_length, n_frames) + + 2. Mapping from `i` to `src_idx` and `trg_idx` can be derived from: + + a. Notion + - `i` stands for the flattened index of a bunch of frames. + - `src_idx` and `trg_idx` are the 1D indices of seqs and frames + respectivly. + + b. 
Sample idx + ```cpp + sample_idx = i / (n_frames_ * frame_length_); + ``` + + c. Maps `i` to `f` and `n`. + ```cpp + f = i % (n_frames_ * frame_length_) / n_frames_; + n = i % (n_frames_ * frame_length_) % n_frames_; + ``` + + d. Replace `sample_idx`, `f` and `n` in the following eqations: + ```cpp + src_idx = sample_idx * seq_length_ + n * hop_length_ + f; + trg_idx = sample_idx * n_frames_ * frame_length_ + f * n_frames_ + n; + col_[trg_idx] = seq_[src_idx]; + ``` + + e. Result can be deduced shown in the function body below. + */ + HOSTDEVICE void operator()(size_t i) const { + size_t src_idx; + size_t trg_idx; + src_idx = i / (n_frames_ * frame_length_) * seq_length_ + + i % (n_frames_ * frame_length_) % n_frames_ * hop_length_ + + i % (n_frames_ * frame_length_) / n_frames_; + trg_idx = i / (n_frames_ * frame_length_) * n_frames_ * frame_length_ + + i % (n_frames_ * frame_length_) / n_frames_ * n_frames_ + + i % (n_frames_ * frame_length_) % n_frames_; + col_[trg_idx] = seq_[src_idx]; + } + + const T* seq_; + T* col_; + size_t seq_length_; + size_t frame_length_; + size_t n_frames_; + size_t hop_length_; +}; + +template +struct Col2SeqFunctor { + Col2SeqFunctor(const T* col, T* seq, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length) + : col_(col), + seq_(seq), + seq_length_(seq_length), + frame_length_(frame_length), + n_frames_(n_frames), + hop_length_(hop_length) {} + + /* + Accumulate output gradient d_out to d_x. + + 1. Dimension infomation: + + d_out d_x + (N, frame_length, n_frames) -> (N, seq_length) + + 2. Using a sliding window to find source indices from `d_out` according to + `i`: + + a. Notion + - `i` stands for the flattened index of `d_x`. + - `seq_i` stands for a relative index of a `d_x` sample. + - `left`: Starting index of a frame window. + - `right`: Ending index of a frame window. + + b. Sample idx + ```cpp + sample_idx = i / seq_length_; + ``` + + c. Slides a window with length of `frame_length` to find `f` and `n`. + - `n`: The idx of num_frames_, increases in each hop. + - `f`: The idx of frame_lengths_, relative idx from left of a sliding + window. + + d. Accumulate all grads from d_out. + ```cpp + seq_[i] += + col_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n]; + ``` + */ + HOSTDEVICE void operator()(size_t i) const { + size_t sample_idx = i / seq_length_; + size_t seq_i = i % seq_length_; + + // Sliding window + seq_[i] = 0; // Init seq_[i] to 0, and sums up all + // grads from col_ in the while loop. + + size_t n = get_start_frame_idx(seq_i); + size_t f; + size_t left = n * hop_length_; + size_t right = left + frame_length_ - 1; + + while (left <= seq_i && right < seq_length_) { + f = seq_i - left; + seq_[i] += + col_[sample_idx * frame_length_ * n_frames_ + f * n_frames_ + n]; + // Next frame. 
+ left += hop_length_; + right += hop_length_; + n += 1; + } + } + + /* + Calculate minimum value of frame index `n` to satisfy the inequality: + + seq_i <= right + ==> seq_i <= left + frame_length - 1 + ==> seq_i <= hop_length_ * n + frame_length_ - 1 + */ + HOSTDEVICE size_t get_start_frame_idx(size_t seq_i) const { + int64_t tmp = seq_i + 1 - frame_length_; + if (tmp > 0) { + size_t n = tmp / hop_length_; + if (tmp % hop_length_ == 0) { + return n; + } else { + return n + 1; + } + } else { + return 0; + } + } + + const T* col_; + T* seq_; + size_t seq_length_; + size_t frame_length_; + size_t n_frames_; + size_t hop_length_; +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/overlap_add_op.cc b/paddle/fluid/operators/overlap_add_op.cc new file mode 100644 index 0000000000000..f710bc9adcbe7 --- /dev/null +++ b/paddle/fluid/operators/overlap_add_op.cc @@ -0,0 +1,183 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/overlap_add_op.h" + +namespace paddle { +namespace operators { + +class OverlapAddOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "overlap_add"); + + const int hop_length = ctx->Attrs().Get("hop_length"); + const int axis = ctx->Attrs().Get("axis"); + + const auto x_dims = ctx->GetInputDim("X"); + const int x_rank = x_dims.size(); + + PADDLE_ENFORCE_GE( + x_rank, 2, + platform::errors::InvalidArgument( + "Input(X) of OverlapAddOp should be a tensor which contains " + "at least 2 dimensions, but got rank %s.", + x_rank)); + + PADDLE_ENFORCE_GT( + hop_length, 0, + platform::errors::InvalidArgument( + "Attribute(hop_length) of OverlapAddOp should be greater " + "than 0, but got %s.", + hop_length)); + PADDLE_ENFORCE_EQ( + (axis == 0 || axis == -1), true, + platform::errors::InvalidArgument( + "Attribute(axis) of OverlapAddOp should 0 or -1, but got %s.", + axis)); + + std::vector output_shape; + int n_frames; + int frame_length; + + int start_axis; + int end_axis; + if (axis == 0) { + n_frames = x_dims[0]; + frame_length = x_dims[1]; + start_axis = 2; + end_axis = x_rank - 1; + } else { + n_frames = x_dims[x_rank - 1]; + frame_length = x_dims[x_rank - 2]; + start_axis = 0; + end_axis = x_rank - 3; + } + + const int seq_length = (n_frames - 1) * hop_length + frame_length; + + // It won't go into for loop when x_rank == 2U. + for (int i = start_axis; i <= end_axis; i++) { + output_shape.push_back(x_dims[i]); + } + + if (axis == 0) { + // (seq_length, ...) 
+ output_shape.insert(output_shape.begin(), seq_length); + } else { + // (..., seq_length) + output_shape.push_back(seq_length); + } + + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +class OverlapAddOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of overlap_add op."); + AddOutput("Out", "(Tensor), The output tensor of overlap_add op."); + AddAttr("hop_length", + "Hop Length" + "Other doc of hop length arg..."); + AddAttr("axis", + "Axis" + "Other doc of axis arg...") + .SetDefault(-1); + AddComment(R"DOC( + OverlapAdd Operator. + + OverlapAdd op convert frames into time sequences. + + )DOC"); + } +}; + +class OverlapAddOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "overlap_add_grad"); + const auto x_dims = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto in_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(in_dtype, ctx.GetPlace()); + } +}; + +template +class OverlapAddOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("overlap_add_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(overlap_add, ops::OverlapAddOp, ops::OverlapAddOpMaker, + ops::OverlapAddOpGradMaker, + ops::OverlapAddOpGradMaker); + +REGISTER_OPERATOR(overlap_add_grad, ops::OverlapAddOpGrad); + +REGISTER_OP_CPU_KERNEL( + overlap_add, ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel>, + ops::OverlapAddKernel>); + +REGISTER_OP_CPU_KERNEL( + overlap_add_grad, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel>, + ops::OverlapAddGradKernel>); diff --git a/paddle/fluid/operators/overlap_add_op.cu b/paddle/fluid/operators/overlap_add_op.cu new file mode 100644 index 0000000000000..2b7935e0191b7 --- /dev/null +++ b/paddle/fluid/operators/overlap_add_op.cu @@ -0,0 +1,43 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/overlap_add_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + overlap_add, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel, + ops::OverlapAddKernel>, + ops::OverlapAddKernel>); + +REGISTER_OP_CUDA_KERNEL( + overlap_add_grad, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel, + ops::OverlapAddGradKernel>, + ops::OverlapAddGradKernel>); diff --git a/paddle/fluid/operators/overlap_add_op.h b/paddle/fluid/operators/overlap_add_op.h new file mode 100644 index 0000000000000..865659ee942e4 --- /dev/null +++ b/paddle/fluid/operators/overlap_add_op.h @@ -0,0 +1,304 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/seq2col.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +struct OverlapAddFunctor { + void operator()(const DeviceContext& dev_ctx, const Tensor* input, + Tensor* output, size_t seq_length, size_t frame_length, + size_t n_frames, size_t hop_length, + bool is_grad = false) const { + auto numel = output->numel(); + const auto* input_data = input->data(); + auto* output_data = output->data(); + + platform::ForRange for_range(dev_ctx, numel); + if (!is_grad) { + math::Col2SeqFunctor functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); + for_range(functor); + } else { + math::Seq2ColFunctor functor(input_data, output_data, seq_length, + frame_length, n_frames, hop_length); + for_range(functor); + } + } +}; + +template +class OverlapAddKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + const Tensor* x = ctx.Input("X"); + Tensor* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + const size_t x_rank = x->dims().size(); + const size_t out_rank = out->dims().size(); + + const int hop_length = ctx.Attr("hop_length"); + const int axis = ctx.Attr("axis"); + const int n_frames = (axis == 0) ? 
x->dims()[0] : x->dims()[x_rank - 1]; + const int frame_length = (axis == 0) ? x->dims()[1] : x->dims()[x_rank - 2]; + const int seq_length = + (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1]; + + auto& dev_ctx = ctx.device_context(); + + Tensor x_(x->type()); + x_ = *x; + + framework::DDim preserved_dims; + if (out_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + framework::DDim x_resized_dims; + framework::DDim out_resized_dims; + if (axis == 0) { + preserved_dims = framework::slice_ddim(out->dims(), 1, out_rank); + x_resized_dims = {n_frames, frame_length, + framework::product(preserved_dims)}; + out_resized_dims = {seq_length, framework::product(preserved_dims)}; + } else { + preserved_dims = framework::slice_ddim(out->dims(), 0, out_rank - 1); + x_resized_dims = {framework::product(preserved_dims), frame_length, + n_frames}; + out_resized_dims = {framework::product(preserved_dims), seq_length}; + } + x_.Resize(x_resized_dims); + out->Resize(out_resized_dims); + } + + Tensor trans_x(x_.type()); + Tensor trans_out(out->type()); + + // Transpose input and output in case that axis is 0. + if (axis == 0) { + if (out_rank == 1U) { + trans_out = *out; + + std::vector perm_x{1, 0}; + auto x_dims_vec = framework::vectorize(x_.dims()); + for (int i = 0; i < x_.dims().size(); ++i) { + x_dims_vec[i] = x_.dims()[perm_x[i]]; + } + trans_x.Resize(framework::make_ddim(x_dims_vec)); + trans_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_x.size(), dev_ctx, x_, &trans_x, + perm_x); + } else { + std::vector perm_out{1, 0}; + auto out_dims_vec = framework::vectorize(out->dims()); + for (int i = 0; i < out->dims().size(); ++i) { + out_dims_vec[i] = out->dims()[perm_out[i]]; + } + trans_out.Resize(framework::make_ddim(out_dims_vec)); + trans_out.mutable_data(ctx.GetPlace()); + TransCompute(perm_out.size(), dev_ctx, *out, + &trans_out, perm_out); + + std::vector perm_x{2, 1, 0}; + auto x_dims_vec = framework::vectorize(x_.dims()); + for (int i = 0; i < x_.dims().size(); ++i) { + x_dims_vec[i] = x_.dims()[perm_x[i]]; + } + trans_x.Resize(framework::make_ddim(x_dims_vec)); + trans_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_x.size(), dev_ctx, x_, &trans_x, + perm_x); + } + } else { + trans_x = x_; + trans_out = *out; + } + + OverlapAddFunctor()(dev_ctx, &trans_x, &trans_out, + seq_length, frame_length, n_frames, + hop_length, /*is_grad*/ false); + + // Transpose output in case axis is 0. + if (axis == 0 && out_rank > 1U) { + std::vector perm_out{1, 0}; + TransCompute(perm_out.size(), dev_ctx, trans_out, out, + perm_out); + } + + // Restore output dims when the number of dims is larger than 2. + if (out_rank > 2) { + std::vector restored_out_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_out_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (seq_length, ...) 
+ restored_out_shape.insert(restored_out_shape.begin(), seq_length); + } else { + // (..., seq_length) + restored_out_shape.push_back(seq_length); + } + + out->Resize(framework::make_ddim(restored_out_shape)); + } + } +}; + +template +class OverlapAddGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); + d_x->mutable_data(ctx.GetPlace()); + const size_t d_out_rank = d_out->dims().size(); + const size_t d_x_rank = d_x->dims().size(); + + const int hop_length = ctx.Attr("hop_length"); + const int axis = ctx.Attr("axis"); + const int n_frames = + (axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1]; + const int frame_length = + (axis == 0) ? d_x->dims()[1] : d_x->dims()[d_x_rank - 2]; + const int seq_length = + (axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1]; + + auto& dev_ctx = ctx.device_context(); + + // When the number of input dims is larger than 2, it needs to copy + // from x to resize input into 2d and output into 3d. Morevoer, output + // dims will be restored at the last step. + Tensor d_out_(d_out->type()); + d_out_ = *d_out; + + framework::DDim preserved_dims; + if (d_out_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + framework::DDim d_x_resized_dims; + framework::DDim d_out_resized_dims; + if (axis == 0) { + preserved_dims = framework::slice_ddim(d_out_.dims(), 1, d_out_rank); + d_x_resized_dims = {n_frames, frame_length, + framework::product(preserved_dims)}; + d_out_resized_dims = {seq_length, framework::product(preserved_dims)}; + } else { + preserved_dims = + framework::slice_ddim(d_out_.dims(), 0, d_out_rank - 1); + d_x_resized_dims = {framework::product(preserved_dims), frame_length, + n_frames}; + d_out_resized_dims = {framework::product(preserved_dims), seq_length}; + } + d_x->Resize(d_x_resized_dims); + d_out_.Resize(d_out_resized_dims); + } + + Tensor trans_d_x(d_x->type()); + Tensor trans_d_out(d_out_.type()); + + // Transpose input and output in case that axis is 0. 
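+ // Note: Seq2ColFunctor/Col2SeqFunctor assume a (batch, seq_length) /
+ // (batch, frame_length, n_frames) layout, so when axis == 0 the tensors are
+ // transposed into that layout before the accumulation and d_x is transposed
+ // back afterwards.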
+ if (axis == 0) { + if (d_out_rank == 1U) { + trans_d_out = d_out_; + + std::vector perm_d_x{1, 0}; + auto d_x_dims_vec = framework::vectorize(d_x->dims()); + for (int i = 0; i < d_x->dims().size(); ++i) { + d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]]; + } + trans_d_x.Resize(framework::make_ddim(d_x_dims_vec)); + trans_d_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_x.size(), dev_ctx, *d_x, + &trans_d_x, perm_d_x); + } else { + std::vector perm_d_out{1, 0}; + auto d_out_dims_vec = framework::vectorize(d_out_.dims()); + for (int i = 0; i < d_out_.dims().size(); ++i) { + d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]]; + } + trans_d_out.Resize(framework::make_ddim(d_out_dims_vec)); + trans_d_out.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_out.size(), dev_ctx, d_out_, + &trans_d_out, perm_d_out); + + std::vector perm_d_x{2, 1, 0}; + auto d_x_dims_vec = framework::vectorize(d_x->dims()); + for (int i = 0; i < d_x->dims().size(); ++i) { + d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]]; + } + trans_d_x.Resize(framework::make_ddim(d_x_dims_vec)); + trans_d_x.mutable_data(ctx.GetPlace()); + TransCompute(perm_d_x.size(), dev_ctx, *d_x, + &trans_d_x, perm_d_x); + } + } else { + trans_d_x = *d_x; + trans_d_out = d_out_; + } + + OverlapAddFunctor()(dev_ctx, &trans_d_out, &trans_d_x, + seq_length, frame_length, n_frames, + hop_length, + /*is_grad*/ true); + + // Transpose output in case axis is 0. + if (axis == 0) { + if (d_out_rank == 1U) { + std::vector perm_d_x{1, 0}; + TransCompute(perm_d_x.size(), dev_ctx, trans_d_x, d_x, + perm_d_x); + } else { + std::vector perm_d_x{2, 1, 0}; + TransCompute(perm_d_x.size(), dev_ctx, trans_d_x, d_x, + perm_d_x); + } + } + + // Restore output dims when the number of dims is larger than 2. + if (d_out_rank > 2) { + std::vector restored_d_x_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_d_x_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (n_frames, frame_length, ...) 
+ restored_d_x_shape.insert(restored_d_x_shape.begin(), frame_length); + restored_d_x_shape.insert(restored_d_x_shape.begin(), n_frames); + } else { + // (..., frame_length, n_frames) + restored_d_x_shape.push_back(frame_length); + restored_d_x_shape.push_back(n_frames); + } + + d_x->Resize(framework::make_ddim(restored_d_x_shape)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_frame_op.py b/python/paddle/fluid/tests/unittests/test_frame_op.py index ec478c9c834fc..f26662dcf4f26 100644 --- a/python/paddle/fluid/tests/unittests/test_frame_op.py +++ b/python/paddle/fluid/tests/unittests/test_frame_op.py @@ -21,23 +21,21 @@ def frame_from_librosa(x, frame_length, hop_length, axis=-1): - if axis == -1 and not x.flags["F_CONTIGUOUS"]: - x = np.asfortranarray(x) - elif axis == 0 and not x.flags["C_CONTIGUOUS"]: + if axis == -1 and not x.flags["C_CONTIGUOUS"]: x = np.ascontiguousarray(x) + elif axis == 0 and not x.flags["F_CONTIGUOUS"]: + x = np.asfortranarray(x) n_frames = 1 + (x.shape[axis] - frame_length) // hop_length strides = np.asarray(x.strides) - new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize - if axis == -1: shape = list(x.shape)[:-1] + [frame_length, n_frames] - strides = list(strides) + [hop_length * new_stride] + strides = list(strides) + [hop_length * x.itemsize] elif axis == 0: shape = [n_frames, frame_length] + list(x.shape)[1:] - strides = [hop_length * new_stride] + list(strides) + strides = [hop_length * x.itemsize] + list(strides) else: raise ValueError("Frame axis={} must be either 0 or -1".format(axis)) @@ -114,28 +112,29 @@ def initTestCase(self): return input_shape, input_type, attrs -# FIXME(chenxiaojie06): There are bugs when input dims >= 3 in librosa. -# class TestCase3(TestFrameOp): -# def initTestCase(self): -# input_shape = (4, 2, 150) -# input_type = 'int32' -# attrs = { -# 'frame_length': 50, -# 'hop_length': 15, -# 'axis': -1, -# } -# return input_shape, input_type, attrs - -# class TestCase4(TestFrameOp): -# def initTestCase(self): -# input_shape = (150, 4, 2) -# input_type = 'int32' -# attrs = { -# 'frame_length': 50, -# 'hop_length': 15, -# 'axis': 0, -# } -# return input_shape, input_type, attrs +class TestCase4(TestFrameOp): + def initTestCase(self): + input_shape = (4, 2, 150) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': -1, + } + return input_shape, input_type, attrs + + +class TestCase5(TestFrameOp): + def initTestCase(self): + input_shape = (150, 4, 2) + input_type = 'float64' + attrs = { + 'frame_length': 50, + 'hop_length': 15, + 'axis': 0, + } + return input_shape, input_type, attrs + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index f41ae38e0d411..631e3aa737fdb 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -207,6 +207,7 @@ from .array import create_array # noqa: F401 from . import fft +from . import signal #this list used in math_op_patch.py for _binary_creator_ diff --git a/python/paddle/tensor/signal.py b/python/paddle/tensor/signal.py new file mode 100644 index 0000000000000..0b10e301c6174 --- /dev/null +++ b/python/paddle/tensor/signal.py @@ -0,0 +1,325 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import paddle + +from .attribute import is_complex, is_floating_point +from .fft import fft_r2c, fft_c2r, fft_c2c +from ..fluid.data_feeder import check_variable_and_dtype +from ..fluid.framework import in_dygraph_mode +from ..fluid.layer_helper import LayerHelper +from .. import _C_ops + +__all__ = [ + 'frame', + 'overlap_add', + 'stft', + 'istft', +] + + +def frame(x, frame_length, hop_length, axis=-1, name=None): + ''' + TODO(chenxiaojie06): Doc of frame. + ''' + if axis not in [0, -1]: + raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.') + + if not isinstance(frame_length, int) or frame_length < 0: + raise ValueError( + f'Unexpected frame_length: {frame_length}. It should be an positive integer.' + ) + + if not isinstance(hop_length, int) or hop_length < 0: + raise ValueError( + f'Unexpected hop_length: {hop_length}. It should be an positive integer.' + ) + + if frame_length > x.shape[axis]: + raise ValueError( + f'Attribute frame_length should be less equal than sequence length, ' + f'but got ({frame_length}) > ({x.shape[axis]}).') + + op_type = 'frame' + + if in_dygraph_mode(): + attrs = ('frame_length', frame_length, 'hop_length', hop_length, 'axis', + axis) + op = getattr(_C_ops, op_type) + out = op(x, *attrs) + else: + check_variable_and_dtype( + x, 'x', ['int32', 'int64', 'float16', 'float32', + 'float64'], op_type) + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype=dtype) + helper.append_op( + type=op_type, + inputs={'X': x}, + attrs={ + 'frame_length': frame_length, + 'hop_length': hop_length, + 'axis': axis + }, + outputs={'Out': out}) + return out + + +def overlap_add(x, hop_length, axis=-1, name=None): + ''' + TODO(chenxiaojie06): Doc of overlap_add. + ''' + if axis not in [0, -1]: + raise ValueError(f'Unexpected axis: {axis}. It should be 0 or -1.') + + if not isinstance(hop_length, int) or hop_length < 0: + raise ValueError( + f'Unexpected hop_length: {hop_length}. It should be an positive integer.' + ) + + op_type = 'overlap_add' + + if in_dygraph_mode(): + attrs = ('hop_length', hop_length, 'axis', axis) + op = getattr(_C_ops, op_type) + out = op(x, *attrs) + else: + check_variable_and_dtype( + x, 'x', ['int32', 'int64', 'float16', 'float32', + 'float64'], op_type) + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype=dtype) + helper.append_op( + type=op_type, + inputs={'X': x}, + attrs={'hop_length': hop_length, + 'axis': axis}, + outputs={'Out': out}) + return out + + +def stft(x, + n_fft, + hop_length=None, + win_length=None, + window=None, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + name=None): + ''' + TODO(chenxiaojie06): Doc of stft. 
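+
+    Briefly: when center is True, x is padded by n_fft // 2 on both sides
+    using pad_mode; the (padded) signal is sliced into frames of length n_fft
+    with stride hop_length; each frame is multiplied by window (zero-padded to
+    n_fft when win_length < n_fft) and transformed with an FFT (orthonormal
+    scaling when normalized is True). The result is a complex tensor of shape
+    [..., n_fft // 2 + 1, num_frames] if onesided is True, otherwise
+    [..., n_fft, num_frames].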
+ ''' + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'complex64', 'complex128'], + 'stft') + + x_rank = len(x.shape) + assert x_rank in [1, 2], \ + f'x should be a 1D or 2D real tensor, but got rank of x is {x_rank}' + + if x_rank == 1: # (batch, seq_length) + x = x.unsqueeze(0) + + if hop_length is None: + hop_length = int(n_fft // 4) + + assert hop_length > 0, \ + f'hop_length should be > 0, but got {hop_length}.' + + if win_length is None: + win_length = n_fft + + assert 0 < n_fft <= x.shape[-1], \ + f'n_fft should be in (0, seq_length({x.shape[-1]})], but got {n_fft}.' + + assert 0 < win_length <= n_fft, \ + f'win_length should be in (0, n_fft({n_fft})], but got {win_length}.' + + if window is not None: + assert len(window.shape) == 1 and len(window) == win_length, \ + f'expected a 1D window tensor of size equal to win_length({win_length}), but got window with shape {window.shape}.' + else: + window = paddle.ones(shape=(win_length, ), dtype=x.dtype) + + if win_length < n_fft: + pad_left = (n_fft - win_length) // 2 + pad_right = n_fft - win_length - pad_left + window = paddle.nn.functional.pad(window, + pad=[pad_left, pad_right], + mode='constant') + + if center: + assert pad_mode in ['constant', 'reflect'], \ + 'pad_mode should be "reflect" or "constant", but got "{}".'.format(pad_mode) + + pad_length = n_fft // 2 + x = paddle.nn.functional.pad(x.unsqueeze(-1), + pad=[pad_length, pad_length], + mode=pad_mode, + data_format="NLC").squeeze(-1) + + x_frames = frame(x=x, frame_length=n_fft, hop_length=hop_length, axis=-1) + x_frames = x_frames.transpose( + perm=[0, 2, + 1]) # switch n_fft to last dim, egs: (batch, num_frames, n_fft) + x_frames = x_frames * window + + norm = 'ortho' if normalized else 'backward' + if is_complex(x_frames): + assert not onesided, \ + 'onesided should be False when input or window is a complex Tensor.' + + if not is_complex(x): + out = fft_r2c( + x=x_frames, + n=None, + axis=-1, + norm=norm, + forward=True, + onesided=onesided, + name=name) + else: + out = fft_c2c( + x=x_frames, n=None, axis=-1, norm=norm, forward=True, name=name) + + out = out.transpose(perm=[0, 2, 1]) # (batch, n_fft, num_frames) + + if x_rank == 1: + out.squeeze_(0) + + return out + + +def istft(x, + n_fft, + hop_length=None, + win_length=None, + window=None, + center=True, + normalized=False, + onesided=True, + length=None, + return_complex=False, + name=None): + ''' + TODO(chenxiaojie06): Doc of istft. + ''' + check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'istft') + + x_rank = len(x.shape) + assert x_rank in [2, 3], \ + 'x should be a 2D or 3D complex tensor, but got rank of x is {}'.format(x_rank) + + if x_rank == 2: # (batch, n_fft, n_frames) + x = x.unsqueeze(0) + + if hop_length is None: + hop_length = int(n_fft // 4) + + if win_length is None: + win_length = n_fft + + # Assure no gaps between frames. 
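+    # (hop_length <= win_length <= n_fft means consecutive frames overlap or at
+    # least abut; zero-valued window regions are still caught by the NOLA check
+    # at the end of this function.)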
+ assert 0 < hop_length <= win_length, \ + 'hop_length should be in (0, win_length({})], but got {}.'.format(win_length, hop_length) + + assert 0 < win_length <= n_fft, \ + 'win_length should be in (0, n_fft({})], but got {}.'.format(n_fft, win_length) + + n_frames = x.shape[-1] + fft_size = x.shape[-2] + + if onesided: + assert (fft_size == n_fft // 2 + 1), \ + 'fft_size should be equal to n_fft // 2 + 1({}) when onesided is True, but got {}.'.format(n_fft // 2 + 1, fft_size) + else: + assert (fft_size == n_fft), \ + 'fft_size should be equal to n_fft({}) when onesided is False, but got {}.'.format(n_fft, fft_size) + + if window is not None: + assert len(window.shape) == 1 and len(window) == win_length, \ + 'expected a 1D window tensor of size equal to win_length({}), but got window with shape {}.'.format(win_length, window.shape) + else: + window = paddle.ones(shape=(win_length, )) + + if win_length < n_fft: + pad_left = (n_fft - win_length) // 2 + pad_right = n_fft - win_length - pad_left + window = paddle.nn.functional.pad(window, + pad=[pad_left, pad_right], + mode='constant') + + x = x.transpose( + perm=[0, 2, + 1]) # switch n_fft to last dim, egs: (batch, num_frames, n_fft) + norm = 'ortho' if normalized else 'backward' + + if return_complex: + assert not onesided, \ + 'onesided should be False when input(output of istft) or window is a complex Tensor.' + + out = fft_c2c(x=x, n=None, axis=-1, norm=norm, forward=False, name=None) + else: + assert not is_complex(window), \ + 'Data type of window should not be complex when return_complex is False.' + + if onesided is False: + x = x[:, :, :n_fft // 2 + 1] + out = fft_c2r(x=x, n=None, axis=-1, norm=norm, forward=False, name=None) + + out = overlap_add( + x=(out * window).transpose( + perm=[0, 2, 1]), # (batch, n_fft, num_frames) + hop_length=hop_length, + axis=-1) # (batch, seq_length) + + # FIXME: Use paddle.square when it supports complex tensor. + window_envelop = overlap_add( + x=paddle.tile( + x=window * window, repeat_times=[n_frames, 1]).transpose( + perm=[1, 0]), # (n_fft, num_frames) + hop_length=hop_length, + axis=-1) # (seq_length, ) + + if length is None: + if center: + out = out[:, (n_fft // 2):-(n_fft // 2)] + window_envelop = window_envelop[(n_fft // 2):-(n_fft // 2)] + else: + if center: + start = n_fft // 2 + else: + start = 0 + + out = out[:, start:start + length] + window_envelop = window_envelop[start:start + length] + + # Check whether the Nonzero Overlap Add (NOLA) constraint is met. + if window_envelop.abs().min().item() < 1e-11: + raise ValueError( + 'Abort istft because Nonzero Overlap Add (NOLA) condition failed. For more information about NOLA constraint please see `scipy.signal.check_NOLA`(https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.check_NOLA.html).' + ) + + out = out / window_envelop + + if x_rank == 2: + out.squeeze_(0) + + return out
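
The flattened-index arithmetic in `Seq2ColFunctor`/`Col2SeqFunctor` above reduces to a simple gather/scatter once the index is split into `(sample, f, n)`. The NumPy sketch below is an editor-added illustration of that mapping (not code from this patch), including the output length `(n_frames - 1) * hop_length + frame_length` used by `OverlapAddOp::InferShape`:

```python
import numpy as np

def frame_ref(seq, frame_length, hop_length):
    # Seq2ColFunctor: (N, seq_length) -> (N, frame_length, n_frames),
    # col[b, f, n] = seq[b, n * hop_length + f]
    batch, seq_length = seq.shape
    n_frames = 1 + (seq_length - frame_length) // hop_length
    col = np.empty((batch, frame_length, n_frames), dtype=seq.dtype)
    for f in range(frame_length):
        for n in range(n_frames):
            col[:, f, n] = seq[:, n * hop_length + f]
    return col

def overlap_add_ref(col, hop_length):
    # Col2SeqFunctor: (N, frame_length, n_frames) -> (N, seq_length),
    # accumulating every frame element into seq[b, n * hop_length + f]
    batch, frame_length, n_frames = col.shape
    seq_length = (n_frames - 1) * hop_length + frame_length
    seq = np.zeros((batch, seq_length), dtype=col.dtype)
    for f in range(frame_length):
        for n in range(n_frames):
            seq[:, n * hop_length + f] += col[:, f, n]
    return seq

x = np.arange(8, dtype=np.float64).reshape(1, 8)
frames = frame_ref(x, frame_length=4, hop_length=2)   # shape (1, 4, 3)
recon = overlap_add_ref(frames, hop_length=2)          # shape (1, 8)
# recon != x in the overlapped region: each sample is summed once per covering
# frame, which is why istft divides by the squared-window envelope.
```

For the new Python APIs, a minimal round-trip sketch, assuming they are used via `paddle.tensor.signal` as registered in `python/paddle/tensor/__init__.py` above; the Hann window, shapes, and parameter values are illustrative choices, not mandated by the patch:

```python
import numpy as np
import paddle
from paddle.tensor import signal

x = paddle.to_tensor(np.random.randn(2, 8000).astype('float32'))
window = paddle.to_tensor(np.hanning(512).astype('float32'))  # any NOLA-satisfying window

# Complex spectrogram of shape (batch, n_fft // 2 + 1, num_frames) with onesided=True.
spec = signal.stft(x, n_fft=512, hop_length=128, window=window)

# Reconstruction; `length` crops the output back to the original number of samples.
x_rec = signal.istft(spec, n_fft=512, hop_length=128, window=window, length=8000)
```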