Skip to content
Closed
34 changes: 33 additions & 1 deletion docs/diffusion/quantization.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Quantization

SGLang-Diffusion supports quantized transformer checkpoints. In most cases, keep
SGLang-Diffusion supports quantized transformer checkpoints and online quantization. In most cases, keep
the base model and the quantized transformer override separate.

## Quick Reference
Expand Down Expand Up @@ -47,6 +47,7 @@ backend.
| `modelopt-nvfp4` | Mixed transformer directory/repo with `config.json`, or raw NVFP4 safetensors export/repo | `--transformer-path` for mixed overrides; `--transformer-weights-path` for raw exports | FLUX.1, FLUX.2, Wan2.2 | None | Mixed override repos keep the base model separate; raw exports such as `black-forest-labs/FLUX.2-dev-NVFP4` still use the weights-path flow |
| `nunchaku-svdq` | Pre-quantized Nunchaku transformer weights, usually named `svdq-{int4\|fp4}_r{rank}-...` | `--transformer-weights-path` | Model-specific support such as Qwen-Image, FLUX, and Z-Image | `nunchaku` | SGLang can infer precision and rank from the filename and supports both `int4` and `nvfp4` |
| `msmodelslim` | Pre-quantized msmodelslim transformer weights | `--model-path` | Wan2.2 family | None | Currently only compatible with the Ascend NPU family and supports both `w8a8` and `w4a4` |
| `fp8 (online)` | Any unquantized BF16 checkpoint | `--quantization fp8` | ALL | None | Online FP8 weight quantization at load time with dynamic activations |

## Validated ModelOpt Checkpoints

Expand Down Expand Up @@ -282,6 +283,37 @@ sglang generate \
- Current runtime validation only allows Nunchaku on NVIDIA CUDA Ampere (SM8x)
or SM12x GPUs. Hopper (SM90) is currently rejected.

## Online Quantization

SGLang-Diffusion supports online quantization of model weights at load
time via the `--quantization` flag. This lets you run quantized inference
from any BF16 checkpoint without a pre-converted or calibrated checkpoint.

### Supported Methods

| value | precision |
|------------------------|-----------|
| `fp8` | FP8 (E4M3) |

### Usage Examples

Apply online FP8 quantization to a FLUX model:

```bash
sglang generate \
--model-path black-forest-labs/FLUX.1-dev \
--quantization fp8 \
--prompt "A small cat" \
--save-output
```

### Notes

- `--quantization` is independent of the checkpoint-based quantization.
It does not require `--transformer-path` or `--transformer-weights-path`.
- When `--quantization` is set, it takes precedence over any quantization
config found in the checkpoint metadata.

## [ModelSlim](https://gitcode.com/Ascend/msmodelslim)
MindStudio-ModelSlim (msModelSlim) is a model offline quantization compression tool launched by MindStudio and optimized for Ascend hardware.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,19 @@ def _resolve_quant_config(
arch_config, "reverse_param_names_mapping", None
)

quant_config = get_quant_config(hf_config, component_model_path)
quant_config = get_quant_config(
hf_config, component_model_path, server_args=server_args
)
if quant_config is None and server_args.transformer_weights_path:
quant_config = _resolve_quant_config_from_transformer_override(
server_args.transformer_weights_path
)

quant_config_name = _get_quant_config_name(quant_config)

if quant_config is not None and quant_config_name != "modelopt_fp4":
return quant_config

inferred_nvfp4_config = None
if quant_config is None or quant_config_name == "modelopt_fp4":
Comment thread
avjves marked this conversation as resolved.
fallback_group_size = None
Expand All @@ -498,13 +509,7 @@ def _resolve_quant_config(
reverse_param_names_mapping_dict,
fallback_group_size,
)
quant_config = _merge_modelopt_fp4_configs(quant_config, inferred_nvfp4_config)
if quant_config is not None or not server_args.transformer_weights_path:
return quant_config

quant_config = _resolve_quant_config_from_transformer_override(
server_args.transformer_weights_path
)
quant_config = _merge_modelopt_fp4_configs(quant_config, inferred_nvfp4_config)
if quant_config is not None:
return quant_config
Expand Down
11 changes: 11 additions & 0 deletions python/sglang/multimodal_gen/runtime/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ class ServerArgs(DisaggArgsMixin):

# path to pre-quantized transformer weights (single .safetensors or directory).
transformer_weights_path: str | None = None
# Online quantization method to apply
quantization: str | None = None
# can restrict layers to adapt, e.g. ["q_proj"]
# Will adapt only q, k, v, o by default.
lora_target_modules: list[str] | None = None
Expand Down Expand Up @@ -1004,6 +1006,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="Disable autocast for denoising loop and vae decoding in pipeline sampling",
)

parser.add_argument(
"--quantization",
type=str,
default=ServerArgs.quantization,
choices=["fp8"],
help="Apply online quantization to model weights. "
"Quantizes weights on-the-fly at load time, no pre-converted checkpoint needed.",
)

# Nunchaku SVDQuant quantization parameters
NunchakuSVDQuantArgs.add_cli_args(parser)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
QuantizationConfig,
get_quantization_config,
)
from sglang.multimodal_gen.runtime.layers.quantization.fp8 import Fp8Config
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)
Expand Down Expand Up @@ -109,6 +111,16 @@ def find_quant_modelslim_config(model_config, component_model_path):
return quant_cfg


def resolve_online_quant_config(quantization_method: str) -> QuantizationConfig:
if quantization_method == "fp8":
logger.info("Online FP8 quantization enabled.")
return Fp8Config(
is_checkpoint_fp8_serialized=False,
activation_scheme="dynamic",
)
return None


def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
for prefix, new_prefix in prefix_mapping.items():
if key.startswith(prefix):
Expand All @@ -121,7 +133,12 @@ def get_quant_config(
component_model_path: str,
packed_modules_mapping: Dict[str, List[str]] = {},
remap_prefix: Dict[str, str] | None = None,
server_args: ServerArgs | None = None,
) -> QuantizationConfig:

if server_args and server_args.quantization:
return resolve_online_quant_config(server_args.quantization)

quant_cfg = find_quant_modelslim_config(model_config, component_model_path)
if quant_cfg is not None:
quant_cls = _load_quant_cls(quant_cfg)
Expand Down
Loading