
Commit

Fix bug when rope_scaling in config.json is None (#359)
shihaobai authored Mar 18, 2024
1 parent b445a6b commit 537c871
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions lightllm/models/llama/model.py
@@ -140,7 +140,10 @@ def _init_to_get_dynamic_ntk_rotary(self):
  partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
  max_position_embeddings = self.config.get("max_position_embeddings", 2048)
  base = self.config.get("rope_theta", 10000.0)
- scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+ if self.config.get("rope_scaling", {}) is None:
+     scaling_factor = 1.0
+ else:
+     scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
  max_seq_len = 32 * max_position_embeddings  # 64k
  self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
  self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
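
Why the guard is needed: dict.get(key, default) only falls back to the default when the key is absent. If config.json contains "rope_scaling": null, the loaded config holds an explicit None, so the old chained .get("factor", 1.0) raised AttributeError. A minimal standalone sketch of the failure mode and of the patched logic (the config dict below is a made-up example, not the real model config):

# Minimal illustration (hypothetical config dict, not the actual lightllm config).
config = {"rope_scaling": None}  # e.g. config.json contains "rope_scaling": null

# Old behavior: .get returns the stored None, not the {} default,
# so the chained .get fails with AttributeError.
try:
    scaling_factor = config.get("rope_scaling", {}).get("factor", 1.0)
except AttributeError as e:
    print("old code fails:", e)

# Patched behavior: check for None explicitly and fall back to 1.0.
if config.get("rope_scaling", {}) is None:
    scaling_factor = 1.0
else:
    scaling_factor = config.get("rope_scaling", {}).get("factor", 1.0)
print("patched scaling_factor:", scaling_factor)  # -> 1.0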
@@ -165,7 +168,10 @@ def _init_to_get_yarn_rotary(self):
  dim = self.head_dim_
  max_position_embeddings = self.config.get("max_position_embeddings", 2048)
  base = self.config.get("rope_theta", 10000.0)
- scale = self.config.get("rope_scaling", {}).get("factor", 1.0)
+ if self.config.get("rope_scaling", {}) is None:
+     scale = 1.0
+ else:
+     scale = self.config.get("rope_scaling", {}).get("factor", 1.0)
  original_max_position_embeddings = self.config.get("original_max_position_embeddings", 2048)
  extrapolation_factor = 1.0
  attn_factor = 1.0
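
The yarn rotary path gets the same guard. A more compact equivalent, not what this commit uses but a common Python idiom, would coalesce the explicit None before the chained lookup:

# Alternative sketch (not part of this commit): `or {}` turns an explicit
# None from config.json into an empty dict; an absent key behaves the same way.
config = {"rope_scaling": None}  # hypothetical config, as above
scale = (config.get("rope_scaling") or {}).get("factor", 1.0)
print(scale)  # -> 1.0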
