Skip to content

Commit

Permalink
Clearer disable validation logic (#650)
Browse files — browse the repository at this point in the history
* Clearer disable validation logic

* fix for fast_dev_run

* flake8 fix

* Test check fix

* update error message
Branch information:
kuynzereb authored and williamFalcon committed Jan 14, 2020
1 parent 083dd6a commit 756c70a
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 77 deletions.
129 changes: 60 additions & 69 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,76 +266,67 @@ def evaluate(self, model, dataloaders, max_batches, test=False):

def run_evaluation(self, test=False):
# when testing make sure user defined a test step
can_run_test_step = False
if test and not (self.is_overriden('test_step') and self.is_overriden('test_end')):
m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`.
Please define and try again'''
raise MisconfigurationException(m)

# hook
model = self.get_model()
model.on_pre_performance_check()

# select dataloaders
if test:
can_run_test_step = self.is_overriden('test_step') and self.is_overriden('test_end')
if not can_run_test_step:
m = '''You called .test() without defining a test step or test_end.
Please define and try again'''
raise MisconfigurationException(m)

# validate only if model has validation_step defined
# test only if test_step or validation_step are defined
run_val_step = self.is_overriden('validation_step')

if run_val_step or can_run_test_step:

# hook
model = self.get_model()
model.on_pre_performance_check()

# select dataloaders
if test:
dataloaders = self.get_test_dataloaders()
max_batches = self.num_test_batches
else:
# val
dataloaders = self.get_val_dataloaders()
max_batches = self.num_val_batches

# cap max batches to 1 when using fast_dev_run
if self.fast_dev_run:
max_batches = 1

# init validation or test progress bar
# main progress bar will already be closed when testing so initial position is free
position = 2 * self.process_position + (not test)
desc = 'Testing' if test else 'Validating'
pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
disable=not self.show_progress_bar, dynamic_ncols=True,
unit='batch', file=sys.stdout)
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)

# run evaluation
eval_results = self.evaluate(self.model,
dataloaders,
max_batches,
test)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
eval_results)

# add metrics to prog bar
self.add_tqdm_metrics(prog_bar_metrics)

# log metrics
self.log_metrics(log_metrics, {})

# track metrics for callbacks
self.callback_metrics.update(callback_metrics)

# hook
model.on_post_performance_check()

# add model specific metrics
tqdm_metrics = self.training_tqdm_dict
if not test:
self.main_progress_bar.set_postfix(**tqdm_metrics)

# close progress bar
if test:
self.test_progress_bar.close()
else:
self.val_progress_bar.close()
dataloaders = self.get_test_dataloaders()
max_batches = self.num_test_batches
else:
# val
dataloaders = self.get_val_dataloaders()
max_batches = self.num_val_batches

# cap max batches to 1 when using fast_dev_run
if self.fast_dev_run:
max_batches = 1

# init validation or test progress bar
# main progress bar will already be closed when testing so initial position is free
position = 2 * self.process_position + (not test)
desc = 'Testing' if test else 'Validating'
pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
disable=not self.show_progress_bar, dynamic_ncols=True,
unit='batch', file=sys.stdout)
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)

# run evaluation
eval_results = self.evaluate(self.model,
dataloaders,
max_batches,
test)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
eval_results)

# add metrics to prog bar
self.add_tqdm_metrics(prog_bar_metrics)

# log metrics
self.log_metrics(log_metrics, {})

# track metrics for callbacks
self.callback_metrics.update(callback_metrics)

# hook
model.on_post_performance_check()

# add model specific metrics
tqdm_metrics = self.training_tqdm_dict
if not test:
self.main_progress_bar.set_postfix(**tqdm_metrics)

# close progress bar
if test:
self.test_progress_bar.close()
else:
self.val_progress_bar.close()

# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def __init__(
# training state
self.model = None
self.testing = False
self.disable_validation = False
self.lr_schedulers = []
self.optimizers = None
self.global_step = 0
Expand Down Expand Up @@ -486,11 +487,16 @@ def run_pretrain_routine(self, model):
self.run_evaluation(test=True)
return

# check if we should run validation during training
self.disable_validation = ((self.num_val_batches == 0 or
not self.is_overriden('validation_step')) and
not self.fast_dev_run)

# run tiny validation (if validation defined)
# to make sure program won't crash during val
ref_model.on_sanity_check_start()
ref_model.on_train_start()
if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
if not self.disable_validation and self.num_sanity_val_steps > 0:
# init progress bars for validation sanity check
pbar = tqdm.tqdm(desc='Validation sanity check',
total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
Expand Down
18 changes: 11 additions & 7 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(self):
self.num_training_batches = None
self.val_check_batch = None
self.num_val_batches = None
self.disable_validation = None
self.fast_dev_run = None
self.is_iterable_train_dataloader = None
self.main_progress_bar = None
Expand Down Expand Up @@ -294,14 +295,16 @@ def train(self):
model.current_epoch = epoch
self.current_epoch = epoch

# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
total_val_batches = 0
if not self.disable_validation:
# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
total_val_batches = self.num_val_batches * val_checks_per_epoch

# total batches includes multiple val checks
self.total_batches = (self.num_training_batches +
self.num_val_batches * val_checks_per_epoch)
self.total_batches = self.num_training_batches + total_val_batches
self.batch_loss_value = 0 # accumulated grads

if self.fast_dev_run:
Expand Down Expand Up @@ -390,7 +393,8 @@ def run_training_epoch(self):
# ---------------
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
should_check_val = ((is_val_check_batch or early_stop_epoch) and can_check_epoch)
should_check_val = (not self.disable_validation and can_check_epoch and
(is_val_check_batch or early_stop_epoch))

# fast_dev_run always forces val checking after train batch
if self.fast_dev_run or should_check_val:
Expand Down

0 comments on commit 756c70a

Please sign in to comment.