Skip to content

Commit

Permalink
Move frame kernel to phi (#44615)
Browse files Browse the repository at this point in the history
* Move frame OP to phi、add frame OP yaml config and supplement single test

* add Header file of in_dygraph_mode

* Modify variable name and FrameGradInferMeta multiplex UnchangedInferMeta

* move seq2col to phi
  • Loading branch information
Charles-hit authored Jul 28, 2022
1 parent 511a2c1 commit 28b4b2f
Show file tree
Hide file tree
Showing 22 changed files with 721 additions and 561 deletions.
134 changes: 18 additions & 116 deletions paddle/fluid/operators/frame_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/frame_op.h"
#include "paddle/phi/core/enforce.h"

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand All @@ -21,89 +27,6 @@ class FrameOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame");

const int frame_length = ctx->Attrs().Get<int>("frame_length");
const int hop_length = ctx->Attrs().Get<int>("hop_length");
const int axis = ctx->Attrs().Get<int>("axis");

const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();

PADDLE_ENFORCE_GE(
x_rank,
1,
platform::errors::InvalidArgument(
"Input(X) of FrameOp should be a tensor which contains "
"at least 1 dimension, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(hop_length,
0,
platform::errors::InvalidArgument(
"Attribute(hop_length) of FrameOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1),
true,
platform::errors::InvalidArgument(
"Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis));

std::vector<int64_t> output_shape;
int seq_length;
int n_frames;

int start_axis;
int end_axis;

if (axis == 0) {
seq_length = x_dims[0];
start_axis = 1;
end_axis = x_rank - 1;
} else {
seq_length = x_dims[x_rank - 1];
start_axis = 0;
end_axis = x_rank - 2;
}

bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(frame_length,
seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length,
seq_length));
}

// It won't go into for loop when x_rank == 1U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}

if (seq_length == -1) {
n_frames = -1;
} else {
n_frames = 1 + (seq_length - frame_length) / hop_length;
}

if (axis == 0) {
// (n_frames, frame_length, ...)
output_shape.insert(output_shape.begin(), frame_length);
output_shape.insert(output_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
output_shape.push_back(frame_length);
output_shape.push_back(n_frames);
}

ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -136,17 +59,6 @@ class FrameOpMaker : public framework::OpProtoAndCheckerMaker {
class FrameOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"frame_grad");
const auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
Expand All @@ -160,7 +72,6 @@ template <typename T>
class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> retv) const override {
retv->SetType("frame_grad");
retv->SetInput("X", this->Input("X"));
Expand All @@ -175,28 +86,19 @@ class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(frame,
FrameInferShapeFunctor,
PD_INFER_META(phi::FrameInferMeta));

DECLARE_INFER_SHAPE_FUNCTOR(frame_grad,
FrameGradInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));

REGISTER_OPERATOR(frame,
ops::FrameOp,
ops::FrameOpMaker,
ops::FrameOpGradMaker<paddle::framework::OpDesc>,
ops::FrameOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad);

REGISTER_OP_CPU_KERNEL(
frame,
ops::FrameKernel<phi::CPUContext, int>,
ops::FrameKernel<phi::CPUContext, int64_t>,
ops::FrameKernel<phi::CPUContext, float>,
ops::FrameKernel<phi::CPUContext, double>,
ops::FrameKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::FrameKernel<phi::CPUContext, paddle::platform::complex<double>>);
ops::FrameOpGradMaker<paddle::imperative::OpBase>,
FrameInferShapeFunctor);

REGISTER_OP_CPU_KERNEL(
frame_grad,
ops::FrameGradKernel<phi::CPUContext, int>,
ops::FrameGradKernel<phi::CPUContext, int64_t>,
ops::FrameGradKernel<phi::CPUContext, float>,
ops::FrameGradKernel<phi::CPUContext, double>,
ops::FrameGradKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::FrameGradKernel<phi::CPUContext, paddle::platform::complex<double>>);
REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad, FrameGradInferShapeFunctor);
43 changes: 0 additions & 43 deletions paddle/fluid/operators/frame_op.cu

This file was deleted.

Loading

0 comments on commit 28b4b2f

Please sign in to comment.