Commit da83683

Add seed arg to MHA layer.

1 parent e620cb4 · commit da83683

File tree: 1 file changed (+5 -1 lines)

keras/src/layers/attention/multi_head_attention.py (+5 -1)
@@ -59,6 +59,7 @@ class MultiHeadAttention(Layer):
         activity_regularizer: Regularizer for dense layer activity.
         kernel_constraint: Constraint for dense layer kernels.
         bias_constraint: Constraint for dense layer kernels.
+        seed: Optional integer to seed the dropout layer.
 
     Call arguments:
         query: Query tensor of shape `(B, T, dim)`, where `B` is the batch size,
@@ -110,6 +111,7 @@ def __init__(
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
+        seed=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -137,6 +139,7 @@ def __init__(
                 f"Received: attention_axes={attention_axes}"
             )
         self._attention_axes = attention_axes
+        self.seed = seed
 
     @property
     def num_heads(self):
@@ -189,6 +192,7 @@ def get_config(self):
             ),
             "kernel_constraint": constraints.serialize(self._kernel_constraint),
             "bias_constraint": constraints.serialize(self._bias_constraint),
+            "seed": self.seed,
         }
         return {**base_config, **config}
 
@@ -359,7 +363,7 @@ def _build_attention(self, rank):
         )
         self._softmax = Softmax(axis=norm_axes, dtype=self.dtype_policy)
         self._dropout_layer = Dropout(
-            rate=self._dropout, dtype=self.dtype_policy
+            rate=self._dropout, dtype=self.dtype_policy, seed=self.seed
         )
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))