@@ -410,7 +410,8 @@ def input_fn():
410410 input_gen = _decode_batch_input_fn (
411411 num_decode_batches , sorted_inputs ,
412412 inputs_vocab , decode_hp .batch_size ,
413- decode_hp .max_input_size , task_id = decode_hp .multiproblem_task_id )
413+ decode_hp .max_input_size ,
414+ task_id = decode_hp .multiproblem_task_id , has_input = has_input )
414415 gen_fn = make_input_fn_from_generator (input_gen )
415416 example = gen_fn ()
416417 return _decode_input_tensor_to_features_dict (example , hparams )
@@ -616,7 +617,8 @@ def input_fn():
616617
617618
618619def _decode_batch_input_fn (num_decode_batches , sorted_inputs , vocabulary ,
619- batch_size , max_input_size , task_id = - 1 ):
620+ batch_size , max_input_size ,
621+ task_id = - 1 , has_input = True ):
620622 """Generator to produce batches of inputs."""
621623 tf .logging .info (" batch %d" % num_decode_batches )
622624 for b in range (num_decode_batches ):
@@ -628,8 +630,9 @@ def _decode_batch_input_fn(num_decode_batches, sorted_inputs, vocabulary,
628630 if max_input_size > 0 :
629631 # Subtract 1 for the EOS_ID.
630632 input_ids = input_ids [:max_input_size - 1 ]
631- final_id = text_encoder .EOS_ID if task_id < 0 else task_id
632- input_ids .append (final_id )
633+ if has_input or task_id > - 1 : # Do not append EOS for pure LM tasks.
634+ final_id = text_encoder .EOS_ID if task_id < 0 else task_id
635+ input_ids .append (final_id )
633636 batch_inputs .append (input_ids )
634637 if len (input_ids ) > batch_length :
635638 batch_length = len (input_ids )
0 commit comments