diff --git a/docs/diffusion/quantization.md b/docs/diffusion/quantization.md index e9340c54bba2..cf6e4d05d07a 100644 --- a/docs/diffusion/quantization.md +++ b/docs/diffusion/quantization.md @@ -43,21 +43,21 @@ backend. | quant_family | checkpoint form | canonical CLI | supported models | extra dependency | platform / notes | |-------------------|--------------------------------------------------------------------------------------------|------------------------------------------------------------------------|-----------------------------------------|---------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------| | `fp8` | Quantized transformer component folder, or safetensors with `quantization_config` metadata | `--transformer-path` or `--transformer-weights-path` | ALL | None | Component-folder and single-file flows are both supported | -| `modelopt-fp8` | Converted ModelOpt FP8 transformer directory or repo with `config.json` | `--transformer-path` | FLUX.1, FLUX.2, Wan2.2, Qwen Image, Qwen Image Edit | None | Serialized config stays `quant_method=modelopt` with `quant_algo=FP8`; `dit_layerwise_offload` is supported and `dit_cpu_offload` stays disabled | +| `modelopt-fp8` | Converted ModelOpt FP8 transformer directory or repo with `config.json` | `--transformer-path` | FLUX.1, FLUX.2, Wan2.2, HunyuanVideo, Qwen Image, Qwen Image Edit | None | Serialized config stays `quant_method=modelopt` with `quant_algo=FP8`; `dit_layerwise_offload` is supported and `dit_cpu_offload` stays disabled | | `modelopt-nvfp4` | Mixed transformer directory/repo with `config.json`, or raw NVFP4 safetensors export/repo | `--transformer-path` for mixed overrides; `--transformer-weights-path` for raw exports | FLUX.1, FLUX.2, Wan2.2 | None | Mixed override repos keep the base model separate; raw exports such as `black-forest-labs/FLUX.2-dev-NVFP4` still use the weights-path flow | | `nunchaku-svdq` | Pre-quantized Nunchaku transformer weights, usually named `svdq-{int4\|fp4}_r{rank}-...` | `--transformer-weights-path` | Model-specific support such as Qwen-Image, FLUX, and Z-Image | `nunchaku` | SGLang can infer precision and rank from the filename and supports both `int4` and `nvfp4` | | `msmodelslim` | Pre-quantized msmodelslim transformer weights | `--model-path` | Wan2.2 family | None | Currently only compatible with the Ascend NPU family and supports both `w8a8` and `w4a4` | ## Validated ModelOpt Checkpoints -This section is the canonical support matrix for the diffusion ModelOpt +This section is the canonical support matrix for the nine diffusion ModelOpt checkpoints currently wired up in SGLang docs and validation coverage. Published checkpoints keep the serialized quantization config as `quant_method=modelopt`; the FP8 vs NVFP4 split below is a documentation label derived from `quant_algo`. -Seven of the eight repos live under `lmsys/*`. The FLUX.2 NVFP4 entry keeps the +Eight of the nine repos live under `lmsys/*`. The FLUX.2 NVFP4 entry keeps the official `black-forest-labs/FLUX.2-dev-NVFP4` repo. | Quant Algo | Base Model | Preferred CLI | HF Repo | Current Scope | Notes | @@ -65,13 +65,14 @@ official `black-forest-labs/FLUX.2-dev-NVFP4` repo. | `FP8` | `black-forest-labs/FLUX.1-dev` | `--transformer-path` | `lmsys/flux1-dev-modelopt-fp8-sglang-transformer` | single-transformer override, deterministic latent/image comparison, H100 benchmark, torch-profiler trace | SGLang converter keeps a validated BF16 fallback set for modulation and FF projection layers; use `--model-id FLUX.1-dev` for local mirrors | | `FP8` | `black-forest-labs/FLUX.2-dev` | `--transformer-path` | `lmsys/flux2-dev-modelopt-fp8-sglang-transformer` | single-transformer override load and generation path | published SGLang-ready transformer override | | `FP8` | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | `--transformer-path` | `lmsys/wan22-t2v-a14b-modelopt-fp8-sglang-transformer` | primary `transformer` quantized, `transformer_2` kept BF16 | primary-transformer-only path; keep `transformer_2` on the base checkpoint, and do not describe this as dual-transformer full-model FP8 unless that path is validated separately | +| `FP8` | `hunyuanvideo-community/HunyuanVideo` | `--transformer-path` | `lmsys/hunyuanvideo-modelopt-fp8-sglang-transformer` | single-transformer override, BF16-vs-FP8 video comparison, H100 benchmark, torch-profiler trace | HunyuanVideo uses different ModelOpt/diffusers and SGLang runtime module names; the converter maps those names before writing FP8 scale tensors and BF16 fallback ignores | | `FP8` | `Qwen/Qwen-Image` | `--transformer-path` | `lmsys/qwen-image-modelopt-fp8-sglang-transformer` | single-transformer override, BF16-vs-FP8 image comparison, H100 benchmark, torch-profiler trace | shares the Qwen Image FP8 fallback preset; keep `img_in`, `txt_in`, timestep embedder, `norm_out.linear`, `proj_out`, `img_mod`/`txt_mod`, and `img_mlp.net.2` in BF16 | -| `FP8` | `Qwen/Qwen-Image-Edit-2511` | `--transformer-path` | `lmsys/qwen-image-edit-modelopt-fp8-sglang-transformer` | TI2I edit smoke, BF16-vs-FP8 image comparison, H100 benchmark | shares `QwenImageTransformer2DModel` with Qwen Image and uses the same Qwen Image FP8 fallback preset | +| `FP8` | `Qwen/Qwen-Image-Edit-2511` | `--transformer-path` | `lmsys/qwen-image-edit-modelopt-fp8-sglang-transformer` | TI2I edit path, BF16-vs-FP8 image comparison, H100 benchmark | shares `QwenImageTransformer2DModel` with Qwen Image and uses the same Qwen Image FP8 fallback preset | | `NVFP4` | `black-forest-labs/FLUX.1-dev` | `--transformer-path` | `lmsys/flux1-dev-modelopt-nvfp4-sglang-transformer` | mixed BF16+NVFP4 transformer override, correctness validation, 4x RTX 5090 benchmark, torch-profiler trace | use `build_modelopt_nvfp4_transformer.py`; validated builder keeps selected FLUX.1 modules in BF16 and sets `swap_weight_nibbles=false` | | `NVFP4` | `black-forest-labs/FLUX.2-dev` | `--transformer-weights-path` | `black-forest-labs/FLUX.2-dev-NVFP4` | packed-QKV load path | official raw export repo; validated packed export detection and runtime layout handling | | `NVFP4` | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | `--transformer-path` | `lmsys/wan22-t2v-a14b-modelopt-nvfp4-sglang-transformer` | primary `transformer` quantized with ModelOpt NVFP4, `transformer_2` kept BF16 | primary-transformer-only path; keep `transformer_2` on the base checkpoint, and current B200/Blackwell bring-up uses `SGLANG_DIFFUSION_FLASHINFER_FP4_GEMM_BACKEND=cudnn` | -These eight checkpoints are also the intended case set for the B200 diffusion +These nine checkpoints are also the intended case set for the B200 diffusion CI job (`multimodal-gen-test-1-b200`). ## ModelOpt FP8 @@ -98,6 +99,15 @@ sglang generate \ --save-output ``` +```bash +sglang generate \ + --model-path hunyuanvideo-community/HunyuanVideo \ + --transformer-path lmsys/hunyuanvideo-modelopt-fp8-sglang-transformer \ + --height 544 --width 960 --num-frames 17 \ + --prompt "A cinematic shot of a red sports car driving through rain at night" \ + --save-output +``` + ```bash sglang generate \ --model-path Qwen/Qwen-Image \ @@ -131,6 +141,17 @@ sglang generate \ - On disk, the quantization config stays `quant_method=modelopt` with `quant_algo=FP8`; the `modelopt-fp8` label in this document is a support family name, not a serialized config key. +- `hunyuanvideo-community/HunyuanVideo` uses the `hunyuan-video` converter + preset. Use `--model-type hunyuan-video` to force it, or rely on + auto-detection from `_class_name=HunyuanVideoTransformer3DModel`. +- The validated HunyuanVideo FP8 fallback preset keeps `context_embedder`, + `x_embedder.proj`, timestep/guidance/text embedder linear layers, + `norm_out.linear`, `proj_out`, double-block modulation linear layers, and + single-block modulation linear layers in BF16. +- HunyuanVideo ModelOpt exports use diffusers module names that do not match + SGLang runtime module names for fused QKV and fused QKV+MLP layers. The + converter maps the names before selecting scale tensors and before writing + the runtime ignore list. - `Qwen/Qwen-Image` and `Qwen/Qwen-Image-Edit-2511` share the `qwen-image` converter preset. Use `--model-type qwen-image` to force it, or rely on auto-detection from `_class_name=QwenImageTransformer2DModel`. diff --git a/docs_new/docs/sglang-diffusion/quantization.mdx b/docs_new/docs/sglang-diffusion/quantization.mdx index 621a99a442e5..392c0831b06b 100644 --- a/docs_new/docs/sglang-diffusion/quantization.mdx +++ b/docs_new/docs/sglang-diffusion/quantization.mdx @@ -76,7 +76,7 @@ backend. modelopt-fp8 Converted ModelOpt FP8 transformer directory or repo with config.json --transformer-path - FLUX.1, FLUX.2, Wan2.2 + FLUX.1, FLUX.2, Wan2.2, HunyuanVideo, Qwen Image, Qwen Image Edit None Serialized config stays quant_method=modelopt with quant_algo=FP8; dit_layerwise_offload is supported and dit_cpu_offload stays disabled @@ -109,14 +109,14 @@ backend. ## Validated ModelOpt Checkpoints -This section is the canonical support matrix for the eight diffusion ModelOpt +This section is the canonical support matrix for the nine diffusion ModelOpt checkpoints currently wired up in SGLang docs and B200 CI coverage. Published checkpoints keep the serialized quantization config as `quant_method=modelopt`; the FP8 vs NVFP4 split below is a documentation label derived from `quant_algo`. -Seven of the eight repos live under `lmsys/*`. The FLUX.2 NVFP4 entry keeps the +Eight of the nine repos live under `lmsys/*`. The FLUX.2 NVFP4 entry keeps the official `black-forest-labs/FLUX.2-dev-NVFP4` repo. @@ -163,6 +163,14 @@ official `black-forest-labs/FLUX.2-dev-NVFP4` repo. + + + + + + + + @@ -176,7 +184,7 @@ official `black-forest-labs/FLUX.2-dev-NVFP4` repo. - + @@ -206,7 +214,7 @@ official `black-forest-labs/FLUX.2-dev-NVFP4` repo.
primary transformer quantized, transformer_2 kept BF16 primary-transformer-only path; keep transformer_2 on the base checkpoint, and do not describe this as dual-transformer full-model FP8 unless that path is validated separately
FP8hunyuanvideo-community/HunyuanVideo--transformer-pathlmsys/hunyuanvideo-modelopt-fp8-sglang-transformersingle-transformer override, BF16-vs-FP8 video comparison, H100 benchmark, torch-profiler traceHunyuanVideo uses different ModelOpt/diffusers and SGLang runtime module names; the converter maps those names before writing FP8 scale tensors and BF16 fallback ignores
FP8 Qwen/Qwen-ImageQwen/Qwen-Image-Edit-2511 --transformer-path lmsys/qwen-image-edit-modelopt-fp8-sglang-transformerTI2I edit smoke, BF16-vs-FP8 image comparison, H100 benchmarkTI2I edit path, BF16-vs-FP8 image comparison, H100 benchmark shares QwenImageTransformer2DModel with Qwen Image and uses the same Qwen Image FP8 fallback preset
-These eight checkpoints are also the intended case set for the B200 diffusion CI +These nine checkpoints are also the intended case set for the B200 diffusion CI job (`multimodal-gen-test-1-b200`). ## ModelOpt FP8 @@ -233,6 +241,15 @@ sglang generate \ --save-output ``` +```bash +sglang generate \ + --model-path hunyuanvideo-community/HunyuanVideo \ + --transformer-path lmsys/hunyuanvideo-modelopt-fp8-sglang-transformer \ + --height 544 --width 960 --num-frames 17 \ + --prompt "A cinematic shot of a red sports car driving through rain at night" \ + --save-output +``` + ```bash sglang generate \ --model-path Qwen/Qwen-Image \ diff --git a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md index 7b6d90be15c2..12c10511d0b2 100644 --- a/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md +++ b/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md @@ -63,7 +63,7 @@ This repo now contains: Validated documentation and CI coverage currently center on these ModelOpt diffusion transformer override families: -- FP8: FLUX.1-dev, FLUX.2-dev, Wan2.2, Qwen Image, Qwen Image Edit +- FP8: FLUX.1-dev, FLUX.2-dev, Wan2.2, HunyuanVideo, Qwen Image, Qwen Image Edit - NVFP4: FLUX.1-dev, FLUX.2-dev, Wan2.2 Treat a new family, a new precision, or a new checkpoint layout as unsupported until it has a documented matrix row and a matching validation story. @@ -71,26 +71,26 @@ Before writing CLI examples, re-read the active branch's `docs/diffusion/quantiz B200 CI coverage can include loose BF16-vs-quantized quality checks. Inspect the active branch's `run_suite.py` before assuming they are part of the suite; mainline and feature branches may differ. Those checks are intended to catch blank, corrupted, or obviously divergent images, not exact image parity. -Mainline documentation now uses `lmsys/*` for the five converted ModelOpt +Mainline documentation now uses `lmsys/*` for the eight converted ModelOpt checkpoint repos; the FLUX.2 NVFP4 raw export remains `black-forest-labs/FLUX.2-dev-NVFP4`. Do not use older `BBuf/*` examples unless you are explicitly testing a historical branch. -## Open PR Watchlist +## Related PR Watchlist -As of 2026-05-02, these related SGLang PRs were open. Treat them as future -support or migration work until they merge and the docs/CI matrix is updated. +As of 2026-05-04, these related SGLang PRs are relevant to ModelOpt diffusion +support. Treat unmerged items as future support or migration work until the +docs/CI matrix is updated. -- #23155 adds Qwen Image ModelOpt FP8 support. +- #23155 added Qwen Image ModelOpt FP8 support. - #23199 adds HunyuanVideo ModelOpt FP8 support. - #23373 adds a runtime quantization flag; keep PTQ/export workflows separate from runtime quant examples until the CLI behavior is merged. - #24024 adds transformer FP8-cast compatibility mode. - #24186 re-enables B200 multimodal CI with NVFP4 fixes for FLUX.2 and Wan2.2. -Do not expand the validated matrix beyond FLUX.1, FLUX.2, and Wan2.2 solely -because one of these PRs exists. Add a row only after the exact checkpoint, -loader path, accuracy check, and benchmark scope are validated on the active -branch. +Do not expand the validated matrix beyond the documented rows solely because a +related PR exists. Add a row only after the exact checkpoint, loader path, +accuracy check, and benchmark scope are validated on the active branch. ## Documentation Maintenance @@ -194,6 +194,28 @@ For `FLUX.1-dev`, the validated fallback set currently keeps these modules in BF Use `--model-type flux1` to force that profile, or rely on `--model-type auto` when the export config identifies `FluxTransformer2DModel`. +HunyuanVideo uses `HunyuanVideoTransformer3DModel`, so the validated +HunyuanVideo FP8 fallback preset keeps these modules in BF16: + +- `context_embedder.*` +- `x_embedder.proj` +- `time_text_embed.(timestep_embedder|guidance_embedder|text_embedder).linear_[12]` +- `norm_out.linear` +- `proj_out` +- `transformer_blocks.*.norm1.linear` +- `transformer_blocks.*.norm1_context.linear` +- `single_transformer_blocks.*.norm.linear` + +Use `--model-type hunyuan-video` to force that profile, or rely on +`--model-type auto` when the export config identifies +`HunyuanVideoTransformer3DModel`. + +HunyuanVideo ModelOpt exports use diffusers module names that differ from +SGLang runtime names for fused QKV and fused QKV+MLP layers. Keep the +diffusers-to-runtime mapping in `build_modelopt_fp8_transformer.py` in sync +with `runtime/models/dits/hunyuanvideo.py` before trusting converted scale +tensors. + Qwen Image and Qwen Image Edit share `QwenImageTransformer2DModel`, so one ModelOpt FP8 fallback preset covers both. The validated Qwen Image fallback set keeps these modules in BF16: diff --git a/python/sglang/multimodal_gen/benchmarks/bench_offline_throughput.py b/python/sglang/multimodal_gen/benchmarks/bench_offline_throughput.py index 2688f70ae41a..e489c2fd5713 100644 --- a/python/sglang/multimodal_gen/benchmarks/bench_offline_throughput.py +++ b/python/sglang/multimodal_gen/benchmarks/bench_offline_throughput.py @@ -477,9 +477,9 @@ def main(): ServerArgs.add_cli_args(parser) BenchArgs.add_cli_args(parser) - args = parser.parse_args() + args, unknown_args = parser.parse_known_args() - server_args = ServerArgs.from_cli_args(args) + server_args = ServerArgs.from_cli_args(args, unknown_args) bench_args = BenchArgs.from_cli_args(args) set_global_server_args(server_args) diff --git a/python/sglang/multimodal_gen/runtime/layers/linear.py b/python/sglang/multimodal_gen/runtime/layers/linear.py index c8c0a2598efd..f0cfc5408258 100644 --- a/python/sglang/multimodal_gen/runtime/layers/linear.py +++ b/python/sglang/multimodal_gen/runtime/layers/linear.py @@ -232,6 +232,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: torch.dtype | None = None, quant_config: QuantizationConfig | None = None, + output_sizes: list[int] | None = None, prefix: str = "", ): super().__init__( @@ -245,10 +246,11 @@ def __init__( # All the linear layer supports quant method. assert self.quant_method is not None + output_partition_sizes = output_sizes or [self.output_size] self.quant_method.create_weights( self, self.input_size, - [self.output_size], + output_partition_sizes, self.input_size, self.output_size, self.params_dtype, @@ -497,7 +499,6 @@ def weight_loader( loaded_weight: torch.Tensor, loaded_shard_id: int | None = None, ) -> None: - param_data = param.data output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. @@ -829,7 +830,6 @@ def weight_loader( loaded_weight: torch.Tensor, loaded_shard_id: str | None = None, ): - param_data = param.data output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. @@ -866,7 +866,6 @@ def weight_loader( ] for shard_id, shard_offset, shard_size in shard_offsets: - loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size ) @@ -1037,7 +1036,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data.copy_(loaded_weight) def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor): - # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). if len(loaded_weight.shape) == 0: diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py index 9decdf49e435..1750f7e7b361 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -95,6 +95,7 @@ def __init__( params_dtype=dtype, prefix=f"{prefix}.img_attn_qkv", quant_config=quant_config, + output_sizes=[hidden_size] * 3, ) self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) @@ -142,7 +143,9 @@ def __init__( hidden_size * 3, bias=True, params_dtype=dtype, + prefix=f"{prefix}.txt_attn_qkv", quant_config=quant_config, + output_sizes=[hidden_size] * 3, ) # QK norm layers for text @@ -154,6 +157,7 @@ def __init__( hidden_size, bias=True, params_dtype=dtype, + prefix=f"{prefix}.txt_attn_proj", quant_config=quant_config, ) @@ -162,6 +166,7 @@ def __init__( mlp_hidden_dim, bias=True, dtype=dtype, + prefix=f"{prefix}.txt_mlp", quant_config=quant_config, ) @@ -220,9 +225,10 @@ def forward( img_k = self.img_attn_k_norm(img_k.contiguous()).to(img_v) # Apply rotary embeddings cos, sin = freqs_cis - img_q, img_k = _apply_rotary_emb( - img_q, cos, sin, is_neox_style=False - ), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False) + img_q, img_k = ( + _apply_rotary_emb(img_q, cos, sin, is_neox_style=False), + _apply_rotary_emb(img_k, cos, sin, is_neox_style=False), + ) # Prepare text for attention using fused operation txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale) @@ -304,6 +310,7 @@ def __init__( params_dtype=dtype, prefix=f"{prefix}.linear1", quant_config=quant_config, + output_sizes=[hidden_size] * 3 + [mlp_hidden_dim], ) # Combined projection and MLP output @@ -386,9 +393,10 @@ def forward( img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:] # Apply rotary embeddings to image parts cos, sin = freqs_cis - img_q, img_k = _apply_rotary_emb( - img_q, cos, sin, is_neox_style=False - ), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False) + img_q, img_k = ( + _apply_rotary_emb(img_q, cos, sin, is_neox_style=False), + _apply_rotary_emb(img_k, cos, sin, is_neox_style=False), + ) # Run distributed attention img_attn_output, txt_attn_output = self.attn( @@ -682,7 +690,6 @@ def maybe_cache_states( self.previous_residual = hidden_states - original_hidden_states def should_skip_forward_for_cached_states(self, **kwargs) -> bool: - forward_context = get_forward_context() forward_batch = forward_context.forward_batch if forward_batch is None: diff --git a/python/sglang/multimodal_gen/test/server/gpu_cases.py b/python/sglang/multimodal_gen/test/server/gpu_cases.py index 47ded1b499fe..8abc50665ab6 100644 --- a/python/sglang/multimodal_gen/test/server/gpu_cases.py +++ b/python/sglang/multimodal_gen/test/server/gpu_cases.py @@ -4,6 +4,7 @@ MODELOPT_FLUX1_NVFP4_TRANSFORMER, MODELOPT_FLUX2_FP8_TRANSFORMER, MODELOPT_FLUX2_NVFP4_WEIGHTS, + MODELOPT_HUNYUANVIDEO_FP8_TRANSFORMER, MODELOPT_NVFP4_B200_ENV_VARS, MODELOPT_QWEN_IMAGE_EDIT_FP8_TRANSFORMER, MODELOPT_QWEN_IMAGE_FP8_TRANSFORMER, @@ -403,6 +404,18 @@ sampling_params=MODELOPT_T2V_CI_sampling_params, extras=["--transformer-path", MODELOPT_WAN22_FP8_TRANSFORMER], ), + _make_modelopt_ci_case( + "hunyuanvideo_modelopt_fp8_t2v", + model_path="hunyuanvideo-community/HunyuanVideo", + modality="video", + sampling_params=MODELOPT_T2V_CI_sampling_params, + extras=[ + "--transformer-path", + MODELOPT_HUNYUANVIDEO_FP8_TRANSFORMER, + "--text-encoder-cpu-offload", + "--pin-cpu-memory", + ], + ), _make_modelopt_ci_case( "qwen_image_modelopt_fp8_t2i", model_path=DEFAULT_QWEN_IMAGE_MODEL_NAME_FOR_TEST, diff --git a/python/sglang/multimodal_gen/test/server/testcase_configs.py b/python/sglang/multimodal_gen/test/server/testcase_configs.py index 3a03c8329b00..122f5f8cb2c8 100644 --- a/python/sglang/multimodal_gen/test/server/testcase_configs.py +++ b/python/sglang/multimodal_gen/test/server/testcase_configs.py @@ -442,6 +442,9 @@ def from_req_perf_record( MODELOPT_FLUX1_FP8_TRANSFORMER = "lmsys/flux1-dev-modelopt-fp8-sglang-transformer" MODELOPT_FLUX2_FP8_TRANSFORMER = "lmsys/flux2-dev-modelopt-fp8-sglang-transformer" MODELOPT_WAN22_FP8_TRANSFORMER = "lmsys/wan22-t2v-a14b-modelopt-fp8-sglang-transformer" +MODELOPT_HUNYUANVIDEO_FP8_TRANSFORMER = ( + "lmsys/hunyuanvideo-modelopt-fp8-sglang-transformer" +) MODELOPT_QWEN_IMAGE_FP8_TRANSFORMER = "lmsys/qwen-image-modelopt-fp8-sglang-transformer" MODELOPT_QWEN_IMAGE_EDIT_FP8_TRANSFORMER = ( "lmsys/qwen-image-edit-modelopt-fp8-sglang-transformer" diff --git a/python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py b/python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py index 923ee5f3275b..a579e2035da8 100644 --- a/python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py +++ b/python/sglang/multimodal_gen/tools/build_modelopt_fp8_transformer.py @@ -29,7 +29,7 @@ import shutil from collections import defaultdict from pathlib import Path -from typing import Iterable, Mapping, Sequence +from typing import Callable, Iterable, Mapping, Sequence import torch from safetensors import safe_open @@ -76,6 +76,128 @@ r"^transformer_blocks\.(0|43|44|45|46|47)\.(attn1|attn2|audio_attn1|audio_attn2|audio_to_video_attn|video_to_audio_attn)\.to_out\.0$", r"^transformer_blocks\.(0|43|44|45|46|47)\.(ff|audio_ff)\.proj_(in|out)$", ] +DEFAULT_HUNYUANVIDEO_KEEP_BF16_PATTERNS = [ + r"^context_embedder\.", + r"^x_embedder\.proj$", + r"^time_text_embed\.(timestep_embedder|guidance_embedder|text_embedder)\.linear_[12]$", + r"^norm_out\.linear$", + r"^proj_out$", + r"^transformer_blocks\.\d+\.norm1\.linear$", + r"^transformer_blocks\.\d+\.norm1_context\.linear$", + r"^single_transformer_blocks\.\d+\.norm\.linear$", +] +HUNYUANVIDEO_RUNTIME_NAME_REPLACEMENTS = [ + ( + r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_1$", + r"txt_in.t_embedder.mlp.fc_in", + ), + ( + r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_2$", + r"txt_in.t_embedder.mlp.fc_out", + ), + (r"^context_embedder\.proj_in$", r"txt_in.input_embedder"), + ( + r"^context_embedder\.time_text_embed\.text_embedder\.linear_1$", + r"txt_in.c_embedder.fc_in", + ), + ( + r"^context_embedder\.time_text_embed\.text_embedder\.linear_2$", + r"txt_in.c_embedder.fc_out", + ), + ( + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm1$", + r"txt_in.refiner_blocks.\1.norm1", + ), + ( + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm2$", + r"txt_in.refiner_blocks.\1.norm2", + ), + ( + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_[qkv]$", + r"txt_in.refiner_blocks.\1.self_attn_qkv", + ), + ( + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_out\.0$", + r"txt_in.refiner_blocks.\1.self_attn_proj", + ), + ( + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?$", + r"txt_in.refiner_blocks.\1.mlp.fc_in", + ), + ( + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?$", + r"txt_in.refiner_blocks.\1.mlp.fc_out", + ), + ( + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm_out\.linear$", + r"txt_in.refiner_blocks.\1.adaLN_modulation.linear", + ), + (r"^x_embedder\.proj$", r"img_in.proj"), + (r"^time_text_embed\.timestep_embedder\.linear_1$", r"time_in.mlp.fc_in"), + (r"^time_text_embed\.timestep_embedder\.linear_2$", r"time_in.mlp.fc_out"), + (r"^time_text_embed\.guidance_embedder\.linear_1$", r"guidance_in.mlp.fc_in"), + (r"^time_text_embed\.guidance_embedder\.linear_2$", r"guidance_in.mlp.fc_out"), + (r"^time_text_embed\.text_embedder\.linear_1$", r"vector_in.fc_in"), + (r"^time_text_embed\.text_embedder\.linear_2$", r"vector_in.fc_out"), + (r"^transformer_blocks\.(\d+)\.norm1\.linear$", r"double_blocks.\1.img_mod.linear"), + ( + r"^transformer_blocks\.(\d+)\.norm1_context\.linear$", + r"double_blocks.\1.txt_mod.linear", + ), + (r"^transformer_blocks\.(\d+)\.attn\.norm_q$", r"double_blocks.\1.img_attn_q_norm"), + (r"^transformer_blocks\.(\d+)\.attn\.norm_k$", r"double_blocks.\1.img_attn_k_norm"), + (r"^transformer_blocks\.(\d+)\.attn\.to_[qkv]$", r"double_blocks.\1.img_attn_qkv"), + ( + r"^transformer_blocks\.(\d+)\.attn\.add_[qkv]_proj$", + r"double_blocks.\1.txt_attn_qkv", + ), + ( + r"^transformer_blocks\.(\d+)\.attn\.to_out\.0$", + r"double_blocks.\1.img_attn_proj", + ), + ( + r"^transformer_blocks\.(\d+)\.attn\.to_add_out$", + r"double_blocks.\1.txt_attn_proj", + ), + ( + r"^transformer_blocks\.(\d+)\.attn\.norm_added_q$", + r"double_blocks.\1.txt_attn_q_norm", + ), + ( + r"^transformer_blocks\.(\d+)\.attn\.norm_added_k$", + r"double_blocks.\1.txt_attn_k_norm", + ), + ( + r"^transformer_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?$", + r"double_blocks.\1.img_mlp.fc_in", + ), + ( + r"^transformer_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?$", + r"double_blocks.\1.img_mlp.fc_out", + ), + ( + r"^transformer_blocks\.(\d+)\.ff_context\.net\.0(?:\.proj)?$", + r"double_blocks.\1.txt_mlp.fc_in", + ), + ( + r"^transformer_blocks\.(\d+)\.ff_context\.net\.2(?:\.proj)?$", + r"double_blocks.\1.txt_mlp.fc_out", + ), + (r"^single_transformer_blocks\.(\d+)\.attn\.norm_q$", r"single_blocks.\1.q_norm"), + (r"^single_transformer_blocks\.(\d+)\.attn\.norm_k$", r"single_blocks.\1.k_norm"), + ( + r"^single_transformer_blocks\.(\d+)\.attn\.to_[qkv]$", + r"single_blocks.\1.linear1", + ), + (r"^single_transformer_blocks\.(\d+)\.proj_mlp$", r"single_blocks.\1.linear1"), + (r"^single_transformer_blocks\.(\d+)\.proj_out$", r"single_blocks.\1.linear2"), + ( + r"^single_transformer_blocks\.(\d+)\.norm\.linear$", + r"single_blocks.\1.modulation.linear", + ), + (r"^norm_out\.linear$", r"final_layer.adaLN_modulation.linear"), + (r"^proj_out$", r"final_layer.linear"), +] DEFAULT_QWEN_IMAGE_KEEP_BF16_PATTERNS = [ r"^img_in$", r"^txt_in$", @@ -166,7 +288,27 @@ def _load_first_shard_metadata( return dict(f.metadata() or {}) -def _module_name_variants(weight_name: str) -> list[str]: +def _map_hunyuanvideo_runtime_module_name(module_name: str) -> list[str]: + mapped_names: list[str] = [] + for pattern, replacement in HUNYUANVIDEO_RUNTIME_NAME_REPLACEMENTS: + mapped = re.sub(pattern, replacement, module_name) + if mapped != module_name: + mapped_names.append(mapped) + return mapped_names + + +def _get_runtime_module_name_mapper( + *, model_type: str, class_name: str | None +) -> Callable[[str], list[str]] | None: + if model_type == "hunyuan-video" or class_name == "HunyuanVideoTransformer3DModel": + return _map_hunyuanvideo_runtime_module_name + return None + + +def _module_name_variants( + weight_name: str, + runtime_name_mapper: Callable[[str], list[str]] | None = None, +) -> list[str]: module_name = weight_name[:-7] if weight_name.endswith(".weight") else weight_name variants = [module_name] @@ -184,6 +326,11 @@ def _module_name_variants(weight_name: str) -> list[str]: ) canonicalized.append(re.sub(r"(\.(img_mod|txt_mod))\.1$", r"\1", variant)) variants.extend(canonicalized) + if runtime_name_mapper is not None: + runtime_variants: list[str] = [] + for variant in variants: + runtime_variants.extend(runtime_name_mapper(variant)) + variants.extend(runtime_variants) deduped: list[str] = [] for variant in variants: @@ -192,8 +339,11 @@ def _module_name_variants(weight_name: str) -> list[str]: return deduped -def _preferred_module_name(weight_name: str) -> str: - return _module_name_variants(weight_name)[-1] +def _preferred_module_name( + weight_name: str, + runtime_name_mapper: Callable[[str], list[str]] | None = None, +) -> str: + return _module_name_variants(weight_name, runtime_name_mapper)[-1] def _scale_key_candidates(weight_name: str) -> list[str]: @@ -269,6 +419,8 @@ def get_default_keep_bf16_patterns( return list(DEFAULT_FLUX1_KEEP_BF16_PATTERNS) if model_type == "flux2": return list(DEFAULT_FLUX2_KEEP_BF16_PATTERNS) + if model_type == "hunyuan-video": + return list(DEFAULT_HUNYUANVIDEO_KEEP_BF16_PATTERNS) if model_type == "qwen-image": return list(DEFAULT_QWEN_IMAGE_KEEP_BF16_PATTERNS) if model_type == "none": @@ -277,6 +429,8 @@ def get_default_keep_bf16_patterns( return list(DEFAULT_FLUX1_KEEP_BF16_PATTERNS) if class_name == "Flux2Transformer2DModel": return list(DEFAULT_FLUX2_KEEP_BF16_PATTERNS) + if class_name == "HunyuanVideoTransformer3DModel": + return list(DEFAULT_HUNYUANVIDEO_KEEP_BF16_PATTERNS) if class_name == "QwenImageTransformer2DModel": return list(DEFAULT_QWEN_IMAGE_KEEP_BF16_PATTERNS) return [] @@ -285,6 +439,7 @@ def get_default_keep_bf16_patterns( def should_keep_bf16( weight_name: str, keep_bf16_patterns: Sequence[str], + runtime_name_mapper: Callable[[str], list[str]] | None = None, ) -> bool: if not keep_bf16_patterns: return False @@ -292,13 +447,14 @@ def should_keep_bf16( return any( re.search(pattern, module_name) for pattern in keep_bf16_patterns - for module_name in _module_name_variants(weight_name) + for module_name in _module_name_variants(weight_name, runtime_name_mapper) ) def is_ignored_by_modelopt( weight_name: str, ignore_patterns: Sequence[str], + runtime_name_mapper: Callable[[str], list[str]] | None = None, ) -> bool: if not ignore_patterns: return False @@ -307,7 +463,7 @@ def is_ignored_by_modelopt( regex_str = pattern.replace(".", r"\.").replace("*", r".*") if any( re.fullmatch(regex_str, module_name) - for module_name in _module_name_variants(weight_name) + for module_name in _module_name_variants(weight_name, runtime_name_mapper) ): return True return False @@ -424,6 +580,9 @@ def build_modelopt_fp8_transformer( source_weight_map=source_weight_map_all, ) class_name = config.get("_class_name") + runtime_name_mapper = _get_runtime_module_name_mapper( + model_type=model_type, class_name=class_name + ) ignore_patterns = list(quant_config.get("ignore", []) or []) patterns = list( get_default_keep_bf16_patterns(model_type=model_type, class_name=class_name) @@ -463,7 +622,8 @@ def build_modelopt_fp8_transformer( fallback_weight_names = sorted( weight_name for weight_name in source_weight_map - if weight_name.endswith(".weight") and should_keep_bf16(weight_name, patterns) + if weight_name.endswith(".weight") + and should_keep_bf16(weight_name, patterns, runtime_name_mapper) ) fallback_weight_names_set = set(fallback_weight_names) @@ -492,14 +652,17 @@ def build_modelopt_fp8_transformer( auto_ignore_modules = sorted( { - _preferred_module_name(weight_name) + _preferred_module_name(weight_name, runtime_name_mapper) for weight_name in source_weight_map if weight_name.endswith(".weight") and _resolve_scale_key(weight_name, fp8_scale_map) is None } ) fallback_ignore_modules = sorted( - {_preferred_module_name(weight_name) for weight_name in fallback_weight_names} + { + _preferred_module_name(weight_name, runtime_name_mapper) + for weight_name in fallback_weight_names + } ) ignore_patterns = sorted( { @@ -570,7 +733,7 @@ def build_modelopt_fp8_transformer( shard_tensors[name] = fallback_tensors[name] continue if name.endswith(".weight") and is_ignored_by_modelopt( - name, ignore_patterns + name, ignore_patterns, runtime_name_mapper ): preserved_ignored_weight_count += 1 continue @@ -620,7 +783,7 @@ def build_modelopt_fp8_transformer( for name in source_weight_map if name.endswith(".weight") and _resolve_scale_key(name, fp8_scale_map) is not None - and not is_ignored_by_modelopt(name, ignore_patterns) + and not is_ignored_by_modelopt(name, ignore_patterns, runtime_name_mapper) ), "bf16_fallback_weights": len(fallback_weight_names), "preserved_ignored_weights": preserved_ignored_weight_count, @@ -660,13 +823,21 @@ def _parse_args() -> argparse.Namespace: ) parser.add_argument( "--model-type", - choices=["auto", "flux1", "flux2", "ltx2", "qwen-image", "none"], + choices=[ + "auto", + "flux1", + "flux2", + "ltx2", + "hunyuan-video", + "qwen-image", + "none", + ], default="auto", help=( "Optional model-family BF16 fallback profile. 'none' uses the generic " "conversion path. 'auto' enables the validated FLUX.1 / FLUX.2 / LTX-2 / " - "Qwen Image fallback set when the export config matches those transformer " - "classes." + "HunyuanVideo / Qwen Image fallback sets when the export config matches " + "those transformer classes." ), ) parser.add_argument(