
Commit 110cda9

multihead_attention: use pure and elemental where necessary
1 parent 2731d63 commit 110cda9

5 files changed: +48 -48 lines changed

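For reference, a minimal standalone sketch (not part of this commit) of what these attributes ask for in Fortran: a `pure` procedure must declare an explicit `intent` for every dummy argument and may not have side effects beyond its `intent(out)`/`intent(in out)` arguments, while an `elemental` procedure is additionally limited to scalar dummies and can then be applied elementwise to arrays. The module and names below (`purity_demo`, `scale_in_place`, `square`) are illustrative only.

module purity_demo
  implicit none
contains

  ! A pure subroutine: every dummy carries an explicit intent, and only
  ! the intent(in out) argument is modified.
  pure subroutine scale_in_place(x, factor)
    real, intent(in out) :: x(:)
    real, intent(in) :: factor
    x = x * factor
  end subroutine scale_in_place

  ! An elemental function: implicitly pure, scalar dummies, callable with
  ! scalar or array actual arguments.
  elemental function square(x) result(y)
    real, intent(in) :: x
    real :: y
    y = x * x
  end function square

end module purity_demo

program demo
  use purity_demo
  implicit none
  real :: v(3) = [1., 2., 3.]
  call scale_in_place(v, 2.)   ! v is now [2., 4., 6.]
  print *, square(v)           ! elemental reference on an array: 4. 16. 36.
end program demo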

src/nf/nf_cross_attention_layer.f90

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ module function cross_attention_layer_cons(n_heads) result(res)
     res % n_heads = n_heads
   end function cross_attention_layer_cons
 
-  module subroutine backward(self, input, gradient)
+  pure module subroutine backward(self, input, gradient)
     !! Cross Attention Back propagation
     class(cross_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
@@ -46,7 +46,7 @@ module subroutine backward(self, input, gradient)
     self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
   end subroutine backward
 
-  module subroutine forward(self, input)
+  pure module subroutine forward(self, input)
     !! Cross Attention Forward propagation
     !! Input Shape (kind, sequence_length, model_dimension)
     !! where kind is 1 for Query and 2 for Key-Value
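One consequence worth noting: a pure procedure may only reference other pure (or elemental, or intrinsic) procedures, so making `common_forward`/`common_backward` pure is what allows thin wrappers such as this layer's `forward`/`backward` to carry the `pure` prefix as well. A minimal sketch of that rule, with hypothetical names (`inner`, `outer`):

module purity_chain_demo
  implicit none
contains

  pure subroutine inner(x)
    real, intent(in out) :: x
    x = 2. * x
  end subroutine inner

  ! Legal only because inner is itself pure: a pure procedure may not
  ! reference an impure procedure.
  pure subroutine outer(x)
    real, intent(in out) :: x
    call inner(x)
  end subroutine outer

end module purity_chain_demo

program demo
  use purity_chain_demo
  implicit none
  real :: x = 1.5
  call outer(x)
  print *, x   ! 3.0
end program demo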

src/nf/nf_multihead_attention.f90

Lines changed: 21 additions & 21 deletions
@@ -60,7 +60,7 @@ end function multihead_attention_layer_cons
 
   interface
 
-    module subroutine common_backward(self, input, gradient)
+    pure module subroutine common_backward(self, input, gradient)
       !! General backprop for MultiHead Attention mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: sum output gradients
@@ -70,7 +70,7 @@ module subroutine common_backward(self, input, gradient)
       real, intent(in) :: gradient(:, :)
     end subroutine common_backward
 
-    module subroutine common_forward(self, query, key, value)
+    pure module subroutine common_forward(self, query, key, value)
       !! General forward propagation for MultiHead Attention Mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: pass the same value thrice
@@ -79,63 +79,63 @@ module subroutine common_forward(self, query, key, value)
       real, intent(in) :: query(:, :), key(:, :), value(:, :)
     end subroutine common_forward
 
-    module subroutine init(self, input_shape)
+    pure module subroutine init(self, input_shape)
       !! Initialize the layer data structures.
       !!
       !! This is a deferred procedure from the `base_layer` abstract type.
       class(multihead_attention_layer), intent(in out) :: self
       integer, intent(in) :: input_shape(:)
     end subroutine init
 
-    module function split_heads(self, input) result(output)
+    pure module function split_heads(self, input) result(output)
       !! Split inputs into heads
       !!
       !! Example with two heads:
       !! input (3, 4)
       !! output (3, 2, 2)
-      class(multihead_attention_layer) :: self
-      real :: input(:, :)
+      class(multihead_attention_layer), intent(in) :: self
+      real, intent(in) :: input(:, :)
       real :: output(self % sequence_length, self % head_size, self % n_heads)
     end function split_heads
 
-    module subroutine create_attention_matrix(self, query, key)
+    pure module subroutine create_attention_matrix(self, query, key)
       !! Create attention matrix for query and key
       !! Output dimensions: sequence_length, sequence_length, n_heads
-      class(multihead_attention_layer) :: self
-      real :: query(:, :, :)
-      real :: key(:, :, :)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: query(:, :, :)
+      real, intent(in) :: key(:, :, :)
      integer :: head
     end subroutine create_attention_matrix
 
-    module subroutine normalize_attention_matrix(self, attention_mask)
+    pure module subroutine normalize_attention_matrix(self, attention_mask)
       !! Create attention matrix for query and key
       !! Output dims: sequence_length, sequence_length, n_heads
-      class(multihead_attention_layer) :: self
+      class(multihead_attention_layer), intent(in out) :: self
       !! (sequence_length, sequence_length, n_heads)
-      real, optional :: attention_mask(:, :, :)
+      real, optional, intent(in) :: attention_mask(:, :, :)
       !! (sequence_length, sequence_length, n_heads)
       real, allocatable :: output(:, :, :)
       integer :: head, seq
     end subroutine normalize_attention_matrix
 
-    module subroutine scaled_dot_product_attention(self, value)
+    pure module subroutine scaled_dot_product_attention(self, value)
       !! Create scaled dot product attention
       !! Output dims: sequence_length, head_size, n_heads
-      class(multihead_attention_layer) :: self
-      real :: value(:, :, :)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: value(:, :, :)
       integer :: head
     end subroutine scaled_dot_product_attention
 
-    module function combine_heads(self, input) result(output)
-      class(multihead_attention_layer) :: self
-      real :: input(:, :, :)
+    pure module function combine_heads(self, input) result(output)
+      class(multihead_attention_layer), intent(in) :: self
+      real, intent(in) :: input(:, :, :)
       !! (sequence_length, head_size, n_heads)
       real :: output(self % sequence_length, self % model_dimension)
       integer :: seq
     end function combine_heads
 
-    module function get_num_params(self) result(num_params)
-      class(multihead_attention_layer) :: self
+    elemental module function get_num_params(self) result(num_params)
+      class(multihead_attention_layer), intent(in) :: self
       integer :: num_params
     end function get_num_params
 
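The `elemental` attribute on `get_num_params`, together with `intent(in)` on `self`, means the function can in principle be referenced with an array of layers and return a conforming array of counts. A standalone sketch of that behaviour, using a hypothetical `toy_layer` type rather than the real `multihead_attention_layer`:

module layer_count_demo
  implicit none

  ! Hypothetical stand-in for a layer type; not the real
  ! multihead_attention_layer from this repository.
  type :: toy_layer
    integer :: n_weights = 0
  end type toy_layer

contains

  ! Mirrors the shape of the change above: elemental + intent(in) self,
  ! so the function applies to a scalar layer or to an array of layers.
  elemental function get_num_params(self) result(num_params)
    type(toy_layer), intent(in) :: self
    integer :: num_params
    num_params = self % n_weights
  end function get_num_params

end module layer_count_demo

program demo
  use layer_count_demo
  implicit none
  type(toy_layer) :: layers(3)
  layers % n_weights = [10, 20, 30]
  print *, get_num_params(layers)        ! one elemental reference: 10 20 30
  print *, sum(get_num_params(layers))   ! total parameter count: 60
end program demo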

src/nf/nf_multihead_attention_submodule.f90

Lines changed: 20 additions & 20 deletions
@@ -14,7 +14,7 @@ module function multihead_attention_layer_cons(n_heads) result(res)
     res % n_heads = n_heads
   end function multihead_attention_layer_cons
 
-  module subroutine common_backward(self, input, gradient)
+  pure module subroutine common_backward(self, input, gradient)
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :)
     real, intent(in) :: gradient(:, :)
@@ -112,7 +112,7 @@ module subroutine common_backward(self, input, gradient)
     deallocate(dk)
   end subroutine common_backward
 
-  module subroutine common_forward(self, query, key, value)
+  pure module subroutine common_forward(self, query, key, value)
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: query(:, :), key(:, :), value(:, :)
 
@@ -156,27 +156,27 @@ module subroutine common_forward(self, query, key, value)
     deallocate(v)
   end subroutine common_forward
 
-  module function split_heads(self, input) result(output)
-    class(multihead_attention_layer) :: self
-    real :: input(:, :)
+  pure module function split_heads(self, input) result(output)
+    class(multihead_attention_layer), intent(in) :: self
+    real, intent(in) :: input(:, :)
     real :: output(self % sequence_length, self % head_size, self % n_heads)
     output = reshape(input, [self % sequence_length, self % head_size, self % n_heads])
   end function split_heads
 
-  module subroutine create_attention_matrix(self, query, key)
-    class(multihead_attention_layer) :: self
-    real :: query(:, :, :)
-    real :: key(:, :, :)
+  pure module subroutine create_attention_matrix(self, query, key)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: query(:, :, :)
+    real, intent(in) :: key(:, :, :)
     integer :: head
     ! create attention matrix for each sequence in each batch
     do concurrent(head = 1: self % n_heads)
       self % attention_matrix(:, :, head) = matmul(query(:, :, head), transpose(key(:, :, head)))
     end do
   end subroutine create_attention_matrix
 
-  module subroutine normalize_attention_matrix(self, attention_mask)
-    class(multihead_attention_layer) :: self
-    real, optional :: attention_mask(:, :, :)
+  pure module subroutine normalize_attention_matrix(self, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, optional, intent(in) :: attention_mask(:, :, :)
     real, allocatable :: output(:, :, :)
     integer :: head, seq
 
@@ -198,19 +198,19 @@ module subroutine normalize_attention_matrix(self, attention_mask)
     deallocate(output)
   end subroutine normalize_attention_matrix
 
-  module subroutine scaled_dot_product_attention(self, value)
-    class(multihead_attention_layer) :: self
-    real :: value(:, :, :)
+  pure module subroutine scaled_dot_product_attention(self, value)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: value(:, :, :)
     integer :: head
 
     do concurrent(head = 1: self % n_heads)
       self % sdpa(:, :, head) = matmul(self % attention_matrix(:, :, head), value(:, :, head))
     end do
   end subroutine scaled_dot_product_attention
 
-  module function combine_heads(self, input) result(output)
-    class(multihead_attention_layer) :: self
-    real :: input(:, :, :)
+  pure module function combine_heads(self, input) result(output)
+    class(multihead_attention_layer), intent(in) :: self
+    real, intent(in) :: input(:, :, :)
     real :: output(self % sequence_length, self % model_dimension)
     integer :: seq
 
@@ -219,8 +219,8 @@ module function combine_heads(self, input) result(output)
     end do
   end function combine_heads
 
-  module function get_num_params(self) result(num_params)
-    class(multihead_attention_layer) :: self
+  elemental module function get_num_params(self) result(num_params)
+    class(multihead_attention_layer), intent(in) :: self
     integer :: num_params
 
     num_params = &
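Note that the `do concurrent` loops in this submodule are compatible with the new `pure` prefixes: `do concurrent` is permitted inside a pure procedure as long as its body sticks to pure operations such as the intrinsics `matmul` and `transpose`. A self-contained sketch of the same per-head pattern, with made-up names and toy array shapes:

module head_matmul_demo
  implicit none
contains

  ! A pure subroutine containing a do concurrent loop, the same pattern
  ! the submodule uses in create_attention_matrix; shapes are toy values.
  pure subroutine batched_qk(query, key, scores)
    real, intent(in) :: query(:, :, :)    ! (seq, head_size, n_heads)
    real, intent(in) :: key(:, :, :)      ! (seq, head_size, n_heads)
    real, intent(out) :: scores(:, :, :)  ! (seq, seq, n_heads)
    integer :: head
    do concurrent(head = 1:size(query, 3))
      scores(:, :, head) = matmul(query(:, :, head), transpose(key(:, :, head)))
    end do
  end subroutine batched_qk

end module head_matmul_demo

program demo
  use head_matmul_demo
  implicit none
  real :: q(3, 2, 2), k(3, 2, 2), s(3, 3, 2)
  call random_number(q)
  call random_number(k)
  call batched_qk(q, k, s)
  print *, shape(s)   ! 3 3 2
end program demo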

src/nf/nf_self_attention_layer.f90

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ module function self_attention_layer_cons(n_heads) result(res)
     res % n_heads = n_heads
   end function self_attention_layer_cons
 
-  module subroutine backward(self, input, gradient)
+  pure module subroutine backward(self, input, gradient)
     !! Self Attention back propagation
     !! Returns sum of Query, Key and Value gradients
     class(self_attention_layer), intent(in out) :: self
@@ -49,7 +49,7 @@ module subroutine backward(self, input, gradient)
         + self % value_layer % gradient
   end subroutine backward
 
-  module subroutine forward(self, input)
+  pure module subroutine forward(self, input)
     !! Cross Attention forward propagation
     !! Passes input three times into MultiHead Attention
     !! Input Shape: (sequence_length, model_dimension)

test/test_multihead_attention_layer.f90

Lines changed: 3 additions & 3 deletions
@@ -68,7 +68,7 @@ subroutine test_multihead_attention_split_heads(attention, input, ok, output)
   end subroutine test_multihead_attention_split_heads
 
   subroutine test_multihead_attention_create_attention_matrix(attention, input, ok)
-    type(multihead_attention_layer), intent(in) :: attention
+    type(multihead_attention_layer), intent(in out) :: attention
     real, intent(in) :: input(:, :, :)
     logical, intent(in out) :: ok
     real :: attention_matrix_shape(3)
@@ -95,7 +95,7 @@ subroutine test_multihead_attention_create_attention_matrix(attention, input, ok)
   end subroutine test_multihead_attention_create_attention_matrix
 
   subroutine test_multihead_attention_normalization(attention, ok)
-    type(multihead_attention_layer), intent(in) :: attention
+    type(multihead_attention_layer), intent(in out) :: attention
     logical, intent(in out) :: ok
     real :: output_flat(18)
     real :: expected_output_flat(18) = [&
@@ -114,7 +114,7 @@ subroutine test_multihead_attention_normalization(attention, ok)
   end subroutine test_multihead_attention_normalization
 
   subroutine test_multihead_attention_scaled_dot_product_attention(attention, value, ok)
-    type(multihead_attention_layer), intent(in) :: attention
+    type(multihead_attention_layer), intent(in out) :: attention
     real, intent(in) :: value(:, :, :)
     logical, intent(in out) :: ok
     real :: output_flat(12)
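The test-side changes follow from the now-explicit `intent(in out)` on `self` in the attention procedures: an actual argument associated with an `intent(in out)` dummy must be definable, so the test subroutines can no longer declare `attention` as `intent(in)`. A minimal sketch of that constraint, again with a hypothetical `toy_layer` type:

module intent_demo
  implicit none

  ! Hypothetical stand-in type; not the real multihead_attention_layer.
  type :: toy_layer
    real :: scratch = 0.
  end type toy_layer

contains

  pure subroutine update(self)
    type(toy_layer), intent(in out) :: self
    self % scratch = self % scratch + 1.
  end subroutine update

  subroutine test_update(layer, ok)
    ! Must be intent(in out): an intent(in) dummy is not definable and
    ! could not be passed on to update's intent(in out) dummy.
    type(toy_layer), intent(in out) :: layer
    logical, intent(in out) :: ok
    call update(layer)
    ok = ok .and. (layer % scratch > 0.)
  end subroutine test_update

end module intent_demo

program demo
  use intent_demo
  implicit none
  type(toy_layer) :: layer
  logical :: ok = .true.
  call test_update(layer, ok)
  print *, ok   ! T
end program demo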
