diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9b7d654335d5..cc6a0d96578b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -555,70 +555,24 @@ def create_weights( layer.weight_block_size = None # WEIGHT - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - - # when the first `loaded_weight` is about to be - # loaded to `param`, materialize `param` just-in-time - weight = ModelWeightParameter( - data=torch.empty_like(layer.weight, device=layer._load_device), - input_dim=1, - output_dim=0, - weight_loader=patched_weight_loader, - ) - _copy_missing_attrs(layer.weight, weight) - layer.register_parameter("weight", weight) - del layer._load_device - - # refresh the reference to `param` to reflect just-in-time - # materialization - param = layer.weight - - # load the current weight chunk - copy_numel_counter = CopyNumelCounter() - with copy_numel_counter: - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - layer._loaded_numel += copy_numel_counter.copied_numel - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.weight.numel() - if layer._loaded_numel == target_loaded_numel: - self.process_weights_after_loading(layer) - - # Prevent the usual `process_weights_after_loading` call from doing - # anything - layer._already_called_process_weights_after_loading = True - - # Note that we keep `layer._loaded_numel` around just in case - # there is logic added to vllm in the future which calls a - # weight loader twice - we do not want to re-initialize in - # that case. 
-
-            return res
-
         weight = ModelWeightParameter(
             data=torch.empty(
                 output_size_per_partition,
                 input_size_per_partition,
                 # materialized just-in-time in `patched_weight_loader`
+                # (it is now materialized in reload/initial_load.py's wrapped loader)
                 device="meta",
                 dtype=params_dtype,
             ),
             input_dim=1,
             output_dim=0,
-            weight_loader=patched_weight_loader,
+            weight_loader=weight_loader,
         )
         # stash the correct device for `patched_weight_loader`
         layer._load_device = torch.get_default_device()
         layer.register_parameter("weight", weight)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
         # deferred initialization of randomly initialized weights for the
         # `--load_format dummy` feature
         if layer.weight.device == torch.device("meta"):
@@ -1074,86 +1028,13 @@ def create_weights(
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None
 
-        # We are doing online quantization, patch the weight loaded
-        # to call `process_weights_after_loading` in a streaming fashion
-        # as soon as the last weight chunk is loaded.
- weight_loader = extra_weight_attrs["weight_loader"] - # create a new holder to prevent modifying behavior of any other - # objects which might depend on the old one - new_extra_weight_attrs = extra_weight_attrs - - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # add a counter to track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - - # save the ids of original w13 and w2 so that we can - # distinguish which one `param` should map to further - # down in this file - layer._w13_weight_orig_id = id(layer.w13_weight) - layer._w2_weight_orig_id = id(layer.w2_weight) - - # when the first `loaded_weight` is about to be - # loaded to `param`, materialize `param` just-in-time - - w13_weight = torch.nn.Parameter( - torch.empty_like(layer.w13_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs(w13_weight, extra_weight_attrs) - _copy_missing_attrs(layer.w13_weight, w13_weight) - layer.register_parameter("w13_weight", w13_weight) - - w2_weight = torch.nn.Parameter( - torch.empty_like(layer.w2_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs(w2_weight, extra_weight_attrs) - _copy_missing_attrs(layer.w2_weight, w2_weight) - layer.register_parameter("w2_weight", w2_weight) - del layer._load_device - - # refresh the reference to `param` to reflect just-in-time - # materialization - if id(param) == layer._w13_weight_orig_id: - param = layer.w13_weight - elif id(param) == layer._w2_weight_orig_id: - param = layer.w2_weight - - # load the current weight chunk - copy_numel_counter = CopyNumelCounter() - with copy_numel_counter: - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - layer._loaded_numel += copy_numel_counter.copied_numel - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() - if layer._loaded_numel 
== target_loaded_numel: - self.process_weights_after_loading(layer) - - # Prevent the usual `process_weights_after_loading` call - # from doing anything - layer._already_called_process_weights_after_loading = True - - # Note that we keep `layer._loaded_numel`, - # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` - # around because if EP is on, weight loaders for non-local - # experts will run but not actually copy any elements, and we - # need to not re-initialize in that case. - - return res - - new_extra_weight_attrs["weight_loader"] = patched_weight_loader - extra_weight_attrs = new_extra_weight_attrs - - # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, - # materialized just-in-time in `patched_weight_loader` + # materialized just-in-time in + # TODO(before review) document where device="meta", dtype=params_dtype, ), @@ -1168,6 +1049,7 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): hidden_size, intermediate_size_per_partition, # materialized just-in-time in `patched_weight_loader` + # TODO(before review) document where device="meta", dtype=params_dtype, ), @@ -1175,29 +1057,12 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - # stash the correct device for `patched_weight_loader` - layer._load_device = torch.get_default_device() - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. 
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
+        # NOTE: the former `_already_called_process_weights_after_loading`
+        # guard is no longer needed; layerwise loading decides when to process.
 
         # deferred initialization of randomly initialized weights for the
         # `--load_format dummy` feature
@@ -1224,7 +1089,21 @@ def process_weights_after_loading(self, layer: Module) -> None:
         layer.register_parameter("w2_weight", w2_weight)
         initialize_single_dummy_weight(layer.w2_weight)
 
-        # If checkpoint is fp16, quantize in place.
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale later.
+ num_experts = layer.num_experts + with layer.w13_weight.device: + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # Quantize the loaded high precision checkpoint to fp8 fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 2c55ee68e25a..288ca4d5e01f 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -9,6 +9,10 @@ from vllm.config import ModelConfig, VllmConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.model_loader.reload.initial_load import ( + finalize_layerwise_initial_load, + initialize_layerwise_initial_load, +) from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, @@ -56,20 +60,43 @@ def load_model( log_model_inspection(model) logger.debug("Loading weights on %s ...", load_device) - # Quantization does not happen in `load_weights` but after it - self.load_weights(model, model_config) - - # Log peak GPU memory after loading weights. This is needed - # to have test coverage on peak memory for online quantization. 
- if current_platform.is_cuda(): - peak_memory = torch.cuda.max_memory_allocated() - logger.debug_once( - "Peak GPU memory after loading weights: %s GiB", - format_gib(peak_memory), - scope="local", - ) - process_weights_after_loading(model, model_config, target_device) + is_online_quant = _is_online_quant(vllm_config, model_config) + if not is_online_quant: + # Regular path, `process_weights_after_loading` is called + # after all weights are loaded. + + # Quantization does not happen in `load_weights` but after it + self.load_weights(model, model_config) + process_weights_after_loading(model, model_config, target_device) + + else: + # Online quantization can take the layerwise loading path + # where `process_weights_after_loading` is done just-in-time + # after all of a layer's weights are loaded. + + # set up weight loaders for layerwise loading with + # streaming post-processing + initialize_layerwise_initial_load(model, target_device) + + # load the weights, layerwise loading infra will call + # each layer's `process_weights_after_loading` function + # as soon as every weight of that layer is loaded + self.load_weights(model, model_config) + + # finalize layerwise reloading (call any post-processing + # that did not happen in real time) + finalize_layerwise_initial_load(model, model_config) + + # Log peak GPU memory after loading weights. This is needed + # to have test coverage on peak memory for online quantization. 
+ if current_platform.is_cuda(): + peak_memory = torch.cuda.max_memory_allocated() + logger.debug_once( + "Peak GPU memory after loading weights: %s GiB", + format_gib(peak_memory), + scope="local", + ) return model.eval() @@ -82,3 +109,14 @@ def log_model_inspection(model: nn.Module) -> None: from vllm.model_inspection import format_model_inspection logger.info("vLLM model structure:\n%s", format_model_inspection(model)) + + +def _is_online_quant(vllm_config: VllmConfig, model_config: ModelConfig) -> bool: + quant_config = vllm_config.quant_config + return ( + # TODO(future): add other online quant paths here + model_config.quantization == "fp8" + and quant_config is not None + and hasattr(quant_config, "is_checkpoint_fp8_serialized") + and not quant_config.is_checkpoint_fp8_serialized + ) diff --git a/vllm/model_executor/model_loader/reload/initial_load.py b/vllm/model_executor/model_loader/reload/initial_load.py new file mode 100644 index 000000000000..52aa9425a2eb --- /dev/null +++ b/vllm/model_executor/model_loader/reload/initial_load.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect +from collections.abc import Callable +from functools import wraps +from weakref import WeakKeyDictionary + +import torch +from torch.utils._python_dispatch import TorchDispatchMode + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.attention import Attention, MLAAttention +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase +from vllm.model_executor.model_loader.reload.layerwise import ( + _get_original_loader, + _get_weight_loader, +) + +from .meta import ( + materialize_layer_tensors_with_device_meta, +) +from .types import LayerReloadingInfo +from .utils import get_layer_size, get_layer_tensors + +logger = init_logger(__name__) + +# Global dict storing information used for layerwise loading 
+INITIAL_LOAD_LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
+    WeakKeyDictionary()
+)
+
+
+def get_initial_load_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
+    """
+    Get information related to restoring and layerwise processing. If no previous
+    information existed, a new entry is constructed.
+    """
+    if layer not in INITIAL_LOAD_LAYERWISE_INFO:
+        INITIAL_LOAD_LAYERWISE_INFO[layer] = LayerReloadingInfo()
+
+    return INITIAL_LOAD_LAYERWISE_INFO[layer]
+
+
+# TODO: deduplicate with `CopyNumelCounter` in layers/quantization/fp8.py
+class CopyCounter(TorchDispatchMode):
+    """
+    Tracks total number of elements modified with `copy_`.
+
+    Useful for keeping track of weight loading where underlying weights can be
+    arbitrarily transformed (such as with `narrow`) before calling copy.
+
+    Note: Assumes that copy kwargs are not used.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.copied_numel = 0
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.ops.aten.copy_.default:
+            assert args[0].numel() == args[1].numel()
+            self.copied_numel += args[0].numel()
+
+        return func(*args, **kwargs)
+
+
+@torch.no_grad()
+def initialize_layerwise_initial_load(model: torch.nn.Module, target_device):
+    """
+    Initialize layerwise initial loading of model weights. In detail:
+
+    1. set up global state to track how many elements have been loaded
+       into each layer
+    2. wrap original weight loaders to turn on layerwise post-processing.
+       Specifically, when all of a weight's chunks are loaded, the
+       `process_weights_after_loading` function will be called immediately.
+       For online quantization this minimizes peak memory usage compared
+       to loading weights for the entire model first and then post-processing
+       weights.
+ """ + for layer in model.modules(): + info = get_initial_load_layerwise_info(layer) + + # Track loading progress to determine when to process/copy + info.load_numel = 0 + info.load_numel_total = get_layer_size(layer) + info.load_device = target_device + + # Wrap each parameter's weight loader + # Note that nested wrapping will occur for shared tensors + for name, tensor in get_layer_tensors(layer).items(): + if ( + _get_weight_loader(tensor).__name__ + != "online_initial_load_process_loader" + ): + tensor.weight_loader = make_online_initial_load_process_loader( + layer, name + ) + + +def make_online_initial_load_process_loader( + layer: torch.nn.Module, param_name: str +) -> Callable: + """Create a wrapped weight loader that defers processing.""" + info = get_initial_load_layerwise_info(layer) + param = getattr(layer, param_name) + original_loader = _get_original_loader(param) + loader_signature = inspect.signature(original_loader) + + @wraps(original_loader, assigned=("__doc__", "__annotations__")) + def online_initial_load_process_loader(*args, **kwargs): + if not info.can_process(): + # Unfortunately, some qconfigs are set up to load the same weight + # multiple times. For example, CT_WNA16 loads `weight_shape` for + # each of the qkv partitions. This results in layers loading extra + # weights (beyond load_numel_total) after it's already processed. + # + # Best solution is to ensure that `load_numel_total` reflects the + # actual number of weights loaded, either by modifying qconfigs to + # create as many weights as loaded (see padding issue as well) + # or maybe capturing how many weights are loaded on first pass + # + # For now, `load_numel_total` is still safe to use as long as + # there's no way to reach `load_numel_total` without loading all + # necessary weights. `weight_shape` is very small, so this is safe. 
+ # see Limitations(4) + logger.debug("%s: Excessive loading", layer.__class__.__name__) + return + + if info.load_numel == 0: + # When the first weight chunk for a layer is seen, + # Materialize any layer tensors on device meta onto device. + # For most layers this is a no-op. For layers which initialize + # weights on device meta during `create_weights`, this is where + # the materialization happens. + with info.load_device: + materialize_layer_tensors_with_device_meta(layer) + + # Bind and normalize arguments + bound_args = loader_signature.bind(*args, **kwargs) + bound_args.apply_defaults() + + # Update param reference to point to the current (materialized) tensor + # instead of the old meta tensor that was captured when the loader was wrapped + current_param = getattr(layer, param_name) + bound_args.arguments["param"] = current_param + + with CopyCounter() as counter: + ret = original_loader(*bound_args.args, **bound_args.kwargs) + + info.load_numel += counter.copied_numel + + logger.debug( + "%s: %d / %d", + layer.__class__.__name__, + info.load_numel, + info.load_numel_total, + ) + + # Process and copy when all weights are loaded + if info.load_numel >= info.load_numel_total and not isinstance( # type: ignore[operator] + layer, (Attention, MLAAttention) + ): + _layerwise_initial_load_process(layer, info) + + return ret + + return online_initial_load_process_loader + + +def finalize_layerwise_initial_load(model: torch.nn.Module, model_config: ModelConfig): + """ + Call `process_weights_after_loading` for any layers that did not participate + in layerwise loading: + 1. Attention (hardcoded out for now due to data dependencies) + 2. 
layers where not all elements were loaded during `model.load_weights()` + """ + + for layer in model.modules(): + info = get_initial_load_layerwise_info(layer) + + # Attention/MLA layers are processed after all other layers + if isinstance(layer, (Attention, MLAAttention)): + if info.load_numel > 0: + raise NotImplementedError( + "Layerwise loading of Q/K/V scale weights is not implemented yet" + ) + + else: + layer.process_weights_after_loading(model_config.dtype) + + # No weights were loaded, nothing to do + elif info.can_process() and info.load_numel <= 0: + pass + + # Process non-attention layers which did not load all elements. This can happen + # if the created weight has extra padding elements which are not loaded + elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator] + logger.debug("%s: Delayed processing", layer.__class__.__name__) + _layerwise_initial_load_process(layer, info) + + +def _layerwise_initial_load_process(layer: torch.nn.Module, info: LayerReloadingInfo): + # Process weights (quantization, repacking, etc.) 
+    # Attention/MLA are processed in `finalize_layerwise_initial_load`
+    quant_method = getattr(layer, "quant_method", None)
+    if isinstance(quant_method, QuantizeMethodBase):
+        quant_method.process_weights_after_loading(layer)
+        logger.debug("%s: Processed", layer.__class__.__name__)
diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py
index af20236d1c9d..ca4b5f0bff5f 100644
--- a/vllm/model_executor/model_loader/reload/meta.py
+++ b/vllm/model_executor/model_loader/reload/meta.py
@@ -104,6 +104,17 @@ def materialize_layer(layer: torch.nn.Module) -> None:
         setattr(layer, name, materialize_meta_tensor(tensor))
 
 
+def materialize_layer_tensors_with_device_meta(layer: torch.nn.Module) -> None:
+    """Materialize this layer's meta-device tensors in place (skip-listed modules/tensors are untouched)."""
+    if layer.__class__.__name__ in SKIP_MODULES:
+        return
+
+    for name, tensor in get_layer_tensors(layer).items():
+        if name in SKIP_TENSORS or tensor.device != torch.device("meta"):
+            continue
+        setattr(layer, name, materialize_meta_tensor(tensor))
+
+
 class MetaCopyCounter(TorchDispatchMode):
     """
     Tracks total number of elements modified with `copy_`.