Skip to content

Commit 5ce9719

Browse files
authored
[https://nvbugs/5503138] [fix] Remove compile warnings (#8167)
Signed-off-by: Xiwen Yu <[email protected]>
1 parent 72fcff1 commit 5ce9719

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -396,6 +396,7 @@ __inline__ __device__ T warpReduceSum(T val)
396396
return val;
397397
}
398398

399+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
399400
inline __device__ float block_reduce_sum(float val)
400401
{
401402
__shared__ float smem[WARP_SIZE];
@@ -426,6 +427,7 @@ __device__ float4 loadfloat4(void const* ptr)
426427

427428
return return_value;
428429
}
430+
#endif
429431
} // namespace
430432

431433
template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -89,12 +89,8 @@ PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled()
8989
// Get pointer to cuTensorMapEncodeTiled
9090
cudaDriverEntryPointQueryResult driver_status;
9191
void* cuTensorMapEncodeTiled_ptr = nullptr;
92-
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5)
9392
cudaGetDriverEntryPointByVersion(
9493
"cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, cudaEnableDefault, &driver_status);
95-
#else
96-
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, cudaEnableDefault, &driver_status);
97-
#endif
9894

9995
if (driver_status != cudaDriverEntryPointSuccess)
10096
{

cpp/tensorrt_llm/kernels/recoverFromRingAtten.cu

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,10 @@ __global__ void reduce4ring_attention(
5353
float* softmax_sum = softmax_stats + 1;
5454
float* max = softmax_stats;
5555

56+
#ifdef __NVCC_DIAG_PRAGMA_SUPPORT__
57+
#pragma nv_diag_suppress static_var_with_dynamic_init
58+
// https://nvidia.github.io/cccl/libcudacxx/extended_api/synchronization_primitives/barrier.html
59+
#endif
5660
__shared__ cuda::barrier<cuda::thread_scope::thread_scope_block> barrier;
5761
if (block.thread_rank() == 0)
5862
{
@@ -113,11 +117,6 @@ template <typename Tout>
113117
void invokeRecoverFromRA(Tout* accu_output, float* accu_softmax_stats, Tout* output, float* softmax_stats, int b, int s,
114118
int h, int d, int* cu_seqlens, cudaStream_t stream)
115119
{
116-
float* accu_softmax_sum = accu_softmax_stats;
117-
float* accu_softmax_max = accu_softmax_stats + b * s * h;
118-
float* softmax_sum = softmax_stats;
119-
float* softmax_max = softmax_stats + b * s * h;
120-
121120
int threads_per_block = 128;
122121
int saturated_s_block_dim = 3000 / b + 1;
123122
s = s * h;

cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,13 +16,16 @@
1616

1717
#include <vector>
1818

19+
// clang-format off
20+
#include "trtllmGen_gemm_export/GemmInterface.h"
21+
#include "trtllmGen_gemm_export/GemmOptions.h"
22+
#include "trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h"
23+
// clang-format on
24+
1925
#include "KernelRunner.h"
2026
#include "tensorrt_llm/common/assert.h"
2127
#include "tensorrt_llm/common/cudaUtils.h"
2228
#include "tensorrt_llm/common/envUtils.h"
23-
#include "trtllmGen_gemm_export/GemmInterface.h"
24-
#include "trtllmGen_gemm_export/GemmOptions.h"
25-
#include "trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h"
2629

2730
namespace tensorrt_llm
2831
{

0 commit comments

Comments
 (0)