Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions csrc/gemm_groupwise_sm120.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ using namespace flashinfer;
SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, \
...) \
[&]() -> bool { \
/* SM120 Cooperative schedule uses 128x128x128 tile shape */ \
/* SM120/SM121 Cooperative schedule uses 128x128x128 tile shape */ \
/* TODO (yongwww): PingPong schedule (64x128x128) will need additional dispatch logic */ \
constexpr int SCALE_GRANULARITY_K = 128; \
if (scale_granularity_k != 128) { \
TVM_FFI_ICHECK(false) \
<< "SM120 requires scale_granularity_k=128. CUTLASS enforces ScaleGranularityK must " \
<< "SM120/SM121 requires scale_granularity_k=128. CUTLASS enforces ScaleGranularityK " \
"must " \
"equal tile shape K dimension (128 for both Cooperative and PingPong schedules)."; \
return false; \
} \
Expand Down
5 changes: 3 additions & 2 deletions csrc/group_gemm_fp8_groupwise_sm120.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ using namespace flashinfer;
constexpr int SCALE_GRANULARITY_K = 128; \
if (scale_granularity_k != 128) { \
TVM_FFI_ICHECK(false) \
<< "SM120 requires scale_granularity_k=128. CUTLASS enforces ScaleGranularityK must " \
"equal tile shape K dimension (128 for both Cooperative and PingPong schedules)."; \
<< "SM120/SM121 requires scale_granularity_k=128. CUTLASS enforces ScaleGranularityK " \
"must equal tile shape K dimension (128 for both Cooperative and PingPong " \
"schedules)."; \
return false; \
} \
/* Match SM100's approach: support only (1,128,128) and (128,128,128) */ \
Expand Down
2 changes: 1 addition & 1 deletion csrc/trtllm_fmha_v2_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ static inline void determine_launch_params(
* - Layout: Separate Q, K, V tensors
* - Q/K dimension: 192, V dimension: 128 (MLA-specific)
* - Output type: BF16
* - Target architecture: SM120 (Blackwell)
* - Target architecture: SM120/SM121 (Blackwell)
*
* @param q Query tensor [batch, q_seqlen, num_heads, 192] in E4M3
* @param k Key tensor [batch, kv_seqlen, num_kv_heads, 192] in E4M3
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2215,7 +2215,7 @@ def trtllm_batch_decode_with_kv_cache(
The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
For sm_90 (hopper architecture) and sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.
For sm_90 (hopper architecture) and sm_120/sm_121 (blackwell architecture), ``auto`` will choose ``xqa`` backend.
o_scale : Optional[float] = 1.0
output scale factor for xqa fp8 output.
Expand Down
12 changes: 6 additions & 6 deletions flashinfer/gemm/gemm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024

# Error messages
CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR = "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR = "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120/SM121 with cuDNN backend version < 9.14.0."


def _match_sm_version(device: torch.device, sm_version: list[str]):
Expand Down Expand Up @@ -622,7 +622,7 @@ def forward(
a, b, scale_a, scale_b, out, workspace_buffer = inputs

# Handle both 2D (MM) and 3D (BMM) cases
# SM120 kernel now supports batch operations natively
# SM120/SM121 kernel now supports batch operations natively
if a.dim() == 2:
# 2D case: simple matrix multiplication
# Make B column-major for the kernel
Expand Down Expand Up @@ -652,7 +652,7 @@ def forward(
def _pad_to_multiple(x, multiple):
return ((x + multiple - 1) // multiple) * multiple

# SM120 CUTLASS blockwise scaling requires:
# SM120/SM121 CUTLASS blockwise scaling requires:
# - N % 128 == 0 (ScaleGranularityN)
# - K % 128 == 0 (TileK)
# If not aligned, we pad and then slice the result
Expand Down Expand Up @@ -709,7 +709,7 @@ def _pad_to_multiple(x, multiple):
else:
out_padded = out

# For scalar scales, create compatible shapes for SM120
# For scalar scales, create compatible shapes for SM120/SM121
if scale_a.numel() == 1:
scale_m_count = (
batch_size * m_dim + scale_gran_m - 1
Expand All @@ -736,7 +736,7 @@ def _pad_to_multiple(x, multiple):
else:
scale_b_expanded = scale_b

# Call SM120 gemm_fp8_nt_groupwise (now handles both 2D and 3D)
# Call SM120/SM121 gemm_fp8_nt_groupwise (now handles both 2D and 3D)
module.gemm_fp8_nt_groupwise(
workspace_buffer,
a_padded,
Expand Down Expand Up @@ -3038,7 +3038,7 @@ def _cudnn_gemm_fp4_requirement(
raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
if (
not use_nvfp4
and _match_sm_version(a.device, ["120"])
and _match_sm_version(a.device, ["120", "121"])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While this change correctly adds support for SM121, the error message raised on line 3044, which is controlled by the CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR constant, is now potentially misleading as it only mentions SM120. To avoid confusion for users on SM121 devices, please consider updating the error message to include SM121.

and cudnn.backend_version() < 91400
):
raise LibraryError(CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR)
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/jit/gemm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def gen_gemm_sm120_module() -> JitSpec:
dtype_in_list = [torch.float8_e4m3fn, torch.float8_e5m2]
dtype_out_list = [torch.float16, torch.bfloat16]
scale_major_k_list = ["true", "false"]
# SM120 uses fixed 128x128x128 tiles with Cooperative schedule
# SM120/SM121 uses fixed 128x128x128 tiles with Cooperative schedule

with open(jit_env.FLASHINFER_CSRC_DIR / f"{prefix}_sm120_kernel_inst.jinja") as f:
kernel_inst_templ = jinja2.Template(f.read())
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/jit/gemm/cutlass/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled):
else:
act_type, weight_type = dtype, dtype

# Minimal filter: for mixed FP8xFP4 on SM120, only emit 128x128x128
# Minimal filter: for mixed FP8xFP4 on SM120/SM121, only emit 128x128x128
if act_type == DataType.e4m3 and weight_type == e2m1:
if cta_shape_mnk != [128, 128, 128]:
continue
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
or kv_cache.dtype != torch.float8_e4m3fn
):
raise ValueError(
f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}"
f"XQA MLA only supports fp8 operation on SM120/SM121 GPUs, got {query.dtype} and {kv_cache.dtype}"
)
if sinks is not None:
raise ValueError("XQA MLA does not support sinks")
Expand Down
4 changes: 2 additions & 2 deletions flashinfer/xqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def xqa(
run_sm90_fp8_mha = False

if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
raise RuntimeError("XQA is only supported on SM90, SM100, SM120/SM121 GPUs")

xqa_module = get_xqa_module(
q.dtype,
Expand Down Expand Up @@ -501,7 +501,7 @@ def xqa_mla(
assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype"

if get_compute_capability(torch.device(device="cuda"))[0] not in [12]:
raise RuntimeError("XQA MLA is only supported on SM120 GPUs")
raise RuntimeError("XQA MLA is only supported on SM120/SM121 GPUs")

xqa_module = get_xqa_module_mla(
q.dtype,
Expand Down
11 changes: 6 additions & 5 deletions include/flashinfer/gemm/cutlass_gemm_configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ struct CutlassGemmConfig {
bool enableCudaKernel = false;
int sm_version = 80; // Use 80 as a catch all for <90
bool is_tma_warp_specialized = false;
bool use_stream_k = false; // SM120: false = DP scheduler (default), true = StreamK scheduler
bool use_stream_k =
false; // SM120/SM121: false = DP scheduler (default), true = StreamK scheduler

CutlassGemmConfig() = default;

Expand Down Expand Up @@ -379,7 +380,7 @@ struct CutlassGemmConfig {
sm_version(100),
is_tma_warp_specialized(true) {}

// SM120 constructor with optional StreamK scheduler
// SM120/SM121 constructor with optional StreamK scheduler
// use_stream_k: false = DP scheduler (default), true = StreamK scheduler (auto heuristic)
CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120,
MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule,
Expand All @@ -393,7 +394,7 @@ struct CutlassGemmConfig {
use_stream_k(use_stream_k) {}

int getTileConfigAsInt() const {
if (sm_version == 120) return (int)tile_config_sm120;
if (sm_version == 120 || sm_version == 121) return (int)tile_config_sm120;
if (sm_version == 110) return (int)tile_config_sm100;
if (sm_version >= 100) return (int)tile_config_sm100;
if (sm_version == 90) return (int)tile_config_sm90;
Expand All @@ -413,8 +414,8 @@ struct CutlassGemmConfig {
<< "\n\tmainloop sched: " << (int)mainloop_schedule
<< "\n\tepi sched: " << (int)epilogue_schedule
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false");
// SM120 specific: StreamK scheduler option
if (sm_version == 120) {
// SM120/SM121 specific: StreamK scheduler option
if (sm_version == 120 || sm_version == 121) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

With this change, the check now includes SM121, but the comment on the preceding line (416) still says 'SM120 specific'. Please update the comment to reflect this change for better code clarity. For example: // SM120/SM121 specific: StreamK scheduler option.

tactic << "\n\tscheduler: " << (use_stream_k ? "StreamK (auto heuristic)" : "DP (default)");
}
} else if (tile_config_sm80 != flashinfer::gemm::CutlassTileConfig::ChooseWithHeuristic) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ std::vector<CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getConfigs(
CutlassTileConfigSM120::CtaShape256x128x128B,
};

// SM120 only supports 1x1x1 cluster shape
// SM120/SM121 only supports 1x1x1 cluster shape
ClusterShape clusterShape = ClusterShape::ClusterShape_1x1x1;

// Generate configs for both DP and StreamK schedulers
Expand Down
8 changes: 4 additions & 4 deletions include/flashinfer/gemm/fp4_gemm_template_sm120.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,14 @@ inline size_t runFp4GemmImpl(void* D, void const* A, void const* B, void const*
auto initStatus = gemm.initialize(args, workspace, stream);
if (initStatus != cutlass::Status::kSuccess) {
throw std::runtime_error(std::string("[FP4 gemm Runner") + scheduler_name + "] " +
"Failed to initialize cutlass FP4 gemm on sm120. Error: " +
"Failed to initialize cutlass FP4 gemm on sm120/sm121. Error: " +
std::string(cutlass::cutlassGetStatusString(initStatus)));
}

auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true);
if (runStatus != cutlass::Status::kSuccess) {
throw std::runtime_error(std::string("[FP4 gemm Runner") + scheduler_name + "] " +
"Failed to run cutlass FP4 gemm on sm120. Error: " +
"Failed to run cutlass FP4 gemm on sm120/sm121. Error: " +
std::string(cutlass::cutlassGetStatusString(runStatus)));
}

Expand Down Expand Up @@ -251,8 +251,8 @@ inline size_t runFp4GemmImpl(void* D, void const* A, void const* B, void const*
cutlass::epilogue::fusion::LinearCombination<OutElementType, float, void, \
float>>::CollectiveOp; \
\
/* SM120 BlockScaled - Use nv_float4_t without tuples like example 79 */ \
/* Use fixed 2 stages for SM120 to meet minimum requirement */ \
/* SM120/SM121 BlockScaled - Use nv_float4_t without tuples like example 79 */ \
/* Use fixed 2 stages for SM120/SM121 to meet minimum requirement */ \
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< \
Arch, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, \
ElementAccumulator, ThreadBlockShape, ClusterShape, \
Expand Down
8 changes: 4 additions & 4 deletions include/flashinfer/gemm/gemm_groupwise_sm120.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ cudaError_t CutlassGroupwiseScaledGEMMSM120(void* float_buffer, size_t float_buf
DTypeIn* A_ptr, DTypeIn* B_ptr, float* SFA_ptr,
float* SFB_ptr, DTypeOut* D_ptr, int m, int n, int k,
int l, cudaStream_t stream) {
// SM120 only supports these specific scale granularities
// SM120/SM121 only supports these specific scale granularities
static_assert(ScaleGranularityM == 1 || ScaleGranularityM == 128,
"SM120 only supports ScaleGranularityM = 1 or 128");
static_assert(ScaleGranularityN == 128, "SM120 only supports ScaleGranularityN = 128");
static_assert(ScaleGranularityK == 128, "SM120 only supports ScaleGranularityK = 128");
"SM120/SM121 only supports ScaleGranularityM = 1 or 128");
static_assert(ScaleGranularityN == 128, "SM120/SM121 only supports ScaleGranularityN = 128");
static_assert(ScaleGranularityK == 128, "SM120/SM121 only supports ScaleGranularityK = 128");
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
using namespace cute;

Expand Down
8 changes: 4 additions & 4 deletions include/flashinfer/gemm/group_gemm_fp8_groupwise_sm120.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ cudaError_t CutlassFP8GroupwiseScaledGroupGEMMSM120(
void* int_buffer, size_t int_buffer_size_in_bytes, void* float_buffer,
size_t float_buffer_size_in_bytes, DTypeIn* A, DTypeIn* B, float* SFA, float* SFB, DTypeOut* D,
int* m_indptr, int max_m, int n, int k, int num_groups, cudaStream_t stream) {
// SM120 only supports these specific scale granularities
// SM120/SM121 only supports these specific scale granularities
static_assert(ScaleGranularityM == 1 || ScaleGranularityM == 128,
"SM120 only supports ScaleGranularityM = 1 or 128");
static_assert(ScaleGranularityN == 128, "SM120 only supports ScaleGranularityN = 128");
static_assert(ScaleGranularityK == 128, "SM120 only supports ScaleGranularityK = 128");
"SM120/SM121 only supports ScaleGranularityM = 1 or 128");
static_assert(ScaleGranularityN == 128, "SM120/SM121 only supports ScaleGranularityN = 128");
static_assert(ScaleGranularityK == 128, "SM120/SM121 only supports ScaleGranularityK = 128");
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>; // <M,N,K> per group

Expand Down
1 change: 1 addition & 0 deletions include/flashinfer/trtllm/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,4 @@ constexpr int32_t kSM_100 = 100;
constexpr int32_t kSM_100f = 10100;
constexpr int32_t kSM_103 = 103;
constexpr int32_t kSM_120 = 120;
constexpr int32_t kSM_121 = 121;
2 changes: 1 addition & 1 deletion tests/attention/test_trtllm_gen_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def trtllm_batch_decode_mla(
compute_capability = get_compute_capability(torch.device(device="cuda"))
if backend == "xqa":
if compute_capability[0] != 12:
pytest.skip("XQA MLA only supports SM120 GPUs")
pytest.skip("XQA MLA only supports SM120/SM121 GPUs")
if q_len_per_request != 1 or dtype != torch.float8_e4m3fn:
pytest.skip(
"XQA MLA only supports q_len_per_request == 1 and dtype == torch.float8_e4m3fn"
Expand Down
4 changes: 2 additions & 2 deletions tests/attention/test_xqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def ref_attention(

@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12],
reason="XQA is only supported on SM90, SM100, SM120 GPUs",
reason="XQA is only supported on SM90, SM100, SM120/SM121 GPUs",
)
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("use_sliding_window", [True, False])
Expand Down Expand Up @@ -467,7 +467,7 @@ def test_xqa(

@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [12],
reason="XQA mla is only supported on SM120 GPUs",
reason="XQA mla is only supported on SM120/SM121 GPUs",
)
@pytest.mark.parametrize("kv_scale", [1.0, 0.5])
@pytest.mark.parametrize("q_scale", [1.0, 0.5])
Expand Down
2 changes: 1 addition & 1 deletion tests/attention/test_xqa_batch_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def generate_causal_mask(

@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12],
reason="XQA is only supported on SM90, SM100, SM120 GPUs",
reason="XQA is only supported on SM90, SM100, SM120/SM121 GPUs",
)
@pytest.mark.parametrize(
"batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
Expand Down
2 changes: 1 addition & 1 deletion tests/attention/test_xqa_mla_batch_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_xqa_mla_batch_decode(
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] != 12:
pytest.skip("These tests are only guaranteed to work on SM120 GPUs.")
pytest.skip("These tests are only guaranteed to work on SM120/SM121 GPUs.")

torch.manual_seed(42)
dtype = torch.float8_e4m3fn
Expand Down
4 changes: 2 additions & 2 deletions tests/moe/test_trtllm_cutlass_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def test_moe_fp8(
)
@pytest.mark.skipif(
torch.cuda.get_device_capability()[0] not in [10, 11, 12],
reason="NVFP4 is only supported on SM100, SM110 and SM120",
reason="NVFP4 is only supported on SM100, SM110 and SM120/SM121",
)
def test_moe_nvfp4(
batch_size,
Expand Down Expand Up @@ -1206,7 +1206,7 @@ def dequant_mxfp4_batches(
)
@pytest.mark.skipif(
torch.cuda.get_device_capability()[0] not in [10, 11, 12],
reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120",
reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120/SM121",
)
def test_moe_mxfp8_mxfp4(
batch_size,
Expand Down
Loading