Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 138 additions & 56 deletions src/python/py/models/quantized_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,18 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
# model.layers.layer_id.self_attention.query_key_value.qweight
# model.layers.layer_id.self_attn.qkv_proj.weight
# model.layers.layer_id.self_attention.query_key_value.weight
q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]
tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim]
tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :]
if quant_type == "olive":
# Olive: (out_features, in_features), split on dim=0
tensor_map["self_attn.q_proj.qweight"] = tensor[:q_size, :]
tensor_map["self_attn.k_proj.qweight"] = tensor[q_size : q_size + kv_size, :]
tensor_map["self_attn.v_proj.qweight"] = tensor[q_size + kv_size :, :]
else:
# AWQ/GPTQ/Quark: (in_features, out_features), split on dim=1
q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size
kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size
tensor_map["self_attn.q_proj.qweight"] = tensor[:, :q_dim]
tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim]
tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :]
elif bool(
re.match(
r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(scales|weight_scale)$",
Expand All @@ -381,9 +388,16 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
# model.layers.layer_id.self_attention.query_key_value.scales
# model.layers.layer_id.self_attn.qkv_proj.weight_scale
# model.layers.layer_id.self_attention.query_key_value.weight_scale
tensor_map["self_attn.q_proj.scales"] = tensor[:, :q_size]
tensor_map["self_attn.k_proj.scales"] = tensor[:, q_size : q_size + kv_size]
tensor_map["self_attn.v_proj.scales"] = tensor[:, q_size + kv_size :]
if quant_type == "olive":
# Olive: (out_features, num_groups), split on dim=0
tensor_map["self_attn.q_proj.scales"] = tensor[:q_size, :]
tensor_map["self_attn.k_proj.scales"] = tensor[q_size : q_size + kv_size, :]
tensor_map["self_attn.v_proj.scales"] = tensor[q_size + kv_size :, :]
else:
# AWQ/GPTQ/Quark: split on dim=1
tensor_map["self_attn.q_proj.scales"] = tensor[:, :q_size]
tensor_map["self_attn.k_proj.scales"] = tensor[:, q_size : q_size + kv_size]
tensor_map["self_attn.v_proj.scales"] = tensor[:, q_size + kv_size :]
elif bool(
re.match(
r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(qzeros|weight_zero_point)$",
Expand All @@ -394,19 +408,28 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
# model.layers.layer_id.self_attention.query_key_value.qzeros
# model.layers.layer_id.self_attn.qkv_proj.weight_zero_point
# model.layers.layer_id.self_attention.query_key_value.weight_zero_point
q_dim = (
q_size // (32 // local_bits)
if quant_type in {"awq", "gptq", "olive", "quark"}
else q_size
)
kv_dim = (
kv_size // (32 // local_bits)
if quant_type in {"awq", "gptq", "olive", "quark"}
else kv_size
)
tensor_map["self_attn.q_proj.qzeros"] = tensor[:, :q_dim]
tensor_map["self_attn.k_proj.qzeros"] = tensor[:, q_dim : q_dim + kv_dim]
tensor_map["self_attn.v_proj.qzeros"] = tensor[:, q_dim + kv_dim :]
if quant_type == "olive":
# Olive: (out_features, packed_num_groups) uint8, split on dim=0
q_dim = q_size // (8 // local_bits)
kv_dim = kv_size // (8 // local_bits)
tensor_map["self_attn.q_proj.qzeros"] = tensor[:q_dim, :]
tensor_map["self_attn.k_proj.qzeros"] = tensor[q_dim : q_dim + kv_dim, :]
tensor_map["self_attn.v_proj.qzeros"] = tensor[q_dim + kv_dim :, :]
else:
# AWQ/GPTQ/Quark: int32 packing, split on dim=1
q_dim = (
q_size // (32 // local_bits)
if quant_type in {"awq", "gptq", "quark"}
else q_size
)
kv_dim = (
kv_size // (32 // local_bits)
if quant_type in {"awq", "gptq", "quark"}
else kv_size
)
tensor_map["self_attn.q_proj.qzeros"] = tensor[:, :q_dim]
tensor_map["self_attn.k_proj.qzeros"] = tensor[:, q_dim : q_dim + kv_dim]
tensor_map["self_attn.v_proj.qzeros"] = tensor[:, q_dim + kv_dim :]
elif bool(
re.match(
r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.g_idx$", name
Expand Down Expand Up @@ -434,13 +457,19 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
# model.layers.layer_id.mlp.dense_h_to_4h.qweight
# model.layers.layer_id.mlp.gate_up_proj.weight
# model.layers.layer_id.mlp.dense_h_to_4h.weight
intermediate_dim = (
intermediate_size // (32 // local_bits)
if quant_type in {"awq", "quark"}
else intermediate_size
)
tensor_map["mlp.gate_proj.qweight"] = tensor[:, :intermediate_dim]
tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim:]
if quant_type == "olive":
# Olive: (out_features, in_features), split on dim=0
tensor_map["mlp.gate_proj.qweight"] = tensor[:intermediate_size, :]
tensor_map["mlp.up_proj.qweight"] = tensor[intermediate_size:, :]
else:
# AWQ/GPTQ/Quark: (in_features, out_features), split on dim=1
intermediate_dim = (
intermediate_size // (32 // local_bits)
if quant_type in {"awq", "quark"}
else intermediate_size
)
tensor_map["mlp.gate_proj.qweight"] = tensor[:, :intermediate_dim]
tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim:]
elif bool(
re.match(
r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(scales|weight_scale)$",
Expand All @@ -451,8 +480,14 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
# model.layers.layer_id.mlp.dense_h_to_4h.scales
# model.layers.layer_id.mlp.gate_up_proj.weight_scale
# model.layers.layer_id.mlp.dense_h_to_4h.weight_scale
tensor_map["mlp.gate_proj.scales"] = tensor[:, :intermediate_size]
tensor_map["mlp.up_proj.scales"] = tensor[:, intermediate_size:]
if quant_type == "olive":
# Olive: (out_features, num_groups), split on dim=0
tensor_map["mlp.gate_proj.scales"] = tensor[:intermediate_size, :]
tensor_map["mlp.up_proj.scales"] = tensor[intermediate_size:, :]
else:
# AWQ/GPTQ/Quark: split on dim=1
tensor_map["mlp.gate_proj.scales"] = tensor[:, :intermediate_size]
tensor_map["mlp.up_proj.scales"] = tensor[:, intermediate_size:]
elif bool(
re.match(
r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(qzeros|weight_zero_point)$",
Expand All @@ -463,13 +498,20 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
# model.layers.layer_id.mlp.dense_h_to_4h.qzeros
# model.layers.layer_id.mlp.gate_up_proj.weight_zero_point
# model.layers.layer_id.mlp.dense_h_to_4h.weight_zero_point
intermediate_dim = (
intermediate_size // (32 // local_bits)
if quant_type in {"awq", "gptq", "quark", "olive"}
else intermediate_size
)
tensor_map["mlp.gate_proj.qzeros"] = tensor[:, :intermediate_dim]
tensor_map["mlp.up_proj.qzeros"] = tensor[:, intermediate_dim:]
if quant_type == "olive":
# Olive: (out_features, packed_num_groups) uint8, split on dim=0
intermediate_dim = intermediate_size // (8 // local_bits)
tensor_map["mlp.gate_proj.qzeros"] = tensor[:intermediate_dim, :]
tensor_map["mlp.up_proj.qzeros"] = tensor[intermediate_dim:, :]
else:
# AWQ/GPTQ/Quark: int32 packing, split on dim=1
intermediate_dim = (
intermediate_size // (32 // local_bits)
if quant_type in {"awq", "gptq", "quark"}
else intermediate_size
)
tensor_map["mlp.gate_proj.qzeros"] = tensor[:, :intermediate_dim]
tensor_map["mlp.up_proj.qzeros"] = tensor[:, intermediate_dim:]
elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.g_idx$", name)):
# model.layers.layer_id.mlp.gate_up_proj.g_idx
# model.layers.layer_id.mlp.dense_h_to_4h.g_idx
Expand Down Expand Up @@ -554,10 +596,10 @@ def set_properties(self):
self.lm_head.out_features = self.lm_head.qweight.shape[1]
self.lm_head.in_features = self.lm_head.g_idx.shape[0]
elif self.quant_type == "olive":
self.lm_head.out_features = self.lm_head.qweight.shape[1]
# expects in_features to be divisible by the packing factor (32 // bits)
# not a new assumption since no code here accounts for padded packed weights
self.lm_head.in_features = self.lm_head.qweight.shape[0] * 32 // self.lm_head.bits
# Olive format: qweight is (out_features, packed_in_features) uint8
# packed_in_features = in_features * bits / 8
self.lm_head.out_features = self.lm_head.qweight.shape[0]
self.lm_head.in_features = self.lm_head.qweight.shape[1] * 8 // self.lm_head.bits
else:
raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.")
for module in self.layers:
Expand Down Expand Up @@ -654,32 +696,31 @@ def set_properties(self):
module.mlp.down_proj.in_features = module.mlp.down_proj.g_idx.shape[0]

elif self.quant_type == "olive":
# Set in_features and out_features
module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[1]
module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[0]
module.self_attn.q_proj.in_features = (
module.self_attn.q_proj.qweight.shape[0] * 32 // module.self_attn.q_proj.bits
module.self_attn.q_proj.qweight.shape[1] * 8 // module.self_attn.q_proj.bits
)
module.self_attn.k_proj.out_features = module.self_attn.k_proj.qweight.shape[1]
module.self_attn.k_proj.out_features = module.self_attn.k_proj.qweight.shape[0]
module.self_attn.k_proj.in_features = (
module.self_attn.k_proj.qweight.shape[0] * 32 // module.self_attn.k_proj.bits
module.self_attn.k_proj.qweight.shape[1] * 8 // module.self_attn.k_proj.bits
)
module.self_attn.v_proj.out_features = module.self_attn.v_proj.qweight.shape[1]
module.self_attn.v_proj.out_features = module.self_attn.v_proj.qweight.shape[0]
module.self_attn.v_proj.in_features = (
module.self_attn.v_proj.qweight.shape[0] * 32 // module.self_attn.v_proj.bits
module.self_attn.v_proj.qweight.shape[1] * 8 // module.self_attn.v_proj.bits
)
module.self_attn.o_proj.out_features = module.self_attn.o_proj.qweight.shape[1]
module.self_attn.o_proj.out_features = module.self_attn.o_proj.qweight.shape[0]
module.self_attn.o_proj.in_features = (
module.self_attn.o_proj.qweight.shape[0] * 32 // module.self_attn.o_proj.bits
module.self_attn.o_proj.qweight.shape[1] * 8 // module.self_attn.o_proj.bits
)
module.mlp.gate_proj.out_features = module.mlp.gate_proj.qweight.shape[1]
module.mlp.gate_proj.out_features = module.mlp.gate_proj.qweight.shape[0]
module.mlp.gate_proj.in_features = (
module.mlp.gate_proj.qweight.shape[0] * 32 // module.mlp.gate_proj.bits
module.mlp.gate_proj.qweight.shape[1] * 8 // module.mlp.gate_proj.bits
)
module.mlp.up_proj.out_features = module.mlp.up_proj.qweight.shape[1]
module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[0] * 32 // module.mlp.up_proj.bits
module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[1]
module.mlp.up_proj.out_features = module.mlp.up_proj.qweight.shape[0]
module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[1] * 8 // module.mlp.up_proj.bits
module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[0]
module.mlp.down_proj.in_features = (
module.mlp.down_proj.qweight.shape[0] * 32 // module.mlp.down_proj.bits
module.mlp.down_proj.qweight.shape[1] * 8 // module.mlp.down_proj.bits
)

else:
Expand Down Expand Up @@ -1138,6 +1179,13 @@ def reverse_reorder_tensor(self, tensor, bits):


class OliveModel(GPTQModel):
"""
Olive quantization format:
- qweight: (out_features, packed_in_features) uint8, packed along last dim
- scales: (out_features, num_groups) float
- qzeros: (out_features, packed_num_groups) uint8, packed along last dim
"""

def _load_quant_config(self, quant_attrs):
super()._load_quant_config(quant_attrs)
self.overrides = quant_attrs["config"]["overrides"] or {}
Expand All @@ -1150,6 +1198,40 @@ def get_layer_group_size(self, layer_name):
name = ".".join(layer_name.split(".")[:-1])
return self.overrides.get(name, {}).get("group_size", self.global_group_size)

def handle_qzeros(self, module):
"""Olive uses unsigned quantization, no offset needed."""
pass

def unpack(self, module):
"""Skip unpack for Olive format."""
pass

def repack(self, module):
"""
Olive format:
- qweight: (out_features, packed_in_features) uint8
- scales: (out_features, num_groups) float
- qzeros: (out_features, packed_num_groups) uint8

ORT format:
- qweight: (out_features, k_blocks, blob_size) uint8
- scales: (out_features * num_groups,) float, flattened
- qzeros: (out_features * packed_num_groups,) uint8, flattened
"""
kpack = 8 // module.bits
k_blocks = module.in_features // module.group_size
blob_size = module.group_size // kpack

# qweight: (out_features, packed_in_features) -> (out_features, k_blocks, blob_size)
module.qweight = module.qweight.reshape(module.out_features, k_blocks, blob_size).contiguous()

# scales: (out_features, num_groups) -> flatten to 1D
module.scales = module.scales.reshape(-1).contiguous()

# qzeros: (out_features, packed_num_groups) -> flatten to 1D
if module.qzeros is not None and module.qzeros.numel() > 0:
module.qzeros = module.qzeros.reshape(-1).contiguous()


class QuantModel:
@staticmethod
Expand Down
Loading