-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling #31593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d6a1f64
9a28683
e84eaa2
058a998
d53b6ff
9a7cf4d
1182e1d
2126f98
33741a8
31c4e22
e0129dd
eb6699b
5be7ab1
844a65a
24a0302
7edf70f
9d994a6
6ff4b75
c9a7e5b
2408ad2
96ff599
f9a4724
e8831f9
f8f9a33
59f97a6
113e472
a98a380
df82e9c
df5035c
3678402
783b64d
c30d404
a285f5e
2d96161
a910872
d2decd6
870fc6a
23e79fd
86a0e5c
7eaa18b
83a7d9b
7300bc5
344167d
b887c4f
d4d4231
218e697
b2e3a50
dd30416
3d22ba3
56edeca
e917f5d
39987f6
140f447
de6faa1
173e67d
12a0638
6982255
59db9a9
d6c4a87
ce913de
a8c5cc9
84dc7ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,7 +50,8 @@ | |
| flashinfer_cutlass_moe_fp8, | ||
| get_flashinfer_moe_backend, | ||
| is_flashinfer_supporting_global_sf, | ||
| register_moe_scaling_factors, | ||
| make_fp8_moe_alpha_scales_for_fi, | ||
| register_scales_for_trtllm_fp8_per_tensor_moe, | ||
| rotate_flashinfer_fp8_moe_weights, | ||
| select_cutlass_fp8_gemm_impl, | ||
| swap_w13_to_w31, | ||
|
|
@@ -947,9 +948,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | |
| if self.flashinfer_moe_backend is not None: | ||
| if self.moe.is_act_and_mul: | ||
| layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) | ||
|
|
||
| # NOTE: this adds some attributes used by the trtllm kernel, | ||
| # which does not conform to the modular kernels abstraction (yet). | ||
| if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: | ||
| rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) | ||
| register_moe_scaling_factors(layer) | ||
| register_scales_for_trtllm_fp8_per_tensor_moe( | ||
| layer=layer, | ||
| w13_weight_scale=layer.w13_weight_scale, | ||
| w13_input_scale=layer.w13_input_scale, | ||
| w2_weight_scale=layer.w2_weight_scale, | ||
| w2_input_scale=layer.w2_input_scale, | ||
| ) | ||
|
|
||
| def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: | ||
| """Pad intermediate size so FlashInfer kernels' alignment constraints hold. | ||
|
|
@@ -999,19 +1009,34 @@ def get_fused_moe_quant_config( | |
| self, layer: torch.nn.Module | ||
| ) -> FusedMoEQuantConfig | None: | ||
| if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: | ||
| # TRTLLM does not use modular kernels | ||
| return None | ||
|
|
||
| return fp8_w8a8_moe_quant_config( | ||
| w1_scale=layer.w13_weight_scale, | ||
| g1_alphas=layer.output1_scales_gate_scalar.squeeze(), | ||
| w2_scale=layer.w2_weight_scale, | ||
| g2_alphas=layer.output2_scales_scalar.squeeze(), | ||
| a1_scale=layer.w13_input_scale, | ||
| a1_gscale=layer.w13_input_scale, | ||
| a2_scale=layer.w2_input_scale, | ||
| a2_gscale=layer.w2_input_scale_inv, | ||
| per_act_token_quant=False, | ||
| ) | ||
| elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: | ||
| g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( | ||
| layer.w13_weight_scale, | ||
| layer.w13_input_scale, | ||
| layer.w2_weight_scale, | ||
| layer.w2_input_scale, | ||
| ) | ||
| return fp8_w8a8_moe_quant_config( | ||
| w1_scale=layer.w13_weight_scale, | ||
| w2_scale=layer.w2_weight_scale, | ||
| a1_scale=layer.w13_input_scale, | ||
| a2_scale=layer.w2_input_scale, | ||
| a1_gscale=(1.0 / layer.w13_input_scale), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this function is called every forward, which means these 2 lines will result in 2 kernel launches for reciprocal. Can we add these 2 scales in
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's not called in the forward pass. I recognize this is confusing, but I am working on an ongoing refactor that makes the conversion
see https://vllm-dev.slack.com/archives/C08NFPURQ1F/p1767650816469009 for more details on my efforts |
||
| a2_gscale=(1.0 / layer.w2_input_scale), | ||
| g1_alphas=g1_alphas, | ||
| g2_alphas=g2_alphas, | ||
| ) | ||
| else: | ||
| assert self.flashinfer_moe_backend is None | ||
| return fp8_w8a8_moe_quant_config( | ||
| w1_scale=layer.w13_weight_scale, | ||
| w2_scale=layer.w2_weight_scale, | ||
| a1_scale=layer.w13_input_scale, | ||
| a2_scale=layer.w2_input_scale, | ||
| ) | ||
|
|
||
| def apply( | ||
| self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the a2_gscales are used for quantization of hidden states (the
`a` tensor) before FFN2 in MoE, hence the a2 (for the second FFN) and gscale for quantization. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated to reflect this