Skip to content

Commit b5f327c

Browse files
committed
Refactor shared memory allocation in GEMM tests
- Removed the unnecessary `scope="shared"` specification from the shared-memory allocations for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
- This change simplifies the allocation calls and aligns them with the updated GEMM function signatures, which use the default scope.
1 parent aa62efb commit b5f327c

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,8 @@ def main(
3131
C: T.Tensor((M, N), out_dtype),
3232
):
3333
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
34-
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
35-
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
34+
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
35+
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
3636
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
3737
T.clear(C_local)
3838
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -162,8 +162,8 @@ def main(
162162
C: T.Tensor((M, N), out_dtype),
163163
):
164164
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
165-
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
166-
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
165+
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
166+
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
167167
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
168168
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
169169
T.clear(C_local)
@@ -296,8 +296,8 @@ def main(
296296
C: T.Tensor((M, N), out_dtype),
297297
):
298298
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
299-
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
300-
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
299+
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
300+
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
301301
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
302302
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
303303
T.clear(C_local)
@@ -431,8 +431,8 @@ def main(
431431
C: T.Tensor((M, N), out_dtype),
432432
):
433433
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
434-
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
435-
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
434+
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
435+
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
436436
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
437437
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
438438
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

0 commit comments

Comments (0)