diff --git a/benchmarks/bench_router_gemm.py b/benchmarks/bench_router_gemm.py new file mode 100644 index 0000000000..97a3570cbd --- /dev/null +++ b/benchmarks/bench_router_gemm.py @@ -0,0 +1,79 @@ +import numpy as np +import torch + +from flashinfer.testing.utils import bench_gpu_time_with_cudagraph +from flashinfer.dsv3_ops import mm_M1_16_K7168_N128, mm_M1_16_K7168_N256 + + +@torch.compile +def reference_torch( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, +): + return torch.nn.functional.linear(x, weight, bias) + + +def get_data_torch(num_tokens, num_experts, hidden_dim): + mat_a = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + mat_b = torch.randn(num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16) + return mat_a, mat_b + + +def get_data_flashinfer(num_tokens, num_experts, hidden_dim, output_dtype): + mat_a = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + mat_b = torch.randn( + num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16 + ).t() + out = torch.empty(num_tokens, num_experts, device="cuda", dtype=output_dtype) + return mat_a, mat_b, out + + +def bench_router_gemm(gemm_fn, data, M, N, K, reps=1000, warmup_reps=1000): + measurements = bench_gpu_time_with_cudagraph( + lambda: gemm_fn(*data), + dry_run_time_ms=warmup_reps, + repeat_time_ms=reps, + ) + ms = np.median(measurements) + flops = (2 * M * N * K) / ms / 1e9 + add_desc = f" launch_with_pdl={data[3]}" if len(data) > 3 else "" + print( + f"Router GEMM function {gemm_fn} | num_tokens={M}, num_experts={N}{add_desc} | Median execution time: {1000 * ms:.3f} us | TFLOPs/s: {flops:.3f}" + ) + + +def main(): + hidden_dim = 7168 + for num_tokens in [1, 2, 4, 8, 16]: + for num_experts, output_dtype, flashinfer_fn in [ + (128, torch.bfloat16, mm_M1_16_K7168_N128), + (256, torch.float32, mm_M1_16_K7168_N256), + ]: + data_torch = get_data_torch( + num_tokens=num_tokens, hidden_dim=hidden_dim, 
num_experts=num_experts + ) + bench_router_gemm( + reference_torch, data_torch, num_tokens, num_experts, hidden_dim + ) + + data_flashinfer = get_data_flashinfer( + num_tokens=num_tokens, + hidden_dim=hidden_dim, + num_experts=num_experts, + output_dtype=output_dtype, + ) + for launch_with_pdl in [False, True]: + bench_router_gemm( + flashinfer_fn, + (*data_flashinfer, launch_with_pdl), + num_tokens, + num_experts, + hidden_dim, + ) + + print() + + +if __name__ == "__main__": + main() diff --git a/csrc/dsv3_router_gemm.cu b/csrc/dsv3_router_gemm.cu index 2d44147d97..311ac3f579 100644 --- a/csrc/dsv3_router_gemm.cu +++ b/csrc/dsv3_router_gemm.cu @@ -2,10 +2,13 @@ #include "tvm_ffi_utils.h" namespace flashinfer::trtllm_dsv3_router_gemm { -template -void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream, + +// Note: Explicit template instantiations are not needed here because +// LoopUnroller already forces instantiation of all required specializations. +template +void invokeRouterGemm(Tout* output, Tin const* mat_a, Tin const* mat_b, cudaStream_t stream, bool use_pdl = false) { - constexpr int VPT = 16 / sizeof(T); + constexpr int VPT = 16 / sizeof(Tin); constexpr int kBlockSize = 128; cudaLaunchConfig_t config; config.gridDim = kNumExperts; @@ -18,83 +21,20 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_ config.numAttrs = 1; config.attrs = attrs; auto status = cudaLaunchKernelEx( - &config, router_gemm_kernel, output, - mat_a, mat_b); + &config, router_gemm_kernel, + output, mat_a, mat_b); TVM_FFI_ICHECK(status == cudaSuccess) << "cudaLaunchKernelEx failed with error code " << cudaGetErrorString(status); } -template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void 
invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - -template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, - __nv_bfloat16 const*, cudaStream_t, - bool); - template struct LoopUnroller { - static void 
unroll(int num_tokens, float* output, __nv_bfloat16 const* input, + template + static void unroll(int num_tokens, Tout* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) { if (num_tokens == kBegin) { - invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, - stream, launch_with_pdl); + invokeRouterGemm<__nv_bfloat16, Tout, kBegin, kNumExperts, kHiddenDim>( + output, input, weights, stream, launch_with_pdl); } else { LoopUnroller::unroll( num_tokens, output, input, weights, stream, launch_with_pdl); @@ -104,24 +44,26 @@ struct LoopUnroller { template struct LoopUnroller { - static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input, + template + static void unroll(int num_tokens, Tout* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) { if (num_tokens == kEnd) { - invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream, - launch_with_pdl); + invokeRouterGemm<__nv_bfloat16, Tout, kEnd, kNumExperts, kHiddenDim>(output, input, weights, + stream, launch_with_pdl); } else { throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); } } }; -void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) { +template +void generic_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, + bool launch_with_pdl) { int const num_tokens = mat_a.sizes()[0]; int const num_experts = mat_b.sizes()[1]; int const hidden_dim = mat_a.sizes()[1]; auto const out_dtype_ = out.dtype(); auto const data_type = mat_a.dtype(); - constexpr int kNumExperts = 256; constexpr int kHiddenDim = 7168; std::vector output_size = {mat_a.sizes()[0], mat_b.sizes()[1]}; TVM_FFI_ICHECK(mat_a.dim() == 2 && mat_b.dim() == 2) << "mat_a and mat_b must be 2D tensors"; @@ -132,13 +74,13 @@ void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, 
TensorView out, boo bool use_custom_kernel = false; if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts && hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code && - encode_dlpack_dtype(out_dtype_) == float32_code) { + encode_dlpack_dtype(out_dtype_) == tout_code) { use_custom_kernel = true; } if (use_custom_kernel) { - LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll( - num_tokens, reinterpret_cast(out.data_ptr()), + LoopUnroller::unroll( + num_tokens, reinterpret_cast(out.data_ptr()), reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl); } else { @@ -146,7 +88,18 @@ void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, boo } } +void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) { + generic_router_gemm_op(mat_a, mat_b, out, launch_with_pdl); +} + +void ml3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) { + generic_router_gemm_op<__nv_bfloat16, bfloat16_code, 128, 1, 16>(mat_a, mat_b, out, + launch_with_pdl); +} + TVM_FFI_DLL_EXPORT_TYPED_FUNC(dsv3_router_gemm_op, flashinfer::trtllm_dsv3_router_gemm::dsv3_router_gemm_op); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(ml3_router_gemm_op, + flashinfer::trtllm_dsv3_router_gemm::ml3_router_gemm_op); } // namespace flashinfer::trtllm_dsv3_router_gemm diff --git a/flashinfer/dsv3_ops/__init__.py b/flashinfer/dsv3_ops/__init__.py index 6c9adb106f..db0a7fea7e 100644 --- a/flashinfer/dsv3_ops/__init__.py +++ b/flashinfer/dsv3_ops/__init__.py @@ -1,8 +1,9 @@ -from flashinfer.gemm import mm_M1_16_K7168_N256 +from flashinfer.gemm import mm_M1_16_K7168_N128, mm_M1_16_K7168_N256 from flashinfer.fused_moe import fused_topk_deepseek from flashinfer.concat_ops import concat_mla_k __all__ = [ + "mm_M1_16_K7168_N128", "mm_M1_16_K7168_N256", "fused_topk_deepseek", "concat_mla_k", diff --git 
a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index cfce6c52c8..bd30c178dc 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -19,6 +19,7 @@ from .gemm_base import fp8_blockscale_gemm_sm90 as fp8_blockscale_gemm_sm90 from .routergemm_dsv3 import ( + mm_M1_16_K7168_N128 as mm_M1_16_K7168_N128, mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256, ) @@ -38,5 +39,6 @@ "gemm_fp8_nt_groupwise", "group_gemm_fp8_nt_groupwise", "fp8_blockscale_gemm_sm90", + "mm_M1_16_K7168_N128", "mm_M1_16_K7168_N256", ] diff --git a/flashinfer/gemm/routergemm_dsv3.py b/flashinfer/gemm/routergemm_dsv3.py index 3a3e1c93d4..65c072d54b 100644 --- a/flashinfer/gemm/routergemm_dsv3.py +++ b/flashinfer/gemm/routergemm_dsv3.py @@ -10,9 +10,9 @@ ) -# TODO: other compute capabilities may be supported but are untested -@supported_compute_capability([100]) -def _mm_M1_16_K7168_N256_shape_checks(mat_a, mat_b, out, launch_with_pdl): +def _mm_M1_16_K7168_shape_checks( + mat_a, mat_b, out, launch_with_pdl, expected_num_experts, expected_out_dtype +): # Dimension checks if mat_a.dim() != 2: raise ValueError("mat_a must be a 2D tensor") @@ -38,7 +38,6 @@ def _mm_M1_16_K7168_N256_shape_checks(mat_a, mat_b, out, launch_with_pdl): # Problem size checks expected_hidden_dim = 7168 - expected_num_experts = 256 min_tokens = 1 max_tokens = 16 if mat_a.shape[0] < min_tokens or mat_a.shape[0] > max_tokens: @@ -59,16 +58,54 @@ def _mm_M1_16_K7168_N256_shape_checks(mat_a, mat_b, out, launch_with_pdl): raise ValueError("mat_a must be a bfloat16 tensor") if mat_b.dtype != torch.bfloat16: raise ValueError("mat_b must be a bfloat16 tensor") - if out.dtype != torch.float32: - raise ValueError("out must be a float32 tensor") + if out.dtype != expected_out_dtype: + raise ValueError(f"out must be a {expected_out_dtype} tensor") return True +# TODO: other compute capabilities may be supported but are untested +@supported_compute_capability([100]) +def _mm_M1_16_K7168_N256_shape_checks(mat_a, 
mat_b, out, launch_with_pdl): + return _mm_M1_16_K7168_shape_checks( + mat_a, + mat_b, + out, + launch_with_pdl, + expected_num_experts=256, + expected_out_dtype=torch.float32, + ) + + +# TODO: other compute capabilities may be supported but are untested +@supported_compute_capability([100]) +def _mm_M1_16_K7168_N128_shape_checks(mat_a, mat_b, out, launch_with_pdl): + return _mm_M1_16_K7168_shape_checks( + mat_a, + mat_b, + out, + launch_with_pdl, + expected_num_experts=128, + expected_out_dtype=torch.bfloat16, + ) + + @functools.cache def get_dsv3_router_gemm_module(): module = gen_dsv3_router_gemm_module().build_and_load() + @register_custom_op( + "flashinfer::ml3_router_gemm_op", + mutates_args=["out"], + ) + def mm_M1_16_K7168_N128( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + out: torch.Tensor, + launch_with_pdl: bool = False, + ) -> None: + module.ml3_router_gemm_op(mat_a, mat_b, out, launch_with_pdl) + @register_custom_op( "flashinfer::dsv3_router_gemm_op", mutates_args=["out"], @@ -82,10 +119,61 @@ def mm_M1_16_K7168_N256( module.dsv3_router_gemm_op(mat_a, mat_b, out, launch_with_pdl) return SimpleNamespace( + mm_M1_16_K7168_N128=mm_M1_16_K7168_N128, mm_M1_16_K7168_N256=mm_M1_16_K7168_N256, ) +@backend_requirement({}, common_check=_mm_M1_16_K7168_N128_shape_checks) +@flashinfer_api +def mm_M1_16_K7168_N128( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + out: torch.Tensor, + launch_with_pdl: bool = False, +) -> None: + """Optimized GEMM for the router operation in Mistral Large 3. + + This function performs a highly optimized matrix multiplication specifically tailored + for the expert routing GEMM in Mistral Large 3's Mixture of Experts (MoE) architecture. + It computes out = mat_a @ mat_b where mat_a contains token embeddings and mat_b + contains expert routing weights. 
+ + The implementation is optimized for the specific problem dimensions used in Mistral Large 3: + - Hidden dimension (K): 7168 + - Number of experts (N): 128 + - Number of tokens (M): 1-16 + + Args: + mat_a (torch.Tensor): Input token embeddings of shape (M, K) where M is the number + of tokens (1-16) and K is the hidden dimension (7168). Must be bfloat16, + row-major (contiguous). + mat_b (torch.Tensor): Expert routing weights of shape (K, N) where K is the hidden + dimension (7168) and N is the number of experts (128). Must be bfloat16, + column-major (transposed layout). + out (torch.Tensor): Pre-allocated output tensor of shape (M, N) containing the + routing scores. Must be bfloat16, row-major (contiguous). This tensor is + mutated in-place. + launch_with_pdl (bool, optional): Whether to launch the kernel using Programmatic + Dependent Launch (PDL). Defaults to False. + + Returns: + None: The result is written directly to the `out` tensor. + + Raises: + ValueError: If tensor dimensions, strides, or data types do not match the + expected Mistral Large 3 router configuration. + + Note: + This kernel is specialized for compute capability 10.0 (Blackwell architecture). + The specific problem size optimization makes this significantly faster than + general-purpose GEMM implementations for the router operation. 
+ """ + get_dsv3_router_gemm_module().mm_M1_16_K7168_N128( + mat_a, mat_b, out, launch_with_pdl + ) + + @backend_requirement({}, common_check=_mm_M1_16_K7168_N256_shape_checks) @flashinfer_api def mm_M1_16_K7168_N256( diff --git a/include/flashinfer/gemm/dsv3_router_gemm.cuh b/include/flashinfer/gemm/dsv3_router_gemm.cuh index aef712d68e..aa9811c985 100644 --- a/include/flashinfer/gemm/dsv3_router_gemm.cuh +++ b/include/flashinfer/gemm/dsv3_router_gemm.cuh @@ -39,9 +39,10 @@ __device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* ds } } -template -__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, - T const* mat_b) { +template +__global__ __launch_bounds__(128, 1) void router_gemm_kernel(Tout* out, Tin const* mat_a, + Tin const* mat_b) { // Each block handles one expert column int const n_idx = blockIdx.x; int const tid = threadIdx.x; @@ -58,7 +59,7 @@ __global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const __shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps // B matrix is in column-major order, so we can directly load a column for the n_idx expert - T const* b_col = mat_b + n_idx * kHiddenDim; + Tin const* b_col = mat_b + n_idx * kHiddenDim; // Pre-compute k_base values for each iteration to help compiler optimize // int k_bases[k_iterations]; diff --git a/tests/model_optimizations/test_dsv3_router_gemm.py b/tests/model_optimizations/test_dsv3_router_gemm.py index c4c8f1ce7b..80f921667f 100644 --- a/tests/model_optimizations/test_dsv3_router_gemm.py +++ b/tests/model_optimizations/test_dsv3_router_gemm.py @@ -1,16 +1,24 @@ import torch import pytest -from flashinfer.dsv3_ops import mm_M1_16_K7168_N256 +from flashinfer.dsv3_ops import mm_M1_16_K7168_N128, mm_M1_16_K7168_N256 import torch.nn.functional as F from flashinfer.utils import get_compute_capability # Positive tests @pytest.mark.parametrize("num_tokens", [1, 2, 3, 5, 8, 13, 16]) 
-@pytest.mark.parametrize("num_experts", [256]) +@pytest.mark.parametrize( + "num_experts,output_dtype,fn_to_test", + ( + [256, torch.float32, mm_M1_16_K7168_N256], + [128, torch.bfloat16, mm_M1_16_K7168_N128], + ), +) @pytest.mark.parametrize("hidden_dim", [7168]) @pytest.mark.parametrize("launch_with_pdl", [True, False]) -def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pdl): +def test_dsv3_router_gemm_op( + num_tokens, num_experts, hidden_dim, launch_with_pdl, output_dtype, fn_to_test +): compute_capability = get_compute_capability(torch.device("cuda")) compute_capability_number = compute_capability[0] * 10 + compute_capability[1] if compute_capability_number != 100: @@ -20,8 +28,8 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd mat_b = torch.randn( num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16 ).t() # column major - out = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32) - mm_M1_16_K7168_N256(mat_a, mat_b, out, launch_with_pdl=launch_with_pdl) + out = torch.empty(num_tokens, num_experts, device="cuda", dtype=output_dtype) + fn_to_test(mat_a, mat_b, out, launch_with_pdl=launch_with_pdl) ref = mat_a @ mat_b cos_sim = F.cosine_similarity(ref.reshape(-1), out.reshape(-1), dim=0) @@ -30,10 +38,11 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd # Negative tests - test values just outside valid ranges @pytest.mark.parametrize( - "num_tokens,num_experts,hidden_dim,mat_a_dtype,mat_b_dtype,out_dtype,mat_b_transpose,expected_error", + "fn_array,num_tokens,num_experts,hidden_dim,mat_a_dtype,mat_b_dtype,out_dtype,mat_b_transpose,expected_error", [ # Invalid num_tokens (must be 1-16) - ( + pytest.param( + [mm_M1_16_K7168_N128, mm_M1_16_K7168_N256], 0, 256, 7168, @@ -42,8 +51,10 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd torch.float32, True, "num_tokens", + id="all-num_tokens_0", ), - ( + 
pytest.param( + [mm_M1_16_K7168_N128, mm_M1_16_K7168_N256], 17, 256, 7168, @@ -52,9 +63,35 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd torch.float32, True, "num_tokens", + id="all-num_tokens_17", + ), + # Invalid num_experts (must be 128 or 256, depending on the function) + pytest.param( + [mm_M1_16_K7168_N128], + 8, + 127, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "num_experts", + id="N128-num_experts_127", + ), + pytest.param( + [mm_M1_16_K7168_N128], + 8, + 129, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "num_experts", + id="N128-num_experts_129", ), - # Invalid num_experts (must be 256) - ( + pytest.param( + [mm_M1_16_K7168_N256], 8, 255, 7168, @@ -63,8 +100,10 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd torch.float32, True, "num_experts", + id="N256-num_experts_255", ), - ( + pytest.param( + [mm_M1_16_K7168_N256], 8, 257, 7168, @@ -73,9 +112,11 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd torch.float32, True, "num_experts", + id="N256-num_experts_257", ), # Invalid hidden_dim (must be 7168) - ( + pytest.param( + [mm_M1_16_K7168_N128, mm_M1_16_K7168_N256], 8, 256, 7167, @@ -84,8 +125,10 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd torch.float32, True, "hidden_dim", + id="all-hidden_dim_7167", ), - ( + pytest.param( + [mm_M1_16_K7168_N128, mm_M1_16_K7168_N256], 8, 256, 7169, @@ -94,13 +137,84 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd torch.float32, True, "hidden_dim", + id="all-hidden_dim_7169", ), # Invalid dtypes - (8, 256, 7168, torch.float32, torch.bfloat16, torch.float32, True, "bfloat16"), - (8, 256, 7168, torch.bfloat16, torch.float32, torch.float32, True, "bfloat16"), - (8, 256, 7168, torch.bfloat16, torch.bfloat16, torch.bfloat16, True, "float32"), + pytest.param( + [mm_M1_16_K7168_N128], + 8, + 128, 
+ 7168, + torch.float32, + torch.bfloat16, + torch.float32, + True, + "bfloat16", + id="N128-invalid_mat_a_dtype", + ), + pytest.param( + [mm_M1_16_K7168_N128], + 8, + 128, + 7168, + torch.bfloat16, + torch.float32, + torch.float32, + True, + "bfloat16", + id="N128-invalid_mat_b_dtype", + ), + pytest.param( + [mm_M1_16_K7168_N128], + 8, + 128, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.float32, + True, + "bfloat16", + id="N128-invalid_out_dtype", + ), + pytest.param( + [mm_M1_16_K7168_N256], + 8, + 256, + 7168, + torch.float32, + torch.bfloat16, + torch.float32, + True, + "bfloat16", + id="N256-invalid_mat_a_dtype", + ), + pytest.param( + [mm_M1_16_K7168_N256], + 8, + 256, + 7168, + torch.bfloat16, + torch.float32, + torch.float32, + True, + "bfloat16", + id="N256-invalid_mat_b_dtype", + ), + pytest.param( + [mm_M1_16_K7168_N256], + 8, + 256, + 7168, + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + True, + "float32", + id="N256-invalid_out_dtype", + ), # Invalid stride (mat_b not transposed = row-major instead of column-major) - ( + pytest.param( + [mm_M1_16_K7168_N128, mm_M1_16_K7168_N256], 8, 256, 7168, @@ -109,10 +223,12 @@ def test_dsv3_router_gemm_op(num_tokens, num_experts, hidden_dim, launch_with_pd torch.float32, False, "column-major", + id="all-invalid_stride", ), ], ) def test_dsv3_router_gemm_op_negative( + fn_array, num_tokens, num_experts, hidden_dim, @@ -133,5 +249,6 @@ def test_dsv3_router_gemm_op_negative( mat_b = mat_b.t() # column major out = torch.randn(num_tokens, num_experts, device="cuda", dtype=out_dtype) - with pytest.raises(ValueError, match=expected_error): - mm_M1_16_K7168_N256(mat_a, mat_b, out, launch_with_pdl=False) + for fn in fn_array: + with pytest.raises(ValueError, match=expected_error): + fn(mat_a, mat_b, out, launch_with_pdl=False)