diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index 09ac99925..044ad63bd 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -144,7 +144,7 @@ def _init_to_get_dynamic_ntk_rotary(self):
         if self.config.get("rope_scaling", {}) is None:
             scaling_factor = 1.0
         else:
             scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
-        max_seq_len = 32 * max_position_embeddings # 64k
+        max_seq_len = self.max_seq_length
         self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
         self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
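For context, below is a minimal, self-contained sketch of what the `_cos_cached` / `_sin_cached` tables resized by this diff would contain, i.e. one row of rotary cos/sin values per absolute position, with the RoPE base rescaled past the trained window. The helper name `build_dynamic_ntk_cache` and its parameters are illustrative, not lightllm's actual code; the base-rescaling formula is the standard dynamic NTK scaling popularized by Hugging Face's dynamic rotary embedding.

```python
import torch

def build_dynamic_ntk_cache(max_seq_len, head_dim, base=10000.0,
                            max_position_embeddings=2048, scaling_factor=1.0):
    """Hypothetical helper: returns (cos, sin) tables of shape
    (max_seq_len, head_dim // 2), one row per absolute position."""
    half = head_dim // 2
    cos = torch.zeros((max_seq_len, half), dtype=torch.float16)
    sin = torch.zeros((max_seq_len, half), dtype=torch.float16)
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    for pos in range(max_seq_len):
        if pos < max_position_embeddings:
            cur_base = base
        else:
            # Past the trained window: grow the RoPE base so the rotary
            # frequencies stretch to cover the longer sequence
            # (dynamic NTK scaling).
            cur_base = base * (
                (scaling_factor * (pos + 1) / max_position_embeddings)
                - (scaling_factor - 1)
            ) ** (head_dim / (head_dim - 2))
        inv_freq = 1.0 / (cur_base ** exponents)
        angles = pos * inv_freq
        cos[pos] = torch.cos(angles).to(torch.float16)
        sin[pos] = torch.sin(angles).to(torch.float16)
    return cos, sin

# Sizing the cache from a configured limit, as the diff does with
# self.max_seq_length, avoids over-allocating GPU memory versus the
# hardcoded 32 * max_position_embeddings (64k rows for a 2k model).
cos_cached, sin_cached = build_dynamic_ntk_cache(max_seq_len=8192, head_dim=128)
```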