Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flashinfer/dsv3_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from flashinfer.gemm import mm_M1_16_K7168_N256
from flashinfer.fused_moe import NoAuxTc
from flashinfer.fused_moe import fused_topk_deepseek

__all__ = [
"mm_M1_16_K7168_N256",
"NoAuxTc",
"fused_topk_deepseek",
]
4 changes: 2 additions & 2 deletions flashinfer/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)

from .fused_routing_dsv3 import ( # noqa: F401
NoAuxTc as NoAuxTc,
fused_topk_deepseek as fused_topk_deepseek,
)

__all__ = [
Expand All @@ -56,5 +56,5 @@
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
"trtllm_mxint4_block_scale_moe",
"NoAuxTc",
"fused_topk_deepseek",
]
2 changes: 1 addition & 1 deletion flashinfer/fused_moe/fused_routing_dsv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def NoAuxTc(


@backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
def NoAuxTc(
def fused_topk_deepseek(
scores: torch.Tensor,
bias: torch.Tensor,
n_group: int,
Expand Down
10 changes: 5 additions & 5 deletions tests/model_optimizations/test_dsv3_fused_routing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Test for NoAuxTc (DSv3 Fused Routing) Kernel
Test for fused_topk_deepseek (DSv3 Fused Routing) Kernel

This test validates the NoAuxTc kernel against a reference implementation,
This test validates the fused_topk_deepseek kernel against a reference implementation,
accounting for numerical precision and tie-breaking differences.

================================================================================
Expand Down Expand Up @@ -118,7 +118,7 @@

import torch
import pytest
from flashinfer.dsv3_ops import NoAuxTc
from flashinfer.dsv3_ops import fused_topk_deepseek
# from flashinfer.utils import get_compute_capability


Expand Down Expand Up @@ -429,7 +429,7 @@ def test_dsv3_fused_routing_op(
num_tokens, num_experts, topk, n_group, topk_group, data_type, bias_type
):
"""
Test NoAuxTc kernel against reference implementation.
Test fused_topk_deepseek kernel against reference implementation.

Validates:
1. Expert selection equivalence (allowing for ties)
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_dsv3_fused_routing_op(
topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=data_type)
topk_indices = torch.zeros(num_tokens, topk, device="cuda", dtype=torch.int32)

NoAuxTc(
fused_topk_deepseek(
scores,
bias,
n_group,
Expand Down