Skip to content

Support Olive new uint8 quantization format#1916

Merged
hanbitmyths merged 2 commits into
mainfrom
xiaoyu/olive-quant
Dec 13, 2025
Merged

Support Olive new uint8 quantization format#1916
hanbitmyths merged 2 commits into
mainfrom
xiaoyu/olive-quant

Conversation

@xiaoyu-work
Copy link
Copy Markdown
Contributor

This pull request updates the handling of the Olive quantization format in quantized_model.py to match the latest specification and improve code clarity. The main changes include correcting how in/out features are computed for Olive quantized layers, documenting the Olive format, and updating repacking logic for compatibility with ONNX Runtime (ORT).

Olive quantization format support and documentation:

  • Updated computation of in_features and out_features for Olive quantized layers to match the new format, which packs weights along the last dimension (qweight is now (out_features, packed_in_features)), and adjusted all relevant projections in self-attention and MLP modules. [1] [2]
  • Added a docstring to the OliveModel class explaining the Olive quantization format for weights, scales, and zero points.

Repacking and compatibility improvements:

  • Implemented a new repack method for Olive quantized modules to reshape tensors for ONNX Runtime (ORT) compatibility, including reshaping qweight, flattening scales, and flattening qzeros.
  • Added placeholder methods handle_qzeros and unpack for Olive format to clarify that no offset or unpacking is required.

@tianleiwu
Copy link
Copy Markdown
Contributor

tianleiwu commented Dec 11, 2025

From AI:

This is a comprehensive review of the Python script, with a specific deep dive into the OliveModel class and the handling of the Olive quantization format.

Executive Summary

The script is a weight converter designed to ingest quantized weights (from Safetensors) and restructure them into an internal intermediate representation (3D tensors for weights) suitable for converting to ONNX Runtime (ORT).

The OliveModel implementation has a critical logic gap regarding padding and assumes perfect divisibility of input shapes. If the model dimensions are not perfectly divisible by the block size, the current repack method will crash or produce malformed tensors.


Critical Findings: OliveModel & Olive Format

The OliveModel class attempts to shortcut the standard unpack -> repack loop used by AWQ/GPTQ by assuming the input weights are already packed. However, the implementation is too brittle.

1. Missing Padding Logic in repack (Critical)

In OliveModel.repack, you reshape the tensor based on in_features and group_size.

k_blocks = module.in_features // module.group_size
# ...
module.qweight = module.qweight.reshape(module.out_features, k_blocks, blob_size).contiguous()

The Bug: If module.in_features is not perfectly divisible by module.group_size, k_blocks will be floored. The reshape operation will then fail because the number of elements in qweight (which includes padding from the source) will be larger than the target shape.

The Fix: You must calculate the padded size and slice or pad the tensor before reshaping.

2. qzeros Flattening Risk

if module.qzeros is not None and module.qzeros.numel() > 0:
    module.qzeros = module.qzeros.reshape(-1).contiguous()

The Risk: The QuantizedModel base class expects qzeros to eventually be byte-aligned. If the Olive input has unpacked zeros (e.g., float or int32 zeros that need packing), this line simply flattens them without packing them into uint8. If the input is already packed uint8, this is fine, but you should add an assertion to ensure dtype is uint8.

3. Inheritance Confusion

class OliveModel(GPTQModel):
Inheriting from GPTQModel causes OliveModel to run the GPTQModel.__init__ logic, which calls handle_qzeros, unpack, and repack.

  • You correctly overrode unpack and handle_qzeros with pass.
  • However, this relies on the GPTQModel initialization flow. If GPTQModel changes, OliveModel might break. It is safer to inherit from QuantizedModel directly and write a custom loop if you don't intend to use the GPTQ unpacking logic.

Code Improvements & Fixes

Here is the corrected version of the OliveModel class, including the padding fix and safety checks.

Recommended Fix for OliveModel

class OliveModel(GPTQModel):
    """
    Olive quantization format:
    - qweight: (out_features, packed_in_features) uint8, packed along last dim
    - scales: (out_features, num_groups) float
    - qzeros: (out_features, packed_num_groups) uint8, packed along last dim
    """

    def _load_quant_config(self, quant_attrs):
        # Call grandparent method to load global bits/group_size
        # (Skipping GPTQModel._load_quant_config because it looks for 'dynamic' keys)
        QuantizedModel._load_quant_config(self, quant_attrs)
        self.overrides = quant_attrs["config"].get("overrides", {})

    def get_layer_bits(self, layer_name):
        # Handle layer names cleanly
        name = ".".join(layer_name.split(".")[:-1])
        return self.overrides.get(name, {}).get("bits", self.global_bits)

    def get_layer_group_size(self, layer_name):
        name = ".".join(layer_name.split(".")[:-1])
        return self.overrides.get(name, {}).get("group_size", self.global_group_size)

    def handle_qzeros(self, module):
        """Olive uses unsigned quantization, no offset needed."""
        pass

    def unpack(self, module):
        """Skip unpack for Olive format (assumes already packed)."""
        pass

    def repack(self, module):
        """
        Repack Olive weights into the intermediate format expected by the ONNX converter.
        Handles potential padding if in_features is not divisible by group_size.
        """
        kpack = 8 // module.bits
        block_size = module.group_size
        blob_size = block_size // kpack

        # Calculate expected blocks including padding
        rows = module.out_features
        # Use ceiling division to account for incomplete blocks
        k_blocks = (module.in_features + block_size - 1) // block_size
        
        # Calculate expected total elements in the packed dimension
        expected_packed_dim = k_blocks * blob_size
        current_packed_dim = module.qweight.shape[1]

        # Pad qweight if necessary
        pad_len = expected_packed_dim - current_packed_dim
        if pad_len > 0:
            # Pad the last dimension (columns)
            module.qweight = torch.nn.functional.pad(module.qweight, (0, pad_len), "constant", 0)
        elif pad_len < 0:
            # This implies the input has more padding than we calculated, strict slicing
            module.qweight = module.qweight[:, :expected_packed_dim]

        # Reshape to ORT intermediate format: (out_features, k_blocks, blob_size)
        try:
            module.qweight = module.qweight.reshape(rows, k_blocks, blob_size).contiguous()
        except RuntimeError as e:
            raise RuntimeError(
                f"Shape mismatch in OliveModel.repack. "
                f"Expected shape ({rows}, {k_blocks}, {blob_size}), "
                f"but could not reshape tensor of size {module.qweight.numel()}. "
                f"Details: {e}"
            )

        # Handle Scales
        module.scales = module.scales.reshape(-1).contiguous()

        # Handle Zeros
        if module.qzeros is not None and module.qzeros.numel() > 0:
            # Ensure qzeros are uint8 before flattening
            if module.qzeros.dtype != torch.uint8:
                 raise ValueError(f"OliveModel expects qzeros to be uint8, got {module.qzeros.dtype}")
            module.qzeros = module.qzeros.reshape(-1).contiguous()

Copy link
Copy Markdown
Contributor

@tianleiwu tianleiwu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(continue from AI)

There is a major logic error in the __init__ method regarding how Fused Layers (like qkv_proj or gate_up_proj) are split for olive.

The current code treats olive exactly like awq and gptq during tensor slicing, but olive uses a transposed data layout compared to those formats. This will cause the script to slice the wrong dimension, resulting in garbage weights for any fused layers.

1. The Fused Layer Splitting Bug

The Issue

In QuantizedModel.__init__, the code handles fused layers (e.g., Llama's gate_up_proj or qkv_proj) by slicing the tensor.

It groups olive with awq and gptq:

if quant_type in {"awq", "gptq", "olive", "quark"}:
    # Calculates a "packed" dimension size
    q_dim = q_size // (32 // local_bits) 
else:
    q_dim = q_size

# Slices the SECOND dimension (dim 1)
tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]

Why this fails for Olive:

  • AWQ/GPTQ Layout: Weights are typically shaped (In_Features, Out_Features). Since the output features (e.g., distinct neurons for Gate vs. Up) are on dim=1, slicing tensor[:, :split] is correct.
  • Olive Layout: Your set_properties method confirms Olive uses (Out_Features, In_Features).
    • For a fused layer, the distinct neurons (Gate vs. Up) are stacked on dim=0 (Rows).
    • The current code slices dim=1 (Columns/Input).
    • Result: You are keeping all the output neurons (Gate + Up) but slicing the input features in half. This destroys the model.

The Fix

You need to create a specific branch for olive that slices dim=0 using the unpacked sizes (since the Output dimension is typically not packed in Olive, only the Input dimension is).

Corrected Code Block for __init__:

# ... inside the loop over weights ...

# --- FIX START: Handling QKV Projection ---
elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.q?weight$", name)):
    if self.quant_type == "olive":
        # Olive: (Out, In). Split on dim 0. Use raw sizes (q_size), not packed sizes.
        tensor_map["self_attn.q_proj.qweight"] = tensor[:q_size, :]
        tensor_map["self_attn.k_proj.qweight"] = tensor[q_size : q_size + kv_size, :]
        tensor_map["self_attn.v_proj.qweight"] = tensor[q_size + kv_size :, :]
    else:
        # AWQ/GPTQ: (In, Out). Split on dim 1. Use packed sizes for qweight.
        q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
        kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
        tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]
        tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim]
        tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :]

# ... (Do the same for scales, qzeros, bias) ...

# --- FIX START: Handling MLP Gate/Up Projection ---
elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.q?weight$", name)):
    if self.quant_type == "olive":
        # Olive: (Out, In). Split on dim 0.
        tensor_map["mlp.gate_proj.qweight"] = tensor[:intermediate_size, :]
        tensor_map["mlp.up_proj.qweight"] = tensor[intermediate_size:, :]
    else:
        # AWQ/GPTQ: (In, Out). Split on dim 1.
        intermediate_dim = (
            intermediate_size // (32 // local_bits)
            if quant_type in {"awq", "quark"}
            else intermediate_size
        )
        tensor_map["mlp.gate_proj.qweight"] = tensor[:, :intermediate_dim]
        tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim:]

2. qzeros and scales Splitting

The same logic applies to scales and qzeros for fused layers.

  • Scales: Usually (Out, Groups).
    • For fused layers, Q-scales, K-scales, and V-scales are stacked on dim=0 (Out).
    • The existing code slices dim=1 (tensor[:, :q_size]).
    • This is wrong for everyone (even AWQ) if scales are (Out, Groups).
    • However, if AWQ scales are (Groups, Out), slicing dim=1 is correct. You should verify the shape of scales in your input Safetensors.
    • For Olive: Scales are definitely (Out, Groups). You must slice scales on dim=0.

Corrected Logic for Olive Scales/Zeros:

elif bool(re.match(r"...(scales|weight_scale)$", name)):
    if self.quant_type == "olive":
        # Split on dim 0 (Out_Features)
        tensor_map["self_attn.q_proj.scales"] = tensor[:q_size, :]
        tensor_map["self_attn.k_proj.scales"] = tensor[q_size : q_size + kv_size, :]
        tensor_map["self_attn.v_proj.scales"] = tensor[q_size + kv_size :, :]
    else:
        # Existing logic for AWQ/GPTQ (assuming they stack on dim 1)
        tensor_map["self_attn.q_proj.scales"] = tensor[:, :q_size]
        # ...

3. set_properties Consistency Check

The set_properties method for olive looks mostly correct assuming the fix above is implemented:

elif self.quant_type == "olive":
    # These rely on the tensors being correctly split in __init__
    module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[0]
    
    # Calculation of in_features from packed shape
    module.self_attn.q_proj.in_features = (
        module.self_attn.q_proj.qweight.shape[1] * 8 // module.self_attn.q_proj.bits
    )

This logic confirms that qweight must be (Out, Packed_In). If you don't apply the fix in __init__, the shape of qweight for fused layers will be wrong (it will be (Fused_Out, Split_Packed_In)), and this calculation will return an incorrect in_features count (half of what it should be), likely causing a crash later in the pipeline or silent data corruption.

Summary of Actions

  1. Modify __init__: Introduce explicit if self.quant_type == "olive": blocks inside the fused layer matching regexes (qkv_proj and gate_up_proj).
  2. Change Slicing: For Olive, slice on dim=0 (using unpacked sizes) instead of dim=1.
  3. Verify Scales: Ensure scales and qzeros are also sliced on dim=0 for Olive.

@hanbitmyths hanbitmyths merged commit 53661d0 into main Dec 13, 2025
15 checks passed
@hanbitmyths hanbitmyths deleted the xiaoyu/olive-quant branch December 13, 2025 00:44
apsonawane pushed a commit that referenced this pull request Dec 19, 2025
This pull request updates the handling of the Olive quantization format
in `quantized_model.py` to match the latest specification and improve
code clarity. The main changes include correcting how in/out features
are computed for Olive quantized layers, documenting the Olive format,
and updating repacking logic for compatibility with ONNX Runtime (ORT).

**Olive quantization format support and documentation:**

* Updated computation of `in_features` and `out_features` for Olive
quantized layers to match the new format, which packs weights along the
last dimension (`qweight` is now `(out_features, packed_in_features)`),
and adjusted all relevant projections in self-attention and MLP modules.
[[1]](diffhunk://#diff-8c2caf775960974ce923934b24e069fae5b819a0fa972976363ab8689f996c23L557-R560)
[[2]](diffhunk://#diff-8c2caf775960974ce923934b24e069fae5b819a0fa972976363ab8689f996c23L658-R684)
* Added a docstring to the `OliveModel` class explaining the Olive
quantization format for weights, scales, and zero points.

**Repacking and compatibility improvements:**

* Implemented a new `repack` method for Olive quantized modules to
reshape tensors for ONNX Runtime (ORT) compatibility, including
reshaping `qweight`, flattening `scales`, and flattening `qzeros`.
* Added placeholder methods `handle_qzeros` and `unpack` for Olive
format to clarify that no offset or unpacking is required.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants