Skip to content

Commit

Permalink
Wrap model output in list if the return is a single tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Jul 9, 2024
1 parent a8072f8 commit 2e75daf
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,9 @@ def load_data_and_get_loss(self, data_and_labels, loss_buffer, *args, **kwargs):
if self.use_torch_amp:
es.enter_context(torch.autocast(device_type=self.device.type, dtype=torch.float16))
preds = self.model(*data)
# If preds is a single tensor, wrap it in a list
if isinstance(preds, torch.Tensor):
preds = [preds]
losses, total_loss_tensor = self.compute_losses(loss_buffer, preds, labels)
return loss_buffer, total_loss_tensor, preds, labels

Expand Down Expand Up @@ -1177,6 +1180,9 @@ def load_data_and_get_loss(self, data, loss_buffer):
# elements of data are supposed to be 2 different augmentations.
data, _ = self.process_data_loader_yield(data)
preds = self.model(*data)
# If preds is a single tensor, wrap it in a list
if isinstance(preds, torch.Tensor):
preds = [preds]
losses, total_loss_tensor = self.compute_losses(loss_buffer, preds)
return loss_buffer, total_loss_tensor, preds, None

Expand Down Expand Up @@ -1414,6 +1420,9 @@ def load_data_and_get_loss(lightning_model, batch, batch_idx, *args, **kwargs):
n_losses = len(lightning_model.gtrainer.configs.pred_names_and_types) + 1
data, labels = lightning_model.gtrainer.process_data_loader_yield(batch)
preds = lightning_model(*data)
# If preds is a single tensor, wrap it in a list
if isinstance(preds, torch.Tensor):
preds = [preds]
_, total_loss_tensor = lightning_model.gtrainer.compute_losses([0.0] * n_losses, preds, labels)
return total_loss_tensor

Expand Down

0 comments on commit 2e75daf

Please sign in to comment.