Skip to content

Commit

Permalink
fix bug for qwen without dynamic ntk (#303)
Browse files Browse the repository at this point in the history
Co-authored-by: baishihao <[email protected]>
  • Loading branch information
shihaobai and baishihao authored Jan 17, 2024
1 parent 9659c53 commit 8ecd84d
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions lightllm/models/qwen/infer_struct.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import numpy as np
from lightllm.common.basemodel import InferStateInfo
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

class QwenInferStateInfo(InferStateInfo):
class QwenInferStateInfo(LlamaInferStateInfo):
def __init__(self):
super().__init__()
self.position_cos = None
Expand All @@ -11,6 +11,11 @@ def __init__(self):
self.logn_values = None

def init_some_extra_state(self, model, input_ids : torch.Tensor):
use_dynamic_ntk = model.config.get("use_dynamic_ntk", False)
if not use_dynamic_ntk:
super().init_some_extra_state(model, input_ids)
return

if self.is_prefill:
b_start_loc_numpy = self.b_start_loc.cpu().numpy()
b_seq_len_numpy = self.b_seq_len.cpu().numpy()
Expand Down

0 comments on commit 8ecd84d

Please sign in to comment.