【PIR OpTest Fix No.12】 Fix test_partial_sum_op #62783

Merged: 8 commits, Mar 20, 2024 (changes shown from 5 commits)
15 changes: 15 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -2851,6 +2851,20 @@ struct FusedElemwiseAddActivationGradOpTranscriber
}
};

struct PartialSumOpTranscriber : public OpTranscriber {
pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
std::string target_op_name = "pd_op.partial_sum";
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Op partial_sum should have corresponding OpInfo "
"pd_op.partial_sum"));
}
return op_info;
}
};

struct MatrixRankOpTranscriber : public OpTranscriber {
pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
@@ -3182,6 +3196,7 @@ OpTranslator::OpTranslator() {
special_handlers["slice"] = SliceOpTranscriber();
special_handlers["split"] = SplitOpTranscriber();
special_handlers["sum"] = AddNOpTranscriber();
special_handlers["partial_sum"] = PartialSumOpTranscriber();
special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
special_handlers["tril_triu_grad"] = TrilAndTriuGradOpTranscriber();
special_handlers["matmul"] = LegacyMatmulOpTranscriber();
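For readers unfamiliar with the translator, the registration above works like a lookup table: when the old-IR program under test reaches an op type that has a special handler, that handler is used instead of the generic transcriber. Below is a minimal, self-contained Python sketch of that dispatch pattern (illustration only, not Paddle code; every class and method name here is made up):

# Illustration only -- a toy version of the special-handler dispatch.
# Op types with custom translation rules override the default
# "foo" -> "pd_op.foo" lookup.
class IrContext:
    def __init__(self, registered_ops):
        self._ops = set(registered_ops)

    def get_registered_op_info(self, name):
        return name if name in self._ops else None


class OpTranscriber:
    def lookup_op_info(self, ctx, op_type):
        return ctx.get_registered_op_info("pd_op." + op_type)


class PartialSumOpTranscriber(OpTranscriber):
    def lookup_op_info(self, ctx, op_type):
        info = ctx.get_registered_op_info("pd_op.partial_sum")
        if info is None:
            raise ValueError("partial_sum has no registered pd_op.partial_sum")
        return info


special_handlers = {"partial_sum": PartialSumOpTranscriber()}

ctx = IrContext({"pd_op.partial_sum"})
handler = special_handlers.get("partial_sum", OpTranscriber())
assert handler.lookup_op_info(ctx, "partial_sum") == "pd_op.partial_sum"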
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -154,6 +154,7 @@
'lars_momentum',
'lars_momentum_',
'max_pool2d_v2',
'partial_sum',
'random_routing',
'recv_v2',
'rnn_',
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1181,6 +1181,16 @@
func : partial_recv
data_type : dtype

- op : partial_sum
args : (Tensor[] x, int start_index = 0, int length = -1)
output : Tensor(out)
infer_meta :
func : PartialSumInferMeta
kernel :
func : partial_sum
data_type : x
backward : partial_sum_grad

- op : pool2d
args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
output : Tensor(out)
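As context for the new partial_sum entry above: the op takes a list of 2-D tensors, slices each one along the second axis from start_index for length columns (to the end when length is -1), and sums the slices elementwise. A NumPy sketch of that contract (my reading of the op, not part of this PR) is:

import numpy as np

def partial_sum_ref(xs, start_index=0, length=-1):
    # xs: list of 2-D arrays that all share the same [batch, width] shape.
    batch, width = xs[0].shape
    end = width if length == -1 else start_index + length
    out = np.zeros((batch, end - start_index), dtype=xs[0].dtype)
    for x in xs:
        out += x[:, start_index:end]
    return out

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.ones((2, 3), dtype=np.float32)
print(partial_sum_ref([a, b], start_index=1, length=2))  # shape (2, 2)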
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -580,6 +580,16 @@
composite : pad_grad(x, out_grad, paddings, pad_value, x_grad)
backward : pad_double_grad

- backward_op : partial_sum_grad
forward : partial_sum (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int start_index, int length)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : PartialSumGradInferMeta
param : [x]
kernel :
func : partial_sum_grad

- backward_op : pool2d_double_grad
forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
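The contract of the partial_sum_grad entry above follows from the forward op: each input receives a gradient of its own shape, with out_grad written into the summed columns and zeros elsewhere. A hedged NumPy sketch (again my reading, not code from this PR):

import numpy as np

def partial_sum_grad_ref(xs, out_grad, start_index=0, length=-1):
    width = xs[0].shape[1]
    end = width if length == -1 else start_index + length
    grads = []
    for x in xs:
        g = np.zeros_like(x)
        g[:, start_index:end] = out_grad  # columns outside the slice stay zero
        grads.append(g)
    return grads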
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -74,6 +74,8 @@ const std::unordered_set<std::string> LegacyOpList = {
MatchMatrixTensorGradOp::name(),
NceOp::name(),
NceGradOp::name(),
PartialSumOp::name(),
PartialSumGradOp::name(),
LrnOp::name(),
LrnGradOp::name(),
MovingAverageAbsMaxScaleOp::name(),
4 changes: 4 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -2475,6 +2475,10 @@

- op : partial_sum
backward : partial_sum_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false]

10 changes: 10 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -876,6 +876,16 @@ void NceGradInferMeta(const MetaTensor& input,
}
}

void PartialSumGradInferMeta(const std::vector<const MetaTensor*>& xs,
std::vector<MetaTensor*> x_grads) {
auto input_num = xs.size();
for (size_t i = 0; i < input_num; i++) {
auto x_dims = xs[i]->dims();
x_grads[i]->set_dims(x_dims);
x_grads[i]->set_dtype(xs[i]->dtype());
}
}

void NllLossGradInferMeta(const MetaTensor& x,
const MetaTensor& label,
const MetaTensor& weight,
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -372,6 +372,9 @@ void NanmedianGradInferMeta(const MetaTensor& x,
bool keep_dim,
MetaTensor* x_grad);

void PartialSumGradInferMeta(const std::vector<const MetaTensor*>& xs,
std::vector<MetaTensor*> x_grads);

void NceGradInferMeta(const MetaTensor& input,
const MetaTensor& bias,
const MetaTensor& weight,
59 changes: 59 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -4453,6 +4453,65 @@ void SumInferMeta(const MetaTensor& x,
SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out, config);
}

void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
MetaTensor* out,
MetaConfig config) {
int64_t batch_size = -1;
int64_t input_len = -1;

auto inputs_num = xs.size();
PADDLE_ENFORCE_GT(inputs_num,
0,
phi::errors::InvalidArgument(
"ShapeError: Input tensors count should > 0. But "
"received inputs' length is 0."));

// Only support two dimensions now, should be extended later
// when length is -1, need make sure all dimensions to be added are the same
for (size_t i = 0; i < inputs_num; i++) {
auto x_dim = xs[i]->dims();

PADDLE_ENFORCE_EQ(
x_dim.size(),
2,
phi::errors::InvalidArgument("Only support two dimensions input now."));

if (i == 0) {
batch_size = x_dim[0];
input_len = x_dim[1];
} else {
// each tensor's dim must eq
PADDLE_ENFORCE_EQ(x_dim[0],
batch_size,
phi::errors::InvalidArgument(
"The batch size of all inputs must be same"));
PADDLE_ENFORCE_EQ(x_dim[1],
input_len,
phi::errors::InvalidArgument(
"The input len of all inputs must be same"));
}
}
PADDLE_ENFORCE_GT(
input_len,
start_index,
phi::errors::OutOfRange("start_index must be less than input len"));
if (length > 0) {
PADDLE_ENFORCE_GE(input_len,
start_index + length,
phi::errors::OutOfRange(
"start_index + length is larger than input length"));
}

std::vector<int64_t> out_dims(2);
out_dims[0] = batch_size;
out_dims[1] = (length == -1) ? input_len - start_index : length;
DDim out_dim = common::make_ddim(out_dims);
out->set_dims(out_dim);
out->set_dtype(xs[0]->dtype());
}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
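To make the shape rule in PartialSumInferMeta above easier to scan, here is the same logic as a small Python function (a sketch mirroring the C++ checks, not code from this PR):

def partial_sum_infer_shape(in_shapes, start_index=0, length=-1):
    # Every input must be 2-D and share one shape; start_index must lie
    # inside the row width, and start_index + length must not run past it.
    assert len(in_shapes) > 0, "at least one input is required"
    batch, width = in_shapes[0]
    for shape in in_shapes:
        assert len(shape) == 2, "only 2-D inputs are supported"
        assert tuple(shape) == (batch, width), "all inputs must share one shape"
    assert start_index < width, "start_index must be less than the input width"
    if length > 0:
        assert start_index + length <= width
    return (batch, width - start_index if length == -1 else length)

print(partial_sum_infer_shape([(4, 8), (4, 8)], start_index=2, length=3))  # (4, 3)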
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -697,6 +697,12 @@ void SumRawInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
MetaTensor* out,
MetaConfig config = MetaConfig());

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
@@ -200,6 +200,7 @@ test_one_hot_v2_op
test_one_hot_v2_op_static_build
test_overlap_add_op
test_pad3d_op
test_partial_sum_op
test_pass_quantization
test_pixel_shuffle_op
test_poisson_op