From 9d9451767f6a0ed572a16c5e6212db8719ce7b13 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Fri, 21 Mar 2025 15:18:00 -0700 Subject: [PATCH 1/4] Update attention_qk.cu --- onnxruntime/contrib_ops/cuda/bert/attention_qk.cu | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu index 78c407fd3bb3b..1d0bb8d7a58d1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu @@ -29,6 +29,10 @@ Status CopyQK(cudaStream_t stream, const int qk_size, const T* input, QK* output) { + if constexpr (std::is_same<T, QK>::value) { + cudaMemcpyAsync(output, input, qk_size * sizeof(QK), cudaMemcpyDeviceToDevice, stream); + return Status::OK(); + } const bool half2float = std::is_same<T, half>::value && std::is_same<QK, float>::value; const bool float2half = std::is_same<T, float>::value && std::is_same<QK, half>::value; ORT_ENFORCE(half2float || float2half); @@ -40,6 +44,11 @@ Status CopyQK(cudaStream_t stream, return CUDA_CALL(cudaGetLastError()); } +template Status CopyQK(cudaStream_t stream, + const int qk_size, + const float* input, + float* output); + template Status CopyQK(cudaStream_t stream, const int qk_size, const float* input, @@ -50,6 +59,11 @@ template Status CopyQK(cudaStream_t stream, const half* input, float* output); +template Status CopyQK(cudaStream_t stream, + const int qk_size, + const half* input, + half* output); + } // namespace cuda } // namespace contrib } // namespace onnxruntime From dca5c5bcb6d910be29f2a0f4f18e213db3d5a0d8 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Fri, 21 Mar 2025 15:18:53 -0700 Subject: [PATCH 2/4] Update attention_impl.cu --- onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git 
a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 84a7cc19f1576..51311715d3b2a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -762,12 +762,8 @@ Status UnfusedAttention( } else { // no mask if (nullptr != data.output_qk) { int64_t qk_size = (int64_t)batch_size * num_heads * sequence_length * total_sequence_length; - if (std::is_same<T, QK>::value) { - cudaMemcpyAsync(data.output_qk, data.scratch, qk_size * sizeof(QK), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - (CopyQK(stream, static_cast<int>(qk_size), data.scratch, reinterpret_cast<QK*>(data.output_qk)))); - } + ORT_RETURN_IF_ERROR( + (CopyQK(stream, static_cast<int>(qk_size), data.scratch, reinterpret_cast<QK*>(data.output_qk)))); } ORT_RETURN_IF_ERROR( ComputeSoftmax( From 72d15aa6be477453a329460db5c63054fd15c5f5 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 22 Mar 2025 03:05:15 +0000 Subject: [PATCH 3/4] Add missing kernel --- onnxruntime/contrib_ops/cuda/bert/attention_qk.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu index 1d0bb8d7a58d1..64d8ecf4a63bf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu @@ -24,6 +24,14 @@ __global__ void ConvertAndCopyQK(const int count, const half* input, float* outp } } +template <typename T> +__global__ void ConvertAndCopyQK(const int count, const T* input, T* output) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < count) { + output[idx] = input[idx]; + } +} + template <typename T, typename QK> Status CopyQK(cudaStream_t stream, const int qk_size, From 586f53493714327907c854c49c9e392517eb8c35 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 22 Mar 2025 03:41:27 +0000 Subject: [PATCH 4/4] Fix ROCm build --- onnxruntime/contrib_ops/cuda/bert/attention_qk.cu 
| 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu index 64d8ecf4a63bf..b81783377936f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu @@ -38,7 +38,7 @@ Status CopyQK(cudaStream_t stream, const T* input, QK* output) { if constexpr (std::is_same<T, QK>::value) { - cudaMemcpyAsync(output, input, qk_size * sizeof(QK), cudaMemcpyDeviceToDevice, stream); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output, input, qk_size * sizeof(QK), cudaMemcpyDeviceToDevice, stream)); return Status::OK(); } const bool half2float = std::is_same<T, half>::value && std::is_same<QK, float>::value;