@@ -133,23 +133,27 @@ def __init__(

     def forward(
         self,
-        input_pos: torch.Tensor,
+        input_pos: Optional[torch.Tensor],
         q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
         k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
         v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
+        if input_pos is None:
+            # No kv cache
+            attn_mask = mask[:seqlen, :seqlen]
         else:
-            attn_mask = mask[None, None, input_pos]
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                # pyre-ignore: Incompatible parameter type [6]
+                attn_mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                attn_mask = mask[None, None, input_pos]

         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True
@@ -218,13 +222,13 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 self.head_dim,
                 args.enable_dynamic_shape,
             )
-            self.SDPA = SDPA(
-                dim=self.n_local_heads * self.head_dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_context_len=self.max_context_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
-            )
+        self.SDPA = SDPA(
+            dim=self.n_local_heads * self.head_dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
+            max_context_len=self.max_context_len,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )

     def forward(
         self,
229233 def forward (
230234 self ,
@@ -257,21 +261,5 @@ def forward(
         if self.use_kv_cache:
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output), None
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
-        assert hasattr(self, "mask")
-
-        mask = self.mask[:seqlen, :seqlen]
-
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-
-        return output, None
+        output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+        return self.wo(output), None
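
For reference, here is a minimal standalone sketch of the mask-selection behavior the patched `SDPA.forward` implements: with `input_pos=None` (no kv cache) it slices a plain `seqlen x seqlen` causal mask; otherwise it keeps the existing dynamic-shape `narrow` path or the `input_pos`-indexed mask. The helper name `select_attn_mask`, the toy causal mask, and the shapes below are illustrative assumptions, not part of the patch.

```python
from typing import Optional

import torch


def select_attn_mask(
    mask: torch.Tensor,
    input_pos: Optional[torch.Tensor],
    seqlen: int,
    enable_dynamic_shape: bool,
    max_context_len: int,
) -> torch.Tensor:
    """Hypothetical helper mirroring the branching added to SDPA.forward
    (seqlen stands in for q.size(2))."""
    if input_pos is None:
        # No kv cache: plain causal mask over the current sequence.
        return mask[:seqlen, :seqlen]
    if enable_dynamic_shape:
        start_pos = input_pos[-1].item()
        torch._check_is_size(start_pos)
        torch._check(start_pos < max_context_len)
        # Slice seqlen rows of the full-context mask starting at start_pos.
        return mask.narrow(0, start_pos, seqlen)
    # Static-shape kv-cache path: index the mask rows by position.
    return mask[None, None, input_pos]


max_context_len = 8
causal = torch.tril(torch.ones(max_context_len, max_context_len, dtype=torch.bool))

# Prefill without a kv cache (input_pos=None): a seqlen x seqlen causal mask.
print(select_attn_mask(causal, None, 4, False, max_context_len).shape)  # torch.Size([4, 4])

# Decode step with a kv cache at position 5: one row over the full context length.
print(select_attn_mask(causal, torch.tensor([5]), 1, True, max_context_len).shape)  # torch.Size([1, 8])
```

Routing both paths through `SDPA` is what lets `Attention.forward` drop its duplicated no-cache branch in the last hunk (the `repeat_interleave` GQA expansion and the direct `F.scaled_dot_product_attention` call).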