diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 9d4424dd0890..6f3cd2cde9c8 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -113,13 +113,17 @@ def __init__(self,
         )
 
         scaling = self.head_size**-0.5
-        rotary_dim = config.rotary_dim
+
+        # https://huggingface.co/microsoft/phi-2/blob/cb2f4533604d8b67de604e7df03bfe6f3ca22869/modeling_phi.py#L278-L281
+        assert config.rope_scaling is None, "rope_scaling is not supported"
+        head_dim = config.hidden_size // config.num_attention_heads
+        rotary_dim = int(config.partial_rotary_factor * head_dim)
         assert rotary_dim % 2 == 0
 
         # pylint: disable=C0301
         # Refer to:
         # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
-        rope_theta = 10000
+        rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "n_positions", 2048)
         self.rotary_emb = get_rope(
             self.head_size,
@@ -166,7 +170,7 @@ def __init__(self,
             linear_method=linear_method,
         )
         quant_config = getattr(linear_method, "quant_config", None)
-        self.act = get_act_fn(config.activation_function, quant_config,
+        self.act = get_act_fn(config.hidden_act, quant_config,
                               n_inner)
 
     def forward(self, hidden_states):
@@ -175,6 +179,12 @@ def forward(self, hidden_states):
         hidden_states, _ = self.fc2(hidden_states)
         return hidden_states
 
+def _get_layer_norm_eps(config: PretrainedConfig) -> float:
+    # check for layer_norm_eps in case of phi-1.5 and layer_norm_epsilon for phi-2
+    layer_norm_eps = getattr(config, "layer_norm_eps", None)
+    if layer_norm_eps is None:
+        layer_norm_eps = getattr(config, "layer_norm_epsilon", 1e-6)
+    return layer_norm_eps
 
 class PhiLayer(nn.Module):
 
@@ -183,7 +193,7 @@ def __init__(self,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
+                               eps=_get_layer_norm_eps(config))
         self.mixer = PhiAttention(config, linear_method)
         self.mlp = PhiMLP(config, linear_method)
 
@@ -245,7 +255,7 @@ class PhiCausalLMHead(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
+                               eps=_get_layer_norm_eps(config))
         self.linear = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)
@@ -304,4 +314,4 @@ def load_weights(self,
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
-                weight_loader(param, loaded_weight)
+                weight_loader(param, loaded_weight)
\ No newline at end of file
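
For reference, a minimal sketch (not part of the patch) of how the new rotary_dim derivation behaves. The helper name and the SimpleNamespace stand-in for PretrainedConfig are illustrative only, and the config values below are assumed from the public phi-2 checkpoint (hidden_size 2560, 32 attention heads, partial_rotary_factor 0.4).

# Illustration only: mirrors the patched PhiAttention.__init__ logic,
# deriving the rotary dimension from partial_rotary_factor instead of
# reading config.rotary_dim (which phi-2 configs do not provide).
from types import SimpleNamespace

def rotary_dim_from_config(config) -> int:
    assert getattr(config, "rope_scaling", None) is None, \
        "rope_scaling is not supported"
    head_dim = config.hidden_size // config.num_attention_heads
    rotary_dim = int(config.partial_rotary_factor * head_dim)
    assert rotary_dim % 2 == 0
    return rotary_dim

# Assumed phi-2-like values: head_dim = 2560 // 32 = 80, rotary_dim = int(0.4 * 80) = 32.
phi2_like = SimpleNamespace(hidden_size=2560,
                            num_attention_heads=32,
                            partial_rotary_factor=0.4,
                            rope_scaling=None)
print(rotary_dim_from_config(phi2_like))  # 32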