From 02910d528080604c0f04df9efd12fb6e456753b2 Mon Sep 17 00:00:00 2001
From: chenruibiao
Date: Wed, 11 Sep 2024 15:29:25 +0800
Subject: [PATCH] Add fast_ln spmd rules

---
 .../gpt-3/external_ops/fast_ln/ln_api.cpp | 37 +++++++++++--------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/legacy/model_zoo/gpt-3/external_ops/fast_ln/ln_api.cpp b/legacy/model_zoo/gpt-3/external_ops/fast_ln/ln_api.cpp
index c5eb9b3c000f..e6bc0649c3ba 100644
--- a/legacy/model_zoo/gpt-3/external_ops/fast_ln/ln_api.cpp
+++ b/legacy/model_zoo/gpt-3/external_ops/fast_ln/ln_api.cpp
@@ -19,9 +19,13 @@
  * with minor changes.
  */
 #include "paddle/extension.h"
-
 #include "ln.h"  // NOLINT
 
+#ifdef CUSTOM_OP_WITH_SPMD
+#include "paddle/phi/api/ext/spmd_infer.h"
+#include "paddle/phi/infermeta/spmd_rules/rules.h"
+#endif
+
 /*
 
 Supported Type combinations:
@@ -197,12 +201,10 @@ std::vector<paddle::Tensor> LnFwd(const paddle::Tensor &x,
   auto sizes = x.shape();
   PD_CHECK(sizes.size() >= 2);
 
-  int rows = 1;
-  for (size_t i = 0; i + 1 < sizes.size(); ++i) {
-    rows *= sizes[i];
-  }
-
+  std::vector<int64_t> row_sizes(sizes.begin(), sizes.begin() + sizes.size() - 1);
+  const int cols = sizes[sizes.size() - 1];
+  const int rows = x.numel() / cols;
 
   auto hidden_size = scale.numel();
 
   PD_CHECK(scale.shape() == bias.shape());
@@ -214,8 +216,8 @@
 
   auto y = paddle::empty(sizes, output_type, place);
 
-  auto mean = paddle::empty({rows}, compute_type, place);
-  auto invvar = paddle::empty({rows}, compute_type, place);
+  auto mean = paddle::empty({row_sizes}, compute_type, place);
+  auto invvar = paddle::empty({row_sizes}, compute_type, place);
 
   LaunchNormFwd(x.stream(),
                 place,
@@ -481,11 +483,8 @@ std::vector<std::vector<int64_t>> LnFwdInferShape(
     std::vector<int64_t> scale_shape,
     std::vector<int64_t> bias_shape,
     float epsilon) {
-  int64_t rows = 1;
-  for (size_t i = 0; i + 1 < x_shape.size(); ++i) {
-    rows *= x_shape[i];
-  }
-  return {x_shape, {rows}, {rows}};
+  std::vector<int64_t> row_shape(x_shape.begin(), x_shape.begin() + x_shape.size() - 1);
+  return {x_shape, row_shape, row_shape};
 }
 
 std::vector<std::vector<int64_t>> RMSLnFwdInferShape(
@@ -543,7 +542,11 @@ PD_BUILD_OP(fast_ln)
     .Attrs({"epsilon: float"})
     .SetKernelFn(PD_KERNEL(LnFwd))
     .SetInferShapeFn(PD_INFER_SHAPE(LnFwdInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(LnFwdInferDtype));
+    .SetInferDtypeFn(PD_INFER_DTYPE(LnFwdInferDtype))
+#ifdef CUSTOM_OP_WITH_SPMD
+    .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::FastLnInferSpmd))
+#endif
+;
 
 PD_BUILD_GRAD_OP(fast_ln)
     .Inputs({"x", "scale", "mean", "invvar", paddle::Grad("y")})
@@ -551,7 +554,11 @@ PD_BUILD_GRAD_OP(fast_ln)
     .Attrs({"epsilon: float"})
     .SetKernelFn(PD_KERNEL(LnBwd))
     .SetInferShapeFn(PD_INFER_SHAPE(LnBwdInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(LnBwdInferDtype));
+    .SetInferDtypeFn(PD_INFER_DTYPE(LnBwdInferDtype))
+#ifdef CUSTOM_OP_WITH_SPMD
+    .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::FastLnGradInferSpmd))
+#endif
+;
 
 PD_BUILD_OP(fast_rms_norm)
     .Inputs({"x", "scale"})
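
Illustrative note (not part of the patch): the change replaces the flattened `{rows}` shape on `mean` and `invvar` with `row_sizes`, i.e. the leading dimensions of `x`, which is what allows the SPMD rules to propagate shardings of those dimensions instead of seeing a single collapsed axis. A minimal standalone sketch of the two shape computations taken from the removed and added lines, assuming a hypothetical input of shape [batch, seq_len, hidden]:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical input shape [batch, seq_len, hidden]; only the last axis is normalized.
  std::vector<int64_t> sizes = {8, 1024, 768};

  // Old behavior: mean/invvar were allocated with one flattened leading dimension.
  int64_t rows = 1;
  for (size_t i = 0; i + 1 < sizes.size(); ++i) {
    rows *= sizes[i];  // 8 * 1024 = 8192
  }

  // New behavior: mean/invvar keep all leading dimensions of x.
  std::vector<int64_t> row_sizes(sizes.begin(), sizes.begin() + sizes.size() - 1);  // {8, 1024}

  std::cout << "flattened rows: " << rows << "\n";               // 8192
  std::cout << "row_sizes rank: " << row_sizes.size() << "\n";   // 2
  return 0;
}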