diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 81f84cff90ca4..bcab5e9e6fa1b 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -1192,7 +1192,7 @@ function(onnxruntime_configure_target target_name)
 
   # Keep BinSkim happy
   if(MSVC AND NOT onnxruntime_target_platform MATCHES "ARM")
-    target_link_options(${target_name} PRIVATE "/CETCOMPAT")
+    target_link_options(${target_name} PRIVATE "$<$<LINK_LANGUAGE:CXX>:/CETCOMPAT>" "$<$<LINK_LANGUAGE:CUDA>:-Xlinker=/CETCOMPAT>")
   endif()
 
 endfunction()
@@ -1421,7 +1421,6 @@ configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_c
 get_property(onnxruntime_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
 
 if (onnxruntime_USE_CUDA)
-  set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
   set(CMAKE_CUDA_STANDARD 17)
   if(onnxruntime_CUDA_HOME)
     file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
@@ -1441,6 +1440,14 @@ if (onnxruntime_USE_CUDA)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin=-compress-all")
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch")
+
+  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
+    if (UNIX)
+      # Suppress deprecation errors (e.g., long4 in CUDA 13)
+      add_compile_options(-Wno-deprecated-declarations)
+    endif()
+  endif()
+
   if (NOT WIN32)
     list(APPEND CUDA_NVCC_FLAGS --compiler-options -fPIC)
   endif()
diff --git a/cmake/external/cuda_configuration.cmake b/cmake/external/cuda_configuration.cmake
index ef94ec25132e3..be6a5febf3e14 100644
--- a/cmake/external/cuda_configuration.cmake
+++ b/cmake/external/cuda_configuration.cmake
@@ -58,6 +58,19 @@ macro(setup_cuda_compiler)
   if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS CUDA_REQUIRED_VERSION)
     message(FATAL_ERROR "CUDA version ${CMAKE_CUDA_COMPILER_VERSION} must be at least ${CUDA_REQUIRED_VERSION}")
   endif()
+
+  # For CUDA 13+, explicitly set the compiler front-end to Clang to handle
+  # MSVC-specific pragmas correctly in device code.
+  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0 AND NOT DEFINED CMAKE_CUDA_COMPILER_FRONTEND_VARIANT)
+    message(STATUS "Setting CUDA compiler front-end to Clang by default for CUDA 13+.")
+    set(CMAKE_CUDA_COMPILER_FRONTEND_VARIANT "CLANG")
+  endif()
+
+  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
+    set(CMAKE_CUDA_RUNTIME_LIBRARY "Hybrid")
+  else()
+    set(CMAKE_CUDA_RUNTIME_LIBRARY "Shared")
+  endif()
 endmacro()
 
 macro(setup_cuda_architectures)
diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake
index 91707c485d3c5..68a3e9014b7b0 100644
--- a/cmake/onnxruntime_providers_cuda.cmake
+++ b/cmake/onnxruntime_providers_cuda.cmake
@@ -191,6 +191,17 @@
     target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=221>")
   endif()
 
+  if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
+    if (UNIX)
+      # Suppress -Wattributes warning from protobuf headers with nvcc on Linux
+      target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-attributes>")
+    endif()
+
+    if (MSVC)
+      target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=20199>")
+    endif()
+  endif()
+
   if (UNIX)
     target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-reorder>"
             "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wno-reorder>")
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 88a1c300f5721..41983d63f6afe 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1523,7 +1523,7 @@ endif()
     list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo)
   endif()
   if (onnxruntime_USE_CUDA)
-    list(APPEND onnxruntime_shared_lib_test_LIBS)
+    list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart)
   endif()
 
   if (onnxruntime_USE_TENSORRT)
@@ -1751,6 +1751,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
       if (HAS_QSPECTRE)
         list(APPEND custom_op_lib_option "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /Qspectre>")
       endif()
+      set(custom_op_lib_link ${custom_op_lib_link} CUDA::cudart)
     endif()
 
     file(GLOB custom_op_src
${custom_op_src_patterns}) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 09f5e66286807..6c08d7fbd9b3f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -22,6 +22,12 @@ namespace cuda { constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128; +// longlong4 is deprecated in cuda 13. +// LongLong4 is similar to longlong4_32a, except this is also visible in Host compiler (longlong4_32a is only visible to nvcc); +typedef struct __align__(32) { + long long int x, y, z, w; +} LongLong4; + // A cache for cumulated sequence length. It will be initialized in the first request, then become read-only after that. struct CumulatedSequenceLengthCache { onnxruntime::IAllocatorUniquePtr buffer; @@ -144,14 +150,14 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i template Status LaunchStridedCopy( cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) int max_threads_per_block); template Status LaunchStridedCopy(cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides, // coord (b,n,s,h) + const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) + T* out, LongLong4 out_strides, // coord (b,n,s,h) int max_threads_per_block); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu index 66e56e701c558..838249a255899 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu +++ 
b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu @@ -11,8 +11,8 @@ namespace contrib { namespace cuda { template -__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides, // coord (b,n,s,h) +__global__ void StridedCopy(const T* in, const int H, LongLong4 in_strides, // coord (b,n,s,h) + T* out, LongLong4 out_strides, // coord (b,n,s,h) const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) { const int h = threadIdx.x; const int n = threadIdx.y; @@ -30,8 +30,8 @@ __global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // } template -__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides, // coord (b,n,s,h) +__global__ void StridedCopyLarge(const T* in, const int H, LongLong4 in_strides, // coord (b,n,s,h) + T* out, LongLong4 out_strides, // coord (b,n,s,h) const int* in_seqlens_offset, const int* out_seqlens_offset) { // Use when (H*)*num_heads > 1024 int h = threadIdx.x; @@ -77,7 +77,7 @@ struct ToByteType<16> { template <> struct ToByteType<32> { - using T = ulonglong4; + using T = LongLong4; }; template @@ -86,8 +86,8 @@ using ToBytes = typename ToByteType::T; template Status LaunchStridedCopy( cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) int max_threads_per_block) { int batch_size = in_shape.x; int num_heads = in_shape.y; @@ -157,8 +157,8 @@ Status LaunchStridedCopy( template Status LaunchStridedCopy(cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides, // coord 
(b,n,s,h) + const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) + T* out, LongLong4 out_strides, // coord (b,n,s,h) int max_threads_per_block) { const int* in_seqlens_offset = nullptr; const int* out_seqlens_offset = nullptr; @@ -170,14 +170,14 @@ Status LaunchStridedCopy(cudaStream_t stream, template Status LaunchStridedCopy( cudaStream_t stream, - const float* in, int4 in_shape, longlong4 in_strides, - float* out, longlong4 out_strides, + const float* in, int4 in_shape, LongLong4 in_strides, + float* out, LongLong4 out_strides, int max_threads_per_block); template Status LaunchStridedCopy( cudaStream_t stream, - const half* in, int4 in_shape, longlong4 in_strides, - half* out, longlong4 out_strides, + const half* in, int4 in_shape, LongLong4 in_strides, + half* out, LongLong4 out_strides, int max_threads_per_block); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h index 41691d823f528..c20b2981f7aca 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h @@ -31,12 +31,13 @@ #pragma once +#include "core/providers/cuda/curand_wrapper.h" + #ifdef HAS_PYTORCH #include #include #endif -#include #include #include #include diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp index 499504439aa46..dd2d7f56d85f0 100644 --- a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -49,7 +49,6 @@ #include 
"cute/atom/copy_traits_sm90_tma.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/numeric/arithmetic_tuple.hpp" -#include "cute/tensor_predicate.hpp" #include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 15bb4f469ba3d..b10dcde122063 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -4,8 +4,9 @@ #pragma once #include +#include "core/providers/cuda/curand_wrapper.h" #include -#include + #include #include "contrib_ops/cpu/transformers/generation_shared.h" diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 3053521ad946e..07d875e90fa4b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -14,6 +14,10 @@ namespace onnxruntime { namespace contrib { namespace rocm { +typedef struct __align__(32) { + long long int x, y, z, w; +} LongLong4; + size_t GetAttentionScratchSize( size_t element_size, int batch_size, @@ -162,14 +166,14 @@ Status ClassifyAttentionMode(AttentionType type, template Status LaunchStridedCopy( hipStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) int max_threads_per_block); template Status LaunchStridedCopy(hipStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) - T* out, 
longlong4 out_strides, // coord (b,n,s,h) + const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) + T* out, LongLong4 out_strides, // coord (b,n,s,h) int max_threads_per_block); } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index e190a6938dc6b..226b89cfb2b86 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -143,7 +143,7 @@ struct Strides { int seqlen_dim, int head_size_dim) { ORT_UNUSED_PARAMETER(batch_dim); - return Strides{longlong4{ + return Strides{LongLong4{ static_cast(num_head_dim) * seqlen_dim * head_size_dim, static_cast(seqlen_dim) * head_size_dim, static_cast(head_size_dim), @@ -157,7 +157,7 @@ struct Strides { int num_head_dim, int head_size_dim) { ORT_UNUSED_PARAMETER(batch_dim); - return Strides{longlong4{ + return Strides{LongLong4{ static_cast(seqlen_dim) * num_head_dim * head_size_dim, static_cast(head_size_dim), static_cast(num_head_dim) * head_size_dim, @@ -165,7 +165,7 @@ struct Strides { }}; } - template + template T ForBNSHCoord() const { using E = typename T::value_type; return T{static_cast(strides_for_bnsh_coord.x), @@ -174,7 +174,7 @@ struct Strides { static_cast(strides_for_bnsh_coord.w)}; } - template + template T ForBSNHCoord() const { using E = typename T::value_type; return T{static_cast(strides_for_bnsh_coord.x), @@ -183,7 +183,7 @@ struct Strides { static_cast(strides_for_bnsh_coord.w)}; } - template + template T ForBNHSCoord() const { using E = typename T::value_type; return T{static_cast(strides_for_bnsh_coord.x), @@ -198,7 +198,7 @@ struct Strides { } // store intermediate strides in the canonical (b,n,s,h) coordinate order - longlong4 strides_for_bnsh_coord; + LongLong4 strides_for_bnsh_coord; }; template 
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index 1dd7a14332b3e..ebdac882bff09 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -178,7 +178,7 @@ struct TreeNodeElement { inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0x1F); } inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); } - inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; } + inline bool is_missing_track_true() const { return static_cast(flags) & static_cast(MissingTrack::kTrue); } #if defined(_TREE_DEBUG) std::string str() const { diff --git a/onnxruntime/core/providers/cuda/curand_wrapper.h b/onnxruntime/core/providers/cuda/curand_wrapper.h new file mode 100644 index 0000000000000..e67fce1e9ff15 --- /dev/null +++ b/onnxruntime/core/providers/cuda/curand_wrapper.h @@ -0,0 +1,12 @@ +// +// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 
+//
+
+#pragma once
+
+#if defined(CUDA_VERSION) && CUDA_VERSION == 13000
+#define __NV_NO_VECTOR_DEPRECATION_DIAG 1
+#endif
+
+#include <curand_kernel.h>
diff --git a/onnxruntime/core/providers/cuda/fpgeneric.cu b/onnxruntime/core/providers/cuda/fpgeneric.cu
index ef9eda8a0553c..78426c1294883 100644
--- a/onnxruntime/core/providers/cuda/fpgeneric.cu
+++ b/onnxruntime/core/providers/cuda/fpgeneric.cu
@@ -10,9 +10,8 @@
  */
 // NV_TODO: optimize speed -- pass things needed in, optimize kernel speed, add half2
 // NV_TODO: investigate cub support for half
-
+#include "core/providers/cuda/curand_wrapper.h"
 #include "core/providers/cuda/cu_inc/common.cuh"
-#include <curand_kernel.h>
 
 #define TRANS_TILE_DIM 32
 #define BLOCK_ROWS 8
diff --git a/onnxruntime/core/providers/cuda/generator/random_impl.cu b/onnxruntime/core/providers/cuda/generator/random_impl.cu
index 7b256f3def25e..2507177d7c21f 100644
--- a/onnxruntime/core/providers/cuda/generator/random_impl.cu
+++ b/onnxruntime/core/providers/cuda/generator/random_impl.cu
@@ -1,9 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/providers/cuda/curand_wrapper.h"
 #include "core/providers/cuda/generator/random_impl.h"
 
-#include <curand_kernel.h>
 #include <algorithm>
 
 #include "core/providers/cuda/cu_inc/common.cuh"
diff --git a/onnxruntime/core/providers/cuda/nn/dropout_impl.cu b/onnxruntime/core/providers/cuda/nn/dropout_impl.cu
index 03af88bbc07cd..b412cc6a7e49e 100644
--- a/onnxruntime/core/providers/cuda/nn/dropout_impl.cu
+++ b/onnxruntime/core/providers/cuda/nn/dropout_impl.cu
@@ -15,10 +15,10 @@
  */
 
 /* Modifications Copyright (c) Microsoft.
 */
+#include "core/providers/cuda/curand_wrapper.h"
 #include "core/providers/cuda/nn/dropout_impl.h"
 
-#include <curand_kernel.h>
 #include <algorithm>
 
 #include "core/providers/cuda/cu_inc/bitmask.cuh"
diff --git a/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.cu b/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.cu
index 62753ef4d5867..72fbbf53bfb21 100644
--- a/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.cu
+++ b/orttraining/orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.cu
@@ -1,9 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-
+#include "core/providers/cuda/curand_wrapper.h"
 #include "orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.h"
-#include <curand_kernel.h>
 #include <algorithm>
 
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/math/softmax_warpwise_impl.cuh"