Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,15 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
"csrc/rocm/moe_q_gemm_rdna3.cu")
endif()

set(VLLM_ROCM_HAS_GFX950 OFF)
if(VLLM_GPU_ARCHES MATCHES "gfx950")
set(VLLM_ROCM_HAS_GFX950 ON)
list(APPEND VLLM_ROCM_EXT_SRC
"csrc/rocm/sparse_mla_decode.cu")
set_source_files_properties("csrc/rocm/sparse_mla_decode.cu"
PROPERTIES COMPILE_OPTIONS "-Wno-c++11-narrowing")
endif()

define_extension_target(
_rocm_C
DESTINATION vllm
Expand All @@ -1325,6 +1334,9 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_ROCM_HAS_GFX1100)
target_compile_definitions(_rocm_C PRIVATE VLLM_ROCM_GFX1100)
endif()
if(VLLM_ROCM_HAS_GFX950)
target_compile_definitions(_rocm_C PRIVATE VLLM_ROCM_GFX950)
endif()
endif()

# Must run after the last HIP `define_extension_target` so every extension
Expand Down
17 changes: 17 additions & 0 deletions csrc/rocm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,20 @@ void paged_attention(
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale,
const std::string& mfma_type);

void sparse_mla_decode_single(
torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices,
torch::Tensor main_indptr, torch::Tensor extra_cache,
torch::Tensor extra_indices, torch::Tensor extra_indptr,
const std::optional<torch::Tensor>& attn_sink, torch::Tensor output,
int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows,
int64_t extra_num_rows, double scale, bool has_extra);

void sparse_mla_decode_split(
torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices,
torch::Tensor main_indptr, torch::Tensor extra_cache,
torch::Tensor extra_indices, torch::Tensor extra_indptr,
const std::optional<torch::Tensor>& attn_sink, torch::Tensor output,
torch::Tensor scratch_m, torch::Tensor scratch_l, torch::Tensor scratch_acc,
int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows,
int64_t extra_num_rows, double scale, bool has_extra, int64_t split_k);
Loading
Loading