diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py index b217743c4ab14..c2917be4510ec 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py @@ -61,9 +61,6 @@ def create_attention_node(self, gemm, gemm_qkv, input, output): self.node_name_to_graph_name[add_node.name] = self.this_graph_name def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): - # (TODO) hasesh/tlwu: Investigate what fixes the following logic needs in order - # to fuse the Attention sub-graph. With some changes to other fusions, this stopped - # working. return_indice = [] is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization" @@ -187,20 +184,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0]) if qk_nodes is not None: (softmax_qk, where_qk, div_qk, matmul_qk) = qk_nodes - mask_nodes = self.model.match_parent_path( + _, mask_nodes, _ = self.model.match_parent_paths( where_qk, [ - "Cast", - "Slice", - "Slice", - "Unsqueeze", - "Sub", - "Squeeze", - "Slice", - "Shape", - "Div", + ( + ["Cast", "Slice", "Slice", "Unsqueeze", "Sub", "Squeeze", "Slice", "Shape", "Div"], + [0, 0, 0, 1, 0, 0, 0, 0, 0], + ), + # For transformers >= 4.27, causal mask uses torch.bool instead of torch.uint8. + ( + ["Slice", "Slice", "Unsqueeze", "Sub", "Squeeze", "Slice", "Shape", "Div"], + [0, 0, 1, 0, 0, 0, 0, 0], + ), ], - [0, 0, 0, 1, 0, 0, 0, 0, 0], + output_name_to_node, ) if mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") diff --git a/onnxruntime/test/python/transformers/gpt2_model_generator.py b/onnxruntime/test/python/transformers/gpt2_model_generator.py index 74136c2b8bc61..62e9c4a66f005 100644 --- a/onnxruntime/test/python/transformers/gpt2_model_generator.py +++ b/onnxruntime/test/python/transformers/gpt2_model_generator.py @@ -928,6 +928,372 @@ def create_gpt2_fused_embedlayer( return helper.make_model(graph, opset_imports=(opsetid,)) +def create_gpt2_attention_no_past(hidden_size=64, num_heads=4, max_seq_len=32, switch_add_inputs=False, add_cast=True): + # unsqueeze in opset version 13 has two inputs (axis is moved from attribute to input). 
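+    # (Assumption worth stating in-line: onnx 1.8.0 is the first onnx release whose default
+    # opset is 13, so the installed onnx version is used below as a proxy for which
+    # Unsqueeze/Squeeze/Split signatures to emit.)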
+    is_opset_13_or_newer = version.parse(onnx.__version__) >= version.parse("1.8.0")
+
+    head_size = int(hidden_size // num_heads)
+
+    # nodes in attention subgraph
+    nodes = [
+        helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm"),
+        helper.make_node(
+            "LayerNormalization",
+            ["layernorm_input", "layer_norm_weight", "layer_norm_bias"],
+            ["layernorm_out"],
+            "layernorm",
+            epsilon=0.000009999999747378752,
+        ),
+        # reshape before gemm: [B, S, hidden] -> [B*S, hidden]
+        helper.make_node(
+            "Reshape",
+            ["layernorm_out", "reshape_before_gemm_shape"],
+            ["reshape_before_gemm_out"],
+            "reshape_before_gemm",
+        ),
+        # fully connected gemm: [B*S, hidden] -> [B*S, 3*hidden]
+        helper.make_node(
+            "Gemm",
+            ["reshape_before_gemm_out", "gemm_fc_weight", "gemm_fc_bias"],
+            ["gemm_fc_out"],
+            "gemm_fc",
+            alpha=1.0,
+            beta=1.0,
+            transA=0,
+            transB=0,
+        ),
+        # reshape after gemm: [B*S, 3*hidden] -> [B, S, 3*hidden]
+        helper.make_node(
+            "Reshape",
+            ["gemm_fc_out", "reshape_after_gemm_shape"],
+            ["reshape_after_gemm_out"],
+            "reshape_after_gemm",
+        ),
+        # split into q, k, v
+        (
+            helper.make_node(
+                "Split",
+                ["reshape_after_gemm_out", "split_q_k_v"],
+                ["q", "k", "v"],
+                "split_qkv",
+                axis=2,
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Split",
+                ["reshape_after_gemm_out"],
+                ["q", "k", "v"],
+                "split_qkv",
+                axis=2,
+                split=[hidden_size, hidden_size, hidden_size],
+            )
+        ),
+        # q nodes: [B, S, hidden] -> [B, S, num_heads, head_size] -> [B, num_heads, S, head_size]
+        helper.make_node("Reshape", ["q", "reshape_x_shape"], ["reshape_q_out"], "reshape_q"),
+        helper.make_node(
+            "Transpose",
+            ["reshape_q_out"],
+            ["transpose_q_out"],
+            "transpose_q",
+            perm=[0, 2, 1, 3],
+        ),
+        # k nodes: [B, S, hidden] -> [B, S, num_heads, head_size] -> [B, num_heads, head_size, S]
+        helper.make_node("Reshape", ["k", "reshape_x_shape"], ["reshape_k_out"], "reshape_k"),
+        helper.make_node(
+            "Transpose",
+            ["reshape_k_out"],
+            ["transpose_k_out"],
+            "transpose_k",
+            perm=[0, 2, 3, 1],
+        ),
+        # v nodes: [B, S, hidden] -> [B, S, num_heads, head_size] -> [B, num_heads, S, head_size]
+        helper.make_node("Reshape", ["v", "reshape_x_shape"], ["reshape_v_out"], "reshape_v"),
+        helper.make_node(
+            "Transpose",
+            ["reshape_v_out"],
+            ["transpose_v_out"],
+            "transpose_v",
+            perm=[0, 2, 1, 3],
+        ),
+        # qk matmul: [B, H, S, d] x [B, H, d, S] -> [B, H, S, S]
+        helper.make_node(
+            "MatMul",
+            ["transpose_q_out", "transpose_k_out"],
+            ["qk_out"],
+            "matmul_qk",
+        ),
+        # qk div (scaling)
+        helper.make_node("Div", ["qk_out", "div_weight"], ["qk_norm_out"], "qk_norm"),
+        # mask subgraph: Shape(div_output) -> extract total_seq_len
+        helper.make_node("Shape", ["qk_norm_out"], ["div_shape_out"], "div_shape"),
+        helper.make_node(
+            "Slice",
+            ["div_shape_out", "starts_n2", "ends_n1", "axes_0"],
+            ["div_shape_slice_out"],
+            "div_shape_slice",
+        ),
+        (
+            helper.make_node(
+                "Squeeze",
+                ["div_shape_slice_out", "axes_0"],
+                ["total_seq_len"],
+                "squeeze_total_seq_len",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Squeeze",
+                ["div_shape_slice_out"],
+                ["total_seq_len"],
+                "squeeze_total_seq_len",
+                axes=[0],
+            )
+        ),
+        # mask subgraph: Shape(transpose_q) -> extract q_seq_len
+        helper.make_node("Shape", ["transpose_q_out"], ["transpose_q_shape_out"], "transpose_q_shape"),
+        helper.make_node(
+            "Slice",
+            ["transpose_q_shape_out", "starts_n2", "ends_n1", "axes_0"],
+            ["transpose_q_shape_slice_out"],
+            "transpose_q_shape_slice",
+        ),
+        (
+            helper.make_node(
+                "Squeeze",
+                ["transpose_q_shape_slice_out", "axes_0"],
+                ["q_seq_len"],
+                "squeeze_q_seq_len",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Squeeze",
+                ["transpose_q_shape_slice_out"],
+                ["q_seq_len"],
+                "squeeze_q_seq_len",
+                axes=[0],
+            )
+        ),
+        # Sub(total_seq_len, q_seq_len) -> start_idx
+        helper.make_node("Sub", ["total_seq_len", "q_seq_len"], ["sub_out"], "sub"),
+        # Unsqueeze start_idx and total_seq_len for Slice inputs
+        (
+            helper.make_node("Unsqueeze", ["sub_out", "axes_0"], ["sub_unsqueeze_out"], "sub_unsqueeze")
+            if is_opset_13_or_newer
+            else helper.make_node("Unsqueeze", ["sub_out"], ["sub_unsqueeze_out"], "sub_unsqueeze", axes=[0])
+        ),
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["total_seq_len", "axes_0"],
+                ["total_seq_len_unsqueeze_out"],
+                "total_seq_len_unsqueeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Unsqueeze",
+                ["total_seq_len"],
+                ["total_seq_len_unsqueeze_out"],
+                "total_seq_len_unsqueeze",
+                axes=[0],
+            )
+        ),
+        # Slice undir_mask on axis 2: [1,1,max,max] -> [1,1,S,max]
+        helper.make_node(
+            "Slice",
+            [
+                "undir_mask",
+                "sub_unsqueeze_out",
+                "total_seq_len_unsqueeze_out",
+                "axes_2",
+                "steps_1",
+            ],
+            ["undir_mask_slice_out"],
+            "undir_mask_slice",
+        ),
+        # Slice on axis 3: [1,1,S,max] -> [1,1,S,S]
+        helper.make_node(
+            "Slice",
+            [
+                "undir_mask_slice_out",
+                "starts_0",
+                "total_seq_len_unsqueeze_out",
+                "axes_3",
+                "steps_1",
+            ],
+            ["mask_slice_slice_out"],
+            "mask_slice_slice",
+        ),
+    ]
+
+    # Optionally add Cast node for old transformers (uint8 mask -> bool)
+    if add_cast:
+        nodes.append(
+            helper.make_node(
+                "Cast",
+                ["mask_slice_slice_out"],
+                ["undir_mask_out"],
+                "undir_mask_cast",
+                to=9,
+            ),
+        )
+        where_mask_input = "undir_mask_out"
+    else:
+        where_mask_input = "mask_slice_slice_out"
+
+    nodes.extend(
+        [
+            # Where(mask, div_output, -10000)
+            helper.make_node(
+                "Where",
+                [where_mask_input, "qk_norm_out", "where_weight"],
+                ["where_out"],
+                "where",
+            ),
+            helper.make_node("Softmax", ["where_out"], ["softmax_out"], "softmax", axis=3),
+            # qkv matmul: [B, H, S, S] x [B, H, S, d] -> [B, H, S, d]
+            helper.make_node(
+                "MatMul",
+                ["softmax_out", "transpose_v_out"],
+                ["matmul_qkv_out"],
+                "matmul_qk_v",
+            ),
+            # transpose qkv: [B, H, S, d] -> [B, S, H, d]
+            helper.make_node(
+                "Transpose",
+                ["matmul_qkv_out"],
+                ["transpose_qkv_out"],
+                "transpose_qkv",
+                perm=[0, 2, 1, 3],
+            ),
+            # reshape: [B, S, H, d] -> [B, S, hidden]
+            helper.make_node(
+                "Reshape",
+                ["transpose_qkv_out", "reshape_weight_qkv"],
+                ["reshape_qkv_1_out"],
+                "reshape_qkv_1",
+            ),
+            # reshape: [B, S, hidden] -> [B*S, hidden]
+            helper.make_node(
+                "Reshape",
+                ["reshape_qkv_1_out", "reshape_before_output_gemm_shape"],
+                ["reshape_qkv_2_out"],
+                "reshape_qkv_2",
+            ),
+            # output projection gemm: [B*S, hidden] -> [B*S, hidden]
+            helper.make_node(
+                "Gemm",
+                ["reshape_qkv_2_out", "gemm_out_weight", "gemm_out_bias"],
+                ["gemm_out"],
+                "gemm_out",
+                alpha=1.0,
+                beta=1.0,
+                transA=0,
+                transB=0,
+            ),
+            # reshape: [B*S, hidden] -> [B, S, hidden]
+            helper.make_node(
+                "Reshape",
+                ["gemm_out", "reshape_after_output_gemm_shape"],
+                ["gemm_reshape_out"],
+                "gemm_reshape",
+            ),
+            # skip connection add
+            helper.make_node(
+                "Add",
+                reverse_if(["gemm_reshape_out", "layernorm_input"], switch_add_inputs),
+                ["skip_output"],
+                "add_skip",
+            ),
+            # final layernorm
+            helper.make_node(
+                "LayerNormalization",
+                ["skip_output", "layer_norm_weight", "layer_norm_bias"],
+                ["output"],
+                "layernorm2",
+                epsilon=0.000009999999747378752,
+            ),
+        ]
+    )
+
+    # Unidirectional mask
unidir_mask_data = numpy.tril(numpy.ones((max_seq_len, max_seq_len))).reshape([max_seq_len * max_seq_len]) + if add_cast: + unidir_mask_data = unidir_mask_data.astype(numpy.uint8) + unidir_mask_dtype = TensorProto.UINT8 + else: + unidir_mask_data = unidir_mask_data.astype(bool) + unidir_mask_dtype = TensorProto.BOOL + + initializers = [ + float_tensor("layer_norm_weight", [hidden_size]), + float_tensor("layer_norm_bias", [hidden_size]), + float_tensor("gemm_fc_weight", [hidden_size, 3 * hidden_size]), + float_tensor("gemm_fc_bias", [3 * hidden_size]), + float_tensor("gemm_out_weight", [hidden_size, hidden_size]), + float_tensor("gemm_out_bias", [hidden_size]), + helper.make_tensor( + "undir_mask", + unidir_mask_dtype, + [1, 1, max_seq_len, max_seq_len], + unidir_mask_data.tolist(), + ), + helper.make_tensor("div_weight", TensorProto.FLOAT, [], [math.sqrt(head_size)]), + helper.make_tensor("where_weight", TensorProto.FLOAT, [], [-10000.0]), + helper.make_tensor("starts_0", TensorProto.INT64, [1], [0]), + helper.make_tensor("starts_n2", TensorProto.INT64, [1], [-2]), + helper.make_tensor("ends_n1", TensorProto.INT64, [1], [-1]), + helper.make_tensor("axes_0", TensorProto.INT64, [1], [0]), + helper.make_tensor("axes_2", TensorProto.INT64, [1], [2]), + helper.make_tensor("axes_3", TensorProto.INT64, [1], [3]), + helper.make_tensor("steps_1", TensorProto.INT64, [1], [1]), + helper.make_tensor("reshape_x_shape", TensorProto.INT64, [4], [0, 0, num_heads, head_size]), + helper.make_tensor("reshape_weight_qkv", TensorProto.INT64, [3], [0, 0, hidden_size]), + helper.make_tensor("reshape_before_gemm_shape", TensorProto.INT64, [2], [-1, hidden_size]), + helper.make_tensor("reshape_after_gemm_shape", TensorProto.INT64, [3], [0, 0, 3 * hidden_size]), + helper.make_tensor("reshape_before_output_gemm_shape", TensorProto.INT64, [2], [-1, hidden_size]), + helper.make_tensor("reshape_after_output_gemm_shape", TensorProto.INT64, [3], [0, 0, hidden_size]), + ] + + if is_opset_13_or_newer: + initializers.append( + helper.make_tensor( + "split_q_k_v", + TensorProto.INT64, + [3], + [hidden_size, hidden_size, hidden_size], + ) + ) + + graph = helper.make_graph( + [node for node in nodes if node], + "GPT2_no_past", # name + [ # inputs + helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + ["batch_size", "sequence_length", hidden_size], + ), + helper.make_tensor_value_info( + "input_2", + TensorProto.FLOAT, + ["batch_size", "sequence_length", hidden_size], + ), + ], + [ # outputs + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + ["batch_size", "sequence_length", hidden_size], + ), + ], + initializers, + ) + + # Needed so that we don't see the new LayerNormalization function added in version 17. + # TODO(https://github.com/microsoft/onnxruntime/issues/11916): Remove once fixed. 
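+    # (onnx.defs.onnx_opset_version() returns the highest ai.onnx opset supported by the
+    # installed onnx package, so the min() caps the exported model at opset 16.)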
+ opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(opsetid,)) + + if __name__ == "__main__": model = create_gpt2_attention() onnx.save(model, "gpt2_attention.onnx") @@ -952,3 +1318,15 @@ def create_gpt2_fused_embedlayer( model = create_gpt2_fused_embedlayer(one_attention_node=True, output_embedding_sum=True) onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx") + + model = create_gpt2_attention_no_past() + onnx.save(model, "gpt2_attention_no_past.onnx") + + model = create_gpt2_attention_no_past(switch_add_inputs=True) + onnx.save(model, "gpt2_attention_no_past_add.onnx") + + model = create_gpt2_attention_no_past(add_cast=False) + onnx.save(model, "gpt2_attention_no_past_no_cast.onnx") + + model = create_gpt2_attention_no_past(switch_add_inputs=True, add_cast=False) + onnx.save(model, "gpt2_attention_no_past_add_no_cast.onnx") diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index 02cffd5e3a3aa..017c9b598046a 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -11,7 +11,7 @@ import onnx from bart_model_generator import create_bart_attention_sdpa from bert_model_generator import create_bert_attention, create_bert_attention_pre_ln, create_tf2onnx_attention_3d -from gpt2_model_generator import create_gpt2_attention +from gpt2_model_generator import create_gpt2_attention, create_gpt2_attention_no_past from model_loader import get_test_data_path from parity_utilities import find_transformers_source @@ -301,6 +301,37 @@ def test_bart_attention_sdpa_fusion(self): if unidirectional_attr is not None: self.assertEqual(unidirectional_attr.i, 0) + def test_gpt2_attention_no_past_fusion(self): + hidden_size = 64 + num_heads = 4 + for add_cast in [True, False]: + for switch_add_inputs in [False, True]: + model = create_gpt2_attention_no_past( + hidden_size=hidden_size, + num_heads=num_heads, + switch_add_inputs=switch_add_inputs, + add_cast=add_cast, + ) + dir = "." 
+ model_path = os.path.join(dir, "gpt2_attention_no_past.onnx") + onnx.save(model, model_path) + + options = FusionOptions("gpt2") + + optimized_model = optimize_model( + model_path, + model_type="gpt2", + num_heads=num_heads, + hidden_size=hidden_size, + optimization_options=options, + ) + + os.remove(model_path) + + model_suffix = "add_opt" if switch_add_inputs else "opt" + model_name = f"gpt2_attention_no_past_{model_suffix}.onnx" + self.verify_fusion(optimized_model, model_name) + def test_megatron_gpt2_attention_fusion(self): for enable_skip_layer_norm_fusion in [False, True]: path = get_test_data_path("models", "gpt2_megatron.onnx") diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_no_past_add_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_no_past_add_opt.onnx new file mode 100644 index 0000000000000..c3af4d18565f3 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_no_past_add_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_no_past_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_no_past_opt.onnx new file mode 100644 index 0000000000000..5bab6a2a9c62e Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_no_past_opt.onnx differ
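For reference, a minimal standalone sketch of how the new generator and the fusion fit together, mirroring test_gpt2_attention_no_past_fusion above. It assumes an installed onnxruntime package that exposes the transformers tooling as onnxruntime.transformers (the test itself imports optimize_model from the source tree instead):

    import onnx
    from gpt2_model_generator import create_gpt2_attention_no_past
    from onnxruntime.transformers.fusion_options import FusionOptions
    from onnxruntime.transformers.optimizer import optimize_model

    # Build the no-past GPT-2 attention graph without the uint8 -> bool Cast,
    # i.e. the transformers >= 4.27 mask pattern that the fusion now also matches.
    model = create_gpt2_attention_no_past(hidden_size=64, num_heads=4, add_cast=False)
    onnx.save(model, "gpt2_attention_no_past_no_cast.onnx")

    optimized = optimize_model(
        "gpt2_attention_no_past_no_cast.onnx",
        model_type="gpt2",
        num_heads=4,
        hidden_size=64,
        optimization_options=FusionOptions("gpt2"),
    )

    # The Q/K/V projections, causal-mask slices, and softmax should have collapsed
    # into a single Attention node.
    assert any(node.op_type == "Attention" for node in optimized.model.graph.node)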