Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move frame kernel to phi #44615

Merged
merged 7 commits into from
Jul 28, 2022
134 changes: 18 additions & 116 deletions paddle/fluid/operators/frame_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/frame_op.h"
#include "paddle/phi/core/enforce.h"

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand All @@ -21,89 +27,6 @@ class FrameOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame");

const int frame_length = ctx->Attrs().Get<int>("frame_length");
const int hop_length = ctx->Attrs().Get<int>("hop_length");
const int axis = ctx->Attrs().Get<int>("axis");

const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();

PADDLE_ENFORCE_GE(
x_rank,
1,
platform::errors::InvalidArgument(
"Input(X) of FrameOp should be a tensor which contains "
"at least 1 dimension, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(hop_length,
0,
platform::errors::InvalidArgument(
"Attribute(hop_length) of FrameOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1),
true,
platform::errors::InvalidArgument(
"Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis));

std::vector<int64_t> output_shape;
int seq_length;
int n_frames;

int start_axis;
int end_axis;

if (axis == 0) {
seq_length = x_dims[0];
start_axis = 1;
end_axis = x_rank - 1;
} else {
seq_length = x_dims[x_rank - 1];
start_axis = 0;
end_axis = x_rank - 2;
}

bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(frame_length,
seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length,
seq_length));
}

// It won't go into for loop when x_rank == 1U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}

if (seq_length == -1) {
n_frames = -1;
} else {
n_frames = 1 + (seq_length - frame_length) / hop_length;
}

if (axis == 0) {
// (n_frames, frame_length, ...)
output_shape.insert(output_shape.begin(), frame_length);
output_shape.insert(output_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
output_shape.push_back(frame_length);
output_shape.push_back(n_frames);
}

ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -136,17 +59,6 @@ class FrameOpMaker : public framework::OpProtoAndCheckerMaker {
class FrameOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"frame_grad");
const auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
Expand All @@ -160,7 +72,6 @@ template <typename T>
class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> retv) const override {
retv->SetType("frame_grad");
retv->SetInput("X", this->Input("X"));
Expand All @@ -175,28 +86,19 @@ class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(frame,
FrameInferShapeFunctor,
PD_INFER_META(phi::FrameInferMeta));

DECLARE_INFER_SHAPE_FUNCTOR(frame_grad,
FrameGradInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));

REGISTER_OPERATOR(frame,
ops::FrameOp,
ops::FrameOpMaker,
ops::FrameOpGradMaker<paddle::framework::OpDesc>,
ops::FrameOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad);

REGISTER_OP_CPU_KERNEL(
frame,
ops::FrameKernel<phi::CPUContext, int>,
ops::FrameKernel<phi::CPUContext, int64_t>,
ops::FrameKernel<phi::CPUContext, float>,
ops::FrameKernel<phi::CPUContext, double>,
ops::FrameKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::FrameKernel<phi::CPUContext, paddle::platform::complex<double>>);
ops::FrameOpGradMaker<paddle::imperative::OpBase>,
FrameInferShapeFunctor);

REGISTER_OP_CPU_KERNEL(
frame_grad,
ops::FrameGradKernel<phi::CPUContext, int>,
ops::FrameGradKernel<phi::CPUContext, int64_t>,
ops::FrameGradKernel<phi::CPUContext, float>,
ops::FrameGradKernel<phi::CPUContext, double>,
ops::FrameGradKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::FrameGradKernel<phi::CPUContext, paddle::platform::complex<double>>);
REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad, FrameGradInferShapeFunctor);
43 changes: 0 additions & 43 deletions paddle/fluid/operators/frame_op.cu

This file was deleted.

Loading