diff --git a/testing/python/language/test_tilelang_language_cooperative.py b/testing/python/language/test_tilelang_language_cooperative.py index 5b21172e3..0a4d7a6df 100644 --- a/testing/python/language/test_tilelang_language_cooperative.py +++ b/testing/python/language/test_tilelang_language_cooperative.py @@ -6,19 +6,20 @@ @tilelang.jit def grid_sync(N=1024): - block = 128 + block = 64 @T.prim_func def kernel(A: T.Tensor((N), T.float32)): with T.Kernel(T.ceildiv(N, block), threads=128) as bx: + A_local = T.alloc_fragment((block), dtype=T.float32) n_idx = bx * block for i in T.Parallel(block): - if n_idx + i < N: - A[n_idx + i] = n_idx + i + A[n_idx + i] = n_idx + i T.sync_grid() for i in T.Parallel(block): - if n_idx + i < N: - A[n_idx + i] = A[n_idx + i] + A[N - n_idx - i - 1] + A_local[i] = A[N - n_idx - i - 1] + T.sync_grid() + A[n_idx + i] = A[n_idx + i] + A_local[i] return kernel