Skip to content

Commit

Permalink
Add fast_ln spmd rules (#68148)
Browse files Browse the repository at this point in the history
* Add fast_ln spmd rules

* Update code
  • Loading branch information
From00 authored Sep 12, 2024
1 parent d1f6bf4 commit f9009cc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/spmd_rules/layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,5 +445,32 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x,
{x_grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr});
}

SpmdInfo FastLnInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
float epsilon) {
int begin_norm_axis = x.dims().size() - 1;
VLOG(4) << "FastLnInferSpmd call LayerNormInferSpmd with begin_norm_axis="
<< begin_norm_axis;
return LayerNormInferSpmd(x, scale, bias, epsilon, begin_norm_axis);
}

SpmdInfo FastLnGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& mean,
const DistMetaTensor& invvar,
const DistMetaTensor& y_grad,
float epsilon) {
int begin_norm_axis = x.dims().size() - 1;
const DistMetaTensor& bias(scale); // bias is not used in FastLnGrad
VLOG(4)
<< "FastLnGradInferSpmd call LayerNormGradInferSpmd with begin_norm_axis="
<< begin_norm_axis << ", the input 'bias' will be ignored.";
SpmdInfo spmd_info = LayerNormGradInferSpmd(
x, scale, bias, mean, invvar, y_grad, epsilon, begin_norm_axis);
spmd_info.first.erase(spmd_info.first.begin() + 2); // remove bias_dist_attr
return spmd_info;
}

} // namespace distributed
} // namespace phi
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/spmd_rules/layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,17 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x,
float epsilon,
int begin_norm_axis);

SpmdInfo FastLnInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
float epsilon);

SpmdInfo FastLnGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& mean,
const DistMetaTensor& invvar,
const DistMetaTensor& y_grad,
float epsilon);

} // namespace distributed
} // namespace phi

0 comments on commit f9009cc

Please sign in to comment.