Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,14 +645,29 @@ TVM_DLL const Op& ptx_mma_sp();
TVM_DLL const Op& ptx_ldmatrix();

/*!
* \brief tvm intrinsics for ptx async copy from global to shared memory
*
* void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
* bytes);
* \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async
*
* void ptx_cp_async(Var shared_ptr,
* Expr shared_offset,
* Var global_ptr,
* Expr global_offset,
* size_t bytes);
*/
TVM_DLL const Op& ptx_cp_async();

/*!
* \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async.bulk
*
* void ptx_cp_async_bulk(Var shared_ptr,
* Expr shared_offset,
* Var global_ptr,
* Expr global_offset,
* size_t bytes,
* Var barrier_ptr,
* Expr barrier_offset);
*/
TVM_DLL const Op& ptx_cp_async_bulk();

/*!
* \brief tvm intrinsics for ptx async copy commit and wait.
*
Expand All @@ -666,31 +681,39 @@ TVM_DLL const Op& ptx_wait_group();
/*!
* \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
*
* ptx_cp_async_barrier(barrier_array, barrier_id)
* ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset)
*
*/
TVM_DLL const Op& ptx_cp_async_barrier();

/*!
* \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
*
* ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
* ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int thread_count)
*
*/
TVM_DLL const Op& ptx_init_barrier_thread_count();

/*!
* \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
*
* ptx_arrive_barrier(barrier_array, barrier_id)
* ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset)
*
*/
TVM_DLL const Op& ptx_arrive_barrier();

/*!
* \brief tvm intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
*
* ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int byte_count)
*
*/
TVM_DLL const Op& ptx_arrive_barrier_expect_tx();

/*!
* \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
*
* ptx_wait_barrier(barrier_array, barrier_id)
* ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset)
*
*/
TVM_DLL const Op& ptx_wait_barrier();
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,7 @@ def wrapped(*args, **kwargs):
ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
Expand Down Expand Up @@ -1876,6 +1877,7 @@ def wrapped(*args, **kwargs):
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
mma_store = _dtype_forward(_tir_op.mma_store)
mma_fill = _dtype_forward(_tir_op.mma_fill)
vectorlow = _dtype_forward(_tir_op.vectorlow)
Expand Down Expand Up @@ -2115,11 +2117,13 @@ def wrapped(*args, **kwargs):
"ptx_mma_sp",
"ptx_ldmatrix",
"ptx_cp_async",
"ptx_cp_async_bulk",
"ptx_wait_group",
"ptx_commit_group",
"ptx_cp_async_barrier",
"ptx_init_barrier_thread_count",
"ptx_arrive_barrier",
"ptx_arrive_barrier_expect_tx",
"ptx_wait_barrier",
"mma_store",
"mma_fill",
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@
from .op import (
ptx_ldmatrix,
ptx_cp_async,
ptx_cp_async_bulk,
ptx_commit_group,
ptx_wait_group,
ptx_cp_async_barrier,
ptx_init_barrier_thread_count,
ptx_arrive_barrier,
ptx_arrive_barrier_expect_tx,
ptx_wait_barrier,
)
from .op import vectorlow, vectorhigh, vectorcombine
Expand Down
126 changes: 104 additions & 22 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme


def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
"""TVM intrinsic for ptx async copy from global to shared memory
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

Parameters
Expand Down Expand Up @@ -1368,6 +1368,56 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by
)


def ptx_cp_async_bulk(
dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_ptr, barrier_offset
):
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

Parameters
----------
dtype : str
The data type of the result.

shared_ptr : Var
The shared memory pointer variable.

shared_offset : Expr
The offset of shared memory pointer.

global_ptr : Var
The global memory pointer variable.

global_offset : Expr
The offset of global memory pointer.

bytes : int
The data size to copy.

barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_offset : int
The offset of the barrier shared memory pointer.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
dtype,
"tir.ptx_cp_async_bulk",
shared_ptr,
shared_offset,
global_ptr,
global_offset,
bytes,
barrier_ptr,
barrier_offset,
)


def ptx_commit_group():
"""TVM intrinsic for ptx async copy commit
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
Expand Down Expand Up @@ -1397,84 +1447,116 @@ def ptx_wait_group(num):
return call_intrin("", "tir.ptx_wait_group", num)


def ptx_cp_async_barrier(barrier_arr, barrier_id):
def ptx_cp_async_barrier(barrier_ptr, barrier_offset):
"""TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

Parameters
----------
barrier_arr : string
The name of the barrier array in shared memory
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
Index into the barrier array
The offset of the barrier shared memory pointer.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_cp_async_barrier", barrier_arr, barrier_id)
return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr, barrier_offset)


def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count):
"""TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

Parameters
----------
barrier_arr : string
The name of the barrier array in shared memory
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
Index into the barrier array
The offset of the barrier shared memory pointer.

thread_count : int
Number of threads expected to arrive at the barrier
Number of threads expected to arrive at the barrier.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
"", "tir.ptx_init_barrier_thread_count", barrier_arr, barrier_id, thread_count
"", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset, thread_count
)


def ptx_arrive_barrier(barrier_arr, barrier_id):
def ptx_arrive_barrier(barrier_ptr, barrier_offset):
"""TVM intrinsic for ptx barrier arrival using mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

Parameters
----------
barrier_arr : string
The name of the barrier array in shared memory
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_offset : int
The offset of the barrier shared memory pointer.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr, barrier_offset)


def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count):
"""TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation

Parameters
----------
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
Index into the barrier array
The offset of the barrier shared memory pointer.

byte_count : int
Increases the tx count of the mbarrier object to track completion of
additional async transactions.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_arrive_barrier", barrier_arr, barrier_id)
return call_intrin(
"", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset, byte_count
)


def ptx_wait_barrier(barrier_arr, barrier_id):
def ptx_wait_barrier(barrier_ptr, barrier_offset):
"""TVM intrinsic for ptx barrier wait using mbarrier.try_wait
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

Parameters
----------
barrier_arr : string
The name of the barrier array in shared memory
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
Index into the barrier array
The offset of the barrier shared memory pointer.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_wait_barrier", barrier_arr, barrier_id)
return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset)


def vectorlow(dtype, vec):
Expand Down
Loading