diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index f68405d05f87..499b1704eeb0 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -71,9 +71,7 @@ def load_model( # Process weights into kernel format. Note that when using online # quantization, weights are (typically) quantized as they are loaded. - if _has_online_quant(model): - finalize_layerwise_processing(model, model_config) - + finalize_layerwise_processing(model, model_config) process_weights_after_loading(model, model_config, target_device) return model.eval() @@ -87,12 +85,3 @@ def log_model_inspection(model: nn.Module) -> None: from vllm.model_inspection import format_model_inspection logger.info("vLLM model structure:\n%s", format_model_inspection(model)) - - -def _has_online_quant(model: nn.Module): - for module in model.modules(): - quant_method = getattr(module, "quant_method", None) - if getattr(quant_method, "uses_meta_device", False): - return True - - return False diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index 32cb1b3d5156..f8b8159fd778 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -6,7 +6,6 @@ from weakref import WeakKeyDictionary import torch -from compressed_tensors import deprecated from vllm.config import ModelConfig from vllm.logger import init_logger @@ -42,6 +41,7 @@ LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = ( WeakKeyDictionary() ) +ATTENTION_LAYERS = (Attention, MLAAttention) def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo: @@ -162,7 +162,7 @@ def online_process_loader(*args, **kwargs): # Process and copy when all weights are loaded if info.load_numel >= info.load_numel_total and not isinstance( # type: ignore[operator] - layer, (Attention, 
MLAAttention) + layer, ATTENTION_LAYERS ): _layerwise_process(layer, info) @@ -186,56 +186,47 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon if hasattr(model, "_original_do_torchao_reload"): model._do_torchao_reload = model._original_do_torchao_reload + # Catch non-attention layers which did not process during loading for layer in model.modules(): info = get_layerwise_info(layer) - if not info.can_load(): - info.reset() - continue - # Attention/MLA layers are processed after all other layers - if isinstance(layer, (Attention, MLAAttention)): - if info.load_numel > 0: - raise NotImplementedError( - "Layerwise reloading of Q/K/V scale weights is not implemented yet" - ) - - elif info.kernel_tensors is None: - raise NotImplementedError( - "Layerwise loading of Q/K/V scale weights is not implemented yet" - ) + if info.can_load() and not isinstance(layer, ATTENTION_LAYERS): + # reloading: place kernel tensors back as a smart fallback + if info.load_numel <= 0 and info.kernel_tensors is not None: + logger.warning("%s: Failed to reload", layer.__class__.__name__) + _place_kernel_tensors(layer, info) else: - _place_kernel_tensors(layer, info) - layer.process_weights_after_loading(model_config.dtype) + logger.debug("%s: Delayed processing", layer.__class__.__name__) + _layerwise_process(layer, info) - # No weights were loaded - elif info.load_numel <= 0: - # first load but received no weights. 
This happens on dummy load - if info.kernel_tensors is None: - materialize_layer(layer) + info.reset() - # reloading: place kernel tensors back as a fallback - else: - logger.warning("%s: Failed to load weights", layer.__class__.__name__) + # Intentionally delay processing for Attention/MLA layers + for layer in model.modules(): + info = get_layerwise_info(layer) + + if info.can_load() and isinstance(layer, ATTENTION_LAYERS): + # reloading: place kernel tensors back as a smart fallback + # unlike non-attention layers, attention scales are typically not loaded + if info.load_numel <= 0 and info.kernel_tensors is not None: _place_kernel_tensors(layer, info) - # Process non-attention layers which did not load all elements. This can happen - # if the created weight has extra padding elements which are not loaded - # Having too many of these delayed layers can lead to excess memory usage - # see Limitations(4) - elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator] - logger.debug("%s: Delayed processing", layer.__class__.__name__) - _layerwise_process(layer, info) + else: + _layerwise_process(layer, info, model_config) - info.reset() + info.reset() -@deprecated("finalize_layerwise_processing") def finalize_layerwise_reload(*args, **kwargs): finalize_layerwise_processing(*args, **kwargs) -def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): +def _layerwise_process( + layer: torch.nn.Module, + info: LayerReloadingInfo, + model_config: ModelConfig | None = None, +): """ Finalize layer loading after all weights have been buffered. @@ -265,9 +256,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): # Process weights (quantization, repacking, etc.) 
# Attention/MLA are processed in `finalize_layerwise_reload` - quant_method = getattr(layer, "quant_method", None) - if isinstance(quant_method, QuantizeMethodBase): - quant_method.process_weights_after_loading(layer) + _process_layer(layer, model_config) # Copy processed values into original tensor storage (preserves cudagraph refs) # this code is a no-op if not reloading (because kernel tensors is empty) @@ -284,6 +273,17 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): logger.debug("%s: Processed", layer.__class__.__name__) +def _process_layer(layer: torch.nn.Module, model_config: ModelConfig | None = None): + if not isinstance(layer, ATTENTION_LAYERS): + quant_method = getattr(layer, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + quant_method.process_weights_after_loading(layer) + + else: + assert model_config is not None, "Must pass model_config to process attention" + layer.process_weights_after_loading(model_config.dtype) + + def _get_original_loader(tensor: torch.Tensor) -> Callable: """Return the weight loader with any layerwise wrappers removed""" loader = _get_weight_loader(tensor) diff --git a/vllm/model_executor/model_loader/reload/types.py b/vllm/model_executor/model_loader/reload/types.py index 20d42414504f..b1506fadcc71 100644 --- a/vllm/model_executor/model_loader/reload/types.py +++ b/vllm/model_executor/model_loader/reload/types.py @@ -16,7 +16,7 @@ class LayerReloadingInfo: # model format (meta), populated by `record_metadata_for_reloading` restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {})) - # kernel format (device) + # kernel format (device), used to copy into when reloading only kernel_tensors: LayerTensors | None = None # track how many restored elements are ready for loading