Fix bug in Squeeze for getting the value of total_seq_len #1886
The Squeeze op removes single-dimensional entries from the shape of a tensor. In this node, axes is set to [0], which eliminates only the first axis and leaves an output shape of [1] when batch_size is 1. This causes a ShapeInference error https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L111-L132 if it is not in strict mode. This PR just removes the axes input so that all single dimensions are removed from the shape.
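For readers following along, the shape behaviour described above can be reproduced with a rough NumPy analogue. This is only a sketch: `squeeze_shapes` is a hypothetical helper, and `np.squeeze` is assumed to mirror ONNX Squeeze closely enough for illustrating shapes (it is not the ONNX operator itself).

```python
import numpy as np

def squeeze_shapes(batch_size: int):
    # Stand-in for a ReduceSum output that keeps the reduced axis: [batch_size, 1].
    reduced = np.ones((batch_size, 1), dtype=np.int32)
    # Analogue of Squeeze with axes=[0]: only axis 0 is removed, and it must have size 1.
    axis0_only = np.squeeze(reduced, axis=0).shape
    # Analogue of Squeeze without axes: every size-1 dimension is removed.
    all_axes = np.squeeze(reduced).shape
    return axis0_only, all_axes

# With batch_size == 1, axes=[0] leaves shape (1,), while no axes gives a scalar shape ().
print(squeeze_shapes(1))
```

With `batch_size > 1`, the `axis=0` call raises `ValueError` in NumPy because axis 0 is no longer of size 1, which is one reason Squeeze alone does not generalize beyond a single batch.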
@qjia7, PTAL, thanks!
Since the ReduceSum output shape is [batch_size, 1], Squeeze might not be the right way to get total_seq_len unless we assume that batch_size==1. A better way is to let ReduceSum output shape [batch_size] by adding If we do not use
qjia7
left a comment
A better way is to let ReduceSum output shape [batch_size] by adding the keepdims=0 attribute. Then use ReduceMax to get the maximum one (the assumption is that the longest sequence does not have padding), or use Gather to get the first item if we assume that batch_size==1 (if we only enable graph capture for the one-batch scenario).
Great idea! I did assume that batch_size was one when adding that code. Your suggestion looks great. @tianleiwu one more question: why expose a scalar total_seq_len since the batch can be larger than one?
#          attention_mask
#                |
#          Cast to int32
#                |
#       ReduceSum (keepdims=0)
#           /         \
#          /           \
#        Sub        ReduceMax
#         |             |
#     seqlens_k    total_seq_len
#       (1D)         (scalar)
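The subgraph above can be sketched in NumPy to check the resulting shapes. This is an illustration under assumptions, not the builder code: `seqlens_from_mask` is a hypothetical helper, the reduction is assumed to run over axis 1 of a [batch_size, sequence_length] mask, and Sub is assumed to subtract the constant 1.

```python
import numpy as np

def seqlens_from_mask(attention_mask):
    # Cast to int32
    mask = np.asarray(attention_mask).astype(np.int32)
    # ReduceSum with keepdims=0 (assumed axis 1): shape [batch_size]
    per_batch = mask.sum(axis=1, keepdims=False)
    # Sub (assumed constant 1): 1D seqlens_k
    seqlens_k = per_batch - 1
    # ReduceMax over the batch: scalar total_seq_len
    total_seq_len = per_batch.max()
    return seqlens_k, total_seq_len

# Two sequences; the first one carries one padding token.
seqlens_k, total_seq_len = seqlens_from_mask([[1, 1, 1, 0], [1, 1, 1, 1]])
print(seqlens_k.shape, np.ndim(total_seq_len))
```

As noted in the comment above, taking the maximum gives the intended total_seq_len only if the longest sequence in the batch has no padding.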
Thanks @tianleiwu @qjia7, I also had the concern about how to handle @tianleiwu's suggestion is really a great idea! I've addressed it in the new commit, PTAL again, thanks!
@Honry You may also need to update the ReduceSum in make_attention_mask_standard_reformatting_for_gqa in a similar way to get the correct 1D seqlens_k.
When the original logic to obtain
The `Squeeze` op removes single-dimensional entries from the shape of a tensor. In this node the `axes` input is set to `[0]`, which eliminates only the first axis and leaves an output shape of `[1]` when `batch_size` is 1. This causes a ShapeInference error at https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L111-L132 if it is not in strict mode.

This PR fixes the issue by:
- Changing the ReduceSum output shape to [batch_size] by adding the keepdims=0 attribute
- Using ReduceMax instead of Squeeze to get the value of total_seq_len as a scalar, which also covers scenarios where batch_size > 1