Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions tests/quantization/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
Run `pytest tests/quantization/test_fp8.py --forked`.
"""

import logging

import pytest
import regex as re
import torch

from tests.quantization.utils import is_quant_method_supported
Expand Down Expand Up @@ -195,6 +198,99 @@ def check_model(model):
print(outputs[0][1])


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_peak_mem(
    vllm_runner,
    caplog_mp_spawn,
    monkeypatch,
) -> None:
    """Check that fp8 online quantization stays under model-specific
    model-load and peak-GPU-memory budgets, parsed from engine debug logs.
    """
    # Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
    # 1. it covers both Linear and MoE paths
    # 2. it is already used by other tests in CI, so adding it here
    #    does not increase disk space for CI runners
    # I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
    # which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
    # 1.3 GiB fp8), but could not as adding one more model makes CI
    # run out of disk space.
    model_name = "allenai/OLMoE-1B-7B-0125-Instruct"

    # Force spawn to ensure caplog_mp_spawn works consistently
    # (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    with (
        caplog_mp_spawn(logging.DEBUG) as log_holder,
        vllm_runner(
            model_name,
            quantization="fp8",
            enforce_eager=True,
        ) as llm,
    ):
        outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
        print(outputs[0][1])

    # Pull the two memory figures out of the captured worker logs.
    model_pattern = re.compile(r"Model loading took ([\d.]+) GiB memory")
    peak_pattern = re.compile(
        r"Peak GPU memory after loading weights: ([\d.]+) GiB"
    )
    model_memory_gib = None
    peak_memory_gib = None
    for line in log_holder.text.splitlines():
        if model_memory_gib is None and (m := model_pattern.search(line)):
            model_memory_gib = float(m.group(1))
        if peak_memory_gib is None and (m := peak_pattern.search(line)):
            peak_memory_gib = float(m.group(1))

    assert model_memory_gib is not None, "Could not find model loading memory log"
    assert peak_memory_gib is not None, "Could not find peak memory log"
    print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
    print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")

    # model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
    # uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
    expected_model_memory_gib = 6.7

    # for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
    # GiB, which is 1.36x above model_memory_gib. A slightly higher number is
    # expected as when we load and quantize weights in a streaming fashion we
    # need to have individual weights in bf16 + fp8 alive at the same time.
    expected_peak_memory_gib = expected_model_memory_gib * 1.4

    assert model_memory_gib < expected_model_memory_gib, (
        f"{model_memory_gib=} higher than {expected_model_memory_gib}"
    )
    assert peak_memory_gib < expected_peak_memory_gib, (
        f"{peak_memory_gib=} higher than {expected_peak_memory_gib}"
    )


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_load_format_dummy(
    vllm_runner,
    monkeypatch,
    caplog,
) -> None:
    """Smoke-test fp8 online quantization with `--load_format dummy`
    (randomly initialized weights) end to end via a short greedy generation.
    """
    model_name = "ibm-granite/granite-3.0-1b-a400m-base"
    with vllm_runner(
        model_name,
        quantization="fp8",
        enforce_eager=True,
        load_format="dummy",
    ) as llm:
        generated = llm.generate_greedy(["The future of AI is"], max_tokens=4)
        print(generated[0][1])


@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
Expand Down
123 changes: 119 additions & 4 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
Expand Down Expand Up @@ -293,6 +294,16 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return out


def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
    """Copies any attrs present in `old` but not in `new` to `new`"""
    # Snapshot `new`'s attribute names once so the membership test is O(1).
    existing = frozenset(dir(new))
    missing = {name: getattr(old, name) for name in dir(old) if name not in existing}
    set_weight_attrs(new, missing)


class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
Expand Down Expand Up @@ -578,6 +589,22 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0

# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
del layer._load_device
Comment thread
mgoin marked this conversation as resolved.

# refresh the reference to `param` to reflect just-in-time
# materialization
param = layer.weight

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
Expand All @@ -590,30 +617,50 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer._already_called_process_weights_after_loading = True

# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.

return res

weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
layer.register_parameter("weight", weight)

def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.weight.device == torch.device("meta"):
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=layer.weight.weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
initialize_single_dummy_weight(layer.weight)

# TODO(future): support block_quant in online quant path
assert not self.block_quant

Expand Down Expand Up @@ -1069,6 +1116,39 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0

# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer._w13_weight_orig_id = id(layer.w13_weight)
layer._w2_weight_orig_id = id(layer.w2_weight)

# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time

w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w13_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)

w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w2_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
del layer._load_device
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found del here will cause error in DP + EP case:

AttributeError: 'FusedMoE' object has no attribute '_load_device'

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yma11 thanks! I will take a look directly after I fix the logging issue in CI. Just in case I don't repro right away, if you can share your repro command that would be great.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nm, I can repro, looking

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the issue by making sure we do not incorrectly reinitialize weights when EP is on, please let me know if there are any further issues


# refresh the reference to `param` to reflect just-in-time
# materialization
if id(param) == layer._w13_weight_orig_id:
param = layer.w13_weight
elif id(param) == layer._w2_weight_orig_id:
param = layer.w2_weight

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
Expand All @@ -1081,12 +1161,16 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)

# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True

# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.

return res

new_extra_weight_attrs["weight_loader"] = patched_weight_loader
Expand All @@ -1098,6 +1182,8 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
requires_grad=False,
Expand All @@ -1110,12 +1196,16 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
num_experts,
hidden_size,
intermediate_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
Expand All @@ -1138,6 +1228,31 @@ def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)

# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
Expand Down
13 changes: 13 additions & 0 deletions vllm/model_executor/model_loader/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
initialize_model,
process_weights_after_loading,
)
from vllm.platforms import current_platform
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)
Expand Down Expand Up @@ -56,6 +58,17 @@ def load_model(
logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)

# Log peak GPU memory after loading weights. This is needed
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that the actual peak as logged here is not visible when just measuring peak memory after the llm object is initialized - seems like we need extra logging. Open on where to put this if there is a better place.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't add this log by default... could you make it a debug_once log and just set the logging level within the test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, fixed!

# to have test coverage on peak memory for online quantization.
if current_platform.is_cuda():
peak_memory = torch.cuda.max_memory_allocated()
logger.debug_once(
"Peak GPU memory after loading weights: %s GiB",
format_gib(peak_memory),
scope="local",
)

process_weights_after_loading(model, model_config, target_device)

return model.eval()
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/dummy_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ def download_model(self, model_config: ModelConfig) -> None:
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
initialize_dummy_weights(model, model_config)
Loading