6 changes: 3 additions & 3 deletions src/transformers/generation/continuous_batching.py
@@ -801,8 +801,8 @@ def get_model_kwargs(self) -> PagedAttentionArgs:
    "input_ids": self.input_ids,
    "position_ids": self.position_ids,
    "attention_mask": self.attention_mask,
-   "cumulative_seqlens_q": self.cumulative_seqlens_q,
-   "cumulative_seqlens_k": self.cumulative_seqlens_k,
+   "cu_seq_lens_q": self.cumulative_seqlens_q,
+   "cu_seq_lens_k": self.cumulative_seqlens_k,
    "write_index": self.write_index,
    "read_index": self.read_index,
    "logits_indices": self.logits_indices,
@@ -1238,7 +1238,7 @@ def _process_logit(self, batch_data, logits):
    # Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner!
    if hasattr(self.logit_processor, "set_continuous_batching_context"):
        self.logit_processor.set_continuous_batching_context(
-           batch_data["logits_indices"], batch_data["cumulative_seqlens_q"]
+           batch_data["logits_indices"], batch_data["cu_seq_lens_q"]
        )
    return self.logit_processor(batch_data["input_ids"], logits)
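
For context, the hook above is duck-typed: the batching loop only calls `set_continuous_batching_context` when the processor defines it, passing the batch's `logits_indices` and `cu_seq_lens_q`. A minimal sketch of a custom processor opting into the hook (the class name and the trivial slicing it performs are illustrative only, not part of this PR):

```python
import torch
from transformers import LogitsProcessor


class BatchAwareProcessor(LogitsProcessor):
    """Illustrative processor that consumes the continuous-batching context."""

    def __init__(self):
        self.logits_indices = None
        self.cu_seq_lens_q = None

    def set_continuous_batching_context(self, logits_indices: torch.Tensor, cu_seq_lens_q: torch.Tensor):
        # Called by _process_logit (see diff above) before __call__.
        self.logits_indices = logits_indices
        self.cu_seq_lens_q = cu_seq_lens_q

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if scores.dim() == 3 and self.logits_indices is not None:
            # Under continuous batching, scores are packed as [1, total_tokens, vocab];
            # keep only the last position of each packed sequence.
            scores = scores[0, self.logits_indices, :]
        return scores
```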

10 changes: 5 additions & 5 deletions src/transformers/generation/logits_process.py
@@ -356,26 +356,26 @@ def __init__(self, penalty: float, prompt_ignore_length: Optional[int] = None):
        self.penalty = penalty
        self.prompt_ignore_length = prompt_ignore_length
        self.logits_indices = None
-       self.cumulative_seqlens_q = None
+       self.cu_seq_lens_q = None

-   def set_continuous_batching_context(self, logits_indices: torch.Tensor, cumulative_seqlens_q: torch.Tensor):
+   def set_continuous_batching_context(self, logits_indices: torch.Tensor, cu_seq_lens_q: torch.Tensor):
        self.logits_indices = logits_indices
-       self.cumulative_seqlens_q = cumulative_seqlens_q
+       self.cu_seq_lens_q = cu_seq_lens_q

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.prompt_ignore_length:
            input_ids = input_ids[:, self.prompt_ignore_length :]

        if scores.dim() == 3:
-           if self.logits_indices is not None and self.cumulative_seqlens_q is not None:
+           if self.logits_indices is not None and self.cu_seq_lens_q is not None:
                batch_size, seq_len, vocab_size = scores.shape
                last_positions = self.logits_indices
                last_scores = scores[0, last_positions, :]

                # Prepare token mask
                token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
-               cu_seq_lens = self.cumulative_seqlens_q
+               cu_seq_lens = self.cu_seq_lens_q
                lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
                seq_indices = torch.repeat_interleave(torch.arange(len(lengths), device=input_ids.device), lengths)
                token_mask[seq_indices, input_ids] = True
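
To make the indexing above concrete, here is a small standalone sketch of the cumulative-length bookkeeping (the tensor values are made up for illustration):

```python
import torch

# Three packed sequences of (hypothetical) lengths 3, 5 and 2.
cu_seq_lens_q = torch.tensor([0, 3, 8, 10], dtype=torch.int32)

lengths = (cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]).long()  # tensor([3, 5, 2])
seq_indices = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
print(seq_indices)  # tensor([0, 0, 0, 1, 1, 1, 1, 1, 2, 2])
# Every packed token position is mapped back to the sequence it belongs to,
# which is what the token_mask[seq_indices, input_ids] scatter above relies on.
```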
14 changes: 7 additions & 7 deletions src/transformers/integrations/flash_paged.py
@@ -18,8 +18,8 @@ def paged_attention_forward(
    v: torch.Tensor,
    attention_mask: torch.Tensor = None,
    cache: PagedAttentionCache = None,
-   cumulative_seqlens_q=None,
-   cumulative_seqlens_k=None,
+   cu_seq_lens_q=None,
+   cu_seq_lens_k=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    block_tables=None,
@@ -35,9 +35,9 @@ def paged_attention_forward(
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full k
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full v
-       cumulative_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+       cu_seq_lens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
Collaborator commented on lines +38 to 39:
    My main issue with this naming is that it is not helpful for newbies, "cu" does not mean anything!

-       cumulative_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+       cu_seq_lens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
@@ -48,7 +48,7 @@
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
    """
-   k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
+   k, v = cache.update(k, v, module.layer_idx, **kwargs)

    sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
    if implementation is not None:
@Kurt232 (Contributor Author) commented on Aug 22, 2025:
    I think paged_attention_forward should use cu_seq_lens_q/k instead of cumulative_seqlens_q/k to stay consistent with flash_attn_varlen_func:
    https://github.com/huggingface/transformers/blob/29ddcacea3ad9d3cdf6c5d8e51d1d39cbc5e7dfa/src/transformers/modeling_flash_attention_utils.py#L557C1-L578C3

@@ -58,8 +58,8 @@
        q.transpose(1, 2).squeeze(0).contiguous(),
        k.transpose(1, 2).squeeze(0).contiguous(),
        v.transpose(1, 2).squeeze(0).contiguous(),
-       cumulative_seqlens_q.to(torch.int32),
-       cumulative_seqlens_k.to(torch.int32).clone(),
+       cu_seq_lens_q.to(torch.int32),
+       cu_seq_lens_k.to(torch.int32).clone(),
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale=module.scaling,
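
For reference, the `cu_seq_lens_q`/`cu_seq_lens_k` tensors documented above are simply a zero-prefixed cumulative sum of the per-sequence lengths, matching the layout `flash_attn_varlen_func` expects. A minimal sketch with made-up lengths:

```python
import torch
import torch.nn.functional as F

# Hypothetical per-sequence query lengths for a packed batch of three requests.
seq_lens = torch.tensor([3, 5, 2])

# (batch_size + 1,) cumulative lengths, int32, starting at 0.
cu_seq_lens_q = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32)
print(cu_seq_lens_q)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

# Token span of sequence i inside the packed (total_q, nheads, headdim) tensor:
i = 1
start, end = int(cu_seq_lens_q[i]), int(cu_seq_lens_q[i + 1])
print(start, end)  # 3 8
```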
8 changes: 4 additions & 4 deletions src/transformers/modeling_flash_attention_utils.py
@@ -462,18 +462,18 @@ class FlashAttentionKwargs(TypedDict, total=False):
    Keyword arguments for Flash Attention with Compile.

    Attributes:
-       cumulative_seqlens_q (`torch.LongTensor`, *optional*)
+       cu_seq_lens_q (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for query state.
-       cumulative_seqlens_k (`torch.LongTensor`, *optional*)
+       cu_seq_lens_k (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for key state.
        max_length_q (`int`, *optional*):
            Maximum sequence length for query state.
        max_length_k (`int`, *optional*):
            Maximum sequence length for key state.
    """

-   cumulative_seqlens_q: Optional[torch.LongTensor]
-   cumulative_seqlens_k: Optional[torch.LongTensor]
+   cu_seq_lens_q: Optional[torch.LongTensor]
+   cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]
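
Since `FlashAttentionKwargs` is a non-total TypedDict, callers pass any subset of these keys as plain keyword arguments. A hedged sketch of how the renamed fields might be populated for a packed batch (the values and the final forward call are illustrative):

```python
import torch
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs

# Illustrative values for two packed sequences of lengths 4 and 6.
flash_kwargs: FlashAttentionKwargs = {
    "cu_seq_lens_q": torch.tensor([0, 4, 10]),  # renamed from cumulative_seqlens_q
    "cu_seq_lens_k": torch.tensor([0, 4, 10]),  # renamed from cumulative_seqlens_k
    "max_length_q": 6,
    "max_length_k": 6,
}
# These keys travel through **kwargs down to the attention implementation,
# e.g. model(**inputs, **flash_kwargs) in a packed / varlen setup.
```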

8 changes: 4 additions & 4 deletions src/transformers/utils/generic.py
@@ -866,9 +866,9 @@ class TransformersKwargs(TypedDict, total=False):
            Turn this on to return the intermediary attention scores.
        output_router_logits (`Optional[bool]`, *optional*):
            For MoE models, this allows returning the router logits to compute the loss.
-       cumulative_seqlens_q (`torch.LongTensor`, *optional*)
+       cu_seq_lens_q (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for query state.
-       cumulative_seqlens_k (`torch.LongTensor`, *optional*)
+       cu_seq_lens_k (`torch.LongTensor`, *optional*)
            Gets cumulative sequence length for key state.
        max_length_q (`int`, *optional*):
            Maximum sequence length for query state.
@@ -880,8 +880,8 @@ class TransformersKwargs(TypedDict, total=False):
    output_hidden_states: Optional[bool]
    output_attentions: Optional[bool]
    output_router_logits: Optional[bool]
-   cumulative_seqlens_q: Optional["torch.LongTensor"]
-   cumulative_seqlens_k: Optional["torch.LongTensor"]
+   cu_seq_lens_q: Optional["torch.LongTensor"]
+   cu_seq_lens_k: Optional["torch.LongTensor"]
    max_length_q: Optional[int]
    max_length_k: Optional[int]
