Update warnings in TrainingTricksConnector #9595

Merged: 6 commits, Sep 25, 2021
Changes from 2 commits
@@ -24,24 +24,36 @@ def __init__(self, trainer):
 
     def on_trainer_init(
         self,
-        gradient_clip_val: float,
+        gradient_clip_val: Union[int, float],
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
         accumulate_grad_batches: Union[int, Dict[int, int]],
         terminate_on_nan: bool,
     ):
 
-        self.trainer.terminate_on_nan = terminate_on_nan
+        if not isinstance(terminate_on_nan, bool):
+            raise MisconfigurationException(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
 
         # gradient clipping
-        if gradient_clip_algorithm not in list(GradClipAlgorithmType):
-            raise MisconfigurationException(f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}")
-        self.trainer.gradient_clip_val = gradient_clip_val
-        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
+        if not isinstance(gradient_clip_val, (int, float)):
+            raise MisconfigurationException(
+                f"Gradient Clipping Value can be an int or a float, got {gradient_clip_val}."
+            )
+
+        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
+            raise MisconfigurationException(
+                f"Gradient Clipping Algorithm {gradient_clip_algorithm} is invalid. "
+                f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
+            )
 
         # gradient norm tracking
         if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != "inf":
-            raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
+            raise MisconfigurationException(
+                f"track_grad_norm can be an int, a float or 'inf' (infinity norm), got {track_grad_norm}."
+            )
 
+        self.trainer.terminate_on_nan = terminate_on_nan
+        self.trainer.gradient_clip_val = gradient_clip_val
+        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
         self.trainer.track_grad_norm = float(track_grad_norm)
 
         # accumulated grads
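For context, here is a minimal sketch (not part of the diff, and assuming a Lightning build that includes this change) of how the tightened validation in on_trainer_init surfaces to users through the Trainer constructor:

import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# A non-numeric clipping value is now rejected up front with a descriptive message.
with pytest.raises(MisconfigurationException, match="Gradient Clipping Value can be an int or a float"):
    Trainer(gradient_clip_val="0.5")

# `terminate_on_nan` must be a real bool, so the string "False" is caught early.
with pytest.raises(MisconfigurationException, match="`terminate_on_nan` should be a bool"):
    Trainer(terminate_on_nan="False")

# Algorithm names are lowercased before validation, so "NORM" is accepted
# and stored as GradClipAlgorithmType.NORM.
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="NORM")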
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -126,7 +126,7 @@ def __init__(
         checkpoint_callback: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
-        gradient_clip_val: float = 0.0,
+        gradient_clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: str = "norm",
         process_position: int = 0,
         num_nodes: int = 1,
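The widened annotation documents what the connector now enforces; a small illustration (not from the diff) under the same assumption as above:

from pytorch_lightning import Trainer

# An integer clip value now matches the Union[int, float] annotation
# and the connector's isinstance check.
trainer = Trainer(gradient_clip_val=1, gradient_clip_algorithm="value")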
8 changes: 8 additions & 0 deletions pytorch_lightning/utilities/enums.py
@@ -142,6 +142,14 @@ class GradClipAlgorithmType(LightningEnum):
     VALUE = "value"
     NORM = "norm"
 
+    @staticmethod
+    def supported_type(val: str) -> bool:
+        return any(x.value == val for x in GradClipAlgorithmType)
+
+    @staticmethod
+    def supported_types() -> List[str]:
+        return [x.value for x in GradClipAlgorithmType]
+
 
 class AutoRestartBatchKeys(LightningEnum):
     """Defines special dictionary keys used to track captured dataset state with multiple workers."""
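A quick illustration (not part of the diff) of the two new helpers; note that supported_type compares against the lowercase enum values, which is why the connector lowercases gradient_clip_algorithm before calling it:

from pytorch_lightning.utilities.enums import GradClipAlgorithmType

assert GradClipAlgorithmType.supported_types() == ["value", "norm"]
assert GradClipAlgorithmType.supported_type("norm")       # exact lowercase match
assert not GradClipAlgorithmType.supported_type("NORM")   # no normalization inside the helper
assert not GradClipAlgorithmType.supported_type("norm2")  # unknown algorithm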
15 changes: 15 additions & 0 deletions tests/trainer/test_trainer.py
@@ -866,6 +866,11 @@ def training_step(self, batch, batch_idx):
         assert torch.isfinite(param).all()
 
 
+def test_invalid_terminate_on_nan(tmpdir):
+    with pytest.raises(MisconfigurationException, match="`terminate_on_nan` should be a bool"):
+        Trainer(default_root_dir=tmpdir, terminate_on_nan="False")
+
+
 @mock.patch("torch.Tensor.backward")
 def test_nan_params_detection(backward_mock, tmpdir):
     class CurrentModel(BoringModel):
@@ -1012,6 +1017,16 @@ def backward(*args, **kwargs):
     trainer.fit(model)
 
 
+def test_invalid_gradient_clip_value(tmpdir):
+    with pytest.raises(MisconfigurationException, match="Gradient Clipping Value can be an int or a float"):
+        Trainer(default_root_dir=tmpdir, gradient_clip_val=(1, 2))
+
+
+def test_invalid_gradient_clip_algo(tmpdir):
+    with pytest.raises(MisconfigurationException, match="Gradient Clipping Algorithm norm2 is invalid"):
+        Trainer(default_root_dir=tmpdir, gradient_clip_algorithm="norm2")
+
+
 def test_gpu_choice(tmpdir):
     trainer_options = dict(default_root_dir=tmpdir)
     # Only run if CUDA is available
9 changes: 8 additions & 1 deletion tests/utilities/test_enums.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import pytest
 
-from pytorch_lightning.utilities.enums import DeviceType, ModelSummaryMode, PrecisionType
+from pytorch_lightning.utilities.enums import DeviceType, GradClipAlgorithmType, ModelSummaryMode, PrecisionType
 
 
 def test_consistency():
@@ -42,3 +42,10 @@ def test_model_summary_mode():
 
     with pytest.raises(ValueError, match=f"`mode` can be {', '.join(list(ModelSummaryMode))}, got invalid."):
         ModelSummaryMode.get_max_depth("invalid")
+
+
+def test_gradient_clip_algorithms():
+    assert GradClipAlgorithmType.supported_types() == ["value", "norm"]
+    assert GradClipAlgorithmType.supported_type("norm")
+    assert GradClipAlgorithmType.supported_type("value")
+    assert not GradClipAlgorithmType.supported_type("norm2")