Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .buildkite/test-amd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ steps:
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor
- pytest -v -s model_executor -m '(not slow_test)'
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py


Expand Down Expand Up @@ -1242,7 +1242,7 @@ steps:
- vllm/platforms/rocm.py
commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
Expand Down Expand Up @@ -2501,7 +2501,7 @@ steps:
- tests/models/
commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
Expand Down
2 changes: 1 addition & 1 deletion .buildkite/test_areas/model_executor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ steps:
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor
- pytest -v -s model_executor -m '(not slow_test)'
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
2 changes: 1 addition & 1 deletion .buildkite/test_areas/models_distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ steps:
- tests/models/
commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
Expand Down
65 changes: 45 additions & 20 deletions tests/model_executor/model_loader/test_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def test_move_metatensors():

def test_reload_lifecycle():
layer = torch.nn.Linear(2, 3)
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
info = LayerReloadingInfo(
restore_metadata=capture_layer_to_meta(layer),
restore_device=torch.device("cpu"),
)

restore_layer_on_meta(layer, info)
for name, tensor in get_layer_tensors(layer).items():
Expand All @@ -48,7 +51,7 @@ def test_reload_lifecycle():
assert tensor.__class__ == meta_tensor.__class__
assert tensor.__dict__ == meta_tensor.__dict__

materialize_layer(layer)
materialize_layer(layer, info)
for name, tensor in get_layer_tensors(layer).items():
materialized_tensor = getattr(layer, name)
assert tensor.dtype == materialized_tensor.dtype
Expand All @@ -60,7 +63,10 @@ def test_reload_lifecycle():
def test_model_cleanup(dist_init, default_vllm_config):
layer = QKVParallelLinear(2, 3, 4)
assert layer.weight.weight_loader.__self__ is layer
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
info = LayerReloadingInfo(
restore_metadata=capture_layer_to_meta(layer),
restore_device=torch.device("cpu"),
)

mock_info_dict: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
WeakKeyDictionary()
Expand Down Expand Up @@ -90,39 +96,46 @@ def complex_weight_loader(param, loaded_weight):
assert ret == "value"


@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize(
"base_model,mul_model,add_model",
[
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/Qwen3-0.6B-FP8_BLOCK",
"inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK",
"inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/Qwen3-0.6B-W4A16-G128",
"inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128",
"inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC",
"inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC",
"inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC",
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16",
"inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16",
"inference-optimization/DeepSeek-V3-debug-add-NVFP4A16",
marks=[pytest.mark.slow_test],
),
],
)
Expand All @@ -138,6 +151,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
tensor_parallel_size=tp_size,
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
enable_prefix_caching=False,
max_model_len=16,
max_num_seqs=1,
) as llm:
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
Expand All @@ -150,34 +165,42 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
assert add_perp < mul_perp


@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize(
"base_model,mul_model,add_model,quantization",
[
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
"fp8",
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
"fp8",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
"mxfp8",
marks=[pytest.mark.slow_test],
),
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
"mxfp8",
marks=[
pytest.mark.slow_test,
pytest.mark.xfail(reason="mxfp4 & mla is not supported yet"),
],
),
# ( TODO: support mxfp4 & mla
# "inference-optimization/DeepSeek-V3-debug-empty",
# "inference-optimization/DeepSeek-V3-debug-multiply",
# "inference-optimization/DeepSeek-V3-debug-add",
# "mxfp8",
# ),
],
)
def test_online_quantize_reload(
Expand All @@ -195,6 +218,8 @@ def test_online_quantize_reload(
tensor_parallel_size=tp_size,
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
enable_prefix_caching=False,
max_model_len=16,
max_num_seqs=1,
) as llm:
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
Expand Down
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,14 +1006,17 @@ def create_weights(
initialize_online_processing(layer)

def process_weights_after_loading(self, layer: Module) -> None:
# TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = torch.ones(layer.num_experts, dtype=torch.float32)
w2_scale = torch.ones(layer.num_experts, dtype=torch.float32)
w13_scale = torch.ones(
layer.num_experts, device=w13.device, dtype=torch.float32
)
w2_scale = torch.ones(layer.num_experts, device=w2.device, dtype=torch.float32)
layer.w13_input_scale = None
layer.w2_input_scale = None

Expand Down
18 changes: 15 additions & 3 deletions vllm/model_executor/model_loader/reload/layerwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
information existed, a new entry is constructed
"""
if layer not in LAYERWISE_INFO:
LAYERWISE_INFO[layer] = LayerReloadingInfo()
LAYERWISE_INFO[layer] = LayerReloadingInfo(
restore_metadata=({}, {}),
restore_device=torch.get_default_device(),
)

return LAYERWISE_INFO[layer]

Expand All @@ -64,6 +67,7 @@ def record_metadata_for_reloading(model: torch.nn.Module):
for layer in model.modules():
info = get_layerwise_info(layer)
info.restore_metadata = capture_layer_to_meta(layer)
info.restore_device = torch.get_default_device()


@torch.no_grad()
Expand Down Expand Up @@ -99,10 +103,18 @@ def initialize_layerwise_reload(model: torch.nn.Module):
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta(layer, info)

# Wrap weight loaders to buffer loading
initialize_online_processing(layer)


def initialize_online_processing(layer: torch.nn.Module):
"""
Wrap a layer's weight loaders with online processing loaders.
Called by either `initialize_layerwise_reload` or an online quantization scheme,
prevents double wrapping in the case of online quantization + reloading

:param layer: layer whose parameter weight loaders will be wrapped
"""
info = get_layerwise_info(layer)

# Track loading progress to determine when to process/copy
Expand Down Expand Up @@ -211,7 +223,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
elif info.load_numel <= 0:
# first load but received no weights. This happens on dummy load
if info.kernel_tensors is None:
materialize_layer(layer)
materialize_layer(layer, info)

# reloading: place kernel tensors back as a fallback
else:
Expand Down Expand Up @@ -244,7 +256,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
4. Copies processed values back to original tensor storage
"""
# Materialize layer tensors onto device
materialize_layer(layer)
materialize_layer(layer, info)

# Reset online quantization flag so process_weights_after_loading
# will run again during reload
Expand Down
9 changes: 5 additions & 4 deletions vllm/model_executor/model_loader/reload/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,15 @@ def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
layer.register_buffer(name, buffer)


def materialize_layer(layer: torch.nn.Module) -> None:
def materialize_layer(layer: torch.nn.Module, info: LayerReloadingInfo):
"""Materialize all meta tensors in a layer to actual tensors."""
if layer.__class__.__name__ in SKIP_MODULES:
return

for name, tensor in get_layer_tensors(layer).items():
if name not in SKIP_TENSORS:
setattr(layer, name, materialize_meta_tensor(tensor))
with info.restore_device:
for name, tensor in get_layer_tensors(layer).items():
if name not in SKIP_TENSORS:
setattr(layer, name, materialize_meta_tensor(tensor))


class CopyCounter(TorchDispatchMode):
Expand Down
19 changes: 12 additions & 7 deletions vllm/model_executor/model_loader/reload/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,26 @@

@dataclass
class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
# model format metadata, recorded by `record_metadata_for_reloading`
restore_metadata: LayerTensors

# kernel format (device), used to copy into when reloading only
kernel_tensors: LayerTensors | None = None
# device to materialize layers with, recorded by `record_metadata_for_reloading`
restore_device: torch.device

# track how many restored elements are ready for loading
# track how many elements are ready for loading, used by `online_process_loader`
load_numel: int = 0
load_numel_total: int | None = None

# stores arguments and tensors ready for loading
# used by `online_process_loader` to buffer args and tensors until ready to load
loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)

# kernel formatted tensors, copied into by `_layerwise_process` when reloading
kernel_tensors: LayerTensors | None = None

def reset(self):
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc]
self.__init__( # type: ignore[misc]
restore_metadata=self.restore_metadata, restore_device=self.restore_device
)

def can_load(self) -> bool:
return self.load_numel_total is not None
38 changes: 17 additions & 21 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4943,28 +4943,24 @@ def reload_weights(

# begin loading weights
logger.info_once("Reloading weights inplace...", scope="local")
load_device = (
self.vllm_config.load_config.device or self.vllm_config.device_config.device
)
with torch.device(load_device):
if is_checkpoint_format:
# load weights from checkpoint/ original model format
initialize_layerwise_reload(model)
loaded_weights = model.load_weights(weights_iterator)
finalize_layerwise_reload(model, self.model_config)
if is_checkpoint_format:
# load weights from checkpoint/ original model format
initialize_layerwise_reload(model)
loaded_weights = model.load_weights(weights_iterator)
finalize_layerwise_reload(model, self.model_config)

else:
# load weights from kernel format
logger.warning_once(
"Reloading with `is_checkpoint_format=True` requires that "
"weights be in kernel format and already sharded",
scope="local",
)
loaded_weights = set()
for name, loaded_weight in weights_iterator:
param = model.get_parameter(name) # TODO: buffers?
param.copy_(loaded_weight)
loaded_weights.add(name)
else:
# load weights from kernel format
logger.warning_once(
"Reloading with `is_checkpoint_format=True` requires that "
"weights be in kernel format and already sharded",
scope="local",
)
loaded_weights = set()
for name, loaded_weight in weights_iterator:
param = model.get_parameter(name) # TODO: buffers?
param.copy_(loaded_weight)
loaded_weights.add(name)

# logging and validation
counter_after_reloading = time.perf_counter()
Expand Down
Loading