diff --git a/CHANGELOG.md b/CHANGELOG.md
index 21df3834bd368..4e17728f9424d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
 - Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
 
 ### Changed
diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py
index d98a6efeb85ca..b5d2d5616a60f 100644
--- a/pytorch_lightning/core/grads.py
+++ b/pytorch_lightning/core/grads.py
@@ -1,13 +1,14 @@
 """
 Module to describe gradients
 """
+from typing import Dict
 
 from torch import nn
 
 
 class GradInformation(nn.Module):
 
-    def grad_norm(self, norm_type):
+    def grad_norm(self, norm_type: float) -> Dict[str, int]:
         results = {}
         total_norm = 0
         for name, p in self.named_parameters():
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 500ba247f1414..d903de21d887a 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -14,10 +14,11 @@
 3. Add the correct place in the :py:mod:`pytorch_lightning.models.trainer`
    where it should be called.
 
 """
-
+from typing import Any
 
 import torch
-
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
 
 try:
     from apex import amp
@@ -36,48 +37,45 @@ def on_sanity_check_start(self):
         :return:
         """
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         """Called at the beginning of training before sanity check
-        :return:
         """
         # do something at the start of training
 
-    def on_train_end(self):
+    def on_train_end(self) -> None:
         """
         Called at the end of training before logger experiment is closed
-        :return:
         """
         # do something at the end of training
 
-    def on_batch_start(self, batch):
+    def on_batch_start(self, batch: Any) -> None:
         """Called in the training loop before anything happens for that batch.
 
         :param batch:
-        :return:
         """
         # do something when the batch starts
 
-    def on_batch_end(self):
+    def on_batch_end(self) -> None:
         """Called in the training loop after the batch."""
         # do something when the batch ends
 
-    def on_epoch_start(self):
+    def on_epoch_start(self) -> None:
         """Called in the training loop at the very beginning of the epoch."""
         # do something when the epoch starts
 
-    def on_epoch_end(self):
+    def on_epoch_end(self) -> None:
        """Called in the training loop at the very end of the epoch."""
         # do something when the epoch ends
 
-    def on_pre_performance_check(self):
+    def on_pre_performance_check(self) -> None:
         """Called at the very beginning of the validation loop."""
         # do something before validation starts
 
-    def on_post_performance_check(self):
+    def on_post_performance_check(self) -> None:
         """Called at the very end of the validation loop."""
         # do something before validation end
 
-    def on_before_zero_grad(self, optimizer):
+    def on_before_zero_grad(self, optimizer: Optimizer) -> None:
         """Called after optimizer.step() and before optimizer.zero_grad()
 
         Called in the training loop after taking an optimizer step and before zeroing grads.
@@ -89,17 +87,13 @@ def on_before_zero_grad(self, optimizer):
             model.on_before_zero_grad(optimizer) # < ---- called here
             optimizer.zero_grad
 
-        :param optimizer:
-        :return:
+        :param optimizer: The optimizer for which grads should be zeroed.
         """
         # do something with the optimizer or inspect it.
 
-    def on_after_backward(self):
-        """Called after loss.backward() and before optimizers do anything.
-
-        :return:
+    def on_after_backward(self) -> None:
+        """Called in the training loop after loss.backward() and before optimizers do anything.
 
-        Called in the training loop after model.backward()
         This is the ideal place to inspect or log gradient information
 
         .. code-block:: python
 
             def on_after_backward(self):
@@ -116,14 +110,13 @@ def on_after_backward(self):
         """
 
-    def backward(self, trainer, loss, optimizer, optimizer_idx):
+    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
         """Override backward with your own implementation if you need to
 
         :param trainer: Pointer to the trainer
         :param loss: Loss is already scaled by accumulated grads
         :param optimizer: Current optimizer being used
         :param optimizer_idx: Index of the current optimizer being used
-        :return:
 
         Called to perform backward step.
         Feel free to override as needed.
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 68cfee6f88828..b531c1542ee6d 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -5,11 +5,15 @@
 import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Adam
+from torch.optim.optimizer import Optimizer
+from torch.utils.data import DataLoader
 
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import ModelHooks
@@ -69,7 +73,7 @@ def __init__(self, *args, **kwargs):
 
         self.hparams = None
 
-    def print(self, *args, **kwargs):
+    def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0.
         Use this in any distributed mode to log only once
@@ -136,7 +140,9 @@ def forward(self, batch):
 
         """
 
-    def training_step(self, *args, **kwargs):
+    def training_step(self, *args, **kwargs) -> Union[
+        int, Dict[str, Union[Tensor, Dict[str, Tensor]]]
+    ]:
         r"""return loss, dict with metrics for tqdm
 
         Args:
@@ -222,7 +228,9 @@ def training_end(self, *args, **kwargs):
         Deprecated in v0.7.0. use training_step_end instead
         """
 
-    def training_step_end(self, *args, **kwargs):
+    def training_step_end(self, *args, **kwargs) -> Dict[
+        str, Union[Tensor, Dict[str, Tensor]]
+    ]:
         """
         Use this when training with dp or ddp2 because training_step will operate
         on only part of the batch. However, this is still optional
@@ -283,7 +291,7 @@ def training_step_end(self, outputs):
 
         .. seealso::
             see the `multi-gpu guide for more details `_.
         """
 
-    def validation_step(self, *args, **kwargs):
+    def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]:
         r"""
         Operate on a single batch of data from the validation set
         In this step you'd might generate examples or calculate anything of interest like accuracy.
@@ -371,7 +379,7 @@ def validation_step(self, batch, batch_idx, dataset_idx):
 
         the model goes back to training mode and gradients are enabled.
         """
 
-    def validation_step_end(self, *args, **kwargs):
+    def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
         """
         Use this when validating with dp or ddp2 because validation_step will operate
         on only part of the batch. However, this is still optional
@@ -435,7 +443,10 @@ def validation_end(self, outputs):
         Deprecated in v0.7.0. use validation_epoch_end instead.
         Will be removed 1.0.0
         """
 
-    def validation_epoch_end(self, outputs: list):
+    def validation_epoch_end(
+        self,
+        outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
+    ) -> Dict[str, Dict[str, Tensor]]:
         """
         Called at end of validation epoch with the output of all validation_steps
@@ -509,7 +520,7 @@ def validation_epoch_end(self, outputs):
 
                 return results
         """
 
-    def test_step(self, *args, **kwargs):
+    def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
         r"""
         Operate on a single batch of data from the test set
         In this step you'd normally generate examples or calculate anything of interest
@@ -590,7 +601,7 @@ def test_step(self, batch, batch_idx, dataset_idx):
 
         to training mode and gradients are enabled.
         """
 
-    def test_step_end(self, *args, **kwargs):
+    def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
         """
         Use this when testing with dp or ddp2 because test_step will operate
         on only part of the batch. However, this is still optional
@@ -654,7 +665,10 @@ def test_end(self, outputs):
 
         Deprecated in v0.7.0. use test_epoch_end instead. Will be removed 1.0.0
         """
 
-    def test_epoch_end(self, outputs):
+    def test_epoch_end(
+        self,
+        outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
+    ) -> Dict[str, Dict[str, Tensor]]:
         """
         Called at end of test epoch with the output of all test_steps.
@@ -669,7 +683,7 @@ def test_epoch_end(self, outputs):
 
             test_epoch_end(test_outs)
 
         Args:
-            outputs (list): List of outputs you defined in test_step, or if there are multiple
+            outputs: List of outputs you defined in test_step, or if there are multiple
                 dataloaders, a list containing a list of outputs for each dataloader
 
         Return:
@@ -728,7 +742,11 @@ def test_epoch_end(self, outputs):
 
             return results
         """
 
-    def configure_ddp(self, model, device_ids):
+    def configure_ddp(
+        self,
+        model: 'LightningModule',
+        device_ids: List[int]
+    ) -> DistributedDataParallel:
         r"""
         Override to init DDP in your own way or with your own wrapper.
         The only requirements are that:
@@ -738,8 +756,8 @@ def configure_ddp(self, model, device_ids):
         3. On a testing batch, the call goes to model.test_step
 
         Args:
-            model (:class:`.LightningModule`): the LightningModule currently being optimized
-            device_ids (list): the list of GPU ids
+            model: the LightningModule currently being optimized
+            device_ids: the list of GPU ids
 
         Return:
             DDP wrapped model
@@ -765,7 +783,7 @@ def configure_ddp(self, model, device_ids):
         )
         return model
 
-    def init_ddp_connection(self, proc_rank, world_size):
+    def init_ddp_connection(self, proc_rank: int, world_size: int) -> None:
         r"""
         Override to define your custom way of setting up a distributed environment.
@@ -773,8 +791,8 @@
         Lightning's implementation uses env:// init by default and sets the first node as root.
 
         Args:
-            proc_rank (int): The current process rank within the node.
-            world_size (int): Number of GPUs being use across all nodes. (num_nodes*nb_gpu_nodes).
+            proc_rank: The current process rank within the node.
+            world_size: Number of GPUs being used across all nodes. (num_nodes*nb_gpu_nodes).
 
         Examples:
             .. code-block:: python
@@ -843,16 +861,22 @@ def init_ddp_connection(self):
         os.environ['MASTER_ADDR'] = root_node
         dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
 
-    def configure_apex(self, amp, model, optimizers, amp_level):
+    def configure_apex(
+        self,
+        amp: object,
+        model: 'LightningModule',
+        optimizers: List[Optimizer],
+        amp_level: str
+    ) -> Tuple['LightningModule', List[Optimizer]]:
         r"""
         Override to init AMP your own way
         Must return a model and list of optimizers
 
         Args:
-            amp (object): pointer to amp library object
-            model (:class:`.LightningModule`): pointer to current lightningModule
-            optimizers (list): list of optimizers passed in configure_optimizers()
-            amp_level (str): AMP mode chosen ('O1', 'O2', etc...)
+            amp: pointer to amp library object
+            model: pointer to current lightningModule
+            optimizers: list of optimizers passed in configure_optimizers()
+            amp_level: AMP mode chosen ('O1', 'O2', etc...)
 
         Return:
             Apex wrapped model and optimizers
@@ -874,7 +898,9 @@ def configure_apex(self, amp, model, optimizers, amp_level):
 
         return model, optimizers
 
-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Union[
+        Optimizer, List[Optimizer], Tuple[Optimizer, ...], Tuple[List[Optimizer], List]
+    ]:
         r"""
         Choose what optimizers and learning-rate schedulers to use in your optimization.
         Normally you'd need one. But in the case of GANs or similar you might have multiple.
@@ -942,7 +968,14 @@ def configure_optimizers(self):
 
         """
         return Adam(self.parameters(), lr=1e-3)
 
-    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
+    def optimizer_step(
+        self,
+        epoch: int,
+        batch_idx: int,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        second_order_closure: Optional[Callable] = None,
+    ) -> None:
         r"""
         Override this method to adjust the default way the Trainer calls each optimizer.
@@ -950,11 +983,11 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_orde
         once per optimizer.
 
         Args:
-            epoch (int): Current epoch
-            batch_idx (int): Index of current batch
-            optimizer (torch.nn.Optimizer): A PyTorch optimizer
-            optimizer_idx (int): If you used multiple optimizers this indexes into that list
-            second_order_closure (int): closure for second order methods
+            epoch: Current epoch
+            batch_idx: Index of current batch
+            optimizer: A PyTorch optimizer
+            optimizer_idx: If you used multiple optimizers this indexes into that list
+            second_order_closure: closure for second order methods
 
         Examples:
             .. code-block:: python
@@ -1013,7 +1046,7 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
 
                 # clear gradients
                 optimizer.zero_grad()
 
-    def tbptt_split_batch(self, batch, split_size):
+    def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
         r"""
         When using truncated backpropagation through time, each batch must be split along the
@@ -1021,8 +1054,8 @@ def tbptt_split_batch(self, batch, split_size):
         this function.
 
         Args:
-            batch (torch.nn.Tensor): Current batch
-            split_size (int): How big the split is
+            batch: Current batch
+            split_size: How big the split is
 
         Return:
             list of batch splits. Each split will be passed to forward_step to enable truncated
@@ -1075,7 +1108,7 @@ def tbptt_split_batch(self, batch, split_size):
 
         return splits
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """Use this to download and prepare data.
         In distributed (GPU, TPU), this will only be called once
@@ -1099,9 +1132,8 @@ def prepare_data(self):
             clean_imagenet()
             cache_imagenet()
         """
-        return None
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """Implement a PyTorch DataLoader
 
         Return:
@@ -1136,7 +1168,6 @@ def train_dataloader(self):
             return loader
 
         """
-        return None
 
     def tng_dataloader(self):  # todo: remove in v1.0.0
         """Implement a PyTorch DataLoader.
@@ -1149,7 +1180,7 @@ def tng_dataloader(self):  # todo: remove in v1.0.0
                       " and this method will be removed in v1.0.0", DeprecationWarning)
         return output
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
         r"""
         Return a dataloader. It will not be called every epoch unless you set
@@ -1168,7 +1199,7 @@ def test_dataloader(self):
         No need to set yourself.
 
         Return:
-            PyTorch DataLoader
+            Single or multiple PyTorch DataLoader
 
         Example:
             .. code-block:: python
@@ -1190,9 +1221,8 @@ def test_dataloader(self):
             this method.
         """
-        return None
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
         r"""
         Return a dataloader. It will not be called every epoch unless you set
@@ -1210,7 +1240,7 @@ def val_dataloader(self):
         No need to set yourself.
 
         Return:
-            PyTorch DataLoader
+            Single or multiple PyTorch DataLoader
 
         Examples:
             .. code-block:: python
@@ -1257,7 +1287,6 @@ def val_dataloader(self):
 
         .. note:: In the case where you return multiple `val_dataloaders`, the `validation_step`
             will have an argument `dataset_idx` which matches the order here.
         """
-        return None
 
     @classmethod
     def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
@@ -1361,7 +1390,7 @@ def __init__(self, hparams):
 
         return model
 
     @classmethod
-    def _load_model_state(cls, checkpoint):
+    def _load_model_state(cls, checkpoint: Dict[str, Any]) -> 'LightningModule':
         cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters
         ckpt_hparams = checkpoint.get('hparams')
@@ -1395,11 +1424,11 @@ def _load_model_state(cls, checkpoint):
 
         return model
 
-    def summarize(self, mode):
+    def summarize(self, mode: str) -> None:
         model_summary = ModelSummary(self, mode=mode)
         log.info('\n' + model_summary.__str__())
 
-    def freeze(self):
+    def freeze(self) -> None:
         r"""
         Freeze all params for inference
@@ -1415,7 +1444,7 @@ def freeze(self):
 
         self.eval()
 
-    def unfreeze(self):
+    def unfreeze(self) -> None:
         """Unfreeze all params for training.
 
         .. code-block:: python
@@ -1429,13 +1458,13 @@ def unfreeze(self):
 
         self.train()
 
-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         r"""
         Called by lightning to restore your model.
         If you saved something with **on_save_checkpoint** this is your chance to restore this.
 
         Args:
-            checkpoint (dict): Loaded checkpoint
+            checkpoint: Loaded checkpoint
 
         Example:
@@ -1449,14 +1478,14 @@ def on_load_checkpoint(self, checkpoint):
 
         No need for you to restore anything regarding training.
         """
 
-    def on_save_checkpoint(self, checkpoint):
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         r"""
         Called by lightning when saving a checkpoint to give you a chance to store anything else you
         might want to save
 
         Args:
-            checkpoint (dic): Checkpoint to be saved
+            checkpoint: Checkpoint to be saved
 
         Example:
             .. code-block:: python
@@ -1471,7 +1500,7 @@ def on_save_checkpoint(self, checkpoint):
 
         """
 
-    def get_tqdm_dict(self):
+    def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
         r"""
         Additional items to be displayed in the progress bar.
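
The hints added to `LightningModule` above are easiest to read from the user's side. Below is a minimal sketch of a subclass that satisfies the annotated hook signatures; the `LitModel` name, layer sizes, and the random dataset wiring are illustrative assumptions, not part of this patch.

.. code-block:: python

    from typing import Dict, Union

    import torch
    from torch import Tensor
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        """Illustrative subclass matching the hinted hook signatures."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x: Tensor) -> Tensor:
            return self.layer(x)

        def training_step(self, batch, batch_idx) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
            # matches the new training_step return hint: a dict of Tensors plus a nested log dict
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            return {'loss': loss, 'log': {'train_loss': loss}}

        def train_dataloader(self) -> DataLoader:
            # matches the new train_dataloader hint: a single DataLoader
            x = torch.randn(64, 28 * 28)
            y = torch.randint(0, 10, (64,))
            return DataLoader(TensorDataset(x, y), batch_size=32)

A static checker such as mypy can then be pointed at such a subclass to flag return-type mismatches against the hints introduced here.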
diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py
index 967710d11d594..926786268d90f 100644
--- a/pytorch_lightning/core/memory.py
+++ b/pytorch_lightning/core/memory.py
@@ -1,23 +1,25 @@
-'''
+"""
 Generates a summary of a model's layers and dimensionality
-'''
+"""
 import gc
 import logging as log
 import os
 import subprocess
 from subprocess import PIPE
+from typing import Tuple, Dict, Union, List
 
 import numpy as np
 import torch
+from torch.nn import Module
+
+import pytorch_lightning as pl
 
 
 class ModelSummary(object):
 
-    def __init__(self, model, mode='full'):
-        '''
-        Generates summaries of model layers and dimensions.
-        '''
+    def __init__(self, model: 'pl.LightningModule', mode: str = 'full'):
+        """ Generates summaries of model layers and dimensions. """
         self.model = model
         self.mode = mode
         self.in_sizes = []
@@ -31,7 +33,7 @@ def __str__(self):
     def __repr__(self):
         return self.summary.__str__()
 
-    def named_modules(self):
+    def named_modules(self) -> List[Tuple[str, Module]]:
         if self.mode == 'full':
             mods = self.model.named_modules()
             mods = list(mods)[1:]  # do not include root module (LightningModule)
@@ -42,8 +44,8 @@ def named_modules(self):
             mods = []
         return list(mods)
 
-    def get_variable_sizes(self):
-        '''Run sample input through each layer to get output sizes'''
+    def get_variable_sizes(self) -> None:
+        """ Run sample input through each layer to get output sizes """
         mods = self.named_modules()
         in_sizes = []
         out_sizes = []
@@ -98,8 +100,8 @@ def get_variable_sizes(self):
         self.out_sizes = out_sizes
         assert len(in_sizes) == len(out_sizes)
 
-    def get_layer_names(self):
-        '''Collect Layer Names'''
+    def get_layer_names(self) -> None:
+        """ Collect Layer Names """
         mods = self.named_modules()
         names = []
         layers = []
@@ -112,8 +114,8 @@ def get_layer_names(self):
         self.layer_names = names
         self.layer_types = layer_types
 
-    def get_parameter_sizes(self):
-        '''Get sizes of all parameters in `model`'''
+    def get_parameter_sizes(self) -> None:
+        """ Get sizes of all parameters in `model` """
         mods = self.named_modules()
         sizes = []
         for _, m in mods:
@@ -123,8 +125,8 @@ def get_parameter_sizes(self):
 
         self.param_sizes = sizes
 
-    def get_parameter_nums(self):
-        '''Get number of parameters in each layer'''
+    def get_parameter_nums(self) -> None:
+        """ Get number of parameters in each layer """
         param_nums = []
         for mod in self.param_sizes:
             all_params = 0
@@ -133,12 +135,12 @@ def get_parameter_nums(self):
             param_nums.append(all_params)
         self.param_nums = param_nums
 
-    def make_summary(self):
-        '''
+    def make_summary(self) -> None:
+        """
         Makes a summary listing with:
 
         Layer Name, Layer Type, Input Size, Output Size, Number of Parameters
-        '''
+        """
         arrays = [['Name', self.layer_names],
                   ['Type', self.layer_types],
                   ['Params', list(map(get_human_readable_count, self.param_nums))]]
@@ -147,9 +149,8 @@ def make_summary(self):
             arrays.append(['Out sizes', self.out_sizes])
 
         self.summary = _format_summary_table(*arrays)
-        return
 
-    def summarize(self):
+    def summarize(self) -> None:
         self.get_layer_names()
         self.get_parameter_sizes()
         self.get_parameter_nums()
@@ -159,12 +160,12 @@ def summarize(self):
             self.make_summary()
 
 
-def _format_summary_table(*cols):
-    '''
+def _format_summary_table(*cols) -> str:
+    """
     Takes in a number of arrays, each specifying a column in
     the summary table, and combines them all into one big string
     defining the summary table that are nicely formatted.
-    '''
+    """
     n_rows = len(cols[0][1])
     n_cols = 1 + len(cols)
@@ -204,7 +205,7 @@ def _format_summary_table(*cols):
     return summary
 
 
-def print_mem_stack():  # pragma: no cover
+def print_mem_stack() -> None:  # pragma: no cover
     for obj in gc.get_objects():
         try:
             if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
@@ -213,7 +214,7 @@ def print_mem_stack():  # pragma: no cover
             pass
 
 
-def count_mem_items():  # pragma: no cover
+def count_mem_items() -> Tuple[int, int]:  # pragma: no cover
     num_params = 0
     num_tensors = 0
     for obj in gc.get_objects():
@@ -230,11 +231,12 @@ def count_mem_items():  # pragma: no cover
     return num_params, num_tensors
 
 
-def get_memory_profile(mode):
-    """
-    'all' means return memory for all gpus
-    'min_max' means return memory for max and min
-    :param mode:
+def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]:
+    """ Get a profile of the current memory usage.
+
+    :param mode: There are two modes:
+        - 'all' means return memory for all gpus
+        - 'min_max' means return memory for max and min
     :return:
     """
     memory_map = get_gpu_memory_map()
@@ -248,14 +250,12 @@ def get_memory_profile(mode):
 
     return memory_map
 
 
-def get_gpu_memory_map():
+def get_gpu_memory_map() -> Dict[str, int]:
     """Get the current gpu usage.
 
-    Returns
-    -------
-    usage: dict
-        Keys are device ids as integers.
-        Values are memory usage as integers in MB.
+    Return:
+        A dictionary in which the keys are device ids as integers and
+        values are memory usage as integers in MB.
     """
     result = subprocess.run(
         [
@@ -273,7 +273,7 @@ def get_gpu_memory_map():
 
     return gpu_memory_map
 
 
-def get_human_readable_count(number):
+def get_human_readable_count(number: int) -> str:
     """
     Abbreviates an integer number with K, M, B, T for thousands, millions,
     billions and trillions, respectively.
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 9668f7c4e75d7..5695d0189f1f8 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -2,11 +2,12 @@
 import logging as log
 import os
 from argparse import Namespace
+from typing import Union, Dict, Any
 
 
 class ModelIO(object):
 
-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """
         Do something with the checkpoint
         Gives model a chance to load something before state_dict is restored
@@ -14,7 +15,7 @@ def on_load_checkpoint(self, checkpoint):
         :return:
         """
 
-    def on_save_checkpoint(self, checkpoint):
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """
         Give the model a chance to add something to the checkpoint.
         state_dict is already there
@@ -23,20 +24,18 @@ def on_save_checkpoint(self, checkpoint):
     # -------------------------
     # OPTIONAL HOOKS
     # -------------------------
-    def on_hpc_save(self, checkpoint):
+    def on_hpc_save(self, checkpoint: Dict[str, Any]) -> None:
         """
         Hook to do whatever you need right before Slurm manager saves the model
-        :return:
         """
 
-    def on_hpc_load(self, checkpoint):
+    def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
         """
         Hook to do whatever you need right before Slurm manager loads the model
-        :return:
         """
 
 
-def load_hparams_from_tags_csv(tags_csv) -> Namespace:
+def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
     if not os.path.isfile(tags_csv):
         log.warning(f'Missing Tags: {tags_csv}.')
         return Namespace()
@@ -48,7 +47,7 @@ def load_hparams_from_tags_csv(tags_csv) -> Namespace:
     return ns
 
 
-def convert(val):
+def convert(val: str) -> Union[int, float, bool, str]:
     constructors = [int, float, str]
 
     if isinstance(val, str):
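
The checkpoint hooks typed as `Dict[str, Any]` in `saving.py` (and mirrored on `LightningModule`) are meant to be overridden roughly as follows. This is only a sketch in the spirit of the docstring examples earlier in the patch; the key name and the stored object are illustrative, not defined by this change.

.. code-block:: python

    from typing import Any, Dict

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # stash extra objects next to the state_dict before the checkpoint is written
            checkpoint['something_i_want_to_save'] = self.hparams

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # read them back when the checkpoint is restored
            self.restored_hparams = checkpoint.get('something_i_want_to_save')

Because the parameter is a plain mutable dict, anything picklable can be stored; the hooks return `None` and mutate `checkpoint` in place, which is what the `Dict[str, Any] -> None` signatures express.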