Skip to content

Commit 172d7c1

Browse files
committed
[Cublas] Added support for bfloat16 when dispatching to cuBLAS kernels
1 parent 43adad7 commit 172d7c1

File tree

3 files changed

+10
-0
lines changed

3 files changed

+10
-0
lines changed

python/tvm/relax/backend/cuda/cublas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
4343
(lhs_dtype == "float16" and rhs_dtype == "float16")
4444
or (lhs_dtype == "float32" and rhs_dtype == "float32")
4545
or (lhs_dtype == "int8" and rhs_dtype == "int8")
46+
or (lhs_dtype == "bfloat16" and rhs_dtype == "bfloat16")
4647
)
4748

4849

src/runtime/contrib/cublas/cublas.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
162162

163163
if (TypeMatch(A->dtype, kDLFloat, 16)) {
164164
ab_type = CUDA_R_16F;
165+
} else if(TypeMatch(A->dtype, kDLBfloat, 16)){
166+
ab_type = CUDA_R_16BF;
165167
} else if (TypeMatch(A->dtype, kDLInt, 8)) {
166168
ab_type = CUDA_R_8I;
167169
} else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
@@ -171,6 +173,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
171173

172174
if (TypeMatch(C->dtype, kDLFloat, 16)) {
173175
c_type = CUDA_R_16F;
176+
} else if(TypeMatch(C->dtype, kDLBfloat, 16)){
177+
c_type = CUDA_R_16BF;
174178
} else if (TypeMatch(C->dtype, kDLInt, 32)) {
175179
c_type = CUDA_R_32I;
176180
compute_type = CUBLAS_COMPUTE_32I;

src/runtime/contrib/cublas/cublas_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
116116
case 64:
117117
return CUDA_R_64F;
118118
}
119+
} else if (type.code == kDLBfloat){
120+
switch (type.bits) {
121+
case 16:
122+
return CUDA_R_16BF;
123+
}
119124
}
120125
LOG(FATAL) << "Unsupported cuda type";
121126
}

0 commit comments

Comments
 (0)