@@ -3825,7 +3825,7 @@ def main(
38253825 inp_0 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ),
38263826 inp_1 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ),
38273827 inp_2 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ),
3828- ) -> R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ):
3828+ ) -> R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ):
38293829 with R .dataflow ():
38303830 lv : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = R .permute_dims (
38313831 inp_0 , axes = [0 , 2 , 1 , 3 ]
@@ -3839,7 +3839,10 @@ def main(
38393839 lv3 : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = R .nn .attention (
38403840 lv , lv1 , lv2 , scale = None
38413841 )
3842- gv : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = lv3
3842+ lv4 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ) = R .permute_dims (
3843+ lv3 , axes = [0 , 2 , 1 , 3 ]
3844+ )
3845+ gv : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ) = lv4
38433846 R .output (gv )
38443847 return gv
38453848
@@ -3851,7 +3854,7 @@ def main(
38513854 inp_1 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ),
38523855 inp_2 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ),
38533856 inp_3 : R .Tensor ((32 , 8 , 128 , 128 ), dtype = "float32" ),
3854- ) -> R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ):
3857+ ) -> R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ):
38553858 with R .dataflow ():
38563859 lv : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = R .permute_dims (
38573860 inp_0 , axes = [0 , 2 , 1 , 3 ]
@@ -3865,7 +3868,10 @@ def main(
38653868 lv3 : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = R .nn .attention (
38663869 lv , lv1 , lv2 , inp_3 , scale = None
38673870 )
3868- gv : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = lv3
3871+ lv4 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ) = R .permute_dims (
3872+ lv3 , axes = [0 , 2 , 1 , 3 ]
3873+ )
3874+ gv : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ) = lv4
38693875 R .output (gv )
38703876 return gv
38713877
@@ -3876,7 +3882,7 @@ def main(
38763882 inp_0 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ),
38773883 inp_1 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ),
38783884 inp_2 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ),
3879- ) -> R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ):
3885+ ) -> R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ):
38803886 with R .dataflow ():
38813887 lv : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = R .permute_dims (
38823888 inp_0 , axes = [0 , 2 , 1 , 3 ]
@@ -3890,7 +3896,10 @@ def main(
38903896 lv3 : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = R .nn .attention (
38913897 lv , lv1 , lv2 , scale = None , causal_mask = "TopLeft"
38923898 )
3893- gv : R .Tensor ((32 , 128 , 8 , 64 ), dtype = "float32" ) = lv3
3899+ lv4 : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ) = R .permute_dims (
3900+ lv3 , axes = [0 , 2 , 1 , 3 ]
3901+ )
3902+ gv : R .Tensor ((32 , 8 , 128 , 64 ), dtype = "float32" ) = lv4
38943903 R .output (gv )
38953904 return gv
38963905
0 commit comments