Commit
* Add initial deepspeed changes
* Address code review
* Move static method outside of function
* Fixes
* Add missing annotation
* Remove seed setting
* Doc changes
* Doc changes, add address reviews
* Fix docs
* Try fixing issue by moving to torch adam
* Clean up check
* Changes, better APIs!
* Add wrapper, swap to git install revision
* Add special test
* Add warning
* Address review
* Add better disclaimer
* Turn off ZeRO for testing due to compilation
* Add description on modifying parameters via the plugin
* Doc strings clear
* Small doc fixes
* Fix hash, reduce test
* Added CI change
* Move to azure pipeline
* Fix test name
* Add missing flag
* Remove sudo...
* Try conda instead
* Swap to conda base
* Try suggested install
* Apply suggestions from code review
* Apply suggestions from code review
* Revert "Apply suggestions from code review" (reverts commit 41cca05)
* Revert "Apply suggestions from code review" (reverts commit e06ec29)
* Remove setter
* Address most review
* Move out function, remove DeepSpeed from requirements
* Install deepspeed/mpi4py within container
* Use special tests, move to master commit for deepspeed
* Export path
* Force compile to happen first
* Remove!
* Debugging ninja
* Fix error in optimizer step logic
* Attempt to fix symbolic link
* Reverse to aid debugging
* Export path again
* Clean up mess
* var
* Revert "var" (reverts commit 3450eac)
* Address review, add todo
* Add note about unsupported functionality

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 6a409c7 · commit 7189d67 · Showing 16 changed files with 877 additions and 10 deletions.
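
Taken together with the other changed files, the integration is driven from the Trainer. Below is a minimal sketch of how it might be enabled end to end; the `DeepSpeedPlugin` import path and default constructor, and passing it through the Trainer's `plugins` argument, are assumptions inferred from this commit rather than verbatim documented API, and `BoringModel` is a hypothetical example model (DeepSpeed and a CUDA GPU are assumed to be installed):

# Hedged sketch: enabling the DeepSpeed integration from the Trainer.
# `DeepSpeedPlugin` (import path and defaults) is an assumption based on
# this commit's changed files; `BoringModel` is a hypothetical toy model.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin  # assumed import path


class BoringModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

trainer = pl.Trainer(
    gpus=1,
    precision=16,                 # precision handled by DeepSpeedPrecisionPlugin
    plugins=[DeepSpeedPlugin()],  # assumed: the plugin added by this commit
    max_epochs=1,
)
trainer.fit(BoringModel(), train_data)

When the DeepSpeed plugin is active, the intent is that backward and optimizer stepping are routed through the `DeepSpeedPrecisionPlugin` shown in the diff below.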
pytorch_lightning/plugins/precision/deepspeed_precision.py · 61 additions, 0 deletions
@@ -0,0 +1,61 @@
from typing import Callable, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class DeepSpeedPrecisionPlugin(PrecisionPlugin):

    def __init__(self, precision):
        super().__init__()
        self.precision = precision

    def pre_optimizer_step(
        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
    ) -> bool:
        deepspeed_engine = pl_module.trainer.model
        # DeepSpeed does not support closures, so run the closure eagerly here.
        lambda_closure()

        if not pl_module.automatic_optimization:
            pl_module.trainer.call_hook("on_after_backward")

        deepspeed_engine.step()

        # Returning False signals that the DeepSpeed engine has already
        # applied the optimizer step, so the caller must not step again.
        return False

    def backward(
        self,
        lightning_module: LightningModule,
        closure_loss: torch.Tensor,
        optimizer: Optimizer,
        opt_idx: int,
        should_accumulate: bool,
        *args,
        **kwargs,
    ):
        if is_overridden('backward', lightning_module):
            warning_cache.warn(
                "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
                " the backward logic outside of the LightningModule"
            )
        # todo: hack around for deepspeed engine to call backward
        deepspeed_engine = lightning_module.trainer.model
        deepspeed_engine.backward(closure_loss, **kwargs)
        # once backward has been applied, release the graph
        closure_loss = closure_loss.detach()

        return closure_loss

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0):
        """DeepSpeed handles gradient clipping via the training type plugin."""
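
To make the contract above concrete, here is a small runnable toy that exercises the two hooks with test doubles. The stubs mimic only the attributes the plugin actually touches (`trainer.model` holding the engine, `automatic_optimization`); they are illustrative stand-ins, not Lightning internals or the real DeepSpeed engine.

# Toy demo of the plugin's call contract, using stand-in objects.
# Only DeepSpeedPrecisionPlugin comes from this commit; the stubs are
# illustrative test doubles, not Lightning or DeepSpeed internals.
import torch
from types import SimpleNamespace

from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin


class StubEngine:
    """Mimics the two calls the plugin makes on the DeepSpeed engine."""

    def backward(self, loss, **kwargs):
        loss.backward()

    def step(self):
        print("engine.step() called")


class StubModule(torch.nn.Module):
    """Bare object exposing only what the plugin reads from a LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 1)
        self.automatic_optimization = True
        self.trainer = SimpleNamespace(model=StubEngine())


module = StubModule()
plugin = DeepSpeedPrecisionPlugin(precision=16)

loss = module.layer(torch.randn(4, 2)).sum()
detached = plugin.backward(module, loss, optimizer=None, opt_idx=0, should_accumulate=False)
assert not detached.requires_grad  # graph was released via .detach()

# pre_optimizer_step runs the closure, steps the engine, and returns False
# so the surrounding loop skips its own optimizer.step().
assert plugin.pre_optimizer_step(module, optimizer=None, optimizer_idx=0, lambda_closure=lambda: None) is False

The `False` return from `pre_optimizer_step` is the key design point: DeepSpeed owns both backward and stepping, so Lightning must not call `optimizer.step()` itself. For the same reason `clip_gradients` is a deliberate no-op here, since clipping is delegated to the training type plugin.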