This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
4 changes: 4 additions & 0 deletions .gitignore
@@ -3,3 +3,7 @@

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info

# PyPI distribution artifacts
build/
dist/
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.0.6',
version='1.0.7',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/generator_utils.py
@@ -22,6 +22,7 @@
import io
import os
import tarfile
import urllib

# Dependency imports

7 changes: 1 addition & 6 deletions tensor2tensor/data_generators/image.py
@@ -29,10 +29,9 @@
# Dependency imports

import numpy as np
from six.moves import cPickle
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
from six.moves import cPickle

from tensor2tensor.data_generators import generator_utils

import tensorflow as tf
@@ -201,10 +200,6 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0):
])
labels = data["labels"]
all_labels.extend([labels[j] for j in xrange(num_images)])
# Shuffle the data to make sure classes are well distributed.
data = zip(all_images, all_labels)
random.shuffle(data)
all_images, all_labels = zip(*data)
return image_generator(all_images[start_from:start_from + how_many],
all_labels[start_from:start_from + how_many])

8 changes: 5 additions & 3 deletions tensor2tensor/data_generators/text_encoder.py
mode change 100755 → 100644
@@ -23,11 +23,12 @@
from __future__ import division
from __future__ import print_function

from collections import defaultdict

# Dependency imports

import six
from six.moves import xrange # pylint: disable=redefined-builtin
from collections import defaultdict
from tensor2tensor.data_generators import tokenizer

import tensorflow as tf
@@ -41,6 +42,7 @@
else:
RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]


class TextEncoder(object):
"""Base class for converting from ints to/from human readable strings."""

@@ -95,7 +97,7 @@ def encode(self, s):
if six.PY2:
return [ord(c) + numres for c in s]
# Python3: explicitly convert to UTF-8
return [c + numres for c in s.encode("utf-8")]
return [c + numres for c in s.encode('utf-8')]

def decode(self, ids):
numres = self._num_reserved_ids
@@ -109,7 +111,7 @@ def decode(self, ids):
if six.PY2:
return ''.join(decoded_ids)
# Python3: join byte arrays and then decode string
return b''.join(decoded_ids).decode("utf-8")
return b''.join(decoded_ids).decode('utf-8')

@property
def vocab_size(self):
3 changes: 2 additions & 1 deletion tensor2tensor/data_generators/tokenizer.py
mode change 100755 → 100644
@@ -45,12 +45,13 @@
from __future__ import division
from __future__ import print_function

from collections import defaultdict
import string

# Dependency imports

from six.moves import xrange # pylint: disable=redefined-builtin
from collections import defaultdict


class Tokenizer(object):
"""Vocab for breaking words into wordpieces.
150 changes: 150 additions & 0 deletions tensor2tensor/models/bluenet.py
@@ -0,0 +1,150 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BlueNet: and out of the blue network to experiment with shake-shake."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from six.moves import xrange # pylint: disable=redefined-builtin

from tensor2tensor.models import common_hparams
from tensor2tensor.models import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


def residual_module(x, hparams, train, n, sep):
"""A stack of convolution blocks with residual connection."""
k = (hparams.kernel_height, hparams.kernel_width)
dilations_and_kernels = [((1, 1), k) for _ in xrange(n)]
with tf.variable_scope("residual_module%d_sep%d" % (n, sep)):
y = common_layers.subseparable_conv_block(
x,
hparams.hidden_size,
dilations_and_kernels,
padding="SAME",
separability=sep,
name="block")
x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm")
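# Dropout is gated on the train flag: keep_prob is 1.0 at eval time.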
return tf.nn.dropout(x, 1.0 - hparams.dropout * tf.to_float(train))


def residual_module1(x, hparams, train):
return residual_module(x, hparams, train, 1, 1)


def residual_module1_sep(x, hparams, train):
return residual_module(x, hparams, train, 1, 0)


def residual_module2(x, hparams, train):
return residual_module(x, hparams, train, 2, 1)


def residual_module2_sep(x, hparams, train):
return residual_module(x, hparams, train, 2, 0)


def residual_module3(x, hparams, train):
return residual_module(x, hparams, train, 3, 1)


def residual_module3_sep(x, hparams, train):
return residual_module(x, hparams, train, 3, 0)


def norm_module(x, hparams, train):
del train # Unused.
return common_layers.layer_norm(x, hparams.hidden_size, name="norm_module")


def identity_module(x, hparams, train):
del hparams, train # Unused.
return x


def run_modules(blocks, cur, hparams, train, dp):
"""Run blocks in parallel using dp as data_parallelism."""
assert len(blocks) % dp.n == 0
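# Each dp() call runs one group of dp.n modules in parallel, one per shard.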
res = []
for i in xrange(len(blocks) // dp.n):
res.extend(dp(blocks[i * dp.n:(i + 1) * dp.n], cur, hparams, train))
return res


@registry.register_model
class BlueNet(t2t_model.T2TModel):

def model_fn_body_sharded(self, sharded_features, train):
dp = self._data_parallelism
dp._reuse = False # pylint:disable=protected-access
hparams = self._hparams
blocks = [identity_module, norm_module,
residual_module1, residual_module1_sep,
residual_module2, residual_module2_sep,
residual_module3, residual_module3_sep]
inputs = sharded_features["inputs"]

cur = tf.concat(inputs, axis=0)
cur_shape = cur.get_shape()
for i in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % i):
processed = run_modules(blocks, cur, hparams, train, dp)
cur = common_layers.shakeshake(processed)
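# The combined tensor loses its static shape through the Defun-based
# shake-shake op; restore it for the next layer.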
cur.set_shape(cur_shape)

return list(tf.split(cur, len(inputs), axis=0)), 0.0


@registry.register_hparams
def bluenet_base():
"""Set of hyperparameters."""
hparams = common_hparams.basic_params1()
hparams.batch_size = 4096
hparams.hidden_size = 768
hparams.dropout = 0.2
hparams.symbol_dropout = 0.2
hparams.label_smoothing = 0.1
hparams.clip_grad_norm = 2.0
hparams.num_hidden_layers = 8
hparams.kernel_height = 3
hparams.kernel_width = 3
hparams.learning_rate_decay_scheme = "exp50k"
hparams.learning_rate = 0.05
hparams.learning_rate_warmup_steps = 3000
hparams.initializer_gain = 1.0
hparams.weight_decay = 3.0
hparams.num_sampled_classes = 0
hparams.sampling_method = "argmax"
hparams.optimizer_adam_epsilon = 1e-6
hparams.optimizer_adam_beta1 = 0.85
hparams.optimizer_adam_beta2 = 0.997
hparams.add_hparam("imagenet_use_2d", True)
return hparams


@registry.register_hparams
def bluenet_tiny():
hparams = bluenet_base()
hparams.batch_size = 1024
hparams.hidden_size = 128
hparams.num_hidden_layers = 4
hparams.learning_rate_decay_scheme = "none"
return hparams
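
For orientation, here is a minimal numpy sketch (illustrative only, not part of this commit; the shard count and shapes are made up) of the concat/split shape contract that model_fn_body_sharded above relies on:

import numpy as np

num_shards = 2
shards = [np.ones((3, 5, 1, 128)) for _ in range(num_shards)]

# Like tf.concat(inputs, axis=0): merge shards along the batch dimension.
merged = np.concatenate(shards, axis=0)  # shape (6, 5, 1, 128)
# ... the shake-shake layers would transform the merged batch here ...
# Like tf.split(cur, len(inputs), axis=0): restore one piece per shard.
outputs = np.split(merged, num_shards, axis=0)
assert outputs[0].shape == shards[0].shape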
54 changes: 54 additions & 0 deletions tensor2tensor/models/bluenet_test.py
@@ -0,0 +1,54 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BlueNet tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np

from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import bluenet

import tensorflow as tf


class BlueNetTest(tf.test.TestCase):

def testBlueNet(self):
vocab_size = 9
x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1))
y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1))
hparams = bluenet.bluenet_tiny()
p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size,
vocab_size)
with self.test_session() as session:
features = {
"inputs": tf.constant(x, dtype=tf.int32),
"targets": tf.constant(y, dtype=tf.int32),
}
model = bluenet.BlueNet(hparams, p_hparams)
sharded_logits, _, _ = model.model_fn(features, True)
logits = tf.concat(sharded_logits, 0)
session.run(tf.global_variables_initializer())
res = session.run(logits)
self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))


if __name__ == "__main__":
tf.test.main()
46 changes: 46 additions & 0 deletions tensor2tensor/models/common_layers.py
@@ -58,6 +58,52 @@ def inverse_exp_decay(max_step, min_value=0.01):
return inv_base**tf.maximum(float(max_step) - step, 0.0)


def shakeshake2_py(x, y, equal=False):
"""The shake-shake sum of 2 tensors, python version."""
alpha = 0.5 if equal else tf.random_uniform([])
return alpha * x + (1.0 - alpha) * y


@function.Defun()
def shakeshake2_grad(x1, x2, dy):
"""Overriding gradient for shake-shake of 2 tensors."""
y = shakeshake2_py(x1, x2)
dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
return dx


@function.Defun()
def shakeshake2_equal_grad(x1, x2, dy):
"""Overriding gradient for shake-shake of 2 tensors."""
y = shakeshake2_py(x1, x2, equal=True)
dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
return dx


@function.Defun(grad_func=shakeshake2_grad)
def shakeshake2(x1, x2):
"""The shake-shake function with a different alpha for forward/backward."""
return shakeshake2_py(x1, x2)


@function.Defun(grad_func=shakeshake2_equal_grad)
def shakeshake2_eqgrad(x1, x2):
"""The shake-shake function with a different alpha for forward/backward."""
return shakeshake2_py(x1, x2)


def shakeshake(xs, equal_grad=False):
"""Multi-argument shake-shake, currently approximated by sums of 2."""
if len(xs) == 1:
return xs[0]
div = (len(xs) + 1) // 2
arg1 = shakeshake(xs[:div], equal_grad=equal_grad)
arg2 = shakeshake(xs[div:], equal_grad=equal_grad)
if equal_grad:
return shakeshake2_eqgrad(arg1, arg2)
return shakeshake2(arg1, arg2)


def standardize_images(x):
"""Image standardization on batches (tf.image.per_image_standardization)."""
with tf.name_scope("standardize_images", [x]):
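
shakeshake reduces an arbitrary list of tensors by splitting it in half, recursing on each half, and combining the two partial results with shakeshake2, so every pairwise combination draws its own alpha. The Defun gradient overrides additionally give the backward pass its own alpha (a fresh random draw, or 0.5 for the _eqgrad variant). A minimal numpy sketch of the forward recursion (shakeshake_np is a hypothetical helper, not part of this commit, and it does not model the backward-pass resampling):

import numpy as np

def shakeshake_np(xs, equal=False):
  # Mirrors common_layers.shakeshake: split the list, recurse on each half,
  # then take a random convex combination of the two partial results.
  if len(xs) == 1:
    return xs[0]
  div = (len(xs) + 1) // 2
  arg1 = shakeshake_np(xs[:div], equal)
  arg2 = shakeshake_np(xs[div:], equal)
  alpha = 0.5 if equal else np.random.uniform()
  return alpha * arg1 + (1.0 - alpha) * arg2

x = np.ones((2, 3))
print(shakeshake_np([x, 2 * x, 3 * x]))  # a random convex mix of the inputs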
9 changes: 9 additions & 0 deletions tensor2tensor/models/common_layers_test.py
@@ -65,6 +65,15 @@ def testEmbedding(self):
res = session.run(y)
self.assertEqual(res.shape, (3, 5, 16))

def testShakeShake(self):
x = np.random.rand(5, 7)
with self.test_session() as session:
x = tf.constant(x, dtype=tf.float32)
y = common_layers.shakeshake([x, x, x, x, x])
session.run(tf.global_variables_initializer())
inp, res = session.run([x, y])
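# Every input is the same tensor, so any convex combination equals x.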
self.assertAllClose(res, inp)

def testConv(self):
x = np.random.rand(5, 7, 1, 11)
with self.test_session() as session:
1 change: 1 addition & 0 deletions tensor2tensor/models/models.py
@@ -24,6 +24,7 @@

from tensor2tensor.models import attention_lm
from tensor2tensor.models import attention_lm_moe
from tensor2tensor.models import bluenet
from tensor2tensor.models import bytenet
from tensor2tensor.models import lstm
from tensor2tensor.models import modalities
10 changes: 10 additions & 0 deletions tensor2tensor/models/xception.py
@@ -87,3 +87,13 @@ def xception_base():
hparams.optimizer_adam_beta2 = 0.997
hparams.add_hparam("imagenet_use_2d", True)
return hparams


@registry.register_hparams
def xception_tiny():
hparams = xception_base()
hparams.batch_size = 1024
hparams.hidden_size = 128
hparams.num_hidden_layers = 4
hparams.learning_rate_decay_scheme = "none"
return hparams