Commit 45b03a5

Address review comments
1 parent 92b168f commit 45b03a5

2 files changed, +43 -27 lines changed


keras_nlp/models/mistral/mistral_attention.py

Lines changed: 30 additions & 20 deletions
@@ -34,12 +34,14 @@ def __init__(
         rope_scaling_factor=1.0,
         kernel_initializer="glorot_uniform",
         sliding_window=512,
+        dropout=0,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._num_query_heads = num_query_heads
         self._num_key_value_heads = num_key_value_heads
         self._sliding_window = sliding_window
+        self._dropout = dropout

         self._num_key_value_groups = num_query_heads // num_key_value_heads
         self._rope_max_wavelength = rope_max_wavelength
@@ -51,24 +53,32 @@ def __init__(
         self._rope_scaling_factor = rope_scaling_factor

     def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = model dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
         self._hidden_dim = inputs_shape[-1]
-        self._attn_head_size = self._hidden_dim // self._num_query_heads
+        self._head_dim = self._hidden_dim // self._num_query_heads

         self._query_dense = keras.layers.EinsumDense(
-            equation="abc,cde->abde",
-            output_shape=(None, self._num_query_heads, self._attn_head_size),
+            equation="bqm,muh->bquh",
+            output_shape=(None, self._num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.compute_dtype,
             name="query",
         )
         self._query_dense.build(inputs_shape)

         self._key_dense = keras.layers.EinsumDense(
-            equation="abc,cde->abde",
+            equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
                 self._num_key_value_heads,
-                self._attn_head_size,
+                self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
             dtype=self.compute_dtype,
@@ -77,11 +87,11 @@ def build(self, inputs_shape):
         self._key_dense.build(inputs_shape)

         self._value_dense = keras.layers.EinsumDense(
-            equation="abc,cde->abde",
+            equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
                 self._num_key_value_heads,
-                self._attn_head_size,
+                self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
             dtype=self.compute_dtype,
@@ -91,14 +101,20 @@ def build(self, inputs_shape):

         self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

+        self._dropout_layer = keras.layers.Dropout(
+            rate=self._dropout, dtype=self.compute_dtype
+        )
+
         self._output_dense = keras.layers.EinsumDense(
-            equation="abc,cd->abd",
+            equation="bquh,uhm->bqm",
             output_shape=(None, self._hidden_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.compute_dtype,
             name="attention_output",
         )
-        self._output_dense.build(inputs_shape)
+        self._output_dense.build(
+            (None, None, self._num_query_heads, self._head_dim)
+        )

         self.rotary_embedding_layer = RotaryEmbedding(
             max_wavelength=self._rope_max_wavelength,
@@ -114,6 +130,7 @@ def call(
         attention_mask=None,
         cache=None,
         cache_update_index=None,
+        training=None,
     ):
         seq_len = ops.shape(hidden_states)[1]
         start_index = (
@@ -221,14 +238,8 @@ def _compute_key_value(x):
             query, key, value, attention_mask
         )

-        attention_output_shape = ops.shape(attention_output)
-        attention_output = ops.reshape(
-            attention_output,
-            [
-                attention_output_shape[0],  # batch_shape
-                attention_output_shape[1],  # seq_len
-                self._hidden_dim,
-            ],
+        attention_output = self._dropout_layer(
+            attention_output, training=training
         )

         attention_output = self._output_dense(attention_output)
@@ -247,9 +258,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
     def _compute_attention(self, query, key, value, attention_mask=None):
         attention_scores = ops.einsum("aecd,abcd->acbe", key, query)

-        norm_factor = ops.sqrt(
-            ops.cast(self._attn_head_size, self.compute_dtype)
-        )
+        norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

         attention_scores = attention_scores / norm_factor

@@ -274,6 +283,7 @@ def get_config(self):
                 self._kernel_initializer
             ),
             "sliding_window": self._sliding_window,
+            "dropout": self._dropout,
         }
     )
     return config
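
As a quick illustration (not part of the commit), the sketch below shows what the renamed einsum equations do to shapes: "bqm,muh->bquh" projects (batch, query, model_dim) activations into per-head queries, and "bquh,uhm->bqm" folds the heads back into the model dimension, which is why the explicit ops.reshape before the output projection could be dropped. The toy sizes are made up for the example.

import numpy as np
import keras

# Toy sizes for illustration only; the real layer infers hidden_dim
# from inputs_shape[-1] in build().
batch, q_len, hidden_dim = 2, 6, 16
num_query_heads = 4
head_dim = hidden_dim // num_query_heads  # 4

# "bqm,muh->bquh": (batch, query, model) x (model, heads, head_dim)
# -> (batch, query, heads, head_dim), as in the new query projection.
query_dense = keras.layers.EinsumDense(
    equation="bqm,muh->bquh",
    output_shape=(None, num_query_heads, head_dim),
)
x = np.random.rand(batch, q_len, hidden_dim).astype("float32")
print(query_dense(x).shape)  # (2, 6, 4, 4)

# "bquh,uhm->bqm" contracts the head axes back into the model
# dimension, so no reshape is needed before the output projection.
output_dense = keras.layers.EinsumDense(
    equation="bquh,uhm->bqm",
    output_shape=(None, hidden_dim),
)
print(output_dense(query_dense(x)).shape)  # (2, 6, 16)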

keras_nlp/models/mistral/mistral_transformer_decoder.py

Lines changed: 13 additions & 7 deletions
@@ -39,10 +39,9 @@ def __init__(
         layer_norm_epsilon=1e-5,
         kernel_initializer="glorot_uniform",
         sliding_window=512,
+        dropout=0,
         **kwargs,
     ):
-        decoder_sequence_shape = kwargs.pop("decoder_sequence_shape", None)
-
         super().__init__(**kwargs)
         self.intermediate_dim = intermediate_dim
         self.num_query_heads = num_query_heads
@@ -51,16 +50,14 @@ def __init__(
         self.rope_max_wavelength = rope_max_wavelength
         self.rope_scaling_factor = rope_scaling_factor

+        self.dropout = dropout
+
         self.sliding_window = sliding_window
         self.activation = keras.activations.get(activation)
         self.layer_norm_epsilon = layer_norm_epsilon
         self.kernel_initializer = keras.initializers.get(kernel_initializer)

         self.supports_masking = True
-        self._decoder_sequence_shape = None
-
-        if decoder_sequence_shape:
-            self.build(decoder_sequence_shape)

     def build(self, decoder_sequence_shape):
         self._decoder_sequence_shape = decoder_sequence_shape
@@ -74,6 +71,7 @@ def build(self, decoder_sequence_shape):
             rope_scaling_factor=self.rope_scaling_factor,
             sliding_window=self.sliding_window,
             kernel_initializer=clone_initializer(self.kernel_initializer),
+            dropout=self.dropout,
             dtype=self.compute_dtype,
             name="self_attention",
         )
@@ -85,6 +83,11 @@ def build(self, decoder_sequence_shape):
             dtype=self.compute_dtype,
         )
         self._self_attention_layernorm.build(decoder_sequence_shape)
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.compute_dtype,
+            name="self_attention_dropout",
+        )

         # Feedforward layers.
         self._feedforward_intermediate_dense = keras.layers.Dense(
@@ -135,6 +138,7 @@ def call(
         decoder_attention_mask=None,
         self_attention_cache=None,
         self_attention_cache_update_index=None,
+        training=None,
     ):
         self_attention_mask = self._compute_self_attention_mask(
             decoder_sequence=decoder_sequence,
@@ -156,6 +160,8 @@ def call(
         if self_attention_cache is not None:
             x, self_attention_cache = x

+        x = self._self_attention_dropout(x, training=training)
+
         x = x + residual
         residual = x

@@ -220,7 +226,7 @@ def get_config(self):
                 "kernel_initializer": keras.initializers.serialize(
                     self.kernel_initializer
                 ),
-                "decoder_sequence_shape": self._decoder_sequence_shape,
+                "dropout": self.dropout,
             }
         )
         return config
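
A small, self-contained sketch (again, not from the commit) of the behaviour the new training argument threads through: the added Dropout layers are active only when training=True, so inference and cached-decoding paths are unaffected when training is off or the rate is 0. The rate of 0.5 is arbitrary, chosen only to make the effect visible.

import numpy as np
import keras

# Standalone illustration of the dropout layers added in this commit.
dropout = keras.layers.Dropout(rate=0.5, name="self_attention_dropout")

x = np.ones((2, 4, 8), dtype="float32")

# With training=True, roughly half of the activations are zeroed and
# the survivors are scaled by 1 / (1 - rate).
train_out = dropout(x, training=True)

# With training=False (or left unset at inference), dropout is a no-op.
infer_out = dropout(x, training=False)
print(np.allclose(keras.ops.convert_to_numpy(infer_out), x))  # True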

0 commit comments
