diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index bb0003b697ca..20ffbc1df450 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -217,9 +217,36 @@ class CUDAWrappedFunc { } } CUstream strm = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); - CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), - wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), - wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); + CUresult result; + + if (launch_param_config_.use_programtic_dependent_launch()) { + CUlaunchConfig config{}; + CUlaunchAttribute attribute[1]{}; + attribute[0].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; + attribute[0].value.programmaticStreamSerializationAllowed = 1; + + config.attrs = attribute; + config.numAttrs = 1; + config.hStream = strm; + config.gridDimX = wl.grid_dim(0); + config.gridDimY = wl.grid_dim(1); + config.gridDimZ = wl.grid_dim(2); + config.blockDimX = wl.block_dim(0); + config.blockDimY = wl.block_dim(1); + config.blockDimZ = wl.block_dim(2); + config.sharedMemBytes = wl.dyn_shmem_size; + + result = cuLaunchKernelEx(&config, fcache_[device_id], void_args, nullptr); + } else if (launch_param_config_.use_cooperative_launch()) { + result = cuLaunchCooperativeKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), + wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), + wl.block_dim(2), wl.dyn_shmem_size, strm, void_args); + } else { + result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), + wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, + strm, void_args, nullptr); + } + if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { const char* msg; cuGetErrorName(result, &msg); @@ -257,6 +284,8 @@ class CUDAWrappedFunc { // Cached last dynamic shared memory size per device and whether it's initialized mutable std::array dyn_smem_last_; mutable std::array dyn_smem_initialized_; + // have pdl setting + bool has_programmatic_dependent_launch_; }; class CUDAPrepGlobalBarrier { diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 85b83289f4d3..aceb97b58374 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -48,6 +48,10 @@ namespace launch_param { /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; +/*! \brief A tag to specify whether or not use programatic dependent launch */ +constexpr const char* kUseProgramaticDependentLaunch = "tir.use_programtic_dependent_launch"; +/*! \brief A tag to specify whether or not use cooperative launch */ +constexpr const char* kUseCooperativeLaunch = "tir.use_cooperative_launch"; } // namespace launch_param diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 914fe67819de..c2cd792220f5 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -247,6 +247,10 @@ class LaunchParamConfig { ICHECK_EQ(i, launch_param_tags.size() - 1) << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; + } else if (tag == launch_param::kUseProgramaticDependentLaunch) { + use_programmatic_dependent_launch_ = true; + } else if (tag == launch_param::kUseCooperativeLaunch) { + use_cooperative_launch_ = true; } else { ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); @@ -281,6 +285,10 @@ class LaunchParamConfig { // return the work dim size_t work_dim() const { return work_dim_; } + bool use_programtic_dependent_launch() const { return use_programmatic_dependent_launch_; } + + bool use_cooperative_launch() const { return use_cooperative_launch_; } + private: /*! \brief base axis */ size_t base_; @@ -290,6 +298,10 @@ class LaunchParamConfig { std::vector arg_index_map_; /*! \brief Whether or not use dynamic shared memory. */ bool use_dyn_shared_memory_{false}; + /*! \brief Whether or not use programmatic dependent launch. */ + bool use_programmatic_dependent_launch_{false}; + /*! \brief Whether or not use cooperative launch. */ + bool use_cooperative_launch_{false}; }; } // namespace runtime