
Fix bug in Squeeze for getting the value of total_seq_len #1886

Merged
kunal-vaishnavi merged 5 commits into microsoft:main from Honry:fix-squeeze on Nov 26, 2025

@Honry Honry commented Nov 21, 2025

The Squeeze op removes single-dimensional entries from the shape of a tensor. In this node, the axes input is set to [0], which eliminates only the first axis and leaves an output shape of [1] when batch_size is 1. This causes a ShapeInference error at https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L111-L132 if strict mode is not enabled.

This PR fixes the issue by:

  • Changing the ReduceSum output shape to [batch_size] by adding keepdims=0 attribute
  • Using ReduceMax instead of Squeeze to get the value of total_seq_len as a scalar, which also covers scenarios where batch_size > 1
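
The shape reasoning behind these two changes can be sketched in plain Python. This is only an illustration, not the builder code: `reduce_shape` and `squeeze_shape` are hypothetical helpers that mimic how ONNX reductions and Squeeze transform a shape, and the mask shape `[batch_size, sequence_length]` with the reduction over axis 1 is an assumption inferred from the `[batch_size, 1]` output discussed in this PR.

```python
def reduce_shape(shape, axis, keepdims):
    """Output shape of an ONNX-style reduction over a single axis (hypothetical helper)."""
    axis %= len(shape)
    if keepdims:
        return shape[:axis] + [1] + shape[axis + 1:]
    return shape[:axis] + shape[axis + 1:]


def squeeze_shape(shape, axes):
    """Output shape of an ONNX-style Squeeze over the given axes (hypothetical helper)."""
    axes = {a % len(shape) for a in axes}
    for a in axes:
        if shape[a] != 1:
            raise ValueError(f"cannot squeeze axis {a} with size {shape[a]}")
    return [d for i, d in enumerate(shape) if i not in axes]


# Assumed attention_mask shape: [batch_size, sequence_length] with batch_size == 1.
mask_shape = [1, 8]

# Before: ReduceSum keeps the reduced axis, then Squeeze(axes=[0]) drops only axis 0.
old_sum = reduce_shape(mask_shape, axis=1, keepdims=True)
old_total = squeeze_shape(old_sum, axes=[0])

# After: ReduceSum(keepdims=0) yields [batch_size]; ReduceMax over it yields a scalar.
new_sum = reduce_shape(mask_shape, axis=1, keepdims=False)
new_total = reduce_shape(new_sum, axis=0, keepdims=False)

print(old_sum, old_total, new_sum, new_total)
```

With `batch_size > 1`, the old path would not even reach a wrong shape: `squeeze_shape([2, 1], axes=[0])` raises, because axis 0 no longer has size 1.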

The Squeeze op removes single-dimensional entries from the shape of a tensor. In this node, the axes input is set to [0], which eliminates only the first axis and leaves an output shape of [1] when batch_size is 1. This causes a ShapeInference error at https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L111-L132 if strict mode is not enabled.

This PR just removes the axes input so that all single-dimensional entries are removed from the shape.

Honry commented Nov 21, 2025

@qjia7, PTAL, thanks!

qjia7 previously approved these changes Nov 21, 2025

tianleiwu commented Nov 21, 2025

Since the ReduceSum output shape is [batch_size, 1], Squeeze might not be the right way to get total_seq_len unless we assume that batch_size == 1.

A better way is to make the ReduceSum output shape [batch_size] by adding the keepdims=0 attribute. Then use ReduceMax to get the maximum (assuming the longest sequence has no padding), or use Gather to get the first item if we assume that batch_size == 1 (that is, if we only enable graph capture for the single-batch scenario).

If we do not use keepdims=0, seqlens_k will have shape [batch_size, 1], while the expected shape is [batch_size]. That is also incorrect.
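
The difference between the ReduceMax and Gather options can be seen on a small example. This is a minimal sketch using plain Python lists in place of tensors; the mask values are made up for illustration, and the right-padded layout is an assumption.

```python
# Hypothetical right-padded attention mask for a batch of two sequences.
attention_mask = [
    [1, 1, 1, 0],  # 3 real tokens, 1 padding token
    [1, 1, 1, 1],  # 4 real tokens, no padding
]

# ReduceSum over the sequence axis with keepdims=0: one length per batch entry,
# i.e. a 1D result of shape [batch_size] rather than [batch_size, 1].
seq_lens = [sum(row) for row in attention_mask]

# ReduceMax: gives the total length for any batch size, provided the longest
# sequence in the batch has no padding.
total_seq_len_max = max(seq_lens)

# Gather(index 0): only matches the total length when batch_size == 1.
total_seq_len_first = seq_lens[0]

print(seq_lens, total_seq_len_max, total_seq_len_first)
```

Here the first sequence is padded, so picking the first item under-reports the total length, while the maximum matches the full mask width.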

@qjia7 qjia7 left a comment


> A better way is to let ReduceSum output shape [batch_size] by adding keepdims=0 attribute. Then use ReduceMax to get the maximum one (assumption is the longest sequence does not have padding), or use Gather to get first item if we assume that batch_size==1 (if we only enable graph capture for one batch scenario)

Great idea! I did assume that batch_size was one when adding that code. Your suggestion looks great. @tianleiwu, one more question: why expose a scalar total_seq_len when the batch size can be larger than one?

        #          attention_mask
        #               |
        #         Cast to int32
        #               |
        #    ReduceSum (keepdims=0)
        #              /    \
        #             /      \
        #           Sub    ReduceMax
        #            |        |
        #       seqlens_k  total_seq_len
        #         (1D)       (scalar)
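
The dataflow in this diagram can be mirrored in a few lines of plain Python. This is a sketch only, not the code in base.py: `gqa_lengths` is a hypothetical name, lists stand in for tensors, and the constant subtracted by the Sub node is assumed to be 1, since the diagram does not state it.

```python
def gqa_lengths(attention_mask):
    """Mirror the diagram: Cast -> ReduceSum(keepdims=0) -> {Sub, ReduceMax}."""
    # Cast to int32, emulated with int() because the mask may hold bools or ints.
    mask_i32 = [[int(v) for v in row] for row in attention_mask]
    # ReduceSum over the sequence axis with keepdims=0: shape [batch_size].
    row_sums = [sum(row) for row in mask_i32]
    # Sub branch: assumed to subtract the constant 1 (not stated in the diagram).
    seqlens_k = [s - 1 for s in row_sums]
    # ReduceMax branch over the batch axis: a single scalar.
    total_seq_len = max(row_sums)
    return seqlens_k, total_seq_len


# Made-up boolean mask for a batch of two sequences, the first one padded.
seqlens_k, total_seq_len = gqa_lengths([[True, True, False], [True, True, True]])
print(seqlens_k, total_seq_len)
```

Both outputs come from the same 1D ReduceSum result, which is why fixing `keepdims` also fixes the shape of `seqlens_k`.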

@qjia7 qjia7 self-requested a review November 22, 2025 03:31

Honry commented Nov 24, 2025

Thanks @tianleiwu @qjia7, I also had the concern about how to handle batch_size > 1.

@tianleiwu's suggestion is really a great idea! I've addressed it in the new commit, PTAL again, thanks!


qjia7 commented Nov 24, 2025

@Honry You may also need to update the ReduceSum in make_attention_mask_standard_reformatting_for_gqa in a similar way to get the correct 1D seqlens_k.


Honry commented Nov 24, 2025

> @Honry You may also need to update the ReduceSum in make_attention_mask_standard_reformatting_for_gqa in the similar way to get the correct 1D seqlens_k.

@qjia7 thanks! Done, pls. take another look.

@qjia7 qjia7 requested a review from tianleiwu November 24, 2025 03:06
@kunal-vaishnavi

When the original logic to obtain seqlens_k and total_seq_len from the attention mask for the GQA op was added, it was assumed that batch_size = 1 since most inference workloads with ORT GenAI are for batch size = 1.

Comment thread src/python/py/models/builders/base.py
@kunal-vaishnavi kunal-vaishnavi enabled auto-merge (squash) November 26, 2025 02:07
@kunal-vaishnavi kunal-vaishnavi merged commit 5492721 into microsoft:main Nov 26, 2025
15 checks passed
kunal-vaishnavi pushed a commit that referenced this pull request Dec 5, 2025