@@ -34,12 +34,14 @@ def __init__(
         rope_scaling_factor=1.0,
         kernel_initializer="glorot_uniform",
         sliding_window=512,
+        dropout=0,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._num_query_heads = num_query_heads
         self._num_key_value_heads = num_key_value_heads
         self._sliding_window = sliding_window
+        self._dropout = dropout

         self._num_key_value_groups = num_query_heads // num_key_value_heads
         self._rope_max_wavelength = rope_max_wavelength
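The stored ratio `num_query_heads // num_key_value_heads` is the grouped-query-attention group size: each key/value head is shared by a group of query heads. A minimal sketch of the arithmetic, with made-up head counts that are not taken from this model:

```python
# Hypothetical head counts, only to illustrate the grouping arithmetic.
num_query_heads = 8
num_key_value_heads = 2
num_key_value_groups = num_query_heads // num_key_value_heads

print(num_key_value_groups)  # 4 -> each key/value head serves 4 query heads
```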
@@ -51,24 +53,32 @@ def __init__(
         self._rope_scaling_factor = rope_scaling_factor

     def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = model dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
         self._hidden_dim = inputs_shape[-1]
-        self._attn_head_size = self._hidden_dim // self._num_query_heads
+        self._head_dim = self._hidden_dim // self._num_query_heads

         self._query_dense = keras.layers.EinsumDense(
-            equation="abc,cde->abde",
-            output_shape=(None, self._num_query_heads, self._attn_head_size),
+            equation="bqm,muh->bquh",
+            output_shape=(None, self._num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.compute_dtype,
             name="query",
         )
         self._query_dense.build(inputs_shape)

         self._key_dense = keras.layers.EinsumDense(
-            equation="abc,cde->abde",
+            equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
                 self._num_key_value_heads,
-                self._attn_head_size,
+                self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
             dtype=self.compute_dtype,
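The rewritten equations name the axes (see the comment block added above) instead of the generic `abc,cde->abde`. A quick NumPy check with arbitrary illustrative sizes, showing that `bqm,muh->bquh` turns the hidden dim into per-head query vectors:

```python
import numpy as np

batch, q_len, model_dim = 2, 5, 16    # arbitrary illustration sizes
num_query_heads, head_dim = 4, 4      # head_dim = model_dim // num_query_heads

x = np.zeros((batch, q_len, model_dim))               # b, q, m
w = np.zeros((model_dim, num_query_heads, head_dim))  # m, u, h

out = np.einsum("bqm,muh->bquh", x, w)
print(out.shape)  # (2, 5, 4, 4) -> batch, query length, query heads, head dim
```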
@@ -77,11 +87,11 @@ def build(self, inputs_shape):
         self._key_dense.build(inputs_shape)

         self._value_dense = keras.layers.EinsumDense(
-            equation="abc,cde->abde",
+            equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
                 self._num_key_value_heads,
-                self._attn_head_size,
+                self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
             dtype=self.compute_dtype,
@@ -91,14 +101,20 @@ def build(self, inputs_shape):

         self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

+        self._dropout_layer = keras.layers.Dropout(
+            rate=self._dropout, dtype=self.compute_dtype
+        )
+
         self._output_dense = keras.layers.EinsumDense(
-            equation="abc,cd->abd",
+            equation="bquh,uhm->bqm",
             output_shape=(None, self._hidden_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.compute_dtype,
             name="attention_output",
         )
-        self._output_dense.build(inputs_shape)
+        self._output_dense.build(
+            (None, None, self._num_query_heads, self._head_dim)
+        )

         self.rotary_embedding_layer = RotaryEmbedding(
             max_wavelength=self._rope_max_wavelength,
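The output projection `bquh,uhm->bqm` now consumes the per-head attention output directly and folds the head and head-dim axes back into the model dim, which is why `build` is given a `(None, None, num_query_heads, head_dim)` shape rather than `inputs_shape`. A rough NumPy check with the same illustrative sizes as above:

```python
import numpy as np

batch, q_len = 2, 5
num_query_heads, head_dim, model_dim = 4, 4, 16

attn = np.zeros((batch, q_len, num_query_heads, head_dim))  # b, q, u, h
w_o = np.zeros((num_query_heads, head_dim, model_dim))      # u, h, m

out = np.einsum("bquh,uhm->bqm", attn, w_o)
print(out.shape)  # (2, 5, 16) -> back to (batch, query length, model dim)
```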
@@ -114,6 +130,7 @@ def call(
         attention_mask=None,
         cache=None,
         cache_update_index=None,
+        training=None,
     ):
         seq_len = ops.shape(hidden_states)[1]
         start_index = (
@@ -221,14 +238,8 @@ def _compute_key_value(x):
             query, key, value, attention_mask
         )

-        attention_output_shape = ops.shape(attention_output)
-        attention_output = ops.reshape(
-            attention_output,
-            [
-                attention_output_shape[0],  # batch_shape
-                attention_output_shape[1],  # seq_len
-                self._hidden_dim,
-            ],
+        attention_output = self._dropout_layer(
+            attention_output, training=training
         )

         attention_output = self._output_dense(attention_output)
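The explicit reshape back to `(batch, seq_len, hidden_dim)` is no longer needed because the output EinsumDense above accepts the 4D per-head tensor; the only new step is dropout on the attention output, driven by the `training` flag threaded through `call`. A minimal sketch of that behaviour using `keras.layers.Dropout` on its own:

```python
import keras
import numpy as np

dropout = keras.layers.Dropout(rate=0.5)
x = np.ones((2, 5, 4, 4))  # e.g. (batch, seq, heads, head_dim)

y_train = dropout(x, training=True)   # some entries zeroed, survivors scaled by 1/(1-rate)
y_infer = dropout(x, training=False)  # identity at inference time
```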
@@ -247,9 +258,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
     def _compute_attention(self, query, key, value, attention_mask=None):
         attention_scores = ops.einsum("aecd,abcd->acbe", key, query)

-        norm_factor = ops.sqrt(
-            ops.cast(self._attn_head_size, self.compute_dtype)
-        )
+        norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

         attention_scores = attention_scores / norm_factor

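The scaling itself is unchanged, only the attribute name: raw scores are still divided by sqrt(head_dim) before the masked softmax. As a worked example with an illustrative head size:

```python
import math

head_dim = 64                 # illustrative head size, not from this model
norm_factor = math.sqrt(head_dim)
print(norm_factor)            # 8.0 -> raw scores are divided by 8 before softmax
```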
@@ -274,6 +283,7 @@ def get_config(self):
                     self._kernel_initializer
                 ),
                 "sliding_window": self._sliding_window,
+                "dropout": self._dropout,
             }
         )
         return config
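With `dropout` added to `get_config()`, the rate survives a serialize/deserialize round trip. A toy stand-in layer (hypothetical, just to show the pattern this diff follows):

```python
import keras


class ToyAttention(keras.layers.Layer):
    """Tiny stand-in layer, only to show the config round trip for `dropout`."""

    def __init__(self, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self._dropout = dropout

    def get_config(self):
        config = super().get_config()
        config.update({"dropout": self._dropout})
        return config


layer = ToyAttention(dropout=0.1)
restored = ToyAttention.from_config(layer.get_config())
print(restored._dropout)  # 0.1 -> the rate is preserved across serialization
```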