|
1 | | -/************************************************************************* |
2 | | - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 1 | +/* |
| 2 | + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
3 | 3 | * |
4 | | - * See LICENSE.txt for license information |
5 | | - ************************************************************************/ |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
6 | 16 |
|
7 | 17 | #include "config.h" |
8 | 18 | #include "nccl.h" |
9 | 19 | #if NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0) |
10 | | -#include "kernels.h" |
| 20 | +#include "kernels.cuh" |
11 | 21 | #endif |
12 | 22 | #include "tensorrt_llm/common/cudaUtils.h" |
13 | 23 | #include "tensorrt_llm/common/envUtils.h" |
@@ -176,19 +186,9 @@ bool TypedLaunchConfig<T>::isValidConfig(int threadsPerBlock, int unrollFactor) |
176 | 186 | { |
177 | 187 | // Get CUDA device properties |
178 | 188 | int dev = -1; |
179 | | - cudaError_t cudaStatus = cudaGetDevice(&dev); |
180 | | - if (cudaStatus != cudaSuccess) |
181 | | - { |
182 | | - TLLM_LOG_ERROR("Failed to get CUDA device: " + std::string(cudaGetErrorString(cudaStatus))); |
183 | | - return false; |
184 | | - } |
| 189 | + TLLM_CUDA_CHECK(cudaGetDevice(&dev)); |
185 | 190 | cudaDeviceProp deviceProp; |
186 | | - cudaStatus = cudaGetDeviceProperties(&deviceProp, dev); |
187 | | - if (cudaStatus != cudaSuccess) |
188 | | - { |
189 | | - TLLM_LOG_ERROR("Failed to get CUDA device properties: " + std::string(cudaGetErrorString(cudaStatus))); |
190 | | - return false; |
191 | | - } |
| 191 | + TLLM_CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, dev)); |
192 | 192 |
|
193 | 193 | // Check threads per block limits |
194 | 194 | if (threadsPerBlock <= 0 || threadsPerBlock > deviceProp.maxThreadsPerBlock) |
@@ -217,13 +217,7 @@ bool TypedLaunchConfig<T>::isValidConfig(int threadsPerBlock, int unrollFactor) |
217 | 217 |
|
218 | 218 | // Get actual register and shared memory usage from the kernel |
219 | 219 | cudaFuncAttributes funcAttrib; |
220 | | - cudaError_t attrStatus = cudaFuncGetAttributes(&funcAttrib, reinterpret_cast<void const*>(kernelPtr)); |
221 | | - if (attrStatus != cudaSuccess) |
222 | | - { |
223 | | - TLLM_LOG_WARNING( |
224 | | - "Failed to get kernel attributes for validation: " + std::string(cudaGetErrorString(attrStatus))); |
225 | | - return false; |
226 | | - } |
| 220 | + TLLM_CUDA_CHECK(cudaFuncGetAttributes(&funcAttrib, reinterpret_cast<void const*>(kernelPtr))); |
227 | 221 |
|
228 | 222 | // Check register usage |
229 | 223 | int const totalRegistersPerBlock = funcAttrib.numRegs * threadsPerBlock; |
|
0 commit comments