66 changes: 64 additions & 2 deletions tests/quantization/test_bitsandbytes.py
@@ -8,9 +8,11 @@

import pytest
import torch
from transformers import BitsAndBytesConfig

from tests.quantization.utils import is_quant_method_supported

from ..models.utils import check_embeddings_close
from ..utils import compare_two_settings, create_new_process_for_each_test

models_4bit_to_test = [
@@ -19,6 +21,10 @@
     "quantize inflight model with both HF and Mistral format weights")
]

models_4bit_to_embedding_test = [
("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
]

models_pre_qaunt_4bit_to_test = [
    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
     'read pre-quantized 4-bit FP4 model'),
@@ -31,6 +37,12 @@
    ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]

models_pre_quant_8bit_to_test = [
    ('meta-llama/Llama-Guard-3-8B-INT8',
     'read pre-quantized llama 8-bit model'),
    ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@@ -39,7 +51,8 @@
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name, description) -> None:

    hf_model_kwargs = {"load_in_4bit": True}
    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
        load_in_4bit=True))
Collaborator (Author) commented:

This change avoids the warning below:

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead
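
For reference, the non-deprecated pattern in transformers looks like the following minimal sketch (the model name is only a placeholder):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Pass a BitsAndBytesConfig instead of the deprecated load_in_4bit flag.
quant_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model name
    quantization_config=quant_config,
)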

    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
                             model_name, False, hf_model_kwargs)

@@ -77,7 +90,8 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                model_name, description) -> None:

    hf_model_kwargs = {"load_in_4bit": True}
    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
        load_in_4bit=True))
    validate_generated_texts(hf_runner,
                             vllm_runner,
                             example_prompts[:1],
@@ -113,6 +127,54 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
    compare_two_settings(model_name, common_args, pp_args)


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
                         models_4bit_to_embedding_test)
@pytest.mark.parametrize("dtype", ["half"])
@create_new_process_for_each_test()
def test_4bit_bnb_embedding_model(
    model_name,
    description,
    hf_runner,
    vllm_runner,
    example_prompts,
    dtype: str,
) -> None:

    # Each prompt in example_prompts ends with "\n", for example:
    # "Write a short story about a robot that dreams for the first time.\n"
    # sentence_transformers strips the input texts, see:
    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
    # This makes the input_ids differ between hf_model and vllm_model,
    # so we strip the input texts here to keep the test from failing.
    example_prompts = [str(s).strip() for s in example_prompts]

    # Inflight 4-bit quantization
    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
        load_in_4bit=True))
    with hf_runner(
            model_name,
            dtype=dtype,
            model_kwargs=hf_model_kwargs,
            is_sentence_transformer=True,
    ) as hf_model:
        hf_outputs = hf_model.encode(example_prompts)

    with vllm_runner(model_name,
                     task="embed",
                     dtype=dtype,
                     quantization="bitsandbytes") as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)
    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=5e-2,
    )
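
For intuition, a closeness check of this kind compares the HF and vLLM embeddings pairwise; a minimal sketch of such a check (an illustration, not the repo's actual check_embeddings_close) could be:

import torch
import torch.nn.functional as F

def embeddings_close(emb_a, emb_b, tol=5e-2) -> bool:
    # Treat two embedding lists as close when every paired vector has
    # cosine distance (1 - cosine similarity) no greater than `tol`.
    for a, b in zip(emb_a, emb_b):
        sim = F.cosine_similarity(torch.tensor(a), torch.tensor(b), dim=0)
        if 1.0 - sim.item() > tol:
            return False
    return True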


def log_generated_texts(prompts, outputs, runner_name):
    logged_texts = []
    for i, (_, generated_text) in enumerate(outputs):
16 changes: 15 additions & 1 deletion vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -35,6 +35,7 @@
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

@@ -133,6 +134,16 @@ def _prepare_weights(self, model_name_or_path: str,
        return hf_weights_files, use_safetensors

    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        def _maybe_pool_model(module_name: str):
            # For pooling models, we need to prepend `model.` to the
            # weight name if possible.
            if (self.is_pool_model
                    and self.target_modules[0].startswith("model.")
                    and not module_name.startswith("model.")):
                return "model." + module_name

            return module_name

        if use_safetensors:
            iterator = safetensors_weights_iterator(
                hf_weights_files,
@@ -148,6 +159,9 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
            # mapping weight names from transformers to vllm while preserving
            # original names.
            mapped_name = self.weight_mapper(org_name)
            mapped_name = _maybe_pool_model(mapped_name)

            yield org_name, mapped_name, param

    def _get_quantized_weights_iterator(
@@ -405,7 +419,7 @@ def _load_weights(self, model_config: ModelConfig,
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet. No 'packed_modules_mapping' found.")

        self.is_pool_model = is_pooling_model(model)
        self.modules_mapping = ParamMapping(
            copy.deepcopy(model.packed_modules_mapping))

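
To make the new prefixing rule concrete, here is an illustrative standalone sketch of the check that _maybe_pool_model performs (the weight name and target_modules value below are hypothetical):

def maybe_pool_prefix(module_name: str, target_modules: list[str],
                      is_pool_model: bool) -> str:
    # Pooling models expose their weights under a `model.` prefix, so
    # prepend it when the checkpoint's weight name lacks it.
    if (is_pool_model and target_modules[0].startswith("model.")
            and not module_name.startswith("model.")):
        return "model." + module_name
    return module_name

# Hypothetical weight name from an embedding checkpoint:
print(maybe_pool_prefix("layers.0.mlp.down_proj.weight",
                        ["model.layers.0.mlp.down_proj"],
                        is_pool_model=True))
# -> "model.layers.0.mlp.down_proj.weight"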