diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index be2daa2d8e..b05d7090b5 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -911,7 +911,7 @@ def make_expand(self, name, inputs, dtype, shape): self.make_node("Expand", inputs=inputs, outputs=[output], name=name) self.make_value(output, dtype, shape=shape) - def make_reduce_sum(self, name, inputs, dtype, shape, keepdims=True): + def make_reduce_sum(self, name, inputs, dtype, shape, keepdims=False): output = f"{name}/output_0" self.make_node("ReduceSum", inputs=inputs, outputs=[output], name=name, keepdims=keepdims) self.make_value(output, dtype, shape=shape) @@ -4172,10 +4172,10 @@ def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen # | # Cast to int32 # | - # ReduceSum + # ReduceSum (keepdims=0) # / \ # / \ - # Sub Squeeze + # Sub ReduceMax # | | # seqlens_k total_seq_len # (1D) (int) @@ -4187,20 +4187,20 @@ def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen ) reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = [f"{cast_1_name}/output_0", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size"]) # Left branch: Calculate seqlens_k = ReduceSum - 1 sub_name = f"{attn_mask_basename}/Sub" sub_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT32/[1]"] - self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT32, shape=["batch_size"]) - # Right branch: Squeeze to get int value for total_seq_len - squeeze_name = f"{attn_mask_basename}/Squeeze" - squeeze_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[0]"] - self.make_squeeze(squeeze_name, squeeze_inputs, dtype=ir.DataType.INT32, shape=[]) + # Right branch: ReduceMax to get maximum int value for total_seq_len + reduce_max_name = f"{attn_mask_basename}/ReduceMax" + reduce_max_inputs = [f"{reduce_sum_name}/output_0"] + self.make_reduce_max(reduce_max_name, reduce_max_inputs, dtype=ir.DataType.INT32, shape=[]) self.mask_attrs["seqlens_k"] = sub_name - self.mask_attrs["total_seq_len"] = squeeze_name + self.mask_attrs["total_seq_len"] = reduce_max_name def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename): # Make nodes for the attention mask subgraph that calculates @@ -4209,6 +4209,7 @@ def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename): # attention_mask # / \ # ReduceSum Shape + # (keepdims=0) | # | | # Sub Gather # | | @@ -4220,12 +4221,12 @@ def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename): # Left path reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"]) sub_name = f"{attn_mask_basename}/Sub" sub_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[1]"] - self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) + self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT64, shape=["batch_size"]) cast_1_name = f"{attn_mask_basename}/Sub/Cast" - self.make_cast(cast_1_name, f"{sub_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_cast(cast_1_name, f"{sub_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size"]) # Right path shape_name = f"{attn_mask_basename}/Shape" @@ -4257,6 +4258,7 @@ def make_attention_mask_reformatting_for_sparse_attn(self): # attention_mask # / \ # ReduceSum Shape + # (keepdims=0) | # | | # Cast to int32 Gather # | | @@ -4271,9 +4273,9 @@ def make_attention_mask_reformatting_for_sparse_attn(self): # Left path reduce_sum_name = f"{attn_mask_basename}/ReduceSum" reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"] - self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1]) + self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"]) cast_1_name = f"{attn_mask_basename}/ReduceSum/Cast" - self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1]) + self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size"]) # Right path shape_name = f"{attn_mask_basename}/Shape" diff --git a/src/python/py/models/builders/gptoss.py b/src/python/py/models/builders/gptoss.py index 71cf69b825..2812a1fcbe 100644 --- a/src/python/py/models/builders/gptoss.py +++ b/src/python/py/models/builders/gptoss.py @@ -519,7 +519,6 @@ def make_moe_decomposed(self, layer_id, mlp, root_input): reduce_sum_inputs, dtype=ir.DataType.FLOAT, shape=["batch_size", "sequence_length", self.intermediate_size, 1], - keepdims=False, ) weighted_sum_squeeze_name = f"{basename}/weighted_sum/Squeeze" weighted_sum_squeeze_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[-1]"]