Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ def supports_compute_capability(cls, capability: DeviceCapability) -> bool:


class TritonAttentionImpl(AttentionImpl):
forward_includes_kv_cache_update: bool = False

def fused_output_quant_supported(self, quant_key: QuantKey):
    """Report whether output quantization with ``quant_key`` can be fused.

    The only key accepted is ``kFp8StaticTensorSym``; any other key
    compares unequal and yields a falsy result.
    NOTE(review): the constant's name suggests FP8 static per-tensor
    symmetric quantization -- confirm against its definition.
    """
    is_supported = quant_key == kFp8StaticTensorSym
    return is_supported

Expand Down Expand Up @@ -572,11 +574,16 @@ def do_kv_cache_update(
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
) -> None:
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return

if self.kv_sharing_target_layer_name is not None:
# Skip this if sharing KV cache with an earlier attention layer.
return

# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)

Expand All @@ -587,16 +594,19 @@ def do_kv_cache_update(
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

# NOTE: key/value may be padded while slot_mapping is not.
if key is not None and value is not None:
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

def fused_rope_kvcache_supported(self):
    """Report whether the fused RoPE + KV-cache path may be used.

    The answer is delegated entirely to ``rocm_aiter_ops.is_enabled()``;
    no instance state is consulted.
    """
    aiter_enabled = rocm_aiter_ops.is_enabled()
    return aiter_enabled
Expand Down
27 changes: 18 additions & 9 deletions vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,16 +550,25 @@ def swap_states(self, i1: int, i2: int) -> None:
self.num_computed_tokens_cpu[i1],
)

# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp

self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
# Swap only the valid token prefix of each row (num_tokens_no_spec plus
# the number of spec tokens) instead of copying the entire rows.
len_of_orig_i2 = self.num_tokens_no_spec[i1] + len(self.spec_token_ids[i1])
len_of_orig_i1 = self.num_tokens_no_spec[i2] + len(self.spec_token_ids[i2])

tmp_tokens = self.token_ids_cpu[i1, :len_of_orig_i1].copy()
tmp_is_token_ids = self.is_token_ids[i1, :len_of_orig_i1].copy()

self.token_ids_cpu[i1, :len_of_orig_i2] = self.token_ids_cpu[
i2, :len_of_orig_i2
]
self.is_token_ids[i1, :len_of_orig_i2] = self.is_token_ids[i2, :len_of_orig_i2]
if len_of_orig_i1 > len_of_orig_i2:
self.is_token_ids[i1, len_of_orig_i2:len_of_orig_i1] = False

self.token_ids_cpu[i2, :len_of_orig_i1] = tmp_tokens
self.is_token_ids[i2, :len_of_orig_i1] = tmp_is_token_ids
if len_of_orig_i2 > len_of_orig_i1:
self.is_token_ids[i2, len_of_orig_i1:len_of_orig_i2] = False

# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
Expand Down