Support Olive new uint8 quantization format #1916
From AI: This is a comprehensive review of the Python script, with a specific deep dive into the
Executive Summary
The script is a weight converter designed to ingest quantized weights (from Safetensors) and restructure them into an internal intermediate representation (3D tensors for weights) suitable for converting to ONNX Runtime (ORT).
The Critical Findings:
tianleiwu
left a comment
(continue from AI)
There is a major logic error in the __init__ method regarding how Fused Layers (like qkv_proj or gate_up_proj) are split for olive.
The current code treats olive exactly like awq and gptq during tensor slicing, but olive uses a transposed data layout compared to those formats. This will cause the script to slice the wrong dimension, resulting in garbage weights for any fused layers.
1. The Fused Layer Splitting Bug
The Issue
In QuantizedModel.__init__, the code handles fused layers (e.g., Llama's gate_up_proj or qkv_proj) by slicing the tensor.
It groups olive with awq and gptq:
    if quant_type in {"awq", "gptq", "olive", "quark"}:
        # Calculates a "packed" dimension size
        q_dim = q_size // (32 // local_bits)
    else:
        q_dim = q_size
    # Slices the SECOND dimension (dim 1)
    tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]

Why this fails for Olive:
- AWQ/GPTQ Layout: Weights are typically shaped `(In_Features, Out_Features)`. Since the output features (e.g., distinct neurons for Gate vs. Up) are on `dim=1`, slicing `tensor[:, :split]` is correct.
- Olive Layout: Your `set_properties` method confirms Olive uses `(Out_Features, In_Features)`.
  - For a fused layer, the distinct neurons (Gate vs. Up) are stacked on `dim=0` (Rows).
  - The current code slices `dim=1` (Columns/Input).
  - Result: You are keeping all the output neurons (Gate + Up) but slicing the input features in half. This destroys the model.
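To see the failure concretely, here is a minimal sketch that uses plain Python lists in place of tensors. The helper names `split_rows`, `split_cols`, and `shape`, and the toy sizes, are hypothetical and chosen only for illustration; the real code slices tensors loaded from Safetensors.

```python
# Toy fused gate/up weight in the Olive layout (Out_Features, In_Features).
# Rows 0-1 belong to "gate", rows 2-3 to "up"; each row has 4 input features.
# All names and sizes here are illustrative assumptions.
INTERMEDIATE_SIZE = 2
fused = [
    ["g0"] * 4,
    ["g1"] * 4,
    ["u0"] * 4,
    ["u1"] * 4,
]


def split_rows(matrix, size):
    """Split along dim 0: like tensor[:size, :] and tensor[size:, :]."""
    return matrix[:size], matrix[size:]


def split_cols(matrix, size):
    """Split along dim 1: like tensor[:, :size] and tensor[:, size:]."""
    return [row[:size] for row in matrix], [row[size:] for row in matrix]


def shape(matrix):
    return (len(matrix), len(matrix[0]))


# Correct for (Out, In): each half keeps every input feature.
gate, up = split_rows(fused, INTERMEDIATE_SIZE)
print(shape(gate), shape(up))  # (2, 4) (2, 4)

# Wrong for (Out, In): both halves still mix gate and up rows,
# and each has lost half of its input features.
bad_gate, bad_up = split_cols(fused, INTERMEDIATE_SIZE)
print(shape(bad_gate), shape(bad_up))  # (4, 2) (4, 2)
```

The shapes alone expose the problem: after the column split, neither half is a standalone gate or up projection.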
The Fix
You need to create a specific branch for olive that slices dim=0 using the unpacked sizes (since the Output dimension is typically not packed in Olive, only the Input dimension is).
Corrected Code Block for __init__:
    # ... inside the loop over weights ...
    # --- FIX START: Handling QKV Projection ---
    elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.q?weight$", name)):
        if self.quant_type == "olive":
            # Olive: (Out, In). Split on dim 0. Use raw sizes (q_size), not packed sizes.
            tensor_map["self_attn.q_proj.qweight"] = tensor[:q_size, :]
            tensor_map["self_attn.k_proj.qweight"] = tensor[q_size : q_size + kv_size, :]
            tensor_map["self_attn.v_proj.qweight"] = tensor[q_size + kv_size :, :]
        else:
            # AWQ/GPTQ: (In, Out). Split on dim 1. Use packed sizes for qweight.
            q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
            kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
            tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]
            tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim]
            tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :]
    # ... (Do the same for scales, qzeros, bias) ...
    # --- FIX START: Handling MLP Gate/Up Projection ---
    elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.q?weight$", name)):
        if self.quant_type == "olive":
            # Olive: (Out, In). Split on dim 0.
            tensor_map["mlp.gate_proj.qweight"] = tensor[:intermediate_size, :]
            tensor_map["mlp.up_proj.qweight"] = tensor[intermediate_size:, :]
        else:
            # AWQ/GPTQ: (In, Out). Split on dim 1.
            intermediate_dim = (
                intermediate_size // (32 // local_bits)
                if quant_type in {"awq", "quark"}
                else intermediate_size
            )
            tensor_map["mlp.gate_proj.qweight"] = tensor[:, :intermediate_dim]
            tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim:]

2. qzeros and scales Splitting
The same logic applies to scales and qzeros for fused layers.
- Scales: Usually `(Out, Groups)`.
  - For fused layers, Q-scales, K-scales, and V-scales are stacked on `dim=0` (Out).
  - The existing code slices `dim=1` (`tensor[:, :q_size]`).
  - This is wrong for every format (even AWQ) if scales are `(Out, Groups)`.
  - However, if AWQ scales are `(Groups, Out)`, slicing `dim=1` is correct. You should verify the shape of `scales` in your input Safetensors.
  - For Olive: Scales are definitely `(Out, Groups)`. You must slice `scales` on `dim=0`.
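One way to make that verification explicit is to check which axis of the scales tensor matches the fused output size before slicing. The sketch below works on shape tuples only. The helpers `split_axis_for` and `qkv_bounds` and the sizes are hypothetical, an assumption about how such a guard could look rather than code from the PR.

```python
def split_axis_for(shape, fused_out):
    """Return the axis of a 2-D `shape` whose length equals `fused_out`.

    Hypothetical helper: raises if the layout cannot be determined
    unambiguously, so a wrong-axis slice fails loudly instead of silently.
    """
    matches = [axis for axis, length in enumerate(shape) if length == fused_out]
    if len(matches) != 1:
        raise ValueError(f"cannot infer output axis from shape {shape}")
    return matches[0]


def qkv_bounds(q_size, kv_size):
    """(start, stop) ranges of Q, K and V along the fused output axis."""
    return [
        (0, q_size),
        (q_size, q_size + kv_size),
        (q_size + kv_size, q_size + 2 * kv_size),
    ]


# Illustrative sizes: q_size=4096, kv_size=1024, 32 quantization groups.
q_size, kv_size, groups = 4096, 1024, 32
fused_out = q_size + 2 * kv_size

print(split_axis_for((fused_out, groups), fused_out))  # 0 -> (Out, Groups)
print(split_axis_for((groups, fused_out), fused_out))  # 1 -> (Groups, Out)
print(qkv_bounds(q_size, kv_size))
```

A guard like this turns a silent wrong-axis slice into an immediate error when the fused size matches neither axis or both.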
Corrected Logic for Olive Scales/Zeros:
    elif bool(re.match(r"...(scales|weight_scale)$", name)):
        if self.quant_type == "olive":
            # Split on dim 0 (Out_Features)
            tensor_map["self_attn.q_proj.scales"] = tensor[:q_size, :]
            tensor_map["self_attn.k_proj.scales"] = tensor[q_size : q_size + kv_size, :]
            tensor_map["self_attn.v_proj.scales"] = tensor[q_size + kv_size :, :]
        else:
            # Existing logic for AWQ/GPTQ (assuming they stack on dim 1)
            tensor_map["self_attn.q_proj.scales"] = tensor[:, :q_size]
            # ...

3. set_properties Consistency Check
The set_properties method for olive looks mostly correct assuming the fix above is implemented:
    elif self.quant_type == "olive":
        # These rely on the tensors being correctly split in __init__
        module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[0]
        # Calculation of in_features from packed shape
        module.self_attn.q_proj.in_features = (
            module.self_attn.q_proj.qweight.shape[1] * 8 // module.self_attn.q_proj.bits
        )

This logic confirms that qweight must be (Out, Packed_In). If you don't apply the fix in __init__, the shape of qweight for fused layers will be wrong (it will be (Fused_Out, Split_Packed_In)), and this calculation will return an incorrect in_features count (half of what it should be), likely causing a crash later in the pipeline or silent data corruption.
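The arithmetic can be checked by hand. The sketch below assumes 4-bit weights packed into uint8 values along the input dimension, as the `* 8 // bits` expression implies. The function name `olive_in_features` and the sizes are hypothetical.

```python
def olive_in_features(packed_in, bits):
    """Recover in_features from the packed last dimension of a uint8 qweight."""
    return packed_in * 8 // bits


bits = 4
in_features = 4096
packed_in = in_features * bits // 8  # 2048 uint8 values per row

# Correct split on dim 0: the packed input dimension is untouched.
print(olive_in_features(packed_in, bits))  # 4096

# A dim-1 split that cuts the packed input dimension in half
# reports half the real in_features.
print(olive_in_features(packed_in // 2, bits))  # 2048
```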
Summary of Actions
- Modify `__init__`: Introduce explicit `if self.quant_type == "olive":` blocks inside the fused layer matching regexes (`qkv_proj` and `gate_up_proj`).
- Change Slicing: For Olive, slice on `dim=0` (using unpacked sizes) instead of `dim=1`.
- Verify Scales: Ensure `scales` and `qzeros` are also sliced on `dim=0` for Olive.
This pull request updates the handling of the Olive quantization format in `quantized_model.py` to match the latest specification and improve code clarity. The main changes include correcting how in/out features are computed for Olive quantized layers, documenting the Olive format, and updating repacking logic for compatibility with ONNX Runtime (ORT).

**Olive quantization format support and documentation:**

* Updated computation of `in_features` and `out_features` for Olive quantized layers to match the new format, which packs weights along the last dimension (`qweight` is now `(out_features, packed_in_features)`), and adjusted all relevant projections in self-attention and MLP modules. [[1]](diffhunk://#diff-8c2caf775960974ce923934b24e069fae5b819a0fa972976363ab8689f996c23L557-R560) [[2]](diffhunk://#diff-8c2caf775960974ce923934b24e069fae5b819a0fa972976363ab8689f996c23L658-R684)
* Added a docstring to the `OliveModel` class explaining the Olive quantization format for weights, scales, and zero points.

**Repacking and compatibility improvements:**

* Implemented a new `repack` method for Olive quantized modules to reshape tensors for ONNX Runtime (ORT) compatibility, including reshaping `qweight`, flattening `scales`, and flattening `qzeros`.
* Added placeholder methods `handle_qzeros` and `unpack` for Olive format to clarify that no offset or unpacking is required.
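As an illustration of what such a repack could involve, the sketch below computes target shapes only. It assumes the layout documented for ONNX Runtime's `MatMulNBits` contrib operator (weights as `(N, n_blocks, blob_size)` uint8, with scales and packed zero points flattened). The function name `matmul_nbits_shapes` and the choice of target layout are assumptions, not taken from this PR.

```python
def ceil_div(a, b):
    """Integer division rounding up."""
    return -(-a // b)


def matmul_nbits_shapes(out_features, in_features, bits, block_size):
    """Target shapes under the assumed MatMulNBits layout (not from the PR)."""
    n_blocks = ceil_div(in_features, block_size)
    blob_size = ceil_div(block_size * bits, 8)  # uint8 bytes per block
    return {
        "qweight": (out_features, n_blocks, blob_size),
        "scales": (out_features * n_blocks,),
        "qzeros": (out_features * ceil_div(n_blocks * bits, 8),),
    }


shapes = matmul_nbits_shapes(out_features=4096, in_features=4096, bits=4, block_size=128)
print(shapes["qweight"])  # (4096, 32, 64)
print(shapes["scales"])   # (131072,)
print(shapes["qzeros"])   # (65536,)
```

Under these assumptions, an Olive `qweight` of shape `(4096, 2048)` holds the same number of bytes per row as the `(32, 64)` trailing dimensions above, which is consistent with the description's statement that no unpacking is required.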