From cc22ddc716820fe4fcc297aefbee644acd0285e1 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Thu, 16 Feb 2023 15:25:56 +0100
Subject: [PATCH] Remove duplicate no_grad context managers (#16773)

---
 src/lightning/pytorch/loops/epoch/training_epoch_loop.py | 5 +----
 src/lightning/pytorch/trainer/trainer.py                 | 3 +--
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py
index cf7f38707ca46..1f52128f06954 100644
--- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py
@@ -15,8 +15,6 @@
 from collections import OrderedDict
 from typing import Any, Dict, Optional, Union
 
-import torch
-
 import lightning.pytorch as pl
 from lightning.pytorch import loops  # import as loops to avoid circular imports
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
@@ -284,8 +282,7 @@ def _run_validation(self) -> None:
         # reload dataloaders
         self.val_loop._reload_evaluation_dataloaders()
 
-        with torch.no_grad():
-            self.val_loop.run()
+        self.val_loop.run()
 
     def _accumulated_batches_reached(self) -> bool:
         """Determine if accumulation will be finished by the end of the current batch."""
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 2926982ecc94c..35beb7c57d662 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -956,8 +956,7 @@ def _run_sanity_check(self) -> None:
         ]
 
         # run eval step
-        with torch.no_grad():
-            val_loop.run()
+        val_loop.run()
 
         call._call_callback_hooks(self, "on_sanity_check_end")
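For context on why this is safe: per the patch title, the outer torch.no_grad() was a duplicate, i.e. the validation loop's run() already executes with gradient tracking disabled, so the wrapper removed here was a no-op. Below is a minimal standalone sketch (the evaluate() helper is hypothetical, not code from this patch) showing that torch.no_grad() contexts nest harmlessly and the outer one adds nothing:

import torch

x = torch.ones(3, requires_grad=True)

def evaluate(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for val_loop.run(): the evaluation code itself
    # already runs under a gradient-disabled context.
    with torch.no_grad():
        return t * 2

# Pre-patch shape: a redundant outer no_grad around the call.
with torch.no_grad():
    y_outer = evaluate(x)

# Post-patch shape: the bare call.
y_plain = evaluate(x)

# Gradient tracking is off either way; the outer context changed nothing.
assert y_outer.requires_grad is False and y_plain.requires_grad is False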