diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index f60ea640359b..732b68fdf8a7 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -33,7 +33,7 @@
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
-    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "YiForCausalLM": ("yi", "YiForCausalLM"),
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 9d4424dd0890..d14326196828 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -62,20 +62,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class PhiEmbedding(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-
-        self.wte = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
-    def forward(self, input_ids: torch.LongTensor):
-        return self.wte(input_ids)
-
-
 class PhiAttention(nn.Module):
 
     def __init__(self,
@@ -93,27 +79,22 @@ def __init__(self,
                           tensor_model_parallel_world_size)
 
         # pylint: disable=C0103
-        self.Wqkv = QKVParallelLinear(
-            self.hidden_size,
-            self.head_size,
-            self.total_num_heads,
-            linear_method=linear_method,
-        )
         self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
+            self.hidden_size,
             self.head_size,
             self.total_num_heads,
-            bias=False,
+            bias=True,
             linear_method=linear_method,
         )
-        self.out_proj = RowParallelLinear(
+        self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             linear_method=linear_method,
         )
 
         scaling = self.head_size**-0.5
-        rotary_dim = config.rotary_dim
+        rotary_dim = int(config.partial_rotary_factor *
+                         (config.hidden_size // config.num_attention_heads))
         assert rotary_dim % 2 == 0
 
         # pylint: disable=C0301
@@ -136,12 +117,12 @@ def forward(
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.Wqkv(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
-        output, _ = self.out_proj(attn_output)
+        output, _ = self.dense(attn_output)
         return output
 
 
@@ -166,8 +147,7 @@ def __init__(self,
             linear_method=linear_method,
         )
         quant_config = getattr(linear_method, "quant_config", None)
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              n_inner)
+        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
@@ -182,9 +162,9 @@ def __init__(self,
                  config: PretrainedConfig,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.mixer = PhiAttention(config, linear_method)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+        self.self_attn = PhiAttention(config, linear_method)
         self.mlp = PhiMLP(config, linear_method)
 
     def forward(
@@ -195,8 +175,8 @@ def forward(
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         residual = hidden_states
-        hidden_states = self.ln(hidden_states)
-        attn_outputs = self.mixer(
+        hidden_states = self.input_layernorm(hidden_states)
+        attn_outputs = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
@@ -215,11 +195,14 @@ def __init__(self,
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.embd = PhiEmbedding(config)
-        self.h = nn.ModuleList([
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
             PhiLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -228,27 +211,19 @@ def forward(
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.embd(input_ids)
+        hidden_states = self.embed_tokens(input_ids)
         for i in range(self.config.num_hidden_layers):
-            layer = self.h[i]
+            layer = self.layers[i]
             hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
             )
-        return hidden_states
 
-
-class PhiCausalLMHead(nn.Module):
+        hidden_states = self.final_layernorm(hidden_states)
 
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.linear = ParallelLMHead(config.vocab_size,
-                                     config.hidden_size,
-                                     bias=True)
+        return hidden_states
 
 
 class PhiForCausalLM(nn.Module):
@@ -260,8 +235,11 @@ def __init__(self,
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.transformer = PhiModel(config, linear_method)
-        self.lm_head = PhiCausalLMHead(config)
+        self.model = PhiModel(config, linear_method)
+
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      bias=True)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -271,9 +249,9 @@ def forward(
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata)
-        hidden_states = self.lm_head.ln(hidden_states)
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
+
         return hidden_states
 
     def sample(
@@ -281,7 +259,7 @@ def sample(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        head = self.lm_head.linear
+        head = self.lm_head
         next_tokens = self.sampler(head.weight, hidden_states,
                                    sampling_metadata, head.bias)
         return next_tokens
@@ -291,17 +269,37 @@ def load_weights(self,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v")
+        ]
         params_dict = dict(self.named_parameters())
+
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                 continue
 
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            # pylint: disable=E1136
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # pylint: disable=E1136
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)