Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
63048e7
WIP: basic design
kylesayrs Jan 12, 2026
b7a1800
fix typos
kylesayrs Jan 12, 2026
469afaf
small simplify code
kylesayrs Jan 12, 2026
6beac38
WIP: delay materialization to better support out-of-order loading
kylesayrs Jan 13, 2026
197ac43
WIP: reloading works (tested 4x reload), code is messy and needs clea…
kylesayrs Jan 14, 2026
9dd5bd9
small cleanup
kylesayrs Jan 14, 2026
276d169
big ai cleanup
kylesayrs Jan 14, 2026
be7bc55
fix bug with second time reloading where shared tensors were only bei…
kylesayrs Jan 15, 2026
b79ab7c
break out LAYER_RELOADING_INFO
kylesayrs Jan 15, 2026
48635ef
cleanup
kylesayrs Jan 15, 2026
1fbc2b0
breakout files, apply style
kylesayrs Jan 16, 2026
34cdeb5
implement torchao decorator
kylesayrs Jan 16, 2026
ea502e6
docstrings
kylesayrs Jan 16, 2026
ec04411
rename, fix for quantized models
kylesayrs Jan 19, 2026
730b96c
fix non-attention modules with non-loaded weights
kylesayrs Jan 19, 2026
8510e52
add torchao decorator test
kylesayrs Jan 19, 2026
82f0e26
add unit tests
kylesayrs Jan 19, 2026
d53f4b4
add reloading tests
kylesayrs Jan 20, 2026
9fa4060
workaround
kylesayrs Jan 20, 2026
1a174b1
add debug logs, cleanup
kylesayrs Jan 20, 2026
cd4a061
docstring
kylesayrs Jan 20, 2026
a379f14
style
kylesayrs Jan 20, 2026
6847a10
style
kylesayrs Jan 20, 2026
eefcc25
support moe modules
kylesayrs Jan 21, 2026
c4737cb
skip hadamard layers, fix reload test
kylesayrs Jan 21, 2026
c40e065
return the return value of the weight loader
kylesayrs Jan 21, 2026
8410264
support ep
kylesayrs Jan 21, 2026
e145e2e
fix torchao test
kylesayrs Jan 22, 2026
bf52c89
fix perplexity bug
kylesayrs Jan 22, 2026
7350365
add eplb test
kylesayrs Jan 22, 2026
b9fc627
fix typo
kylesayrs Jan 22, 2026
af0a44f
fix perplexity bug, break out make_online_process_loader, fix docstring
kylesayrs Jan 22, 2026
e4cb84f
fix hadamard skip
kylesayrs Jan 23, 2026
fde119d
add reference sanitation to enable proper model cleanup
kylesayrs Jan 23, 2026
8b9bb36
add model cleanup test
kylesayrs Jan 23, 2026
893d3fc
cleanup
kylesayrs Jan 23, 2026
5a808f8
change api to pass model_config
kylesayrs Jan 23, 2026
80feeaf
address comments, add tests
kylesayrs Jan 24, 2026
363b92e
fix logic
kylesayrs Jan 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from collections.abc import Generator
from contextlib import nullcontext
from enum import Enum
from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING
from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING, Optional

import numpy as np
import pytest
Expand Down Expand Up @@ -1023,7 +1023,9 @@ def generate_greedy_logprobs(
**kwargs,
)

def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
def generate_prompt_perplexity(
self, prompts: list[str], mask: Optional[list[str]] = None
) -> list[float]:
"""
Return the perplexity score associated with generating the prompts

Expand All @@ -1034,13 +1036,20 @@ def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0
)

mask_prefix_lens = (
[len(self.llm.get_tokenizer()(prefix)["input_ids"]) for prefix in mask]
if mask is not None
else [0 for _ in range(len(prompts))]
)

perplexities = []
for output in outputs:
for output, mask_prefix_len in zip(outputs, mask_prefix_lens):
output = cast(TokensTextLogprobsPromptLogprobs, output)
token_datas = cast(list[dict[int, Logprob] | None], output[3])
assert token_datas[0] is None

token_log_probs = []
for token_data in token_datas[1:]:
for token_data in token_datas[mask_prefix_len + 1 :]:
assert token_data is not None
assert len(token_data) == 1
token_log_prob = list(token_data.values())[0].logprob
Expand Down Expand Up @@ -1121,6 +1130,9 @@ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
def get_llm(self) -> LLM:
return self.llm

def collective_rpc(self, *args, **kwargs):
    """Forward a collective RPC call to the wrapped `LLM` instance.

    Thin pass-through to `self.llm.collective_rpc`; returns whatever the
    underlying call returns (e.g. per-worker results).
    """
    return self.llm.collective_rpc(*args, **kwargs)

def __enter__(self):
return self

Expand Down Expand Up @@ -1531,3 +1543,9 @@ def use_fresh_inductor_cache():
"""
with fresh_cache():
yield


@pytest.fixture(scope="function")
def enable_pickle(monkeypatch):
    """Allow insecure (pickle-based) serialization for the duration of a test.

    `LLM.apply_model` requires pickling a function, which vLLM only permits
    when VLLM_ALLOW_INSECURE_SERIALIZATION is set. `monkeypatch` restores the
    environment automatically after the test.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
150 changes: 150 additions & 0 deletions tests/model_executor/model_loader/test_reload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import inspect
from weakref import WeakKeyDictionary, ref

import pytest
import torch

from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.model_loader.reload.meta import (
capture_layer_to_meta,
get_numel_loaded,
materialize_layer,
materialize_meta_tensor,
restore_layer_on_meta,
to_meta_tensor,
)
from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless


def test_move_metatensors():
    """Round-trip a tensor to the meta device and back.

    The meta copy must live on the meta device, the materialized copy must
    return to the original device, and dtype/shape/class/instance-dict must
    survive both conversions.
    """
    original = torch.empty((1, 2, 3))
    on_meta = to_meta_tensor(original)
    restored = materialize_meta_tensor(on_meta)

    assert on_meta.device.type == "meta"
    assert restored.device == original.device

    # Everything except device placement is preserved across the round trip.
    for attr in ("dtype", "shape", "__class__", "__dict__"):
        assert getattr(original, attr) == getattr(on_meta, attr)
        assert getattr(original, attr) == getattr(restored, attr)


def test_reload_lifecycle():
    """Exercise the capture -> restore-on-meta -> materialize lifecycle.

    After each stage, the layer's tensors must keep the same dtype, shape,
    class, and instance dict as reported by `get_layer_tensors`.
    """

    def check_metadata_preserved(module):
        # Compare every tracked tensor against the attribute currently
        # installed on the module under the same name.
        for name, reference in get_layer_tensors(module).items():
            current = getattr(module, name)
            assert reference.dtype == current.dtype
            assert reference.shape == current.shape
            assert reference.__class__ == current.__class__
            assert reference.__dict__ == current.__dict__

    layer = torch.nn.Linear(2, 3)
    info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))

    restore_layer_on_meta(layer, info)
    check_metadata_preserved(layer)

    materialize_layer(layer)
    check_metadata_preserved(layer)


def test_model_cleanup(dist_init, default_vllm_config):
    """Dropping the last strong reference to a layer must actually free it.

    The layer's weight loader is a bound method (it holds `self`), and the
    reload info is held in a WeakKeyDictionary -- neither may keep the layer
    alive once the caller lets go of it.
    """
    layer = QKVParallelLinear(2, 3, 4)
    assert layer.weight.weight_loader.__self__ is layer

    tracked: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo]
    tracked = WeakKeyDictionary()
    tracked[layer] = LayerReloadingInfo(
        restore_metadata=capture_layer_to_meta(layer)
    )
    weak_layer = ref(layer)

    del layer
    gc.collect()

    # Both the weakref and the weak-keyed entry must have been cleared.
    assert weak_layer() is None
    assert not tracked


def test_get_numel_loaded():
    """`get_numel_loaded` counts written elements and forwards the return.

    A loader that writes two disjoint slices (3 + 3 elements) must report 6
    loaded elements, and its own return value must be passed through.
    """
    dest = torch.empty(10, device="meta")
    src = torch.empty(10)

    def sparse_loader(param, loaded_weight):
        # Two non-contiguous writes totalling 6 elements.
        param[:3] = loaded_weight[:3]
        param[5:8] = loaded_weight[5:8]
        return "value"

    bound_args = inspect.signature(sparse_loader).bind(dest, src)
    count, result = get_numel_loaded(sparse_loader, bound_args)
    assert count == 6
    assert result == "value"


@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
    "base_model,mul_model,add_model",
    [
        # Each triple is a base checkpoint plus two debug variants. Judging by
        # the names and the assertions below, the "multiply" variant should
        # prefer the completion "3 4 = 12" and the "add" variant "3 4 = 7" --
        # TODO confirm against the checkpoints. Coverage spans dense (Qwen3)
        # and MoE (DeepSeek-V3) models, unquantized and quantized
        # (FP8 / W4A16 / NVFP4) formats.
        (
            "Qwen/Qwen3-0.6B",
            "inference-optimization/Qwen3-0.6B-debug-multiply",
            "inference-optimization/Qwen3-0.6B-debug-add",
        ),
        (
            "inference-optimization/Qwen3-0.6B-FP8_BLOCK",
            "inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK",
            "inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK",
        ),
        (
            "inference-optimization/Qwen3-0.6B-W4A16-G128",
            "inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128",
            "inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128",
        ),
        (
            "inference-optimization/DeepSeek-V3-debug-empty",
            "inference-optimization/DeepSeek-V3-debug-multiply",
            "inference-optimization/DeepSeek-V3-debug-add",
        ),
        (
            "inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC",
            "inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC",
            "inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC",
        ),
        (
            "inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16",
            "inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16",
            "inference-optimization/DeepSeek-V3-debug-add-NVFP4A16",
        ),
    ],
)
def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
    """Reload weights in place and verify behavior follows the new checkpoint.

    Starts from `base_model`, swaps in the multiply then the add checkpoint
    via the `reload_weights` worker RPC, and compares perplexities of the two
    candidate completions. The `mask` argument excludes the shared "3 4 ="
    prefix from the perplexity computation so only the answer token is scored.
    Reloading twice also covers the reload-after-reload path.
    """
    if cuda_device_count_stateless() < tp_size:
        pytest.skip(reason="Not enough CUDA devices")

    if "FP8" in base_model and not current_platform.supports_fp8():
        pytest.skip(reason="Requires FP8 support")

    with vllm_runner(
        model_name=base_model,
        tensor_parallel_size=tp_size,
        # DeepSeek (MoE) models run with expert parallelism when tp > 1.
        enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
        enable_prefix_caching=False,
    ) as llm:
        # After loading the multiply variant, "3 4 = 12" should be the more
        # likely (lower-perplexity) completion.
        llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
        mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
        add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
        assert mul_perp < add_perp

        # A second reload must also take effect: preference flips to "3 4 = 7".
        llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model})
        mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
        add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
        assert add_perp < mul_perp
19 changes: 15 additions & 4 deletions tests/quantization/test_torchao.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.metadata
import importlib.util

import pytest
import torch

from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform

DTYPE = ["bfloat16"]
Expand Down Expand Up @@ -105,8 +105,8 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_online_quant_config_dict_json(vllm_runner):
"""Testing on the fly quantization, load_weights integration point,
def test_online_quant_config_dict_json(vllm_runner, enable_pickle):
"""Testing online quantization, load_weights integration point,
with config dict serialized to json string
"""
torch._dynamo.reset()
Expand Down Expand Up @@ -135,7 +135,18 @@ def test_online_quant_config_dict_json(vllm_runner):
) as llm:
output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

assert output
load_config = llm.llm.llm_engine.vllm_config.load_config
model_config = llm.llm.llm_engine.vllm_config.model_config

def load_weights(model):
model_loader = get_model_loader(load_config)
weights_iterator = model_loader.get_all_weights(model_config, model)
model.load_weights(weights_iterator)

llm.apply_model(load_weights)

reload_output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
assert output[0][0] == reload_output[0][0]


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
Expand Down
2 changes: 1 addition & 1 deletion tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):


def test_reload_weights_before_load_model(model_runner):
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
model_runner.reload_weights()


Expand Down
Loading