Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 061488b

Browse files
mbzCopybara-Service
authored andcommitted
fixing internal_loss for recurrent models with L2 loss.
PiperOrigin-RevId: 219540943
1 parent 33583af commit 061488b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensor2tensor/models/video/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ def __process(self, all_frames, all_actions, all_rewards, all_raw_frames):
445445
# using the default (a bit strange) video modality - we should change that.
446446

447447
hparams = self.hparams
448+
all_frames_copy = [tf.identity(frame) for frame in all_frames]
448449
orig_frame_shape = common_layers.shape_list(all_frames[0])
449450
batch_size = orig_frame_shape[0]
450451
ss_func = self.get_scheduled_sample_func(batch_size)
@@ -506,7 +507,7 @@ def __process(self, all_frames, all_actions, all_rewards, all_raw_frames):
506507
has_input_predictions = hparams.video_num_input_frames > 1
507508
if self.is_training and hparams.internal_loss and has_input_predictions:
508509
# add the loss for input frames as well.
509-
extra_gts = all_frames[1:hparams.video_num_input_frames]
510+
extra_gts = all_frames_copy[1:hparams.video_num_input_frames]
510511
extra_raw_gts = all_raw_frames[1:hparams.video_num_input_frames]
511512
extra_pds = res_frames[:hparams.video_num_input_frames-1]
512513
recon_loss = self.get_extra_internal_loss(

0 commit comments

Comments
 (0)