diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py
index aabb4aa27918..d32bfe6cabbd 100644
--- a/vllm/model_executor/models/transformers/base.py
+++ b/vllm/model_executor/models/transformers/base.py
@@ -16,8 +16,9 @@
 # limitations under the License.
 """Transformers modeling backend base class."""
 
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from itertools import chain
+from operator import attrgetter
 from typing import TYPE_CHECKING
 
 import regex as re
@@ -296,6 +297,15 @@ def _create_hf_to_vllm_mapper(self):
         # Apply mapping to quantization config if needed
         self._maybe_apply_model_mapping()
 
+    def _get_tie_word_embeddings(self):
+        """
+        Check if the model has tied word embeddings.
+        """
+        # Transformers v4 and v5 will store this in different places
+        tie_word_embeddings_v4 = getattr(self.text_config, "tie_word_embeddings", False)
+        tie_word_embeddings_v5 = getattr(self.config, "tie_word_embeddings", False)
+        return tie_word_embeddings_v4 or tie_word_embeddings_v5
+
     def pipeline_parallel(self):
         """
         Apply the model's pipeline parallelization plan.
@@ -311,11 +321,22 @@ def pipeline_parallel(self):
                 f"{type(self.model)} does not support pipeline parallel. {tip}"
             )
 
+        def attrsetter(attr: str) -> Callable[[object, object], None]:
+            """Set a possibly nested attribute, like the inverse of attrgetter."""
+            parent, _, name = attr.rpartition(".")
+
+            def setter(obj: object, value: object):
+                attr_parent = attrgetter(parent)(obj) if parent else obj
+                setattr(attr_parent, name, value)
+
+            return setter
+
         module_lists = []
         module_list_idx = None
         pp_plan = list(self.model._pp_plan.keys())
         for i, name in enumerate(pp_plan):
-            if isinstance(getattr(self.model, name), nn.ModuleList):
+            # attrgetter in case the module is nested (e.g. "text_model.layers")
+            if isinstance(attrgetter(name)(self.model), nn.ModuleList):
                 module_lists.append(name)
                 module_list_idx = i
 
@@ -330,11 +351,11 @@ def pipeline_parallel(self):
         # Layers before module list
         for name in pp_plan[:module_list_idx]:
             if self.pp_group.is_first_rank or (
-                getattr(self.text_config, "tie_word_embeddings", False)
-                and self.pp_group.is_last_rank
+                self._get_tie_word_embeddings() and self.pp_group.is_last_rank
             ):
                 continue
-            setattr(self.model, name, PPMissingLayer())
+            # attrsetter in case the module is nested (e.g. "text_model.embed_tokens")
+            attrsetter(name)(self.model, PPMissingLayer())
 
         # Module list
         start_layer, end_layer = get_pp_indices(
@@ -343,7 +364,8 @@ def pipeline_parallel(self):
             self.pp_group.world_size,
         )
         layers_name = pp_plan[module_list_idx]
-        layers = getattr(self.model, layers_name)
+        # attrgetter in case the module is nested (e.g. "text_model.layers")
+        layers = attrgetter(layers_name)(self.model)
         for i in range(len(layers)):
             if start_layer <= i and i < end_layer:
                 continue
@@ -353,7 +375,8 @@ def pipeline_parallel(self):
         for name in pp_plan[module_list_idx + 1 :]:
             # Modules that should be on last rank
             if not self.pp_group.is_last_rank:
-                setattr(self.model, name, PPMissingLayer())
+                # attrsetter in case the module is nested (e.g. "text_model.norm")
+                attrsetter(name)(self.model, PPMissingLayer())
 
     def recursive_replace(self):
         """Recursively replace modules in the model as needed.
diff --git a/vllm/model_executor/models/transformers/causal.py b/vllm/model_executor/models/transformers/causal.py
index d1efa6a11ee2..b6ceb2d67706 100644
--- a/vllm/model_executor/models/transformers/causal.py
+++ b/vllm/model_executor/models/transformers/causal.py
@@ -38,7 +38,7 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
 
         # Tell `Base.load_weights` to skip
         # `lm_head` if the model has tied word embeddings
-        tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
+        tie_word_embeddings = self._get_tie_word_embeddings()
         if tie_word_embeddings:
             self.skip_prefixes.append("lm_head.")