Add Trainer(gradient_clip_algorithm='value'|'norm') #6123

Merged
merged 59 commits into master from feat/clip_grad_by_value on Apr 6, 2021

Changes from 1 commit

Commits (59)
09ea112
add changelog
dhkim0225 Feb 22, 2021
c0e8064
add clip by value
dhkim0225 Feb 22, 2021
ca8e6fd
fix bug in training tricks.rst
dhkim0225 Feb 22, 2021
87f12c1
fix bug in trainer.rst
dhkim0225 Feb 22, 2021
8e43b8a
Update trainer.rst
dhkim0225 Feb 22, 2021
2bb9924
Update trainer.rst
dhkim0225 Feb 22, 2021
5b83f0d
Update CHANGELOG.md
dhkim0225 Feb 23, 2021
caafdf2
Update pytorch_lightning/plugins/precision/deepspeed_precision.py
dhkim0225 Feb 23, 2021
ca774b6
Update pytorch_lightning/utilities/enums.py
dhkim0225 Feb 23, 2021
5a741e2
yapf formatting
dhkim0225 Feb 23, 2021
0568e3a
update training tricks
dhkim0225 Feb 23, 2021
1a4e79e
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Feb 26, 2021
4a813c1
Merge branch 'master' into feat/clip_grad_by_value
tchaton Feb 26, 2021
2f5cb3e
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 2, 2021
f4275a2
update based on comment
dhkim0225 Mar 2, 2021
e92ec69
update based on comment
dhkim0225 Mar 2, 2021
ac701ce
Update pytorch_lightning/trainer/trainer.py
dhkim0225 Mar 2, 2021
bc20fa4
update based on comment
dhkim0225 Mar 2, 2021
b842210
Merge branch 'feat/clip_grad_by_value' of https://github.com/dhkim022…
dhkim0225 Mar 2, 2021
5ec2ebd
pep8
dhkim0225 Mar 2, 2021
d37fbbc
mypy
dhkim0225 Mar 2, 2021
952c778
mypy
dhkim0225 Mar 2, 2021
b8fdbe1
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 2, 2021
c4cccf0
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 3, 2021
6bd4793
Update docs/source/advanced/training_tricks.rst
dhkim0225 Mar 4, 2021
3aeba85
Update sharded_native_amp.py
dhkim0225 Mar 4, 2021
902a33c
Update test_sharded_parity.py
dhkim0225 Mar 4, 2021
7467616
update test codes
dhkim0225 Mar 4, 2021
5463830
Update test_tpu.py
dhkim0225 Mar 4, 2021
2e933d4
Update pytorch_lightning/trainer/connectors/training_trick_connector.py
dhkim0225 Mar 4, 2021
b1e26e6
Update test_trainer.py
dhkim0225 Mar 4, 2021
cedf5f6
Update enums.py
dhkim0225 Mar 4, 2021
f5bb45d
Update enums.py
dhkim0225 Mar 4, 2021
42fc5f6
Merge branch 'master' into feat/clip_grad_by_value
Borda Mar 4, 2021
e55b90c
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 5, 2021
308ce38
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 23, 2021
903f2e2
add super-class initialization to precision plugins.
dhkim0225 Mar 25, 2021
28c948a
add clip_grad horovod cpu test
dhkim0225 Mar 25, 2021
177a1c9
add clip_grad horovod cpu test
dhkim0225 Mar 25, 2021
fc23845
use subprocess check_call
dhkim0225 Mar 25, 2021
d99a650
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
dhkim0225 Mar 25, 2021
f80aa8d
change order of horovod tests
dhkim0225 Mar 25, 2021
fb895b6
set max_epochs 2 in horovod test
dhkim0225 Mar 25, 2021
caa0bbf
remove clip_grad_val test from horovod-cpu
dhkim0225 Mar 25, 2021
f1f9015
remove "type: ignore"
dhkim0225 Mar 25, 2021
5dfe5ef
divide clip grad val test in horovod
dhkim0225 Mar 25, 2021
50a6c74
update based on comments
dhkim0225 Mar 25, 2021
c337b12
add super-class initialization to precision plugins.
dhkim0225 Mar 25, 2021
f7a4fda
bugfix
dhkim0225 Mar 25, 2021
48c3dd8
bugfix
dhkim0225 Mar 25, 2021
e7e3b47
revert some changes
dhkim0225 Mar 26, 2021
2997536
revert some changes
dhkim0225 Mar 26, 2021
fb34e84
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
dhkim0225 Mar 26, 2021
8e665ec
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 27, 2021
9575774
Update tests/models/test_horovod.py
carmocca Mar 27, 2021
7c16f6a
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 29, 2021
fec189a
merge master
dhkim0225 Apr 6, 2021
1e80304
merge master
dhkim0225 Apr 6, 2021
4d5e05f
Delete signature test
carmocca Apr 6, 2021
add clip by value
dhkim0225 committed Feb 22, 2021
commit c0e80642609e68e3408a52109b10dc231850126a
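Before the file-by-file diff of this commit, here is a minimal sketch of the user-facing API the PR adds; the model and dataloader names are hypothetical placeholders, not part of the change:

```python
# Sketch of the new Trainer argument introduced by this PR. `MyModel` and
# `train_loader` are hypothetical placeholders.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gradient_clip_val=0.5,            # existing argument: the clipping threshold
    gradient_clip_algorithm='value',  # new argument: 'norm' (default) or 'value'
)
# trainer.fit(MyModel(), train_loader)
```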
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [UnReleased] - 2021-MM-DD

### Added
Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6121](https://github.com/PyTorchLightning/pytorch-lightning/pull/6121)).
Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6123](https://github.com/PyTorchLightning/pytorch-lightning/pull/6123)).

### Changed

24 changes: 24 additions & 0 deletions benchmarks/test_sharded_parity.py
@@ -115,6 +115,28 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
    )


@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(
    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
def test_ddp_sharded_plugin_clip_gradients(tmpdir, args=None):
    plugin_parity_test(
        gpus=args.gpus,
        precision=args.precision,
        model_cls=SeedTrainLoaderModel,
        gradient_clip_val=0.001,
    )
    plugin_parity_test(
        gpus=args.gpus,
        precision=args.precision,
        model_cls=SeedTrainLoaderModel,
        gradient_clip_val=0.001,
        gradient_clip_algorithm='value',
    )


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@@ -245,6 +267,8 @@ def plugin_parity_test(
    gpus: int = 0,
    precision: int = 32,
    max_percent_speed_diff: float = 0.1,
    gradient_clip_val: float = 0,
    gradient_clip_algorithm: str = 'norm',
):
    """
    Ensures that the trained model is identical to the standard DDP implementation.
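Outside of the DDP-sharded harness, the property this new test is after can be sanity-checked with plain PyTorch: after clipping by value, no gradient element may exceed the threshold. A hedged, torch-only sketch:

```python
# Torch-only sketch of the invariant the sharded parity test exercises:
# after clip-by-value, every gradient element is within the threshold.
import torch

clip_val = 0.001
param = torch.nn.Parameter(torch.randn(4, 4))
loss = (param * 100).sum()  # deliberately produce large gradients
loss.backward()

torch.nn.utils.clip_grad_value_([param], clip_value=clip_val)
assert float(param.grad.abs().max()) <= clip_val
```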
8 changes: 6 additions & 2 deletions docs/source/advanced/training_tricks.rst
@@ -26,8 +26,10 @@ The effect is a large effective batch size of size KxN.

Gradient Clipping
-----------------
Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient
norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm
<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
If the ``gradient_clip_algorithm`` option (default ``'norm'``) is set to ``'value'``, this will
`clip the gradient value <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value_>`_ for each parameter instead.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -39,6 +41,8 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
    # clip gradients with norm above 0.5
    trainer = Trainer(gradient_clip_val=0.5)

    # clip gradients with value above 0.5
    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')
----------

Stochastic Weight Averaging
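For readers skimming the docs change, here is a plain-PyTorch sketch (not Lightning code) of what the two settings map to under the hood; the tensors and thresholds are arbitrary:

```python
# Plain-PyTorch sketch of the two algorithms described in the docs above.
# clip_grad_norm_ rescales all gradients together so their total norm is at
# most the threshold; clip_grad_value_ clamps each gradient element.
import torch

params = [torch.nn.Parameter(torch.randn(10)) for _ in range(3)]
for p in params:
    p.grad = torch.randn_like(p) * 10

# gradient_clip_algorithm='norm' (the Trainer default)
torch.nn.utils.clip_grad_norm_(params, max_norm=0.5)

# gradient_clip_algorithm='value'
torch.nn.utils.clip_grad_value_(params, clip_value=0.5)
```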
13 changes: 13 additions & 0 deletions docs/source/common/trainer.rst
@@ -735,6 +735,19 @@ Gradient clipping value
    # default used by the Trainer
    trainer = Trainer(gradient_clip_val=0.0)

gradient_clip_algorithm
^^^^^^^^^^^^^^^^^^^^^^^

Gradient clipping algorithm

- Clip gradients by norm or value.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(gradient_clip_algorithm='norm')


limit_train_batches
^^^^^^^^^^^^^^^^^^^

11 changes: 8 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -22,7 +22,7 @@
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum


class Accelerator(object):
@@ -287,10 +287,15 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt
        model_ref = self.lightning_module
        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: str = GradClipAlgorithmType.NORM,
    ) -> None:
        """clips all the optimizer parameters to the given value"""

        self.precision_plugin.clip_gradients(optimizer, clip_val)
        self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm)

    def on_train_epoch_end(self, outputs) -> None:
        """Hook to do something on the end of an training epoch
pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -5,6 +5,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

@@ -54,7 +55,13 @@ def backward(

        return closure_loss

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: str = GradClipAlgorithmType.NORM,
        norm_type: float = float(2.0),
    ):
        """
        DeepSpeed handles clipping gradients via the training type plugin.
        """
57 changes: 33 additions & 24 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -20,6 +20,7 @@

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.utilities import GradClipAlgorithmType


class PrecisionPlugin(Plugin):
@@ -86,7 +87,13 @@ def pre_optimizer_step(
    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
        """Hook to do something after each optimizer step."""

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None:
    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: str = GradClipAlgorithmType.NORM,
        norm_type: float = float(2.0),
    ) -> None:
        """Clips the gradients to a specific value"""
        # TODO: separate TPU case from here
        if clip_val is None:
@@ -98,26 +105,28 @@ def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm
            return

        parameters = list(self.master_params(optimizer))

        max_norm = grad_clip_val

        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = list(filter(lambda p: p.grad is not None, parameters))

        device = parameters[0].device

        if norm_type == math.inf:
            total_norm = max(p.grad.data.abs().max() for p in parameters)
        else:
            out = torch.empty(len(parameters), device=device)
            for i, p in enumerate(parameters):
                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
            total_norm = torch.norm(out, norm_type)

        eps = self.EPSILON

        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
        for p in parameters:
            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            torch.nn.utils.clip_grad_value_(parameters, clip_value=grad_clip_val)
        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
            max_norm = grad_clip_val

            if isinstance(parameters, torch.Tensor):
                parameters = [parameters]
            parameters = list(filter(lambda p: p.grad is not None, parameters))

            device = parameters[0].device

            if norm_type == math.inf:
                total_norm = max(p.grad.data.abs().max() for p in parameters)
            else:
                out = torch.empty(len(parameters), device=device)
                for i, p in enumerate(parameters):
                    torch.norm(p.grad.data.to(device), norm_type, out=out[i])
                total_norm = torch.norm(out, norm_type)

            eps = self.EPSILON

            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
            clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
            for p in parameters:
                p.grad.data.mul_(clip_coef.to(p.grad.data.device))
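The 'norm' branch above keeps the plugin's hand-rolled norm clipping. Below is a standalone sketch of that math (a hypothetical helper, not the Lightning plugin; `eps` stands in for the plugin's `self.EPSILON`), checked against `torch.nn.utils.clip_grad_norm_` on a toy set of parameters:

```python
# Standalone sketch of the 'norm' branch above. It should agree with
# torch.nn.utils.clip_grad_norm_ up to the small epsilon term.
import math
import torch


def clip_grad_norm_like_plugin(parameters, max_norm, norm_type=2.0, eps=1e-6):
    parameters = [p for p in parameters if p.grad is not None]
    device = parameters[0].device
    if norm_type == math.inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        per_param = torch.stack([torch.norm(p.grad.data.to(device), norm_type) for p in parameters])
        total_norm = torch.norm(per_param, norm_type)
    clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
    clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
    for p in parameters:
        p.grad.data.mul_(clip_coef.to(p.grad.data.device))


params_a = [torch.nn.Parameter(torch.randn(5)) for _ in range(2)]
params_b = [torch.nn.Parameter(p.detach().clone()) for p in params_a]
for pa, pb in zip(params_a, params_b):
    pa.grad = torch.randn(5) * 10
    pb.grad = pa.grad.clone()

clip_grad_norm_like_plugin(params_a, max_norm=1.0)
torch.nn.utils.clip_grad_norm_(params_b, max_norm=1.0)
assert all(torch.allclose(pa.grad, pb.grad, atol=1e-4) for pa, pb in zip(params_a, params_b))
```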
21 changes: 17 additions & 4 deletions pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -13,10 +13,11 @@
# limitations under the License.
from typing import cast, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE, GradClipAlgorithmType

if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
@@ -31,6 +32,18 @@ def __init__(self):
        super().__init__()
        self.scaler = ShardedGradScaler()

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
        optimizer = cast(OSS, optimizer)
        optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: str = GradClipAlgorithmType.NORM,
        norm_type: float = float(2.0),
    ):
        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            parameters = list(self.master_params(optimizer))
            if isinstance(parameters, torch.Tensor):
                parameters = [parameters]
            torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
            optimizer = cast(OSS, optimizer)
            optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
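A design note on the branch split above, as a sketch: value clipping is element-wise, so each rank can clamp its own gradient shard locally, whereas norm clipping needs the gradient norm across all shards, which is why the 'norm' path defers to fairscale's OSS.clip_grad_norm. A toy, fairscale-free illustration:

```python
# Toy illustration (no fairscale required): per-element clamping needs no
# cross-rank communication, so clip-by-value can run on the local shard alone.
import torch

local_shard_grads = [torch.randn(8) * 5, torch.randn(8) * 5]  # pretend: this rank's shard
for g in local_shard_grads:
    g.clamp_(-0.5, 0.5)  # equivalent to clip-by-value with clip_val=0.5
assert max(float(g.abs().max()) for g in local_shard_grads) <= 0.5
```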
pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -23,6 +24,7 @@ def __init__(self, trainer):
    def on_trainer_init(
        self,
        gradient_clip_val,
        gradient_clip_algorithm,
        track_grad_norm,
        accumulate_grad_batches,
        truncated_bptt_steps,
@@ -32,7 +34,11 @@ def on_trainer_init(
        self.trainer.terminate_on_nan = terminate_on_nan

        # gradient clipping
        if gradient_clip_algorithm not in [GradClipAlgorithmType.VALUE, GradClipAlgorithmType.NORM]:
            raise MisconfigurationException(f"gradient_clip_algorithm should be "
                                            f"'{GradClipAlgorithmType.VALUE}' or '{GradClipAlgorithmType.NORM}'")
        self.trainer.gradient_clip_val = gradient_clip_val
        self.trainer.gradient_clip_algorithm = gradient_clip_algorithm

        # gradient norm tracking
        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
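A hedged sketch of the guard added above: constructing a Trainer with an unrecognized algorithm name should fail fast with MisconfigurationException. The test name and the invalid string are made up for illustration:

```python
# Sketch of a test for the validation added above; 'norm2' is an arbitrary
# invalid value, and the test name is illustrative.
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_invalid_gradient_clip_algorithm(tmpdir):
    with pytest.raises(MisconfigurationException):
        Trainer(default_root_dir=tmpdir, gradient_clip_algorithm='norm2')
```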
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -90,6 +90,7 @@ def __init__(
        callbacks: Optional[Union[List[Callback], Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        gradient_clip_algorithm: str = 'norm',
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
@@ -201,6 +202,8 @@ def __init__(

            gradient_clip_val: 0 means don't clip.

            gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Defualt: 'norm'
Review comment (Contributor), suggested change:
-            gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Defualt: 'norm'
+            gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'

dhkim0225 (Contributor, Author) commented on Mar 2, 2021:
@ananthsub All modifications have been completed based on comments. Can you check the last changes?
Thanks.


            limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

            limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
@@ -355,7 +358,12 @@ def __init__(

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
            gradient_clip_val,
            gradient_clip_algorithm,
            track_grad_norm,
            accumulate_grad_batches,
            truncated_bptt_steps,
            terminate_on_nan,
        )

        # init train loop related flags
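As a minimal end-to-end check of the wiring above (assuming the attributes are stored on the trainer exactly as shown in this commit's connector change):

```python
# Minimal sketch: the Trainer argument flows through on_trainer_init and is
# stored on the trainer instance; the algorithm defaults to 'norm'.
from pytorch_lightning import Trainer

trainer = Trainer(gradient_clip_val=0.5)
assert trainer.gradient_clip_val == 0.5
assert trainer.gradient_clip_algorithm == 'norm'
```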
8 changes: 7 additions & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -22,7 +22,13 @@
    rank_zero_only,
    rank_zero_warn,
)
from pytorch_lightning.utilities.enums import AMPType, DeviceType, DistributedType, LightningEnum  # noqa: F401
from pytorch_lightning.utilities.enums import (  # noqa: F401
    AMPType,
    DeviceType,
    DistributedType,
    GradClipAlgorithmType,
    LightningEnum,
)
from pytorch_lightning.utilities.imports import (  # noqa: F401
    _APEX_AVAILABLE,
    _BOLTS_AVAILABLE,
9 changes: 9 additions & 0 deletions pytorch_lightning/utilities/enums.py
@@ -84,3 +84,12 @@ class DeviceType(LightningEnum):
    CPU = 'CPU'
    GPU = 'GPU'
    TPU = 'TPU'


class GradClipAlgorithmType(LightningEnum):
    """ Define gradient_clip_algorithm types - training-tricks.
    >>> GradClipAlgorithmType.VALUE in ('value', 'norm')
    True
"""
VALUE = 'value'
NORM = 'norm'
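The doctest above works because LightningEnum members are string-backed and compare equal to plain strings. A toy enum (not the real LightningEnum) showing the same behaviour:

```python
# Toy stand-in for LightningEnum to illustrate the doctest above; the real
# class lives in pytorch_lightning/utilities/enums.py.
from enum import Enum


class ToyClipAlgorithm(str, Enum):
    VALUE = 'value'
    NORM = 'norm'


assert ToyClipAlgorithm.VALUE == 'value'            # str-backed members equal plain strings
assert ToyClipAlgorithm.VALUE in ('value', 'norm')  # hence the doctest's membership check
```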
13 changes: 13 additions & 0 deletions tests/models/test_horovod.py
@@ -17,6 +17,7 @@
import shlex
import subprocess
import sys
from copy import deepcopy

import numpy as np
import pytest
@@ -66,6 +67,13 @@ def _run_horovod(trainer_options, on_gpu=False):
    assert exit_code == 0


def _run_horovod_clip_grad_by_value(trainer_options, on_gpu=False):
    # clip_grad_by_value test
    trainer_options_clip_grad_val = deepcopy(trainer_options)
    trainer_options_clip_grad_val.update({'gradient_clip_algorithm': 'value'})
    _run_horovod(trainer_options_clip_grad_val, on_gpu)


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_horovod_cpu(tmpdir):
    """Test Horovod running multi-process on CPU."""
@@ -81,6 +89,7 @@ def test_horovod_cpu(tmpdir):
        deterministic=True,
    )
    _run_horovod(trainer_options)
    _run_horovod_clip_grad_by_value(trainer_options)


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@@ -97,6 +106,7 @@ def test_horovod_cpu_implicit(tmpdir):
        deterministic=True,
    )
    _run_horovod(trainer_options)
    _run_horovod_clip_grad_by_value(trainer_options)


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@@ -117,6 +127,7 @@ def test_horovod_multi_gpu(tmpdir):
        accelerator='horovod',
    )
    _run_horovod(trainer_options, on_gpu=True)
    _run_horovod_clip_grad_by_value(trainer_options, on_gpu=True)


@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?")
@@ -141,6 +152,7 @@ def test_horovod_apex(tmpdir):
        precision=16,
    )
    _run_horovod(trainer_options, on_gpu=True)
    _run_horovod_clip_grad_by_value(trainer_options, on_gpu=True)


@pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp")
@@ -165,6 +177,7 @@ def test_horovod_amp(tmpdir):
        precision=16,
    )
    _run_horovod(trainer_options, on_gpu=True)
    _run_horovod_clip_grad_by_value(trainer_options, on_gpu=True)


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
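The Horovod tests above reuse existing trainer options and flip the algorithm via the deepcopy helper. An alternative, not part of the PR, would be to parametrize a dedicated test over both algorithms; the trainer options below are illustrative placeholders, and `_run_horovod` refers to the helper defined in tests/models/test_horovod.py above.

```python
# Alternative sketch (not in the PR): parametrize over both clipping algorithms
# instead of mutating a copy of the options. Trainer options are placeholders.
import pytest


@pytest.mark.parametrize("gradient_clip_algorithm", ["norm", "value"])
def test_horovod_cpu_clip_gradients(tmpdir, gradient_clip_algorithm):
    trainer_options = dict(
        default_root_dir=str(tmpdir),
        gradient_clip_val=1.0,
        gradient_clip_algorithm=gradient_clip_algorithm,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
    )
    _run_horovod(trainer_options)
```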