diff --git a/src/tl_templates/cuda/copy_sm90.h b/src/tl_templates/cuda/copy_sm90.h index 4632161af..005cd64d7 100644 --- a/src/tl_templates/cuda/copy_sm90.h +++ b/src/tl_templates/cuda/copy_sm90.h @@ -10,6 +10,28 @@ namespace tl { +TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar, + uint32_t size) { + uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::" + "bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar) + :); +} + +TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr, + uint64_t &smem_mbar, uint32_t size, + uint16_t mask) { + uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes." + "multicast::cluster [%0], [%1], %2, [%3], %4; \n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "r"(size), "r"(smem_int_mbar), "h"(mask) + :); +} + TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar, void const *const smem_ptr, int32_t const &crd0) { uint64_t gmem_int_desc = reinterpret_cast(&descriptor); @@ -105,6 +127,15 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor, : "memory"); } +TL_DEVICE void tma_store(void *dst_gmem_ptr, void *smem_ptr, uint32_t size) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.global.shared::cta.bulk_group [%1], [%0], %2; \n" ::"r"( + smem_int_ptr), + "l"(dst_gmem_ptr), "r"(size) + :); +} + TL_DEVICE void tma_store(const CUtensorMap &descriptor, void const *const smem_ptr, int32_t const &crd0) { uint64_t gmem_int_desc = reinterpret_cast(&descriptor);