diff --git a/tests/model_executor/model_loader/test_reload.py b/tests/model_executor/model_loader/test_reload.py index 6fcb077c1c73..d031eafe8087 100644 --- a/tests/model_executor/model_loader/test_reload.py +++ b/tests/model_executor/model_loader/test_reload.py @@ -148,3 +148,60 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner): mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0] assert add_perp < mul_perp + + +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize( + "base_model,mul_model,add_model,quantization", + [ + ( + "Qwen/Qwen3-0.6B", + "inference-optimization/Qwen3-0.6B-debug-multiply", + "inference-optimization/Qwen3-0.6B-debug-add", + "fp8", + ), + ( + "inference-optimization/DeepSeek-V3-debug-empty", + "inference-optimization/DeepSeek-V3-debug-multiply", + "inference-optimization/DeepSeek-V3-debug-add", + "fp8", + ), + ( + "Qwen/Qwen3-0.6B", + "inference-optimization/Qwen3-0.6B-debug-multiply", + "inference-optimization/Qwen3-0.6B-debug-add", + "mxfp8", + ), + # ( TODO: support mxfp4 & mla + # "inference-optimization/DeepSeek-V3-debug-empty", + # "inference-optimization/DeepSeek-V3-debug-multiply", + # "inference-optimization/DeepSeek-V3-debug-add", + # "mxfp8", + # ), + ], +) +def test_online_quantize_reload( + base_model, mul_model, add_model, quantization, tp_size, vllm_runner +): + if cuda_device_count_stateless() < tp_size: + pytest.skip(reason="Not enough CUDA devices") + + if quantization == "fp8" and not current_platform.supports_fp8(): + pytest.skip(reason="Requires FP8 support") + + with vllm_runner( + model_name=base_model, + quantization=quantization, + tensor_parallel_size=tp_size, + enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model), + enable_prefix_caching=False, + ) as llm: + llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model}) + mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] + add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0] + assert mul_perp < add_perp + + llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model}) + mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] + add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0] + assert add_perp < mul_perp diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d758edd9ca50..7d103fb11640 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -73,7 +73,9 @@ cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, ) -from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight +from vllm.model_executor.model_loader.reload.layerwise import ( + initialize_online_processing, +) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, @@ -491,8 +493,8 @@ def apply( class Fp8OnlineLinearMethod(Fp8LinearMethod): - """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint - and quantized the weights during loading.""" + """Online version of Fp8LinearMethod which loads a full precision checkpoint + and quantizes weights during loading.""" uses_meta_device: bool = True @@ -514,84 +516,25 @@ def create_weights( layer.orig_dtype = params_dtype layer.weight_block_size = None - # WEIGHT - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # 
track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - - # when the first `loaded_weight` is about to be - # loaded to `param`, materialize `param` just-in-time - weight = ModelWeightParameter( - data=torch.empty_like(layer.weight, device=layer._load_device), - input_dim=1, - output_dim=0, - weight_loader=patched_weight_loader, - ) - _copy_missing_attrs(layer.weight, weight) - layer.register_parameter("weight", weight) - del layer._load_device - - # refresh the reference to `param` to reflect just-in-time - # materialization - param = layer.weight - - # load the current weight chunk - copy_numel_counter = CopyNumelCounter() - with copy_numel_counter: - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - layer._loaded_numel += copy_numel_counter.copied_numel - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.weight.numel() - if layer._loaded_numel == target_loaded_numel: - self.process_weights_after_loading(layer) - - # Prevent the usual `process_weights_after_loading` call from doing - # anything - layer._already_called_process_weights_after_loading = True - - # Note that we keep `layer._loaded_numel` around just in case - # there is logic added to vllm in the future which calls a - # weight loader twice - we do not want to re-initialize in - # that case. - - return res - weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, - # materialized just-in-time in `patched_weight_loader` - device="meta", + device="meta", # materialized and processed during loading dtype=params_dtype, ), input_dim=1, output_dim=0, - weight_loader=patched_weight_loader, + weight_loader=weight_loader, ) - # stash the correct device for `patched_weight_loader` - layer._load_device = torch.get_default_device() layer.register_parameter("weight", weight) + initialize_online_processing(layer) + def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return - # deferred initialization of randomly initialized weights for the - # `--load_format dummy` feature - if layer.weight.device == torch.device("meta"): - weight = ModelWeightParameter( - data=torch.empty_like(layer.weight, device=layer._load_device), - input_dim=1, - output_dim=0, - weight_loader=layer.weight.weight_loader, - ) - _copy_missing_attrs(layer.weight, weight) - layer.register_parameter("weight", weight) - initialize_single_dummy_weight(layer.weight) - # TODO(future): support block_quant in online quant path assert not self.block_quant @@ -842,9 +785,6 @@ def _setup_kernel( ) def process_weights_after_loading(self, layer: Module) -> None: - if getattr(layer, "_already_called_process_weights_after_loading", False): - return - # Allow for accessing weights and scales in standard way. 
w13 = layer.w13_weight w2 = layer.w2_weight @@ -889,9 +829,6 @@ def process_weights_after_loading(self, layer: Module) -> None: layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) - # Prevent duplicate processing (e.g., during weight reload) - layer._already_called_process_weights_after_loading = True - def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, @@ -1012,86 +949,12 @@ def create_weights( layer.orig_dtype = params_dtype layer.weight_block_size = None - # We are doing online quantization, patch the weight loaded - # to call `process_weights_after_loading` in a streaming fashion - # as soon as the last weight chunk is loaded. - weight_loader = extra_weight_attrs["weight_loader"] - # create a new holder to prevent modifying behavior of any other - # objects which might depend on the old one - new_extra_weight_attrs = extra_weight_attrs - - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # add a counter to track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - - # save the ids of original w13 and w2 so that we can - # distinguish which one `param` should map to further - # down in this file - layer._w13_weight_orig_id = id(layer.w13_weight) - layer._w2_weight_orig_id = id(layer.w2_weight) - - # when the first `loaded_weight` is about to be - # loaded to `param`, materialize `param` just-in-time - - w13_weight = torch.nn.Parameter( - torch.empty_like(layer.w13_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs(w13_weight, extra_weight_attrs) - _copy_missing_attrs(layer.w13_weight, w13_weight) - layer.register_parameter("w13_weight", w13_weight) - - w2_weight = torch.nn.Parameter( - torch.empty_like(layer.w2_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs(w2_weight, extra_weight_attrs) - _copy_missing_attrs(layer.w2_weight, w2_weight) - layer.register_parameter("w2_weight", w2_weight) - del layer._load_device - - # refresh the reference to `param` to reflect just-in-time - # materialization - if id(param) == layer._w13_weight_orig_id: - param = layer.w13_weight - elif id(param) == layer._w2_weight_orig_id: - param = layer.w2_weight - - # load the current weight chunk - copy_numel_counter = CopyNumelCounter() - with copy_numel_counter: - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - layer._loaded_numel += copy_numel_counter.copied_numel - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() - if layer._loaded_numel == target_loaded_numel: - self.process_weights_after_loading(layer) - - # Prevent the usual `process_weights_after_loading` call - # from doing anything - layer._already_called_process_weights_after_loading = True - - # Note that we keep `layer._loaded_numel`, - # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` - # around because if EP is on, weight loaders for non-local - # experts will run but not actually copy any elements, and we - # need to not re-initialize in that case. 
- - return res - - new_extra_weight_attrs["weight_loader"] = patched_weight_loader - extra_weight_attrs = new_extra_weight_attrs - # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, - # materialized just-in-time in `patched_weight_loader` device="meta", dtype=params_dtype, ), @@ -1105,91 +968,53 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): num_experts, hidden_size, intermediate_size_per_partition, - # materialized just-in-time in `patched_weight_loader` - device="meta", + device="meta", # materialized and processed during loading dtype=params_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - # stash the correct device for `patched_weight_loader` - layer._load_device = torch.get_default_device() # BIASES (for models like GPT-OSS that have biased MoE) if self.moe.has_bias: - # Use the original weight_loader (not patched) for biases - orig_extra_weight_attrs = dict(extra_weight_attrs) - orig_extra_weight_attrs["weight_loader"] = weight_loader w13_bias = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition, + device="meta", # materialized and processed during loading dtype=layer.orig_dtype, ), requires_grad=False, ) layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, orig_extra_weight_attrs) + set_weight_attrs(w13_bias, extra_weight_attrs) + w2_bias = torch.nn.Parameter( - torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype), + torch.zeros( + num_experts, + hidden_size, + device="meta", # materialized and processed during loading + dtype=layer.orig_dtype, + ), requires_grad=False, ) layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, orig_extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. 
- w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_bias, extra_weight_attrs) - layer.w13_input_scale = None - layer.w2_input_scale = None + initialize_online_processing(layer) def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return - # deferred initialization of randomly initialized weights for the - # `--load_format dummy` feature - if layer.w13_weight.device == torch.device("meta"): - w13_weight = torch.nn.Parameter( - torch.empty_like(layer.w13_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs( - w13_weight, {"weight_loader": layer.w13_weight.weight_loader} - ) - _copy_missing_attrs(layer.w13_weight, w13_weight) - layer.register_parameter("w13_weight", w13_weight) - initialize_single_dummy_weight(layer.w13_weight) - if layer.w2_weight.device == torch.device("meta"): - w2_weight = torch.nn.Parameter( - torch.empty_like(layer.w2_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs( - w2_weight, {"weight_loader": layer.w2_weight.weight_loader} - ) - _copy_missing_attrs(layer.w2_weight, w2_weight) - layer.register_parameter("w2_weight", w2_weight) - initialize_single_dummy_weight(layer.w2_weight) - - # If checkpoint is fp16, quantize in place. fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) - w13_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale + w13_scale = torch.ones(layer.num_experts, dtype=torch.float32) + w2_scale = torch.ones(layer.num_experts, dtype=torch.float32) + layer.w13_input_scale = None + layer.w2_input_scale = None for expert in range(layer.local_num_experts): w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant( @@ -1206,8 +1031,8 @@ def process_weights_after_loading(self, layer: Module) -> None: w2, w13_scale, w2_scale, - layer.w13_input_scale, - layer.w2_input_scale, + w13_input_scale=layer.w13_input_scale, + w2_input_scale=layer.w2_input_scale, ) # Prevent duplicate processing (e.g., during weight reload) diff --git a/vllm/model_executor/layers/quantization/mxfp8.py b/vllm/model_executor/layers/quantization/mxfp8.py index 5b4564bea31c..bd29f272bd10 100644 --- a/vllm/model_executor/layers/quantization/mxfp8.py +++ b/vllm/model_executor/layers/quantization/mxfp8.py @@ -337,6 +337,8 @@ def process_weights_after_loading(self, layer: Module) -> None: w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) w13_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale + layer.w13_input_scale = None + layer.w2_input_scale = None w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight) w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight) diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index e3b965db8aaf..f68405d05f87 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -9,6 +9,7 @@ from vllm.config import ModelConfig, 
VllmConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.model_loader.reload import finalize_layerwise_processing from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, @@ -49,16 +50,13 @@ def load_model( device_config.device if load_config.device is None else load_config.device ) target_device = torch.device(load_device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = initialize_model( - vllm_config=vllm_config, model_config=model_config, prefix=prefix - ) - + with set_default_torch_dtype(model_config.dtype), target_device: + model = initialize_model( + vllm_config=vllm_config, model_config=model_config, prefix=prefix + ) log_model_inspection(model) logger.debug("Loading weights on %s ...", load_device) - # Quantization does not happen in `load_weights` but after it self.load_weights(model, model_config) # Log peak GPU memory after loading weights. This is needed @@ -71,6 +69,11 @@ def load_model( scope="local", ) + # Process weights into kernel format. Note that when using online + # quantization, weights are (typically) quantized as they are loaded. + if _has_online_quant(model): + finalize_layerwise_processing(model, model_config) + process_weights_after_loading(model, model_config, target_device) return model.eval() @@ -84,3 +87,12 @@ def log_model_inspection(model: nn.Module) -> None: from vllm.model_inspection import format_model_inspection logger.info("vLLM model structure:\n%s", format_model_inspection(model)) + + +def _has_online_quant(model: nn.Module): + for module in model.modules(): + quant_method = getattr(module, "quant_method", None) + if getattr(quant_method, "uses_meta_device", False): + return True + + return False diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index 156071f1dae3..5a8b5de6f553 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch import torch.nn as nn from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.reload.meta import materialize_meta_tensor +from vllm.model_executor.model_loader.reload.utils import get_layer_tensors from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights @@ -23,6 +26,12 @@ def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + # materialize meta tensors as part of online quantization lifecycle + for layer in model.modules(): + for name, param in get_layer_tensors(layer).items(): + if param.device == torch.device("meta"): + setattr(layer, name, materialize_meta_tensor(param)) + # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
initialize_dummy_weights(model, model_config) diff --git a/vllm/model_executor/model_loader/reload/__init__.py b/vllm/model_executor/model_loader/reload/__init__.py index ea0b0bc06ad9..56a9d88ac4e4 100644 --- a/vllm/model_executor/model_loader/reload/__init__.py +++ b/vllm/model_executor/model_loader/reload/__init__.py @@ -21,12 +21,14 @@ __all__ = [ "record_metadata_for_reloading", "initialize_layerwise_reload", + "finalize_layerwise_processing", "finalize_layerwise_reload", "set_torchao_reload_attrs", "support_quantized_model_reload_from_hp_weights", ] from .layerwise import ( + finalize_layerwise_processing, finalize_layerwise_reload, initialize_layerwise_reload, record_metadata_for_reloading, diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index 21795e63995e..2a174673b91b 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -28,6 +28,7 @@ "get_layerwise_info", "record_metadata_for_reloading", "initialize_layerwise_reload", + "finalize_layerwise_processing", "finalize_layerwise_reload", ] @@ -89,7 +90,7 @@ def initialize_layerwise_reload(model: torch.nn.Module): info = get_layerwise_info(layer) # Skip if the layer has already been initialized - if info.can_process(): + if info.can_load(): continue # Save current tensors for later copying @@ -98,15 +99,21 @@ def initialize_layerwise_reload(model: torch.nn.Module): # Restore layer parameters/buffers onto meta device restore_layer_on_meta(layer, info) - # Track loading progress to determine when to process/copy - info.load_numel = 0 - info.load_numel_total = get_layer_size(layer) + initialize_online_processing(layer) - # Wrap each parameter's weight loader - # Note that nested wrapping will occur for shared tensors - for name, tensor in get_layer_tensors(layer).items(): - if _get_weight_loader(tensor).__name__ != "online_process_loader": - tensor.weight_loader = make_online_process_loader(layer, name) + +def initialize_online_processing(layer: torch.nn.Module): + info = get_layerwise_info(layer) + + # Track loading progress to determine when to process/copy + info.load_numel = 0 + info.load_numel_total = get_layer_size(layer) + + # Wrap each parameter's weight loader + # Note that nested wrapping will occur for shared tensors + for name, tensor in get_layer_tensors(layer).items(): + if _get_weight_loader(tensor).__name__ != "online_process_loader": + tensor.weight_loader = make_online_process_loader(layer, name) def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable: @@ -118,7 +125,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla @wraps(original_loader, assigned=("__doc__", "__annotations__")) def online_process_loader(*args, **kwargs): - if not info.can_process(): + if not info.can_load(): # Unfortunately, some qconfigs are set up to load the same weight # multiple times. For example, CT_WNA16 loads `weight_shape` for # each of the qkv partitions. 
This results in layers loading extra @@ -140,7 +147,7 @@ def online_process_loader(*args, **kwargs): bound_args = loader_signature.bind(*args, **kwargs) bound_args.apply_defaults() - # Cache loaded weights, track loading progress + # Buffer loaded weights, track loading progress info.loaded_weights.append((param_name, bound_args)) num_loaded, ret = get_numel_loaded(original_loader, bound_args) info.load_numel += num_loaded @@ -163,19 +170,26 @@ def online_process_loader(*args, **kwargs): return online_process_loader -def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig): +def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelConfig): """ - Remove the outermost layer of weight loading wrappers. + Apply processing to any layers which were not layerwise processed during loading. + This includes attention layers and layers which have weight elements which are not + loaded (due to padding). This function should be applied after `initialize_layerwise_reload` is applied unwrap the layerwise weight loaders. - Also processes Attention/MLA layers, which must be processed after all other layers + :param model: model to finalize processing for + :param model_config: config needed for applying processing to attention layers """ - model._do_torchao_reload = model._original_do_torchao_reload + if hasattr(model, "_original_do_torchao_reload"): + model._do_torchao_reload = model._original_do_torchao_reload for layer in model.modules(): info = get_layerwise_info(layer) + if not info.can_load(): + info.reset() + continue # Attention/MLA layers are processed after all other layers if isinstance(layer, (Attention, MLAAttention)): @@ -184,17 +198,29 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig) "Layerwise reloading of Q/K/V scale weights is not implemented yet" ) + elif info.kernel_tensors is None: + raise NotImplementedError( + "Layerwise loading of Q/K/V scale weights is not implemented yet" + ) + else: _place_kernel_tensors(layer, info) layer.process_weights_after_loading(model_config.dtype) - # No weights were loaded, place kernel tensors back - elif info.can_process() and info.load_numel <= 0: - _place_kernel_tensors(layer, info) + # No weights were loaded + elif info.load_numel <= 0: + # first load but received no weights. This happens on dummy load + if info.kernel_tensors is None: + materialize_layer(layer) + + # reloading: place kernel tensors back as a fallback + else: + logger.warning("%s: Failed to load weights", layer.__class__.__name__) + _place_kernel_tensors(layer, info) # Process non-attention layers which did not load all elements. This can happen # if the created weight has extra padding elements which are not loaded - # Having too many of these delayed layers can lead to execess memory usage + # Having too many of these delayed layers can lead to excess memory usage # see Limitations(4) elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator] logger.debug("%s: Delayed processing", layer.__class__.__name__) @@ -203,20 +229,24 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig) info.reset() +def finalize_layerwise_reload(*args, **kwargs): + finalize_layerwise_processing(*args, **kwargs) + + def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): """ - Finalize layer loading after all weights have been cached. + Finalize layer loading after all weights have been buffered. This function: 1. 
Materializes the layer onto the target device - 2. Loads all cached weights + 2. Loads all buffered weights 3. Runs quantization processing if applicable 4. Copies processed values back to original tensor storage """ # Materialize layer tensors onto device materialize_layer(layer) - # Reset FP8 online quantization flag so process_weights_after_loading + # Reset online quantization flag so process_weights_after_loading # will run again during reload if hasattr(layer, "_already_called_process_weights_after_loading"): delattr(layer, "_already_called_process_weights_after_loading") @@ -225,7 +255,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): for param in get_layer_tensors(layer).values(): param.weight_loader = _get_original_loader(param) - # Load all cached weights into materialized layer (using original loaders) + # Load all buffered weights into materialized layer (using original loaders) for name, args in info.loaded_weights: param = getattr(layer, name) args.arguments["param"] = param @@ -239,13 +269,14 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): # Copy processed values into original tensor storage (preserves cudagraph refs) # this code is a no-op if not reloading (because kernel tensors is empty) - parameters, buffers = info.kernel_tensors - for name, param in parameters.items(): - param.data.copy_(getattr(layer, name)) - for name, buffer in buffers.items(): - buffer.data.copy_(getattr(layer, name)) + if info.kernel_tensors is not None: + parameters, buffers = info.kernel_tensors + for name, param in parameters.items(): + param.data.copy_(getattr(layer, name)) + for name, buffer in buffers.items(): + buffer.data.copy_(getattr(layer, name)) - _place_kernel_tensors(layer, info) + _place_kernel_tensors(layer, info) info.reset() logger.debug("%s: Processed", layer.__class__.__name__) @@ -268,6 +299,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo): for name in get_layer_tensors(layer): delattr(layer, name) + assert info.kernel_tensors is not None parameters, buffers = info.kernel_tensors for name, param in parameters.items(): layer.register_parameter(name, param) diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py index af20236d1c9d..138b9f01d69b 100644 --- a/vllm/model_executor/model_loader/reload/meta.py +++ b/vllm/model_executor/model_loader/reload/meta.py @@ -104,7 +104,7 @@ def materialize_layer(layer: torch.nn.Module) -> None: setattr(layer, name, materialize_meta_tensor(tensor)) -class MetaCopyCounter(TorchDispatchMode): +class CopyCounter(TorchDispatchMode): """ Tracks total number of elements modified with `copy_`. 
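# Illustrative sketch (not part of the patch): the renamed CopyCounter above follows the
# standard TorchDispatchMode pattern for counting elements written by aten.copy_. This is
# a minimal standalone example of that pattern, not the exact vLLM implementation; the
# class and variable names below are hypothetical.
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class CopiedNumelCounter(TorchDispatchMode):
    """Count elements written by aten.copy_ while the mode is active."""

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.copy_.default:
            # copy_ writes args[1] into args[0]; count the destination's elements
            self.copied_numel += args[0].numel()
        return func(*args, **kwargs)


# Usage: 32 elements are copied into `dst` while the mode is active.
dst = torch.empty(4, 8)
with CopiedNumelCounter() as counter:
    dst.copy_(torch.ones(4, 8))
assert counter.copied_numel == 32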
@@ -122,7 +122,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - if func is torch.ops.aten.copy_.default and args[0].device.type == "meta": + if func is torch.ops.aten.copy_.default: assert args[0].numel() == args[1].numel() self.copied_numel += args[0].numel() @@ -140,7 +140,6 @@ def get_numel_loaded( :return: number of elements loaded by the weight loader, the return value of the weight loader """ - assert args.arguments["param"].device.type == "meta" - with MetaCopyCounter() as counter: + with CopyCounter() as counter: return_value = weight_loader(*args.args, **args.kwargs) return counter.copied_numel, return_value diff --git a/vllm/model_executor/model_loader/reload/types.py b/vllm/model_executor/model_loader/reload/types.py index a7edbe79a75e..b1506fadcc71 100644 --- a/vllm/model_executor/model_loader/reload/types.py +++ b/vllm/model_executor/model_loader/reload/types.py @@ -16,8 +16,8 @@ class LayerReloadingInfo: # model format (meta), populated by `record_metadata_for_reloading` restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {})) - # kernel format (device) - kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {})) + # kernel format (device), used to copy into when reloading only + kernel_tensors: LayerTensors | None = None # track how many restored elements are ready for loading load_numel: int = 0 @@ -29,5 +29,5 @@ class LayerReloadingInfo: def reset(self): self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc] - def can_process(self) -> bool: + def can_load(self) -> bool: return self.load_numel_total is not None diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 5986eb01b675..2350e17e740f 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1323,25 +1323,11 @@ def initialize_dummy_weights( is fixed, the random values generated by this function only depends on the parameter's number of elements and its data type. """ - - # Check if any module uses online quantization with meta device weights. - # If so, we'll skip initializing params on meta device since they'll be - # handled in `process_weights_after_loading`. - def uses_meta_device(module: torch.nn.Module) -> bool: - quant_method = getattr(module, "quant_method", None) - return getattr(quant_method, "uses_meta_device", False) - - has_online_quant = any(uses_meta_device(m) for m in model.modules()) - for param in model.state_dict().values(): - if has_online_quant and param.device == torch.device("meta"): - # For online quantization, weights are created on meta device and - # dummy weight init will happen in `process_weights_after_loading`. - continue - initialize_single_dummy_weight(param, low, high, seed) +@torch.no_grad() def initialize_single_dummy_weight( param: torch.Tensor, low: float = -1e-3,
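# Illustrative sketch (not part of the patch): the end-to-end flow this change enables,
# mirroring the new test_online_quantize_reload test above. This is a minimal sketch that
# assumes the public vllm.LLM API and the debug checkpoints referenced in the test; the
# test's fixture plumbing (vllm_runner, generate_prompt_perplexity, tensor parallelism)
# is omitted here.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-0.6B",     # full-precision checkpoint
    quantization="fp8",          # quantized online, layer by layer, while loading
    enable_prefix_caching=False,
)

# Reload a different full-precision checkpoint; each layer is re-quantized and
# re-processed as its weights stream in, copying results back into the existing
# kernel tensors so cudagraph references are preserved.
llm.collective_rpc(
    "reload_weights",
    kwargs={"weights_path": "inference-optimization/Qwen3-0.6B-debug-add"},
)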