✨ [diffusion][npu][quant] Add MXFP8 quantization support for Wan2.2 Diffusion on Ascend NPU #20922
Merged
34 commits
| Commit | Message | Author |
|---|---|---|
| ef874c0 | ✨ feat(npu): add online MXFP8 quantization support for Ascend NPU (Pa… | TallMessiWu |
| d2d19c6 | ✨ feat(diffusion): add online MXFP8 quantization support for Wan2.2 o… | TallMessiWu |
| c838ade | :bug: fix(diffusion): fix npu method call error | TallMessiWu |
| be3b684 | :bug: fix(diffusion): fix MXFP8 quantization scale dimension mismatch… | TallMessiWu |
| fd79b23 | :recycle: refactor(mxfp8): split linear method into config and NPU la… | TallMessiWu |
| df61b29 | :twisted_rightwards_arrows: merge: sync from upstream | TallMessiWu |
| 490ad0b | :sparkles: feat(diffusion): add offline MXFP8 pre-quantized weight su… | TallMessiWu |
| cc80690 | :bug: fix(diffusion): correct MXFP8 weight dtype and scale shape | TallMessiWu |
| b9aa785 | ✨ feat(wan22): Redesigned the wan_repack tool. Now support one-click … | TallMessiWu |
| 22bee9e | :recycle: refactor(mxfp8): hoist imports and replace print with logger | TallMessiWu |
| a29bb3d | :pencil2: fix(diffusion/mxfp8): address review comments on ModelSlimM… | TallMessiWu |
| 3bbf703 | :twisted_rightwards_arrows: chore(merge): sync upstream/main, keep MX… | TallMessiWu |
| 250fe65 | :adhesive_bandage: fix(diffusion): register --quantization CLI arg to… | TallMessiWu |
| e146b03 | :bug: fix(mxfp8_npu): move weight to current NPU device before quanti… | TallMessiWu |
| 711bb8b | :rewind: revert(llm): remove LLM MXFP8 online quantization (Path B) f… | TallMessiWu |
| 1604d4e | :twisted_rightwards_arrows: chore(merge): sync upstream/main into junlin | TallMessiWu |
| 1101cf5 | :adhesive_bandage: fix(loader): preserve --quantization flag priority… | TallMessiWu |
| 553a82c | Merge branch 'main' into junlin | ping1jing2 |
| 6bc42f9 | :bug: fix(diffusion/mxfp8): fix torch_npu import error in non-npu env… | TallMessiWu |
| 4dc2135 | Merge branch 'main' into junlin | ping1jing2 |
| 92e2939 | Merge branch 'main' into junlin | ping1jing2 |
| 0617ae8 | :memo: docs(quant/mxfp8): update docs | TallMessiWu |
| 8f54ac3 | :twisted_rightwards_arrows: merge(junlin): sync upstream/main, preser… | TallMessiWu |
| a0ee0be | :bug: fix(ci): fix Windows encoding and path separator bugs in pre-co… | TallMessiWu |
| f92f820 | :twisted_rightwards_arrows: Merge remote-tracking branch 'upstream/ma… | TallMessiWu |
| e1dd1d3 | :green_heart: fix(docs): fix broken relative path to ascend_npu_quant… | TallMessiWu |
| 1787605 | Merge branch 'main' into junlin | ping1jing2 |
| b3cd75b | :green_heart: fix(test): add missing quantization attr to _make_serve… | TallMessiWu |
| 4f79522 | Merge branch 'main' into junlin | ping1jing2 |
| d87905a | Merge branch 'main' into junlin | ping1jing2 |
| 762f21f | docs: sync LMSYS SGLang blog cards | github-actions[bot] |
| ba014f2 | Merge branch 'main' into junlin | ping1jing2 |
| 9be9672 | Merge branch 'main' into junlin | ping1jing2 |
| d44c6d2 | :rewind: revert: drop auto-generated LMSYS blog sync from PR diff | TallMessiWu |
124 changes: 124 additions & 0 deletions
python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py
@@ -0,0 +1,124 @@

    """ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU.

    Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights,
    uint8 scales) and runs MXFP8 matmul at inference.
    """

    from typing import List, Optional

    import torch

    from sglang.multimodal_gen.runtime.platforms import current_platform

    _is_npu = current_platform.is_npu()

    if _is_npu:
        import torch_npu

    from sglang.multimodal_gen.runtime.models.parameter import (
        GroupQuantScaleParameter,
        ModelWeightParameter,
    )
    from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme

    MXFP8_BLOCK_SIZE = 32


    class ModelSlimMXFP8Scheme(ModelSlimLinearScheme):

        def create_weights(
            self,
            layer: torch.nn.Module,
            input_size_per_partition: int,
            output_partition_sizes: List[int],
            input_size: int,
            output_size: int,
            params_dtype: torch.dtype,
            **extra_weight_attrs,
        ):
            weight_loader = extra_weight_attrs.get("weight_loader")
            output_size_per_partition = sum(output_partition_sizes)

            # msmodelslim exports weight as float8_e4m3fn, shape [out, in]
            weight = ModelWeightParameter(
                data=torch.empty(
                    (output_size_per_partition, input_size_per_partition),
                    dtype=torch.float8_e4m3fn,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight", weight)

            # msmodelslim exports weight_scale as uint8, shape [out, in/32].
            # NOTE: This parameter is intentionally named "weight_scale" (not
            # "weight_scale_inv" as used in mxfp8_npu.py) because the weight loader
            # matches parameter names to checkpoint keys, and msmodelslim checkpoints
            # store this tensor under the key "<layer>.weight_scale".
            scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE
            weight_scale = GroupQuantScaleParameter(
                data=torch.empty(
                    (output_size_per_partition, scale_dim),
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_scale", weight_scale)

        def process_weights_after_loading(self, layer: torch.nn.Module):
            # weight is already float8_e4m3fn, no cast needed
            weight = layer.weight.data
            layer.weight = torch.nn.Parameter(weight, requires_grad=False)

            # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2]
            weight_scale = layer.weight_scale.data
            weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2)
            layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)

        def apply_weights(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:

            original_dtype = x.dtype
            if original_dtype not in (torch.float16, torch.bfloat16):
                # npu_dynamic_mx_quant only accepts fp16/bf16 activations
                x = x.to(torch.bfloat16)
                original_dtype = torch.bfloat16

            # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size].
            # Diffusion transformer inputs are typically 3D [batch, seq, hidden] or
            # higher. Flattening to 2D merges all leading dimensions into a single
            # token axis so the NPU kernel can compute per-token MXFP8 scales, then
            # we restore the original shape from the output.
            input_shape = x.shape
            x_2d = x.reshape(-1, x.shape[-1])

            # Dynamic MXFP8 activation quantisation
            qx, input_scale = torch_npu.npu_dynamic_mx_quant(
                x_2d, dst_type=torch_npu.float8_e4m3fn
            )

            # MXFP8 matmul
            output = torch_npu.npu_quant_matmul(
                qx,
                layer.weight.transpose(0, 1),
                layer.weight_scale.transpose(0, 1),
                scale_dtype=torch_npu.float8_e8m0fnu,
                pertoken_scale=input_scale,
                pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
                bias=bias.to(torch.float32) if bias is not None else None,
                output_dtype=original_dtype,
                group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
            )

            # Restore original shape
            output_shape = list(input_shape[:-1]) + [output.shape[-1]]
            output = output.reshape(output_shape)

            return output
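The shape bookkeeping in `create_weights`, `process_weights_after_loading`, and `apply_weights` can be sanity-checked without an NPU. The sketch below is not part of the PR: `mxfp8_scale_shapes`, `flatten_tokens`, and `restore_shape` are hypothetical helper names, and the sizes in the usage lines are purely illustrative. It assumes that the input dimension should be a multiple of 64, since `input_size_per_partition // MXFP8_BLOCK_SIZE` floors silently and the later `reshape(out, -1, 2)` only works when the number of scales per row is even.

```python
from math import prod
from typing import Sequence, Tuple

MXFP8_BLOCK_SIZE = 32  # same value as the constant in the diff


def mxfp8_scale_shapes(
    out_features: int, in_features: int
) -> Tuple[Tuple[int, int], Tuple[int, int, int]]:
    """Return (checkpoint scale shape, scale shape after the reshape)."""
    # Assumption: in_features must divide evenly into pairs of 32-wide blocks,
    # otherwise the [out, -1, 2] reshape cannot succeed.
    if in_features % (2 * MXFP8_BLOCK_SIZE) != 0:
        raise ValueError(
            f"in_features={in_features} is not a multiple of {2 * MXFP8_BLOCK_SIZE}"
        )
    scale_dim = in_features // MXFP8_BLOCK_SIZE
    return (out_features, scale_dim), (out_features, scale_dim // 2, 2)


def flatten_tokens(shape: Sequence[int]) -> Tuple[int, int]:
    """Mirror x.reshape(-1, x.shape[-1]): merge leading dims into one token axis."""
    return prod(shape[:-1]), shape[-1]


def restore_shape(input_shape: Sequence[int], out_features: int) -> Tuple[int, ...]:
    """Mirror list(input_shape[:-1]) + [output.shape[-1]]."""
    return tuple(input_shape[:-1]) + (out_features,)


# Illustrative sizes only.
print(mxfp8_scale_shapes(5120, 5120))        # ((5120, 160), (5120, 80, 2))
print(flatten_tokens((2, 75, 5120)))         # (150, 5120)
print(restore_shape((2, 75, 5120), 1024))    # (2, 75, 1024)
```

A check like this could run at weight-creation time so that an unsupported partition size fails with a clear message instead of a reshape error after loading; whether the PR's loader already guarantees this alignment is not shown in the diff.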
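The `uint8` weight scales are handed to `npu_quant_matmul` with `scale_dtype=torch_npu.float8_e8m0fnu`. Assuming those bytes follow the OCP Microscaling E8M0 convention (exponent only, bias 127, `0xFF` reserved for NaN), each byte decodes to a power of two shared by one block of 32 elements. The pure-Python sketch below illustrates that decoding; `decode_e8m0` and `dequantize_block` are hypothetical names, and the NPU kernel's actual internals are not shown in this PR.

```python
import math
from typing import List


def decode_e8m0(byte: int) -> float:
    """Decode one E8M0 scale byte to a float, assuming bias 127 and 0xFF = NaN."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"scale byte out of range: {byte}")
    if byte == 0xFF:
        return math.nan
    # ldexp(1.0, e) computes 2**e exactly for this exponent range.
    return math.ldexp(1.0, byte - 127)


def dequantize_block(elements: List[float], scale_byte: int) -> List[float]:
    """Apply one shared block scale to already-decoded FP8 element values."""
    scale = decode_e8m0(scale_byte)
    return [e * scale for e in elements]


print(decode_e8m0(127))                     # 1.0
print(decode_e8m0(130))                     # 8.0
print(dequantize_block([1.5, -2.0], 128))   # [3.0, -4.0]
```

Because the scale is a pure power of two, applying it only shifts exponents, which is why the checkpoint can store it in a single byte per block.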