【PIR OpTest Fix No.12】 Fix test_partial_sum_op #62783

Merged
merged 8 commits on Mar 20, 2024

Changes from 2 commits
15 changes: 15 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -2851,6 +2851,20 @@ struct FusedElemwiseAddActivationGradOpTranscriber
  }
};

struct PartialSumOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name = "pd_op.partial_sum";
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      IR_THROW(
          "Op partial_sum should have corresponding OpInfo "
          "pd_op.partial_sum");
    }
    return op_info;
  }
};

cmcamdy marked this conversation as resolved.

Member: Heads-up: we are replacing IR_THROW across the codebase (#62748); consider changing this one to PADDLE_THROW.

Contributor Author: OK, I'll change all of them in this file.

Member: Please first check whether anyone has claimed this file in #62748; if nobody has, claim it before changing it so the work isn't duplicated. Also, I'd suggest doing the full replacement in a separate PR.

Contributor Author: OK, I just saw that someone has claimed it, so for now I'll only change the ones I added myself.

struct MatrixRankOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
@@ -3182,6 +3196,7 @@ OpTranslator::OpTranslator() {
  special_handlers["slice"] = SliceOpTranscriber();
  special_handlers["split"] = SplitOpTranscriber();
  special_handlers["sum"] = AddNOpTranscriber();
  special_handlers["partial_sum"] = PartialSumOpTranscriber();
  special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
  special_handlers["tril_triu_grad"] = TrilAndTriuGradOpTranscriber();
  special_handlers["matmul"] = LegacyMatmulOpTranscriber();
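A note on how this registration is used: OpTranslator keeps a map from legacy op names to special transcribers and falls back to the generic OpTranscriber for everything else. A rough, self-contained Python paraphrase of that dispatch (simplified; the method names and stub behavior below are ours, not Paddle's):

class OpTranscriber:
    """Generic fallback: translate a legacy op with the default rules (stub)."""
    def __call__(self, ctx, op_desc):
        return f"generic translation of {op_desc['type']}"

class PartialSumOpTranscriber(OpTranscriber):
    """Special handler: force the PIR OpInfo lookup to pd_op.partial_sum (stub)."""
    def __call__(self, ctx, op_desc):
        return "translated via pd_op.partial_sum"

class OpTranslator:
    def __init__(self):
        self.general_handler = OpTranscriber()
        # Mirrors the special_handlers map filled in the hunk above.
        self.special_handlers = {"partial_sum": PartialSumOpTranscriber()}

    def translate(self, ctx, op_desc):
        handler = self.special_handlers.get(op_desc["type"], self.general_handler)
        return handler(ctx, op_desc)

print(OpTranslator().translate(None, {"type": "partial_sum"}))
# -> translated via pd_op.partial_sum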
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -154,6 +154,7 @@
    'lars_momentum',
    'lars_momentum_',
    'max_pool2d_v2',
    'partial_sum',
    'random_routing',
    'recv_v2',
    'rnn_',
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1564,6 +1564,16 @@
  backward : sum_grad
  interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : partial_sum
  args : (Tensor[] x, int start_index = 0, int length = -1)
  output : Tensor(out)
  infer_meta :
    func : PartialSumInferMeta
  kernel :
    func : partial_sum
    data_type : x
  backward : partial_sum_grad

- op : swish
  args : (Tensor x)
  output : Tensor(out)
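For readers unfamiliar with this op: as the yaml signature above and PartialSumInferMeta further down suggest, partial_sum takes a list of same-shaped 2-D tensors, slices each along axis 1 at [start_index, start_index + length), and sums the slices elementwise. A NumPy sketch of that reading (partial_sum_ref is our name for illustration, not a Paddle API):

import numpy as np

def partial_sum_ref(xs, start_index=0, length=-1):
    # length == -1 means "from start_index to the end of each row"
    end = xs[0].shape[1] if length == -1 else start_index + length
    return sum(x[:, start_index:end] for x in xs)

a = np.arange(6.0).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]
b = np.ones((2, 3))
print(partial_sum_ref([a, b], start_index=1, length=2))
# [[2. 3.]
#  [5. 6.]]  -> shape (2, 2) == (batch_size, length)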
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -846,6 +846,16 @@
  no_need_buffer : x
  backward : sum_double_grad

- backward_op : partial_sum_grad
  forward : partial_sum (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out)
  args : (Tensor[] x, Tensor out_grad, int start_index, int length)
  output : Tensor[](x_grad){x.size()}
  infer_meta :
    func : PartialSumGradInferMeta
    param : [x]
  kernel :
    func : partial_sum_grad

- backward_op : swish_grad
  forward : swish (Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
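Under the same reading, the gradient is straightforward: out_grad flows back unchanged into the summed slice of every input, and the rest of each x_grad is zero, which is consistent with PartialSumGradInferMeta below giving each x_grad the shape and dtype of its input. A sketch (again ours, not the Paddle kernel):

import numpy as np

def partial_sum_grad_ref(xs, out_grad, start_index=0, length=-1):
    end = xs[0].shape[1] if length == -1 else start_index + length
    grads = []
    for x in xs:
        g = np.zeros_like(x)              # same shape/dtype as the input
        g[:, start_index:end] = out_grad  # gradient lands only in the slice
        grads.append(g)
    return grads

xs = [np.zeros((2, 3)), np.zeros((2, 3))]
print(partial_sum_grad_ref(xs, np.ones((2, 2)), start_index=1, length=2)[0])
# [[0. 1. 1.]
#  [0. 1. 1.]]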
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -74,6 +74,8 @@ const std::unordered_set<std::string> LegacyOpList = {
    MatchMatrixTensorGradOp::name(),
    NceOp::name(),
    NceGradOp::name(),
    PartialSumOp::name(),
    PartialSumGradOp::name(),
    LrnOp::name(),
    LrnGradOp::name(),
    MovingAverageAbsMaxScaleOp::name(),
4 changes: 4 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -2475,6 +2475,10 @@

- op : partial_sum
  backward : partial_sum_grad
  inputs :
    x : X
  outputs :
    out : Out
  extra :
    attrs : [bool use_mkldnn = false]

10 changes: 10 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -876,6 +876,16 @@ void NceGradInferMeta(const MetaTensor& input,
  }
}

void PartialSumGradInferMeta(const std::vector<const MetaTensor*>& xs,
                             std::vector<MetaTensor*> x_grads) {
  auto input_num = xs.size();
  for (size_t i = 0; i < input_num; i++) {
    auto x_dims = xs[i]->dims();
    x_grads[i]->set_dims(x_dims);
    x_grads[i]->set_dtype(xs[i]->dtype());
  }
}

void NllLossGradInferMeta(const MetaTensor& x,
                          const MetaTensor& label,
                          const MetaTensor& weight,
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -372,6 +372,9 @@ void NanmedianGradInferMeta(const MetaTensor& x,
                            bool keep_dim,
                            MetaTensor* x_grad);

void PartialSumGradInferMeta(const std::vector<const MetaTensor*>& xs,
                             std::vector<MetaTensor*> x_grads);

void NceGradInferMeta(const MetaTensor& input,
                      const MetaTensor& bias,
                      const MetaTensor& weight,
60 changes: 60 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -4453,6 +4453,66 @@ void SumInferMeta(const MetaTensor& x,
  SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out, config);
}

void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
                         int start_index,
                         int length,
                         MetaTensor* out,
                         MetaConfig config) {
  int64_t batch_size = -1;
  int64_t input_len = -1;

  auto inputs_num = xs.size();
  PADDLE_ENFORCE_GT(inputs_num,
                    0,
                    phi::errors::InvalidArgument(
                        "ShapeError: Input tensors count should > 0. But "
                        "received inputs' length is 0."));

  // Only support two dimensions now, should be extended later
  // when length is -1, need make sure all dimensions to be added are the same
  for (size_t i = 0; i < inputs_num; i++) {
    auto x_dim = xs[i]->dims();
    VLOG(1) << "inputs_dims:" << x_dim;

    PADDLE_ENFORCE_EQ(
        x_dim.size(),
        2,
        phi::errors::InvalidArgument("Only support two dimensions input now."));

    if (i == 0) {
      batch_size = x_dim[0];
      input_len = x_dim[1];
    } else {
      // each tensor's dim must eq
      PADDLE_ENFORCE_EQ(x_dim[0],
                        batch_size,
                        phi::errors::InvalidArgument(
                            "The batch size of all inputs must be same"));
      PADDLE_ENFORCE_EQ(x_dim[1],
                        input_len,
                        phi::errors::InvalidArgument(
                            "The input len of all inputs must be same"));
    }
  }
  PADDLE_ENFORCE_GT(
      input_len,
      start_index,
      phi::errors::OutOfRange("start_index must be less than input len"));
  if (length > 0) {
    PADDLE_ENFORCE_GE(input_len,
                      start_index + length,
                      phi::errors::OutOfRange(
                          "start_index + length is larger than input length"));
  }

  std::vector<int64_t> out_dims(2);
  out_dims[0] = batch_size;
  out_dims[1] = (length == -1) ? input_len - start_index : length;
  DDim out_dim = common::make_ddim(out_dims);
  out->set_dims(out_dim);
  out->set_dtype(xs[0]->dtype());
}

void SvdInferMeta(const MetaTensor& x,
                  bool full_matrices,
                  MetaTensor* u,
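To summarize the shape rule enforced above: every input must be 2-D with identical dims, start_index must be smaller than the row length, and the output is (batch_size, length), where length == -1 means input_len - start_index. A small Python mirror of that rule for sanity-checking (hand-written to match the code above, not generated from it):

def infer_partial_sum_shape(in_shapes, start_index=0, length=-1):
    assert len(in_shapes) > 0, "input tensors count should > 0"
    batch, input_len = in_shapes[0]
    assert all(s == (batch, input_len) for s in in_shapes)
    assert start_index < input_len, "start_index must be less than input len"
    if length > 0:
        assert start_index + length <= input_len
    return (batch, input_len - start_index if length == -1 else length)

assert infer_partial_sum_shape([(4, 8), (4, 8)], start_index=3) == (4, 5)
assert infer_partial_sum_shape([(4, 8)], start_index=2, length=4) == (4, 4)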
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -697,6 +697,13 @@ void SumRawInferMeta(const MetaTensor& x,
                     MetaTensor* out,
                     MetaConfig config = MetaConfig());

void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
                         int start_index,
                         int length,
                         MetaTensor* out,
                         MetaConfig config = MetaConfig());

void SvdInferMeta(const MetaTensor& x,
                  bool full_matrices,
                  MetaTensor* u,