diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index ca82f2feb7ef..79f7e714c2f3 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -320,7 +320,12 @@ def sparse_attn_indexer( num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048): + allowed_topk = ( + (512, 1024, 2048) + if current_platform.is_device_capability_family(100) + else (512, 2048) + ) + if current_platform.is_cuda() and topk_tokens in allowed_topk: workspace_manager = current_workspace_manager() (topk_workspace,) = workspace_manager.get_simultaneous( ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),