Skip to content

Commit

Permalink
move index_sample op
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed Feb 24, 2022
1 parent ab0835d commit 0f59dfa
Show file tree
Hide file tree
Showing 13 changed files with 743 additions and 482 deletions.
148 changes: 81 additions & 67 deletions paddle/fluid/operators/index_sample_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/index_sample_op.h"
// #include "paddle/fluid/operators/index_sample_op.h"
#include <vector>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/enforce.h"

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker {
Expand All @@ -42,44 +46,48 @@ class IndexSampleOpMaker : public framework::OpProtoAndCheckerMaker {
class IndexSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Inputs(Input) of FindByIndex should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Inputs(Index) of FindByIndex should not be null."));

auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
input_dims.size(), 2,
platform::errors::InvalidArgument(
"Inputs(X) shape of IndexSample op should be 2-D, but "
"got X's shape = [%s], please check X shape.",
input_dims));

auto index_dims = ctx->GetInputDim("Index");
PADDLE_ENFORCE_EQ(
input_dims.size(), 2,
platform::errors::InvalidArgument(
"Inputs(Index) shape of IndexSample op should be 2-D, but "
"got Index's shape [%s] , please check index shape.",
input_dims));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(input_dims[0], index_dims[0],
platform::errors::InvalidArgument(
"Inputs(X)'s value of dimension 0 must same with "
"Inputs(Index)'s value of dimension 0, but "
"got %d of Inputs(X), and got %d of Inputs(Index), "
"please check Inputs shape.",
input_dims[0], index_dims[0]));
}
ctx->SetOutputDim("Out", index_dims);
auto type = ctx->GetInputsVarType("Index")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("Index", /*->*/ "Out");
}
}
// void InferShape(framework::InferShapeContext* ctx) const override {
// PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
// platform::errors::InvalidArgument(
// "Inputs(Input) of FindByIndex should not be
// null."));
// PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
// platform::errors::InvalidArgument(
// "Inputs(Index) of FindByIndex should not be
// null."));

// auto input_dims = ctx->GetInputDim("X");
// PADDLE_ENFORCE_EQ(
// input_dims.size(), 2,
// platform::errors::InvalidArgument(
// "Inputs(X) shape of IndexSample op should be 2-D, but "
// "got X's shape = [%s], please check X shape.",
// input_dims));

// auto index_dims = ctx->GetInputDim("Index");
// PADDLE_ENFORCE_EQ(
// input_dims.size(), 2,
// platform::errors::InvalidArgument(
// "Inputs(Index) shape of IndexSample op should be 2-D, but "
// "got Index's shape [%s] , please check index shape.",
// input_dims));
// if (ctx->IsRuntime()) {
// PADDLE_ENFORCE_EQ(input_dims[0], index_dims[0],
// platform::errors::InvalidArgument(
// "Inputs(X)'s value of dimension 0 must same with
// "
// "Inputs(Index)'s value of dimension 0, but "
// "got %d of Inputs(X), and got %d of
// Inputs(Index), "
// "please check Inputs shape.",
// input_dims[0], index_dims[0]));
// }
// ctx->SetOutputDim("Out", index_dims);
// auto type = ctx->GetInputsVarType("Index")[0];
// if (type == framework::proto::VarType::LOD_TENSOR) {
// ctx->ShareLoD("Index", /*->*/ "Out");
// }
//}

protected:
framework::OpKernelType GetExpectedKernelType(
Expand All @@ -93,19 +101,20 @@ class IndexSampleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Index"), true,
platform::errors::InvalidArgument("Input(Index) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should be not null."));

ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
// void InferShape(framework::InferShapeContext* ctx) const override {
// PADDLE_ENFORCE_EQ(
// ctx->HasInput("Index"), true,
// platform::errors::InvalidArgument("Input(Index) should be not
// null."));
// PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
// platform::errors::InvalidArgument(
// "Input(Out@GRAD) should be not null."));
// PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
// platform::errors::InvalidArgument(
// "Output(X@GRAD) should be not null."));

// ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
// }

protected:
framework::OpKernelType GetExpectedKernelType(
Expand Down Expand Up @@ -136,20 +145,25 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSampleGradNoNeedBufferVarInferer, "X");
} // namespace paddle

namespace ops = paddle::operators;
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(index_sample, IndexSampleInferShapeFunctor,
PT_INFER_META(phi::IndexSampleInferMeta));
REGISTER_OPERATOR(index_sample, ops::IndexSampleOp, ops::IndexSampleOpMaker,
ops::IndexSampleGradMaker<paddle::framework::OpDesc>,
ops::IndexSampleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp,
ops::IndexSampleGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
index_sample,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
index_sample_grad,
ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::IndexSampleGradMaker<paddle::imperative::OpBase>,
IndexSampleInferShapeFunctor);
REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp);
// REGISTER_OPERATOR(index_sample_grad, ops::IndexSampleGradOp,
// ops::IndexSampleGradNoNeedBufferVarInferer);
// REGISTER_OP_CPU_KERNEL(
// index_sample,
// ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, float>,
// ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, double>,
// ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, int>,
// ops::IndexSampleKernel<paddle::platform::CPUDeviceContext, int64_t>);
// REGISTER_OP_CPU_KERNEL(
// index_sample_grad,
// ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, float>,
// ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, double>,
// ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, int>,
// ops::IndexSampleGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
215 changes: 0 additions & 215 deletions paddle/fluid/operators/index_sample_op.cu

This file was deleted.

Loading

1 comment on commit 0f59dfa

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #39905 Commit ID: 0f59dfa contains failed CI.

🔹 Failed: PR-CI-iScan-Python

Unknown Failed
Unknown Failed

Please sign in to comment.