diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 52949d973c60..1de2c0f23439 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -143,6 +143,14 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { typedef CUresult (*cuOccupancyMaxActiveClusters_t)( int *numClusters, CUfunction func, const CUlaunchConfig *config); +typedef CUresult (*cuTensorMapEncodeTiled_t)( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, + const cuuint64_t *globalStrides, const cuuint32_t *boxDim, + const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill); + #define defineGetFunctionHandle(name, symbolName) \ static symbolName##_t name() { \ /* Open the shared library */ \ @@ -168,6 +176,9 @@ typedef CUresult (*cuOccupancyMaxActiveClusters_t)( defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, cuOccupancyMaxActiveClusters); +defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, + cuTensorMapEncodeTiled); + static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, maxActiveClusters = -1; @@ -206,6 +217,9 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; if (cuOccupancyMaxActiveClusters == NULL) { cuOccupancyMaxActiveClusters = getCuOccupancyMaxActiveClustersHandle(); + if (cuOccupancyMaxActiveClusters == NULL) { + return NULL; + } } Py_BEGIN_ALLOW_THREADS; @@ -288,6 +302,13 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { } assert((elementSize * tensorDim) >= 32 && "block size too small."); int rank = 1; + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + if (cuTensorMapEncodeTiled == NULL) { + cuTensorMapEncodeTiled = getCuTensorMapEncodeTiledHandle(); + if (cuTensorMapEncodeTiled == NULL) { + return NULL; + } + } CUresult result = cuTensorMapEncodeTiled( (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,