@@ -59,6 +59,7 @@ class MultiHeadAttention(Layer):
         activity_regularizer: Regularizer for dense layer activity.
         kernel_constraint: Constraint for dense layer kernels.
         bias_constraint: Constraint for dense layer kernels.
+        seed: Optional integer to seed the dropout layer.
 
     Call arguments:
         query: Query tensor of shape `(B, T, dim)`, where `B` is the batch size,
@@ -110,6 +111,7 @@ def __init__(
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
+        seed=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -137,6 +139,7 @@ def __init__(
                 f"Received: attention_axes={attention_axes}"
             )
         self._attention_axes = attention_axes
+        self.seed = seed
 
     @property
     def num_heads(self):
@@ -189,6 +192,7 @@ def get_config(self):
             ),
             "kernel_constraint": constraints.serialize(self._kernel_constraint),
             "bias_constraint": constraints.serialize(self._bias_constraint),
+            "seed": self.seed,
         }
         return {**base_config, **config}
 
@@ -359,7 +363,7 @@ def _build_attention(self, rank):
         )
         self._softmax = Softmax(axis=norm_axes, dtype=self.dtype_policy)
         self._dropout_layer = Dropout(
-            rate=self._dropout, dtype=self.dtype_policy
+            rate=self._dropout, dtype=self.dtype_policy, seed=self.seed
         )
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
 
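A minimal usage sketch (not part of the diff), assuming a Keras build that already includes this change: the new `seed` argument pins the dropout applied to the attention scores, and because `get_config()` now serializes it, the value survives config round-trips.

```python
import numpy as np
from keras import layers

# Construct the layer with attention-score dropout and a fixed seed.
# The `seed` keyword only exists in versions that include this change.
mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, dropout=0.5, seed=42)

query = np.random.normal(size=(2, 8, 16)).astype("float32")
value = np.random.normal(size=(2, 4, 16)).astype("float32")

# With training=True the seeded dropout is actually applied to the scores.
out = mha(query, value, training=True)

# The seed is part of the layer config, so rebuilding from config keeps it.
clone = layers.MultiHeadAttention.from_config(mha.get_config())
print(mha.get_config()["seed"], clone.get_config()["seed"])  # 42 42
```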