Commit 8edf01f

[Contrib] Workspace for cuBLAS backend
This PR adds a 32MB workspace for the cuBLAS backend so that functions like `cublasLtMatmul` can take the workspace as input. The workspace is managed under `CuBlasLtThreadEntry` so that it is allocated only once per thread.
1 parent 593a4bd commit 8edf01f
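
For context, the cuBLASLt workspace pattern that this change wires into `CallCublasLt` is sketched below as a standalone program. This is an illustrative sketch only, not TVM code: it assumes cuBLAS 11+ for `CUBLAS_COMPUTE_32F`, the `CHECK_CUDA`/`CHECK_CUBLAS` macros and all variable names are made up for the example, and only the preference -> heuristic -> workspace handoff mirrors what the diff below does.

#include <cublasLt.h>
#include <cuda_runtime.h>

#include <cstdio>

// Placeholder error checks for this sketch only (TVM uses CHECK_CUBLAS_ERROR / CUDA_CALL).
#define CHECK_CUDA(call) \
  do { cudaError_t e = (call); if (e != cudaSuccess) std::printf("CUDA error %d\n", e); } while (0)
#define CHECK_CUBLAS(call) \
  do { cublasStatus_t s = (call); if (s != CUBLAS_STATUS_SUCCESS) std::printf("cuBLAS error %d\n", s); } while (0)

int main() {
  const int m = 64, n = 64, k = 64;
  size_t workspace_size = 32 * 1024 * 1024;  // 32MB, matching the default added by this PR.

  float *A, *B, *C;
  void* workspace;
  CHECK_CUDA(cudaMalloc(&A, sizeof(float) * m * k));
  CHECK_CUDA(cudaMalloc(&B, sizeof(float) * k * n));
  CHECK_CUDA(cudaMalloc(&C, sizeof(float) * m * n));
  CHECK_CUDA(cudaMalloc(&workspace, workspace_size));

  cublasLtHandle_t handle;
  cublasLtMatmulDesc_t op_desc;
  cublasLtMatrixLayout_t a_desc, b_desc, c_desc;
  CHECK_CUBLAS(cublasLtCreate(&handle));
  CHECK_CUBLAS(cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_32F, m, k, m));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_32F, k, n, k));
  CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_32F, m, n, m));

  // 1. Advertise the workspace budget through a matmul preference descriptor.
  cublasLtMatmulPreference_t pref;
  CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&pref));
  CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));

  // 2. Ask the heuristic for one algorithm that fits within that budget.
  cublasLtMatmulHeuristicResult_t heuristic = {};
  int returned = 0;
  CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, op_desc, a_desc, b_desc, c_desc, c_desc,
                                              pref, 1, &heuristic, &returned));
  if (returned == 0) {
    std::printf("no suitable cuBLASLt algorithm found\n");
    return 1;
  }

  // 3. Run the matmul with the selected algorithm and the pre-allocated workspace.
  float alpha = 1.0f, beta = 0.0f;
  CHECK_CUBLAS(cublasLtMatmul(handle, op_desc, &alpha, A, a_desc, B, b_desc, &beta, C, c_desc,
                              C, c_desc, &heuristic.algo, workspace, workspace_size,
                              /*stream=*/nullptr));
  CHECK_CUDA(cudaDeviceSynchronize());

  cublasLtMatmulPreferenceDestroy(pref);
  cublasLtMatrixLayoutDestroy(c_desc);
  cublasLtMatrixLayoutDestroy(b_desc);
  cublasLtMatrixLayoutDestroy(a_desc);
  cublasLtMatmulDescDestroy(op_desc);
  cublasLtDestroy(handle);
  cudaFree(workspace);
  cudaFree(C);
  cudaFree(B);
  cudaFree(A);
  return 0;
}

The key point, as in the diff below, is that cublasLtMatmulAlgoGetHeuristic only considers algorithms whose workspace requirement fits within CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, and the same workspace pointer and size are then passed to cublasLtMatmul.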

File tree

4 files changed: +43 -7 lines changed


src/runtime/contrib/cublas/cublas.cc

Lines changed: 17 additions & 3 deletions
@@ -135,9 +135,10 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 #if CUDART_VERSION >= 10010
 
-void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A, const DLTensor* B,
+void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
+                  cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
                   const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
-                  cublasLtEpilogue_t epilogue) {
+                  void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue) {
   ICHECK(TypeEqual(A->dtype, B->dtype));
   // Reversed strides indicates an in-place transpose operation.
   transa = IsInPlaceTransposed(A) ? !transa : transa;
@@ -265,8 +266,21 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A,
   auto B_data = static_cast<char*>(B->data) + B->byte_offset;
   auto C_data = static_cast<char*>(C->data) + C->byte_offset;
 
+  cublasLtMatmulPreferenceSetAttribute(matmul_pref_desc, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+                                       &workspace_size, sizeof(size_t));
+
+  cublasLtMatmulHeuristicResult_t heuristic_result = {};
+  int returned_result = 0;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(hdl, op_desc, A_desc, B_desc, C_desc, C_desc,
+                                                    matmul_pref_desc, 1, &heuristic_result,
+                                                    &returned_result));
+  if (returned_result == 0) {
+    CHECK_CUBLAS_ERROR(CUBLAS_STATUS_NOT_SUPPORTED);
+  }
+
   CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta,
-                                    C_data, C_desc, C_data, C_desc, nullptr, nullptr, 0, stream));
+                                    C_data, C_desc, C_data, C_desc, &heuristic_result.algo,
+                                    workspace_ptr, workspace_size, stream));
 
   cublasLtMatmulDescDestroy(op_desc);
   cublasLtMatrixLayoutDestroy(A_desc);

src/runtime/contrib/cublas/cublas_json_runtime.cc

Lines changed: 3 additions & 2 deletions
@@ -129,8 +129,9 @@ class CublasJSONRuntime : public JSONRuntimeBase {
 
       auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
 
-      tvm::contrib::CallCublasLt(entry_ptr->handle, stream, a_ptr, b_ptr, bias_ptr, out_ptr,
-                                 transa, transb, epilogue);
+      tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
+                                 b_ptr, bias_ptr, out_ptr, transa, transb,
+                                 entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue);
     }
   }
 }

src/runtime/contrib/cublas/cublas_utils.cc

Lines changed: 13 additions & 1 deletion
@@ -48,13 +48,25 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
   return retval;
 }
 
-CuBlasLtThreadEntry::CuBlasLtThreadEntry() { CHECK_CUBLAS_ERROR(cublasLtCreate(&handle)); }
+CuBlasLtThreadEntry::CuBlasLtThreadEntry() {
+  CHECK_CUBLAS_ERROR(cublasLtCreate(&handle));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&matmul_pref_desc));
+  CUDA_CALL(cudaMalloc(&workspace_ptr, workspace_size));
+}
 
 CuBlasLtThreadEntry::~CuBlasLtThreadEntry() {
   if (handle) {
     cublasLtDestroy(handle);
     handle = nullptr;
   }
+  if (matmul_pref_desc) {
+    cublasLtMatmulPreferenceDestroy(matmul_pref_desc);
+    matmul_pref_desc = nullptr;
+  }
+  if (workspace_ptr != nullptr) {
+    cudaFree(workspace_ptr);
+    workspace_ptr = nullptr;
+  }
 }
 
 typedef dmlc::ThreadLocalStore<CuBlasLtThreadEntry> CuBlasLtThreadStore;
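
Because CuBlasLtThreadEntry::ThreadLocal() is backed by the dmlc::ThreadLocalStore typedef above, the handle, preference descriptor, and 32MB workspace are created at most once per thread and then reused. A rough caller-side sketch follows; the ThreadLocal() body shown here is an assumption based on the surrounding pattern rather than part of this diff, and the call mirrors cublas_json_runtime.cc above.

// Assumed implementation: the store lazily constructs one entry per thread, so the
// constructor above (cublasLtCreate + cublasLtMatmulPreferenceCreate + cudaMalloc)
// runs only on first use in each thread.
CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal() { return CuBlasLtThreadStore::Get(); }

// Caller side: fetch the per-thread entry and hand its workspace to CallCublasLt.
CuBlasLtThreadEntry* entry_ptr = CuBlasLtThreadEntry::ThreadLocal();
tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr, b_ptr,
                           bias_ptr, out_ptr, transa, transb, entry_ptr->workspace_ptr,
                           entry_ptr->workspace_size, epilogue);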

src/runtime/contrib/cublas/cublas_utils.h

Lines changed: 10 additions & 1 deletion
@@ -80,7 +80,14 @@ struct CuBlasThreadEntry {
 struct CuBlasLtThreadEntry {
   CuBlasLtThreadEntry();
   ~CuBlasLtThreadEntry();
+
   cublasLtHandle_t handle{nullptr};
+  cublasLtMatmulPreference_t matmul_pref_desc{nullptr};
+  void* workspace_ptr{nullptr};
+  // 32MB workspace as suggested by NVIDIA
+  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace.
+  static constexpr const size_t workspace_size = 33554432;
+
   static CuBlasLtThreadEntry* ThreadLocal();
 };  // CuBlasLtThreadEntry
 
@@ -113,8 +120,10 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
 }
 
 /*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */
-void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* A, const DLTensor* B,
+void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
+                  cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
                   const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
+                  void* workspace_ptr, size_t workspace_size,
                   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT);
 
 }  // namespace contrib
