Update warnings in TrainingTricksConnector (#9595)
* update warnings

* add tests

* comments

* Apply suggestions from code review

* Apply suggestions from code review
rohitgr7 authored Sep 25, 2021
1 parent ddf6967 commit a4bc0ac
Showing 5 changed files with 55 additions and 10 deletions.
26 changes: 18 additions & 8 deletions pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -23,21 +23,31 @@ def __init__(self, trainer):
 
     def on_trainer_init(
         self,
-        gradient_clip_val: float,
+        gradient_clip_val: Union[int, float],
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
         terminate_on_nan: bool,
     ):
-
-        self.trainer.terminate_on_nan = terminate_on_nan
+        if not isinstance(terminate_on_nan, bool):
+            raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
 
         # gradient clipping
-        if gradient_clip_algorithm not in list(GradClipAlgorithmType):
-            raise MisconfigurationException(f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}")
-        self.trainer.gradient_clip_val = gradient_clip_val
-        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
+        if not isinstance(gradient_clip_val, (int, float)):
+            raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
+
+        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
+            raise MisconfigurationException(
+                f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
+                f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
+            )
 
         # gradient norm tracking
         if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != "inf":
-            raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
+            raise MisconfigurationException(
+                f"`track_grad_norm` should be an int, a float or 'inf' (infinity norm). Got {track_grad_norm}."
+            )
+
+        self.trainer.terminate_on_nan = terminate_on_nan
+        self.trainer.gradient_clip_val = gradient_clip_val
+        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
         self.trainer.track_grad_norm = float(track_grad_norm)
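
For context, a minimal sketch (not part of the commit) of the user-facing effect of these checks, assuming pytorch_lightning at this commit is installed: invalid values now fail fast when the Trainer is constructed, and the algorithm name is lower-cased before validation.

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    Trainer(gradient_clip_val=(1, 2))  # not an int/float -> TypeError
except TypeError as err:
    print(err)

try:
    Trainer(gradient_clip_algorithm="norm2")  # unsupported algorithm name -> MisconfigurationException
except MisconfigurationException as err:
    print(err)

# Accepted: int or float clip value; algorithm name in any case ("NORM" is lower-cased internally).
Trainer(gradient_clip_val=1, gradient_clip_algorithm="NORM")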
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -124,7 +124,7 @@ def __init__(
         checkpoint_callback: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
-        gradient_clip_val: float = 0.0,
+        gradient_clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: str = "norm",
         process_position: int = 0,
         num_nodes: int = 1,
8 changes: 8 additions & 0 deletions pytorch_lightning/utilities/enums.py
@@ -142,6 +142,14 @@ class GradClipAlgorithmType(LightningEnum):
     VALUE = "value"
     NORM = "norm"
 
+    @staticmethod
+    def supported_type(val: str) -> bool:
+        return any(x.value == val for x in GradClipAlgorithmType)
+
+    @staticmethod
+    def supported_types() -> List[str]:
+        return [x.value for x in GradClipAlgorithmType]
+
 
 class AutoRestartBatchKeys(LightningEnum):
     """Defines special dictionary keys used to track captured dataset state with multiple workers."""
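
For reference (not part of the diff), a small sketch of how the new helpers behave, again assuming this commit is installed. The comparison is a plain string match and therefore case-sensitive, which is why the connector lower-cases the user-supplied value before calling supported_type.

from pytorch_lightning.utilities.enums import GradClipAlgorithmType

print(GradClipAlgorithmType.supported_types())        # ['value', 'norm']
print(GradClipAlgorithmType.supported_type("norm"))   # True
print(GradClipAlgorithmType.supported_type("NORM"))   # False: callers must lower-case first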
20 changes: 20 additions & 0 deletions tests/trainer/test_trainer.py
@@ -924,6 +924,16 @@ def training_step(self, batch, batch_idx):
         assert torch.isfinite(param).all()
 
 
+def test_invalid_terminate_on_nan(tmpdir):
+    with pytest.raises(TypeError, match="`terminate_on_nan` should be a bool"):
+        Trainer(default_root_dir=tmpdir, terminate_on_nan="False")
+
+
+def test_invalid_track_grad_norm(tmpdir):
+    with pytest.raises(MisconfigurationException, match="`track_grad_norm` should be an int, a float"):
+        Trainer(default_root_dir=tmpdir, track_grad_norm="nan")
+
+
 @mock.patch("torch.Tensor.backward")
 def test_nan_params_detection(backward_mock, tmpdir):
     class CurrentModel(BoringModel):
@@ -1070,6 +1080,16 @@ def backward(*args, **kwargs):
     trainer.fit(model)
 
 
+def test_invalid_gradient_clip_value(tmpdir):
+    with pytest.raises(TypeError, match="`gradient_clip_val` should be an int or a float"):
+        Trainer(default_root_dir=tmpdir, gradient_clip_val=(1, 2))
+
+
+def test_invalid_gradient_clip_algo(tmpdir):
+    with pytest.raises(MisconfigurationException, match="`gradient_clip_algorithm` norm2 is invalid"):
+        Trainer(default_root_dir=tmpdir, gradient_clip_algorithm="norm2")
+
+
 def test_gpu_choice(tmpdir):
     trainer_options = dict(default_root_dir=tmpdir)
     # Only run if CUDA is available
9 changes: 8 additions & 1 deletion tests/utilities/test_enums.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import pytest
 
-from pytorch_lightning.utilities.enums import DeviceType, ModelSummaryMode, PrecisionType
+from pytorch_lightning.utilities.enums import DeviceType, GradClipAlgorithmType, ModelSummaryMode, PrecisionType
 
 
 def test_consistency():
@@ -42,3 +42,10 @@ def test_model_summary_mode():
 
     with pytest.raises(ValueError, match=f"`mode` can be {', '.join(list(ModelSummaryMode))}, got invalid."):
         ModelSummaryMode.get_max_depth("invalid")
+
+
+def test_gradient_clip_algorithms():
+    assert GradClipAlgorithmType.supported_types() == ["value", "norm"]
+    assert GradClipAlgorithmType.supported_type("norm")
+    assert GradClipAlgorithmType.supported_type("value")
+    assert not GradClipAlgorithmType.supported_type("norm2")