diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d758edd9ca50..01571b946c81 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -5,7 +5,6 @@ import torch from torch.nn import Module -from torch.utils._python_dispatch import TorchDispatchMode import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -223,26 +222,6 @@ def get_cache_scale(self, name: str) -> str | None: return None -class CopyNumelCounter(TorchDispatchMode): - """ - Tracks total number of elements modified with `copy_`. Useful for keeping - track of weight loading where underlying weights can be arbitrarily - transformed (such as with `narrow`) before calling copy. - """ - - def __init__(self): - super().__init__() - self.copied_numel = 0 - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - out = func(*args, **kwargs) - if func == torch.ops.aten.copy_.default: - self.copied_numel += args[0].numel() - return out - - def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None: """Copies any attrs present in `old` but not in `new` to `new`""" new_attrs = set(dir(new)) @@ -515,75 +494,26 @@ def create_weights( layer.weight_block_size = None # WEIGHT - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - - # when the first `loaded_weight` is about to be - # loaded to `param`, materialize `param` just-in-time - weight = ModelWeightParameter( - data=torch.empty_like(layer.weight, device=layer._load_device), - input_dim=1, - output_dim=0, - weight_loader=patched_weight_loader, - ) - _copy_missing_attrs(layer.weight, weight) - layer.register_parameter("weight", weight) - del layer._load_device - - # refresh the reference to `param` to reflect just-in-time - # materialization - param = layer.weight - - # load the current weight chunk - copy_numel_counter = CopyNumelCounter() - with copy_numel_counter: - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - layer._loaded_numel += copy_numel_counter.copied_numel - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.weight.numel() - if layer._loaded_numel == target_loaded_numel: - self.process_weights_after_loading(layer) - - # Prevent the usual `process_weights_after_loading` call from doing - # anything - layer._already_called_process_weights_after_loading = True - - # Note that we keep `layer._loaded_numel` around just in case - # there is logic added to vllm in the future which calls a - # weight loader twice - we do not want to re-initialize in - # that case. - - return res - weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, - # materialized just-in-time in `patched_weight_loader` + # materialized just-in-time with layerwise loading device="meta", dtype=params_dtype, ), input_dim=1, output_dim=0, - weight_loader=patched_weight_loader, + weight_loader=weight_loader, ) - # stash the correct device for `patched_weight_loader` - layer._load_device = torch.get_default_device() layer.register_parameter("weight", weight) def process_weights_after_loading(self, layer: Module) -> None: - if getattr(layer, "_already_called_process_weights_after_loading", False): - return - # deferred initialization of randomly initialized weights for the # `--load_format dummy` feature if layer.weight.device == torch.device("meta"): weight = ModelWeightParameter( - data=torch.empty_like(layer.weight, device=layer._load_device), + data=torch.empty_like(layer.weight, device=torch.get_default_device()), input_dim=1, output_dim=0, weight_loader=layer.weight.weight_loader, @@ -612,9 +542,6 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = qweight.t() replace_parameter(layer, "weight", weight.data) - # Prevent duplicate processing (e.g., during weight reload) - layer._already_called_process_weights_after_loading = True - class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. @@ -1012,86 +939,13 @@ def create_weights( layer.orig_dtype = params_dtype layer.weight_block_size = None - # We are doing online quantization, patch the weight loaded - # to call `process_weights_after_loading` in a streaming fashion - # as soon as the last weight chunk is loaded. - weight_loader = extra_weight_attrs["weight_loader"] - # create a new holder to prevent modifying behavior of any other - # objects which might depend on the old one - new_extra_weight_attrs = extra_weight_attrs - - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # add a counter to track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - - # save the ids of original w13 and w2 so that we can - # distinguish which one `param` should map to further - # down in this file - layer._w13_weight_orig_id = id(layer.w13_weight) - layer._w2_weight_orig_id = id(layer.w2_weight) - - # when the first `loaded_weight` is about to be - # loaded to `param`, materialize `param` just-in-time - - w13_weight = torch.nn.Parameter( - torch.empty_like(layer.w13_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs(w13_weight, extra_weight_attrs) - _copy_missing_attrs(layer.w13_weight, w13_weight) - layer.register_parameter("w13_weight", w13_weight) - - w2_weight = torch.nn.Parameter( - torch.empty_like(layer.w2_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs(w2_weight, extra_weight_attrs) - _copy_missing_attrs(layer.w2_weight, w2_weight) - layer.register_parameter("w2_weight", w2_weight) - del layer._load_device - - # refresh the reference to `param` to reflect just-in-time - # materialization - if id(param) == layer._w13_weight_orig_id: - param = layer.w13_weight - elif id(param) == layer._w2_weight_orig_id: - param = layer.w2_weight - - # load the current weight chunk - copy_numel_counter = CopyNumelCounter() - with copy_numel_counter: - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - layer._loaded_numel += copy_numel_counter.copied_numel - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() - if layer._loaded_numel == target_loaded_numel: - self.process_weights_after_loading(layer) - - # Prevent the usual `process_weights_after_loading` call - # from doing anything - layer._already_called_process_weights_after_loading = True - - # Note that we keep `layer._loaded_numel`, - # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` - # around because if EP is on, weight loaders for non-local - # experts will run but not actually copy any elements, and we - # need to not re-initialize in that case. - - return res - - new_extra_weight_attrs["weight_loader"] = patched_weight_loader - extra_weight_attrs = new_extra_weight_attrs - # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, - # materialized just-in-time in `patched_weight_loader` + # materialized just-in-time with layerwise loading device="meta", dtype=params_dtype, ), @@ -1105,7 +959,7 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): num_experts, hidden_size, intermediate_size_per_partition, - # materialized just-in-time in `patched_weight_loader` + # materialized just-in-time with layerwise loading device="meta", dtype=params_dtype, ), @@ -1118,52 +972,42 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): # BIASES (for models like GPT-OSS that have biased MoE) if self.moe.has_bias: - # Use the original weight_loader (not patched) for biases - orig_extra_weight_attrs = dict(extra_weight_attrs) - orig_extra_weight_attrs["weight_loader"] = weight_loader w13_bias = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition, dtype=layer.orig_dtype, + # materialized just-in-time with layerwise loading + # Note: this is currently broken for gpt-oss because it + # does not use weight loaders at all in the bf16 weights + # path + device="meta", ), requires_grad=False, ) layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, orig_extra_weight_attrs) + set_weight_attrs(w13_bias, extra_weight_attrs) w2_bias = torch.nn.Parameter( torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype), requires_grad=False, + # materialized just-in-time with layerwise loading + # Note: this is currently broken for gpt-oss because it + # does not use weight loaders at all in the bf16 weights + # path + device="meta", ) layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, orig_extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_bias, extra_weight_attrs) layer.w13_input_scale = None layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: - if getattr(layer, "_already_called_process_weights_after_loading", False): - return - # deferred initialization of randomly initialized weights for the # `--load_format dummy` feature if layer.w13_weight.device == torch.device("meta"): w13_weight = torch.nn.Parameter( - torch.empty_like(layer.w13_weight, device=layer._load_device), + torch.empty_like(layer.w13_weight, device=torch.get_default_device()), requires_grad=False, ) set_weight_attrs( @@ -1174,7 +1018,7 @@ def process_weights_after_loading(self, layer: Module) -> None: initialize_single_dummy_weight(layer.w13_weight) if layer.w2_weight.device == torch.device("meta"): w2_weight = torch.nn.Parameter( - torch.empty_like(layer.w2_weight, device=layer._load_device), + torch.empty_like(layer.w2_weight, device=torch.get_default_device()), requires_grad=False, ) set_weight_attrs( @@ -1184,6 +1028,20 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.register_parameter("w2_weight", w2_weight) initialize_single_dummy_weight(layer.w2_weight) + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + num_experts = layer.num_experts + with layer.w13_weight.device: + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # If checkpoint is fp16, quantize in place. fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) @@ -1210,9 +1068,6 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_input_scale, ) - # Prevent duplicate processing (e.g., during weight reload) - layer._already_called_process_weights_after_loading = True - class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index e3b965db8aaf..a9bc37faae53 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -9,6 +9,10 @@ from vllm.config import ModelConfig, VllmConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.model_loader.reload import ( + finalize_layerwise_reload, + initialize_layerwise_reload, +) from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, @@ -58,8 +62,27 @@ def load_model( log_model_inspection(model) logger.debug("Loading weights on %s ...", load_device) - # Quantization does not happen in `load_weights` but after it - self.load_weights(model, model_config) + + use_layerwise_loading = _get_use_layerwise_loading(model, self) + + if use_layerwise_loading: + # set up layer loading + initialize_layerwise_reload( + model, is_reload=False, target_device=load_device + ) + # load weights, quantization via each layer's + # `process_weights_after_loading` will happen for each layer + # as soon as all of that layer's weights are loaded + self.load_weights(model, model_config) + # finalize layer reloading + finalize_layerwise_reload(model, model_config, is_reload=False) + + else: + # Load weights to model format + self.load_weights(model, model_config) + # For layers with quantization, convert to kernel format + with target_device: + process_weights_after_loading(model, model_config, target_device) # Log peak GPU memory after loading weights. This is needed # to have test coverage on peak memory for online quantization. @@ -71,11 +94,24 @@ def load_model( scope="local", ) - process_weights_after_loading(model, model_config, target_device) - return model.eval() +def _get_use_layerwise_loading( + model: torch.nn.Module, + model_loader: BaseModelLoader, +) -> bool: + from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader + from vllm.model_executor.model_loader.utils import ( + model_has_any_online_quant_with_device_meta, + ) + + has_online_quant = model_has_any_online_quant_with_device_meta(model) + + is_dummy_loader = isinstance(model_loader, DummyModelLoader) + return has_online_quant and not is_dummy_loader + + def log_model_inspection(model: nn.Module) -> None: """Log model structure if VLLM_LOG_MODEL_INSPECTION=1.""" if not envs.VLLM_LOG_MODEL_INSPECTION: diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index 21795e63995e..b9f4139f1de1 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -17,10 +17,16 @@ capture_layer_to_meta, get_numel_loaded, materialize_layer, + materialize_layer_tensors_with_device_meta, restore_layer_on_meta, ) from .types import LayerReloadingInfo -from .utils import get_layer_params_buffers, get_layer_size, get_layer_tensors +from .utils import ( + CopyCounter, + get_layer_params_buffers, + get_layer_size, + get_layer_tensors, +) logger = init_logger(__name__) @@ -66,20 +72,31 @@ def record_metadata_for_reloading(model: torch.nn.Module): @torch.no_grad() -def initialize_layerwise_reload(model: torch.nn.Module): +def initialize_layerwise_reload( + model: torch.nn.Module, + is_reload: bool = True, + target_device: torch.device | None = None, +): """ - Set up layerwise weight loading with deferred processing. - - Must be called after `record_metadata_for_reloading`. This function: - 1. Saves current kernel tensors for later copying - 2. Restores layer parameters/buffers from metadata (on meta device) - 3. Wraps weight loaders to defer processing until all weights are loaded - - When all weights for a layer are loaded, the wrapped loaders will: - 1. Materialize the layer onto the target device - 2. Load all cached weights - 3. Run quantization processing if applicable - 4. Copy processed values back to original tensor storage + Set up layerwise weight reloading|initial loading with deferred processing. + + For weight reloading (is_reload = True): + + Must be called after `record_metadata_for_reloading`. This function: + 1. Saves current kernel tensors for later copying + 2. Restores layer parameters/buffers from metadata (on meta device) + 3. Wraps weight loaders to defer processing until all weights are loaded + + When all weights for a layer are loaded, the wrapped loaders will: + 1. Materialize the layer onto the target device + 2. Load all cached weights + 3. Run quantization processing if applicable + 4. Copy processed values back to original tensor storage + + For weight initial loading (is_reload = False): + + 1. saves `target_device` to be used during weight loading + 2. wraps all weight loaders to enable deferred processing """ # disable torchao reloading to avoid infinite recursion model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False) @@ -93,10 +110,16 @@ def initialize_layerwise_reload(model: torch.nn.Module): continue # Save current tensors for later copying - info.kernel_tensors = get_layer_params_buffers(layer) - - # Restore layer parameters/buffers onto meta device - restore_layer_on_meta(layer, info) + if is_reload: + # reload path, TODO document more + assert target_device is None + info.kernel_tensors = get_layer_params_buffers(layer) + # Restore layer parameters/buffers onto meta device + restore_layer_on_meta(layer, info) + else: + # initial load path + assert target_device is not None + info.initial_load_target_device = target_device # Track loading progress to determine when to process/copy info.load_numel = 0 @@ -106,11 +129,21 @@ def initialize_layerwise_reload(model: torch.nn.Module): # Note that nested wrapping will occur for shared tensors for name, tensor in get_layer_tensors(layer).items(): if _get_weight_loader(tensor).__name__ != "online_process_loader": - tensor.weight_loader = make_online_process_loader(layer, name) + tensor.weight_loader = make_online_process_loader( + layer, name, is_reload + ) -def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable: - """Create a wrapped weight loader that defers processing.""" +def make_online_process_loader( + layer: torch.nn.Module, param_name: str, is_reload: bool +) -> Callable: + """ + Create a wrapped weight loader that defers processing. + + * If `is_reload` is True, weights are cached on CPU and only actually + loaded when every chunk has been processed. + * If `is_reload` is False, weights are loaded directly (no CPU cache) + """ info = get_layerwise_info(layer) param = getattr(layer, param_name) original_loader = _get_original_loader(param) @@ -133,16 +166,42 @@ def online_process_loader(*args, **kwargs): # there's no way to reach `load_numel_total` without loading all # necessary weights. `weight_shape` is very small, so this is safe. # see Limitations(4) - logger.debug("%s: Excessive loading", layer.__class__.__name__) + logger.debug( + "%s: Excessive loading of param '%s'", + layer.__class__.__name__, + param_name, + ) return # Bind and normalize arguments bound_args = loader_signature.bind(*args, **kwargs) bound_args.apply_defaults() - # Cache loaded weights, track loading progress - info.loaded_weights.append((param_name, bound_args)) - num_loaded, ret = get_numel_loaded(original_loader, bound_args) + if is_reload: + # Cache loaded weights, track loading progress. Weights will be + # actually loaded later. + info.loaded_weights.append((param_name, bound_args)) + num_loaded, ret = get_numel_loaded(original_loader, bound_args) + + else: + if info.load_numel == 0: + # If this is the first weight chunk being loaded, materialize any + # non-materialized tensors in this layer + assert info.initial_load_target_device is not None + with info.initial_load_target_device: + materialize_layer_tensors_with_device_meta(layer) + + # Update param reference to point to the current (materialized) + # tensor instead of the old meta tensor that was captured when + # the loader was wrapped + current_param = getattr(layer, param_name) + bound_args.arguments["param"] = current_param + + # Load weights directly + with CopyCounter() as counter: + ret = original_loader(*bound_args.args, **bound_args.kwargs) + num_loaded = counter.copied_numel + info.load_numel += num_loaded logger.debug( @@ -156,14 +215,16 @@ def online_process_loader(*args, **kwargs): if info.load_numel >= info.load_numel_total and not isinstance( # type: ignore[operator] layer, (Attention, MLAAttention) ): - _layerwise_process(layer, info) + _layerwise_process(layer, info, is_reload) return ret return online_process_loader -def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig): +def finalize_layerwise_reload( + model: torch.nn.Module, model_config: ModelConfig, is_reload: bool = True +): """ Remove the outermost layer of weight loading wrappers. @@ -185,12 +246,14 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig) ) else: - _place_kernel_tensors(layer, info) + if is_reload: + _place_kernel_tensors(layer, info) layer.process_weights_after_loading(model_config.dtype) # No weights were loaded, place kernel tensors back elif info.can_process() and info.load_numel <= 0: - _place_kernel_tensors(layer, info) + if is_reload: + _place_kernel_tensors(layer, info) # Process non-attention layers which did not load all elements. This can happen # if the created weight has extra padding elements which are not loaded @@ -198,38 +261,40 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig) # see Limitations(4) elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator] logger.debug("%s: Delayed processing", layer.__class__.__name__) - _layerwise_process(layer, info) + _layerwise_process(layer, info, is_reload) info.reset() -def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): +def _layerwise_process( + layer: torch.nn.Module, info: LayerReloadingInfo, is_reload: bool +): """ Finalize layer loading after all weights have been cached. - This function: - 1. Materializes the layer onto the target device - 2. Loads all cached weights - 3. Runs quantization processing if applicable - 4. Copies processed values back to original tensor storage - """ - # Materialize layer tensors onto device - materialize_layer(layer) + If `is_reload` is True, this function: + 1. Materializes the layer onto the target device + 2. Loads all cached weights + 3. Runs `process_weights_after_loading` if applicable + 4. Copies processed values back to original tensor storage - # Reset FP8 online quantization flag so process_weights_after_loading - # will run again during reload - if hasattr(layer, "_already_called_process_weights_after_loading"): - delattr(layer, "_already_called_process_weights_after_loading") + If `is_reload` is False, this function: + 1. Runs `process_weights_after_loading` if applicable + """ + if is_reload: + # Materialize layer tensors onto device + materialize_layer(layer) # Unwrap layerwise loading wrappers for param in get_layer_tensors(layer).values(): param.weight_loader = _get_original_loader(param) - # Load all cached weights into materialized layer (using original loaders) - for name, args in info.loaded_weights: - param = getattr(layer, name) - args.arguments["param"] = param - param.weight_loader(*args.args, **args.kwargs) + if is_reload: + # Load all cached weights into materialized layer (using original loaders) + for name, args in info.loaded_weights: + param = getattr(layer, name) + args.arguments["param"] = param + param.weight_loader(*args.args, **args.kwargs) # Process weights (quantization, repacking, etc.) # Attention/MLA are processed in `finalize_layerwise_reload` @@ -237,15 +302,17 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): if isinstance(quant_method, QuantizeMethodBase): quant_method.process_weights_after_loading(layer) - # Copy processed values into original tensor storage (preserves cudagraph refs) - # this code is a no-op if not reloading (because kernel tensors is empty) - parameters, buffers = info.kernel_tensors - for name, param in parameters.items(): - param.data.copy_(getattr(layer, name)) - for name, buffer in buffers.items(): - buffer.data.copy_(getattr(layer, name)) + if is_reload: + # Reloading path + # Copy processed values into original tensor storage (preserves cudagraph refs) + # this code is a no-op if not reloading (because kernel tensors is empty) + parameters, buffers = info.kernel_tensors + for name, param in parameters.items(): + param.data.copy_(getattr(layer, name)) + for name, buffer in buffers.items(): + buffer.data.copy_(getattr(layer, name)) - _place_kernel_tensors(layer, info) + _place_kernel_tensors(layer, info) info.reset() logger.debug("%s: Processed", layer.__class__.__name__) @@ -265,6 +332,7 @@ def _get_weight_loader(tensor: torch.Tensor): def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo): + """Assign each parameter/buffer in `info.kernel_tensors` back onto `layer`""" for name in get_layer_tensors(layer): delattr(layer, name) diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py index af20236d1c9d..ca4b5f0bff5f 100644 --- a/vllm/model_executor/model_loader/reload/meta.py +++ b/vllm/model_executor/model_loader/reload/meta.py @@ -104,6 +104,17 @@ def materialize_layer(layer: torch.nn.Module) -> None: setattr(layer, name, materialize_meta_tensor(tensor)) +def materialize_layer_tensors_with_device_meta(layer: torch.nn.Module) -> None: + """Materialize all meta tensors in a layer to actual tensors.""" + if layer.__class__.__name__ in SKIP_MODULES: + return + + for name, tensor in get_layer_tensors(layer).items(): + if name in SKIP_TENSORS or tensor.device != torch.device("meta"): + continue + setattr(layer, name, materialize_meta_tensor(tensor)) + + class MetaCopyCounter(TorchDispatchMode): """ Tracks total number of elements modified with `copy_`. diff --git a/vllm/model_executor/model_loader/reload/types.py b/vllm/model_executor/model_loader/reload/types.py index a7edbe79a75e..a369f3a01958 100644 --- a/vllm/model_executor/model_loader/reload/types.py +++ b/vllm/model_executor/model_loader/reload/types.py @@ -26,6 +26,9 @@ class LayerReloadingInfo: # stores arguments and tensors ready for loading loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list) + # device to materialize meta tensors to for initial loading + initial_load_target_device: torch.device | None = None + def reset(self): self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc] diff --git a/vllm/model_executor/model_loader/reload/utils.py b/vllm/model_executor/model_loader/reload/utils.py index 463ff6422213..16e87e0c857d 100644 --- a/vllm/model_executor/model_loader/reload/utils.py +++ b/vllm/model_executor/model_loader/reload/utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from torch.utils._python_dispatch import TorchDispatchMode from .types import LayerTensors @@ -39,3 +40,28 @@ def get_layer_size(layer: torch.nn.Module) -> int: for name, tensor in get_layer_tensors(layer).items() if name not in SKIP_TENSORS ) + + +class CopyCounter(TorchDispatchMode): + """ + Tracks total number of elements modified with `copy_`. + + Useful for keeping track of weight loading where underlying weights can be + arbitrarily transformed (such as with `narrow`) before calling copy. + + Note: Assumes that copy kwargs are not used. + """ + + def __init__(self): + super().__init__() + self.copied_numel = 0 + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.ops.aten.copy_.default: + assert args[0].numel() == args[1].numel() + self.copied_numel += args[0].numel() + + return func(*args, **kwargs) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index dc525c4541af..13f5f2c4aa93 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -284,3 +284,15 @@ def configure_quant_config( quant_config.apply_vllm_mapper(hf_to_vllm_mapper) if packed_mapping is not None: quant_config.packed_modules_mapping = packed_mapping + + +def model_has_any_online_quant_with_device_meta(model: nn.Module) -> bool: + """ + Returns True if any module uses online quantization with meta device weights. + """ + + def uses_meta_device(module: torch.nn.Module) -> bool: + quant_method = getattr(module, "quant_method", None) + return getattr(quant_method, "uses_meta_device", False) + + return any(uses_meta_device(m) for m in model.modules()) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index dd4bf636e0af..dd731345acce 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1260,11 +1260,11 @@ def initialize_dummy_weights( # Check if any module uses online quantization with meta device weights. # If so, we'll skip initializing params on meta device since they'll be # handled in `process_weights_after_loading`. - def uses_meta_device(module: torch.nn.Module) -> bool: - quant_method = getattr(module, "quant_method", None) - return getattr(quant_method, "uses_meta_device", False) + from vllm.model_executor.model_loader.utils import ( + model_has_any_online_quant_with_device_meta, + ) - has_online_quant = any(uses_meta_device(m) for m in model.modules()) + has_online_quant = model_has_any_online_quant_with_device_meta(model) for param in model.state_dict().values(): if has_online_quant and param.device == torch.device("meta"):