Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions tests/quantization/test_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

import torch

from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
CompressedTensorsW8A8StaticTensor)


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed"
with vllm_runner(model_path, quantization="sparseml",
enforce_eager=True) as llm:
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]

Expand All @@ -40,11 +40,17 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
assert qkv_proj.input_scale.dtype is torch.float32


def test_compressed_tensors_no_enforce_eager(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
with vllm_runner(model_path) as llm:
sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output


def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
with vllm_runner(model_path,
quantization="sparseml",
enforce_eager=True,
model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
with vllm_runner(model_path, enforce_eager=True,
dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
Expand Down
8 changes: 2 additions & 6 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,8 @@ def _verify_embedding_mode(self) -> None:
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
# SparseML uses a "compression_config" with a "quantization_config".
compression_cfg = getattr(self.hf_config, "compression_config",
None)
if compression_cfg is not None:
quant_cfg = compression_cfg.get("quantization_config", None)

# compress-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None)
return quant_cfg

def _verify_quantization(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"sparseml": CompressedTensorsConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
}

Expand Down
9 changes: 3 additions & 6 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,9 @@ def get_quant_config(model_config: ModelConfig,
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
if hf_quant_config is None:
compression_config = getattr(model_config.hf_config,
"compression_config", None)
if compression_config is not None:
hf_quant_config = compression_config.get("quantization_config",
None)

# compressed-tensors uses a compressions_config
hf_quant_config = getattr(model_config.hf_config, "compression_config",
None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
Expand Down