Skip to content

Commit f19da64

Browse files
authored
[Core] Refactor GGUF parameters packing and forwarding (#8859)
1 parent 4f95ffe commit f19da64

File tree

4 files changed

+64
-62
lines changed

4 files changed

+64
-62
lines changed

tests/models/decoder_only/language/test_gguf.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020
# FIXME: Move this to confest
2121
MODELS = [
22-
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
23-
hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
24-
filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")),
25-
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
26-
hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
27-
filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
22+
("meta-llama/Llama-3.2-1B-Instruct",
23+
hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
24+
filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
25+
("meta-llama/Llama-3.2-1B-Instruct",
26+
hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
27+
filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
2828
("Qwen/Qwen2-1.5B-Instruct",
2929
hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
3030
filename="qwen2-1_5b-instruct-q4_k_m.gguf")),

vllm/model_executor/layers/linear.py

+32-44
Original file line numberDiff line numberDiff line change
@@ -440,17 +440,23 @@ def weight_loader(self,
440440
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
441441
return
442442

443-
if is_gguf_weight and isinstance(param, UninitializedParameter):
444-
from gguf.constants import GGML_QUANT_SIZES
443+
if is_gguf_weight:
444+
tp_size = get_tensor_model_parallel_world_size()
445+
tp_rank = get_tensor_model_parallel_rank()
446+
447+
output_dim = getattr(param, "output_dim", None)
448+
shard_size = loaded_weight.size(output_dim) // tp_size
449+
start_idx = tp_rank * shard_size
450+
451+
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
452+
shard_size)
445453

446-
ori_shape = param.tensor_shape
447-
weight_types = self.qweight_type.shard_weight_type.values()
448-
row_size = []
449-
for weight_type in weight_types:
450-
block_size, type_size = GGML_QUANT_SIZES[weight_type]
451-
row_size.append(ori_shape[1] // block_size * type_size)
452-
q_shape = (ori_shape[0], max(row_size))
453-
param.materialize(q_shape, dtype=loaded_weight.dtype)
454+
param.shard_id.append(loaded_shard_id)
455+
param.shard_id_map[loaded_shard_id] = len(param.data_container)
456+
param.data_container.append(loaded_weight)
457+
if len(param.data_container) == 2:
458+
self.qweight = param.materialize_nested()
459+
return
454460

455461
param_data = param.data
456462
output_dim = getattr(param, "output_dim", None)
@@ -515,18 +521,6 @@ def weight_loader(self,
515521
shard_offset = loaded_weight.shape[output_dim] * \
516522
loaded_shard_id
517523

518-
if is_gguf_weight:
519-
tp_size = get_tensor_model_parallel_world_size()
520-
output_dim = getattr(param, "output_dim", None)
521-
shard_shape = list(loaded_weight.shape)
522-
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
523-
param.shard_id.append(loaded_shard_id)
524-
param.shard_size[loaded_shard_id] = shard_shape
525-
526-
input_dim = getattr(param, "input_dim", None)
527-
input_size = loaded_weight.shape[input_dim]
528-
param_data = param_data.narrow(input_dim, 0, input_size)
529-
530524
param_data = param_data.narrow(output_dim, shard_offset,
531525
shard_size)
532526
start_idx = tp_rank * shard_size
@@ -783,17 +777,23 @@ def weight_loader(self,
783777
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
784778
return
785779

786-
if is_gguf_weight and isinstance(param, UninitializedParameter):
787-
from gguf.constants import GGML_QUANT_SIZES
780+
if is_gguf_weight:
781+
tp_size = get_tensor_model_parallel_world_size()
782+
tp_rank = get_tensor_model_parallel_rank()
788783

789-
ori_shape = param.tensor_shape
790-
weight_types = self.qweight_type.shard_weight_type.values()
791-
row_size = []
792-
for weight_type in weight_types:
793-
block_size, type_size = GGML_QUANT_SIZES[weight_type]
794-
row_size.append(ori_shape[1] // block_size * type_size)
795-
q_shape = (ori_shape[0], max(row_size))
796-
param.materialize(q_shape, dtype=loaded_weight.dtype)
784+
output_dim = getattr(param, "output_dim", None)
785+
shard_size = loaded_weight.size(output_dim) // tp_size
786+
start_idx = tp_rank * shard_size
787+
788+
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
789+
shard_size)
790+
791+
param.shard_id.append(loaded_shard_id)
792+
param.shard_id_map[loaded_shard_id] = len(param.data_container)
793+
param.data_container.append(loaded_weight)
794+
if len(param.data_container) == 3:
795+
self.qweight = param.materialize_nested()
796+
return
797797

798798
param_data = param.data
799799
output_dim = getattr(param, "output_dim", None)
@@ -883,18 +883,6 @@ def weight_loader(self,
883883
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
884884
param, orig_qkv_offsets, loaded_shard_id)
885885

886-
if is_gguf_weight:
887-
tp_size = get_tensor_model_parallel_world_size()
888-
output_dim = getattr(param, "output_dim", None)
889-
shard_shape = list(loaded_weight.shape)
890-
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
891-
param.shard_id.append(loaded_shard_id)
892-
param.shard_size[loaded_shard_id] = shard_shape
893-
894-
input_dim = getattr(param, "input_dim", None)
895-
input_size = loaded_weight.shape[input_dim]
896-
param_data = param_data.narrow(input_dim, 0, input_size)
897-
898886
param_data = param_data.narrow(output_dim, shard_offset,
899887
shard_size)
900888
if loaded_shard_id == "q":

vllm/model_executor/layers/quantization/gguf.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,16 @@ def create_weights(self, layer: torch.nn.Module,
8686
output_size_per_partition = sum(output_partition_sizes)
8787

8888
tensor_shape = (output_size_per_partition, input_size_per_partition)
89-
qweight = UninitializedParameter(requires_grad=False)
89+
qweight = GGUFUninitializedParameter(requires_grad=False)
9090
set_weight_attrs(
9191
qweight, {
9292
"input_dim": 1,
9393
"output_dim": 0,
9494
"tensor_shape": tensor_shape,
9595
"is_gguf_weight": True,
96-
"shard_size": {},
96+
"data_container": [],
9797
"shard_id": [],
98+
"shard_id_map": {},
9899
})
99100
set_weight_attrs(qweight, extra_weight_attrs)
100101
layer.register_parameter("qweight", qweight)
@@ -116,21 +117,17 @@ def apply(self,
116117
layer: torch.nn.Module,
117118
x: torch.Tensor,
118119
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
119-
shard_size = getattr(layer.qweight, "shard_size", None)
120120
shard_id = getattr(layer.qweight, "shard_id", None)
121121

122-
if shard_id and shard_size:
123-
result = []
124-
offset = 0
122+
if shard_id:
125123
# dequantize shard weights respectively
126124
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
125+
qweight = layer.qweight.unbind(0)
126+
result = []
127127
for id in shard_id:
128-
shard_weight = layer.qweight[
129-
offset:offset +
130-
shard_size[id][0], :shard_size[id][1]].contiguous()
128+
q_idx = layer.qweight.shard_id_map[id]
131129
qweight_type = layer.qweight_type.shard_weight_type[id]
132-
result.append(_fuse_mul_mat(x, shard_weight, qweight_type))
133-
offset += shard_size[id][0]
130+
result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
134131
out = torch.cat(result, axis=1)
135132
else:
136133
qweight = layer.qweight
@@ -162,3 +159,20 @@ def embedding(self, layer: torch.nn.Module,
162159
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
163160
x_flat.shape[0])
164161
return dequant.view(*x.shape, hidden_size)
162+
163+
164+
class GGUFUninitializedParameter(UninitializedParameter):
165+
cls_to_become = Parameter
166+
data_container: List[torch.Tensor]
167+
168+
def materialize_nested(self) -> Parameter:
169+
nested_data = torch.nested.nested_tensor(self.data_container,
170+
device=self.device,
171+
dtype=torch.uint8)
172+
self.data_container.clear()
173+
param = torch.Tensor._make_subclass(self.cls_to_become,
174+
nested_data,
175+
require_grad=False)
176+
for k, v in self.__dict__.items():
177+
setattr(param, k, v)
178+
return param

vllm/model_executor/models/llama.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def __init__(
512512
quant_config=quant_config,
513513
)
514514
if config.tie_word_embeddings:
515-
self.lm_head.weight = self.model.embed_tokens.weight
515+
self.lm_head = self.model.embed_tokens
516516

517517
logit_scale = getattr(config, "logit_scale", 1.0)
518518
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,

0 commit comments

Comments
 (0)