diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7e456d32598b..1fcc89cb5b26 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -99,9 +99,11 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs, // Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single // driver call, amortizing per-copy submission overhead. - // int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms, - // so we reinterpret_cast the tensor data directly to avoid copies. + // int64_t, CUdeviceptr, void*, and size_t are all 8 bytes on 64-bit + // platforms, so we reinterpret_cast the tensor data directly to avoid + // copies. static_assert(sizeof(CUdeviceptr) == sizeof(int64_t)); + static_assert(sizeof(void*) == sizeof(int64_t)); static_assert(sizeof(size_t) == sizeof(int64_t)); #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080 // Resolve cuMemcpyBatchAsync at runtime via cuGetProcAddress so that @@ -134,18 +136,26 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs, &fail_idx, static_cast(stream)); TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ", fail_idx, " with error ", result); - } else + return; + } +#elif defined(USE_ROCM) && defined(HIP_VERSION) && HIP_VERSION >= 70100000 + // ROCm 7.1 does not yet honor hipMemcpyAttributes; pass nullptr. + size_t fail_idx = 0; + hipError_t result = hipMemcpyBatchAsync( + reinterpret_cast(dst_data), reinterpret_cast(src_data), + reinterpret_cast(size_data), static_cast(n), nullptr, + nullptr, 0, &fail_idx, stream); + TORCH_CHECK(result == hipSuccess, "hipMemcpyBatchAsync failed at index ", + fail_idx, " with error ", result); + return; #endif - { - // Fallback for CUDA < 12.8, older drivers, and ROCm: - // individual async copies. - // cudaMemcpyDefault lets the driver infer direction from pointer types. - for (int64_t i = 0; i < n; i++) { - cudaMemcpyAsync(reinterpret_cast(dst_data[i]), - reinterpret_cast(src_data[i]), - static_cast(size_data[i]), cudaMemcpyDefault, - stream); - } + // Fallback for CUDA < 12.8, older drivers, and ROCm < 7.1. + // cudaMemcpyDefault lets the driver infer direction from pointer types. + for (int64_t i = 0; i < n; i++) { + cudaMemcpyAsync(reinterpret_cast(dst_data[i]), + reinterpret_cast(src_data[i]), + static_cast(size_data[i]), cudaMemcpyDefault, + stream); } }