Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/quantization/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""

# Whether this method creates weights on meta device for online quantization.
# When True, weights are created on meta device and quantized layer-wise
# in process_weights_after_loading, reducing peak memory during loading.
uses_meta_device: bool = False

@abstractmethod
def create_weights(
self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
"""Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
and quantized the weights during loading."""

uses_meta_device: bool = True

def create_weights(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -1039,6 +1041,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
quant_config: The quantization config.
"""

uses_meta_device: bool = True

def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(quant_config, layer)
assert not quant_config.is_checkpoint_fp8_serialized
Expand Down
20 changes: 12 additions & 8 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,16 +1092,20 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function only depend on
the parameter's number of elements and its data type.
"""
# TODO(future PR): make the check below more generic as more online
# quant backends are added
is_fp8_py_quant = model_config.quantization == "fp8"

# Check if any module uses online quantization with meta device weights.
# If so, we'll skip initializing params on meta device since they'll be
# handled in `process_weights_after_loading`.
def uses_meta_device(module: torch.nn.Module) -> bool:
    """Return the module's quant method ``uses_meta_device`` flag.

    Falls back to ``False`` when the module has no ``quant_method``
    attribute or the method does not declare the flag.
    """
    method = getattr(module, "quant_method", None)
    if method is None:
        return False
    return getattr(method, "uses_meta_device", False)

has_online_quant = any(uses_meta_device(m) for m in model.modules())

for param in model.state_dict().values():
if is_fp8_py_quant and param.device == torch.device("meta"):
# for fp8.py's online quantization, dummy weight init will happen
# in `process_weights_after_loading`.
# TODO(future PR): consider refactoring dummy model init to compose
# better with online quantization
if has_online_quant and param.device == torch.device("meta"):
# For online quantization, weights are created on meta device and
# dummy weight init will happen in `process_weights_after_loading`.
continue

initialize_single_dummy_weight(param, low, high, seed)
Expand Down