diff --git a/flashinfer/dsv3_ops/__init__.py b/flashinfer/dsv3_ops/__init__.py index 05a7c4e657..527f4f3d9b 100644 --- a/flashinfer/dsv3_ops/__init__.py +++ b/flashinfer/dsv3_ops/__init__.py @@ -1,7 +1,7 @@ from flashinfer.gemm import mm_M1_16_K7168_N256 -from flashinfer.fused_moe import NoAuxTc +from flashinfer.fused_moe import fused_topk_deepseek __all__ = [ "mm_M1_16_K7168_N256", - "NoAuxTc", + "fused_topk_deepseek", ] diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index a7d7a368db..a34d37f149 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -35,7 +35,7 @@ ) from .fused_routing_dsv3 import ( # noqa: F401 - NoAuxTc as NoAuxTc, + fused_topk_deepseek as fused_topk_deepseek, ) __all__ = [ @@ -56,5 +56,5 @@ "trtllm_fp8_block_scale_moe", "trtllm_fp8_per_tensor_scale_moe", "trtllm_mxint4_block_scale_moe", - "NoAuxTc", + "fused_topk_deepseek", ] diff --git a/flashinfer/fused_moe/fused_routing_dsv3.py b/flashinfer/fused_moe/fused_routing_dsv3.py index bb12472272..9c6fb79c91 100644 --- a/flashinfer/fused_moe/fused_routing_dsv3.py +++ b/flashinfer/fused_moe/fused_routing_dsv3.py @@ -116,7 +116,7 @@ def NoAuxTc( @backend_requirement({}, common_check=_check_dsv3_fused_routing_supported) -def NoAuxTc( +def fused_topk_deepseek( scores: torch.Tensor, bias: torch.Tensor, n_group: int, diff --git a/tests/model_optimizations/test_dsv3_fused_routing.py b/tests/model_optimizations/test_dsv3_fused_routing.py index 1749e94f46..e84c9ca884 100644 --- a/tests/model_optimizations/test_dsv3_fused_routing.py +++ b/tests/model_optimizations/test_dsv3_fused_routing.py @@ -1,7 +1,7 @@ """ -Test for NoAuxTc (DSv3 Fused Routing) Kernel +Test for fused_topk_deepseek (DSv3 Fused Routing) Kernel -This test validates the NoAuxTc kernel against a reference implementation, +This test validates the fused_topk_deepseek kernel against a reference implementation, accounting for numerical precision and tie-breaking differences. ================================================================================ @@ -118,7 +118,7 @@ import torch import pytest -from flashinfer.dsv3_ops import NoAuxTc +from flashinfer.dsv3_ops import fused_topk_deepseek # from flashinfer.utils import get_compute_capability @@ -429,7 +429,7 @@ def test_dsv3_fused_routing_op( num_tokens, num_experts, topk, n_group, topk_group, data_type, bias_type ): """ - Test NoAuxTc kernel against reference implementation. + Test fused_topk_deepseek kernel against reference implementation. Validates: 1. Expert selection equivalence (allowing for ties) @@ -473,7 +473,7 @@ def test_dsv3_fused_routing_op( topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=data_type) topk_indices = torch.zeros(num_tokens, topk, device="cuda", dtype=torch.int32) - NoAuxTc( + fused_topk_deepseek( scores, bias, n_group,