Skip to content

Commit

Permalink
Allow running test data after each epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed May 11, 2024
1 parent 9e6afe5 commit 4e6d566
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
5 changes: 5 additions & 0 deletions generic_trainer/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ class TrainingConfig(Config):
validation_dataset: Optional[Dataset] = None
"""The validation dataset. See the docstring of `training_dataset` for more details."""

test_dataset: Optional[Dataset] = None
"""
The test dataset. It has no influence on training, just providing a way to check test performance after each epoch.
"""

batch_size_per_process: int = 64
"""
The batch size per process. With this value denoted by `n_bspp`, the trainer behaves as the following:
Expand Down
54 changes: 54 additions & 0 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, pred_names=('cs', 'eg', 'sg'), *args, **kwargs):
self['loss'] = []
self['val_loss'] = []
self['best_val_loss'] = np.inf
self['test_loss'] = []
self['lrs'] = []
self['epoch_best_val_loss'] = 0
self.current_epoch = 0
Expand All @@ -80,8 +81,10 @@ def __init__(self, pred_names=('cs', 'eg', 'sg'), *args, **kwargs):
self['loss_{}'.format(pred_name)] = []
self['val_loss_{}'.format(pred_name)] = []
self['best_val_loss_{}'.format(pred_name)] = np.inf
self['test_loss_{}'.format(pred_name)] = []
self['train_acc_{}'.format(pred_name)] = []
self['val_acc_{}'.format(pred_name)] = []
self['test_acc_{}'.format(pred_name)] = []
self['classification_preds_{}'.format(pred_name)] = []
self['classification_labels_{}'.format(pred_name)] = []

Expand Down Expand Up @@ -264,6 +267,7 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
self.dataset = self.configs.dataset
self.training_dataset = None
self.validation_dataset = None
self.test_dataset = None
self.validation_ratio = self.configs.validation_ratio
self.model = None
self.model_params = None
Expand All @@ -272,6 +276,8 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
self.training_dataloader = None
self.validation_sampler = None
self.validation_dataloader = None
self.test_sampler = None
self.test_dataloader = None
self.num_local_devices = self.get_num_local_devices()
self.num_processes = num_processes
self.rank = rank
Expand Down Expand Up @@ -460,6 +466,19 @@ def build_dataloaders(self):
sampler=self.validation_sampler,
collate_fn=lambda x: x
)
if self.configs.test_dataset is not None:
self.test_sampler = torch.utils.data.distributed.DistributedSampler(
self.test_dataset,
num_replicas=self.num_processes,
rank=self.rank,
drop_last=False
)
self.test_dataloader = DistributedDataLoader(
self.test_dataset,
batch_size=self.configs.batch_size_per_process,
sampler=self.test_sampler,
collate_fn=lambda x: x
)
else:
# ALCF documentation mentions that there is a bug in Pytorch's multithreaded data loaders with
# distributed training across multiple nodes. Therefore, `num_workers` is set to 0. See also:
Expand All @@ -474,6 +493,11 @@ def build_dataloaders(self):
collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
generator=self.get_dataloader_generator(), num_workers=0,
drop_last=False)
self.test_dataloader = DataLoader(self.test_dataset, shuffle=True,
batch_size=self.all_proc_batch_size,
collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
generator=self.get_dataloader_generator(), num_workers=0,
drop_last=False)

def run_training(self):
for self.current_epoch in range(self.current_epoch, self.num_epochs):
Expand All @@ -487,6 +511,9 @@ def run_training(self):
self.model.eval()
self.run_validation()

if self.test_dataset is not None:
self.run_test()

if self.verbose and self.rank == 0:
self.loss_tracker.print_losses()
self.write_training_info()
Expand Down Expand Up @@ -694,6 +721,31 @@ def run_validation(self):
if self.configs.post_validation_epoch_hook is not None:
self.configs.post_validation_epoch_hook()

def run_test(self):
losses = self.get_epoch_loss_buffer()
n_batches = 0
if self.configs.task_type == 'classification':
self.loss_tracker.clear_classification_results_and_labels()
for j, data_and_labels in enumerate(self.test_dataloader):
losses, _, preds, labels = self.load_data_and_get_loss(data_and_labels, losses)
if self.configs.task_type == 'classification':
self.loss_tracker.update_classification_results_and_labels(preds, labels)
n_batches += 1
if n_batches == 0:
logging.warning('Test set might be too small that at least 1 rank did not get any test data.')
n_batches = np.max([n_batches, 1])

losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='test_loss')

if self.configs.task_type == 'classification':
self.loss_tracker.sync_classification_preds_and_labels_across_ranks()
acc_dict = self.loss_tracker.calculate_classification_accuracy()
self.loss_tracker.update_accuracy_history(acc_dict, 'test')

if self.configs.post_validation_epoch_hook is not None:
self.configs.post_validation_epoch_hook()

def run_model_update_step(self, loss_node):
self.optimizer.zero_grad()
self.grad_scaler.scale(loss_node).backward()
Expand All @@ -719,6 +771,8 @@ def build_split_datasets(self):
logging.info('Training set size = {}; validation set size = {}.'.format(
len(self.training_dataset), len(self.validation_dataset))
)
if self.configs.test_dataset is not None:
self.test_dataset = self.configs.test_dataset

def build_optimizer(self):
if self.configs.pretrained_model_path is not None and self.configs.load_pretrained_encoder_only:
Expand Down

0 comments on commit 4e6d566

Please sign in to comment.