diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a68235016767..73f41c3f317e 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -215,7 +215,7 @@ def rms_norm_dynamic_per_token_quant(
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                    zeros: torch.Tensor, split_k_iters: int, thx: int,
                    thy: int) -> torch.Tensor:
-    if envs.VLLM_USE_TRITON_AWQ:
+    if envs.VLLM_USE_TRITON_AWQ or qweight.dtype != torch.float16:
         from vllm.model_executor.layers.quantization.awq_triton import (
             awq_dequantize_triton)
         return awq_dequantize_triton(qweight, scales, zeros)
diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index a41140ec8378..5a9b4b03b5b2 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -18,6 +18,8 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.awq_marlin import (
+    AWQMarlinLinearMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -227,8 +229,9 @@ def is_layer_fp8(layer: LinearBase) -> bool:
                  and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
 
         def quantization_scheme_supported(layer: LinearBase) -> bool:
-            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
-                is_layer_fp8(layer)
+            return isinstance(layer.quant_method,
+                              (UnquantizedLinearMethod,
+                               AWQMarlinLinearMethod)) or is_layer_fp8(layer)
 
         # TODO(lucas) This is very gross, we need a more wide scale refactor of
         # all the FP8 code with a more standard way of
@@ -289,6 +292,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
 
                 return scaled_dequantize(weight, scales,
                                          weight_scale_group_shape)
+            elif isinstance(layer.quant_method, AWQMarlinLinearMethod):
+                return layer.quant_method.decompress_weights(layer).T
             else:
                 return layer.weight
 
@@ -296,12 +301,21 @@
             quantization_scheme_supported(self.q_proj) and\
             quantization_scheme_supported(self.o_proj)):
             raise NotImplementedError(
-                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                "Only FP8, AWQ, and Unquantized are supported for MLA"
                 ", please run with VLLM_MLA_DISABLE=1")
 
-        weight_dtype = self.kv_b_proj.weight.dtype
-        assert self.o_proj.weight.dtype == weight_dtype
-        assert self.q_proj.weight.dtype == weight_dtype
+        def get_layer_dtype(layer):
+            if hasattr(layer, "weight"):
+                return layer.weight.dtype
+            elif hasattr(layer, "qweight"):
+                return layer.qweight.dtype
+            else:
+                raise AttributeError(
+                    f"Layer '{layer}' has neither weight nor qweight")
+
+        weight_dtype = get_layer_dtype(self.kv_b_proj)
+        assert get_layer_dtype(self.o_proj) == weight_dtype
+        assert get_layer_dtype(self.q_proj) == weight_dtype
 
         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (
diff --git a/vllm/config.py b/vllm/config.py
index 1740871e7c10..29a5b2014953 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -990,7 +990,7 @@ def use_mla(self) -> bool:
             return False
 
         if self.quantization is not None and self.quantization not in [\
-                "fp8", "compressed-tensors"]:
+                "fp8", "compressed-tensors", "awq_marlin", "moe_wna16"]:
             logger.warning(
                 "MLA is not supported with %s quantization. "
                 "Disabling MLA.", self.quantization)
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index a43b2e597c1e..d687765313ab 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -242,6 +242,16 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition
         layer.num_groups = num_groups
 
+    def decompress_weights(self, layer: torch.nn.Module) -> torch.Tensor:
+        """
+        Decompress to recover the original unquantized weight.
+        NOTE: this is only to be used before process_weights_after_loading
+        """
+        # We can use AWQ's dequant since the unprocessed weights
+        # are in AWQ format
+        return ops.awq_dequantize(layer.qweight, layer.scales, layer.qzeros, 0,
+                                  0, 0)
+
     # TODO: Update this docs
     # Checkpoints are serialized in AutoAWQ format, which is different from the
     # marlin format. This function is called after the weights are loaded.
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 2a2c2523b725..47aa35ed25b0 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -153,6 +153,30 @@ def _initialize_model(
     return model_class(**kwargs)
 
 
+def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
+                                   target_device: torch.device) -> None:
+    # Currently only used by MLA.
+    # NOTE: This intentionally happens before other modules so we can easily
+    # decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+            hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
+    for _, module in model.named_modules():
+        quant_method = getattr(module, "quant_method", None)
+        if isinstance(quant_method, QuantizeMethodBase):
+            # When quant methods need to process weights after loading
+            # (for repacking, quantizing, etc), they expect parameters
+            # to be on the global target device. This scope is for the
+            # case where cpu offloading is used, where we will move the
+            # parameters onto device for processing and back off after.
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
+
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
 
@@ -376,7 +400,6 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -394,23 +417,8 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if isinstance(quant_method, QuantizeMethodBase):
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    # TODO(lucas): see if there is a way to unify the signatures
-                    # of process_weights_after_loading
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
+
         return model.eval()
 
@@ -429,29 +437,15 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
                 # NOTE(woosuk): For accurate performance evaluation, we assign
                 # random values to the weights.
                 initialize_dummy_weights(model)
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(
-                            module, torch.device(device_config.device)):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()
 
@@ -632,6 +626,7 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         from safetensors.torch import safe_open
         from vllm.distributed import get_tensor_model_parallel_rank
 
@@ -640,18 +635,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                                                  model_config.revision)
 
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
-                    if isinstance(module, Attention) and \
-                        hasattr(module, "process_weights_after_loading"):
-                        # When attention modules need to process weights after
-                        # currently only used by MLA
-                        module.process_weights_after_loading(
-                            model_config.dtype)
+                _process_weights_after_loading(model, model_config,
+                                               target_device)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
@@ -1401,16 +1388,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                 self._get_weights_iterator(model_weights,
                                            model_config.revision))
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
 
         return model.eval()
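The behavioural point of the new `_process_weights_after_loading` helper is its ordering: MLA attention modules run first, while layers still hold their original AWQ-format `qweight`/`scales`/`qzeros` tensors (which `decompress_weights` dequantizes), and only afterwards do the quant methods repack those tensors into Marlin format. The following is a minimal standalone sketch of that ordering; `ToyQuantMethod` and `ToyAttention` are hypothetical stand-ins, not vLLM APIs.

```python
# Standalone sketch (not vLLM code) of the ordering enforced by
# _process_weights_after_loading: attention (MLA) modules are handled first,
# while the packed AWQ-format weights still exist, and only then do the
# quant methods repack them (after which the packed qweight is gone).
import torch
import torch.nn as nn


class ToyQuantMethod:
    """Stand-in for a quant method such as AWQMarlinLinearMethod."""

    def process_weights_after_loading(self, module: nn.Module) -> None:
        # Pretend to repack to Marlin format: the AWQ tensor is consumed.
        module.repacked = True
        del module.qweight


class ToyAttention(nn.Module):
    """Stand-in for an MLA attention module."""

    def __init__(self, proj: nn.Module):
        super().__init__()
        self.proj = proj

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        # MLA needs the unprocessed (AWQ-format) weight to dequantize it.
        assert hasattr(self.proj, "qweight"), "must run before repacking"


def process_weights_after_loading_sketch(model: nn.Module) -> None:
    # Pass 1: attention (MLA) modules, before any repacking has happened.
    for module in model.modules():
        if isinstance(module, ToyAttention):
            module.process_weights_after_loading(torch.float16)
    # Pass 2: quant methods repack / quantize in place.
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if isinstance(quant_method, ToyQuantMethod):
            quant_method.process_weights_after_loading(module)


proj = nn.Linear(8, 8)
proj.qweight = torch.zeros(8, 1, dtype=torch.int32)  # AWQ-style packed weight
proj.quant_method = ToyQuantMethod()
model = ToyAttention(proj)
process_weights_after_loading_sketch(model)  # MLA sees qweight, then repack
assert not hasattr(proj, "qweight") and proj.repacked
```

If the two passes were swapped, the quant method would already have repacked (and dropped) the AWQ-format tensors by the time MLA tried to dequantize them, which is what the NOTE comment in the helper guards against.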