@@ -1335,7 +1335,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme
13351335
13361336
13371337def ptx_cp_async (dtype , shared_ptr , shared_offset , global_ptr , global_offset , bytes ):
1338- """TVM intrinsic for ptx async copy from global to shared memory
1338+ """TVM intrinsic for ptx async copy from global to shared memory using cp.async
13391339 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
13401340
13411341 Parameters
@@ -1368,6 +1368,56 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by
13681368 )
13691369
13701370
1371+ def ptx_cp_async_bulk (
1372+ dtype , shared_ptr , shared_offset , global_ptr , global_offset , bytes , barrier_ptr , barrier_offset
1373+ ):
1374+ """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
1375+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
1376+
1377+ Parameters
1378+ ----------
1379+ dtype : str
1380+ The data type of the result.
1381+
1382+ shared_ptr : Var
1383+ The shared memory pointer variable.
1384+
1385+ shared_offset : Expr
1386+ The offset of shared memory pointer.
1387+
1388+ global_ptr : Var
1389+ The global memory pointer variable.
1390+
1391+ global_offset : Expr
1392+ The offset of global memory pointer.
1393+
1394+ bytes : int
1395+ The data size to copy.
1396+
1397+ barrier_ptr : Var
1398+ The barrier shared memory pointer variable.
1399+
1400+ barrier_id : int
1401+ The offset of the barrier shared memory pointer.
1402+
1403+ Returns
1404+ -------
1405+ call : PrimExpr
1406+ The call expression.
1407+ """
1408+ return call_intrin (
1409+ dtype ,
1410+ "tir.ptx_cp_async_bulk" ,
1411+ shared_ptr ,
1412+ shared_offset ,
1413+ global_ptr ,
1414+ global_offset ,
1415+ bytes ,
1416+ barrier_ptr ,
1417+ barrier_offset ,
1418+ )
1419+
1420+
13711421def ptx_commit_group ():
13721422 """TVM intrinsic for ptx async copy commit
13731423 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
@@ -1397,84 +1447,116 @@ def ptx_wait_group(num):
13971447 return call_intrin ("" , "tir.ptx_wait_group" , num )
13981448
13991449
1400- def ptx_cp_async_barrier (barrier_arr , barrier_id ):
1450+ def ptx_cp_async_barrier (barrier_ptr , barrier_offset ):
14011451 """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
14021452 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
14031453
14041454 Parameters
14051455 ----------
1406- barrier_arr : string
1407- The name of the barrier array in shared memory
1456+ barrier_ptr : Var
1457+ The barrier shared memory pointer variable.
1458+
14081459 barrier_id : int
1409- Index into the barrier array
1460+ The offset of the barrier shared memory pointer.
14101461
14111462 Returns
14121463 -------
14131464 call : PrimExpr
14141465 The call expression.
14151466 """
1416- return call_intrin ("" , "tir.ptx_cp_async_barrier" , barrier_arr , barrier_id )
1467+ return call_intrin ("" , "tir.ptx_cp_async_barrier" , barrier_ptr , barrier_offset )
14171468
14181469
1419- def ptx_init_barrier_thread_count (barrier_arr , barrier_id , thread_count ):
1470+ def ptx_init_barrier_thread_count (barrier_ptr , barrier_offset , thread_count ):
14201471 """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
14211472 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
14221473
14231474 Parameters
14241475 ----------
1425- barrier_arr : string
1426- The name of the barrier array in shared memory
1476+ barrier_ptr : Var
1477+ The barrier shared memory pointer variable.
1478+
14271479 barrier_id : int
1428- Index into the barrier array
1480+ The offset of the barrier shared memory pointer.
1481+
14291482 thread_count : int
1430- Number of threads expected to arrive at the barrier
1483+ Number of threads expected to arrive at the barrier.
14311484
14321485 Returns
14331486 -------
14341487 call : PrimExpr
14351488 The call expression.
14361489 """
14371490 return call_intrin (
1438- "" , "tir.ptx_init_barrier_thread_count" , barrier_arr , barrier_id , thread_count
1491+ "" , "tir.ptx_init_barrier_thread_count" , barrier_ptr , barrier_offset , thread_count
14391492 )
14401493
14411494
1442- def ptx_arrive_barrier (barrier_arr , barrier_id ):
1495+ def ptx_arrive_barrier (barrier_ptr , barrier_offset ):
14431496 """TVM intrinsic for ptx barrier arrival using mbarrier.arrive
14441497 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
14451498
14461499 Parameters
14471500 ----------
1448- barrier_arr : string
1449- The name of the barrier array in shared memory
1501+ barrier_ptr : Var
1502+ The barrier shared memory pointer variable.
1503+
1504+ barrier_id : int
1505+ The offset of the barrier shared memory pointer.
1506+
1507+ Returns
1508+ -------
1509+ call : PrimExpr
1510+ The call expression.
1511+ """
1512+ return call_intrin ("" , "tir.ptx_arrive_barrier" , barrier_ptr , barrier_offset )
1513+
1514+
1515+ def ptx_arrive_barrier_expect_tx (barrier_ptr , barrier_offset , byte_count ):
1516+ """TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
1517+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
1518+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
1519+
1520+ Parameters
1521+ ----------
1522+ barrier_ptr : Var
1523+ The barrier shared memory pointer variable.
1524+
14501525 barrier_id : int
1451- Index into the barrier array
1526+ The offset of the barrier shared memory pointer.
1527+
1528+ byte_count : int
1529+ Increases the tx count of the mbarrier object to track completion of
1530+ addtional async transactions.
14521531
14531532 Returns
14541533 -------
14551534 call : PrimExpr
14561535 The call expression.
14571536 """
1458- return call_intrin ("" , "tir.ptx_arrive_barrier" , barrier_arr , barrier_id )
1537+ return call_intrin (
1538+ "" , "tir.ptx_arrive_barrier_expect_tx" , barrier_ptr , barrier_offset , byte_count
1539+ )
14591540
14601541
1461- def ptx_wait_barrier (barrier_arr , barrier_id ):
1542+ def ptx_wait_barrier (barrier_ptr , barrier_offset ):
14621543 """TVM intrinsic for ptx barrier wait using mbarrier.try_wait
14631544 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
14641545
14651546 Parameters
14661547 ----------
1467- barrier_arr : string
1468- The name of the barrier array in shared memory
1548+ barrier_ptr : Var
1549+ The barrier shared memory pointer variable.
1550+
14691551 barrier_id : int
1470- Index into the barrier array
1552+ The offset of the barrier shared memory pointer.
14711553
14721554 Returns
14731555 -------
14741556 call : PrimExpr
14751557 The call expression.
14761558 """
1477- return call_intrin ("" , "tir.ptx_wait_barrier" , barrier_arr , barrier_id )
1559+ return call_intrin ("" , "tir.ptx_wait_barrier" , barrier_ptr , barrier_offset )
14781560
14791561
14801562def vectorlow (dtype , vec ):
0 commit comments