@@ -386,13 +386,14 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
386386 T .func_attr ({"target" : T .target ("cuda" , host = "llvm" )})
387387 for i in range (128 ):
388388 threadIdx_x = T .launch_thread ("threadIdx.x" , 128 )
389- red_buf0 = T .allocate ([1 ], "float32" , "local " )
390- red_buf0_3 = T .Buffer ((1 ,), data = red_buf0 , scope = "local " )
389+ red_result = T .allocate ([1 ], "float32" , "shared " )
390+ red_result_1 = T .Buffer ((1 ,), data = red_result , scope = "shared " )
391391 with T .attr (
392392 T .comm_reducer (lambda x0 , y0 : x0 + y0 , [T .float32 (0 )]),
393393 "reduce_scope" ,
394394 T .reinterpret ("handle" , T .uint64 (0 )),
395395 ):
396+ red_buf0 = T .allocate ([1 ], "float32" , "local" )
396397 mask = T .allocate ([1 ], "uint32" , "local" )
397398 t0 = T .allocate ([1 ], "float32" , "local" )
398399 red_buf0_1 = T .allocate ([1 ], "float32" , "local" )
@@ -415,11 +416,11 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
415416 red_buf0_2 [0 ] = red_buf0_2 [0 ] + t0_2 [0 ]
416417 t0_2 [0 ] = T .tvm_warp_shuffle_down (mask_2 [0 ], red_buf0_2 [0 ], 1 , 32 , 32 )
417418 red_buf0_2 [0 ] = red_buf0_2 [0 ] + t0_2 [0 ]
418- red_buf0_2 [0 ] = T .tvm_warp_shuffle (mask_2 [0 ], red_buf0_2 [0 ], 0 , 32 , 32 )
419419 red_buf_staging_1 = T .Buffer ((4 ,), data = red_buf_staging , scope = "shared" )
420420 if threadIdx_x % 32 == 0 :
421421 red_buf_staging_1 [threadIdx_x // 32 ] = red_buf0_2 [0 ]
422422 T .tvm_storage_sync ("shared" )
423+ red_buf0_3 = T .Buffer ((1 ,), data = red_buf0 , scope = "local" )
423424 if threadIdx_x < 4 :
424425 red_buf0_3 [0 ] = red_buf_staging_1 [threadIdx_x ]
425426 mask_3 = T .Buffer ((1 ,), "uint32" , data = mask , scope = "local" )
@@ -429,10 +430,12 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
429430 red_buf0_3 [0 ] = red_buf0_3 [0 ] + t0_3 [0 ]
430431 t0_3 [0 ] = T .tvm_warp_shuffle_down (mask_3 [0 ], red_buf0_3 [0 ], 1 , 32 , 32 )
431432 red_buf0_3 [0 ] = red_buf0_3 [0 ] + t0_3 [0 ]
432- red_buf0_3 [0 ] = T .tvm_warp_shuffle (mask_3 [0 ], red_buf0_3 [0 ], 0 , 32 , 32 )
433+ if threadIdx_x == 0 :
434+ red_result_1 [0 ] = red_buf0_3 [0 ]
435+ T .tvm_storage_sync ("shared" )
433436 if threadIdx_x == 0 :
434437 B_1 = T .Buffer ((128 ,), data = B .data )
435- B_1 [i ] = red_buf0_3 [0 ]
438+ B_1 [i ] = red_result_1 [0 ]
436439
437440
438441class TestMultiWarpReduce2 (BaseCompare ):
@@ -459,13 +462,14 @@ def before(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
459462 def expected (A : T .Buffer ((1 , 1024 ), "float32" ), B : T .Buffer ((1 ,), "float32" )):
460463 T .func_attr ({"target" : T .target ("cuda" , host = "llvm" )})
461464 threadIdx_x = T .launch_thread ("threadIdx.x" , 1024 )
462- red_buf0 = T .allocate ([1 ], "float32" , "local " )
463- red_buf0_3 = T .Buffer ((1 ,), data = red_buf0 , scope = "local " )
465+ red_result = T .allocate ([1 ], "float32" , "shared " )
466+ red_result_1 = T .Buffer ((1 ,), data = red_result , scope = "shared " )
464467 with T .attr (
465468 T .comm_reducer (lambda x0 , y0 : x0 + y0 , [T .float32 (0 )]),
466469 "reduce_scope" ,
467470 T .reinterpret ("handle" , T .uint64 (0 )),
468471 ):
472+ red_buf0 = T .allocate ([1 ], "float32" , "local" )
469473 mask = T .allocate ([1 ], "uint32" , "local" )
470474 t0 = T .allocate ([1 ], "float32" , "local" )
471475 red_buf0_1 = T .allocate ([1 ], "float32" , "local" )
@@ -488,11 +492,11 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
488492 red_buf0_2 [0 ] = red_buf0_2 [0 ] + t0_2 [0 ]
489493 t0_2 [0 ] = T .tvm_warp_shuffle_down (mask_2 [0 ], red_buf0_2 [0 ], 1 , 32 , 32 )
490494 red_buf0_2 [0 ] = red_buf0_2 [0 ] + t0_2 [0 ]
491- red_buf0_2 [0 ] = T .tvm_warp_shuffle (mask_2 [0 ], red_buf0_2 [0 ], 0 , 32 , 32 )
492495 red_buf_staging_1 = T .Buffer ((32 ,), data = red_buf_staging , scope = "shared" )
493496 if threadIdx_x % 32 == 0 :
494497 red_buf_staging_1 [threadIdx_x // 32 ] = red_buf0_2 [0 ]
495498 T .tvm_storage_sync ("shared" )
499+ red_buf0_3 = T .Buffer ((1 ,), data = red_buf0 , scope = "local" )
496500 if threadIdx_x < 32 :
497501 red_buf0_3 [0 ] = red_buf_staging_1 [threadIdx_x ]
498502 mask_3 = T .Buffer ((1 ,), "uint32" , data = mask , scope = "local" )
@@ -508,10 +512,12 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
508512 red_buf0_3 [0 ] = red_buf0_3 [0 ] + t0_3 [0 ]
509513 t0_3 [0 ] = T .tvm_warp_shuffle_down (mask_3 [0 ], red_buf0_3 [0 ], 1 , 32 , 32 )
510514 red_buf0_3 [0 ] = red_buf0_3 [0 ] + t0_3 [0 ]
511- red_buf0_3 [0 ] = T .tvm_warp_shuffle (mask_3 [0 ], red_buf0_3 [0 ], 0 , 32 , 32 )
515+ if threadIdx_x == 0 :
516+ red_result_1 [0 ] = red_buf0_3 [0 ]
517+ T .tvm_storage_sync ("shared" )
512518 if threadIdx_x == 0 :
513519 B_1 = T .Buffer ((1 ,), data = B .data )
514- B_1 [0 ] = red_buf0_3 [0 ]
520+ B_1 [0 ] = red_result_1 [0 ]
515521
516522
517523class TestMultiGroupMultiWarpReduction (BaseCompare ):
@@ -543,14 +549,15 @@ def before(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
543549 def expected (A : T .Buffer ((4 , 128 ), "float32" ), B : T .Buffer ((4 ,), "float32" )):
544550 T .func_attr ({"target" : T .target ("cuda" , host = "llvm" )})
545551 threadIdx_y = T .launch_thread ("threadIdx.y" , 4 )
546- red_buf0 = T .allocate ([1 ], "float32" , "local " )
552+ red_result = T .allocate ([4 ], "float32" , "shared " )
547553 threadIdx_x = T .launch_thread ("threadIdx.x" , 128 )
548- red_buf0_3 = T .Buffer ((1 ,), data = red_buf0 , scope = "local " )
554+ red_result_1 = T .Buffer ((4 ,), data = red_result , scope = "shared " )
549555 with T .attr (
550556 T .comm_reducer (lambda x0 , y0 : x0 + y0 , [T .float32 (0 )]),
551557 "reduce_scope" ,
552558 T .reinterpret ("handle" , T .uint64 (0 )),
553559 ):
560+ red_buf0 = T .allocate ([1 ], "float32" , "local" )
554561 mask = T .allocate ([1 ], "uint32" , "local" )
555562 t0 = T .allocate ([1 ], "float32" , "local" )
556563 red_buf0_1 = T .allocate ([1 ], "float32" , "local" )
@@ -573,11 +580,11 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
573580 red_buf0_2 [0 ] = red_buf0_2 [0 ] + t0_2 [0 ]
574581 t0_2 [0 ] = T .tvm_warp_shuffle_down (mask_2 [0 ], red_buf0_2 [0 ], 1 , 32 , 32 )
575582 red_buf0_2 [0 ] = red_buf0_2 [0 ] + t0_2 [0 ]
576- red_buf0_2 [0 ] = T .tvm_warp_shuffle (mask_2 [0 ], red_buf0_2 [0 ], 32 * threadIdx_y , 32 , 32 )
577583 red_buf_staging_1 = T .Buffer ((16 ,), data = red_buf_staging , scope = "shared" )
578584 if threadIdx_x % 32 == 0 :
579585 red_buf_staging_1 [threadIdx_y * 4 + threadIdx_x // 32 ] = red_buf0_2 [0 ]
580586 T .tvm_storage_sync ("shared" )
587+ red_buf0_3 = T .Buffer ((1 ,), data = red_buf0 , scope = "local" )
581588 if threadIdx_x < 16 :
582589 red_buf0_3 [0 ] = red_buf_staging_1 [threadIdx_x ]
583590 mask_3 = T .Buffer ((1 ,), "uint32" , data = mask , scope = "local" )
@@ -589,10 +596,12 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
589596 red_buf0_3 [0 ] = red_buf0_3 [0 ] + t0_3 [0 ]
590597 t0_3 [0 ] = T .tvm_warp_shuffle_down (mask_3 [0 ], red_buf0_3 [0 ], 1 , 32 , 32 )
591598 red_buf0_3 [0 ] = red_buf0_3 [0 ] + t0_3 [0 ]
592- red_buf0_3 [0 ] = T .tvm_warp_shuffle (mask_3 [0 ], red_buf0_3 [0 ], 4 * threadIdx_y , 32 , 32 )
599+ if threadIdx_x == 0 :
600+ red_result_1 [0 ] = red_buf0_3 [0 ]
601+ T .tvm_storage_sync ("shared" )
593602 if threadIdx_x == 0 :
594603 B_1 = T .Buffer ((4 ,), data = B .data )
595- B_1 [threadIdx_y ] = red_buf0_3 [0 ]
604+ B_1 [threadIdx_y ] = red_result_1 [0 ]
596605
597606
598607class TestMultiGroupMultiWarpPredicatedReduction (BaseCompare ):
@@ -626,19 +635,20 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
626635 T .func_attr ({"target" : T .target ("cuda" , host = "llvm" )})
627636 threadIdx_y = T .launch_thread ("threadIdx.y" , 2 )
628637 in_thread_B = T .allocate ([1 ], "float32" , "local" )
629- red_buf0 = T .allocate ([1 ], "float32" , "local " )
638+ red_result = T .allocate ([2 ], "float32" , "shared " )
630639 threadIdx_x = T .launch_thread ("threadIdx.x" , 512 )
631640 in_thread_B_1 = T .Buffer ((1 ,), data = in_thread_B , scope = "local" )
632641 in_thread_B_1 [0 ] = T .float32 (0 )
633642 if threadIdx_x < 70 :
634643 A_1 = T .Buffer ((140 ,), data = A .data )
635644 in_thread_B_1 [0 ] = in_thread_B_1 [0 ] + A_1 [threadIdx_y * 70 + threadIdx_x ]
636- red_buf0_3 = T .Buffer ((1 ,), data = red_buf0 , scope = "local " )
645+ red_result_1 = T .Buffer ((2 ,), data = red_result , scope = "shared " )
637646 with T .attr (
638647 T .comm_reducer (lambda x0 , y0 : x0 + y0 , [T .float32 (0 )]),
639648 "reduce_scope" ,
640649 T .reinterpret ("handle" , T .uint64 (0 )),
641650 ):
651+ red_buf0 = T .allocate ([1 ], "float32" , "local" )
642652 mask = T .allocate ([1 ], "uint32" , "local" )
643653 t0 = T .allocate ([1 ], "float32" , "local" )
644654 red_buf0_1 = T .allocate ([1 ], "float32" , "local" )
@@ -660,11 +670,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
660670 red_buf0_2 [0 ] = red_buf0_2 [0 ] + t0_2 [0 ]
661671 t0_2 [0 ] = T .tvm_warp_shuffle_down (mask_2 [0 ], red_buf0_2 [0 ], 1 , 32 , 32 )
662672 red_buf0_2 [0 ] = red_buf0_2 [0 ] + t0_2 [0 ]
663- red_buf0_2 [0 ] = T .tvm_warp_shuffle (mask_2 [0 ], red_buf0_2 [0 ], 32 * threadIdx_y , 32 , 32 )
664673 red_buf_staging_1 = T .Buffer ((32 ,), data = red_buf_staging , scope = "shared" )
665674 if threadIdx_x % 32 == 0 :
666675 red_buf_staging_1 [threadIdx_y * 16 + threadIdx_x // 32 ] = red_buf0_2 [0 ]
667676 T .tvm_storage_sync ("shared" )
677+ red_buf0_3 = T .Buffer ((1 ,), data = red_buf0 , scope = "local" )
668678 if threadIdx_x < 32 :
669679 red_buf0_3 [0 ] = red_buf_staging_1 [threadIdx_x ]
670680 mask_3 = T .Buffer ((1 ,), "uint32" , data = mask , scope = "local" )
@@ -680,10 +690,12 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
680690 red_buf0_3 [0 ] = red_buf0_3 [0 ] + t0_3 [0 ]
681691 t0_3 [0 ] = T .tvm_warp_shuffle_down (mask_3 [0 ], red_buf0_3 [0 ], 1 , 32 , 32 )
682692 red_buf0_3 [0 ] = red_buf0_3 [0 ] + t0_3 [0 ]
683- red_buf0_3 [0 ] = T .tvm_warp_shuffle (mask_3 [0 ], red_buf0_3 [0 ], 16 * threadIdx_y , 32 , 32 )
693+ if threadIdx_x == 0 :
694+ red_result_1 [0 ] = red_buf0_3 [0 ]
695+ T .tvm_storage_sync ("shared" )
684696 if threadIdx_x == 0 :
685697 B_1 = T .Buffer ((2 ,), data = B .data )
686- B_1 [threadIdx_y ] = red_buf0_3 [0 ]
698+ B_1 [threadIdx_y ] = red_result_1 [0 ]
687699
688700
689701if __name__ == "__main__" :
0 commit comments