Commit 89b9b77

jackdkpe authored and committed
internal merge of PR tensorflow#1401
PiperOrigin-RevId: 230778721
1 parent 2ac65f9 commit 89b9b77

File tree

1 file changed: +10 -8 lines

tensor2tensor/utils/optimize.py

Lines changed: 10 additions & 8 deletions
@@ -23,8 +23,8 @@
 from tensor2tensor.utils import adafactor
 from tensor2tensor.utils import mlperf_log
 from tensor2tensor.utils import multistep_optimizer
-from tensor2tensor.utils import yellowfin
 from tensor2tensor.utils import registry
+from tensor2tensor.utils import yellowfin

 import tensorflow as tf

@@ -94,7 +94,7 @@ def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
   return train_op


-@registry.register_optimizer
+@registry.register_optimizer("adam")
 def adam(learning_rate, hparams):
   # We change the default epsilon for Adam.
   # Using LazyAdam as it's much faster for large vocabulary embeddings.
@@ -105,7 +105,7 @@ def adam(learning_rate, hparams):
       epsilon=hparams.optimizer_adam_epsilon)


-@registry.register_optimizer
+@registry.register_optimizer("multistep_adam")
 def multistep_adam(learning_rate, hparams):
   return multistep_optimizer.MultistepAdamOptimizer(
       learning_rate,
@@ -115,22 +115,22 @@ def multistep_adam(learning_rate, hparams):
       n=hparams.optimizer_multistep_accumulate_steps)


-@registry.register_optimizer
+@registry.register_optimizer("momentum")
 def momentum(learning_rate, hparams):
   return tf.train.MomentumOptimizer(
       learning_rate,
       momentum=hparams.optimizer_momentum_momentum,
       use_nesterov=hparams.optimizer_momentum_nesterov)


-@registry.register_optimizer
+@registry.register_optimizer("yellow_fin")
 def yellow_fin(learning_rate, hparams):
   return yellowfin.YellowFinOptimizer(
       learning_rate=learning_rate,
       momentum=hparams.optimizer_momentum_momentum)


-@registry.register_optimizer
+@registry.register_optimizer("true_adam")
 def true_adam(learning_rate, hparams):
   return tf.train.AdamOptimizer(
       learning_rate,
@@ -139,7 +139,7 @@ def true_adam(learning_rate, hparams):
       epsilon=hparams.optimizer_adam_epsilon)


-@registry.register_optimizer
+@registry.register_optimizer("adam_w")
 def adam_w(learning_rate, hparams):
   # Openai gpt used weight decay.
   # Given the internals of AdamW, weight decay dependent on the
@@ -161,13 +161,15 @@ def register_adafactor(learning_rate, hparams):
   return adafactor.adafactor_optimizer_from_hparams(hparams, learning_rate)


+
+
 def _register_base_optimizer(key, fn):
   registry.register_optimizer(key)(
       lambda learning_rate, hparams: fn(learning_rate))


 for k in tf.contrib.layers.OPTIMIZER_CLS_NAMES:
-  if k not in registry._OPTIMIZERS:
+  if k not in registry._OPTIMIZERS:  # pylint: disable=protected-access
     _register_base_optimizer(k, tf.contrib.layers.OPTIMIZER_CLS_NAMES[k])

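The substantive change in this diff is that each optimizer factory is now registered under an explicit string key (e.g. @registry.register_optimizer("adam")) rather than via the bare decorator, and plain TensorFlow optimizer classes are adapted to the common (learning_rate, hparams) signature through a lambda in _register_base_optimizer. The sketch below illustrates that registration pattern in isolation; it is a minimal stand-in, not tensor2tensor's actual registry module, and the dual bare/named decorator behavior, the module-level _OPTIMIZERS dict name, and the FakeSGD class are assumptions inferred from this diff.

# Minimal, self-contained sketch of a decorator-based optimizer registry
# that accepts both the bare form (@register_optimizer) and an explicit
# key (@register_optimizer("adam")), mirroring the pattern in this diff.
# Illustrative only -- not tensor2tensor's real registry module.

_OPTIMIZERS = {}  # key -> factory(learning_rate, hparams) -> optimizer


def register_optimizer(name_or_fn=None):
  def _register(fn, key):
    if key in _OPTIMIZERS:
      raise KeyError("optimizer %r already registered" % key)
    _OPTIMIZERS[key] = fn
    return fn

  if callable(name_or_fn):
    # Bare usage: @register_optimizer -- key defaults to the function name.
    return _register(name_or_fn, name_or_fn.__name__)
  # Named usage: @register_optimizer("explicit_key").
  return lambda fn: _register(fn, name_or_fn or fn.__name__)


@register_optimizer("momentum")
def momentum(learning_rate, hparams):
  # Stand-in body; the real function returns tf.train.MomentumOptimizer.
  return ("MomentumOptimizer", learning_rate, hparams)


def _register_base_optimizer(key, opt_cls):
  # Adapt a constructor that only takes learning_rate to the shared
  # (learning_rate, hparams) signature, as the diff does for the entries
  # of tf.contrib.layers.OPTIMIZER_CLS_NAMES.
  register_optimizer(key)(
      lambda learning_rate, hparams: opt_cls(learning_rate))


class FakeSGD(object):
  """Placeholder for an optimizer class that takes only a learning rate."""

  def __init__(self, learning_rate):
    self.learning_rate = learning_rate


_register_base_optimizer("SGD", FakeSGD)

# Factories are looked up by their explicit string key.
print(_OPTIMIZERS["momentum"](0.1, hparams=None))
print(_OPTIMIZERS["SGD"](0.1, hparams=None).learning_rate)

In tensor2tensor itself the registered key is presumably what hparams.optimizer points at; with explicit keys, renaming one of these functions no longer silently changes the name the registry exposes.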