Skip to content

Commit 6649037

Browse files
committed
fix the testcase
1 parent ff8e416 commit 6649037

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

tests/python/relax/test_frontend_from_fx.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3825,7 +3825,7 @@ def main(
38253825
inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
38263826
inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
38273827
inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
3828-
) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
3828+
) -> R.Tensor((32, 8, 128, 64), dtype="float32"):
38293829
with R.dataflow():
38303830
lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
38313831
inp_0, axes=[0, 2, 1, 3]
@@ -3839,7 +3839,10 @@ def main(
38393839
lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
38403840
lv, lv1, lv2, scale=None
38413841
)
3842-
gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
3842+
lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
3843+
lv3, axes=[0, 2, 1, 3]
3844+
)
3845+
gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4
38433846
R.output(gv)
38443847
return gv
38453848

@@ -3851,7 +3854,7 @@ def main(
38513854
inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
38523855
inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
38533856
inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
3854-
) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
3857+
) -> R.Tensor((32, 8, 128, 64), dtype="float32"):
38553858
with R.dataflow():
38563859
lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
38573860
inp_0, axes=[0, 2, 1, 3]
@@ -3865,7 +3868,10 @@ def main(
38653868
lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
38663869
lv, lv1, lv2, inp_3, scale=None
38673870
)
3868-
gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
3871+
lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
3872+
lv3, axes=[0, 2, 1, 3]
3873+
)
3874+
gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4
38693875
R.output(gv)
38703876
return gv
38713877

@@ -3876,7 +3882,7 @@ def main(
38763882
inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
38773883
inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
38783884
inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
3879-
) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
3885+
) -> R.Tensor((32, 8, 128, 64), dtype="float32"):
38803886
with R.dataflow():
38813887
lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
38823888
inp_0, axes=[0, 2, 1, 3]
@@ -3890,7 +3896,10 @@ def main(
38903896
lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
38913897
lv, lv1, lv2, scale=None, causal_mask="TopLeft"
38923898
)
3893-
gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
3899+
lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
3900+
lv3, axes=[0, 2, 1, 3]
3901+
)
3902+
gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4
38943903
R.output(gv)
38953904
return gv
38963905

0 commit comments

Comments
 (0)