From 44247226f3cb8c56821b2bb6ab8a1d9d08f2d236 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Tue, 8 Mar 2022 03:26:33 +0000 Subject: [PATCH] add reduce max infermeta --- .../operators/reduce_ops/reduce_max_op.cc | 2 +- paddle/phi/infermeta/unary.cc | 40 ++++++++++++++++--- paddle/phi/infermeta/unary.h | 11 +++++ 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc index 666f313297fa7..a1e6efc2360f0 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc @@ -27,7 +27,7 @@ class ReduceMaxOpMaker : public ops::ReduceOpMaker { }; DECLARE_INFER_SHAPE_FUNCTOR(reduce_max, ReduceMaxInferShapeFunctor, - PD_INFER_META(phi::MeanRawInferMeta)); + PD_INFER_META(phi::MaxRawInferMeta)); REGISTER_OPERATOR( reduce_max, ops::ReduceOp, ReduceMaxOpMaker, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 85db1547f16cc..52455dcf19941 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -394,12 +394,10 @@ void SumInferMeta(const MetaTensor& x, ReduceInferMetaBase(x, axis, keep_dim, reduce_all, dtype, out); } -void ReduceInferMetaBase(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - bool reduce_all, - DataType dtype, - MetaTensor* out) { +DDim ReduceInferDim(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all) { auto x_rank = x.dims().size(); std::vector formated_axis = axis; @@ -462,6 +460,17 @@ void ReduceInferMetaBase(const MetaTensor& x, } DDim out_dim = phi::make_ddim(out_dim_vector); + return out_dim; +} + +void ReduceInferMetaBase(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + DataType dtype, + MetaTensor* out) { + DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all); + DataType out_dtype; if (dtype != DataType::UNDEFINED) { out_dtype = dtype; @@ -479,6 +488,25 @@ void ReduceInferMetaBase(const MetaTensor& x, out->set_layout(x.layout()); } +void MaxRawInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out) { + DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all); + out->set_dims(out_dim); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + +void MaxInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + MetaTensor* out) { + bool reduce_all = false; + MaxRawInferMeta(x, axis, keep_dim, reduce_all, out); +} + void MeanRawInferMeta(const MetaTensor& x, const std::vector& axis, bool keep_dim, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index d4e21fbd8244b..71d0e837dbe53 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -98,6 +98,17 @@ void MeanRawInferMeta(const MetaTensor& x, bool reduce_all, MetaTensor* out); +void MaxRawInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out); + +void MaxInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + MetaTensor* out); + void MeanInferMeta(const MetaTensor& x, const std::vector& axis, bool keep_dim,