[OMNIML-2932] Fusing pre_quant_scale for NVFP4 AWQ #421
Changes from all commits
b5153fa
711796d
98b3e5f
d4c73ad
63b8b37
8713e3b
7864f56
1b40581
6a704d4
2eeb5bb
0b14f0b
f74c041
d646fa9
756a4ed
0781465
fbd3ab3
dd7f8af
e49137e
506ba83
0573c7b
6854b80
25ff362
f55baad
@@ -489,7 +489,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
         return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
@@ -959,18 +959,145 @@ def all_items_same(item_list):
     return all(x == item_list[0] for x in item_list)


+def _update_pre_quant_scale(module, new_pre_quant_scale):
+    old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+    # do the processing in fp32 for numerical stability
+    dtype = module.weight.dtype
+    module.weight = nn.Parameter(
+        (
+            module.weight.to(torch.float32)
+            * old_pre_quant_scale.to(dtype=torch.float32, device=module.weight.device)
+            / new_pre_quant_scale.to(dtype=torch.float32, device=module.weight.device)
+        ).to(dtype)
+    )
+    module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+
+    # Redo weights collection
+    module.weight_quantizer.reset_amax()
+    enable_stats_collection(module.weight_quantizer)
+    module.weight_quantizer(module.weight)
+    finish_stats_collection(module.weight_quantizer)
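To see why `_update_pre_quant_scale` multiplies the weight by `old_pre_quant_scale / new_pre_quant_scale`: the smoothed linear computes `(x * s) @ W.T`, so any change of `s` can be absorbed into `W`. A minimal standalone sketch with plain PyTorch and toy shapes (no modelopt quantizer objects involved):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)            # activations
W = torch.randn(16, 8)           # linear weight: (out_features, in_features)
s_old = torch.rand(8) + 0.5      # old per-input-channel pre_quant_scale
s_new = torch.rand(8) + 0.5      # new pre_quant_scale

# Rescale the weight by old / new, as _update_pre_quant_scale does.
W_rescaled = W * (s_old / s_new)

# The smoothed linear output is unchanged.
out_old = (x * s_old) @ W.t()
out_new = (x * s_new) @ W_rescaled.t()
assert torch.allclose(out_old, out_new, atol=1e-5)
```

The fp32 upcast in the function above guards the division when the ratio of the two scales is very large or very small.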
+# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
+PQS_FUSE_MODULE_MAPPING = [
+    # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
+    # Mathematical equivalence:
+    # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
+    # After:  o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
+    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
+    # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
+    # Mathematical equivalence:
+    # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
+    # After:  down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
+    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
+]
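The MLP equivalence stated in the comments above can be checked numerically. A hedged sketch with toy tensors; the SiLU gate, the shapes, and the variable names are illustrative assumptions, not code from this PR:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 8)
gate_w = torch.randn(32, 8)      # gate_proj weight
up_w = torch.randn(32, 8)        # up_proj weight
down_w = torch.randn(8, 32)      # down_proj weight
scale = torch.rand(32) + 0.5     # down_proj's pre_quant_scale (per input channel)

gate = F.silu(x @ gate_w.t())
up = x @ up_w.t()

# Before: the scale is applied to down_proj's input.
before = ((gate * up) * scale) @ down_w.t()
# After: the same scale is folded into up_proj's output channels.
after = (gate * (x @ (up_w * scale.view(-1, 1)).t())) @ down_w.t()
assert torch.allclose(before, after, atol=1e-4)
```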
+def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
+    """Fuse pre_quant_scale to the linear weights if possible.
+
+    Args:
+        model: The model to fuse pre_quant_scale to.
+        fuse_grouped_heads: If True, fuse the pre_quant_scale even if its dimension does not match
+            the linear weight's output dimension (e.g. GQA/MQA attention), by averaging over shared head groups.
+
+    Returns:
+        fused_modules: A list of modules whose pre_quant_scale has been fused into the preceding linear layer.
+    """
+    # Fuse pre_quant_scale to the linear weights
+    for _, module in model.named_modules():
+        for module_map in PQS_FUSE_MODULE_MAPPING:
+            target_module_list = module_map[0]
+            linear_pair = module_map[1]
+            if any(module_name in type(module).__name__ for module_name in target_module_list):
+                linear_fuse_into = module.get_submodule(linear_pair[0])
+                linear_pqs_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
+                    linear_pqs_from.input_quantizer, "_pre_quant_scale"
+                ):
+                    pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
+
+                    # for GQA/MQA models, we can apply averaging to the pre_quant_scale for shared head groups
+                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
+                        if (
+                            not fuse_grouped_heads
+                            or "attention" not in type(module).__name__.lower()
+                        ):
+                            warn(
+                                f"Skipping pattern fuse prequant for {type(module).__name__}: "
+                                f"pre_quant_scale dim {pre_quant_scale.numel()} != "
+                                f"out_channel dim {linear_fuse_into.weight.shape[-2]}"
+                            )
+                            continue
+                        config = module.config
+                        num_kv_heads = config.num_key_value_heads
+                        kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
+                        n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                        # Reshape: (num_kv_heads, n_rep, kv_head_dim)
+                        # n_rep is the number of query heads per KV group
+                        averaged_scale = pre_quant_scale.view(
+                            num_kv_heads, n_rep, kv_head_dim
+                        ).mean(dim=1)
+
+                        # To update o_proj, we need to repeat back to the original shape
+                        repeated_scale = (
+                            averaged_scale.unsqueeze(1)
+                            .expand(num_kv_heads, n_rep, kv_head_dim)
+                            .reshape(-1)
+                        )
+                        # Update o_proj's pre_quant_scale
Collaborator: So this update is about o_proj's PQS, so we can just take the first head and apply it to v, right?

Author: Yes, this updates o_proj's PQS, so the input channels of o_proj associated with the same query group (output channel) of v have the same pre-quant scale.
+                        _update_pre_quant_scale(linear_pqs_from, repeated_scale)
+
+                        # Use averaged scale (flattened) for v_proj fusion
+                        pre_quant_scale = averaged_scale.reshape(-1)
+
+                    # Fuse the pre_quant_scale to weight
+                    linear_fuse_into.weight = torch.nn.Parameter(
+                        linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
+                    )
+                    if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
+                        linear_fuse_into.bias = torch.nn.Parameter(
+                            linear_fuse_into.bias * pre_quant_scale
+                        )
+
+                    # Recalibrate the weight quantizer for linear_fuse_into
+                    linear_fuse_into.weight_quantizer.reset_amax()
+                    enable_stats_collection(linear_fuse_into.weight_quantizer)
+                    linear_fuse_into.weight_quantizer(linear_fuse_into.weight)
+                    finish_stats_collection(linear_fuse_into.weight_quantizer)
+
+                    delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_pqs_from, "fused_with_prequant", True)
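For GQA/MQA attention, `o_proj`'s pre_quant_scale has one entry per query-head channel, while `v_proj` has only `num_kv_heads * head_dim` output channels, so the code above averages the scales within each query group and then repeats them back for `o_proj`. A standalone sketch of just that reshaping step, with made-up head counts:

```python
import torch

num_kv_heads, n_rep, kv_head_dim = 2, 4, 8      # hypothetical: 8 query heads sharing 2 KV heads
pre_quant_scale = torch.rand(num_kv_heads * n_rep * kv_head_dim)

# One averaged scale per (kv_head, channel), shared by the n_rep query heads of the group.
averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)

# Repeat back so all of o_proj's input channels in the same group carry the same scale.
repeated_scale = (
    averaged_scale.unsqueeze(1).expand(num_kv_heads, n_rep, kv_head_dim).reshape(-1)
)

assert averaged_scale.numel() == num_kv_heads * kv_head_dim   # matches v_proj's output channels
assert repeated_scale.shape == pre_quant_scale.shape          # matches o_proj's input channels
```

Averaging changes the per-channel scales that AWQ calibrated, which is why the weight quantizers of both linears are recalibrated afterwards.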
 def fuse_prequant_layernorm(
     layernorm_module: torch.nn.Module,
     modules: list[torch.Tensor],
 ):
-    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
+    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.
+
+    original:
+        layernorm_output = (normalization(input) * weight) + bias
+        layernorm_output_scaled = layernorm_output * pre_quant_scale
+
+    fused:
+        fused_weight = weight * avg_pre_quant_scale
+        fused_bias = bias * avg_pre_quant_scale
+        layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
+    """
     layernorm_module.weight = torch.nn.Parameter(
         layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
     )
     if hasattr(layernorm_module, "bias") and layernorm_module.bias is not None:
         layernorm_module.bias = torch.nn.Parameter(
             layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
         )
     # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
     for module in modules:
         delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)
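The equivalence in the updated docstring is just elementwise associativity; a minimal RMSNorm-style sketch (hypothetical normalization and sizes, independent of the modelopt modules):

```python
import torch

torch.manual_seed(0)
hidden = 16
x = torch.randn(2, hidden)
weight = torch.rand(hidden) + 0.5                 # layernorm weight
avg_pre_quant_scale = torch.rand(hidden) + 0.5    # shared pre_quant_scale of the fused modules

def rms_normalize(t, eps=1e-6):
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)

# original: scale applied after the norm
original = (rms_normalize(x) * weight) * avg_pre_quant_scale
# fused: scale folded into the norm weight
fused = rms_normalize(x) * (weight * avg_pre_quant_scale)
assert torch.allclose(original, fused, atol=1e-6)
```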

 def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):

@@ -992,22 +1119,7 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
     for module in modules:
-        if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
-            module.weight = nn.Parameter(
-                module.weight
-                * module.input_quantizer.pre_quant_scale.to(
-                    dtype=module.weight.dtype, device=module.weight.device
-                )
-                / avg_prequant_scale.to(
-                    dtype=module.weight.dtype, device=module.weight.device
-                )
-            )
-            module.input_quantizer.pre_quant_scale = avg_prequant_scale
-
-            # Redo weights collection
-            module.weight_quantizer.reset_amax()
-            enable_stats_collection(module.weight_quantizer)
-            module.weight_quantizer(module.weight)
-            finish_stats_collection(module.weight_quantizer)
+        _update_pre_quant_scale(module, avg_prequant_scale)

     if resmooth_only:
         return
@@ -60,6 +60,7 @@
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
+    fuse_prequant_to_linear,
Contributor: Can …

Author: They are quite different.
     get_activation_scaling_factor,
     get_quant_config,
     get_quantization_format,

@@ -107,6 +108,10 @@ def _output_hook(module, input, output):
     fused_linears = {}
     module_names = set()

+    # Fuse pre_quant_scale to the linear weights if possible
+    if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
+        fuse_prequant_to_linear(model)
+
     for name, module in model.named_modules():
         module_names.add(name)
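For context, the gate above only fires for NVFP4 AWQ checkpoints. A hedged sketch of the check in isolation (the format strings are illustrative, not an exhaustive list of modelopt formats):

```python
def should_fuse_prequant(quantization_format):
    # Mirrors the condition used in the export hook above.
    return quantization_format is not None and "nvfp4_awq" in quantization_format.lower()

assert should_fuse_prequant("NVFP4_AWQ")
assert not should_fuse_prequant("int4_awq")   # illustrative non-NVFP4 format string
assert not should_fuse_prequant(None)
```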