Skip to content

Commit 45079bb

Browse files
committed
fix strided_slice
1 parent 188e3af commit 45079bb

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tester/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,9 @@ def torch_assert_accuracy(self, paddle_tensor, torch_tensor, atol, rtol):
893893

894894
paddle_tensor = paddle_tensor.cpu().detach()
895895
torch_tensor = torch_tensor.cpu().detach()
896+
if self.api_config.api_name == "paddle.strided_slice" and any(s < 0 for s in paddle_tensor.strides):
897+
# torch's from_dlpack now don't support negative strides
898+
paddle_tensor = paddle_tensor.contiguous()
896899

897900
paddle_dlpack = paddle.utils.dlpack.to_dlpack(paddle_tensor)
898901
converted_paddle_tensor = torch.utils.dlpack.from_dlpack(paddle_dlpack)

0 commit comments

Comments
 (0)