Commit

ref: organize args 3/n (#3447)
* ref: organize args 2/n

* ref: organize args 2/n

* ref: organize args 2/n

* ref: organize args 2/n
williamFalcon authored Sep 10, 2020
1 parent deb82d9 commit 541c4ab
Showing 6 changed files with 157 additions and 98 deletions.
2 changes: 2 additions & 0 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -82,6 +82,8 @@ def on_trainer_init(
        # NVIDIA setup
        self.trainer.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids)

        self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

    def select_accelerator(self):
        # SLURM ddp
        use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
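
For reference, a minimal standalone sketch (not part of this diff) of the Colab/Kaggle detection that `on_trainer_init` now performs; it assumes only the two environment variables shown above:

import os

# Mirrors the added line: a set value for either variable marks the run as Colab or Kaggle.
on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
print(bool(on_colab_kaggle))
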
7 changes: 0 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
@@ -345,10 +345,3 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
            hvd.join()

        return dataloader

    def determine_data_use_amount(self, overfit_batches: float) -> None:
        """Use less data for debugging purposes"""
        if overfit_batches > 0:
            self.limit_train_batches = overfit_batches
            self.limit_val_batches = overfit_batches
            self.limit_test_batches = overfit_batches
109 changes: 109 additions & 0 deletions pytorch_lightning/trainer/debugging_connector.py
@@ -0,0 +1,109 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from typing import Union
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info


class DebuggingConnector:

    def __init__(self, trainer):
        self.trainer = trainer

    def on_init_start(
        self,
        overfit_pct,
        val_percent_check,
        test_percent_check,
        train_percent_check,
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        val_check_interval,
        overfit_batches,
        fast_dev_run
    ):
        self.trainer.fast_dev_run = fast_dev_run
        if self.trainer.fast_dev_run:
            limit_train_batches = 1
            limit_val_batches = 1
            limit_test_batches = 1
            self.trainer.num_sanity_val_steps = 0
            self.trainer.max_epochs = 1
            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
            )

        # how much of the data to use
        # TODO: remove in 0.10.0
        if overfit_pct is not None:
            rank_zero_warn(
                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            overfit_batches = overfit_pct

        # TODO: remove in 0.10.0
        if val_percent_check is not None:
            rank_zero_warn(
                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_val_batches = val_percent_check

        # TODO: remove in 0.10.0
        if test_percent_check is not None:
            rank_zero_warn(
                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_test_batches = test_percent_check

        # TODO: remove in 0.10.0
        if train_percent_check is not None:
            rank_zero_warn(
                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_train_batches = train_percent_check

        self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
        self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
        self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
        self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
        self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
        self.determine_data_use_amount(self.trainer.overfit_batches)

    def determine_data_use_amount(self, overfit_batches: float) -> None:
        """Use less data for debugging purposes"""
        if overfit_batches > 0:
            self.trainer.limit_train_batches = overfit_batches
            self.trainer.limit_val_batches = overfit_batches
            self.trainer.limit_test_batches = overfit_batches


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
    if 0 <= batches <= 1:
        return batches
    elif batches > 1 and batches % 1.0 == 0:
        return int(batches)
    else:
        raise MisconfigurationException(
            f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
        )
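
As a usage note (illustrative sketch only, not part of this commit): the arguments that `DebuggingConnector.on_init_start` routes through `_determine_batch_limits` keep their existing Trainer semantics, where a float in [0.0, 1.0] is a fraction of the dataloader and a whole number is an absolute batch count. Assuming an otherwise default Trainer:

from pytorch_lightning import Trainer

# 0.25 -> use 25% of the training batches; 10 -> use exactly 10 validation batches.
trainer = Trainer(limit_train_batches=0.25, limit_val_batches=10)
assert trainer.limit_train_batches == 0.25
assert trainer.limit_val_batches == 10

# overfit_batches > 0 overrides all three limits (see determine_data_use_amount above).
overfit_trainer = Trainer(overfit_batches=4)
assert overfit_trainer.limit_train_batches == 4
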
pytorch_lightning/trainer/initializer.py → pytorch_lightning/trainer/precision_connector.py
@@ -11,16 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType


class Initializer:
class PrecisionConnector:

    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, precision, amp_level, amp_backend):
        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.trainer.autocast_original_forward = None
        self.trainer.precision = precision
        self.trainer.scaler = None

        self.trainer.amp_level = amp_level
        self.init_amp(amp_backend)

    def init_amp(self, amp_type: str):
        assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
        self.trainer.amp_backend = None
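
For reference, a hedged usage sketch (not part of this commit) of the precision settings that `PrecisionConnector.on_trainer_init` now stores on the trainer; it assumes an environment where 16-bit training is actually available (CUDA with native AMP or apex):

from pytorch_lightning import Trainer

# precision must be 16 or 32 (enforced by the assert in init_amp above); amp_backend
# chooses between native torch.cuda.amp and NVIDIA apex.
trainer = Trainer(precision=16, amp_backend='native')
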
112 changes: 23 additions & 89 deletions pytorch_lightning/trainer/trainer.py
@@ -43,15 +43,16 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.callback_connector import CallbackConnector
from pytorch_lightning.trainer.model_connector import ModelConnector
from pytorch_lightning.trainer.debugging_connector import DebuggingConnector
from pytorch_lightning import _logger as log
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.trainer.initializer import Initializer
from pytorch_lightning.trainer.precision_connector import PrecisionConnector
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer import docstrings
from pytorch_lightning.trainer.properties import TrainerProperties
@@ -173,8 +174,10 @@ def __init__(
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.initializer = Initializer(self)
        self.precision_connector = PrecisionConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)

        self.tuner = Tuner(self)
        self.accelerator_backend = None

@@ -253,15 +256,8 @@ def __init__(
        # -------------------
        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.max_steps = max_steps
        self.min_steps = min_steps

        if num_sanity_val_steps == -1:
            self.num_sanity_val_steps = float('inf')
        else:
            self.num_sanity_val_steps = num_sanity_val_steps
        # init train loop related flags
        self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

@@ -275,17 +271,6 @@ def __init__(
        self.terminate_on_nan = terminate_on_nan
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            limit_train_batches = 1
            limit_val_batches = 1
            limit_test_batches = 1
            self.num_sanity_val_steps = 0
            self.max_epochs = 1
            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
            )

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
@@ -300,61 +285,22 @@ def __init__(
        self.log_save_interval = log_save_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        # TODO: remove in 0.10.0
        if overfit_pct is not None:
            rank_zero_warn(
                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            overfit_batches = overfit_pct

        # TODO: remove in 0.10.0
        if val_percent_check is not None:
            rank_zero_warn(
                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_val_batches = val_percent_check

        # TODO: remove in 0.10.0
        if test_percent_check is not None:
            rank_zero_warn(
                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_test_batches = test_percent_check

        # TODO: remove in 0.10.0
        if train_percent_check is not None:
            rank_zero_warn(
                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_train_batches = train_percent_check

        self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
        self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
        self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
        self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
        self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
        self.determine_data_use_amount(self.overfit_batches)

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        self.precision = precision
        self.scaler = None

        self.amp_level = amp_level
        self.initializer.init_amp(amp_backend)
        # init debugging flags
        self.debugging_connector.on_init_start(
            overfit_pct,
            val_percent_check,
            test_percent_check,
            train_percent_check,
            limit_train_batches,
            limit_val_batches,
            limit_test_batches,
            val_check_interval,
            overfit_batches,
            fast_dev_run
        )

        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
        # set precision
        self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)

        # Callback system
        self.on_init_end()
@@ -862,18 +808,6 @@ def call_hook(self, hook_name, *args, **kwargs):

        return output


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
    if 0 <= batches <= 1:
        return batches
    elif batches > 1 and batches % 1.0 == 0:
        return int(batches)
    else:
        raise MisconfigurationException(
            f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
        )


# add docstrings
Trainer.__init__.__doc__ = docstrings.trainer.init
Trainer.fit.__doc__ = docstrings.trainer.fit
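
A hedged sketch (not part of this commit) of how the deprecated percent-style arguments behave after this refactor: `Trainer.__init__` still accepts them, and `DebuggingConnector.on_init_start` remaps them onto the `limit_*_batches` attributes while emitting a `DeprecationWarning`:

import warnings

from pytorch_lightning import Trainer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    trainer = Trainer(val_percent_check=0.5)

# val_percent_check is remapped onto limit_val_batches (see on_init_start above).
assert trainer.limit_val_batches == 0.5
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
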
11 changes: 11 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -38,6 +38,17 @@ def __init__(self, trainer):
        self._teardown_already_run = False
        self.running_loss = TensorRunningAccum(window_length=20)

    def on_init_start(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps):
        self.trainer.max_epochs = max_epochs
        self.trainer.min_epochs = min_epochs
        self.trainer.max_steps = max_steps
        self.trainer.min_steps = min_steps

        if num_sanity_val_steps == -1:
            self.trainer.num_sanity_val_steps = float('inf')
        else:
            self.trainer.num_sanity_val_steps = num_sanity_val_steps

    @property
    def num_optimizers(self):
        num_optimizers = len(self.get_optimizers_iterable())
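
For reference, a hedged usage sketch (not part of this commit) of the sentinel handling that `TrainLoop.on_init_start` now owns: per the branch above, `num_sanity_val_steps=-1` resolves to an unbounded sanity check.

from pytorch_lightning import Trainer

# -1 means "run the sanity check over all validation batches"; other values are kept as-is.
trainer = Trainer(num_sanity_val_steps=-1)
assert trainer.num_sanity_val_steps == float('inf')
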
