Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set history_cross_kv_seqlens to 0 by default #2666

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions lmdeploy/pytorch/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ class SchedulerSequence:
mrope_position_ids: Optional[Tensor] = None
mrope_position_delta: Optional[int] = None
cross_attention_states: Optional[Tensor] = None
history_cross_kv_seqlens: Optional[int] = None
history_cross_kv_seqlens: int = 0

def __post_init__(self):
"""post init."""
Expand Down Expand Up @@ -489,12 +489,11 @@ def num_all_tokens(self):

def num_all_cross_tokens(self):
"""num of all cross tokens."""
if self.history_cross_kv_seqlens is None:
if self.cross_attention_states is None:
self.history_cross_kv_seqlens = 0
else:
self.history_cross_kv_seqlens = self.cross_attention_states.shape[ # noqa
-2]
if self.cross_attention_states is None:
self.history_cross_kv_seqlens = 0
else:
self.history_cross_kv_seqlens = self.cross_attention_states.shape[
-2]
return self.history_cross_kv_seqlens

def update_token_ids(self,
Expand Down
Loading