@@ -438,21 +438,14 @@ def update_hparams_for_universal_transformer(hparams):
438438
439439@registry.register_hparams
440440def universal_transformer_base():
441- hparams = transformer.transformer_base()
442- # To have a similar capacity to the transformer_base with 6 layers,
443- # we need to increase the size of the UT's layer
444- # since, in fact, UT has a single layer repeating multiple times.
445- hparams.hidden_size = 1024
446- hparams.filter_size = 4096
447- hparams.num_heads = 16
448- hparams.layer_prepostprocess_dropout = 0.3
441+ hparams = transformer.transformer_big()
449442 hparams = update_hparams_for_universal_transformer(hparams)
450443 return hparams
451444
452445
453446@registry.register_hparams
454447def universal_transformer_base_tpu():
455- hparams = universal_transformer_base()
448+ hparams = transformer.transformer_big()
456449 hparams = update_hparams_for_universal_transformer(hparams)
457450 transformer.update_hparams_for_tpu(hparams)
458451 hparams.add_step_timing_signal = False
@@ -461,7 +454,7 @@ def universal_transformer_base_tpu():
461454
462455@registry.register_hparams
463456def universal_transformer_big():
464- hparams = universal_transformer_base()
457+ hparams = transformer.transformer_big()
465458 hparams = update_hparams_for_universal_transformer(hparams)
466459 hparams.hidden_size = 2048
467460 hparams.filter_size = 8192
0 commit comments