This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
126 changes: 79 additions & 47 deletions tensor2tensor/utils/optimize.py
@@ -24,6 +24,7 @@
from tensor2tensor.utils import mlperf_log
from tensor2tensor.utils import multistep_optimizer
from tensor2tensor.utils import yellowfin
from tensor2tensor.utils import registry

import tensorflow as tf

@@ -93,6 +94,83 @@ def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
return train_op


@registry.register_optimizer
def adam(learning_rate, hparams):
# We change the default epsilon for Adam.
# Using LazyAdam as it's much faster for large vocabulary embeddings.
return tf.contrib.opt.LazyAdamOptimizer(
learning_rate,
beta1=hparams.optimizer_adam_beta1,
beta2=hparams.optimizer_adam_beta2,
epsilon=hparams.optimizer_adam_epsilon)


@registry.register_optimizer
def multistep_adam(learning_rate, hparams):
return multistep_optimizer.MultistepAdamOptimizer(
learning_rate,
beta1=hparams.optimizer_adam_beta1,
beta2=hparams.optimizer_adam_beta2,
epsilon=hparams.optimizer_adam_epsilon,
n=hparams.optimizer_multistep_accumulate_steps)
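
MultistepAdamOptimizer takes the extra n argument for gradient accumulation: roughly, gradients from n successive mini-batches are combined and a single Adam update is applied, emulating an n-times-larger batch. A rough sketch of that accumulation pattern, with hypothetical names; this is an illustration, not the MultistepAdamOptimizer implementation, which does the equivalent inside apply_gradients:

```python
# Rough sketch of n-step gradient accumulation; names are hypothetical.
def accumulate_then_step(grads_per_batch, apply_adam_update, n=4):
  """Apply one Adam update per n mini-batches, using the mean gradient."""
  accum, seen = None, 0
  for g in grads_per_batch:
    accum = g if accum is None else accum + g   # sum gradients across batches
    seen += 1
    if seen == n:                               # every n batches...
      apply_adam_update(accum / n)              # ...step once with the mean
      accum, seen = None, 0
```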


@registry.register_optimizer
def momentum(learning_rate, hparams):
return tf.train.MomentumOptimizer(
learning_rate,
momentum=hparams.optimizer_momentum_momentum,
use_nesterov=hparams.optimizer_momentum_nesterov)


@registry.register_optimizer
def yellow_fin(learning_rate, hparams):
return yellowfin.YellowFinOptimizer(
learning_rate=learning_rate,
momentum=hparams.optimizer_momentum_momentum)


@registry.register_optimizer
def true_adam(learning_rate, hparams):
return tf.train.AdamOptimizer(
learning_rate,
beta1=hparams.optimizer_adam_beta1,
beta2=hparams.optimizer_adam_beta2,
epsilon=hparams.optimizer_adam_epsilon)


@registry.register_optimizer
def adam_w(learning_rate, hparams):
# OpenAI GPT used weight decay.
# Given the internals of AdamW, a weight decay dependent on the
# learning rate is chosen to match the OpenAI implementation.
# The weight-decay update to each parameter is applied before the Adam
# gradient computation, which differs from what is described in the
# paper and in the OpenAI implementation:
# https://arxiv.org/pdf/1711.05101.pdf
return tf.contrib.opt.AdamWOptimizer(
0.01*learning_rate,
learning_rate,
beta1=hparams.optimizer_adam_beta1,
beta2=hparams.optimizer_adam_beta2,
epsilon=hparams.optimizer_adam_epsilon)
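
The 0.01 * learning_rate passed as the weight-decay rate above ties the decay strength to the learning-rate schedule, as the comment notes. A simplified sketch of what that coupling means for a single parameter; adamw_step_sketch and its arguments are hypothetical, and this is an illustration rather than the AdamWOptimizer internals:

```python
# Simplified sketch of a learning-rate-coupled, decoupled weight-decay step.
# Only the 0.01 factor and the decay-before-Adam ordering come from the code
# above; everything else is a stand-in.
def adamw_step_sketch(param, adam_update, learning_rate, decay_coeff=0.01):
  weight_decay = decay_coeff * learning_rate  # decay shrinks as the LR decays
  param = param - weight_decay * param        # weight decay applied first...
  param = param - adam_update                 # ...then the Adam update
  return param
```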


@registry.register_optimizer("Adafactor")
def register_adafactor(learning_rate, hparams):
return adafactor.adafactor_optimizer_from_hparams(hparams, learning_rate)
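
With these registrations in place, the optimizer choice in ConditionalOptimizer below reduces to a registry lookup, and new optimizers can be added without touching it. A minimal sketch of registering and resolving a custom optimizer, assuming a TF1-style environment like the rest of this module; my_sgd / "MySgd" and the HParams stand-in are hypothetical:

```python
# Sketch only: "my_sgd" is a made-up optimizer, not part of tensor2tensor.
import tensorflow as tf
from tensor2tensor.utils import registry


@registry.register_optimizer  # registered as "MySgd" (CamelCase of the fn name)
def my_sgd(learning_rate, hparams):
  return tf.train.GradientDescentOptimizer(learning_rate)


hparams = tf.contrib.training.HParams()  # stand-in; real runs pass model hparams
opt_fn = registry.optimizer("MySgd")     # raises LookupError if never registered
opt = opt_fn(0.01, hparams)
```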


def _register_base_optimizer(key, fn):
registry.register_optimizer(key)(
lambda learning_rate, hparams: fn(learning_rate))


for k in tf.contrib.layers.OPTIMIZER_CLS_NAMES:
if k not in registry._OPTIMIZERS:
_register_base_optimizer(k, tf.contrib.layers.OPTIMIZER_CLS_NAMES[k])
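
The _register_base_optimizer helper, rather than an inline lambda in the loop above, presumably sidesteps Python's late-binding closures: a lambda closing over the loop variable would register every name against the last entry of OPTIMIZER_CLS_NAMES. A hypothetical illustration of that pitfall, unrelated to TensorFlow:

```python
# Hypothetical demonstration of late-binding closures; not tensor2tensor code.
callbacks = []
for factor in [1, 2, 3]:
  callbacks.append(lambda x: x * factor)  # all three lambdas share 'factor'
print([cb(10) for cb in callbacks])       # [30, 30, 30], not [10, 20, 30]


def make_callback(factor):                # a helper binds 'factor' per call,
  return lambda x: x * factor             # mirroring _register_base_optimizer
callbacks = [make_callback(f) for f in [1, 2, 3]]
print([cb(10) for cb in callbacks])       # [10, 20, 30]
```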


class ConditionalOptimizer(tf.train.Optimizer):
"""Conditional optimizer."""

@@ -113,53 +191,7 @@ def __init__(self, optimizer_name, lr, hparams, use_tpu=False):  # pylint: disab
value=hparams.optimizer_adam_epsilon,
hparams=hparams)

if optimizer_name == "Adam":
# We change the default epsilon for Adam.
# Using LazyAdam as it's much faster for large vocabulary embeddings.
self._opt = tf.contrib.opt.LazyAdamOptimizer(
lr,
beta1=hparams.optimizer_adam_beta1,
beta2=hparams.optimizer_adam_beta2,
epsilon=hparams.optimizer_adam_epsilon)
elif optimizer_name == "MultistepAdam":
self._opt = multistep_optimizer.MultistepAdamOptimizer(
lr,
beta1=hparams.optimizer_adam_beta1,
beta2=hparams.optimizer_adam_beta2,
epsilon=hparams.optimizer_adam_epsilon,
n=hparams.optimizer_multistep_accumulate_steps)
elif optimizer_name == "Momentum":
self._opt = tf.train.MomentumOptimizer(
lr,
momentum=hparams.optimizer_momentum_momentum,
use_nesterov=hparams.optimizer_momentum_nesterov)
elif optimizer_name == "YellowFin":
self._opt = yellowfin.YellowFinOptimizer(
learning_rate=lr, momentum=hparams.optimizer_momentum_momentum)
elif optimizer_name == "TrueAdam":
self._opt = tf.train.AdamOptimizer(
lr,
beta1=hparams.optimizer_adam_beta1,
beta2=hparams.optimizer_adam_beta2,
epsilon=hparams.optimizer_adam_epsilon)
elif optimizer_name == "AdamW":
# Openai gpt used weight decay.
# Given the internals of AdamW, weight decay dependent on the
# learning rate is chosen to match the openai implementation.
# The weight decay update to each parameter is applied before the adam
# gradients computation, which is different from that described
# in the paper and in the openai implementation:
# https://arxiv.org/pdf/1711.05101.pdf
self._opt = tf.contrib.opt.AdamWOptimizer(
0.01*lr,
lr,
beta1=hparams.optimizer_adam_beta1,
beta2=hparams.optimizer_adam_beta2,
epsilon=hparams.optimizer_adam_epsilon)
elif optimizer_name == "Adafactor":
self._opt = adafactor.adafactor_optimizer_from_hparams(hparams, lr)
else:
self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr)
self._opt = registry.optimizer(optimizer_name)(lr, hparams)
if _mixed_precision_is_enabled(hparams):
if not hparams.mixed_precision_optimizer_loss_scaler:
tf.logging.warning("Using mixed precision without a loss scaler will "
45 changes: 45 additions & 0 deletions tensor2tensor/utils/registry.py
@@ -119,6 +119,51 @@ def list_models():
return list(sorted(_MODELS))


_OPTIMIZERS = {}


def register_optimizer(name=None):
"""Register an optimizer. name defaults to upper camel case of fn name."""

def default_opt_name(opt_fn):
return misc_utils.snakecase_to_camelcase(default_name(opt_fn))

def decorator(opt_fn, registration_name):
"""Registers and returns optimizer_fn with registration_name or default."""
if registration_name is None:
registration_name = default_opt_name(opt_fn)

if registration_name in _OPTIMIZERS and not tf.executing_eagerly():
raise LookupError("Optimizer %s already registered." % registration_name)
args, varargs, keywords, _ = inspect.getargspec(opt_fn)

if len(args) != 2 or varargs is not None or keywords is not None:
raise ValueError("Optimizer registration function must take two "
"arguments: learning_rate (float) and "
"hparams (HParams).")
_OPTIMIZERS[registration_name] = opt_fn
return opt_fn

if callable(name):
opt_fn = name
registration_name = default_opt_name(opt_fn)
return decorator(opt_fn, registration_name=registration_name)

return lambda opt_fn: decorator(opt_fn, name)
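
Used bare, the decorator derives the registration name from the function name (snake_case converted to CamelCase); called with a string, that name is used instead, as the "Adafactor" registration in optimize.py does. A short sketch of both forms, assuming it runs in this module's namespace; the optimizer names below are made up:

```python
# Hypothetical registrations illustrating the two call forms.
@register_optimizer                # bare decorator: registered as "MyAdam"
def my_adam(learning_rate, hparams):
  return tf.train.AdamOptimizer(learning_rate)


@register_optimizer("CustomAdam")  # explicit registration name
def another_adam(learning_rate, hparams):
  return tf.train.AdamOptimizer(learning_rate)


assert optimizer("MyAdam") is my_adam
assert "CustomAdam" in list_optimizers()
```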


def optimizer(name):
if name not in _OPTIMIZERS:
raise LookupError("Optimizer %s never registered. "
"Available optimizers:\n %s"
% (name, "\n".join(list_optimizers())))
return _OPTIMIZERS[name]


def list_optimizers():
return list(sorted(_OPTIMIZERS))


def register_hparams(name=None):
"""Register an HParams set. name defaults to function name snake-cased."""
