diff --git a/docs/diffusion/quantization.md b/docs/diffusion/quantization.md index 970b1ee5d743..4d876839ae35 100644 --- a/docs/diffusion/quantization.md +++ b/docs/diffusion/quantization.md @@ -1,6 +1,6 @@ # Quantization -SGLang-Diffusion supports quantized transformer checkpoints. In most cases, keep +SGLang-Diffusion supports quantized transformer checkpoints and online quantization. In most cases, keep the base model and the quantized transformer override separate. ## Quick Reference @@ -47,6 +47,7 @@ backend. | `modelopt-nvfp4` | Mixed transformer directory/repo with `config.json`, or raw NVFP4 safetensors export/repo | `--transformer-path` for mixed overrides; `--transformer-weights-path` for raw exports | FLUX.1, FLUX.2, Wan2.2 | None | Mixed override repos keep the base model separate; raw exports such as `black-forest-labs/FLUX.2-dev-NVFP4` still use the weights-path flow | | `nunchaku-svdq` | Pre-quantized Nunchaku transformer weights, usually named `svdq-{int4\|fp4}_r{rank}-...` | `--transformer-weights-path` | Model-specific support such as Qwen-Image, FLUX, and Z-Image | `nunchaku` | SGLang can infer precision and rank from the filename and supports both `int4` and `nvfp4` | | `msmodelslim` | Pre-quantized msmodelslim transformer weights | `--model-path` | Wan2.2 family | None | Currently only compatible with the Ascend NPU family and supports both `w8a8` and `w4a4` | +| `fp8 (online)` | Any unquantized BF16 checkpoint | `--quantization fp8` | ALL | None | Online FP8 weight quantization at load time with dynamic activations | ## Validated ModelOpt Checkpoints @@ -282,6 +283,37 @@ sglang generate \ - Current runtime validation only allows Nunchaku on NVIDIA CUDA Ampere (SM8x) or SM12x GPUs. Hopper (SM90) is currently rejected. +## Online Quantization + +SGLang-Diffusion supports online quantization of model weights at load +time via the `--quantization` flag. This lets you run quantized inference +from any BF16 checkpoint without a pre-converted or calibrated checkpoint. + +### Supported Methods + +| value | precision | +|------------------------|-----------| +| `fp8` | FP8 (E4M3) | + +### Usage Examples + +Apply online FP8 quantization to a FLUX model: + +```bash +sglang generate \ + --model-path black-forest-labs/FLUX.1-dev \ + --quantization fp8 \ + --prompt "A small cat" \ + --save-output +``` + +### Notes + +- `--quantization` is independent of the checkpoint-based quantization. + It does not require `--transformer-path` or `--transformer-weights-path`. +- When `--quantization` is set, it takes precedence over any quantization + config found in the checkpoint metadata. + ## [ModelSlim](https://gitcode.com/Ascend/msmodelslim) MindStudio-ModelSlim (msModelSlim) is a model offline quantization compression tool launched by MindStudio and optimized for Ascend hardware. diff --git a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py index 2f2057b8dd65..328476ebbac6 100644 --- a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py @@ -485,8 +485,19 @@ def _resolve_quant_config( arch_config, "reverse_param_names_mapping", None ) - quant_config = get_quant_config(hf_config, component_model_path) + quant_config = get_quant_config( + hf_config, component_model_path, server_args=server_args + ) + if quant_config is None and server_args.transformer_weights_path: + quant_config = _resolve_quant_config_from_transformer_override( + server_args.transformer_weights_path + ) + quant_config_name = _get_quant_config_name(quant_config) + + if quant_config is not None and quant_config_name != "modelopt_fp4": + return quant_config + inferred_nvfp4_config = None if quant_config is None or quant_config_name == "modelopt_fp4": fallback_group_size = None @@ -498,13 +509,7 @@ def _resolve_quant_config( reverse_param_names_mapping_dict, fallback_group_size, ) - quant_config = _merge_modelopt_fp4_configs(quant_config, inferred_nvfp4_config) - if quant_config is not None or not server_args.transformer_weights_path: - return quant_config - quant_config = _resolve_quant_config_from_transformer_override( - server_args.transformer_weights_path - ) quant_config = _merge_modelopt_fp4_configs(quant_config, inferred_nvfp4_config) if quant_config is not None: return quant_config diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index cdc82c78f90d..4200a563781f 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -180,6 +180,8 @@ class ServerArgs(DisaggArgsMixin): # path to pre-quantized transformer weights (single .safetensors or directory). transformer_weights_path: str | None = None + # Online quantization method to apply + quantization: str | None = None # can restrict layers to adapt, e.g. ["q_proj"] # Will adapt only q, k, v, o by default. lora_target_modules: list[str] | None = None @@ -1004,6 +1006,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Disable autocast for denoising loop and vae decoding in pipeline sampling", ) + parser.add_argument( + "--quantization", + type=str, + default=ServerArgs.quantization, + choices=["fp8"], + help="Apply online quantization to model weights. " + "Quantizes weights on-the-fly at load time, no pre-converted checkpoint needed.", + ) + # Nunchaku SVDQuant quantization parameters NunchakuSVDQuantArgs.add_cli_args(parser) diff --git a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py index 2ddf2968739a..6715a93edcce 100644 --- a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py @@ -11,6 +11,8 @@ QuantizationConfig, get_quantization_config, ) +from sglang.multimodal_gen.runtime.layers.quantization.fp8 import Fp8Config +from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) @@ -109,6 +111,16 @@ def find_quant_modelslim_config(model_config, component_model_path): return quant_cfg +def resolve_online_quant_config(quantization_method: str) -> QuantizationConfig: + if quantization_method == "fp8": + logger.info("Online FP8 quantization enabled.") + return Fp8Config( + is_checkpoint_fp8_serialized=False, + activation_scheme="dynamic", + ) + return None + + def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str: for prefix, new_prefix in prefix_mapping.items(): if key.startswith(prefix): @@ -121,7 +133,12 @@ def get_quant_config( component_model_path: str, packed_modules_mapping: Dict[str, List[str]] = {}, remap_prefix: Dict[str, str] | None = None, + server_args: ServerArgs | None = None, ) -> QuantizationConfig: + + if server_args and server_args.quantization: + return resolve_online_quant_config(server_args.quantization) + quant_cfg = find_quant_modelslim_config(model_config, component_model_path) if quant_cfg is not None: quant_cls = _load_quant_cls(quant_cfg)