Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tests/model_executor/model_loader/test_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,34 @@ def test_reload_lifecycle():
assert tensor.__dict__ == materialized_tensor.__dict__


def test_materialize_layer_preserves_non_meta_tensors():
    """materialize_layer must only materialize meta tensors, leaving
    already-materialized (non-meta) tensors and their values untouched."""
    linear = torch.nn.Linear(2, 3, bias=True)

    # Mixed state: real bias alongside a meta weight (can occur with FP8).
    expected_bias = torch.ones(3)
    linear.bias.data.copy_(expected_bias)
    linear.weight = torch.nn.Parameter(linear.weight.data.to("meta"))

    assert linear.weight.is_meta
    assert not linear.bias.is_meta

    # Materialize the layer after the bias has already been initialized.
    reload_info = LayerReloadingInfo(
        restore_metadata=({}, {}),
        restore_device=torch.device("cpu"),
    )
    materialize_layer(linear, reload_info)

    # The meta weight should now live on the restore device.
    assert not linear.weight.is_meta
    assert linear.weight.device.type == "cpu"

    # The bias must remain non-meta with its original values intact.
    assert not linear.bias.is_meta
    assert torch.equal(linear.bias.data, expected_bias)


def test_model_cleanup(dist_init, default_vllm_config):
layer = QKVParallelLinear(2, 3, 4)
assert layer.weight.weight_loader.__self__ is layer
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/reload/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def materialize_layer(layer: torch.nn.Module, info: LayerReloadingInfo):

with info.restore_device:
for name, tensor in get_layer_tensors(layer).items():
if name not in SKIP_TENSORS:
if name not in SKIP_TENSORS and tensor.is_meta:
setattr(layer, name, materialize_meta_tensor(tensor))


Expand Down
Loading