import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
- from vllm.model_executor.layers.fused_moe.utils import (
-     _resize_cache, per_token_group_quant_fp8)
+ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+ from vllm.triton_utils import tl, triton

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


+ @triton.jit
+ def _silu_mul_fp8_quant_deep_gemm(
+     # Pointers ------------------------------------------------------------
+     input_ptr,  # 16-bit activations (E, T, 2*H)
+     y_q_ptr,  # fp8 quantized activations (E, T, H)
+     y_s_ptr,  # float32 scales (E, T, G)
+     counts_ptr,  # int32 num tokens per expert (E)
+
+     # Sizes ---------------------------------------------------------------
+     H: tl.constexpr,  # hidden dimension (per output)
+     GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
+
+     # Strides for input (elements) ----------------------------------------
+     stride_i_e,
+     stride_i_t,
+     stride_i_h,
+
+     # Strides for y_q (elements) -------------------------------------------
+     stride_yq_e,
+     stride_yq_t,
+     stride_yq_h,
+
+     # Strides for y_s (elements) -------------------------------------------
+     stride_ys_e,
+     stride_ys_t,
+     stride_ys_g,
+
+     # Stride for counts (elements)
+     stride_counts_e,
+
+     # Numeric params -------------------------------------------------------
+     eps: tl.constexpr,
+     fp8_min: tl.constexpr,
+     fp8_max: tl.constexpr,
+
+     # Meta -----------------------------------------------------------------
+     BLOCK: tl.constexpr,
+ ):
+     G = H // GROUP_SIZE
+
+     # map program id -> (e, g)
+     pid = tl.program_id(0)
+     e = pid // G
+     g = pid % G
+
+     e = e.to(tl.int64)
+     g = g.to(tl.int64)
+
+     # number of valid tokens for this expert
+     n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
+
+     cols = tl.arange(0, BLOCK)
+     cols = cols.to(tl.int64)
+     mask_h = cols < BLOCK
+
+     t = tl.zeros([], tl.int64)
+     while t < n_tokens:
+         base_i_offset = (e * stride_i_e + t * stride_i_t +
+                          g * GROUP_SIZE * stride_i_h)
+         base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
+                           g * GROUP_SIZE * stride_yq_h)
+         base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g
+
+         mask = mask_h
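+         # load this group's slice of both halves: x (first half, silu-activated)
+         # and y2 (second half, the multiplier)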
+         x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
+                     mask=mask,
+                     other=0.0).to(tl.float32)
+         y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
+                      cols * stride_i_h,
+                      mask=mask,
+                      other=0.0).to(tl.float32)
+
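+         # silu(x) * y2, computed in fp32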
+         x = x * (1.0 / (1.0 + tl.exp(-x)))
+         y = x * y2
+
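+         # per-group absmax -> scale, then quantize and clamp into the fp8 range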
+         _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+         y_s = _absmax / fp8_max
+         y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+         tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
+         tl.store(y_s_ptr + base_ys_offset, y_s)
+
+         t += 1
+
+
+ def silu_mul_fp8_quant_deep_gemm(
+     y: torch.Tensor,  # (E, T, 2*H) float32
+     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+     group_size: int = 128,
+     eps: float = 1e-10,
+ ):
+     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales.
+
+     y has shape (E, T, 2*H). The first half of the last dimension is
+     silu-activated, multiplied by the second half, then quantized into FP8.
+
+     Returns `(y_q, y_s)` where
+     * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
+     * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`.
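+
+     A minimal usage sketch (the shapes, dtype and device below are
+     illustrative assumptions, not requirements):
+
+         E, T, H = 8, 16, 256
+         y = torch.randn(E, T, 2 * H, device="cuda", dtype=torch.bfloat16)
+         counts = torch.full((E,), T, dtype=torch.int32, device="cuda")
+         y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, counts)
+         # y_q: (E, T, H) float8_e4m3fn; y_s: (E, T, H // 128) float32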
116+ """
117+ assert y .ndim == 3 , "y must be (E, T, 2*H)"
118+ E , T , H2 = y .shape
119+ assert H2 % 2 == 0 , "last dim of y must be even (2*H)"
120+ H = H2 // 2
121+ G = H // group_size
122+ assert H % group_size == 0 , "H must be divisible by group_size"
123+ assert tokens_per_expert .ndim == 1 and tokens_per_expert .shape [0 ] == E , \
124+ "tokens_per_expert must be shape (E,)"
125+ tokens_per_expert = tokens_per_expert .to (device = y .device ,
126+ dtype = torch .int32 )
127+
128+ # allocate outputs
129+ fp8_dtype = torch .float8_e4m3fn
130+ y_q = torch .empty ((E , T , H ), dtype = fp8_dtype , device = y .device )
131+
132+ # strides (elements)
133+ stride_i_e , stride_i_t , stride_i_h = y .stride ()
134+ stride_yq_e , stride_yq_t , stride_yq_h = y_q .stride ()
135+
136+ # desired scale strides (elements): (T*G, 1, T)
137+ stride_ys_e = T * G
138+ stride_ys_t = 1
139+ stride_ys_g = T
140+ y_s = torch .empty_strided ((E , T , G ),
141+ (stride_ys_e , stride_ys_t , stride_ys_g ),
142+ dtype = torch .float32 ,
143+ device = y .device )
144+
145+ stride_cnt_e = tokens_per_expert .stride ()[0 ]
146+
147+ # static grid over experts and H-groups.
148+ # A loop inside the kernel handles the token dim
149+ grid = (E * G , )
150+
151+ f_info = torch .finfo (fp8_dtype )
152+ fp8_max = f_info .max
153+ fp8_min = f_info .min
154+
155+ _silu_mul_fp8_quant_deep_gemm [grid ](
156+ y ,
157+ y_q ,
158+ y_s ,
159+ tokens_per_expert ,
160+ H ,
161+ group_size ,
162+ stride_i_e ,
163+ stride_i_t ,
164+ stride_i_h ,
165+ stride_yq_e ,
166+ stride_yq_t ,
167+ stride_yq_h ,
168+ stride_ys_e ,
169+ stride_ys_t ,
170+ stride_ys_g ,
171+ stride_cnt_e ,
172+ eps ,
173+ fp8_min ,
174+ fp8_max ,
175+ BLOCK = group_size ,
176+ num_warps = 4 ,
177+ )
178+
179+ return y_q , y_s
180+
181+
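+ # For reference, per expert e and token t < tokens_per_expert[e], the kernel
+ # launched above computes the following (illustrative eager-PyTorch sketch of
+ # the math only, not used at runtime):
+ #
+ #     x, y2 = y[e, t, :H].float(), y[e, t, H:].float()
+ #     u = (x * torch.sigmoid(x)) * y2
+ #     u = u.view(-1, group_size)
+ #     s = u.abs().amax(dim=-1).clamp_min(eps) / fp8_max      # y_s[e, t]
+ #     q = (u / s[:, None]).clamp(fp8_min, fp8_max)           # -> float8_e4m3fn
+
+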
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

    # The Deep Gemm kernels only support block size of 128
@@ -96,7 +261,6 @@ def apply(
            hidden_states, w1, w2, topk_ids)

        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
-         workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))

        # (from deepgemm docs) : A value hint (which is a value on CPU)
        # for the M expectation of each batch, correctly setting this value
@@ -109,19 +273,9 @@ def apply(
                                                 masked_m=expert_num_tokens,
                                                 expected_m=expected_m)

-         # TODO (varun) [Optimization]: Use a batched version of activation.
-         # Similarly for the quant below.
-         self.activation(activation, workspace2, workspace1.view(-1, N))
-
-         w2_hidden_size = workspace2.size(-1)
-         workspace2 = workspace2.view(-1, w2_hidden_size)
-
-         a2q_scale: Optional[torch.Tensor] = None
-         a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
-                                                    self.block_shape[1],
-                                                    column_major_scales=False)
-         a2q = a2q.view(E, max_num_tokens, -1)
-         a2q_scale = a2q_scale.view(E, max_num_tokens, -1)
+         assert expert_num_tokens is not None
+         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
+                                                       expert_num_tokens)

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
                                                 (w2, w2_scale),