Skip to content

Commit 43268e1

Browse files
committed
fix msc testcase
1 parent f6f6c1a commit 43268e1

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/python/contrib/test_msc/test_graph_build.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,13 +2363,13 @@ def forward(self, q_data, k_data, v_data):
23632363
],
23642364
"outputs": [
23652365
{
2366-
"name": "attention",
2367-
"shape": [1, seq, 8, 64],
2366+
"name": "permute_dims_3",
2367+
"shape": [1, 8, seq, 64],
23682368
"dtype": "float32",
2369-
"layout": "ABCD",
2369+
"layout": "ACBD",
23702370
}
23712371
],
2372-
"nodes": {"total": 4, "input": 3, "msc.attention": 1},
2372+
"nodes": {"total": 5, "input": 3, "msc.attention": 1, "permute_dims": 1},
23732373
}
23742374
if dynamic:
23752375
expected1["prims"] = {"total": 1, "shape": 1}
@@ -2395,13 +2395,13 @@ def forward(self, q_data, k_data, v_data, mask):
23952395
],
23962396
"outputs": [
23972397
{
2398-
"name": "attention_bias",
2399-
"shape": [1, seq, 8, 64],
2398+
"name": "permute_dims_3",
2399+
"shape": [1, 8, seq, 64],
24002400
"dtype": "float32",
2401-
"layout": "ABCD",
2401+
"layout": "ACBD",
24022402
}
24032403
],
2404-
"nodes": {"total": 5, "input": 4, "msc.attention": 1},
2404+
"nodes": {"total": 6, "input": 4, "msc.attention": 1, "permute_dims": 1},
24052405
}
24062406
if dynamic:
24072407
expected2["prims"] = {"total": 1, "shape": 1}

0 commit comments

Comments
 (0)