Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2330,18 +2330,22 @@ def _resolve_quant_algo(self, prefix: str) -> str | None:
is not found.
"""
# 1. Direct lookup
if prefix in self.quantized_layers:
return self.quantized_layers[prefix]["quant_algo"].upper()
for candidate in self._quantized_layer_prefix_candidates(prefix):
if candidate in self.quantized_layers:
return self.quantized_layers[candidate]["quant_algo"].upper()

# 2. Packed / fused layer lookup
proj_name = prefix.rsplit(".", 1)[-1]
if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
algos: set[str] = set()
base = prefix.rsplit(".", 1)[0]
for shard_name in self.packed_modules_mapping[proj_name]:
shard_prefix = f"{base}.{shard_name}"
if shard_prefix in self.quantized_layers:
algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
for base_candidate in self._quantized_layer_prefix_candidates(base):
for shard_name in self.packed_modules_mapping[proj_name]:
shard_prefix = f"{base_candidate}.{shard_name}"
if shard_prefix in self.quantized_layers:
algos.add(
self.quantized_layers[shard_prefix]["quant_algo"].upper()
)
if len(algos) == 1:
return algos.pop()
if len(algos) > 1:
Expand All @@ -2351,13 +2355,32 @@ def _resolve_quant_algo(self, prefix: str) -> str | None:
)

# 3. Prefix-based lookup (for RoutedExperts / parent modules)
prefix_dot = prefix + "."
for key, info in self.quantized_layers.items():
if key.startswith(prefix_dot):
return info["quant_algo"].upper()
for candidate in self._quantized_layer_prefix_candidates(prefix):
prefix_dot = candidate + "."
for key, info in self.quantized_layers.items():
if key.startswith(prefix_dot):
return info["quant_algo"].upper()

return None

@staticmethod
def _quantized_layer_prefix_candidates(prefix: str) -> tuple[str, ...]:
candidates = [prefix]

if prefix.endswith(".lm_head"):
candidates.append("lm_head")

if prefix.startswith("language_model.model."):
candidates.append(
"model.language_model." + prefix[len("language_model.model.") :]
)
elif prefix.startswith("model.language_model."):
candidates.append(
"language_model.model." + prefix[len("model.language_model.") :]
)

return tuple(dict.fromkeys(candidates))
Comment thread
meenchen marked this conversation as resolved.

def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
Expand Down
Loading