
Commit a2b29c0

fix msc testcase

Wrap relax.nn.attention with a trailing permute_dims in the MSC attention patterns and in the relax codegen, and update the expected output shapes in test_graph_build.py to match.
1 parent f6f6c1a commit a2b29c0

3 files changed: +11 -11 lines changed

python/tvm/contrib/msc/core/transform/pattern.py

Lines changed: 8 additions & 4 deletions
@@ -330,15 +330,17 @@ def make_relax_attention_pattern() -> (
     q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
     k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
     v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
-    out = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
+    attention = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
+    out = relax_pattern.is_op("relax.permute_dims")(attention)
     annotations = {
         "weight_q": weight_q,
         "weight_k": weight_k,
         "weight_v": weight_v,
         "q_trans": q_trans,
         "k_trans": k_trans,
         "v_trans": v_trans,
-        "attention": out,
+        "attention": attention,
+        "out": out,
     }
     return out, annotations

@@ -378,7 +380,8 @@ def make_relax_mask_attention_pattern() -> (
     q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
     k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
     v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
-    out = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask)
+    attention = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask)
+    out = relax_pattern.is_op("relax.permute_dims")(attention)
     annotations = {
         "weight_q": weight_q,
         "weight_k": weight_k,
@@ -387,7 +390,8 @@ def make_relax_mask_attention_pattern() -> (
         "q_trans": q_trans,
         "k_trans": k_trans,
         "v_trans": v_trans,
-        "attention": out,
+        "attention": attention,
+        "out": out,
     }
     return out, annotations
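
Why the extra permute_dims: relax.nn.attention takes and returns tensors in [batch, seq, heads, dim] layout, so frontends working in [batch, heads, seq, dim] wrap the op in transposes, and the pattern must also capture the trailing transpose to keep the fused output in the caller's layout. As an illustration (a hedged sketch, not code from this commit; shapes mirror the testcase below), the subgraph the updated pattern matches looks like:

# Hedged sketch (not from this commit) of the subgraph the updated
# pattern matches; shapes are illustrative.
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class AttentionGraph:
    @R.function
    def main(
        q: R.Tensor((1, 8, 128, 64), "float32"),
        k: R.Tensor((1, 8, 128, 64), "float32"),
        v: R.Tensor((1, 8, 128, 64), "float32"),
    ) -> R.Tensor((1, 8, 128, 64), "float32"):
        # Transpose into the [batch, seq, heads, dim] layout that
        # relax.nn.attention expects.
        q_trans = R.permute_dims(q, axes=[0, 2, 1, 3])
        k_trans = R.permute_dims(k, axes=[0, 2, 1, 3])
        v_trans = R.permute_dims(v, axes=[0, 2, 1, 3])
        attention = R.nn.attention(q_trans, k_trans, v_trans)
        # The trailing transpose is the op this commit adds to the pattern.
        out = R.permute_dims(attention, axes=[0, 2, 1, 3])
        return out

These pattern factories feed relax pattern-based fusion (e.g. relax.transform.FuseOpsByPattern), so widening the pattern by one op moves the transpose inside the fused msc.attention composite.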

src/contrib/msc/framework/tvm/relax_opcode.cc

Lines changed: 1 addition & 0 deletions
@@ -107,6 +107,7 @@ class RelaxAttentionCodeGen : public RelaxOpCode {
           .op_list_arg<int>(axes_key, "axes");
     }
     stack_.op_call().op_inputs_arg(false).op_arg<float>("scale").op_str_arg("causal_mask");
+    stack_.op_call("relax.op.permute_dims").op_output_arg().op_list_arg<int>("axes_3", "axes");
   }
 };
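
The codegen side mirrors the pattern change: after emitting the relax.nn.attention call, RelaxAttentionCodeGen now also emits a permute_dims whose axes come from the node's "axes_3" attribute, restoring the pre-attention layout. A hedged Python sketch of the expression the generated code corresponds to (variable names and shapes are illustrative, not from the commit):

# Sketch only: the relax expression the updated codegen builds.
from tvm import relax

# Transposed q/k/v in the [batch, seq, heads, dim] layout that
# relax.nn.attention expects.
q_trans = relax.Var("q_trans", relax.TensorStructInfo((1, 128, 8, 64), "float32"))
k_trans = relax.Var("k_trans", relax.TensorStructInfo((1, 128, 8, 64), "float32"))
v_trans = relax.Var("v_trans", relax.TensorStructInfo((1, 128, 8, 64), "float32"))
attention = relax.op.nn.attention(q_trans, k_trans, v_trans)
# The one-line addition above emits this trailing transpose.
out = relax.op.permute_dims(attention, axes=[0, 2, 1, 3])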

tests/python/contrib/test_msc/test_graph_build.py

Lines changed: 2 additions & 7 deletions
@@ -2362,12 +2362,7 @@ def forward(self, q_data, k_data, v_data):
             {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"},
         ],
         "outputs": [
-            {
-                "name": "attention",
-                "shape": [1, seq, 8, 64],
-                "dtype": "float32",
-                "layout": "ABCD",
-            }
+            {"name": "attention", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ABCD"}
         ],
         "nodes": {"total": 4, "input": 3, "msc.attention": 1},
     }
@@ -2396,7 +2391,7 @@ def forward(self, q_data, k_data, v_data, mask):
         "outputs": [
             {
                 "name": "attention_bias",
-                "shape": [1, seq, 8, 64],
+                "shape": [1, 8, seq, 64],
                 "dtype": "float32",
                 "layout": "ABCD",
             }
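
The new expectations follow the source framework's semantics: assuming the testcase maps torch's scaled_dot_product_attention to msc.attention (an assumption, not shown in this diff), that op preserves the [batch, heads, seq, dim] layout of its inputs, so the graph output should be [1, 8, seq, 64] rather than the previously expected [1, seq, 8, 64]:

# Hedged layout check; assumes the testcase lowers
# torch.nn.functional.scaled_dot_product_attention to msc.attention.
import torch
import torch.nn.functional as F

seq = 128  # illustrative value; the test parameterizes seq
q, k, v = (torch.rand(1, 8, seq, 64) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (1, 8, seq, 64)  # same layout as the inputs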
