Skip to content

Commit e1c430c

Browse files
authored
[Relay][Frontend][Torch] fix pytorch frontend linspace op (#16417)
fix pytorch frontend linspace op
1 parent a5e883e commit e1c430c

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ def linspace(self, inputs, input_types):
918918
# Find the spacing between values as step
919919
if step != 1:
920920
step = (stop - start) / (step - 1)
921-
stop = stop + step
921+
stop = stop + (step / 2)
922922
else:
923923
stop = start + step
924924

tests/python/frontend/pytorch/test_forward.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3632,6 +3632,10 @@ class Linspace8(Module):
36323632
def forward(self, *args):
36333633
return torch.linspace(1, 2, 1, dtype=torch.int16)
36343634

3635+
class Linspace9(Module):
3636+
def forward(self, *args):
3637+
return torch.linspace(0, 8, 10)
3638+
36353639
verify_model(Linspace1().float().eval())
36363640
verify_model(Linspace2().float().eval())
36373641
verify_model(Linspace3().float().eval())
@@ -3640,6 +3644,7 @@ def forward(self, *args):
36403644
verify_model(Linspace6().float().eval())
36413645
verify_model(Linspace7().float().eval())
36423646
verify_model(Linspace8().float().eval())
3647+
verify_model(Linspace9().float().eval())
36433648

36443649

36453650
@tvm.testing.uses_gpu

0 commit comments

Comments
 (0)