
[AutoParallel]Revise Infermeta of LayerNorm for Sequence-Data Hybrid Parallelism #58776

Merged
7 commits merged into PaddlePaddle:develop on Nov 14, 2023

Conversation

Contributor

@JZ-LIANG JZ-LIANG commented Nov 7, 2023

PR types

Function optimization

PR changes

Others

Description

Pcard-76459

The current LayerNorm implementation flattens the broadcast axes (the axes before "begin_norm_axis") of LayerNorm, which prevents those broadcast axes from being sharded along different mesh dimensions.

In Sequence-Data Hybrid Parallelism, we need to shard both the batch and sequence axes (the broadcast axes) of the LayerNorm input to get the best performance.

Therefore, this PR removes the "flatten logic" from the LayerNorm op.

Before PR:
all axes before "begin_norm_axis" are flattened into one axis for the mean and variance outputs of LayerNorm.

After PR:
the mean and variance keep the shape of the input axes before "begin_norm_axis" (i.e. input.shape[:begin_norm_axis]).

# pseudocode
x = paddle.randn(shape=[8, 1024, 768])
var, mean, out = layer_norm(x, begin_norm_axis=2)

# before PR
out.shape == [8, 1024, 768]
var.shape == [8192]
mean.shape == [8192]

# after PR
out.shape == [8, 1024, 768]
var.shape == [8, 1024]
mean.shape == [8, 1024]
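
With the un-flattened shapes, the batch axis and the sequence axis of the input (and of mean/var) can be placed on different mesh dimensions. Below is a minimal sketch of that intent, assuming the placement-style semi-auto-parallel API (paddle.distributed.ProcessMesh, shard_tensor, Shard); the 2 x 2 mesh and rank layout are illustrative only and not part of this PR.

import paddle
import paddle.distributed as dist

# illustrative 2 x 2 mesh: one dim for data parallel, one for sequence parallel
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "sp"])

x = paddle.randn([8, 1024, 768])
# shard the batch axis (0) along "dp" and the sequence axis (1) along "sp";
# the hidden axis stays replicated
x_dist = dist.shard_tensor(x, mesh, [dist.Shard(0), dist.Shard(1)])

Because mean and var now keep the [8, 1024] shape, they can carry the same [Shard(0), Shard(1)] placement, whereas the flattened [8192] shape could only be sharded along a single mesh dimension.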

Logic modified in this PR:

  • LayerNorm InferMeta (a shape-rule sketch follows this list)
  • LayerNorm InferSpmd
  • LayerNorm prim composite rule
  • LayerNorm CPU kernel
  • LayerNorm XPU & MKLDNN unit tests
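
For reference, a minimal Python sketch of the shape rule described above; the helper below is illustrative only and is not the actual InferMeta code.

import math

def layer_norm_stat_shape(input_shape, begin_norm_axis, flatten=False):
    # shape of LayerNorm's mean/variance outputs for a given input shape;
    # flatten=True mimics the behavior before this PR, flatten=False after it
    broadcast_axes = list(input_shape[:begin_norm_axis])
    if flatten:
        return [math.prod(broadcast_axes)]  # e.g. [8 * 1024] == [8192]
    return broadcast_axes                   # e.g. [8, 1024]

# matches the pseudocode above
assert layer_norm_stat_shape([8, 1024, 768], 2, flatten=True) == [8192]
assert layer_norm_stat_shape([8, 1024, 768], 2, flatten=False) == [8, 1024]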

Contributor

@pkuzyc pkuzyc left a comment


Are there corresponding cases in the unit tests? If so, the unit tests should also be revised.

@JZ-LIANG JZ-LIANG changed the title from "Revise Infermeta of LayerNorm" to "[AutoParallel]Revise Infermeta of LayerNorm for Sequence-Data Hybrid Parallelism" on Nov 9, 2023
@JZ-LIANG JZ-LIANG requested review from zhiboniu and zhiqiu November 9, 2023 06:36
@JZ-LIANG JZ-LIANG closed this Nov 9, 2023
@JZ-LIANG JZ-LIANG reopened this Nov 9, 2023
Contributor

@xiaoguoguo626807 xiaoguoguo626807 left a comment


LGTM for composite rules

Contributor

@zhiqiu zhiqiu left a comment


LGTM

Contributor

@chenwhql chenwhql left a comment


LGTM

Contributor

@pkuzyc pkuzyc left a comment


LGTM for spmd rules

@JZ-LIANG JZ-LIANG merged commit db105fd into PaddlePaddle:develop Nov 14, 2023
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
…Parallelism (PaddlePaddle#58776)

* modify infermate

* bugfix for kernel and spmd

* fix prim

* update unitest
SecretXV pushed a commit to SecretXV/Paddle that referenced this pull request Nov 28, 2023
…Parallelism (PaddlePaddle#58776)

* modify infermate

* bugfix for kernel and spmd

* fix prim

* update unitest
5 participants