Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions src/python/py/models/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ def make_expand(self, name, inputs, dtype, shape):
self.make_node("Expand", inputs=inputs, outputs=[output], name=name)
self.make_value(output, dtype, shape=shape)

def make_reduce_sum(self, name, inputs, dtype, shape, keepdims=True):
def make_reduce_sum(self, name, inputs, dtype, shape, keepdims=False):
output = f"{name}/output_0"
self.make_node("ReduceSum", inputs=inputs, outputs=[output], name=name, keepdims=keepdims)
self.make_value(output, dtype, shape=shape)
Expand Down Expand Up @@ -4172,10 +4172,10 @@ def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen
# |
# Cast to int32
# |
# ReduceSum
# ReduceSum (keepdims=0)
# / \
# / \
# Sub Squeeze
# Sub ReduceMax
# | |
# seqlens_k total_seq_len
# (1D) (int)
Expand All @@ -4187,20 +4187,20 @@ def make_attention_mask_graph_capture_reformatting_for_gqa(self, attn_mask_basen
)
reduce_sum_name = f"{attn_mask_basename}/ReduceSum"
reduce_sum_inputs = [f"{cast_1_name}/output_0", "/model/constants/INT64/[1]"]
self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size", 1])
self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT32, shape=["batch_size"])

# Left branch: Calculate seqlens_k = ReduceSum - 1
sub_name = f"{attn_mask_basename}/Sub"
sub_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT32/[1]"]
self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT32, shape=["batch_size", 1])
self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT32, shape=["batch_size"])

# Right branch: Squeeze to get int value for total_seq_len
squeeze_name = f"{attn_mask_basename}/Squeeze"
squeeze_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[0]"]
self.make_squeeze(squeeze_name, squeeze_inputs, dtype=ir.DataType.INT32, shape=[])
# Right branch: ReduceMax to get maximum int value for total_seq_len
reduce_max_name = f"{attn_mask_basename}/ReduceMax"
reduce_max_inputs = [f"{reduce_sum_name}/output_0"]
self.make_reduce_max(reduce_max_name, reduce_max_inputs, dtype=ir.DataType.INT32, shape=[])

self.mask_attrs["seqlens_k"] = sub_name
self.mask_attrs["total_seq_len"] = squeeze_name
self.mask_attrs["total_seq_len"] = reduce_max_name

def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename):
# Make nodes for the attention mask subgraph that calculates
Expand All @@ -4209,6 +4209,7 @@ def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename):
# attention_mask
# / \
# ReduceSum Shape
# (keepdims=0) |
# | |
# Sub Gather
# | |
Expand All @@ -4220,12 +4221,12 @@ def make_attention_mask_standard_reformatting_for_gqa(self, attn_mask_basename):
# Left path
reduce_sum_name = f"{attn_mask_basename}/ReduceSum"
reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"]
self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1])
self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"])
sub_name = f"{attn_mask_basename}/Sub"
sub_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[1]"]
self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1])
self.make_sub(sub_name, sub_inputs, dtype=ir.DataType.INT64, shape=["batch_size"])
cast_1_name = f"{attn_mask_basename}/Sub/Cast"
self.make_cast(cast_1_name, f"{sub_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1])
self.make_cast(cast_1_name, f"{sub_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size"])
Comment thread
kunal-vaishnavi marked this conversation as resolved.

# Right path
shape_name = f"{attn_mask_basename}/Shape"
Expand Down Expand Up @@ -4257,6 +4258,7 @@ def make_attention_mask_reformatting_for_sparse_attn(self):
# attention_mask
# / \
# ReduceSum Shape
# (keepdims=0) |
# | |
# Cast to int32 Gather
# | |
Expand All @@ -4271,9 +4273,9 @@ def make_attention_mask_reformatting_for_sparse_attn(self):
# Left path
reduce_sum_name = f"{attn_mask_basename}/ReduceSum"
reduce_sum_inputs = ["attention_mask", "/model/constants/INT64/[1]"]
self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size", 1])
self.make_reduce_sum(reduce_sum_name, reduce_sum_inputs, dtype=ir.DataType.INT64, shape=["batch_size"])
cast_1_name = f"{attn_mask_basename}/ReduceSum/Cast"
self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size", 1])
self.make_cast(cast_1_name, f"{reduce_sum_name}/output_0", dtype=ir.DataType.INT32, shape=["batch_size"])

# Right path
shape_name = f"{attn_mask_basename}/Shape"
Expand Down
1 change: 0 additions & 1 deletion src/python/py/models/builders/gptoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,6 @@ def make_moe_decomposed(self, layer_id, mlp, root_input):
reduce_sum_inputs,
dtype=ir.DataType.FLOAT,
shape=["batch_size", "sequence_length", self.intermediate_size, 1],
keepdims=False,
)
weighted_sum_squeeze_name = f"{basename}/weighted_sum/Squeeze"
weighted_sum_squeeze_inputs = [f"{reduce_sum_name}/output_0", "/model/constants/INT64/[-1]"]
Expand Down
Loading