diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu
index b81783377936f..3f02a441da73e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu
@@ -37,26 +37,17 @@ Status CopyQK(cudaStream_t stream,
               const int qk_size,
               const T* input,
               QK* output) {
-  if constexpr (std::is_same<T, QK>::value) {
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output, input, qk_size * sizeof(QK), cudaMemcpyDeviceToDevice, stream));
-    return Status::OK();
-  }
-  const bool half2float = std::is_same<T, half>::value && std::is_same<QK, float>::value;
-  const bool float2half = std::is_same<T, float>::value && std::is_same<QK, half>::value;
-  ORT_ENFORCE(half2float || float2half);
+  constexpr const bool half2float = std::is_same<T, half>::value && std::is_same<QK, float>::value;
+  constexpr const bool float2half = std::is_same<T, float>::value && std::is_same<QK, half>::value;
+  static_assert(half2float || float2half, "This function supports either <float,half> or <half,float>");
 
-  int block_size = 256;
+  constexpr const int block_size = 256;
   int num_blocks = (qk_size + block_size - 1) / block_size;
   ConvertAndCopyQK<<<num_blocks, block_size, 0, stream>>>(qk_size, input, output);
 
   return CUDA_CALL(cudaGetLastError());
 }
 
-template Status CopyQK(cudaStream_t stream,
-                       const int qk_size,
-                       const float* input,
-                       float* output);
-
 template Status CopyQK(cudaStream_t stream,
                        const int qk_size,
                        const float* input,
@@ -67,10 +58,23 @@ template Status CopyQK(cudaStream_t stream,
                        const half* input,
                        float* output);
 
-template Status CopyQK(cudaStream_t stream,
-                       const int qk_size,
-                       const half* input,
-                       half* output);
+template <>
+Status CopyQK(cudaStream_t stream,
+              const int qk_size,
+              const float* input,
+              float* output) {
+  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output, input, qk_size * sizeof(float), cudaMemcpyDeviceToDevice, stream));
+  return Status::OK();
+}
+
+template <>
+Status CopyQK(cudaStream_t stream,
+              const int qk_size,
+              const half* input,
+              half* output) {
+  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output, input, qk_size * sizeof(half), cudaMemcpyDeviceToDevice, stream));
+  return Status::OK();
+}
 
 }  // namespace cuda
 }  // namespace contrib