Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clearer disable validation logic #650

Merged
merged 5 commits into from
Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 60 additions & 69 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,76 +266,67 @@ def evaluate(self, model, dataloaders, max_batches, test=False):

def run_evaluation(self, test=False):
# when testing make sure user defined a test step
can_run_test_step = False
if not (self.is_overriden('test_step') and self.is_overriden('test_end')):
m = '''You called .test() without defining a test step or test_end.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls add ` around functions and variables

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Please define and try again'''
raise MisconfigurationException(m)

# hook
model = self.get_model()
model.on_pre_performance_check()

# select dataloaders
if test:
can_run_test_step = self.is_overriden('test_step') and self.is_overriden('test_end')
if not can_run_test_step:
m = '''You called .test() without defining a test step or test_end.
Please define and try again'''
raise MisconfigurationException(m)

# validate only if model has validation_step defined
# test only if test_step or validation_step are defined
run_val_step = self.is_overriden('validation_step')

if run_val_step or can_run_test_step:

# hook
model = self.get_model()
model.on_pre_performance_check()

# select dataloaders
if test:
dataloaders = self.get_test_dataloaders()
max_batches = self.num_test_batches
else:
# val
dataloaders = self.get_val_dataloaders()
max_batches = self.num_val_batches

# cap max batches to 1 when using fast_dev_run
if self.fast_dev_run:
max_batches = 1

# init validation or test progress bar
# main progress bar will already be closed when testing so initial position is free
position = 2 * self.process_position + (not test)
desc = 'Testing' if test else 'Validating'
pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
disable=not self.show_progress_bar, dynamic_ncols=True,
unit='batch', file=sys.stdout)
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)

# run evaluation
eval_results = self.evaluate(self.model,
dataloaders,
max_batches,
test)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
eval_results)

# add metrics to prog bar
self.add_tqdm_metrics(prog_bar_metrics)

# log metrics
self.log_metrics(log_metrics, {})

# track metrics for callbacks
self.callback_metrics.update(callback_metrics)

# hook
model.on_post_performance_check()

# add model specific metrics
tqdm_metrics = self.training_tqdm_dict
if not test:
self.main_progress_bar.set_postfix(**tqdm_metrics)

# close progress bar
if test:
self.test_progress_bar.close()
else:
self.val_progress_bar.close()
dataloaders = self.get_test_dataloaders()
max_batches = self.num_test_batches
else:
# val
dataloaders = self.get_val_dataloaders()
max_batches = self.num_val_batches

# cap max batches to 1 when using fast_dev_run
if self.fast_dev_run:
max_batches = 1

# init validation or test progress bar
# main progress bar will already be closed when testing so initial position is free
position = 2 * self.process_position + (not test)
desc = 'Testing' if test else 'Validating'
pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
disable=not self.show_progress_bar, dynamic_ncols=True,
unit='batch', file=sys.stdout)
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)

# run evaluation
eval_results = self.evaluate(self.model,
dataloaders,
max_batches,
test)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
eval_results)

# add metrics to prog bar
self.add_tqdm_metrics(prog_bar_metrics)

# log metrics
self.log_metrics(log_metrics, {})

# track metrics for callbacks
self.callback_metrics.update(callback_metrics)

# hook
model.on_post_performance_check()

# add model specific metrics
tqdm_metrics = self.training_tqdm_dict
if not test:
self.main_progress_bar.set_postfix(**tqdm_metrics)

# close progress bar
if test:
self.test_progress_bar.close()
else:
self.val_progress_bar.close()

# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def __init__(
# training state
self.model = None
self.testing = False
self.disable_validation = False
self.lr_schedulers = []
self.optimizers = None
self.global_step = 0
Expand Down Expand Up @@ -486,11 +487,16 @@ def run_pretrain_routine(self, model):
self.run_evaluation(test=True)
return

# check if we should run validation during training
self.disable_validation = ((self.num_val_batches == 0 or
not self.is_overriden('validation_step'))
and not self.fast_dev_run)

# run tiny validation (if validation defined)
# to make sure program won't crash during val
ref_model.on_sanity_check_start()
ref_model.on_train_start()
if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
if not self.disable_validation and self.num_sanity_val_steps > 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider user _ so self._disable_validation but maybe it s not needed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like _ prefix is not commonly used across the code so I believe it would be more consistent to leave it as it is

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know, that I left it to your consideration... you introduced this variable so you should know if it is exposed (without _ ) or internal (with _ ) one...

Copy link
Contributor Author

@kuynzereb kuynzereb Dec 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I understand. And I just mean that I would prefer to stick to the current codebase style, where there seems to be no distinction between exposed and internal variables :)

# init progress bars for validation sanity check
pbar = tqdm.tqdm(desc='Validation sanity check',
total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
Expand Down
18 changes: 11 additions & 7 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(self):
self.num_training_batches = None
self.val_check_batch = None
self.num_val_batches = None
self.disable_validation = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

earlier you have it as True/False

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but it is the definition inside TrainerTrainLoopMixin, where we have

# this is just a summary on variables used in this abstract class,
#  the proper values/initialisation should be done in child class

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, but be consistent in your addition, use bool or obejct/None everywhere...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm... I define this only in 2 places: in TrainerTrainLoopMixin and in Trainer. Inside TrainerTrainLoopMixin all fields are set to None, and inside Trainer we assign the actual value. Maybe I don't unserstand something?

self.fast_dev_run = None
self.is_iterable_train_dataloader = None
self.main_progress_bar = None
Expand Down Expand Up @@ -294,14 +295,16 @@ def train(self):
model.current_epoch = epoch
self.current_epoch = epoch

# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
total_val_batches = 0
if not self.disable_validation:
# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
total_val_batches = self.num_val_batches * val_checks_per_epoch

# total batches includes multiple val checks
self.total_batches = (self.num_training_batches +
self.num_val_batches * val_checks_per_epoch)
self.total_batches = self.num_training_batches + total_val_batches
self.batch_loss_value = 0 # accumulated grads

if self.fast_dev_run:
Expand Down Expand Up @@ -386,7 +389,8 @@ def run_training_epoch(self):
# ---------------
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
should_check_val = ((is_val_check_batch or early_stop_epoch) and can_check_epoch)
should_check_val = (not self.disable_validation and can_check_epoch and
(is_val_check_batch or early_stop_epoch))

# fast_dev_run always forces val checking after train batch
if self.fast_dev_run or should_check_val:
Expand Down