 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, VllmConfig)
+from vllm.config.utils import getattr_iter
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
@@ -486,10 +487,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         # Input embeddings
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
+            names = ("embedding_size", "hidden_size")
+            embedding_dim = getattr_iter(self.text_config, names, None)
+            assert embedding_dim is not None
             self.model.set_input_embeddings(
                 VocabParallelEmbedding(
                     self.text_config.vocab_size,
-                    self.text_config.hidden_size,
+                    embedding_dim=embedding_dim,
                     org_num_embeddings=self.text_config.vocab_size,
                     quant_config=self.quant_config,
                 ))
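For context on the helper used above: judging only from this call site, `getattr_iter` from `vllm.config.utils` appears to return the first of the given attribute names that exists on the config (preferring `embedding_size` over `hidden_size` here), falling back to the default otherwise. A minimal sketch of that assumed behaviour, not the actual vLLM implementation:

# Assumed behaviour of getattr_iter, inferred from the call site above;
# the real implementation lives in vllm.config.utils.
def getattr_iter(obj, names, default):
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default

# e.g. ALBERT-style configs expose a separate `embedding_size`, while most
# configs only have `hidden_size`, so the tuple ordering lets the narrower
# attribute win when it is present.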
@@ -645,7 +649,9 @@ def create_attention_instances(
                 attn_type=attn_type)
         return attention_instances
 
-    def init_parameters(self, module: nn.Module):
+    def init_parameters(self,
+                        module: nn.Module,
+                        dtype: Optional[torch.dtype] = None):
         """
         If a `parameter` is on the `meta` device, then its parent
         `module` is the original module created by:
@@ -659,11 +665,11 @@ def init_parameters(self, module: nn.Module):
             if param.device == torch.device("meta"):
                 new_param = nn.Parameter(
                     torch.empty_like(param.data,
-                                     dtype=self.model_config.dtype,
+                                     dtype=dtype or self.model_config.dtype,
                                      device=self.device_config.device))
                 setattr(module, name, new_param)
         for child in module.children():
-            self.init_parameters(child)
+            self.init_parameters(child, dtype)
 
     def forward(
         self,
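The new optional `dtype` argument lets a caller materialize meta-device parameters in a dtype other than `model_config.dtype`, with the override propagated through the recursion. A self-contained sketch of that pattern (a simplified stand-in, not the vLLM class itself):

from typing import Optional

import torch
import torch.nn as nn

def materialize_meta_params(module: nn.Module,
                            default_dtype: torch.dtype,
                            dtype: Optional[torch.dtype] = None) -> None:
    # Replace meta-device parameters with real (uninitialized) storage,
    # preferring the explicit dtype override when one is given.
    for name, param in module.named_parameters(recurse=False):
        if param.device == torch.device("meta"):
            new_param = nn.Parameter(
                torch.empty_like(param.data,
                                 dtype=dtype or default_dtype,
                                 device="cpu"))
            setattr(module, name, new_param)
    for child in module.children():
        materialize_meta_params(child, default_dtype, dtype)

# Parameters created on the meta device get float16 storage by default;
# passing dtype=torch.float32 would override that for the whole subtree.
with torch.device("meta"):
    layer = nn.Linear(8, 8)
materialize_meta_params(layer, default_dtype=torch.float16)
assert layer.weight.dtype == torch.float16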
@@ -712,73 +718,6 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
-@support_torch_compile(enable_if=can_enable_torch_compile)
-class TransformersModel(TransformersBase):
-    hf_to_vllm_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            # Handle BERT-like models
-            "bert": "model",
-            # Add `model.` prefix for base model checkpoints
-            "": "model.",
-            # Remove `model.` prefix if it was already there
-            "model.model.": "model.",
-            # Pooling adapters will be adjacent to `model`
-            "model.pooler": "pooler",
-            "model.score": "score",
-            # Classifier adapter's classifier layer is renamed to score
-            "model.classifier": "score",
-        },
-        orig_to_new_suffix={
-            # Replace legacy suffixes used for norms
-            ".gamma": ".weight",
-            ".beta": ".bias",
-        })
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-
-        # After creating a pooling model, `pooler` will be duplicated.
-        # The one inside `model` comes from the Transformers modelling code.
-        # The one after `model` is an adapter from vLLM.
-        # We want to use the adapter so we nullify the original pooler.
-        if getattr(self.model, "pooler", None) is not None:
-            self.skip_prefixes.append("pooler.")
-            self.model.pooler = torch.nn.Identity()
-
-        # Some encoder models have the position_ids buffer in the checkpoint.
-        # vLLM will always pass position_ids as an argument, so we skip loading
-        # the buffer if it exists
-        self.skip_substrs.append("position_ids")
-
-        # Some encoder models have the bias of the final classifier layer
-        # in the checkpoint. vLLM does not use this bias, so we skip loading
-        # it if it exists
-        self.skip_substrs.append("score.bias")
-
-    def create_attention_instances(
-            self, attn_type: AttentionType = AttentionType.DECODER):
-        # TODO(hmellor): Better way to detect encoder models
-        # In encoder models, the attention layers will have `is_causal=False`
-        is_encoder = lambda m: not getattr(m, "is_causal", True)
-        # vLLM does not support encoder-decoder models, so if any encoder layer
-        # is found, we assume the whole model is an encoder model
-        if any(is_encoder(m) for m in self.model.modules()):
-            attn_type = AttentionType.ENCODER_ONLY
-
-        # Check minimum transformers version for encoder models support
-        if attn_type == AttentionType.ENCODER_ONLY:
-            import transformers
-            from packaging.version import Version
-            installed = Version(transformers.__version__)
-            required = Version("4.57.0.dev0")
-            if installed < required:
-                raise ValueError(
-                    "Encoder models with the Transformers backend require "
-                    f"transformers>={required}, but got {installed}")
-
-        return super().create_attention_instances(attn_type)
-
-
 @support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersForCausalLM(TransformersBase):
 