Skip to content

Commit a783823

Browse files
committed
Fix MSC test case
1 parent f6f6c1a commit a783823

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

python/tvm/contrib/msc/core/transform/pattern.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,15 +330,17 @@ def make_relax_attention_pattern() -> (
330330
q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
331331
k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
332332
v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
333-
out = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
333+
attention = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
334+
out = relax_pattern.is_op("relax.permute_dims")(attention)
334335
annotations = {
335336
"weight_q": weight_q,
336337
"weight_k": weight_k,
337338
"weight_v": weight_v,
338339
"q_trans": q_trans,
339340
"k_trans": k_trans,
340341
"v_trans": v_trans,
341-
"attention": out,
342+
"attention": attention,
343+
"out": out,
342344
}
343345
return out, annotations
344346

@@ -378,7 +380,8 @@ def make_relax_mask_attention_pattern() -> (
378380
q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
379381
k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
380382
v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
381-
out = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask)
383+
attention = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask)
384+
out = relax_pattern.is_op("relax.permute_dims")(attention)
382385
annotations = {
383386
"weight_q": weight_q,
384387
"weight_k": weight_k,
@@ -387,7 +390,8 @@ def make_relax_mask_attention_pattern() -> (
387390
"q_trans": q_trans,
388391
"k_trans": k_trans,
389392
"v_trans": v_trans,
390-
"attention": out,
393+
"attention": attention,
394+
"out": out,
391395
}
392396
return out, annotations
393397

tests/python/contrib/test_msc/test_graph_build.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,13 +2363,13 @@ def forward(self, q_data, k_data, v_data):
23632363
],
23642364
"outputs": [
23652365
{
2366-
"name": "attention",
2367-
"shape": [1, seq, 8, 64],
2366+
"name": "permute_dims_3",
2367+
"shape": [1, 8, seq, 64],
23682368
"dtype": "float32",
2369-
"layout": "ABCD",
2369+
"layout": "ACBD",
23702370
}
23712371
],
2372-
"nodes": {"total": 4, "input": 3, "msc.attention": 1},
2372+
"nodes": {"total": 5, "input": 3, "msc.attention": 1, "permute_dims": 1},
23732373
}
23742374
if dynamic:
23752375
expected1["prims"] = {"total": 1, "shape": 1}
@@ -2395,13 +2395,13 @@ def forward(self, q_data, k_data, v_data, mask):
23952395
],
23962396
"outputs": [
23972397
{
2398-
"name": "attention_bias",
2399-
"shape": [1, seq, 8, 64],
2398+
"name": "permute_dims_3",
2399+
"shape": [1, 8, seq, 64],
24002400
"dtype": "float32",
2401-
"layout": "ABCD",
2401+
"layout": "ACBD",
24022402
}
24032403
],
2404-
"nodes": {"total": 5, "input": 4, "msc.attention": 1},
2404+
"nodes": {"total": 6, "input": 4, "msc.attention": 1, "permute_dims": 1},
24052405
}
24062406
if dynamic:
24072407
expected2["prims"] = {"total": 1, "shape": 1}

0 commit comments

Comments (0)