Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 33 additions & 178 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import torch
from torch.nn import Module
from torch.utils._python_dispatch import TorchDispatchMode

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
Expand Down Expand Up @@ -223,26 +222,6 @@ def get_cache_scale(self, name: str) -> str | None:
return None


class CopyNumelCounter(TorchDispatchMode):
    """
    Tracks total number of elements modified with `copy_`. Useful for keeping
    track of weight loading where underlying weights can be arbitrarily
    transformed (such as with `narrow`) before calling copy.

    Use as a context manager; after the `with` block, `copied_numel` holds
    the summed `numel()` of the destination tensor of every
    `aten.copy_.default` call dispatched while the mode was active.
    """

    def __init__(self):
        super().__init__()
        # running total of destination elements written by `copy_`
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # always run the real op first so the mode is purely observational
        out = func(*args, **kwargs)
        if func == torch.ops.aten.copy_.default:
            # args[0] is the destination tensor of `copy_`
            self.copied_numel += args[0].numel()
        return out


def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
"""Copies any attrs present in `old` but not in `new` to `new`"""
new_attrs = set(dir(new))
Expand Down Expand Up @@ -515,75 +494,26 @@ def create_weights(
layer.weight_block_size = None

# WEIGHT
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0

# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
del layer._load_device

# refresh the reference to `param` to reflect just-in-time
# materialization
param = layer.weight

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer._already_called_process_weights_after_loading = True

# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.

return res

weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
# materialized just-in-time with layerwise loading
device="meta",
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
weight_loader=weight_loader,
)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
layer.register_parameter("weight", weight)

def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.weight.device == torch.device("meta"):
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
data=torch.empty_like(layer.weight, device=torch.get_default_device()),
input_dim=1,
output_dim=0,
weight_loader=layer.weight.weight_loader,
Expand Down Expand Up @@ -612,9 +542,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
weight = qweight.t()
replace_parameter(layer, "weight", weight.data)

# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True


class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Expand Down Expand Up @@ -1012,86 +939,13 @@ def create_weights(
layer.orig_dtype = params_dtype
layer.weight_block_size = None

# We are doing online quantization, patch the weight loader
# to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded.
weight_loader = extra_weight_attrs["weight_loader"]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs = extra_weight_attrs

def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0

# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer._w13_weight_orig_id = id(layer.w13_weight)
layer._w2_weight_orig_id = id(layer.w2_weight)

# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time

w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w13_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)

w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w2_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
del layer._load_device

# refresh the reference to `param` to reflect just-in-time
# materialization
if id(param) == layer._w13_weight_orig_id:
param = layer.w13_weight
elif id(param) == layer._w2_weight_orig_id:
param = layer.w2_weight

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True

# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.

return res

new_extra_weight_attrs["weight_loader"] = patched_weight_loader
extra_weight_attrs = new_extra_weight_attrs

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
# materialized just-in-time in `patched_weight_loader`
# materialized just-in-time with layerwise loading
device="meta",
dtype=params_dtype,
),
Expand All @@ -1105,7 +959,7 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
num_experts,
hidden_size,
intermediate_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
# materialized just-in-time with layerwise loading
device="meta",
dtype=params_dtype,
),
Expand All @@ -1118,52 +972,42 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):

# BIASES (for models like GPT-OSS that have biased MoE)
if self.moe.has_bias:
# Use the original weight_loader (not patched) for biases
orig_extra_weight_attrs = dict(extra_weight_attrs)
orig_extra_weight_attrs["weight_loader"] = weight_loader
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=layer.orig_dtype,
# materialized just-in-time with layerwise loading
# Note: this is currently broken for gpt-oss because it
# does not use weight loaders at all in the bf16 weights
# path
device="meta",
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GPT-OSS bf16 is broken whether the biases are initialized on GPU or on meta, so I'm going with meta to be consistent with the layerwise loading infrastructure.

If we want GPT-OSS to work with fp8.py, we should refactor gpt_oss.py to use weight loaders.

),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, orig_extra_weight_attrs)
set_weight_attrs(w13_bias, extra_weight_attrs)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to verify that GPT-OSS 120B still works, as this changes the code added by #34906 and there is no CI coverage.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following up on this: GPT-OSS bf16 is not expected to work with fp8.py online quantization because:

  1. fp8.py online quant (and future online quant backends in vllm) require weight_loaders, because we use weight_loaders to inject the streaming weight loading functionality
  2. gpt_oss.py model definition for the bf16 weights case does not use weight loaders:
    param.copy_(narrow_weight)

I'm not exactly sure how #34906 worked given points 1 and 2 above. I'm going to skip this for now, since GPT-OSS + online quantization seems low priority (the official weights are in mxfp4); we can follow up if needed.

For posterity, the easiest way to test this is to use the 20B model from unsloth, which goes through the same path as the 120B:

VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/basic/chat.py --model unsloth/gpt-oss-20b-BF16 --enforce-eager --dtype=bfloat16 --quantization=fp8

w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
requires_grad=False,
# materialized just-in-time with layerwise loading
# Note: this is currently broken for gpt-oss because it
# does not use weight loaders at all in the bf16 weights
# path
device="meta",
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, orig_extra_weight_attrs)

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_bias, extra_weight_attrs)

layer.w13_input_scale = None
layer.w2_input_scale = None

def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
torch.empty_like(layer.w13_weight, device=torch.get_default_device()),
requires_grad=False,
)
set_weight_attrs(
Expand All @@ -1174,7 +1018,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
torch.empty_like(layer.w2_weight, device=torch.get_default_device()),
requires_grad=False,
)
set_weight_attrs(
Expand All @@ -1184,6 +1028,20 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
num_experts = layer.num_experts
with layer.w13_weight.device:
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)

# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
Expand All @@ -1210,9 +1068,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w2_input_scale,
)

# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True


class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Expand Down
44 changes: 40 additions & 4 deletions vllm/model_executor/model_loader/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.reload import (
finalize_layerwise_reload,
initialize_layerwise_reload,
)
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
Expand Down Expand Up @@ -58,8 +62,27 @@ def load_model(
log_model_inspection(model)

logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)

use_layerwise_loading = _get_use_layerwise_loading(model, self)

if use_layerwise_loading:
# set up layer loading
initialize_layerwise_reload(
model, is_reload=False, target_device=load_device
)
# load weights, quantization via each layer's
# `process_weights_after_loading` will happen for each layer
# as soon as all of that layer's weights are loaded
self.load_weights(model, model_config)
# finalize layer reloading
finalize_layerwise_reload(model, model_config, is_reload=False)

else:
# Load weights to model format
self.load_weights(model, model_config)
# For layers with quantization, convert to kernel format
with target_device:
process_weights_after_loading(model, model_config, target_device)

# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
Expand All @@ -71,11 +94,24 @@ def load_model(
scope="local",
)

process_weights_after_loading(model, model_config, target_device)

return model.eval()


def _get_use_layerwise_loading(
    model: torch.nn.Module,
    model_loader: BaseModelLoader,
) -> bool:
    """Decide whether `model` should be loaded with layerwise loading.

    Returns True only when `model_has_any_online_quant_with_device_meta`
    reports True for `model` and `model_loader` is not a `DummyModelLoader`.
    """
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import between the loader modules - confirm.
    from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
    from vllm.model_executor.model_loader.utils import (
        model_has_any_online_quant_with_device_meta,
    )

    has_online_quant = model_has_any_online_quant_with_device_meta(model)

    is_dummy_loader = isinstance(model_loader, DummyModelLoader)
    return has_online_quant and not is_dummy_loader


def log_model_inspection(model: nn.Module) -> None:
"""Log model structure if VLLM_LOG_MODEL_INSPECTION=1."""
if not envs.VLLM_LOG_MODEL_INSPECTION:
Expand Down
Loading
Loading