Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions vllm/model_executor/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model)

# Initialize parameters
self.init_parameters(self.model)

# Move remaining meta tensors to device (should happen last)
self.meta_to_empty(self.model)

Expand Down Expand Up @@ -298,6 +301,25 @@ def init_buffers(self, module: nn.Module):
for child in module.children():
self.init_buffers(child)

def init_parameters(self, module: nn.Module):
"""
If a `parameter` is on the `meta` device, then its parent
`module` is the original module created by:

```python
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
"""
for name, param in module.named_parameters(recurse=False):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data,
device=self.device_config.device))
setattr(module, name, new_param)
for child in module.children():
Comment on lines 319 to 320
Copy link

Copilot AI Apr 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using module.register_parameter(name, new_param) instead of setattr(module, name, new_param) to ensure the parameter is correctly registered in the module's parameter dictionary.

Suggested change
setattr(module, name, new_param)
for child in module.children():
module.register_parameter(name, new_param)

Copilot uses AI. Check for mistakes.
self.init_parameters(child)

def meta_to_empty(self, module: nn.Module):
tensors = list(chain(module.buffers(), module.parameters()))
if tensors and all(t.device == torch.device("meta") for t in tensors):
Expand Down Expand Up @@ -342,6 +364,7 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())

loaded_params = set[str]()
for name, loaded_weight in weights:
# Use "model" instead of base_model_prefix because
Expand Down