diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc index a1e6efc2360f0..41df8e4a15f09 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc @@ -27,7 +27,7 @@ class ReduceMaxOpMaker : public ops::ReduceOpMaker { }; DECLARE_INFER_SHAPE_FUNCTOR(reduce_max, ReduceMaxInferShapeFunctor, - PD_INFER_META(phi::MaxRawInferMeta)); + PD_INFER_META(phi::ReduceInferMetaBase)); REGISTER_OPERATOR( reduce_max, ops::ReduceOp, ReduceMaxOpMaker, diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc index 894106883cb0a..4a18330913803 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc @@ -97,7 +97,7 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker { }; DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean, ReduceMeanInferShapeFunctor, - PD_INFER_META(phi::MeanRawInferMeta)); + PD_INFER_META(phi::ReduceInferMetaBase)); REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__, ops::ReduceMeanOpGradMaker, diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index 6559ed479c84c..6441d53239e95 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -103,7 +103,7 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker { }; DECLARE_INFER_SHAPE_FUNCTOR(reduce_sum, ReduceSumInferShapeFunctor, - PD_INFER_META(phi::ReduceInferMetaBase)); + PD_INFER_META(phi::SumRawInferMeta)); REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker, ops::ReduceSumVarTypeInference, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 52455dcf19941..f2c9d9bd2fdfc 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -382,7 +382,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ReshapeInferMeta(x, shape, out, config); } -/* Why not use ReduceInferMetaBase directly? +/* Why not use SumRawInferMeta directly? Because we need make InferMetaFunction's args follow the design of api.yaml */ void SumInferMeta(const MetaTensor& x, @@ -391,7 +391,7 @@ void SumInferMeta(const MetaTensor& x, bool keep_dim, MetaTensor* out) { bool reduce_all = false; - ReduceInferMetaBase(x, axis, keep_dim, reduce_all, dtype, out); + SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out); } DDim ReduceInferDim(const MetaTensor& x, @@ -463,12 +463,12 @@ DDim ReduceInferDim(const MetaTensor& x, return out_dim; } -void ReduceInferMetaBase(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - bool reduce_all, - DataType dtype, - MetaTensor* out) { +void SumRawInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + DataType dtype, + MetaTensor* out) { DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all); DataType out_dtype; @@ -488,39 +488,23 @@ void ReduceInferMetaBase(const MetaTensor& x, out->set_layout(x.layout()); } -void MaxRawInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - bool reduce_all, - MetaTensor* out) { +void ReduceInferMetaBase(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out) { DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all); out->set_dims(out_dim); out->set_dtype(x.dtype()); out->set_layout(x.layout()); } -void MaxInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - MetaTensor* out) { - bool reduce_all = false; - MaxRawInferMeta(x, axis, keep_dim, reduce_all, out); -} - -void MeanRawInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - bool reduce_all, - MetaTensor* out) { - ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out); -} - -void MeanInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - MetaTensor* out) { +void ReduceInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + MetaTensor* out) { bool reduce_all = false; - ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out); + ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out); } void TransferLayoutInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 71d0e837dbe53..8b111e53af53b 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -85,35 +85,24 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void SumRawInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + DataType dtype, + MetaTensor* out); + void ReduceInferMetaBase(const MetaTensor& x, const std::vector& axis, bool keep_dim, bool reduce_all, - DataType dtype, MetaTensor* out); -void MeanRawInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - bool reduce_all, - MetaTensor* out); - -void MaxRawInferMeta(const MetaTensor& x, +void ReduceInferMeta(const MetaTensor& x, const std::vector& axis, bool keep_dim, - bool reduce_all, MetaTensor* out); -void MaxInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - MetaTensor* out); - -void MeanInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - MetaTensor* out); - void SumInferMeta(const MetaTensor& x, const std::vector& axis, DataType dtype, diff --git a/paddle/phi/kernels/math_kernel.h b/paddle/phi/kernels/math_kernel.h index fe8f3b749cdd8..7569cbcff087d 100644 --- a/paddle/phi/kernels/math_kernel.h +++ b/paddle/phi/kernels/math_kernel.h @@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx, bool keep_dim) { DenseTensor dense_out; MetaTensor meta_out(&dense_out); - ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out); + SumRawInferMeta(x, axis, keep_dim, false, x.dtype(), &meta_out); MeanKernel(dev_ctx, x, axis, keep_dim, &dense_out); return dense_out; } diff --git a/paddle/phi/ops/compat/reduce_sig.cc b/paddle/phi/ops/compat/reduce_sig.cc index 2544a175b1018..36798abe4c11b 100644 --- a/paddle/phi/ops/compat/reduce_sig.cc +++ b/paddle/phi/ops/compat/reduce_sig.cc @@ -21,7 +21,7 @@ KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) { bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in // InferShape, so we must return the "sum_raw" KernelSignature. - // And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with + // And the InferMeta function(i.e. SumRawInferMeta) is accordance with // the "sum_raw" KernelSignature if (ctx.IsForInferShape() || reduce_all) { return KernelSignature("sum_raw", @@ -40,7 +40,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in // InferShape, so we must return the "mean_raw" KernelSignature. - // And the InferMeta function(i.e. MeanRawInferMeta) is accordance with the + // And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with + // the // "mean_raw" KernelSignature if (ctx.IsForInferShape() || reduce_all) { return KernelSignature( @@ -61,7 +62,8 @@ KernelSignature ReduceMaxOpArgumentMapping(const ArgumentMappingContext& ctx) { bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in // InferShape, so we must return the "max_raw" KernelSignature. - // And the InferMeta function(i.e. MaxRawInferMeta) is accordance with the + // And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with + // the // "max_raw" KernelSignature if (ctx.IsForInferShape() || reduce_all) { return KernelSignature( diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 699e42f23732a..5c374bbb35b0a 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -124,7 +124,7 @@ args : (Tensor x, int64_t[] axis={}, bool keep_dim=false) output : Tensor infer_meta : - func : MeanInferMeta + func : ReduceInferMeta kernel : func : mean