Skip to content

Commit

Permalink
add reduce max infermeta
Browse files Browse the repository at this point in the history
  • Loading branch information
MingMingShangTian committed Mar 8, 2022
1 parent cbaa7c3 commit 4424722
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/reduce_ops/reduce_max_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ReduceMaxOpMaker : public ops::ReduceOpMaker {
};

DECLARE_INFER_SHAPE_FUNCTOR(reduce_max, ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::MeanRawInferMeta));
PD_INFER_META(phi::MaxRawInferMeta));

REGISTER_OPERATOR(
reduce_max, ops::ReduceOp, ReduceMaxOpMaker,
Expand Down
40 changes: 34 additions & 6 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,10 @@ void SumInferMeta(const MetaTensor& x,
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, dtype, out);
}

void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType dtype,
MetaTensor* out) {
DDim ReduceInferDim(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all) {
auto x_rank = x.dims().size();

std::vector<int64_t> formated_axis = axis;
Expand Down Expand Up @@ -462,6 +460,17 @@ void ReduceInferMetaBase(const MetaTensor& x,
}
DDim out_dim = phi::make_ddim(out_dim_vector);

return out_dim;
}

void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType dtype,
MetaTensor* out) {
DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all);

DataType out_dtype;
if (dtype != DataType::UNDEFINED) {
out_dtype = dtype;
Expand All @@ -479,6 +488,25 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout());
}

void MaxRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out) {
DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all);
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}

void MaxInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out) {
bool reduce_all = false;
MaxRawInferMeta(x, axis, keep_dim, reduce_all, out);
}

void MeanRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ void MeanRawInferMeta(const MetaTensor& x,
bool reduce_all,
MetaTensor* out);

void MaxRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out);

void MaxInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out);

void MeanInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
Expand Down

0 comments on commit 4424722

Please sign in to comment.