@@ -60,7 +60,7 @@ end function multihead_attention_layer_cons
 
   interface
 
-    module subroutine common_backward(self, input, gradient)
+    pure module subroutine common_backward(self, input, gradient)
       !! General backprop for MultiHead Attention mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: sum output gradients
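This hunk only adds the `pure` attribute to `common_backward`; the argument list is unchanged. In standard Fortran, a `pure` procedure must declare an explicit intent for every dummy argument and may not have side effects outside its `intent(out)`/`intent(in out)` arguments, which is why the later hunks also add `intent(...)` to declarations that previously had none. A minimal, self-contained sketch of that rule (illustrative only, not code from this commit):

! Illustrative only, not part of the patched module: the language rule
! that the intent(...) additions in this diff satisfy. Every dummy
! argument of a pure procedure needs an explicit intent, and intent(in)
! arguments cannot be modified.
pure subroutine scale_in_place(x, factor)
  real, intent(in out) :: x(:, :)  ! updated in place, so intent(in out)
  real, intent(in)     :: factor   ! read only, so intent(in)
  x = x * factor
end subroutine scale_in_place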
@@ -70,7 +70,7 @@ module subroutine common_backward(self, input, gradient)
       real, intent(in) :: gradient(:, :)
     end subroutine common_backward
 
-    module subroutine common_forward(self, query, key, value)
+    pure module subroutine common_forward(self, query, key, value)
       !! General forward propagation for MultiHead Attention Mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: pass the same value thrice
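The doc comment above spells out the calling convention: self attention passes the same activations as query, key, and value, while cross attention passes a separate key/value source. A hedged usage sketch, under the assumptions (not visible in this hunk) that `common_forward` is bound to the type under the same name and that the module is called `nf_multihead_attention`:

! Hedged sketch, not from this commit; the module name and the type-bound
! call syntax are assumptions.
subroutine run_attention(attn, x, encoder_out)
  use nf_multihead_attention, only: multihead_attention_layer
  type(multihead_attention_layer), intent(in out) :: attn
  real, intent(in) :: x(:, :)            ! (sequence_length, model_dimension)
  real, intent(in) :: encoder_out(:, :)  ! key/value source for cross attention

  ! Self attention: pass the same array thrice, as the doc comment says.
  call attn % common_forward(x, x, x)

  ! Cross attention: query from the current sequence, key and value from
  ! another sequence (e.g. encoder output).
  call attn % common_forward(x, encoder_out, encoder_out)
end subroutine run_attention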
@@ -79,63 +79,63 @@ module subroutine common_forward(self, query, key, value)
       real, intent(in) :: query(:, :), key(:, :), value(:, :)
     end subroutine common_forward
 
-    module subroutine init(self, input_shape)
+    pure module subroutine init(self, input_shape)
       !! Initialize the layer data structures.
       !!
       !! This is a deferred procedure from the `base_layer` abstract type.
       class(multihead_attention_layer), intent(in out) :: self
       integer, intent(in) :: input_shape(:)
     end subroutine init
 
-    module function split_heads(self, input) result(output)
+    pure module function split_heads(self, input) result(output)
       !! Split inputs into heads
       !!
       !! Example with two heads:
       !! input (3, 4)
       !! output (3, 2, 2)
-      class(multihead_attention_layer) :: self
-      real :: input(:, :)
+      class(multihead_attention_layer), intent(in) :: self
+      real, intent(in) :: input(:, :)
       real :: output(self % sequence_length, self % head_size, self % n_heads)
     end function split_heads
 
-    module subroutine create_attention_matrix(self, query, key)
+    pure module subroutine create_attention_matrix(self, query, key)
       !! Create attention matrix for query and key
       !! Output dimensions: sequence_length, sequence_length, n_heads
-      class(multihead_attention_layer) :: self
-      real :: query(:, :, :)
-      real :: key(:, :, :)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: query(:, :, :)
+      real, intent(in) :: key(:, :, :)
      integer :: head
     end subroutine create_attention_matrix
 
-    module subroutine normalize_attention_matrix(self, attention_mask)
+    pure module subroutine normalize_attention_matrix(self, attention_mask)
       !! Create attention matrix for query and key
       !! Output dims: sequence_length, sequence_length, n_heads
-      class(multihead_attention_layer) :: self
+      class(multihead_attention_layer), intent(in out) :: self
       !! (sequence_length, sequence_length, n_heads)
-      real, optional :: attention_mask(:, :, :)
+      real, optional, intent(in) :: attention_mask(:, :, :)
       !! (sequence_length, sequence_length, n_heads)
       real, allocatable :: output(:, :, :)
       integer :: head, seq
     end subroutine normalize_attention_matrix
 
-    module subroutine scaled_dot_product_attention(self, value)
+    pure module subroutine scaled_dot_product_attention(self, value)
       !! Create scaled dot product attention
       !! Output dims: sequence_length, head_size, n_heads
-      class(multihead_attention_layer) :: self
-      real :: value(:, :, :)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: value(:, :, :)
       integer :: head
     end subroutine scaled_dot_product_attention
 
-    module function combine_heads(self, input) result(output)
-      class(multihead_attention_layer) :: self
-      real :: input(:, :, :)
+    pure module function combine_heads(self, input) result(output)
+      class(multihead_attention_layer), intent(in) :: self
+      real, intent(in) :: input(:, :, :)
       !! (sequence_length, head_size, n_heads)
       real :: output(self % sequence_length, self % model_dimension)
       integer :: seq
     end function combine_heads
 
-    module function get_num_params(self) result(num_params)
-      class(multihead_attention_layer) :: self
+    elemental module function get_num_params(self) result(num_params)
+      class(multihead_attention_layer), intent(in) :: self
       integer :: num_params
     end function get_num_params
 
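Taken together, the per-head helpers suggest the usual attention pipeline: split the projected inputs into heads, build and softmax-normalize the score matrix, apply it to the value heads, and merge the heads back to `model_dimension`. The sketch below is inferred from the interfaces only; the type-bound call syntax, the module name, and the idea that intermediate results live on `self` (implied by the `intent(in out)` passed-object arguments and the absence of output arguments) are assumptions, not code from this commit.

! Hedged sketch of how the helpers appear to compose; names and call
! syntax are assumptions based on the interfaces above.
subroutine attention_pipeline(self, query, key, value)
  use nf_multihead_attention, only: multihead_attention_layer  ! module name assumed
  type(multihead_attention_layer), intent(in out) :: self
  real, intent(in) :: query(:, :), key(:, :), value(:, :)
  real, allocatable :: q(:, :, :), k(:, :, :), v(:, :, :)

  ! (sequence_length, model_dimension) -> (sequence_length, head_size, n_heads),
  ! e.g. (3, 4) -> (3, 2, 2) with two heads, per the split_heads doc comment.
  q = self % split_heads(query)
  k = self % split_heads(key)
  v = self % split_heads(value)

  ! Raw scores per head: (sequence_length, sequence_length, n_heads).
  call self % create_attention_matrix(q, k)

  ! Softmax over the scores; the attention_mask argument is optional.
  call self % normalize_attention_matrix()

  ! Weighted values per head: (sequence_length, head_size, n_heads).
  call self % scaled_dot_product_attention(v)

  ! combine_heads then maps the per-head result back to
  ! (sequence_length, model_dimension); the component holding that result
  ! is not visible in this interface, so the call is omitted here.
end subroutine attention_pipeline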