
Release overlap_comm & contiguous_gradients restrictions for ZeRO 1 #4887

Merged: 2 commits into microsoft:master from li-plus:opt-zero1 on Jan 5, 2024

Conversation

li-plus (Contributor) commented on Dec 30, 2023

The `overlap_comm` and `contiguous_gradients` options have been ignored in ZeRO stage 1 since #1246. At that time, ZeRO 1 and ZeRO 2 were implemented separately (see https://github.com/microsoft/DeepSpeed/tree/6ae756c03f12674f17aef90622e7664a8af9d2af/deepspeed/runtime/zero), and ZeRO 1 did not register gradient hooks to overlap the backward pass with gradient all-reduce, so ignoring `overlap_comm` and `contiguous_gradients` was fine. In the current implementation, however, ZeRO 1 and 2 share almost the same code (`stage_1_and_2.py`), so features like `overlap_comm` and `contiguous_gradients` can also be enabled for ZeRO 1 (please correct me if I made a mistake).
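For context, here is a minimal sketch of the kind of configuration this change makes usable for stage 1. `overlap_comm` and `contiguous_gradients` are the `zero_optimization` keys discussed in this PR; the batch size, optimizer settings, and toy model below are placeholders, and the script is assumed to be launched with the `deepspeed` launcher so the distributed environment is already set up.

```python
import torch
import deepspeed

ds_config = {
    "train_batch_size": 32,  # placeholder
    "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},  # placeholder
    "zero_optimization": {
        "stage": 1,                    # ZeRO-1, which previously ignored the two flags below
        "overlap_comm": True,          # overlap gradient all-reduce with backward
        "contiguous_gradients": True,  # copy gradients into a flat buffer
    },
}

model = torch.nn.Linear(1024, 1024)  # stand-in for a real model
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```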

With this PR, turning on `overlap_comm` and `contiguous_gradients` for ZeRO 1 on the [SFT task](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning) produces exactly the same training curve as the latest master.

![image](https://github.com/microsoft/DeepSpeed/assets/39846316/bda3be7b-c236-4e08-b687-b3cd01f5cc73)

I also see a ~1.05x end-to-end speedup from overlapping backward with the gradient all-reduce. The trace confirms that backward and all-reduce do overlap, and that the individual gradients are indeed copied into a flat buffer, so these options are effective for ZeRO 1 as well.

![image](https://github.com/microsoft/DeepSpeed/assets/39846316/5f876296-e1b4-404b-8b33-03cee8e5e6b2)

![image](https://github.com/microsoft/DeepSpeed/assets/39846316/9654f6be-5c7a-401a-b0bc-413ecd3f4e6b)
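For readers unfamiliar with the mechanism being measured above, the following is a minimal, illustrative sketch in plain PyTorch, not DeepSpeed's actual `stage_1_and_2.py` code: per-parameter backward hooks copy each gradient into its slice of one flat buffer and launch an asynchronous all-reduce, so communication runs while backward continues for earlier layers. It assumes `torch.distributed` is already initialized and omits the bucketing and gradient averaging that the real implementation performs.

```python
import torch
import torch.distributed as dist

def attach_overlap_hooks(model: torch.nn.Module):
    params = [p for p in model.parameters() if p.requires_grad]
    # One contiguous buffer holding all gradients back to back ("contiguous_gradients").
    flat = torch.zeros(sum(p.numel() for p in params),
                       dtype=params[0].dtype, device=params[0].device)
    offsets, handles, off = {}, [], 0
    for p in params:
        offsets[p] = off
        off += p.numel()

    def make_hook(p):
        def hook(grad):
            # Copy this parameter's gradient into its slice of the flat buffer.
            dst = flat.narrow(0, offsets[p], p.numel())
            dst.copy_(grad.reshape(-1))
            # Launch the all-reduce without waiting ("overlap_comm"): backward for
            # earlier layers keeps running while communication is in flight.
            handles.append(dist.all_reduce(dst, async_op=True))
        return hook

    for p in params:
        p.register_hook(make_hook(p))

    def wait_for_all_reduces():
        for h in handles:
            h.wait()
        handles.clear()

    return flat, wait_for_all_reduces
```

A training loop using such a helper would call `wait_for_all_reduces()` after `loss.backward()` and before the optimizer step; at that point the flat buffer holds the summed (not yet averaged) gradients.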

Related issue: #2295

tjruwase (Contributor) commented on Jan 5, 2024

@li-plus, thanks so much for this PR. Your analysis of the problem is accurate. Previously, we did not have bandwidth to evaluate the correctness of enabling those optimizations for ZeRO-1. This is a great contribution. Thanks so much.

tjruwase added this pull request to the merge queue on Jan 5, 2024
Merged via the queue into microsoft:master with commit af03383 on Jan 5, 2024
14 checks passed
li-plus deleted the opt-zero1 branch on Jan 6, 2024
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024