Skip to content

Commit 36f2502

Browse files
Add op support for roll op (#17839)
* add op support for roll op * lint fix * fixed unit check * add unit test in fx_graph * lint issues * lint check * conflict resolved
1 parent 40a16db commit 36f2502

File tree

5 files changed

+327
-1
lines changed

5 files changed

+327
-1
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import math
2424
from typing import Callable, Dict, Optional, Tuple, Union, List
2525

26-
from tvm import relax
26+
from tvm import relax, tir
2727

2828

2929
class BaseFXGraphImporter(metaclass=abc.ABCMeta):
@@ -1164,6 +1164,85 @@ def _repeat(self, node: fx.Node) -> relax.Var:
11641164
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
11651165
return self.block_builder.emit(relax.op.tile(x, dims))
11661166

1167+
def _roll(self, node: fx.Node) -> relax.Var:
    """Convert ``torch.roll`` to Relax ops.

    A roll along one dimension is lowered to two ``strided_slice`` ops that
    split the tensor at ``dim_size - shift``, followed by a ``concat`` that
    swaps the two pieces. When ``dims`` is None, PyTorch semantics require
    flattening the tensor, rolling the flat view, then restoring the shape.

    Raises
    ------
    ValueError
        If ``dims`` is None but ``shifts`` has more than one element, or if
        ``shifts`` and ``dims`` have different lengths.
    TypeError
        If a shift/dim value cannot be converted to a Python int (e.g. a
        symbolic dimension).
    """
    args = self.retrieve_args(node)
    input_tensor = args[0]
    # shifts/dims may arrive positionally or as keyword arguments.
    shifts = args[1] if len(node.args) > 1 else node.kwargs.get("shifts", None)
    dims = args[2] if len(node.args) > 2 else node.kwargs.get("dims", None)

    # Original shape, needed to restore after the flatten-roll-reshape path.
    original_shape = self.shape_of(input_tensor)

    def to_int(val):
        # Normalize values that may arrive as tir.IntImm or other int-likes.
        if isinstance(val, tir.IntImm):
            return int(val.value)
        if isinstance(val, int):
            return val
        if hasattr(val, "__int__"):
            return int(val)
        raise TypeError(f"Unsupported type for shift/dim: {type(val)}")

    def roll_single_dim(tensor: relax.Var, shift: int, dim: int) -> relax.Var:
        # Roll `tensor` by `shift` along `dim` via slice-and-concat.
        shape = self.shape_of(tensor)
        dim_size_val = to_int(shape.values[dim])
        # Rolling an empty dimension is a no-op; also avoids modulo-by-zero.
        if dim_size_val == 0:
            return tensor
        # Python's % yields a value in [0, dim_size) even for negative shifts,
        # so forward and backward rolls share one code path.
        shift_mod = to_int(shift) % dim_size_val
        if shift_mod == 0:
            return tensor

        split_pos = dim_size_val - shift_mod
        part1 = self.block_builder.emit(
            relax.op.strided_slice(
                tensor,
                axes=[dim],
                begin=[0],
                end=[split_pos],
                strides=[1],
            )
        )
        part2 = self.block_builder.emit(
            relax.op.strided_slice(
                tensor,
                axes=[dim],
                begin=[split_pos],
                end=[dim_size_val],
                strides=[1],
            )
        )
        # The tail moves to the front: concat([part2, part1]).
        return self.block_builder.emit(relax.op.concat([part2, part1], axis=dim))

    # Handle dims=None: flatten -> roll -> reshape back (PyTorch semantics).
    if dims is None:
        if isinstance(shifts, (list, tuple)):
            # torch.roll requires exactly one shift when dims is unspecified;
            # silently using shifts[0] would produce wrong results.
            if len(shifts) != 1:
                raise ValueError(
                    "roll: shifts must contain a single value when dims is not specified"
                )
            shift_scalar = to_int(shifts[0])
        else:
            shift_scalar = to_int(shifts)
        flattened = self.block_builder.emit(relax.op.reshape(input_tensor, (-1,)))
        rolled = roll_single_dim(flattened, shift_scalar, 0)
        return self.block_builder.emit(relax.op.reshape(rolled, original_shape))

    # Normalize shifts and dims to equal-length lists of Python ints.
    if isinstance(shifts, (list, tuple)):
        shifts = [to_int(s) for s in shifts]
    else:
        shifts = [to_int(shifts)]

    if isinstance(dims, (list, tuple)):
        dims = [to_int(d) for d in dims]
    else:
        dims = [to_int(dims)]

    if len(shifts) != len(dims):
        raise ValueError("shifts and dims must have the same length")

    # Apply one roll per (shift, dim) pair, normalizing negative dims.
    result = input_tensor
    rank = len(original_shape.values)
    for shift, dim in zip(shifts, dims):
        if dim < 0:
            dim += rank
        result = roll_single_dim(result, shift, dim)

    return result
1245+
11671246
def _reshape(self, node: fx.Node) -> relax.Var:
11681247
import torch # type: ignore
11691248

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ def create_convert_map(
423423
"narrow.default": self._narrow,
424424
"permute.default": self._permute,
425425
"repeat.default": self._repeat,
426+
"roll.default": self._roll,
426427
"select.int": self._select,
427428
"slice.Tensor": self._slice,
428429
"split.Tensor": self._split,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,7 @@ def create_convert_map(
750750
"numel": self._numel,
751751
"permute": self._permute,
752752
"repeat": self._repeat,
753+
"roll": self._roll,
753754
"reshape": self._reshape,
754755
"scatter": self._scatter,
755756
"select": self._select,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2968,6 +2968,131 @@ def main(
29682968
verify_model(ReshapeAs(), example_args, {}, expected1)
29692969

29702970

2971+
def test_roll():
    """Check torch.roll lowering via the exported-program frontend.

    Covers the three torch.roll calling conventions: a bare shift with
    dims=None (flatten -> roll -> reshape), a single negative shift on one
    dim, and per-dim shift tuples.
    """

    class Roll1(Module):
        def forward(self, x):
            return torch.roll(x, 1)

    class Roll2(Module):
        def forward(self, x):
            return torch.roll(x, -1, 0)

    class Roll3(Module):
        def forward(self, x):
            return torch.roll(x, shifts=(2, 1), dims=(0, 1))

    # Test case 1: torch.roll(x, 1)
    # dims=None path: the tensor is flattened to (8,), rolled via two
    # strided_slice ops + concat, then reshaped back to (4, 2).
    @I.ir_module
    class Expected1:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8]))
                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
                    lv,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(7)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
                    lv,
                    axes=[0],
                    begin=[R.prim_value(7)],
                    end=[R.prim_value(8)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0)
                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2]))
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,)
                R.output(gv)
            return gv

    # Test case 2: torch.roll(x, -1, 0)
    # Negative shift: -1 % 4 == 3, so the split point is at row 1.
    @I.ir_module
    class Expected2:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
                    x,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(1)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
                    x,
                    axes=[0],
                    begin=[R.prim_value(1)],
                    end=[R.prim_value(4)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,)
                R.output(gv)
            return gv

    # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1))
    # Multi-dim rolls are applied sequentially, one (shift, dim) pair at a time.
    @I.ir_module
    class Expected3:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                # First roll along dim=0 with shift=2
                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
                    x,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(2)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
                    x,
                    axes=[0],
                    begin=[R.prim_value(2)],
                    end=[R.prim_value(4)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)

                # Second roll along dim=1 with shift=1
                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
                    lv2,
                    axes=[1],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(1)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
                    lv2,
                    axes=[1],
                    begin=[R.prim_value(1)],
                    end=[R.prim_value(2)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1)
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
                R.output(gv)
            return gv

    # Test inputs
    example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)

    # Run verification for each case
    verify_model(Roll1(), (example_input,), {}, Expected1)
    verify_model(Roll2(), (example_input,), {}, Expected2)
    verify_model(Roll3(), (example_input,), {}, Expected3)
3094+
3095+
29713096
def test_select_slice():
29723097
class Slice1(Module):
29733098
def forward(self, x):

tests/python/relax/test_frontend_from_fx.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3560,6 +3560,126 @@ def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float3
35603560
verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)
35613561

35623562

3563+
def test_roll():
    """Check torch.roll lowering via the FX-graph frontend.

    Mirrors the exported-program roll test: bare shift with dims=None,
    a single negative shift, and per-dim shift tuples. The FX frontend
    returns a plain tensor (not a tuple) from main.
    """

    class Roll1(Module):
        def forward(self, x):
            return torch.roll(x, 1)

    class Roll2(Module):
        def forward(self, x):
            return torch.roll(x, -1, 0)

    class Roll3(Module):
        def forward(self, x):
            return torch.roll(x, shifts=(2, 1), dims=(0, 1))

    # Test case 1: torch.roll(x, 1)
    # dims=None path: flatten to (8,), roll by slice+concat, reshape back.
    @I.ir_module
    class Expected1:
        @R.function
        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
            with R.dataflow():
                lv: R.Tensor((8,), dtype="int64") = R.reshape(inp_0, R.shape([8]))
                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
                    lv,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(7)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
                    lv,
                    axes=[0],
                    begin=[R.prim_value(7)],
                    end=[R.prim_value(8)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0)
                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2]))
                gv: R.Tensor((4, 2), dtype="int64") = lv4
                R.output(gv)
            return gv

    # Test case 2: torch.roll(x, -1, 0)
    # Negative shift: -1 % 4 == 3, so the split point is at row 1.
    @I.ir_module
    class Expected2:
        @R.function
        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
            with R.dataflow():
                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
                    inp_0,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(1)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
                    inp_0,
                    axes=[0],
                    begin=[R.prim_value(1)],
                    end=[R.prim_value(4)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
                gv: R.Tensor((4, 2), dtype="int64") = lv2
                R.output(gv)
            return gv

    # Test case 3: torch.roll(x, shifts=(2, 1), dims=(0, 1))
    # Multi-dim rolls applied sequentially: dim 0 first, then dim 1.
    @I.ir_module
    class Expected3:
        @R.function
        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
            with R.dataflow():
                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
                    inp_0,
                    axes=[0],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(2)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
                    inp_0,
                    axes=[0],
                    begin=[R.prim_value(2)],
                    end=[R.prim_value(4)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
                    lv2,
                    axes=[1],
                    begin=[R.prim_value(0)],
                    end=[R.prim_value(1)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
                    lv2,
                    axes=[1],
                    begin=[R.prim_value(1)],
                    end=[R.prim_value(2)],
                    strides=[R.prim_value(1)],
                    assume_inbound=False,
                )
                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1)
                gv: R.Tensor((4, 2), dtype="int64") = lv5
                R.output(gv)
            return gv

    input_info = [([4, 2], "int64")]

    verify_model(Roll1(), input_info, {}, Expected1)
    verify_model(Roll2(), input_info, {}, Expected2)
    verify_model(Roll3(), input_info, {}, Expected3)
3681+
3682+
35633683
def test_view():
35643684
input_info = [([1, 2, 3, 4], "float32")]
35653685

0 commit comments

Comments
 (0)