@@ -41,7 +41,7 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
-        return (num_blocks, block_size, num_kv_heads * head_size)
+        return (num_blocks, block_size, num_kv_heads * 2, head_size)
 
     @staticmethod
     def swap_blocks(
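Note: with this change the backend allocates a single KV cache tensor whose third dimension holds the key and value heads side by side, instead of two separate [num_blocks, block_size, num_kv_heads * head_size] tensors. A minimal sketch of the resulting allocation, assuming made-up example sizes (not values from this patch):

    import torch

    # Hypothetical example sizes, purely for illustration.
    num_blocks, block_size = 512, 16
    num_kv_heads, head_size = 8, 128

    # New combined layout returned by get_kv_cache_shape():
    # one tensor with K and V heads stacked along dim 2.
    kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
    assert kv_cache.shape == (512, 16, 16, 128)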
@@ -132,7 +132,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: tuple[torch.Tensor, torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: PallasMetadata,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -142,14 +142,13 @@ def forward(
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
-                        [num_blocks, block_size, num_kv_heads * head_size])
+            kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
         # For determine_available_memory case.
-        if kv_cache[0].numel() == 0:
+        if kv_cache.numel() == 0:
             if output is None:
                 output = torch.ones_like(query)
             return output
@@ -158,15 +157,13 @@ def forward(
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
 
-        key_cache, value_cache = kv_cache
-        if kv_cache[0].numel() > 0:
+        if kv_cache.numel() > 0:
             slot_mapping = attn_metadata.slot_mapping
-            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
+            write_to_kv_cache(key, value, kv_cache, slot_mapping)
 
         output = torch.ops.xla.ragged_paged_attention(
             query,
-            key_cache,
-            value_cache,
+            kv_cache,
             attn_metadata.context_lens,
             attn_metadata.block_tables,
             attn_metadata.query_start_loc,
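Here write_to_kv_cache indexes the cache through attn_metadata.slot_mapping, which addresses the flattened (num_blocks * block_size) dimension. A rough sketch of how such flat slot indices are typically derived from per-token block numbers and in-block offsets, with illustrative values rather than code from this change:

    import torch

    block_size = 16
    # Hypothetical per-token block numbers and in-block offsets.
    block_numbers = torch.tensor([3, 3, 7])
    block_offsets = torch.tensor([14, 15, 0])

    # Flat slot index into the cache after flatten(0, 1):
    # slot = block_number * block_size + offset.
    slot_mapping = block_numbers * block_size + block_offsets  # tensor([62, 63, 112])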
@@ -183,23 +180,27 @@ def forward(
 def write_to_kv_cache(
     key: torch.Tensor,
     value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
+    kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
 ) -> None:
     """Write the key and values to the KV cache.
 
     Args:
         key: shape = [num_tokens, num_kv_heads * head_size]
-        value: shape = [num_tokens, num_kv_heads * head_size]
-        k_cache = [num_blocks, block_size, num_kv_heads * head_size]
-        v_cache = [num_blocks, block_size, num_kv_heads * head_size]
+        value: shape = [num_tokens, num_kv_heads * head_size]
+        kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
 
     """
-    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
-    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
+    _, _, num_combined_kv_heads, head_size = kv_cache.shape
+    num_kv_heads = num_combined_kv_heads // 2
 
-    key_cache = key_cache.flatten(0, 1)
-    value_cache = value_cache.flatten(0, 1)
-    key_cache.index_copy_(0, slot_mapping, key)
-    value_cache.index_copy_(0, slot_mapping, value)
+    key = key.view(-1, num_kv_heads, head_size)
+    value = value.view(-1, num_kv_heads, head_size)
+
+    kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
+                                                  head_size)
+
+    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
+
+    kv_cache = kv_cache.flatten(0, 1)
+    kv_cache.index_copy_(0, slot_mapping, kv)
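The cat-then-reshape above interleaves key and value heads along the combined dimension, so slice 2*h holds key head h and slice 2*h + 1 holds value head h. A small standalone check of that layout, using arbitrary example sizes:

    import torch

    num_tokens, num_kv_heads, head_size = 4, 2, 8
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)

    # Same cat + reshape pattern as write_to_kv_cache.
    kv = torch.cat([key, value], dim=-1).reshape(-1, num_kv_heads * 2, head_size)

    # Key head h lands at index 2*h, value head h at index 2*h + 1.
    assert torch.equal(kv[:, 0], key[:, 0])
    assert torch.equal(kv[:, 1], value[:, 0])
    assert torch.equal(kv[:, 2], key[:, 1])
    assert torch.equal(kv[:, 3], value[:, 1])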