Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,

// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
// driver call, amortizing per-copy submission overhead.
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
// so we reinterpret_cast the tensor data directly to avoid copies.
// int64_t, CUdeviceptr, void*, and size_t are all 8 bytes on 64-bit
// platforms, so we reinterpret_cast the tensor data directly to avoid
// copies.
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
static_assert(sizeof(void*) == sizeof(int64_t));
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
// Resolve cuMemcpyBatchAsync at runtime via cuGetProcAddress so that
Expand Down Expand Up @@ -134,18 +136,26 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
&fail_idx, static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
} else
return;
}
#elif defined(USE_ROCM) && defined(HIP_VERSION) && HIP_VERSION >= 70100000
// ROCm 7.1 does not yet honor hipMemcpyAttributes, so pass nullptr for the
// attribute and attribute-index arrays and 0 for the attribute count.
size_t fail_idx = 0;
hipError_t result = hipMemcpyBatchAsync(
reinterpret_cast<void**>(dst_data), reinterpret_cast<void**>(src_data),
reinterpret_cast<size_t*>(size_data), static_cast<size_t>(n), nullptr,
nullptr, 0, &fail_idx, stream);
Comment thread
Etelis marked this conversation as resolved.
TORCH_CHECK(result == hipSuccess, "hipMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
return;
#endif
{
// Fallback for CUDA < 12.8, older drivers, and ROCm:
// individual async copies.
// cudaMemcpyDefault lets the driver infer direction from pointer types.
for (int64_t i = 0; i < n; i++) {
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
reinterpret_cast<void*>(src_data[i]),
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
stream);
}
// Fallback for CUDA < 12.8, older drivers, and ROCm < 7.1.
// cudaMemcpyDefault lets the driver infer direction from pointer types.
for (int64_t i = 0; i < n; i++) {
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
reinterpret_cast<void*>(src_data[i]),
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
stream);
}
}

Expand Down
Loading