@@ -31,8 +31,8 @@ def main(
3131 C : T .Tensor ((M , N ), out_dtype ),
3232 ):
3333 with T .Kernel (T .ceildiv (N , block_N ), T .ceildiv (M , block_M ), threads = threads ) as (bx , by ):
34- A_shared = T .alloc_shared (A_shared_shape , in_dtype , scope = "shared" )
35- B_shared = T .alloc_shared (B_shared_shape , in_dtype , scope = "shared" )
34+ A_shared = T .alloc_shared (A_shared_shape , in_dtype )
35+ B_shared = T .alloc_shared (B_shared_shape , in_dtype )
3636 C_local = T .alloc_fragment ((block_M , block_N ), accum_dtype )
3737 T .clear (C_local )
3838 for k in T .Pipelined (T .ceildiv (K , block_K ), num_stages = num_stages ):
@@ -162,8 +162,8 @@ def main(
162162 C : T .Tensor ((M , N ), out_dtype ),
163163 ):
164164 with T .Kernel (T .ceildiv (N , block_N ), T .ceildiv (M , block_M ), threads = threads ) as (bx , by ):
165- A_shared = T .alloc_shared (A_shared_shape , in_dtype , scope = "shared" )
166- B_shared = T .alloc_shared (B_shared_shape , in_dtype , scope = "shared" )
165+ A_shared = T .alloc_shared (A_shared_shape , in_dtype )
166+ B_shared = T .alloc_shared (B_shared_shape , in_dtype )
167167 A_frag = T .alloc_fragment (A_frag_shape , in_dtype )
168168 C_local = T .alloc_fragment ((block_M , block_N ), accum_dtype )
169169 T .clear (C_local )
@@ -296,8 +296,8 @@ def main(
296296 C : T .Tensor ((M , N ), out_dtype ),
297297 ):
298298 with T .Kernel (T .ceildiv (N , block_N ), T .ceildiv (M , block_M ), threads = threads ) as (bx , by ):
299- A_shared = T .alloc_shared (A_shared_shape , in_dtype , scope = "shared" )
300- B_shared = T .alloc_shared (B_shared_shape , in_dtype , scope = "shared" )
299+ A_shared = T .alloc_shared (A_shared_shape , in_dtype )
300+ B_shared = T .alloc_shared (B_shared_shape , in_dtype )
301301 B_frag = T .alloc_fragment (B_frag_shape , in_dtype )
302302 C_local = T .alloc_fragment ((block_M , block_N ), accum_dtype )
303303 T .clear (C_local )
@@ -431,8 +431,8 @@ def main(
431431 C : T .Tensor ((M , N ), out_dtype ),
432432 ):
433433 with T .Kernel (T .ceildiv (N , block_N ), T .ceildiv (M , block_M ), threads = threads ) as (bx , by ):
434- A_shared = T .alloc_shared (A_shared_shape , in_dtype , scope = "shared" )
435- B_shared = T .alloc_shared (B_shared_shape , in_dtype , scope = "shared" )
434+ A_shared = T .alloc_shared (A_shared_shape , in_dtype )
435+ B_shared = T .alloc_shared (B_shared_shape , in_dtype )
436436 A_frag = T .alloc_fragment (A_frag_shape , in_dtype )
437437 B_frag = T .alloc_fragment (B_frag_shape , in_dtype )
438438 C_local = T .alloc_fragment ((block_M , block_N ), accum_dtype )
0 commit comments