diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index ef629c75a6bc..7844c1ebb265 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -230,8 +230,8 @@ def getGlobalStrides(self, args): for ii in range(i): t *= t_globalDims[ii] # -2 means the sride in arguments is folded constant 1, we don't use 1 because it can not be distinguished from index 1 - elif t_globalStridesArgIdx[i] == -2: - t = 1 + elif t_globalStridesArgIdx[i] < 0: + t = -1 - t_globalStridesArgIdx[i] else: new_idx = self.getOriginArgIdx(t_globalStridesArgIdx[i], args) t = args[new_idx]