Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Black format pytorch_lightning/core/hooks.py #3575

Merged
merged 1 commit into from
Sep 21, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Union, List
from typing import Any, List, Union

import torch
from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, AMPType, rank_zero_warn

try:
from apex import amp
Expand All @@ -28,7 +28,6 @@


class ModelHooks:

def setup(self, stage: str):
"""
Called at the beginning of fit and test.
Expand Down Expand Up @@ -113,7 +112,9 @@ def on_pretrain_routine_end(self) -> None:
"""
# do something at the end of the pretrain routine

def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_train_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the training loop before anything happens for that batch.

Expand All @@ -126,7 +127,9 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
"""
# do something when the batch starts

def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_train_batch_end(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the training loop after the batch.

Expand All @@ -137,7 +140,9 @@ def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) ->
"""
# do something when the batch ends

def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_validation_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the validation loop before anything happens for that batch.

Expand All @@ -148,7 +153,9 @@ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
"""
# do something when the batch starts

def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_validation_batch_end(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the validation loop after the batch.

Expand All @@ -159,7 +166,9 @@ def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: in
"""
# do something when the batch ends

def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_test_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the test loop before anything happens for that batch.

Expand All @@ -170,7 +179,9 @@ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -
"""
# do something when the batch starts

def on_test_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_test_batch_end(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""
Called in the test loop after the batch.

Expand Down Expand Up @@ -288,7 +299,9 @@ def on_after_backward(self):

"""

def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
def backward(
self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
) -> None:
"""
Override backward with your own implementation if you need to.

Expand All @@ -311,7 +324,13 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
"""
loss.backward()

def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
def amp_scale_loss(
self,
unscaled_loss: Tensor,
optimizer: Optimizer,
optimizer_idx: int,
amp_backend: AMPType,
):
if amp_backend == AMPType.NATIVE:
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
else:
Expand All @@ -321,7 +340,6 @@ def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: A


class DataHooks:

def prepare_data(self) -> None:
"""
Use this to download and prepare data.
Expand Down Expand Up @@ -412,7 +430,9 @@ def train_dataloader(self):
return loader

"""
rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
rank_zero_warn(
"`train_dataloader` must be implemented to be used with the Lightning Trainer"
)

def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
r"""
Expand Down