tests/quantization/test_compressed_tensors.py (+21, -0)

@@ -32,6 +32,9 @@
     CompressedTensorsW8A16Fp8,
     CompressedTensorsWNA16,
 )
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    find_matched_target,
+)
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
     cutlass_fp4_supported,
@@ -635,6 +638,24 @@ def test_get_quant_method_returns_none_for_unmatched_parallel_lm_head():
     )


+def test_find_matched_target_returns_none_on_no_match():
+    result = find_matched_target(
+        layer_name="model.layers.0.self_attn.qkv_proj",
+        module=Mock(spec=torch.nn.Linear),
+        targets=["no_match_target"],
+    )
+    assert result is None
+
+
+def test_get_scheme_dict_returns_none_on_no_match():
+    config = _make_ct_config(target="matched_layer")
+    result = config.get_scheme_dict(
+        layer=Mock(spec=torch.nn.Linear),
+        layer_name="model.layers.0.unmatched_layer",
+    )
+    assert result is None
+
+
 @pytest.mark.skipif(
     not current_platform.is_cuda() or not current_platform.has_device_capability(75),
     reason="MXFP8 requires Turing (sm_75+) or newer.",
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from contextlib import suppress
 from functools import partial
 from typing import TYPE_CHECKING, Any, Literal, cast

@@ -747,13 +746,13 @@ def get_scheme(
             self.sparsity_ignore_list
         )
         sparsity_scheme: SparsityCompressionConfig | None = None
-        with suppress(ValueError):
-            matched_target = find_matched_target(
-                layer_name=layer_name,
-                module=layer,
-                targets=sparsity_targets,
-                fused_mapping=self.packed_modules_mapping,
-            )
+        matched_target = find_matched_target(
+            layer_name=layer_name,
+            module=layer,
+            targets=sparsity_targets,
+            fused_mapping=self.packed_modules_mapping,
+        )
+        if matched_target is not None:
             sparsity_scheme = self.sparsity_scheme_map[matched_target]

         if self.supports_cutlass_24(
@@ -821,10 +820,11 @@ def get_scheme_dict(
             targets=self.target_scheme_map.keys(),
             fused_mapping=self.packed_modules_mapping,
         )
-        scheme_dict = self.target_scheme_map[matched_target]
-        if scheme_dict.get("format") is None:
-            scheme_dict["format"] = self.quant_format
-        return scheme_dict
+        if matched_target is not None:
+            scheme_dict = self.target_scheme_map[matched_target]
+            if scheme_dict.get("format") is None:
+                scheme_dict["format"] = self.quant_format
+            return scheme_dict

         return None

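Aside: a self-contained toy contrasting the old suppress(ValueError) flow with the new explicit check (the lookup_* helpers below are hypothetical stand-ins, not vLLM's API):

from contextlib import suppress


def lookup_raising(name: str, table: dict[str, str]) -> str:
    # Old-style helper (hypothetical): raises on a miss.
    if name not in table:
        raise ValueError(f"no matching target for {name}")
    return table[name]


def lookup_optional(name: str, table: dict[str, str]) -> str | None:
    # New-style helper (hypothetical): returns None on a miss.
    return table.get(name)


table = {"lm_head": "Linear"}

# Old pattern: the miss is invisible at the call site, and the suppress
# block would also swallow unrelated ValueErrors raised inside it.
scheme_old = None
with suppress(ValueError):
    matched = lookup_raising("qkv_proj", table)
    scheme_old = table[matched]

# New pattern: the miss is an explicit value checked at the call site.
scheme_new = None
matched = lookup_optional("qkv_proj", table)
if matched is not None:
    scheme_new = table[matched]

assert scheme_old is None and scheme_new is None

Making the no-match case a value rather than an exception is the design point here: suppress(ValueError) hides both the expected miss and any unexpected ValueError raised deeper in the block.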
vllm/model_executor/layers/quantization/compressed_tensors/utils.py

@@ -115,7 +115,7 @@ def find_matched_target(
     module: Module,
     targets: Iterable[str],
     fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
-) -> str:
+) -> str | None:
     """
     Helper function to look up which "target" in the compressed-tensors
     config that a layer corresponds to.
@@ -150,12 +150,6 @@
         or _match_fused_layer(layer_name, targets, fused_mapping)
     )

-    if matched_target is None:
-        raise ValueError(
-            f"Unable to find matching target for {layer_name} in the "
-            "compressed-tensors config."
-        )
-
     return matched_target
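Aside: with the return type now str | None, type checkers force callers to narrow the result before using it as a dict key, mirroring the guards added in get_scheme and get_scheme_dict. A minimal sketch with hypothetical names:

def pick_scheme(
    matched_target: str | None,
    target_scheme_map: dict[str, dict],
) -> dict | None:
    # A type checker rejects target_scheme_map[matched_target] until the
    # None case is handled, which is exactly the new caller pattern.
    if matched_target is None:
        return None
    return target_scheme_map[matched_target]


assert pick_scheme(None, {"Linear": {"format": "dense"}}) is None
assert pick_scheme("Linear", {"Linear": {"format": "dense"}}) == {"format": "dense"}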