diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py
index cf7f38707ca46..1f52128f06954 100644
--- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py
@@ -15,8 +15,6 @@
 from collections import OrderedDict
 from typing import Any, Dict, Optional, Union
 
-import torch
-
 import lightning.pytorch as pl
 from lightning.pytorch import loops  # import as loops to avoid circular imports
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
@@ -284,8 +282,7 @@ def _run_validation(self) -> None:
         # reload dataloaders
         self.val_loop._reload_evaluation_dataloaders()
 
-        with torch.no_grad():
-            self.val_loop.run()
+        self.val_loop.run()
 
     def _accumulated_batches_reached(self) -> bool:
         """Determine if accumulation will be finished by the end of the current batch."""
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 2926982ecc94c..35beb7c57d662 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -956,8 +956,7 @@ def _run_sanity_check(self) -> None:
         ]
 
         # run eval step
-        with torch.no_grad():
-            val_loop.run()
+        val_loop.run()
 
         call._call_callback_hooks(self, "on_sanity_check_end")