@@ -184,7 +184,9 @@ def _maybe_get_cached_w3_w1_permute_indices(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w3_w1_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key (weight_type, weight_shape)
+    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    if cache_key not in _cache_permute_indices:
         # Get permute indices and chain them together
         permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight)
         if num_elts_per_sf is None:
@@ -198,10 +200,10 @@ def _maybe_get_cached_w3_w1_permute_indices(
                 num_elts_per_sf=num_elts_per_sf,
             )
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w3_w1_weight.shape] = permute0[permute1].to(
+        _cache_permute_indices[cache_key] = permute0[permute1].to(
             dst_w3_w1_weight.device
         )
-    permute_indices = _cache_permute_indices[dst_w3_w1_weight.shape]
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices
 
 
@@ -211,7 +213,9 @@ def get_w2_permute_indices_with_cache(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w2_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key (weight_type, weight_shape)
+    cache_key = ("w2", dst_w2_weight.shape)
+    if cache_key not in _cache_permute_indices:
         if num_elts_per_sf is None:
             permute_indices = get_shuffle_matrix_a_row_indices(
                 dst_w2_weight, epilogue_tile_m
@@ -223,8 +227,8 @@ def get_w2_permute_indices_with_cache(
                 num_elts_per_sf=num_elts_per_sf,
             ).to(dst_w2_weight.device)
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w2_weight.shape] = permute_indices
-    permute_indices = _cache_permute_indices[dst_w2_weight.shape]
+        _cache_permute_indices[cache_key] = permute_indices
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices
 
 
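
# --- Illustration (not part of the diff): why the cache key now includes the
# weight type. The w3_w1 path stores chained indices (permute0[permute1]) while
# the w2 path stores shuffle-only indices, so a shape-only key could hand back
# the wrong permutation whenever the two weight shapes coincide. Minimal sketch
# with a plain dict and hypothetical helper names:
_demo_cache: dict = {}

def _demo_lookup(weight_type: str, shape: tuple) -> str:
    key = (weight_type, shape)  # unique per (weight_type, weight_shape)
    if key not in _demo_cache:
        _demo_cache[key] = f"permutation for {weight_type} {shape}"
    return _demo_cache[key]

# Same shape, different weight type -> different cache entries, no collision.
assert _demo_lookup("w3_w1", (256, 512)) != _demo_lookup("w2", (256, 512))
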
@@ -1097,12 +1101,12 @@ def trtllm_fp8_per_tensor_scale_moe_op(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
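
# --- Illustration (not part of the diff): what the Optional routing arguments
# allow. Routing methods without expert-group structure have no meaningful
# n_group / topk_group / routed_scaling_factor, so callers can now pass None
# instead of dummy values. The validation below is a hedged sketch, not the
# library's implementation.
from typing import Optional

def _check_group_args(n_group: Optional[int], topk_group: Optional[int],
                      num_experts: int, top_k: int) -> None:
    if n_group is None or topk_group is None:
        # Non-grouped routing: both group parameters should be omitted together.
        assert n_group is None and topk_group is None
        return
    # Grouped routing: groups must tile the expert set and cover top_k.
    assert num_experts % n_group == 0
    assert topk_group <= n_group
    assert top_k <= topk_group * (num_experts // n_group)

_check_group_args(None, None, num_experts=128, top_k=8)   # plain top-k routing
_check_group_args(8, 4, num_experts=256, top_k=8)         # grouped routing
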
@@ -1151,12 +1155,12 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
@@ -1183,12 +1187,12 @@ def trtllm_fp8_block_scale_moe_op(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int,
     routing_method_type: int,
     use_shuffled_weight: bool = False,
@@ -1197,6 +1201,7 @@ def trtllm_fp8_block_scale_moe_op(
 ) -> torch.Tensor:
     if enable_pdl is None:
         enable_pdl = device_support_pdl(hidden_states.device)
+
     # Call the C++ function for block scale MoE
     moe_op.trtllm_fp8_block_scale_moe(
         routing_logits,
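
# --- Illustration (not part of the diff): the None-means-autodetect pattern used
# for enable_pdl above. The capability check below is an assumption (programmatic
# dependent launch is a Hopper-and-newer feature), not the library's
# device_support_pdl implementation.
import torch

def _guess_pdl_support(device: torch.device) -> bool:
    if device.type != "cuda":
        return False
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9  # assumed: PDL requires SM90 or newer

enable_pdl = None
if enable_pdl is None and torch.cuda.is_available():
    enable_pdl = _guess_pdl_support(torch.device("cuda"))
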
@@ -1238,12 +1243,12 @@ def _fake_trtllm_fp8_block_scale_moe(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
@@ -1503,12 +1508,12 @@ def trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
@@ -1576,12 +1581,12 @@ def trtllm_fp8_block_scale_moe(
     gemm2_weights_scale: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
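
# --- Illustration (not part of the diff): a call-site pattern the relaxed
# signatures enable. A caller forwards group-routing parameters only when the
# chosen routing method defines them; otherwise None flows straight through to
# the wrappers above. The helper below is hypothetical, not library code.
from typing import Any, Dict, Optional

def routing_kwargs(use_grouped_routing: bool,
                   n_group: int = 8,
                   topk_group: int = 4,
                   routed_scaling_factor: float = 2.5) -> Dict[str, Optional[Any]]:
    if not use_grouped_routing:
        return {"n_group": None, "topk_group": None, "routed_scaling_factor": None}
    return {"n_group": n_group, "topk_group": topk_group,
            "routed_scaling_factor": routed_scaling_factor}

# e.g. trtllm_fp8_block_scale_moe(..., **routing_kwargs(use_grouped_routing=False), ...)
print(routing_kwargs(use_grouped_routing=False))
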