Skip to content

Commit fe345cd

Browse files
tongyuantongyuchzblych
authored andcommitted
[None][fix] Migrate to new cuda binding package name (NVIDIA#6700)
Signed-off-by: Yuan Tong <[email protected]>
1 parent 9d902e5 commit fe345cd

File tree

16 files changed

+75
-2153
lines changed

16 files changed

+75
-2153
lines changed

cpp/kernels/fmha_v2/fmha_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import subprocess
22

33
import pytest
4-
from cuda import cuda, nvrtc
4+
5+
try:
6+
from cuda.bindings import driver as cuda
7+
from cuda.bindings import nvrtc
8+
except ImportError:
9+
from cuda import cuda, nvrtc
510

611

712
def ASSERT_DRV(err):

tensorrt_llm/_ipc_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,20 @@
1717
import sys
1818
from typing import List, Tuple
1919

20-
from cuda import cuda, cudart
21-
from cuda.cudart import cudaError_t
20+
try:
21+
from cuda.bindings import driver as cuda
22+
from cuda.bindings import runtime as cudart
23+
except ImportError:
24+
from cuda import cuda, cudart
2225

2326
from ._utils import mpi_comm
2427
from .logger import logger
2528
from .mapping import Mapping
2629

2730

28-
def _raise_if_error(error: cudaError_t | cuda.CUresult):
29-
if isinstance(error, cudaError_t):
30-
if error != cudaError_t.cudaSuccess:
31+
def _raise_if_error(error: cudart.cudaError_t | cuda.CUresult):
32+
if isinstance(error, cudart.cudaError_t):
33+
if error != cudart.cudaError_t.cudaSuccess:
3134
raise RuntimeError(f"CUDA Runtime API error: {repr(error)}")
3235
if isinstance(error, cuda.CUresult):
3336
if error != cuda.CUresult.CUDA_SUCCESS:

tensorrt_llm/_mnnvl_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818

1919
import pynvml
2020
import torch
21-
from cuda import cuda
21+
22+
try:
23+
from cuda.bindings import driver as cuda
24+
except ImportError:
25+
from cuda import cuda
2226

2327
from ._dlpack_utils import pack_strided_memory
2428
from ._utils import mpi_comm

0 commit comments

Comments
 (0)