This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3669442

T2T Team authored and Copybara-Service committed
Allow Evolved Transformer number of decoder attention heads to exceed 16.
PiperOrigin-RevId: 239634336
1 parent ba35f3d commit 3669442

1 file changed: +2 -2 lines changed


tensor2tensor/models/evolved_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -287,15 +287,15 @@ def evolved_transformer_decoder(decoder_input,
         residual_state = hidden_state
         hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

-        # 16 head attention. Hard coding number of heads.
+        # Attention with at least 16 heads.
         left_state = common_attention.multihead_attention(
             hidden_state,
             None,
             decoder_self_attention_bias,
             hparams.attention_key_channels or hparams.hidden_size,
             hparams.attention_value_channels or hparams.hidden_size,
             hparams.hidden_size,
-            16,  # Heads are hard coded to replicate paper.
+            max(16, hparams.num_heads),
             hparams.attention_dropout,
             attention_type=hparams.self_attention_type,
             max_relative_position=hparams.max_relative_position,
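
Below is a minimal sketch, not part of the commit, of how the decoder attention head count behaves after this change, assuming the standard hparams.num_heads field: the paper's 16 heads remain a floor, and larger configurations are no longer clipped to 16.

# Hypothetical helper for illustration only; the commit itself simply passes
# max(16, hparams.num_heads) to common_attention.multihead_attention.
def decoder_num_heads(num_heads):
  # Before this commit the head count was hard coded to 16 to replicate the paper.
  return max(16, num_heads)

print(decoder_num_heads(8))   # 16: small models still use the paper's 16 heads
print(decoder_num_heads(32))  # 32: head counts above 16 now pass through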

0 commit comments
