|
6 | 6 | import torch |
7 | 7 |
|
8 | 8 | import vllm.model_executor.layers.fused_moe.modular_kernel as mk |
| 9 | +from vllm import envs |
| 10 | +from vllm.logger import init_logger |
9 | 11 | from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig |
10 | 12 | from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( |
11 | 13 | TopKWeightAndReduceDelegate, |
|
24 | 26 | DEEPEP_QUANT_BLOCK_SIZE = 128 |
25 | 27 | DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE] |
26 | 28 |
|
| 29 | +logger = init_logger(__name__) |
| 30 | + |
27 | 31 |
|
28 | 32 | def dequant_fp8( |
29 | 33 | expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor |
@@ -110,21 +114,31 @@ def _do_quant( |
110 | 114 | assert isinstance(x, torch.Tensor) |
111 | 115 |
|
112 | 116 | num_experts, max_tokens, hidden_dim = x.size() |
113 | | - |
114 | | - # TODO (varun): Optimization - Use a batched version of quant |
115 | | - x = x.view((-1, hidden_dim)) |
116 | | - x, x_scales = moe_kernel_quantize_input( |
117 | | - x, |
118 | | - quant_config.a1_scale, |
119 | | - quant_config.quant_dtype, |
120 | | - quant_config.per_act_token_quant, |
121 | | - quant_config.block_shape, |
122 | | - ) |
123 | | - x = x.view((num_experts, -1, hidden_dim)) |
124 | | - |
125 | | - if quant_config.quant_dtype is not None: |
126 | | - assert x_scales is not None |
127 | | - x_scales = normalize_batched_scales_shape(x_scales, num_experts) |
| 117 | + if not envs.VLLM_DEEPEPLL_BF16_DISPATCH: |
| 118 | + # TODO (varun): Optimization - Use a batched version of quant |
| 119 | + x = x.view((-1, hidden_dim)) |
| 120 | + x, x_scales = moe_kernel_quantize_input( |
| 121 | + x, |
| 122 | + quant_config.a1_scale, |
| 123 | + quant_config.quant_dtype, |
| 124 | + quant_config.per_act_token_quant, |
| 125 | + quant_config.block_shape, |
| 126 | + ) |
| 127 | + x = x.view((num_experts, -1, hidden_dim)) |
| 128 | + |
| 129 | + if quant_config.quant_dtype is not None: |
| 130 | + assert x_scales is not None |
| 131 | + x_scales = normalize_batched_scales_shape(x_scales, num_experts) |
| 132 | + else: |
| 133 | + # BF16 dispatch path - no quantization |
| 134 | + # TODO([email protected]): enable nvfp4 dispatch once DeepEP is ready. |
| 135 | + logger.info_once("Using BF16 dispatch path for DeepEPLLPrepareAndFinalize") |
| 136 | + assert x.dtype == torch.bfloat16, ( |
| 137 | + "BF16 dispatch requires input to be in BF16" |
| 138 | + ) |
| 139 | + x_scales = None |
| 140 | + x = x.view((num_experts, -1, hidden_dim)) |
128 | 142 |
|
129 | 143 | return x, x_scales |
130 | 144 |
|
@@ -262,6 +276,8 @@ def _finalize( |
262 | 276 |
|
263 | 277 | # TODO (varun) : Enable zero copy mode |
264 | 278 | dbo_maybe_run_recv_hook() |
265 | 281 | _, _, recv_hook = self.buffer.low_latency_combine( |
266 | 282 | fused_expert_output, |
267 | 283 | topk_ids, |
|
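For readability, a minimal standalone sketch of the control flow this change adds to `_do_quant`. The `bf16_dispatch` flag and `quantize` helper below are hypothetical stand-ins for `envs.VLLM_DEEPEPLL_BF16_DISPATCH` and `moe_kernel_quantize_input()`, not the actual vLLM API:

```python
import torch


def quantize(x: torch.Tensor):
    # Placeholder quantization so the sketch is self-contained:
    # identity with no scales (the real path calls moe_kernel_quantize_input).
    return x, None


def do_quant_sketch(x: torch.Tensor, bf16_dispatch: bool):
    num_experts, _, hidden_dim = x.size()
    if not bf16_dispatch:
        # Existing path: flatten to 2D, quantize, then restore the batched shape.
        x = x.view((-1, hidden_dim))
        x, x_scales = quantize(x)
        x = x.view((num_experts, -1, hidden_dim))
        return x, x_scales
    # New path: keep activations in bfloat16 and dispatch without scales.
    assert x.dtype == torch.bfloat16, "BF16 dispatch requires a BF16 input"
    return x.view((num_experts, -1, hidden_dim)), None
```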