@@ -2363,13 +2363,13 @@ def forward(self, q_data, k_data, v_data):
23632363 ],
23642364 "outputs" : [
23652365 {
2366- "name" : "attention " ,
2367- "shape" : [1 , seq , 8 , 64 ],
2366+ "name" : "permute_dims_3 " ,
2367+ "shape" : [1 , 8 , seq , 64 ],
23682368 "dtype" : "float32" ,
2369- "layout" : "ABCD " ,
2369+ "layout" : "ACBD " ,
23702370 }
23712371 ],
2372- "nodes" : {"total" : 4 , "input" : 3 , "msc.attention" : 1 },
2372+ "nodes" : {"total" : 5 , "input" : 3 , "msc.attention" : 1 , "permute_dims " : 1 },
23732373 }
23742374 if dynamic :
23752375 expected1 ["prims" ] = {"total" : 1 , "shape" : 1 }
@@ -2395,13 +2395,13 @@ def forward(self, q_data, k_data, v_data, mask):
23952395 ],
23962396 "outputs" : [
23972397 {
2398- "name" : "attention_bias " ,
2399- "shape" : [1 , seq , 8 , 64 ],
2398+ "name" : "permute_dims_3 " ,
2399+ "shape" : [1 , 8 , seq , 64 ],
24002400 "dtype" : "float32" ,
2401- "layout" : "ABCD " ,
2401+ "layout" : "ACBD " ,
24022402 }
24032403 ],
2404- "nodes" : {"total" : 5 , "input" : 4 , "msc.attention" : 1 },
2404+ "nodes" : {"total" : 6 , "input" : 4 , "msc.attention" : 1 , "permute_dims " : 1 },
24052405 }
24062406 if dynamic :
24072407 expected2 ["prims" ] = {"total" : 1 , "shape" : 1 }
0 commit comments