Handle no eval loss (#121)
* Handle no eval loss

* Assert len(data_loader) > 0
erogol authored Aug 14, 2023
1 parent 33bd187 commit 50f1d8a
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions trainer/trainer.py
@@ -895,6 +895,10 @@ def _get_loader(
         loader = model.get_data_loader(
             config=config, assets=assets, is_eval=is_eval, samples=samples, verbose=verbose, num_gpus=num_gpus
         )
+
+        assert (
+            len(loader) > 0
+        ), " ❗ len(DataLoader) returns 0. Make sure your dataset is not empty or len(dataset) > 0. "
         return loader
 
     def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader:
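
The new assertion fails fast when the dataset yields an empty DataLoader instead of silently training for zero steps per epoch. A minimal sketch of the failure mode it guards against, using a hypothetical empty dataset (names and message are illustrative, not from this repo):

from torch.utils.data import DataLoader, Dataset


class EmptyDataset(Dataset):
    """A dataset with no samples, e.g. after over-aggressive filtering."""

    def __len__(self):
        return 0

    def __getitem__(self, idx):
        raise IndexError(idx)


loader = DataLoader(EmptyDataset(), batch_size=4)
print(len(loader))  # 0 -> every epoch would run zero steps
# The added assert turns this silent no-op into an explicit error:
assert len(loader) > 0, "len(DataLoader) returns 0. Make sure your dataset is not empty."
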
@@ -1210,11 +1214,9 @@ def optimize(
         )
 
         # skip the rest if not outputs from the model
-        if not outputs:
-            if loss_dict:
-                raise RuntimeError(" [!] Model must return outputs when losses are computed.")
+        if not loss_dict:
             step_time = time.time() - step_start_time
-            return None, {}, step_time
+            return outputs, {}, step_time
 
         grad_clip = self._set_grad_clip_per_optimizer(config=config, optimizer_idx=optimizer_idx)
         # optimizer step
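
With this change, a step that yields no losses simply returns the model outputs and an empty loss dict instead of erroring on missing outputs. A rough stand-in for the new early-return path (simplified signature, not the actual Trainer.optimize() API):

import time


def optimize_step(outputs, loss_dict, step_start_time):
    # If the model reported no losses for this step, skip gradient clipping
    # and the optimizer update entirely and hand the outputs back.
    if not loss_dict:
        step_time = time.time() - step_start_time
        return outputs, {}, step_time
    # ... gradient clipping and optimizer.step() would follow here ...
    return outputs, loss_dict, time.time() - step_start_time
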
@@ -1758,9 +1760,9 @@ def _fit(self) -> None:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-                self.c_logger.print_epoch_end(self.epochs_done, self.keep_avg_eval.avg_values)
             if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                 self.test_run()
+
             self.c_logger.print_epoch_end(
                 epoch,
                 self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values,
@@ -1882,12 +1884,14 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller then the previous."""
 
-        # set the target loss to choose the best model
-        target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
+        eval_loss = None
+        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
+            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        train_loss = self._pick_target_avg_loss(self.keep_avg_train)
 
         # save the model and update the best_loss
         self.best_loss = save_best_model(
-            target_loss_dict,
+            train_loss if eval_loss is None else eval_loss,
             self.best_loss,
             self.config,
             self.model,
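
The best-model criterion now prefers the averaged eval loss and falls back to the averaged train loss when evaluation tracked nothing, which is the "no eval loss" case in the commit title. A minimal sketch of the selection logic with plain dicts standing in for the KeepAverage trackers (hypothetical helper, not code from the repo):

def pick_selection_loss(avg_eval_values, avg_train_values):
    """Pick the loss used to decide whether the current model is the new best."""
    eval_loss = None
    # Only trust the eval side if it actually accumulated any averaged values.
    if avg_eval_values and len(avg_eval_values.keys()) > 0:
        eval_loss = avg_eval_values.get("avg_loss", 0)
    train_loss = avg_train_values.get("avg_loss", 0)
    return train_loss if eval_loss is None else eval_loss


print(pick_selection_loss({}, {"avg_loss": 0.42}))                  # 0.42 (no eval loss, use train)
print(pick_selection_loss({"avg_loss": 0.31}, {"avg_loss": 0.42}))  # 0.31 (eval available)
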
@@ -1904,7 +1908,11 @@ def save_best_model(self) -> None:
     @rank_zero_only
     def save_checkpoint(self) -> None:
         """Save the current model checkpoint."""
-        target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
+        eval_loss = None
+        if self.keep_avg_eval and len(self.keep_avg_eval.avg_values.keys()) > 0:
+            eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
+        train_loss = self._pick_target_avg_loss(self.keep_avg_train)
+
         save_checkpoint(
             self.config,
             self.model,
@@ -1913,7 +1921,7 @@ def save_checkpoint(self) -> None:
             self.total_steps_done,
             self.epochs_done,
             self.output_path,
-            model_loss=target_avg_loss,
+            model_loss={"train_loss": train_loss, "eval_loss": eval_loss},
             save_n_checkpoints=self.config.save_n_checkpoints,
             save_func=self.dashboard_logger.save_model,
         )
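
Checkpoints now record both losses in a small dict rather than a single scalar, so a run without evaluation still stores a meaningful train loss. Roughly, the value passed as model_loss looks like this (illustrative numbers):

# With eval losses available:
model_loss = {"train_loss": 0.387, "eval_loss": 0.412}

# With no eval loss (eval disabled or nothing tracked), the eval entry
# is simply None and checkpoint saving still proceeds:
model_loss = {"train_loss": 0.387, "eval_loss": None}
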
@@ -2094,7 +2102,6 @@ def _detach_loss_dict(loss_dict: Dict) -> Dict:
     def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
         """Pick the target loss to compare models"""
         target_avg_loss = None
-
         # return if target loss defined in the model config
         # if not available in Dict use loss_1 as by default loss
         if "target_loss" in self.config and self.config.target_loss:
@@ -2115,7 +2122,7 @@ def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
                 target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
             target_avg_loss /= len(self.optimizer)
         else:
-            target_avg_loss = keep_avg_target["avg_loss"]
+            target_avg_loss = keep_avg_target.avg_values.get("avg_loss", 0)
         return target_avg_loss
 
     def _setup_logger_config(self, log_file: str) -> None:
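
Indexing the averages for a key that was never accumulated would fail (presumably with a KeyError) when evaluation produced no losses. Switching to a dict-style .get with a default of 0 makes the lookup safe; the same pattern on a plain dict (sketch only, not the KeepAverage class itself):

avg_values = {}  # nothing accumulated, e.g. the model reported no eval losses

# Old behaviour: avg_values["avg_loss"] raises KeyError and aborts saving.
# New behaviour: fall back to 0 so model selection and saving can proceed.
target_avg_loss = avg_values.get("avg_loss", 0)
print(target_avg_loss)  # 0
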