diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index ca82f2feb7ef..6bca26c4b9f6 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -320,7 +320,7 @@ def sparse_attn_indexer( num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048): + if current_platform.is_cuda() and topk_tokens in (512, 2048): workspace_manager = current_workspace_manager() (topk_workspace,) = workspace_manager.get_simultaneous( ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),