diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 98082b447f..7eaeed26ed 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -378,8 +378,8 @@ def dtype_str_to_torch_dtype(dtype_str): "8.6": [], "8.9": [], "9.0": [], - "10.0": ["cutlass", "cute-dsl"], - "10.3": ["cutlass", "cute-dsl"], + "10.0": ["cutlass", "cute-dsl", "trtllm"], + "10.3": ["cutlass", "cute-dsl", "trtllm"], "11.0": ["cutlass"], "12.0": [], }, diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 401f369530..41ba7b3c18 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -1308,7 +1308,12 @@ def testMmMxfp8(args): res_dtype = args.out_dtype is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - autotune_supported_backends = ["cutlass", "cute-dsl", "auto"] + autotune_supported_backends = [ + "cutlass", + "cute-dsl", + "trtllm", + "auto", + ] res = [] backends = filter_backends_by_compute_capability(backends, args.routine, device) @@ -1336,42 +1341,73 @@ def testMmMxfp8(args): print("[ERROR] No backends to test. Exiting.") return res - ## Prepare input tensors - # Use swizzled layout for optimal performance - is_sf_swizzled_layout = True - + inputs = {} input = torch.randn([m, k], device=device, dtype=torch.bfloat16) - input_mxfp8, input_scale = mxfp8_quantize( - input, is_sf_swizzled_layout=is_sf_swizzled_layout - ) - mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16) - mat2_mxfp8, mat2_scale = mxfp8_quantize( - mat2, is_sf_swizzled_layout=is_sf_swizzled_layout - ) + for backend in backends: + ## Prepare input tensors + # Use swizzled layout for optimal performance + is_sf_swizzled_layout = backend in ["cutlass", "trtllm"] + + if not is_sf_swizzled_layout: + sf_layout_input = flashinfer.SfLayout.layout_linear + elif backend == "cutlass" or args.use_128x4_sf_layout: + sf_layout_input = flashinfer.SfLayout.layout_128x4 + elif backend == "trtllm": + if not args.use_128x4_sf_layout: + sf_layout_input = flashinfer.SfLayout.layout_8x4 + else: + sf_layout_input = flashinfer.SfLayout.layout_128x4 + input_mxfp8, input_scale = mxfp8_quantize( + input, sf_swizzle_layout=sf_layout_input + ) + # when using trtllm, the shuffle_matrix_sf_a will swizzle the layout. + mat2_mxfp8, mat2_scale = mxfp8_quantize( + mat2, + is_sf_swizzled_layout=False + if backend == "trtllm" + else is_sf_swizzled_layout, + ) - if args.verbose >= 2: - print(f"[VVERBOSE] {input_mxfp8.shape = }") - print(f"[VVERBOSE] {input_mxfp8.dtype = }") - print(f"[VVERBOSE] {mat2_mxfp8.shape = }") - print(f"[VVERBOSE] {mat2_mxfp8.dtype = }") - print(f"[VVERBOSE] {input_scale.shape = }") - print(f"[VVERBOSE] {input_scale.dtype = }") - print(f"[VVERBOSE] {mat2_scale.shape = }") - print(f"[VVERBOSE] {mat2_scale.dtype = }") + if backend == "trtllm": + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale): - if backend in ["cutlass", "cute-dsl", "auto"]: - return flashinfer.gemm.mm_mxfp8( - a=input_mxfp8, - b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t() - a_descale=input_scale, - b_descale=mat2_scale, # mm_mxfp8 handles swizzled 1D internally - out_dtype=res_dtype, - backend=backend, + mat2_mxfp8 = shuffle_matrix_a(mat2_mxfp8, 128).reshape(n, k) + mat2_scale = shuffle_matrix_sf_a( + mat2_scale.reshape(n, k // 32), + 128, + num_elts_per_sf=32, ) - else: - raise ValueError(f"Unsupported backend: {backend}") + mat2_scale = mat2_scale.t() + + if args.verbose >= 2: + print(f"[VERBOSE] {backend}: {input_mxfp8.shape = }") + print(f"[VERBOSE] {backend}: {input_mxfp8.dtype = }") + print(f"[VERBOSE] {backend}: {mat2_mxfp8.shape = }") + print(f"[VERBOSE] {backend}: {mat2_mxfp8.dtype = }") + print(f"[VERBOSE] {backend}: {input_scale.shape = }") + print(f"[VERBOSE] {backend}: {input_scale.dtype = }") + print(f"[VERBOSE] {backend}: {mat2_scale.shape = }") + print(f"[VERBOSE] {backend}: {mat2_scale.dtype = }") + inputs[backend] = (input_mxfp8, mat2_mxfp8, input_scale, mat2_scale) + + def run_backend( + backend: str, + inputs: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + assert backend in ["cutlass", "trtllm", "cute-dsl", "auto"], ( + f"Unsupported backend: {backend}" + ) + input_mxfp8, mat2_mxfp8, input_scale, mat2_scale = inputs + return flashinfer.gemm.mm_mxfp8( + a=input_mxfp8, + b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t() + a_descale=input_scale, + b_descale=mat2_scale, + out_dtype=res_dtype, + backend=backend, + use_8x4_sf_layout=backend == "trtllm" and not args.use_128x4_sf_layout, + ) has_reference_output = False if run_refcheck: @@ -1391,10 +1427,7 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale): for _ in range(warmup_iters): run_backend( cur_backend, - input_mxfp8, - mat2_mxfp8, - input_scale, - mat2_scale, + inputs[cur_backend], ) elif cache_path: with autotune(False, cache=cache_path): @@ -1406,7 +1439,7 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale): for cur_backend in backends: if run_refcheck: outputs[cur_backend] = run_backend( - cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale + cur_backend, inputs[cur_backend] ).detach() backend_times[cur_backend] = bench_gpu_time( fn=run_backend, @@ -1416,7 +1449,7 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale): enable_cupti=args.use_cupti, use_cuda_graph=is_cuda_graph_compatible, cold_l2_cache=True, - input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale), + input_args=(cur_backend, inputs[cur_backend]), ) # Minimum cosine similarity for swizzled layout diff --git a/benchmarks/routines/quantization.py b/benchmarks/routines/quantization.py index 55bfa33691..46a0ce2822 100644 --- a/benchmarks/routines/quantization.py +++ b/benchmarks/routines/quantization.py @@ -555,7 +555,7 @@ def testNvfp4Quantize(args): Returns: dict: List of dictionaries containing performance results """ - from flashinfer.fp4_quantization import SfLayout + from flashinfer import SfLayout if args.verbose >= 1: print("[INFO] Running testNvfp4Quantize") diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp index 0af0b1f030..f63b8cf18c 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp +++ b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp @@ -27,7 +27,7 @@ // linear layout. See QuantizationSFLayout enum for more details about the two layouts. // returns void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF, - bool isSfSwizzledLayout, int64_t alignment, bool enable_pdl) { + int64_t sfSwizzleLayout, int64_t alignment, bool enable_pdl) { CHECK_CUDA(input); CHECK_CONTIGUOUS(input); @@ -50,8 +50,7 @@ void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); - auto const layout = isSfSwizzledLayout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4 - : tensorrt_llm::QuantizationSFLayout::LINEAR; + auto const layout = static_cast(sfSwizzleLayout); #define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \ tensorrt_llm::kernels::invokeMxFP8Quantization( \ @@ -94,7 +93,7 @@ inline uint8_t float_to_ue8m0(float value) { // Used in tests to quantize mxe4m3 tensors on host. void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView scale_tensor, - bool is_sf_swizzled_layout) { + int64_t sfSwizzleLayout) { int32_t const sf_vec_size = 32; auto fp32_dtype = DLDataType{kDLFloat, 32, 1}; CHECK_INPUT_TYPE(x_fp32, fp32_dtype); @@ -104,9 +103,7 @@ void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView sc int hidden_dim = data_shape[1]; int groups_per_hidden_dim = hidden_dim / sf_vec_size; - tensorrt_llm::QuantizationSFLayout layout = - is_sf_swizzled_layout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4 - : tensorrt_llm::QuantizationSFLayout::LINEAR; + auto const layout = static_cast(sfSwizzleLayout); for (size_t ti = 0; ti < static_cast(data_shape[0]); ++ti) { for (int group = 0; group < groups_per_hidden_dim; ++group) { @@ -141,7 +138,7 @@ void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView sc // Used in tests to dequantize mxe4m3 tensors on host. void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf, - TensorView float_tensor, bool is_sf_swizzled_layout) { + TensorView float_tensor, int64_t sfSwizzleLayout) { int32_t const sf_vec_size = 32; CHECK_INPUT_TYPE(value_e4m3, dl_uint8); CHECK_INPUT_TYPE(scale_ue8m08sf, dl_uint8); @@ -153,9 +150,7 @@ void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf, int hidden_dim = data_shape[1]; int groups_per_hidden_dim = hidden_dim / sf_vec_size; - tensorrt_llm::QuantizationSFLayout layout = - is_sf_swizzled_layout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4 - : tensorrt_llm::QuantizationSFLayout::LINEAR; + auto const layout = static_cast(sfSwizzleLayout); for (size_t ti = 0; ti < static_cast(data_shape[0]); ++ti) { for (int group = 0; group < groups_per_hidden_dim; ++group) { float* float_ptr = diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h index 15d587e8bf..b0404bc06b 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h +++ b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h @@ -67,14 +67,14 @@ inline int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, // alignment: sfVecSize // returns fp8_quantized and block_scale_factors. void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF, - bool is_sf_swizzled_layout, int64_t alignment, bool enable_pdl); + int64_t sfSwizzleLayout, int64_t alignment, bool enable_pdl); // x_fp32: [M, K], fp32_quantized (on the host) // isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in // linear layout. See QuantizationSFLayout enum for more details about the two layouts. // returns fp8_quantized and block_scale_factors (on the host). void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView scale_tensor, - bool is_sf_swizzled_layout = true); + int64_t sfSwizzleLayout = 2); void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf, - TensorView float_tensor, bool is_sf_swizzled_layout = true); + TensorView float_tensor, int64_t sfSwizzleLayout = 2); diff --git a/csrc/trtllm_gemm_runner.cu b/csrc/trtllm_gemm_runner.cu index 7ab52ca06a..5fc0a02eb6 100644 --- a/csrc/trtllm_gemm_runner.cu +++ b/csrc/trtllm_gemm_runner.cu @@ -38,6 +38,7 @@ struct TrtllmGenGemmRunnerOptions { gemm::trtllm::gen::Dtype outputType; bool transposeMmaOutput{false}; gemm::trtllm::gen::SfLayout sfLayoutB; + gemm::gemm::MatrixLayout layoutA{gemm::gemm::MatrixLayout::MajorK}; }; int64_t select_kernel_fp8(int32_t M, int32_t N, int32_t K, @@ -45,7 +46,7 @@ int64_t select_kernel_fp8(int32_t M, int32_t N, int32_t K, static constexpr const char* KERNEL_NAME_HIGH_N_K_RATIO = "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_cga1x1x1_16dp256b_rM_TN_" "transOut_" - "noShflA_dsFp8_schedP2x2x1x3_sm100f"; + "noShflA_dsFp8_schPd2x2x1x3_sm100f"; static constexpr const char* KERNEL_NAME_LOW_N_K_RATIO = "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_rM_TN_" @@ -53,7 +54,7 @@ int64_t select_kernel_fp8(int32_t M, int32_t N, int32_t K, static constexpr const char* KERNEL_NAME_LARGE_N = "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_rM_TN_" - "transOut_noShflA_dsFp8_schedP2x2x1x3_sm100f"; + "transOut_noShflA_dsFp8_schPd2x2x1x3_sm100f"; static constexpr const char* KERNEL_NAME_DEFAULT = "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_cga1x1x1_16dp256b_rM_TN_" @@ -98,14 +99,18 @@ class TrtllmGenGemmRunner { if (options.mDtypeA == mOptions.eltType && options.mDtypeC == mOptions.outputType && options.mTransposeMmaOutput == mOptions.transposeMmaOutput && - options.mSfLayoutB == mOptions.sfLayoutB) { + options.mSfLayoutB == mOptions.sfLayoutB && + options.mLayoutA == mOptions.layoutA) { // FIXME(siyuanf): expose matrix layout to user mPassingConfigIndices.push_back(i); } } - FLASHINFER_CHECK( - mPassingConfigIndices.size() > 0, - "No valid tactic found for the given options (precision, transpose, sf layout)"); + FLASHINFER_CHECK(mPassingConfigIndices.size() > 0, + "No valid tactic found for the given options", + "mDtypeA: ", gemm::trtllm::gen::dtypeToString(mOptions.eltType), + "mDtypeC: ", gemm::trtllm::gen::dtypeToString(mOptions.outputType), + "mTransposeMmaOutput: ", mOptions.transposeMmaOutput, + "mSfLayoutB: ", gemm::trtllm::gen::sfLayoutToString(mOptions.sfLayoutB)); } int64_t getWorkspaceSizeInBytes(int64_t m, int64_t n, int64_t k, int64_t tactic) { @@ -150,6 +155,10 @@ class TrtllmGenGemmRunner { gemmData.mProblemDimensions.mRank = 0; gemmData.mProblemDimensions.mWorldSize = 1; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + // Inputs gemmData.mInputBuffers.mPtrA = mOptions.transposeMmaOutput ? b : a; gemmData.mInputBuffers.mPtrSfA = mOptions.transposeMmaOutput ? bScale : aScale; @@ -202,6 +211,10 @@ class TrtllmGenGemmRunner { gemmData.mProblemDimensions.mRank = 0; gemmData.mProblemDimensions.mWorldSize = 1; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + std::vector sortedIndices = mPassingConfigIndices; std::sort(sortedIndices.begin(), sortedIndices.end(), [&configs](int64_t idx0, int64_t idx1) { auto const& optionsA = configs[idx0].mOptions; @@ -250,14 +263,12 @@ class TrtllmGenGemmRunner { int64_t selectHeuristic(int64_t m, int64_t n, int64_t k) const { if (mOptions.eltType == gemm::trtllm::gen::Dtype::E4m3) { return select_kernel_fp8(m, n, k, gemm::gemm::GemmInterface()); - } else if (mOptions.eltType == gemm::trtllm::gen::Dtype::E2m1) { + } else { auto sortedIndices = getValidTactics(m, n, k); TVM_FFI_ICHECK(!sortedIndices.empty()) << "No valid tactic found"; // the getValidTactics is sorted by priority, so the first one is the best one return sortedIndices[0]; - } else { - TVM_FFI_ICHECK(false) << "Unsupported eltType"; } } @@ -269,9 +280,12 @@ class TrtllmGenGemmRunner { using tvm::ffi::Array; using tvm::ffi::Optional; -void trtllm_gemm(TensorView workspace_buffer, TensorView a, TensorView b, TensorView a_scale, - TensorView b_scale, Optional globalScale, TensorView out, - bool use_8x4_sf_layout, int64_t tactic) { +void trtllm_gemm(int64_t input_dtype_, int64_t output_dtype_, TensorView workspace_buffer, + TensorView a, TensorView b, TensorView a_scale, TensorView b_scale, + Optional globalScale, TensorView out, bool use_8x4_sf_layout, + int64_t tactic) { + auto input_dtype = static_cast(input_dtype_); + auto output_dtype = static_cast(output_dtype_); CHECK_DEVICE(a, b); CHECK_DEVICE(a, out); CHECK_INPUT(a); @@ -302,11 +316,12 @@ void trtllm_gemm(TensorView workspace_buffer, TensorView a, TensorView b, Tensor TVM_FFI_ICHECK(out.size(0) == m && out.size(1) == n) << "Output tensor has wrong dimensions"; auto runner = flashinfer::TrtllmGenGemmRunner(flashinfer::TrtllmGenGemmRunnerOptions{ - .eltType = is_fp8 ? gemm::trtllm::gen::Dtype::E4m3 : gemm::trtllm::gen::Dtype::E2m1, - .outputType = gemm::trtllm::gen::Dtype::Bfloat16, + .eltType = input_dtype, + .outputType = output_dtype, .transposeMmaOutput = true, .sfLayoutB = use_8x4_sf_layout ? gemm::trtllm::gen::SfLayout::R8c4 : gemm::trtllm::gen::SfLayout::R128c4, + .layoutA = gemm::gemm::MatrixLayout::MajorK, // currently only support major k layout }); if (tactic == -1) { @@ -332,27 +347,23 @@ void trtllm_gemm(TensorView workspace_buffer, TensorView a, TensorView b, Tensor } } -enum class Dtype : int64_t { - E2m1 = 0, - E4m3 = 1, - Bfloat16 = 2, -}; - -Array trtllm_gemm_tactics(int64_t m, int64_t n, int64_t k, int64_t input_dtype, - int64_t output_dtype, bool use_8x4_sf_layout) { - TVM_FFI_ICHECK(input_dtype == static_cast(Dtype::E4m3) || - input_dtype == static_cast(Dtype::E2m1)) - << "Unsupported input dtype"; - TVM_FFI_ICHECK_EQ(output_dtype, static_cast(Dtype::Bfloat16)) - << "Unsupported output dtype"; +Array trtllm_gemm_tactics(int64_t m, int64_t n, int64_t k, int64_t input_dtype_, + int64_t output_dtype_, bool use_8x4_sf_layout) { + auto input_dtype = static_cast(input_dtype_); + auto output_dtype = static_cast(output_dtype_); + TVM_FFI_CHECK(input_dtype == gemm::trtllm::gen::Dtype::E4m3 || + input_dtype == gemm::trtllm::gen::Dtype::MxE4m3 || + input_dtype == gemm::trtllm::gen::Dtype::E2m1, + "Unsupported input dtype"); + TVM_FFI_CHECK(output_dtype == gemm::trtllm::gen::Dtype::Bfloat16, "Unsupported output dtype"); auto runner = flashinfer::TrtllmGenGemmRunner(flashinfer::TrtllmGenGemmRunnerOptions{ - .eltType = input_dtype == static_cast(Dtype::E4m3) ? gemm::trtllm::gen::Dtype::E4m3 - : gemm::trtllm::gen::Dtype::E2m1, - .outputType = gemm::trtllm::gen::Dtype::Bfloat16, + .eltType = input_dtype, + .outputType = output_dtype, .transposeMmaOutput = true, .sfLayoutB = use_8x4_sf_layout ? gemm::trtllm::gen::SfLayout::R8c4 : gemm::trtllm::gen::SfLayout::R128c4, + .layoutA = gemm::gemm::MatrixLayout::MajorK, // currently only support major k layout }); return runner.getValidTactics(m, n, k); diff --git a/csrc/trtllm_low_latency_gemm_runner.cu b/csrc/trtllm_low_latency_gemm_runner.cu index 6b47d2f7cc..43135445b7 100644 --- a/csrc/trtllm_low_latency_gemm_runner.cu +++ b/csrc/trtllm_low_latency_gemm_runner.cu @@ -50,9 +50,9 @@ gemm::gemm::GemmData createGemmData(int64_t m, int64_t n, int64_t k) { gemmData.mProblemDimensions.mN = m; gemmData.mProblemDimensions.mK = k; // TODO(jimmyzho) disable until fix trtllm-gen - // gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; - // gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; - // gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; gemmData.mProblemDimensions.mRank = 0; gemmData.mProblemDimensions.mWorldSize = 1; @@ -63,64 +63,40 @@ gemm::gemm::GemmData createGemmData(int64_t m, int64_t n, int64_t k) { * Very rough heuristic for selecting a kernel. Prefer using auto-tuning. */ int64_t select_kernel(int32_t m, int32_t n, int32_t k, const gemm::gemm::GemmInterface& interface) { - static constexpr const char* KERNEL_MMAN_8_TILEK_128_SPLITK_2 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_cga1x1x2_16dp256b_splitK2_BN_" + static constexpr const char* KERNEL_MMAN_8_TILEK_128 = + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_cga1x1x1_16dp256b_rM_BN_" "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_8_TILEK_128_SPLITK_3 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_cga1x1x3_16dp256b_splitK3_BN_" + static constexpr const char* KERNEL_MMAN_8_TILEK_256 = + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_cga1x1x1_16dp256b_rM_BN_" "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_8_TILEK_256_SPLITK_2 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_cga1x1x2_16dp256b_splitK2_BN_" + static constexpr const char* KERNEL_MMAN_16_TILEK_128 = + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x32_m128x64x32_cga1x1x1_16dp256b_rM_BN_" "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_8_TILEK_256_SPLITK_3 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_cga1x1x3_16dp256b_splitK3_BN_" + static constexpr const char* KERNEL_MMAN_16_TILEK_256 = + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x32_m128x64x32_cga1x1x1_16dp256b_rM_BN_" "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_16_TILEK_128_SPLITK_2 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_cga1x1x2_16dp256b_splitK2_BN_" + static constexpr const char* KERNEL_MMAN_32_TILEK_128 = + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_cga1x1x1_16dp256b_rM_BN_" "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_16_TILEK_128_SPLITK_3 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_cga1x1x3_16dp256b_splitK3_BN_" + static constexpr const char* KERNEL_MMAN_32_TILEK_256 = + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_cga1x1x1_16dp256b_rM_BN_" "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_16_TILEK_256_SPLITK_2 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_cga1x1x2_16dp256b_splitK2_BN_" + static constexpr const char* KERNEL_MMAN_64_TILEK_128 = + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_cga1x1x1_16dp256b_rM_BN_" "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_16_TILEK_256_SPLITK_3 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_cga1x1x3_16dp256b_splitK3_BN_" - "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_32_TILEK_128_SPLITK_2 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_cga1x1x2_16dp256b_splitK2_BN_" - "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_32_TILEK_128_SPLITK_3 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_cga1x1x3_16dp256b_splitK3_BN_" - "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_32_TILEK_256_SPLITK_2 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_cga1x1x2_16dp256b_splitK2_BN_" - "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_32_TILEK_256_SPLITK_3 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_cga1x1x3_16dp256b_splitK3_BN_" - "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_64_TILEK_128_SPLITK_2 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x64_m128x64x32_cga1x1x2_16dp256b_splitK2_BN_" - "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_64_TILEK_128_SPLITK_3 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x64_m128x64x32_cga1x1x3_16dp256b_splitK3_BN_" - "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_64_TILEK_256_SPLITK_2 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x64_m128x64x32_cga1x1x2_16dp256b_splitK2_BN_" - "transOut_schedS_sm100f"; - static constexpr const char* KERNEL_MMAN_64_TILEK_256_SPLITK_3 = - "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x64_m128x64x32_cga1x1x3_16dp256b_splitK3_BN_" + static constexpr const char* KERNEL_MMAN_64_TILEK_256 = + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_cga1x1x1_16dp256b_rM_BN_" "transOut_schedS_sm100f"; std::string kernel_name; if (m <= 8) { - kernel_name = KERNEL_MMAN_8_TILEK_128_SPLITK_2; + kernel_name = KERNEL_MMAN_8_TILEK_128; } else if (m <= 16) { - kernel_name = KERNEL_MMAN_16_TILEK_128_SPLITK_2; + kernel_name = KERNEL_MMAN_16_TILEK_128; } else if (m <= 32) { - kernel_name = KERNEL_MMAN_32_TILEK_128_SPLITK_2; + kernel_name = KERNEL_MMAN_32_TILEK_128; } else { - kernel_name = KERNEL_MMAN_64_TILEK_128_SPLITK_2; + kernel_name = KERNEL_MMAN_64_TILEK_128; } auto const& configs = interface.getGemmConfigs(); @@ -170,7 +146,7 @@ class TrtllmLowLatencyGemmRunner { configOptions.mDtypeC == mOptions.outputType && configOptions.mTransposeMmaOutput == true && configOptions.mLayoutA == gemm::gemm::MatrixLayout::BlockMajorK && - configOptions.mUseShuffledMatrixA) { + configOptions.mUseShuffledMatrix) { mPassingConfigIndices.push_back(i); } } diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 6ac32efe69..247218d748 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -60,7 +60,6 @@ from .decode import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache from .quantization.fp4_quantization import ( - SfLayout, block_scale_interleave, nvfp4_block_scale_interleave, e2m1_and_ufp8sf_scale_to_float, @@ -79,8 +78,6 @@ ) from .quantization.fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize from .fused_moe import ( - ActivationType, - RoutingMethodType, cutlass_fused_moe, reorder_rows_for_gated_act_gemm, trtllm_bf16_moe, @@ -167,6 +164,7 @@ from .sampling import top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs from .sampling import top_p_renorm_probs as top_p_renorm_probs from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs +from .tllm_enums import SfLayout, ActivationType, RoutingMethodType from . import topk as topk from .topk import top_k as top_k from .topk import top_k_page_table_transform as top_k_page_table_transform diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 99a51039bc..b42cfeeb1a 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -140,7 +140,7 @@ class ArtifactPath: "b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/" ) TRTLLM_GEN_GEMM: str = ( - "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3/" + "b117d5a6b2dd2228aa966a938eac398cf336d8c0/gemm-b3c1646-1fddea2/" ) CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/" # For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py @@ -162,7 +162,7 @@ class CheckSumHash: ) DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( - "15cb8c85dfb5eddd4f121d64cb5a718321fb55b85aa19df10ddc1329d4a726b9" + "18262161e624f7da9d2d04c528c645a5ff7f5efd774024a0b2eb92748ab18bb9" ) map_checksums: dict[str, str] = { safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA, diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index d6e8f0bf1f..09e8d8eb5f 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -10,7 +10,6 @@ # Re-export everything from the new location from .quantization.fp4_quantization import ( - SfLayout, block_scale_interleave, nvfp4_block_scale_interleave, e2m1_and_ufp8sf_scale_to_float, @@ -40,7 +39,6 @@ ) __all__ = [ - "SfLayout", "block_scale_interleave", "nvfp4_block_scale_interleave", "e2m1_and_ufp8sf_scale_to_float", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 7e0760e7b2..2b52c63947 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -15,7 +15,6 @@ """ import functools -from enum import IntEnum from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -54,143 +53,7 @@ get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, ) - - -# The type of method in top-K routing, for use in torch custom op -# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h -class RoutingMethodType(IntEnum): - # Default: Softmax -> TopK - Default = (0,) - # Renormalize: TopK -> Softmax - Renormalize = (1,) - # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the Top4 groups - DeepSeekV3 = (2,) - # Llama4: Top1 -> Sigmoid - Llama4 = (3,) - # Qwen3: Softmax -> TopK -> Renormalize - RenormalizeNaive = (4,) - # TopK only (no softmax) - TopK = (5,) - # Unspecified - Unspecified = 6 - - -# Copied from csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h -class ActivationType(IntEnum): - Gelu = 0 - Relu = 1 - Silu = 2 - Swiglu = 3 - Geglu = 4 - SwigluBias = 5 - Relu2 = 6 - Identity = 7 - InvalidType = 8 - - -class DtypeTrtllmGen(IntEnum): - def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid): - value = ( - (block_format_bit << 24) - | (signed_bit << 20) - | (integer_bit << 16) - | (num_bits << 8) - | uid - ) - obj = int.__new__(cls, value) - obj._value_ = value - return obj - - # keep the values in sync with include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h - Bfloat16 = (0, 1, 0, 16, 0) - Bool = (0, 0, 1, 1, 1) - E2m1 = (1, 1, 0, 4, 2) - E2m3 = (1, 1, 0, 6, 3) - E3m2 = (1, 1, 0, 6, 4) - E4m3 = (0, 1, 0, 8, 5) - E5m2 = (0, 1, 0, 8, 6) - Fp16 = (0, 1, 0, 16, 7) - Fp32 = (0, 1, 0, 32, 8) - Int8 = (0, 1, 1, 8, 9) - Int32 = (0, 1, 1, 32, 10) - Int64 = (0, 1, 1, 64, 11) - MxE2m1 = (1, 1, 0, 4, 12) - MxE4m3 = (1, 1, 0, 8, 13) - MxInt4 = (1, 1, 1, 4, 14) - UE8m0 = (0, 0, 0, 8, 15) - UInt8 = (0, 0, 1, 8, 16) - UInt16 = (0, 0, 1, 16, 17) - UInt32 = (0, 0, 1, 32, 18) - UInt64 = (0, 0, 1, 64, 19) - UInt128 = (0, 0, 1, 128, 20) - Void = (0, 1, 0, 0, 21) - - -def trtllm_gen_dtype_has_scale(dtype: DtypeTrtllmGen) -> bool: - if dtype in [ - DtypeTrtllmGen.MxE4m3, - DtypeTrtllmGen.E2m1, - DtypeTrtllmGen.MxE2m1, - DtypeTrtllmGen.MxE4m3, - DtypeTrtllmGen.MxInt4, - ]: - return True - else: - return False - - -def deduce_trtllm_gen_tensor_dtype( - x: torch.Tensor, scale: Optional[torch.Tensor] -) -> DtypeTrtllmGen: - hidden_size = x.shape[-1] - if x.dtype == torch.uint8: # FIXME(siyuan): use torch.float4_e2m1x2 after torch 2.8 - hidden_size *= 2 - if x.dtype == torch.bfloat16: - dtype = DtypeTrtllmGen.Bfloat16 - elif x.dtype == torch.float8_e4m3fn: - dtype = DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3 - elif ( - x.dtype == torch.uint8 - ): # FIXME(siyuan): use torch.float4_e2m1x2 after torch 2.8 - assert scale is not None, "Scale tensor must be provided for float4x2 input" - if scale.shape[-1] == hidden_size // 16: - dtype = DtypeTrtllmGen.E2m1 - else: - dtype = DtypeTrtllmGen.MxE2m1 - else: - raise ValueError("Unsupported trtllm-gen input tensor.") - return dtype - - -# See MatrixLayout from include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h -class WeightLayout(IntEnum): - # K-major layout (default). [Mn, K] - MajorK = 0 - # M-major for A and N-major for B. [K, Mn] - MajorMn = 1 - # Layout is blocked along the K dimension. [K / blockK, Mn, blockK] - # where blockK is fixed at 128B - BlockMajorK = 2 - - -# The type of gated activation function -# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h -class GatedActType(IntEnum): - # SwiGlu - SwiGlu = 0 - # GeGlu - GeGlu = 1 - - -# The type of FP8 quantization -# Please keep this in sync with the counterpart defined in trtllm_fused_moe_kernel_launcher.cu -class Fp8QuantizationType(IntEnum): - # No FP8 quantization - NoneFp8 = 0 - # DeepSeek FP8 - DeepSeekFp8 = 1 - # MxFp8 x MxFp8 - MxFp8 = 2 +from ..tllm_enums import * @functools.cache diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 30ef414daf..4fa980a971 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -59,6 +59,7 @@ from ..jit.gemm import gen_deepgemm_sm100_module from ..jit.cpp_ext import get_cuda_version from ..jit.gemm import gen_fp8_blockscale_gemm_sm90_module +from ..tllm_enums import DtypeTrtllmGen, SfLayout CUDNN_AVAILABLE = False @@ -773,14 +774,6 @@ def _pad_to_multiple(x, multiple): ) -@functools.cache -def get_trtllm_gemm_module(): - mod = gen_trtllm_gen_gemm_module() - op = mod.build_and_load() - setup_cubin_loader(mod.get_library_path()) - return op - - @functools.cache def get_gemm_sm100_module_cutlass_fp8(): module = gen_gemm_sm100_module_cutlass_fp8().build_and_load() @@ -2536,7 +2529,8 @@ def _check_mm_mxfp8_problem_size( b_descale: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass", "cute-dsl", "auto"] = "auto", # unused + use_8x4_sf_layout: bool = True, + backend: Literal["cutlass", "cute-dsl", "trtllm", "auto"] = "auto", # unused ) -> bool: # Generic checks ## pre-check the input tensors and block scale tensors @@ -2574,9 +2568,13 @@ def _check_mm_mxfp8_problem_size( # MXFP8 block size sf_vec_size = 32 + if a_descale.ndim == 2: + sf_layout = SfLayout.layout_linear + else: + sf_layout = SfLayout.layout_8x4 if use_8x4_sf_layout else SfLayout.layout_128x4 if a_descale.ndim == 1: - expected_len = _mxfp8_swizzled_scale_len(a.shape[0], a.shape[1]) + expected_len = _mxfp8_swizzled_scale_len(a.shape[0], a.shape[1], sf_layout) if a_descale.shape[0] != expected_len: raise ValueError( "a_descale shape mismatch for swizzled layout. " @@ -2600,7 +2598,7 @@ def _check_mm_mxfp8_problem_size( ) if b_descale.ndim == 1: - expected_len = _mxfp8_swizzled_scale_len(b.shape[1], b.shape[0]) + expected_len = _mxfp8_swizzled_scale_len(b.shape[1], b.shape[0], sf_layout) if b_descale.shape[0] != expected_len: raise ValueError( "b_descale shape mismatch for swizzled layout. " @@ -2650,8 +2648,30 @@ def _cutlass_gemm_mxfp8_requirement( b_descale: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass", "cute-dsl", "auto"] = "auto", + use_8x4_sf_layout: bool = True, + backend: Literal["cutlass", "cute-dsl", "trtllm", "auto"] = "auto", +): + return True + + +@supported_compute_capability([100, 103]) +def _trtllm_gemm_mxfp8_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + use_8x4_sf_layout: bool = True, + backend: Literal["trtllm", "auto"] = "auto", ): + if out_dtype != torch.bfloat16: + return False + if a.ndim != 2 or b.ndim != 2: # currently don't support BlockMajorK layout + return False + k, n = b.shape + if k % 256 != 0: + return False return True @@ -2663,6 +2683,7 @@ def _cute_dsl_gemm_mxfp8_requirement( b_descale: torch.Tensor, out: Optional[torch.Tensor] = None, # unused out_dtype: torch.dtype = torch.bfloat16, # unused + use_8x4_sf_layout: bool = True, # unused backend: Literal["cutlass", "cute-dsl", "auto"] = "auto", # unused ): # CuTe DSL MXFP8 path currently expects swizzled 1D block scales @@ -3068,8 +3089,10 @@ def _heuristic_func_mm_mxfp8( b_descale: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass", "cute-dsl", "auto"] = "auto", + use_8x4_sf_layout: bool = True, + backend: Literal["cutlass", "cute-dsl", "trtllm", "auto"] = "auto", ) -> List[str]: + # don't select trtllm since it requires weight shuffling if "cutlass" in suitable_backends: return ["cutlass"] return [] @@ -3078,6 +3101,7 @@ def _heuristic_func_mm_mxfp8( @backend_requirement( { "cutlass": _cutlass_gemm_mxfp8_requirement, + "trtllm": _trtllm_gemm_mxfp8_requirement, "cute-dsl": _cute_dsl_gemm_mxfp8_requirement, }, common_check=_check_mm_mxfp8_problem_size, @@ -3091,7 +3115,8 @@ def mm_mxfp8( b_descale: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass", "cute-dsl", "auto"] = "auto", + use_8x4_sf_layout: bool = False, + backend: Literal["cutlass", "cute-dsl", "trtllm", "auto"] = "auto", ) -> torch.Tensor: r"""MM MXFP8 (block size 32) @@ -3106,7 +3131,8 @@ def mm_mxfp8( a_descale: torch.Tensor Block scale tensor for A. Can be: - 2D non-swizzled: shape (m, k // 32) - - 1D swizzled: shape (M_padded * K_padded,) where M_padded = round_up(m, 128), K_padded = round_up(k // 32, 4) + - 1D swizzled: shape (M_padded * K_padded,) + where M_padded = round_up(m, 8 if 8x4 layout else 128), K_padded = round_up(k // 32, 4) dtype: uint8. b_descale: torch.Tensor @@ -3123,11 +3149,16 @@ def mm_mxfp8( out_dtype: torch.dtype Output dtype, bf16 or fp16. Defaults to ``torch.bfloat16``. - backend: Literal["cutlass", "cute-dsl", "auto"] + use_8x4_sf_layout: bool + Whether the scale tensors for a are in 8x4 layout (vs 128x4). + + backend: Literal["cutlass", "cute-dsl", "trtllm", "auto"] The backend to use for the operation. Defaults to ``"auto"``. ``"auto"`` selects the CUTLASS backend. - The ``"cute-dsl"`` backend currently requires swizzled 1D scales - (``mxfp8_quantize(..., is_sf_swizzled_layout=True)``). + - The ``"cute-dsl"`` backend currently requires swizzled 1D scales + (``mxfp8_quantize(..., is_sf_swizzled_layout=True)``). + - The ``"trtllm"`` requires b to be quantized with 128x4 swizzle layout and shuffled. + a can be quantized with either 128x4 or 8x4 layout (controlled by `use_8x4_sf_layout`). Returns ------- @@ -3204,6 +3235,9 @@ def mm_mxfp8( "cutlass": lambda: get_cutlass_mxfp8_gemm_module( major ).cutlass_mxfp8_gemm_runner(), + "trtllm": lambda: get_trtllm_gemm_module().trtllm_mxfp8_gemm_runner( + use_8x4_sf_layout + ), "cute-dsl": lambda: _cute_dsl_gemm_mxfp8_runner(major, minor, True, out_dtype), } @@ -3910,11 +3944,20 @@ def _pad_up(x, y): return ((x + y - 1) // y) * y -def _mxfp8_swizzled_scale_len(m: int, k: int) -> int: - """Return the 1D swizzled scale length for MXFP8 (F8_128x4 layout).""" - m_padded = _pad_up(m, 128) - num_k_tiles = _pad_up(k, 128) // 128 - return m_padded * num_k_tiles * 4 +def _mxfp8_swizzled_scale_len(m: int, k: int, swizzle_layout: SfLayout) -> int: + """Return the 1D swizzled scale length for MXFP8.""" + if swizzle_layout == SfLayout.layout_128x4: + m_padded = _pad_up(m, 128) + num_k_tiles = _pad_up(k, 128) // 128 + return m_padded * num_k_tiles * 4 + elif swizzle_layout == SfLayout.layout_8x4: + m_padded = _pad_up(m, 8) + num_k_tiles = _pad_up(k, 128) // 128 + return m_padded * num_k_tiles * 4 + elif swizzle_layout == SfLayout.layout_linear: + return m * k + else: + raise ValueError(f"Unsupported swizzle layout: {swizzle_layout}") _MM_FP4_TUNING_CONFIG_8x4 = TuningConfig( @@ -3979,7 +4022,9 @@ def _mxfp8_swizzled_scale_len(m: int, k: int) -> int: 2, # a_descale_tensor_index 0, lambda shapes: ( - _mxfp8_swizzled_scale_len(shapes[0][0], shapes[0][1]) + _mxfp8_swizzled_scale_len( + shapes[0][0], shapes[0][1], SfLayout.layout_128x4 + ) if len(shapes[2]) == 1 else shapes[0][0] ), @@ -4116,7 +4161,7 @@ def mm_fp4( backend_to_runner_factory = { "cudnn": lambda: _cudnn_gemm_fp4_runner(), - "trtllm": lambda: get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner( + "trtllm": lambda: get_trtllm_gemm_module().trtllm_fp4_gemm_runner( use_8x4_sf_layout ), "cutlass": lambda: get_cutlass_fp4_gemm_module( @@ -4547,6 +4592,8 @@ def gemm_fp8_nt_groupwise( elif backend == "trtllm": # mma_sm is ignored get_trtllm_gemm_module().trtllm_gemm( + DtypeTrtllmGen.E4m3, + DtypeTrtllmGen.Bfloat16, workspace_buffer, a, b, @@ -4562,87 +4609,166 @@ def gemm_fp8_nt_groupwise( @functools.cache -def get_trtllm_fp4_gemm_module(): +def get_trtllm_gemm_module(): mod = gen_trtllm_gen_gemm_module() op = mod.build_and_load() setup_cubin_loader(mod.get_library_path()) - def trtllm_fp4_gemm_runner(use_8x4_sf_layout: bool = True): - class TrtllmFp4GemmRunner(TunableRunner): - def __init__(self, use_8x4_sf_layout: bool = True): - self._fp4_gemm_runner = op.trtllm_gemm - self._use_8x4_sf_layout = use_8x4_sf_layout + class TrtllmGemmRunner(TunableRunner): + def __init__( + self, + input_dtype: DtypeTrtllmGen, + output_dtype: DtypeTrtllmGen, + use_8x4_sf_layout: bool = True, + ): + self._gemm_runner = op.trtllm_gemm + self._use_8x4_sf_layout = use_8x4_sf_layout + self._input_dtype = input_dtype + self._output_dtype = output_dtype - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: - a_tensor_index = 1 - b_tensor_index = 2 + def unpack_inputs( + self, + inputs: List[torch.Tensor], + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + ( + a, + b, + a_descale, + b_descale, + alpha, + _, + out, + _, + _, + workspace_buffer, + ) = inputs + return a, b, a_descale, b_descale, alpha, out, workspace_buffer - a = profile.get_opt_shapes()[a_tensor_index] - b = profile.get_opt_shapes()[b_tensor_index] - m = a[0] - n = b[0] + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + a_tensor_index = 0 + b_tensor_index = 1 + + a = profile.get_opt_shapes()[a_tensor_index] + b = profile.get_opt_shapes()[b_tensor_index] + m = a[0] + n = b[1] + assert a[1] == b[0], ( + f"The k dimension is inconsistent between A ({a}) and B ({b})" + ) + if self._input_dtype == DtypeTrtllmGen.E2m1: k = a[1] * 2 - ( - a, - b, - a_descale, - b_descale, - alpha, - _, - out, - _, - _, - workspace_buffer, - ) = inputs - type_e2m1 = 0 - type_bf16 = 2 - return list( - op.trtllm_gemm_tactics( - m, n, k, type_e2m1, type_bf16, self._use_8x4_sf_layout - ) + else: + k = a[1] + ( + a, + b, + a_descale, + b_descale, + alpha, + out, + workspace_buffer, + ) = self.unpack_inputs(inputs) + return list( + op.trtllm_gemm_tactics( + m, + n, + k, + self._input_dtype, + self._output_dtype, + self._use_8x4_sf_layout, ) + ) - def forward( + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + a, b, a_descale, b_descale, alpha, out, workspace_buffer = ( + self.unpack_inputs(inputs) + ) + self._gemm_runner( + self._input_dtype, + self._output_dtype, + workspace_buffer, + a, + b.T, + a_descale, + b_descale.T, + alpha, + out, + self._use_8x4_sf_layout, + tactic, + ) + return out + + def trtllm_gemm_runner( + input_dtype: DtypeTrtllmGen, + output_dtype: DtypeTrtllmGen, + use_8x4_sf_layout: bool = True, + ): + return TrtllmGemmRunner(input_dtype, output_dtype, use_8x4_sf_layout) + + def trtllm_fp4_gemm_runner( + use_8x4_sf_layout: bool = True, + ): + return TrtllmGemmRunner( + DtypeTrtllmGen.E2m1, DtypeTrtllmGen.Bfloat16, use_8x4_sf_layout + ) + + def trtllm_mxfp8_gemm_runner( + use_8x4_sf_layout: bool = True, + ): + # monkey patch to align with cutlass runner's input format + class TrtllmMxFp8GemmRunner(TrtllmGemmRunner): + def unpack_inputs( self, inputs: List[torch.Tensor], - tactic: int = -1, - do_preparation: bool = False, - **kwargs, - ): + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: ( a, b, a_descale, b_descale, - alpha, - _, + out_dtype, out, - _, - _, workspace_buffer, ) = inputs - self._fp4_gemm_runner( - workspace_buffer, - a, - b.T, - a_descale, - b_descale.T, - alpha, - out, - self._use_8x4_sf_layout, - tactic, - ) - return out + assert out_dtype == torch.bfloat16 + return a, b, a_descale, b_descale, None, out, workspace_buffer - return TrtllmFp4GemmRunner(use_8x4_sf_layout) + return TrtllmMxFp8GemmRunner( + DtypeTrtllmGen.MxE4m3, DtypeTrtllmGen.Bfloat16, use_8x4_sf_layout + ) # Register the module return SimpleNamespace( + trtllm_gemm_runner=trtllm_gemm_runner, trtllm_fp4_gemm_runner=trtllm_fp4_gemm_runner, + trtllm_mxfp8_gemm_runner=trtllm_mxfp8_gemm_runner, + trtllm_gemm=op.trtllm_gemm, ) diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py index 31f0c5db29..8b8e6da1ea 100644 --- a/flashinfer/quantization/fp4_quantization.py +++ b/flashinfer/quantization/fp4_quantization.py @@ -15,7 +15,6 @@ """ import functools -from enum import Enum from types import SimpleNamespace from typing import List, Optional, Tuple @@ -46,6 +45,7 @@ supported_compute_capability, round_up, ) +from ..tllm_enums import SfLayout def _compute_swizzled_layout_sf_size(total_row, total_column, row_size=128): @@ -825,16 +825,6 @@ def shuffle_matrix_sf_a( return block_scale_interleave(w_shuffled) -class SfLayout(Enum): - """ - Layout of scale factors for NVFP4. - """ - - layout_128x4 = 0 - layout_8x4 = 1 - layout_linear = 2 - - @flashinfer_api def nvfp4_quantize( a, diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py index 997d5d2b5f..f2c9f41249 100644 --- a/flashinfer/quantization/fp8_quantization.py +++ b/flashinfer/quantization/fp8_quantization.py @@ -1,6 +1,6 @@ import functools from types import SimpleNamespace -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import torch @@ -11,6 +11,7 @@ register_custom_op, register_fake_op, ) +from ..tllm_enums import SfLayout def _compute_swizzled_layout_sf_size(total_row, total_column, row_size=128): @@ -29,7 +30,7 @@ def get_mxfp8_quantization_sm100_module(): ) def mxfp8_quantize_sm100( input: torch.Tensor, - is_sf_swizzled_layout: bool = True, + sf_swizzle_layout: SfLayout = SfLayout.layout_linear, alignment: int = 32, enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -37,7 +38,7 @@ def mxfp8_quantize_sm100( Args: input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + sf_swizzle_layout (SfLayout, optional): Swizzle layout for scale factors. Defaults to SfLayout.layout_linear. alignment (int, optional): sfVecSize. Defaults to 32. Note that alignment is not used in the host kernel. enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). If None, automatically detects based on device capability. Defaults to None. @@ -48,18 +49,26 @@ def mxfp8_quantize_sm100( """ if input.device.type == "cpu": out_val = torch.empty(input.shape, dtype=torch.uint8, device=input.device) - if is_sf_swizzled_layout: + if sf_swizzle_layout == SfLayout.layout_128x4: out_sf_size = _compute_swizzled_layout_sf_size( input.shape[0], input.shape[1] // 32, 128 ) - else: + elif sf_swizzle_layout == SfLayout.layout_linear: out_sf_size = input.numel() // 32 + elif sf_swizzle_layout == SfLayout.layout_8x4: + raise ValueError( + f"{sf_swizzle_layout} is not supported for mxfp8 quantization on CPU." + ) + else: + raise ValueError( + f"Invalid sf_swizzle_layout value: {sf_swizzle_layout}" + ) out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.device) module.mxfp8_quantize_host( input, out_val, out_sf, - is_sf_swizzled_layout, + sf_swizzle_layout.value, ) return out_val, out_sf else: @@ -73,16 +82,22 @@ def mxfp8_quantize_sm100( dtype=torch.float8_e4m3fn, device=input.device, ) - if is_sf_swizzled_layout: + if sf_swizzle_layout == SfLayout.layout_128x4: out_sf_size = _compute_swizzled_layout_sf_size(m, padded_k // 32, 128) - else: + elif sf_swizzle_layout == SfLayout.layout_8x4: + out_sf_size = _compute_swizzled_layout_sf_size(m, padded_k // 32, 8) + elif sf_swizzle_layout == SfLayout.layout_linear: out_sf_size = m * padded_k // 32 + else: + raise ValueError( + f"Invalid sf_swizzle_layout value: {sf_swizzle_layout}" + ) out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.device) module.mxfp8_quantize( input, out_val, out_sf, - is_sf_swizzled_layout, + sf_swizzle_layout.value, alignment, enable_pdl, ) @@ -91,7 +106,7 @@ def mxfp8_quantize_sm100( @register_fake_op("flashinfer::mxfp8_quantize_sm100") def _fake_mxfp8_quantize_sm100( input: torch.Tensor, - is_sf_swizzled_layout: bool = True, + sf_swizzle_layout: SfLayout = SfLayout.layout_linear, alignment: int = 32, ) -> Tuple[torch.Tensor, torch.Tensor]: m, k = input.shape @@ -107,14 +122,14 @@ def _fake_mxfp8_quantize_sm100( def mxfp8_dequantize_host_sm100( input: torch.Tensor, scale_tensor: torch.Tensor, - is_sf_swizzled_layout: bool = True, + sf_swizzle_layout: SfLayout = SfLayout.layout_linear, ) -> torch.Tensor: """Dequantize input tensor from MxFP8 format. Args: input (torch.Tensor): Input tensor of shape [M, K] with dtype FLOAT8_E4M3. scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + sf_swizzle_layout (SfLayout, optional): Swizzle layout for scale factors. Defaults to SfLayout.layout_linear. Returns: torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. @@ -124,7 +139,7 @@ def mxfp8_dequantize_host_sm100( input, scale_tensor, out, - is_sf_swizzled_layout, + sf_swizzle_layout.value, ) return out @@ -132,7 +147,7 @@ def mxfp8_dequantize_host_sm100( def _fake_mxfp8_dequantize_host_sm100( input: torch.Tensor, scale_tensor: torch.Tensor, - is_sf_swizzled_layout: bool = True, + sf_swizzle_layout: SfLayout = SfLayout.layout_linear, ) -> torch.Tensor: return input.new_empty([input.shape[0], input.shape[1]], dtype=torch.float32) @@ -149,7 +164,8 @@ def mxfp8_quantize( is_sf_swizzled_layout: bool = True, alignment: int = 32, enable_pdl: Optional[bool] = None, - backend: str = "cuda", + backend: Literal["cuda", "cute-dsl"] = "cuda", + sf_swizzle_layout: Optional[SfLayout] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize input tensor to MxFP8 format. @@ -162,9 +178,12 @@ def mxfp8_quantize( alignment (int, optional): sfVecSize. Defaults to 32. enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). If None, automatically detects based on device capability (SM >= 9.0). Defaults to None. - backend (str, optional): Backend to use for quantization. Options are: + backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are: - "cuda": Use JIT-compiled CUDA kernel (default, stable) - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**) + sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors. + If provided,it overrides is_sf_swizzled_layout. Defaults to None. + The SfLayout.layout_8x4 is only available for 'cuda' backend. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: @@ -183,6 +202,11 @@ def mxfp8_quantize( f"backend must be 'cuda' or 'cute-dsl', got '{backend}'" ) + if sf_swizzle_layout is None: + sf_swizzle_layout = ( + SfLayout.layout_128x4 if is_sf_swizzled_layout else SfLayout.layout_linear + ) + if backend == "cute-dsl": from ..cute_dsl import is_cute_dsl_available @@ -193,6 +217,9 @@ def mxfp8_quantize( ) from .kernels.mxfp8_quantize import mxfp8_quantize_cute_dsl + if sf_swizzle_layout == SfLayout.layout_8x4: + raise ValueError("SfLayout.layout_8x4 is not supported in cute-dsl backend") + return mxfp8_quantize_cute_dsl( input, is_sf_swizzled_layout=is_sf_swizzled_layout, @@ -205,7 +232,7 @@ def mxfp8_quantize( enable_pdl = device_support_pdl(input.device) x_q, sf = get_mxfp8_quantization_sm100_module().mxfp8_quantize_sm100( input, - is_sf_swizzled_layout, + sf_swizzle_layout, alignment, enable_pdl, ) @@ -217,6 +244,7 @@ def mxfp8_dequantize_host( input: torch.Tensor, scale_tensor: torch.Tensor, is_sf_swizzled_layout: bool = True, + sf_swizzle_layout: Optional[SfLayout] = None, ) -> torch.Tensor: """Dequantize input tensor from MxFP8 format. @@ -226,15 +254,22 @@ def mxfp8_dequantize_host( Args: input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors. + If provided,it overrides is_sf_swizzled_layout. Defaults to None. + Available options are 1. SfLayout.layout_128x4; 2. SfLayout.layout_linear. Returns: torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. """ + if sf_swizzle_layout is None: + sf_swizzle_layout = ( + SfLayout.layout_128x4 if is_sf_swizzled_layout else SfLayout.layout_linear + ) return get_mxfp8_quantization_sm100_module().mxfp8_dequantize_host_sm100( input, scale_tensor, - is_sf_swizzled_layout, + sf_swizzle_layout, ) diff --git a/flashinfer/tllm_enums.py b/flashinfer/tllm_enums.py new file mode 100644 index 0000000000..034b1acd11 --- /dev/null +++ b/flashinfer/tllm_enums.py @@ -0,0 +1,150 @@ +from enum import IntEnum +import torch +from typing import Optional + + +# The type of method in top-K routing, for use in torch custom op +# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h +class RoutingMethodType(IntEnum): + # Default: Softmax -> TopK + Default = (0,) + # Renormalize: TopK -> Softmax + Renormalize = (1,) + # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the Top4 groups + DeepSeekV3 = (2,) + # Llama4: Top1 -> Sigmoid + Llama4 = (3,) + # Qwen3: Softmax -> TopK -> Renormalize + RenormalizeNaive = (4,) + # TopK only (no softmax) + TopK = (5,) + # Unspecified + Unspecified = 6 + + +# Copied from csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h +class ActivationType(IntEnum): + Gelu = 0 + Relu = 1 + Silu = 2 + Swiglu = 3 + Geglu = 4 + SwigluBias = 5 + Relu2 = 6 + Identity = 7 + InvalidType = 8 + + +class DtypeTrtllmGen(IntEnum): + def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid): + value = ( + (block_format_bit << 24) + | (signed_bit << 20) + | (integer_bit << 16) + | (num_bits << 8) + | uid + ) + obj = int.__new__(cls, value) + obj._value_ = value + return obj + + # keep the values in sync with include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h + Bfloat16 = (0, 1, 0, 16, 0) + Bool = (0, 0, 1, 1, 1) + E2m1 = (1, 1, 0, 4, 2) + E2m3 = (1, 1, 0, 6, 3) + E3m2 = (1, 1, 0, 6, 4) + E4m3 = (0, 1, 0, 8, 5) + E5m2 = (0, 1, 0, 8, 6) + Fp16 = (0, 1, 0, 16, 7) + Fp32 = (0, 1, 0, 32, 8) + Int8 = (0, 1, 1, 8, 9) + Int32 = (0, 1, 1, 32, 10) + Int64 = (0, 1, 1, 64, 11) + MxE2m1 = (1, 1, 0, 4, 12) + MxE4m3 = (1, 1, 0, 8, 13) + MxInt4 = (1, 1, 1, 4, 14) + UE8m0 = (0, 0, 0, 8, 15) + UInt8 = (0, 0, 1, 8, 16) + UInt16 = (0, 0, 1, 16, 17) + UInt32 = (0, 0, 1, 32, 18) + UInt64 = (0, 0, 1, 64, 19) + UInt128 = (0, 0, 1, 128, 20) + Void = (0, 1, 0, 0, 21) + + +def trtllm_gen_dtype_has_scale(dtype: DtypeTrtllmGen) -> bool: + if dtype in [ + DtypeTrtllmGen.E2m1, + DtypeTrtllmGen.MxE2m1, + DtypeTrtllmGen.MxE4m3, + DtypeTrtllmGen.MxInt4, + ]: + return True + else: + return False + + +def deduce_trtllm_gen_tensor_dtype( + x: torch.Tensor, scale: Optional[torch.Tensor] +) -> DtypeTrtllmGen: + hidden_size = x.shape[-1] + if x.dtype == torch.uint8: # FIXME(siyuan): use torch.float4_e2m1x2 after torch 2.8 + hidden_size *= 2 + if x.dtype == torch.bfloat16: + dtype = DtypeTrtllmGen.Bfloat16 + elif x.dtype == torch.float8_e4m3fn: + dtype = DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3 + elif ( + x.dtype == torch.uint8 + ): # FIXME(siyuan): use torch.float4_e2m1x2 after torch 2.8 + assert scale is not None, "Scale tensor must be provided for float4x2 input" + if scale.shape[-1] == hidden_size // 16: + dtype = DtypeTrtllmGen.E2m1 + else: + dtype = DtypeTrtllmGen.MxE2m1 + else: + raise ValueError("Unsupported trtllm-gen input tensor.") + return dtype + + +# Please keep the values in sync with include/flashinfer/fp4_layout.cuh +class SfLayout(IntEnum): + """ + Layout of scale factors for quantization. + """ + + layout_128x4 = 0 + layout_8x4 = 1 + layout_linear = 2 + + +# See MatrixLayout from include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h +class WeightLayout(IntEnum): + # K-major layout (default). [Mn, K] + MajorK = 0 + # M-major for A and N-major for B. [K, Mn] + MajorMn = 1 + # Layout is blocked along the K dimension. [K / blockK, Mn, blockK] + # where blockK is fixed at 128B + BlockMajorK = 2 + + +# The type of gated activation function +# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h +class GatedActType(IntEnum): + # SwiGlu + SwiGlu = 0 + # GeGlu + GeGlu = 1 + + +# The type of FP8 quantization +# Please keep this in sync with the counterpart defined in trtllm_fused_moe_kernel_launcher.cu +class Fp8QuantizationType(IntEnum): + # No FP8 quantization + NoneFp8 = 0 + # DeepSeek FP8 + DeepSeekFp8 = 1 + # MxFp8 x MxFp8 + MxFp8 = 2 diff --git a/flashinfer/trtllm_low_latency_gemm.py b/flashinfer/trtllm_low_latency_gemm.py index e7beb51343..3aea77affb 100644 --- a/flashinfer/trtllm_low_latency_gemm.py +++ b/flashinfer/trtllm_low_latency_gemm.py @@ -90,8 +90,6 @@ def forward( global_scale, out, ) = inputs - if tactic < 0: - return out m = a.shape[0] n = b.shape[1] k = a.shape[1] diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h index d77b936476..727e964fc2 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,11 +25,15 @@ namespace gemm { enum class AllReduceAlgo : uint32_t { // Does not apply all-reduce. None = 0, - // Reduction occurs at L2 cache; pulls N-1 partial outputs from peer devices. Result is - // non-deterministic. Potentially lower latency at cost of higher memory traffic. + // Reduction occurs at L2 cache; pulls N-1 partial outputs from peer devices. + // Result is + // non-deterministic. Potentially lower latency at cost of higher memory + // traffic. OneShot, - // Reduction occurs at switch; pulls 1/Nth of the output from switch (reduce-scatter phase) and - // store to multicast mem (all-gather phase). Result is deterministic. Lower memory traffic at + // Reduction occurs at switch; pulls 1/Nth of the output from switch + // (reduce-scatter phase) and + // store to multicast mem (all-gather phase). Result is deterministic. Lower + // memory traffic at // cost of potentially higher latency. TwoShot, }; @@ -41,7 +45,8 @@ enum class MatrixLayout { MajorK = 0, // M-major for A and N-major for B. [K, Mn] MajorMn, - // Layout is blocked along the K dimension as seen in the diagram below. [K / blockK, Mn, blockK] + // Layout is blocked along the K dimension as seen in the diagram below. [K / + // blockK, Mn, blockK] // where blockK is fixed at 128B // // ├────────────── K ──────────────┤ @@ -64,10 +69,12 @@ enum class SplitK : uint32_t { // No split-k is needed. I.e. mNumSlicesForSplitK == 1. None = 0, // CTAs computing one MN tile save partial results to global memory. - // Then wait on the barrier and the last CTA in the group loads partial results from gmem, + // Then wait on the barrier and the last CTA in the group loads partial + // results from gmem, // sums them up and writes back to gmem. Gmem, - // All CTAs in one CGA calculate partial sums. Then send the results to the smem of + // All CTAs in one CGA calculate partial sums. Then send the results to the + // smem of // the last CTA in the CGA, which sums them up and writes to gmem. Dsmem, }; @@ -87,12 +94,40 @@ enum class BiasType : uint32_t { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Type of the element-wise activation to apply after the Gemm +enum class EltwiseActType { + None = 0, + // Gelu is defined as the following operation: + // act = x0 * phi(x0) + // where x0 is the output of the Gemm + // phi is the CDF of standard normal distribution approximated by + // phi(x) = 0.5 * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x))) + Gelu, + // Relu2 (also known as squared Relu) is defined as the following operation: + // act = relu(x0) ^ 2 + // where x0 is the output of the Gemm. + Relu2, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + enum class TileScheduler { // Static scheduler (Non-persistent). Static = 0, - // Dynamic persistent scheduler. This is either based on an atomically incremented global work id - // prior to SM100 archs, or the HW supported work id scheduler based on UGETNEXTWORKID for SM100+. + // Dynamic persistent scheduler for SM100+. Persistent, + // Static persistent scheduler. Launches a fixed grid size based on the number + // of SMs and uses + // the underlying PersistentTileSchedulerSm90 for static work distribution. + // Each CTA iterates + // through tiles and exits the loop by setting is_valid_tile to false when + // work is exhausted. + StaticPersistent, + // Dynamic persistent scheduler for SM90+ using atomicAdd on a global counter. + // Uses DynamicPersistentPipelinedTileSchedulerSm90 which enables + // work-stealing among CTAs + // by atomically fetching work tile indices from a global counter. + PersistentSm90, }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -140,6 +175,28 @@ BIAS_TYPE_FUNCTION(Mn) //////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper function to check if a scheduler is persistent. +inline bool isPersistentScheduler(TileScheduler scheduler) { + return scheduler == TileScheduler::Persistent || scheduler == TileScheduler::StaticPersistent || + scheduler == TileScheduler::PersistentSm90; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Helper function to check if CTA rasterization order is compatible with clean +// early exit of the kernel. Clean early exit requires CTA indices to increase +// monotonically along the batch dimension, so when a CTA exits the kernel +// early, it exits with all valid tiles already done. Zigzag or batch-major +// patterns are NOT compatible because they may cause valid tiles to be skipped +// when exiting early. +inline bool supportsCleanEarlyExit(CtaSwizzleType swizzleType, bool batchM, + TileScheduler /* scheduler */) { + return (batchM ? (swizzleType == CtaSwizzleType::RasterizeAlongN) + : (swizzleType == CtaSwizzleType::RasterizeAlongM)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gemm } // namespace gemm diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h index ffea1cf4f4..d9c3e8284c 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,24 +18,19 @@ #include #include +#include #include "GemmOptions.h" #include "KernelParams.h" #include "trtllm/gen/CudaKernelLauncher.h" #ifdef TLLM_GEN_EXPORT_INTERFACE +#ifdef TLLM_GEN_EXPORT_FLASHINFER #include "flashinferMetaInfo.h" -#endif // TLLM_GEN_EXPORT_INTERFACE - -#ifdef TLLM_GEN_GEMM_CUBIN_PATH -static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH); #else -static_assert(false, "TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling"); -#endif - -namespace flashinfer::trtllm_cubin_loader { -std::string getCubin(const std::string& kernelName, const std::string& sha256); -} // namespace flashinfer::trtllm_cubin_loader +#include "KernelMetaInfo.h" +#endif // TLLM_GEN_EXPORT_FLASHINFER +#endif // TLLM_GEN_EXPORT_INTERFACE namespace gemm { @@ -52,13 +47,19 @@ struct GemmData { // The M dimension. // It is the total number of tokens if A is the activation matrix. // It is the total number of output channels if A is the weight matrix. + // ValidM/N/K by default assumes to be full range of M/N/K respectively. If + // we pad M/N/K due to alignment of other constraints, then we can specify + // ValidM/N/K to indicate the valid range. int32_t mM{0}; + int32_t mValidM{0}; // The N dimension. // It is the total number of tokens if B is the activation matrix. // It is the total number of output channels if B is the weight matrix. int32_t mN{0}; + int32_t mValidN{0}; // The K dimension. It is the hidden dimension of the input matrices. int32_t mK{0}; + int32_t mValidK{0}; // The rank id of the current device in the multi-gpu space. int32_t mRank{0}; // The number of devices in tensor-parallel group. @@ -68,11 +69,12 @@ struct GemmData { struct InputBuffers { // The matrix A. The data type is controlled by options.mDtypeA. // - // When layoutA is MatrixLayout::MajorK, the shape is [M, K]. - // When LayoutA is MatrixLayout::MajorMn, the shape is [K, M]. - // When LayoutA is MatrixLayout::BlockMajorK, the shape is [K / blockK, M, blockK] where blockK - // is 128B. - // The rightmost dimension is contiguous in memory. + // If S is the sparsity ratio (1 for dense, 2 for sparse): + // When layoutA is MatrixLayout::MajorK, the shape is [M, K / S]. + // When LayoutA is MatrixLayout::MajorMn, the shape is [K, M] (sparsity not + // supported) When LayoutA is MatrixLayout::BlockMajorK, the shape is [K / S + // / blockK, M, blockK] where blockK is 128B. The rightmost dimension is + // contiguous in memory. void const* mPtrA{nullptr}; // The block scaling factors to dequantize A. @@ -82,14 +84,16 @@ struct GemmData { // Otherwise, shape is [M / 128, K / 128]. // The rightmost dimension is contiguous in memory. // - // If DeepSeek FP8 recipe is not used, but for MxFp{4,8} and NvFp4 formats: + // If DeepSeek FP8 recipe is not used, but for MxFp{4,8}, MxInt4 and NvFp4 + // formats: // The layout of scaling factors for A is always R128c4 // M must be a multiple of 128. // K must be a multiple of 64. - // The "logical" shape is: [M, K / 16]. - // The R128c4 layout is: [M / 128, K / 16 / 4, 512]. - // The shape we use for TMA is: [M / 128, K / 16 / 4, 2, 256]. - // Dtype is Dtype::E4m3. + // The "logical" shape is: [M, K / P], where P is the scaling block size. + // The R128c4 layout is: [M / 128, K / P / 4, 512]. + // The shape we use for TMA is: [M / 128, K / P / 4, 2, 256]. + // Dtype is E4m3 for NvFp4, UE8m0 for MxFp{4,8} formats, Bfloat16 for + // MxInt4. // // Otherwise should be set to nullptr. void const* mPtrSfA{nullptr}; @@ -97,21 +101,39 @@ struct GemmData { // The per-token scaling factors from scale A. // // This is used for either: - // * Per-token scaling factor quantization schemes, such as MetaFP8. The dtype is - // Dtype::Float32 - // * When the routing scales are applied to the input activations (only when output is not - // transposed). The dtype is Dtype::Bfloat16 + // * Per-token scaling factor quantization schemes, such as MetaFP8. The + // dtype is Dtype::Float32 + // * When the routing scales are applied to the input activations (only + // when output is not transposed). The dtype is Dtype::Bfloat16 // // The shape is [M] void const* mPtrPerTokenSfA{nullptr}; + // The sparsity information of A, if structured sparsity is used. + // + // When sparsityA is Any_2_4: + // 2 elements are non-zero in any chunk of 4 elements. + // A 4-bit index indicates the position of the non-zero elements. + // The shape in Uint8 is: [M, K / 8] (two 4-bit indices packed into one + // UInt8) + // + // When sparsityA is Pairwise_4_8: + // 4 elements are non-zero in any chunk of 8 elements. + // The zero and non-zero elements are grouped in pairs. + // A 4-bit index indicates the position of the non-zero pairs. + // The shape in Uint8 is: [M, K / 16] (two 4-bit indices packed into one + // UInt8) + // + // If sparsityA is Dense, this should be set to nullptr. + void const* mPtrSparsityInfoA{nullptr}; + // The matrix B. The data type is controlled by options.mDtypeB. // // When layoutB is MatrixLayout::MajorK, the shape is [N, K]. // When layoutB is MatrixLayout::MajorMn, the shape is [K, N]. - // When layoutB is MatrixLayout::BlockMajorK, the shape is [K / blockK, N, blockK] where blockK - // is 128B. - // The rightmost dimension is contiguous in memory. + // When layoutB is MatrixLayout::BlockMajorK, the shape is [K / blockK, N, + // blockK] where blockK is 128B. The rightmost dimension is contiguous in + // memory. void const* mPtrB{nullptr}; // The scaling factors to dequantize B. @@ -125,17 +147,18 @@ struct GemmData { // If the layout is R128c4, // N must be a multiple of 128. // K must be a multiple of 64. - // The R128c4 layout is: [N / 128, K / 16 / 4, 512] - // The shape we use for TMA is: [N / 128, K / 16 / 4, 2, 256] + // The R128c4 layout is: [N / 128, K / P / 4, 512], where P is the + // scaling block size. The shape we use for TMA is: [N / 128, K / P / 4, + // 2, 256] // // If the layout is R8c4, // N must be a multiple of 8. // K must be a multiple of 64. - // The R8c4 layout is: [N / 8, K / 16 / 4, 32] - // The shape we use for TMA is: [N / 8, K / 16 / 4 / repeats, repeats * 32] - // where repeats = min(tileK / 16 / 4, 8) + // The R8c4 layout is: [N / 8, K / P / 4, 32], where P is the scaling + // block size. The shape we use for TMA is: [N / 8, K / P / 4 / repeats, + // repeats * 32] where repeats = min(tileK / P / 4, 8) // - // Dtype is Dtype::E4m3. + // Dtype is E4m3 for NvFp4, UE8m0 for MxFp{4,8} formats. // // Otherwise should be set to nullptr. void const* mPtrSfB{nullptr}; @@ -143,10 +166,10 @@ struct GemmData { // The per-token scaling factors from scale B. // // This is used for either: - // * Per-token scaling factor quantization schemes, such as MetaFP8. The dtype is - // Dtype::Float32 - // * When the routing scales are applied to the input activations (only when output is - // transposed). The dtype is Dtype::Bfloat16 + // * Per-token scaling factor quantization schemes, such as MetaFP8. The + // dtype is Dtype::Float32 + // * When the routing scales are applied to the input activations (only + // when output is transposed). The dtype is Dtype::Bfloat16 // // The shape is [N] void const* mPtrPerTokenSfB{nullptr}; @@ -155,7 +178,8 @@ struct GemmData { // The bias is applied before applying the global scaling factor. I.e. // C' = (A * B + bias') * scaleC // scaleC = dequantA * dequantB * quantC - // Thus, the bias' = bias / (dequantA * dequantB), where the bias is the original bias. + // Thus, the bias' = bias / (dequantA * dequantB), where the bias is the + // original bias. // // if BiasType is N, the shape is [N]. // The bias is broadcasted along the M dimension. @@ -166,20 +190,26 @@ struct GemmData { // The dtype is float32. void const* mPtrBias{nullptr}; - // The output tensor scaling factor for Fp8 (not DeepSeek FP8) and NvFp4 quantization. - // TensorRT-LLM API requires a scaling factor on the device. + // The output tensor scaling factor for Fp8 (not DeepSeek FP8) and NvFp4 + // quantization. TensorRT-LLM API requires a scaling factor on the device. // scaleC = dequantA * dequantB * quantC, // where dequantA is global dequantization scaling factor of A - // if dtypeA is FP8, it transforms the range from [-448, 448] to [-amaxA, amaxA] - // if dtypeA is NvFp4, it transforms the range from [-448 * 6, 448 * 6] to [-amaxA, amaxA], - // otherwise it is 1. + // if dtypeA is FP8, it transforms the range from [-448, 448] to [-amaxA, + // amaxA] if dtypeA is NvFp4, it transforms the range from [-448 * 6, 448 + // * 6] to [-amaxA, amaxA], otherwise it is 1. // dequantB is defined similarly to dequantA. // quantC is the quantization scaling factor of C. - // if dtypeC is FP8, it transforms the range from [-amaxC, amaxC] to [-448, 448] - // if dtypeC is NvFp4, it transforms the range from [-amaxC, amaxC] to [-448 * 6, 448 * 6], - // otherwise it is 1. + // if dtypeC is FP8, it transforms the range from [-amaxC, amaxC] to + // [-448, 448] if dtypeC is NvFp4, it transforms the range from [-amaxC, + // amaxC] to [-448 * 6, 448 * 6], otherwise it is 1. // Shape is [1]. void* mPtrScaleC{nullptr}; + + // The pre-activation scaling factor (typically dequantA * dequantB) for + // non-linear activation. Only used when non-linear activation is applied + // (e.g., GELU, Relu2). When used, scaleC should be quantScaleC only, and + // this scale is applied before the activation. Shape is [1]. + void* mPtrScaleAct{nullptr}; }; struct OutputBuffers { @@ -190,13 +220,13 @@ struct GemmData { // Elements in a given row are stored contiguously in memory (row-major). void* mPtrC{nullptr}; - // Pointer for output with multicast mapping. It is used by the "reduce" op (LDGMC.ADD) of the - // two-shot reduce-scatter phase. Otherwise, it should be set to nullptr. - // The shape is [M, N] and the dtype is float. + // Pointer for output with multicast mapping. It is used by the "reduce" op + // (LDGMC.ADD) of the two-shot reduce-scatter phase. Otherwise, it should be + // set to nullptr. The shape is [M, N] and the dtype is float. void* mPtrMultiMemC{nullptr}; - // The scaling factors calculated when quantizing C, for MxFp{4,8} and NvFp4 formats, also - // used for the DeepSeek FP8 recipe. + // The scaling factors calculated when quantizing C, for MxFp{4,8} and NvFp4 + // formats, also used for the DeepSeek FP8 recipe. // // For DeepSeek FP8 recipe: // If transposeMmaOutput is false, shape is [N / 128, M]. @@ -204,10 +234,10 @@ struct GemmData { // The rightmost dimension is contiguous in memory. // // For MxFp{4,8} and NvFp4 formats: - // If transposeMmaOutput is false, shape is [M, N / 16]. - // Otherwise, shape is [N, M / 16]. - // The layout is controlled by options.mSfLayoutC (either R128c4 or R8c4). - // The layout (R128c4 and R8c4) is the same as explained in mPtrSfB. + // If transposeMmaOutput is false, shape is [M, N / P], where P is the + // scaling block size. Otherwise, shape is [N, M / P]. The layout is + // controlled by options.mSfLayoutC (either R128c4 or R8c4). The layout + // (R128c4 and R8c4) is the same as explained in mPtrSfB. // // Otherwise should be set to nullptr. void* mPtrSfC{nullptr}; @@ -216,13 +246,14 @@ struct GemmData { struct AllReduceBuffers { // The barriers in global memory. // - // The kernel arrives at (with release ordering) the multicast mapping of the barrier to - // broadcast amongst peer devices. It then waits (with acquire ordering) for the unicast mapping - // of the barrier. + // The kernel arrives at (with release ordering) the multicast mapping of + // the barrier to broadcast amongst peer devices. It then waits (with + // acquire ordering) for the unicast mapping of the barrier. // - // Flags in global memory that sync on "entrance" of reduce-scatter phase in two-shot - // all-reduce. The shape is [numTilesM * numTilesN] and the dtype is uint32_t. The pointer to - // the unicast memory created with IpcNvlsHandle. Must be set to 0 before the kernel launch. + // Flags in global memory that sync on "entrance" of reduce-scatter phase in + // two-shot all-reduce. The shape is [numTilesM * numTilesN] and the dtype + // is uint32_t. The pointer to the unicast memory created with + // IpcNvlsHandle. Must be set to 0 before the kernel launch. void* mPtrTileBars{nullptr}; // The shape is [numTilesM * numTilesN] and the dtype is uint32_t. @@ -256,355 +287,381 @@ class GemmInterface { public: using ModuleCache = std::unordered_map>; - GemmInterface() {} - - // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. - // Provided config must be validated with isValidConfig before the call. - int32_t run(GemmConfig const& config, void* workspace, GemmData const& options, void* cudaStream, - int32_t multiProcessorCount, bool usePdl = true, - std::optional> moduleCache = std::nullopt) const; - - // Initializes the buffers before the world sync. Must be called before run. - int32_t runInitBeforeWorldSync(GemmConfig const& config, GemmData const& data, - void* cudaStream) const; - - // Returns the size of the workspace buffers in bytes - size_t getWorkspaceSizeInBytes(GemmConfig const& config, GemmData const& data) const; - - // Returns the list of all available cubin configurations - GemmConfig const* getGemmConfigs() const; - - // Returns the number of available cubin configurations - size_t getNumGemmConfigs() const; - - // Returns true if the configuration of the cubin can be executed for the given params. - bool isValidConfig(GemmConfig const& config, GemmData const& data) const; + //////////////////////////////////////////////////////////////////////////////////////////////////// - private: - // Aligns the pointer to the alignment - template - inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const; + GemmInterface(int32_t rankId = 0, bool exportsCubin = false, int32_t numRotations = 1) + : mRankId{rankId}, mExportsCubin{exportsCubin}, mNumRotations{numRotations} {} - // Returns the number of tiles and number of CTAs for Z dimension. - std::tuple getGridSize(int32_t M, int32_t N, int32_t tileM, - int32_t tileN, int32_t clusterDimX, - int32_t clusterDimY, - int32_t numSlicesForSplitK) const; + //////////////////////////////////////////////////////////////////////////////////////////////////// - // Creates GemmOptions from kernel and data. - GemmOptions getOptionsFromConfigAndData(GemmConfig const& config, GemmData const& data) const; +#ifndef TLLM_GEN_EXPORT_INTERFACE + // Generates and compiles the kernel using either nvcc or nvrtc. + GemmConfig generateAndCompileKernel(GemmConfig const& gemmConfig) const; +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // Returns the size of the workspace buffers in bytes - std::vector getWorkspaceSizesInBytes(GemmConfig const& config, - GemmData const& data) const; - - // Returns the size padded to the alignment - size_t getSizePaddedToAlignment(size_t size, size_t alignment) const; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline Dtype* GemmInterface::alignPtr(Dtype* ptr, int64_t alignment) const { - assert((alignment & (alignment - 1)) == 0 && "Alignment must be a power of 2"); - return reinterpret_cast((reinterpret_cast(ptr) + alignment - 1) & - ~(alignment - 1)); -} + size_t getWorkspaceSizeInBytes(GemmConfig const& config, GemmData const& data) const { + auto workspaceSizes = getWorkspaceSizesInBytes(config, data); + auto size = std::accumulate(workspaceSizes.begin(), workspaceSizes.end(), 0); + // Additional 1023 bytes to align the pointer to 1024 + return size > 0 ? size + 1023 : 0; + } -//////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// -GemmConfig const* GemmInterface::getGemmConfigs() const { + // Returns the list of all available cubin configurations + GemmConfig const* getGemmConfigs() const { #ifdef TLLM_GEN_EXPORT_INTERFACE - return tensorrt_llm::kernels::tllmGenGemmList; + return tensorrt_llm::kernels::tllmGenGemmList; #else - return nullptr; + return nullptr; #endif -} + } -//////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// -size_t GemmInterface::getNumGemmConfigs() const { + // Returns the number of available cubin configurations + size_t getNumGemmConfigs() const { #ifdef TLLM_GEN_EXPORT_INTERFACE - return sizeof(tensorrt_llm::kernels::tllmGenGemmList) / - sizeof(tensorrt_llm::kernels::tllmGenGemmList[0]); + return sizeof(tensorrt_llm::kernels::tllmGenGemmList) / + sizeof(tensorrt_llm::kernels::tllmGenGemmList[0]); #else - return 0; + return 0; #endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// + } -std::tuple GemmInterface::getGridSize(int32_t M, int32_t N, - int32_t tileM, int32_t tileN, - int32_t clusterDimX, - int32_t clusterDimY, - int32_t numSlicesForSplitK) const { - // The number of tiles in the M dimension. - auto numTilesM = gemm::divUpMul(gemm::divUp(M, tileM), clusterDimX); - // The number of tiles in the N dimension. - auto numTilesN = gemm::divUpMul(gemm::divUp(N, tileN), clusterDimY); - return std::make_tuple(numTilesM, numTilesN, numSlicesForSplitK); -} + //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Creates GemmOptions from kernel and data. + GemmOptions getOptionsFromConfigAndData(GemmConfig const& config, GemmData const& data) const { + // Create options from config and data. + GemmOptions options; + options = config.mOptions; + options.mM = data.mProblemDimensions.mM; + options.mN = data.mProblemDimensions.mN; + options.mK = data.mProblemDimensions.mK; + options.mValidM = data.mProblemDimensions.mValidM; + options.mValidN = data.mProblemDimensions.mValidN; + options.mValidK = data.mProblemDimensions.mValidK; + return options; + } -GemmOptions GemmInterface::getOptionsFromConfigAndData(GemmConfig const& config, - GemmData const& data) const { - // Create options from config and data. - GemmOptions options; - options = config.mOptions; - options.mM = data.mProblemDimensions.mM; - options.mN = data.mProblemDimensions.mN; - options.mK = data.mProblemDimensions.mK; - return options; -} + //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Returns true if the configuration of the cubin can be executed for the + // given params. + bool isValidConfig(GemmConfig const& config, GemmData const& data) const { + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, data); -size_t GemmInterface::getSizePaddedToAlignment(size_t size, size_t alignment) const { - assert((alignment & (alignment - 1)) == 0); - return (size + alignment - 1) & ~(alignment - 1); -} + // Check options without modifications. + return checkAndUpdateGemmOptions(options, config.mSm, data.mProblemDimensions.mWorldSize, + /* updateOptions */ false); + } -//////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// -size_t GemmInterface::getWorkspaceSizeInBytes(GemmConfig const& config, - GemmData const& data) const { - auto workspaceSizes = getWorkspaceSizesInBytes(config, data); - auto size = std::accumulate(workspaceSizes.begin(), workspaceSizes.end(), 0); - // Additional 1023 bytes to align the pointer to 1024 - return size > 0 ? size + 1023 : 0; -} + // If config.mData is specified, it launches the cubin from the provided + // config. Otherwise, it generates and compiles the kernel using either nvcc + // or nvrtc. Launch the cubin from the provided config. It calls all necessary + // memsets for internal buffers. Provided config must be validated with + // isValidConfig before the call. + int32_t run(GemmConfig const& config, void* workspace, GemmData const& data, void* cudaStream, + int32_t multiProcessorCount, bool usePdl = true, + std::optional> moduleCache = std::nullopt) const { + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, data); -//////////////////////////////////////////////////////////////////////////////////////////////////// + auto workspaceSizes = getWorkspaceSizesInBytes(config, data); + void* dSplitKSlices{nullptr}; + void* dPtrSplitKCompletionBars{nullptr}; -std::vector GemmInterface::getWorkspaceSizesInBytes(GemmConfig const& config, - GemmData const& data) const { - // Get options from config. - auto& options = config.mOptions; - - // Get the number of tiles and cluster dimension Z. - auto [numTilesM, numTilesN, gridDimZ] = getGridSize( - data.mProblemDimensions.mM, data.mProblemDimensions.mN, options.mTileM, options.mTileN, - options.mClusterDimX, options.mClusterDimY, options.mNumSlicesForSplitK); - - std::vector workspaceSizes; - - int64_t numBytesSplitK{0}, numBytesSplitKBars{0}; - if (doesSplitKUseGmem(options.mSplitK)) { - // The number of elements for intermediate split-k buffer that contains K slices padded to - // TileM/TileN sizes to avoid OOB accesses during the reduction. - // FIXME: Split-K has excessive memory traffic when combined with slice-K. - // Currently, data for all slice-K slices is sent, even though the slice-K reduction - // has already been performed. - // This should be optimized to send data for only one reduced slice. - auto numEltsSplitK = options.mNumSlicesForSplitK * numTilesM * numTilesN * options.mTileM * - options.mTileN * options.mNumSlicesForSliceK; - - // The number of bytes for intermediate split-k buffer that contains K slices. - numBytesSplitK = numEltsSplitK * tg::dtypeGetNumBits(tg::Dtype::Fp32) / /* bits */ 8; - // The number of bytes for the split-k completion barriers. - numBytesSplitKBars = numTilesM * numTilesN * sizeof(uint32_t); - // Two epilogue warps do GMEM split-k in DS GEMM. - if (options.mUseDeepSeekFp8) { - numBytesSplitKBars *= 2; + // Set the completion barriers to 0 if needed. + if (doesSplitKUseGmem(options.mSplitK)) { + dSplitKSlices = alignPtr(reinterpret_cast(workspace), 1024); + dPtrSplitKCompletionBars = reinterpret_cast( + alignPtr(reinterpret_cast(dSplitKSlices) + workspaceSizes[0], 1024)); + auto err = cudaMemsetAsync((void*)dPtrSplitKCompletionBars, 0x00, workspaceSizes[1], + reinterpret_cast(cudaStream)); + if (err != cudaSuccess) { + return 1; + } } - // TODO: do we need to pad to 1024? - workspaceSizes.push_back(getSizePaddedToAlignment(numBytesSplitK, 1024)); - workspaceSizes.push_back(getSizePaddedToAlignment(numBytesSplitKBars, 1024)); - } + // Determine if the scheduler requires a fixed grid dimension. + bool const isFixedGridDim = (options.mTileScheduler == gemm::TileScheduler::StaticPersistent || + options.mTileScheduler == gemm::TileScheduler::PersistentSm90); + // Get the number of tiles and number of CTAs for Z dimension. + auto [gridDimX, gridDimY, gridDimZ] = + isFixedGridDim + ? getFixedGridSize(options.mClusterDimX, options.mClusterDimY, + options.mNumSlicesForSplitK, multiProcessorCount) + : getGridSize(options.mM, options.mN, options.mTileM, options.mTileN, + options.mClusterDimX, options.mClusterDimY, options.mNumSlicesForSplitK); + + // Create kernel params. + auto kernelParams = gemm::KernelParamsSetup::setKernelParams( + options, data.mInputBuffers.mPtrA, data.mInputBuffers.mPtrSfA, + data.mInputBuffers.mPtrPerTokenSfA, data.mInputBuffers.mPtrB, data.mInputBuffers.mPtrSfB, + data.mInputBuffers.mPtrPerTokenSfB, data.mInputBuffers.mPtrSparsityInfoA, + data.mInputBuffers.mPtrBias, data.mOutputBuffers.mPtrC, data.mOutputBuffers.mPtrSfC, + data.mOutputBuffers.mPtrMultiMemC, (float*)data.mInputBuffers.mPtrScaleC, + (float*)data.mInputBuffers.mPtrScaleAct, dSplitKSlices, data.mAllReduceBuffers.mPtrTileBars, + data.mAllReduceBuffers.mPtrMultiMemTileBars, data.mAllReduceBuffers.mPtrCompletionBars, + data.mAllReduceBuffers.mPtrMultiMemCompletionBars, dPtrSplitKCompletionBars, + /* dPtrNumNonExitingCtas */ nullptr, data.mProblemDimensions.mRank, + data.mProblemDimensions.mWorldSize); + // The size of the grid. + std::vector grid{gridDimX, gridDimY, gridDimZ}; + + // When split-k is enabled and to guarantee the forward progress, we must + // ensure that the number of tiles is less than number of SMs. This way, at + // least one CTA in the grid can make forward. + if (doesSplitKUseGmem(options.mSplitK)) { + if (grid[0] * grid[1] >= multiProcessorCount) { + // The number of MN tiles in Split-K (grid[0] * grid[1]) must be less + // than the number of SMs. + return 2; + } + } - return workspaceSizes; -} + GemmConfig gemmConfig = config; + +#ifndef TLLM_GEN_EXPORT_INTERFACE + // Generate and compile the kernel if data is not provided. + if (config.mData == nullptr) { + gemmConfig = generateAndCompileKernel(gemmConfig); + TLLM_CHECK_ERROR(gemmConfig.mCudaRunner != nullptr, "CudaRunner is not set"); + gemmConfig.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid, + /*cluster*/ {}, + /*instanceId*/ gemmConfig.mInstanceIdx); + return 0; + } +#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Load from cubin if data is provided. + CUmodule cuModule; + CUfunction cuFunction; + + if (moduleCache.has_value()) { + ModuleCache& moduleCacheRef = moduleCache.value().get(); + + // Modules are associated with a specific context, so the context is + // included in the key + CUcontext ctx; + unsigned long long ctxId; + cuCtxGetCurrent(&ctx); + cuCtxGetId(ctx, &ctxId); + + // Reinterpret the ctxId as a string to avoid needing a custom hash or + // converting it to a string in decimal representation. + std::string const ctxName = + std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); + std::string const funcName = std::string(gemmConfig.mFunctionName); + auto const moduleKey = ctxName + funcName; + auto module = moduleCacheRef.find(moduleKey); + + // Use cache if module is found, otherwise load and insert into cache + if (module != moduleCacheRef.end()) { + cuFunction = std::get<1>(module->second); + } else { + loadCubinData(&cuModule, gemmConfig); + cuModuleGetFunction(&cuFunction, cuModule, gemmConfig.mFunctionName); + moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); + } + } else { + loadCubinData(&cuModule, gemmConfig); + cuModuleGetFunction(&cuFunction, cuModule, gemmConfig.mFunctionName); + } -bool GemmInterface::isValidConfig(GemmConfig const& config, GemmData const& data) const { - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, data); + // Prepare the grid/block. + dim3 block3{static_cast(gemmConfig.mNumThreadsPerCTA), static_cast(1), + static_cast(1)}; + dim3 grid3{(grid.size() > 0 ? static_cast(grid[0]) : 1u), + (grid.size() > 1 ? static_cast(grid[1]) : 1u), + (grid.size() > 2 ? static_cast(grid[2]) : 1u)}; + // Prepare the cluster size. + dim3 cluster3{static_cast(options.mClusterDimX), + static_cast(options.mClusterDimY), + static_cast(options.mClusterDimZ)}; + + // Whether PDL can safely be enabled + const bool pdlSafe = gemmConfig.mOptions.mGridWaitForPrimaryEarlyExit || + gemmConfig.mOptions.mGridWaitForPrimaryA || + gemmConfig.mOptions.mGridWaitForPrimaryB; + + // Run the kernel. + auto result = + trtllm::gen::launchKernel((void*)&kernelParams, cudaStream, gemmConfig.mSharedMemSize, + cuFunction, block3, grid3, cluster3, usePdl && pdlSafe); + // If a module cache has not been given, unload the module to avoid leaking + if (!moduleCache.has_value()) { + cuModuleUnload(cuModule); + } + if (result != CUDA_SUCCESS) { + return result; + } - // Is Blackwell? - bool isBlackwell = isSmVersionBlackwell(config.mSm); + return 0; + } - // Check options without modifications. - return checkAndUpdateGemmOptions(options, isBlackwell, data.mProblemDimensions.mWorldSize, - /* updateOptions */ false); -} + //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Initializes the buffers before the world sync. Must be called before run. + int32_t runInitBeforeWorldSync(GemmConfig const& config, GemmData const& data, + void* cudaStream) const { + if (data.mProblemDimensions.mWorldSize > 1) { + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, data); + if (options.mAllReduceAlgo == gemm::AllReduceAlgo::OneShot) { + // The size of each element of C in bits. + int64_t const numBitsPerEltC = options.mAllReduceAlgo == gemm::AllReduceAlgo::TwoShot + ? tg::dtypeGetNumBits(options.mDtypeAcc) + : tg::dtypeGetNumBits(options.mDtypeC); + // The number of bytes for C. + int64_t const numBytesC = data.mProblemDimensions.mM * data.mProblemDimensions.mN * + numBitsPerEltC / + /*bits*/ 8; + // Reset the output buffer as one-shot uses UTMAREDG at multicast memory + // for reduction. + auto err = cudaMemsetAsync(data.mOutputBuffers.mPtrC, 0x00, numBytesC, + reinterpret_cast(cudaStream)); + if (err != cudaSuccess) { + return 1; + } + } -int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData const& data, - void* cudaStream, int32_t multiProcessorCount, bool usePdl, - std::optional> moduleCache) const { - // Might be used. - (void)usePdl; - (void)moduleCache; - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, data); - - auto workspaceSizes = getWorkspaceSizesInBytes(config, data); - void* dSplitKSlices{nullptr}; - void* dPtrSplitKCompletionBars{nullptr}; - - // Set the completion barriers to 0 if needed. - if (doesSplitKUseGmem(options.mSplitK)) { - dSplitKSlices = alignPtr(reinterpret_cast(workspace), 1024); - dPtrSplitKCompletionBars = reinterpret_cast( - alignPtr(reinterpret_cast(dSplitKSlices) + workspaceSizes[0], 1024)); - auto err = cudaMemsetAsync((void*)dPtrSplitKCompletionBars, 0x00, workspaceSizes[1], - reinterpret_cast(cudaStream)); - if (err != cudaSuccess) { - return 1; + // Get the number of tiles and number of CTAs for Z dimension. + auto [numTilesM, numTilesN, gridDimZ] = + getGridSize(options.mM, options.mN, options.mTileM, options.mTileN, options.mClusterDimX, + options.mClusterDimY, options.mNumSlicesForSplitK); + // The number of bytes for the tile barriers. + int32_t numBytesTileBars = numTilesM * numTilesN * sizeof(uint32_t); + // Sanitize system barriers. + auto err = cudaMemsetAsync((void*)data.mAllReduceBuffers.mPtrTileBars, 0x00, numBytesTileBars, + reinterpret_cast(cudaStream)); + if (err != cudaSuccess) { + return 2; + } + err = cudaMemsetAsync((void*)data.mAllReduceBuffers.mPtrCompletionBars, 0x00, + numBytesTileBars, reinterpret_cast(cudaStream)); + if (err != cudaSuccess) { + return 3; + } } + return 0; } - // Get the number of tiles and number of CTAs for Z dimension. - auto [numTilesM, numTilesN, gridDimZ] = - getGridSize(options.mM, options.mN, options.mTileM, options.mTileN, options.mClusterDimX, - options.mClusterDimY, options.mNumSlicesForSplitK); - - // Create kernel params. - auto kernelParams = gemm::KernelParamsSetup::setKernelParams( - options, data.mInputBuffers.mPtrA, data.mInputBuffers.mPtrSfA, - data.mInputBuffers.mPtrPerTokenSfA, data.mInputBuffers.mPtrB, data.mInputBuffers.mPtrSfB, - data.mInputBuffers.mPtrPerTokenSfB, data.mInputBuffers.mPtrBias, data.mOutputBuffers.mPtrC, - data.mOutputBuffers.mPtrSfC, data.mOutputBuffers.mPtrMultiMemC, - (float*)data.mInputBuffers.mPtrScaleC, dSplitKSlices, data.mAllReduceBuffers.mPtrTileBars, - data.mAllReduceBuffers.mPtrMultiMemTileBars, data.mAllReduceBuffers.mPtrCompletionBars, - data.mAllReduceBuffers.mPtrMultiMemCompletionBars, dPtrSplitKCompletionBars, - /* dPtrNumNonExitingCtas */ nullptr, data.mProblemDimensions.mRank, - data.mProblemDimensions.mWorldSize); - // The size of the grid. - std::vector grid{numTilesM, numTilesN, gridDimZ}; - - // When split-k is enabled and to guarantee the forward progress, we must ensure that the number - // of tiles is less than number of SMs. This way, at least one CTA in the grid can make forward. - if (doesSplitKUseGmem(options.mSplitK)) { - if (grid[0] * grid[1] >= multiProcessorCount) { - // The number of MN tiles in Split-K (grid[0] * grid[1]) must be less than the number of SMs. - return 2; - } - } + //////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef TLLM_GEN_EXPORT_INTERFACE - CUmodule cuModule; - CUfunction cuFunction; - - auto fiModuleLoadData = [&](CUmodule* module) { - const std::string sha256 = config.mHash ? config.mHash : ""; - std::string fname_cubin = config.mFunctionName; - if (!fname_cubin.empty()) { - fname_cubin[0] = static_cast(std::toupper(static_cast(fname_cubin[0]))); - } - fname_cubin = tllm_gen_gemm_cubin_path + "/" + fname_cubin + ".cubin"; - std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256); - cuModuleLoadData(&cuModule, cubin.c_str()); - }; + private: + //////////////////////////////////////////////////////////////////////////////////////////////////// - if (moduleCache.has_value()) { - ModuleCache& moduleCacheRef = moduleCache.value().get(); - - // Modules are associated with a specific context, so the context is included in the key - CUcontext ctx; - unsigned long long ctxId; - cuCtxGetCurrent(&ctx); - cuCtxGetId(ctx, &ctxId); - - // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a - // string in decimal representation. - std::string const ctxName = - std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); - std::string const funcName = std::string(config.mFunctionName); - auto const moduleKey = ctxName + funcName; - auto module = moduleCacheRef.find(moduleKey); - - // Use cache if module is found, otherwise load and insert into cache - if (module != moduleCacheRef.end()) { - cuFunction = std::get<1>(module->second); - } else { - fiModuleLoadData(&cuModule); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); - moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); - } - } else { - fiModuleLoadData(&cuModule); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + // Aligns the pointer to the alignment + template + inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const { + assert((alignment & (alignment - 1)) == 0 && "Alignment must be a power of 2"); + return reinterpret_cast((reinterpret_cast(ptr) + alignment - 1) & + ~(alignment - 1)); } - // Prepare the grid/block. - dim3 block3{static_cast(config.mNumThreadsPerCTA), static_cast(1), - static_cast(1)}; - dim3 grid3{(grid.size() > 0 ? static_cast(grid[0]) : 1u), - (grid.size() > 1 ? static_cast(grid[1]) : 1u), - (grid.size() > 2 ? static_cast(grid[2]) : 1u)}; - // Prepare the cluster size. - dim3 cluster3{static_cast(options.mClusterDimX), - static_cast(options.mClusterDimY), - static_cast(options.mClusterDimZ)}; - - // Run the kernel. - auto result = trtllm::gen::launchKernel( - (void*)&kernelParams, cudaStream, config.mSharedMemSize, cuFunction, block3, grid3, cluster3, - usePdl && (config.mOptions.mGridWaitForPrimaryEarlyExit | - config.mOptions.mGridWaitForPrimaryA | config.mOptions.mGridWaitForPrimaryB)); - // If a module cache has not been given, unload the module to avoid leaking - if (!moduleCache.has_value()) { - cuModuleUnload(cuModule); - } - if (result != CUDA_SUCCESS) { - return -1; + //////////////////////////////////////////////////////////////////////////////////////////////////// + + // Returns the number of tiles and number of CTAs for Z dimension. + std::tuple getFixedGridSize(int32_t clusterDimX, int32_t clusterDimY, + int32_t numSlicesForSplitK, + int32_t multiProcessorCount) const { + assert(multiProcessorCount > 0 && + "multiProcessorCount must be provided " + "when using StaticPersistent scheduler"); + // The cluster size spanned in the XY dimension. + auto clusterSizeXy = clusterDimX * clusterDimY; + // The maximum number of CTAs a GPU can run across the XY dimension. + auto numCtasXy = multiProcessorCount / numSlicesForSplitK; + // Round down to the nearest multiple of the cluster size. + numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy; + + return std::make_tuple(numCtasXy, 1, numSlicesForSplitK); } -#else - config.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid); -#endif - return 0; -} + //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Returns the number of tiles and number of CTAs for Z dimension. + std::tuple getGridSize(int32_t M, int32_t N, int32_t tileM, + int32_t tileN, int32_t clusterDimX, + int32_t clusterDimY, + int32_t numSlicesForSplitK) const { + // The number of tiles in the M dimension. + auto numTilesM = gemm::divUpMul(gemm::divUp(M, tileM), clusterDimX); + // The number of tiles in the N dimension. + auto numTilesN = gemm::divUpMul(gemm::divUp(N, tileN), clusterDimY); + return std::make_tuple(numTilesM, numTilesN, numSlicesForSplitK); + } -int32_t GemmInterface::runInitBeforeWorldSync(GemmConfig const& config, GemmData const& data, - void* cudaStream) const { - if (data.mProblemDimensions.mWorldSize > 1) { - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, data); - if (options.mAllReduceAlgo == gemm::AllReduceAlgo::OneShot) { - // The size of each element of C in bits. - int64_t const numBitsPerEltC = options.mAllReduceAlgo == gemm::AllReduceAlgo::TwoShot - ? tg::dtypeGetNumBits(options.mDtypeAcc) - : tg::dtypeGetNumBits(options.mDtypeC); - // The number of bytes for C. - int64_t const numBytesC = - data.mProblemDimensions.mM * data.mProblemDimensions.mN * numBitsPerEltC / /*bits*/ 8; - // Reset the output buffer as one-shot uses UTMAREDG at multicast memory for reduction. - auto err = cudaMemsetAsync(data.mOutputBuffers.mPtrC, 0x00, numBytesC, - reinterpret_cast(cudaStream)); - if (err != cudaSuccess) { - return 1; + //////////////////////////////////////////////////////////////////////////////////////////////////// + + // Returns the size of the workspace buffers in bytes + std::vector getWorkspaceSizesInBytes(GemmConfig const& config, + GemmData const& data) const { + // Get options from config. + auto& options = config.mOptions; + + // Get the number of tiles and cluster dimension Z. + auto [numTilesM, numTilesN, gridDimZ] = getGridSize( + data.mProblemDimensions.mM, data.mProblemDimensions.mN, options.mTileM, options.mTileN, + options.mClusterDimX, options.mClusterDimY, options.mNumSlicesForSplitK); + + std::vector workspaceSizes; + + int64_t numBytesSplitK{0}, numBytesSplitKBars{0}; + if (doesSplitKUseGmem(options.mSplitK)) { + // The number of elements for intermediate split-k buffer that contains K + // slices padded to TileM/TileN sizes to avoid OOB accesses during the + // reduction. + // FIXME: Split-K has excessive memory traffic when combined with slice-K. + // Currently, data for all slice-K slices is sent, even though the slice-K + // reduction has already been performed. This should be optimized to send + // data for only one reduced slice. + auto numEltsSplitK = options.mNumSlicesForSplitK * numTilesM * numTilesN * options.mTileM * + options.mTileN * options.mNumSlicesForSliceK; + + // The number of bytes for intermediate split-k buffer that contains K + // slices. + numBytesSplitK = numEltsSplitK * tg::dtypeGetNumBits(tg::Dtype::Fp32) / /* bits */ 8; + // The number of bytes for the split-k completion barriers. + numBytesSplitKBars = numTilesM * numTilesN * sizeof(uint32_t); + // Two epilogue warps do GMEM split-k in DS GEMM. + if (options.mUseDeepSeekFp8) { + numBytesSplitKBars *= 2; } - } - // Get the number of tiles and number of CTAs for Z dimension. - auto [numTilesM, numTilesN, gridDimZ] = - getGridSize(options.mM, options.mN, options.mTileM, options.mTileN, options.mClusterDimX, - options.mClusterDimY, options.mNumSlicesForSplitK); - // The number of bytes for the tile barriers. - int32_t numBytesTileBars = numTilesM * numTilesN * sizeof(uint32_t); - // Sanitize system barriers. - auto err = cudaMemsetAsync((void*)data.mAllReduceBuffers.mPtrTileBars, 0x00, numBytesTileBars, - reinterpret_cast(cudaStream)); - if (err != cudaSuccess) { - return 2; - } - err = cudaMemsetAsync((void*)data.mAllReduceBuffers.mPtrCompletionBars, 0x00, numBytesTileBars, - reinterpret_cast(cudaStream)); - if (err != cudaSuccess) { - return 3; + // TODO: do we need to pad to 1024? + workspaceSizes.push_back(getSizePaddedToAlignment(numBytesSplitK, 1024)); + workspaceSizes.push_back(getSizePaddedToAlignment(numBytesSplitKBars, 1024)); } + + return workspaceSizes; } - return 0; -} + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + // Returns the size padded to the alignment + size_t getSizePaddedToAlignment(size_t size, size_t alignment) const { + assert((alignment & (alignment - 1)) == 0); + return (size + alignment - 1) & ~(alignment - 1); + } + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + private: + // The rank id of the current device in the multi-gpu space. + int32_t mRankId; + // Whether to export the cubin file. + bool mExportsCubin; + // The number of rotations. + int32_t mNumRotations; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h index cfb9e348cc..a0a621dffd 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,14 +23,23 @@ #include "Enums.h" #include "KernelParams.h" #include "KernelTraits.h" +#include "trtllm/gen/CudaArchDecl.h" #include "trtllm/gen/DtypeDecl.h" #include "trtllm/gen/MmaDecl.h" #include "trtllm/gen/SfLayoutDecl.h" +#include "trtllm/gen/SparsityDecl.h" #ifndef TLLM_GEN_EXPORT_INTERFACE #include "trtllm/gen/CudaRunner.h" #include "trtllm/gen/GenCtx.h" #else +#ifdef TLLM_GEN_EXPORT_FLASHINFER +#include +namespace flashinfer::trtllm_cubin_loader { +std::string getCubin(const std::string& kernelName, const std::string& sha256); +} +#endif // TLLM_GEN_EXPORT_FLASHINFER #include +namespace gemm { template void printArgs(T arg) { @@ -72,7 +81,18 @@ void printArgs(T first, Args... args) { #endif // TLLM_GEN_EXPORT_INTERFACE -namespace gemm { +#define GEMM_UPDATE_OR_ERROR(OPTION, VALUE) \ + if (updateOptions) { \ + OPTION = VALUE; \ + } else \ + return false + +namespace trtllm { +namespace gen { +class CudaRunner; +class GenCfg; +} // namespace gen +} // namespace trtllm namespace gemm { @@ -91,31 +111,34 @@ struct GemmOptions { #endif GemmOptions() = default; - GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, - int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, - tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, - tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, - bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, - int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, - bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, - bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, MatrixLayout layoutA, - MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, - bool mockAllReduce, int n, int numRegsCastAWarps, int numRegsCopySfLdsSttm, + GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, bool clcFastDrain, + int clusterDimX, int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, + tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, + tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, EltwiseActType eltwiseActType, + bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, + int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, + bool fuseUtccpWithUtcmma, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, + bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, + bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, + MatrixLayout layoutA, MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, + int mmaM, int mmaN, bool mockAllReduce, int n, int numEpilogueWarps, + int numRegsCastAWarps, int numRegsCopySfLdsSttm, int numRegsCopySparsityInfo, int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, - bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, - int sfReshapeFactor, bool sliceK, SplitK splitK, int tileK, int tileM, int tileN, - TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, - bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, - bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, - bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, - int worldSize) + bool outputDebugTensors, bool patchF2fp, int32_t sfBlockSizeA, int32_t sfBlockSizeB, + int32_t sfBlockSizeC, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, + tg::SfLayout sfLayoutC, int sfReshapeFactor, bool sliceK, tg::Sparsity sparsityA, + SplitK splitK, int tileK, int tileM, int tileN, TileScheduler tileScheduler, + bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8, + bool useHoistTryWaitForCustomMmaSchedule, bool useMaxTmemOverlap, bool usePerTokenSfA, + bool usePerTokenSfB, bool useShuffledMatrix, bool useTmaStore, + bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int validM, + int validN, int validK, int worldSize) : mAllReduceAlgo{allReduceAlgo}, mBiasType{biasType}, mBlockK(blockK), + mClcFastDrain{clcFastDrain}, mClusterDimX{clusterDimX}, mClusterDimY{clusterDimY}, mClusterDimZ{clusterDimZ}, @@ -126,6 +149,7 @@ struct GemmOptions { mDtypeC{dtypeC}, mDtypeMmaA{dtypeMmaA}, mDtypeMmaB{dtypeMmaB}, + mEltwiseActType{eltwiseActType}, mEnablesEarlyExit{enablesEarlyExit}, mEnablesDelayedEarlyExit{enablesDelayedEarlyExit}, mEnablesGlobalPtxKnobs{enablesGlobalPtxKnobs}, @@ -133,6 +157,7 @@ struct GemmOptions { mEpilogueLdtmBits{epilogueLdtmBits}, mEpilogueTileM{epilogueTileM}, mEpilogueTileN{epilogueTileN}, + mFuseUtccpWithUtcmma{fuseUtccpWithUtcmma}, mGridTriggerSecondaryA{gridTriggerSecondaryA}, mGridTriggerSecondaryB{gridTriggerSecondaryB}, mGridWaitForPrimaryEarlyExit{gridWaitForPrimaryEarlyExit}, @@ -151,8 +176,10 @@ struct GemmOptions { mMmaN{mmaN}, mMockAllReduce{mockAllReduce}, mN{n}, + mNumEpilogueWarps{numEpilogueWarps}, mNumRegsCastAWarps(numRegsCastAWarps), mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm), + mNumRegsCopySparsityInfo(numRegsCopySparsityInfo), mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), mNumSlicesForSplitK{numSlicesForSplitK}, @@ -165,11 +192,14 @@ struct GemmOptions { mOutputDebugTensors{outputDebugTensors}, mPatchF2fp{patchF2fp}, mSfBlockSizeA{sfBlockSizeA}, + mSfBlockSizeB{sfBlockSizeB}, + mSfBlockSizeC{sfBlockSizeC}, mSfLayoutA{sfLayoutA}, mSfLayoutB{sfLayoutB}, mSfLayoutC{sfLayoutC}, mSfReshapeFactor{sfReshapeFactor}, mSliceK{sliceK}, + mSparsityA{sparsityA}, mSplitK{splitK}, mTileK{tileK}, mTileM{tileM}, @@ -179,21 +209,27 @@ struct GemmOptions { mUseCustomMmaSchedule{useCustomMmaSchedule}, mUseDeepSeekFp8{useDeepSeekFp8}, mUseHoistTryWaitForCustomMmaSchedule{useHoistTryWaitForCustomMmaSchedule}, + mUseMaxTmemOverlap{useMaxTmemOverlap}, mUsePerTokenSfA{usePerTokenSfA}, mUsePerTokenSfB{usePerTokenSfB}, - mUseShuffledMatrixA{useShuffledMatrixA}, + mUseShuffledMatrix{useShuffledMatrix}, mUseTmaStore{useTmaStore}, mUseTwoTmaLoadWarps{useTwoTmaLoadWarps}, mUseTwoMmaWarps{useTwoMmaWarps}, mUseUnrollLoop2xForMma{useUnrollLoop2xForMma}, + mValidM{validM}, + mValidN{validN}, + mValidK{validK}, mWorldSize{worldSize} {} - // The all-reduce algorithm. AllReduceAlgo mAllReduceAlgo{AllReduceAlgo::None}; // The type of bias. BiasType mBiasType{BiasType::None}; // Block size in the K dimension int mBlockK{-1}; + // Whether to enable CLC fast drain for early exit in SM100 CLC-based + // scheduler. + bool mClcFastDrain{true}; // Cluster size in X dim. int mClusterDimX{1}; // Cluster size in Y dim. @@ -214,15 +250,18 @@ struct GemmOptions { tg::Dtype mDtypeMmaA{tg::Dtype::Void}; // Data type of the B matrix for the MMA, if different from the input type. tg::Dtype mDtypeMmaB{tg::Dtype::Void}; + // The type of activation. + EltwiseActType mEltwiseActType{EltwiseActType::None}; // Whether to enable early exit. bool mEnablesEarlyExit{false}; // Whether to enable delayed early exit to overlap // numNonExitingCtas loading with the other instructions. bool mEnablesDelayedEarlyExit{false}; - // Whether to enable the global PTX knobs for guiding the compiler optimizations. + // Whether to enable the global PTX knobs for guiding the compiler + // optimizations. bool mEnablesGlobalPtxKnobs{true}; - // The epilogue supports multiple LDTM shapes, although not every shape is applicable in every - // case. In particular: + // The epilogue supports multiple LDTM shapes, although not every shape is + // applicable in every case. In particular: // - On Hopper: must be 16dp256bit. // - Transposed output: must be 16dp256bit. // - Non-transposed output: @@ -236,11 +275,14 @@ struct GemmOptions { int mEpilogueTileM{128}; // Tile size for the epilogue in N dimension. int mEpilogueTileN{32}; + // Whether fuse UTCCP with UTC*MMA. + bool mFuseUtccpWithUtcmma{false}; // Whether load task A triggers the next grid. bool mGridTriggerSecondaryA{false}; // Whether load task B triggers the next grid. bool mGridTriggerSecondaryB{false}; - // Whether the loads that check for an early exit should wait on a grid dependency. + // Whether the loads that check for an early exit should wait on a grid + // dependency. bool mGridWaitForPrimaryEarlyExit{true}; // Whether the load of A should wait on a grid dependency. bool mGridWaitForPrimaryA{true}; @@ -248,7 +290,8 @@ struct GemmOptions { bool mGridWaitForPrimaryB{true}; // Whether to hoist the initialization of the loading tasks. bool mHoistLoadTaskInit{true}; - // Whether to hoist the mbarrier try_waits (e.g., mma.prodAcq, smemAb.consWait) in the MMA task. + // Whether to hoist the mbarrier try_waits (e.g., mma.prodAcq, + // smemAb.consWait) in the MMA task. bool mHoistMmaTaskTryWaits{false}; // The K dimension of GEMM. int mK{16 * 16}; @@ -272,27 +315,32 @@ struct GemmOptions { bool mMockAllReduce{false}; // The N dimension of GEMM. int mN{64 * 4}; + // Number of Epilogue Warps + int mNumEpilogueWarps{4}; // Number of registers for the cast A warps. int mNumRegsCastAWarps{0}; // Number of registers for the LDS+STTM warps. int mNumRegsCopySfLdsSttm{0}; + // Number of registers per thread to copy sparsity info with LDS+STTM. + int mNumRegsCopySparsityInfo{0}; // Number of registers per thread for epilogue warps int mNumRegsPerThreadEpilogueWarp{0}; // Number of registers per thread for non-epilogue warps int mNumRegsPerThreadNonEpilogueWarp{0}; // Number of partitions along the K dimension. When mNumSlicesForSplitK > 1, - // the problem is distributed across several SMs, where each CTA works on its local K slice. - // Partial results are accumulated afterwards using either GMEM or DSMEM (in CGA) - // to exchange the data between CTAs. + // the problem is distributed across several SMs, where each CTA works on its + // local K slice. Partial results are accumulated afterwards using either GMEM + // or DSMEM (in CGA) to exchange the data between CTAs. int mNumSlicesForSplitK{1}; // Number of slices for slice-K along K dimension. int mNumSlicesForSliceK{1}; // The depth of the mainloop pipeline. int mNumStages{2}; - // The depth of the mma pipeline. Equals numStagesMmaWithinWorkTile * numStagesMmaAcrossWorkTile. + // The depth of the mma pipeline. Equals numStagesMmaWithinWorkTile * + // numStagesMmaAcrossWorkTile. int mNumStagesMma{1}; - // The depth of the mma pipeline within work tile. Only GmemC classes with "WithAccInReg" suffix - // are allowed to be greater than 1. + // The depth of the mma pipeline within work tile. Only GmemC classes with + // "WithAccInReg" suffix are allowed to be greater than 1. int mNumStagesMmaWithinWorkTile{-1}; // The depth of the mma pipeline across work tiles in the persistent loop. int mNumStagesMmaAcrossWorkTile{-1}; @@ -302,23 +350,31 @@ struct GemmOptions { bool mOutputDebugTensors{false}; // Patch float conversions. bool mPatchF2fp{false}; - // Block size of A. For dtypeA == E2m1 and dtypeB == E4m3. - std::optional mSfBlockSizeA{std::nullopt}; + // Block size of A, for block-scaled types. + int mSfBlockSizeA{-1}; + // Block size of B, for block-scaled types. + int mSfBlockSizeB{-1}; + // Block size of C, for block-scaled types. + int mSfBlockSizeC{-1}; // Scale factors layout for A. tg::SfLayout mSfLayoutA{tg::SfLayout::R128c4}; // Scale factors layout for B. tg::SfLayout mSfLayoutB{tg::SfLayout::R128c4}; // Scale factors layout for C. tg::SfLayout mSfLayoutC{tg::SfLayout::R128c4}; - // Number of "repeats", i.e. reshaping factor, to fold hidden dimension into SfBlock dimension. - // As result, the hidden dimension of the SF tensor must be a multiple of NumRepeats * - // numEltsPerSf * 4. This reduces the problem shape space that the kernel is able to run. - // But it reduces the number of L2 requests under the hood and potentially improves perf. - // Applies to layout 8x4 only. + // Number of "repeats", i.e. reshaping factor, to fold hidden dimension into + // SfBlock dimension. As result, the hidden dimension of the SF tensor must be + // a multiple of NumRepeats * numEltsPerSf * 4. This reduces the problem shape + // space that the kernel is able to run. But it reduces the number of L2 + // requests under the hood and potentially improves perf. Applies to layout + // 8x4 only. int mSfReshapeFactor{1}; // Slice-K implementation to use TileM dimension for TileK. bool mSliceK{false}; - // The location of the exchange for split-K (it's None when split-K is disabled). + // Sparsity of A. + tg::Sparsity mSparsityA{tg::Sparsity::Dense}; + // The location of the exchange for split-K (it's None when split-K is + // disabled). SplitK mSplitK{SplitK::None}; // K tile dimension of GEMM. int mTileK{16}; @@ -334,16 +390,20 @@ struct GemmOptions { bool mUseCustomMmaSchedule{false}; // Use DeepSeek Fp8. bool mUseDeepSeekFp8{false}; - // The purpose of hoisting trywaits is to opportunistically peek at the availability of the next - // k-block. It benefits when the next k-block is already available and thus sustaining the - // momentum, but it adds latency to the first k-block for smaller k-loop. + // The purpose of hoisting trywaits is to opportunistically peek at the + // availability of the next k-block. It benefits when the next k-block is + // already available and thus sustaining the momentum, but it adds latency to + // the first k-block for smaller k-loop. bool mUseHoistTryWaitForCustomMmaSchedule{false}; + // Whether use the max Tmem overlap trick. + bool mUseMaxTmemOverlap{false}; // Apply per-token scales from A bool mUsePerTokenSfA{false}; // Apply per-token scales from B bool mUsePerTokenSfB{false}; - // Reorder rows/cols in the A matrix for the better memory accesses in the M-major epilogue. - bool mUseShuffledMatrixA{false}; + // Reorder rows/cols in the A matrix (when TransposeMmaOutput is true, + // otherwise B matrix) for the better memory accesses in the M-major epilogue. + bool mUseShuffledMatrix{false}; // Use TMA to store the result. bool mUseTmaStore{true}; // Use two different warps for A and B matrix load. @@ -352,20 +412,23 @@ struct GemmOptions { bool mUseTwoMmaWarps{false}; // Whether to unroll the loop by 2x. bool mUseUnrollLoop2xForMma{true}; + // The valid range of M/N/K dimension of GEMM without padding values. + // Used to opportunistically remove memory traffic from the padding due to + // rigid SF shape constraint or TMA constraint. Such as: + // 1. outputDim % (4 * sfBlockSize) == 0; as 4x SFs are packed into 4 bytes + // 2. MxFp4 x Fp8 mmaType requires bespoke TMA load which requires hiddenDim % + // 128 == 0 + // 3. TMA requires 16B alignment for each row + int mValidM{-1}; + int mValidN{-1}; + int mValidK{-1}; // World size for all-reduce. int mWorldSize{1}; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -enum class SmVersion { Sm90a, Sm100a, Sm100f, Sm103a }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline bool isSmVersionBlackwell(SmVersion smVersion) { - return smVersion == SmVersion::Sm100a || smVersion == SmVersion::Sm100f || - smVersion == SmVersion::Sm103a; -} +using SmVersion = tg::CudaArch; //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -374,23 +437,20 @@ inline bool isSmVersionBlackwell(SmVersion smVersion) { //////////////////////////////////////////////////////////////////////////////////////////////////// struct GemmConfig { - // When TRT-LLM Gen is exported to the other frameworks, the TLLM_GEN_EXPORT_INTERFACE must be - // defined. In this case, the cubins will be loaded from the provided data and function name. - // Otherwise, the kernel will be loaded from the CudaRunner. -#ifdef TLLM_GEN_EXPORT_INTERFACE uint8_t const* mData{nullptr}; - uint32_t const mSize{0}; - uint32_t const mSharedMemSize{0}; + uint32_t mSize{0}; + uint32_t mSharedMemSize{0}; char const* mFunctionName{nullptr}; - uint32_t const mNumThreadsPerCTA{0}; + uint32_t mNumThreadsPerCTA{0}; char const* mHash{nullptr}; -#else + std::string mGenCfgJsonStr{""}; + char const* mExecPath{nullptr}; trtllm::gen::CudaRunner* mCudaRunner{nullptr}; + trtllm::gen::GenCfg* mGenCfg{nullptr}; int32_t mInstanceIdx{0}; -#endif GemmOptions mOptions{}; - SmVersion mSm{SmVersion::Sm100a}; + tg::CudaArch mSm{tg::CudaArch::Sm100a}; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -417,13 +477,36 @@ inline std::string toString(trtllm::gen::MmaKind e) { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline std::string dumpOptions(GemmOptions const& options) { +template <> +inline std::string toString(CtaSwizzleType e) { + switch (e) { + case CtaSwizzleType::RasterizeAlongM: + return "RasterizeAlongM"; + case CtaSwizzleType::RasterizeAlongN: + return "RasterizeAlongN"; + case CtaSwizzleType::ZigZagAlongM2: + return "ZigZagAlongM2"; + case CtaSwizzleType::ZigZagAlongN2: + return "ZigZagAlongN2"; + case CtaSwizzleType::ZigZagAlongM4: + return "ZigZagAlongM4"; + case CtaSwizzleType::ZigZagAlongN4: + return "ZigZagAlongN4"; + default: + return std::to_string(static_cast(e)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::string dumpOptions(GemmOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; ss << "mAllReduceAlgo=" << "gemm::AllReduceAlgo(" << static_cast(options.mAllReduceAlgo) << ")" << "," << std::endl; ss << "mBiasType=" << "gemm::BiasType(" << static_cast(options.mBiasType) << ")" << "," << std::endl; ss << "mBlockK=" << options.mBlockK << "," << std::endl; + ss << "mClcFastDrain=" << options.mClcFastDrain << "," << std::endl; ss << "mClusterDimX=" << options.mClusterDimX << "," << std::endl; ss << "mClusterDimY=" << options.mClusterDimY << "," << std::endl; ss << "mClusterDimZ=" << options.mClusterDimZ << "," << std::endl; @@ -441,6 +524,8 @@ inline std::string dumpOptions(GemmOptions const& options) { << "," << std::endl; ss << "mDtypeMmaB=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaB) << ")" << "," << std::endl; + ss << "mEltwiseActType=" << "gemm::EltwiseActType(" + << static_cast(options.mEltwiseActType) << ")" << "," << std::endl; ss << "mEnablesEarlyExit=" << options.mEnablesEarlyExit << "," << std::endl; ss << "mEnablesDelayedEarlyExit=" << options.mEnablesDelayedEarlyExit << "," << std::endl; ss << "mEnablesGlobalPtxKnobs=" << options.mEnablesGlobalPtxKnobs << "," << std::endl; @@ -448,6 +533,7 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mEpilogueLdtmBits=" << options.mEpilogueLdtmBits << "," << std::endl; ss << "mEpilogueTileM=" << options.mEpilogueTileM << "," << std::endl; ss << "mEpilogueTileN=" << options.mEpilogueTileN << "," << std::endl; + ss << "mFuseUtccpWithUtcmma=" << options.mFuseUtccpWithUtcmma << "," << std::endl; ss << "mGridTriggerSecondaryA=" << options.mGridTriggerSecondaryA << "," << std::endl; ss << "mGridTriggerSecondaryB=" << options.mGridTriggerSecondaryB << "," << std::endl; ss << "mGridWaitForPrimaryEarlyExit=" << options.mGridWaitForPrimaryEarlyExit << "," << std::endl; @@ -455,22 +541,30 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mGridWaitForPrimaryB=" << options.mGridWaitForPrimaryB << "," << std::endl; ss << "mHoistLoadTaskInit=" << options.mHoistLoadTaskInit << "," << std::endl; ss << "mHoistMmaTaskTryWaits=" << options.mHoistMmaTaskTryWaits << "," << std::endl; - ss << "mK=" << options.mK << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mK=" << options.mK << "," << std::endl; + } ss << "mKernelTraits={}" << "," << std::endl; ss << "mLayoutA=gemm::MatrixLayout(" << static_cast(options.mLayoutA) << ")" << "," << std::endl; ss << "mLayoutB=gemm::MatrixLayout(" << static_cast(options.mLayoutB) << ")" << "," << std::endl; - ss << "mM=" << options.mM << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mM=" << options.mM << "," << std::endl; + } ss << "mMmaK=" << options.mMmaK << "," << std::endl; ss << "mMmaKind=" << "trtllm::gen::MmaKind(" << static_cast(options.mMmaKind) << ")" << "," << std::endl; ss << "mMmaM=" << options.mMmaM << "," << std::endl; ss << "mMmaN=" << options.mMmaN << "," << std::endl; ss << "mMockAllReduce=" << options.mMockAllReduce << "," << std::endl; - ss << "mN=" << options.mN << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mN=" << options.mN << "," << std::endl; + } + ss << "mNumEpilogueWarps=" << options.mNumEpilogueWarps << "," << std::endl; ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; ss << "mNumRegsCopySfLdsSttm=" << options.mNumRegsCopySfLdsSttm << "," << std::endl; + ss << "mNumRegsCopySparsityInfo=" << options.mNumRegsCopySparsityInfo << "," << std::endl; ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," << std::endl; ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," @@ -484,11 +578,9 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mNumStagesWorkId=" << options.mNumStagesWorkId << "," << std::endl; ss << "mOutputDebugTensors=" << options.mOutputDebugTensors << "," << std::endl; ss << "mPatchF2fp=" << options.mPatchF2fp << "," << std::endl; - if (options.mSfBlockSizeA.has_value()) { - ss << "mSfBlockSizeA=" << options.mSfBlockSizeA.value() << "," << std::endl; - } else { - ss << "mSfBlockSizeA=" << "std::nullopt" << ", " << std::endl; - } + ss << "mSfBlockSizeA=" << options.mSfBlockSizeA << "," << std::endl; + ss << "mSfBlockSizeB=" << options.mSfBlockSizeB << "," << std::endl; + ss << "mSfBlockSizeC=" << options.mSfBlockSizeC << "," << std::endl; ss << "mSfLayoutA=" << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutA) << ")" << "," << std::endl; ss << "mSfLayoutB=" << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutB) << ")" @@ -497,6 +589,8 @@ inline std::string dumpOptions(GemmOptions const& options) { << "," << std::endl; ss << "mSfReshapeFactor=" << options.mSfReshapeFactor << "," << std::endl; ss << "mSliceK=" << options.mSliceK << "," << std::endl; + ss << "mSparsityA=" << "trtllm::gen::Sparsity(" << static_cast(options.mSparsityA) << ")" + << "," << std::endl; ss << "mSplitK=" << "gemm::SplitK(" << static_cast(options.mSplitK) << ")" << "," << std::endl; ss << "mTileK=" << options.mTileK << "," << std::endl; @@ -509,14 +603,20 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mUseDeepSeekFp8=" << options.mUseDeepSeekFp8 << "," << std::endl; ss << "mUseHoistTryWaitForCustomMmaSchedule=" << options.mUseHoistTryWaitForCustomMmaSchedule << "," << std::endl; + ss << "mUseMaxTmemOverlap=" << options.mUseMaxTmemOverlap << "," << std::endl; ss << "mUsePerTokenSfA=" << options.mUsePerTokenSfA << "," << std::endl; ss << "mUsePerTokenSfB=" << options.mUsePerTokenSfB << "," << std::endl; - ss << "mUseShuffledMatrixA=" << options.mUseShuffledMatrixA << "," << std::endl; + ss << "mUseShuffledMatrix=" << options.mUseShuffledMatrix << "," << std::endl; ss << "mUseTmaStore=" << options.mUseTmaStore << "," << std::endl; ss << "mUseTwoTmaLoadWarps=" << options.mUseTwoTmaLoadWarps << "," << std::endl; ss << "mUseTwoMmaWarps=" << options.mUseTwoMmaWarps << "," << std::endl; ss << "mUseUnrollLoop2xForMma=" << options.mUseUnrollLoop2xForMma << "," << std::endl; - ss << "mWorldSize=" << options.mWorldSize << std::endl; + if (dumpRuntimeParams) { + ss << "mValidM=" << options.mValidM << "," << std::endl; + ss << "mValidN=" << options.mValidN << "," << std::endl; + ss << "mValidK=" << options.mValidK << "," << std::endl; + ss << "mWorldSize=" << options.mWorldSize << std::endl; + } return ss.str(); } @@ -536,6 +636,33 @@ inline T divUpMul(T a, T b) { //////////////////////////////////////////////////////////////////////////////////////////////////// +// clang-format off +inline std::vector srcToDstBlk16RowMap = + { + 0, 8, + 1, 9, + 2, 10, + 3, 11, + 4, 12, + 5, 13, + 6, 14, + 7, 15 + }; +inline std::vector srcToDstBlk32RowMap = + { + 0, 8, 16, 24, + 1, 9, 17, 25, + 2, 10, 18, 26, + 3, 11, 19, 27, + 4, 12, 20, 28, + 5, 13, 21, 29, + 6, 14, 22, 30, + 7, 15, 23, 31 + }; +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline int32_t getShuffleBlockSize(int epilogueTileM) { int shuffleBlockSize = 16; if (epilogueTileM % 128 == 0) { @@ -546,11 +673,21 @@ inline int32_t getShuffleBlockSize(int epilogueTileM) { //////////////////////////////////////////////////////////////////////////////////////////////////// +inline std::vector const& getShuffleIndices(int epilogueTileM) { + auto const shuffleBlockSize = getShuffleBlockSize(epilogueTileM); + return shuffleBlockSize == 16 ? srcToDstBlk16RowMap : srcToDstBlk32RowMap; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Check if the options are valid or not. -inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, int tpGrpSize, +inline bool checkAndUpdateGemmOptions(GemmOptions& options, tg::CudaArch cudaArch, int tpGrpSize, bool updateOptions = true) { options.mWorldSize = tpGrpSize; + bool isBlackwell = tg::isArchBlackwell(cudaArch); + + // If dtypeB is unspecified (Dtype::Void), assign to dtypeA. if (options.mDtypeB == tg::Dtype::Void) { if (updateOptions) { options.mDtypeB = options.mDtypeA; @@ -558,6 +695,15 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in return false; } } + // If dtypeC is unspecified (Dtype::Void), assign to dtypeA. + if (options.mDtypeC == tg::Dtype::Void) { + TLLM_LOG_INFO("Setting dtypeC to ", tg::dtypeToString(options.mDtypeA)); + if (updateOptions) { + options.mDtypeC = options.mDtypeA; + } else { + return false; + } + } // If not specified, used the input dtypes as MMA dtypes (no cast required). if (options.mDtypeMmaA == tg::Dtype::Void) { @@ -575,11 +721,58 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } + // If validM/N/K is not specified, then assume the full range of the dimension + // is valid. + if (options.mValidM < 0 || options.mValidN < 0 || options.mValidK < 0) { + if (updateOptions) { + options.mValidM = options.mValidM < 0 ? options.mM : options.mValidM; + options.mValidN = options.mValidN < 0 ? options.mN : options.mValidN; + options.mValidK = options.mValidK < 0 ? options.mK : options.mValidK; + } else { + return false; + } + } + + // It must not exceed the padded dimensions. + if (options.mValidM > options.mM || options.mValidN > options.mN || + options.mValidK > options.mK) { + TLLM_LOG_WARNING( + "ValidM, ValidN, and ValidK must be less than or equal to " + "M, N, and K respectively."); + if (updateOptions) { + options.mValidM = std::min(options.mValidM, options.mM); + options.mValidN = std::min(options.mValidN, options.mN); + options.mValidK = std::min(options.mValidK, options.mK); + } else { + return false; + } + } + + // BlockMajorK layout does not support validM, validN, validK parameters + if (options.mLayoutA == gemm::MatrixLayout::BlockMajorK || + options.mLayoutB == gemm::MatrixLayout::BlockMajorK) { + bool hasValidParams = (options.mValidM != -1 && options.mValidM != options.mM) || + (options.mValidN != -1 && options.mValidN != options.mN) || + (options.mValidK != -1 && options.mValidK != options.mK); + TLLM_CHECK_ERROR(!hasValidParams, + "BlockMajorK layout does not support validM/validN/validK " + "parameters due to swizzled layout. " + "Found validM=", + options.mValidM, " validN=", options.mValidN, " validK=", options.mValidK); + } + +#ifdef TLLM_PUBLIC_RELEASE + if (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3) { + TLLM_CHECK_ERROR(false, "E2m1 x E4m3 is not supported for JIT compile. Use cubins instead."); + } +#endif // TLLM_PUBLIC_RELEASE + // Check that the A cast is supported. - // Currently, we only support {MxFp4, NvFp4} -> Bf16. + // Currently, we only support {MxFp4, NvFp4, MxInt4} -> Bf16. TLLM_CHECK_ERROR( (options.mDtypeA == options.mDtypeMmaA) || - ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1) && + ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1 || + options.mDtypeA == tg::Dtype::MxInt4) && options.mDtypeMmaA == tg::Dtype::Bfloat16) || (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3), "Unsupported cast for A: ", tg::dtypeToString(options.mDtypeA), " -> ", @@ -604,15 +797,19 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mDtypeA == tg::Dtype::MxE2m1 && options.mDtypeMmaA == tg::Dtype::Bfloat16, "PatchF2fp is only supported for MxFp4 to Bf16 casts."); } +#ifdef TLLM_PUBLIC_RELEASE + options.mPatchF2fp = false; +#endif // TLLM_PUBLIC_RELEASE - // FIXME: We do not support different dtypes for A and B when not on Blackwell. + // FIXME: We do not support different dtypes for A and B when not on + // Blackwell. if (!isBlackwell) { TLLM_CHECK_ERROR(options.mDtypeMmaA == options.mDtypeMmaB, "For non-Blackwell, A and B must have the same dtype."); } - // Check that the different dtypes for A and B are supported by the tensor core - // kind::f8f6f4 + // Check that the different dtypes for A and B are supported by the tensor + // core kind::f8f6f4 if (options.mDtypeMmaA == tg::Dtype::E4m3 || options.mDtypeMmaA == tg::Dtype::E2m1) { TLLM_CHECK_ERROR(options.mDtypeMmaB == tg::Dtype::E4m3 || options.mDtypeMmaB == tg::Dtype::E2m1, "For dtypeMmaA = E4m3/E2m1 A, dtypeMmaB must also be E4m3/E2m1."); @@ -622,31 +819,38 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1) { TLLM_CHECK_ERROR( options.mDtypeMmaB == tg::Dtype::MxE4m3 || options.mDtypeMmaB == tg::Dtype::MxE2m1, - "For dtypeMmaA = MxE4m3 or MxE2m1, dtypeMmaB must also be MxE4m3 or MxE2m1."); + "For dtypeMmaA = MxE4m3 or MxE2m1, dtypeMmaB must also be " + "MxE4m3 or MxE2m1."); } if (options.mDtypeMmaB == tg::Dtype::MxE4m3 || options.mDtypeMmaB == tg::Dtype::MxE2m1) { TLLM_CHECK_ERROR( options.mDtypeMmaA == tg::Dtype::MxE4m3 || options.mDtypeMmaA == tg::Dtype::MxE2m1, - "For dtypeMmaB = MxE4m3 or MxE2m1, dtypeMmaA must also be MxE4m3 or MxE2m1."); + "For dtypeMmaB = MxE4m3 or MxE2m1, dtypeMmaA must also be " + "MxE4m3 or MxE2m1."); } // kind::f16 if (options.mDtypeMmaA == tg::Dtype::Fp16 || options.mDtypeMmaA == tg::Dtype::Bfloat16) { TLLM_CHECK_ERROR(options.mDtypeMmaB == options.mDtypeMmaA, - "For dtypeMmaA = Fp16/Bfloat16, dtypeMmaB must be the same as dtypeMmaA."); + "For dtypeMmaA = Fp16/Bfloat16, dtypeMmaB must be the " + "same as dtypeMmaA."); } // When one of the inputs needs to be cast, we must use two load warps. if ((options.mDtypeMmaA != options.mDtypeA || options.mDtypeMmaB != options.mDtypeB) && !options.mUseTwoTmaLoadWarps) { - TLLM_LOG_WARNING("Two TMA load warps must be enabled if any of the inputs needs to be cast."); + TLLM_LOG_WARNING( + "Two TMA load warps must be enabled if any of the inputs " + "needs to be cast."); } - // When different dtypes are used for A and B, we must use different tiles to do the loading. - // It is not strictly required, but current implementation of SmemAb requires that. + // When different dtypes are used for A and B, we must use different tiles to + // do the loading. It is not strictly required, but current implementation of + // SmemAb requires that. if (options.mDtypeA != options.mDtypeB) { TLLM_CHECK_ERROR(options.mUseTwoTmaLoadWarps, - "Two TMA load warps must be enabled for different input types of A and B."); + "Two TMA load warps must be enabled for different input " + "types of A and B."); } // Get the mma kind for the input types. @@ -658,16 +862,63 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } - if ((options.mMmaKind == tg::MmaKind::Fp8Fp6Fp4 || - options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) && - options.mMmaK != 32) { - TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, - ") for MmaKind=", gemm::toString(options.mMmaKind), ". Setting MmaK to 32"); - if (updateOptions) { - options.mMmaK = 32; - options.mTileK = std::max(options.mMmaK, options.mTileK); - } else { - return false; + // Check that the sparsity mode of A is supported, and compatible with the MMA + // kind. Note: trtllm-gen currently does not support sparsity with tf32, fp16, + // bf16. + switch (options.mSparsityA) { + case tg::Sparsity::Dense: + // Always supported. + break; + case tg::Sparsity::Any_1_2: + TLLM_LOG_ERROR("1:2 sparsity is not supported."); + break; + case tg::Sparsity::Any_2_4: { + bool isSupported_2_4 = (options.mMmaKind == tg::MmaKind::Fp8Fp6Fp4 || + options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4); + TLLM_CHECK_ERROR(isSupported_2_4, "2:4 sparsity is not supported for MMA kind ", + tg::mmaKindToString(options.mMmaKind), " on target ", + tg::cudaArchToString(cudaArch)); + break; + } + case tg::Sparsity::Pairwise_4_8: + TLLM_CHECK_ERROR(options.mMmaKind == tg::MmaKind::MxFp4NvFp4, + "Pairwise 4:8 sparsity is only supported for MMA kind MxFp4NvFp4."); + break; + default: + TLLM_CHECK_ERROR(false, "Unsupported sparsityA: ", tg::sparsityToString(options.mSparsityA)); + break; + } + + // Is A sparse? + bool const isSparseA = tg::isSparse(options.mSparsityA); + + // Requirements specific to sparsity, and compatibility with other features. + if (isSparseA) { + TLLM_CHECK_ERROR(isBlackwell, "Sparsity is only supported on Blackwell"); + // The following requirement is for TMA load: the box width must be a + // multiple of 16B. + TLLM_CHECK_ERROR(tg::getNumBytesSparsityInfo(options.mSparsityA, options.mTileK) % 16 == 0, + "The sparsity information for one tile row must be a multiple of 16B. " + "Use larger tileK."); + TLLM_CHECK_ERROR(options.mDtypeA == options.mDtypeMmaA, + "Sparsity is not supported with on-the-fly upcasting."); + TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, "Sparsity is not supported with DeepSeek Fp8."); + TLLM_CHECK_ERROR(!options.mSliceK, "Sparsity is not supported with slice-k."); + } + + if (options.mMmaKind == tg::MmaKind::Fp8Fp6Fp4) { + int mmaK = isSparseA ? 64 : 32; + + if (options.mMmaK != mmaK) { + TLLM_LOG_WARNING( + "Unsupported MmaK (", options.mMmaK, ") for MmaKind=", gemm::toString(options.mMmaKind), + " and sparsity=", tg::sparsityToString(options.mSparsityA), ". Setting MmaK to ", mmaK); + if (updateOptions) { + options.mMmaK = mmaK; + options.mTileK = std::max(options.mMmaK, options.mTileK); + } else { + return false; + } } } @@ -683,24 +934,25 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "dp", options.mEpilogueLdtmBits, "bit."); } if (options.mTransposeMmaOutput) { - // We can't use 32dp32bit LDTM for transposed outputs because we need each thread to own - // multiple consecutive output elements. + // We can't use 32dp32bit LDTM for transposed outputs because we need each + // thread to own multiple consecutive output elements. TLLM_CHECK_ERROR((options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256), "Only 16dp256bit LDTM is supported for transposed outputs."); } } else { - TLLM_CHECK_ERROR( - options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256, - "Hopper does not use TMEM. The register layout corresponds to 16dp256bit. Got ", - options.mEpilogueLdtmDps, "dp", options.mEpilogueLdtmBits, "bit."); + TLLM_CHECK_ERROR(options.mEpilogueLdtmDps == 16 && options.mEpilogueLdtmBits == 256, + "Hopper does not use TMEM. The register layout corresponds to " + "16dp256bit. Got ", + options.mEpilogueLdtmDps, "dp", options.mEpilogueLdtmBits, "bit."); } - // Constraints for NvFp4 and MxFp8. + // Constraints for NvFp4, MxFp8, and MxFp4. if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 || options.mDtypeC == tg::Dtype::MxE4m3) && options.mMmaM != 128) { if (options.mClusterDimX == 1) { - // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. + // MMA M must be 128 when the input uses block scaling, or when the output + // is an Mx format. int newTileM = 128 * divUp(options.mTileM, 128); TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, ") for MmaKind=", gemm::toString(options.mMmaKind), @@ -723,14 +975,13 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); - int mmaK = 32; + int mmaK = isSparseA ? 64 : 32; if (options.mMmaKind == tg::MmaKind::MxFp4NvFp4) { - if (options.mMmaK == 96) { + mmaK = isSparseA ? 128 : 64; + if (options.mMmaK == 96 && !isSparseA) { mmaK = 96; TLLM_CHECK_ERROR(options.mTileK == 768, "When mmaK == 96, only tileK == 768 is supported"); TLLM_CHECK_ERROR(options.mTileN <= 128, "When mmaK == 96, only tileN <= 128 is supported"); - } else { - mmaK = 64; } } if (options.mMmaK != mmaK) { @@ -751,25 +1002,96 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mMmaN, ") must be >= 64 or equal to TileN (", options.mTileN, ")"); } - if (options.mSfBlockSizeA.has_value()) { - // Only E2m1 x E4m3 is tested. MxE2m1 x bf16 may also work. - TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeB == tg::Dtype::E4m3, - "sfBlockSizeA is only supported for E2m1 and E4m3 types. Found dtypeA=", - tg::dtypeToString(options.mDtypeA), - " dtypeB=", tg::dtypeToString(options.mDtypeB)); - - // sfBlockSizeA must be 16 or 32. - // SfBlockSizeA can also support 64 and 128, although they are not officially supported Nvida - // format. Note that the type conversion needs to happen before TCs. + // Note: the logic for selecting/checking the correct block size based on + // dtypes and sparsity is centralized here, to avoid error-prone code + // duplication and make it a more explicit "contract" with the user who is + // providing inputs in this format. Additionally, in some cases, multiple + // values are possible: + // - When we use type casting before the MMA (e.g. e2m1 x e4m3). + // - For output C, based on whether the consumer will use sparsity. + + // SF block size for A. + if (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeB == tg::Dtype::E4m3) { + // Note that the type conversion needs to happen before TCs. // For example, convert e2m1 to e4m3 inside TmemCastA. - // If we want to support sfBlockSizeA=8, we can write another version of convertE2m1ToSfE4m3, - // which only packs 8 e2m1 elements. - TLLM_CHECK_ERROR(options.mSfBlockSizeA.value() == 16 || options.mSfBlockSizeA.value() == 32, - "SfBlockSizeA (", options.mSfBlockSizeA.value(), ") must be 16 or 32."); + if (!(options.mSfBlockSizeA == 16 || options.mSfBlockSizeA == 32)) { + TLLM_LOG_WARNING("sfBlockSizeA must be 16 or 32 for e2m1 x e4m3, got ", + options.mSfBlockSizeA); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeA, 16); + } + } else if (options.mDtypeA == tg::Dtype::E2m1) { + if (!((options.mSfBlockSizeA == 16 && !isSparseA) || + (options.mSfBlockSizeA == 32 && isSparseA))) { + TLLM_LOG_WARNING( + "sfBlockSizeA must be 16 (dense) or 32 (sparse) for " + "dtypeA=e2m1, got ", + options.mSfBlockSizeA); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeA, isSparseA ? 32 : 16); + } + } else if (options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::MxE4m3 || + options.mDtypeA == tg::Dtype::MxInt4) { + if (!((options.mSfBlockSizeA == 32 && !isSparseA) || + (options.mSfBlockSizeA == 64 && isSparseA))) { + TLLM_LOG_WARNING( + "sfBlockSizeA must be 32 (dense) or 64 (sparse) for " + "dtypeA=mx{e2m1,e4m3,int4}, got ", + options.mSfBlockSizeA); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeA, isSparseA ? 64 : 32); + } + } else if (options.mSfBlockSizeA > 0) { + TLLM_LOG_WARNING("Got sfBlockSizeA=", options.mSfBlockSizeA, + " but dtypeA=", tg::dtypeToString(options.mDtypeA), + " does not use block scales"); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeA, -1); + } + // SF block size for B. + if (options.mDtypeB == tg::Dtype::E2m1) { + if (!((options.mSfBlockSizeB == 16 && !isSparseA) || + (options.mSfBlockSizeB == 32 && isSparseA))) { + TLLM_LOG_WARNING( + "sfBlockSizeB must be 16 (dense) or 32 (sparse) for " + "dtypeB=e2m1, got ", + options.mSfBlockSizeB); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeB, isSparseA ? 32 : 16); + } + } else if (options.mDtypeB == tg::Dtype::MxE2m1 || options.mDtypeB == tg::Dtype::MxE4m3 || + (options.mDtypeB == tg::Dtype::E4m3 && options.mDtypeMmaB == tg::Dtype::MxE4m3)) { + if (!((options.mSfBlockSizeB == 32 && !isSparseA) || + (options.mSfBlockSizeB == 64 && isSparseA))) { + TLLM_LOG_WARNING( + "sfBlockSizeB must be 32 (dense) or 64 (sparse) for " + "dtypeB=mx{e2m1,e4m3}, got ", + options.mSfBlockSizeB); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeB, isSparseA ? 64 : 32); + } + } else if (options.mSfBlockSizeB > 0) { + TLLM_LOG_WARNING("Got sfBlockSizeB=", options.mSfBlockSizeB, + " but dtypeB=", tg::dtypeToString(options.mDtypeB), + " does not use block scales"); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeB, -1); + } + // SF block size for C. + if (options.mDtypeC == tg::Dtype::E2m1) { + if (!(options.mSfBlockSizeC == 16 || options.mSfBlockSizeC == 32)) { + TLLM_LOG_WARNING("sfBlockSizeC must be 16 or 32 for dtypeC=e2m1, got ", + options.mSfBlockSizeC); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeC, 16); + } + } else if (options.mDtypeC == tg::Dtype::MxE2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { + if (!(options.mSfBlockSizeC == 32 || options.mSfBlockSizeC == 64)) { + TLLM_LOG_WARNING("sfBlockSizeC must be 32 or 64 for dtypeC=mx{e2m1,e4m3}, got ", + options.mSfBlockSizeC); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeC, 32); + } + } else if (options.mSfBlockSizeC > 0) { + TLLM_LOG_WARNING("Got sfBlockSizeC=", options.mSfBlockSizeC, + " but dtypeC=", tg::dtypeToString(options.mDtypeC), + " does not use block scales"); + GEMM_UPDATE_OR_ERROR(options.mSfBlockSizeC, -1); } if (tg::dtypeIsBlockFmt(options.mDtypeA)) { - int numEltsPerSfA = options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA)); + int numEltsPerSfA = options.mSfBlockSizeA; TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfA) == 0, "TileK (", options.mTileK, ") must be a multiple of ", (4 * numEltsPerSfA), " for typeA ", gemm::toString(options.mDtypeA)); @@ -790,7 +1112,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in ") must be a multiple of ", numSfTileRowsB, " for B SF layout ", tg::sfLayoutToString(options.mSfLayoutB)); - int numEltsPerSfB = tg::dtypeNumEltsPerSf(options.mDtypeB); + int numEltsPerSfB = options.mSfBlockSizeB; TLLM_CHECK_ERROR(options.mTileK % (4 * numEltsPerSfB) == 0, "TileK (", options.mTileK, ") must be a multiple of ", (4 * numEltsPerSfB), " for typeB ", gemm::toString(options.mDtypeB)); @@ -816,12 +1138,17 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in (padMultiplierB * tg::dtypeGetNumBits(options.mDtypeB) * options.mK / 8) % 16 == 0, "K dimension of B must be aligned to 16 bytes."); - if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { + if (tg::dtypeIsBlockFmt(options.mDtypeC)) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); TLLM_CHECK_ERROR( options.mSfLayoutC == tg::SfLayout::R128c4 || options.mSfLayoutC == tg::SfLayout::R8c4, "Only the 128x4 and 8x4 SF layouts are supported for C."); + if (!options.mTransposeMmaOutput) { + TLLM_CHECK_ERROR(options.mEpilogueTileN % options.mSfBlockSizeC == 0, + "EpilogueTileN must be a multiple of the number of " + "elements per SF for C"); + } int const numSfTileRowsC = options.mSfLayoutC == tg::SfLayout::R128c4 ? 128 : 8; int const tileTokenDim = options.mTransposeMmaOutput ? options.mTileN : options.mTileM; TLLM_CHECK_ERROR_FMT(tileTokenDim % numSfTileRowsC == 0, @@ -829,25 +1156,20 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mTransposeMmaOutput ? "N" : "M", tileTokenDim, numSfTileRowsC, tg::sfLayoutToString(options.mSfLayoutC).c_str()); + int numEltsPerSfC = options.mSfBlockSizeC; int const hiddenDim = options.mTransposeMmaOutput ? options.mM : options.mN; - int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); + int const hiddenGranularity = 4 * numEltsPerSfC; TLLM_CHECK_ERROR(hiddenDim % hiddenGranularity == 0, "Hidden dim (", hiddenDim, ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs."); - TLLM_CHECK_ERROR(!options.mTransposeMmaOutput || options.mUseShuffledMatrixA, - "Transposing block-scaled outputs requires shuffled A."); - } - - // If dtypeC is unspecified (Dtype::Void), assign to the input dtype. - if (options.mDtypeC == tg::Dtype::Void) { - TLLM_LOG_INFO("Setting dtypeC to ", tg::dtypeToString(options.mDtypeA)); - if (updateOptions) { - options.mDtypeC = options.mDtypeA; - } else { - return false; - } + int const validHiddenDim = options.mTransposeMmaOutput ? options.mValidM : options.mValidN; + TLLM_CHECK_ERROR(validHiddenDim % numEltsPerSfC == 0, "Valid hidden dim (", validHiddenDim, + ") must be a multiple of ", numEltsPerSfC, " for block-scaled outputs."); + TLLM_CHECK_ERROR(!options.mTransposeMmaOutput || options.mUseShuffledMatrix, + "Transposing block-scaled outputs requires shuffled matrix."); } - // Set epilogue tile sizes to the output tile sizes, when epilogue tile sizes are incorrect. + // Set epilogue tile sizes to the output tile sizes, when epilogue tile sizes + // are incorrect. if (options.mTileM % options.mEpilogueTileM != 0) { TLLM_LOG_WARNING("TileM (", options.mTileM, ") must be divisible by EpilogueTileM (", options.mEpilogueTileM, "). Setting EpilogueTileM to TileM"); @@ -872,7 +1194,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (!isBlackwell && (options.mEpilogueTileM != options.mTileM || options.mEpilogueTileN != options.mTileN)) { TLLM_LOG_WARNING( - "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); + "Overwriting epilogueTileM and epilogueTileN to match " + "tileM and tileN respectively"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; options.mEpilogueTileN = options.mTileN; @@ -884,7 +1207,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // Unsupported epilogue tile size. if (options.mMmaM == 128 && options.mEpilogueTileM != options.mTileM) { TLLM_LOG_WARNING( - "When MmaM = 128, EpilogueTileM must be equal to TileM. Setting EpilogueTileM to TileM"); + "When MmaM = 128, EpilogueTileM must be equal to TileM. " + "Setting EpilogueTileM to TileM"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; } else { @@ -896,19 +1220,19 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "M, N and K must be larger than 0"); TLLM_CHECK_ERROR(options.mNumSlicesForSplitK > 0, "Split K must be larger than 0."); - if (options.mUseShuffledMatrixA) { + if (options.mUseShuffledMatrix) { auto const shuffleBlockSize = getShuffleBlockSize(options.mEpilogueTileM); - TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0, - "M must be a multiple of shuffle block size (", shuffleBlockSize, - ") when useShuffledMatrixA"); + TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0 && options.mValidM % shuffleBlockSize == 0, + "M/validM must be a multiple of shuffle block size (", shuffleBlockSize, + ") when useShuffledMatrix"); } if (!options.mSliceK) { TLLM_CHECK_ERROR(options.mMmaM / options.mClusterDimX <= options.mEpilogueTileM, "EpilogueTileM must be larger or equal than mmaM."); } else { - // FIXME: this is not necessary limitation. Simply fixing num repeats in TmemSliceKA should be - // enough. + // FIXME: this is not necessary limitation. Simply fixing num repeats in + // TmemSliceKA should be enough. TLLM_CHECK_ERROR((options.mTileN & (options.mTileN - 1)) == 0, "For Slice-K TileN is required to be a power of 2"); } @@ -920,7 +1244,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in TLLM_CHECK_ERROR( options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0, - "TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively."); + "TileM and TileN must be divisible by EpilogueTileM and " + "EpilogueTileN respectively."); TLLM_CHECK_ERROR( (options.mClusterDimX == 1 || options.mClusterDimX == 2) && options.mClusterDimY == 1, "GEMM does not support cluster in X and Y dimensions."); @@ -934,18 +1259,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "K must be a multiple of TileK * numSlicesForSplitK for DeepSeekFp8"); } - // When the A-matrix is shuffled, the output must be transposed. - if (options.mUseShuffledMatrixA) { - // TODO add matrix shuffle for N-major epilogue. - TLLM_CHECK_ERROR( - options.mTransposeMmaOutput, - "Shuffled matrix A is only supported with M-major epilogue. Set -transposeMmaOutput"); - } - // Check all-reduce options. if (options.mAllReduceAlgo == AllReduceAlgo::OneShot) { - // One shot is implemented with PTX cp.reduce.async.bulk.tensor which supports only the - // following types for reduce add: u32, s32, u64, f32, f16, bf16. + // One shot is implemented with PTX cp.reduce.async.bulk.tensor which + // supports only the following types for reduce add: u32, s32, u64, f32, + // f16, bf16. // // See: https://docs.nvidia.com/cuda/parallel-thread-execution/ // #data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor @@ -957,10 +1275,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // TODO(anchengc): // Input dtype == output dtype -> can perform all-reduce in-place. // Input dtype != output dtype -> must perform all-reduce out of place. - TLLM_CHECK_ERROR_FMT( - options.mDtypeC == options.mDtypeAcc, - "Not implemented - mixed dtype (dtypeC (%s) != dtypeAcc (%s)) requires out of place update", - tg::dtypeToString(options.mDtypeC).c_str(), tg::dtypeToString(options.mDtypeAcc).c_str()); + TLLM_CHECK_ERROR_FMT(options.mDtypeC == options.mDtypeAcc, + "Not implemented - mixed dtype (dtypeC (%s) != " + "dtypeAcc (%s)) requires out of place update", + tg::dtypeToString(options.mDtypeC).c_str(), + tg::dtypeToString(options.mDtypeAcc).c_str()); } if (options.mAllReduceAlgo != AllReduceAlgo::None) { TLLM_CHECK_ERROR(options.mUseTmaStore, "Non-TMA store with all-reduce is not implemented"); @@ -988,7 +1307,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if ((options.mEpilogueTileM != options.mTileM || options.mEpilogueTileN != options.mTileN) && !options.mUseDeepSeekFp8) { TLLM_LOG_WARNING( - "Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively"); + "Overwriting epilogueTileM and epilogueTileN to match " + "tileM and tileN respectively"); if (updateOptions) { options.mEpilogueTileM = options.mTileM; options.mEpilogueTileN = options.mTileN; @@ -1002,10 +1322,35 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "CGA size must be equal to the number of slices in split-k"); } - // Maps numStagesMma to (stagesWithinWorkTile, stagesAcrossWorkTile) if not already set. - // If (-1, -1) -> (numStagesMma / min(2, numStagesMma), min(2, numStagesMma)) - // If ( m, -1) -> (m, numStagesMma / m) - // If (-1, n) -> (numStagesMma / n, n) + if (options.mUseShuffledMatrix && !options.mTransposeMmaOutput) { + TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, + "DeepSeek Fp8 is not supported when using shuffled matrix " + "and non-transposed mma output"); + TLLM_CHECK_ERROR(options.mEpilogueLdtmBits == 32, + "EpilogueLdtmBits must be 32 when using shuffled matrix " + "and non-transposed mma output"); + TLLM_CHECK_ERROR(options.mEpilogueLdtmDps == 32, + "EpilogueLdtmDps must be 32 when using shuffled matrix " + "and non-transposed mma output"); + TLLM_CHECK_ERROR(options.mUseTmaStore, + "TMA store is required when using shuffled matrix and " + "non-transposed mma output"); + TLLM_CHECK_ERROR(!options.mSliceK, + "Slice-K is not supported when using shuffled matrix and " + "non-transposed mma output"); + // When doing unshuffle in the epilogue, one fragment of epilogue tile must + // have at least one shuffle block. + auto minEpilogueTileN = getShuffleBlockSize(options.mEpilogueTileM); + TLLM_CHECK_ERROR_FMT(options.mEpilogueTileN >= minEpilogueTileN, + "EpilogueTileN (%d) must be a larger than the shuffle block size (%d) " + "when using shuffled matrix and non-transposed mma output", + options.mEpilogueTileN, minEpilogueTileN); + } + + // Maps numStagesMma to (stagesWithinWorkTile, stagesAcrossWorkTile) if not + // already set. If (-1, -1) -> (numStagesMma / min(2, numStagesMma), min(2, + // numStagesMma)) If ( m, -1) -> (m, numStagesMma / m) If (-1, n) -> + // (numStagesMma / n, n) if (options.mNumStagesMmaWithinWorkTile == -1 && options.mNumStagesMmaAcrossWorkTile == -1) { if (updateOptions) { options.mNumStagesMmaAcrossWorkTile = std::min(2, options.mNumStagesMma); @@ -1034,7 +1379,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mNumStagesMma && options.mNumStagesMmaAcrossWorkTile <= 2, "Condition numStagesMmaWithinWorkTile (%d) * numStagesMmaAcrossWorkTile " - "(%d) == numStagesMma (%d) && numStagesMmaAcrossWorkTile (%d) <= 2 must be " + "(%d) == numStagesMma (%d) && numStagesMmaAcrossWorkTile (%d) <= 2 must " + "be " "satisfied. Check arguments.", options.mNumStagesMmaWithinWorkTile, options.mNumStagesMmaAcrossWorkTile, options.mNumStagesMma, options.mNumStagesMmaAcrossWorkTile); @@ -1046,8 +1392,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in TLLM_CHECK_ERROR(options.mNumStagesMmaWithinWorkTile == 1, "Non-DeepSeekFp8 requires numStagesMmaWithinWorkTile == 1"); if (options.mNumStagesMma > 1) { - TLLM_CHECK_ERROR(options.mTileScheduler == TileScheduler::Persistent, - "Non-DeepSeekFp8 requires persistent scheduler when using numStagesMma >1"); + TLLM_CHECK_ERROR(isPersistentScheduler(options.mTileScheduler), + "Non-DeepSeekFp8 requires persistent scheduler when " + "using numStagesMma >1"); } } if (options.mUseDeepSeekFp8) { @@ -1074,36 +1421,42 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in auto hiddenDimName = options.mTransposeMmaOutput ? "M" : "N"; TLLM_CHECK_WARNING(options.mNumStagesMmaWithinWorkTile > 1, "DeepSeekFp8 recommends setting \"-numStagesMmaWithinWorkTile 2\"."); - // Update the number of stages of the MMA accumulator pipeline. TODO: enable by default for - // deepseek. - // options.mNumStagesMma = 2; - // Use two MMA warps to reduce mbar trywait latency. TODO: enable by default for deepseek. + // Update the number of stages of the MMA accumulator pipeline. TODO: enable + // by default for deepseek. options.mNumStagesMma = 2; Use two MMA warps to + // reduce mbar trywait latency. TODO: enable by default for deepseek. // options.mUseTwoMmaWarps = true; - // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mK % 128 == 0, - "GEMM-K must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mK); + // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek + // FP8. + TLLM_CHECK_ERROR(options.mK % 128 == 0 && options.mValidK % 128 == 0, + "GEMM-K and validK must be a multiple of 128 when using " + "DeepSeek Fp8. Found ", + options.mK, " and validK=", options.mValidK); - // Check that the output tile N can be processed with the epilogue tile granularity. + // Check that the output tile N can be processed with the epilogue tile + // granularity. TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerEpilogueTile == 0, "DeepSeek Fp8 requires Tile", hiddenDimName, " / 2 (", hiddenDimPerOutputTile / 2, ") being a multiple of EpilogueTile", hiddenDimName, " (", hiddenDimPerEpilogueTile, ")"); - // Check that the output tile N can be processed with the epilogue tile granularity. + // Check that the output tile N can be processed with the epilogue tile + // granularity. TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerMma == 0, "DeepSeek Fp8 requires Tile", hiddenDimName, " / 2 (", hiddenDimPerOutputTile / 2, ") being a multiple of mma", hiddenDimName, " (", hiddenDimPerMma, ")"); } + TLLM_CHECK_ERROR(options.mNumEpilogueWarps == 4 || options.mNumEpilogueWarps == 8, + "mNumEpilogueWarps has to be either 4 or 8."); + if (options.mSliceK) { TLLM_CHECK_ERROR(isBlackwell, "Slice-K is not supported on Hopper"); TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, "DeepSeek Fp8 GEMM is not supported for slice-K"); TLLM_CHECK_ERROR(options.mUseTwoTmaLoadWarps, "Slice-K requires two warp load for A and B"); TLLM_CHECK_ERROR(options.mTransposeMmaOutput, "Slice-K requires transpose mma output"); - TLLM_CHECK_ERROR(options.mUseShuffledMatrixA, "Slice-K requires shuffled matrix A"); + TLLM_CHECK_ERROR(options.mUseShuffledMatrix, "Slice-K requires shuffled matrix"); TLLM_CHECK_ERROR(options.mTileK % 128 == 0, "Slice-K requires TileK be a multiple of 128"); TLLM_CHECK_ERROR(options.mMmaM == 128, "Slice-K requires MmaM == 128"); TLLM_CHECK_ERROR(options.mTileN == options.mEpilogueTileN, @@ -1136,15 +1489,16 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } // Number of iterations in K dimension after padding. - // Note the perCtaK in each CTA in the splitK group are padded to the same number of iterations. - // E.g., K = 512, TileK = 128, numSlicesForSplitK = 3. Then the padded K is + // Note the perCtaK in each CTA in the splitK group are padded to the same + // number of iterations. E.g., K = 512, TileK = 128, numSlicesForSplitK = 3. + // Then the padded K is // // ceil(512 / (128*3)) * (128*3) = 768 // int const paddedK = divUpMul(options.mK, options.mTileK * options.mNumSlicesForSplitK); int const perCtaK = paddedK / options.mNumSlicesForSplitK; - // However, number of iterations is clamped to multiples of tileK within individual CTAs - // E.g., K = 448, TileK = 64, numSlicesForSplitK = 4. + // However, number of iterations is clamped to multiples of tileK within + // individual CTAs E.g., K = 448, TileK = 64, numSlicesForSplitK = 4. // // paddedK = 512 // perCtaK = 128 @@ -1153,9 +1507,10 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in int const paddingForK = paddedK - options.mK; int const clampedAndPaddedPerCtaK = divUpMul(perCtaK - paddingForK, options.mTileK); if (options.mUseUnrollLoop2xForMma) { - // Check that the padded K and clamped padded K (K rounded to next multiple of tileK) is a - // multiple of 2*TileK when UnrollLoop2x is enabled. This is to avoid deadlock when mma runs - // even-numbered loop while the other warps run odd-numbered loop. + // Check that the padded K and clamped padded K (K rounded to next multiple + // of tileK) is a multiple of 2*TileK when UnrollLoop2x is enabled. This is + // to avoid deadlock when mma runs even-numbered loop while the other warps + // run odd-numbered loop. // bool notSupported = (perCtaK % (options.mTileK * 2) != 0) || (clampedAndPaddedPerCtaK % (options.mTileK * 2) != 0); @@ -1173,14 +1528,16 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } if (options.mNumSlicesForSplitK > 1) { - TLLM_CHECK_ERROR( - perCtaK * (options.mNumSlicesForSplitK - 1) < options.mK, - "K must be greater than perCtaK * (numSlicesForSplitK - 1) to ensure each CTA has work"); + TLLM_CHECK_ERROR(perCtaK * (options.mNumSlicesForSplitK - 1) < options.mK, + "K must be greater than perCtaK * (numSlicesForSplitK - " + "1) to ensure each CTA has work"); } if (!isBlackwell && options.mTileScheduler == TileScheduler::Persistent) { - // TODO(anchengc): will be supported in upcoming MRs. - TLLM_LOG_WARNING("Persistent scheduling is not supported on Hopper. Using Static scheduling."); + TLLM_LOG_WARNING( + "Persistent scheduling is not supported on Hopper. Use " + "StaticPersistent or " + "PersistentSm90 instead. Fallback to Static scheduling."); if (updateOptions) { options.mTileScheduler = TileScheduler::Static; } else { @@ -1188,9 +1545,19 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } + if (isBlackwell && !options.mUseCustomMmaSchedule && !options.mUseDeepSeekFp8 && + isPersistentScheduler(options.mTileScheduler)) { + if (updateOptions) { + options.mUseCustomMmaSchedule = true; + } else { + TLLM_CHECK_ERROR(false, "Persistent scheduler and !UseCustomMmaSchedule is not supported."); + } + } + if (options.mEnablesDelayedEarlyExit && options.mEnablesEarlyExit) { TLLM_LOG_WARNING( - "Only one of early exit and delayed early exit should be enabled. Disabling " + "Only one of early exit and delayed early exit should be " + "enabled. Disabling " "delayed early exit"); if (updateOptions) { options.mEnablesDelayedEarlyExit = false; @@ -1199,9 +1566,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } - // This check prevents the triggering of the secondary (PREEXIT) from executing before the wait - // for primary (ACQBULK). This could lead to the following confusing situation, which we want to - // avoid: + // This check prevents the triggering of the secondary (PREEXIT) from + // executing before the wait for primary (ACQBULK). This could lead to the + // following confusing situation, which we want to avoid: // // Kernel 3 is written with the assumption that it can read the output of // kernel 1 *without* ACQBULK and the output of kernel 2 *with* ACQBULK. @@ -1214,15 +1581,16 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // Kernel 2: -------PREEXIT----ACQBULK---FLUSH // Kernel 3: Warp 0: ---- (!) Output of 1,2 is not yet visible // ----------------------- - // Warp 1: ---- (!) We normally assume that 1 is visible is not yet - // visible- Warp 2: -------------------ACQBULK-- Kernel 1,2 output visible + // Warp 1: ---- (!) We normally assume that 1 is visible + // is not yet visible- Warp 2: -------------------ACQBULK-- Kernel + // 1,2 output visible // ---------- - TLLM_CHECK_ERROR( - (options.mGridWaitForPrimaryA || !options.mGridTriggerSecondaryA), - "A: If a task triggers a secondary kernel, it must also wait for primary kernel."); - TLLM_CHECK_ERROR( - (options.mGridWaitForPrimaryB || !options.mGridTriggerSecondaryB), - "B: If a task triggers a secondary kernel, it must also wait for primary kernel."); + TLLM_CHECK_ERROR((options.mGridWaitForPrimaryA || !options.mGridTriggerSecondaryA), + "A: If a task triggers a secondary kernel, it must also wait for primary " + "kernel."); + TLLM_CHECK_ERROR((options.mGridWaitForPrimaryB || !options.mGridTriggerSecondaryB), + "B: If a task triggers a secondary kernel, it must also wait for primary " + "kernel."); if (options.mUsePerTokenSfA || options.mUsePerTokenSfB) { // Checks applicable to both MetaFP8 and RoutingScalesOnInput @@ -1239,28 +1607,30 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // RoutingScalesOnInput case TLLM_CHECK_ERROR((options.mUsePerTokenSfA && !options.mTransposeMmaOutput) || (options.mUsePerTokenSfB && options.mTransposeMmaOutput), - "In RoutingScalesOnInput mode, perToken scales must be used on activations"); + "In RoutingScalesOnInput mode, perToken scales must be used on " + "activations"); } } - // The generation should support non K-major layouts for both A and B; however, it is unclear if - // there is a use-case + // The generation should support non K-major layouts for both A and B; + // however, it is unclear if there is a use-case TLLM_CHECK_ERROR( (options.mLayoutA == MatrixLayout::MajorK) || (options.mLayoutB == MatrixLayout::MajorK), "At least one matrix must be in k-major layout"); - // Some features are currently only support when both matrices are in K-major format - if (options.mLayoutB != MatrixLayout::MajorK || options.mLayoutB != MatrixLayout::MajorK) { + // Some features are currently only support when both matrices are in K-major + // format + if (options.mLayoutA != MatrixLayout::MajorK || options.mLayoutB != MatrixLayout::MajorK) { TLLM_CHECK_ERROR(isBlackwell, "Non K-major layouts are only supported on Blackwell"); TLLM_CHECK_ERROR(options.mSplitK == SplitK::None, "Non K-major layouts do not support split K"); } if (options.mLayoutA == MatrixLayout::MajorMn) { TLLM_CHECK_ERROR(tg::dtypeGetNumBits(options.mDtypeA) >= 8, - "Subbyte types only support K major layout"); + "Subbyte types do not support m-major layout"); } if (options.mLayoutB == MatrixLayout::MajorMn) { TLLM_CHECK_ERROR(tg::dtypeGetNumBits(options.mDtypeB) >= 8, - "Subbyte types only support K major layout"); + "Subbyte types do not support n-major layout"); } if ((options.mLayoutA == MatrixLayout::BlockMajorK) || @@ -1268,8 +1638,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in bool const isBlockA = options.mLayoutA == MatrixLayout::BlockMajorK; // Block K size must be 128B. - // TODO Leaving this as an option for now in case we want to expertiment with other block sizes - // As the user is not expected to set this, do not fail if updateOptions is false + // TODO Leaving this as an option for now in case we want to expertiment + // with other block sizes As the user is not expected to set this, do not + // fail if updateOptions is false int32_t const elemSizeInBits = (isBlockA) ? tg::dtypeGetNumBits(options.mDtypeA) : tg::dtypeGetNumBits(options.mDtypeB); int32_t const elemsIn128B = 128 * 8 /* Bits in byte */ / elemSizeInBits; @@ -1283,13 +1654,13 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (options.mBlockK > options.mTileK) { - TLLM_CHECK_ERROR( - options.mBlockK % options.mTileK == 0, - "If block size is greater than tile size, block size must be a multiple of tile size"); + TLLM_CHECK_ERROR(options.mBlockK % options.mTileK == 0, + "If block size is greater than tile size, block size " + "must be a multiple of tile size"); } else if (options.mBlockK < options.mTileK) { - TLLM_CHECK_ERROR( - options.mTileK % options.mBlockK == 0, - "If tile size is greater than block size, tile size must be a multiple of block size"); + TLLM_CHECK_ERROR(options.mTileK % options.mBlockK == 0, + "If tile size is greater than block size, tile size " + "must be a multiple of block size"); } } @@ -1300,15 +1671,47 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "Bias is not supported for Meta Fp8"); } + if (options.mUseMaxTmemOverlap) { + TLLM_CHECK_ERROR(options.mUseTmaStore, "mUseMaxTmemOverlap only works with TMA store"); + TLLM_CHECK_ERROR(options.mFuseUtccpWithUtcmma, + "mUseMaxTmemOverlap only works with mFuseUtccpWithUtcmma"); + TLLM_CHECK_ERROR(options.mNumSlicesForSplitK == 1, + "mUseMaxTmemOverlap does not work with splitK"); + TLLM_CHECK_ERROR(options.mNumSlicesForSliceK == 1, + "mUseMaxTmemOverlap does not work with sliceK"); + TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, + "mUseMaxTmemOverlap does not work with mUseDeepSeekFp8"); + TLLM_CHECK_ERROR(!options.mUseUnrollLoop2xForMma, + "mUseMaxTmemOverlap does not work with mUseUnrollLoop2xForMma"); + } + + if (options.mNumEpilogueWarps > 4) { + TLLM_CHECK_ERROR(options.mUseTmaStore, + "Using more than 4 warps for epilogue only works with TMA store"); + TLLM_CHECK_ERROR(options.mNumSlicesForSplitK == 1, + "Using more than 4 warps for epilogue does not work with splitK"); + TLLM_CHECK_ERROR(options.mNumSlicesForSliceK == 1, + "Using more than 4 warps for epilogue does not work with sliceK"); + TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, + "Using more than 4 warps for epilogue does not work with " + "mUseDeepSeekFp8"); + + auto const numEpilogueWrpGrps = options.mNumEpilogueWarps / 4; + TLLM_CHECK_ERROR(options.mTileN % (options.mEpilogueTileN * numEpilogueWrpGrps) == 0, + "TileN must be a multiple of EpilogueTileN * numEpilogueWrpGrps"); + } + if (updateOptions) { // Init kernel traits. options.mKernelTraits = KernelTraits( options.mDtypeA, options.mDtypeB, options.mDtypeC, options.mDtypeAcc, options.mDtypeMmaA, - options.mDtypeMmaB, options.mMmaKind, options.mMmaK, options.mTileM, options.mTileN, - options.mTileK, options.mEpilogueTileM, options.mEpilogueTileN, options.mNumStages, - options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, - options.mSplitK, options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, - options.mTileScheduler == TileScheduler::Persistent, options.mUseDeepSeekFp8, + options.mDtypeMmaB, options.mMmaKind, options.mSparsityA, options.mMmaK, options.mTileM, + options.mTileN, options.mTileK, options.mEpilogueTileM, options.mEpilogueTileN, + options.mSfBlockSizeA, options.mSfBlockSizeB, options.mNumStages, options.mNumStagesMma, + options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, options.mSplitK, + options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, + options.mFuseUtccpWithUtcmma, options.mUseMaxTmemOverlap, options.mNumEpilogueWarps, + isPersistentScheduler(options.mTileScheduler), options.mUseDeepSeekFp8, options.mUsePerTokenSfA, options.mUsePerTokenSfB, /* useTwoCtas*/ options.mClusterDimX == 2, options.mBiasType); } @@ -1318,6 +1721,70 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in //////////////////////////////////////////////////////////////////////////////////////////////////// +inline bool getDoesScaleC(tg::Dtype dtypeC) { + // Need to scale/quantize the output C matrix when the output type is Fp8 or + // NvFp4. + return dtypeC == tg::Dtype::E4m3 || dtypeC == tg::Dtype::E2m1; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool getDoesScaleAb(tg::Dtype dtypeA, tg::Dtype dtypeB, bool useDeepSeekFp8) { + // Need to scale/dequantize the input A/B matrices when the input type is Fp8 + // or NvFp4 and DeepSeekFp8 is not used. + bool const doesScaleAb{ + dtypeA == tg::Dtype::E2m1 || dtypeB == tg::Dtype::E2m1 || + ((dtypeA == tg::Dtype::E4m3 || dtypeB == tg::Dtype::E4m3) && !useDeepSeekFp8)}; + return doesScaleAb; +} + +////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool getDoesScaleAct(tg::Dtype dtypeA, tg::Dtype dtypeB, bool useDeepSeekFp8, + EltwiseActType eltwiseActType) { + // Only non-linear activations require separate scaleAct. + bool const isLinearAct = eltwiseActType == EltwiseActType::None; + return !isLinearAct && getDoesScaleAb(dtypeA, dtypeB, useDeepSeekFp8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool getKernelDoesScaleC(tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, + bool useDeepSeekFp8) { + // In the Gemm/BatchedGemm kernels, dequantScaleAb and quantScaleC are + // combined into one single scaling factor (called scaleC). As a result, we + // combine the logic for getDoesScaleAb and getDoesScaleC. + return getDoesScaleC(dtypeC) || getDoesScaleAb(dtypeA, dtypeB, useDeepSeekFp8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline CUresult loadCubinData(CUmodule* module, Config const& config) { + // Trtllm links the cubin into the executable while Flashinfer loads the cubin + // from storage. +#ifdef TLLM_GEN_EXPORT_FLASHINFER +#ifdef TLLM_GEN_GEMM_CUBIN_PATH + static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH); + const std::string sha256 = config.mHash ? config.mHash : ""; + std::string fileName = config.mFunctionName; + if (!fileName.empty()) { + fileName[0] = static_cast(std::toupper(static_cast(fileName[0]))); + } + const std::string& data = flashinfer::trtllm_cubin_loader::getCubin( + tllm_gen_gemm_cubin_path + "/" + fileName + ".cubin", sha256); + CUresult result = cuModuleLoadData(module, data.c_str()); +#else + static_assert(false, "TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling"); +#endif // TLLM_GEN_GEMM_CUBIN_PATH +#else + CUresult result = cuModuleLoadData(module, config.mData); +#endif // TLLM_GEN_EXPORT_FLASHINFER + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gemm #ifdef TLLM_GEN_EXPORT_INTERFACE diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h index 64d065cd21..f30e673ea5 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,9 +20,10 @@ #include "TmaDescriptor.h" #include "trtllm/gen/CommonUtils.h" #include "trtllm/gen/SfLayoutDecl.h" +#include "trtllm/gen/SparsityDecl.h" -// NOTE: keep this code dependency free. It has to be included by the device code and has to be -// compilable with NVRTC. +// NOTE: keep this code dependency free. It has to be included by the device +// code and has to be compilable with NVRTC. #include "KernelParamsDecl.h" namespace gemm { @@ -40,25 +41,44 @@ using MatrixType = KernelParams::MatrixType; // Create the TMA shape/stride for A/B. template static auto makeTmaShapeStrideAb(GemmOptions const& options, MatrixType matrixType) { - // The outer dimension. - auto numTokens = (matrixType == MatrixType::MatrixA) ? options.mM : options.mN; + // For sparse A, the k dimension in the TMA shapes is halved. + int const isSparse = matrixType == MatrixType::MatrixA && tg::isSparse(options.mSparsityA); + + // The padded/valid dimensions. + int const sizeM = options.mM; + int const sizeN = options.mN; + int const sizeK = options.mK >> isSparse; + int const tileM = options.mTileM; + int const tileN = options.mTileN; + int const tileK = options.mTileK >> isSparse; + int const validM = options.mValidM; + int const validN = options.mValidN; + int const validK = options.mValidK >> isSparse; + + // The outer dimension. Uses padded dimensions for strides and valid + // dimensions for shapes. + auto numTokens = (matrixType == MatrixType::MatrixA) ? sizeM : sizeN; + auto numTokensValid = (matrixType == MatrixType::MatrixA) ? validM : validN; // The outer dimension tile size. - auto tileMn = (matrixType == MatrixType::MatrixA) ? options.mTileM : options.mTileN; + auto tileMn = (matrixType == MatrixType::MatrixA) ? tileM : tileN; // The inner dimension. - auto hiddenSize = options.mK; + auto hiddenSize = sizeK; + auto hiddenSizeValid = validK; // The cute tensor shape for A/B: (numTokens, hiddenSize). // Note that TMA descriptor expects the first dimension's stride to be - // 1, so swap the first two dimension so that the hiddenSize dimension comes first. - auto shape = - std::vector{static_cast(hiddenSize), static_cast(numTokens)}; + // 1, so swap the first two dimension so that the hiddenSize dimension comes + // first. Use valid dimensions for shape, padded dimension for stride. + auto shape = std::vector{static_cast(hiddenSizeValid), + static_cast(numTokensValid)}; // Assemble the stride (strideTokens, 1). // Swap the first two dimension as mentioned before. auto stride = std::vector{1, static_cast(hiddenSize)}; // Assemble the box shape - std::vector tileShape = {options.mTileK, tileMn}; - // When using 2CTA MMA, we only need to load half of the tile in each CTA for B. + std::vector tileShape = {tileK, tileMn}; + // When using 2CTA MMA, we only need to load half of the tile in each CTA for + // B. if (matrixType == MatrixType::MatrixB && options.mClusterDimX == 2) { tileShape[1] /= 2; } @@ -71,15 +91,15 @@ static auto makeTmaShapeStrideAb(GemmOptions const& options, MatrixType matrixTy std::swap(tileShape[0], tileShape[1]); } else if (layout == MatrixLayout::BlockMajorK) { // FIXME: fix for the 2CTA MMA case - // Set shapes based on blocking layout + // Set shapes based on blocking layout. shape = {static_cast(options.mBlockK), static_cast(numTokens), - static_cast(options.mK / options.mBlockK)}; + static_cast(sizeK / options.mBlockK)}; stride = {1, static_cast(options.mBlockK), static_cast(numTokens * options.mBlockK)}; // If blockK > tileK, then the inner most box size will be based on the tile - int32_t const tileBlockK = std::min(options.mBlockK, options.mTileK); - tileShape = {tileBlockK, tileMn, options.mTileK / tileBlockK}; + int32_t const tileBlockK = std::min(options.mBlockK, tileK); + tileShape = {tileBlockK, tileMn, tileK / tileBlockK}; } return std::make_tuple(shape, stride, tileShape); @@ -93,7 +113,8 @@ static auto makeTmaShapeStrideC(GemmOptions const& options) { // The hidden dimension. auto hiddenSize = options.mTransposeMmaOutput ? options.mM : options.mN; // Note that TMA descriptor expects the first dimension's stride to be - // 1, so swap the first two dimension so that the hiddenSize dimension comes first. + // 1, so swap the first two dimension so that the hiddenSize dimension comes + // first. auto shape = std::vector{static_cast(hiddenSize), static_cast(numTokens)}; @@ -107,7 +128,7 @@ static auto makeTmaShapeStrideC(GemmOptions const& options) { // Create the TMA shape/stride for A/B block scaling factors. template static auto makeTmaShapeStrideSfAb(GemmOptions const& options, MatrixType matrixType, - tg::SfLayout layout) { + tg::SfLayout layout, int32_t numEltsPerSf) { // The outer dimension. auto numTokens = matrixType == MatrixType::MatrixA ? options.mM : options.mN; // The inner dimension. @@ -116,27 +137,9 @@ static auto makeTmaShapeStrideSfAb(GemmOptions const& options, MatrixType matrix auto numTokensPerTile = matrixType == MatrixType::MatrixA ? options.mTileM : options.mTileN; // The inner tile dimension. auto hiddenSizePerTile = options.mTileK; - // The dtype of the matrix. - tg::Dtype matrixDtype = matrixType == MatrixType::MatrixA ? options.mDtypeA : options.mDtypeB; - // Number of elements per scaling factor. - int32_t const numEltsPerSf = - (matrixType == MatrixType::MatrixA && options.mSfBlockSizeA.has_value()) - ? options.mSfBlockSizeA.value() - : (tg::dtypeIsBlockFmt(matrixDtype) ? tg::dtypeNumEltsPerSf(matrixDtype) : 32); switch (layout) { case tg::SfLayout::R128c4: { - // The scaling factor tensor packs 128x4 tiles into contiguous 512B blocks. - // The 512B block maps to a 32x16B (32x128b) block in TMEM. - // See https://nvbugspro.nvidia.com/bug/4165523 - // - // Additionally, we have to meet constraints of TMA that the box dimensions are less - // than 256 and boxDim[0] is a multiple of 16B. - // - // The "logical" tensor is: [outer, inner / numEltsPerSf] - // The aforementioned format is: [⌈outer / 128⌉, inner / (4 * numEltsPerSf), 512] - // The shape we use for TMA is: [⌈outer / 128⌉, inner / (4 * numEltsPerSf), 2, 256] - auto shape = std::vector{ 256, 2, static_cast(tg::ceilDiv(hiddenSize, numEltsPerSf * 4)), static_cast(tg::ceilDiv(numTokens, 128))}; @@ -157,20 +160,22 @@ static auto makeTmaShapeStrideSfAb(GemmOptions const& options, MatrixType matrix case tg::SfLayout::R8c4: { // The scaling factor tensor packs 8x4 tiles into contiguous 32B blocks. // - // As the inner dimension (k) is often a multiple of the tile size, we can reshape to use - // fewer read requests, if the tile dimensions allow. It does not reduce the number of - // instructions. + // As the inner dimension (k) is often a multiple of the tile size, we can + // reshape to use fewer read requests, if the tile dimensions allow. It does + // not reduce the number of instructions. // // I.e., let's define r = min(⌈hiddenSizePerTile / (numEltsPerSf * 4)⌉, 8) // // The "logical" tensor is: [outer, inner / numEltsPerSf] // The 8x4 SF layout is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf), 32] - // The TMA tensor shape is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf * r), r * 32] + // The TMA tensor shape is: [⌈outer / 8⌉, inner / (4 * numEltsPerSf * r), r + // * 32] // - // The caveat of NumRepeats>1 is we must pad the hidden dimension of SF to multiples of - // NumRepeats * numEltsPerSf * 4. + // The caveat of NumRepeats>1 is we must pad the hidden dimension of SF to + // multiples of NumRepeats * numEltsPerSf * 4. - // Detect if the supplied factor is power of 2. E.g., 0b0100 and (0b0100 - 1) == 0b0000. + // Detect if the supplied factor is power of 2. E.g., 0b0100 and (0b0100 - + // 1) == 0b0000. int const r = options.mSfReshapeFactor; if (r > 0 && (r & (r - 1)) != 0) { throw std::runtime_error("mSfReshapeFactor must be positive and a power of 2. Found " + @@ -212,17 +217,40 @@ static auto makeTmaShapeStrideSfAb(GemmOptions const& options, MatrixType matrix return std::make_tuple(std::vector{}, std::vector{}, std::vector{}); } +// Create the TMA shape/stride for the sparsity information of A. +template +static auto makeTmaShapeStrideSparsityInfoA(GemmOptions const& options) { + // Tensor dimensions. + auto outerDim = options.mM; + auto innerDim = tg::getNumBytesSparsityInfo(options.mSparsityA, options.mK); + // Tile dimensions. + auto tileOuterDim = options.mTileM; + auto tileInnerDim = tg::getNumBytesSparsityInfo(options.mSparsityA, options.mTileK); + + auto shape = + std::vector{static_cast(innerDim), static_cast(outerDim)}; + + std::vector stride(shape.size()); + stride[0] = 1; + for (size_t i = 1; i < shape.size(); i++) { + stride[i] = shape[i - 1] * stride[i - 1]; + } + + auto tileShapes = + std::vector{static_cast(tileInnerDim), static_cast(tileOuterDim)}; + + return std::make_tuple(shape, stride, tileShapes); +} + // Setup the kernel parameters. template -static KernelParams setKernelParams(GemmOptions_ const& options, void const* ptrA, - void const* ptrSfA, void const* ptrPerTokenSfA, - void const* ptrB, void const* ptrSfB, - void const* ptrPerTokenSfB, void const* ptrBias, void* ptrC, - void* ptrSfC, void* multimemC, float* ptrScaleC, - void* ptrPartialSumsForSplitK, void* ptrTileBars, - void* multimemTileBars, void* ptrCompletionBars, - void* multimemCompletionBars, void* ptrSplitKCompletionBars, - int32_t* ptrNumNonExitingCtas, int rank, int tpGrpSize) { +static KernelParams setKernelParams( + GemmOptions_ const& options, void const* ptrA, void const* ptrSfA, void const* ptrPerTokenSfA, + void const* ptrB, void const* ptrSfB, void const* ptrPerTokenSfB, void const* ptrSparsityInfoA, + void const* ptrBias, void* ptrC, void* ptrSfC, void* multimemC, float* ptrScaleC, + float* ptrScaleAct, void* ptrPartialSumsForSplitK, void* ptrTileBars, void* multimemTileBars, + void* ptrCompletionBars, void* multimemCompletionBars, void* ptrSplitKCompletionBars, + int32_t* ptrNumNonExitingCtas, int rank, int tpGrpSize) { // Is one-shot all-reduce? bool const oneShotAr{options.mAllReduceAlgo == AllReduceAlgo::OneShot}; // Is two-shot all-reduce? @@ -233,27 +261,42 @@ static KernelParams setKernelParams(GemmOptions_ const& options, void const* ptr // Create the return struct. KernelParams params; + // Is A using sparsity? + int32_t const isSparseA = tg::isSparse(options.mSparsityA); + // Do we pad A or B? + bool doPadA = tg::dtypeNeedsPadding(options.mDtypeA, options.mMmaKind, options.mMmaK, isSparseA); + bool doPadB = tg::dtypeNeedsPadding(options.mDtypeB, options.mMmaKind, options.mMmaK, isSparseA); + // Shape/stride for gmem tensor A. auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAb(options, MatrixType::MatrixA); // Build tma descriptor for A. - params.tmaA = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, strideA, - tileShapeA, const_cast(ptrA)); + params.tmaA = gemm::buildNdTmaDescriptor(options.mDtypeA, shapeA, strideA, tileShapeA, + const_cast(ptrA), doPadA, + /*doSwizzle=*/true); // Shape/stride for gmem tensor B. auto [shapeB, strideB, tileShapeB] = makeTmaShapeStrideAb(options, MatrixType::MatrixB); // Build tma descriptor for B. - params.tmaB = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, strideB, - tileShapeB, const_cast(ptrB), - /* swizzle */ !options.mSliceK); + params.tmaB = gemm::buildNdTmaDescriptor(options.mDtypeB, shapeB, strideB, tileShapeB, + const_cast(ptrB), doPadB, + /*doSwizzle=*/!options.mSliceK); if (options.mDtypeA == tg::Dtype::E2m1 || options.mDtypeA == tg::Dtype::MxE2m1 || - options.mDtypeA == tg::Dtype::MxE4m3) { - tg::Dtype const dTypeSfA = - (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + options.mDtypeA == tg::Dtype::MxE4m3 || options.mDtypeA == tg::Dtype::MxInt4) { + tg::Dtype dTypeSfA{}; + if (options.mDtypeA == tg::Dtype::E2m1) { + dTypeSfA = tg::Dtype::E4m3; + } else if (options.mDtypeA == tg::Dtype::MxInt4) { + dTypeSfA = tg::Dtype::Bfloat16; + } else { + dTypeSfA = tg::Dtype::UE8m0; + } + + int32_t const numEltsPerSfA = options.mSfBlockSizeA; // Build TMA descriptor for gmem A block scaling factors. auto [shapeSfA, strideSfA, tileShapesSfA] = - makeTmaShapeStrideSfAb(options, MatrixType::MatrixA, tg::SfLayout::R128c4); + makeTmaShapeStrideSfAb(options, MatrixType::MatrixA, tg::SfLayout::R128c4, numEltsPerSfA); params.tmaSfA = gemm::buildSfTmaDescriptor(dTypeSfA, shapeSfA, strideSfA, tileShapesSfA, const_cast(ptrSfA)); } @@ -263,13 +306,26 @@ static KernelParams setKernelParams(GemmOptions_ const& options, void const* ptr tg::Dtype const dTypeSfB = (options.mDtypeB == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + int32_t const numEltsPerSfB = options.mSfBlockSizeB; + // Build TMA descriptor for gmem B block scaling factors. auto [shapeSfB, strideSfB, tileShapesSfB] = - makeTmaShapeStrideSfAb(options, MatrixType::MatrixB, options.mSfLayoutB); + makeTmaShapeStrideSfAb(options, MatrixType::MatrixB, options.mSfLayoutB, numEltsPerSfB); params.tmaSfB = gemm::buildSfTmaDescriptor(dTypeSfB, shapeSfB, strideSfB, tileShapesSfB, const_cast(ptrSfB)); } + if (isSparseA) { + // Build TMA descriptor for gmem A sparsity. + auto [shapeSparsityInfoA, strideSparsityInfoA, tileShapesSparsityInfoA] = + makeTmaShapeStrideSparsityInfoA(options); + params.tmaSparsityInfoA = + gemm::buildNdTmaDescriptor(tg::Dtype::UInt8, shapeSparsityInfoA, strideSparsityInfoA, + tileShapesSparsityInfoA, const_cast(ptrSparsityInfoA), + /*doPad=*/false, + /*doSwizzle=*/true); + } + if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. auto [shapeC, strideC] = makeTmaShapeStrideC(options); @@ -280,9 +336,9 @@ static KernelParams setKernelParams(GemmOptions_ const& options, void const* ptr auto outputTileN = options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN; - // One-shot performs TMA reduction on multicast mapping of the output buffer directly. - // Two-shot performs TMA store on unicast mapping of the output buffer. The reduction happens - // in the next phase. + // One-shot performs TMA reduction on multicast mapping of the output buffer + // directly. Two-shot performs TMA store on unicast mapping of the output + // buffer. The reduction happens in the next phase. void* ptrTmaC{oneShotAr && multiDevice ? multimemC : ptrC}; auto dtypeC{options.mDtypeC}; // Regardless of output dtype, two-shot all-reduce store partial @@ -292,12 +348,14 @@ static KernelParams setKernelParams(GemmOptions_ const& options, void const* ptr } // Build tma descriptor for C. - params.tmaC = gemm::buildNdTmaDescriptor(dtypeC, tg::MmaKind::Auto, shapeC, strideC, + params.tmaC = gemm::buildNdTmaDescriptor(dtypeC, shapeC, strideC, std::vector{outputTileN, outputTileM}, - const_cast(ptrTmaC)); + const_cast(ptrTmaC), + /*doPad=*/false); } - // Set the dequantization factors for A and B when DeepSeek FP8 recipe is used. + // Set the dequantization factors for A and B when DeepSeek FP8 recipe is + // used. params.ptrSfA = ptrSfA; params.ptrSfB = ptrSfB; @@ -308,12 +366,17 @@ static KernelParams setKernelParams(GemmOptions_ const& options, void const* ptr // Set the bias. params.ptrBias = ptrBias; - // Also set ptrC (it may be used by the NCCL reduction code in "layers/Llama"). + // Also set ptrC (it may be used by the NCCL reduction code in + // "layers/Llama"). params.ptrC = ptrC; + + // The scaling factors for the output tensor and the pre-activation scale. params.ptrScaleC = ptrScaleC; + params.ptrScaleAct = ptrScaleAct; // The block scaling factors of C for MxFp{4,8} and NvFp4 formats. - // (not to be confused with the tensor-level scaling factor stored in ptrScaleC) + // (not to be confused with the tensor-level scaling factor stored in + // ptrScaleC) params.ptrSfC = ptrSfC; params.m = options.mM; diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h index 413a4ccda3..0865b119d2 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,8 +16,10 @@ */ #pragma once -// NOTE: keep this code dependency free. It has to be included by the device code and has to be -// compilable with NVRTC. +#include + +// NOTE: keep this code dependency free. It has to be included by the device +// code and has to be compilable with NVRTC. namespace gemm { @@ -52,9 +54,9 @@ struct KernelParams { // If layoutA is MatrixLayout::BlockMajorK // Logical shape is [K / blockK, M, blockK]. // Logical strides are [M * blockK, blockK, 1]. - // Tile box shape is [tileK / min(blockK, tileK), tileM, min(blockK, tileK)]. - // Tile box strides are [tileM * min(blockK, tileK), min(blockK, tileK), 1]. - // Dtype is set from options.mDtypeA, and blockK is 128B. + // Tile box shape is [tileK / min(blockK, tileK), tileM, min(blockK, + // tileK)]. Tile box strides are [tileM * min(blockK, tileK), min(blockK, + // tileK), 1]. Dtype is set from options.mDtypeA, and blockK is 128B. CUtensorMap tmaA; // TMA descriptor for B. @@ -78,9 +80,9 @@ struct KernelParams { // If layoutB is MatrixLayout::BlockMajorK // Logical shape is [K / blockK, N, blockK]. // Logical strides are [N * blockK, blockK, 1]. - // Tile box shape is [tileK / min(blockK, tileK), tileN, min(blockK, tileK)]. - // Tile box strides are [tileN * min(blockK, tileK), min(blockK, tileK), 1]. - // Dtype is set from options.mDtypeB, and blockK is 128B. + // Tile box shape is [tileK / min(blockK, tileK), tileN, min(blockK, + // tileK)]. Tile box strides are [tileN * min(blockK, tileK), min(blockK, + // tileK), 1]. Dtype is set from options.mDtypeB, and blockK is 128B. CUtensorMap tmaB; // TMA descriptor for C, (when useTmaStore is true) @@ -102,27 +104,26 @@ struct KernelParams { // Dtype is set from options.mDtypeC. CUtensorMap tmaC; - // TMA descriptor for the block scaling factors for A, for MxFp{4,8} and NvFp4 formats. - // Must be setup using gemm::buildSfTmaDescriptor with shapes and strides from - // makeTmaShapeStrideSfAb. - // The layout of scaling factors for A is always R128c4 + // TMA descriptor for the block scaling factors for A, for MxFp{4,8} and NvFp4 + // formats. Must be setup using gemm::buildSfTmaDescriptor with shapes and + // strides from makeTmaShapeStrideSfAb. The layout of scaling factors for A is + // always R128c4 // - // Let P be the number of elements per SF. P=16 for NvFp4, P=32 for Mx formats. - // K must be a multiple of 4P. - // The "logical" shape is: [M, K / P]. + // Let P be the number of elements per SF. P=16 for NvFp4, P=32 for Mx + // formats. K must be a multiple of 4P. The "logical" shape is: [M, K / P]. // The R128c4 layout is: [⌈M / 128⌉, K / P / 4, 512]. // The shape we use for TMA is: [⌈M / 128⌉, K / P / 4, 2, 256]. // // Dtype is Dtype::E4m3 for NvFp4, Dtype::UE8m0 for Mx formats. CUtensorMap tmaSfA; - // TMA descriptor for the block scaling factors for B, for MxFp{4,8} and NvFp4 formats. - // Must be setup using gemm::buildSfTmaDescriptor with shapes and strides from - // makeTmaShapeStrideSfAb. - // The layout of scaling factors for B is controlled by options.mSfLayoutB. + // TMA descriptor for the block scaling factors for B, for MxFp{4,8}, MxInt4 + // and NvFp4 formats. Must be setup using gemm::buildSfTmaDescriptor with + // shapes and strides from makeTmaShapeStrideSfAb. The layout of scaling + // factors for B is controlled by options.mSfLayoutB. // - // Let P be the number of elements per SF. P=16 for NvFp4, P=32 for Mx formats. - // The "logical" shape is: [N, K / P] + // Let P be the number of elements per SF. P=16 for NvFp4, P=32 for Mx + // formats. The "logical" shape is: [N, K / P] // // If the layout is R128c4, // K must be a multiple of 4P. @@ -135,9 +136,28 @@ struct KernelParams { // The shape we use for TMA is: [⌈N / 8⌉, K / P / 4 / r, r * 32] // where r = min(tileK / P / 4, 8) // - // Dtype is Dtype::E4m3 for NvFp4, Dtype::UE8m0 for Mx formats. + // Dtype is Dtype::E4m3 for NvFp4, Dtype::UE8m0 for MxFp{4,8} formats, + // Dtype::Bfloat16 for MxInt4. CUtensorMap tmaSfB; + // TMA descriptor for the sparsity information of A, if structured sparsity is + // used. Must be setup using gemm::buildNdTmaDescriptor with shapes and + // strides from makeTmaShapeStrideSparsityInfoA. + // + // When sparsityA is Any_2_4: + // 2 elements are non-zero in any chunk of 4 elements. + // A 4-bit index indicates the position of the non-zero elements. + // The shape in UInt8 is: [M, K / 8] + // + // When sparsityA is Pairwise_4_8: + // 4 elements are non-zero in any chunk of 8 elements. + // The zero and non-zero elements are grouped in pairs. + // A 4-bit index indicates the position of the non-zero pairs. + // The shape in UInt8 is: [M, K / 16] + // + // Dtype is Dtype::UInt8. + CUtensorMap tmaSparsityInfoA; + // The output matrix C. The data type is controlled by options.mDtypeC. // // When transposeMmaOutput is true, the shape is [N, M]. @@ -175,7 +195,8 @@ struct KernelParams { // The bias is applied before applying the global scaling factor. I.e. // C' = (A * B + bias') * scaleC // scaleC = dequantA * dequantB * quantC - // Thus, the bias' = bias / (dequantA * dequantB), where the bias is the original bias. + // Thus, the bias' = bias / (dequantA * dequantB), where the bias is the + // original bias. // // if BiasType is N, the shape is [N]. // The bias is broadcasted along the M dimension. @@ -189,9 +210,10 @@ struct KernelParams { // The per-token scaling factors from scale A. // // This is used for either: - // * Per-token scaling factor quantization schemes, such as MetaFP8. The dtype is Dtype::Float32 - // * When the routing scales are applied to the input activations (only when output is not - // transposed). The dtype is Dtype::Bfloat16 + // * Per-token scaling factor quantization schemes, such as MetaFP8. The + // dtype is Dtype::Float32 + // * When the routing scales are applied to the input activations (only when + // output is not transposed). The dtype is Dtype::Bfloat16 // // The shape is [M] void const* ptrPerTokenSfA; @@ -199,15 +221,16 @@ struct KernelParams { // The per-token scaling factors from scale B. // // This is used for either: - // * Per-token scaling factor quantization schemes, such as MetaFP8. The dtype is Dtype::Float32 - // * When the routing scales are applied to the input activations (only when output is - // transposed). The dtype is Dtype::Bfloat16 + // * Per-token scaling factor quantization schemes, such as MetaFP8. The + // dtype is Dtype::Float32 + // * When the routing scales are applied to the input activations (only when + // output is transposed). The dtype is Dtype::Bfloat16 // // The shape is [N] void const* ptrPerTokenSfB; - // The scaling factors calculated when quantizing C, for MxFp{4,8} and NvFp4 formats, also - // used for the DeepSeek FP8 recipe. + // The scaling factors calculated when quantizing C, for MxFp{4,8} and NvFp4 + // formats, also used for the DeepSeek FP8 recipe. // // For DeepSeek FP8 recipe: // If transposeMmaOutput is false, shape is [N / 128, M]. @@ -220,11 +243,20 @@ struct KernelParams { // The layout is controlled by options.mSfLayoutC (either R128c4 or R8c4). void* ptrSfC; - // The output tensor scaling factor for MxFp{4,8}, Fp8, NvFp4 and DeepSeek FP8 quantization. - // TensorRT-LLM API requires a scaling factor on the device. - // Shape is [1]. + // The output tensor scaling factor for MxFp{4,8}, Fp8, NvFp4 and DeepSeek FP8 + // quantization. TensorRT-LLM API requires a scaling factor on the device. + // Without non-linear activation, it is typically defined as quantScaleC * + // dequantScaleAb. With non-linear activation, it is typically defined as + // quantScaleC only. Shape is [1]. float const* ptrScaleC; + // The pre-activation scaling factor for MxFp{4,8}, Fp8, NvFp4 and DeepSeek + // FP8 quantization. Only used with non-linear activation functions (e.g., + // GELU, Relu2). Without non-linear activation, this is ignored by the kernel. + // With non-linear activation, it is typically defined as dequantScaleAb. + // Shape is [1]. + float const* ptrScaleAct; + // The M dimension. // It is the total number of tokens if A is the activation matrix. // It is the total number of output channels if A is the weight matrix. @@ -246,20 +278,20 @@ struct KernelParams { int rank; // The number of peer devices in tensor-parallel group. int tpGrpSize; - // Pointer for output with multicast mapping. It is used by the "reduce" op (LDGMC.ADD) of the - // two-shot reduce-scatter phase. - // The shape is [M, N] and the dtype is float. + // Pointer for output with multicast mapping. It is used by the "reduce" op + // (LDGMC.ADD) of the two-shot reduce-scatter phase. The shape is [M, N] and + // the dtype is float. void* multimemC; // The barriers in global memory. // - // The kernel arrives at (with release ordering) the multicast mapping of the barrier to broadcast - // amongst peer devices. It then waits (with acquire ordering) for the unicast mapping of the - // barrier. + // The kernel arrives at (with release ordering) the multicast mapping of the + // barrier to broadcast amongst peer devices. It then waits (with acquire + // ordering) for the unicast mapping of the barrier. // - // Flags in global memory that sync on "entrance" of reduce-scatter phase in two-shot all-reduce. - // The shape is [numTilesM * numTilesN] and the dtype is uint32_t. - // The pointer to the unicast memory created with IpcNvlsHandle. + // Flags in global memory that sync on "entrance" of reduce-scatter phase in + // two-shot all-reduce. The shape is [numTilesM * numTilesN] and the dtype is + // uint32_t. The pointer to the unicast memory created with IpcNvlsHandle. // Must be set to 0 before the kernel launch. void* ptrTileBars; // The shape is [numTilesM * numTilesN] and the dtype is uint32_t. @@ -282,22 +314,24 @@ struct KernelParams { ////////////////////////////////////////////////////////////////////////////////////////////////// // The barriers in global memory for Split-k reduction with exchange in GMEM. - // Each CTAs arrives at the barrier and blockIdx.z == gridDim.Z - 1 waits for the barrier to flip - // to perform a reduction. - // The shape is [numTilesM * numTilesN] and the dtype is uint32_t. - // For DeepSeek FP8 recipe, the shape is [numTilesM * numTilesN * 2]. - // The memory must be set to 0 before the kernel launch. + // Each CTAs arrives at the barrier and blockIdx.z == gridDim.Z - 1 waits for + // the barrier to flip to perform a reduction. The shape is [numTilesM * + // numTilesN] and the dtype is uint32_t. For DeepSeek FP8 recipe, the shape is + // [numTilesM * numTilesN * 2]. The memory must be set to 0 before the kernel + // launch. void* ptrSplitKCompletionBars; // Pointer to the memory holding the partial sums for split-K in GMEM. - // The shape is [numSlicesForSplitK, numSlicesForSliceK, numTilesM * tileM, numTilesN * tileN]. - // The dtype is dtypeAcc, i.e. float. + // The shape is [numSlicesForSplitK, numSlicesForSliceK, numTilesM * tileM, + // numTilesN * tileN]. The dtype is dtypeAcc, i.e. float. void* ptrPartialSumsForSplitK; - // In some cases, some CTAs need to exit early. E.g. when the grid is statically set, but the - // actual workload is decided at runtime. This device pointer maps to the number of non exiting - // CTAs in the X dim of the grid when transposeMmaOutput is false. And the Y dim, otherwise. - // The pointer points to a scalar and the dtype is int32_t. The pointed value must be >= 0. + // In some cases, some CTAs need to exit early. E.g. when the grid is + // statically set, but the actual workload is decided at runtime. This device + // pointer maps to the number of non exiting CTAs in the X dim of the grid + // when transposeMmaOutput is false. And the Y dim, otherwise. The pointer + // points to a scalar and the dtype is int32_t. The pointed value must be >= + // 0. int32_t* ptrNumNonExitingCtas; ////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h index 4ca6af8a4c..549b03fdc3 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,12 +17,14 @@ #pragma once #include +#include #include #include "Enums.h" #include "trtllm/gen/CommonUtils.h" #include "trtllm/gen/DtypeDecl.h" #include "trtllm/gen/MmaDecl.h" +#include "trtllm/gen/SparsityDecl.h" namespace gemm { @@ -58,7 +60,8 @@ class MemAllocatorHelper { // E.g. possible in case of // mNumBytesAndAlignmentPerSmemChunk = {{1, 1}, {1, 1}, {1024, 1}} // mFirstChunkReuse = {false, false, true} - // The last chunk is larger than the first plus second, so total size is 1024. + // The last chunk is larger than the first plus second, so total size is + // 1024. totalSize = paddedSize; } else if (!mFirstChunkReuse[ii]) { totalSize += paddedSize; @@ -129,7 +132,8 @@ class MemAllocatorHelper { // NOTE: be careful and make sure that the memory dependency is clear and // chunks in the beginning of the SMEM can be overwritten. std::vector> mNumBytesAndAlignmentPerSmemChunk; - // Chunk reuse configuration. True at ith position means that ith chunk starts at smemOffset = 0. + // Chunk reuse configuration. True at ith position means that ith chunk starts + // at smemOffset = 0. std::vector mFirstChunkReuse; // Buffer names for inspection purposes. std::vector mSmemChunkNames; @@ -137,7 +141,7 @@ class MemAllocatorHelper { //////////////////////////////////////////////////////////////////////////////////////////////////// -int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind) { +inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind, int mmaK, bool isSparseA) { if (mmaKind == tg::MmaKind::Auto) { throw std::runtime_error("mmaKind != tg::MmaKind::Auto"); } @@ -157,14 +161,19 @@ class KernelTraits { // The constructor. KernelTraits(tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeAcc, - tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, tg::MmaKind mmaKind, int32_t mmaK, - int32_t tileM, int32_t tileN, int32_t tileK, int32_t epilogueTileM, - int32_t epilogueTileN, int32_t numStages, int32_t numStagesMma, + tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, tg::MmaKind mmaKind, + tg::Sparsity sparsityA, int32_t mmaK, int32_t tileM, int32_t tileN, int32_t tileK, + int32_t epilogueTileM, int32_t epilogueTileN, int32_t numEltsPerSfA, + int32_t numEltsPerSfB, int32_t numStages, int32_t numStagesMma, int32_t numSlicesForSplitK, int32_t numSlicesForSliceK, SplitK splitK, bool useTmaStore, bool transposeMmaOutput, AllReduceAlgo allReduceAlgo, + bool fuseUtccpWithUtcmma, bool useMaxTmemOverlap, int32_t numEpilogueWarps, bool usePersistentScheduler, bool useDeepSeekFp8, bool usePerTokenSfA, bool usePerTokenSfB, bool useTwoCtas, BiasType biasType) - : mMmaKind{mmaKind} { + : mMmaKind{mmaKind}, + mFuseUtccpWithUtcmma{fuseUtccpWithUtcmma}, + mUseMaxTmemOverlap{useMaxTmemOverlap}, + mNumEpilogueWarps{numEpilogueWarps} { // // SMEM // @@ -179,10 +188,12 @@ class KernelTraits { // [per-token SF ] (16B aligned) (if needed) // [bias ] (16B aligned) (if needed) // - // SMEM for smemA and smemB might be repurposed and used for gmemC0 and gmemC1: + // SMEM for smemA and smemB might be repurposed and used for gmemC0 and + // gmemC1: // // [..smemA..][..smemB..][..smemBShuffle..] - // [..gmemC0..][..gmemC1..][..rowMax..][..sliceK..][..per-token SF..][..bias..] + // [..gmemC0..][..gmemC1..][..rowMax..][..sliceK..][..per-token + // SF..][..bias..] // if (mMmaKind == tg::MmaKind::Auto) { @@ -194,11 +205,15 @@ class KernelTraits { // Buffer names for inspection purposes. std::vector smemChunkNames; + int const isSparseA = static_cast(tg::isSparse(sparsityA)); + // LoadA { // Number of bytes in load A shared memory. - auto const numSmemBytesLoadA = - numStages * tileM * tileK * getNumSmemBitsPerElt(dtypeA, mMmaKind) / 8 /* bits */; + // If A is sparse, we load only the non-zero elements. + auto const numSmemBytesLoadA = numStages * tileM * (tileK >> isSparseA) * + getNumSmemBitsPerElt(dtypeA, mMmaKind, mmaK, isSparseA) / + 8 /* bits */; // Number of bytes for load A alignment for TMA load. auto const numBytesAlignmentLoadA = 1024; // loadA is already at first chunk. No need to reuse it. @@ -214,7 +229,8 @@ class KernelTraits { { // Number of bytes in load B shared memory. auto const numSmemBytesLoadB = numStages * (useTwoCtas ? tileN / 2 : tileN) * tileK * - getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; + getNumSmemBitsPerElt(dtypeB, mMmaKind, mmaK, isSparseA) / + 8 /* bits */; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. @@ -228,14 +244,15 @@ class KernelTraits { // SmemBShuffle // FIXME: we should be able either: - // - Do modification in-place. For that we need to resolve pipeline dependency between - // smemB -> shuffleSmemB -> mma + // - Do modification in-place. For that we need to resolve pipeline + // dependency between smemB -> shuffleSmemB -> mma // - Do 4 TMA SW32 loads or several LDGSTS loads. { // Number of bytes in save shuffled B in shared memory. auto const numSmemBytesLoadB = numSlicesForSliceK > 1 - ? numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */ + ? numStages * tileN * tileK * + getNumSmemBitsPerElt(dtypeB, mMmaKind, mmaK, isSparseA) / 8 /* bits */ : 0; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; @@ -250,7 +267,8 @@ class KernelTraits { } // GmemC - // FIXME we might need to fix this for GemmGatedAct, it needs less SMEM to store gated output. + // FIXME we might need to fix this for GemmGatedAct, it needs less SMEM to + // store gated output. for (int resIdx = 0; resIdx < 2; ++resIdx) { // Type of the data in the SMEM for GmemC auto dtypeSmemC = dtypeC; @@ -271,6 +289,10 @@ class KernelTraits { extraGmemCMultiplier = 0; } + if (numEpilogueWarps) { + extraGmemCMultiplier *= numEpilogueWarps / 4; + } + // Number of bytes to store the output in smem. auto const numBytesSmemStoreC = usesSmemForGmemC ? extraGmemCMultiplier * epilogueTileM * epilogueTileN * @@ -279,8 +301,10 @@ class KernelTraits { // Number of bytes for store C alignment for TMA store. auto const numBytesAlignmentStoreC = 1024; // gmemC reuses loadAb memory for split-K in DSMEM. - // Epilogue1 does not reuse and continues after the memory allocated Epilogue0 - // NOTE: we can always reuse loadAb SMEM as long as we don't have persistent scheduler. + // Epilogue1 does not reuse and continues after the memory allocated + // Epilogue0 NOTE: we can always reuse loadAb SMEM as long as we don't + // have persistent scheduler. + auto const reuseFirstChunksSmemStoreC = doesSplitKUseDsmem(splitK) && resIdx == 0 && !usePersistentScheduler; @@ -291,6 +315,23 @@ class KernelTraits { firstChunkReuseSmem.emplace_back(reuseFirstChunksSmemStoreC); } + // SmemSparsityInfoA + { + // Number of bytes for sparsity info in SMEM. + auto const numBytesSmemSparsityInfoA = + numStages * tileM * tg::getNumBytesSparsityInfo(sparsityA, tileK); + // Number of bytes alignment for sparsity info in SMEM. + auto const numBytesAlignmentSparsityInfoA = 1024; + // No need to reuse the first chunk. + auto const reuseChunksSmemSparsityInfoA = false; + + // Add info. + smemChunkNames.emplace_back("smemSparsityInfoA"); + numBytesAndAlignmentPerSmemChunk.emplace_back( + std::make_pair(numBytesSmemSparsityInfoA, numBytesAlignmentSparsityInfoA)); + firstChunkReuseSmem.emplace_back(reuseChunksSmemSparsityInfoA); + } + // RowMax { // Number of dqSfsC per CTA. @@ -361,7 +402,6 @@ class KernelTraits { // Per-block absolute maximum for multi-warp reduction. { // Number of bytes: number of epilogue warps * number of tile columns. - // TODO: avoid allocating this memory when it's not needed (it's only for MxFp8 + fusedAct) auto const numBytesSmemBlockAmax = transposeMmaOutput ? 4 * tileN * sizeof(float) : 0; // Number of bytes alignment. auto const numBytesAlignmentBlockAmax = 16; @@ -418,8 +458,11 @@ class KernelTraits { std::vector tmemChunkNames; // Matrix D { + // Two set of TMEM resources for D share epilogueTileN columns, + // | set0:epiTileN0 | set0:epiTileN1/set1:epiTileN0 | set1:epiTileN1 | + auto const numCols = mUseMaxTmemOverlap ? 2 * tileN - epilogueTileN : tileN; // Number of columns for accumulators. - auto const numTmemColsD = numSlicesForSliceK * tileN * numStagesMma * + auto const numTmemColsD = numSlicesForSliceK * numCols * numStagesMma * tg::dtypeGetNumBits(dtypeAcc) / tg::dtypeGetNumBits(tg::Dtype::UInt32); // Number of columns for D alignment. @@ -462,13 +505,17 @@ class KernelTraits { bool const useBlockScalingA = tg::dtypeIsBlockFmt(dtypeMmaA); // Are the block scales constant? bool const useConstSfA = useBlockScalingA && !tg::dtypeIsBlockFmt(dtypeA); + // TMEM cols group size in the K dimension. + int32_t kGroupSize = 4; + // Number of columns per stage. + int32_t const numColsPerStage = + useBlockScalingA ? ((tileK / (kGroupSize * numEltsPerSfA)) * + tg::getTmemColStridePerGroup(tileM, mmaK, kGroupSize)) + : 0; // Number of columns for scaling factors of A. auto const numTmemColsSfA = - useConstSfA - ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK), 4) - : (useBlockScalingA - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * numStages - : 0); + useConstSfA ? tg::roundUp(numColsPerStage, 4) + : (numColsPerStage * (mFuseUtccpWithUtcmma ? 1 : numStages)); // Number of columns for Sf alignment. auto const numColsAlignmentSfA = 4; // No need to reuse TMEM. @@ -487,13 +534,17 @@ class KernelTraits { bool const useBlockScalingB = tg::dtypeIsBlockFmt(dtypeMmaB); // Are the block scales constant? bool const useConstSfB = useBlockScalingB && !tg::dtypeIsBlockFmt(dtypeB); + // TMEM cols group size in the K dimension. + int32_t kGroupSize = 4; + // Number of columns per stage. + int32_t const numColsPerStage = + useBlockScalingB ? ((tileK / (kGroupSize * numEltsPerSfB)) * + tg::getTmemColStridePerGroup(tileN, mmaK, kGroupSize)) + : 0; // Number of columns for scaling factors of B. auto const numTmemColsSfB = - useConstSfB - ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK), 4) - : (useBlockScalingB - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * numStages - : 0); + useConstSfB ? tg::roundUp(numColsPerStage, 4) + : (numColsPerStage * (mFuseUtccpWithUtcmma ? 1 : numStages)); // Number of columns for Sf alignment. auto const numColsAlignmentSfB = 4; // No need to reuse TMEM. @@ -506,6 +557,24 @@ class KernelTraits { firstChunkReuseTmem.emplace_back(reuseChunksTmemSfB); } + // Sparsity info for A + { + // Number of columns for the sparsity info for A (note: for Dense, this + // is 0). + auto const numTmemColsSparsityInfoA = + numStages * tg::getNumBytesSparsityInfo(sparsityA, tileK) / 4 /* bytes */; + // Number of columns for Sf alignment. + auto const numColsAlignmentSparsityInfoA = 2; + // No need to reuse TMEM. + auto const reuseChunksTmemSparsityInfoA = false; + + // Add info. + tmemChunkNames.emplace_back("tmemSparsityInfoA"); + numBytesAndAlignmentPerTmemChunk.emplace_back( + std::make_pair(numTmemColsSparsityInfoA, numColsAlignmentSparsityInfoA)); + firstChunkReuseTmem.emplace_back(reuseChunksTmemSparsityInfoA); + } + // Create TMEM helper object. mTmemAllocatorHelper = MemAllocatorHelper(numBytesAndAlignmentPerTmemChunk, firstChunkReuseTmem, tmemChunkNames); @@ -514,7 +583,13 @@ class KernelTraits { public: // The MMA kind. - tg::MmaKind mMmaKind; + tg::MmaKind mMmaKind{}; + // Whether fuse Utccp into the MMA task. + bool mFuseUtccpWithUtcmma{}; + // Whether use the max TMEM overlap trick. + bool mUseMaxTmemOverlap{}; + // The number of epilogue warps. + int32_t mNumEpilogueWarps{}; // Helper for SMEM allocation. MemAllocatorHelper mSmemAllocatorHelper; // Helper for TMEM allocation. @@ -603,6 +678,12 @@ inline int32_t getSmemOffsetConstSfBuf(KernelTraits traits) { //////////////////////////////////////////////////////////////////////////////////////////////////// +inline int32_t getSmemOffsetSparsityInfoA(KernelTraits traits) { + return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemSparsityInfoA"); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline int32_t isSmemAbRepurposedToGmemC(KernelTraits traits, int resIdx = 0) { return traits.mSmemAllocatorHelper.getFirstChunkReuseFlagByName("smemGmemC" + std::to_string(resIdx)); @@ -638,6 +719,12 @@ inline int32_t getTmemOffsetSfB(KernelTraits traits) { //////////////////////////////////////////////////////////////////////////////////////////////////// +inline int32_t getTmemOffsetSparsityInfoA(KernelTraits traits) { + return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemSparsityInfoA"); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gemm } // namespace gemm diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h index 0722a42a4e..744755fb45 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,8 +23,6 @@ #ifdef TLLM_ENABLE_CUDA #include -#include -#include #endif namespace gemm { @@ -39,17 +37,17 @@ namespace tg = trtllm::gen; #ifdef TLLM_ENABLE_CUDA -inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, - std::vector const& shapes, +inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, std::vector const& shapes, std::vector const& strides, std::vector const& tileShapes, void* gmemAddr, - bool doSwizzle = true) { + bool doPad, bool doSwizzle = true) { // The multiplication factor of the data padding in SMEM. int32_t padMultiplier = 1; CUtensorMap desc{}; // The data type. CUtensorMapDataType tmaDataFormat{CU_TENSOR_MAP_DATA_TYPE_FLOAT32}; - if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::MxE4m3 || dtype == tg::Dtype::UE8m0) { + if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::MxE4m3 || dtype == tg::Dtype::UE8m0 || + dtype == tg::Dtype::UInt8) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if (dtype == tg::Dtype::Fp16) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; @@ -57,13 +55,11 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else if (dtype == tg::Dtype::E2m1) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; - } else if (dtype == tg::Dtype::MxE2m1) { - if (mmaKind == tg::MmaKind::MxFp8Fp6Fp4) { + } else if (dtype == tg::Dtype::MxE2m1 || dtype == tg::Dtype::MxInt4) { + if (doPad) { padMultiplier = 2; tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; } else { - // Note: this is used with the MMA kind MxFp4NvFp4 and also when casting to a higher-precision - // type such as Bfloat16 before the MMA. tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } } else if (dtype == tg::Dtype::Fp32) { @@ -75,8 +71,8 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, // The swizzle type. CUtensorMapSwizzle swizzleType{CU_TENSOR_MAP_SWIZZLE_NONE}; - int32_t fastestDimTileSizeBytes = - (tileShapes[0] * tg::dtypeGetNumBits(dtype) * padMultiplier) / /* bits */ 8; + int32_t fastestDimTileSizeBytes = (tileShapes[0] * tg::dtypeGetNumBits(dtype) * padMultiplier) / + /* bits */ 8; if (doSwizzle) { if ((fastestDimTileSizeBytes % 128) == 0) { swizzleType = CU_TENSOR_MAP_SWIZZLE_128B; @@ -84,9 +80,9 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, swizzleType = CU_TENSOR_MAP_SWIZZLE_64B; } else if ((fastestDimTileSizeBytes % 32) == 0) { swizzleType = CU_TENSOR_MAP_SWIZZLE_32B; - // This path is only for the scaling factors. } else if ((fastestDimTileSizeBytes % 16) == 0 && - (dtype == tg::Dtype::UE8m0 || dtype == tg::Dtype::E4m3)) { + (dtype == tg::Dtype::UE8m0 || dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::E2m1 || + dtype == tg::Dtype::UInt8)) { swizzleType = CU_TENSOR_MAP_SWIZZLE_NONE; } else { std::cerr << "buildNdTmaDescriptor: unexpected fastestDimTileSizeBytes " @@ -100,8 +96,8 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, // Check shape must be in range [1, 2^32] int32_t dim = shapes.size(); - // Expect 2 dimensions for regular gemm, 3 dimensions for batched gemm or blocked layout, and 4 - // dimensions for batched gemm with blocked layout. + // Expect 2 dimensions for regular gemm, 3 dimensions for batched gemm or + // blocked layout, and 4 dimensions for batched gemm with blocked layout. assert(dim == 2 || dim == 3 || dim == 4); // Check shape range. for (int32_t ii = 0; ii < dim; ++ii) { @@ -114,7 +110,8 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, assert(strides[0] == 1); // Build strides in bytes. - // cuTensorMapEncodeTiled ignores the stride of the first dimension (implicitly 1). + // cuTensorMapEncodeTiled ignores the stride of the first dimension + // (implicitly 1). std::vector stridesInBytes(dim - 1); for (int32_t ii = 0; ii < dim - 1; ++ii) { stridesInBytes[ii] = (strides[ii + 1] * tg::dtypeGetNumBits(dtype)) / /* bits */ 8; @@ -124,7 +121,8 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, auto const numEltsPerUInt32 = 4 * /* bits */ 8 / (tg::dtypeGetNumBits(dtype) * padMultiplier); // The number of elements in 128B. auto const numEltsIn128B = numEltsPerUInt32 /*4B*/ * 32; - // The number of tile K hidden size (per token) in each block of shared memory. + // The number of tile K hidden size (per token) in each block of shared + // memory. auto const numEltsInClampedFastestTileSize = std::min(numEltsIn128B, tileShapes[0]); // Build box dim array. If tileShapes is smaller than dim, just fill with 1s. @@ -197,9 +195,11 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c std::vector const& strides, const std::vector& tileShapes, void* gmemAddr) { CUtensorMap desc{}; - CUtensorMapDataType tmaDataFormat; + CUtensorMapDataType tmaDataFormat{}; if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::UE8m0) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if (dtype == tg::Dtype::Bfloat16) { + tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else { std::cerr << "buildSfTmaDescriptor: unexpected dtype " << tg::dtypeToString(dtype) << std::endl; assert(false); @@ -224,7 +224,8 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c assert(strides[0] == 1); // Build strides in bytes. - // cuTensorMapEncodeTiled ignores the stride of the first dimension (implicitly 1). + // cuTensorMapEncodeTiled ignores the stride of the first dimension + // (implicitly 1). std::vector stridesInBytes(dim - 1); for (int32_t ii = 0; ii < dim - 1; ++ii) { stridesInBytes[ii] = (strides[ii + 1] * tg::dtypeGetNumBits(dtype)) / /* bits */ 8; diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h index 331f6cd285..b40cd58ad2 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,12 +26,13 @@ namespace gen { // // TMA OOB optimization constants. // -// CUDA Programming Guide states that "globalDim must be non-zero and less than or equal to 2^32". -// In practice, the kernel acts funny with TMA shape of 2^32 so we use 2^31. +// CUDA Programming Guide states that "globalDim must be non-zero and less than +// or equal to 2^32". In practice, the kernel acts funny with TMA shape of 2^32 +// so we use 2^31. constexpr unsigned long TmaDimMax = 1UL << 31; -// Chosen so that LargeN * XLargeN * sizeof(dtype) >= 2^64 which causes overflow and effectively -// becomes 0. As sizeof(dtype) can be as small as 0.5B, we choose LargeN = 2^30 and XLargeN = 2^35 -// so overflow can happen. +// Chosen so that LargeN * XLargeN * sizeof(dtype) >= 2^64 which causes overflow +// and effectively becomes 0. As sizeof(dtype) can be as small as 0.5B, we +// choose LargeN = 2^30 and XLargeN = 2^35 so overflow can happen. constexpr unsigned long LargeN = 1UL << 30; // Used in TMA stride. Should be less than 2^40. constexpr unsigned long XLargeN = 1UL << 35; diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaArchDecl.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaArchDecl.h new file mode 100644 index 0000000000..19dc8c6dea --- /dev/null +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaArchDecl.h @@ -0,0 +1,96 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Be careful when modifying this file as it is included by the generated +// kernels. For example, do not add TLLM_CHECK_* constructs in this file. +// Thanks! +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace gemm { + +namespace trtllm { +namespace gen { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class CudaArch { + // Hopper + Sm90a = 0, + // Blackwell + Sm100a, + // Blackwell-family + Sm100f, + // Blackwell Ultra + Sm103a, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool isArchHopper(CudaArch cudaArch) { return cudaArch == CudaArch::Sm90a; } + +inline bool isArchBlackwell(CudaArch cudaArch) { + return cudaArch == CudaArch::Sm100a || cudaArch == CudaArch::Sm100f || + cudaArch == CudaArch::Sm103a; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::string cudaArchToString(CudaArch cudaArch, bool isFull = true) { + switch (cudaArch) { + case CudaArch::Sm90a: + return isFull ? "90a" : "90"; + case CudaArch::Sm100a: + return isFull ? "100a" : "100"; + case CudaArch::Sm100f: + return isFull ? "100f" : "100"; + case CudaArch::Sm103a: + return isFull ? "103a" : "103"; + default: + assert(false); + return ""; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline CudaArch stringToCudaArch(std::string const& str) { + if (str == "90a") { + return CudaArch::Sm90a; + } else if (str == "100a") { + return CudaArch::Sm100a; + } else if (str == "100f") { + return CudaArch::Sm100f; + } else if (str == "103a") { + return CudaArch::Sm103a; + } else { + assert(false); + return CudaArch::Sm100a; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gen +} // namespace trtllm +} // namespace gemm diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaKernelLauncher.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaKernelLauncher.h index 92c88dda16..7e1fa9d1ed 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaKernelLauncher.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaKernelLauncher.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -36,8 +36,8 @@ inline CUresult launchKernel(void* kernelParams, void* cudaStream, int32_t smemS bool enablesPdl) { // Make sure we can launch with that much shared memory. if (smemSize > 48 * 1024) { - CUresult result = - cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smemSize); + CUresult result; + result = cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smemSize); if (result != CUDA_SUCCESS) { return result; } diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h index 0e087769f0..88e6c96626 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,8 +28,9 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// // -// Be careful when modifying this file as it is included by the generated kernels. For example, do -// not add TLLM_CHECK_* constructs in this file. Thanks! +// Be careful when modifying this file as it is included by the generated +// kernels. For example, do not add TLLM_CHECK_* constructs in this file. +// Thanks! // //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -44,9 +45,9 @@ enum class Dtype : uint32_t { // We use the following encoding for the types: // -// Byte 0: Identifier for the type (going from 0 to the number of data types - 1, -// Byte 1: Number of bits in the type, -// Byte 2: Bit 0: Is it an integer? 0x1 if true, 0x0 otherwise; +// Byte 0: Identifier for the type (going from 0 to the number of data types - +// 1, Byte 1: Number of bits in the type, Byte 2: Bit 0: Is it an integer? 0x1 +// if true, 0x0 otherwise; // Bit 4: is it signed? 0x1 if true, 0x0 otherwise. // Byte 3: Is it a block format? 0x1 if true, 0x0 otherwise. @@ -70,13 +71,14 @@ enum class Dtype : uint32_t { Int64 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 1u, /*int*/ 1u, /*bits*/ 64u, /*uid*/ 11u), MxE2m1 = TLLM_ENCODE_DTYPE(/*block*/ 1u, /*signed*/ 1u, /*int*/ 0u, /*bits*/ 4u, /*uid*/ 12u), MxE4m3 = TLLM_ENCODE_DTYPE(/*block*/ 1u, /*signed*/ 1u, /*int*/ 0u, /*bits*/ 8u, /*uid*/ 13u), - UE8m0 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 0u, /*bits*/ 8u, /*uid*/ 14u), - UInt8 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 8u, /*uid*/ 15u), - UInt16 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 16u, /*uid*/ 16u), - UInt32 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 32u, /*uid*/ 17u), - UInt64 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 64u, /*uid*/ 18u), - UInt128 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 128u, /*uid*/ 19u), - Void = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 1u, /*int*/ 0u, /*bits*/ 0u, /*uid*/ 20u), + MxInt4 = TLLM_ENCODE_DTYPE(/*block*/ 1u, /*signed*/ 1u, /*int*/ 1u, /*bits*/ 4u, /*uid*/ 14u), + UE8m0 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 0u, /*bits*/ 8u, /*uid*/ 15u), + UInt8 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 8u, /*uid*/ 16u), + UInt16 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 16u, /*uid*/ 17u), + UInt32 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 32u, /*uid*/ 18u), + UInt64 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 64u, /*uid*/ 19u), + UInt128 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 128u, /*uid*/ 20u), + Void = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 1u, /*int*/ 0u, /*bits*/ 0u, /*uid*/ 21u), // clang-format on #undef TLLM_ENCODE_DTYPE @@ -160,6 +162,8 @@ inline std::string dtypeToString(Dtype dtype) { return "MxE4m3"; case Dtype::MxE2m1: return "MxE2m1"; + case Dtype::MxInt4: + return "MxInt4"; case Dtype::UE8m0: return "UE8m0"; case Dtype::UInt8: @@ -195,13 +199,16 @@ inline Dtype dtypeEltType(Dtype dtype) { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline int dtypeNumEltsPerSf(Dtype dtype) { +// Note: the block size from the options should be used instead. +// TODO: remove this function? +inline int dtypeNumEltsPerSf(Dtype dtype, bool useSparsity = false) { switch (dtype) { case Dtype::E2m1: - return 16; + return useSparsity ? 32 : 16; case Dtype::MxE2m1: case Dtype::MxE4m3: - return 32; + case Dtype::MxInt4: + return useSparsity ? 64 : 32; default: assert(false); return -1; @@ -218,6 +225,8 @@ inline Dtype dtypeGetBlockSfType(Dtype dtype) { case Dtype::MxE2m1: case Dtype::MxE4m3: return Dtype::UE8m0; + case Dtype::MxInt4: + return Dtype::Bfloat16; default: assert(false); return Dtype::Void; @@ -253,7 +262,8 @@ inline MmaKind dtypeGetMmaKind(Dtype dtypeA, Dtype dtypeB) { return MmaKind::Fp8Fp6Fp4; } - // At this point we know that both dtypes are Mx types and not both MxE2m1 at the same time. + // At this point we know that both dtypes are Mx types and not both MxE2m1 at + // the same time. if ((dtypeEltA == Dtype::E4m3 || dtypeEltA == Dtype::E5m2 || dtypeEltA == Dtype::E2m3 || dtypeEltA == Dtype::E3m2 || dtypeEltA == Dtype::E2m1) && (dtypeEltB == Dtype::E4m3 || dtypeEltB == Dtype::E5m2 || dtypeEltB == Dtype::E2m3 || @@ -265,6 +275,14 @@ inline MmaKind dtypeGetMmaKind(Dtype dtypeA, Dtype dtypeB) { //////////////////////////////////////////////////////////////////////////////////////////////////// +inline bool dtypeNeedsPadding(Dtype dtype, MmaKind mmaKind, [[maybe_unused]] int mmaK, + [[maybe_unused]] bool isSparseA) { + bool needsPadding = mmaKind == MmaKind::MxFp8Fp6Fp4 && dtype == Dtype::MxE2m1; + return needsPadding; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gen } // namespace trtllm diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.h index ab030bec99..0b24a533e1 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -43,7 +43,8 @@ enum class MmaKind : uint32_t { // or dtypeA = dtypeB = Bfloat16 and dtypeD = [Fp32] // Corresponds to the kind::f16 of tcgen05.mma. Fp16 = 1, - // Supports dtypeA/B = [E4m3, E5m2, E2m3, E3m2, E2m1] and dtypeD = [Fp16, Fp32] + // Supports dtypeA/B = [E4m3, E5m2, E2m3, E3m2, E2m1] and dtypeD = [Fp16, + // Fp32] // Corresponds to the kind::f8f6f4 of tcgen05.mma. Fp8Fp6Fp4 = 2, // Supports dtypeA = dtypeB = [Int8, Uint8] and dtypeD = [Int32] @@ -95,11 +96,13 @@ inline std::string mmaKindToString(MmaKind mmaKind) { //////////////////////////////////////////////////////////////////////////////////////////////////// -// function to get the TMEM column stride per group (i.e., 64 K elements) -inline int32_t getTmemColStridePerGroup(int32_t tileMn, int32_t mmaK) { - // Calculate the stride of TMEM column for every 64 elements in the K dimension - int32_t div = 2 * ceilDiv(tileMn, 64); - return mmaK == 96 ? std::max(4, div) : div; +// Get the TMEM column stride per group (i.e. kGroupSize * blockSize K elements) +inline int32_t getTmemColStridePerGroup(int32_t tileMn, int32_t mmaK, int32_t kGroupSize) { + int32_t colStride = 2 * ceilDiv(tileMn, 64); + if (mmaK == 96) { + colStride = std::max(4, colStride); + } + return colStride; } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.h index 25adab8902..423f9d4223 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,8 +21,9 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// // -// Be careful when modifying this file as it is included by the generated kernels. For example, do -// not add TLLM_CHECK_* constructs in this file. Thanks! +// Be careful when modifying this file as it is included by the generated +// kernels. For example, do not add TLLM_CHECK_* constructs in this file. +// Thanks! // //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -33,21 +34,25 @@ namespace gen { //////////////////////////////////////////////////////////////////////////////////////////////////// -// This enumeration defines layouts for storing scale factors for FP4, FP6, and FP8 formats. +// This enumeration defines layouts for storing scale factors for FP4, FP6, and +// FP8 formats. enum class SfLayout { // Scale factors are stored in the same order as the associated matrix. - // I.e., the SF buffer is a tensor [m, ⌈n/b⌉], where m, n, and b are respectively the number of + // I.e., the SF buffer is a tensor [m, ⌈n/b⌉], where m, n, and b are + // respectively the number of // rows, columns and the block size. // The SF for the element (i, j) is stored at (i, j/b). Linear = 0, - // A tile of 8x4 is stored contiguously. The order of elements inside the tile, and the order + // A tile of 8x4 is stored contiguously. The order of elements inside the + // tile, and the order // of tiles, are both row-major. // I.e., the SF buffer is a tensor [⌈m/8⌉, ⌈n/b/4⌉, 8, 4]. // The SF for the element (i, j) is stored at (i/8, j/b/4, i%8, (j/b)%4). R8c4, - // A tile of 8x16 is stored contiguously. The order of elements inside the tile, and the order + // A tile of 8x16 is stored contiguously. The order of elements inside the + // tile, and the order // of tiles, are both row-major. // I.e., the SF buffer is a tensor [⌈m/8⌉, ⌈n/b/16⌉, 8, 16]. // The SF for the element (i, j) is stored at (i/8, j/b/16, i%8, (j/b)%16). @@ -57,16 +62,16 @@ enum class SfLayout { // addition to the above requirements it requires n to be a multiple of 256. R8c16, - // A tile of 128x4 is stored contiguously. Rows 0-31, 32-63, 64-95 and 96-127 are interleaved + // A tile of 128x4 is stored contiguously. Rows 0-31, 32-63, 64-95 and 96-127 + // are interleaved // as illustrated below: // | 0,0 | 0,1 | 0,2 | 0,3 | 32,0 | 32,1 | 32,2 | 32,3 | ... | 96,3 | // | 1,0 | 1,1 | 1,2 | 1,3 | 33,0 | 33,1 | 33,2 | 33,3 | ... | 97,3 | // | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | // | 31,0 | 31,1 | 31,2 | 31,3 | 63,0 | 63,1 | 63,2 | 63,3 | ... | 127,3 | - // See https://nvbugspro.nvidia.com/bug/4165523 - // // I.e., the SF buffer is a tensor [⌈m/128⌉, ⌈n/b/4⌉, 32, 4, 4] - // The SF for the element (i, j) is stored at (i/128, j/b/4, i%32, (i%128)/32, (j/b)%4). + // The SF for the element (i, j) is stored at (i/128, j/b/4, i%32, (i%128)/32, + // (j/b)%4). R128c4, }; diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SparsityDecl.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SparsityDecl.h new file mode 100644 index 0000000000..645d348be3 --- /dev/null +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SparsityDecl.h @@ -0,0 +1,134 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Be careful when modifying this file as it is included by the generated kernels. For example, do +// not add TLLM_CHECK_* constructs in this file. Thanks! +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace gemm { + +namespace trtllm { +namespace gen { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// This enumeration defines structured sparsity modes. Please refer to the PTX ISA for more details. +enum class Sparsity { + // No sparsity. + Dense, + + // For each chunk of 2 elements, 1 is non-zero. Only non-zero elements are stored. + // A 4-bit index is used to indicate the position of the non-zero element. + // The index may only take the value 0b1110 or 0b0100, other values are undefined behavior. + // + // 0b1110: 0b0100: + // |------ a ------|------ 0 ------| |------ 0 ------|------ a ------| + // | 11 | 10 | 01 | 00 | | 11 | 10 | 01 | 00 | + Any_1_2, + + // For each chunk of 4 elements, 2 are non-zero. Only non-zero elements are stored. + // A 4-bit index is used to indicate the position of the non-zero elements. + // Meaningful values are: 0b0100, 0b1000, 0b1100, 0b1001, 0b1101, 0b1110. + // Most other values are undefined behavior. + // + // E.g. 0b1100 corresponds to: + // |-- b --|-- 0 --|-- 0 --|-- a --| + // | 11 | 10 | 01 | 00 | + Any_2_4, + + // For each chunk of 8 elements, 4 are non-zero. Only non-zero elements are stored. + // Further, the zero and non-zero elements are grouped in pairs. + // A 4-bit index is used to indicate the position of the non-zero elements. + // Meaningful values are: 0b0100, 0b1000, 0b1100, 0b1001, 0b1101, 0b1110. + // Most other values are undefined behavior. + // + // E.g. 0b1100 corresponds to: + // | d | c | 0 | 0 | 0 | 0 | b | a | + // | 11 | 10 | 01 | 00 | + Pairwise_4_8, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool isSparse(Sparsity sparsity) { return sparsity != Sparsity::Dense; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::string sparsityToString(Sparsity sparsity) { + switch (sparsity) { + case Sparsity::Dense: + return "dense"; + case Sparsity::Any_1_2: + return "1:2"; + case Sparsity::Any_2_4: + return "2:4"; + case Sparsity::Pairwise_4_8: + return "4:8"; + default: + assert(false); + return "Unsupported sparsity"; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Size of a sparsity chunk, for sparse modes. +inline int32_t getSparsityChunkSize(Sparsity sparsity) { + switch (sparsity) { + case Sparsity::Any_1_2: + return 2; + case Sparsity::Any_2_4: + return 4; + case Sparsity::Pairwise_4_8: + return 8; + case Sparsity::Dense: + default: + assert(false); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Number of bytes needed to store the sparsity information. +inline size_t getNumBytesSparsityInfo(Sparsity sparsity, size_t numElts) { + switch (sparsity) { + case Sparsity::Dense: + return 0; + case Sparsity::Any_1_2: + case Sparsity::Any_2_4: + case Sparsity::Pairwise_4_8: + return numElts / getSparsityChunkSize(sparsity) * 4 /*bits*/ / 8; + default: + assert(false); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gen +} // namespace trtllm + +} // namespace gemm diff --git a/tests/attention/test_xqa_batch_decode.py b/tests/attention/test_xqa_batch_decode.py index 19a218826d..27fff7846e 100644 --- a/tests/attention/test_xqa_batch_decode.py +++ b/tests/attention/test_xqa_batch_decode.py @@ -3,9 +3,9 @@ from tests.test_helpers.sink_attention_reference import sink_attention_unified import flashinfer +from flashinfer import SfLayout from flashinfer.utils import get_compute_capability from flashinfer.fp4_quantization import ( - SfLayout, nvfp4_quantize, e2m1_and_ufp8sf_scale_to_float, ) diff --git a/tests/gemm/test_mm_mxfp8.py b/tests/gemm/test_mm_mxfp8.py index ea87645e1d..12e4384962 100644 --- a/tests/gemm/test_mm_mxfp8.py +++ b/tests/gemm/test_mm_mxfp8.py @@ -2,7 +2,13 @@ import torch import torch.nn.functional as F -from flashinfer import autotune, mm_mxfp8 +from flashinfer import ( + autotune, + mm_mxfp8, + SfLayout, + shuffle_matrix_a, + shuffle_matrix_sf_a, +) from flashinfer.fp8_quantization import mxfp8_quantize from flashinfer.utils import get_compute_capability @@ -75,8 +81,19 @@ def _run_mm_mxfp8( backend, auto_tuning, provide_out, + use_8x4_sf_layout_for_a=False, ): _skip_if_unsupported(backend) + if backend == "trtllm": + if not is_sf_swizzled_layout: + pytest.skip("trtllm must have swizzled scales") + if k % 256 != 0: + pytest.skip("trtllm does not support non-multiple of 256") + if out_dtype != torch.bfloat16: + pytest.skip("trtllm does not support non-bfloat16 output") + if backend == "cutlass": + if is_sf_swizzled_layout and use_8x4_sf_layout_for_a: + pytest.skip("cutlass doesn't support 8x4 swizzle layout") if backend == "cute-dsl" and not is_sf_swizzled_layout: pytest.skip( "cute-dsl mm_mxfp8 currently supports only swizzled 1D scale layout." @@ -85,9 +102,23 @@ def _run_mm_mxfp8( input = torch.randn([m, k], device="cuda", dtype=input_dtype) mat2 = torch.randn([n, k], device="cuda", dtype=input_dtype) + if is_sf_swizzled_layout: + sflayout_a = SfLayout.layout_128x4 + sflayout_b = SfLayout.layout_128x4 + else: + sflayout_a = SfLayout.layout_linear + sflayout_b = SfLayout.layout_linear + if is_sf_swizzled_layout and use_8x4_sf_layout_for_a: + sflayout_a = SfLayout.layout_8x4 + input_mxfp8, mat2_mxfp8, input_descale, mat2_descale = _prepare_mxfp8_tensors( - input, mat2, is_sf_swizzled_layout + input, + mat2, + sflayout_a, + sflayout_b, + backend, ) + reference = torch.mm(input, mat2.T) res = torch.empty([m, n], device="cuda", dtype=out_dtype) if provide_out else None @@ -101,6 +132,7 @@ def _run_mm_mxfp8( out=res, out_dtype=out_dtype, backend=backend, + use_8x4_sf_layout=use_8x4_sf_layout_for_a, ) assert res.shape == (m, n) @@ -111,27 +143,38 @@ def _run_mm_mxfp8( _assert_cosine_similarity(reference, res, is_sf_swizzled_layout) -def _prepare_descales(input_scale, weight_scale, m, n, k, is_sf_swizzled_layout): - if is_sf_swizzled_layout: - return input_scale, weight_scale - input_descale = input_scale.view(m, k // 32) - weight_descale = weight_scale.view(n, k // 32).t() - return input_descale, weight_descale - - -def _prepare_mxfp8_tensors(input_bf16, weight_bf16, is_sf_swizzled_layout): +def _prepare_mxfp8_tensors( + input_bf16, + weight_bf16, + sf_layout_input: SfLayout, + sf_layout_weight: SfLayout, + backend="cutlass", +): m, k = input_bf16.shape n = weight_bf16.shape[0] input_mxfp8, input_scale = mxfp8_quantize( - input_bf16, is_sf_swizzled_layout=is_sf_swizzled_layout + input_bf16, sf_swizzle_layout=sf_layout_input ) weight_mxfp8, weight_scale = mxfp8_quantize( - weight_bf16, is_sf_swizzled_layout=is_sf_swizzled_layout + weight_bf16, + sf_swizzle_layout=SfLayout.layout_linear + if backend == "trtllm" + else sf_layout_weight, ) - input_descale, weight_descale = _prepare_descales( - input_scale, weight_scale, m, n, k, is_sf_swizzled_layout - ) - return input_mxfp8, weight_mxfp8, input_descale, weight_descale + if backend == "trtllm": + assert sf_layout_weight == SfLayout.layout_128x4, ( + "shuffle_matrix_sf_a only supports 128x4 swizzling now" + ) + # the shuffle_matrix_sf_a expects linear layout and will swizzle the scales to 128x4 afterwards + weight_mxfp8 = shuffle_matrix_a(weight_mxfp8, 128).reshape(n, k) + weight_scale = shuffle_matrix_sf_a( + weight_scale.reshape(n, k // 32), 128, num_elts_per_sf=32 + ).reshape(-1) + if sf_layout_input == SfLayout.layout_linear: + input_scale = input_scale.view(m, k // 32) + if sf_layout_weight == SfLayout.layout_linear: + weight_scale = weight_scale.view(n, k // 32).t() + return input_mxfp8, weight_mxfp8, input_scale, weight_scale @pytest.mark.parametrize("m", [128, 256, 512, 1024]) @@ -140,7 +183,7 @@ def _prepare_mxfp8_tensors(input_bf16, weight_bf16, is_sf_swizzled_layout): @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("backend", ["cutlass", "cute-dsl"]) +@pytest.mark.parametrize("backend", ["cutlass", "cute-dsl", "trtllm"]) @pytest.mark.parametrize("auto_tuning", [True, False]) def test_mm_mxfp8( m, n, k, input_dtype, is_sf_swizzled_layout, out_dtype, backend, auto_tuning @@ -155,6 +198,7 @@ def test_mm_mxfp8( backend, auto_tuning, provide_out=True, + use_8x4_sf_layout_for_a=backend == "trtllm", ) @@ -164,7 +208,7 @@ def test_mm_mxfp8( @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16]) -@pytest.mark.parametrize("backend", ["cutlass", "cute-dsl", "auto"]) +@pytest.mark.parametrize("backend", ["cutlass", "cute-dsl", "trtllm", "auto"]) def test_mm_mxfp8_large_dimensions( m, n, k, input_dtype, is_sf_swizzled_layout, out_dtype, backend ): @@ -178,6 +222,7 @@ def test_mm_mxfp8_large_dimensions( backend, auto_tuning=False, provide_out=True, + use_8x4_sf_layout_for_a=backend == "trtllm", ) @@ -265,7 +310,11 @@ def test_mm_mxfp8_find_minimum_cosine_similarity(is_sf_swizzled_layout): mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) * value_scale input_mxfp8, mat2_mxfp8, input_descale, mat2_descale = _prepare_mxfp8_tensors( - input_data, mat2, is_sf_swizzled_layout + input_data, + mat2, + SfLayout.layout_128x4 if is_sf_swizzled_layout else SfLayout.layout_linear, + SfLayout.layout_128x4 if is_sf_swizzled_layout else SfLayout.layout_linear, + backend="cutlass", ) reference = torch.mm(input_data, mat2.T) @@ -327,7 +376,11 @@ def test_mm_mxfp8_realistic_model_statistics(m, n, k, input_std, weight_std): reference = torch.mm(input_data, mat2.T) input_mxfp8, mat2_mxfp8, input_descale, mat2_descale = _prepare_mxfp8_tensors( - input_data, mat2, True + input_data, + mat2, + SfLayout.layout_128x4, + SfLayout.layout_128x4, + backend="cutlass", ) result = mm_mxfp8( @@ -405,7 +458,13 @@ def test_mm_mxfp8_llm_full_layer_simulation(): reference = torch.mm(layer_input, weight.T) input_mxfp8, weight_mxfp8, input_descale, weight_descale = ( - _prepare_mxfp8_tensors(layer_input, weight, True) + _prepare_mxfp8_tensors( + layer_input, + weight, + SfLayout.layout_128x4, + SfLayout.layout_128x4, + backend="cutlass", + ) ) result = mm_mxfp8( @@ -490,7 +549,11 @@ def test_mm_mxfp8_scale_1d_tensor_interpretation(m): weight_bf16 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) * 0.02 input_fp8, weight_fp8, input_descale, weight_descale = _prepare_mxfp8_tensors( - input_bf16, weight_bf16, True + input_bf16, + weight_bf16, + SfLayout.layout_128x4, + SfLayout.layout_128x4, + backend="cutlass", ) input_scale = input_descale