Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 2 additions & 155 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,62 +544,16 @@ def create_weights(
layer.orig_dtype = params_dtype
layer.weight_block_size = None

# WEIGHT
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0

# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
del layer._load_device

# refresh the reference to `param` to reflect just-in-time
# materialization
param = layer.weight

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer._already_called_process_weights_after_loading = True

# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.

return res

weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
weight_loader=weight_loader,
)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
Expand Down Expand Up @@ -1057,86 +1011,12 @@ def create_weights(
layer.orig_dtype = params_dtype
layer.weight_block_size = None

# We are doing online quantization, patch the weight loaded
# to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded.
weight_loader = extra_weight_attrs["weight_loader"]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs = extra_weight_attrs

def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0

# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer._w13_weight_orig_id = id(layer.w13_weight)
layer._w2_weight_orig_id = id(layer.w2_weight)

# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time

w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w13_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)

w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w2_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
del layer._load_device

# refresh the reference to `param` to reflect just-in-time
# materialization
if id(param) == layer._w13_weight_orig_id:
param = layer.w13_weight
elif id(param) == layer._w2_weight_orig_id:
param = layer.w2_weight

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True

# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.

return res

new_extra_weight_attrs["weight_loader"] = patched_weight_loader
extra_weight_attrs = new_extra_weight_attrs

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
Expand All @@ -1150,20 +1030,18 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
num_experts,
hidden_size,
intermediate_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# TODO: technically should be meta
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
Expand All @@ -1179,34 +1057,6 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
layer.w2_input_scale = None

def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)

# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
Expand All @@ -1233,9 +1083,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w2_input_scale,
)

# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True


class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Expand Down
29 changes: 25 additions & 4 deletions vllm/model_executor/model_loader/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.reload.layerwise import (
finalize_layerwise_reload,
initialize_layerwise_reload,
)
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
Expand Down Expand Up @@ -55,11 +59,17 @@ def load_model(
vllm_config=vllm_config, model_config=model_config, prefix=prefix
)

log_model_inspection(model)
log_model_inspection(model)

logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
logger.debug("Loading weights on %s ...", load_device)
if not _is_online_quant(vllm_config, model_config):
# load weights eagerly, which may lead to excess memory usage
self.load_weights(model, model_config)
else:
# load weights layerwise, which minimizes peak memory usage
initialize_layerwise_reload(model, is_reloading=False)
self.load_weights(model, model_config)
finalize_layerwise_reload(model, model_config)

# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
Expand All @@ -84,3 +94,14 @@ def log_model_inspection(model: nn.Module) -> None:
from vllm.model_inspection import format_model_inspection

logger.info("vLLM model structure:\n%s", format_model_inspection(model))


def _is_online_quant(vllm_config: VllmConfig, model_config: ModelConfig) -> bool:
    """Return True when fp8 online (on-the-fly) quantization will be used.

    Online quantization applies when the model is configured for fp8 but the
    checkpoint itself is not already fp8-serialized, i.e. weights must be
    quantized after loading.
    """
    # TODO(future): add other online quant paths here
    if model_config.quantization != "fp8":
        return False
    quant_cfg = vllm_config.quant_config
    if quant_cfg is None:
        return False
    if not hasattr(quant_cfg, "is_checkpoint_fp8_serialized"):
        return False
    # Checkpoint already serialized as fp8 -> no online quantization needed.
    return not quant_cfg.is_checkpoint_fp8_serialized
54 changes: 33 additions & 21 deletions vllm/model_executor/model_loader/reload/layerwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def record_metadata_for_reloading(model: torch.nn.Module):


@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
def initialize_layerwise_reload(model: torch.nn.Module, is_reloading: bool = True):
"""
Set up layerwise weight loading with deferred processing.

Expand All @@ -92,8 +92,8 @@ def initialize_layerwise_reload(model: torch.nn.Module):
if info.can_process():
continue

# Save current tensors for later copying
info.kernel_tensors = get_layer_params_buffers(layer)
# Save current tensors for later copying (only for reloading)
info.kernel_tensors = get_layer_params_buffers(layer) if is_reloading else None

# Restore layer parameters/buffers onto meta device
restore_layer_on_meta(layer, info)
Expand Down Expand Up @@ -178,22 +178,38 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
info = get_layerwise_info(layer)

# Attention/MLA layers are processed after all other layers
# TODO(@kylesayrs): process attention in a separate for loop
if isinstance(layer, (Attention, MLAAttention)):
if info.load_numel > 0:
raise NotImplementedError(
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)

# Loading: initialize model tensors with empty values
elif info.kernel_tensors is None:
materialize_layer(layer)

# Reloading: place kernel tensors back (assumed to be empty)
else:
_place_kernel_tensors(layer, info)
layer.process_weights_after_loading(model_config.dtype)

# No weights were loaded, place kernel tensors back
layer.process_weights_after_loading(model_config.dtype)

# Non-attention: No weights were loaded
elif info.can_process() and info.load_numel <= 0:
_place_kernel_tensors(layer, info)
if info.load_numel_total is not None and info.load_numel_total > 0:
logger.warning("%s: Did not load weights", layer.__class__.__name__)

# Loading: initialize model tensors with empty values
if info.kernel_tensors is None:
materialize_layer(layer)

# Reloading: place kernel tensors back (assumed to be empty)
else:
_place_kernel_tensors(layer, info)

# Process non-attention layers which did not load all elements. This can happen
# if the created weight has extra padding elements which are not loaded
# Non-attention: Some weights were loaded
# This can happen if the weight has extra padding elements which are not loaded
# Having too many of these delayed layers can lead to excess memory usage
# see Limitations(4)
elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator]
Expand All @@ -216,11 +232,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Materialize layer tensors onto device
materialize_layer(layer)

# Reset FP8 online quantization flag so process_weights_after_loading
# will run again during reload
if hasattr(layer, "_already_called_process_weights_after_loading"):
delattr(layer, "_already_called_process_weights_after_loading")

# Unwrap layerwise loading wrappers
for param in get_layer_tensors(layer).values():
param.weight_loader = _get_original_loader(param)
Expand All @@ -237,15 +248,15 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
if isinstance(quant_method, QuantizeMethodBase):
quant_method.process_weights_after_loading(layer)

# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
param.data.copy_(getattr(layer, name))
for name, buffer in buffers.items():
buffer.data.copy_(getattr(layer, name))
# Reloading: copy processed values into original tensors (preserves cudagraph refs)
if info.kernel_tensors is not None:
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
param.data.copy_(getattr(layer, name))
for name, buffer in buffers.items():
buffer.data.copy_(getattr(layer, name))

_place_kernel_tensors(layer, info)
_place_kernel_tensors(layer, info)

info.reset()
logger.debug("%s: Processed", layer.__class__.__name__)
Expand All @@ -268,6 +279,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
for name in get_layer_tensors(layer):
delattr(layer, name)

assert info.kernel_tensors is not None
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
layer.register_parameter(name, param)
Expand Down
Loading