FP16 optimizer automatically detects DeepSpeed compatibility #18084
Merged

Commits (6):
- 400c304: code compare dymanically (pengwa)
- e430d96: fix (pengwa)
- dd2e6a5: refine (pengwa)
- 17a30c0: comment (pengwa)
- f12aeba: Merge branch 'main' of https://github.com/microsoft/onnxruntime into … (pengwa)
- 2cd979c: minor (pengwa)
orttraining/orttraining/python/training/optim/_ds_code_store.py (81 additions, 0 deletions)
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Copyright 2020 The Microsoft DeepSpeed Team
#
# !!!IMPORTANT: This file is a copy of the original one in DeepSpeed repo at given version,
# It is used to compare with the source code of current installed DeepSpeed during runtime.
# Please don't modify it or do any code formatting for it.
# 'orttraining/orttraining/python/training/optim/_ds_code_store.py' is removed from lintrunner config by intention.
# --------------------------------------------------------------------------

# Wrap code in this to make sure the indentation is correct compared with raw DeepSpeed.

class Stage1And2_DeepSpeedZeroOptimizer_0_9_2:

    def has_overflow_serial(self, params, is_grad_list=False):
        for p in params:
            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                return True

        return False


    def get_grad_norm_direct(self, gradients, params, norm_type=2):
        """Clips gradient norm of an iterable of parameters.

        This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
        added functionality to handle model parallel parameters. Note that
        the gradients are modified in place.

        Arguments:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).
        """
        norm_type = float(norm_type)
        if norm_type == inf:
            total_norm = max(g.data.abs().max() for g in gradients)
            total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)

            # Take max across all GPUs.
            self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
            total_norm = total_norm_cuda[0].item()
        else:
            total_norm = 0.0
            # if dist.get_rank() == 0:
            #    logger.info(f"Total Norm beginning {total_norm}")
            for g, p in zip(gradients, params):
                # Pipeline parallelism may replicate parameters. Avoid multi-counting.
                if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                    continue
                if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                    param_norm = g.data.double().norm(2)
                    total_norm += param_norm.item()**2
            # Sum across all model parallel GPUs.
            total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)

            self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

            total_norm = total_norm_cuda[0].item()**(1. / norm_type)

        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
            total_norm = -1

        return total_norm


    def has_overflow_partitioned_grads_serial(self):
        for i in range(len(self.bit16_groups)):
            for j, grad in enumerate(self.averaged_gradients[i]):
                if grad is not None and self._has_inf_or_nan(grad.data, j):
                    return True
        return False
```

Check warning (Code scanning / CodeQL): Comparison of identical values; use cmath.isnan() if testing for not-a-number. (Raised on the `total_norm != total_norm` check.)
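To make the arithmetic in `get_grad_norm_direct` concrete, here is a minimal single-process sketch (my own illustration, not code from the PR): each rank accumulates the sum of squared 2-norms of its gradients, `dist.all_reduce` with SUM combines the partial sums across ranks (a no-op with one rank), and the global norm is the combined sum raised to `1 / norm_type`. It also shows why CodeQL flags the `total_norm != total_norm` guard: self-inequality is the classic NaN test, equivalent to `math.isnan`, and it stays as written only because this file must mirror upstream DeepSpeed verbatim.

```python
# Single-process sketch of the norm math; the all-reduce over one rank is the identity.
import math

import torch

grads = [torch.randn(10), torch.randn(5)]

# Each rank sums the squared 2-norms of its own gradients ...
local_sq_sum = sum(g.data.double().norm(2).item() ** 2 for g in grads)

# ... dist.all_reduce(op=SUM) would add the partial sums across ranks;
# here there is only one rank, so the "global" sum is the local one.
norm_type = 2.0
total_norm = local_sq_sum ** (1.0 / norm_type)

# Equivalent to taking the 2-norm over the concatenated gradients directly.
assert abs(total_norm - torch.cat(grads).double().norm(2).item()) < 1e-6

# The guard CodeQL flags: x != x is True exactly when x is NaN.
for x in (float("nan"), float("inf"), 1.0):
    assert (x != x) == math.isnan(x)
```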
Check notice (Code scanning / CodeQL): Commented-out code. (Raised on the disabled `logger.info` lines.)
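For context on how a frozen copy like `_ds_code_store.py` can drive the compatibility detection named in the PR title: at runtime, the source of the installed DeepSpeed functions can be compared against the stored versions, and the FP16 optimizer's overrides enabled only on an exact match. The sketch below is an assumption about the shape of that check, not the PR's actual implementation; the `is_compatible` helper and both import paths are illustrative.

```python
# Hedged sketch: helper names and import paths here are assumptions; only the
# stored class comes from this PR's file.
import inspect

from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer  # assumed path

from onnxruntime.training.optim._ds_code_store import (  # assumed packaging of this file
    Stage1And2_DeepSpeedZeroOptimizer_0_9_2,
)

# The three functions this PR stores a reference copy of.
_CHECKED_FUNCTIONS = (
    "has_overflow_serial",
    "get_grad_norm_direct",
    "has_overflow_partitioned_grads_serial",
)


def _source(cls, name):
    # The stored copy wraps its functions in a class so inspect.getsource
    # yields the same indentation as the real methods (see the file's comment).
    return inspect.getsource(getattr(cls, name))


def is_compatible():
    """Return True when every checked function of the installed DeepSpeed
    is textually identical to the copy stored for version 0.9.2."""
    return all(
        _source(DeepSpeedZeroOptimizer, name)
        == _source(Stage1And2_DeepSpeedZeroOptimizer_0_9_2, name)
        for name in _CHECKED_FUNCTIONS
    )


# A caller could apply the FP16 optimizer's custom overrides only when
# is_compatible() holds, and fall back to stock DeepSpeed behavior otherwise.
```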