Fix deepspeed default precision plugin amp_level to O2 (#13897)
Co-authored-by: Akihiro Nitta <[email protected]>
rohitgr7 and akihironitta authored Jul 29, 2022
1 parent aefb9ab commit 0f6caff
Showing 5 changed files with 34 additions and 4 deletions.
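For context, the configuration affected by this fix is the DeepSpeed strategy combined with the Apex AMP backend when no explicit `amp_level` is given: previously the DeepSpeed precision plugin kept `amp_level=None`, whereas it now falls back to `"O2"`, matching `ApexMixedPrecisionPlugin`. A minimal sketch of such a setup (not part of this commit, assuming the 1.6/1.7 `Trainer` API and an environment with `deepspeed` and NVIDIA `apex` installed):

```python
from pytorch_lightning import Trainer

# Sketch only: running this requires a CUDA GPU plus deepspeed and NVIDIA apex.
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    strategy="deepspeed",
    precision=16,
    amp_backend="apex",
    # amp_level is deliberately omitted: with this fix the DeepSpeed precision
    # plugin defaults it to "O2" instead of passing None through to apex.
)
```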
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -396,6 +396,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))


+- Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))



## [1.6.5] - 2022-07-13

2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/apex_amp.py
@@ -35,7 +35,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     def __init__(self, amp_level: str = "O2") -> None:
         if not _APEX_AVAILABLE:
             raise MisconfigurationException(
-                "You have asked for Apex AMP but you have not installed it."
+                "You have asked for Apex AMP but `apex` is not installed."
                 " Install `apex` using this guide: https://github.com/NVIDIA/apex"
             )
         super().__init__()
13 changes: 11 additions & 2 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -20,9 +20,9 @@
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
-from pytorch_lightning.utilities.enums import PrecisionType
+from pytorch_lightning.utilities.enums import AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _RequirementAvailable
+from pytorch_lightning.utilities.imports import _APEX_AVAILABLE, _RequirementAvailable
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -51,6 +51,15 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
     """

     def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
+        if amp_type == AMPType.APEX:
+            if not _APEX_AVAILABLE:
+                raise MisconfigurationException(
+                    "You have asked for Apex AMP but `apex` is not installed."
+                    " Install `apex` using this guide: https://github.com/NVIDIA/apex"
+                )
+
+            amp_level = amp_level or "O2"
+
         supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
         if precision not in supported_precision:
             raise ValueError(
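At the plugin level, the new behavior can be summarized with a short sketch (assuming `apex` is importable, since the constructor now raises `MisconfigurationException` for `amp_type="apex"` otherwise):

```python
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

# No amp_level given with the apex backend: the plugin now defaults to "O2",
# mirroring ApexMixedPrecisionPlugin, instead of leaving amp_level as None.
plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
assert plugin.amp_level == "O2"

# An explicitly passed level is still respected.
plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex", amp_level="O1")
assert plugin.amp_level == "O1"

# With native AMP, amp_level is left untouched (None).
plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="native")
assert plugin.amp_level is None
```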
18 changes: 18 additions & 0 deletions tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
@@ -11,11 +11,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest

 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 def test_invalid_precision_with_deepspeed_precision():
     with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
         DeepSpeedPrecisionPlugin(precision=64, amp_type="native")
+
+
+def test_deepspeed_precision_apex_not_installed(monkeypatch):
+    import pytorch_lightning.plugins.precision.deepspeed as deepspeed_apex
+
+    monkeypatch.setattr(deepspeed_apex, "_APEX_AVAILABLE", False)
+    with pytest.raises(MisconfigurationException, match="You have asked for Apex AMP but `apex` is not installed."):
+        DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
+
+
+@mock.patch("pytorch_lightning.plugins.precision.deepspeed._APEX_AVAILABLE", return_value=True)
+def test_deepspeed_precision_apex_default_level(_):
+    precision_plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
+    assert isinstance(precision_plugin, DeepSpeedPrecisionPlugin)
+    assert precision_plugin.amp_level == "O2"
2 changes: 1 addition & 1 deletion tests/tests_pytorch/plugins/test_amp_plugins.py
@@ -289,5 +289,5 @@ def test_precision_selection_raises(monkeypatch):
     monkeypatch.setattr(apex, "_APEX_AVAILABLE", False)
     with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1), mock.patch(
         "pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True
-    ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but you have not installed it"):
+    ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but `apex` is not installed"):
         Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1)
