diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 9226f2a55c3..dddbd119392 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -5,7 +5,10 @@ Run `pytest tests/quantization/test_fp8.py --forked`. """ +import logging + import pytest +import regex as re import torch from tests.quantization.utils import is_quant_method_supported @@ -195,6 +198,99 @@ def check_model(model): print(outputs[0][1]) +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) +def test_online_quant_peak_mem( + vllm_runner, + caplog_mp_spawn, + monkeypatch, +) -> None: + # Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because: + # 1. it covers both Linear and MoE paths + # 2. it is already used by other tests in CI, so adding it here + # does not increase disk space for CI runners + # I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base` + # which I think is the smallest MoE model in vLLM (2.5 GiB bf16, + # 1.3 GiB fp8), but could not as adding one more model makes CI + # run out of disk space. 
+ model_name = "allenai/OLMoE-1B-7B-0125-Instruct" + + # Force spawn to ensure caplog_mp_spawn works consistently + # (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores) + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + + with ( + caplog_mp_spawn(logging.DEBUG) as log_holder, + vllm_runner( + model_name, + quantization="fp8", + enforce_eager=True, + ) as llm, + ): + outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4) + print(outputs[0][1]) + + log_text = log_holder.text + + # Parse memory usage from captured logs + model_memory_gib = None + peak_memory_gib = None + for line in log_text.splitlines(): + if model_memory_gib is None: + match = re.search(r"Model loading took ([\d.]+) GiB memory", line) + if match: + model_memory_gib = float(match.group(1)) + if peak_memory_gib is None: + match = re.search( + r"Peak GPU memory after loading weights: ([\d.]+) GiB", line + ) + if match: + peak_memory_gib = float(match.group(1)) + + assert model_memory_gib is not None, "Could not find model loading memory log" + assert peak_memory_gib is not None, "Could not find peak memory log" + print(f"GPU memory used after loading weights: {model_memory_gib} GiB") + print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB") + + # model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant + # uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB) + expected_model_memory_gib = 6.7 + + # for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06 + # GiB, which is 1.36x above model_memory_gib. A slightly higher number is + # expected as when we load and quantize weights in a streaming fashion we + # need to have individual weights in bf16 + fp8 alive at the same time. 
+ expected_peak_memory_gib = expected_model_memory_gib * 1.4 + + assert model_memory_gib < expected_model_memory_gib, ( + f"{model_memory_gib=} higher than {expected_model_memory_gib}" + ) + assert peak_memory_gib < expected_peak_memory_gib, ( + f"{peak_memory_gib=} higher than {expected_peak_memory_gib}" + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) +def test_online_quant_load_format_dummy( + vllm_runner, + monkeypatch, + caplog, +) -> None: + with vllm_runner( + "ibm-granite/granite-3.0-1b-a400m-base", + quantization="fp8", + enforce_eager=True, + load_format="dummy", + ) as llm: + outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4) + print(outputs[0][1]) + + @pytest.mark.skipif( not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.", diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 6436a9ae0ab..a8467b5f07f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -86,6 +86,7 @@ cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, ) +from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight from vllm.model_executor.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, @@ -293,6 +294,16 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return out +def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None: + """Copies any attrs present in `old` but not in `new` to `new`""" + new_attrs = set(dir(new)) + attrs_to_set = {} + for attr in dir(old): + if attr not in new_attrs: + attrs_to_set[attr] = getattr(old, attr) + set_weight_attrs(new, attrs_to_set) + + class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. 
Supports loading FP8 checkpoints with static weight scale and @@ -578,6 +589,22 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): if not hasattr(layer, "_loaded_numel"): layer._loaded_numel = 0 + # when the first `loaded_weight` is about to be + # loaded to `param`, materialize `param` just-in-time + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=patched_weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + del layer._load_device + + # refresh the reference to `param` to reflect just-in-time + # materialization + param = layer.weight + # load the current weight chunk copy_numel_counter = CopyNumelCounter() with copy_numel_counter: @@ -590,30 +617,50 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): if layer._loaded_numel == target_loaded_numel: self.process_weights_after_loading(layer) - # Delete the bookkeeping - del layer._loaded_numel # Prevent the usual `process_weights_after_loading` call from doing # anything layer._already_called_process_weights_after_loading = True + # Note that we keep `layer._loaded_numel` around just in case + # there is logic added to vllm in the future which calls a + # weight loader twice - we do not want to re-initialize in + # that case. 
+ return res weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, + # materialized just-in-time in `patched_weight_loader` + device="meta", dtype=params_dtype, ), input_dim=1, output_dim=0, weight_loader=patched_weight_loader, ) + # stash the correct device for `patched_weight_loader` + layer._load_device = torch.get_default_device() layer.register_parameter("weight", weight) def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return + # deferred initialization of randomly initialized weights for the + # `--load_format dummy` feature + if layer.weight.device == torch.device("meta"): + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=layer.weight.weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + initialize_single_dummy_weight(layer.weight) + # TODO(future): support block_quant in online quant path assert not self.block_quant @@ -1069,6 +1116,39 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): if not hasattr(layer, "_loaded_numel"): layer._loaded_numel = 0 + # save the ids of original w13 and w2 so that we can + # distinguish which one `param` should map to further + # down in this file + layer._w13_weight_orig_id = id(layer.w13_weight) + layer._w2_weight_orig_id = id(layer.w2_weight) + + # when the first `loaded_weight` is about to be + # loaded to `param`, materialize `param` just-in-time + + w13_weight = torch.nn.Parameter( + torch.empty_like(layer.w13_weight, device=layer._load_device), + requires_grad=False, + ) + set_weight_attrs(w13_weight, extra_weight_attrs) + _copy_missing_attrs(layer.w13_weight, w13_weight) + layer.register_parameter("w13_weight", w13_weight) + + w2_weight = torch.nn.Parameter( + torch.empty_like(layer.w2_weight, 
device=layer._load_device), + requires_grad=False, + ) + set_weight_attrs(w2_weight, extra_weight_attrs) + _copy_missing_attrs(layer.w2_weight, w2_weight) + layer.register_parameter("w2_weight", w2_weight) + del layer._load_device + + # refresh the reference to `param` to reflect just-in-time + # materialization + if id(param) == layer._w13_weight_orig_id: + param = layer.w13_weight + elif id(param) == layer._w2_weight_orig_id: + param = layer.w2_weight + # load the current weight chunk copy_numel_counter = CopyNumelCounter() with copy_numel_counter: @@ -1081,12 +1161,16 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): if layer._loaded_numel == target_loaded_numel: self.process_weights_after_loading(layer) - # Delete the bookkeeping - del layer._loaded_numel # Prevent the usual `process_weights_after_loading` call # from doing anything layer._already_called_process_weights_after_loading = True + # Note that we keep `layer._loaded_numel`, + # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` + # around because if EP is on, weight loaders for non-local + # experts will run but not actually copy any elements, and we + # need to not re-initialize in that case. 
+ return res new_extra_weight_attrs["weight_loader"] = patched_weight_loader @@ -1098,6 +1182,8 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): num_experts, 2 * intermediate_size_per_partition, hidden_size, + # materialized just-in-time in `patched_weight_loader` + device="meta", dtype=params_dtype, ), requires_grad=False, @@ -1110,12 +1196,16 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): num_experts, hidden_size, intermediate_size_per_partition, + # materialized just-in-time in `patched_weight_loader` + device="meta", dtype=params_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + # stash the correct device for `patched_weight_loader` + layer._load_device = torch.get_default_device() # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. @@ -1138,6 +1228,31 @@ def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return + # deferred initialization of randomly initialized weights for the + # `--load_format dummy` feature + if layer.w13_weight.device == torch.device("meta"): + w13_weight = torch.nn.Parameter( + torch.empty_like(layer.w13_weight, device=layer._load_device), + requires_grad=False, + ) + set_weight_attrs( + w13_weight, {"weight_loader": layer.w13_weight.weight_loader} + ) + _copy_missing_attrs(layer.w13_weight, w13_weight) + layer.register_parameter("w13_weight", w13_weight) + initialize_single_dummy_weight(layer.w13_weight) + if layer.w2_weight.device == torch.device("meta"): + w2_weight = torch.nn.Parameter( + torch.empty_like(layer.w2_weight, device=layer._load_device), + requires_grad=False, + ) + set_weight_attrs( + w2_weight, {"weight_loader": layer.w2_weight.weight_loader} + ) + _copy_missing_attrs(layer.w2_weight, w2_weight) + layer.register_parameter("w2_weight", w2_weight) + initialize_single_dummy_weight(layer.w2_weight) + # 
If checkpoint is fp16, quantize in place. fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 4b89b34813e..2c55ee68e25 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -13,6 +13,8 @@ initialize_model, process_weights_after_loading, ) +from vllm.platforms import current_platform +from vllm.utils.mem_utils import format_gib from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) @@ -56,6 +58,17 @@ def load_model( logger.debug("Loading weights on %s ...", load_device) # Quantization does not happen in `load_weights` but after it self.load_weights(model, model_config) + + # Log peak GPU memory after loading weights. This is needed + # to have test coverage on peak memory for online quantization. + if current_platform.is_cuda(): + peak_memory = torch.cuda.max_memory_allocated() + logger.debug_once( + "Peak GPU memory after loading weights: %s GiB", + format_gib(peak_memory), + scope="local", + ) + process_weights_after_loading(model, model_config, target_device) return model.eval() diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index b2a934ce594..156071f1dae 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -25,4 +25,4 @@ def download_model(self, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
- initialize_dummy_weights(model) + initialize_dummy_weights(model, model_config) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 7ea3bb2ebd1..15fd4423943 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1059,6 +1059,7 @@ def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: def initialize_dummy_weights( model: torch.nn.Module, + model_config: ModelConfig, low: float = -1e-3, high: float = 1e-3, seed: int = 1234, @@ -1075,41 +1076,61 @@ def initialize_dummy_weights( is fixed, the random values generated by this function only depends on the parameter's number of elements and its data type. """ + # TODO(future PR): make the check below more generic as more online + # quant backends are added + is_fp8_py_quant = model_config.quantization == "fp8" + for param in model.state_dict().values(): - if torch.is_floating_point(param): - if current_platform.is_tpu(): - generator = torch.Generator(device="cpu") - generator.manual_seed(seed) - # Note: The param.uniform_ function cannot be used in this - # context because it demands more TPU HBM than directly copying - # from a CPU tensor. - # Note: We avoid using torch.rank_like as it doesn't currently - # support the generator argument. - param.copy_( - (high - low) - * torch.rand( - param.shape, - generator=generator, - dtype=param.dtype, - layout=param.layout, - requires_grad=param.requires_grad, - device="cpu", - ) - + low - ) - torch._sync(param) - continue + if is_fp8_py_quant and param.device == torch.device("meta"): + # for fp8.py's online quantization, dummy weight init will happen + # in `process_weights_after_loading`. 
+ # TODO(future PR): consider refactoring dummy model init to compose + # better with online quantization + continue - generator = torch.Generator(device=param.data.device) + initialize_single_dummy_weight(param, low, high, seed) + + + def initialize_single_dummy_weight( + param: torch.Tensor, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 1234, + ) -> None: + if torch.is_floating_point(param): + if current_platform.is_tpu(): + generator = torch.Generator(device="cpu") generator.manual_seed(seed) - if torch.finfo(param.data.dtype).bits < 16: - # uniform_ doesn't support < 16-bit datatypes (FP8) - dtype = param.data.dtype - tmp_param = param.data.to(torch.float16) - tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype) - param.data.copy_(tmp_param) - else: - param.uniform_(low, high, generator=generator) + # Note: The param.uniform_ function cannot be used in this + # context because it demands more TPU HBM than directly copying + # from a CPU tensor. + # Note: We avoid using torch.rand_like as it doesn't currently + # support the generator argument. + param.copy_( + (high - low) + * torch.rand( + param.shape, + generator=generator, + dtype=param.dtype, + layout=param.layout, + requires_grad=param.requires_grad, + device="cpu", + ) + + low + ) + torch._sync(param) + return + + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high, generator=generator) def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: