From 412f7cca1c643778be276056a525b77e126bbe81 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 10 Sep 2020 07:38:40 -0400 Subject: [PATCH 1/4] ref: organize args 2/n --- .../accelerators/accelerator_connector.py | 24 ++++- pytorch_lightning/trainer/callback_config.py | 99 ------------------- .../trainer/callback_connector.py | 94 ++++++++++++++++++ pytorch_lightning/trainer/trainer.py | 52 +++------- 4 files changed, 131 insertions(+), 138 deletions(-) delete mode 100644 pytorch_lightning/trainer/callback_config.py create mode 100644 pytorch_lightning/trainer/callback_connector.py diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index b23aefeb01924..206951e6085da 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -10,7 +10,29 @@ class AcceleratorConnector: def __init__(self, trainer): self.trainer = trainer - def on_trainer_init(self, num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus): + def on_trainer_init( + self, + num_processes, + tpu_cores, + distributed_backend, + auto_select_gpus, + gpus, + num_nodes, + log_gpu_memory, + sync_batchnorm, + benchmark + ): + # benchmarking + self.trainer.benchmark = benchmark + torch.backends.cudnn.benchmark = self.trainer.benchmark + + # Transfer params + self.trainer.num_nodes = num_nodes + self.trainer.log_gpu_memory = log_gpu_memory + + # sync-bn backend + self.trainer.sync_batchnorm = sync_batchnorm + self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores) self.trainer.on_tpu = self.trainer.tpu_cores is not None diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py deleted file mode 100644 index 5e64f2e886bfc..0000000000000 --- a/pytorch_lightning/trainer/callback_config.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from abc import ABC, abstractmethod -from typing import List, Optional - -from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_utils import is_overridden -from pytorch_lightning.core.lightning import LightningModule - - -class TrainerCallbackConfigMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - callbacks: List[Callback] - default_root_dir: str - logger: LightningLoggerBase - weights_save_path: Optional[str] - ckpt_path: str - checkpoint_callback: Optional[ModelCheckpoint] - - @property - @abstractmethod - def slurm_job_id(self) -> int: - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def save_checkpoint(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def get_model(self) -> LightningModule: - """Warning: this is just empty shell for code implemented in other class.""" - - def configure_checkpoint_callback(self, checkpoint_callback): - if checkpoint_callback is True: - # when no val step is defined, use 'loss' otherwise 'val_loss' - train_step_only = not is_overridden('validation_step', self.get_model()) - monitor_key = 'loss' if train_step_only else 'val_loss' - checkpoint_callback = ModelCheckpoint( - filepath=None, - monitor=monitor_key - ) - elif checkpoint_callback is False: - checkpoint_callback = None - - if checkpoint_callback: - checkpoint_callback.save_function = self.save_checkpoint - - return checkpoint_callback - - def configure_early_stopping(self, early_stop_callback): - if early_stop_callback is True or None: - early_stop_callback = EarlyStopping( - monitor='val_loss', - patience=3, - strict=True, - verbose=True, - mode='min' - ) - elif not early_stop_callback: - early_stop_callback = None - else: - early_stop_callback = early_stop_callback - return early_stop_callback - - def configure_progress_bar(self, refresh_rate=1, process_position=0): - progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)] - if len(progress_bars) > 1: - raise MisconfigurationException( - 'You added multiple progress bar callbacks to the Trainer, but currently only one' - ' progress bar is supported.' 
- ) - elif len(progress_bars) == 1: - progress_bar_callback = progress_bars[0] - elif refresh_rate > 0: - progress_bar_callback = ProgressBar( - refresh_rate=refresh_rate, - process_position=process_position, - ) - self.callbacks.append(progress_bar_callback) - else: - progress_bar_callback = None - - return progress_bar_callback diff --git a/pytorch_lightning/trainer/callback_connector.py b/pytorch_lightning/trainer/callback_connector.py new file mode 100644 index 0000000000000..7071a73678446 --- /dev/null +++ b/pytorch_lightning/trainer/callback_connector.py @@ -0,0 +1,94 @@ +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_utils import is_overridden + + +class CallbackConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def on_trainer_init( + self, + callbacks, + early_stop_callback, + checkpoint_callback, + progress_bar_refresh_rate, + process_position, + ): + # init callbacks + self.trainer.callbacks = callbacks or [] + + # configure early stop callback + # creates a default one if none passed in + early_stop_callback = self.trainer.configure_early_stopping(early_stop_callback) + if early_stop_callback: + self.trainer.callbacks.append(early_stop_callback) + + # configure checkpoint callback + # it is important that this is the last callback to run + # pass through the required args to figure out defaults + checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback) + if checkpoint_callback: + self.trainer.callbacks.append(checkpoint_callback) + + # TODO refactor codebase (tests) to not directly reach into these callbacks + self.trainer.checkpoint_callback = checkpoint_callback + self.trainer.early_stop_callback = early_stop_callback + + # init progress bar + self.trainer._progress_bar_callback = self.configure_progress_bar( + progress_bar_refresh_rate, process_position + ) + + def configure_checkpoint_callback(self, checkpoint_callback): + if checkpoint_callback is True: + # when no val step is defined, use 'loss' otherwise 'val_loss' + train_step_only = not is_overridden('validation_step', self.trainer.get_model()) + monitor_key = 'loss' if train_step_only else 'val_loss' + checkpoint_callback = ModelCheckpoint( + filepath=None, + monitor=monitor_key + ) + elif checkpoint_callback is False: + checkpoint_callback = None + + if checkpoint_callback: + checkpoint_callback.save_function = self.trainer.save_checkpoint + + return checkpoint_callback + + def configure_early_stopping(self, early_stop_callback): + if early_stop_callback is True or None: + early_stop_callback = EarlyStopping( + monitor='val_loss', + patience=3, + strict=True, + verbose=True, + mode='min' + ) + elif not early_stop_callback: + early_stop_callback = None + else: + early_stop_callback = early_stop_callback + return early_stop_callback + + def configure_progress_bar(self, refresh_rate=1, process_position=0): + progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] + if len(progress_bars) > 1: + raise MisconfigurationException( + 'You added multiple progress bar callbacks to the Trainer, but currently only one' + ' progress bar is supported.' 
+ ) + elif len(progress_bars) == 1: + progress_bar_callback = progress_bars[0] + elif refresh_rate > 0: + progress_bar_callback = ProgressBar( + refresh_rate=refresh_rate, + process_position=process_position, + ) + self.trainer.callbacks.append(progress_bar_callback) + else: + progress_bar_callback = None + + return progress_bar_callback diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6ccdaa050a36e..ef24759ce8403 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -27,13 +27,11 @@ from pytorch_lightning.core.step_result import EvalResult from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler -from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import ConfigValidator from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin -from pytorch_lightning.utilities import device_parser from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin @@ -49,6 +47,7 @@ from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.logger_connector import LoggerConnector from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector +from pytorch_lightning.trainer.callback_connector import CallbackConnector from pytorch_lightning.trainer.model_connector import ModelConnector from pytorch_lightning import _logger as log from pytorch_lightning.tuner.tuning import Tuner @@ -94,7 +93,6 @@ class Trainer( TrainerLoggingMixin, TrainerTrainingTricksMixin, TrainerDataLoadingMixin, - TrainerCallbackConfigMixin, TrainerDeprecatedAPITillVer0_10, ): def __init__( @@ -176,6 +174,7 @@ def __init__( self.logger_connector = LoggerConnector(self) self.model_connector = ModelConnector(self) self.initializer = Initializer(self) + self.callback_connector = CallbackConnector(self) self.tuner = Tuner(self) self.accelerator_backend = None @@ -218,42 +217,17 @@ def __init__( self._default_root_dir = default_root_dir or os.getcwd() self._weights_save_path = weights_save_path or self._default_root_dir - # ------------------------------- - # CALLBACK INITS - # ------------------------------- # init callbacks - self.callbacks = callbacks or [] - - # configure early stop callback - # creates a default one if none passed in - early_stop_callback = self.configure_early_stopping(early_stop_callback) - if early_stop_callback: - self.callbacks.append(early_stop_callback) - - # configure checkpoint callback - # it is important that this is the last callback to run - # pass through the required args to figure out defaults - checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback) - if checkpoint_callback: - self.callbacks.append(checkpoint_callback) - - # TODO refactor codebase (tests) to not directly reach into these callbacks - self.checkpoint_callback = checkpoint_callback - self.early_stop_callback = early_stop_callback + self.callback_connector.on_trainer_init( + callbacks, + 
early_stop_callback, + checkpoint_callback, + progress_bar_refresh_rate, + process_position + ) self.on_init_start() - # benchmarking - self.benchmark = benchmark - torch.backends.cudnn.benchmark = self.benchmark - - # Transfer params - self.num_nodes = num_nodes - self.log_gpu_memory = log_gpu_memory - - # sync-bn backend - self.sync_batchnorm = sync_batchnorm - self.gradient_clip_val = gradient_clip_val self.check_val_every_n_epoch = check_val_every_n_epoch @@ -267,7 +241,11 @@ def __init__( tpu_cores, distributed_backend, auto_select_gpus, - gpus + gpus, + num_nodes, + log_gpu_memory, + sync_batchnorm, + benchmark ) # ------------------- @@ -317,8 +295,6 @@ def __init__( self.accumulate_grad_batches = accumulate_grad_batches self.configure_accumulated_gradients(accumulate_grad_batches) - self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position) - # logging self.configure_logger(logger) self.log_save_interval = log_save_interval From 874dd2743ff42b3f1dfbc56c01c2effc638edeb3 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 10 Sep 2020 07:49:54 -0400 Subject: [PATCH 2/4] ref: organize args 2/n --- pytorch_lightning/trainer/callback_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_connector.py b/pytorch_lightning/trainer/callback_connector.py index 7071a73678446..b064068e7d5f0 100644 --- a/pytorch_lightning/trainer/callback_connector.py +++ b/pytorch_lightning/trainer/callback_connector.py @@ -21,7 +21,7 @@ def on_trainer_init( # configure early stop callback # creates a default one if none passed in - early_stop_callback = self.trainer.configure_early_stopping(early_stop_callback) + early_stop_callback = self.configure_early_stopping(early_stop_callback) if early_stop_callback: self.trainer.callbacks.append(early_stop_callback) From c7f060cc21462e252e63881e88b71ad66b94cab2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 10 Sep 2020 08:29:43 -0400 Subject: [PATCH 3/4] ref: organize args 2/n --- .../accelerators/accelerator_connector.py | 2 + pytorch_lightning/trainer/data_loading.py | 7 -- .../trainer/debugging_connector.py | 109 +++++++++++++++++ ...{initializer.py => precision_connector.py} | 14 ++- pytorch_lightning/trainer/trainer.py | 111 ++++-------------- pytorch_lightning/trainer/training_loop.py | 11 ++ 6 files changed, 156 insertions(+), 98 deletions(-) create mode 100644 pytorch_lightning/trainer/debugging_connector.py rename pytorch_lightning/trainer/{initializer.py => precision_connector.py} (84%) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 206951e6085da..af4a6be4c2431 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -82,6 +82,8 @@ def on_trainer_init( # NVIDIA setup self.trainer.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids) + self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') + def select_accelerator(self): # SLURM ddp use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 898fbc92cbd21..f7c53c1cbe8fe 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -345,10 +345,3 @@ def request_dataloader(self, dataloader_fx: 
Callable) -> DataLoader: hvd.join() return dataloader - - def determine_data_use_amount(self, overfit_batches: float) -> None: - """Use less data for debugging purposes""" - if overfit_batches > 0: - self.limit_train_batches = overfit_batches - self.limit_val_batches = overfit_batches - self.limit_test_batches = overfit_batches diff --git a/pytorch_lightning/trainer/debugging_connector.py b/pytorch_lightning/trainer/debugging_connector.py new file mode 100644 index 0000000000000..309c6ae9a98b2 --- /dev/null +++ b/pytorch_lightning/trainer/debugging_connector.py @@ -0,0 +1,109 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from typing import Union +from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info + + +class DebuggingConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def on_init_start( + self, + overfit_pct, + val_percent_check, + test_percent_check, + train_percent_check, + limit_train_batches, + limit_val_batches, + limit_test_batches, + val_check_interval, + overfit_batches, + fast_dev_run + ): + self.trainer.fast_dev_run = fast_dev_run + if self.trainer.fast_dev_run: + limit_train_batches = 1 + limit_val_batches = 1 + limit_test_batches = 1 + self.trainer.num_sanity_val_steps = 0 + self.trainer.max_epochs = 1 + rank_zero_info( + 'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch' + ) + + # how much of the data to use + # TODO: remove in 0.10.0 + if overfit_pct is not None: + rank_zero_warn( + "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) + overfit_batches = overfit_pct + + # TODO: remove in 0.10.0 + if val_percent_check is not None: + rank_zero_warn( + "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) + limit_val_batches = val_percent_check + + # TODO: remove in 0.10.0 + if test_percent_check is not None: + rank_zero_warn( + "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) + limit_test_batches = test_percent_check + + # TODO: remove in 0.10.0 + if train_percent_check is not None: + rank_zero_warn( + "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) + limit_train_batches = train_percent_check + + self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') + self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches') + self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches') + self.trainer.val_check_interval = 
_determine_batch_limits(val_check_interval, 'val_check_interval') + self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches') + self.trainer.determine_data_use_amount(self.trainer.overfit_batches) + + def determine_data_use_amount(self, overfit_batches: float) -> None: + """Use less data for debugging purposes""" + if overfit_batches > 0: + self.trainer.limit_train_batches = overfit_batches + self.trainer.limit_val_batches = overfit_batches + self.trainer.limit_test_batches = overfit_batches + + +def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: + if 0 <= batches <= 1: + return batches + elif batches > 1 and batches % 1.0 == 0: + return int(batches) + else: + raise MisconfigurationException( + f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.' + ) diff --git a/pytorch_lightning/trainer/initializer.py b/pytorch_lightning/trainer/precision_connector.py similarity index 84% rename from pytorch_lightning/trainer/initializer.py rename to pytorch_lightning/trainer/precision_connector.py index b2a39056e152f..55fb945caf09e 100644 --- a/pytorch_lightning/trainer/initializer.py +++ b/pytorch_lightning/trainer/precision_connector.py @@ -11,16 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from pytorch_lightning import _logger as log from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType -class Initializer: +class PrecisionConnector: def __init__(self, trainer): self.trainer = trainer + def on_trainer_init(self, precision, amp_level, amp_backend): + # AMP init + # These are the only lines needed after v0.8.0 + # we wrap the user's forward with autocast and give it back at the end of fit + self.trainer.autocast_original_forward = None + self.trainer.precision = precision + self.trainer.scaler = None + + self.trainer.amp_level = amp_level + self.init_amp(amp_backend) + def init_amp(self, amp_type: str): assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported' self.trainer.amp_backend = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ef24759ce8403..55adee55e9499 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -43,15 +43,16 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.training_loop import TrainLoop -from pytorch_lightning.trainer.data_connector import DataConnector from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.logger_connector import LoggerConnector from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector from pytorch_lightning.trainer.callback_connector import CallbackConnector from pytorch_lightning.trainer.model_connector import ModelConnector +from pytorch_lightning.trainer.debugging_connector import DebuggingConnector from pytorch_lightning import _logger as log from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.trainer.initializer import Initializer +from pytorch_lightning.trainer.precision_connector import PrecisionConnector +from pytorch_lightning.trainer.data_connector import DataConnector from 
pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.trainer import docstrings from pytorch_lightning.trainer.properties import TrainerProperties @@ -173,8 +174,9 @@ def __init__( self.accelerator_connector = AcceleratorConnector(self) self.logger_connector = LoggerConnector(self) self.model_connector = ModelConnector(self) - self.initializer = Initializer(self) + self.precision_connector = PrecisionConnector(self) self.callback_connector = CallbackConnector(self) + self.debugging_connector = DebuggingConnector(self) self.tuner = Tuner(self) self.accelerator_backend = None @@ -253,15 +255,8 @@ def __init__( # ------------------- self.weights_summary = weights_summary - self.max_epochs = max_epochs - self.min_epochs = min_epochs - self.max_steps = max_steps - self.min_steps = min_steps - - if num_sanity_val_steps == -1: - self.num_sanity_val_steps = float('inf') - else: - self.num_sanity_val_steps = num_sanity_val_steps + # init train loop related flags + self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps) self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch @@ -275,17 +270,6 @@ def __init__( self.terminate_on_nan = terminate_on_nan self.shown_warnings = set() - self.fast_dev_run = fast_dev_run - if self.fast_dev_run: - limit_train_batches = 1 - limit_val_batches = 1 - limit_test_batches = 1 - self.num_sanity_val_steps = 0 - self.max_epochs = 1 - rank_zero_info( - 'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch' - ) - # configure profiler if profiler is True: profiler = SimpleProfiler() @@ -300,61 +284,22 @@ def __init__( self.log_save_interval = log_save_interval self.row_log_interval = row_log_interval - # how much of the data to use - # TODO: remove in 0.10.0 - if overfit_pct is not None: - rank_zero_warn( - "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", - DeprecationWarning, - ) - overfit_batches = overfit_pct - - # TODO: remove in 0.10.0 - if val_percent_check is not None: - rank_zero_warn( - "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", - DeprecationWarning, - ) - limit_val_batches = val_percent_check - - # TODO: remove in 0.10.0 - if test_percent_check is not None: - rank_zero_warn( - "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", - DeprecationWarning, - ) - limit_test_batches = test_percent_check - - # TODO: remove in 0.10.0 - if train_percent_check is not None: - rank_zero_warn( - "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", - DeprecationWarning, - ) - limit_train_batches = train_percent_check - - self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') - self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches') - self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches') - self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval') - self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches') - self.determine_data_use_amount(self.overfit_batches) - - # AMP init - # These are the only lines needed after v0.8.0 - # we wrap the user's forward with autocast and give it 
back at the end of fit - self.autocast_original_forward = None - self.precision = precision - self.scaler = None - - self.amp_level = amp_level - self.initializer.init_amp(amp_backend) + # init debugging flags + self.debugging_connector.on_init_start( + overfit_pct, + val_percent_check, + test_percent_check, + train_percent_check, + limit_train_batches, + limit_val_batches, + limit_test_batches, + val_check_interval, + overfit_batches, + fast_dev_run + ) - self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') + # set precision + self.precision_connector.on_trainer_init(precision, amp_level, amp_backend) # Callback system self.on_init_end() @@ -862,18 +807,6 @@ def call_hook(self, hook_name, *args, **kwargs): return output - -def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: - if 0 <= batches <= 1: - return batches - elif batches > 1 and batches % 1.0 == 0: - return int(batches) - else: - raise MisconfigurationException( - f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.' - ) - - # add docstrings Trainer.__init__.__doc__ = docstrings.trainer.init Trainer.fit.__doc__ = docstrings.trainer.fit diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ef0d2074b2382..f62e91c11b833 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -38,6 +38,17 @@ def __init__(self, trainer): self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) + def on_init_start(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps): + self.trainer.max_epochs = max_epochs + self.trainer.min_epochs = min_epochs + self.trainer.max_steps = max_steps + self.trainer.min_steps = min_steps + + if num_sanity_val_steps == -1: + self.trainer.num_sanity_val_steps = float('inf') + else: + self.trainer.num_sanity_val_steps = num_sanity_val_steps + @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) From ada886fe7a67e33d8c4192c417243c0bcb21f01b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 10 Sep 2020 08:35:15 -0400 Subject: [PATCH 4/4] ref: organize args 2/n --- pytorch_lightning/trainer/debugging_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/debugging_connector.py b/pytorch_lightning/trainer/debugging_connector.py index 309c6ae9a98b2..49b32b0903da8 100644 --- a/pytorch_lightning/trainer/debugging_connector.py +++ b/pytorch_lightning/trainer/debugging_connector.py @@ -88,7 +88,7 @@ def on_init_start( self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches') self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval') self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches') - self.trainer.determine_data_use_amount(self.trainer.overfit_batches) + self.determine_data_use_amount(self.trainer.overfit_batches) def determine_data_use_amount(self, overfit_batches: float) -> None: """Use less data for debugging purposes"""
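The four patches above share one pattern: argument handling that previously lived in Trainer.__init__ (or in mixins such as TrainerCallbackConfigMixin) moves into small connector objects that hold a back-reference to the trainer and expose an on_trainer_init / on_init_start hook. The following is a minimal, self-contained sketch of that pattern, not the actual pytorch_lightning code: the class and hook names mirror the diff, but the bodies are trimmed placeholders and the Trainer shown here is a toy stand-in.

# Minimal sketch of the connector pattern used in these patches.
# Class and hook names mirror the diff; bodies are simplified placeholders,
# not the real pytorch_lightning implementations.


class CallbackConnector:
    def __init__(self, trainer):
        # every connector keeps a back-reference and writes state onto the trainer
        self.trainer = trainer

    def on_trainer_init(self, callbacks, progress_bar_refresh_rate, process_position):
        self.trainer.callbacks = callbacks or []
        self.trainer.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.trainer.process_position = process_position


class DebuggingConnector:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_init_start(self, fast_dev_run, limit_train_batches):
        # fast_dev_run collapses the run to a single batch, as in patch 3/4
        self.trainer.fast_dev_run = fast_dev_run
        self.trainer.limit_train_batches = 1 if fast_dev_run else limit_train_batches


class Trainer:
    def __init__(self, callbacks=None, progress_bar_refresh_rate=1, process_position=0,
                 fast_dev_run=False, limit_train_batches=1.0):
        # __init__ only instantiates the connectors and forwards its arguments to them
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)

        self.callback_connector.on_trainer_init(
            callbacks, progress_bar_refresh_rate, process_position
        )
        self.debugging_connector.on_init_start(fast_dev_run, limit_train_batches)


if __name__ == '__main__':
    trainer = Trainer(fast_dev_run=True)
    print(trainer.callbacks, trainer.limit_train_batches)  # prints: [] 1

The same shape covers the remaining moves in the series: AcceleratorConnector.on_trainer_init absorbs benchmark, num_nodes, log_gpu_memory and sync_batchnorm; PrecisionConnector.on_trainer_init absorbs the AMP setup; and TrainLoop.on_init_start absorbs max/min epochs, max/min steps and num_sanity_val_steps (mapping -1 to float('inf')).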
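For reference on the batch-limit handling moved into debugging_connector.py in patch 3/4: values in [0.0, 1.0] are kept as fractions of the dataloader, whole numbers greater than 1 become absolute batch counts, and anything else raises MisconfigurationException. The snippet below is a standalone copy of that helper for experimentation only; it substitutes ValueError for MisconfigurationException so it runs without pytorch_lightning installed.

from typing import Union


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
    # copied from debugging_connector.py in patch 3/4, with ValueError swapped in
    if 0 <= batches <= 1:
        return batches  # fraction of available batches
    elif batches > 1 and batches % 1.0 == 0:
        return int(batches)  # absolute number of batches
    else:
        raise ValueError(
            f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
        )


print(_determine_batch_limits(0.25, 'limit_val_batches'))  # 0.25 -> use 25% of the val batches
print(_determine_batch_limits(10.0, 'limit_val_batches'))  # 10   -> use exactly 10 val batches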