Black format pytorch_lightning/core/hooks.py (#3575)
Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter.
ananthsub authored Sep 21, 2020
1 parent cf1b946 commit 3442b97
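For reviewers unfamiliar with Black: the changes below are mechanical and can be reproduced by running the formatter over the file (typically `black pytorch_lightning/core/hooks.py` from the CLI). The snippet below is a minimal sketch of doing the same through Black's Python API; it assumes Black's default settings (line length 88), whereas the project's own configuration may differ.

# Minimal sketch: reformat hooks.py via Black's Python API.
# Assumption: default Black settings; the project's configured
# line length / target Python versions may differ.
import black

path = "pytorch_lightning/core/hooks.py"
with open(path) as f:
    source = f.read()

# format_str returns the source unchanged if it is already Black-formatted
formatted = black.format_str(source, mode=black.FileMode())

if formatted != source:
    with open(path, "w") as f:
        f.write(formatted)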
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions pytorch_lightning/core/hooks.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Union, List
+from typing import Any, List, Union
 
 import torch
+from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
-from pytorch_lightning.utilities import move_data_to_device, AMPType, rank_zero_warn
 
 try:
     from apex import amp
@@ -28,7 +28,6 @@
 
 
 class ModelHooks:
-
     def setup(self, stage: str):
         """
         Called at the beginning of fit and test.
@@ -113,7 +112,9 @@ def on_pretrain_routine_end(self) -> None:
         """
         # do something at the end of the pretrain routine
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the training loop before anything happens for that batch.
@@ -126,7 +127,9 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
         """
         # do something when the batch starts
 
-    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the training loop after the batch.
@@ -137,7 +140,9 @@ def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) ->
         """
         # do something when the batch ends
 
-    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the validation loop before anything happens for that batch.
@@ -148,7 +153,9 @@ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         """
         # do something when the batch starts
 
-    def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the validation loop after the batch.
@@ -159,7 +166,9 @@ def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: in
         """
         # do something when the batch ends
 
-    def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_test_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the test loop before anything happens for that batch.
@@ -170,7 +179,9 @@ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -
         """
         # do something when the batch starts
 
-    def on_test_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_test_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the test loop after the batch.
@@ -288,7 +299,9 @@ def on_after_backward(self):
         """
 
-    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
+    def backward(
+        self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
+    ) -> None:
         """
         Override backward with your own implementation if you need to.
@@ -311,7 +324,13 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
         """
         loss.backward()
 
-    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
+    def amp_scale_loss(
+        self,
+        unscaled_loss: Tensor,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        amp_backend: AMPType,
+    ):
         if amp_backend == AMPType.NATIVE:
             scaled_loss = self.trainer.scaler.scale(unscaled_loss)
         else:
@@ -321,7 +340,6 @@ def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: A
 
 
 class DataHooks:
-
     def prepare_data(self) -> None:
         """
         Use this to download and prepare data.
@@ -412,7 +430,9 @@ def train_dataloader(self):
                     return loader
         """
-        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn(
+            "`train_dataloader` must be implemented to be used with the Lightning Trainer"
+        )
 
     def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
         r"""
