Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type Hints for Lightning Core #946

Merged
merged 38 commits into from
Mar 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
661405d
first pass for LightningModule typehints
Feb 23, 2020
9b5e784
fix return types
Feb 23, 2020
dc89b90
add missing types
Feb 25, 2020
d3460cb
add type annotations to grads.py
Feb 25, 2020
cdfd1d1
add type annotations to hooks.py
Feb 25, 2020
3526a77
add type annotation to memory.py
Feb 25, 2020
614256c
proper docstring quotation marks
Feb 25, 2020
e36f42a
add type annotations to saving.py
Feb 25, 2020
0c86afe
fix cyclic import problem
Feb 25, 2020
280a0d0
fix cyclic import problem
Feb 25, 2020
e433490
add missing whitespace
Feb 25, 2020
4d24fdf
finish type hints for load_from_ methods
Feb 26, 2020
c2e1cec
docs: prepare_data does not return anything
Feb 26, 2020
a878adb
fix auto types in docs
Feb 26, 2020
35eb67c
revert typehint for trainer in hook
Feb 26, 2020
0272309
remove unnecessary return docs
Feb 26, 2020
34f7c7d
some fixes for memory docs
Feb 26, 2020
2fa525f
revert typing for args kwargs
Feb 26, 2020
1187ef7
added all missing None return types
Feb 26, 2020
fe19508
remove unused import
Feb 26, 2020
02e50c3
add more details to dict/list return types
Feb 26, 2020
d11f747
fix line too long
Feb 27, 2020
bee44d0
optimize imports
Mar 3, 2020
09b8c16
Merge branch 'master' into module-typehints
Borda Mar 5, 2020
8555961
linted
Borda Mar 5, 2020
7482cb8
Revert "linted"
Borda Mar 5, 2020
cad32d8
remove whitespace
Mar 5, 2020
ced9435
Merge branch 'master' into module-typehints
Mar 6, 2020
b3f0ba5
update
Mar 7, 2020
aef5138
Merge branch 'master' into module-typehints
Mar 7, 2020
15516de
update
Mar 7, 2020
9c026ce
update
Mar 7, 2020
2915e6b
update
Mar 7, 2020
4bd0437
update
Mar 7, 2020
5c4b27f
Merge branch 'master' into module-typehints
Mar 10, 2020
4f68fc1
changelog
Mar 10, 2020
4a517c3
Merge branch 'master' into module-typehints
Mar 11, 2020
1909bad
Merge branch 'master' into module-typehints
williamFalcon Mar 12, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))

### Changed
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/core/grads.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""
Module to describe gradients
"""
from typing import Dict

from torch import nn


class GradInformation(nn.Module):

def grad_norm(self, norm_type):
def grad_norm(self, norm_type: float) -> Dict[str, int]:
results = {}
total_norm = 0
for name, p in self.named_parameters():
Expand Down
39 changes: 16 additions & 23 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
3. Add the correct place in the :py:mod:`pytorch_lightning.models.trainer` where it should be called.

"""

from typing import Any

import torch

from torch import Tensor
from torch.optim.optimizer import Optimizer

try:
from apex import amp
Expand All @@ -36,48 +37,45 @@ def on_sanity_check_start(self):
:return:
"""

def on_train_start(self):
def on_train_start(self) -> None:
"""Called at the beginning of training before sanity check
:return:
"""
# do something at the start of training

def on_train_end(self):
def on_train_end(self) -> None:
"""
Called at the end of training before logger experiment is closed
:return:
"""
# do something at the end of training

def on_batch_start(self, batch):
def on_batch_start(self, batch: Any) -> None:
"""Called in the training loop before anything happens for that batch.

:param batch:
:return:
"""
# do something when the batch starts

def on_batch_end(self):
def on_batch_end(self) -> None:
"""Called in the training loop after the batch."""
# do something when the batch ends

def on_epoch_start(self):
def on_epoch_start(self) -> None:
"""Called in the training loop at the very beginning of the epoch."""
# do something when the epoch starts

def on_epoch_end(self):
def on_epoch_end(self) -> None:
"""Called in the training loop at the very end of the epoch."""
# do something when the epoch ends

def on_pre_performance_check(self):
def on_pre_performance_check(self) -> None:
"""Called at the very beginning of the validation loop."""
# do something before validation starts

def on_post_performance_check(self):
def on_post_performance_check(self) -> None:
"""Called at the very end of the validation loop."""
# do something before validation end

def on_before_zero_grad(self, optimizer):
def on_before_zero_grad(self, optimizer: Optimizer) -> None:
"""Called after optimizer.step() and before optimizer.zero_grad()

Called in the training loop after taking an optimizer step and before zeroing grads.
Expand All @@ -89,17 +87,13 @@ def on_before_zero_grad(self, optimizer):
model.on_before_zero_grad(optimizer) # < ---- called here
optimizer.zero_grad

:param optimizer:
:return:
:param optimizer: The optimizer for which grads should be zeroed.
"""
# do something with the optimizer or inspect it.

def on_after_backward(self):
"""Called after loss.backward() and before optimizers do anything.

:return:
def on_after_backward(self) -> None:
"""Called in the training loop after loss.backward() and before optimizers do anything.

Called in the training loop after model.backward()
This is the ideal place to inspect or log gradient information

.. code-block:: python
Expand All @@ -116,14 +110,13 @@ def on_after_backward(self):

"""

def backward(self, trainer, loss, optimizer, optimizer_idx):
def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Override backward with your own implementation if you need to

:param trainer: Pointer to the trainer
:param loss: Loss is already scaled by accumulated grads
:param optimizer: Current optimizer being used
:param optimizer_idx: Index of the current optimizer being used
:return:

Called to perform backward step.
Feel free to override as needed.
Expand Down
Loading