From ad3dcdc3a676ca6c797e2326825448fed2b8be8c Mon Sep 17 00:00:00 2001
From: Aakash Kaushik
Date: Tue, 9 Jan 2024 00:41:45 +0530
Subject: [PATCH 1/5] fix: update layer_norm_epsilon in phi_1_5 to layer_norm_eps

---
 vllm/model_executor/models/phi_1_5.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py
index 9d4424dd0890..66707ddee52a 100644
--- a/vllm/model_executor/models/phi_1_5.py
+++ b/vllm/model_executor/models/phi_1_5.py
@@ -183,7 +183,7 @@ def __init__(self,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
+                               eps=config.layer_norm_eps)
         self.mixer = PhiAttention(config, linear_method)
         self.mlp = PhiMLP(config, linear_method)
 
@@ -245,7 +245,7 @@ class PhiCausalLMHead(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
+                               eps=config.layer_norm_eps)
         self.linear = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)

From cd9a6ef5be50330058e94432ba8a9125696ed245 Mon Sep 17 00:00:00 2001
From: Aakash Kaushik
Date: Tue, 9 Jan 2024 03:14:32 +0530
Subject: [PATCH 2/5] fix(phi-1_5): layer_norm_eps and layer_norm_epsilon checking.

---
 vllm/model_executor/models/phi_1_5.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py
index 66707ddee52a..3ab0f7baa1d8 100644
--- a/vllm/model_executor/models/phi_1_5.py
+++ b/vllm/model_executor/models/phi_1_5.py
@@ -61,6 +61,11 @@
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
+# check for layer_norm_eps in case of phi-1.5 and layer_norm_epsilon for phi-2
+layer_norm_eps = getattr(PretrainedConfig, "layer_norm_eps", None)
+if layer_norm_eps is None:
+    layer_norm_eps = PretrainedConfig.layer_norm_epsilon
+
 
 class PhiEmbedding(nn.Module):
 
@@ -183,7 +188,7 @@ def __init__(self,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_eps)
+                               eps=layer_norm_eps)
         self.mixer = PhiAttention(config, linear_method)
         self.mlp = PhiMLP(config, linear_method)
 
@@ -245,7 +250,7 @@ class PhiCausalLMHead(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_eps)
+                               eps=layer_norm_eps)
         self.linear = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)

From 2803c81fa0981772695051054055718a7a02fc4d Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Fri, 12 Jan 2024 00:19:03 +0000
Subject: [PATCH 3/5] fix config variable names

---
 vllm/model_executor/models/phi_1_5.py | 30 ++++++++++++++++++++----------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py
index 3ab0f7baa1d8..4e4a5b046fea 100644
--- a/vllm/model_executor/models/phi_1_5.py
+++ b/vllm/model_executor/models/phi_1_5.py
@@ -61,11 +61,6 @@
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
-# check for layer_norm_eps in case of phi-1.5 and layer_norm_epsilon for phi-2
-layer_norm_eps = getattr(PretrainedConfig, "layer_norm_eps", None)
-if layer_norm_eps is None:
-    layer_norm_eps = PretrainedConfig.layer_norm_epsilon
-
 
 class PhiEmbedding(nn.Module):
 
@@ -118,13 +113,17 @@ def __init__(self,
         )
 
         scaling = self.head_size**-0.5
-        rotary_dim = config.rotary_dim
+
+        # https://huggingface.co/microsoft/phi-2/blob/cb2f4533604d8b67de604e7df03bfe6f3ca22869/modeling_phi.py#L278-L281
+        assert config.rope_scaling is None, "rope_scaling is not supported"
+        head_dim = config.hidden_size // config.num_attention_heads
+        rotary_dim = int(config.partial_rotary_factor * head_dim)
         assert rotary_dim % 2 == 0
 
         # pylint: disable=C0301
         # Refer to:
         # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
-        rope_theta = 10000
+        rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "n_positions", 2048)
         self.rotary_emb = get_rope(
             self.head_size,
@@ -171,7 +170,7 @@ def __init__(self,
             linear_method=linear_method,
         )
         quant_config = getattr(linear_method, "quant_config", None)
-        self.act = get_act_fn(config.activation_function, quant_config,
+        self.act = get_act_fn(config.hidden_act, quant_config,
                               n_inner)
 
     def forward(self, hidden_states):
@@ -180,6 +179,12 @@ def forward(self, hidden_states):
         hidden_states, _ = self.fc2(hidden_states)
         return hidden_states
 
+def _get_layer_norm_eps(config: PretrainedConfig) -> float:
+    # check for layer_norm_eps in case of phi-1.5 and layer_norm_epsilon for phi-2
+    layer_norm_eps = getattr(config, "layer_norm_eps", None)
+    if layer_norm_eps is None:
+        layer_norm_eps = getattr(config, "layer_norm_epsilon", 1e-6)
+    return layer_norm_eps
 
 class PhiLayer(nn.Module):
 
@@ -188,7 +193,7 @@ def __init__(self,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=layer_norm_eps)
+                               eps=_get_layer_norm_eps(config))
         self.mixer = PhiAttention(config, linear_method)
         self.mlp = PhiMLP(config, linear_method)
 
@@ -250,7 +255,7 @@ class PhiCausalLMHead(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=layer_norm_eps)
+                               eps=_get_layer_norm_eps(config))
         self.linear = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)
@@ -299,6 +304,9 @@ def load_weights(self,
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            print(name)
+            continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -306,7 +314,9 @@ def load_weights(self,
             if name.endswith(".bias") and name not in params_dict:
                 continue
             # pylint: disable=E1136
+            print(params_dict.keys())
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            1/0
\ No newline at end of file

From 0b643db7f6a825916946f5c8c37685c3e68028a5 Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Fri, 12 Jan 2024 00:19:24 +0000
Subject: [PATCH 4/5] rename phi

---
 vllm/model_executor/models/{phi_1_5.py => phi.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename vllm/model_executor/models/{phi_1_5.py => phi.py} (100%)

diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi.py
similarity index 100%
rename from vllm/model_executor/models/phi_1_5.py
rename to vllm/model_executor/models/phi.py

From 6d73b8244de3aea1f16a734875062b76fc26f768 Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Fri, 12 Jan 2024 00:24:43 +0000
Subject: [PATCH 5/5] reset debugging

---
 vllm/model_executor/models/phi.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 4e4a5b046fea..6f3cd2cde9c8 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -304,9 +304,6 @@ def load_weights(self,
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
-            print(name)
-            continue
-
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -314,9 +311,7 @@ def load_weights(self,
             if name.endswith(".bias") and name not in params_dict:
                 continue
             # pylint: disable=E1136
-            print(params_dict.keys())
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
-            weight_loader(param, loaded_weight)
-            1/0
\ No newline at end of file
+            weight_loader(param, loaded_weight)
\ No newline at end of file
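
Note (illustrative sketch, not part of the patches above): the config handling the series converges on mirrors the _get_layer_norm_eps helper added in PATCH 3/5 and can be exercised on its own. The snippet below assumes a transformers-style PretrainedConfig and uses hypothetical example values; phi-1.5-era configs expose layer_norm_eps, phi-2-era configs expose layer_norm_epsilon, and the helper falls back accordingly.

    from transformers import PretrainedConfig

    def _get_layer_norm_eps(config: PretrainedConfig) -> float:
        # phi-1.5-style configs use `layer_norm_eps`; phi-2-style configs
        # use `layer_norm_epsilon`; fall back to a conservative default.
        eps = getattr(config, "layer_norm_eps", None)
        if eps is None:
            eps = getattr(config, "layer_norm_epsilon", 1e-6)
        return eps

    # Hypothetical configs, for illustration only.
    phi_1_5_like = PretrainedConfig(layer_norm_eps=1e-5)
    phi_2_like = PretrainedConfig(layer_norm_epsilon=1e-5)

    assert _get_layer_norm_eps(phi_1_5_like) == 1e-5
    assert _get_layer_norm_eps(phi_2_like) == 1e-5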