1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -47,6 +47,7 @@ Model Optimizer Changelog (Linux)
- Enabled native Modelopt quantization support for FP8 and NVFP4 formats in SGLang. See `SGLang quantization documentation <https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/quantization.md#using-nvidia-modelopt>`_ for more details.
- Added modelopt quantized checkpoints in vLLM/SGLang CI/CD pipelines (PRs are under review).
- Add support for exporting QLoRA checkpoints fine-tuned using ModelOpt.
- Update NVFP4 AWQ checkpoint export. It now fuses the pre-quantization scaling factors of o_proj and down_proj layers into the preceding linear weights when possible to facilitate deployment.

**Documentation**

150 changes: 131 additions & 19 deletions modelopt/torch/export/quant_utils.py
@@ -489,7 +489,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames

if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
return QUANTIZATION_NVFP4_AWQ
if getattr(layer, "fused_with_layernorm", False):
if getattr(layer, "fused_with_prequant", False):
return QUANTIZATION_NVFP4_AWQ
assert input_quantizer is not None, (
f"input_quantizer is None for {quantizer_attr_names}"
@@ -959,18 +959,145 @@ def all_items_same(item_list):
return all(x == item_list[0] for x in item_list)


def _update_pre_quant_scale(module, new_pre_quant_scale):
old_pre_quant_scale = module.input_quantizer._pre_quant_scale
# do the processing in fp32 for numerical stability
dtype = module.weight.dtype
module.weight = nn.Parameter(
(
module.weight.to(torch.float32)
* old_pre_quant_scale.to(dtype=torch.float32, device=module.weight.device)
/ new_pre_quant_scale.to(dtype=torch.float32, device=module.weight.device)
).to(dtype)
)
module.input_quantizer.pre_quant_scale = new_pre_quant_scale

# Redo weights collection
module.weight_quantizer.reset_amax()
enable_stats_collection(module.weight_quantizer)
module.weight_quantizer(module.weight)
finish_stats_collection(module.weight_quantizer)
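
Below is a minimal, standalone sketch (toy shapes, plain tensors rather than ModelOpt modules) of the identity this helper relies on, and which `preprocess_linear_fusion` now reuses when resmoothing fused linears to their averaged scale: rescaling the weight by `old_scale / new_scale` along the input channels exactly compensates for switching the input smoothing scale.

```python
# Standalone numerical check of the rescaling identity; names and shapes are illustrative.
import torch

torch.manual_seed(0)
in_features, out_features = 8, 4
x = torch.randn(2, in_features)
weight = torch.randn(out_features, in_features)
s_old = torch.rand(in_features) + 0.5   # old pre_quant_scale
s_new = torch.rand(in_features) + 0.5   # new (e.g. averaged) pre_quant_scale

# Reference: input smoothed with the old scale, original weight
ref = (x * s_old) @ weight.T

# Rescaled weight compensates for smoothing the input with s_new instead
weight_new = weight * (s_old / s_new)   # broadcasts over the input-channel dim
out = (x * s_new) @ weight_new.T

assert torch.allclose(ref, out, atol=1e-5)
```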


# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
PQS_FUSE_MODULE_MAPPING = [
# Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
# Mathematical equivalence:
# Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
# After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
(["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
# MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
# Mathematical equivalence:
# Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
# After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
]
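
Both entries of this mapping rest on a per-channel commutation: scaling the output channels of the earlier linear is equivalent to scaling the input of the later one, whether the scale passes through the attention matmul or through the element-wise MLP gate. A toy check of both equivalences with plain tensors (no quantizers; SiLU as the MLP activation is an assumption, the real activation comes from the model config):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter, seq = 8, 16, 4
x = torch.randn(seq, hidden)

# --- Attention pattern: o_proj's input scale folded into v_proj's output channels ---
w_v = torch.randn(hidden, hidden)        # v_proj.weight, laid out [out, in]
w_o = torch.randn(hidden, hidden)        # o_proj.weight
attn = torch.softmax(torch.randn(seq, seq), dim=-1)
scale_o = torch.rand(hidden) + 0.5       # o_proj's pre_quant_scale

before = ((attn @ (x @ w_v.T)) * scale_o) @ w_o.T
after = (attn @ (x @ (w_v * scale_o.view(-1, 1)).T)) @ w_o.T
assert torch.allclose(before, after, atol=1e-4)

# --- MLP pattern: down_proj's input scale folded into up_proj's output channels ---
w_gate = torch.randn(inter, hidden)
w_up = torch.randn(inter, hidden)
w_down = torch.randn(hidden, inter)
scale_d = torch.rand(inter) + 0.5        # down_proj's pre_quant_scale

before = ((F.silu(x @ w_gate.T) * (x @ w_up.T)) * scale_d) @ w_down.T
after = (F.silu(x @ w_gate.T) * (x @ (w_up * scale_d.view(-1, 1)).T)) @ w_down.T
assert torch.allclose(before, after, atol=1e-4)
```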


def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
"""Fuse pre_quant_scale to the linear weights if possible.

Args:
model: The model to fuse pre_quant_scale to.
fuse_grouped_heads: If True, fuse the pre_quant_scale even when its dimension does not match the
output dimension of the linear it is fused into (GQA/MQA attention); the scale is first averaged over shared head groups.

Returns:
fused_modules: A list of modules whose pre_quant_scale has been fused into the preceding linear layer.
"""
# Fuse pre_quant_scale to the linear weights
for _, module in model.named_modules():
for module_map in PQS_FUSE_MODULE_MAPPING:
target_module_list = module_map[0]
linear_pair = module_map[1]
if any(module_name in type(module).__name__ for module_name in target_module_list):
linear_fuse_into = module.get_submodule(linear_pair[0])
linear_pqs_from = module.get_submodule(linear_pair[1])
if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
linear_pqs_from.input_quantizer, "_pre_quant_scale"
):
pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale

# for GQA/MQA models, we can apply averaging to the pre_quant_scale for shared head groups
if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
if (
not fuse_grouped_heads
or "attention" not in type(module).__name__.lower()
):
warn(
f"Skipping pre_quant_scale fusion for {type(module).__name__}: "
f"pre_quant_scale dim {pre_quant_scale.numel()} != "
f"out_channel dim {linear_fuse_into.weight.shape[-2]}"
)
continue
config = module.config
num_kv_heads = config.num_key_value_heads
kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim

# Reshape to (num_kv_heads, n_rep, kv_head_dim),
# where n_rep is the number of query heads sharing each KV head
averaged_scale = pre_quant_scale.view(
num_kv_heads, n_rep, kv_head_dim
).mean(dim=1)

# To update o_proj, we need to repeat back to original shape
repeated_scale = (
averaged_scale.unsqueeze(1)
.expand(num_kv_heads, n_rep, kv_head_dim)
.reshape(-1)
)
# Update o_proj's pre_quant_scale
Collaborator: So this update is in regard to updating o_proj's PQS, so we can just take the first head and apply it to v, right?

Contributor Author: Yes, this updates o_proj's pre_quant_scale so that the input channels of o_proj associated with the same query group (output channel) of v_proj share the same scale.

_update_pre_quant_scale(linear_pqs_from, repeated_scale)

# Use averaged scale (flattened) for v_proj fusion
pre_quant_scale = averaged_scale.reshape(-1)

# Fuse the pre_quant_scale to weight
linear_fuse_into.weight = torch.nn.Parameter(
linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
)
if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
linear_fuse_into.bias = torch.nn.Parameter(
linear_fuse_into.bias * pre_quant_scale
)

# Recalibrate the weight quantizer for linear_fuse_into
linear_fuse_into.weight_quantizer.reset_amax()
enable_stats_collection(linear_fuse_into.weight_quantizer)
linear_fuse_into.weight_quantizer(linear_fuse_into.weight)
finish_stats_collection(linear_fuse_into.weight_quantizer)

delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
setattr(linear_pqs_from, "fused_with_prequant", True)
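
As discussed in the review thread above, the GQA/MQA path first averages o_proj's pre_quant_scale over the query heads that share each KV head, rewrites o_proj's scale so every channel in a query group carries the same value, and only then fuses the smaller averaged scale into v_proj's output channels. A toy sketch of just that reshaping arithmetic, with illustrative head counts:

```python
import torch

torch.manual_seed(0)
num_kv_heads, n_rep, head_dim = 2, 4, 8          # 8 query heads sharing 2 KV heads
o_proj_in = num_kv_heads * n_rep * head_dim      # o_proj input dim (64); v_proj out dim is 16

pre_quant_scale = torch.rand(o_proj_in) + 0.5    # o_proj's per-input-channel scale

# One averaged scale per (kv_head, head_dim) position, shared by its n_rep query heads
averaged = pre_quant_scale.view(num_kv_heads, n_rep, head_dim).mean(dim=1)          # (2, 8)

# o_proj keeps a full-size scale, identical within each query group ...
repeated = averaged.unsqueeze(1).expand(num_kv_heads, n_rep, head_dim).reshape(-1)  # (64,)
# ... while v_proj absorbs the flattened averaged scale on its output channels
v_scale = averaged.reshape(-1)                                                      # (16,)

# Every query head in a group now carries exactly the scale that v_proj absorbed
assert torch.equal(repeated.view(num_kv_heads, n_rep, head_dim)[:, 0, :].reshape(-1), v_scale)
```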


def fuse_prequant_layernorm(
layernorm_module: torch.nn.Module,
modules: list[torch.Tensor],
):
"""Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
"""Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.

original:
layernorm_output = (normalization(input) * weight) + bias
layernorm_output_scaled = layernorm_output * pre_quant_scale

fused:
fused_weight = weight * avg_pre_quant_scale
fused_bias = bias * avg_pre_quant_scale
layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
"""
layernorm_module.weight = torch.nn.Parameter(
layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
)
if hasattr(layernorm_module, "bias") and layernorm_module.bias is not None:
layernorm_module.bias = torch.nn.Parameter(
layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
)
# Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
for module in modules:
delattr(module.input_quantizer, "_pre_quant_scale")
setattr(module, "fused_with_layernorm", True)
setattr(module, "fused_with_prequant", True)


def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
@@ -992,22 +1119,7 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False

for module in modules:
if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
module.weight = nn.Parameter(
module.weight
* module.input_quantizer.pre_quant_scale.to(
dtype=module.weight.dtype, device=module.weight.device
)
/ avg_prequant_scale.to(
dtype=module.weight.dtype, device=module.weight.device
)
)
module.input_quantizer.pre_quant_scale = avg_prequant_scale

# Redo weights collection
module.weight_quantizer.reset_amax()
enable_stats_collection(module.weight_quantizer)
module.weight_quantizer(module.weight)
finish_stats_collection(module.weight_quantizer)
_update_pre_quant_scale(module, avg_prequant_scale)

if resmooth_only:
return
5 changes: 5 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -60,6 +60,7 @@
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
fuse_prequant_layernorm,
fuse_prequant_to_linear,
Contributor: Can fuse_prequant_to_linear and fuse_prequant_layernorm be combined, or are they mutually exclusive?

Contributor Author: They are quite different. fuse_prequant_to_linear is rule-based fusion and doesn't need graph tracing.

get_activation_scaling_factor,
get_quant_config,
get_quantization_format,
@@ -107,6 +108,10 @@ def _output_hook(module, input, output):
fused_linears = {}
module_names = set()

# Fuse pre_quant_scale to the linear weights if possible
if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
fuse_prequant_to_linear(model)

for name, module in model.named_modules():
module_names.add(name)
