@@ -407,18 +407,13 @@ def input_fn(params):
       return dataset
   else:
     def input_fn():
-      if has_input:
-        input_gen = _decode_batch_input_fn(
-            num_decode_batches, sorted_inputs,
-            inputs_vocab, decode_hp.batch_size,
-            decode_hp.max_input_size, task_id=decode_hp.multiproblem_task_id)
-      else:
-        input_gen = _decode_batch_input_fn_no_padding(sorted_inputs=sorted_inputs, max_batch_size=decode_hp.batch_size,
-            vocabulary=inputs_vocab, max_input_size=decode_hp.max_input_size,
-            decode_hp=decode_hp)
-      gen_fn = make_input_fn_from_generator(input_gen)
-      example = gen_fn()
-      return _decode_input_tensor_to_features_dict(example, hparams)
+      input_gen = _decode_batch_input_fn(
+          num_decode_batches, sorted_inputs,
+          inputs_vocab, decode_hp.batch_size,
+          decode_hp.max_input_size, task_id=decode_hp.multiproblem_task_id)
+      gen_fn = make_input_fn_from_generator(input_gen)
+      example = gen_fn()
+      return _decode_input_tensor_to_features_dict(example, hparams)
   decodes = []
   result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
 
@@ -648,73 +643,6 @@ def _decode_batch_input_fn(num_decode_batches, sorted_inputs, vocabulary,
648643 "inputs" : np .array (final_batch_inputs ).astype (np .int32 ),
649644 }
650645
651- def _decode_batch_input_fn_no_padding (sorted_inputs , max_batch_size , vocabulary , max_input_size , decode_hp ):
652- """Generator to produce batches of same length inputs (batch size will be variable)."""
653-
654- # First reverse all the input sentences so that if you're going to get OOMs,
655- # you'll see it in the first batch
656- sorted_inputs .reverse ()
657-
658- #Get variable batch sizes
659- last_batch_length = None
660- batch_lengths , batch_indicies = [],[]
661- for batch_index ,elm in enumerate (sorted_inputs ):
662- #Exclude whitespace and empty strings from batch length.
663- this_batch_length = len (elm .split (' ' ))
664- if max_input_size > 0 :
665- if this_batch_length > max_input_size :
666- this_batch_length = max_input_size
667- if this_batch_length != last_batch_length :
668- batch_lengths .append (this_batch_length )
669- batch_indicies .append (batch_index )
670- last_batch_length = this_batch_length
671- batch_indicies .append (len (sorted_inputs ))
672-
673- #Ensure no batches exceed the maximum batch_size
674- batch_sizes = np .diff (batch_indicies )
675- final_batch_sizes = []
676- final_batch_lengths = []
677- for ii ,bs in enumerate (batch_sizes ):
678- if bs < max_batch_size :
679- final_batch_sizes .append (bs )
680- final_batch_lengths .append (batch_lengths [ii ])
681- else :
682- full_batches = bs // max_batch_size
683- partial_batch = bs % max_batch_size
684- for _ in range (full_batches ):
685- final_batch_sizes .append (max_batch_size )
686- final_batch_lengths .append (batch_lengths [ii ])
687- if partial_batch > 0 :
688- final_batch_sizes .append (partial_batch )
689- final_batch_lengths .append (batch_lengths [ii ])
690-
691- #Continue with now variable batch sizes, no need for padding.
692- last_index = 0
693- for b ,batch_size in enumerate (final_batch_sizes ):
694- tf .logging .info ("Decoding batch %d" % b )
695- # Batch length should be the same for the entire batch -- Add one additional term for <EOS> token insertion (opt)
696- batch_length = min (max_input_size ,final_batch_lengths [b ]) + 1
697- batch_inputs = []
698- for inputs in sorted_inputs [last_index :last_index + batch_size ]:
699- input_ids = vocabulary .encode (inputs )
700- if max_input_size > 0 :
701- #For language modeling problems, more recent inputs are often more important.
702- input_ids = input_ids [- max_input_size :]
703- #Padding and <EOS> removed -- for language modeling problems.
704- batch_inputs .append (input_ids )
705- last_index += batch_size
706-
707- final_batch_inputs = []
708- #Ensure consistent batch size
709- for in_ids in batch_inputs :
710- assert len (in_ids ) == batch_length
711- x = in_ids
712- final_batch_inputs .append (x )
713-
714- yield {
715- "inputs" : np .array (final_batch_inputs ).astype (np .int32 ),
716- }
717-
718646
719647def _interactive_input_fn (hparams , decode_hp ):
720648 """Generator that reads from the terminal and yields "interactive inputs".