Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions vllm/model_executor/model_loader/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def load_model(

# Process weights into kernel format. Note that when using online
# quantization, weights are (typically) quantized as they are loaded.
if _has_online_quant(model):
finalize_layerwise_processing(model, model_config)

finalize_layerwise_processing(model, model_config)
process_weights_after_loading(model, model_config, target_device)

return model.eval()
Expand All @@ -87,12 +85,3 @@ def log_model_inspection(model: nn.Module) -> None:
from vllm.model_inspection import format_model_inspection

logger.info("vLLM model structure:\n%s", format_model_inspection(model))


def _has_online_quant(model: nn.Module):
for module in model.modules():
quant_method = getattr(module, "quant_method", None)
if getattr(quant_method, "uses_meta_device", False):
return True

return False
78 changes: 39 additions & 39 deletions vllm/model_executor/model_loader/reload/layerwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from weakref import WeakKeyDictionary

import torch
from compressed_tensors import deprecated

from vllm.config import ModelConfig
from vllm.logger import init_logger
Expand Down Expand Up @@ -42,6 +41,7 @@
LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
WeakKeyDictionary()
)
ATTENTION_LAYERS = (Attention, MLAAttention)


def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
Expand Down Expand Up @@ -162,7 +162,7 @@ def online_process_loader(*args, **kwargs):

# Process and copy when all weights are loaded
if info.load_numel >= info.load_numel_total and not isinstance( # type: ignore[operator]
layer, (Attention, MLAAttention)
layer, ATTENTION_LAYERS
):
_layerwise_process(layer, info)

Expand All @@ -186,56 +186,47 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
if hasattr(model, "_original_do_torchao_reload"):
model._do_torchao_reload = model._original_do_torchao_reload

# Catch non-attention layers which did not process during loading
for layer in model.modules():
info = get_layerwise_info(layer)
if not info.can_load():
info.reset()
continue

# Attention/MLA layers are processed after all other layers
if isinstance(layer, (Attention, MLAAttention)):
if info.load_numel > 0:
raise NotImplementedError(
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)

elif info.kernel_tensors is None:
raise NotImplementedError(
"Layerwise loading of Q/K/V scale weights is not implemented yet"
)
if info.can_load() and not isinstance(layer, ATTENTION_LAYERS):
# reloading: place kernel tensors back as a smart fallback
if info.load_numel <= 0 and info.kernel_tensors is not None:
logger.warning("%s: Failed to reload", layer.__class__.__name__)
_place_kernel_tensors(layer, info)

else:
_place_kernel_tensors(layer, info)
layer.process_weights_after_loading(model_config.dtype)
logger.debug("%s: Delayed processing", layer.__class__.__name__)
_layerwise_process(layer, info)

# No weights were loaded
elif info.load_numel <= 0:
# first load but received no weights. This happens on dummy load
if info.kernel_tensors is None:
materialize_layer(layer)
info.reset()

# reloading: place kernel tensors back as a fallback
else:
logger.warning("%s: Failed to load weights", layer.__class__.__name__)
# Intentionally delay processing for Attention/MLA layers
for layer in model.modules():
info = get_layerwise_info(layer)

if info.can_load() and isinstance(layer, ATTENTION_LAYERS):
# reloading: place kernel tensors back as a smart fallback
# unlike non-attention layers, attention scales are typically not loaded
if info.load_numel <= 0 and info.kernel_tensors is not None:
_place_kernel_tensors(layer, info)

# Process non-attention layers which did not load all elements. This can happen
# if the created weight has extra padding elements which are not loaded
# Having too many of these delayed layers can lead to excess memory usage
# see Limitations(4)
elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator]
logger.debug("%s: Delayed processing", layer.__class__.__name__)
_layerwise_process(layer, info)
else:
_layerwise_process(layer, info)

info.reset()
info.reset()


@deprecated("finalize_layerwise_processing")
def finalize_layerwise_reload(*args, **kwargs):
    """Deprecated alias for :func:`finalize_layerwise_processing`.

    Forwards all arguments unchanged and propagates the return value so the
    alias stays fully transparent to existing callers.
    """
    return finalize_layerwise_processing(*args, **kwargs)


def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
def _layerwise_process(
layer: torch.nn.Module,
info: LayerReloadingInfo,
model_config: ModelConfig | None = None,
):
"""
Finalize layer loading after all weights have been buffered.

Expand Down Expand Up @@ -265,9 +256,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):

# Process weights (quantization, repacking, etc.)
# Attention/MLA are processed in `finalize_layerwise_reload`
quant_method = getattr(layer, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
quant_method.process_weights_after_loading(layer)
_process_layer(layer, model_config)

# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
Expand All @@ -284,6 +273,17 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
logger.debug("%s: Processed", layer.__class__.__name__)


def _process_layer(layer: torch.nn.Module, model_config: ModelConfig | None = None):
    """Run post-load weight processing (quantization, repacking, etc.) for one layer.

    Non-attention layers delegate to their attached quant method's
    ``process_weights_after_loading`` hook, if any. Attention/MLA layers call
    the layer's own ``process_weights_after_loading``, which needs the model
    dtype and therefore requires ``model_config``.

    Raises:
        ValueError: if ``layer`` is an attention layer and ``model_config``
            is ``None``.
    """
    if isinstance(layer, ATTENTION_LAYERS):
        # Validate explicitly: a bare `assert` is stripped under `python -O`,
        # which would turn a caller bug into an opaque AttributeError below.
        if model_config is None:
            raise ValueError("Must pass model_config to process attention")
        layer.process_weights_after_loading(model_config.dtype)
    else:
        quant_method = getattr(layer, "quant_method", None)
        if isinstance(quant_method, QuantizeMethodBase):
            quant_method.process_weights_after_loading(layer)


def _get_original_loader(tensor: torch.Tensor) -> Callable:
"""Return the weight loader with any layerwise wrappers removed"""
loader = _get_weight_loader(tensor)
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/reload/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))

# kernel format (device)
# kernel format (device), used to copy into when reloading only
kernel_tensors: LayerTensors | None = None

# track how many restored elements are ready for loading
Expand Down
Loading