From f9009cc238480d79f8f37219626446f760322f49 Mon Sep 17 00:00:00 2001
From: Ruibiao Chen
Date: Thu, 12 Sep 2024 17:09:11 +0800
Subject: [PATCH] Add fast_ln spmd rules (#68148)

* Add fast_ln spmd rules

* Update code
---

Notes (review):
  * FastLnInferSpmd fixes begin_norm_axis to the last axis of x
    (x.dims().size() - 1) and forwards to LayerNormInferSpmd.
  * FastLnGradInferSpmd takes no bias argument, so it binds `scale` as a
    stand-in `bias` when calling LayerNormGradInferSpmd and then erases
    element 2 of the returned input dist attrs. This presumes the bias
    attr is the third input entry returned by LayerNormGradInferSpmd,
    which is not visible in this patch -- confirm before changing either
    function.

 paddle/phi/infermeta/spmd_rules/layer_norm.cc | 27 +++++++++++++++++++
 paddle/phi/infermeta/spmd_rules/layer_norm.h  | 12 +++++++++
 2 files changed, 39 insertions(+)

diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc
index 0137af15d5e832..f5ede839c988d6 100644
--- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc
+++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc
@@ -445,5 +445,32 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x,
       {x_grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr});
 }
 
+SpmdInfo FastLnInferSpmd(const DistMetaTensor& x,
+                         const DistMetaTensor& scale,
+                         const DistMetaTensor& bias,
+                         float epsilon) {
+  int begin_norm_axis = x.dims().size() - 1;
+  VLOG(4) << "FastLnInferSpmd call LayerNormInferSpmd with begin_norm_axis="
+          << begin_norm_axis;
+  return LayerNormInferSpmd(x, scale, bias, epsilon, begin_norm_axis);
+}
+
+SpmdInfo FastLnGradInferSpmd(const DistMetaTensor& x,
+                             const DistMetaTensor& scale,
+                             const DistMetaTensor& mean,
+                             const DistMetaTensor& invvar,
+                             const DistMetaTensor& y_grad,
+                             float epsilon) {
+  int begin_norm_axis = x.dims().size() - 1;
+  const DistMetaTensor& bias(scale);  // bias is not used in FastLnGrad
+  VLOG(4)
+      << "FastLnGradInferSpmd call LayerNormGradInferSpmd with begin_norm_axis="
+      << begin_norm_axis << ", the input 'bias' will be ignored.";
+  SpmdInfo spmd_info = LayerNormGradInferSpmd(
+      x, scale, bias, mean, invvar, y_grad, epsilon, begin_norm_axis);
+  spmd_info.first.erase(spmd_info.first.begin() + 2);  // remove bias_dist_attr
+  return spmd_info;
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.h b/paddle/phi/infermeta/spmd_rules/layer_norm.h
index 195618168cefe3..17e07fdff6a82a 100644
--- a/paddle/phi/infermeta/spmd_rules/layer_norm.h
+++ b/paddle/phi/infermeta/spmd_rules/layer_norm.h
@@ -44,5 +44,17 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x,
                                    float epsilon,
                                    int begin_norm_axis);
 
+SpmdInfo FastLnInferSpmd(const DistMetaTensor& x,
+                         const DistMetaTensor& scale,
+                         const DistMetaTensor& bias,
+                         float epsilon);
+
+SpmdInfo FastLnGradInferSpmd(const DistMetaTensor& x,
+                             const DistMetaTensor& scale,
+                             const DistMetaTensor& mean,
+                             const DistMetaTensor& invvar,
+                             const DistMetaTensor& y_grad,
+                             float epsilon);
+
 }  // namespace distributed
 }  // namespace phi