[hybrid] optimizer sharding support optimize cast #35878

Merged

Conversation

wangxicoding
Contributor

@wangxicoding wangxicoding commented Sep 18, 2021

PR types

Performance optimization

PR changes

Others

Describe

Add optimize cast support to optimizer sharding:

  1. Move the forward/backward parameter casts into the optimizer stage, reducing the number of cast ops and improving performance (see the sketch after this list).
  2. Under optimizer sharding, each rank only needs to store the fp32 parameters it owns, which saves GPU memory when dp_degree > 2.
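
A minimal, self-contained sketch of the idea (the `Op` class, the `.cast_fp16` suffix, and the helper name are illustrative stand-ins, not the PR's actual pass): drop the per-micro-step fp32→fp16 parameter casts from the forward/backward ops and append one cast per owned parameter at the optimizer stage.

```python
# Hypothetical simplification of optimize_cast; Op is a stand-in for Paddle's IR op.
from dataclasses import dataclass, field

@dataclass
class Op:
    type: str
    inputs: dict = field(default_factory=dict)
    outputs: dict = field(default_factory=dict)

def move_param_casts_to_optimizer(fwd_bwd_ops, optimizer_ops, owned_params):
    """Remove fp32->fp16 param casts from forward/backward and re-add one
    cast per parameter owned by this rank at the optimizer stage."""
    kept, cast_params = [], set()
    for op in fwd_bwd_ops:
        out = op.outputs.get("Out", "")
        if op.type == "cast" and out.endswith(".cast_fp16"):
            cast_params.add(op.inputs["X"])  # remember the fp32 source param
            continue                         # drop the per-step cast
        kept.append(op)
    # one cast per param, only for params this rank stores (optimizer sharding)
    for p in sorted(cast_params & owned_params):
        optimizer_ops.append(
            Op("cast", inputs={"X": p}, outputs={"Out": p + ".cast_fp16"}))
    return kept, optimizer_ops
```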

Accuracy test

Ernie3.0, base model, single node with 8 GPUs.
baseline=2mp+2pp+2dp, optimize_cast=2mp+2pp+2opt_sharding+optimize_cast
(accuracy comparison figure)

Speed test

Model configuration:

| setting | value |
| --- | --- |
| hidden_size | 3072 |
| num_attention_heads | 48 |
| num_hidden_layers | 39 |
| num_sharding_layers | 36 |
| branch_hidden_size | 256 |
| branch_num_attention_heads | 4 |

| baseline (token/s) | optimize_cast (token/s) | speedup |
| --- | --- | --- |
| 6421 | 6868 | 6.96% |

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@wangxicoding wangxicoding changed the title optimizer sharding support optimize cast [hybrid] optimizer sharding support optimize cast Sep 24, 2021
Contributor

@JZ-LIANG JZ-LIANG left a comment


LGTM

@wangxicoding wangxicoding merged commit eef0a94 into PaddlePaddle:develop Sep 28, 2021
@wangxicoding wangxicoding deleted the opt_sharding_optimize_cast branch September 28, 2021 02:45
```diff
 startup_block.append_op(
     type='c_sync_comm_stream',
-    inputs={'X': broadcast_params},
-    outputs={'Out': broadcast_params},
+    inputs={'X': params_name},
```
Contributor

If the broadcast is launched into the calc stream, there is no need to sync the calc stream at the end of the broadcasts.

Contributor Author

Yes, I originally wanted to delete it in this PR, but too many unit tests would need to be changed, so I kept it for now... will remove it in a future PR.
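
For reference, a hedged sketch of what dropping the sync would look like (assuming Paddle's `c_broadcast` / `c_sync_comm_stream` static-graph ops; the variable names and shapes are illustrative): when each broadcast is appended with `use_calc_stream=True`, it runs on the calculation stream, so the trailing `c_sync_comm_stream` op becomes unnecessary.

```python
# Sketch only, not the PR's exact code.
import paddle

paddle.enable_static()
startup_block = paddle.static.Program().global_block()
param = startup_block.create_var(
    name="fc_0.w_0", shape=[8, 8], dtype="float32", persistable=True)

startup_block.append_op(
    type='c_broadcast',
    inputs={'X': param},
    outputs={'Out': param},
    attrs={'ring_id': 0, 'root': 0,
           'use_calc_stream': True})  # launch the broadcast on the calc stream
# With use_calc_stream=True there is no separate comm stream to wait on,
# so no trailing c_sync_comm_stream op is required here.
```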

```python
# param is only used by the cast op,
# which casts fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
```
Contributor

Better to use a global variable to record the 'cast_fp16' naming rule; otherwise, if this pattern changes in AMP, we would have to change it everywhere in sharding.

Contributor Author

Got it, good idea.
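
A small sketch of the reviewer's suggestion (the constant and helper names below are made up for illustration, not existing Paddle symbols): keep AMP's fp16-cast naming rule in one place and reference it from every sharding pass.

```python
# Illustrative only; this is not an existing Paddle constant.
FP16_CAST_PATTERN = 'cast_fp16'  # assumed AMP output-name pattern

def is_fp16_cast_output(var_name: str) -> bool:
    """True if var_name matches AMP's fp32->fp16 cast output naming rule."""
    return FP16_CAST_PATTERN in var_name

# in the pass above, instead of hard-coding the string:
# if not is_fp16_cast_output(output_name):
#     continue
```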

```diff
-offload_helper.cast_fp32param_in_optimize(main_block, startup_block)
+offload_helper = OffloadHelper(ring_id=dp_ring_id)
+if self._optimizer_sharding:
+    offload_helper.opt_sharding_cast_fp32param(
```
Contributor

Great job, this not only reduces the number of cast ops from two per param to one per param, but also reduces the frequency of cast calls to 1/acc_step!
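
To make that count concrete, a toy calculation (all numbers below are assumed, not measured from the PR):

```python
# Assumed figures only: with acc_step gradient-accumulation micro-steps,
# the old scheme runs two casts per param every micro-step, while
# optimize_cast runs one cast per param once per optimizer step.
num_params = 100
acc_step = 4
calls_before = 2 * num_params * acc_step   # 800 cast calls per optimizer step
calls_after = 1 * num_params               # 100 cast calls per optimizer step
print(calls_before, calls_after)           # 2x -> 1x per param, frequency down to 1/acc_step
```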

AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021