Commit 05c8061

MEZTech-LLC authored and kpe committed
internal merge of PR tensorflow#1282
PiperOrigin-RevId: 228809213
1 parent 563a410 commit 05c8061

1 file changed: +7 −4

tensor2tensor/utils/decoding.py (7 additions, 4 deletions)
@@ -410,7 +410,8 @@ def input_fn():
     input_gen = _decode_batch_input_fn(
         num_decode_batches, sorted_inputs,
         inputs_vocab, decode_hp.batch_size,
-        decode_hp.max_input_size, task_id=decode_hp.multiproblem_task_id)
+        decode_hp.max_input_size,
+        task_id=decode_hp.multiproblem_task_id, has_input=has_input)
     gen_fn = make_input_fn_from_generator(input_gen)
     example = gen_fn()
     return _decode_input_tensor_to_features_dict(example, hparams)
@@ -616,7 +617,8 @@ def input_fn():
 
 
 def _decode_batch_input_fn(num_decode_batches, sorted_inputs, vocabulary,
-                           batch_size, max_input_size, task_id=-1):
+                           batch_size, max_input_size,
+                           task_id=-1, has_input=True):
   """Generator to produce batches of inputs."""
   tf.logging.info(" batch %d" % num_decode_batches)
   for b in range(num_decode_batches):
@@ -628,8 +630,9 @@ def _decode_batch_input_fn(num_decode_batches, sorted_inputs, vocabulary,
       if max_input_size > 0:
         # Subtract 1 for the EOS_ID.
         input_ids = input_ids[:max_input_size - 1]
-      final_id = text_encoder.EOS_ID if task_id < 0 else task_id
-      input_ids.append(final_id)
+      if has_input or task_id > -1:  # Do not append EOS for pure LM tasks.
+        final_id = text_encoder.EOS_ID if task_id < 0 else task_id
+        input_ids.append(final_id)
       batch_inputs.append(input_ids)
       if len(input_ids) > batch_length:
         batch_length = len(input_ids)
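
For context, here is a minimal standalone sketch of the branch this commit adds: the final id (EOS, or the multi-problem task id) is only appended when the problem has inputs or a task id is set, so pure language-model decoding passes the ids through unchanged. This is not the tensor2tensor API itself; EOS_ID = 1 mirrors text_encoder.EOS_ID, and finalize_input_ids is a hypothetical stand-in for the relevant lines of _decode_batch_input_fn.

# Sketch only: stand-in for the changed lines in _decode_batch_input_fn.
EOS_ID = 1  # mirrors text_encoder.EOS_ID in tensor2tensor

def finalize_input_ids(input_ids, task_id=-1, has_input=True, max_input_size=0):
  """Truncate and optionally append a final id, as in the changed branch."""
  if max_input_size > 0:
    # Subtract 1 to leave room for the final id.
    input_ids = input_ids[:max_input_size - 1]
  if has_input or task_id > -1:  # Pure LM decoding: append nothing.
    final_id = EOS_ID if task_id < 0 else task_id
    input_ids = input_ids + [final_id]
  return input_ids

print(finalize_input_ids([5, 6, 7]))                               # [5, 6, 7, 1]
print(finalize_input_ids([5, 6, 7], has_input=False))              # [5, 6, 7]
print(finalize_input_ids([5, 6, 7], task_id=64, has_input=False))  # [5, 6, 7, 64]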
