Commit

Add typings for evaluation_loop.py and remove some dead code (#7015)
ethanwharris authored Apr 15, 2021
1 parent 5bd3cd5 commit f645df5
Showing 22 changed files with 114 additions and 161 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_pkg-install.yml
@@ -60,4 +60,4 @@ jobs:
 pip install dist/*.whl
 cd ..
 python -c "import pytorch_lightning as pl ; print(pl.__version__)"
-pip uninstall -y pytorch-lightning
+pip uninstall -y pytorch-lightning
2 changes: 1 addition & 1 deletion .github/workflows/events-recurrent.yml
@@ -39,4 +39,4 @@ jobs:
 echo $jobs_to_delete
 if [ ${#jobs_to_delete} -gt 1 ];
 then kubectl delete job $(kubectl get job | awk 'match($4,/[0-9]+[dh]/) {print $1}');
-fi
+fi
5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -56,7 +56,10 @@ def run_optimizer_step(
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

     def clip_gradients(
-        self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0,
+        self,
+        optimizer: Optimizer,
+        clip_val: Union[float, int],
+        norm_type: float = 2.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM
     ) -> None:
         assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \
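The assert above means the TPU path only accepts norm-based clipping and rejects any other GradClipAlgorithmType. A minimal sketch of that norm-only guard using plain torch utilities rather than the accelerator's XLA path; the function name and the `parameters` argument are illustrative assumptions, not the accelerator's implementation:

import torch


def clip_gradients_norm_only(parameters, clip_val, algorithm='norm', norm_type=2.0):
    # mirror the guard above: anything other than norm-based clipping is rejected
    assert algorithm == 'norm', 'only norm-based gradient clipping is supported here'
    if clip_val <= 0:
        return  # a non-positive clip value disables clipping
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=clip_val, norm_type=norm_type)


# usage (hypothetical): clip_gradients_norm_only(model.parameters(), clip_val=0.5)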
4 changes: 1 addition & 3 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -118,9 +118,7 @@ def on_train_epoch_start(self, trainer, pl_module) -> None:
         self._snap_inter_step_time = None

     @rank_zero_only
-    def on_train_batch_start(
-        self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def on_train_batch_start(self, trainer, pl_module, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         if self._log_stats.intra_step_time:
             self._snap_intra_step_time = time.time()
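For context, the hook above only snapshots a timestamp at batch start so the monitor can later report intra-step time. A stripped-down, hypothetical callback using the same hook signatures (not GPUStatsMonitor itself):

import time

from pytorch_lightning.callbacks import Callback


class IntraStepTimer(Callback):
    """Records how long each training batch takes, in the spirit of the hook above."""

    def __init__(self):
        self._snap_intra_step_time = None
        self.durations = []

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._snap_intra_step_time = time.time()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if self._snap_intra_step_time is not None:
            self.durations.append(time.time() - self._snap_intra_step_time)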
5 changes: 2 additions & 3 deletions pytorch_lightning/callbacks/pruning.py
@@ -422,9 +422,8 @@ def sanitize_parameters_to_prune(
         current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]

         if parameters_to_prune is None:
-            parameters_to_prune = [
-                (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
-            ]
+            parameters_to_prune = [(m, p) for p in parameters for m in current_modules
+                                   if getattr(m, p, None) is not None]
         elif (
             isinstance(parameters_to_prune, (list, tuple)) and len(parameters_to_prune) > 0
             and all(len(p) == 2 for p in parameters_to_prune)
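The comprehension above builds the (module, parameter_name) pairs that torch's pruning utilities expect. A small standalone sketch of the same pattern; the toy model, the container types, and the 30% amount are assumptions for illustration, not values from this commit:

from torch import nn
from torch.nn.utils import prune

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))

# container modules (here: Sequential/ModuleList) are skipped, as in the callback above
current_modules = [m for m in model.modules() if not isinstance(m, (nn.Sequential, nn.ModuleList))]
parameters = ('weight', 'bias')
parameters_to_prune = [(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None]

# prune 30% of the collected weights/biases globally by L1 magnitude
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.3)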
6 changes: 3 additions & 3 deletions pytorch_lightning/loggers/mlflow.py
@@ -29,17 +29,17 @@
 _MLFLOW_AVAILABLE = _module_available("mlflow")
 try:
     import mlflow
-    from mlflow.tracking import MlflowClient
-    from mlflow.tracking import context
+    from mlflow.tracking import context, MlflowClient
 # todo: there seems to be still some remaining import error with Conda env
 except ImportError:
     _MLFLOW_AVAILABLE = False
     mlflow, MlflowClient, context = None, None, None


 # before v1.1.0
 if hasattr(context, 'resolve_tags'):
     from mlflow.tracking.context import resolve_tags

 # since v1.1.0
 elif hasattr(context, 'registry'):
     from mlflow.tracking.context.registry import resolve_tags
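The block above keeps mlflow optional and picks up resolve_tags from wherever the installed mlflow version provides it. A hedged sketch of the same optional-import pattern, reduced to a no-op fallback so callers can always call resolve_tags(tags); this is not the logger's exact code:

try:
    from mlflow.tracking import context

    if hasattr(context, 'resolve_tags'):  # before mlflow 1.1.0
        from mlflow.tracking.context import resolve_tags
    elif hasattr(context, 'registry'):  # since mlflow 1.1.0
        from mlflow.tracking.context.registry import resolve_tags
    else:
        raise ImportError
except ImportError:

    def resolve_tags(tags=None):
        # fallback when mlflow is missing or exposes neither location
        return tags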
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -31,6 +31,6 @@ def __init__(self) -> None:
         super().__init__()
         self.scaler = ShardedGradScaler()

-    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0) -> None:
+    def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
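The only change here is quoting the annotation as 'Optimizer', which makes it a forward reference: the name then only has to resolve for a type checker, a common way to keep such an import out of the runtime path. A minimal sketch of that pattern; the surrounding function is illustrative, not the plugin's code:

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # imported for type checking only; not required at runtime
    from torch.optim import Optimizer


def clip_grad_by_norm(optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
    """The quoted 'Optimizer' is resolved lazily, so no runtime import is needed."""
    ...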
@@ -35,9 +35,7 @@ def on_trainer_init(

         # gradient clipping
         if gradient_clip_algorithm not in list(GradClipAlgorithmType):
-            raise MisconfigurationException(
-                f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
-            )
+            raise MisconfigurationException(f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}")
         self.trainer.gradient_clip_val = gradient_clip_val
         self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
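on_trainer_init validates the raw string against the GradClipAlgorithmType enum before converting it with GradClipAlgorithmType(...). A toy reproduction of that validate-then-convert pattern with a str-backed Enum; GradAlg and ValueError are stand-ins, not Lightning's own types:

from enum import Enum


class GradAlg(str, Enum):
    VALUE = 'value'
    NORM = 'norm'


def set_clip_algorithm(name: str) -> GradAlg:
    # reject unknown algorithms before converting the raw string to the enum
    if name not in list(GradAlg):
        raise ValueError(f"gradient_clip_algorithm should be in {list(GradAlg)}")
    return GradAlg(name)


print(set_clip_algorithm('norm'))  # GradAlg.NORM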
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -15,7 +15,7 @@
 import multiprocessing
 from abc import ABC
 from copy import deepcopy
-from typing import Iterable, List, Tuple, Union
+from typing import Iterable, List, Optional, Tuple, Union

 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -41,9 +41,9 @@ class TrainerDataLoadingMixin(ABC):
     train_dataloader: DataLoader
     num_training_batches: Union[int, float]
     val_check_batch: float
-    val_dataloaders: List[DataLoader]
+    val_dataloaders: Optional[List[DataLoader]]
     num_val_batches: List[Union[int, float]]
-    test_dataloaders: List[DataLoader]
+    test_dataloaders: Optional[List[DataLoader]]
     num_test_batches: List[Union[int, float]]
     limit_train_batches: Union[int, float]
     overfit_batches: Union[int, float]
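The Optional annotations above reflect that these attributes may be None (for example, when no val/test dataloaders have been attached). A small sketch of what Optional[List[DataLoader]] implies for callers; the count_val_batches helper is hypothetical:

from typing import List, Optional

from torch.utils.data import DataLoader


def count_val_batches(val_dataloaders: Optional[List[DataLoader]]) -> int:
    if val_dataloaders is None:  # nothing attached yet
        return 0
    return sum(len(dl) for dl in val_dataloaders)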