Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions vllm/model_executor/model_loader/bitsandbytes_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,6 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
raise ValueError("Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")

torch.cuda.empty_cache()

param_dict = dict(model.named_parameters())
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
# TODO: Change this lazy import to normal import
Expand Down Expand Up @@ -545,6 +543,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
# Dequantize double quantized values during weight loading.
dequantize_dq(quant_states)
set_weight_attrs(param, {"bnb_quant_state": quant_states})

pack_ratio = getattr(param, "pack_factor", -1)
Expand All @@ -565,6 +565,28 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})

torch.cuda.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
    """Download/prepare the model weights described by *model_config*.

    Delegates to ``self._prepare_weights`` with the model identifier and
    optional revision; does not load the weights into a module.
    """
    self._prepare_weights(model_config.model, model_config.revision)


def dequantize_dq(quant_states: dict) -> None:
    """
    When BNB employs Double Quantization, we perform the dequantization of
    these constants during weight loading rather than at inference time,
    thereby avoiding this computational overhead during inference. This comes
    at the cost of increased memory usage.

    Mutates each quant state in place: on return, ``absmax`` holds the fully
    dequantized float32 scaling constants and the nested (double-quantization)
    metadata is cleared. States that are not nested are left untouched.
    """
    for quant_state in quant_states.values():
        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
        if quant_state.nested:
            # Lazy import: bitsandbytes is only needed when a state is
            # actually double-quantized.
            from bitsandbytes.functional import dequantize_blockwise
            absmax = dequantize_blockwise(quant_state.absmax,
                                          quant_state.state2)
            # The second-level quantization stores absmax relative to a
            # shared offset; apply it after blockwise dequantization.
            absmax += quant_state.offset
            # Downstream kernels expect float32 scaling constants.
            if absmax.dtype != torch.float32:
                absmax = absmax.float()
            quant_state.absmax = absmax
            # Mark as no longer nested and drop the now-unneeded metadata.
            quant_state.nested = False
            quant_state.offset = None
            quant_state.state2 = None