diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 5353096ff7b7..c4036e2f18b0 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -2330,18 +2330,22 @@ def _resolve_quant_algo(self, prefix: str) -> str | None: is not found. """ # 1. Direct lookup - if prefix in self.quantized_layers: - return self.quantized_layers[prefix]["quant_algo"].upper() + for candidate in self._quantized_layer_prefix_candidates(prefix): + if candidate in self.quantized_layers: + return self.quantized_layers[candidate]["quant_algo"].upper() # 2. Packed / fused layer lookup proj_name = prefix.rsplit(".", 1)[-1] if self.packed_modules_mapping and proj_name in self.packed_modules_mapping: algos: set[str] = set() base = prefix.rsplit(".", 1)[0] - for shard_name in self.packed_modules_mapping[proj_name]: - shard_prefix = f"{base}.{shard_name}" - if shard_prefix in self.quantized_layers: - algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper()) + for base_candidate in self._quantized_layer_prefix_candidates(base): + for shard_name in self.packed_modules_mapping[proj_name]: + shard_prefix = f"{base_candidate}.{shard_name}" + if shard_prefix in self.quantized_layers: + algos.add( + self.quantized_layers[shard_prefix]["quant_algo"].upper() + ) if len(algos) == 1: return algos.pop() if len(algos) > 1: @@ -2351,13 +2355,32 @@ def _resolve_quant_algo(self, prefix: str) -> str | None: ) # 3. Prefix-based lookup (for RoutedExperts / parent modules) - prefix_dot = prefix + "." - for key, info in self.quantized_layers.items(): - if key.startswith(prefix_dot): - return info["quant_algo"].upper() + for candidate in self._quantized_layer_prefix_candidates(prefix): + prefix_dot = candidate + "." + for key, info in self.quantized_layers.items(): + if key.startswith(prefix_dot): + return info["quant_algo"].upper() return None + @staticmethod + def _quantized_layer_prefix_candidates(prefix: str) -> tuple[str, ...]: + candidates = [prefix] + + if prefix.endswith(".lm_head"): + candidates.append("lm_head") + + if prefix.startswith("language_model.model."): + candidates.append( + "model.language_model." + prefix[len("language_model.model.") :] + ) + elif prefix.startswith("model.language_model."): + candidates.append( + "language_model.model." + prefix[len("model.language_model.") :] + ) + + return tuple(dict.fromkeys(candidates)) + def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None":