Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 22 additions & 143 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,70 +555,24 @@ def create_weights(
layer.weight_block_size = None

# WEIGHT
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0

# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
del layer._load_device

# refresh the reference to `param` to reflect just-in-time
# materialization
param = layer.weight

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer._already_called_process_weights_after_loading = True

# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.

return res

weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
# materialized just-in-time by
# `make_online_initial_load_process_loader` in
# vllm/model_executor/model_loader/reload/layerwise.py
Comment on lines 562 to +563
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This comment and TODO are outdated since patched_weight_loader has been removed in this refactoring. The materialization now happens in make_online_initial_load_process_loader within vllm/model_executor/model_loader/reload/layerwise.py. Please remove these lines.

device="meta",
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
weight_loader=weight_loader,
)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
layer.register_parameter("weight", weight)

def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.weight.device == torch.device("meta"):
Expand Down Expand Up @@ -1074,86 +1028,13 @@ def create_weights(
layer.orig_dtype = params_dtype
layer.weight_block_size = None

# We are doing online quantization, patch the weight loader
# to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded.
weight_loader = extra_weight_attrs["weight_loader"]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs = extra_weight_attrs

def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0

# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer._w13_weight_orig_id = id(layer.w13_weight)
layer._w2_weight_orig_id = id(layer.w2_weight)

# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time

w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w13_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)

w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w2_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
del layer._load_device

# refresh the reference to `param` to reflect just-in-time
# materialization
if id(param) == layer._w13_weight_orig_id:
param = layer.w13_weight
elif id(param) == layer._w2_weight_orig_id:
param = layer.w2_weight

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True

# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.

return res

new_extra_weight_attrs["weight_loader"] = patched_weight_loader
extra_weight_attrs = new_extra_weight_attrs

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
# materialized just-in-time in `patched_weight_loader`
# materialized just-in-time by
# `make_online_initial_load_process_loader` in
# vllm/model_executor/model_loader/reload/layerwise.py
device="meta",
dtype=params_dtype,
),
Expand All @@ -1168,36 +1049,20 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
hidden_size,
intermediate_size_per_partition,
# materialized just-in-time by
# `make_online_initial_load_process_loader` in
# vllm/model_executor/model_loader/reload/layerwise.py
device="meta",
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

layer.w13_input_scale = None
layer.w2_input_scale = None

def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# if getattr(layer, "_already_called_process_weights_after_loading", False):
# return

# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
Expand All @@ -1224,7 +1089,21 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)

# If checkpoint is fp16, quantize in place.
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale later.
num_experts = layer.num_experts
with layer.w13_weight.device:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it not make more sense to just call process_after_weight_loading within the with target_device context?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be initialized with ones? Why not empty?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is code movement from one place to another, so keeping as is to minimize risk since not technically related to this PR

Copy link
Copy Markdown
Contributor

@kylesayrs kylesayrs Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think scale dtype should theoretically be the model dtype, not necessarily float32, but it's been a while since I looked at this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is code movement from one place to another, so keeping as is to minimize risk since not technically related to this PR

)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)

# Quantize the loaded high precision checkpoint to fp8
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
Expand Down
64 changes: 51 additions & 13 deletions vllm/model_executor/model_loader/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.reload.initial_load import (
finalize_layerwise_initial_load,
initialize_layerwise_initial_load,
)
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
Expand Down Expand Up @@ -56,20 +60,43 @@ def load_model(
log_model_inspection(model)

logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)

# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if current_platform.is_cuda():
peak_memory = torch.cuda.max_memory_allocated()
logger.debug_once(
"Peak GPU memory after loading weights: %s GiB",
format_gib(peak_memory),
scope="local",
)

process_weights_after_loading(model, model_config, target_device)
is_online_quant = _is_online_quant(vllm_config, model_config)
if not is_online_quant:
# Regular path, `process_weights_after_loading` is called
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably don't need this many comments

# after all weights are loaded.

# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)

else:
# Online quantization can take the layerwise loading path
# where `process_weights_after_loading` is done just-in-time
# after all of a layer's weights are loaded.

# set up weight loaders for layerwise loading with
# streaming post-processing
initialize_layerwise_initial_load(model, target_device)

# load the weights, layerwise loading infra will call
# each layer's `process_weights_after_loading` function
# as soon as every weight of that layer is loaded
self.load_weights(model, model_config)

# finalize layerwise reloading (call any post-processing
# that did not happen in real time)
finalize_layerwise_initial_load(model, model_config)

# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if current_platform.is_cuda():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be unindented by one?

peak_memory = torch.cuda.max_memory_allocated()
logger.debug_once(
"Peak GPU memory after loading weights: %s GiB",
format_gib(peak_memory),
scope="local",
)

return model.eval()

def log_model_inspection(model: nn.Module) -> None:
    """Log a structural summary of *model* at INFO level."""
    # Imported lazily so the inspection helper is only paid for when
    # this logging path actually runs.
    from vllm.model_inspection import format_model_inspection

    structure = format_model_inspection(model)
    logger.info("vLLM model structure:\n%s", structure)


def _is_online_quant(vllm_config: VllmConfig, model_config: ModelConfig) -> bool:
    """Return True when the checkpoint holds unquantized weights that must
    be quantized to fp8 online (i.e. it is not already fp8-serialized)."""
    # TODO(future): add other online quant paths here
    if model_config.quantization != "fp8":
        return False
    quant_config = vllm_config.quant_config
    if quant_config is None:
        return False
    if not hasattr(quant_config, "is_checkpoint_fp8_serialized"):
        return False
    # Falsy flag (False/None/0) means the checkpoint is high-precision,
    # matching the original `not quant_config.is_checkpoint_fp8_serialized`.
    return not quant_config.is_checkpoint_fp8_serialized
Loading
Loading