diff --git a/tests/conftest.py b/tests/conftest.py
index 1352cdeeaa81..ed4e0aa9cd74 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -27,7 +27,7 @@
 from collections.abc import Generator
 from contextlib import nullcontext
 from enum import Enum
-from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING
+from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING, Optional
 
 import numpy as np
 import pytest
@@ -1023,7 +1023,9 @@ def generate_greedy_logprobs(
             **kwargs,
         )
 
-    def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
+    def generate_prompt_perplexity(
+        self, prompts: list[str], mask: Optional[list[str]] = None
+    ) -> list[float]:
         """
         Return the perplexity score associated with generating the prompts
 
@@ -1034,13 +1036,20 @@ def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
             prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0
         )
 
+        mask_prefix_lens = (
+            [len(self.llm.get_tokenizer()(prefix)["input_ids"]) for prefix in mask]
+            if mask is not None
+            else [0 for _ in range(len(prompts))]
+        )
+
        perplexities = []
-        for output in outputs:
+        for output, mask_prefix_len in zip(outputs, mask_prefix_lens):
             output = cast(TokensTextLogprobsPromptLogprobs, output)
             token_datas = cast(list[dict[int, Logprob] | None], output[3])
             assert token_datas[0] is None
+
             token_log_probs = []
-            for token_data in token_datas[1:]:
+            for token_data in token_datas[mask_prefix_len + 1 :]:
                 assert token_data is not None
                 assert len(token_data) == 1
                 token_log_prob = list(token_data.values())[0].logprob
@@ -1121,6 +1130,9 @@ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
     def get_llm(self) -> LLM:
         return self.llm
 
+    def collective_rpc(self, *args, **kwargs):
+        return self.llm.collective_rpc(*args, **kwargs)
+
     def __enter__(self):
         return self
 
@@ -1531,3 +1543,9 @@ def use_fresh_inductor_cache():
     """
     with fresh_cache():
         yield
+
+
+@pytest.fixture(scope="function")
+def enable_pickle(monkeypatch):
+    """`LLM.apply_model` requires pickling a function."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
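The `mask` argument added above lets a caller score only the completion portion of a prompt: logprobs for the masked prefix tokens are dropped before the perplexity is averaged. A minimal sketch of that computation follows; the logprob values are made up for illustration.

```python
import math

prompt_logprobs = [None, -2.1, -1.7, -0.9, -0.3]  # one entry per prompt token
mask_prefix_len = 3  # e.g. token count of the prefix "3 4 ="

# Skip the first token (which has no logprob) plus the masked prefix,
# mirroring token_datas[mask_prefix_len + 1:] in the fixture above.
scored = prompt_logprobs[mask_prefix_len + 1 :]
perplexity = math.exp(-sum(scored) / len(scored))
print(perplexity)  # ~1.35: only the completion token contributes
```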
diff --git a/tests/model_executor/model_loader/test_reload.py b/tests/model_executor/model_loader/test_reload.py
new file mode 100644
index 000000000000..6fcb077c1c73
--- /dev/null
+++ b/tests/model_executor/model_loader/test_reload.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
+import inspect
+from weakref import WeakKeyDictionary, ref
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.linear import QKVParallelLinear
+from vllm.model_executor.model_loader.reload.meta import (
+    capture_layer_to_meta,
+    get_numel_loaded,
+    materialize_layer,
+    materialize_meta_tensor,
+    restore_layer_on_meta,
+    to_meta_tensor,
+)
+from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
+from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
+from vllm.platforms import current_platform
+from vllm.utils.torch_utils import cuda_device_count_stateless
+
+
+def test_move_metatensors():
+    tensor = torch.empty((1, 2, 3))
+    meta_tensor = to_meta_tensor(tensor)
+    materialized_tensor = materialize_meta_tensor(meta_tensor)
+
+    assert meta_tensor.device.type == "meta"
+    assert tensor.device == materialized_tensor.device
+
+    assert tensor.dtype == meta_tensor.dtype == materialized_tensor.dtype
+    assert tensor.shape == meta_tensor.shape == materialized_tensor.shape
+    assert tensor.__class__ == meta_tensor.__class__ == materialized_tensor.__class__
+    assert tensor.__dict__ == meta_tensor.__dict__ == materialized_tensor.__dict__
+
+
+def test_reload_lifecycle():
+    layer = torch.nn.Linear(2, 3)
+    info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
+
+    restore_layer_on_meta(layer, info)
+    for name, tensor in get_layer_tensors(layer).items():
+        meta_tensor = getattr(layer, name)
+        assert tensor.dtype == meta_tensor.dtype
+        assert tensor.shape == meta_tensor.shape
+        assert tensor.__class__ == meta_tensor.__class__
+        assert tensor.__dict__ == meta_tensor.__dict__
+
+    materialize_layer(layer)
+    for name, tensor in get_layer_tensors(layer).items():
+        materialized_tensor = getattr(layer, name)
+        assert tensor.dtype == materialized_tensor.dtype
+        assert tensor.shape == materialized_tensor.shape
+        assert tensor.__class__ == materialized_tensor.__class__
+        assert tensor.__dict__ == materialized_tensor.__dict__
+
+
+def test_model_cleanup(dist_init, default_vllm_config):
+    layer = QKVParallelLinear(2, 3, 4)
+    assert layer.weight.weight_loader.__self__ is layer
+    info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
+
+    mock_info_dict: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
+        WeakKeyDictionary()
+    )
+    mock_info_dict[layer] = info
+    layer_ref = ref(layer)
+
+    del layer
+    gc.collect()
+
+    assert layer_ref() is None
+    assert len(mock_info_dict) == 0
+
+
+def test_get_numel_loaded():
+    param = torch.empty(10, device="meta")
+    loaded_weight = torch.empty(10)
+
+    def complex_weight_loader(param, loaded_weight):
+        param[:3] = loaded_weight[:3]
+        param[5:8] = loaded_weight[5:8]
+        return "value"
+
+    args = inspect.signature(complex_weight_loader).bind(param, loaded_weight)
+    num_loaded, ret = get_numel_loaded(complex_weight_loader, args)
+    assert num_loaded == 6
+    assert ret == "value"
+
+
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize(
+    "base_model,mul_model,add_model",
+    [
+        (
+            "Qwen/Qwen3-0.6B",
+            "inference-optimization/Qwen3-0.6B-debug-multiply",
+            "inference-optimization/Qwen3-0.6B-debug-add",
+        ),
+        (
+            "inference-optimization/Qwen3-0.6B-FP8_BLOCK",
+            "inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK",
+            "inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK",
+        ),
+        (
+            "inference-optimization/Qwen3-0.6B-W4A16-G128",
+            "inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128",
+            "inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128",
+        ),
+        (
+            "inference-optimization/DeepSeek-V3-debug-empty",
+            "inference-optimization/DeepSeek-V3-debug-multiply",
+            "inference-optimization/DeepSeek-V3-debug-add",
+        ),
+        (
+            "inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC",
+            "inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC",
+            "inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC",
+        ),
+        (
+            "inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16",
+            "inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16",
+            "inference-optimization/DeepSeek-V3-debug-add-NVFP4A16",
+        ),
+    ],
+)
+def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
+    if cuda_device_count_stateless() < tp_size:
+        pytest.skip(reason="Not enough CUDA devices")
+
+    if "FP8" in base_model and not current_platform.supports_fp8():
+        pytest.skip(reason="Requires FP8 support")
+
+    with vllm_runner(
+        model_name=base_model,
+        tensor_parallel_size=tp_size,
+        enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
+        enable_prefix_caching=False,
+    ) as llm:
+        llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
+        mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
+        add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
+        assert mul_perp < add_perp
+
+        llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model})
+        mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
+        add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
+        assert add_perp < mul_perp
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model}) + mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] + add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0] + assert mul_perp < add_perp + + llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model}) + mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] + add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0] + assert add_perp < mul_perp diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index da4f6a028709..c859f890bddf 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib.metadata import importlib.util import pytest import torch +from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform DTYPE = ["bfloat16"] @@ -105,8 +105,8 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner): @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") -def test_online_quant_config_dict_json(vllm_runner): - """Testing on the fly quantization, load_weights integration point, +def test_online_quant_config_dict_json(vllm_runner, enable_pickle): + """Testing online quantization, load_weights integration point, with config dict serialized to json string """ torch._dynamo.reset() @@ -135,7 +135,18 @@ def test_online_quant_config_dict_json(vllm_runner): ) as llm: output = llm.generate_greedy(["The capital of France is"], max_tokens=4) - assert output + load_config = llm.llm.llm_engine.vllm_config.load_config + model_config = llm.llm.llm_engine.vllm_config.model_config + + def load_weights(model): + model_loader = get_model_loader(load_config) + weights_iterator = model_loader.get_all_weights(model_config, model) + model.load_weights(weights_iterator) + + llm.apply_model(load_weights) + + reload_output = llm.generate_greedy(["The capital of France is"], max_tokens=4) + assert output[0][0] == reload_output[0][0] @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index badbd3e9adff..00c2bfa406ab 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -543,7 +543,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): def test_reload_weights_before_load_model(model_runner): - with pytest.raises(AssertionError): + with pytest.raises(ValueError): model_runner.reload_weights() diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py deleted file mode 100644 index f330af85bbe8..000000000000 --- a/vllm/model_executor/model_loader/online_quantization.py +++ /dev/null @@ -1,275 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import types -from collections.abc import Iterable - -import torch -from torch import nn - -from vllm.config import ModelConfig -from vllm.logger import init_logger -from vllm.model_executor.model_loader.utils import process_weights_after_loading - -logger = init_logger(__name__) - -# Notes for Online Quantization -# In terms of state of checkpoints, quantization config and their -# correspondance to online 
-# | Use Case      | Checkpoints    | model_config.quantization |
-# | no quant      | high precision | None                      |
-# | offline quant | quantized      | fp8, torchao etc.         |
-# | online quant  | high precision | torchao etc.              |
-#
-# The process for loading non-quantized checkpoint
-# 1. load non-quantized weights (load_weights)
-# 2. do any additional post processing (process_weights_after_loading)
-#
-# The process for loading offline quantized checkpoint
-# 1. load offline-quantized weights (load_weights)
-# 2. do any additional post processing (process_weights_after_loading)
-
-# The process for unquantized model reloading
-# (repeated run in RL training loop)
-# first run
-# UI1. load_weights: load bfloat16 weights
-# UI2. process_weights_after_loading: any additional post processing
-# subsequent run
-# UC1: load_weights: load bfloat16 weights
-# (shouldn't be any issues since we didn't change any attributes
-# of the weights)
-# UC2: process_weights_after_loading: any additional post processing

-# The process for weight reloading with online quantization
-# (repeated run in RL training loop)
-# first run
-# I1. load_weights: load bfloat16 weights
-# I2. process_weights_after_loading:
-#     record weight metadata and attributes for R1 and R2
-#     quantize weights to fp8
-# subsequent run
-# (beginning model weight is in fp8)
-# load_weights:
-#   R1. restore bfloat16 model weight metadata
-#   R2. restore the model weight attributes
-#   R3. reload bfloat16 weights
-#   R4. quantize weights (by calling process_weights_after_loading),
-#       also set `process_weights_after_loading_already_called` to
-#       True to stop it from running again
-#   R5. (workaround for cudagraph), we restore the weight params to original quantized
-#       weights params, and use original_weight_param.copy_(updated_weight_param) so that
-#       the weight update work well with cudagraph
-# process_weights_after_loading (if called):
-#   this will be skipped since it's already ran in
-#   load_weights
-
-
-def maybe_save_metadata_and_attributes_for_weight_reloading(
-    model: nn.Module, model_config: ModelConfig
-):
-    # following is to support on the fly quantization, currently only supported
-    # for torchao
-    if model_config.quantization != "torchao":
-        return
-
-    from vllm.model_executor.model_loader.weight_utils import get_quant_config
-
-    quant_config = get_quant_config(model_config, None)
-
-    # If checkpoint is already torchao serialized, this means it's
-    # pre-quantized quantization case, we'll skip saving the metadata
-    # Otherwise, this is Step I2 of initialization steps of
-    # online quantization
-    # This step record the weights metadata and weight attributes so we can
-    # restore the bfloat16 model weights during the relad step (R1 and R2)
-    # see Notes in online_quantization.py for more details
-    if not (
-        hasattr(quant_config, "is_checkpoint_torchao_serialized")
-        and not quant_config.is_checkpoint_torchao_serialized
-    ):
-        return
-
-    # This is the I2 step of online quantiztion that saves
-    # metadata and attributes of weights so they can be used in R1 and
-    # R2 step, note that we only save these during initialization
-
-    # Includes two things
-    # 1. save floating point metadata (shape, dtype, device) for init
-    # 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init
-
-    if getattr(model, "weight_metadata_and_attr_saved", False):
-        return
-
-    # save the dtype, shape and device for model parameter, used for
-    # restoring the model high precision parameters before
-    # reloading the weights
-    assert not hasattr(model, "original_weights_rebuild_keys")
-    model.original_weights_rebuild_keys = {}
-    for name, p in model.named_parameters():
-        model.original_weights_rebuild_keys[name] = {
-            "shape": p.shape,
-            "dtype": p.dtype,
-            "device": p.device,
-        }
-
-    # record the weight attributes (loader functions etc.)
-    # so these can be recovered later when we reload the weights
-    # structure: {"weight_name": {"weight_attr_key": attr}}
-    assert not hasattr(model, "recorded_weight_attr")
-    model.recorded_weight_attr = {}
-    for name, param in model.named_parameters():
-        model.recorded_weight_attr[name] = {}
-        for key in param.__dict__:
-            if hasattr(param, key):
-                attr = getattr(param, key)
-                if not callable(attr):
-                    model.recorded_weight_attr[name][key] = attr
-                elif hasattr(attr, "__self__") and param is attr.__self__:
-                    # if attr is a bonded method for an instance, and
-                    # attr.__self__ points to the instance (param)
-                    # we'll record the underlying function object
-                    model.recorded_weight_attr[name][key] = attr.__func__
-                else:
-                    model.recorded_weight_attr[name][key] = attr
-    # mark the metadata and attributes saved so we don't run it again
-    model._model_config = model_config
-    model.weight_metadata_and_attr_saved = True
-
-
-def _bond_method_to_cls(func, obj):
-    if hasattr(func, "__self__") or not callable(func):
-        # If the function is already bound to an instance, return it as is
-        return func
-    else:
-        return types.MethodType(func, obj)
-
-
-def support_quantized_model_reload_from_hp_weights(original_load_weights):
-    """Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
-    reloading high precision (bfloat16/float16/float32) weight for an already quantized
-    model, this involves restoring the weights to a high precision weights and
-    then online quantize the weights
-    """
-    # online quantization, right now only enabled for
-    # torchao
-    # R1, R2, R3, R4, R5 in the Notes
-
-    def patched_model_load_weights(
-        auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
-    ) -> set[str]:
-        model = auto_weight_loader.module
-        offline_quantization_or_first_run_of_online_quantization = not getattr(
-            model, "weight_metadata_and_attr_saved", False
-        )
-
-        # if we don't have `model.weight_metadata_and_attr_saved` defined and
-        # set to True, it means that this is either offline quantization case
-        # or the first run of online quantization
-        # see Notes in this file for more details
-        if offline_quantization_or_first_run_of_online_quantization:
-            # case 1: offline quantized checkpoint
-            # case 2: Step I1 first run of weight loading with
-            # online quantization
-            return original_load_weights(auto_weight_loader, weights, mapper=mapper)
-
-        model_config = model._model_config
-
-        # TODO: Add fp8 support
-        assert model_config.quantization == "torchao", (
-            "online quantization is only enabled for torchao currently"
-        )
-        # TODO: use create_weights to restore the weights to original state
-
-        # Step R1: First restore the quantized weights to original bfloat16
-        # weights, with original metadata (shape, dtype, device)
-        # and attributes, so that bfloat16 weights can be loaded properly
-        # TODO: maybe set remove_duplicate to True?
-        original_quantized_weight_dict = dict(
-            model.named_parameters(remove_duplicate=False)
-        )
-        named_modules = dict(model.named_modules(remove_duplicate=False))
-        model_device = None
-
-        for name, d in model.original_weights_rebuild_keys.items():
-            _shape = d["shape"]
-            _dtype = d["dtype"]
-            _device = d["device"]
-            if model_device is not None:
-                assert model_device == _device, (
-                    "Expecting all weights "
-                    "to be in the same device for now, got both: "
-                    f"{model_device} and {_device}"
-                )
-            else:
-                model_device = _device
-
-            if name in original_quantized_weight_dict:
-                module_name, weight_name = name.rsplit(".", 1)
-                module = named_modules[module_name]
-                setattr(
-                    module,
-                    weight_name,
-                    torch.nn.Parameter(
-                        torch.empty(_shape, dtype=_dtype, device=_device),
-                        requires_grad=False,
-                    ),
-                )
-
-        # Step R2: recover the weight attributes to the state before first loading
-        # recorded_weight_attr is
-        # {"weight_name": {"weight_attr_key": attr}}
-        # e.g.
-        # {
-        #     {
-        #         "layer.0.weight": {
-        #             "weight_loader": weight_loader_function_object,
-        #             "input_dim": 0, ...
-        #         },
-        #         "layer.1.weight": ...,
-        #     }
-        # }
-        for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
-            for attr_name, attr in weight_attr_dict.items():
-                module_name, weight_name = full_weight_name.rsplit(".", 1)
-                module = named_modules[module_name]
-                weight = getattr(module, weight_name)
-                if not hasattr(weight, attr_name):
-                    setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
-
-        # Step R3: reload bfloat16 / high precision weights
-        updated_params = original_load_weights(
-            auto_weight_loader, weights, mapper=mapper
-        )
-
-        # Step R4: online quantize the weights
-        # manually process weights after loading
-        model.process_weights_after_loading_already_called = False
-        if model_device is not None:
-            process_weights_after_loading(model, model_config, model_device)
-        else:
-            logger.warning_once(
-                "model_device is None, skip calling process_weights_after_loading"
-            )
-
-        # Step R5 (workaround for cudagraph): restore the original quantized weights
-        # and do a copy_ of the currents weights to the original weights
-        updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
-        for name in model.original_weights_rebuild_keys:
-            if name in original_quantized_weight_dict:
-                original_quantized_weight = original_quantized_weight_dict[name]
-                updated_quantized_weight = updated_quantized_weights[name]
-
-                module_name, weight_name = name.rsplit(".", 1)
-                module = named_modules[module_name]
-                setattr(module, weight_name, original_quantized_weight)
-                with torch.no_grad():
-                    original_quantized_weight.copy_(updated_quantized_weight)
-
-        del original_quantized_weight_dict
-        del named_modules
-        del updated_quantized_weight
-
-        model.process_weights_after_loading_already_called = True
-        return updated_params
-
-    return patched_model_load_weights
diff --git a/vllm/model_executor/model_loader/reload/__init__.py b/vllm/model_executor/model_loader/reload/__init__.py
new file mode 100644
index 000000000000..ea0b0bc06ad9
--- /dev/null
+++ b/vllm/model_executor/model_loader/reload/__init__.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Layerwise weight reloading utilities for vLLM.
+
+This module provides functionality to reload model weights layer-by-layer,
+which is useful for weight updates without full model reconstruction.
+
+Limitations:
+1. Composition with CPU offloading has not been implemented
+2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented
+3. Tied parameters will only reflect processing from one of the parent layers (for
+   example, only processing from embed_tokens will have an effect)
+4. This design assumes that the number of weights loaded from disk is the same as the
+   number of weights created at model init time. This is not true for quant methods
+   which (1) pad weights or (2) load qkv weights into the same parameter. Both of these
+   cases are non-issues for today's quant methods, but future quantizations may cause
+   reloading to fail
+"""
+
+__all__ = [
+    "record_metadata_for_reloading",
+    "initialize_layerwise_reload",
+    "finalize_layerwise_reload",
+    "set_torchao_reload_attrs",
+    "support_quantized_model_reload_from_hp_weights",
+]
+
+from .layerwise import (
+    finalize_layerwise_reload,
+    initialize_layerwise_reload,
+    record_metadata_for_reloading,
+)
+from .torchao_decorator import (
+    set_torchao_reload_attrs,
+    support_quantized_model_reload_from_hp_weights,
+)
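Taken together, the three public entry points bracket a single pass of `load_weights`. A sketch of the expected call pattern, mirroring the `GPUModelRunner.reload_weights` call site further down; `model`, `weights_iterator`, and `model_config` are placeholders:

```python
from vllm.config import ModelConfig
from vllm.model_executor.model_loader.reload import (
    finalize_layerwise_reload,
    initialize_layerwise_reload,
)


def reload_checkpoint_weights(model, weights_iterator, model_config: ModelConfig):
    # record_metadata_for_reloading(model) must already have run at model
    # construction time (initialize_model does this below).
    initialize_layerwise_reload(model)  # park layer tensors on the meta device
    loaded = model.load_weights(weights_iterator)  # wrapped loaders defer processing
    finalize_layerwise_reload(model, model_config)  # unwrap loaders, process stragglers
    return loaded
```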
+ """ + for layer in model.modules(): + info = get_layerwise_info(layer) + info.restore_metadata = capture_layer_to_meta(layer) + + +@torch.no_grad() +def initialize_layerwise_reload(model: torch.nn.Module): + """ + Set up layerwise weight loading with deferred processing. + + Must be called after `record_metadata_for_reloading`. This function: + 1. Saves current kernel tensors for later copying + 2. Restores layer parameters/buffers from metadata (on meta device) + 3. Wraps weight loaders to defer processing until all weights are loaded + + When all weights for a layer are loaded, the wrapped loaders will: + 1. Materialize the layer onto the target device + 2. Load all cached weights + 3. Run quantization processing if applicable + 4. Copy processed values back to original tensor storage + """ + # disable torchao reloading to avoid infinite recursion + model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False) + model._do_torchao_reload = False + + for layer in model.modules(): + info = get_layerwise_info(layer) + + # Skip if the layer has already been initialized + if info.can_process(): + continue + + # Save current tensors for later copying + info.kernel_tensors = get_layer_params_buffers(layer) + + # Restore layer parameters/buffers onto meta device + restore_layer_on_meta(layer, info) + + # Track loading progress to determine when to process/copy + info.load_numel = 0 + info.load_numel_total = get_layer_size(layer) + + # Wrap each parameter's weight loader + # Note that nested wrapping will occur for shared tensors + for name, tensor in get_layer_tensors(layer).items(): + if _get_weight_loader(tensor).__name__ != "online_process_loader": + tensor.weight_loader = make_online_process_loader(layer, name) + + +def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable: + """Create a wrapped weight loader that defers processing.""" + info = get_layerwise_info(layer) + param = getattr(layer, param_name) + original_loader = _get_original_loader(param) + loader_signature = inspect.signature(original_loader) + + @wraps(original_loader, assigned=("__doc__", "__annotations__")) + def online_process_loader(*args, **kwargs): + if not info.can_process(): + # Unfortunately, some qconfigs are set up to load the same weight + # multiple times. For example, CT_WNA16 loads `weight_shape` for + # each of the qkv partitions. This results in layers loading extra + # weights (beyond load_numel_total) after it's already processed. + # + # Best solution is to ensure that `load_numel_total` reflects the + # actual number of weights loaded, either by modifying qconfigs to + # create as many weights as loaded (see padding issue as well) + # or maybe capturing how many weights are loaded on first pass + # + # For now, `load_numel_total` is still safe to use as long as + # there's no way to reach `load_numel_total` without loading all + # necessary weights. `weight_shape` is very small, so this is safe. 
+
+
+def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
+    """Create a wrapped weight loader that defers processing."""
+    info = get_layerwise_info(layer)
+    param = getattr(layer, param_name)
+    original_loader = _get_original_loader(param)
+    loader_signature = inspect.signature(original_loader)
+
+    @wraps(original_loader, assigned=("__doc__", "__annotations__"))
+    def online_process_loader(*args, **kwargs):
+        if not info.can_process():
+            # Unfortunately, some qconfigs are set up to load the same weight
+            # multiple times. For example, CT_WNA16 loads `weight_shape` for
+            # each of the qkv partitions. This results in layers loading extra
+            # weights (beyond load_numel_total) after it's already processed.
+            #
+            # The best solution is to ensure that `load_numel_total` reflects the
+            # actual number of weights loaded, either by modifying qconfigs to
+            # create as many weights as are loaded (see padding issue as well)
+            # or maybe capturing how many weights are loaded on a first pass.
+            #
+            # For now, `load_numel_total` is still safe to use as long as
+            # there's no way to reach `load_numel_total` without loading all
+            # necessary weights. `weight_shape` is very small, so this is safe.
+            # see Limitations(4)
+            logger.debug("%s: Excessive loading", layer.__class__.__name__)
+            return
+
+        # Bind and normalize arguments
+        bound_args = loader_signature.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+
+        # Cache loaded weights, track loading progress
+        info.loaded_weights.append((param_name, bound_args))
+        num_loaded, ret = get_numel_loaded(original_loader, bound_args)
+        info.load_numel += num_loaded
+
+        logger.debug(
+            "%s: %d / %d",
+            layer.__class__.__name__,
+            info.load_numel,
+            info.load_numel_total,
+        )
+
+        # Process and copy when all weights are loaded
+        if info.load_numel >= info.load_numel_total and not isinstance(  # type: ignore[operator]
+            layer, (Attention, MLAAttention)
+        ):
+            _layerwise_process(layer, info)
+
+        return ret
+
+    return online_process_loader
+
+
+def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
+    """
+    Remove the outermost layer of weight loading wrappers.
+
+    This function should be applied after `initialize_layerwise_reload` to
+    unwrap the layerwise weight loaders.
+
+    Also processes Attention/MLA layers, which must be processed after all other layers
+    """
+    model._do_torchao_reload = model._original_do_torchao_reload
+
+    for layer in model.modules():
+        info = get_layerwise_info(layer)
+
+        # Attention/MLA layers are processed after all other layers
+        if isinstance(layer, (Attention, MLAAttention)):
+            if info.load_numel > 0:
+                raise NotImplementedError(
+                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
+                )
+
+            else:
+                _place_kernel_tensors(layer, info)
+                layer.process_weights_after_loading(model_config.dtype)
+
+        # No weights were loaded, place kernel tensors back
+        elif info.can_process() and info.load_numel <= 0:
+            _place_kernel_tensors(layer, info)
+
+        # Process non-attention layers which did not load all elements. This can happen
+        # if the created weight has extra padding elements which are not loaded.
+        # Having too many of these delayed layers can lead to excess memory usage
+        # see Limitations(4)
+        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
+            logger.debug("%s: Delayed processing", layer.__class__.__name__)
+            _layerwise_process(layer, info)
+
+        info.reset()
+
+
+def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
+    """
+    Finalize layer loading after all weights have been cached.
+
+    This function:
+    1. Materializes the layer onto the target device
+    2. Loads all cached weights
+    3. Runs quantization processing if applicable
+    4. Copies processed values back to original tensor storage
+    """
+    # Materialize layer tensors onto device
+    materialize_layer(layer)
+
+    # Unwrap layerwise loading wrappers
+    for param in get_layer_tensors(layer).values():
+        param.weight_loader = _get_original_loader(param)
+
+    # Load all cached weights into materialized layer (using original loaders)
+    for name, args in info.loaded_weights:
+        param = getattr(layer, name)
+        args.arguments["param"] = param
+        param.weight_loader(*args.args, **args.kwargs)
+
+    # Process weights (quantization, repacking, etc.)
+    # Attention/MLA are processed in `finalize_layerwise_reload`
+    quant_method = getattr(layer, "quant_method", None)
+    if isinstance(quant_method, QuantizeMethodBase):
+        quant_method.process_weights_after_loading(layer)
+
+    # Copy processed values into original tensor storage (preserves cudagraph refs)
+    # this code is a no-op if not reloading (because `kernel_tensors` is empty)
+    parameters, buffers = info.kernel_tensors
+    for name, param in parameters.items():
+        param.data.copy_(getattr(layer, name))
+    for name, buffer in buffers.items():
+        buffer.data.copy_(getattr(layer, name))
+
+    _place_kernel_tensors(layer, info)
+
+    info.reset()
+    logger.debug("%s: Processed", layer.__class__.__name__)
+
+
+def _get_original_loader(tensor: torch.Tensor) -> Callable:
+    """Return the weight loader with any layerwise wrappers removed"""
+    loader = _get_weight_loader(tensor)
+    while loader.__name__ == "online_process_loader":
+        loader = loader.__wrapped__  # type: ignore[union-attr]
+
+    return loader
+
+
+def _get_weight_loader(tensor: torch.Tensor):
+    return getattr(tensor, "weight_loader", default_weight_loader)
+
+
+def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
+    for name in get_layer_tensors(layer):
+        delattr(layer, name)
+
+    parameters, buffers = info.kernel_tensors
+    for name, param in parameters.items():
+        layer.register_parameter(name, param)
+    for name, buffer in buffers.items():
+        layer.register_buffer(name, buffer)
+ """ + tensor = torch.empty_strided( + size=tuple(meta_tensor.size()), + stride=tuple(meta_tensor.stride()), + dtype=meta_tensor.dtype, + requires_grad=False, + ) + tensor.__class__ = meta_tensor.__class__ + tensor.__dict__ = meta_tensor.__dict__.copy() + return tensor + + +def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors: + if layer.__class__.__name__ in SKIP_MODULES: + return ({}, {}) + + params, buffers = get_layer_params_buffers(layer) + return ( + { + name: sanitize_layer_refs(to_meta_tensor(param), layer) + for name, param in params.items() + if name not in SKIP_TENSORS + }, + { + name: sanitize_layer_refs(to_meta_tensor(buffer), layer) + for name, buffer in buffers.items() + if name not in SKIP_TENSORS + }, + ) + + +def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo): + """Restore a layer to model format with tensors on the meta device""" + if layer.__class__.__name__ in SKIP_MODULES: + return + + for name in get_layer_tensors(layer): + if name not in SKIP_TENSORS: + delattr(layer, name) + + restore_params, restore_buffers = info.restore_metadata + for name, param in restore_params.items(): + if name not in SKIP_TENSORS: + param = restore_layer_refs(param, layer) + layer.register_parameter(name, param) + + for name, buffer in restore_buffers.items(): + if name not in SKIP_TENSORS: + buffer = restore_layer_refs(buffer, layer) + layer.register_buffer(name, buffer) + + +def materialize_layer(layer: torch.nn.Module) -> None: + """Materialize all meta tensors in a layer to actual tensors.""" + if layer.__class__.__name__ in SKIP_MODULES: + return + + for name, tensor in get_layer_tensors(layer).items(): + if name not in SKIP_TENSORS: + setattr(layer, name, materialize_meta_tensor(tensor)) + + +class MetaCopyCounter(TorchDispatchMode): + """ + Tracks total number of elements modified with `copy_`. + + Useful for keeping track of weight loading where underlying weights can be + arbitrarily transformed (such as with `narrow`) before calling copy. + + Note: Assumes that copy kwargs are not used. + """ + + def __init__(self): + super().__init__() + self.copied_numel = 0 + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.ops.aten.copy_.default and args[0].device.type == "meta": + assert args[0].numel() == args[1].numel() + self.copied_numel += args[0].numel() + + return func(*args, **kwargs) + + +def get_numel_loaded( + weight_loader: Callable, args: inspect.BoundArguments +) -> tuple[int, object]: + """ + Determine how many elements would be loaded by a weight loader call. 
diff --git a/vllm/model_executor/model_loader/reload/sanitize.py b/vllm/model_executor/model_loader/reload/sanitize.py
new file mode 100644
index 000000000000..2a6dc7182d02
--- /dev/null
+++ b/vllm/model_executor/model_loader/reload/sanitize.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from types import MethodType
+
+import torch
+
+__all__ = ["sanitize_layer_refs", "restore_layer_refs"]
+
+
+layer_ref_sentinel = object()
+
+
+def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
+    """
+    Removes references to layer held by tensor attributes. Specifically, removes the
+    `__self__` attribute of weight loader methods attached to the tensor.
+
+    Used by `capture_layer_to_meta` to avoid circular references to layers in
+    `LAYERWISE_INFO`, leading to modules never being cleaned up. Without sanitation,
+    tensors will reference layers, and the WeakKeyDictionary will never evict entries,
+    even when the model is deleted.
+
+    :param tensor: tensor to be sanitized
+    :param layer: layer whose references should be removed
+    :return: sanitized tensor
+    """
+    for key, value in tensor.__dict__.items():
+        if isinstance(value, MethodType) and value.__self__ is layer:
+            tensor.__dict__[key] = value.__func__.__get__(layer_ref_sentinel)
+
+    return tensor
+
+
+def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
+    """
+    Restores references to layer held by tensor attributes.
+
+    Used by `restore_layer_on_meta` to add back layer references, allowing for proper
+    weight loading.
+
+    :param tensor: tensor to be restored
+    :param layer: layer whose references should be restored
+    :return: restored tensor
+    """
+    for key, value in tensor.__dict__.items():
+        if isinstance(value, MethodType) and value.__self__ is layer_ref_sentinel:
+            tensor.__dict__[key] = value.__func__.__get__(layer)
+
+    return tensor
diff --git a/vllm/model_executor/model_loader/reload/torchao_decorator.py b/vllm/model_executor/model_loader/reload/torchao_decorator.py
new file mode 100644
index 000000000000..7fbc1c32944a
--- /dev/null
+++ b/vllm/model_executor/model_loader/reload/torchao_decorator.py
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Iterable
+from functools import wraps
+from types import FunctionType
+from typing import TYPE_CHECKING
+
+import torch
+
+from vllm.config import ModelConfig
+
+from .layerwise import (
+    finalize_layerwise_reload,
+    initialize_layerwise_reload,
+)
+
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import AutoWeightsLoader
+
+__all__ = ["set_torchao_reload_attrs", "support_quantized_model_reload_from_hp_weights"]
+
+
+def set_torchao_reload_attrs(model: torch.nn.Module, model_config: ModelConfig):
+    model._do_torchao_reload = True
+    model._model_config = model_config
+
+
+def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
+    """
+    Decorator for the `AutoWeightsLoader.load_weights` method to support reloading
+    high precision (bfloat16/float16/float32) weights for an already quantized
+    model. This involves restoring the weights to high precision and then online
+    quantizing the weights.
+
+    Only applies to torchao quantized models. Assumes that all model weights are
+    loaded within a single weights iterator (cannot perform batched updates)
+    """
+
+    @wraps(original_load_weights)
+    def patched_model_load_weights(
+        self: "AutoWeightsLoader",
+        weights: Iterable[tuple[str, torch.Tensor]],
+        *args,
+        **kwargs,
+    ):
+        model = self.module
+
+        if not getattr(model, "_do_torchao_reload", False):
+            return original_load_weights(self, weights, *args, **kwargs)
+
+        initialize_layerwise_reload(model)
+        loaded_weights = original_load_weights(self, weights, *args, **kwargs)
+        finalize_layerwise_reload(model, model._model_config)
+
+        return loaded_weights
+
+    return patched_model_load_weights
diff --git a/vllm/model_executor/model_loader/reload/types.py b/vllm/model_executor/model_loader/reload/types.py
new file mode 100644
index 000000000000..a7edbe79a75e
--- /dev/null
+++ b/vllm/model_executor/model_loader/reload/types.py
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass, field
+from inspect import BoundArguments
+
+import torch
+
+__all__ = ["LayerTensors", "LayerReloadingInfo"]
+
+# encodes both parameters and buffers separately
+LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
+
+
+@dataclass
+class LayerReloadingInfo:
+    # model format (meta), populated by `record_metadata_for_reloading`
+    restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
+
+    # kernel format (device)
+    kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
+
+    # track how many restored elements are ready for loading
+    load_numel: int = 0
+    load_numel_total: int | None = None
+
+    # stores arguments and tensors ready for loading
+    loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)
+
+    def reset(self):
+        self.__init__(restore_metadata=self.restore_metadata)  # type: ignore[misc]
+
+    def can_process(self) -> bool:
+        return self.load_numel_total is not None
diff --git a/vllm/model_executor/model_loader/reload/utils.py b/vllm/model_executor/model_loader/reload/utils.py
new file mode 100644
index 000000000000..1e5d42ba7515
--- /dev/null
+++ b/vllm/model_executor/model_loader/reload/utils.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from .types import LayerTensors
+
+__all__ = [
+    "get_layer_tensors",
+    "get_layer_params_buffers",
+    "get_layer_size",
+]
+
+
+def get_layer_tensors(layer: torch.nn.Module) -> dict[str, torch.Tensor]:
+    """Get all parameters and buffers from a module as a dict."""
+    params, buffers = get_layer_params_buffers(layer)
+    return params | buffers
+
+
+def get_layer_params_buffers(layer: torch.nn.Module) -> LayerTensors:
+    """Get all parameters and buffers of a module as a tuple of dicts."""
+    return (
+        {name: param for name, param in layer._parameters.items() if param is not None},
+        {name: buffer for name, buffer in layer._buffers.items() if buffer is not None},
+    )
+
+
+def get_layer_size(layer: torch.nn.Module) -> int:
+    """Calculate total number of elements across all tensors in a layer."""
+    return sum(tensor.numel() for tensor in get_layer_tensors(layer).values())
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 1d67cb835e93..755c92298fb4 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -18,6 +18,10 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.model_loader.reload import (
+    record_metadata_for_reloading,
+    set_torchao_reload_attrs,
+)
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.utils.platform_utils import is_pin_memory_available
 
@@ -45,7 +49,9 @@ def initialize_model(
     if "vllm_config" in all_params and "prefix" in all_params:
         # new-style model class
         with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
-            return model_class(vllm_config=vllm_config, prefix=prefix)
+            model = model_class(vllm_config=vllm_config, prefix=prefix)
+            record_metadata_for_reloading(model)
+            return model
 
     msg = (
         "vLLM model class should accept `vllm_config` and `prefix` as "
@@ -75,27 +81,15 @@ def initialize_model(
     if "scheduler_config" in all_params:
         kwargs["scheduler_config"] = vllm_config.scheduler_config
     with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
-        return model_class(**kwargs)
+        model = model_class(**kwargs)
+        record_metadata_for_reloading(model)
+
+    return model
 
 
 def process_weights_after_loading(
     model: nn.Module, model_config: ModelConfig, target_device: torch.device
 ) -> None:
-    if getattr(model, "process_weights_after_loading_already_called", False):
-        # In case `process_weights_after_loading` is called multiple times
-        # we'll skip it at later times
-        logger.debug_once(
-            "process_weights_after_loading already called for model %s", model
-        )
-        return
-
-    # to avoid circular dependency
-    from vllm.model_executor.model_loader.online_quantization import (
-        maybe_save_metadata_and_attributes_for_weight_reloading,
-    )
-
-    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)
-
     for _, module in model.named_modules():
"quant_method", None) if isinstance(quant_method, QuantizeMethodBase): @@ -117,6 +111,11 @@ def process_weights_after_loading( # of process_weights_after_loading module.process_weights_after_loading(model_config.dtype) + # Needed for torchao model reloading via model.reload_weights + # @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights` + if model_config.quantization == "torchao": + set_torchao_reload_attrs(model, model_config) + @contextmanager def device_loading_context(module: torch.nn.Module, target_device: torch.device): diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index d3e1434b75f8..594f975aacde 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, ) -from vllm.model_executor.model_loader.online_quantization import ( +from vllm.model_executor.model_loader.reload import ( support_quantized_model_reload_from_hp_weights, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index d3a91feab64d..2bfa7a5a7d87 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -536,8 +536,7 @@ def process_weights_after_loading(self): @property def data(self): raise ValueError( - "Accessing `data` of a " - "`PartitionedModelWeightParameter` is not allowed. " + "Accessing `data` of a `SharedWeightParameter` is not allowed. " "Instead, use `get_partition` to get the weight of " "the particular partition you want to access" ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 64f6263cc038..e0b297d10b83 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -7,7 +7,7 @@ import threading import time from collections import defaultdict -from collections.abc import Iterator, Sequence +from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -59,6 +59,10 @@ XDRotaryEmbedding, ) from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.model_loader.reload import ( + finalize_layerwise_reload, + initialize_layerwise_reload, +) from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, SupportsMRoPE, @@ -2515,8 +2519,10 @@ def _gather_mm_embeddings( return mm_embeds, is_mm_embed def get_model(self) -> nn.Module: - # get raw model out of the cudagraph wrapper. + if not hasattr(self, "model"): + raise ValueError("Cannot get model before model has been initialized") if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): + # get raw model out of the cudagraph wrapper. return self.model.unwrap() return self.model @@ -4209,13 +4215,89 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: return None - def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, ( - "Cannot reload weights before model is loaded." 
+    def reload_weights(
+        self,
+        weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None,
+        weights_path: str | None = None,
+        is_checkpoint_format: bool = True,
+    ) -> None:
+        """
+        Reload weights from a weights iterator or from disk
+
+        :param weights_iterator: weights to load into model
+        :param weights_path: path to load weights from if weights_iterator is not
+            provided. If neither is provided, the path of the original model is used.
+        :param is_checkpoint_format: set to False if weights have already been processed
+            into kernel format (repacking, renaming, etc.)
+        """
+        # TODO(@kylesayrs): generalize to all runners and loaders
+        # argument validation
+        if weights_iterator is None and not is_checkpoint_format:
+            logger.warning(
+                "Reloading from disk means that weights will be in checkpoint format. "
+                "Please use `is_checkpoint_format=True` "
+                "to avoid weight reloading errors"
+            )
+
+        model = self.get_model()
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        counter_before_reloading = time.perf_counter()
+
+        # load weights from disk if none are provided
+        if weights_iterator is None:
+            model_loader = get_model_loader(self.load_config)
+            if not hasattr(model_loader, "get_all_weights"):
+                raise NotImplementedError(
+                    f"Model reloading with `{self.load_config.load_format}` format"
+                )
+
+            if weights_path is not None:
+                self.model_config.model = weights_path
+            weights_iterator = model_loader.get_all_weights(self.model_config, model)
+            weights_iterator = cast(
+                Iterable[tuple[str, torch.Tensor]], weights_iterator
+            )
+
+        # begin loading weights
+        logger.info_once("Reloading weights inplace...", scope="local")
+        load_device = (
+            self.vllm_config.load_config.device or self.vllm_config.device_config.device
        )
-        model_loader = get_model_loader(self.load_config)
-        logger.info("Reloading weights inplace...")
-        model_loader.load_weights(self.get_model(), model_config=self.model_config)
+        with torch.device(load_device):
+            if is_checkpoint_format:
+                # load weights from checkpoint / original model format
+                initialize_layerwise_reload(model)
+                loaded_weights = model.load_weights(weights_iterator)
+                finalize_layerwise_reload(model, self.model_config)
+
+            else:
+                # load weights from kernel format
+                logger.warning_once(
+                    "Reloading with `is_checkpoint_format=False` requires that "
+                    "weights be in kernel format and already sharded",
+                    scope="local",
+                )
+                loaded_weights = set()
+                for name, loaded_weight in weights_iterator:
+                    param = model.get_parameter(name)  # TODO: buffers?
+                    param.copy_(loaded_weight)
+                    loaded_weights.add(name)
+
+        # logging and validation
+        counter_after_reloading = time.perf_counter()
+        diff_seconds = counter_after_reloading - counter_before_reloading
+        logger.info_once(
+            "Reloading and processing weights took %.2f seconds",
+            diff_seconds,
+            scope="local",
+        )
+        if self.model_config.quantization is None and loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                logger.warning(
+                    "The following weights were not loaded from checkpoint: %s",
+                    weights_not_loaded,
+                )
 
     def save_tensorized_model(
         self,
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 4c02c0598545..c57b6e762fe5 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -277,8 +277,8 @@ def load_model(self) -> None:
 
     def update_config(self, overrides: dict[str, Any]) -> None:
         self.model_runner.update_config(overrides)
 
-    def reload_weights(self) -> None:
-        self.model_runner.reload_weights()
+    def reload_weights(self, *args, **kwargs) -> None:
+        self.model_runner.reload_weights(*args, **kwargs)
 
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
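End to end, a trainer-style caller drives all of this through `collective_rpc`, exactly as the new test does. A hypothetical driver sketch; the model name is real, but the checkpoint path is a placeholder:

```python
from vllm import LLM

llm = LLM(model="Qwen/Qwen3-0.6B", enable_prefix_caching=False)

# Targets Worker.reload_weights on every rank, which forwards to
# GPUModelRunner.reload_weights; weights are expected in checkpoint format.
llm.collective_rpc(
    "reload_weights",
    kwargs={"weights_path": "/path/to/updated/checkpoint"},  # placeholder path
)
print(llm.generate(["3 4 ="])[0].outputs[0].text)
```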