Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 32440ab

Browse files
committed
Using partial targets at inference time.
1 parent 20a0116 commit 32440ab

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tensor2tensor/models/transformer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,10 @@ def _fast_decode(self,
699699
features=features)
700700
encoder_output = encoder_output[0]
701701
encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
702-
partial_targets = None
702+
if 'partial_targets' in features:
703+
partial_targets = features['partial_targets']
704+
else:
705+
partial_targets = None
703706
else:
704707
# The problem has no inputs.
705708
encoder_output = None
@@ -712,6 +715,8 @@ def _fast_decode(self,
712715
if partial_targets is None:
713716
partial_targets = features["targets"]
714717
assert partial_targets is not None
718+
719+
if partial_targets is not None:
715720
partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
716721
partial_targets = tf.to_int64(partial_targets)
717722
partial_targets_shape = common_layers.shape_list(partial_targets)

0 commit comments

Comments
 (0)