Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions vllm/model_executor/models/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# limitations under the License.
"""Transformers modeling backend base class."""

from collections.abc import Iterable
from collections.abc import Callable, Iterable
from itertools import chain
from operator import attrgetter
from typing import TYPE_CHECKING

import regex as re
Expand Down Expand Up @@ -296,6 +297,15 @@ def _create_hf_to_vllm_mapper(self):
# Apply mapping to quantization config if needed
self._maybe_apply_model_mapping()

def _get_tie_word_embeddings(self):
"""
Check if the model has tied word embeddings.
"""
# Transformers v4 and v5 will store this in different places
tie_word_embeddings_v4 = getattr(self.text_config, "tie_word_embeddings", False)
tie_word_embeddings_v5 = getattr(self.config, "tie_word_embeddings", False)
return tie_word_embeddings_v4 or tie_word_embeddings_v5

def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
Expand All @@ -311,11 +321,22 @@ def pipeline_parallel(self):
f"{type(self.model)} does not support pipeline parallel. {tip}"
)

def attrsetter(attr: str) -> Callable[[object, object], None]:
    """Return a setter for a possibly dotted attribute path.

    The inverse of :func:`operator.attrgetter`: ``attrsetter("a.b")(obj, v)``
    performs ``obj.a.b = v``. A path with no dots sets the attribute
    directly on ``obj``.
    """
    parent_path, _, leaf = attr.rpartition(".")

    def setter(obj: object, value: object):
        # Resolve the parent object first when the path is nested;
        # an empty parent_path means the attribute lives on obj itself.
        if parent_path:
            target = attrgetter(parent_path)(obj)
        else:
            target = obj
        setattr(target, leaf, value)

    return setter

module_lists = []
module_list_idx = None
pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList):
# attrgetter in case the module is nested (e.g. "text_model.layers")
if isinstance(attrgetter(name)(self.model), nn.ModuleList):
module_lists.append(name)
module_list_idx = i

Expand All @@ -330,11 +351,11 @@ def pipeline_parallel(self):
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (
getattr(self.text_config, "tie_word_embeddings", False)
and self.pp_group.is_last_rank
self._get_tie_word_embeddings() and self.pp_group.is_last_rank
):
continue
setattr(self.model, name, PPMissingLayer())
# attrsetter in case the module is nested (e.g. "text_model.embed_tokens")
attrsetter(name)(self.model, PPMissingLayer())

# Module list
start_layer, end_layer = get_pp_indices(
Expand All @@ -343,7 +364,8 @@ def pipeline_parallel(self):
self.pp_group.world_size,
)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
# attrgetter in case the module is nested (e.g. "text_model.layers")
layers = attrgetter(layers_name)(self.model)
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
Expand All @@ -353,7 +375,8 @@ def pipeline_parallel(self):
for name in pp_plan[module_list_idx + 1 :]:
# Modules that should be on last rank
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
# attrsetter in case the module is nested (e.g. "text_model.norm")
attrsetter(name)(self.model, PPMissingLayer())

def recursive_replace(self):
"""Recursively replace modules in the model as needed.
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/transformers/causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):

# Tell `Base.load_weights` to skip
# `lm_head` if the model has tied word embeddings
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
tie_word_embeddings = self._get_tie_word_embeddings()
if tie_word_embeddings:
self.skip_prefixes.append("lm_head.")

Expand Down
Loading