diff --git a/tests/quantization/test_fp8_scale_parameter.py b/tests/quantization/test_fp8_scale_parameter.py new file mode 100644 index 000000000000..c95cdbb98be2 --- /dev/null +++ b/tests/quantization/test_fp8_scale_parameter.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.model_executor.parameter as parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + create_fp8_scale_parameter, +) +from vllm.model_executor.parameter import BlockQuantScaleParameter + + +@pytest.mark.skipif( + not hasattr(torch, "float8_e8m0fnu"), + reason="torch does not expose float8_e8m0fnu", +) +def test_create_fp8_scale_parameter_initializes_e8m0(monkeypatch): + monkeypatch.setattr(parameter, "get_tensor_model_parallel_rank", lambda: 0) + monkeypatch.setattr(parameter, "get_tensor_model_parallel_world_size", lambda: 1) + + scale = create_fp8_scale_parameter( + BlockQuantScaleParameter, + output_partition_sizes=[128], + input_size_per_partition=128, + block_size=[128, 128], + weight_loader=None, + scale_dtype=torch.float8_e8m0fnu, + ) + + assert scale.dtype == torch.float8_e8m0fnu + raw_scale = scale.data.view(torch.uint8) + assert torch.equal(raw_scale, torch.zeros_like(raw_scale)) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index c6473c406c92..6826c3eba891 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1231,6 +1231,8 @@ def create_fp8_scale_parameter( if dtype == torch.float32: scale[:] = torch.finfo(torch.float32).min + elif dtype == getattr(torch, "float8_e8m0fnu", None): + scale[:] = 0 set_weight_attrs(scale, {"scale_type": "weight_scale"}) return scale