diff --git a/tests/attention/test_attention_support_checks.py b/tests/attention/test_attention_support_checks.py new file mode 100644 index 0000000000..debfb1aeac --- /dev/null +++ b/tests/attention/test_attention_support_checks.py @@ -0,0 +1,416 @@ +""" +Test file for Attention API support checks. + +This file serves as a TODO list for support check implementations. +APIs with @pytest.mark.xfail need support checks to be implemented. +""" + +import pytest + +from flashinfer.decode import ( + BatchDecodeWithPagedKVCacheWrapper, + CUDAGraphBatchDecodeWithPagedKVCacheWrapper, + single_decode_with_kv_cache, + trtllm_batch_decode_with_kv_cache, + xqa_batch_decode_with_kv_cache, +) +from flashinfer.prefill import ( + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + fmha_v2_prefill_deepseek, + single_prefill_with_kv_cache, + trtllm_batch_context_with_kv_cache, + trtllm_ragged_attention_deepseek, +) +from flashinfer.cascade import ( + BatchDecodeWithSharedPrefixPagedKVCacheWrapper, + BatchPrefillWithSharedPrefixPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper, + merge_state, + merge_state_in_place, + merge_states, +) +from flashinfer.mla import ( + BatchMLAPagedAttentionWrapper, + trtllm_batch_decode_with_kv_cache_mla, + xqa_batch_decode_with_kv_cache_mla, +) +from flashinfer.sparse import ( + BlockSparseAttentionWrapper, +) +from flashinfer.xqa import xqa, xqa_mla +from flashinfer.page import ( + append_paged_kv_cache, + append_paged_mla_kv_cache, + get_batch_indices_positions, +) +from flashinfer.rope import ( + apply_llama31_rope, + apply_llama31_rope_inplace, + apply_llama31_rope_pos_ids, + apply_llama31_rope_pos_ids_inplace, + apply_rope, + apply_rope_inplace, + apply_rope_pos_ids, + apply_rope_pos_ids_inplace, + apply_rope_with_cos_sin_cache, + apply_rope_with_cos_sin_cache_inplace, +) +from flashinfer.cudnn.decode import cudnn_batch_decode_with_kv_cache +from flashinfer.cudnn.prefill import cudnn_batch_prefill_with_kv_cache +from flashinfer.pod import PODWithPagedKVCacheWrapper, BatchPODWithPagedKVCacheWrapper + + +# Decode APIs +@pytest.mark.xfail( + reason="TODO: Support checks for single_decode_with_kv_cache are not implemented" +) +def test_single_decode_with_kv_cache_support_checks(): + assert hasattr(single_decode_with_kv_cache, "is_compute_capability_supported") + assert hasattr(single_decode_with_kv_cache, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for BatchDecodeWithPagedKVCacheWrapper are not implemented" +) +def test_batch_decode_with_paged_kv_cache_wrapper_support_checks(): + assert hasattr( + BatchDecodeWithPagedKVCacheWrapper.run, "is_compute_capability_supported" + ) + assert hasattr(BatchDecodeWithPagedKVCacheWrapper.run, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for CUDAGraphBatchDecodeWithPagedKVCacheWrapper are not implemented" +) +def test_cuda_graph_batch_decode_wrapper_support_checks(): + assert hasattr( + CUDAGraphBatchDecodeWithPagedKVCacheWrapper.run, + "is_compute_capability_supported", + ) + assert hasattr( + CUDAGraphBatchDecodeWithPagedKVCacheWrapper.run, "is_backend_supported" + ) + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_batch_decode_with_kv_cache are not implemented" +) +def test_trtllm_batch_decode_with_kv_cache_support_checks(): + assert hasattr(trtllm_batch_decode_with_kv_cache, "is_compute_capability_supported") + assert hasattr(trtllm_batch_decode_with_kv_cache, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for xqa_batch_decode_with_kv_cache are not implemented" +) +def test_xqa_batch_decode_with_kv_cache_support_checks(): + assert hasattr(xqa_batch_decode_with_kv_cache, "is_compute_capability_supported") + assert hasattr(xqa_batch_decode_with_kv_cache, "is_backend_supported") + + +# Prefill APIs +@pytest.mark.xfail( + reason="TODO: Support checks for single_prefill_with_kv_cache are not implemented" +) +def test_single_prefill_with_kv_cache_support_checks(): + assert hasattr(single_prefill_with_kv_cache, "is_compute_capability_supported") + assert hasattr(single_prefill_with_kv_cache, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for BatchPrefillWithPagedKVCacheWrapper are not implemented" +) +def test_batch_prefill_with_paged_kv_cache_wrapper_support_checks(): + assert hasattr( + BatchPrefillWithPagedKVCacheWrapper.run, "is_compute_capability_supported" + ) + assert hasattr(BatchPrefillWithPagedKVCacheWrapper.run, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for BatchPrefillWithRaggedKVCacheWrapper are not implemented" +) +def test_batch_prefill_with_ragged_kv_cache_wrapper_support_checks(): + assert hasattr( + BatchPrefillWithRaggedKVCacheWrapper.run, "is_compute_capability_supported" + ) + assert hasattr(BatchPrefillWithRaggedKVCacheWrapper.run, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_ragged_attention_deepseek are not implemented" +) +def test_trtllm_ragged_attention_deepseek_support_checks(): + assert hasattr(trtllm_ragged_attention_deepseek, "is_compute_capability_supported") + assert hasattr(trtllm_ragged_attention_deepseek, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_batch_context_with_kv_cache are not implemented" +) +def test_trtllm_batch_context_with_kv_cache_support_checks(): + assert hasattr( + trtllm_batch_context_with_kv_cache, "is_compute_capability_supported" + ) + assert hasattr(trtllm_batch_context_with_kv_cache, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for fmha_v2_prefill_deepseek are not implemented" +) +def test_fmha_v2_prefill_deepseek_support_checks(): + assert hasattr(fmha_v2_prefill_deepseek, "is_compute_capability_supported") + assert hasattr(fmha_v2_prefill_deepseek, "is_backend_supported") + + +# Cascade APIs +@pytest.mark.xfail(reason="TODO: Support checks for merge_state are not implemented") +def test_merge_state_support_checks(): + assert hasattr(merge_state, "is_compute_capability_supported") + assert hasattr(merge_state, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for merge_state_in_place are not implemented" +) +def test_merge_state_in_place_support_checks(): + assert hasattr(merge_state_in_place, "is_compute_capability_supported") + assert hasattr(merge_state_in_place, "is_backend_supported") + + +@pytest.mark.xfail(reason="TODO: Support checks for merge_states are not implemented") +def test_merge_states_support_checks(): + assert hasattr(merge_states, "is_compute_capability_supported") + assert hasattr(merge_states, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for MultiLevelCascadeAttentionWrapper are not implemented" +) +def test_multi_level_cascade_wrapper_support_checks(): + assert hasattr( + MultiLevelCascadeAttentionWrapper.run, "is_compute_capability_supported" + ) + assert hasattr(MultiLevelCascadeAttentionWrapper.run, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for BatchDecodeWithSharedPrefixPagedKVCacheWrapper are not implemented" +) +def test_batch_decode_shared_prefix_wrapper_support_checks(): + assert hasattr( + BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward, + "is_compute_capability_supported", + ) + assert hasattr( + BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward, "is_backend_supported" + ) + + +@pytest.mark.xfail( + reason="TODO: Support checks for BatchPrefillWithSharedPrefixPagedKVCacheWrapper are not implemented" +) +def test_batch_prefill_shared_prefix_wrapper_support_checks(): + assert hasattr( + BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward, + "is_compute_capability_supported", + ) + assert hasattr( + BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward, "is_backend_supported" + ) + + +# MLA APIs +@pytest.mark.xfail( + reason="TODO: Support checks for BatchMLAPagedAttentionWrapper are not implemented" +) +def test_batch_decode_mla_wrapper_support_checks(): + assert hasattr(BatchMLAPagedAttentionWrapper.run, "is_compute_capability_supported") + assert hasattr(BatchMLAPagedAttentionWrapper.run, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_batch_decode_with_kv_cache_mla are not implemented" +) +def test_trtllm_batch_decode_mla_support_checks(): + assert hasattr( + trtllm_batch_decode_with_kv_cache_mla, "is_compute_capability_supported" + ) + assert hasattr(trtllm_batch_decode_with_kv_cache_mla, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for xqa_batch_decode_with_kv_cache_mla are not implemented" +) +def test_xqa_batch_decode_mla_support_checks(): + assert hasattr( + xqa_batch_decode_with_kv_cache_mla, "is_compute_capability_supported" + ) + assert hasattr(xqa_batch_decode_with_kv_cache_mla, "is_backend_supported") + + +# Sparse APIs +@pytest.mark.xfail( + reason="TODO: Support checks for BlockSparseAttentionWrapper are not implemented" +) +def test_block_sparse_attention_wrapper_support_checks(): + assert hasattr(BlockSparseAttentionWrapper.run, "is_compute_capability_supported") + assert hasattr(BlockSparseAttentionWrapper.run, "is_backend_supported") + + +# XQA APIs +@pytest.mark.xfail(reason="TODO: Support checks for xqa are not implemented") +def test_xqa_support_checks(): + assert hasattr(xqa, "is_compute_capability_supported") + assert hasattr(xqa, "is_backend_supported") + + +@pytest.mark.xfail(reason="TODO: Support checks for xqa_mla are not implemented") +def test_xqa_mla_support_checks(): + assert hasattr(xqa_mla, "is_compute_capability_supported") + assert hasattr(xqa_mla, "is_backend_supported") + + +# Page APIs +@pytest.mark.xfail( + reason="TODO: Support checks for get_batch_indices_positions are not implemented" +) +def test_get_batch_indices_positions_support_checks(): + assert hasattr(get_batch_indices_positions, "is_compute_capability_supported") + assert hasattr(get_batch_indices_positions, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for append_paged_mla_kv_cache are not implemented" +) +def test_append_paged_mla_kv_cache_support_checks(): + assert hasattr(append_paged_mla_kv_cache, "is_compute_capability_supported") + assert hasattr(append_paged_mla_kv_cache, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for append_paged_kv_cache are not implemented" +) +def test_append_paged_kv_cache_support_checks(): + assert hasattr(append_paged_kv_cache, "is_compute_capability_supported") + assert hasattr(append_paged_kv_cache, "is_backend_supported") + + +# RoPE APIs +@pytest.mark.xfail( + reason="TODO: Support checks for apply_rope_inplace are not implemented" +) +def test_apply_rope_inplace_support_checks(): + assert hasattr(apply_rope_inplace, "is_compute_capability_supported") + assert hasattr(apply_rope_inplace, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for apply_rope_pos_ids_inplace are not implemented" +) +def test_apply_rope_pos_ids_inplace_support_checks(): + assert hasattr(apply_rope_pos_ids_inplace, "is_compute_capability_supported") + assert hasattr(apply_rope_pos_ids_inplace, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for apply_llama31_rope_inplace are not implemented" +) +def test_apply_llama31_rope_inplace_support_checks(): + assert hasattr(apply_llama31_rope_inplace, "is_compute_capability_supported") + assert hasattr(apply_llama31_rope_inplace, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for apply_llama31_rope_pos_ids_inplace are not implemented" +) +def test_apply_llama31_rope_pos_ids_inplace_support_checks(): + assert hasattr( + apply_llama31_rope_pos_ids_inplace, "is_compute_capability_supported" + ) + assert hasattr(apply_llama31_rope_pos_ids_inplace, "is_backend_supported") + + +@pytest.mark.xfail(reason="TODO: Support checks for apply_rope are not implemented") +def test_apply_rope_support_checks(): + assert hasattr(apply_rope, "is_compute_capability_supported") + assert hasattr(apply_rope, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for apply_rope_pos_ids are not implemented" +) +def test_apply_rope_pos_ids_support_checks(): + assert hasattr(apply_rope_pos_ids, "is_compute_capability_supported") + assert hasattr(apply_rope_pos_ids, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for apply_llama31_rope are not implemented" +) +def test_apply_llama31_rope_support_checks(): + assert hasattr(apply_llama31_rope, "is_compute_capability_supported") + assert hasattr(apply_llama31_rope, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for apply_llama31_rope_pos_ids are not implemented" +) +def test_apply_llama31_rope_pos_ids_support_checks(): + assert hasattr(apply_llama31_rope_pos_ids, "is_compute_capability_supported") + assert hasattr(apply_llama31_rope_pos_ids, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for apply_rope_with_cos_sin_cache are not implemented" +) +def test_apply_rope_with_cos_sin_cache_support_checks(): + assert hasattr(apply_rope_with_cos_sin_cache, "is_compute_capability_supported") + assert hasattr(apply_rope_with_cos_sin_cache, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for apply_rope_with_cos_sin_cache_inplace are not implemented" +) +def test_apply_rope_with_cos_sin_cache_inplace_support_checks(): + assert hasattr( + apply_rope_with_cos_sin_cache_inplace, "is_compute_capability_supported" + ) + assert hasattr(apply_rope_with_cos_sin_cache_inplace, "is_backend_supported") + + +# cuDNN APIs +@pytest.mark.xfail( + reason="TODO: Support checks for cudnn_batch_decode_with_kv_cache are not implemented" +) +def test_cudnn_batch_decode_support_checks(): + assert hasattr(cudnn_batch_decode_with_kv_cache, "is_compute_capability_supported") + assert hasattr(cudnn_batch_decode_with_kv_cache, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for cudnn_batch_prefill_with_kv_cache are not implemented" +) +def test_cudnn_batch_prefill_support_checks(): + assert hasattr(cudnn_batch_prefill_with_kv_cache, "is_compute_capability_supported") + assert hasattr(cudnn_batch_prefill_with_kv_cache, "is_backend_supported") + + +# POD APIs +@pytest.mark.xfail( + reason="TODO: Support checks for BatchPODWithPagedKVCacheWrapper are not implemented" +) +def test_pod_prefill_wrapper_support_checks(): + assert hasattr( + BatchPODWithPagedKVCacheWrapper.run, "is_compute_capability_supported" + ) + assert hasattr(BatchPODWithPagedKVCacheWrapper.run, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for PODWithPagedKVCacheWrapper are not implemented" +) +def test_pod_decode_wrapper_support_checks(): + assert hasattr(PODWithPagedKVCacheWrapper.run, "is_compute_capability_supported") + assert hasattr(PODWithPagedKVCacheWrapper.run, "is_backend_supported") diff --git a/tests/comm/test_comm_support_checks.py b/tests/comm/test_comm_support_checks.py new file mode 100644 index 0000000000..1d8ab0fa50 --- /dev/null +++ b/tests/comm/test_comm_support_checks.py @@ -0,0 +1,245 @@ +""" +Test file for Comm API support checks. + +This file serves as a TODO list for support check implementations. +APIs with @pytest.mark.xfail need support checks to be implemented. +""" + +import pytest + +from flashinfer.comm import ( + trtllm_allreduce_fusion, + trtllm_custom_all_reduce, + trtllm_moe_allreduce_fusion, + trtllm_moe_finalize_allreduce_fusion, + trtllm_create_ipc_workspace_for_all_reduce, + trtllm_destroy_ipc_workspace_for_all_reduce, + trtllm_create_ipc_workspace_for_all_reduce_fusion, + trtllm_destroy_ipc_workspace_for_all_reduce_fusion, + trtllm_lamport_initialize, + trtllm_lamport_initialize_all, + vllm_all_reduce, + vllm_init_custom_ar, + vllm_dispose, + vllm_register_buffer, + vllm_register_graph_buffers, + MoeAlltoAll, +) +from flashinfer.comm.nvshmem_allreduce import NVSHMEMAllReduce +from flashinfer.comm.mnnvl import MnnvlMemory +from flashinfer.comm.trtllm_alltoall import MnnvlMoe, MoEAlltoallInfo +from flashinfer.comm.trtllm_mnnvl_ar import ( + trtllm_mnnvl_allreduce, + trtllm_mnnvl_fused_allreduce_add_rmsnorm, + trtllm_mnnvl_fused_allreduce_rmsnorm, + trtllm_mnnvl_all_reduce, +) + + +# TRTLLM AllReduce APIs +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_allreduce_fusion are not implemented" +) +def test_trtllm_allreduce_fusion_support_checks(): + assert hasattr(trtllm_allreduce_fusion, "is_compute_capability_supported") + assert hasattr(trtllm_allreduce_fusion, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_custom_all_reduce are not implemented" +) +def test_trtllm_custom_all_reduce_support_checks(): + assert hasattr(trtllm_custom_all_reduce, "is_compute_capability_supported") + assert hasattr(trtllm_custom_all_reduce, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_moe_allreduce_fusion are not implemented" +) +def test_trtllm_moe_allreduce_fusion_support_checks(): + assert hasattr(trtllm_moe_allreduce_fusion, "is_compute_capability_supported") + assert hasattr(trtllm_moe_allreduce_fusion, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_moe_finalize_allreduce_fusion are not implemented" +) +def test_trtllm_moe_finalize_allreduce_fusion_support_checks(): + assert hasattr( + trtllm_moe_finalize_allreduce_fusion, "is_compute_capability_supported" + ) + assert hasattr(trtllm_moe_finalize_allreduce_fusion, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_create_ipc_workspace_for_all_reduce are not implemented" +) +def test_trtllm_create_ipc_workspace_for_all_reduce_support_checks(): + assert hasattr( + trtllm_create_ipc_workspace_for_all_reduce, "is_compute_capability_supported" + ) + assert hasattr(trtllm_create_ipc_workspace_for_all_reduce, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_destroy_ipc_workspace_for_all_reduce are not implemented" +) +def test_trtllm_destroy_ipc_workspace_for_all_reduce_support_checks(): + assert hasattr( + trtllm_destroy_ipc_workspace_for_all_reduce, "is_compute_capability_supported" + ) + assert hasattr(trtllm_destroy_ipc_workspace_for_all_reduce, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_create_ipc_workspace_for_all_reduce_fusion are not implemented" +) +def test_trtllm_create_ipc_workspace_for_all_reduce_fusion_support_checks(): + assert hasattr( + trtllm_create_ipc_workspace_for_all_reduce_fusion, + "is_compute_capability_supported", + ) + assert hasattr( + trtllm_create_ipc_workspace_for_all_reduce_fusion, "is_backend_supported" + ) + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_destroy_ipc_workspace_for_all_reduce_fusion are not implemented" +) +def test_trtllm_destroy_ipc_workspace_for_all_reduce_fusion_support_checks(): + assert hasattr( + trtllm_destroy_ipc_workspace_for_all_reduce_fusion, + "is_compute_capability_supported", + ) + assert hasattr( + trtllm_destroy_ipc_workspace_for_all_reduce_fusion, "is_backend_supported" + ) + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_lamport_initialize are not implemented" +) +def test_trtllm_lamport_initialize_support_checks(): + assert hasattr(trtllm_lamport_initialize, "is_compute_capability_supported") + assert hasattr(trtllm_lamport_initialize, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_lamport_initialize_all are not implemented" +) +def test_trtllm_lamport_initialize_all_support_checks(): + assert hasattr(trtllm_lamport_initialize_all, "is_compute_capability_supported") + assert hasattr(trtllm_lamport_initialize_all, "is_backend_supported") + + +# VLLM AllReduce APIs +@pytest.mark.xfail( + reason="TODO: Support checks for vllm_all_reduce are not implemented" +) +def test_vllm_all_reduce_support_checks(): + assert hasattr(vllm_all_reduce, "is_compute_capability_supported") + assert hasattr(vllm_all_reduce, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for vllm_init_custom_ar are not implemented" +) +def test_vllm_init_custom_ar_support_checks(): + assert hasattr(vllm_init_custom_ar, "is_compute_capability_supported") + assert hasattr(vllm_init_custom_ar, "is_backend_supported") + + +@pytest.mark.xfail(reason="TODO: Support checks for vllm_dispose are not implemented") +def test_vllm_dispose_support_checks(): + assert hasattr(vllm_dispose, "is_compute_capability_supported") + assert hasattr(vllm_dispose, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for vllm_register_buffer are not implemented" +) +def test_vllm_register_buffer_support_checks(): + assert hasattr(vllm_register_buffer, "is_compute_capability_supported") + assert hasattr(vllm_register_buffer, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for vllm_register_graph_buffers are not implemented" +) +def test_vllm_register_graph_buffers_support_checks(): + assert hasattr(vllm_register_graph_buffers, "is_compute_capability_supported") + assert hasattr(vllm_register_graph_buffers, "is_backend_supported") + + +# NVSHMEM APIs +@pytest.mark.xfail( + reason="TODO: Support checks for NVSHMEMAllReduce are not implemented" +) +def test_nvshmem_allreduce_support_checks(): + assert hasattr(NVSHMEMAllReduce, "is_compute_capability_supported") + assert hasattr(NVSHMEMAllReduce, "is_backend_supported") + + +# MNNVL APIs +@pytest.mark.xfail(reason="TODO: Support checks for MnnvlMemory are not implemented") +def test_mnnvl_memory_support_checks(): + assert hasattr(MnnvlMemory, "is_compute_capability_supported") + assert hasattr(MnnvlMemory, "is_backend_supported") + + +@pytest.mark.xfail(reason="TODO: Support checks for MnnvlMoe are not implemented") +def test_mnnvl_moe_support_checks(): + assert hasattr(MnnvlMoe, "is_compute_capability_supported") + assert hasattr(MnnvlMoe, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for MoEAlltoallInfo are not implemented" +) +def test_moe_alltoall_info_support_checks(): + assert hasattr(MoEAlltoallInfo, "is_compute_capability_supported") + assert hasattr(MoEAlltoallInfo, "is_backend_supported") + + +@pytest.mark.xfail(reason="TODO: Support checks for MoeAlltoAll are not implemented") +def test_moe_alltoall_support_checks(): + assert hasattr(MoeAlltoAll.dispatch, "is_compute_capability_supported") + assert hasattr(MoeAlltoAll.dispatch, "is_backend_supported") + + +# TRTLLM MNNVL AllReduce APIs +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_mnnvl_allreduce are not implemented" +) +def test_trtllm_mnnvl_allreduce_support_checks(): + assert hasattr(trtllm_mnnvl_allreduce, "is_compute_capability_supported") + assert hasattr(trtllm_mnnvl_allreduce, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_mnnvl_fused_allreduce_add_rmsnorm are not implemented" +) +def test_trtllm_mnnvl_fused_allreduce_add_rmsnorm_support_checks(): + assert hasattr( + trtllm_mnnvl_fused_allreduce_add_rmsnorm, "is_compute_capability_supported" + ) + assert hasattr(trtllm_mnnvl_fused_allreduce_add_rmsnorm, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_mnnvl_fused_allreduce_rmsnorm are not implemented" +) +def test_trtllm_mnnvl_fused_allreduce_rmsnorm_support_checks(): + assert hasattr( + trtllm_mnnvl_fused_allreduce_rmsnorm, "is_compute_capability_supported" + ) + assert hasattr(trtllm_mnnvl_fused_allreduce_rmsnorm, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_mnnvl_all_reduce are not implemented" +) +def test_trtllm_mnnvl_all_reduce_support_checks(): + assert hasattr(trtllm_mnnvl_all_reduce, "is_compute_capability_supported") + assert hasattr(trtllm_mnnvl_all_reduce, "is_backend_supported") diff --git a/tests/gemm/test_gemm_support_checks.py b/tests/gemm/test_gemm_support_checks.py new file mode 100644 index 0000000000..eeaa9f419a --- /dev/null +++ b/tests/gemm/test_gemm_support_checks.py @@ -0,0 +1,160 @@ +""" +Test file for GEMM API support checks. + +This file serves as a TODO list for support check implementations. +APIs with @pytest.mark.xfail need support checks to be implemented in gemm_base.py. +""" + +import pytest + +from flashinfer import ( + bmm_fp8, + mm_fp4, + mm_fp8, + prepare_low_latency_gemm_weights, + tgv_gemm_sm100, +) +from flashinfer.gemm import ( + SegmentGEMMWrapper, + batch_deepgemm_fp8_nt_groupwise, + gemm_fp8_nt_blockscaled, + gemm_fp8_nt_groupwise, + group_deepgemm_fp8_nt_groupwise, + group_gemm_fp8_nt_groupwise, + group_gemm_mxfp4_nt_groupwise, +) +from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked +from flashinfer.cute_dsl.gemm_allreduce_two_shot import PersistentDenseGemmKernel +import flashinfer.triton.sm_constraint_gemm as sm_constraint_gemm + + +def test_mm_fp4_support_checks(): + assert hasattr(mm_fp4, "is_compute_capability_supported") + assert hasattr(mm_fp4, "is_backend_supported") + + +def test_bmm_fp8_support_checks(): + assert hasattr(bmm_fp8, "is_compute_capability_supported") + assert hasattr(bmm_fp8, "is_backend_supported") + + +@pytest.mark.xfail(reason="TODO: Support checks for tgv_gemm_sm100 are not implemented") +def test_tgv_gemm_sm100_support_checks(): + assert hasattr(tgv_gemm_sm100, "is_compute_capability_supported") + assert hasattr(tgv_gemm_sm100, "is_backend_supported") + + +@pytest.mark.xfail(reason="TODO: Support checks for mm_fp8 are not implemented") +def test_mm_fp8_support_checks(): + assert hasattr(mm_fp8, "is_compute_capability_supported") + assert hasattr(mm_fp8, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for gemm_fp8_nt_groupwise are not implemented" +) +def test_gemm_fp8_nt_groupwise_support_checks(): + assert hasattr(gemm_fp8_nt_groupwise, "is_compute_capability_supported") + assert hasattr(gemm_fp8_nt_groupwise, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for gemm_fp8_nt_blockscaled are not implemented" +) +def test_gemm_fp8_nt_blockscaled_support_checks(): + assert hasattr(gemm_fp8_nt_blockscaled, "is_compute_capability_supported") + assert hasattr(gemm_fp8_nt_blockscaled, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for group_gemm_fp8_nt_groupwise are not implemented" +) +def test_group_gemm_fp8_nt_groupwise_support_checks(): + assert hasattr(group_gemm_fp8_nt_groupwise, "is_compute_capability_supported") + assert hasattr(group_gemm_fp8_nt_groupwise, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for group_gemm_mxfp4_nt_groupwise are not implemented" +) +def test_group_gemm_mxfp4_nt_groupwise_support_checks(): + assert hasattr(group_gemm_mxfp4_nt_groupwise, "is_compute_capability_supported") + assert hasattr(group_gemm_mxfp4_nt_groupwise, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for group_deepgemm_fp8_nt_groupwise are not implemented" +) +def test_group_deepgemm_fp8_nt_groupwise_support_checks(): + assert hasattr(group_deepgemm_fp8_nt_groupwise, "is_compute_capability_supported") + assert hasattr(group_deepgemm_fp8_nt_groupwise, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for batch_deepgemm_fp8_nt_groupwise are not implemented" +) +def test_batch_deepgemm_fp8_nt_groupwise_support_checks(): + assert hasattr(batch_deepgemm_fp8_nt_groupwise, "is_compute_capability_supported") + assert hasattr(batch_deepgemm_fp8_nt_groupwise, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for SegmentGEMMWrapper are not implemented" +) +def test_segment_gemm_wrapper_support_checks(): + assert hasattr(SegmentGEMMWrapper.run, "is_compute_capability_supported") + assert hasattr(SegmentGEMMWrapper.run, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for prepare_low_latency_gemm_weights are not implemented" +) +def test_prepare_low_latency_gemm_weights_support_checks(): + assert hasattr(prepare_low_latency_gemm_weights, "is_compute_capability_supported") + assert hasattr(prepare_low_latency_gemm_weights, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for grouped_gemm_nt_masked are not implemented" +) +def test_grouped_gemm_nt_masked_support_checks(): + assert hasattr(grouped_gemm_nt_masked, "is_compute_capability_supported") + assert hasattr(grouped_gemm_nt_masked, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for PersistentDenseGemmKernel are not implemented" +) +def test_persistent_dense_gemm_kernel_support_checks(): + assert hasattr(PersistentDenseGemmKernel, "is_compute_capability_supported") + assert hasattr(PersistentDenseGemmKernel, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for sm_constraint_gemm.gemm are not implemented" +) +def test_sm_constraint_gemm_support_checks(): + assert hasattr(sm_constraint_gemm.gemm, "is_compute_capability_supported") + assert hasattr(sm_constraint_gemm.gemm, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for sm_constraint_gemm.gemm_persistent are not implemented" +) +def test_sm_constraint_gemm_persistent_support_checks(): + assert hasattr( + sm_constraint_gemm.gemm_persistent, "is_compute_capability_supported" + ) + assert hasattr(sm_constraint_gemm.gemm_persistent, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for sm_constraint_gemm.gemm_descriptor_persistent are not implemented" +) +def test_sm_constraint_gemm_descriptor_persistent_support_checks(): + assert hasattr( + sm_constraint_gemm.gemm_descriptor_persistent, "is_compute_capability_supported" + ) + assert hasattr( + sm_constraint_gemm.gemm_descriptor_persistent, "is_backend_supported" + ) diff --git a/tests/moe/test_moe_support_checks.py b/tests/moe/test_moe_support_checks.py new file mode 100644 index 0000000000..5ebd23312b --- /dev/null +++ b/tests/moe/test_moe_support_checks.py @@ -0,0 +1,140 @@ +""" +Test file for MoE API support checks. + +This file serves as a TODO list for support check implementations. +APIs with @pytest.mark.xfail need support checks to be implemented. +""" + +import pytest + +from flashinfer.fused_moe import ( + cutlass_fused_moe, + fused_topk_deepseek, + trtllm_bf16_moe, + trtllm_fp4_block_scale_moe, + trtllm_fp4_block_scale_routed_moe, + trtllm_fp8_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe, + trtllm_mxint4_block_scale_moe, +) +from flashinfer.comm.trtllm_moe_alltoall import ( + moe_a2a_combine, + moe_a2a_dispatch, + moe_a2a_get_workspace_size_per_rank, + moe_a2a_initialize, + moe_a2a_sanitize_expert_ids, + moe_a2a_wrap_payload_tensor_in_workspace, +) + + +def test_fused_topk_deepseek_support_checks(): + assert hasattr(fused_topk_deepseek, "is_compute_capability_supported") + assert hasattr(fused_topk_deepseek, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for cutlass_fused_moe are not implemented" +) +def test_cutlass_fused_moe_support_checks(): + assert hasattr(cutlass_fused_moe, "is_compute_capability_supported") + assert hasattr(cutlass_fused_moe, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_bf16_moe are not implemented" +) +def test_trtllm_bf16_moe_support_checks(): + assert hasattr(trtllm_bf16_moe, "is_compute_capability_supported") + assert hasattr(trtllm_bf16_moe, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_fp8_per_tensor_scale_moe are not implemented" +) +def test_trtllm_fp8_per_tensor_scale_moe_support_checks(): + assert hasattr(trtllm_fp8_per_tensor_scale_moe, "is_compute_capability_supported") + assert hasattr(trtllm_fp8_per_tensor_scale_moe, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_fp8_block_scale_moe are not implemented" +) +def test_trtllm_fp8_block_scale_moe_support_checks(): + assert hasattr(trtllm_fp8_block_scale_moe, "is_compute_capability_supported") + assert hasattr(trtllm_fp8_block_scale_moe, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_fp4_block_scale_moe are not implemented" +) +def test_trtllm_fp4_block_scale_moe_support_checks(): + assert hasattr(trtllm_fp4_block_scale_moe, "is_compute_capability_supported") + assert hasattr(trtllm_fp4_block_scale_moe, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_fp4_block_scale_routed_moe are not implemented" +) +def test_trtllm_fp4_block_scale_routed_moe_support_checks(): + assert hasattr(trtllm_fp4_block_scale_routed_moe, "is_compute_capability_supported") + assert hasattr(trtllm_fp4_block_scale_routed_moe, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for trtllm_mxint4_block_scale_moe are not implemented" +) +def test_trtllm_mxint4_block_scale_moe_support_checks(): + assert hasattr(trtllm_mxint4_block_scale_moe, "is_compute_capability_supported") + assert hasattr(trtllm_mxint4_block_scale_moe, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for moe_a2a_initialize are not implemented" +) +def test_moe_a2a_initialize_support_checks(): + assert hasattr(moe_a2a_initialize, "is_compute_capability_supported") + assert hasattr(moe_a2a_initialize, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for moe_a2a_wrap_payload_tensor_in_workspace are not implemented" +) +def test_moe_a2a_wrap_payload_tensor_in_workspace_support_checks(): + assert hasattr( + moe_a2a_wrap_payload_tensor_in_workspace, "is_compute_capability_supported" + ) + assert hasattr(moe_a2a_wrap_payload_tensor_in_workspace, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for moe_a2a_dispatch are not implemented" +) +def test_moe_a2a_dispatch_support_checks(): + assert hasattr(moe_a2a_dispatch, "is_compute_capability_supported") + assert hasattr(moe_a2a_dispatch, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for moe_a2a_combine are not implemented" +) +def test_moe_a2a_combine_support_checks(): + assert hasattr(moe_a2a_combine, "is_compute_capability_supported") + assert hasattr(moe_a2a_combine, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for moe_a2a_sanitize_expert_ids are not implemented" +) +def test_moe_a2a_sanitize_expert_ids_support_checks(): + assert hasattr(moe_a2a_sanitize_expert_ids, "is_compute_capability_supported") + assert hasattr(moe_a2a_sanitize_expert_ids, "is_backend_supported") + + +@pytest.mark.xfail( + reason="TODO: Support checks for moe_a2a_get_workspace_size_per_rank are not implemented" +) +def test_moe_a2a_get_workspace_size_per_rank_support_checks(): + assert hasattr( + moe_a2a_get_workspace_size_per_rank, "is_compute_capability_supported" + ) + assert hasattr(moe_a2a_get_workspace_size_per_rank, "is_backend_supported")