diff --git a/tests/model_executor/model_loader/test_reload.py b/tests/model_executor/model_loader/test_reload.py
index d031eafe8087..6fcb077c1c73 100644
--- a/tests/model_executor/model_loader/test_reload.py
+++ b/tests/model_executor/model_loader/test_reload.py
@@ -148,60 +148,3 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
         mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
         add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
         assert add_perp < mul_perp
-
-
-@pytest.mark.parametrize("tp_size", [2])
-@pytest.mark.parametrize(
-    "base_model,mul_model,add_model,quantization",
-    [
-        (
-            "Qwen/Qwen3-0.6B",
-            "inference-optimization/Qwen3-0.6B-debug-multiply",
-            "inference-optimization/Qwen3-0.6B-debug-add",
-            "fp8",
-        ),
-        (
-            "inference-optimization/DeepSeek-V3-debug-empty",
-            "inference-optimization/DeepSeek-V3-debug-multiply",
-            "inference-optimization/DeepSeek-V3-debug-add",
-            "fp8",
-        ),
-        (
-            "Qwen/Qwen3-0.6B",
-            "inference-optimization/Qwen3-0.6B-debug-multiply",
-            "inference-optimization/Qwen3-0.6B-debug-add",
-            "mxfp8",
-        ),
-        # ( TODO: support mxfp4 & mla
-        #     "inference-optimization/DeepSeek-V3-debug-empty",
-        #     "inference-optimization/DeepSeek-V3-debug-multiply",
-        #     "inference-optimization/DeepSeek-V3-debug-add",
-        #     "mxfp8",
-        # ),
-    ],
-)
-def test_online_quantize_reload(
-    base_model, mul_model, add_model, quantization, tp_size, vllm_runner
-):
-    if cuda_device_count_stateless() < tp_size:
-        pytest.skip(reason="Not enough CUDA devices")
-
-    if quantization == "fp8" and not current_platform.supports_fp8():
-        pytest.skip(reason="Requires FP8 support")
-
-    with vllm_runner(
-        model_name=base_model,
-        quantization=quantization,
-        tensor_parallel_size=tp_size,
-        enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
-        enable_prefix_caching=False,
-    ) as llm:
-        llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
-        mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
-        add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
-        assert mul_perp < add_perp
-
-        llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model})
-        mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
-        add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
-        assert add_perp < mul_perp
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 9e717da43fbd..e01148313eb7 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -73,9 +73,7 @@
     cutlass_fp8_supported,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from vllm.model_executor.model_loader.reload.layerwise import (
-    initialize_online_processing,
-)
+from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
 from vllm.model_executor.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -498,8 +496,8 @@ def apply(
 
 
 class Fp8OnlineLinearMethod(Fp8LinearMethod):
-    """Online version of Fp8LinearMethod which loads a full precision checkpoint
-    and quantizes weights during loading."""
+    """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
+    and quantizes the weights during loading."""
 
     uses_meta_device: bool = True
 
@@ -521,25 +519,84 @@ def create_weights(
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None
 
+        # WEIGHT
+        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+            # track how many elements we have updated
+            if not hasattr(layer, "_loaded_numel"):
+                layer._loaded_numel = 0
+
+                # when the first `loaded_weight` is about to be
+                # loaded to `param`, materialize `param` just-in-time
+                weight = ModelWeightParameter(
+                    data=torch.empty_like(layer.weight, device=layer._load_device),
+                    input_dim=1,
+                    output_dim=0,
+                    weight_loader=patched_weight_loader,
+                )
+                _copy_missing_attrs(layer.weight, weight)
+                layer.register_parameter("weight", weight)
+                del layer._load_device
+
+            # refresh the reference to `param` to reflect just-in-time
+            # materialization
+            param = layer.weight
+
+            # load the current weight chunk
+            copy_numel_counter = CopyNumelCounter()
+            with copy_numel_counter:
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+            layer._loaded_numel += copy_numel_counter.copied_numel
+
+            # if we have loaded all of the elements, call
+            # process_weights_after_loading
+            target_loaded_numel = layer.weight.numel()
+            if layer._loaded_numel == target_loaded_numel:
+                self.process_weights_after_loading(layer)
+
+                # Prevent the usual `process_weights_after_loading` call from doing
+                # anything
+                layer._already_called_process_weights_after_loading = True
+
+            # Note that we keep `layer._loaded_numel` around just in case
+            # there is logic added to vllm in the future which calls a
+            # weight loader twice - we do not want to re-initialize in
+            # that case.
+
+            return res
+
         weight = ModelWeightParameter(
             data=torch.empty(
                 output_size_per_partition,
                 input_size_per_partition,
-                device="meta",  # materialized and processed during loading
+                # materialized just-in-time in `patched_weight_loader`
+                device="meta",
                 dtype=params_dtype,
             ),
             input_dim=1,
             output_dim=0,
-            weight_loader=weight_loader,
+            weight_loader=patched_weight_loader,
         )
+
+        # stash the correct device for `patched_weight_loader`
+        layer._load_device = torch.get_default_device()
         layer.register_parameter("weight", weight)
 
-        initialize_online_processing(layer)
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return
 
+        # deferred initialization of randomly initialized weights for the
+        # `--load_format dummy` feature
+        if layer.weight.device == torch.device("meta"):
+            weight = ModelWeightParameter(
+                data=torch.empty_like(layer.weight, device=layer._load_device),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=layer.weight.weight_loader,
+            )
+            _copy_missing_attrs(layer.weight, weight)
+            layer.register_parameter("weight", weight)
+            initialize_single_dummy_weight(layer.weight)
+
         # TODO(future): support block_quant in online quant path
         assert not self.block_quant
@@ -788,6 +845,9 @@ def _setup_kernel(
         )
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
         # Allow for accessing weights and scales in standard way.
         w13 = layer.w13_weight
         w2 = layer.w2_weight
@@ -832,6 +892,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
         )
 
+        # Prevent duplicate processing (e.g., during weight reload)
+        layer._already_called_process_weights_after_loading = True
+
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -950,12 +1013,86 @@ def create_weights(
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None
 
+        # We are doing online quantization, patch the weight loader
+        # to call `process_weights_after_loading` in a streaming fashion
+        # as soon as the last weight chunk is loaded.
+        weight_loader = extra_weight_attrs["weight_loader"]
+        # create a new holder to prevent modifying behavior of any other
+        # objects which might depend on the old one
+        new_extra_weight_attrs = extra_weight_attrs
+
+        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+            # add a counter to track how many elements we have updated
+            if not hasattr(layer, "_loaded_numel"):
+                layer._loaded_numel = 0
+
+                # save the ids of original w13 and w2 so that we can
+                # distinguish which one `param` should map to further
+                # down in this file
+                layer._w13_weight_orig_id = id(layer.w13_weight)
+                layer._w2_weight_orig_id = id(layer.w2_weight)
+
+                # when the first `loaded_weight` is about to be
+                # loaded to `param`, materialize `param` just-in-time
+
+                w13_weight = torch.nn.Parameter(
+                    torch.empty_like(layer.w13_weight, device=layer._load_device),
+                    requires_grad=False,
+                )
+                set_weight_attrs(w13_weight, extra_weight_attrs)
+                _copy_missing_attrs(layer.w13_weight, w13_weight)
+                layer.register_parameter("w13_weight", w13_weight)
+
+                w2_weight = torch.nn.Parameter(
+                    torch.empty_like(layer.w2_weight, device=layer._load_device),
+                    requires_grad=False,
+                )
+                set_weight_attrs(w2_weight, extra_weight_attrs)
+                _copy_missing_attrs(layer.w2_weight, w2_weight)
+                layer.register_parameter("w2_weight", w2_weight)
+                del layer._load_device
+
+            # refresh the reference to `param` to reflect just-in-time
+            # materialization
+            if id(param) == layer._w13_weight_orig_id:
+                param = layer.w13_weight
+            elif id(param) == layer._w2_weight_orig_id:
+                param = layer.w2_weight
+
+            # load the current weight chunk
+            copy_numel_counter = CopyNumelCounter()
+            with copy_numel_counter:
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+            layer._loaded_numel += copy_numel_counter.copied_numel
+
+            # if we have loaded all of the elements, call
+            # process_weights_after_loading
+            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
+            if layer._loaded_numel == target_loaded_numel:
+                self.process_weights_after_loading(layer)
+
+                # Prevent the usual `process_weights_after_loading` call
+                # from doing anything
+                layer._already_called_process_weights_after_loading = True
+
+            # Note that we keep `layer._loaded_numel`,
+            # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
+            # around because if EP is on, weight loaders for non-local
+            # experts will run but not actually copy any elements, and we
+            # need to not re-initialize in that case.
+
+            return res
+
+        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
+        extra_weight_attrs = new_extra_weight_attrs
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
                 2 * intermediate_size_per_partition,
                 hidden_size,
+                # materialized just-in-time in `patched_weight_loader`
                 device="meta",
                 dtype=params_dtype,
             ),
@@ -969,53 +1106,91 @@ def create_weights(
                 num_experts,
                 hidden_size,
                 intermediate_size_per_partition,
-                device="meta",  # materialized and processed during loading
+                # materialized just-in-time in `patched_weight_loader`
+                device="meta",
                 dtype=params_dtype,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # stash the correct device for `patched_weight_loader`
+        layer._load_device = torch.get_default_device()
 
         # BIASES (for models like GPT-OSS that have biased MoE)
         if self.moe.has_bias:
+            # Use the original weight_loader (not patched) for biases
+            orig_extra_weight_attrs = dict(extra_weight_attrs)
+            orig_extra_weight_attrs["weight_loader"] = weight_loader
             w13_bias = torch.nn.Parameter(
                 torch.zeros(
                     num_experts,
                     2 * intermediate_size_per_partition,
-                    device="meta",  # materialized and processed during loading
                     dtype=layer.orig_dtype,
                 ),
                 requires_grad=False,
             )
             layer.register_parameter("w13_bias", w13_bias)
-            set_weight_attrs(w13_bias, extra_weight_attrs)
-
+            set_weight_attrs(w13_bias, orig_extra_weight_attrs)
             w2_bias = torch.nn.Parameter(
-                torch.zeros(
-                    num_experts,
-                    hidden_size,
-                    device="meta",  # materialized and processed during loading
-                    dtype=layer.orig_dtype,
-                ),
+                torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
                 requires_grad=False,
             )
             layer.register_parameter("w2_bias", w2_bias)
-            set_weight_attrs(w2_bias, extra_weight_attrs)
+            set_weight_attrs(w2_bias, orig_extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        initialize_online_processing(layer)
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return
 
+        # deferred initialization of randomly initialized weights for the
+        # `--load_format dummy` feature
+        if layer.w13_weight.device == torch.device("meta"):
+            w13_weight = torch.nn.Parameter(
+                torch.empty_like(layer.w13_weight, device=layer._load_device),
+                requires_grad=False,
+            )
+            set_weight_attrs(
+                w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
+            )
+            _copy_missing_attrs(layer.w13_weight, w13_weight)
+            layer.register_parameter("w13_weight", w13_weight)
+            initialize_single_dummy_weight(layer.w13_weight)
+        if layer.w2_weight.device == torch.device("meta"):
+            w2_weight = torch.nn.Parameter(
+                torch.empty_like(layer.w2_weight, device=layer._load_device),
+                requires_grad=False,
+            )
+            set_weight_attrs(
+                w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
+            )
+            _copy_missing_attrs(layer.w2_weight, w2_weight)
+            layer.register_parameter("w2_weight", w2_weight)
+            initialize_single_dummy_weight(layer.w2_weight)
+
+        # If checkpoint is fp16, quantize in place.
         fp8_dtype = current_platform.fp8_dtype()
         w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
         w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
-        w13_scale = torch.ones(layer.num_experts, dtype=torch.float32)
-        w2_scale = torch.ones(layer.num_experts, dtype=torch.float32)
-        layer.w13_input_scale = None
-        layer.w2_input_scale = None
+        w13_scale = layer.w13_weight_scale
+        w2_scale = layer.w2_weight_scale
 
         for expert in range(layer.local_num_experts):
             w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
@@ -1032,8 +1207,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
             w2,
             w13_scale,
             w2_scale,
-            w13_input_scale=layer.w13_input_scale,
-            w2_input_scale=layer.w2_input_scale,
+            layer.w13_input_scale,
+            layer.w2_input_scale,
         )
 
         # Prevent duplicate processing (e.g., during weight reload)
diff --git a/vllm/model_executor/layers/quantization/mxfp8.py b/vllm/model_executor/layers/quantization/mxfp8.py
index bd29f272bd10..5b4564bea31c 100644
--- a/vllm/model_executor/layers/quantization/mxfp8.py
+++ b/vllm/model_executor/layers/quantization/mxfp8.py
@@ -337,8 +337,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
         w13_scale = layer.w13_weight_scale
         w2_scale = layer.w2_weight_scale
-        layer.w13_input_scale = None
-        layer.w2_input_scale = None
 
         w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight)
         w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight)
diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py
index d6c38664fde6..0e222139abc8 100644
--- a/vllm/model_executor/model_loader/base_loader.py
+++ b/vllm/model_executor/model_loader/base_loader.py
@@ -9,7 +9,6 @@
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.reload import finalize_layerwise_processing
 from vllm.model_executor.model_loader.utils import (
     initialize_model,
     process_weights_after_loading,
@@ -61,6 +60,7 @@ def load_model(
         log_model_inspection(model)
 
         logger.debug("Loading weights on %s ...", load_device)
+        # Quantization does not happen in `load_weights` but after it
         self.load_weights(model, model_config)
 
         # Log peak GPU memory after loading weights. This is needed
@@ -73,11 +73,6 @@ def load_model(
             scope="local",
         )
 
-        # Process weights into kernel format. Note that when using online
-        # quantization, weights are (typically) quantized as they are loaded.
-        if _has_online_quant(model):
-            finalize_layerwise_processing(model, model_config)
-
         process_weights_after_loading(model, model_config, target_device)
 
         return model.eval()
@@ -91,12 +86,3 @@ def log_model_inspection(model: nn.Module) -> None:
     from vllm.model_inspection import format_model_inspection
 
     logger.info("vLLM model structure:\n%s", format_model_inspection(model))
-
-
-def _has_online_quant(model: nn.Module):
-    for module in model.modules():
-        quant_method = getattr(module, "quant_method", None)
-        if getattr(quant_method, "uses_meta_device", False):
-            return True
-
-    return False
diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py
index 5a8b5de6f553..156071f1dae3 100644
--- a/vllm/model_executor/model_loader/dummy_loader.py
+++ b/vllm/model_executor/model_loader/dummy_loader.py
@@ -1,13 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import torch
 import torch.nn as nn
 
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.reload.meta import materialize_meta_tensor
-from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
 from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
@@ -26,12 +23,6 @@ def download_model(self, model_config: ModelConfig) -> None:
         pass  # Nothing to download
 
     def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        # materialize meta tensors as part of online quantization lifecycle
-        for layer in model.modules():
-            for name, param in get_layer_tensors(layer).items():
-                if param.device == torch.device("meta"):
-                    setattr(layer, name, materialize_meta_tensor(param))
-
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
         initialize_dummy_weights(model, model_config)
diff --git a/vllm/model_executor/model_loader/reload/__init__.py b/vllm/model_executor/model_loader/reload/__init__.py
index 56a9d88ac4e4..ea0b0bc06ad9 100644
--- a/vllm/model_executor/model_loader/reload/__init__.py
+++ b/vllm/model_executor/model_loader/reload/__init__.py
@@ -21,14 +21,12 @@
 __all__ = [
     "record_metadata_for_reloading",
     "initialize_layerwise_reload",
-    "finalize_layerwise_processing",
     "finalize_layerwise_reload",
     "set_torchao_reload_attrs",
     "support_quantized_model_reload_from_hp_weights",
 ]
 
 from .layerwise import (
-    finalize_layerwise_processing,
     finalize_layerwise_reload,
     initialize_layerwise_reload,
     record_metadata_for_reloading,
diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py
index 2a174673b91b..21795e63995e 100644
--- a/vllm/model_executor/model_loader/reload/layerwise.py
+++ b/vllm/model_executor/model_loader/reload/layerwise.py
@@ -28,7 +28,6 @@
     "get_layerwise_info",
     "record_metadata_for_reloading",
     "initialize_layerwise_reload",
-    "finalize_layerwise_processing",
     "finalize_layerwise_reload",
 ]
 
@@ -90,7 +89,7 @@ def initialize_layerwise_reload(model: torch.nn.Module):
         info = get_layerwise_info(layer)
 
         # Skip if the layer has already been initialized
-        if info.can_load():
+        if info.can_process():
             continue
 
         # Save current tensors for later copying
@@ -99,21 +98,15 @@ def initialize_layerwise_reload(model: torch.nn.Module):
         # Restore layer parameters/buffers onto meta device
         restore_layer_on_meta(layer, info)
 
-        initialize_online_processing(layer)
+        # Track loading progress to determine when to process/copy
+        info.load_numel = 0
+        info.load_numel_total = get_layer_size(layer)
 
-
-def initialize_online_processing(layer: torch.nn.Module):
-    info = get_layerwise_info(layer)
-
-    # Track loading progress to determine when to process/copy
-    info.load_numel = 0
-    info.load_numel_total = get_layer_size(layer)
-
-    # Wrap each parameter's weight loader
-    # Note that nested wrapping will occur for shared tensors
-    for name, tensor in get_layer_tensors(layer).items():
-        if _get_weight_loader(tensor).__name__ != "online_process_loader":
-            tensor.weight_loader = make_online_process_loader(layer, name)
+        # Wrap each parameter's weight loader
+        # Note that nested wrapping will occur for shared tensors
+        for name, tensor in get_layer_tensors(layer).items():
+            if _get_weight_loader(tensor).__name__ != "online_process_loader":
+                tensor.weight_loader = make_online_process_loader(layer, name)
 
 
 def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
@@ -125,7 +118,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
     @wraps(original_loader, assigned=("__doc__", "__annotations__"))
     def online_process_loader(*args, **kwargs):
-        if not info.can_load():
+        if not info.can_process():
             # Unfortunately, some qconfigs are set up to load the same weight
             # multiple times. For example, CT_WNA16 loads `weight_shape` for
             # each of the qkv partitions. This results in layers loading extra
@@ -147,7 +140,7 @@ def online_process_loader(*args, **kwargs):
         bound_args = loader_signature.bind(*args, **kwargs)
         bound_args.apply_defaults()
 
-        # Buffer loaded weights, track loading progress
+        # Cache loaded weights, track loading progress
         info.loaded_weights.append((param_name, bound_args))
         num_loaded, ret = get_numel_loaded(original_loader, bound_args)
         info.load_numel += num_loaded
@@ -170,26 +163,19 @@ def online_process_loader(*args, **kwargs):
     return online_process_loader
 
 
-def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelConfig):
+def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
     """
-    Apply processing to any layers which were not layerwise processed during loading.
-    This includes attention layers and layers which have weight elements which are not
-    loaded (due to padding).
+    Remove the outermost layer of weight loading wrappers.
 
     This function should be applied after `initialize_layerwise_reload` is applied
     unwrap the layerwise weight loaders.
 
-    :param model: model to finalize processing for
-    :param model_config: config needed for applying processing to attention layers
+    Also processes Attention/MLA layers, which must be processed after all other layers
     """
-    if hasattr(model, "_original_do_torchao_reload"):
-        model._do_torchao_reload = model._original_do_torchao_reload
+    model._do_torchao_reload = model._original_do_torchao_reload
 
     for layer in model.modules():
         info = get_layerwise_info(layer)
-        if not info.can_load():
-            info.reset()
-            continue
 
         # Attention/MLA layers are processed after all other layers
         if isinstance(layer, (Attention, MLAAttention)):
@@ -198,29 +184,17 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
                     "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                 )
 
-            elif info.kernel_tensors is None:
-                raise NotImplementedError(
-                    "Layerwise loading of Q/K/V scale weights is not implemented yet"
-                )
-
             else:
                 _place_kernel_tensors(layer, info)
                 layer.process_weights_after_loading(model_config.dtype)
 
-        # No weights were loaded
-        elif info.load_numel <= 0:
-            # first load but received no weights. This happens on dummy load
-            if info.kernel_tensors is None:
-                materialize_layer(layer)
-
-            # reloading: place kernel tensors back as a fallback
-            else:
-                logger.warning("%s: Failed to load weights", layer.__class__.__name__)
-                _place_kernel_tensors(layer, info)
+        # No weights were loaded, place kernel tensors back
+        elif info.can_process() and info.load_numel <= 0:
+            _place_kernel_tensors(layer, info)
 
         # Process non-attention layers which did not load all elements. This can happen
         # if the created weight has extra padding elements which are not loaded
         # Having too many of these delayed layers can lead to excess memory usage
         # see Limitations(4)
         elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
             logger.debug("%s: Delayed processing", layer.__class__.__name__)
@@ -229,24 +203,20 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
         info.reset()
 
 
-def finalize_layerwise_reload(*args, **kwargs):
-    finalize_layerwise_processing(*args, **kwargs)
-
-
 def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
     """
-    Finalize layer loading after all weights have been buffered.
+    Finalize layer loading after all weights have been cached.
 
     This function:
    1. Materializes the layer onto the target device
-    2. Loads all buffered weights
+    2. Loads all cached weights
     3. Runs quantization processing if applicable
     4. Copies processed values back to original tensor storage
     """
 
     # Materialize layer tensors onto device
     materialize_layer(layer)
 
-    # Reset online quantization flag so process_weights_after_loading
+    # Reset FP8 online quantization flag so process_weights_after_loading
     # will run again during reload
     if hasattr(layer, "_already_called_process_weights_after_loading"):
         delattr(layer, "_already_called_process_weights_after_loading")
@@ -255,7 +225,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
     for param in get_layer_tensors(layer).values():
         param.weight_loader = _get_original_loader(param)
 
-    # Load all buffered weights into materialized layer (using original loaders)
+    # Load all cached weights into materialized layer (using original loaders)
     for name, args in info.loaded_weights:
         param = getattr(layer, name)
         args.arguments["param"] = param
@@ -269,14 +239,13 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
 
     # Copy processed values into original tensor storage (preserves cudagraph refs)
     # this code is a no-op if not reloading (because kernel tensors is empty)
-    if info.kernel_tensors is not None:
-        parameters, buffers = info.kernel_tensors
-        for name, param in parameters.items():
-            param.data.copy_(getattr(layer, name))
-        for name, buffer in buffers.items():
-            buffer.data.copy_(getattr(layer, name))
+    parameters, buffers = info.kernel_tensors
+    for name, param in parameters.items():
+        param.data.copy_(getattr(layer, name))
+    for name, buffer in buffers.items():
+        buffer.data.copy_(getattr(layer, name))
 
-        _place_kernel_tensors(layer, info)
+    _place_kernel_tensors(layer, info)
 
     info.reset()
     logger.debug("%s: Processed", layer.__class__.__name__)
@@ -299,7 +268,6 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
     for name in get_layer_tensors(layer):
         delattr(layer, name)
 
-    assert info.kernel_tensors is not None
     parameters, buffers = info.kernel_tensors
     for name, param in parameters.items():
         layer.register_parameter(name, param)
diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py
index 138b9f01d69b..af20236d1c9d 100644
--- a/vllm/model_executor/model_loader/reload/meta.py
+++ b/vllm/model_executor/model_loader/reload/meta.py
@@ -104,7 +104,7 @@ def materialize_layer(layer: torch.nn.Module) -> None:
         setattr(layer, name, materialize_meta_tensor(tensor))
 
 
-class CopyCounter(TorchDispatchMode):
+class MetaCopyCounter(TorchDispatchMode):
     """
     Tracks total number of elements modified with `copy_`.
@@ -122,7 +122,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
 
-        if func is torch.ops.aten.copy_.default:
+        if func is torch.ops.aten.copy_.default and args[0].device.type == "meta":
             assert args[0].numel() == args[1].numel()
             self.copied_numel += args[0].numel()
 
@@ -140,6 +140,7 @@ def get_numel_loaded(
     :return: number of elements loaded by the weight loader,
         the return value of the weight loader
     """
-    with CopyCounter() as counter:
+    assert args.arguments["param"].device.type == "meta"
+    with MetaCopyCounter() as counter:
         return_value = weight_loader(*args.args, **args.kwargs)
     return counter.copied_numel, return_value
diff --git a/vllm/model_executor/model_loader/reload/types.py b/vllm/model_executor/model_loader/reload/types.py
index b1506fadcc71..a7edbe79a75e 100644
--- a/vllm/model_executor/model_loader/reload/types.py
+++ b/vllm/model_executor/model_loader/reload/types.py
@@ -16,8 +16,8 @@ class LayerReloadingInfo:
     # model format (meta), populated by `record_metadata_for_reloading`
     restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
 
-    # kernel format (device), used to copy into when reloading only
-    kernel_tensors: LayerTensors | None = None
+    # kernel format (device)
+    kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
 
     # track how many restored elements are ready for loading
     load_numel: int = 0
@@ -29,5 +29,5 @@ class LayerReloadingInfo:
     def reset(self):
         self.__init__(restore_metadata=self.restore_metadata)  # type: ignore[misc]
 
-    def can_load(self) -> bool:
+    def can_process(self) -> bool:
         return self.load_numel_total is not None
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 37023d3f1f5c..fbaaef59de0b 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -1323,11 +1323,25 @@ def initialize_dummy_weights(
     is fixed, the random values generated by this function only depends on the
     parameter's number of elements and its data type.
     """
+
+    # Check if any module uses online quantization with meta device weights.
+    # If so, we'll skip initializing params on meta device since they'll be
+    # handled in `process_weights_after_loading`.
+    def uses_meta_device(module: torch.nn.Module) -> bool:
+        quant_method = getattr(module, "quant_method", None)
+        return getattr(quant_method, "uses_meta_device", False)
+
+    has_online_quant = any(uses_meta_device(m) for m in model.modules())
+
     for param in model.state_dict().values():
+        if has_online_quant and param.device == torch.device("meta"):
+            # For online quantization, weights are created on meta device and
+            # dummy weight init will happen in `process_weights_after_loading`.
+            continue
+
         initialize_single_dummy_weight(param, low, high, seed)
 
 
-@torch.no_grad()
 def initialize_single_dummy_weight(
     param: torch.Tensor,
     low: float = -1e-3,
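
Note (not part of the patch): the `patched_weight_loader` functions added to `Fp8OnlineLinearMethod` and its MoE counterpart above all follow the same streaming idea — count the elements actually written by each weight-loader call and run `process_weights_after_loading` as soon as the whole parameter has been filled. The sketch below is a minimal, self-contained illustration of that mechanism; `make_streaming_loader`, `on_complete`, and this simplified `CopyNumelCounter` are illustrative stand-ins, not the vLLM implementations.

import torch
from torch.utils._python_dispatch import TorchDispatchMode


class CopyNumelCounter(TorchDispatchMode):
    """Count elements written by `aten.copy_` while the mode is active."""

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func is torch.ops.aten.copy_.default:
            # copy_ writes exactly args[0].numel() destination elements
            self.copied_numel += args[0].numel()
        return func(*args, **(kwargs or {}))


def make_streaming_loader(layer, param_name, weight_loader, on_complete):
    """Wrap `weight_loader` so `on_complete(layer)` fires once every element
    of `getattr(layer, param_name)` has been written (hypothetical helper)."""

    def patched(param, loaded_weight, *args, **kwargs):
        if not hasattr(layer, "_loaded_numel"):
            layer._loaded_numel = 0
        counter = CopyNumelCounter()
        with counter:
            out = weight_loader(param, loaded_weight, *args, **kwargs)
        layer._loaded_numel += counter.copied_numel
        if layer._loaded_numel == getattr(layer, param_name).numel():
            on_complete(layer)  # e.g. quantize via process_weights_after_loading
        return out

    return patched

The real code in this patch additionally materializes the meta-device parameter on the first loader call and keeps `layer._loaded_numel` (and the original parameter ids in the MoE case) so that loaders which run without copying anything — for example non-local experts under expert parallelism — do not retrigger initialization or processing.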