Commit 563a410

lgeiger authored and kpe committed

internal merge of PR tensorflow#1350
PiperOrigin-RevId: 228806666
1 parent 87bc13f commit 563a410

File tree

1 file changed: +7 −79 lines changed


tensor2tensor/utils/decoding.py

Lines changed: 7 additions & 79 deletions

@@ -407,18 +407,13 @@ def input_fn(params):
       return dataset
   else:
     def input_fn():
-      if has_input:
-        input_gen = _decode_batch_input_fn(
-            num_decode_batches, sorted_inputs,
-            inputs_vocab, decode_hp.batch_size,
-            decode_hp.max_input_size, task_id=decode_hp.multiproblem_task_id)
-      else:
-        input_gen = _decode_batch_input_fn_no_padding(sorted_inputs=sorted_inputs, max_batch_size=decode_hp.batch_size,
-                        vocabulary=inputs_vocab, max_input_size=decode_hp.max_input_size,
-                        decode_hp=decode_hp)
+      input_gen = _decode_batch_input_fn(
+          num_decode_batches, sorted_inputs,
+          inputs_vocab, decode_hp.batch_size,
+          decode_hp.max_input_size, task_id=decode_hp.multiproblem_task_id)
       gen_fn = make_input_fn_from_generator(input_gen)
       example = gen_fn()
       return _decode_input_tensor_to_features_dict(example, hparams)
   decodes = []
   result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
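
This hunk removes the has_input branch: decoding now always batches through _decode_batch_input_fn, which sorts the inputs and pads every example in a batch up to a shared length. As a rough illustration of that padded, fixed-size batching (a minimal self-contained sketch, not the tensor2tensor API; the function name, token ids, and pad/EOS values below are all made up):

import numpy as np

def padded_batches(encoded_inputs, batch_size, eos_id=1, pad_id=0):
  # Sort by length so examples in a batch need similar amounts of padding.
  encoded_inputs = sorted(encoded_inputs, key=len)
  for start in range(0, len(encoded_inputs), batch_size):
    batch = encoded_inputs[start:start + batch_size]
    # Pad every example (plus an EOS token) up to the batch maximum.
    batch_length = max(len(ids) for ids in batch) + 1
    padded = [ids + [eos_id] + [pad_id] * (batch_length - len(ids) - 1)
              for ids in batch]
    yield {"inputs": np.array(padded, dtype=np.int32)}

# Toy usage: three "sentences" already encoded to ids.
for batch in padded_batches([[5, 6], [7], [8, 9, 10]], batch_size=2):
  print(batch["inputs"].shape)  # (2, 3), then (1, 4)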

@@ -648,73 +643,6 @@ def _decode_batch_input_fn(num_decode_batches, sorted_inputs, vocabulary,
         "inputs": np.array(final_batch_inputs).astype(np.int32),
     }
 
-def _decode_batch_input_fn_no_padding(sorted_inputs, max_batch_size, vocabulary, max_input_size, decode_hp):
-  """Generator to produce batches of same-length inputs (batch size will be variable)."""
-
-  # First reverse all the input sentences so that if you're going to get OOMs,
-  # you'll see it in the first batch.
-  sorted_inputs.reverse()
-
-  # Get variable batch sizes.
-  last_batch_length = None
-  batch_lengths, batch_indices = [], []
-  for batch_index, elm in enumerate(sorted_inputs):
-    # Batch length is the number of space-separated tokens.
-    this_batch_length = len(elm.split(' '))
-    if max_input_size > 0:
-      if this_batch_length > max_input_size:
-        this_batch_length = max_input_size
-    if this_batch_length != last_batch_length:
-      batch_lengths.append(this_batch_length)
-      batch_indices.append(batch_index)
-      last_batch_length = this_batch_length
-  batch_indices.append(len(sorted_inputs))
-
-  # Ensure no batch exceeds the maximum batch size.
-  batch_sizes = np.diff(batch_indices)
-  final_batch_sizes = []
-  final_batch_lengths = []
-  for ii, bs in enumerate(batch_sizes):
-    if bs < max_batch_size:
-      final_batch_sizes.append(bs)
-      final_batch_lengths.append(batch_lengths[ii])
-    else:
-      full_batches = bs // max_batch_size
-      partial_batch = bs % max_batch_size
-      for _ in range(full_batches):
-        final_batch_sizes.append(max_batch_size)
-        final_batch_lengths.append(batch_lengths[ii])
-      if partial_batch > 0:
-        final_batch_sizes.append(partial_batch)
-        final_batch_lengths.append(batch_lengths[ii])
-
-  # Continue with now-variable batch sizes; no need for padding.
-  last_index = 0
-  for b, batch_size in enumerate(final_batch_sizes):
-    tf.logging.info("Decoding batch %d" % b)
-    # Batch length is the same for the entire batch; add one term for the
-    # optional <EOS> token.
-    batch_length = min(max_input_size, final_batch_lengths[b]) + 1
-    batch_inputs = []
-    for inputs in sorted_inputs[last_index:last_index + batch_size]:
-      input_ids = vocabulary.encode(inputs)
-      if max_input_size > 0:
-        # For language modeling problems, more recent inputs are often more
-        # important.
-        input_ids = input_ids[-max_input_size:]
-      # Padding and <EOS> removed -- for language modeling problems.
-      batch_inputs.append(input_ids)
-    last_index += batch_size
-
-    final_batch_inputs = []
-    # Ensure a consistent batch length.
-    for in_ids in batch_inputs:
-      assert len(in_ids) == batch_length
-      final_batch_inputs.append(in_ids)
-
-    yield {
-        "inputs": np.array(final_batch_inputs).astype(np.int32),
-    }
 
 def _interactive_input_fn(hparams, decode_hp):
   """Generator that reads from the terminal and yields "interactive inputs".
