2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
@@ -215,7 +215,7 @@ def rms_norm_dynamic_per_token_quant(
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                    zeros: torch.Tensor, split_k_iters: int, thx: int,
                    thy: int) -> torch.Tensor:
-    if envs.VLLM_USE_TRITON_AWQ:
+    if envs.VLLM_USE_TRITON_AWQ or qweight.dtype != torch.float16:
         from vllm.model_executor.layers.quantization.awq_triton import (
             awq_dequantize_triton)
         return awq_dequantize_triton(qweight, scales, zeros)
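Note for context: the guard now falls back to the Triton dequant kernel whenever the packed weights are not float16, not only when VLLM_USE_TRITON_AWQ is set, which is presumably what lets the new decompress_weights hook (below) work without extra configuration. A minimal, self-contained sketch of the dispatch; the helper name is hypothetical, only the condition is from the diff:

    import torch

    def awq_dequant_backend(qweight: torch.Tensor, use_triton_env: bool) -> str:
        # Mirrors the condition above: any non-fp16 qweight routes to Triton.
        if use_triton_env or qweight.dtype != torch.float16:
            return "triton"
        return "cuda"

    packed = torch.zeros(128, 16, dtype=torch.int32)  # AWQ-style packed weights
    assert awq_dequant_backend(packed, use_triton_env=False) == "triton"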
26 changes: 20 additions & 6 deletions vllm/attention/backends/mla/utils.py
@@ -18,6 +18,8 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.awq_marlin import (
+    AWQMarlinLinearMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -227,8 +229,9 @@ def is_layer_fp8(layer: LinearBase) -> bool:
                  and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
 
         def quantization_scheme_supported(layer: LinearBase) -> bool:
-            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
-                is_layer_fp8(layer)
+            return isinstance(layer.quant_method,
+                              (UnquantizedLinearMethod,
+                               AWQMarlinLinearMethod)) or is_layer_fp8(layer)
 
         # TODO(lucas) This is very gross, we need a more wide scale refactor of
         # all the FP8 code with a more standard way of
@@ -289,19 +292,30 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
 
                 return scaled_dequantize(weight, scales,
                                          weight_scale_group_shape)
+            elif isinstance(layer.quant_method, AWQMarlinLinearMethod):
+                return layer.quant_method.decompress_weights(layer).T
             else:
                 return layer.weight
 
         if not (quantization_scheme_supported(self.kv_b_proj) and\
                 quantization_scheme_supported(self.q_proj) and\
                 quantization_scheme_supported(self.o_proj)):
             raise NotImplementedError(
-                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                "Only FP8, AWQ, and Unquantized are supported for MLA"
                 ", please run with VLLM_MLA_DISABLE=1")
 
-        weight_dtype = self.kv_b_proj.weight.dtype
-        assert self.o_proj.weight.dtype == weight_dtype
-        assert self.q_proj.weight.dtype == weight_dtype
+        def get_layer_dtype(layer):
+            if hasattr(layer, "weight"):
+                return layer.weight.dtype
+            elif hasattr(layer, "qweight"):
+                return layer.qweight.dtype
+            else:
+                raise AttributeError(
+                    f"Layer '{layer}' has neither weight nor qweight")
+
+        weight_dtype = get_layer_dtype(self.kv_b_proj)
+        assert get_layer_dtype(self.o_proj) == weight_dtype
+        assert get_layer_dtype(self.q_proj) == weight_dtype
 
         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (
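Quantized layers such as AWQ store their packed weights under qweight (an int32 tensor) rather than weight, so the dtype probe has to check both attribute names; all three projections of an AWQ model report int32, so the consistency asserts still hold. A self-contained sketch of the behavior, with a hypothetical toy layer standing in for a real AWQ linear:

    import torch

    class ToyAWQLayer:
        # Stand-in for an AWQ linear layer: no `weight`, only packed `qweight`.
        def __init__(self):
            self.qweight = torch.zeros(128, 16, dtype=torch.int32)

    def get_layer_dtype(layer):
        if hasattr(layer, "weight"):
            return layer.weight.dtype
        elif hasattr(layer, "qweight"):
            return layer.qweight.dtype
        raise AttributeError(f"Layer '{layer}' has neither weight nor qweight")

    print(get_layer_dtype(ToyAWQLayer()))  # torch.int32 for packed AWQ layers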
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -990,7 +990,7 @@ def use_mla(self) -> bool:
             return False
 
         if self.quantization is not None and self.quantization not in [\
-                "fp8", "compressed-tensors"]:
+                "fp8", "compressed-tensors", "awq_marlin", "moe_wna16"]:
             logger.warning(
                 "MLA is not supported with %s quantization. "
                 "Disabling MLA.", self.quantization)
10 changes: 10 additions & 0 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -242,6 +242,16 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition
         layer.num_groups = num_groups
 
+    def decompress_weights(self, layer: torch.nn.Module) -> torch.Tensor:
+        """
+        Decompress to recover the original unquantized weight.
+        NOTE: this is only to be used before process_weights_after_loading
+        """
+        # We can use AWQ's dequant since the unprocessed weights
+        # are in AWQ format
+        return ops.awq_dequantize(layer.qweight, layer.scales, layer.qzeros, 0,
+                                  0, 0)
+
     # TODO: Update this docs
     # Checkpoints are serialized in AutoAWQ format, which is different from the
     # marlin format. This function is called after the weights are loaded.
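For shape intuition: AutoAWQ packs eight 4-bit values per int32, so qweight is (in_features, out_features // 8) and the dequantized result is (in_features, out_features), which is why the mla/utils.py hunk above transposes with .T to recover the usual (out_features, in_features) layout. A hedged, shapes-only sketch; the toy class is hypothetical and no kernels run here:

    import torch

    class ToyAWQWeights:
        # AWQ 4-bit layout for a K x N linear (group_size = 128), shapes only.
        def __init__(self, k: int, n: int, group_size: int = 128):
            self.qweight = torch.zeros(k, n // 8, dtype=torch.int32)
            self.qzeros = torch.zeros(k // group_size, n // 8, dtype=torch.int32)
            self.scales = torch.ones(k // group_size, n, dtype=torch.float16)

    layer = ToyAWQWeights(k=512, n=256)
    # decompress_weights(layer) would return a (512, 256) = (K, N) tensor;
    # callers transpose to (N, K) to match nn.Linear's (out, in) convention.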
90 changes: 34 additions & 56 deletions vllm/model_executor/model_loader/loader.py
@@ -153,6 +153,30 @@ def _initialize_model(
     return model_class(**kwargs)
 
 
+def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
+                                   target_device: torch.device) -> None:
+    # Currently only used by MLA.
+    # NOTE: This intentionally happens before other modules so we can easily
+    # decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+            hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
+    for _, module in model.named_modules():
+        quant_method = getattr(module, "quant_method", None)
+        if isinstance(quant_method, QuantizeMethodBase):
+            # When quant methods need to process weights after loading
+            # (for repacking, quantizing, etc), they expect parameters
+            # to be on the global target device. This scope is for the
+            # case where cpu offloading is used, where we will move the
+            # parameters onto device for processing and back off after.
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
+
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
 
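The ordering here is the crux of the refactor: MLA's Attention.process_weights_after_loading must run while the projection weights are still in checkpoint (AWQ) format, because decompress_weights is only valid before the quant method repacks them for Marlin. A self-contained toy showing why the two passes cannot be swapped; all names are hypothetical:

    import torch
    from torch import nn

    class ToyQuantMethod:
        def process_weights_after_loading(self, module):
            # Repacks in place; the checkpoint-format layout is gone afterwards.
            module.qweight = module.qweight.T.contiguous()
            module.repacked = True

    class ToyMLAAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.qweight = torch.zeros(512, 32, dtype=torch.int32)
            self.quant_method = ToyQuantMethod()
            self.repacked = False

        def process_weights_after_loading(self, act_dtype):
            # Must see the original (K, N // 8) checkpoint layout.
            assert not self.repacked and self.qweight.shape == (512, 32)

    model = nn.ModuleList([ToyMLAAttention()])
    # Pass 1: attention modules first, while weights are unprocessed.
    for _, m in model.named_modules():
        if isinstance(m, ToyMLAAttention):
            m.process_weights_after_loading(torch.float16)
    # Pass 2: quant methods repack for the runtime kernel.
    for _, m in model.named_modules():
        qm = getattr(m, "quant_method", None)
        if qm is not None:
            qm.process_weights_after_loading(m)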
@@ -376,7 +400,6 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -394,23 +417,8 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if isinstance(quant_method, QuantizeMethodBase):
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    # TODO(lucas): see if there is a way to unify the signatures
-                    # of process_weights_after_loading
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
+
         return model.eval()
 
 
@@ -429,29 +437,15 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
                 # NOTE(woosuk): For accurate performance evaluation, we assign
                 # random values to the weights.
                 initialize_dummy_weights(model)
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(
-                            module, torch.device(device_config.device)):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()
 
 
@@ -632,6 +626,7 @@ def download_model(self, model_config: ModelConfig) -> None:
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
         from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
@@ -640,18 +635,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                                                      model_config.revision)
 
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
-                    if isinstance(module, Attention) and \
-                        hasattr(module, "process_weights_after_loading"):
-                        # When attention modules need to process weights after
-                        # currently only used by MLA
-                        module.process_weights_after_loading(
-                            model_config.dtype)
+                _process_weights_after_loading(model, model_config,
+                                               target_device)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
@@ -1401,16 +1388,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                 self._get_weights_iterator(model_weights,
                                            model_config.revision))
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
         return model.eval()
 
 