misc: support checks unit test tracking #2224

jimmyzho wants to merge 1 commit into flashinfer-ai:main
Conversation
Walkthrough: This pull request adds four new test modules containing xfail-marked support-check tests for the Attention, Comm, GEMM, and MoE APIs. Each test verifies the presence of the `is_compute_capability_supported` and `is_backend_supported` attributes on an API.
Estimated code review effort: 🎯 1 (Trivial) | ⏱️ ~4 minutes
Summary of Changes: Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request establishes a comprehensive framework for tracking the implementation of backend and compute capability support checks across numerous FlashInfer APIs, by introducing a suite of xfail-marked unit tests that document which APIs still need these checks.
Code Review
This pull request introduces a comprehensive set of unit tests to track the implementation of support checks (is_compute_capability_supported and is_backend_supported) across various FlashInfer APIs. The use of pytest.mark.xfail is a good strategy for managing this as a TODO list.
My main feedback is to refactor the new test files to use pytest.mark.parametrize. This will significantly reduce code duplication and improve the maintainability of these tests. I've left specific suggestions on each of the new test files with examples of how to apply this pattern.
| """ | ||
| Test file for Attention API support checks. | ||
|
|
||
| This file serves as a TODO list for support check implementations. | ||
| APIs with @pytest.mark.xfail need support checks to be implemented. | ||
| """ | ||
|
|
||
| import pytest | ||
|
|
||
| from flashinfer.decode import ( | ||
| BatchDecodeWithPagedKVCacheWrapper, | ||
| CUDAGraphBatchDecodeWithPagedKVCacheWrapper, | ||
| single_decode_with_kv_cache, | ||
| trtllm_batch_decode_with_kv_cache, | ||
| xqa_batch_decode_with_kv_cache, | ||
| ) | ||
| from flashinfer.prefill import ( | ||
| BatchPrefillWithPagedKVCacheWrapper, | ||
| BatchPrefillWithRaggedKVCacheWrapper, | ||
| fmha_v2_prefill_deepseek, | ||
| single_prefill_with_kv_cache, | ||
| trtllm_batch_context_with_kv_cache, | ||
| trtllm_ragged_attention_deepseek, | ||
| ) | ||
| from flashinfer.cascade import ( | ||
| BatchDecodeWithSharedPrefixPagedKVCacheWrapper, | ||
| BatchPrefillWithSharedPrefixPagedKVCacheWrapper, | ||
| MultiLevelCascadeAttentionWrapper, | ||
| merge_state, | ||
| merge_state_in_place, | ||
| merge_states, | ||
| ) | ||
| from flashinfer.mla import ( | ||
| BatchMLAPagedAttentionWrapper, | ||
| trtllm_batch_decode_with_kv_cache_mla, | ||
| xqa_batch_decode_with_kv_cache_mla, | ||
| ) | ||
| from flashinfer.sparse import ( | ||
| BlockSparseAttentionWrapper, | ||
| ) | ||
| from flashinfer.xqa import xqa, xqa_mla | ||
| from flashinfer.page import ( | ||
| append_paged_kv_cache, | ||
| append_paged_mla_kv_cache, | ||
| get_batch_indices_positions, | ||
| ) | ||
| from flashinfer.rope import ( | ||
| apply_llama31_rope, | ||
| apply_llama31_rope_inplace, | ||
| apply_llama31_rope_pos_ids, | ||
| apply_llama31_rope_pos_ids_inplace, | ||
| apply_rope, | ||
| apply_rope_inplace, | ||
| apply_rope_pos_ids, | ||
| apply_rope_pos_ids_inplace, | ||
| apply_rope_with_cos_sin_cache, | ||
| apply_rope_with_cos_sin_cache_inplace, | ||
| ) | ||
| from flashinfer.cudnn.decode import cudnn_batch_decode_with_kv_cache | ||
| from flashinfer.cudnn.prefill import cudnn_batch_prefill_with_kv_cache | ||
| from flashinfer.pod import PODWithPagedKVCacheWrapper, BatchPODWithPagedKVCacheWrapper | ||
|
|
||
|
|
||
| # Decode APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for single_decode_with_kv_cache are not implemented" | ||
| ) | ||
| def test_single_decode_with_kv_cache_support_checks(): | ||
| assert hasattr(single_decode_with_kv_cache, "is_compute_capability_supported") | ||
| assert hasattr(single_decode_with_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for BatchDecodeWithPagedKVCacheWrapper are not implemented" | ||
| ) | ||
| def test_batch_decode_with_paged_kv_cache_wrapper_support_checks(): | ||
| assert hasattr( | ||
| BatchDecodeWithPagedKVCacheWrapper.run, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(BatchDecodeWithPagedKVCacheWrapper.run, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for CUDAGraphBatchDecodeWithPagedKVCacheWrapper are not implemented" | ||
| ) | ||
| def test_cuda_graph_batch_decode_wrapper_support_checks(): | ||
| assert hasattr( | ||
| CUDAGraphBatchDecodeWithPagedKVCacheWrapper.run, | ||
| "is_compute_capability_supported", | ||
| ) | ||
| assert hasattr( | ||
| CUDAGraphBatchDecodeWithPagedKVCacheWrapper.run, "is_backend_supported" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_batch_decode_with_kv_cache are not implemented" | ||
| ) | ||
| def test_trtllm_batch_decode_with_kv_cache_support_checks(): | ||
| assert hasattr(trtllm_batch_decode_with_kv_cache, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_batch_decode_with_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for xqa_batch_decode_with_kv_cache are not implemented" | ||
| ) | ||
| def test_xqa_batch_decode_with_kv_cache_support_checks(): | ||
| assert hasattr(xqa_batch_decode_with_kv_cache, "is_compute_capability_supported") | ||
| assert hasattr(xqa_batch_decode_with_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| # Prefill APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for single_prefill_with_kv_cache are not implemented" | ||
| ) | ||
| def test_single_prefill_with_kv_cache_support_checks(): | ||
| assert hasattr(single_prefill_with_kv_cache, "is_compute_capability_supported") | ||
| assert hasattr(single_prefill_with_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for BatchPrefillWithPagedKVCacheWrapper are not implemented" | ||
| ) | ||
| def test_batch_prefill_with_paged_kv_cache_wrapper_support_checks(): | ||
| assert hasattr( | ||
| BatchPrefillWithPagedKVCacheWrapper.run, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(BatchPrefillWithPagedKVCacheWrapper.run, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for BatchPrefillWithRaggedKVCacheWrapper are not implemented" | ||
| ) | ||
| def test_batch_prefill_with_ragged_kv_cache_wrapper_support_checks(): | ||
| assert hasattr( | ||
| BatchPrefillWithRaggedKVCacheWrapper.run, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(BatchPrefillWithRaggedKVCacheWrapper.run, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_ragged_attention_deepseek are not implemented" | ||
| ) | ||
| def test_trtllm_ragged_attention_deepseek_support_checks(): | ||
| assert hasattr(trtllm_ragged_attention_deepseek, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_ragged_attention_deepseek, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_batch_context_with_kv_cache are not implemented" | ||
| ) | ||
| def test_trtllm_batch_context_with_kv_cache_support_checks(): | ||
| assert hasattr( | ||
| trtllm_batch_context_with_kv_cache, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(trtllm_batch_context_with_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for fmha_v2_prefill_deepseek are not implemented" | ||
| ) | ||
| def test_fmha_v2_prefill_deepseek_support_checks(): | ||
| assert hasattr(fmha_v2_prefill_deepseek, "is_compute_capability_supported") | ||
| assert hasattr(fmha_v2_prefill_deepseek, "is_backend_supported") | ||
|
|
||
|
|
||
| # Cascade APIs | ||
| @pytest.mark.xfail(reason="TODO: Support checks for merge_state are not implemented") | ||
| def test_merge_state_support_checks(): | ||
| assert hasattr(merge_state, "is_compute_capability_supported") | ||
| assert hasattr(merge_state, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for merge_state_in_place are not implemented" | ||
| ) | ||
| def test_merge_state_in_place_support_checks(): | ||
| assert hasattr(merge_state_in_place, "is_compute_capability_supported") | ||
| assert hasattr(merge_state_in_place, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail(reason="TODO: Support checks for merge_states are not implemented") | ||
| def test_merge_states_support_checks(): | ||
| assert hasattr(merge_states, "is_compute_capability_supported") | ||
| assert hasattr(merge_states, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for MultiLevelCascadeAttentionWrapper are not implemented" | ||
| ) | ||
| def test_multi_level_cascade_wrapper_support_checks(): | ||
| assert hasattr( | ||
| MultiLevelCascadeAttentionWrapper.run, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(MultiLevelCascadeAttentionWrapper.run, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for BatchDecodeWithSharedPrefixPagedKVCacheWrapper are not implemented" | ||
| ) | ||
| def test_batch_decode_shared_prefix_wrapper_support_checks(): | ||
| assert hasattr( | ||
| BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward, | ||
| "is_compute_capability_supported", | ||
| ) | ||
| assert hasattr( | ||
| BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward, "is_backend_supported" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for BatchPrefillWithSharedPrefixPagedKVCacheWrapper are not implemented" | ||
| ) | ||
| def test_batch_prefill_shared_prefix_wrapper_support_checks(): | ||
| assert hasattr( | ||
| BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward, | ||
| "is_compute_capability_supported", | ||
| ) | ||
| assert hasattr( | ||
| BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward, "is_backend_supported" | ||
| ) | ||
|
|
||
|
|
||
| # MLA APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for BatchMLAPagedAttentionWrapper are not implemented" | ||
| ) | ||
| def test_batch_decode_mla_wrapper_support_checks(): | ||
| assert hasattr(BatchMLAPagedAttentionWrapper.run, "is_compute_capability_supported") | ||
| assert hasattr(BatchMLAPagedAttentionWrapper.run, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_batch_decode_with_kv_cache_mla are not implemented" | ||
| ) | ||
| def test_trtllm_batch_decode_mla_support_checks(): | ||
| assert hasattr( | ||
| trtllm_batch_decode_with_kv_cache_mla, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(trtllm_batch_decode_with_kv_cache_mla, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for xqa_batch_decode_with_kv_cache_mla are not implemented" | ||
| ) | ||
| def test_xqa_batch_decode_mla_support_checks(): | ||
| assert hasattr( | ||
| xqa_batch_decode_with_kv_cache_mla, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(xqa_batch_decode_with_kv_cache_mla, "is_backend_supported") | ||
|
|
||
|
|
||
| # Sparse APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for BlockSparseAttentionWrapper are not implemented" | ||
| ) | ||
| def test_block_sparse_attention_wrapper_support_checks(): | ||
| assert hasattr(BlockSparseAttentionWrapper.run, "is_compute_capability_supported") | ||
| assert hasattr(BlockSparseAttentionWrapper.run, "is_backend_supported") | ||
|
|
||
|
|
||
| # XQA APIs | ||
| @pytest.mark.xfail(reason="TODO: Support checks for xqa are not implemented") | ||
| def test_xqa_support_checks(): | ||
| assert hasattr(xqa, "is_compute_capability_supported") | ||
| assert hasattr(xqa, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail(reason="TODO: Support checks for xqa_mla are not implemented") | ||
| def test_xqa_mla_support_checks(): | ||
| assert hasattr(xqa_mla, "is_compute_capability_supported") | ||
| assert hasattr(xqa_mla, "is_backend_supported") | ||
|
|
||
|
|
||
| # Page APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for get_batch_indices_positions are not implemented" | ||
| ) | ||
| def test_get_batch_indices_positions_support_checks(): | ||
| assert hasattr(get_batch_indices_positions, "is_compute_capability_supported") | ||
| assert hasattr(get_batch_indices_positions, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for append_paged_mla_kv_cache are not implemented" | ||
| ) | ||
| def test_append_paged_mla_kv_cache_support_checks(): | ||
| assert hasattr(append_paged_mla_kv_cache, "is_compute_capability_supported") | ||
| assert hasattr(append_paged_mla_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for append_paged_kv_cache are not implemented" | ||
| ) | ||
| def test_append_paged_kv_cache_support_checks(): | ||
| assert hasattr(append_paged_kv_cache, "is_compute_capability_supported") | ||
| assert hasattr(append_paged_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| # RoPE APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_rope_inplace are not implemented" | ||
| ) | ||
| def test_apply_rope_inplace_support_checks(): | ||
| assert hasattr(apply_rope_inplace, "is_compute_capability_supported") | ||
| assert hasattr(apply_rope_inplace, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_rope_pos_ids_inplace are not implemented" | ||
| ) | ||
| def test_apply_rope_pos_ids_inplace_support_checks(): | ||
| assert hasattr(apply_rope_pos_ids_inplace, "is_compute_capability_supported") | ||
| assert hasattr(apply_rope_pos_ids_inplace, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_llama31_rope_inplace are not implemented" | ||
| ) | ||
| def test_apply_llama31_rope_inplace_support_checks(): | ||
| assert hasattr(apply_llama31_rope_inplace, "is_compute_capability_supported") | ||
| assert hasattr(apply_llama31_rope_inplace, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_llama31_rope_pos_ids_inplace are not implemented" | ||
| ) | ||
| def test_apply_llama31_rope_pos_ids_inplace_support_checks(): | ||
| assert hasattr( | ||
| apply_llama31_rope_pos_ids_inplace, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(apply_llama31_rope_pos_ids_inplace, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail(reason="TODO: Support checks for apply_rope are not implemented") | ||
| def test_apply_rope_support_checks(): | ||
| assert hasattr(apply_rope, "is_compute_capability_supported") | ||
| assert hasattr(apply_rope, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_rope_pos_ids are not implemented" | ||
| ) | ||
| def test_apply_rope_pos_ids_support_checks(): | ||
| assert hasattr(apply_rope_pos_ids, "is_compute_capability_supported") | ||
| assert hasattr(apply_rope_pos_ids, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_llama31_rope are not implemented" | ||
| ) | ||
| def test_apply_llama31_rope_support_checks(): | ||
| assert hasattr(apply_llama31_rope, "is_compute_capability_supported") | ||
| assert hasattr(apply_llama31_rope, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_llama31_rope_pos_ids are not implemented" | ||
| ) | ||
| def test_apply_llama31_rope_pos_ids_support_checks(): | ||
| assert hasattr(apply_llama31_rope_pos_ids, "is_compute_capability_supported") | ||
| assert hasattr(apply_llama31_rope_pos_ids, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_rope_with_cos_sin_cache are not implemented" | ||
| ) | ||
| def test_apply_rope_with_cos_sin_cache_support_checks(): | ||
| assert hasattr(apply_rope_with_cos_sin_cache, "is_compute_capability_supported") | ||
| assert hasattr(apply_rope_with_cos_sin_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for apply_rope_with_cos_sin_cache_inplace are not implemented" | ||
| ) | ||
| def test_apply_rope_with_cos_sin_cache_inplace_support_checks(): | ||
| assert hasattr( | ||
| apply_rope_with_cos_sin_cache_inplace, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(apply_rope_with_cos_sin_cache_inplace, "is_backend_supported") | ||
|
|
||
|
|
||
| # cuDNN APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for cudnn_batch_decode_with_kv_cache are not implemented" | ||
| ) | ||
| def test_cudnn_batch_decode_support_checks(): | ||
| assert hasattr(cudnn_batch_decode_with_kv_cache, "is_compute_capability_supported") | ||
| assert hasattr(cudnn_batch_decode_with_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for cudnn_batch_prefill_with_kv_cache are not implemented" | ||
| ) | ||
| def test_cudnn_batch_prefill_support_checks(): | ||
| assert hasattr(cudnn_batch_prefill_with_kv_cache, "is_compute_capability_supported") | ||
| assert hasattr(cudnn_batch_prefill_with_kv_cache, "is_backend_supported") | ||
|
|
||
|
|
||
| # POD APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for BatchPODWithPagedKVCacheWrapper are not implemented" | ||
| ) | ||
| def test_pod_prefill_wrapper_support_checks(): | ||
| assert hasattr( | ||
| BatchPODWithPagedKVCacheWrapper.run, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(BatchPODWithPagedKVCacheWrapper.run, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for PODWithPagedKVCacheWrapper are not implemented" | ||
| ) | ||
| def test_pod_decode_wrapper_support_checks(): | ||
| assert hasattr(PODWithPagedKVCacheWrapper.run, "is_compute_capability_supported") | ||
| assert hasattr(PODWithPagedKVCacheWrapper.run, "is_backend_supported") |
This test file is quite long and contains a lot of repetitive code. You can make it more concise and maintainable by using pytest.mark.parametrize to test all APIs with a single test function. This approach also makes it easy to manage which tests are expected to fail.
Here's an example for the Decode APIs:
```python
import pytest

from flashinfer.decode import (
    BatchDecodeWithPagedKVCacheWrapper,
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
    single_decode_with_kv_cache,
    trtllm_batch_decode_with_kv_cache,
    xqa_batch_decode_with_kv_cache,
)

decode_apis = [
    pytest.param(
        single_decode_with_kv_cache,
        id="single_decode_with_kv_cache",
        marks=pytest.mark.xfail(
            reason="TODO: Support checks for single_decode_with_kv_cache are not implemented"
        ),
    ),
    pytest.param(
        BatchDecodeWithPagedKVCacheWrapper.run,
        id="BatchDecodeWithPagedKVCacheWrapper",
        marks=pytest.mark.xfail(
            reason="TODO: Support checks for BatchDecodeWithPagedKVCacheWrapper are not implemented"
        ),
    ),
    # ... other decode APIs
]


@pytest.mark.parametrize("api", decode_apis)
def test_decode_apis_support_checks(api):
    assert hasattr(api, "is_compute_capability_supported")
    assert hasattr(api, "is_backend_supported")
```

You can create lists of APIs for each category (Prefill, Cascade, etc.) and apply this pattern. When a check is implemented, you can simply remove the `marks` from the corresponding `pytest.param`.
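For context on what these `hasattr` assertions expect: the two attributes could plausibly be attached to each API with a small decorator. The sketch below is purely illustrative — `with_support_checks`, the CC/backend sets, and the stand-in `single_decode_with_kv_cache` are invented here and are not flashinfer's actual implementation:

```python
def with_support_checks(supported_ccs, supported_backends):
    """Attach is_compute_capability_supported / is_backend_supported
    predicates (the attributes the PR's tests look for) to an API."""

    def decorate(fn):
        fn.is_compute_capability_supported = (
            lambda major, minor: (major, minor) in supported_ccs
        )
        fn.is_backend_supported = lambda backend: backend in supported_backends
        return fn

    return decorate


# Hypothetical stand-in for the real flashinfer API.
@with_support_checks({(8, 0), (9, 0)}, {"fa2", "fa3"})
def single_decode_with_kv_cache(*args, **kwargs):
    ...  # kernel dispatch elided
```

With such a decorator in place, the corresponding xfail marker in this PR would be removed and the test would pass as-is.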
| """ | ||
| Test file for Comm API support checks. | ||
|
|
||
| This file serves as a TODO list for support check implementations. | ||
| APIs with @pytest.mark.xfail need support checks to be implemented. | ||
| """ | ||
|
|
||
| import pytest | ||
|
|
||
| from flashinfer.comm import ( | ||
| trtllm_allreduce_fusion, | ||
| trtllm_custom_all_reduce, | ||
| trtllm_moe_allreduce_fusion, | ||
| trtllm_moe_finalize_allreduce_fusion, | ||
| trtllm_create_ipc_workspace_for_all_reduce, | ||
| trtllm_destroy_ipc_workspace_for_all_reduce, | ||
| trtllm_create_ipc_workspace_for_all_reduce_fusion, | ||
| trtllm_destroy_ipc_workspace_for_all_reduce_fusion, | ||
| trtllm_lamport_initialize, | ||
| trtllm_lamport_initialize_all, | ||
| vllm_all_reduce, | ||
| vllm_init_custom_ar, | ||
| vllm_dispose, | ||
| vllm_register_buffer, | ||
| vllm_register_graph_buffers, | ||
| MoeAlltoAll, | ||
| ) | ||
| from flashinfer.comm.nvshmem_allreduce import NVSHMEMAllReduce | ||
| from flashinfer.comm.mnnvl import MnnvlMemory | ||
| from flashinfer.comm.trtllm_alltoall import MnnvlMoe, MoEAlltoallInfo | ||
| from flashinfer.comm.trtllm_mnnvl_ar import ( | ||
| trtllm_mnnvl_allreduce, | ||
| trtllm_mnnvl_fused_allreduce_add_rmsnorm, | ||
| trtllm_mnnvl_fused_allreduce_rmsnorm, | ||
| trtllm_mnnvl_all_reduce, | ||
| ) | ||
|
|
||
|
|
||
| # TRTLLM AllReduce APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_allreduce_fusion are not implemented" | ||
| ) | ||
| def test_trtllm_allreduce_fusion_support_checks(): | ||
| assert hasattr(trtllm_allreduce_fusion, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_allreduce_fusion, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_custom_all_reduce are not implemented" | ||
| ) | ||
| def test_trtllm_custom_all_reduce_support_checks(): | ||
| assert hasattr(trtllm_custom_all_reduce, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_custom_all_reduce, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_moe_allreduce_fusion are not implemented" | ||
| ) | ||
| def test_trtllm_moe_allreduce_fusion_support_checks(): | ||
| assert hasattr(trtllm_moe_allreduce_fusion, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_moe_allreduce_fusion, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_moe_finalize_allreduce_fusion are not implemented" | ||
| ) | ||
| def test_trtllm_moe_finalize_allreduce_fusion_support_checks(): | ||
| assert hasattr( | ||
| trtllm_moe_finalize_allreduce_fusion, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(trtllm_moe_finalize_allreduce_fusion, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_create_ipc_workspace_for_all_reduce are not implemented" | ||
| ) | ||
| def test_trtllm_create_ipc_workspace_for_all_reduce_support_checks(): | ||
| assert hasattr( | ||
| trtllm_create_ipc_workspace_for_all_reduce, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(trtllm_create_ipc_workspace_for_all_reduce, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_destroy_ipc_workspace_for_all_reduce are not implemented" | ||
| ) | ||
| def test_trtllm_destroy_ipc_workspace_for_all_reduce_support_checks(): | ||
| assert hasattr( | ||
| trtllm_destroy_ipc_workspace_for_all_reduce, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(trtllm_destroy_ipc_workspace_for_all_reduce, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_create_ipc_workspace_for_all_reduce_fusion are not implemented" | ||
| ) | ||
| def test_trtllm_create_ipc_workspace_for_all_reduce_fusion_support_checks(): | ||
| assert hasattr( | ||
| trtllm_create_ipc_workspace_for_all_reduce_fusion, | ||
| "is_compute_capability_supported", | ||
| ) | ||
| assert hasattr( | ||
| trtllm_create_ipc_workspace_for_all_reduce_fusion, "is_backend_supported" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_destroy_ipc_workspace_for_all_reduce_fusion are not implemented" | ||
| ) | ||
| def test_trtllm_destroy_ipc_workspace_for_all_reduce_fusion_support_checks(): | ||
| assert hasattr( | ||
| trtllm_destroy_ipc_workspace_for_all_reduce_fusion, | ||
| "is_compute_capability_supported", | ||
| ) | ||
| assert hasattr( | ||
| trtllm_destroy_ipc_workspace_for_all_reduce_fusion, "is_backend_supported" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_lamport_initialize are not implemented" | ||
| ) | ||
| def test_trtllm_lamport_initialize_support_checks(): | ||
| assert hasattr(trtllm_lamport_initialize, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_lamport_initialize, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_lamport_initialize_all are not implemented" | ||
| ) | ||
| def test_trtllm_lamport_initialize_all_support_checks(): | ||
| assert hasattr(trtllm_lamport_initialize_all, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_lamport_initialize_all, "is_backend_supported") | ||
|
|
||
|
|
||
| # VLLM AllReduce APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for vllm_all_reduce are not implemented" | ||
| ) | ||
| def test_vllm_all_reduce_support_checks(): | ||
| assert hasattr(vllm_all_reduce, "is_compute_capability_supported") | ||
| assert hasattr(vllm_all_reduce, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for vllm_init_custom_ar are not implemented" | ||
| ) | ||
| def test_vllm_init_custom_ar_support_checks(): | ||
| assert hasattr(vllm_init_custom_ar, "is_compute_capability_supported") | ||
| assert hasattr(vllm_init_custom_ar, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail(reason="TODO: Support checks for vllm_dispose are not implemented") | ||
| def test_vllm_dispose_support_checks(): | ||
| assert hasattr(vllm_dispose, "is_compute_capability_supported") | ||
| assert hasattr(vllm_dispose, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for vllm_register_buffer are not implemented" | ||
| ) | ||
| def test_vllm_register_buffer_support_checks(): | ||
| assert hasattr(vllm_register_buffer, "is_compute_capability_supported") | ||
| assert hasattr(vllm_register_buffer, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for vllm_register_graph_buffers are not implemented" | ||
| ) | ||
| def test_vllm_register_graph_buffers_support_checks(): | ||
| assert hasattr(vllm_register_graph_buffers, "is_compute_capability_supported") | ||
| assert hasattr(vllm_register_graph_buffers, "is_backend_supported") | ||
|
|
||
|
|
||
| # NVSHMEM APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for NVSHMEMAllReduce are not implemented" | ||
| ) | ||
| def test_nvshmem_allreduce_support_checks(): | ||
| assert hasattr(NVSHMEMAllReduce, "is_compute_capability_supported") | ||
| assert hasattr(NVSHMEMAllReduce, "is_backend_supported") | ||
|
|
||
|
|
||
| # MNNVL APIs | ||
| @pytest.mark.xfail(reason="TODO: Support checks for MnnvlMemory are not implemented") | ||
| def test_mnnvl_memory_support_checks(): | ||
| assert hasattr(MnnvlMemory, "is_compute_capability_supported") | ||
| assert hasattr(MnnvlMemory, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail(reason="TODO: Support checks for MnnvlMoe are not implemented") | ||
| def test_mnnvl_moe_support_checks(): | ||
| assert hasattr(MnnvlMoe, "is_compute_capability_supported") | ||
| assert hasattr(MnnvlMoe, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for MoEAlltoallInfo are not implemented" | ||
| ) | ||
| def test_moe_alltoall_info_support_checks(): | ||
| assert hasattr(MoEAlltoallInfo, "is_compute_capability_supported") | ||
| assert hasattr(MoEAlltoallInfo, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail(reason="TODO: Support checks for MoeAlltoAll are not implemented") | ||
| def test_moe_alltoall_support_checks(): | ||
| assert hasattr(MoeAlltoAll.dispatch, "is_compute_capability_supported") | ||
| assert hasattr(MoeAlltoAll.dispatch, "is_backend_supported") | ||
|
|
||
|
|
||
| # TRTLLM MNNVL AllReduce APIs | ||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_mnnvl_allreduce are not implemented" | ||
| ) | ||
| def test_trtllm_mnnvl_allreduce_support_checks(): | ||
| assert hasattr(trtllm_mnnvl_allreduce, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_mnnvl_allreduce, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_mnnvl_fused_allreduce_add_rmsnorm are not implemented" | ||
| ) | ||
| def test_trtllm_mnnvl_fused_allreduce_add_rmsnorm_support_checks(): | ||
| assert hasattr( | ||
| trtllm_mnnvl_fused_allreduce_add_rmsnorm, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(trtllm_mnnvl_fused_allreduce_add_rmsnorm, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_mnnvl_fused_allreduce_rmsnorm are not implemented" | ||
| ) | ||
| def test_trtllm_mnnvl_fused_allreduce_rmsnorm_support_checks(): | ||
| assert hasattr( | ||
| trtllm_mnnvl_fused_allreduce_rmsnorm, "is_compute_capability_supported" | ||
| ) | ||
| assert hasattr(trtllm_mnnvl_fused_allreduce_rmsnorm, "is_backend_supported") | ||
|
|
||
|
|
||
| @pytest.mark.xfail( | ||
| reason="TODO: Support checks for trtllm_mnnvl_all_reduce are not implemented" | ||
| ) | ||
| def test_trtllm_mnnvl_all_reduce_support_checks(): | ||
| assert hasattr(trtllm_mnnvl_all_reduce, "is_compute_capability_supported") | ||
| assert hasattr(trtllm_mnnvl_all_reduce, "is_backend_supported") |
This file has a lot of repeated test code. To improve maintainability and reduce boilerplate, consider using pytest.mark.parametrize. This allows you to define a list of APIs to test and use a single test function to check them all.
For example:
```python
import pytest

from flashinfer.comm import (
    trtllm_allreduce_fusion,
    trtllm_custom_all_reduce,
    # ... other imports
)

trtllm_allreduce_apis = [
    pytest.param(
        trtllm_allreduce_fusion,
        id="trtllm_allreduce_fusion",
        marks=pytest.mark.xfail(
            reason="TODO: Support checks for trtllm_allreduce_fusion are not implemented"
        ),
    ),
    pytest.param(
        trtllm_custom_all_reduce,
        id="trtllm_custom_all_reduce",
        marks=pytest.mark.xfail(
            reason="TODO: Support checks for trtllm_custom_all_reduce are not implemented"
        ),
    ),
    # ... other TRTLLM AllReduce APIs
]


@pytest.mark.parametrize("api", trtllm_allreduce_apis)
def test_trtllm_allreduce_apis_support_checks(api):
    assert hasattr(api, "is_compute_capability_supported")
    assert hasattr(api, "is_backend_supported")
```

This pattern can be applied to the other API groups in this file as well. It makes it easy to track which tests are expected to fail and to update them as features are implemented.
```python
"""
Test file for GEMM API support checks.

This file serves as a TODO list for support check implementations.
APIs with @pytest.mark.xfail need support checks to be implemented in gemm_base.py.
"""

import pytest

from flashinfer import (
    bmm_fp8,
    mm_fp4,
    mm_fp8,
    prepare_low_latency_gemm_weights,
    tgv_gemm_sm100,
)
from flashinfer.gemm import (
    SegmentGEMMWrapper,
    batch_deepgemm_fp8_nt_groupwise,
    gemm_fp8_nt_blockscaled,
    gemm_fp8_nt_groupwise,
    group_deepgemm_fp8_nt_groupwise,
    group_gemm_fp8_nt_groupwise,
    group_gemm_mxfp4_nt_groupwise,
)
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from flashinfer.cute_dsl.gemm_allreduce_two_shot import PersistentDenseGemmKernel
import flashinfer.triton.sm_constraint_gemm as sm_constraint_gemm


def test_mm_fp4_support_checks():
    assert hasattr(mm_fp4, "is_compute_capability_supported")
    assert hasattr(mm_fp4, "is_backend_supported")


def test_bmm_fp8_support_checks():
    assert hasattr(bmm_fp8, "is_compute_capability_supported")
    assert hasattr(bmm_fp8, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for tgv_gemm_sm100 are not implemented")
def test_tgv_gemm_sm100_support_checks():
    assert hasattr(tgv_gemm_sm100, "is_compute_capability_supported")
    assert hasattr(tgv_gemm_sm100, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for mm_fp8 are not implemented")
def test_mm_fp8_support_checks():
    assert hasattr(mm_fp8, "is_compute_capability_supported")
    assert hasattr(mm_fp8, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for gemm_fp8_nt_groupwise are not implemented"
)
def test_gemm_fp8_nt_groupwise_support_checks():
    assert hasattr(gemm_fp8_nt_groupwise, "is_compute_capability_supported")
    assert hasattr(gemm_fp8_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for gemm_fp8_nt_blockscaled are not implemented"
)
def test_gemm_fp8_nt_blockscaled_support_checks():
    assert hasattr(gemm_fp8_nt_blockscaled, "is_compute_capability_supported")
    assert hasattr(gemm_fp8_nt_blockscaled, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for group_gemm_fp8_nt_groupwise are not implemented"
)
def test_group_gemm_fp8_nt_groupwise_support_checks():
    assert hasattr(group_gemm_fp8_nt_groupwise, "is_compute_capability_supported")
    assert hasattr(group_gemm_fp8_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for group_gemm_mxfp4_nt_groupwise are not implemented"
)
def test_group_gemm_mxfp4_nt_groupwise_support_checks():
    assert hasattr(group_gemm_mxfp4_nt_groupwise, "is_compute_capability_supported")
    assert hasattr(group_gemm_mxfp4_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for group_deepgemm_fp8_nt_groupwise are not implemented"
)
def test_group_deepgemm_fp8_nt_groupwise_support_checks():
    assert hasattr(group_deepgemm_fp8_nt_groupwise, "is_compute_capability_supported")
    assert hasattr(group_deepgemm_fp8_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for batch_deepgemm_fp8_nt_groupwise are not implemented"
)
def test_batch_deepgemm_fp8_nt_groupwise_support_checks():
    assert hasattr(batch_deepgemm_fp8_nt_groupwise, "is_compute_capability_supported")
    assert hasattr(batch_deepgemm_fp8_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for SegmentGEMMWrapper are not implemented"
)
def test_segment_gemm_wrapper_support_checks():
    assert hasattr(SegmentGEMMWrapper.run, "is_compute_capability_supported")
    assert hasattr(SegmentGEMMWrapper.run, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for prepare_low_latency_gemm_weights are not implemented"
)
def test_prepare_low_latency_gemm_weights_support_checks():
    assert hasattr(prepare_low_latency_gemm_weights, "is_compute_capability_supported")
    assert hasattr(prepare_low_latency_gemm_weights, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for grouped_gemm_nt_masked are not implemented"
)
def test_grouped_gemm_nt_masked_support_checks():
    assert hasattr(grouped_gemm_nt_masked, "is_compute_capability_supported")
    assert hasattr(grouped_gemm_nt_masked, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for PersistentDenseGemmKernel are not implemented"
)
def test_persistent_dense_gemm_kernel_support_checks():
    assert hasattr(PersistentDenseGemmKernel, "is_compute_capability_supported")
    assert hasattr(PersistentDenseGemmKernel, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for sm_constraint_gemm.gemm are not implemented"
)
def test_sm_constraint_gemm_support_checks():
    assert hasattr(sm_constraint_gemm.gemm, "is_compute_capability_supported")
    assert hasattr(sm_constraint_gemm.gemm, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for sm_constraint_gemm.gemm_persistent are not implemented"
)
def test_sm_constraint_gemm_persistent_support_checks():
    assert hasattr(
        sm_constraint_gemm.gemm_persistent, "is_compute_capability_supported"
    )
    assert hasattr(sm_constraint_gemm.gemm_persistent, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for sm_constraint_gemm.gemm_descriptor_persistent are not implemented"
)
def test_sm_constraint_gemm_descriptor_persistent_support_checks():
    assert hasattr(
        sm_constraint_gemm.gemm_descriptor_persistent, "is_compute_capability_supported"
    )
    assert hasattr(
        sm_constraint_gemm.gemm_descriptor_persistent, "is_backend_supported"
    )
```
This test file can be simplified by using pytest.mark.parametrize to avoid repeating the same test logic for each API. This will make the code more concise and easier to maintain. Since you have a mix of passing tests and expected failures, you can configure this within the parametrization.
Here's how you could refactor it:

```python
import pytest
from flashinfer import (
    bmm_fp8,
    mm_fp4,
    tgv_gemm_sm100,
    # ... other imports
)

gemm_apis = [
    pytest.param(mm_fp4, id="mm_fp4"),  # This test is expected to pass
    pytest.param(bmm_fp8, id="bmm_fp8"),  # This test is expected to pass
    pytest.param(
        tgv_gemm_sm100,
        id="tgv_gemm_sm100",
        marks=pytest.mark.xfail(
            reason="TODO: Support checks for tgv_gemm_sm100 are not implemented"
        ),
    ),
    # ... other gemm APIs
]


@pytest.mark.parametrize("api", gemm_apis)
def test_gemm_apis_support_checks(api):
    assert hasattr(api, "is_compute_capability_supported")
    assert hasattr(api, "is_backend_supported")
```

This approach clearly separates the passing tests from the xfail ones while keeping the test logic DRY (Don't Repeat Yourself).
```python
"""
Test file for MoE API support checks.

This file serves as a TODO list for support check implementations.
APIs with @pytest.mark.xfail need support checks to be implemented.
"""

import pytest

from flashinfer.fused_moe import (
    cutlass_fused_moe,
    fused_topk_deepseek,
    trtllm_bf16_moe,
    trtllm_fp4_block_scale_moe,
    trtllm_fp4_block_scale_routed_moe,
    trtllm_fp8_block_scale_moe,
    trtllm_fp8_per_tensor_scale_moe,
    trtllm_mxint4_block_scale_moe,
)
from flashinfer.comm.trtllm_moe_alltoall import (
    moe_a2a_combine,
    moe_a2a_dispatch,
    moe_a2a_get_workspace_size_per_rank,
    moe_a2a_initialize,
    moe_a2a_sanitize_expert_ids,
    moe_a2a_wrap_payload_tensor_in_workspace,
)


def test_fused_topk_deepseek_support_checks():
    assert hasattr(fused_topk_deepseek, "is_compute_capability_supported")
    assert hasattr(fused_topk_deepseek, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for cutlass_fused_moe are not implemented"
)
def test_cutlass_fused_moe_support_checks():
    assert hasattr(cutlass_fused_moe, "is_compute_capability_supported")
    assert hasattr(cutlass_fused_moe, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_bf16_moe are not implemented"
)
def test_trtllm_bf16_moe_support_checks():
    assert hasattr(trtllm_bf16_moe, "is_compute_capability_supported")
    assert hasattr(trtllm_bf16_moe, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_fp8_per_tensor_scale_moe are not implemented"
)
def test_trtllm_fp8_per_tensor_scale_moe_support_checks():
    assert hasattr(trtllm_fp8_per_tensor_scale_moe, "is_compute_capability_supported")
    assert hasattr(trtllm_fp8_per_tensor_scale_moe, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_fp8_block_scale_moe are not implemented"
)
def test_trtllm_fp8_block_scale_moe_support_checks():
    assert hasattr(trtllm_fp8_block_scale_moe, "is_compute_capability_supported")
    assert hasattr(trtllm_fp8_block_scale_moe, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_fp4_block_scale_moe are not implemented"
)
def test_trtllm_fp4_block_scale_moe_support_checks():
    assert hasattr(trtllm_fp4_block_scale_moe, "is_compute_capability_supported")
    assert hasattr(trtllm_fp4_block_scale_moe, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_fp4_block_scale_routed_moe are not implemented"
)
def test_trtllm_fp4_block_scale_routed_moe_support_checks():
    assert hasattr(trtllm_fp4_block_scale_routed_moe, "is_compute_capability_supported")
    assert hasattr(trtllm_fp4_block_scale_routed_moe, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_mxint4_block_scale_moe are not implemented"
)
def test_trtllm_mxint4_block_scale_moe_support_checks():
    assert hasattr(trtllm_mxint4_block_scale_moe, "is_compute_capability_supported")
    assert hasattr(trtllm_mxint4_block_scale_moe, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for moe_a2a_initialize are not implemented"
)
def test_moe_a2a_initialize_support_checks():
    assert hasattr(moe_a2a_initialize, "is_compute_capability_supported")
    assert hasattr(moe_a2a_initialize, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for moe_a2a_wrap_payload_tensor_in_workspace are not implemented"
)
def test_moe_a2a_wrap_payload_tensor_in_workspace_support_checks():
    assert hasattr(
        moe_a2a_wrap_payload_tensor_in_workspace, "is_compute_capability_supported"
    )
    assert hasattr(moe_a2a_wrap_payload_tensor_in_workspace, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for moe_a2a_dispatch are not implemented"
)
def test_moe_a2a_dispatch_support_checks():
    assert hasattr(moe_a2a_dispatch, "is_compute_capability_supported")
    assert hasattr(moe_a2a_dispatch, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for moe_a2a_combine are not implemented"
)
def test_moe_a2a_combine_support_checks():
    assert hasattr(moe_a2a_combine, "is_compute_capability_supported")
    assert hasattr(moe_a2a_combine, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for moe_a2a_sanitize_expert_ids are not implemented"
)
def test_moe_a2a_sanitize_expert_ids_support_checks():
    assert hasattr(moe_a2a_sanitize_expert_ids, "is_compute_capability_supported")
    assert hasattr(moe_a2a_sanitize_expert_ids, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for moe_a2a_get_workspace_size_per_rank are not implemented"
)
def test_moe_a2a_get_workspace_size_per_rank_support_checks():
    assert hasattr(
        moe_a2a_get_workspace_size_per_rank, "is_compute_capability_supported"
    )
    assert hasattr(moe_a2a_get_workspace_size_per_rank, "is_backend_supported")
```
To make this test file more compact and maintainable, you could use pytest.mark.parametrize. This would allow you to test multiple APIs with a single test function, reducing code duplication. You can also handle the mix of passing and xfail tests cleanly.
Here is an example of how to apply this pattern:

```python
import pytest
from flashinfer.fused_moe import (
    cutlass_fused_moe,
    fused_topk_deepseek,
    # ... other imports
)

moe_apis = [
    pytest.param(fused_topk_deepseek, id="fused_topk_deepseek"),  # Expected to pass
    pytest.param(
        cutlass_fused_moe,
        id="cutlass_fused_moe",
        marks=pytest.mark.xfail(
            reason="TODO: Support checks for cutlass_fused_moe are not implemented"
        ),
    ),
    # ... other MoE APIs
]


@pytest.mark.parametrize("api", moe_apis)
def test_moe_apis_support_checks(api):
    assert hasattr(api, "is_compute_capability_supported")
    assert hasattr(api, "is_backend_supported")
```

This refactoring makes it easier to see which APIs are covered and to update their status as the support checks are implemented.
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/moe/test_moe_support_checks.py (1)
35-40: Consider adding `strict=True` to xfail markers. Adding `strict=True` ensures the test suite will flag when support checks are implemented, prompting removal of the xfail marker. This helps keep the TODO list accurate.

```diff
-@pytest.mark.xfail(
-    reason="TODO: Support checks for cutlass_fused_moe are not implemented"
-)
+@pytest.mark.xfail(
+    reason="TODO: Support checks for cutlass_fused_moe are not implemented",
+    strict=True,
+)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- `tests/attention/test_attention_support_checks.py` (1 hunks)
- `tests/comm/test_comm_support_checks.py` (1 hunks)
- `tests/gemm/test_gemm_support_checks.py` (1 hunks)
- `tests/moe/test_moe_support_checks.py` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
- `tests/gemm/test_gemm_support_checks.py` (5)
  - `flashinfer/gemm/gemm_base.py` (6): `mm_fp4` (2086-2223), `mm_fp8` (1560-1667), `SegmentGEMMWrapper` (808-1062), `batch_deepgemm_fp8_nt_groupwise` (3157-3289), `gemm_fp8_nt_blockscaled` (2666-2691), `gemm_fp8_nt_groupwise` (2414-2577)
  - `flashinfer/trtllm_low_latency_gemm.py` (1): `prepare_low_latency_gemm_weights` (196-224)
  - `flashinfer/cute_dsl/blockscaled_gemm.py` (1): `grouped_gemm_nt_masked` (2945-3046)
  - `flashinfer/cute_dsl/gemm_allreduce_two_shot.py` (1): `PersistentDenseGemmKernel` (184-2105)
  - `flashinfer/triton/sm_constraint_gemm.py` (2): `gemm_persistent` (14-95), `gemm_descriptor_persistent` (173-278)
- `tests/comm/test_comm_support_checks.py` (5)
  - `flashinfer/comm/trtllm_ar.py` (4): `trtllm_create_ipc_workspace_for_all_reduce` (403-480), `trtllm_destroy_ipc_workspace_for_all_reduce` (483-495), `trtllm_create_ipc_workspace_for_all_reduce_fusion` (504-643), `trtllm_destroy_ipc_workspace_for_all_reduce_fusion` (646-662)
  - `flashinfer/comm/nvshmem_allreduce.py` (1): `NVSHMEMAllReduce` (25-127)
  - `flashinfer/comm/mnnvl.py` (1): `MnnvlMemory` (244-563)
  - `flashinfer/comm/trtllm_alltoall.py` (1): `MoEAlltoallInfo` (423-430)
  - `flashinfer/comm/trtllm_mnnvl_ar.py` (3): `trtllm_mnnvl_allreduce` (308-383), `trtllm_mnnvl_fused_allreduce_rmsnorm` (620-707), `trtllm_mnnvl_all_reduce` (546-614)
- `tests/attention/test_attention_support_checks.py` (2)
  - `flashinfer/page.py` (1): `get_batch_indices_positions` (123-175)
  - `flashinfer/cudnn/decode.py` (1): `cudnn_batch_decode_with_kv_cache` (257-350)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
tests/moe/test_moe_support_checks.py (2)
1-140: LGTM overall structure and coverage. The test file is well-organized with clear imports and consistent test patterns. Each test verifies both the `is_compute_capability_supported` and `is_backend_supported` attributes, which aligns with the PR objective of tracking support check implementations.
30-32: The support check attributes for `fused_topk_deepseek` are already implemented via the `@backend_requirement` decorator. This decorator automatically adds both the `is_compute_capability_supported` and `is_backend_supported` methods to the decorated function. The test is correctly not marked with `@pytest.mark.xfail` and will pass. Likely an incorrect or invalid review comment.
tests/gemm/test_gemm_support_checks.py (2)
1-160: Good coverage of GEMM API surfaces. The test file comprehensively covers various GEMM implementations, including standard, grouped, block-scaled, and Triton-based variants. The organization by category (imports, tests) is clear and maintainable.
31-38: These support checks are already implemented via the `@backend_requirement` decorator. Both `mm_fp4` and `bmm_fp8` are decorated with `@backend_requirement` (in `flashinfer/gemm/gemm_base.py` at lines 2078-2084 and 2316-2323 respectively). This decorator automatically adds the `is_compute_capability_supported` and `is_backend_supported` methods to these functions, so the test assertions will pass. The tests are correctly not marked `@pytest.mark.xfail` because the implementation is complete. Likely an incorrect or invalid review comment.
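For reference, the attribute-attaching mechanism described here can be sketched as follows. This is a hypothetical minimal reimplementation for illustration only; the real `@backend_requirement` decorator in flashinfer takes different arguments and performs actual hardware detection:

```python
from typing import Callable


def backend_requirement_sketch(
    supported_backends: set,
    min_compute_capability: tuple,
) -> Callable:
    """Attach support-check predicates to the decorated API function."""

    def decorator(fn: Callable) -> Callable:
        def is_backend_supported(backend: str) -> bool:
            return backend in supported_backends

        def is_compute_capability_supported(major: int, minor: int) -> bool:
            return (major, minor) >= min_compute_capability

        # The unit tests in this PR only check that these attributes exist.
        fn.is_backend_supported = is_backend_supported
        fn.is_compute_capability_supported = is_compute_capability_supported
        return fn

    return decorator


@backend_requirement_sketch({"cutlass", "trtllm"}, (8, 0))
def some_gemm_api():  # placeholder API, not a real flashinfer function
    pass
```

Once an API carries such a decorator, the corresponding `hasattr` assertions pass and its `xfail` marker can be dropped.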
tests/comm/test_comm_support_checks.py (2)
1-245: Comprehensive coverage of communication APIs. The test file covers a broad range of communication primitives across the TRTLLM, VLLM, NVSHMEM, and MNNVL backends. All tests are consistently marked as xfail, making this a clear roadmap for future support check implementations.
197-202: Remove this test or reconsider whether `MoEAlltoallInfo` should implement support checks. `MoEAlltoallInfo` is a pure dataclass containing only tensor references and a count field. Support check attributes (`is_compute_capability_supported`, `is_backend_supported`) belong on classes that perform compute operations or have backend/capability requirements, not on data containers. Either this test should be removed, or the design intent for this class should be reconsidered if support checks are actually needed. Likely an incorrect or invalid review comment.
tests/attention/test_attention_support_checks.py (2)
198-221: Consistent method targeting for different wrapper interfaces. Note that `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and `BatchPrefillWithSharedPrefixPagedKVCacheWrapper` check `.forward` while most other wrappers check `.run`. This appears intentional based on each wrapper's public API, but ensure the `@backend_requirement` decorator is applied consistently to the appropriate entry-point method for each wrapper.
1-416: Excellent coverage of Attention API surfaces. This file comprehensively covers the Decode, Prefill, Cascade, MLA, Sparse, XQA, Page, RoPE, cuDNN, and POD APIs. The organization by category with section comments (e.g., `# Decode APIs`, `# RoPE APIs`) improves readability and maintainability. All tests are consistently marked as xfail, serving as a clear tracking mechanism for support check implementation progress.
📌 Description
Add unit-test-style tracking for all flashinfer APIs. Tests report xfail where support checks are not yet implemented; the idea is to have them pass as we progress with applying the `@backend_requirement` decorator.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed, and all tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit