Skip to content

Commit d26fdcf

Browse files
authored
[Hopper TMA] Add CUDA codegen support for bulk asynchronous copy (#15656)
* [Hopper TMA] Add CUDA codegen support for bulk asynchronous copy * fix typo in comments; use barrier ptr and offset rather than string
1 parent 04ee895 commit d26fdcf

File tree

11 files changed

+421
-75
lines changed

11 files changed

+421
-75
lines changed

include/tvm/tir/builtin.h

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -645,14 +645,29 @@ TVM_DLL const Op& ptx_mma_sp();
645645
TVM_DLL const Op& ptx_ldmatrix();
646646

647647
/*!
648-
* \brief tvm intrinsics for ptx async copy from global to shared memory
649-
*
650-
* void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
651-
* bytes);
648+
* \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async
652649
*
650+
* void ptx_cp_async(Var shared_ptr,
651+
* Expr shared_offset,
652+
* Var global_ptr,
653+
* Expr global_offset,
654+
* size_t bytes);
653655
*/
654656
TVM_DLL const Op& ptx_cp_async();
655657

658+
/*!
659+
* \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async.bulk
660+
*
661+
* void ptx_cp_async(Var shared_ptr,
662+
* Expr shared_offset,
663+
* Var global_ptr,
664+
* Expr global_offset,
665+
* size_t bytes,
666+
* Var barrier_ptr,
667+
* Expr barrier_offset);
668+
*/
669+
TVM_DLL const Op& ptx_cp_async_bulk();
670+
656671
/*!
657672
* \brief tvm intrinsics for ptx async copy commit and wait.
658673
*
@@ -666,31 +681,39 @@ TVM_DLL const Op& ptx_wait_group();
666681
/*!
667682
* \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
668683
*
669-
* ptx_cp_async_barrier(barrier_array, barrier_id)
684+
* ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset)
670685
*
671686
*/
672687
TVM_DLL const Op& ptx_cp_async_barrier();
673688

674689
/*!
675690
* \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
676691
*
677-
* ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
692+
* ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int thread_count)
678693
*
679694
*/
680695
TVM_DLL const Op& ptx_init_barrier_thread_count();
681696

682697
/*!
683698
* \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
684699
*
685-
* ptx_arrive_barrier(barrier_array, barrier_id)
700+
* ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset)
686701
*
687702
*/
688703
TVM_DLL const Op& ptx_arrive_barrier();
689704

705+
/*!
706+
* \brief tvm intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
707+
*
708+
* ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int byte_count)
709+
*
710+
*/
711+
TVM_DLL const Op& ptx_arrive_barrier_expect_tx();
712+
690713
/*!
691714
* \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
692715
*
693-
* ptx_wait_barrier(barrier_array, barrier_id)
716+
* ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset)
694717
*
695718
*/
696719
TVM_DLL const Op& ptx_wait_barrier();

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,6 +1847,7 @@ def wrapped(*args, **kwargs):
18471847
ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
18481848
ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
18491849
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
1850+
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
18501851
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
18511852
assume = _op_wrapper(_tir_op.assume)
18521853
undef = _op_wrapper(_tir_op.undef)
@@ -1876,6 +1877,7 @@ def wrapped(*args, **kwargs):
18761877
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
18771878
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
18781879
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
1880+
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
18791881
mma_store = _dtype_forward(_tir_op.mma_store)
18801882
mma_fill = _dtype_forward(_tir_op.mma_fill)
18811883
vectorlow = _dtype_forward(_tir_op.vectorlow)
@@ -2115,11 +2117,13 @@ def wrapped(*args, **kwargs):
21152117
"ptx_mma_sp",
21162118
"ptx_ldmatrix",
21172119
"ptx_cp_async",
2120+
"ptx_cp_async_bulk",
21182121
"ptx_wait_group",
21192122
"ptx_commit_group",
21202123
"ptx_cp_async_barrier",
21212124
"ptx_init_barrier_thread_count",
21222125
"ptx_arrive_barrier",
2126+
"ptx_arrive_barrier_expect_tx",
21232127
"ptx_wait_barrier",
21242128
"mma_store",
21252129
"mma_fill",

python/tvm/tir/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,13 @@
6363
from .op import (
6464
ptx_ldmatrix,
6565
ptx_cp_async,
66+
ptx_cp_async_bulk,
6667
ptx_commit_group,
6768
ptx_wait_group,
6869
ptx_cp_async_barrier,
6970
ptx_init_barrier_thread_count,
7071
ptx_arrive_barrier,
72+
ptx_arrive_barrier_expect_tx,
7173
ptx_wait_barrier,
7274
)
7375
from .op import vectorlow, vectorhigh, vectorcombine

python/tvm/tir/op.py

Lines changed: 104 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,7 +1335,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme
13351335

13361336

13371337
def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
1338-
"""TVM intrinsic for ptx async copy from global to shared memory
1338+
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async
13391339
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
13401340
13411341
Parameters
@@ -1368,6 +1368,56 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by
13681368
)
13691369

13701370

1371+
def ptx_cp_async_bulk(
1372+
dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_ptr, barrier_offset
1373+
):
1374+
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
1375+
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
1376+
1377+
Parameters
1378+
----------
1379+
dtype : str
1380+
The data type of the result.
1381+
1382+
shared_ptr : Var
1383+
The shared memory pointer variable.
1384+
1385+
shared_offset : Expr
1386+
The offset of shared memory pointer.
1387+
1388+
global_ptr : Var
1389+
The global memory pointer variable.
1390+
1391+
global_offset : Expr
1392+
The offset of global memory pointer.
1393+
1394+
bytes : int
1395+
The data size to copy.
1396+
1397+
barrier_ptr : Var
1398+
The barrier shared memory pointer variable.
1399+
1400+
barrier_id : int
1401+
The offset of the barrier shared memory pointer.
1402+
1403+
Returns
1404+
-------
1405+
call : PrimExpr
1406+
The call expression.
1407+
"""
1408+
return call_intrin(
1409+
dtype,
1410+
"tir.ptx_cp_async_bulk",
1411+
shared_ptr,
1412+
shared_offset,
1413+
global_ptr,
1414+
global_offset,
1415+
bytes,
1416+
barrier_ptr,
1417+
barrier_offset,
1418+
)
1419+
1420+
13711421
def ptx_commit_group():
13721422
"""TVM intrinsic for ptx async copy commit
13731423
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
@@ -1397,84 +1447,116 @@ def ptx_wait_group(num):
13971447
return call_intrin("", "tir.ptx_wait_group", num)
13981448

13991449

1400-
def ptx_cp_async_barrier(barrier_arr, barrier_id):
1450+
def ptx_cp_async_barrier(barrier_ptr, barrier_offset):
14011451
"""TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
14021452
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
14031453
14041454
Parameters
14051455
----------
1406-
barrier_arr : string
1407-
The name of the barrier array in shared memory
1456+
barrier_ptr : Var
1457+
The barrier shared memory pointer variable.
1458+
14081459
barrier_id : int
1409-
Index into the barrier array
1460+
The offset of the barrier shared memory pointer.
14101461
14111462
Returns
14121463
-------
14131464
call : PrimExpr
14141465
The call expression.
14151466
"""
1416-
return call_intrin("", "tir.ptx_cp_async_barrier", barrier_arr, barrier_id)
1467+
return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr, barrier_offset)
14171468

14181469

1419-
def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
1470+
def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count):
14201471
"""TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
14211472
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
14221473
14231474
Parameters
14241475
----------
1425-
barrier_arr : string
1426-
The name of the barrier array in shared memory
1476+
barrier_ptr : Var
1477+
The barrier shared memory pointer variable.
1478+
14271479
barrier_id : int
1428-
Index into the barrier array
1480+
The offset of the barrier shared memory pointer.
1481+
14291482
thread_count : int
1430-
Number of threads expected to arrive at the barrier
1483+
Number of threads expected to arrive at the barrier.
14311484
14321485
Returns
14331486
-------
14341487
call : PrimExpr
14351488
The call expression.
14361489
"""
14371490
return call_intrin(
1438-
"", "tir.ptx_init_barrier_thread_count", barrier_arr, barrier_id, thread_count
1491+
"", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset, thread_count
14391492
)
14401493

14411494

1442-
def ptx_arrive_barrier(barrier_arr, barrier_id):
1495+
def ptx_arrive_barrier(barrier_ptr, barrier_offset):
14431496
"""TVM intrinsic for ptx barrier arrival using mbarrier.arrive
14441497
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
14451498
14461499
Parameters
14471500
----------
1448-
barrier_arr : string
1449-
The name of the barrier array in shared memory
1501+
barrier_ptr : Var
1502+
The barrier shared memory pointer variable.
1503+
1504+
barrier_id : int
1505+
The offset of the barrier shared memory pointer.
1506+
1507+
Returns
1508+
-------
1509+
call : PrimExpr
1510+
The call expression.
1511+
"""
1512+
return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr, barrier_offset)
1513+
1514+
1515+
def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count):
1516+
"""TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
1517+
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
1518+
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
1519+
1520+
Parameters
1521+
----------
1522+
barrier_ptr : Var
1523+
The barrier shared memory pointer variable.
1524+
14501525
barrier_id : int
1451-
Index into the barrier array
1526+
The offset of the barrier shared memory pointer.
1527+
1528+
byte_count : int
1529+
Increases the tx count of the mbarrier object to track completion of
1530+
addtional async transactions.
14521531
14531532
Returns
14541533
-------
14551534
call : PrimExpr
14561535
The call expression.
14571536
"""
1458-
return call_intrin("", "tir.ptx_arrive_barrier", barrier_arr, barrier_id)
1537+
return call_intrin(
1538+
"", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset, byte_count
1539+
)
14591540

14601541

1461-
def ptx_wait_barrier(barrier_arr, barrier_id):
1542+
def ptx_wait_barrier(barrier_ptr, barrier_offset):
14621543
"""TVM intrinsic for ptx barrier wait using mbarrier.try_wait
14631544
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
14641545
14651546
Parameters
14661547
----------
1467-
barrier_arr : string
1468-
The name of the barrier array in shared memory
1548+
barrier_ptr : Var
1549+
The barrier shared memory pointer variable.
1550+
14691551
barrier_id : int
1470-
Index into the barrier array
1552+
The offset of the barrier shared memory pointer.
14711553
14721554
Returns
14731555
-------
14741556
call : PrimExpr
14751557
The call expression.
14761558
"""
1477-
return call_intrin("", "tir.ptx_wait_barrier", barrier_arr, barrier_id)
1559+
return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset)
14781560

14791561

14801562
def vectorlow(dtype, vec):

0 commit comments

Comments
 (0)