@@ -37,7 +37,7 @@ def symbols_to_logits(_):
3737      # Just return random logits 
3838      return  tf .random_uniform ((batch_size  *  beam_size , vocab_size ))
3939
40-     final_ids , final_probs  =  beam_search .beam_search (
40+     final_ids , final_probs ,  _  =  beam_search .beam_search (
4141        symbols_to_logits , initial_ids , beam_size , decode_length , vocab_size ,
4242        0. )
4343
@@ -113,7 +113,7 @@ def symbols_to_logits(ids):
113113      logits  =  tf .to_float (tf .log (probabilities [pos  -  1 , :]))
114114      return  logits 
115115
116-     final_ids , final_probs  =  beam_search .beam_search (
116+     final_ids , final_probs ,  _  =  beam_search .beam_search (
117117        symbols_to_logits ,
118118        initial_ids ,
119119        beam_size ,
@@ -144,7 +144,7 @@ def symbols_to_logits(ids):
144144      logits  =  tf .to_float (tf .log (probabilities [pos  -  1 , :]))
145145      return  logits 
146146
147-     final_ids , final_probs  =  beam_search .beam_search (
147+     final_ids , final_probs ,  _  =  beam_search .beam_search (
148148        symbols_to_logits ,
149149        initial_ids ,
150150        beam_size ,
@@ -173,7 +173,7 @@ def symbols_to_logits(ids):
173173      logits  =  tf .to_float (tf .log (probabilities [pos  -  1 , :]))
174174      return  logits 
175175
176-     final_ids , final_probs  =  beam_search .beam_search (
176+     final_ids , final_probs ,  _  =  beam_search .beam_search (
177177        symbols_to_logits ,
178178        initial_ids ,
179179        beam_size ,
@@ -213,7 +213,7 @@ def symbols_to_logits(ids):
213213      logits  =  tf .to_float (tf .log (probabilities [pos  -  1 , :]))
214214      return  logits 
215215
216-     final_ids , final_scores  =  beam_search .beam_search (
216+     final_ids , final_scores ,  _  =  beam_search .beam_search (
217217        symbols_to_logits ,
218218        initial_ids ,
219219        beam_size ,
@@ -256,7 +256,7 @@ def symbols_to_logits(ids):
256256      return  logits 
257257
258258    # Disable early stopping 
259-     final_ids , final_scores  =  beam_search .beam_search (
259+     final_ids , final_scores ,  _  =  beam_search .beam_search (
260260        symbols_to_logits ,
261261        initial_ids ,
262262        beam_size ,
@@ -302,7 +302,7 @@ def symbols_to_logits(ids, _, states):
302302    states ["state" ] =  tf .placeholder_with_default (
303303        states ["state" ], shape = (None , 1 ))
304304
305-     final_ids , _  =  beam_search .beam_search (
305+     final_ids , _ ,  _  =  beam_search .beam_search (
306306        symbols_to_logits ,
307307        initial_ids ,
308308        beam_size ,
@@ -319,6 +319,41 @@ def symbols_to_logits(ids, _, states):
319319      except  tf .errors .InvalidArgumentError  as  e :
320320        raise  AssertionError (e .message )
321321
322+   def  testStatesAfterLoop (self ):
323+     batch_size  =  1 
324+     beam_size  =  1 
325+     vocab_size  =  2 
326+     decode_length  =  3 
327+ 
328+     initial_ids  =  tf .constant ([0 ] *  batch_size )  # GO 
329+     probabilities  =  tf .constant ([[[0.7 , 0.3 ]], [[0.4 , 0.6 ]], [[0.5 , 0.5 ]]])
330+ 
331+     def  symbols_to_logits (ids , _ , states ):
332+       pos  =  tf .shape (ids )[1 ] -  1 
333+       logits  =  tf .to_float (tf .log (probabilities [pos , :]))
334+       states ["state" ] +=  1 
335+       return  logits , states 
336+ 
337+     states  =  {
338+         "state" : tf .zeros ((batch_size , 1 )),
339+     }
340+     states ["state" ] =  tf .placeholder_with_default (
341+         states ["state" ], shape = (None , 1 ))
342+ 
343+     _ , _ , final_states  =  beam_search .beam_search (
344+         symbols_to_logits ,
345+         initial_ids ,
346+         beam_size ,
347+         decode_length ,
348+         vocab_size ,
349+         0.0 ,
350+         eos_id = 1 ,
351+         states = states )
352+     
353+     with  self .test_session () as  sess :
354+       final_states  =  sess .run (final_states )
355+     self .assertAllEqual ([[1 ]], final_states ["state" ])
356+ 
322357  def  testStateBeamTwo (self ):
323358    batch_size  =  1 
324359    beam_size  =  2 
@@ -352,7 +387,7 @@ def symbols_to_logits(ids, _, states):
352387    states ["state" ] =  tf .placeholder_with_default (
353388        states ["state" ], shape = (None , 1 ))
354389
355-     final_ids , _  =  beam_search .beam_search (
390+     final_ids , _ ,  _  =  beam_search .beam_search (
356391        symbols_to_logits ,
357392        initial_ids ,
358393        beam_size ,
0 commit comments