diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 953d7b3c45dd..0f81a2bc53b2 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -348,6 +348,8 @@ def supports_compute_capability(cls, capability: DeviceCapability) -> bool: class TritonAttentionImpl(AttentionImpl): + forward_includes_kv_cache_update: bool = False + def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym @@ -572,11 +574,16 @@ def do_kv_cache_update( value: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, - ): + ) -> None: if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching return + + if self.kv_sharing_target_layer_name is not None: + # Skip this if sharing KV cache with an earlier attention layer. + return + # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(1) @@ -587,16 +594,19 @@ def do_kv_cache_update( # triton kernel does not support uint8 kv_cache # (because some explicit casts (e.g. float8_e4m3fnuz) # are not supported) - triton_reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + + # NOTE: key/value may be padded while slot_mapping is not.
+ if key is not None and value is not None: + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) def fused_rope_kvcache_supported(self): return rocm_aiter_ops.is_enabled() diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c70970fdc06e..174642353680 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -550,16 +550,25 @@ def swap_states(self, i1: int, i2: int) -> None: self.num_computed_tokens_cpu[i1], ) - # NOTE: the following is unsafe - # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ - # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] # instead, we need to temporarily copy the data for one of the indices - # TODO(lucas): optimize this by only copying valid indices - tmp = self.token_ids_cpu[i1, ...].copy() - self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] - self.token_ids_cpu[i2, ...] = tmp - - self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] 
+ # optimize this by only copying valid indices (active token prefix) + len_of_orig_i2 = self.num_tokens_no_spec[i1] + len(self.spec_token_ids[i1]) + len_of_orig_i1 = self.num_tokens_no_spec[i2] + len(self.spec_token_ids[i2]) + + tmp_tokens = self.token_ids_cpu[i1, :len_of_orig_i1].copy() + tmp_is_token_ids = self.is_token_ids[i1, :len_of_orig_i1].copy() + + self.token_ids_cpu[i1, :len_of_orig_i2] = self.token_ids_cpu[ + i2, :len_of_orig_i2 + ] + self.is_token_ids[i1, :len_of_orig_i2] = self.is_token_ids[i2, :len_of_orig_i2] + if len_of_orig_i1 > len_of_orig_i2: + self.is_token_ids[i1, len_of_orig_i2:len_of_orig_i1] = False + + self.token_ids_cpu[i2, :len_of_orig_i1] = tmp_tokens + self.is_token_ids[i2, :len_of_orig_i1] = tmp_is_token_ids + if len_of_orig_i2 > len_of_orig_i1: + self.is_token_ids[i2, len_of_orig_i1:len_of_orig_i2] = False # Swap prompt embeddings if they exist embeds_i1 = self.req_prompt_embeds.get(i1)