[Draft] Migrate bitsandbytes support to OOT plugin#43529
Conversation
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
|
Documentation preview: https://vllm--43529.org.readthedocs.build/en/43529/ |
There was a problem hiding this comment.
Code Review
This pull request migrates BitsAndBytes support to an out-of-tree plugin, removing hardcoded BNB logic from the core vLLM codebase. It introduces generic hooks for quantization configurations, model loaders, and weight sharding to support this plugin architecture. Feedback focuses on performance optimizations in vllm/model_executor/layers/linear.py, specifically recommending that the calculation of shard offsets and indices be guarded by a check for the presence of a shard_indexer to avoid unnecessary overhead during standard model loading.
| index = list(itertools.accumulate([0] + self.output_sizes)) | ||
| orig_offsets = { | ||
| str(i): (index[i], size) for i, size in enumerate(self.output_sizes) | ||
| } | ||
| orig_offsets["total"] = (self.output_size, 0) | ||
| shard_size, shard_offset = adjust_shard_indexes( | ||
| param, orig_offsets, str(shard_id), shard_size, shard_offset | ||
| ) |
There was a problem hiding this comment.
The calculation of index and orig_offsets is performed inside the shard loop for every parameter. This introduces unnecessary overhead during model loading for all models using merged linear layers, even when no custom quantization plugin is used. These calculations should be guarded by a check for shard_indexer to avoid performance regressions in standard model loading.
if getattr(param, "shard_indexer", None) is not None:
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
str(i): (index[i], size)
for i, size in enumerate(self.output_sizes)
}
orig_offsets["total"] = (self.output_size, 0)
shard_size, shard_offset = adjust_shard_indexes(
param, orig_offsets, str(shard_id), shard_size, shard_offset
)| index = list(itertools.accumulate([0] + self.output_sizes)) | ||
| orig_offsets = { | ||
| str(i): (index[i], size) for i, size in enumerate(self.output_sizes) | ||
| } | ||
| orig_offsets["total"] = (self.output_size, 0) | ||
| shard_size, shard_offset = adjust_shard_indexes( | ||
| param, orig_offsets, str(loaded_shard_id), shard_size, shard_offset | ||
| ) |
There was a problem hiding this comment.
The orig_offsets dictionary is constructed for every parameter in the weight_loader, which is unnecessary for standard models that do not utilize a custom shard_indexer. Adding a guard check for the indexer will prevent this overhead during model initialization.
| index = list(itertools.accumulate([0] + self.output_sizes)) | |
| orig_offsets = { | |
| str(i): (index[i], size) for i, size in enumerate(self.output_sizes) | |
| } | |
| orig_offsets["total"] = (self.output_size, 0) | |
| shard_size, shard_offset = adjust_shard_indexes( | |
| param, orig_offsets, str(loaded_shard_id), shard_size, shard_offset | |
| ) | |
| if getattr(param, "shard_indexer", None) is not None: | |
| index = list(itertools.accumulate([0] + self.output_sizes)) | |
| orig_offsets = { | |
| str(i): (index[i], size) for i, size in enumerate(self.output_sizes) | |
| } | |
| orig_offsets["total"] = (self.output_size, 0) | |
| shard_size, shard_offset = adjust_shard_indexes( | |
| param, orig_offsets, str(loaded_shard_id), shard_size, shard_offset | |
| ) |
| orig_qkv_offsets = { | ||
| "q": (0, self.total_num_heads * self.head_size), | ||
| "k": ( | ||
| self.total_num_heads * self.head_size, | ||
| self.total_num_kv_heads * self.head_size, | ||
| ), | ||
| "v": ( | ||
| (self.total_num_heads + self.total_num_kv_heads) | ||
| * self.head_size, | ||
| self.total_num_kv_heads * self.v_head_size, | ||
| ), | ||
| "total": ( | ||
| (self.total_num_heads + self.total_num_kv_heads) | ||
| * self.head_size | ||
| + self.total_num_kv_heads * self.v_head_size, | ||
| 0, | ||
| ), | ||
| } | ||
| shard_size, shard_offset = adjust_shard_indexes( | ||
| param, orig_qkv_offsets, shard_id, shard_size, shard_offset | ||
| ) |
There was a problem hiding this comment.
In QKVParallelLinear.weight_loader, the orig_qkv_offsets dictionary is recalculated inside the shard loop for every parameter. This results in redundant computations during model loading for all models. This block should be guarded to only execute when a shard_indexer is present on the parameter.
if getattr(param, "shard_indexer", None) is not None:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
"v": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
"total": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size
+ self.total_num_kv_heads * self.v_head_size,
0,
),
}
shard_size, shard_offset = adjust_shard_indexes(
param, orig_qkv_offsets, shard_id, shard_size, shard_offset
)| orig_qkv_offsets = { | ||
| "q": (0, self.num_heads * self.head_size), | ||
| "k": ( | ||
| self.num_heads * self.head_size, | ||
| self.num_kv_heads * self.head_size, | ||
| ), | ||
| "v": ( | ||
| (self.num_heads + self.num_kv_heads) * self.head_size, | ||
| self.num_kv_heads * self.v_head_size, | ||
| ), | ||
| "total": ( | ||
| (self.num_heads + self.num_kv_heads) * self.head_size | ||
| + self.num_kv_heads * self.v_head_size, | ||
| 0, | ||
| ), | ||
| } | ||
| shard_size, shard_offset = adjust_shard_indexes( | ||
| param, orig_qkv_offsets, loaded_shard_id, shard_size, shard_offset | ||
| ) |
There was a problem hiding this comment.
Constructing orig_qkv_offsets for every parameter in the QKV weight loader is wasteful for non-quantized models. A guard check for shard_indexer should be added to maintain optimal loading performance.
| orig_qkv_offsets = { | |
| "q": (0, self.num_heads * self.head_size), | |
| "k": ( | |
| self.num_heads * self.head_size, | |
| self.num_kv_heads * self.head_size, | |
| ), | |
| "v": ( | |
| (self.num_heads + self.num_kv_heads) * self.head_size, | |
| self.num_kv_heads * self.v_head_size, | |
| ), | |
| "total": ( | |
| (self.num_heads + self.num_kv_heads) * self.head_size | |
| + self.num_kv_heads * self.v_head_size, | |
| 0, | |
| ), | |
| } | |
| shard_size, shard_offset = adjust_shard_indexes( | |
| param, orig_qkv_offsets, loaded_shard_id, shard_size, shard_offset | |
| ) | |
| if getattr(param, "shard_indexer", None) is not None: | |
| orig_qkv_offsets = { | |
| "q": (0, self.num_heads * self.head_size), | |
| "k": ( | |
| self.num_heads * self.head_size, | |
| self.num_kv_heads * self.head_size, | |
| ), | |
| "v": ( | |
| (self.num_heads + self.num_kv_heads) * self.head_size, | |
| self.num_kv_heads * self.v_head_size, | |
| ), | |
| "total": ( | |
| (self.num_heads + self.num_kv_heads) * self.head_size | |
| + self.num_kv_heads * self.v_head_size, | |
| 0, | |
| ), | |
| } | |
| shard_size, shard_offset = adjust_shard_indexes( | |
| param, orig_qkv_offsets, loaded_shard_id, shard_size, shard_offset | |
| ) |
|
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
|
This pull request has merge conflicts that must be resolved before it can be |
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.