Skip to content

Commit a96a634

Browse files
nshazeer authored and kpe committed
Code for Multiquery attention paper
PiperOrigin-RevId: 231286854
1 parent 871476a commit a96a634

File tree

2 files changed

+219
-0
lines changed

2 files changed

+219
-0
lines changed

tensor2tensor/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from tensor2tensor.models.research import glow
5353
from tensor2tensor.models.research import lm_experiments
5454
from tensor2tensor.models.research import moe_experiments
55+
from tensor2tensor.models.research import multiquery_paper
5556
from tensor2tensor.models.research import rl
5657
from tensor2tensor.models.research import similarity_transformer
5758
from tensor2tensor.models.research import super_lm
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Experiments for Multiquery-Attention Paper.
17+
"""
18+
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
from tensor2tensor.models import mtf_transformer2
24+
from tensor2tensor.utils import registry
25+
26+
27+
@registry.register_hparams
def mqp_ende_base():
  """Baseline hparams for the `mqp_ende_*` experiment series.

  Starts from `mtf_transformer2.mtr_tr_dense_0()` and overrides the
  learning-rate decay steps, the embedding/softmax weight sharing flag and
  the layer pre/post-process dropout rate.

  Returns:
    The hparams object with the overrides applied.
  """
  # Original author's note: params=211M
  base = mtf_transformer2.mtr_tr_dense_0()
  base.learning_rate_decay_steps = 20000
  base.shared_embedding_and_softmax_weights = True
  base.layer_prepostprocess_dropout = 0.2
  return base
35+
36+
37+
@registry.register_hparams
def mqp_ende_local():
  """`mqp_ende_base` with the decoder local-attention radius set to 32.

  Returns:
    The hparams object with the override applied.
  """
  config = mqp_ende_base()
  config.decoder_local_attention_radius = 32
  return config
42+
43+
44+
@registry.register_hparams
def mqp_ende_mq8():
  """`mqp_ende_base` with 8 heads and 1 memory head in encoder and decoder.

  Returns:
    The hparams object with the overrides applied.
  """
  # Original author's note: params=178M
  config = mqp_ende_base()
  # Same four overrides as before, applied in the same order.
  overrides = (
      ("decoder_num_heads", 8),
      ("decoder_num_memory_heads", 1),
      ("encoder_num_heads", 8),
      ("encoder_num_memory_heads", 1),
  )
  for name, value in overrides:
    setattr(config, name, value)
  return config
53+
54+
55+
@registry.register_hparams
def mqp_ende_mq8_ff5440():
  """`mqp_ende_mq8` with `d_ff` set to 5440.

  Returns:
    The hparams object with the override applied.
  """
  # Original author's note: params=211M
  config = mqp_ende_mq8()
  config.d_ff = 5440
  return config
61+
62+
63+
@registry.register_hparams
def mqp_ende_mq8_ff5440_local():
  """`mqp_ende_mq8_ff5440` with the decoder local-attention radius set to 32.

  Returns:
    The hparams object with the override applied.
  """
  config = mqp_ende_mq8_ff5440()
  config.decoder_local_attention_radius = 32
  return config
68+
69+
70+
@registry.register_hparams
def mqp_ende_h4_kv256():
  """`mqp_ende_base` with 4 encoder/decoder heads and `d_kv` of 256.

  Returns:
    The hparams object with the overrides applied.
  """
  config = mqp_ende_base()
  config.decoder_num_heads = config.encoder_num_heads = 4
  config.d_kv = 256
  return config
77+
78+
79+
@registry.register_hparams
def mqp_ende_h2_kv512():
  """`mqp_ende_base` with 2 encoder/decoder heads and `d_kv` of 512.

  Returns:
    The hparams object with the overrides applied.
  """
  config = mqp_ende_base()
  config.decoder_num_heads = config.encoder_num_heads = 2
  config.d_kv = 512
  return config
86+
87+
88+
@registry.register_hparams
def mqp_ende_h1_kv1024():
  """`mqp_ende_base` with 1 encoder/decoder head and `d_kv` of 1024.

  Returns:
    The hparams object with the overrides applied.
  """
  config = mqp_ende_base()
  config.decoder_num_heads = config.encoder_num_heads = 1
  config.d_kv = 1024
  return config
95+
96+
97+
@registry.register_hparams
def mqp_ende_h4_ff5632():
  """`mqp_ende_base` with 4 encoder/decoder heads and `d_ff` of 5632.

  Returns:
    The hparams object with the overrides applied.
  """
  config = mqp_ende_base()
  config.decoder_num_heads = config.encoder_num_heads = 4
  config.d_ff = 5632
  return config
104+
105+
106+
@registry.register_hparams
def mqp_ende_h2_ff6400():
  """`mqp_ende_base` with 2 encoder/decoder heads and `d_ff` of 6400.

  Returns:
    The hparams object with the overrides applied.
  """
  config = mqp_ende_base()
  config.decoder_num_heads = config.encoder_num_heads = 2
  config.d_ff = 6400
  return config
113+
114+
115+
@registry.register_hparams
def mqp_ende_h1_ff6784():
  """`mqp_ende_base` with 1 encoder/decoder head and `d_ff` of 6784.

  Returns:
    The hparams object with the overrides applied.
  """
  config = mqp_ende_base()
  config.decoder_num_heads = config.encoder_num_heads = 1
  config.d_ff = 6784
  return config
122+
123+
124+
@registry.register_hparams
def mqp_ende_h2_kv64_ff6784():
  """`mqp_ende_base` with 2 heads, `d_kv` of 64 and `d_ff` of 6784.

  The head count is applied to both the encoder and the decoder.

  Returns:
    The hparams object with the overrides applied.
  """
  config = mqp_ende_base()
  config.decoder_num_heads = config.encoder_num_heads = 2
  config.d_kv = 64
  config.d_ff = 6784
  return config
132+
133+
134+
@registry.register_hparams
def mqp_ende_h4_kv32_ff6784():
  """`mqp_ende_base` with 4 heads, `d_kv` of 32 and `d_ff` of 6784.

  The head count is applied to both the encoder and the decoder.

  Returns:
    The hparams object with the overrides applied.
  """
  config = mqp_ende_base()
  config.decoder_num_heads = config.encoder_num_heads = 4
  config.d_kv = 32
  config.d_ff = 6784
  return config
142+
143+
144+
@registry.register_hparams
def mqp_ende_h8_kv16_ff6784():
  """`mqp_ende_base` with 8 heads, `d_kv` of 16 and `d_ff` of 6784.

  The head count is applied to both the encoder and the decoder.

  Returns:
    The hparams object with the overrides applied.
  """
  hparams = mqp_ende_base()
  hparams.decoder_num_heads = 8
  hparams.encoder_num_heads = 8
  hparams.d_kv = 16
  # Fix: the d_ff override was missing, so this config silently kept the
  # base feed-forward width despite its "ff6784" name.  The sibling configs
  # mqp_ende_h2_kv64_ff6784 and mqp_ende_h4_kv32_ff6784 both set it.
  hparams.d_ff = 6784
  return hparams
151+
152+
153+
@registry.register_hparams
def mqp_lm1b_base():
  """Baseline for the series of language-modeling architectures.

  Starts from `mtf_transformer2.mtf_unitransformer_base()` and applies the
  overrides listed below.

  Returns:
    The hparams object with the overrides applied.
  """
  lm = mtf_transformer2.mtf_unitransformer_base()
  overrides = (
      ("d_model", 1024),
      ("max_length", 256),
      ("batch_size", 256),
      # Parameters for my_layer_stack()
      ("num_hidden_layers", 6),
      ("d_ff", 8192),
      ("d_kv", 128),
      ("num_heads", 8),
      ("learning_rate_decay_steps", 13600),
      ("layout", "batch:batch;vocab:model;d_ff:model;heads:model"),
      ("mesh_shape", "batch:32"),
  )
  # Apply in the listed order; equivalent to plain attribute assignment.
  for name, value in overrides:
    setattr(lm, name, value)
  return lm
169+
170+
171+
@registry.register_hparams
def mqp_lm1b_mq8():
  """`mqp_lm1b_base` with 8 heads and a single memory head.

  Returns:
    The hparams object with the overrides applied.
  """
  lm = mqp_lm1b_base()
  # num_heads is re-stated explicitly even though the base already uses 8.
  lm.num_heads = 8
  lm.num_memory_heads = 1
  return lm
177+
178+
179+
@registry.register_hparams
def mqp_lm1b_mq8_ff9088():
  """`mqp_lm1b_mq8` with `d_ff` set to 9088.

  Returns:
    The hparams object with the override applied.
  """
  lm = mqp_lm1b_mq8()
  lm.d_ff = 9088
  return lm
184+
185+
186+
@registry.register_hparams
def mqp_lm1b_h1_ff9984():
  """`mqp_lm1b_base` with 1 head and `d_ff` of 9984.

  Returns:
    The hparams object with the overrides applied.
  """
  lm = mqp_lm1b_base()
  lm.num_heads = 1
  lm.d_ff = 9984
  return lm
192+
193+
194+
@registry.register_hparams
def mqp_lm1b_h2_kv64_ff9984():
  """`mqp_lm1b_base` with 2 heads, `d_kv` of 64 and `d_ff` of 9984.

  Returns:
    The hparams object with the overrides applied.
  """
  lm = mqp_lm1b_base()
  lm.num_heads = 2
  lm.d_kv = 64
  lm.d_ff = 9984
  return lm
201+
202+
203+
@registry.register_hparams
def mqp_lm1b_h4_kv32_ff9984():
  """`mqp_lm1b_base` with 4 heads, `d_kv` of 32 and `d_ff` of 9984.

  Returns:
    The hparams object with the overrides applied.
  """
  lm = mqp_lm1b_base()
  lm.num_heads = 4
  lm.d_kv = 32
  lm.d_ff = 9984
  return lm
210+
211+
212+
@registry.register_hparams
def mqp_lm1b_h8_kv16_ff9984():
  """`mqp_lm1b_base` with 8 heads, `d_kv` of 16 and `d_ff` of 9984.

  Returns:
    The hparams object with the overrides applied.
  """
  lm = mqp_lm1b_base()
  lm.num_heads = 8
  lm.d_kv = 16
  lm.d_ff = 9984
  return lm

0 commit comments

Comments
 (0)