From bd5d77f4071f01dc10d0da2c45fdc8929510bab0 Mon Sep 17 00:00:00 2001 From: dongdongl Date: Thu, 23 Nov 2023 09:01:31 +0000 Subject: [PATCH] fix bug when stride is constant --- python/triton/compiler/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index ef629c75a6bc..7844c1ebb265 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -230,8 +230,8 @@ def getGlobalStrides(self, args): for ii in range(i): t *= t_globalDims[ii] # -2 means the sride in arguments is folded constant 1, we don't use 1 because it can not be distinguished from index 1 - elif t_globalStridesArgIdx[i] == -2: - t = 1 + elif t_globalStridesArgIdx[i] < 0: + t = -1 - t_globalStridesArgIdx[i] else: new_idx = self.getOriginArgIdx(t_globalStridesArgIdx[i], args) t = args[new_idx]