
Commit c06cf3f

Lukasz Kaiser authored and copybara-github committed
Make Reformer config smaller and allow to use a single rng for all steps.
PiperOrigin-RevId: 268567430
1 parent 35639ad · commit c06cf3f

File tree: 2 files changed, +44 -14 lines

tensor2tensor/trax/configs/transformer_revnet_imagenet64_8gb.gin (19 additions, 10 deletions)
@@ -5,8 +5,8 @@ import tensor2tensor.trax.trax
 
 # Parameters for batch_fun:
 # ==============================================================================
-batch_fun.batch_size_per_device = 8
-batch_fun.eval_batch_size = 128
+batch_fun.batch_size_per_device = 2
+batch_fun.eval_batch_size = 16
 batch_fun.max_eval_length = 12288  # 64 * 64 * 3
 
 # Parameters for inputs:
@@ -41,24 +41,33 @@ DotProductCausalAttention.dropout = 0.0
 MemoryEfficientCausalAttention.dropout = 0.0
 MemoryEfficientCausalAttention.loop_stride = 512
 
-# Parameters for DummyHashedAttention:
+# Parameters for MergedHashedCausalAttention:
 # ==============================================================================
-# DummyHashedAttention.dropout = 0.0
-# DummyHashedAttention.n_bins = 64
+MergedHashedCausalAttention.dropout = 0.0
+MergedHashedCausalAttention.n_bins = 32
+MergedHashedCausalAttention.bin_by_time = True
+MergedHashedCausalAttention.one_rng = False
+
+# Parameters for MergedMultiHashedCausalAttention:
+# ==============================================================================
+MergedMultiHashedCausalAttention.dropout = 0.0
+MergedMultiHashedCausalAttention.n_bins = 64
+MergedMultiHashedCausalAttention.n_hashes = 4
+MergedMultiHashedCausalAttention.bin_by_time = False
+MergedHashedCausalAttention.one_rng = True
 
 # Parameters for TransformerRevnetLM:
 # ==============================================================================
 TransformerRevnetLM.d_model = 1024
 TransformerRevnetLM.d_ff = 2048
-TransformerRevnetLM.d_attention_key = 32
-TransformerRevnetLM.d_attention_value = 32
+TransformerRevnetLM.d_attention_key = 64
+TransformerRevnetLM.d_attention_value = 64
 TransformerRevnetLM.dropout = 0.0
 TransformerRevnetLM.max_len = 12288  # 64 * 64 * 3
 TransformerRevnetLM.mode = 'train'
 TransformerRevnetLM.n_heads = 4
-TransformerRevnetLM.n_layers = 6
+TransformerRevnetLM.n_layers = 4
 TransformerRevnetLM.vocab_size = 256
 TransformerRevnetLM.n_chunks = 16
 TransformerRevnetLM.n_attention_chunks = 1
-TransformerRevnetLM.attention_type = @trax.layers.MemoryEfficientCausalAttention
-
+TransformerRevnetLM.attention_type = @trax.layers.MergedMultiHashedCausalAttention
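For readers unfamiliar with gin, the bindings above become keyword arguments injected when the layer is constructed. The snippet below is a minimal sketch (not the trax code) of that mechanism, using a stand-in class with the same constructor keywords as MergedMultiHashedCausalAttention; the defaults given to dropout and mode are assumptions made only so the stand-in can be built on its own, and the one_rng binding is included just to show how the new flag is switched on.

# A minimal sketch of gin binding, using a stand-in class rather than the
# actual trax layer.
import gin


@gin.configurable
class MergedMultiHashedCausalAttention(object):
  """Stand-in that records the gin-supplied keyword arguments."""

  def __init__(self, dropout=0.0, mode='train', n_bins=64, n_hashes=1,
               bin_by_time=False, one_rng=False):
    self.config = dict(dropout=dropout, mode=mode, n_bins=n_bins,
                       n_hashes=n_hashes, bin_by_time=bin_by_time,
                       one_rng=one_rng)


# Same binding syntax as the .gin file above.
gin.parse_config("""
MergedMultiHashedCausalAttention.n_bins = 64
MergedMultiHashedCausalAttention.n_hashes = 4
MergedMultiHashedCausalAttention.bin_by_time = False
MergedMultiHashedCausalAttention.one_rng = True
""")

layer = MergedMultiHashedCausalAttention()
print(layer.config['n_hashes'], layer.config['one_rng'])  # 4 True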

tensor2tensor/trax/layers/attention.py (25 additions, 4 deletions)
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import random
 import jax
 import numpy as onp
 
@@ -559,11 +560,17 @@ def body_fun(vals):  # pylint: disable=invalid-name
 class MergedHashedCausalAttention(BaseCausalAttention):
   """Hash-based causal attention."""
 
-  def __init__(self, dropout, mode, n_bins=64, bin_by_time=False):
+  def __init__(self, dropout, mode, n_bins=64,
+               bin_by_time=False, one_rng=False):
     del dropout, mode
     super(MergedHashedCausalAttention, self).__init__()
     self.n_bins = n_bins
     self.bin_by_time = bin_by_time
+    seed = random.randint(0, 2**31 - 1)
+    self._one_rng = one_rng
+    self._prng = None
+    if one_rng:
+      self._prng = backend.random.get_prng(seed)
 
   def call(self, inputs, params=(), state=(), **kwargs):
     del params
@@ -604,8 +611,12 @@ def hash_vectors(self, vecs, rng):
     # It's not clear whether sampling a different random rotation for each head
     # and batch element matters here, but see MergedMultiHashedCausalAttention.
     assert self.n_bins % 2 == 0
+    rot_rng = rng
+    if self._one_rng:
+      rot_rng = jax.lax.tie_in(vecs, self._prng)
     random_rotation = jax.random.normal(
-        rng, (vecs.shape[0], vecs.shape[-1], self.n_bins//2)).astype('float32')
+        rot_rng,
+        (vecs.shape[0], vecs.shape[-1], self.n_bins//2)).astype('float32')
 
     # TODO(kitaev): making the vectors unit-length here is probably redundant.
     vecs = self.make_unit_length(vecs)
@@ -735,12 +746,18 @@ def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
 class MergedMultiHashedCausalAttention(BaseCausalAttention):
   """Hash-based causal attention, with multiple hashes."""
 
-  def __init__(self, dropout, mode, n_bins=64, n_hashes=1, bin_by_time=False):
+  def __init__(self, dropout, mode, n_bins=64, n_hashes=1,
+               bin_by_time=False, one_rng=False):
     del dropout, mode
     super(MergedMultiHashedCausalAttention, self).__init__()
     self.n_bins = n_bins
     self.n_hashes = n_hashes
     self.bin_by_time = bin_by_time
+    seed = random.randint(0, 2**31 - 1)
+    self._one_rng = one_rng
+    self._prng = None
+    if one_rng:
+      self._prng = backend.random.get_prng(seed)
 
   def bin_vectors_by_time(self, vecs):
     seqlen = vecs.shape[-2]
@@ -770,8 +787,12 @@ def hash_vectors(self, vecs, rng):
    # of vecs. Applying multiple hashes to the same input is important because
    # it increases the probability of being in the same bin as relevant items.
     assert self.n_bins % 2 == 0
+    rot_rng = rng
+    if self._one_rng:
+      rot_rng = jax.lax.tie_in(vecs, self._prng)
     random_rotation = jax.random.normal(
-        rng, (vecs.shape[0], vecs.shape[-1], self.n_bins//2)).astype('float32')
+        rot_rng,
+        (vecs.shape[0], vecs.shape[-1], self.n_bins//2)).astype('float32')
 
     # TODO(kitaev): making the vectors unit-length here is probably redundant.
     vecs = self.make_unit_length(vecs)
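The behavioral change behind one_rng: by default, hash_vectors draws its random rotation from the per-step rng, so the hash buckets are re-randomized every step; with one_rng=True the rotation comes from a single PRNG key fixed at construction (tied into the computation with jax.lax.tie_in in the diff above), so every step hashes with the same rotation. The sketch below reproduces that difference outside trax; the function name, shapes, and n_bins value are illustrative assumptions, and the tie_in call is omitted.

# A standalone sketch of the one_rng behavior, not the trax implementation.
import jax
import jax.numpy as jnp


def random_rotation(vecs, step_rng, one_rng=False, fixed_rng=None):
  # vecs: [batch, seqlen, d]; rotation: [batch, d, n_bins // 2].
  n_bins = 64
  rot_rng = fixed_rng if one_rng else step_rng
  return jax.random.normal(
      rot_rng, (vecs.shape[0], vecs.shape[-1], n_bins // 2)).astype('float32')


vecs = jnp.ones((1, 8, 16))
fixed = jax.random.PRNGKey(0)
step1, step2 = jax.random.PRNGKey(1), jax.random.PRNGKey(2)

# Default: a fresh rotation per step.
r1 = random_rotation(vecs, step1)
r2 = random_rotation(vecs, step2)
print(bool(jnp.allclose(r1, r2)))  # False (with overwhelming probability)

# one_rng=True: the same rotation for all steps.
s1 = random_rotation(vecs, step1, one_rng=True, fixed_rng=fixed)
s2 = random_rotation(vecs, step2, one_rng=True, fixed_rng=fixed)
print(bool(jnp.allclose(s1, s2)))  # True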
