-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clearer disable validation logic #650
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -213,6 +213,7 @@ def __init__( | |
# training state | ||
self.model = None | ||
self.testing = False | ||
self.disable_validation = False | ||
self.lr_schedulers = [] | ||
self.optimizers = None | ||
self.global_step = 0 | ||
|
@@ -486,11 +487,16 @@ def run_pretrain_routine(self, model): | |
self.run_evaluation(test=True) | ||
return | ||
|
||
# check if we should run validation during training | ||
self.disable_validation = ((self.num_val_batches == 0 or | ||
not self.is_overriden('validation_step')) | ||
and not self.fast_dev_run) | ||
|
||
# run tiny validation (if validation defined) | ||
# to make sure program won't crash during val | ||
ref_model.on_sanity_check_start() | ||
ref_model.on_train_start() | ||
if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0: | ||
if not self.disable_validation and self.num_sanity_val_steps > 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider using a `_` prefix, i.e. `_disable_validation` There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the `_` prefix is not commonly used across the code, so I believe it would be more consistent to leave it as it is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know, that is why I left it to your consideration... you introduced this variable, so you should know whether it is an exposed (without `_`) or internal (with `_`) one... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I understand. And I just mean that I would prefer to stick to the current codebase style, where there seems to be no distinction between exposed and internal variables :) |
||
# init progress bars for validation sanity check | ||
pbar = tqdm.tqdm(desc='Validation sanity check', | ||
total=self.num_sanity_val_steps * len(self.get_val_dataloaders()), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -184,6 +184,7 @@ def __init__(self): | |
self.num_training_batches = None | ||
self.val_check_batch = None | ||
self.num_val_batches = None | ||
self.disable_validation = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. earlier you have it as True/False There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, but it is the definition inside
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, but be consistent in your addition, use bool or object/None everywhere... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm... I define this only in 2 places: in |
||
self.fast_dev_run = None | ||
self.is_iterable_train_dataloader = None | ||
self.main_progress_bar = None | ||
|
@@ -294,14 +295,16 @@ def train(self): | |
model.current_epoch = epoch | ||
self.current_epoch = epoch | ||
|
||
# val can be checked multiple times in epoch | ||
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 | ||
val_checks_per_epoch = self.num_training_batches // self.val_check_batch | ||
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 | ||
total_val_batches = 0 | ||
if not self.disable_validation: | ||
# val can be checked multiple times in epoch | ||
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 | ||
val_checks_per_epoch = self.num_training_batches // self.val_check_batch | ||
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 | ||
total_val_batches = self.num_val_batches * val_checks_per_epoch | ||
|
||
# total batches includes multiple val checks | ||
self.total_batches = (self.num_training_batches + | ||
self.num_val_batches * val_checks_per_epoch) | ||
self.total_batches = self.num_training_batches + total_val_batches | ||
self.batch_loss_value = 0 # accumulated grads | ||
|
||
if self.fast_dev_run: | ||
|
@@ -386,7 +389,8 @@ def run_training_epoch(self): | |
# --------------- | ||
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 | ||
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 | ||
should_check_val = ((is_val_check_batch or early_stop_epoch) and can_check_epoch) | ||
should_check_val = (not self.disable_validation and can_check_epoch and | ||
(is_val_check_batch or early_stop_epoch)) | ||
|
||
# fast_dev_run always forces val checking after train batch | ||
if self.fast_dev_run or should_check_val: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add backticks (`) around function and variable names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done