@@ -357,5 +357,211 @@ TVM_DLL int GetCudaDeviceCount() {
 
 TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount);
 
+/*!
+ * \brief FFI wrapper for cuTensorMapEncodeTiled.
+ *
+ * This function registers a global function `runtime.cuTensorMapEncodeTiled` that can be
+ * called from other parts of the TVM runtime (e.g., Python). It wraps the CUDA Driver API
+ * function `cuTensorMapEncodeTiled`, which initializes a tensor map descriptor (CUtensorMap).
+ *
+ * \param tensor_map (handle): A `void*` pointer to the CUtensorMap object to be initialized.
+ * \param tensor_dtype (DataType): The TVM data type of the tensor.
+ * \param tensor_rank (int): The rank (number of dimensions) of the tensor.
+ * \param tensor_ptr (handle): A `void*` pointer to the start of the tensor in global memory.
+ * \param global_shape (int...): `tensor_rank` integer arguments for the global tensor dimensions.
+ * \param global_strides (int...): `tensor_rank - 1` integer arguments for the global tensor
+ *        strides, in bytes. The stride of the innermost dimension is not provided, as it is
+ *        assumed to be contiguous.
+ * \param shared_shape (int...): `tensor_rank` integer arguments for the shape of the tile (box)
+ *        in shared memory.
+ * \param shared_strides (int...): `tensor_rank` integer arguments for the strides of the tile
+ *        (box) in shared memory.
+ * \param interleaved_kind (int): An integer corresponding to the CUtensorMapInterleave enum.
+ * \param swizzle_kind (int): An integer corresponding to the CUtensorMapSwizzle enum.
+ * \param l2_promotion_kind (int): An integer corresponding to the CUtensorMapL2promotion enum.
+ * \param oob_fill_kind (int): An integer corresponding to the CUtensorMapFloatOOBfill enum.
+ */
+TVM_FFI_REGISTER_GLOBAL("runtime.cuTensorMapEncodeTiled")
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
+      CHECK_GE(args.size(), 4) << "cuTensorMapEncodeTiled expects at least 4 arguments";
+      size_t arg_cnt = 0;
+      CUtensorMap* tensor_map = static_cast<CUtensorMap*>(args[arg_cnt++].cast<void*>());
+      runtime::DataType tensor_dtype = args[arg_cnt++].cast<runtime::DataType>();
+      uint32_t tensor_rank = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+      void* tensor_ptr = args[arg_cnt++].cast<void*>();
+
+      CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3)
+          << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments: "
+          << "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank
+          << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank
+          << "), shared_strides(" << tensor_rank << "), interleaved_kind, swizzle_kind"
+          << ", l2_promotion_kind, oob_fill_kind";
+
+      std::vector<cuuint64_t> global_shape(tensor_rank);
+      std::vector<cuuint64_t> global_strides(tensor_rank);
+      std::vector<uint32_t> shared_shape(tensor_rank);
+      std::vector<uint32_t> shared_strides(tensor_rank);
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        global_shape[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+      }
+      for (size_t i = 0; i < tensor_rank - 1; ++i) {
+        global_strides[i] = static_cast<cuuint64_t>(args[arg_cnt++].cast<int64_t>());
+        CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16";
+      }
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        shared_shape[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+        CHECK_GT(shared_shape[i], 0) << "boxDim must be non-zero";
+        CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256";
+      }
+      for (size_t i = 0; i < tensor_rank; ++i) {
+        shared_strides[i] = static_cast<uint32_t>(args[arg_cnt++].cast<int32_t>());
+      }
+      auto interleaved_kind = static_cast<CUtensorMapInterleave>(args[arg_cnt++].cast<int>());
+      auto swizzle_kind = static_cast<CUtensorMapSwizzle>(args[arg_cnt++].cast<int>());
+      auto l2_promotion_kind = static_cast<CUtensorMapL2promotion>(args[arg_cnt++].cast<int>());
+      auto oob_fill_kind = static_cast<CUtensorMapFloatOOBfill>(args[arg_cnt++].cast<int>());
+
+      ICHECK_EQ(tensor_dtype.lanes(), 1)
+          << "Expect tensor_dtype to have lanes=1, but got " << tensor_dtype;
+      CUtensorMapDataType cu_dtype;
+      switch (tensor_dtype.code()) {
+        case DataType::kInt:
+          // int
+          switch (tensor_dtype.bits()) {
+            case 8:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kUInt:
+          // unsigned int
+          switch (tensor_dtype.bits()) {
+            case 8:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+              break;
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kFloat:
+          // float
+          switch (tensor_dtype.bits()) {
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
+              break;
+            case 32:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
+              break;
+            case 64:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kBFloat:
+          // bfloat
+          switch (tensor_dtype.bits()) {
+            case 16:
+              cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+              break;
+            default:
+              LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+          }
+          break;
+        case DataType::kFloat8_e4m3fn:
+          // NV float8 e4m3
+          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+          break;
+        case DataType::kFloat8_e5m2:
+          // NV float8 e5m2
+          cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+          break;
+        default:
+          LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype);
+      }
+
+      // sanity checks per cuTensorMapEncodeTiled requirements
+      // see
+      // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
+      CHECK_EQ((reinterpret_cast<uint64_t>(tensor_ptr) & 0b1111), 0);    // 16-byte alignment
+      CHECK_EQ((reinterpret_cast<uint64_t>(tensor_map) & 0b111111), 0);  // 64-byte alignment
+      CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors";
+
+      if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32)
+            << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32.";
+      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64)
+            << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64.";
+      } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) {
+        CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128)
+            << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= "
+               "128.";
+      }
+
+      const cuuint64_t* global_shape_ptr = global_shape.data();
+      const cuuint64_t* global_strides_ptr = global_strides.data();
+      const uint32_t* shared_shape_ptr = shared_shape.data();
+      const uint32_t* shared_strides_ptr = shared_strides.data();
+
+      CUresult res =
+          cuTensorMapEncodeTiled(tensor_map, cu_dtype, tensor_rank, tensor_ptr, global_shape_ptr,
+                                 global_strides_ptr, shared_shape_ptr, shared_strides_ptr,
+                                 interleaved_kind, swizzle_kind, l2_promotion_kind, oob_fill_kind);
+      if (res != CUDA_SUCCESS) {
+        // get error string
+        const char* error_string = nullptr;
+        cuGetErrorString(res, &error_string);
+        std::cerr << "Error in cuTensorMapEncodeTiled: " << error_string << std::endl;
+        std::cout << "cu_dtype: " << cu_dtype << "\n";
+        std::cout << "TMA Desc Addr: " << tensor_map << "\n";
+        std::cout << "TMA Interleave: " << interleaved_kind << "\n";
+        std::cout << "TMA L2Promotion: " << l2_promotion_kind << "\n";
+        std::cout << "TMA OOBFill: " << oob_fill_kind << "\n";
+        std::cout << "SMEM Swizzle: " << swizzle_kind << "\n";
+        std::cout << "tensor rank: " << tensor_rank << "\n";
+        std::cout << "global prob shape: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << global_shape[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "global prob stride: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << global_strides[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "smem box shape: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << shared_shape[i] << " ";
+        }
+        std::cout << "\n";
+        std::cout << "smem box stride: ";
+        for (size_t i = 0; i < tensor_rank; i++) {
+          std::cout << shared_strides[i] << " ";
+        }
+        std::cout << "\n";
+        CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << error_string;
+      }
+    });
+
 }  // namespace runtime
 }  // namespace tvm
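
For reference, below is a minimal host-side sketch of how the registered packed function might be invoked for a 2D row-major float32 tensor. It is illustrative only: it assumes `tvm::ffi::Function::GetGlobal` is the C++ lookup entry point for globals registered via `TVM_FFI_REGISTER_GLOBAL`, that the shown header paths match the TVM version in use, that CUDA 12+ headers (which define CUtensorMap) are available, and that a CUDA context is already current; the shapes, tile sizes, and enum values are placeholders, not values taken from this commit.

#include <cuda.h>
#include <cuda_runtime.h>
#include <tvm/ffi/function.h>
#include <tvm/runtime/data_type.h>

void EncodeTiledExample() {
  // Assumption: a CUDA context is current (e.g., after cudaSetDevice(0) and a prior runtime call).
  void* dev_ptr = nullptr;
  cudaMalloc(&dev_ptr, 256 * 256 * sizeof(float));  // 256x256 row-major float32 tensor

  // CUtensorMap is an opaque, 64-byte-aligned descriptor; it is passed by address as a handle.
  CUtensorMap tensor_map;

  // Assumption: Function::GetGlobal returns an optional handle to the registered global.
  auto f = tvm::ffi::Function::GetGlobal("runtime.cuTensorMapEncodeTiled");

  // Argument layout for rank = 2, i.e. 4 + 4 * rank + 3 = 15 arguments. Dimension 0 is the
  // innermost (fastest-varying) dimension, following the cuTensorMapEncodeTiled convention.
  (*f)(static_cast<void*>(&tensor_map),    // tensor_map
       tvm::runtime::DataType::Float(32),  // tensor_dtype
       2,                                  // tensor_rank
       dev_ptr,                            // tensor_ptr
       256, 256,                           // global_shape
       256 * sizeof(float),                // global_strides (bytes, rank - 1 entries)
       64, 64,                             // shared_shape (box dimensions)
       1, 1,                               // shared_strides (element strides)
       0,                                  // interleaved_kind: CU_TENSOR_MAP_INTERLEAVE_NONE
       0,                                  // swizzle_kind: CU_TENSOR_MAP_SWIZZLE_NONE
       0,                                  // l2_promotion_kind: CU_TENSOR_MAP_L2_PROMOTION_NONE
       0);                                 // oob_fill_kind: CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
}

With no swizzle the bounding-box checks above do not constrain the inner tile size, and the single global stride (256 * 4 = 1024 bytes) satisfies the multiple-of-16 requirement enforced by the wrapper.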