diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index f3530b574..4f48fda56 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -338,9 +338,9 @@ def model_fn_sharded(self, sharded_features): sharded_logits[k] = dp(self.top, v, datashard_to_features) sharded_losses[k] = dp(self.loss, sharded_logits[k], datashard_to_features) - training_loss_dict = average_sharded_losses([{ + training_loss_dict = average_sharded_losses([({ "training": l - } for l in loss for loss in sharded_losses.values()]) + } for l in loss) for loss in sharded_losses.values()]) losses.update(training_loss_dict) else: sharded_logits = dp(self.top, body_out, datashard_to_features)