-
Notifications
You must be signed in to change notification settings - Fork 255
/
moe.py
2059 lines (1799 loc) · 81 KB
/
moe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2023 The Mesh TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mixture-of-experts code.
Interfaces and algorithms are under development and subject to rapid change
without notice.
TODO(noam): Remove the other copy of this code from tensor2tensor.
TODO(noam): Write a new, simpler, cleaner version of this code.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gin
import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import transformer
import tensorflow.compat.v1 as tf
@gin.configurable
class MoE1D(transformer.TransformerLayer):
  """Mixture of Experts layer backed by `transformer_moe_layer_v1`.

  The constructor arguments are collected into an `HParams` object under
  `moe_`-prefixed names (the activation is kept separately) and handed to
  `transformer_moe_layer_v1` every time the layer is called.
  """

  def __init__(self,
               num_experts=16,
               loss_coef=1e-2,
               hidden_size=4096,
               group_size=1024,
               capacity_factor_train=1.25,
               capacity_factor_eval=2.0,
               use_second_place_loss=False,
               second_policy_train="random",
               second_policy_eval="random",
               second_threshold_train=0.2,
               second_threshold_eval=0.2,
               dropout_rate=0.0,
               activation="relu",
               moe_gating="top_2",
               min_expert_capacity=4,
               switch_policy_train="input_jitter",
               switch_policy_eval="input_jitter",
               switch_dropout=0.1,
               switch_temperature=1.0,
               switch_jitter=1e-2,
               ntlb_top_k=4,
               output_dim=None,
               use_experts_attention=False,
               z_loss=None,
               word_embed_mode=None,
               use_second_place_expert_prob=None,
               use_second_place_expert_prob_temp=None,
               top_n_num_experts_per_token=3):
    # Map every constructor argument onto the hparam name that the MoE
    # layer / gating functions read.
    moe_settings = dict(
        # Routing algorithm and expert pool.
        moe_gating=moe_gating,
        moe_num_experts=num_experts,
        moe_hidden_size=hidden_size,
        moe_output_dim=output_dim,
        moe_use_experts_attention=use_experts_attention,
        moe_dropout_rate=dropout_rate,
        # Grouping and capacity.
        moe_group_size=group_size,
        moe_min_expert_capacity=min_expert_capacity,
        moe_capacity_factor_train=capacity_factor_train,
        moe_capacity_factor_eval=capacity_factor_eval,
        # Auxiliary losses.
        moe_loss_coef=loss_coef,
        moe_z_loss=z_loss,
        # Second-place (top-2) routing policy.
        moe_use_second_place_loss=use_second_place_loss,
        moe_second_policy_train=second_policy_train,
        moe_second_policy_eval=second_policy_eval,
        moe_second_threshold_train=second_threshold_train,
        moe_second_threshold_eval=second_threshold_eval,
        moe_use_second_place_expert_prob=use_second_place_expert_prob,
        moe_use_second_place_expert_prob_temp=(
            use_second_place_expert_prob_temp),
        # Switch / NTLB / top-n routing knobs.
        moe_switch_policy_train=switch_policy_train,
        moe_switch_policy_eval=switch_policy_eval,
        moe_switch_dropout=switch_dropout,
        moe_switch_temperature=switch_temperature,
        moe_switch_jitter=switch_jitter,
        moe_ntlb_top_k=ntlb_top_k,
        moe_top_n_num_experts_per_token=top_n_num_experts_per_token,
        # Optional word-embedding-based routing.
        moe_word_embed_mode=word_embed_mode,
    )
    self._hparams = HParams(**moe_settings)
    self._activation = activation

  def call(self, context, x, losses=None):
    """Runs the MoE layer on `x` and records its auxiliary loss.

    Args:
      context: the transformer call context; its model, train flag,
        variable dtype, layout, mesh shape, nonpadding, microbatch count and
        input embeddings are forwarded to `transformer_moe_layer_v1`.
      x: the input mtf.Tensor. If it lacks `context.length_dim`, a size-1
        "length" dimension is temporarily inserted before its last dimension.
      losses: unused; the loss is appended to `context.losses` instead.

    Returns:
      The layer output (a list of outputs when
      `moe_use_experts_attention` is set), reshaped back to the shape of `x`
      when a length dimension had to be inserted.

    Raises:
      NotImplementedError: if the model has an ensemble dimension.
    """
    if context.model.ensemble_dim:
      raise NotImplementedError("MoE not yet implemented with ensembles")

    needs_length_dim = context.length_dim not in x.shape.dims
    original_shape = x.shape
    if needs_length_dim:
      dims = original_shape.dims
      x = mtf.reshape(
          x, mtf.Shape(dims[:-1] + [mtf.Dimension("length", 1)] + dims[-1:]))

    # Fall back to the model dimension unless an explicit output dim is set.
    output_dim = self._hparams.moe_output_dim
    if output_dim is None:
      output_dim = context.model.model_dim

    y, loss = transformer_moe_layer_v1(
        inputs=x,
        output_dim=output_dim,
        hparams=self._hparams,
        train=context.train,
        variable_dtype=context.variable_dtype,
        layout=context.model.layout,
        mesh_shape=context.model.mesh_shape,
        nonpadding=context.nonpadding,
        activation=self._activation,
        num_microbatches=context.num_microbatches,
        token_embeddings=context.input_embeddings)

    if context.losses is not None:
      context.losses.append(loss)

    if needs_length_dim:
      if self._hparams.moe_use_experts_attention:
        y = [mtf.reshape(part, original_shape) for part in y]
      else:
        y = mtf.reshape(y, original_shape)
    return y
class MoE2D(transformer.TransformerLayer):
  """Two-level (expert_x by expert_y) Mixture of Experts layer.

  Wraps `transformer_moe_layer_v2`, which routes with top-2 gating at both
  levels ("top_2" is hard-coded below).
  """

  def __init__(self,
               expert_x=8,
               expert_y=8,
               loss_coef=1e-2,
               hidden_size=4096,
               group_size=1024,
               capacity_factor_train=1.25,
               capacity_factor_eval=2.0,
               capacity_factor_second_level=1.0,
               use_second_place_loss=False,
               second_policy_train="random",
               second_policy_eval="random",
               second_threshold_train=0.2,
               second_threshold_eval=0.2,
               min_expert_capacity=4):
    """Stores the layer configuration.

    Args:
      expert_x: number of first-level expert groups.
      expert_y: number of second-level experts per group.
      loss_coef: multiplier applied to the summed gating losses.
      hidden_size: hidden size of each expert.
      group_size: target group size used when splitting tokens into groups.
      capacity_factor_train: first-level capacity factor during training.
      capacity_factor_eval: first-level capacity factor during eval.
      capacity_factor_second_level: capacity factor of the second level.
      use_second_place_loss: forwarded to the gating function.
      second_policy_train: forwarded to the gating function.
      second_policy_eval: forwarded to the gating function.
      second_threshold_train: forwarded to the gating function.
      second_threshold_eval: forwarded to the gating function.
      min_expert_capacity: lower bound on the per-expert capacity at both
        routing levels. `transformer_moe_layer_v2` reads it as
        `hparams.moe_min_expert_capacity`; it was previously never set here.
        The default of 4 matches MoE1D.
    """
    self._hparams = HParams(
        moe_gating="top_2",
        moe_num_experts=[expert_x, expert_y],
        moe_loss_coef=loss_coef,
        moe_hidden_size=hidden_size,
        moe_group_size=group_size,
        moe_min_expert_capacity=min_expert_capacity,
        moe_capacity_factor_train=capacity_factor_train,
        moe_capacity_factor_eval=capacity_factor_eval,
        moe_capacity_factor_second_level=capacity_factor_second_level,
        moe_use_second_place_loss=use_second_place_loss,
        moe_second_policy_train=second_policy_train,
        moe_second_policy_eval=second_policy_eval,
        moe_second_threshold_train=second_threshold_train,
        moe_second_threshold_eval=second_threshold_eval)

  def call(self, context, x, losses=None):
    """Runs the 2-level MoE layer on `x` and records its auxiliary loss.

    Args:
      context: the transformer call context.
      x: the input mtf.Tensor. If it lacks `context.length_dim`, a size-1
        "length" dimension is temporarily inserted before its last dimension.
      losses: unused; the loss is appended to `context.losses` instead.

    Returns:
      The layer output with the same shape as `x`.

    Raises:
      NotImplementedError: if the model has an ensemble dimension.
    """
    if context.model.ensemble_dim:
      raise NotImplementedError("MoE not yet implemented with ensembles")

    has_length_dim = context.length_dim in x.shape.dims
    if not has_length_dim:
      x_shape = x.shape
      shape_with_length = mtf.Shape(
          x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
          + x_shape.dims[-1:])
      x = mtf.reshape(x, shape_with_length)
    y, loss = transformer_moe_layer_v2(
        x,
        context.model.model_dim,
        self._hparams,
        context.train,
        context.variable_dtype,
        layout=context.model.layout,
        mesh_shape=context.model.mesh_shape,
        nonpadding=context.nonpadding,
        num_microbatches=context.num_microbatches)
    if context.losses is not None:
      context.losses.append(loss)
    if not has_length_dim:
      y = mtf.reshape(y, x_shape)
    return y
def transformer_moe_layer_v1(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
    num_microbatches=None, token_embeddings=None):
  """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py . Once this code moves out of "research", we should pass
  the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_min_expert_capacity: lower bound on the per-expert capacity
    hparams.moe_gating: a string; one of "top_2", "top_n", "switch", "ntlb",
      "switch_max", "expert_selection"
    hparams.moe_word_embed_mode: if not None, `token_embeddings` is reshaped
      like the inputs and passed to the gating function
    hparams.moe_dropout_rate: dropout rate on the expert hidden layer
      (skipped when 0.0)
    hparams.moe_use_experts_attention: if True, two output projections
      ("q_wo" and "k_wo") are computed from the same hidden layer and a list
      of two outputs is returned
    hparams.moe_loss_coef: multiplier applied to the gating loss
    + all hyperparameters used by the selected gating function

  The number of parameters in the gating network is presumably
    (input_dim.size * hparams.moe_num_experts)
  for a single dense router; the gating functions are defined elsewhere, so
  check them for the exact count.

  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)
  (twice the output projection when hparams.moe_use_experts_attention).

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  With top_2 gating, each position of each sequence is sent to 0-2 experts.
  The expert choices and the combination weights are determined by a learned
  gating function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model. This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to
    divide the batch into "groups". For each group, the same number of
    elements are sent to each expert.

  TODO(noam): Factor this code better. We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
    B: batch dim(s)
    L: original sequence length
    M: input depth
    N: output depth
    G: number of groups
    S: group size
    E: number of experts
    C: expert capacity

  Args:
    inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).
    activation: a function (or activation name) applied in the expert hidden
      layer.
    num_microbatches: number of microbatches.
    token_embeddings: a mtf.Tensor with shape
      [batch_dim(s), length_dim, input_dim]. These are the word embeddings
      that correspond to the inputs. They are only used (for routing) when
      hparams.moe_word_embed_mode is not None.

  Returns:
    outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim], or a
      list of two such Tensors when hparams.moe_use_experts_attention is set.
    loss: a mtf scalar, already multiplied by hparams.moe_loss_coef.

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  # pylint: disable=line-too-long
  #
  # O outer_batch dimension can be used for expert replication, e.g.
  # outer_batch=4 for placing 128 experts on 512 cores with 4 replicas of each
  # expert.
  #
  # E.g. 16x16 basic example:
  #   moe_num_experts=512, num_groups=1024, batch=4096, length=256, d_model=1024
  # ---
  # Below ` indicates common way of splitting along mesh dimension.
  #
  # orig_inputs      OB`LM Tensor
  #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
  #                  v (reshaped)
  # inputs           OG`SM
  #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
  #
  # combine_tensor,
  # dispatch_tensor  OG`SEC
  #                  Shape[outer_batch=1, batch=1024, group=1024, expert_unsplit=512, expert_capacity=4]
  #
  # (dispatched inputs)
  # expert_inputs    OEG`CM
  #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
  #                  v (re-split via ReshapeOperation)
  #                  OE`GCM
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
  #
  # (hidden representation)
  # h                OE`GCH
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, expert_hidden=8192]
  #
  # expert_output    OE`GCM
  #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
  #                  v (re-split via ReshapeOperation)
  #                  OEG`CM
  #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
  #
  # (combined expert_output)
  # output           OG`SM
  #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
  #                  v (reshape)
  #                  OB`LM
  #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
  #
  # pylint: enable=line-too-long

  orig_inputs = inputs
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups is a multiple of the mesh dimension
  # over which those groups are split.
  batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                      orig_inputs.shape.dims[-1])

  # Hack: we assume that
  #   "outer_batch" == replication of experts
  #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
  #
  # We then require num_groups to be a multiple of mesh_dim_size.
  if orig_inputs.shape.dims[0].name == "outer_batch":
    outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
  else:
    # No explicit outer_batch: use a size-1 one so the shapes below are
    # uniform.
    outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                       orig_inputs.shape.dims[0])

  # Number of MoE inputs (total number of positions across
  # batch_and_length_dims) per replica.
  n = 1
  for d in batch_and_length_dims:
    n *= d.size

  n = n // outer_batch_dim.size

  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                  orig_batch_dim)
  num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
                                              mesh_dim_size)

  group_size_dim = mtf.Dimension("group", group_size)
  # The groups dimension reuses the batch dimension's name so that it is
  # split over the same mesh dimension as the original batch.
  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

  moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
  # OGSM Tensor
  inputs = mtf.reshape(inputs, moe_input_dims)

  # Token embeddings that can be optionally used in the router for determining
  # where to send tokens.
  if hparams.moe_word_embed_mode is not None:
    token_embeddings = mtf.cast(
        mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

  # Each sequence sends expert_capacity positions to each expert.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  # Note: because max() is applied last, the capacity can exceed the group
  # size when moe_min_expert_capacity > group size.
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  tf.logging.info("expert_capacity: %d" % expert_capacity)
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
  # Differently-named copies of the experts/groups dims; reshaping between
  # "experts"/"expert_unsplit" and "batch"/"batch_unsplit" changes which
  # dimension is split across the mesh (see the cheat sheet above).
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
  if nonpadding is not None:
    # Broadcast nonpadding to the full batch/length shape before regrouping.
    nonpadding = mtf.zeros(
        inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
    nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
  if hparams.moe_gating == "top_2":
    # combine_tensor,
    # dispatch_tensor  OG`SEC Tensors
    # (G is generally split along mesh dim)
    dispatch_tensor, combine_tensor, loss = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "top_n":
    dispatch_tensor, combine_tensor, loss = _top_n_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "switch":
    dispatch_tensor, combine_tensor, loss = _switch_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "ntlb":
    dispatch_tensor, combine_tensor, loss = _ntlb_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "switch_max":
    dispatch_tensor, combine_tensor, loss = _switch_max_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  elif hparams.moe_gating == "expert_selection":
    # Unlike the other gating functions, expert selection also receives the
    # group-size dimension.
    dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        group_size_dim=group_size_dim,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        name="expert_selection_gating",
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Gather-by-matmul: dispatch each group's positions into expert slots.
  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                             mtf.Shape([
                                 outer_batch_dim, experts_dim_unsplit,
                                 num_groups_dim, expert_capacity_dim, input_dim
                             ]))

  # Extra reshape reduces communication cost for model-parallel versions.
  # For model-parallel versions, this reshape causes an mtf.slice and for non-
  # model-parallel versions, this has no effect.
  d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          d_model_split_dim
      ]))

  # Split over batch -> split over experts
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
          input_dim
      ]))

  # Now feed the expert inputs through the experts.
  h = mtf.layers.dense_product(
      expert_inputs,
      reduced_dims=expert_inputs.shape.dims[-1:],
      new_dims=[hidden_dim],
      expert_dims=[experts_dim],
      activation_functions=activation, use_bias=False,
      variable_dtype=variable_dtype, name="wi")

  if hparams.moe_dropout_rate != 0.0:
    h = mtf.dropout(h, is_training=train,
                    keep_prob=1.0 - hparams.moe_dropout_rate)

  def _compute_output(hidden, layer_name):
    """Project expert hidden states to output_dim and combine them per token.

    Applies a per-expert dense layer named `layer_name`, moves the split back
    from experts to groups, combines with `combine_tensor`, and reshapes to
    `batch_and_length_dims + [output_dim]`.
    """
    expert_output = mtf.layers.dense(
        hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
        reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
        name=layer_name)

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for
    # non-model-parallel versions, this has no effect.
    d_model_split_dim = mtf.Dimension(
        "d_model_split", expert_output.shape[-1].size)
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
            expert_capacity_dim, d_model_split_dim
        ]))

    # Split over experts -> split over batch
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    # Scatter-by-matmul: weight and sum each position's expert outputs.
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output

  if hparams.moe_use_experts_attention:
    # We share k_h and v_h with no degradation in performance
    q_h, k_h = h, h
    outputs = []
    q = _compute_output(q_h, layer_name="q_wo")
    k = _compute_output(k_h, layer_name="k_wo")
    outputs.append(q)
    outputs.append(k)
    return outputs, loss * hparams.moe_loss_coef
  else:
    output = _compute_output(h, layer_name="wo")
    return output, loss * hparams.moe_loss_coef
def transformer_moe_layer_v2(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None, num_microbatches=None):
  """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py . Once this code moves out of "research", we should pass
  the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_num_experts: a list of two ints [num_x, num_y]
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_min_expert_capacity: lower bound on the per-expert capacity
      at both levels
    hparams.moe_gating: a string (only "top_2" is supported)
    hparams.moe_loss_coef: multiplier applied to the summed gating losses
    + all hyperparameters used by _top_2_gating()

  One set of params for experts in first level and different of hparams
  per expert in the second level.
  Presumably (from the "gating params" shapes in the cheat sheet below) the
  number of parameters in the gating network is:
    m * x  (first level)  +  x * m * y  (second level)

  The number of parameters in the experts themselves is:
    (x * y
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.
  NOTE(review): with top-2 gating at both levels this looks like it could be
  up to 4; confirm. The expert choices and the combination weights are
  determined by a learned gating function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model. This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to
    divide the batch into "groups". For each group, the same number of
    elements are sent to each expert.

  TODO(noam): Factor this code better. We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  a, b: batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

  input: [a0, b1, l, m]
  input: [a0, g1, s, m]
  dispatch_tensor_x: [a0, g1, s, x, c]
  expert_input: [a0, g1, x, c, m]
  alltoall: [a0, g, x1, c, m]
  alltoall: [a0, g, x1, c, m]
  transpose: [x1, a0, g, c, m]
  reshape: [x1, h0, s, m]
  assignment2: [x1, h0, t, y, d]
  expert_input2: [x1, h0, y, d, m]
  alltoall: [x1, h, y0, d, m]
  ...
  reverse of that

  gating params 0: [m, x]
  gating params 1: [x1, m, y]

  expert params:
     [x1, y0, m, hidden]
     [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]; a 3-d [b, l, m] tensor is
      also accepted and gets a size-1 "outer_batch" dimension prepended.
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional mtf.Tensor with shape [a, b, l]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).
    num_microbatches: number of microbatches.

  Returns:
    outputs: a Tensor with shape [a, b, l, n] ([b, l, n] for 3-d inputs)
    loss: a mtf scalar, (loss_outer + loss_inner) * hparams.moe_loss_coef

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  if nonpadding is not None:
    # Broadcast nonpadding to the full (pre-outer-batch) input shape.
    nonpadding = mtf.zeros(inputs.mesh, inputs.shape.dims[:-1],
                           dtype=inputs.dtype) + nonpadding
  insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
  if insert_outer_batch_dim:
    inputs = mtf.reshape(
        inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims)

  assert len(hparams.moe_num_experts) == 2
  a0, b1, l, m = inputs.shape.dims
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  # "expert_x"/"expert_y" are the mesh-split expert dims; the "_unsplit"
  # copies have the same sizes but different names, so reshaping between
  # them moves the split (the all-to-alls in the cheat sheet).
  x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
  y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
  x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
  y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
  n = output_dim

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (g.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      b1.size * l.size, hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
  g1 = mtf.Dimension(b1.name, num_groups)
  g = mtf.Dimension(b1.name + "_unsplit", g1.size)
  s = mtf.Dimension("group_size_x", group_size)

  # Each sequence sends (at most?) expert_capacity positions to each expert.
  # Static expert_capacity dimension is needed for expert batch sizes
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  c = mtf.Dimension("expert_capacity_x", expert_capacity)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (h.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      a0.size * g.size * c.size,
      hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
  t = mtf.Dimension("group_size_y", group_size)
  h0 = mtf.Dimension(a0.name, num_groups)
  h = mtf.Dimension(a0.name + "_unsplit", h0.size)

  # Second-level capacity always uses moe_capacity_factor_second_level,
  # regardless of train/eval.
  expert_capacity = min(
      t.size,
      int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
  d = mtf.Dimension("expert_capacity_y", expert_capacity)

  # First level of expert routing
  # Reshape the inner batch size to a multiple of group_dim g1 and
  # group_size_dim s.
  inputs = mtf.reshape(inputs, [a0, g1, s, m])
  if nonpadding is not None:
    nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

  # Get the assignments for the first level.
  # dispatch_tensor_x has shape [a0, g1, s, x, c]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=x,
        expert_capacity_dim=c,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        name="outer_gating",
        importance=nonpadding,
        num_microbatches=num_microbatches)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x], [x, a0, g1, c, m])

  # we construct an "importance" Tensor for the inputs to the second-level
  # gating.  The importance of an input is 1.0 if it represents the
  # first-choice expert-group and 0.5 if it represents the second-choice expert
  # group.  This is used by the second-level gating.
  importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
  importance = 0.5 * (
      mtf.to_float(mtf.greater(importance, 0.5)) +
      mtf.to_float(mtf.greater(importance, 0.0)))

  # First level, all to all. Here we change the split dimension from g1 to x1.
  expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape(
      [x1, a0, g, c, m]))
  importance = mtf.reshape(importance, [x1, a0, g, c])

  # Second level of expert routing
  # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
  # and group_size_dim t.
  inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
  importance = mtf.reshape(importance, [x1, h0, t])

  # Get the assignments for the second level.
  # dispatch_tensor_y has shape [x1, h0, t, y, d]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
        inputs=inputs_y,
        outer_expert_dims=[x1],
        experts_dim=y,
        expert_capacity_dim=d,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=importance,
        name="inner_gating",
        num_microbatches=num_microbatches)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y], [y, x1, h0, d, m])

  # Second level, all to all. Here we change the split dimension from h0 to y0.
  expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape(
      [y0, x1, h, d, m]))

  # Expert feed-forward. Note: the activation is hard-coded to relu here,
  # unlike transformer_moe_layer_v1 which takes an `activation` argument.
  hidden_output = mtf.layers.dense(
      expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
      reduced_dims=expert_inputs_y.shape.dims[-1:],
      activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype,
      name="wi")
  expert_output = mtf.layers.dense(
      hidden_output, output_dim, expert_dims=[y0, x1],
      reduced_dims=hidden_output.shape.dims[-1:],
      use_bias=False, variable_dtype=variable_dtype,
      name="wo")

  # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
  # expert_output has shape [y0, x1, h, d, n]

  # alltoall
  expert_output = mtf.reshape(expert_output, mtf.Shape(
      [y, x1, h0, d, n]))

  # combine results from inner level
  output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

  # Reshape the combined tensor from inner level to now contain outer_batch_dim
  # a0 and group_dim g
  output = mtf.reshape(output_y, [x1, a0, g, c, n])

  # alltoall from expert_dim x to group_dim g1
  expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

  # combine results from outer level
  output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

  # Reshape the combined tensor to now contain inner_batch_dim
  # b1 and the original sequence length
  output = mtf.reshape(output_x, [a0, b1, l, n])
  if insert_outer_batch_dim:
    # Drop the size-1 outer_batch dimension added at the top.
    output = mtf.reshape(output, [b1, l, n])
  return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
def _stochastically_use_non_top_expert(gate_logits, experts_dim, hparams):
  """With a specified probability use the second place or lower experts.

  For every position of `gate_logits` other than the experts axis, a uniform
  random draw decides (with probability
  `hparams.moe_use_second_place_expert_prob`) whether the top-1 expert's
  logit is pushed down by 1e9:
    * If `hparams.moe_use_second_place_expert_prob_temp` is None, only the top
      logit is lowered, so a later top-1 / top-k selection on these logits
      picks the second-place expert first.
    * Otherwise an expert is sampled from the lowered logits with that
      temperature via `mtf.sample_with_temperature`, and every other expert's
      logit is lowered by an additional 1e9.
  Positions whose draw fails keep their original logits.

  Args:
    gate_logits: a mtf.Tensor of router logits that contains `experts_dim`.
    experts_dim: the mtf.Dimension indexing the experts.
    hparams: hyperparameters; reads `moe_use_second_place_expert_prob` and
      `moe_use_second_place_expert_prob_temp`.

  Returns:
    the adjusted gate logits, in the dtype of `gate_logits`.
  """
  prob = hparams.moe_use_second_place_expert_prob
  temperature = hparams.moe_use_second_place_expert_prob_temp
  tf.logging.info("Using second place expert with prob: {}".format(prob))

  _, top_expert_index = mtf.top_1(gate_logits, reduced_dim=experts_dim)
  top_expert_mask = mtf.one_hot(
      top_expert_index, experts_dim, dtype=gate_logits.dtype)

  # One Bernoulli(prob) decision per token. The token shape is obtained by
  # removing `experts_dim` by identity rather than assuming it is the last
  # axis, so the decision can never vary across experts of the same token.
  token_shape = mtf.Shape(
      [d for d in gate_logits.shape.dims if d != experts_dim])
  use_second_place_expert = mtf.cast(
      mtf.less(mtf.random_uniform(gate_logits.mesh, token_shape), prob),
      gate_logits.dtype)

  # Lower the top logit by a large finite amount (not -inf), which keeps the
  # 0/1 blend at the bottom free of 0 * inf terms.
  second_place_gate_logits = -1e9 * top_expert_mask + gate_logits

  # If a temperature is specified, sample one of the remaining N-1 experts
  # and lower every other expert's logit so the sampled one dominates.
  if temperature is not None:
    tf.logging.info("Expert second place temp: {}".format(temperature))
    second_expert_index = mtf.sample_with_temperature(
        second_place_gate_logits, experts_dim, temperature=temperature)
    second_expert_mask = mtf.one_hot(
        second_expert_index, experts_dim, dtype=gate_logits.dtype)
    second_place_gate_logits += (1 - second_expert_mask) * -1e9

  # Per-token select between the modified and the original logits.
  gate_logits = (use_second_place_expert * second_place_gate_logits +
                 (1 - use_second_place_expert) * gate_logits)
  return gate_logits
def _ntlb_gating(inputs,
outer_expert_dims,
experts_dim,
expert_capacity_dim,
hparams,
train,
variable_dtype,
importance=None,
name="ntlb_gating",
num_microbatches=None,
token_embeddings=None):
"""Compute Switch gating with no-token-left behind (NTLB) behavior."""
# SELECT EXPERT
if train:
policy = hparams.moe_switch_policy_train
else:
policy = hparams.moe_switch_policy_eval
# The internals of this function run in float32.
# bfloat16 seems to reduce quality.
gate_inputs = mtf.to_float(inputs)
# Input perturbations
if train and policy == "input_jitter":
gate_inputs = mtf.layers.multiplicative_jitter(
gate_inputs, hparams.moe_switch_jitter)
if hparams.moe_word_embed_mode is not None:
gate_inputs = _add_token_emb_to_gate_inputs(
gate_inputs, token_embeddings, hparams.moe_word_embed_mode)
gate_logits = mtf.layers.dense(
gate_inputs,
experts_dim,
use_bias=False,
expert_dims=outer_expert_dims,
variable_dtype=variable_dtype,
name=name)
if hparams.moe_use_second_place_expert_prob is not None and train:
gate_logits = _stochastically_use_non_top_expert(
gate_logits, experts_dim, hparams)
raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)
# The internals of this function run in float32.
# bfloat16 seems to reduce quality.
raw_gates = mtf.to_float(raw_gates)
# Top-k operation
k_dim = mtf.Dimension("k", hparams.moe_ntlb_top_k)
expert_gate, expert_index = mtf.top_k(
raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
expert_mask = mtf.one_hot(expert_index, experts_dim)
# LOAD BALANCING LOSS
outer_batch_dim = inputs.shape[0]
batch_dim = inputs.shape[1]
group_size_dim = inputs.shape[-2]
density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
if importance is not None:
expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
density_1_proxy *= mtf.cast(
mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
loss = (
mtf.reduce_mean(density_1_proxy * density_1) *
float(experts_dim.size * experts_dim.size))
if num_microbatches and num_microbatches > 1:
tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
num_microbatches))
loss /= num_microbatches
# Add in the z_loss for router.
if train and hparams.moe_z_loss is not None:
tf.logging.info("Using z_loss: {}".format(hparams.moe_z_loss))
z_loss = _router_z_loss(gate_logits, experts_dim, num_microbatches,
importance)
mtf.scalar_summary(name + "/z_loss", z_loss)
loss += (hparams.moe_z_loss * z_loss)
# Logging
if train:
entropy = mtf.reduce_sum(
-raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
batch_entropy = mtf.reduce_mean(entropy)
mtf.scalar_summary(name + "/entropy", batch_entropy)
mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
total_routed = mtf.reduce_sum(mask_count_experts)
expert_fraction = mtf.to_float(mask_count_experts / total_routed)
split_fractions = mtf.split(
expert_fraction,
split_dim=experts_dim,
num_or_size_splits=experts_dim.size)
for fraction in split_fractions:
mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
mtf.reduce_mean(fraction))
mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
# COMPUTE ASSIGNMENT TO EXPERT
# Iteratively route tokens (no-token-left-behind). The idea is to route as
# many tokens as possible to top-i before then trying top-(i+1).
top_k_masks = mtf.split(
expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
top_k_gates = mtf.split(
expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
top_k_indices = mtf.split(
expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)
# Tensors cumulative values over the iterative process.
combine_tensor = mtf.constant(
inputs.mesh,
value=0,
shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
cum_tokens = mtf.constant(
inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
tokens_left_to_route = mtf.constant(
inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])