
misc: support checks unit test tracking #2224

Open
jimmyzho wants to merge 1 commit into flashinfer-ai:main from jimmyzho:support-checks

Conversation

@jimmyzho
Contributor

@jimmyzho commented Dec 16, 2025

📌 Description

Add unit-test-style tracking for all FlashInfer APIs. Each test is marked xfail while the corresponding support checks are unimplemented; the idea is to have them pass one by one as we progress with applying the @backend_requirement decorator.
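For context, here is a minimal sketch of the mechanism these tests probe. The names and signature below are hypothetical (the real @backend_requirement decorator in FlashInfer may differ); the point is that the decorator attaches is_compute_capability_supported and is_backend_supported callables to the wrapped function, which is exactly what the hasattr assertions in the new tests check for.

```python
# Hypothetical sketch of a support-check decorator; the real
# @backend_requirement decorator in FlashInfer may have a different
# signature and semantics.
def backend_requirement(supported_backends, supported_compute_capabilities):
    def decorator(fn):
        # Attach the attributes that the xfail tests probe with hasattr().
        fn.is_backend_supported = lambda backend: backend in supported_backends
        fn.is_compute_capability_supported = (
            lambda cc: cc in supported_compute_capabilities
        )
        return fn

    return decorator


@backend_requirement({"fa2", "trtllm"}, {80, 89, 90})
def my_decode_kernel(*args, **kwargs):
    ...  # placeholder for an actual kernel API


# Once an API is decorated, the corresponding xfail test starts passing:
assert hasattr(my_decode_kernel, "is_compute_capability_supported")
assert hasattr(my_decode_kernel, "is_backend_supported")
assert my_decode_kernel.is_backend_supported("fa2")
```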

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests
    • Added xfail-marked support-check test suites for four core API surfaces: Attention operations, Communication primitives, GEMM matrix computations, and Mixture-of-Experts routing. The suites verify that each API exposes compute capability and backend compatibility checks.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Contributor

coderabbitai bot commented Dec 16, 2025

Walkthrough

This pull request adds four new test modules containing xfail-marked support-check tests for Attention, Comm, GEMM, and MoE APIs. Each test verifies the presence of is_compute_capability_supported and is_backend_supported attributes on corresponding API functions and classes, serving as test infrastructure placeholders for future support check implementations.

Changes

  • Attention API Support Checks (tests/attention/test_attention_support_checks.py): Adds 40 xfail-marked test functions covering Decode, Prefill, Cascade, MLA, Sparse, XQA, Page, RoPE, cuDNN, and POD attention API surfaces; each test asserts the presence of support check attributes on target API components.
  • Comm API Support Checks (tests/comm/test_comm_support_checks.py): Adds 24 xfail-marked test functions for TRTLLM, VLLM, NVSHMEM, MNNVL, MoE allreduce, and related communication APIs; validates support check attribute presence across all comm subsystems.
  • GEMM API Support Checks (tests/gemm/test_gemm_support_checks.py): Adds 17 xfail-marked test functions covering GEMM variants (fp4, fp8, fp8_nt, group/deep GEMM, etc.), utilities, and constraint-based GEMM kernels; verifies support check attributes on GEMM-related components.
  • MoE API Support Checks (tests/moe/test_moe_support_checks.py): Adds xfail-marked test functions for fused MoE and A2A MoE components including topk, CUTLASS, TRT-LLM variants, and workspace utilities; validates support check attribute presence across the MoE API surface.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~4 minutes

  • All changes consist of highly repetitive, homogeneous test placeholders following identical xfail-decorated patterns with attribute assertions
  • Large sets of similar test functions with no logic variance or control-flow complexity
  • No changes to core functionality, runtime behavior, or existing test logic—purely additive test infrastructure

Suggested reviewers

  • yzh119
  • bkryu
  • cyx-6
  • yongwww
  • wenscarl
  • nvmbreughe

Poem

🐰 Hop, hop, hooray! Test checks now align,
Four new modules, each one fine—
Attention, Comm, GEMM, and MoE combine,
Xfail placeholders await their time to shine! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Title check: ❓ Inconclusive. The title 'misc: support checks unit test tracking' is vague and uses non-descriptive terms that don't clearly convey the specific changes being made. Consider a more descriptive title, such as 'Add support check unit tests for Attention, Comm, GEMM, and MoE APIs' or 'Add placeholder unit tests for API support checks tracking'.
✅ Passed checks (1 passed)
  • Description check: ✅ Passed. The PR description includes the required section about adding unit test tracking for FlashInfer APIs and explains the xfail mechanism, though it uses the template format and contains incomplete checklist items.

@gemini-code-assist
Contributor

Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request establishes a comprehensive framework for tracking the implementation of backend and compute capability support checks across numerous FlashInfer APIs. By introducing a suite of xfail unit tests, it provides a clear roadmap for integrating the @backend_requirement decorator, ensuring that each API properly reports its hardware and software compatibility status as development progresses.

Highlights

  • New Test Files for Support Checks: Introduced four new test files: test_attention_support_checks.py, test_comm_support_checks.py, test_gemm_support_checks.py, and test_moe_support_checks.py.
  • API Support Tracking: These new tests track the implementation of is_compute_capability_supported and is_backend_supported attributes across a wide range of FlashInfer APIs, including Attention, Communication, GEMM, and Mixture-of-Experts (MoE) functions.
  • XFail Marking: All newly added tests are marked with pytest.mark.xfail, indicating that they are expected to fail until the corresponding support checks are fully implemented using the @backend_requirement decorator.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a comprehensive set of unit tests to track the implementation of support checks (is_compute_capability_supported and is_backend_supported) across various FlashInfer APIs. The use of pytest.mark.xfail is a good strategy for managing this as a TODO list.

My main feedback is to refactor the new test files to use pytest.mark.parametrize. This will significantly reduce code duplication and improve the maintainability of these tests. I've left specific suggestions on each of the new test files with examples of how to apply this pattern.

Comment on lines +1 to +416
"""
Test file for Attention API support checks.

This file serves as a TODO list for support check implementations.
APIs with @pytest.mark.xfail need support checks to be implemented.
"""

import pytest

from flashinfer.decode import (
    BatchDecodeWithPagedKVCacheWrapper,
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
    single_decode_with_kv_cache,
    trtllm_batch_decode_with_kv_cache,
    xqa_batch_decode_with_kv_cache,
)
from flashinfer.prefill import (
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
    fmha_v2_prefill_deepseek,
    single_prefill_with_kv_cache,
    trtllm_batch_context_with_kv_cache,
    trtllm_ragged_attention_deepseek,
)
from flashinfer.cascade import (
    BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
    BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
    MultiLevelCascadeAttentionWrapper,
    merge_state,
    merge_state_in_place,
    merge_states,
)
from flashinfer.mla import (
    BatchMLAPagedAttentionWrapper,
    trtllm_batch_decode_with_kv_cache_mla,
    xqa_batch_decode_with_kv_cache_mla,
)
from flashinfer.sparse import (
    BlockSparseAttentionWrapper,
)
from flashinfer.xqa import xqa, xqa_mla
from flashinfer.page import (
    append_paged_kv_cache,
    append_paged_mla_kv_cache,
    get_batch_indices_positions,
)
from flashinfer.rope import (
    apply_llama31_rope,
    apply_llama31_rope_inplace,
    apply_llama31_rope_pos_ids,
    apply_llama31_rope_pos_ids_inplace,
    apply_rope,
    apply_rope_inplace,
    apply_rope_pos_ids,
    apply_rope_pos_ids_inplace,
    apply_rope_with_cos_sin_cache,
    apply_rope_with_cos_sin_cache_inplace,
)
from flashinfer.cudnn.decode import cudnn_batch_decode_with_kv_cache
from flashinfer.cudnn.prefill import cudnn_batch_prefill_with_kv_cache
from flashinfer.pod import PODWithPagedKVCacheWrapper, BatchPODWithPagedKVCacheWrapper


# Decode APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for single_decode_with_kv_cache are not implemented"
)
def test_single_decode_with_kv_cache_support_checks():
    assert hasattr(single_decode_with_kv_cache, "is_compute_capability_supported")
    assert hasattr(single_decode_with_kv_cache, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for BatchDecodeWithPagedKVCacheWrapper are not implemented"
)
def test_batch_decode_with_paged_kv_cache_wrapper_support_checks():
    assert hasattr(
        BatchDecodeWithPagedKVCacheWrapper.run, "is_compute_capability_supported"
    )
    assert hasattr(BatchDecodeWithPagedKVCacheWrapper.run, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for CUDAGraphBatchDecodeWithPagedKVCacheWrapper are not implemented"
)
def test_cuda_graph_batch_decode_wrapper_support_checks():
    assert hasattr(
        CUDAGraphBatchDecodeWithPagedKVCacheWrapper.run,
        "is_compute_capability_supported",
    )
    assert hasattr(
        CUDAGraphBatchDecodeWithPagedKVCacheWrapper.run, "is_backend_supported"
    )


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_batch_decode_with_kv_cache are not implemented"
)
def test_trtllm_batch_decode_with_kv_cache_support_checks():
    assert hasattr(trtllm_batch_decode_with_kv_cache, "is_compute_capability_supported")
    assert hasattr(trtllm_batch_decode_with_kv_cache, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for xqa_batch_decode_with_kv_cache are not implemented"
)
def test_xqa_batch_decode_with_kv_cache_support_checks():
    assert hasattr(xqa_batch_decode_with_kv_cache, "is_compute_capability_supported")
    assert hasattr(xqa_batch_decode_with_kv_cache, "is_backend_supported")


# Prefill APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for single_prefill_with_kv_cache are not implemented"
)
def test_single_prefill_with_kv_cache_support_checks():
    assert hasattr(single_prefill_with_kv_cache, "is_compute_capability_supported")
    assert hasattr(single_prefill_with_kv_cache, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for BatchPrefillWithPagedKVCacheWrapper are not implemented"
)
def test_batch_prefill_with_paged_kv_cache_wrapper_support_checks():
    assert hasattr(
        BatchPrefillWithPagedKVCacheWrapper.run, "is_compute_capability_supported"
    )
    assert hasattr(BatchPrefillWithPagedKVCacheWrapper.run, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for BatchPrefillWithRaggedKVCacheWrapper are not implemented"
)
def test_batch_prefill_with_ragged_kv_cache_wrapper_support_checks():
    assert hasattr(
        BatchPrefillWithRaggedKVCacheWrapper.run, "is_compute_capability_supported"
    )
    assert hasattr(BatchPrefillWithRaggedKVCacheWrapper.run, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_ragged_attention_deepseek are not implemented"
)
def test_trtllm_ragged_attention_deepseek_support_checks():
    assert hasattr(trtllm_ragged_attention_deepseek, "is_compute_capability_supported")
    assert hasattr(trtllm_ragged_attention_deepseek, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_batch_context_with_kv_cache are not implemented"
)
def test_trtllm_batch_context_with_kv_cache_support_checks():
    assert hasattr(
        trtllm_batch_context_with_kv_cache, "is_compute_capability_supported"
    )
    assert hasattr(trtllm_batch_context_with_kv_cache, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for fmha_v2_prefill_deepseek are not implemented"
)
def test_fmha_v2_prefill_deepseek_support_checks():
    assert hasattr(fmha_v2_prefill_deepseek, "is_compute_capability_supported")
    assert hasattr(fmha_v2_prefill_deepseek, "is_backend_supported")


# Cascade APIs
@pytest.mark.xfail(reason="TODO: Support checks for merge_state are not implemented")
def test_merge_state_support_checks():
    assert hasattr(merge_state, "is_compute_capability_supported")
    assert hasattr(merge_state, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for merge_state_in_place are not implemented"
)
def test_merge_state_in_place_support_checks():
    assert hasattr(merge_state_in_place, "is_compute_capability_supported")
    assert hasattr(merge_state_in_place, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for merge_states are not implemented")
def test_merge_states_support_checks():
    assert hasattr(merge_states, "is_compute_capability_supported")
    assert hasattr(merge_states, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for MultiLevelCascadeAttentionWrapper are not implemented"
)
def test_multi_level_cascade_wrapper_support_checks():
    assert hasattr(
        MultiLevelCascadeAttentionWrapper.run, "is_compute_capability_supported"
    )
    assert hasattr(MultiLevelCascadeAttentionWrapper.run, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for BatchDecodeWithSharedPrefixPagedKVCacheWrapper are not implemented"
)
def test_batch_decode_shared_prefix_wrapper_support_checks():
    assert hasattr(
        BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward,
        "is_compute_capability_supported",
    )
    assert hasattr(
        BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward, "is_backend_supported"
    )


@pytest.mark.xfail(
    reason="TODO: Support checks for BatchPrefillWithSharedPrefixPagedKVCacheWrapper are not implemented"
)
def test_batch_prefill_shared_prefix_wrapper_support_checks():
    assert hasattr(
        BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward,
        "is_compute_capability_supported",
    )
    assert hasattr(
        BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward, "is_backend_supported"
    )


# MLA APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for BatchMLAPagedAttentionWrapper are not implemented"
)
def test_batch_decode_mla_wrapper_support_checks():
    assert hasattr(BatchMLAPagedAttentionWrapper.run, "is_compute_capability_supported")
    assert hasattr(BatchMLAPagedAttentionWrapper.run, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_batch_decode_with_kv_cache_mla are not implemented"
)
def test_trtllm_batch_decode_mla_support_checks():
    assert hasattr(
        trtllm_batch_decode_with_kv_cache_mla, "is_compute_capability_supported"
    )
    assert hasattr(trtllm_batch_decode_with_kv_cache_mla, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for xqa_batch_decode_with_kv_cache_mla are not implemented"
)
def test_xqa_batch_decode_mla_support_checks():
    assert hasattr(
        xqa_batch_decode_with_kv_cache_mla, "is_compute_capability_supported"
    )
    assert hasattr(xqa_batch_decode_with_kv_cache_mla, "is_backend_supported")


# Sparse APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for BlockSparseAttentionWrapper are not implemented"
)
def test_block_sparse_attention_wrapper_support_checks():
    assert hasattr(BlockSparseAttentionWrapper.run, "is_compute_capability_supported")
    assert hasattr(BlockSparseAttentionWrapper.run, "is_backend_supported")


# XQA APIs
@pytest.mark.xfail(reason="TODO: Support checks for xqa are not implemented")
def test_xqa_support_checks():
    assert hasattr(xqa, "is_compute_capability_supported")
    assert hasattr(xqa, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for xqa_mla are not implemented")
def test_xqa_mla_support_checks():
    assert hasattr(xqa_mla, "is_compute_capability_supported")
    assert hasattr(xqa_mla, "is_backend_supported")


# Page APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for get_batch_indices_positions are not implemented"
)
def test_get_batch_indices_positions_support_checks():
    assert hasattr(get_batch_indices_positions, "is_compute_capability_supported")
    assert hasattr(get_batch_indices_positions, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for append_paged_mla_kv_cache are not implemented"
)
def test_append_paged_mla_kv_cache_support_checks():
    assert hasattr(append_paged_mla_kv_cache, "is_compute_capability_supported")
    assert hasattr(append_paged_mla_kv_cache, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for append_paged_kv_cache are not implemented"
)
def test_append_paged_kv_cache_support_checks():
    assert hasattr(append_paged_kv_cache, "is_compute_capability_supported")
    assert hasattr(append_paged_kv_cache, "is_backend_supported")


# RoPE APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for apply_rope_inplace are not implemented"
)
def test_apply_rope_inplace_support_checks():
    assert hasattr(apply_rope_inplace, "is_compute_capability_supported")
    assert hasattr(apply_rope_inplace, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for apply_rope_pos_ids_inplace are not implemented"
)
def test_apply_rope_pos_ids_inplace_support_checks():
    assert hasattr(apply_rope_pos_ids_inplace, "is_compute_capability_supported")
    assert hasattr(apply_rope_pos_ids_inplace, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for apply_llama31_rope_inplace are not implemented"
)
def test_apply_llama31_rope_inplace_support_checks():
    assert hasattr(apply_llama31_rope_inplace, "is_compute_capability_supported")
    assert hasattr(apply_llama31_rope_inplace, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for apply_llama31_rope_pos_ids_inplace are not implemented"
)
def test_apply_llama31_rope_pos_ids_inplace_support_checks():
    assert hasattr(
        apply_llama31_rope_pos_ids_inplace, "is_compute_capability_supported"
    )
    assert hasattr(apply_llama31_rope_pos_ids_inplace, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for apply_rope are not implemented")
def test_apply_rope_support_checks():
    assert hasattr(apply_rope, "is_compute_capability_supported")
    assert hasattr(apply_rope, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for apply_rope_pos_ids are not implemented"
)
def test_apply_rope_pos_ids_support_checks():
    assert hasattr(apply_rope_pos_ids, "is_compute_capability_supported")
    assert hasattr(apply_rope_pos_ids, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for apply_llama31_rope are not implemented"
)
def test_apply_llama31_rope_support_checks():
    assert hasattr(apply_llama31_rope, "is_compute_capability_supported")
    assert hasattr(apply_llama31_rope, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for apply_llama31_rope_pos_ids are not implemented"
)
def test_apply_llama31_rope_pos_ids_support_checks():
    assert hasattr(apply_llama31_rope_pos_ids, "is_compute_capability_supported")
    assert hasattr(apply_llama31_rope_pos_ids, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for apply_rope_with_cos_sin_cache are not implemented"
)
def test_apply_rope_with_cos_sin_cache_support_checks():
    assert hasattr(apply_rope_with_cos_sin_cache, "is_compute_capability_supported")
    assert hasattr(apply_rope_with_cos_sin_cache, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for apply_rope_with_cos_sin_cache_inplace are not implemented"
)
def test_apply_rope_with_cos_sin_cache_inplace_support_checks():
    assert hasattr(
        apply_rope_with_cos_sin_cache_inplace, "is_compute_capability_supported"
    )
    assert hasattr(apply_rope_with_cos_sin_cache_inplace, "is_backend_supported")


# cuDNN APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for cudnn_batch_decode_with_kv_cache are not implemented"
)
def test_cudnn_batch_decode_support_checks():
    assert hasattr(cudnn_batch_decode_with_kv_cache, "is_compute_capability_supported")
    assert hasattr(cudnn_batch_decode_with_kv_cache, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for cudnn_batch_prefill_with_kv_cache are not implemented"
)
def test_cudnn_batch_prefill_support_checks():
    assert hasattr(cudnn_batch_prefill_with_kv_cache, "is_compute_capability_supported")
    assert hasattr(cudnn_batch_prefill_with_kv_cache, "is_backend_supported")


# POD APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for BatchPODWithPagedKVCacheWrapper are not implemented"
)
def test_pod_prefill_wrapper_support_checks():
    assert hasattr(
        BatchPODWithPagedKVCacheWrapper.run, "is_compute_capability_supported"
    )
    assert hasattr(BatchPODWithPagedKVCacheWrapper.run, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for PODWithPagedKVCacheWrapper are not implemented"
)
def test_pod_decode_wrapper_support_checks():
    assert hasattr(PODWithPagedKVCacheWrapper.run, "is_compute_capability_supported")
    assert hasattr(PODWithPagedKVCacheWrapper.run, "is_backend_supported")
Severity: medium

This test file is quite long and contains a lot of repetitive code. You can make it more concise and maintainable by using pytest.mark.parametrize to test all APIs with a single test function. This approach also makes it easy to manage which tests are expected to fail.

Here's an example for the Decode APIs:

import pytest
from flashinfer.decode import (
    BatchDecodeWithPagedKVCacheWrapper,
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
    single_decode_with_kv_cache,
    trtllm_batch_decode_with_kv_cache,
    xqa_batch_decode_with_kv_cache,
)

decode_apis = [
    pytest.param(
        single_decode_with_kv_cache,
        id="single_decode_with_kv_cache",
        marks=pytest.mark.xfail(reason="TODO: Support checks for single_decode_with_kv_cache are not implemented")
    ),
    pytest.param(
        BatchDecodeWithPagedKVCacheWrapper.run,
        id="BatchDecodeWithPagedKVCacheWrapper",
        marks=pytest.mark.xfail(reason="TODO: Support checks for BatchDecodeWithPagedKVCacheWrapper are not implemented")
    ),
    # ... other decode APIs
]

@pytest.mark.parametrize("api", decode_apis)
def test_decode_apis_support_checks(api):
    assert hasattr(api, "is_compute_capability_supported")
    assert hasattr(api, "is_backend_supported")

You can create lists of APIs for each category (Prefill, Cascade, etc.) and apply this pattern. When a check is implemented, you can simply remove the marks from the corresponding pytest.param.

Comment on lines +1 to +245
"""
Test file for Comm API support checks.

This file serves as a TODO list for support check implementations.
APIs with @pytest.mark.xfail need support checks to be implemented.
"""

import pytest

from flashinfer.comm import (
    trtllm_allreduce_fusion,
    trtllm_custom_all_reduce,
    trtllm_moe_allreduce_fusion,
    trtllm_moe_finalize_allreduce_fusion,
    trtllm_create_ipc_workspace_for_all_reduce,
    trtllm_destroy_ipc_workspace_for_all_reduce,
    trtllm_create_ipc_workspace_for_all_reduce_fusion,
    trtllm_destroy_ipc_workspace_for_all_reduce_fusion,
    trtllm_lamport_initialize,
    trtllm_lamport_initialize_all,
    vllm_all_reduce,
    vllm_init_custom_ar,
    vllm_dispose,
    vllm_register_buffer,
    vllm_register_graph_buffers,
    MoeAlltoAll,
)
from flashinfer.comm.nvshmem_allreduce import NVSHMEMAllReduce
from flashinfer.comm.mnnvl import MnnvlMemory
from flashinfer.comm.trtllm_alltoall import MnnvlMoe, MoEAlltoallInfo
from flashinfer.comm.trtllm_mnnvl_ar import (
    trtllm_mnnvl_allreduce,
    trtllm_mnnvl_fused_allreduce_add_rmsnorm,
    trtllm_mnnvl_fused_allreduce_rmsnorm,
    trtllm_mnnvl_all_reduce,
)


# TRTLLM AllReduce APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_allreduce_fusion are not implemented"
)
def test_trtllm_allreduce_fusion_support_checks():
    assert hasattr(trtllm_allreduce_fusion, "is_compute_capability_supported")
    assert hasattr(trtllm_allreduce_fusion, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_custom_all_reduce are not implemented"
)
def test_trtllm_custom_all_reduce_support_checks():
    assert hasattr(trtllm_custom_all_reduce, "is_compute_capability_supported")
    assert hasattr(trtllm_custom_all_reduce, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_moe_allreduce_fusion are not implemented"
)
def test_trtllm_moe_allreduce_fusion_support_checks():
    assert hasattr(trtllm_moe_allreduce_fusion, "is_compute_capability_supported")
    assert hasattr(trtllm_moe_allreduce_fusion, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_moe_finalize_allreduce_fusion are not implemented"
)
def test_trtllm_moe_finalize_allreduce_fusion_support_checks():
    assert hasattr(
        trtllm_moe_finalize_allreduce_fusion, "is_compute_capability_supported"
    )
    assert hasattr(trtllm_moe_finalize_allreduce_fusion, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_create_ipc_workspace_for_all_reduce are not implemented"
)
def test_trtllm_create_ipc_workspace_for_all_reduce_support_checks():
    assert hasattr(
        trtllm_create_ipc_workspace_for_all_reduce, "is_compute_capability_supported"
    )
    assert hasattr(trtllm_create_ipc_workspace_for_all_reduce, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_destroy_ipc_workspace_for_all_reduce are not implemented"
)
def test_trtllm_destroy_ipc_workspace_for_all_reduce_support_checks():
    assert hasattr(
        trtllm_destroy_ipc_workspace_for_all_reduce, "is_compute_capability_supported"
    )
    assert hasattr(trtllm_destroy_ipc_workspace_for_all_reduce, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_create_ipc_workspace_for_all_reduce_fusion are not implemented"
)
def test_trtllm_create_ipc_workspace_for_all_reduce_fusion_support_checks():
    assert hasattr(
        trtllm_create_ipc_workspace_for_all_reduce_fusion,
        "is_compute_capability_supported",
    )
    assert hasattr(
        trtllm_create_ipc_workspace_for_all_reduce_fusion, "is_backend_supported"
    )


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_destroy_ipc_workspace_for_all_reduce_fusion are not implemented"
)
def test_trtllm_destroy_ipc_workspace_for_all_reduce_fusion_support_checks():
    assert hasattr(
        trtllm_destroy_ipc_workspace_for_all_reduce_fusion,
        "is_compute_capability_supported",
    )
    assert hasattr(
        trtllm_destroy_ipc_workspace_for_all_reduce_fusion, "is_backend_supported"
    )


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_lamport_initialize are not implemented"
)
def test_trtllm_lamport_initialize_support_checks():
    assert hasattr(trtllm_lamport_initialize, "is_compute_capability_supported")
    assert hasattr(trtllm_lamport_initialize, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_lamport_initialize_all are not implemented"
)
def test_trtllm_lamport_initialize_all_support_checks():
    assert hasattr(trtllm_lamport_initialize_all, "is_compute_capability_supported")
    assert hasattr(trtllm_lamport_initialize_all, "is_backend_supported")


# VLLM AllReduce APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for vllm_all_reduce are not implemented"
)
def test_vllm_all_reduce_support_checks():
    assert hasattr(vllm_all_reduce, "is_compute_capability_supported")
    assert hasattr(vllm_all_reduce, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for vllm_init_custom_ar are not implemented"
)
def test_vllm_init_custom_ar_support_checks():
    assert hasattr(vllm_init_custom_ar, "is_compute_capability_supported")
    assert hasattr(vllm_init_custom_ar, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for vllm_dispose are not implemented")
def test_vllm_dispose_support_checks():
    assert hasattr(vllm_dispose, "is_compute_capability_supported")
    assert hasattr(vllm_dispose, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for vllm_register_buffer are not implemented"
)
def test_vllm_register_buffer_support_checks():
    assert hasattr(vllm_register_buffer, "is_compute_capability_supported")
    assert hasattr(vllm_register_buffer, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for vllm_register_graph_buffers are not implemented"
)
def test_vllm_register_graph_buffers_support_checks():
    assert hasattr(vllm_register_graph_buffers, "is_compute_capability_supported")
    assert hasattr(vllm_register_graph_buffers, "is_backend_supported")


# NVSHMEM APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for NVSHMEMAllReduce are not implemented"
)
def test_nvshmem_allreduce_support_checks():
    assert hasattr(NVSHMEMAllReduce, "is_compute_capability_supported")
    assert hasattr(NVSHMEMAllReduce, "is_backend_supported")


# MNNVL APIs
@pytest.mark.xfail(reason="TODO: Support checks for MnnvlMemory are not implemented")
def test_mnnvl_memory_support_checks():
    assert hasattr(MnnvlMemory, "is_compute_capability_supported")
    assert hasattr(MnnvlMemory, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for MnnvlMoe are not implemented")
def test_mnnvl_moe_support_checks():
    assert hasattr(MnnvlMoe, "is_compute_capability_supported")
    assert hasattr(MnnvlMoe, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for MoEAlltoallInfo are not implemented"
)
def test_moe_alltoall_info_support_checks():
    assert hasattr(MoEAlltoallInfo, "is_compute_capability_supported")
    assert hasattr(MoEAlltoallInfo, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for MoeAlltoAll are not implemented")
def test_moe_alltoall_support_checks():
    assert hasattr(MoeAlltoAll.dispatch, "is_compute_capability_supported")
    assert hasattr(MoeAlltoAll.dispatch, "is_backend_supported")


# TRTLLM MNNVL AllReduce APIs
@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_mnnvl_allreduce are not implemented"
)
def test_trtllm_mnnvl_allreduce_support_checks():
    assert hasattr(trtllm_mnnvl_allreduce, "is_compute_capability_supported")
    assert hasattr(trtllm_mnnvl_allreduce, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_mnnvl_fused_allreduce_add_rmsnorm are not implemented"
)
def test_trtllm_mnnvl_fused_allreduce_add_rmsnorm_support_checks():
    assert hasattr(
        trtllm_mnnvl_fused_allreduce_add_rmsnorm, "is_compute_capability_supported"
    )
    assert hasattr(trtllm_mnnvl_fused_allreduce_add_rmsnorm, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_mnnvl_fused_allreduce_rmsnorm are not implemented"
)
def test_trtllm_mnnvl_fused_allreduce_rmsnorm_support_checks():
    assert hasattr(
        trtllm_mnnvl_fused_allreduce_rmsnorm, "is_compute_capability_supported"
    )
    assert hasattr(trtllm_mnnvl_fused_allreduce_rmsnorm, "is_backend_supported")


@pytest.mark.xfail(
    reason="TODO: Support checks for trtllm_mnnvl_all_reduce are not implemented"
)
def test_trtllm_mnnvl_all_reduce_support_checks():
    assert hasattr(trtllm_mnnvl_all_reduce, "is_compute_capability_supported")
    assert hasattr(trtllm_mnnvl_all_reduce, "is_backend_supported")

This file has a lot of repeated test code. To improve maintainability and reduce boilerplate, consider using pytest.mark.parametrize. This allows you to define a list of APIs to test and use a single test function to check them all.

For example:

import pytest
from flashinfer.comm import (
    trtllm_allreduce_fusion,
    trtllm_custom_all_reduce,
    # ... other imports
)

trtllm_allreduce_apis = [
    pytest.param(
        trtllm_allreduce_fusion,
        id="trtllm_allreduce_fusion",
        marks=pytest.mark.xfail(reason="TODO: Support checks for trtllm_allreduce_fusion are not implemented")
    ),
    pytest.param(
        trtllm_custom_all_reduce,
        id="trtllm_custom_all_reduce",
        marks=pytest.mark.xfail(reason="TODO: Support checks for trtllm_custom_all_reduce are not implemented")
    ),
    # ... other TRTLLM AllReduce APIs
]

@pytest.mark.parametrize("api", trtllm_allreduce_apis)
def test_trtllm_allreduce_apis_support_checks(api):
    assert hasattr(api, "is_compute_capability_supported")
    assert hasattr(api, "is_backend_supported")

This pattern can be applied to the other API groups in this file as well. It makes it easy to track which tests are expected to fail and to update them as features are implemented.

Comment on lines +1 to +160
"""
Test file for GEMM API support checks.

This file serves as a TODO list for support check implementations.
APIs with @pytest.mark.xfail need support checks to be implemented in gemm_base.py.
"""

import pytest

from flashinfer import (
bmm_fp8,
mm_fp4,
mm_fp8,
prepare_low_latency_gemm_weights,
tgv_gemm_sm100,
)
from flashinfer.gemm import (
SegmentGEMMWrapper,
batch_deepgemm_fp8_nt_groupwise,
gemm_fp8_nt_blockscaled,
gemm_fp8_nt_groupwise,
group_deepgemm_fp8_nt_groupwise,
group_gemm_fp8_nt_groupwise,
group_gemm_mxfp4_nt_groupwise,
)
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from flashinfer.cute_dsl.gemm_allreduce_two_shot import PersistentDenseGemmKernel
import flashinfer.triton.sm_constraint_gemm as sm_constraint_gemm


def test_mm_fp4_support_checks():
assert hasattr(mm_fp4, "is_compute_capability_supported")
assert hasattr(mm_fp4, "is_backend_supported")


def test_bmm_fp8_support_checks():
assert hasattr(bmm_fp8, "is_compute_capability_supported")
assert hasattr(bmm_fp8, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for tgv_gemm_sm100 are not implemented")
def test_tgv_gemm_sm100_support_checks():
assert hasattr(tgv_gemm_sm100, "is_compute_capability_supported")
assert hasattr(tgv_gemm_sm100, "is_backend_supported")


@pytest.mark.xfail(reason="TODO: Support checks for mm_fp8 are not implemented")
def test_mm_fp8_support_checks():
assert hasattr(mm_fp8, "is_compute_capability_supported")
assert hasattr(mm_fp8, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for gemm_fp8_nt_groupwise are not implemented"
)
def test_gemm_fp8_nt_groupwise_support_checks():
assert hasattr(gemm_fp8_nt_groupwise, "is_compute_capability_supported")
assert hasattr(gemm_fp8_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for gemm_fp8_nt_blockscaled are not implemented"
)
def test_gemm_fp8_nt_blockscaled_support_checks():
assert hasattr(gemm_fp8_nt_blockscaled, "is_compute_capability_supported")
assert hasattr(gemm_fp8_nt_blockscaled, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for group_gemm_fp8_nt_groupwise are not implemented"
)
def test_group_gemm_fp8_nt_groupwise_support_checks():
assert hasattr(group_gemm_fp8_nt_groupwise, "is_compute_capability_supported")
assert hasattr(group_gemm_fp8_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for group_gemm_mxfp4_nt_groupwise are not implemented"
)
def test_group_gemm_mxfp4_nt_groupwise_support_checks():
assert hasattr(group_gemm_mxfp4_nt_groupwise, "is_compute_capability_supported")
assert hasattr(group_gemm_mxfp4_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for group_deepgemm_fp8_nt_groupwise are not implemented"
)
def test_group_deepgemm_fp8_nt_groupwise_support_checks():
assert hasattr(group_deepgemm_fp8_nt_groupwise, "is_compute_capability_supported")
assert hasattr(group_deepgemm_fp8_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for batch_deepgemm_fp8_nt_groupwise are not implemented"
)
def test_batch_deepgemm_fp8_nt_groupwise_support_checks():
assert hasattr(batch_deepgemm_fp8_nt_groupwise, "is_compute_capability_supported")
assert hasattr(batch_deepgemm_fp8_nt_groupwise, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for SegmentGEMMWrapper are not implemented"
)
def test_segment_gemm_wrapper_support_checks():
assert hasattr(SegmentGEMMWrapper.run, "is_compute_capability_supported")
assert hasattr(SegmentGEMMWrapper.run, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for prepare_low_latency_gemm_weights are not implemented"
)
def test_prepare_low_latency_gemm_weights_support_checks():
assert hasattr(prepare_low_latency_gemm_weights, "is_compute_capability_supported")
assert hasattr(prepare_low_latency_gemm_weights, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for grouped_gemm_nt_masked are not implemented"
)
def test_grouped_gemm_nt_masked_support_checks():
assert hasattr(grouped_gemm_nt_masked, "is_compute_capability_supported")
assert hasattr(grouped_gemm_nt_masked, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for PersistentDenseGemmKernel are not implemented"
)
def test_persistent_dense_gemm_kernel_support_checks():
assert hasattr(PersistentDenseGemmKernel, "is_compute_capability_supported")
assert hasattr(PersistentDenseGemmKernel, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for sm_constraint_gemm.gemm are not implemented"
)
def test_sm_constraint_gemm_support_checks():
assert hasattr(sm_constraint_gemm.gemm, "is_compute_capability_supported")
assert hasattr(sm_constraint_gemm.gemm, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for sm_constraint_gemm.gemm_persistent are not implemented"
)
def test_sm_constraint_gemm_persistent_support_checks():
assert hasattr(
sm_constraint_gemm.gemm_persistent, "is_compute_capability_supported"
)
assert hasattr(sm_constraint_gemm.gemm_persistent, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for sm_constraint_gemm.gemm_descriptor_persistent are not implemented"
)
def test_sm_constraint_gemm_descriptor_persistent_support_checks():
assert hasattr(
sm_constraint_gemm.gemm_descriptor_persistent, "is_compute_capability_supported"
)
assert hasattr(
sm_constraint_gemm.gemm_descriptor_persistent, "is_backend_supported"
)

This test file can be simplified by using pytest.mark.parametrize to avoid repeating the same test logic for each API. This will make the code more concise and easier to maintain. Since you have a mix of passing tests and expected failures, you can configure this within the parametrization.

Here's how you could refactor it:

import pytest
from flashinfer import (
    bmm_fp8,
    mm_fp4,
    tgv_gemm_sm100,
    # ... other imports
)

gemm_apis = [
    pytest.param(mm_fp4, id="mm_fp4"), # This test is expected to pass
    pytest.param(bmm_fp8, id="bmm_fp8"), # This test is expected to pass
    pytest.param(
        tgv_gemm_sm100,
        id="tgv_gemm_sm100",
        marks=pytest.mark.xfail(reason="TODO: Support checks for tgv_gemm_sm100 are not implemented")
    ),
    # ... other gemm APIs
]

@pytest.mark.parametrize("api", gemm_apis)
def test_gemm_apis_support_checks(api):
    assert hasattr(api, "is_compute_capability_supported")
    assert hasattr(api, "is_backend_supported")

This approach clearly separates the passing tests from the xfail ones while keeping the test logic DRY (Don't Repeat Yourself).

Comment on lines +1 to +140
"""
Test file for MoE API support checks.

This file serves as a TODO list for support check implementations.
APIs with @pytest.mark.xfail need support checks to be implemented.
"""

import pytest

from flashinfer.fused_moe import (
cutlass_fused_moe,
fused_topk_deepseek,
trtllm_bf16_moe,
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
trtllm_mxint4_block_scale_moe,
)
from flashinfer.comm.trtllm_moe_alltoall import (
moe_a2a_combine,
moe_a2a_dispatch,
moe_a2a_get_workspace_size_per_rank,
moe_a2a_initialize,
moe_a2a_sanitize_expert_ids,
moe_a2a_wrap_payload_tensor_in_workspace,
)


def test_fused_topk_deepseek_support_checks():
assert hasattr(fused_topk_deepseek, "is_compute_capability_supported")
assert hasattr(fused_topk_deepseek, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for cutlass_fused_moe are not implemented"
)
def test_cutlass_fused_moe_support_checks():
assert hasattr(cutlass_fused_moe, "is_compute_capability_supported")
assert hasattr(cutlass_fused_moe, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for trtllm_bf16_moe are not implemented"
)
def test_trtllm_bf16_moe_support_checks():
assert hasattr(trtllm_bf16_moe, "is_compute_capability_supported")
assert hasattr(trtllm_bf16_moe, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for trtllm_fp8_per_tensor_scale_moe are not implemented"
)
def test_trtllm_fp8_per_tensor_scale_moe_support_checks():
assert hasattr(trtllm_fp8_per_tensor_scale_moe, "is_compute_capability_supported")
assert hasattr(trtllm_fp8_per_tensor_scale_moe, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for trtllm_fp8_block_scale_moe are not implemented"
)
def test_trtllm_fp8_block_scale_moe_support_checks():
assert hasattr(trtllm_fp8_block_scale_moe, "is_compute_capability_supported")
assert hasattr(trtllm_fp8_block_scale_moe, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for trtllm_fp4_block_scale_moe are not implemented"
)
def test_trtllm_fp4_block_scale_moe_support_checks():
assert hasattr(trtllm_fp4_block_scale_moe, "is_compute_capability_supported")
assert hasattr(trtllm_fp4_block_scale_moe, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for trtllm_fp4_block_scale_routed_moe are not implemented"
)
def test_trtllm_fp4_block_scale_routed_moe_support_checks():
assert hasattr(trtllm_fp4_block_scale_routed_moe, "is_compute_capability_supported")
assert hasattr(trtllm_fp4_block_scale_routed_moe, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for trtllm_mxint4_block_scale_moe are not implemented"
)
def test_trtllm_mxint4_block_scale_moe_support_checks():
assert hasattr(trtllm_mxint4_block_scale_moe, "is_compute_capability_supported")
assert hasattr(trtllm_mxint4_block_scale_moe, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for moe_a2a_initialize are not implemented"
)
def test_moe_a2a_initialize_support_checks():
assert hasattr(moe_a2a_initialize, "is_compute_capability_supported")
assert hasattr(moe_a2a_initialize, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for moe_a2a_wrap_payload_tensor_in_workspace are not implemented"
)
def test_moe_a2a_wrap_payload_tensor_in_workspace_support_checks():
assert hasattr(
moe_a2a_wrap_payload_tensor_in_workspace, "is_compute_capability_supported"
)
assert hasattr(moe_a2a_wrap_payload_tensor_in_workspace, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for moe_a2a_dispatch are not implemented"
)
def test_moe_a2a_dispatch_support_checks():
assert hasattr(moe_a2a_dispatch, "is_compute_capability_supported")
assert hasattr(moe_a2a_dispatch, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for moe_a2a_combine are not implemented"
)
def test_moe_a2a_combine_support_checks():
assert hasattr(moe_a2a_combine, "is_compute_capability_supported")
assert hasattr(moe_a2a_combine, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for moe_a2a_sanitize_expert_ids are not implemented"
)
def test_moe_a2a_sanitize_expert_ids_support_checks():
assert hasattr(moe_a2a_sanitize_expert_ids, "is_compute_capability_supported")
assert hasattr(moe_a2a_sanitize_expert_ids, "is_backend_supported")


@pytest.mark.xfail(
reason="TODO: Support checks for moe_a2a_get_workspace_size_per_rank are not implemented"
)
def test_moe_a2a_get_workspace_size_per_rank_support_checks():
assert hasattr(
moe_a2a_get_workspace_size_per_rank, "is_compute_capability_supported"
)
assert hasattr(moe_a2a_get_workspace_size_per_rank, "is_backend_supported")

To make this test file more compact and maintainable, you could use pytest.mark.parametrize. This would allow you to test multiple APIs with a single test function, reducing code duplication. You can also handle the mix of passing and xfail tests cleanly.

Here is an example of how to apply this pattern:

import pytest
from flashinfer.fused_moe import (
    cutlass_fused_moe,
    fused_topk_deepseek,
    # ... other imports
)

moe_apis = [
    pytest.param(fused_topk_deepseek, id="fused_topk_deepseek"), # Expected to pass
    pytest.param(
        cutlass_fused_moe,
        id="cutlass_fused_moe",
        marks=pytest.mark.xfail(reason="TODO: Support checks for cutlass_fused_moe are not implemented")
    ),
    # ... other MoE APIs
]

@pytest.mark.parametrize("api", moe_apis)
def test_moe_apis_support_checks(api):
    assert hasattr(api, "is_compute_capability_supported")
    assert hasattr(api, "is_backend_supported")

This refactoring makes it easier to see which APIs are covered and to update their status as the support checks are implemented.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/moe/test_moe_support_checks.py (1)

35-40: Consider adding strict=True to xfail markers.

Adding strict=True ensures the test suite will flag when support checks are implemented, prompting removal of the xfail marker. This helps keep the TODO list accurate.

-@pytest.mark.xfail(
-    reason="TODO: Support checks for cutlass_fused_moe are not implemented"
-)
+@pytest.mark.xfail(
+    reason="TODO: Support checks for cutlass_fused_moe are not implemented",
+    strict=True,
+)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f0355f7 and 41f0b79.

📒 Files selected for processing (4)
  • tests/attention/test_attention_support_checks.py (1 hunks)
  • tests/comm/test_comm_support_checks.py (1 hunks)
  • tests/gemm/test_gemm_support_checks.py (1 hunks)
  • tests/moe/test_moe_support_checks.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/gemm/test_gemm_support_checks.py (5)
flashinfer/gemm/gemm_base.py (6)
  • mm_fp4 (2086-2223)
  • mm_fp8 (1560-1667)
  • SegmentGEMMWrapper (808-1062)
  • batch_deepgemm_fp8_nt_groupwise (3157-3289)
  • gemm_fp8_nt_blockscaled (2666-2691)
  • gemm_fp8_nt_groupwise (2414-2577)
flashinfer/trtllm_low_latency_gemm.py (1)
  • prepare_low_latency_gemm_weights (196-224)
flashinfer/cute_dsl/blockscaled_gemm.py (1)
  • grouped_gemm_nt_masked (2945-3046)
flashinfer/cute_dsl/gemm_allreduce_two_shot.py (1)
  • PersistentDenseGemmKernel (184-2105)
flashinfer/triton/sm_constraint_gemm.py (2)
  • gemm_persistent (14-95)
  • gemm_descriptor_persistent (173-278)
tests/comm/test_comm_support_checks.py (5)
flashinfer/comm/trtllm_ar.py (4)
  • trtllm_create_ipc_workspace_for_all_reduce (403-480)
  • trtllm_destroy_ipc_workspace_for_all_reduce (483-495)
  • trtllm_create_ipc_workspace_for_all_reduce_fusion (504-643)
  • trtllm_destroy_ipc_workspace_for_all_reduce_fusion (646-662)
flashinfer/comm/nvshmem_allreduce.py (1)
  • NVSHMEMAllReduce (25-127)
flashinfer/comm/mnnvl.py (1)
  • MnnvlMemory (244-563)
flashinfer/comm/trtllm_alltoall.py (1)
  • MoEAlltoallInfo (423-430)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
  • trtllm_mnnvl_allreduce (308-383)
  • trtllm_mnnvl_fused_allreduce_rmsnorm (620-707)
  • trtllm_mnnvl_all_reduce (546-614)
tests/attention/test_attention_support_checks.py (2)
flashinfer/page.py (1)
  • get_batch_indices_positions (123-175)
flashinfer/cudnn/decode.py (1)
  • cudnn_batch_decode_with_kv_cache (257-350)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
tests/moe/test_moe_support_checks.py (2)

1-140: LGTM overall structure and coverage.

The test file is well-organized with clear imports and consistent test patterns. Each test verifies both is_compute_capability_supported and is_backend_supported attributes, which aligns with the PR objective of tracking support check implementations.


30-32: The support check attributes for fused_topk_deepseek are already implemented via the @backend_requirement decorator. This decorator automatically adds both is_compute_capability_supported and is_backend_supported methods to the decorated function. The test is correctly not marked with @pytest.mark.xfail and will pass.

Likely an incorrect or invalid review comment.
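
The attachment mechanism described above can be sketched in a few lines. Only the two attribute names (`is_compute_capability_supported`, `is_backend_supported`) and the decorator name `backend_requirement` come from the discussion; the parameters and check logic below are assumptions, not FlashInfer's actual implementation.

```python
# Hedged sketch of how a backend_requirement-style decorator could attach the
# support-check attributes that these tests probe via hasattr(). The
# parameter names and check logic are illustrative assumptions.
from typing import Callable, Set


def backend_requirement(backends: Set[str], min_compute_capability: int):
    """Return a decorator that attaches support-check predicates."""

    def decorate(fn: Callable) -> Callable:
        # Attach predicates as function attributes so callers (and tests)
        # can query support without invoking the kernel itself.
        fn.is_backend_supported = lambda backend: backend in backends
        fn.is_compute_capability_supported = (
            lambda cc: cc >= min_compute_capability
        )
        return fn

    return decorate


@backend_requirement(backends={"cutlass", "trtllm"}, min_compute_capability=90)
def my_gemm(a, b):  # hypothetical stand-in for a decorated API
    return a @ b
```

With this in place, a `hasattr`-based support-check test for `my_gemm` would pass without the `xfail` marker.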

tests/gemm/test_gemm_support_checks.py (2)

1-160: Good coverage of GEMM API surfaces.

The test file comprehensively covers various GEMM implementations including standard, grouped, block-scaled, and Triton-based variants. The organization by category (imports, tests) is clear and maintainable.


31-38: These support checks are already implemented via the @backend_requirement decorator.

Both mm_fp4 and bmm_fp8 are decorated with @backend_requirement (in flashinfer/gemm/gemm_base.py at lines 2078-2084 and 2316-2323 respectively). This decorator automatically adds the is_compute_capability_supported and is_backend_supported methods to these functions, so the test assertions will pass. The tests are correctly not marked @pytest.mark.xfail because the implementation is complete.

Likely an incorrect or invalid review comment.

tests/comm/test_comm_support_checks.py (2)

1-245: Comprehensive coverage of communication APIs.

The test file covers a broad range of communication primitives across TRTLLM, VLLM, NVSHMEM, and MNNVL backends. All tests are consistently marked as xfail, making this a clear roadmap for future support check implementations.


197-202: Remove this test or reconsider if MoEAlltoallInfo should implement support checks.

MoEAlltoallInfo is a pure dataclass containing only tensor references and a count field. Support check attributes (is_compute_capability_supported, is_backend_supported) belong on classes that perform compute operations or have backend/capability requirements, not data containers. Either this test should be removed, or the design intent for this class should be reconsidered if support checks are actually needed.

Likely an incorrect or invalid review comment.

tests/attention/test_attention_support_checks.py (2)

198-221: Consistent method targeting for different wrapper interfaces.

Note that BatchDecodeWithSharedPrefixPagedKVCacheWrapper and BatchPrefillWithSharedPrefixPagedKVCacheWrapper check .forward while most other wrappers check .run. This appears intentional based on each wrapper's public API, but ensure the @backend_requirement decorator is applied consistently to the appropriate entry point method for each wrapper.
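
For the wrapper classes, the same attachment can target whichever method is the public entry point (`.run` or `.forward`). A minimal illustration, with all names hypothetical:

```python
# Hedged sketch: attaching support checks to a wrapper's entry-point method,
# mirroring how these tests probe SegmentGEMMWrapper.run or a shared-prefix
# wrapper's .forward. All names here are illustrative, not FlashInfer code.

def attach_support_checks(fn):
    fn.is_compute_capability_supported = lambda cc: cc >= 80
    fn.is_backend_supported = lambda backend: backend in {"fa2", "fa3"}
    return fn


class FakeDecodeWrapper:
    @attach_support_checks
    def run(self, q, kv_cache):  # hypothetical entry point
        return q


# The tests probe the function object on the class; plain function
# attributes survive lookup through the class in Python 3.
assert hasattr(FakeDecodeWrapper.run, "is_compute_capability_supported")
assert hasattr(FakeDecodeWrapper.run, "is_backend_supported")
```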


1-416: Excellent coverage of Attention API surfaces.

This file comprehensively covers Decode, Prefill, Cascade, MLA, Sparse, XQA, Page, RoPE, cuDNN, and POD APIs. The organization by category with section comments (e.g., # Decode APIs, # RoPE APIs) improves readability and maintainability. All tests are consistently marked as xfail, serving as a clear tracking mechanism for support check implementation progress.


@yzh119 yzh119 left a comment


Hi @jimmyzho, it looks good to me! Just wondering, does this mean we have to update this unit test whenever we introduce a new API?

Also, would you mind mentioning is_compute_capability_supported and is_backend_supported to Claude Code? You should be able to commit a change to #2240.

@jimmyzho
Contributor Author

@yzh119 Added to #2240. And yes, it would be good to keep our interface updated with support checks where feasible.

@jimmyzho jimmyzho requested a review from yzh119 January 5, 2026 19:06