diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 80348edcc350..d0b71ebf69c7 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -544,62 +544,16 @@ def create_weights(
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None
 
-        # WEIGHT
-        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-            # track how many elements we have updated
-            if not hasattr(layer, "_loaded_numel"):
-                layer._loaded_numel = 0
-
-                # when the first `loaded_weight` is about to be
-                # loaded to `param`, materialize `param` just-in-time
-                weight = ModelWeightParameter(
-                    data=torch.empty_like(layer.weight, device=layer._load_device),
-                    input_dim=1,
-                    output_dim=0,
-                    weight_loader=patched_weight_loader,
-                )
-                _copy_missing_attrs(layer.weight, weight)
-                layer.register_parameter("weight", weight)
-                del layer._load_device
-
-            # refresh the reference to `param` to reflect just-in-time
-            # materialization
-            param = layer.weight
-
-            # load the current weight chunk
-            copy_numel_counter = CopyNumelCounter()
-            with copy_numel_counter:
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-            layer._loaded_numel += copy_numel_counter.copied_numel
-
-            # if we have loaded all of the elements, call
-            # process_weights_after_loading
-            target_loaded_numel = layer.weight.numel()
-            if layer._loaded_numel == target_loaded_numel:
-                self.process_weights_after_loading(layer)
-
-                # Prevent the usual `process_weights_after_loading` call from doing
-                # anything
-                layer._already_called_process_weights_after_loading = True
-
-            # Note that we keep `layer._loaded_numel` around just in case
-            # there is logic added to vllm in the future which calls a
-            # weight loader twice - we do not want to re-initialize in
-            # that case.
-
-            return res
-
         weight = ModelWeightParameter(
             data=torch.empty(
                 output_size_per_partition,
                 input_size_per_partition,
-                # materialized just-in-time in `patched_weight_loader`
                 device="meta",
                 dtype=params_dtype,
             ),
             input_dim=1,
             output_dim=0,
-            weight_loader=patched_weight_loader,
+            weight_loader=weight_loader,
         )
         # stash the correct device for `patched_weight_loader`
         layer._load_device = torch.get_default_device()
@@ -1057,86 +1011,12 @@ def create_weights(
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None
 
-        # We are doing online quantization, patch the weight loaded
-        # to call `process_weights_after_loading` in a streaming fashion
-        # as soon as the last weight chunk is loaded.
-        weight_loader = extra_weight_attrs["weight_loader"]
-        # create a new holder to prevent modifying behavior of any other
-        # objects which might depend on the old one
-        new_extra_weight_attrs = extra_weight_attrs
-
-        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-            # add a counter to track how many elements we have updated
-            if not hasattr(layer, "_loaded_numel"):
-                layer._loaded_numel = 0
-
-                # save the ids of original w13 and w2 so that we can
-                # distinguish which one `param` should map to further
-                # down in this file
-                layer._w13_weight_orig_id = id(layer.w13_weight)
-                layer._w2_weight_orig_id = id(layer.w2_weight)
-
-                # when the first `loaded_weight` is about to be
-                # loaded to `param`, materialize `param` just-in-time
-
-                w13_weight = torch.nn.Parameter(
-                    torch.empty_like(layer.w13_weight, device=layer._load_device),
-                    requires_grad=False,
-                )
-                set_weight_attrs(w13_weight, extra_weight_attrs)
-                _copy_missing_attrs(layer.w13_weight, w13_weight)
-                layer.register_parameter("w13_weight", w13_weight)
-
-                w2_weight = torch.nn.Parameter(
-                    torch.empty_like(layer.w2_weight, device=layer._load_device),
-                    requires_grad=False,
-                )
-                set_weight_attrs(w2_weight, extra_weight_attrs)
-                _copy_missing_attrs(layer.w2_weight, w2_weight)
-                layer.register_parameter("w2_weight", w2_weight)
-                del layer._load_device
-
-            # refresh the reference to `param` to reflect just-in-time
-            # materialization
-            if id(param) == layer._w13_weight_orig_id:
-                param = layer.w13_weight
-            elif id(param) == layer._w2_weight_orig_id:
-                param = layer.w2_weight
-
-            # load the current weight chunk
-            copy_numel_counter = CopyNumelCounter()
-            with copy_numel_counter:
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-            layer._loaded_numel += copy_numel_counter.copied_numel
-
-            # if we have loaded all of the elements, call
-            # process_weights_after_loading
-            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
-            if layer._loaded_numel == target_loaded_numel:
-                self.process_weights_after_loading(layer)
-
-                # Prevent the usual `process_weights_after_loading` call
-                # from doing anything
-                layer._already_called_process_weights_after_loading = True
-
-            # Note that we keep `layer._loaded_numel`,
-            # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
-            # around because if EP is on, weight loaders for non-local
-            # experts will run but not actually copy any elements, and we
-            # need to not re-initialize in that case.
-
-            return res
-
-        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
-        extra_weight_attrs = new_extra_weight_attrs
-
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
                 2 * intermediate_size_per_partition,
                 hidden_size,
-                # materialized just-in-time in `patched_weight_loader`
                 device="meta",
                 dtype=params_dtype,
             ),
@@ -1150,7 +1030,6 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                 num_experts,
                 hidden_size,
                 intermediate_size_per_partition,
-                # materialized just-in-time in `patched_weight_loader`
                 device="meta",
                 dtype=params_dtype,
             ),
@@ -1158,12 +1037,11 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
-        # stash the correct device for `patched_weight_loader`
-        layer._load_device = torch.get_default_device()
 
         # WEIGHT_SCALES
         # Allocate 2 scales for w1 and w3 respectively.
         # They will be combined to a single scale after weight loading.
+        # TODO: technically should be meta
         w13_weight_scale = torch.nn.Parameter(
             torch.ones(num_experts, dtype=torch.float32), requires_grad=False
         )
@@ -1179,34 +1057,6 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
-        # deferred initialization of randomly initialized weights for the
-        # `--load_format dummy` feature
-        if layer.w13_weight.device == torch.device("meta"):
-            w13_weight = torch.nn.Parameter(
-                torch.empty_like(layer.w13_weight, device=layer._load_device),
-                requires_grad=False,
-            )
-            set_weight_attrs(
-                w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
-            )
-            _copy_missing_attrs(layer.w13_weight, w13_weight)
-            layer.register_parameter("w13_weight", w13_weight)
-            initialize_single_dummy_weight(layer.w13_weight)
-        if layer.w2_weight.device == torch.device("meta"):
-            w2_weight = torch.nn.Parameter(
-                torch.empty_like(layer.w2_weight, device=layer._load_device),
-                requires_grad=False,
-            )
-            set_weight_attrs(
-                w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
-            )
-            _copy_missing_attrs(layer.w2_weight, w2_weight)
-            layer.register_parameter("w2_weight", w2_weight)
-            initialize_single_dummy_weight(layer.w2_weight)
-
         # If checkpoint is fp16, quantize in place.
         fp8_dtype = current_platform.fp8_dtype()
         w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
@@ -1233,9 +1083,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.w2_input_scale,
         )
 
-        # Prevent duplicate processing (e.g., during weight reload)
-        layer._already_called_process_weights_after_loading = True
-
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py
index 77fbb41f0371..f213f30cfbf0 100644
--- a/vllm/model_executor/model_loader/base_loader.py
+++ b/vllm/model_executor/model_loader/base_loader.py
@@ -9,6 +9,10 @@
 from vllm.config import ModelConfig, VllmConfig
 from vllm.config.load import LoadConfig
 from vllm.logger import init_logger
+from vllm.model_executor.model_loader.reload.layerwise import (
+    finalize_layerwise_reload,
+    initialize_layerwise_reload,
+)
 from vllm.model_executor.model_loader.utils import (
     initialize_model,
     process_weights_after_loading,
@@ -55,11 +59,17 @@ def load_model(
                 model = initialize_model(
                     vllm_config=vllm_config, model_config=model_config, prefix=prefix
                 )
-            log_model_inspection(model)
+                log_model_inspection(model)
 
-            logger.debug("Loading weights on %s ...", load_device)
-            # Quantization does not happen in `load_weights` but after it
-            self.load_weights(model, model_config)
+                logger.debug("Loading weights on %s ...", load_device)
+                if not _is_online_quant(vllm_config, model_config):
+                    # load weights eagerly, which may lead to excess memory usage
+                    self.load_weights(model, model_config)
+                else:
+                    # load weights layerwise, which minimizes peak memory usage
+                    initialize_layerwise_reload(model, is_reloading=False)
+                    self.load_weights(model, model_config)
+                    finalize_layerwise_reload(model, model_config)
 
             # Log peak GPU memory after loading weights. This is needed
             # to have test coverage on peak memory for online quantization.
@@ -84,3 +94,14 @@ def log_model_inspection(model: nn.Module) -> None:
     from vllm.model_inspection import format_model_inspection
 
     logger.info("vLLM model structure:\n%s", format_model_inspection(model))
+
+
+def _is_online_quant(vllm_config: VllmConfig, model_config: ModelConfig) -> bool:
+    quant_config = vllm_config.quant_config
+    return (
+        # TODO(future): add other online quant paths here
+        model_config.quantization == "fp8"
+        and quant_config is not None
+        and hasattr(quant_config, "is_checkpoint_fp8_serialized")
+        and not quant_config.is_checkpoint_fp8_serialized
+    )
diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py
index 21795e63995e..81b671f238f3 100644
--- a/vllm/model_executor/model_loader/reload/layerwise.py
+++ b/vllm/model_executor/model_loader/reload/layerwise.py
@@ -66,7 +66,7 @@ def record_metadata_for_reloading(model: torch.nn.Module):
 
 
 @torch.no_grad()
-def initialize_layerwise_reload(model: torch.nn.Module):
+def initialize_layerwise_reload(model: torch.nn.Module, is_reloading: bool = True):
     """
     Set up layerwise weight loading with deferred processing.
 
@@ -92,8 +92,8 @@ def initialize_layerwise_reload(model: torch.nn.Module):
         if info.can_process():
             continue
 
-        # Save current tensors for later copying
-        info.kernel_tensors = get_layer_params_buffers(layer)
+        # Save current tensors for later copying (only for reloading)
+        info.kernel_tensors = get_layer_params_buffers(layer) if is_reloading else None
 
         # Restore layer parameters/buffers onto meta device
        restore_layer_on_meta(layer, info)
@@ -178,22 +178,38 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
         info = get_layerwise_info(layer)
 
         # Attention/MLA layers are processed after all other layers
+        # TODO(@kylesayrs): process attention in a separate for loop
         if isinstance(layer, (Attention, MLAAttention)):
             if info.load_numel > 0:
                 raise NotImplementedError(
                     "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                 )
+            # Loading: initialize model tensors with empty values
+            elif info.kernel_tensors is None:
+                materialize_layer(layer)
+
+            # Reloading: place kernel tensors back (assumed to be empty)
             else:
                 _place_kernel_tensors(layer, info)
-                layer.process_weights_after_loading(model_config.dtype)
-        # No weights were loaded, place kernel tensors back
+
+            layer.process_weights_after_loading(model_config.dtype)
+
+        # Non-attention: No weights were loaded
         elif info.can_process() and info.load_numel <= 0:
-            _place_kernel_tensors(layer, info)
+            if info.load_numel_total is not None and info.load_numel_total > 0:
+                logger.warning("%s: Did not load weights", layer.__class__.__name__)
+
+            # Loading: initialize model tensors with empty values
+            if info.kernel_tensors is None:
+                materialize_layer(layer)
+
+            # Reloading: place kernel tensors back (assumed to be empty)
+            else:
+                _place_kernel_tensors(layer, info)
 
-        # Process non-attention layers which did not load all elements. This can happen
-        # if the created weight has extra padding elements which are not loaded
+        # Non-attention: Some weights were loaded
+        # This can happen if the weight has extra padding elements which are not loaded
         # Having too many of these delayed layers can lead to execess memory usage
         # see Limitations(4)
         elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
@@ -216,11 +232,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
     # Materialize layer tensors onto device
     materialize_layer(layer)
 
-    # Reset FP8 online quantization flag so process_weights_after_loading
-    # will run again during reload
-    if hasattr(layer, "_already_called_process_weights_after_loading"):
-        delattr(layer, "_already_called_process_weights_after_loading")
-
     # Unwrap layerwise loading wrappers
     for param in get_layer_tensors(layer).values():
         param.weight_loader = _get_original_loader(param)
@@ -237,15 +248,15 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
     if isinstance(quant_method, QuantizeMethodBase):
         quant_method.process_weights_after_loading(layer)
 
-    # Copy processed values into original tensor storage (preserves cudagraph refs)
-    # this code is a no-op if not reloading (because kernel tensors is empty)
-    parameters, buffers = info.kernel_tensors
-    for name, param in parameters.items():
-        param.data.copy_(getattr(layer, name))
-    for name, buffer in buffers.items():
-        buffer.data.copy_(getattr(layer, name))
+    # Reloading: copy processed values into original tensors (preserves cudagraph refs)
+    if info.kernel_tensors is not None:
+        parameters, buffers = info.kernel_tensors
+        for name, param in parameters.items():
+            param.data.copy_(getattr(layer, name))
+        for name, buffer in buffers.items():
+            buffer.data.copy_(getattr(layer, name))
 
-    _place_kernel_tensors(layer, info)
+        _place_kernel_tensors(layer, info)
 
     info.reset()
     logger.debug("%s: Processed", layer.__class__.__name__)
@@ -268,6 +279,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
     for name in get_layer_tensors(layer):
         delattr(layer, name)
 
+    assert info.kernel_tensors is not None
     parameters, buffers = info.kernel_tensors
     for name, param in parameters.items():
         layer.register_parameter(name, param)
diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py
index af20236d1c9d..07c23aa88c02 100644
--- a/vllm/model_executor/model_loader/reload/meta.py
+++ b/vllm/model_executor/model_loader/reload/meta.py
@@ -19,7 +19,19 @@
     "get_numel_loaded",
 ]
 
-SKIP_MODULES: set[str] = {"HadamardTransform"}
+
+def _is_skip_module(module: torch.nn.Module):
+    """
+    `HadamardTransform`: uses `SharedWeightParameter` which does not have `.data` attr
+    `RotaryEmbedding`: never loads weights and keeps its initialization values
+    """
+    from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import (  # noqa: E501
+        HadamardTransform,
+    )
+    from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
+
+    return isinstance(module, (RotaryEmbedding, HadamardTransform))
+
 
 SKIP_TENSORS: set[str] = {
     "_expert_map",
@@ -55,7 +67,7 @@ def materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
 
 
 def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
-    if layer.__class__.__name__ in SKIP_MODULES:
+    if _is_skip_module(layer):
         return ({}, {})
 
     params, buffers = get_layer_params_buffers(layer)
@@ -75,7 +87,7 @@ def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
 
 def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
     """Restore a layer to model format with tensors on the meta device"""
-    if layer.__class__.__name__ in SKIP_MODULES:
+    if _is_skip_module(layer):
         return
 
     for name in get_layer_tensors(layer):
@@ -96,7 +108,7 @@ def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
 
 def materialize_layer(layer: torch.nn.Module) -> None:
     """Materialize all meta tensors in a layer to actual tensors."""
-    if layer.__class__.__name__ in SKIP_MODULES:
+    if _is_skip_module(layer):
        return
 
     for name, tensor in get_layer_tensors(layer).items():
diff --git a/vllm/model_executor/model_loader/reload/types.py b/vllm/model_executor/model_loader/reload/types.py
index a7edbe79a75e..c0134ab9551f 100644
--- a/vllm/model_executor/model_loader/reload/types.py
+++ b/vllm/model_executor/model_loader/reload/types.py
@@ -17,7 +17,7 @@ class LayerReloadingInfo:
     restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
 
     # kernel format (device)
-    kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
+    kernel_tensors: LayerTensors | None = None
 
     # track how many restored elements are ready for loading
     load_numel: int = 0
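
Reviewer note (not part of the diff above): a minimal sketch of the control flow
this PR adds to BaseModelLoader.load_model, assuming the helpers introduced in
this diff are importable; the wrapper name `_load_weights_sketch` is hypothetical
and exists only to show the two load paths side by side.

    from vllm.model_executor.model_loader.base_loader import _is_online_quant
    from vllm.model_executor.model_loader.reload.layerwise import (
        finalize_layerwise_reload,
        initialize_layerwise_reload,
    )

    def _load_weights_sketch(loader, model, vllm_config, model_config):
        if not _is_online_quant(vllm_config, model_config):
            # Eager path: weights are loaded at full precision and quantized
            # afterwards, so peak memory includes the unquantized copy.
            loader.load_weights(model, model_config)
        else:
            # Layerwise path: parameters start on the meta device, are
            # materialized per layer as the checkpoint streams in, and each
            # layer is quantized as soon as it has fully loaded.
            initialize_layerwise_reload(model, is_reloading=False)
            loader.load_weights(model, model_config)
            finalize_layerwise_reload(model, model_config)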