diff --git a/tests/model_executor/model_loader/test_reload.py b/tests/model_executor/model_loader/test_reload.py
index 6e3e2d63e144..cf3553bd57de 100644
--- a/tests/model_executor/model_loader/test_reload.py
+++ b/tests/model_executor/model_loader/test_reload.py
@@ -59,6 +59,34 @@ def test_reload_lifecycle():
     assert tensor.__dict__ == materialized_tensor.__dict__
 
 
+def test_materialize_layer_preserves_non_meta_tensors():
+    """Ensure that materialize_layer does not overwrite non-meta tensors."""
+    layer = torch.nn.Linear(2, 3, bias=True)
+
+    # Create a non-meta bias tensor and a meta weight, which can happen with FP8
+    bias_values = torch.ones(3)
+    layer.bias.data.copy_(bias_values)
+    layer.weight = torch.nn.Parameter(layer.weight.data.to("meta"))
+
+    assert layer.weight.is_meta
+    assert not layer.bias.is_meta
+
+    # Materialize the layer weights after the bias is initialized
+    info = LayerReloadingInfo(
+        restore_metadata=({}, {}),
+        restore_device=torch.device("cpu"),
+    )
+    materialize_layer(layer, info)
+
+    # Ensure the weight was materialized off the meta device
+    assert not layer.weight.is_meta
+    assert layer.weight.device.type == "cpu"
+
+    # Ensure that the bias is (still) not meta and its values are unchanged
+    assert not layer.bias.is_meta
+    assert torch.equal(layer.bias.data, bias_values)
+
+
 def test_model_cleanup(dist_init, default_vllm_config):
     layer = QKVParallelLinear(2, 3, 4)
     assert layer.weight.weight_loader.__self__ is layer
diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py
index 91fce6f57b3e..baa2081d58b2 100644
--- a/vllm/model_executor/model_loader/reload/meta.py
+++ b/vllm/model_executor/model_loader/reload/meta.py
@@ -102,7 +102,7 @@ def materialize_layer(layer: torch.nn.Module, info: LayerReloadingInfo):
 
     with info.restore_device:
         for name, tensor in get_layer_tensors(layer).items():
-            if name not in SKIP_TENSORS:
+            if name not in SKIP_TENSORS and tensor.is_meta:
                 setattr(layer, name,
                         materialize_meta_tensor(tensor))
 