Skip to content

Commit

Permalink
reset the ntk length range (#374)
Browse files Browse the repository at this point in the history
Co-authored-by: baishihao <[email protected]>
  • Loading branch information
shihaobai and baishihao authored Mar 25, 2024
1 parent 755c4fd commit f5dc783
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion lightllm/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _init_to_get_dynamic_ntk_rotary(self):
scaling_factor = 1.0
else:
scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
max_seq_len = 32 * max_position_embeddings # 64k
max_seq_len = self.max_seq_length
self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")

Expand Down

0 comments on commit f5dc783

Please sign in to comment.