Skip to content

Commit

Permalink
support bitsandbytes 8-bit and FP4 quantized models (vllm-project#7445)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenqianfzh authored Aug 29, 2024
1 parent 257afc3 commit 4664cea
Show file tree
Hide file tree
Showing 6 changed files with 437 additions and 191 deletions.
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,14 @@ class HfRunner:

def wrap_device(self, input: _T) -> _T:
if not is_cpu():
# Check if the input is already on the GPU
if hasattr(input, 'device') and input.device.type == "cuda":
return input # Already on GPU, no need to move
return input.to("cuda")
else:
# Check if the input is already on the CPU
if hasattr(input, 'device') and input.device.type == "cpu":
return input # Already on CPU, no need to move
return input.to("cpu")

def __init__(
Expand Down
166 changes: 98 additions & 68 deletions tests/quantization/test_bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,85 +2,115 @@
Run `pytest tests/quantization/test_bitsandbytes.py`.
'''

import gc

import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams

models_to_test = [
models_4bit_to_test = [
('huggyllama/llama-7b', 'quantize model inflight'),
('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
]

models_pre_qaunt_4bit_to_test = [
('lllyasviel/omost-llama-3-8b-4bits',
'read pre-quantized 4-bit NF4 model'),
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
'read pre-quantized 4-bit FP4 model'),
]

models_pre_quant_8bit_to_test = [
('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
]


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:

hf_model_kwargs = {"load_in_4bit": True}
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name, hf_model_kwargs)


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
models_pre_qaunt_4bit_to_test)
def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:

validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name)


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_to_test)
def test_load_bnb_model(vllm_runner, model_name, description) -> None:
@pytest.mark.parametrize("model_name, description",
models_pre_quant_8bit_to_test)
def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:

validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
model_name)


def log_generated_texts(prompts, outputs, runner_name):
logged_texts = []
for i, (_, generated_text) in enumerate(outputs):
log_entry = {
"prompt": prompts[i],
"runner_name": runner_name,
"generated_text": generated_text,
}
logged_texts.append(log_entry)
return logged_texts


def validate_generated_texts(hf_runner,
vllm_runner,
prompts,
model_name,
hf_model_kwargs=None):

if hf_model_kwargs is None:
hf_model_kwargs = {}

# Run with HF runner
with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
hf_outputs = llm.generate_greedy(prompts, 8)
hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()

#Run with vLLM runner
with vllm_runner(model_name,
quantization='bitsandbytes',
load_format='bitsandbytes',
enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501

# check the weights in MLP & SelfAttention are quantized to torch.uint8
qweight = model.model.layers[0].mlp.gate_up_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')

qweight = model.model.layers[0].mlp.down_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')

qweight = model.model.layers[0].self_attn.o_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')

qweight = model.model.layers[0].self_attn.qkv_proj.qweight
assert qweight.dtype == torch.uint8, (
f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')

# some weights should not be quantized
weight = model.lm_head.weight
assert weight.dtype != torch.uint8, (
'lm_head weight dtype should not be torch.uint8')

weight = model.model.embed_tokens.weight
assert weight.dtype != torch.uint8, (
'embed_tokens weight dtype should not be torch.uint8')

weight = model.model.layers[0].input_layernorm.weight
assert weight.dtype != torch.uint8, (
'input_layernorm weight dtype should not be torch.uint8')

weight = model.model.layers[0].post_attention_layernorm.weight
assert weight.dtype != torch.uint8, (
'input_layernorm weight dtype should not be torch.uint8')

# check the output of the model is expected
sampling_params = SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=8)

prompts = ['That which does not kill us', 'To be or not to be,']
expected_outputs = [
'That which does not kill us makes us stronger.',
'To be or not to be, that is the question.'
]
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(outputs) == len(prompts)

for index in range(len(outputs)):
# compare the first line of the output
actual_output = outputs[index][1][0].split('\n', 1)[0]
expected_output = expected_outputs[index].split('\n', 1)[0]

assert len(actual_output) >= len(expected_output), (
f'Actual {actual_output} should be larger than or equal to '
f'expected {expected_output}')
actual_output = actual_output[:len(expected_output)]

assert actual_output == expected_output, (
f'Expected: {expected_output}, but got: {actual_output}')
enforce_eager=True,
gpu_memory_utilization=0.8) as llm:
vllm_outputs = llm.generate_greedy(prompts, 8)
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()

# Compare the generated strings
for hf_log, vllm_log in zip(hf_logs, vllm_logs):
hf_str = hf_log["generated_text"]
vllm_str = vllm_log["generated_text"]
prompt = hf_log["prompt"]
assert hf_str == vllm_str, (f"Model: {model_name}"
f"Mismatch between HF and vLLM outputs:\n"
f"Prompt: {prompt}\n"
f"HF Output: '{hf_str}'\n"
f"vLLM Output: '{vllm_str}'")
2 changes: 2 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ def verify_with_parallel_config(
raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.")

# Remove the constraint after the bitsandbytes issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
if self.quantization == "bitsandbytes" and self.enforce_eager is False:
logger.warning("CUDA graph is not supported on BitAndBytes yet, "
"fallback to the eager mode.")
Expand Down
18 changes: 10 additions & 8 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_bitsandbytes_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
def adjust_bitsandbytes_4bit_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

total, _ = qkv_offsets["total"]
Expand Down Expand Up @@ -505,8 +505,9 @@ def weight_loader(self,
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
Expand Down Expand Up @@ -858,8 +859,9 @@ def weight_loader(self,
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (self.num_heads * self.head_size,
Expand All @@ -871,7 +873,7 @@ def weight_loader(self,
((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
0)
}
shard_size, shard_offset = adjust_bitsandbytes_shard(
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id)

if is_gguf_weight:
Expand Down
Loading

0 comments on commit 4664cea

Please sign in to comment.