@@ -166,17 +166,17 @@ def add_positional_embedding_nd(x, max_length, name):
 
 
 def embedding_to_padding(emb):
-  """Input embeddings -> is_padding.
+  """Calculates the padding mask based on which embeddings are all zero.
 
   We have hacked symbol_modality to return all-zero embeddings for padding.
 
   Args:
     emb: a Tensor with shape [..., depth].
   Returns:
-    a boolean Tensor with shape [...].
+    a float Tensor with shape [...].
   """
   emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
-  return tf.equal(emb_sum, 0.0)
+  return tf.to_float(tf.equal(emb_sum, 0.0))
 
 
 def attention_bias_lower_triangle(length):
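
A minimal sketch (TensorFlow 1.x) of the updated embedding_to_padding behavior, with the helper's body inlined and the tensor values invented for illustration: an all-zero embedding row now yields 1.0 in the mask, a non-zero row yields 0.0.

import tensorflow as tf

emb = tf.constant([[[0.1, -0.2],   # real token   -> 0.0 in the mask
                    [0.0, 0.0]]])  # padding row  -> 1.0 in the mask
# Same computation as the updated embedding_to_padding above.
padding = tf.to_float(tf.equal(tf.reduce_sum(tf.abs(emb), axis=-1), 0.0))
with tf.Session() as sess:
  print(sess.run(padding))  # [[0. 1.]]
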
@@ -197,13 +197,13 @@ def attention_bias_ignore_padding(memory_padding):
   """Create a bias tensor to be added to attention logits.
 
   Args:
-    memory_padding: a boolean `Tensor` with shape [batch, memory_length].
+    memory_padding: a float `Tensor` with shape [batch, memory_length].
 
   Returns:
     a `Tensor` with shape [batch, 1, 1, memory_length].
   """
-  ret = tf.to_float(memory_padding) * -1e9
-  return tf.expand_dims(tf.expand_dims(ret, 1), 1)
+  ret = memory_padding * -1e9
+  return tf.expand_dims(tf.expand_dims(ret, axis=1), axis=1)
 
 
 def attention_bias_proximal(length):
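
A short sketch (TensorFlow 1.x) of how the float mask becomes an additive bias; the mask values are made up, and the computation mirrors the diffed attention_bias_ignore_padding.

import tensorflow as tf

memory_padding = tf.constant([[0.0, 0.0, 1.0]])  # [batch, memory_length], 1.0 marks padding
bias = tf.expand_dims(tf.expand_dims(memory_padding * -1e9, axis=1), axis=1)
# bias has shape [batch, 1, 1, memory_length]; adding it to attention logits of
# shape [batch, heads, query_length, memory_length] drives the softmax weight
# of padded positions to (nearly) zero.
with tf.Session() as sess:
  print(sess.run(bias))  # roughly [[[[0., 0., -1e9]]]]
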
@@ -523,8 +523,7 @@ def pad_l_and_r(x, pad_length):
     # [batch, heads, blocks, block_length, dim]
     k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
 
-    attention_bias = tf.expand_dims(
-        tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
+    attention_bias = tf.expand_dims(embedding_to_padding(k_new) * -1e9, axis=-2)
 
     v_t = tf.transpose(v, [2, 0, 1, 3])
     v_new = tf.gather(v_t, gather_indices)
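
For the block-local case above, a small shape sketch (TensorFlow 1.x, sizes hypothetical, embedding_to_padding inlined) of the simplified bias construction, now that the helper already returns floats:

import tensorflow as tf

k_new = tf.random_normal([2, 4, 3, 5, 8])  # [batch, heads, blocks, block_length, dim]
padding = tf.to_float(tf.equal(tf.reduce_sum(tf.abs(k_new), axis=-1), 0.0))
attention_bias = tf.expand_dims(padding * -1e9, axis=-2)
print(attention_bias.shape)  # (2, 4, 3, 1, 5), broadcast over the query positions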