 from tensor2tensor.utils import adafactor
 from tensor2tensor.utils import mlperf_log
 from tensor2tensor.utils import multistep_optimizer
-from tensor2tensor.utils import yellowfin
 from tensor2tensor.utils import registry
+from tensor2tensor.utils import yellowfin
 
 import tensorflow as tf
 
@@ -94,7 +94,7 @@ def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
   return train_op
 
 
-@registry.register_optimizer
+@registry.register_optimizer("adam")
 def adam(learning_rate, hparams):
   # We change the default epsilon for Adam.
   # Using LazyAdam as it's much faster for large vocabulary embeddings.
@@ -105,7 +105,7 @@ def adam(learning_rate, hparams):
       epsilon=hparams.optimizer_adam_epsilon)
 
 
-@registry.register_optimizer
+@registry.register_optimizer("multistep_adam")
 def multistep_adam(learning_rate, hparams):
   return multistep_optimizer.MultistepAdamOptimizer(
       learning_rate,
@@ -115,22 +115,22 @@ def multistep_adam(learning_rate, hparams):
       n=hparams.optimizer_multistep_accumulate_steps)
 
 
-@registry.register_optimizer
+@registry.register_optimizer("momentum")
 def momentum(learning_rate, hparams):
   return tf.train.MomentumOptimizer(
       learning_rate,
       momentum=hparams.optimizer_momentum_momentum,
       use_nesterov=hparams.optimizer_momentum_nesterov)
 
 
-@registry.register_optimizer
+@registry.register_optimizer("yellow_fin")
 def yellow_fin(learning_rate, hparams):
   return yellowfin.YellowFinOptimizer(
       learning_rate=learning_rate,
       momentum=hparams.optimizer_momentum_momentum)
 
 
-@registry.register_optimizer
+@registry.register_optimizer("true_adam")
 def true_adam(learning_rate, hparams):
   return tf.train.AdamOptimizer(
       learning_rate,
@@ -139,7 +139,7 @@ def true_adam(learning_rate, hparams):
       epsilon=hparams.optimizer_adam_epsilon)
 
 
-@registry.register_optimizer
+@registry.register_optimizer("adam_w")
 def adam_w(learning_rate, hparams):
   # Openai gpt used weight decay.
   # Given the internals of AdamW, weight decay dependent on the
@@ -161,13 +161,15 @@ def register_adafactor(learning_rate, hparams):
   return adafactor.adafactor_optimizer_from_hparams(hparams, learning_rate)
 
 
+
+
 def _register_base_optimizer(key, fn):
   registry.register_optimizer(key)(
       lambda learning_rate, hparams: fn(learning_rate))
 
 
 for k in tf.contrib.layers.OPTIMIZER_CLS_NAMES:
-  if k not in registry._OPTIMIZERS:
+  if k not in registry._OPTIMIZERS:  # pylint: disable=protected-access
     _register_base_optimizer(k, tf.contrib.layers.OPTIMIZER_CLS_NAMES[k])
 
 
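The decorator change above switches register_optimizer from a bare decorator to a decorator factory that takes an explicit key, the same call form _register_base_optimizer already uses via registry.register_optimizer(key)(...). Below is a minimal sketch of that pattern; it is not tensor2tensor's actual registry implementation, only an illustration of the name-keyed decorator, with the optimizer bodies stubbed out.

# Sketch only: a name-keyed registry mirroring the decorator-factory usage
# shown in the diff above; not the real tensor2tensor registry module.
_OPTIMIZERS = {}


def register_optimizer(name):
  """Returns a decorator that registers a function under the explicit name."""
  def decorator(fn):
    if name in _OPTIMIZERS:
      raise KeyError("Optimizer %s already registered." % name)
    _OPTIMIZERS[name] = fn
    return fn
  return decorator


# Decorator-factory form, as used in the edited module.
@register_optimizer("adam")
def adam(learning_rate, hparams):
  del hparams  # unused in this sketch
  return ("LazyAdamOptimizer", learning_rate)  # stand-in for a real optimizer


# Plain call form, as _register_base_optimizer uses for the base TF optimizers.
register_optimizer("sgd")(lambda learning_rate, hparams: ("SGD", learning_rate))

assert sorted(_OPTIMIZERS) == ["adam", "sgd"]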