
Commit 1f3fb07

Merge branch 'main' into artifacts
2 parents: 43a2ae7 + 0260ab3

12 files changed: 1046 additions, 715 deletions

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 87 additions & 51 deletions
Large diffs are not rendered by default.

csrc/trtllm_fused_moe_routing_deepseek.cu

Lines changed: 253 additions & 165 deletions
Large diffs are not rendered by default.

csrc/trtllm_fused_moe_routing_llama4.cu

Lines changed: 143 additions & 84 deletions
Large diffs are not rendered by default.

csrc/trtllm_fused_moe_routing_renormalize.cu

Lines changed: 239 additions & 83 deletions
Large diffs are not rendered by default.

csrc/trtllm_fused_moe_runner.cu

Lines changed: 6 additions & 6 deletions
@@ -70,12 +70,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   routingData.mUsePdl = true;

   // output:
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;

   routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
   routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;

@@ -107,12 +107,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   routingData.mUsePdl = true;

   // output:
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;

   routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
   routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;

@@ -149,12 +149,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   //
   // Outputs
   //
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;

   //
   // Grouped Gemm Launch Config Buffers
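
For reference, the renamed mPtrTopKPacked and mPtrTopKWeights buffers carry the per-token top-k routing results that these kernels produce. A minimal PyTorch sketch of the equivalent computation (generic top-k MoE routing for illustration only; this is not the kernel's implementation, and the function name is made up):

import torch

def topk_routing(routing_logits: torch.Tensor, top_k: int):
    # routing_logits: [num_tokens, num_experts]
    scores = torch.softmax(routing_logits.float(), dim=-1)
    # Per token: the chosen expert indices and their weights, i.e. the
    # kind of data the mPtrTopKPacked / mPtrTopKWeights buffers hold.
    topk_weights, topk_indices = torch.topk(scores, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_indices, topk_weights

# Example: 4 tokens routed over 8 experts with top-2 selection.
indices, weights = topk_routing(torch.randn(4, 8), top_k=2)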

flashinfer/artifacts.py

Lines changed: 10 additions & 4 deletions
@@ -99,10 +99,16 @@ class ArtifactPath:


 @dataclass(frozen=True)
 class MetaInfoHash:
-    """
-    Encode sha256 hash of kernel_map.json for DEEPGEMM
-    """
-    DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
+    DEEPGEMM: str = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb"
+    TRTLLM_GEN_FMHA: str = (
+        "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a"
+    )
+    TRTLLM_GEN_BMM: str = (
+        "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152"
+    )
+    TRTLLM_GEN_GEMM: str = (
+        "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340"
+    )


 class CheckSumHash:
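
These constants pin the expected sha256 digest of each backend's kernel metadata, so a cached or freshly downloaded artifact can be validated before use. A minimal sketch of such a check, using only hashlib from the standard library (the file path and helper name are illustrative, not flashinfer's API):

import hashlib

# Digest pinned in flashinfer/artifacts.py above (MetaInfoHash.DEEPGEMM).
DEEPGEMM_SHA256 = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb"

def verify_sha256(path: str, expected: str) -> bool:
    # Hash the file in 1 MiB chunks to keep memory bounded.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest() == expected

# Usage (path is illustrative):
# assert verify_sha256("kernel_map.json", DEEPGEMM_SHA256)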

flashinfer/fused_moe/core.py

Lines changed: 29 additions & 24 deletions
@@ -184,7 +184,9 @@ def _maybe_get_cached_w3_w1_permute_indices(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w3_w1_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key (weight_type, weight_shape)
+    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    if cache_key not in _cache_permute_indices:
         # Get permute indices and chain them together
         permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight)
         if num_elts_per_sf is None:

@@ -198,10 +200,10 @@ def _maybe_get_cached_w3_w1_permute_indices(
             num_elts_per_sf=num_elts_per_sf,
         )
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w3_w1_weight.shape] = permute0[permute1].to(
+        _cache_permute_indices[cache_key] = permute0[permute1].to(
             dst_w3_w1_weight.device
         )
-    permute_indices = _cache_permute_indices[dst_w3_w1_weight.shape]
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices

@@ -211,7 +213,9 @@ def get_w2_permute_indices_with_cache(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w2_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key (weight_type, weight_shape)
+    cache_key = ("w2", dst_w2_weight.shape)
+    if cache_key not in _cache_permute_indices:
         if num_elts_per_sf is None:
             permute_indices = get_shuffle_matrix_a_row_indices(
                 dst_w2_weight, epilogue_tile_m

@@ -223,8 +227,8 @@ def get_w2_permute_indices_with_cache(
             num_elts_per_sf=num_elts_per_sf,
         ).to(dst_w2_weight.device)
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w2_weight.shape] = permute_indices
-    permute_indices = _cache_permute_indices[dst_w2_weight.shape]
+        _cache_permute_indices[cache_key] = permute_indices
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices

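Keying the memoization dict on tensor shape alone meant a w2 weight with the same shape as an already-cached w3_w1 weight would silently reuse the wrong permutation. A minimal sketch of the collision the (weight_type, weight_shape) key avoids (generic caching pattern; the index computation is a stand-in):

import torch

_cache: dict = {}

def permute_indices_for(kind: str, weight: torch.Tensor) -> torch.Tensor:
    # (kind, shape) keeps "w3_w1" and "w2" entries distinct even when
    # their weight tensors happen to share a shape; shape alone would
    # return whichever permutation was cached first.
    key = (kind, weight.shape)
    if key not in _cache:
        _cache[key] = torch.randperm(weight.shape[0])  # stand-in computation
    return _cache[key]

w3_w1 = torch.empty(128, 256)
w2 = torch.empty(128, 256)  # same shape, different required permutation
assert permute_indices_for("w3_w1", w3_w1) is not permute_indices_for("w2", w2)
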
@@ -1097,12 +1101,12 @@ def trtllm_fp8_per_tensor_scale_moe_op(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,

@@ -1151,12 +1155,12 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,

@@ -1183,12 +1187,12 @@ def trtllm_fp8_block_scale_moe_op(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int,
     routing_method_type: int,
     use_shuffled_weight: bool = False,

@@ -1197,6 +1201,7 @@ def trtllm_fp8_block_scale_moe_op(
 ) -> torch.Tensor:
     if enable_pdl is None:
         enable_pdl = device_support_pdl(hidden_states.device)
+
     # Call the C++ function for block scale MoE
     moe_op.trtllm_fp8_block_scale_moe(
         routing_logits,

@@ -1238,12 +1243,12 @@ def _fake_trtllm_fp8_block_scale_moe(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,

@@ -1503,12 +1508,12 @@ def trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,

@@ -1576,12 +1581,12 @@ def trtllm_fp8_block_scale_moe(
     gemm2_weights_scale: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
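
Relaxing n_group, topk_group, and routed_scaling_factor to Optional lets routing methods without grouped routing pass None instead of dummy values. A sketch of the validation such a signature implies, assuming the grouped parameters must be given together (illustrative helper, not flashinfer code):

from typing import Optional

def check_routing_args(
    n_group: Optional[int],
    topk_group: Optional[int],
    routed_scaling_factor: Optional[float],
) -> None:
    grouped = (n_group is not None, topk_group is not None)
    if any(grouped) and not all(grouped):
        raise ValueError("n_group and topk_group must be set together")
    if n_group is not None and topk_group is not None and topk_group > n_group:
        raise ValueError("topk_group cannot exceed n_group")
    # routed_scaling_factor=None means no extra scaling of routed outputs.

check_routing_args(n_group=8, topk_group=4, routed_scaling_factor=2.5)  # grouped (DeepSeek-style)
check_routing_args(n_group=None, topk_group=None, routed_scaling_factor=None)  # ungrouped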
