# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, fields

import pytest
import torch
import torch.nn.functional as F
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
                                                  upcast_from_mxfp)
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    BatchedOAITritonExperts, triton_kernel_moe_forward)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.utils import round_up

def deshuffle(w: torch.Tensor):
    # Undo the column interleaving applied by shuffle_weight: gather the
    # even and odd columns back into two contiguous halves.
    first = w[..., ::2]
    second = w[..., 1::2]

    deshuffled = torch.concat((first, second), dim=-1)
    return deshuffled

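# init_compute_data builds one set of tensors for the eager reference path
# (x, w1, w2, biases, gating logits) and a matching set laid out for the
# triton_kernels path (transposed, shuffled, padded and, for "mx4",
# quantized to mxfp4 together with the corresponding PrecisionConfigs).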
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
    randbits = [torch.randperm(E) for _ in range(M)]
    x_list = [
        (-1)**i *
        ((16384 +
          ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
        for i, bits in enumerate(randbits)
    ]
    exp_data = torch.stack(x_list).to(
        device="cuda")  # simulating gate_output (M, E)

    # create input tensor
    x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
    w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
    w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")

    w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
    w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")

    exp_data_tri = exp_data.clone()
    x_tri = x.clone()
    w1_tri = w1.clone()
    w2_tri = w2.clone()

    w1_bias_tri = w1_bias.clone()
    w2_bias_tri = w2_bias.clone()
    w1_bias_tri = w1_bias_tri.to(torch.float32)
    w2_bias_tri = w2_bias_tri.to(torch.float32)

    dtype_dict = {
        "bf16": torch.bfloat16,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2
    }

    x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
    if w_dtype != "mx4":
        # simulate quantization support in the reference impl
        w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
        w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)

    # the triton moe kernel uses the transposed shape for matmul
    w1_tri = w1_tri.transpose(-2, -1)
    w2_tri = w2_tri.transpose(-2, -1)

    # shuffle weights
    w1_tri = shuffle_weight(w1_tri)
    w1_bias_tri = shuffle_weight(w1_bias_tri)

    # quantize triton weights
    x_tri = x.to(dtype_dict[a_dtype])
    if w_dtype != "mx4":
        pytest.skip("NYI")
    else:  # quantize to mx4
        # careful with the padding here: the activation padding needs to be
        # a multiple of 64; the actual engine is not implemented
        w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
        w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]

        w2_bottom_pad = w1_right_pad // 2
        w2_right_pad = w1_bottom_pad

        x_pad = w1_bottom_pad

        w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
                       mode="constant",
                       value=0)
        w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
                       mode="constant",
                       value=0)

        w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
                            mode="constant",
                            value=0)
        w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
                            mode="constant",
                            value=0)

        x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)

        w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
            mx_axis=1)
        w_scale_layout, w_scale_layout_opts = (
            layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=1, num_warps=num_warps))

        w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
        w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)

        w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
        w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)

        w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
                                **w_layout_opts)
        w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
                                      w_scale_layout, **w_scale_layout_opts)

        w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
                                **w_layout_opts)
        w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
                                      w_scale_layout, **w_scale_layout_opts)

        pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
                              flex_ctx=FlexCtx(rhs_data=InFlexData()))
        pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
                              flex_ctx=FlexCtx(rhs_data=InFlexData()))

        # truncate so the rest can run properly
        w1 = w1[..., :K, :2 * N]
        w2 = w2[..., :N, :K]

        w1 = deshuffle(w1)

        w1 = w1.transpose(-1, -2).contiguous()
        w2 = w2.transpose(-1, -2).contiguous()

    return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
            exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)

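# Model dimensions used by the tests below; the defaults match the
# gpt-oss-120b configuration.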
@dataclass
class ModelConfig:
    num_hidden_layers: int = 36
    num_experts: int = 128
    experts_per_token: int = 4
    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    head_dim: int = 64
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    sliding_window: int = 128
    initial_context_length: int = 4096
    rope_theta: float = 150000.0
    rope_scaling_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0

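# Reference activation: clamped SwiGLU with an extra +1 bias on the linear
# branch, following the gpt-oss reference implementation.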
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    # Note we add an extra bias of 1 to the linear layer
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    if limit is not None:
        x_linear = x_linear.clamp(min=-limit, max=limit)
    return out_glu * (x_linear + 1)

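# Eager reference MoE forward: top-k routing, softmax over the selected
# logits, two expert MLPs with bias, and a weighted sum over experts.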
def oai_moe_forward(
        hidden_states: torch.Tensor,  # (M, K)
        w1: torch.Tensor,  # (E, 2N, K)
        w1_bias: torch.Tensor,  # (E, 2N)
        w2: torch.Tensor,  # (E, K, N)
        w2_bias: torch.Tensor,  # (E, K)
        gating_output: torch.Tensor,  # (M, E)
        topk: int):
    # model.py 309:330, assuming gating and norm
    t = hidden_states
    experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
    expert_indices = experts.indices

    # MLP #1
    mlp1_weight = w1[expert_indices, ...]
    mlp1_bias = w1_bias[expert_indices, ...]
    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
    t = swiglu(t, limit=7)

    # MLP #2
    mlp2_weight = w2[expert_indices, ...]
    mlp2_bias = w2_bias[expert_indices, ...]
    t = torch.einsum("beck,bek->bec", mlp2_weight, t)
    t += mlp2_bias

    # Weighted sum of experts
    t = torch.einsum("bec,be->bc", t, expert_weights)

    return t

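# Each Case pins an activation dtype / weight dtype combination; only the
# bf16 activation + mx4 weight combination is currently enabled.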
@dataclass
class Case:
    a_dtype: str
    w_dtype: str

@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # Case(a_dtype="bf16", w_dtype="bf16"),
            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
            Case(a_dtype="bf16", w_dtype="mx4")
        ]
    ],
)
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp):
    M = num_token
    E = ModelConfig.num_experts
    K = ModelConfig.hidden_size
    N = ModelConfig.intermediate_size // tp
    topk = ModelConfig.experts_per_token

    x, w1, w1_bias, w2, w2_bias, exp_data, \
        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
        w2_bias_tri, pc1, pc2 = init_compute_data(
            M, K, N, E, a_dtype, w_dtype, num_warps=8)

    out_triton_monolithic = triton_kernel_moe_forward(
        hidden_states=x_tri,
        w1=w1_tri,
        w2=w2_tri,
        gating_output=exp_data_tri,
        topk=topk,
        renormalize=True,
        w1_bias=w1_bias_tri,
        w2_bias=w2_bias_tri,
        w1_precision=pc1,
        w2_precision=pc2)
    out_triton_monolithic = out_triton_monolithic[..., :K]

    out_ref = oai_moe_forward(hidden_states=x,
                              w1=w1,
                              w1_bias=w1_bias,
                              w2=w2,
                              w2_bias=w2_bias,
                              gating_output=exp_data,
                              topk=topk)
    assert_close(ref=out_ref,
                 tri=out_triton_monolithic,
                 maxtol=0.025,
                 rmstol=0.005)

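# Modular-kernel variant of the MoE: BatchedPrepareAndFinalize handles token
# dispatch/combine and BatchedOAITritonExperts runs the expert matmuls with
# the OAI triton kernels.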
def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
                topk: int, renormalize: bool, w1_bias: torch.Tensor,
                w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
                w2_precision: PrecisionConfig) -> torch.Tensor:
    max_num_tokens = round_up(a.shape[0], 64)

    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens,
                                  num_dispatchers=1,
                                  num_local_experts=w1.shape[0],
                                  rank=0),
        BatchedOAITritonExperts(
            None,
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            w1_precision=w1_precision,
            w2_precision=w2_precision,
        ),
    )

    extra_expert_args = {
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
    }

    topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)

    return fused_experts(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        extra_expert_args=extra_expert_args,
    )

@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # Case(a_dtype="bf16", w_dtype="bf16"),
            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
            Case(a_dtype="bf16", w_dtype="mx4")
        ]
    ],
)
@pytest.mark.parametrize("num_token", [64])
@pytest.mark.parametrize("ep", [1, 2, 4, 8])
def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
    M = num_token
    E = ModelConfig.num_experts // ep
    K = ModelConfig.hidden_size
    N = ModelConfig.intermediate_size
    topk = ModelConfig.experts_per_token

    x, w1, w1_bias, w2, w2_bias, exp_data, \
        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
        w2_bias_tri, pc1, pc2 = init_compute_data(
            M, K, N, E, a_dtype, w_dtype, num_warps=4)

    out_tri = batched_moe(a=x_tri,
                          w1=w1_tri,
                          w2=w2_tri,
                          gating_output=exp_data_tri,
                          topk=topk,
                          renormalize=True,
                          w1_bias=w1_bias_tri,
                          w2_bias=w2_bias_tri,
                          w1_precision=pc1,
                          w2_precision=pc2)
    out_tri = out_tri[..., :K]

    out_ref = oai_moe_forward(hidden_states=x,
                              w1=w1,
                              w1_bias=w1_bias,
                              w2=w2,
                              w2_bias=w2_bias,
                              gating_output=exp_data,
                              topk=topk)
    assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)

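# Sanity check for the weight shuffling: a matmul against the shuffled weight
# followed by the triton interleaved swiglu must match the unshuffled matmul
# followed by the reference swiglu.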
def test_unit_shuffle():
    N = ModelConfig.intermediate_size
    K = ModelConfig.hidden_size
    m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")

    x = torch.randn(K, dtype=torch.bfloat16, device="cuda")

    m_shuffled = shuffle_weight(m)

    out_ref = x @ m
    out_ref = swiglu(out_ref, limit=1.0)

    out = x @ m_shuffled
    out = triton_kernels.swiglu.swiglu_torch(
        out,
        alpha=1.702,
        precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))

    assert_close(ref=out_ref, tri=out)
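

# A minimal roundtrip sanity sketch, assuming shuffle_weight only interleaves
# the two column halves of its input (the same assumption deshuffle makes when
# it gathers even/odd columns); under that assumption the roundtrip is an
# exact permutation with no arithmetic, so exact equality is expected.
def test_unit_deshuffle_roundtrip():
    N = ModelConfig.intermediate_size
    K = ModelConfig.hidden_size
    m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")

    assert torch.equal(deshuffle(shuffle_weight(m)), m)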