[AutoParallel]Revise Infermeta of LayerNorm for Sequence-Data Hybrid Parallelism #58776
Conversation
Are there corresponding cases in the unit tests? If so, the unit tests should also be revised.
LGTM for composite rules
LGTM
LGTM
LGTM for spmd rules
…Parallelism (PaddlePaddle#58776) * modify infermeta * bugfix for kernel and spmd * fix prim * update unittest
PR types
Function optimization
PR changes
Others
Description
Pcard-76459
The current LayerNorm implementation flattens the broadcast axes (the axes before "begin_norm_axis") of LayerNorm, which prevents those axes from being sharded along different mesh dimensions.
In sequence-data hybrid parallelism, we need to shard both the batch and sequence axes (broadcast axes) of the LayerNorm input to get the best performance.
Therefore, this PR removes the "flatten logic" from the LayerNorm op.
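As a hedged illustration of the target sharding pattern, the sketch below uses Paddle's semi-auto parallel API (paddle.distributed.ProcessMesh, shard_tensor, Shard); the 2x2 mesh, dim names, and tensor sizes are illustrative assumptions, and the script is meant to be launched with 4 ranks:

```python
import paddle
import paddle.distributed as dist

# Hypothetical 2x2 mesh: "dp" shards the batch axis, "sp" shards the sequence axis.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "sp"])

x = paddle.randn([8, 1024, 4096])  # [batch, seq_len, hidden], illustrative sizes
# Shard batch (dim 0) along mesh dim "dp" and sequence (dim 1) along "sp";
# the hidden axis stays replicated.
dist_x = dist.shard_tensor(x, mesh, [dist.Shard(0), dist.Shard(1)])

# LayerNorm over the hidden axis (begin_norm_axis == 2 for this input).
# With the flattening removed, mean/variance keep both broadcast axes,
# so batch and sequence shardings can be propagated through the op.
norm = paddle.nn.LayerNorm(4096)
y = norm(dist_x)
```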
Before PR:
All axes before "begin_norm_axis" are flattened into one axis for the mean and variance outputs of LayerNorm.
After PR:
The mean and variance keep the same shape as the input's leading axes (those before "begin_norm_axis"), with no flattening.
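A minimal plain-Python sketch of the shape change (the helper names and the [batch, seq_len, hidden] example are hypothetical, not Paddle's actual C++ InferMeta code):

```python
import math

def mean_var_shape_before(x_shape, begin_norm_axis):
    # Before this PR: all axes in front of begin_norm_axis are flattened
    # into a single dimension, e.g. [batch, seq_len, hidden] -> [batch * seq_len].
    return [math.prod(x_shape[:begin_norm_axis])]

def mean_var_shape_after(x_shape, begin_norm_axis):
    # After this PR: the leading axes are kept as-is, so each broadcast axis
    # (batch, sequence, ...) can be sharded along a different mesh dimension.
    return list(x_shape[:begin_norm_axis])

x_shape = [8, 1024, 4096]                 # [batch, seq_len, hidden], illustrative sizes
print(mean_var_shape_before(x_shape, 2))  # [8192]
print(mean_var_shape_after(x_shape, 2))   # [8, 1024]
```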
The logic modified in this PR: