Add bfloat16 support to DeepSpeedStrategy #12508

Conversation
This PR misses an implementation for:

@@ -512,11 +521,12 @@ def _initialize_deepspeed_train(self, model):
     def model_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
             assert self._config_initialized
-            dtype = (
-                torch.float16
-                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
-                else torch.float32
-            )
+            if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
+                dtype = torch.float16
+            elif self.precision_plugin.precision == PrecisionType.BFLOAT:
+                dtype = torch.bfloat16
+            else:
+                dtype = torch.float32
             model_parallel_context = deepspeed.zero.Init(
                 remote_device=self.remote_device, pin_memory=True, config_dict_or_path=self.config, dtype=dtype
             )
If you include the patch, please also add a test for ZeRO-3 + bf16 with model_sharded_context. Otherwise, please ping me and I can open a new PR after this one is merged.
yep, I was reading about why it is done like that for stage 3; that's why I kept this as a draft.
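For illustration, a test for ZeRO-3 + bf16 exercising model_sharded_context might look roughly like the sketch below. The BoringModel import path follows the test-suite layout of that era, and GPU/DeepSpeed requirement markers (e.g. RunIf) are omitted; treat the names and structure as assumptions, not the test that actually landed.

```python
# Hypothetical test sketch for ZeRO stage 3 + bf16 through model_sharded_context.
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy
from tests.helpers.boring_model import BoringModel  # path per the era's test suite; an assumption


class _CheckBF16Model(BoringModel):
    def configure_sharded_model(self) -> None:
        # This hook runs inside model_sharded_context, so the layer is created
        # under deepspeed.zero.Init(dtype=...) and should come out in bfloat16.
        self.layer = torch.nn.Linear(32, 2)
        assert self.layer.weight.dtype == torch.bfloat16


def test_deepspeed_zero3_bf16_sharded_context(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=1,
        strategy=DeepSpeedStrategy(stage=3),
        precision="bf16",
        fast_dev_run=True,
    )
    trainer.fit(_CheckBF16Model())
```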
Can I use bf16 in the latest master branch?
@toriving, yes
Thanks. I cannot use it in the latest release version (1.6.5), but I can use it on the master branch.
yep... v1.6.5 is the bug-fix branch and does not include any new features that were merged after the v1.6 release. This is currently available only on master and will be included in the v1.7 release.
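For anyone landing here, a minimal usage sketch on master at the time might look like this (the strategy alias and precision flag follow the 1.7-era API; treat as illustrative):

```python
# Minimal sketch: bf16 + DeepSpeed with the 1.7-era Trainer API.
from pytorch_lightning import Trainer

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy="deepspeed_stage_3",  # registered alias for DeepSpeedStrategy(stage=3)
    precision="bf16",              # selects torch.bfloat16 inside the strategy
)
```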
What does this PR do?
Add bf16 support to DeepSpeedStrategy. More info: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options
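For context, the same option can also be set at the DeepSpeed config level; a minimal sketch of passing such a config dict to the strategy (keys per the linked DeepSpeed docs):

```python
# Sketch: enabling bfloat16 via a raw DeepSpeed config dict.
# Keys follow https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options.
from pytorch_lightning.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    config={
        "zero_optimization": {"stage": 3},
        "bf16": {"enabled": True},  # requires deepspeed>=0.6.0
    }
)
```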
Note that bf16 support was added in deepspeed==0.6.0. Do we maintain a minimum version for DeepSpeed? If not, then if someone has deepspeed < 0.6.0 and tries to use bf16, it won't fail. I can add a check in that case and throw an exception if we support deepspeed < 0.6.0.
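If such a guard were added, a minimal sketch could look like this (the helper name and error type are hypothetical, not the actual Lightning code):

```python
# Hypothetical minimum-version guard for bf16 with DeepSpeed.
import deepspeed
from packaging.version import Version


def _verify_deepspeed_bf16_support() -> None:
    # bf16 training options first appeared in deepspeed==0.6.0.
    if Version(deepspeed.__version__) < Version("0.6.0"):
        raise RuntimeError(
            f"bf16 with DeepSpeed requires deepspeed>=0.6.0, found {deepspeed.__version__}."
        )
```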
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃
cc @Borda @SeanNaren @awaelchli @rohitgr7 @akihironitta