From d89a1e8320f2662915efd30024aef9b2e2b415ad Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 21 Nov 2025 19:33:16 +0800 Subject: [PATCH 1/5] Fix bug in Squeeze for getting the value of total_seq_len The Squeeze op is used for removing single-dimensional entries from the shape of a tensor. In this node the axes input is set to [0], which only eliminates the first axis and leads to an output shape of [1] if the batch_size is 1. This would cause a ShapeInference error https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph.cc#L111-L132 if it is not in strict mode. This PR just removes the axes input to ensure all the single dimensions are removed from the shape. --- src/python/py/models/builders/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index be2daa2d8e..57ea43c8eb 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -4196,7 +4196,7 @@ def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen # Right branch: Squeeze to get int value for total_seq_len squeeze_name = f"{attn_mask_basename}/Squeeze" - squeeze_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[0]"] + squeeze_inputs = [f"{reduce_sum_name}/output_0"] self.make_squeeze(squeeze_name, squeeze_inputs, dtype=ir.DataType.INT32, shape=[]) self.mask_attrs["seqlens_k"] = sub_name From 60caa6e1ee887679049cdad70bcc00d8ade70c00 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 24 Nov 2025 10:35:23 +0800 Subject: [PATCH 2/5] Address comments --- src/python/py/models/builders/base.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 57ea43c8eb..69c7c77ebb 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -4172,10 +4172,10 @@ def 
make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen # | # Cast to int32 # | - # ReduceSum + # ReduceSum (keepdims=0) # / \ # / \ - # Sub Squeeze + # Sub ReduceMax # | | # seqlens_k total_seq_len # (1D) (int) @@ -4187,20 +4187,20 @@ def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen ) reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = [f"{cast_1_name}/output_0", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size"], keepdims=False) # Left branch: Calculate seqlens_k = ReduceSum - 1 sub_name = f"{attn_mask_basename}/Sub" sub_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT32/[1]"] - self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT32, shape=["batch_size"]) - # Right branch: Squeeze to get int value for total_seq_len - squeeze_name = f"{attn_mask_basename}/Squeeze" - squeeze_inputs = [f"{reduce_sum_name}/output_0"] - self.make_squeeze(squeeze_name, squeeze_inputs, dtype=ir.DataType.INT32, shape=[]) + # Right branch: ReduceMax to get maximum int value for total_seq_len + reduce_max_name = f"{attn_mask_basename}/ReduceMax" + reduce_max_inputs = [f"{reduce_sum_name}/output_0"] + self.make_reduce_max(reduce_max_name, reduce_max_inputs, dtype=ir.DataType.INT32, shape=[]) self.mask_attrs["seqlens_k"] = sub_name - self.mask_attrs["total_seq_len"] = squeeze_name + self.mask_attrs["total_seq_len"] = reduce_max_name def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename): # Make nodes for the attention mask subgraph that calculates From 2d3838211c6eb01151f84a73b0d51e1cecb81bca Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 24 Nov 2025 10:59:42 +0800 Subject: [PATCH 3/5] update 
the ReduceSum in make_attention_mask_standard_reformatting_for_gqa --- src/python/py/models/builders/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 69c7c77ebb..f928baef0a 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -4209,6 +4209,7 @@ def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename): # attention_mask # / \ # ReduceSum Shape + # (keepdims=0) | # | | # Sub Gather # | | @@ -4220,12 +4221,12 @@ def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename): # Left path reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"], keepdims=False) sub_name = f"{attn_mask_basename}/Sub" sub_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[1]"] - self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) + self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT64, shape=["batch_size"]) cast_1_name = f"{attn_mask_basename}/Sub/Cast" - self.make_cast(cast_1_name, f"{sub_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_cast(cast_1_name, f"{sub_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size"]) # Right path shape_name = f"{attn_mask_basename}/Shape" From 435172cb14fd350c68cf6884a7f12e36fd20f946 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 24 Nov 2025 13:03:53 +0800 Subject: [PATCH 4/5] update ReduceSum for the attention mask subgraph as well --- src/python/py/models/builders/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/python/py/models/builders/base.py 
b/src/python/py/models/builders/base.py index f928baef0a..913df61815 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -4258,6 +4258,7 @@ def make_attention_mask_reformatting_for_sparse_attn(self): # attention_mask # / \ # ReduceSum Shape + # (keepdims=0) | # | | # Cast to int32 Gather # | | @@ -4272,9 +4273,9 @@ def make_attention_mask_reformatting_for_sparse_attn(self): # Left path reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"], keepdims=False) cast_1_name = f"{attn_mask_basename}/ReduceSum/Cast" - self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size"]) # Right path shape_name = f"{attn_mask_basename}/Shape" From 245a0e250d521d19865bfea5f44b628c8a730f8d Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 24 Nov 2025 13:13:11 +0800 Subject: [PATCH 5/5] Simplify make_reduce_sum by using keepdims=False by default --- src/python/py/models/builders/base.py | 8 ++++---- src/python/py/models/builders/gptoss.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 913df61815..b05d7090b5 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -911,7 +911,7 @@ def make_expand(self, name, inputs, dtype, shape): self.make_node("Expand", inputs=inputs, outputs=[output], name=name) self.make_value(output, dtype, shape=shape) - def make_reduce_sum(self, name, inputs, dtype, shape, keepdims=True): + def make_reduce_sum(self, name, inputs, dtype, shape, 
keepdims=False): output = f"{name}/output_0" self.make_node("ReduceSum", inputs=inputs, outputs=[output], name=name, keepdims=keepdims) self.make_value(output, dtype, shape=shape) @@ -4187,7 +4187,7 @@ def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen ) reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = [f"{cast_1_name}/output_0", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size"], keepdims=False) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size"]) # Left branch: Calculate seqlens_k = ReduceSum - 1 sub_name = f"{attn_mask_basename}/Sub" @@ -4221,7 +4221,7 @@ def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename): # Left path reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"], keepdims=False) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"]) sub_name = f"{attn_mask_basename}/Sub" sub_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[1]"] self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT64, shape=["batch_size"]) @@ -4273,7 +4273,7 @@ def make_attention_mask_reformatting_for_sparse_attn(self): # Left path reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"], keepdims=False) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"]) cast_1_name = f"{attn_mask_basename}/ReduceSum/Cast" self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size"]) 
diff --git a/src/python/py/models/builders/gptoss.py b/src/python/py/models/builders/gptoss.py index 71cf69b825..2812a1fcbe 100644 --- a/src/python/py/models/builders/gptoss.py +++ b/src/python/py/models/builders/gptoss.py @@ -519,7 +519,6 @@ def make_moe_decomposed(self, layer_id, mlp, root_input): reduce_sum_inputs, dtype=ir.DataType.FLOAT, shape=["batch_size", "sequence_length", self.intermediate_size, 1], - keepdims=False, ) weighted_sum_squeeze_name = f"{basename}/weighted_sum/Squeeze" weighted_sum_squeeze_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[-1]"]