@@ -217,6 +217,170 @@ def transformer_encoder(encoder_input,
  return common_layers.layer_preprocess(x, hparams)


def evolved_transformer_encoder(encoder_input,
                                encoder_self_attention_bias,
                                hparams,
                                name="encoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None,
                                attn_bias_for_padding=None):
229+ """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details.
230+
231+ Note: Pad remover is not supported.
232+
233+ Args:
234+ encoder_input: a Tensor.
235+ encoder_self_attention_bias: bias Tensor for self-attention (see
236+ common_attention.attention_bias()).
237+ hparams: hyperparameters for model.
238+ name: a string.
239+ nonpadding: optional Tensor with shape [batch_size, encoder_length]
240+ indicating what positions are not padding. This must either be passed in,
241+ which we do for "packed" datasets, or inferred from
242+ encoder_self_attention_bias. The knowledge about padding is used for
243+ pad_remover(efficiency) and to mask out padding in convolutional layers.
244+ save_weights_to: an optional dictionary to capture attention weights for
245+ visualization; the weights tensor will be appended there under a string
246+ key created from the variable scope (including name).
247+ make_image_summary: Whether to make an attention image summary.
248+ losses: Not used.
249+ attn_bias_for_padding: Padded attention bias in case a unidirectional
250+ encoder is being used where future attention is masked.
251+
252+ Returns:
253+ Tensor encoder output.
254+ """
  del losses

  hidden_state = encoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

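  # When the caller does not supply `nonpadding`, recover the padding mask from
  # the self-attention bias so the convolutional branches below can zero out
  # padded positions explicitly.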
  with tf.variable_scope(name):
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      attention_bias = encoder_self_attention_bias
      if attn_bias_for_padding is not None:
        attention_bias = attn_bias_for_padding
      padding = common_attention.attention_bias_to_padding(attention_bias)
      nonpadding = 1.0 - padding

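    # Each layer applies the Evolved Transformer encoder cell: a gated linear
    # unit, two parallel convolution branches, multi-head self-attention, and a
    # dense feed-forward block, each wrapped in layer_preprocess /
    # layer_postprocess with its own residual connection.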
    for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

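        # Gated linear unit: a sigmoid gate elementwise-scales a linear
        # projection of the normalized input.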
        with tf.variable_scope("gated_linear_unit"):

          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          values = tf.layers.dense(hidden_state, hparams.hidden_size)
          gates = tf.layers.dense(
              hidden_state, hparams.hidden_size, activation=tf.nn.sigmoid)
          hidden_state = values * gates

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

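        # Two parallel branches: a 4x-wide ReLU dense projection and a 3x1
        # convolution at half width. The narrower branch is zero-padded on the
        # channel axis so the two outputs can be summed.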
        with tf.variable_scope("conv_branches"):

          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          # Mask padding from conv layers.
          mask = tf.tile(
              tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
          hidden_state *= mask

          left_output_dim = int(hparams.hidden_size * 4)
          left_state = tf.layers.dense(
              hidden_state, left_output_dim, activation=tf.nn.relu)
          left_state = tf.nn.dropout(left_state,
                                     1 - hparams.layer_prepostprocess_dropout)

          right_output_dim = int(hparams.hidden_size / 2)
          right_state = tf.layers.conv1d(
              hidden_state,
              right_output_dim,
              3,
              padding="SAME",
              name="standard_conv_3x1",
              activation=tf.nn.relu)
          right_state = tf.nn.dropout(right_state,
                                      1 - hparams.layer_prepostprocess_dropout)

          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)
          hidden_state = left_state + right_state

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          # Mask padding from conv layer.
          mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, left_output_dim])
          hidden_state *= mask

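          # A 9x1 depthwise-separable convolution processes the combined branch
          # output at half width, then is zero-padded back up to hidden_size
          # before the residual is added.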
          separable_conv_9x1 = tf.layers.SeparableConv1D(
              right_output_dim, 9, padding="SAME", name="separable_conv_9x1")
          hidden_state = separable_conv_9x1.apply(hidden_state)
          hidden_state = tf.pad(
              hidden_state,
              [[0, 0], [0, 0], [0, hparams.hidden_size - right_output_dim]],
              constant_values=0)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

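        # Standard Transformer multi-head self-attention, reusing the same
        # self-attention bias as the vanilla encoder.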
        with tf.variable_scope("self_attention"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = common_attention.multihead_attention(
              hidden_state,
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

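        # Position-wise feed-forward block: a 4x-wide ReLU layer followed by a
        # projection back down to hidden_size.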
        with tf.variable_scope("dense_layers"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(
              hidden_state, int(hparams.hidden_size * 4), activation=tf.nn.relu)
          hidden_state = tf.nn.dropout(hidden_state,
                                       1 - hparams.layer_prepostprocess_dropout)

          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

    # If normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(hidden_state, hparams)


def transformer_ffn_layer(x,
                          hparams,
                          pad_remover=None,