diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index aa66df9b54a8b..ea649a9b65236 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -396,6 +396,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
 
+- Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))
+
+
 ## [1.6.5] - 2022-07-13
 
 ### Fixed
diff --git a/src/pytorch_lightning/plugins/precision/apex_amp.py b/src/pytorch_lightning/plugins/precision/apex_amp.py
index e18f82dc27f6e..2077f2072ab95 100644
--- a/src/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/src/pytorch_lightning/plugins/precision/apex_amp.py
@@ -35,7 +35,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     def __init__(self, amp_level: str = "O2") -> None:
         if not _APEX_AVAILABLE:
             raise MisconfigurationException(
-                "You have asked for Apex AMP but you have not installed it."
+                "You have asked for Apex AMP but `apex` is not installed."
                 " Install `apex` using this guide: https://github.com/NVIDIA/apex"
             )
         super().__init__()
diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py
index fa948520e1fd6..791a08a87d107 100644
--- a/src/pytorch_lightning/plugins/precision/deepspeed.py
+++ b/src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -20,9 +20,9 @@
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
-from pytorch_lightning.utilities.enums import PrecisionType
+from pytorch_lightning.utilities.enums import AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _RequirementAvailable
+from pytorch_lightning.utilities.imports import _APEX_AVAILABLE, _RequirementAvailable
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -51,6 +51,15 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
     """
 
     def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
+        if amp_type == AMPType.APEX:
+            if not _APEX_AVAILABLE:
+                raise MisconfigurationException(
+                    "You have asked for Apex AMP but `apex` is not installed."
+                    " Install `apex` using this guide: https://github.com/NVIDIA/apex"
+                )
+
+            amp_level = amp_level or "O2"
+
         supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
         if precision not in supported_precision:
             raise ValueError(
diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
index a4698e7c19c97..c1f7979ea8482 100644
--- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
+++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
@@ -11,11 +11,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest
 
 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 def test_invalid_precision_with_deepspeed_precision():
     with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
         DeepSpeedPrecisionPlugin(precision=64, amp_type="native")
+
+
+def test_deepspeed_precision_apex_not_installed(monkeypatch):
+    import pytorch_lightning.plugins.precision.deepspeed as deepspeed
+
+    monkeypatch.setattr(deepspeed, "_APEX_AVAILABLE", False)
+    with pytest.raises(MisconfigurationException, match="You have asked for Apex AMP but `apex` is not installed."):
+        DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
+
+
+@mock.patch("pytorch_lightning.plugins.precision.deepspeed._APEX_AVAILABLE", True)
+def test_deepspeed_precision_apex_default_level():
+    precision_plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
+    assert isinstance(precision_plugin, DeepSpeedPrecisionPlugin)
+    assert precision_plugin.amp_level == "O2"
diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py
index b02e3e29e9539..974964e5b9101 100644
--- a/tests/tests_pytorch/plugins/test_amp_plugins.py
+++ b/tests/tests_pytorch/plugins/test_amp_plugins.py
@@ -289,5 +289,5 @@ def test_precision_selection_raises(monkeypatch):
     monkeypatch.setattr(apex, "_APEX_AVAILABLE", False)
     with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1), mock.patch(
         "pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True
-    ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but you have not installed it"):
+    ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but `apex` is not installed"):
         Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1)
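Taken together, the `deepspeed.py` change means the plugin now validates `apex` availability up front and falls back to `amp_level="O2"` when no level is given. A minimal sketch of the resulting behavior, assuming this revision of `pytorch-lightning` and a working NVIDIA `apex` install (without `apex`, the constructor raises the new `MisconfigurationException` instead):

```python
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

# No amp_level given: with amp_type="apex", the plugin now defaults
# to "O2" instead of leaving the level as None.
plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
assert plugin.amp_level == "O2"

# An explicitly passed level still takes precedence, since the
# fallback is written as `amp_level = amp_level or "O2"`.
plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex", amp_level="O1")
assert plugin.amp_level == "O1"
```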