
Commit 0d2eab2

[Cublas] Added support for bfloat16 while dispatching to cublas kernels (#17796)
This PR adds support for the bfloat16 data type when dispatching operations to cuBLAS kernels.
1 parent 88d9aa6 commit 0d2eab2

4 files changed, with 53 additions and 1 deletion.

python/tvm/relax/backend/cuda/cublas.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
         (lhs_dtype == "float16" and rhs_dtype == "float16")
         or (lhs_dtype == "float32" and rhs_dtype == "float32")
         or (lhs_dtype == "int8" and rhs_dtype == "int8")
+        or (lhs_dtype == "bfloat16" and rhs_dtype == "bfloat16")
     )
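For context, a minimal usage sketch (not part of this commit) of how a bfloat16 matmul can be routed through the cuBLAS BYOC path once this predicate accepts it. It assumes the partition_for_cublas helper defined in this file, a CUDA build of TVM with cuBLAS enabled, and mirrors the offload flow used by the tests below; shapes and module name are illustrative.

# Hedged sketch: offload a bfloat16 matmul to cuBLAS via the Relax BYOC flow.
import tvm
from tvm import relax
from tvm.relax.backend.cuda.cublas import partition_for_cublas
from tvm.script import relax as R


@tvm.script.ir_module
class MatmulBF16:
    @R.function
    def main(
        x: R.Tensor((10, 32), "bfloat16"), y: R.Tensor((32, 64), "bfloat16")
    ) -> R.Tensor((10, 64), "float32"):
        with R.dataflow():
            gv = R.matmul(x, y, out_dtype="float32")
            R.output(gv)
        return gv


mod = partition_for_cublas(MatmulBF16)    # matmul now lands in a cuBLAS composite function
mod = relax.transform.RunCodegen()(mod)   # lower the composite to an external cuBLAS call
ex = tvm.relax.build(mod, target="cuda")  # run with relax.VirtualMachine on a CUDA device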

src/runtime/contrib/cublas/cublas.cc

Lines changed: 6 additions & 1 deletion
@@ -125,7 +125,8 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
   if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
     return TypeMatch(in_dtype, kDLInt, 8);
   } else if (TypeMatch(out_dtype, kDLFloat, 32)) {
-    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16);
+    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) ||
+           TypeMatch(in_dtype, kDLBfloat, 16);
   } else {
     return false;
   }
@@ -162,6 +163,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
 
   if (TypeMatch(A->dtype, kDLFloat, 16)) {
     ab_type = CUDA_R_16F;
+  } else if (TypeMatch(A->dtype, kDLBfloat, 16)) {
+    ab_type = CUDA_R_16BF;
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
   } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
@@ -171,6 +174,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
 
   if (TypeMatch(C->dtype, kDLFloat, 16)) {
     c_type = CUDA_R_16F;
+  } else if (TypeMatch(C->dtype, kDLBfloat, 16)) {
+    c_type = CUDA_R_16BF;
   } else if (TypeMatch(C->dtype, kDLInt, 32)) {
     c_type = CUDA_R_32I;
     compute_type = CUBLAS_COMPUTE_32I;

src/runtime/contrib/cublas/cublas_utils.h

Lines changed: 5 additions & 0 deletions
@@ -116,6 +116,11 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
       case 64:
         return CUDA_R_64F;
     }
+  } else if (type.code == kDLBfloat) {
+    switch (type.bits) {
+      case 16:
+        return CUDA_R_16BF;
+    }
   }
   LOG(FATAL) << "Unsupported cuda type";
 }
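Taken together, the two runtime changes extend the DLPack-dtype to cudaDataType_t translation used when calling cuBLASLt. As a quick reference, here is an illustrative summary derived only from the hunks above (a plain Python mapping, not an API in TVM):

# Illustrative only: DLPack (type code, bits) -> cudaDataType_t pairs visible in this diff.
DTYPE_TO_CUDA = {
    ("kDLFloat", 16): "CUDA_R_16F",
    ("kDLBfloat", 16): "CUDA_R_16BF",  # added by this commit
    ("kDLInt", 8): "CUDA_R_8I",
    ("kDLInt", 32): "CUDA_R_32I",
    ("kDLFloat", 64): "CUDA_R_64F",
}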

tests/python/relax/test_codegen_cublas.py

Lines changed: 41 additions & 0 deletions
@@ -393,6 +393,47 @@ def test_matmul_fp8_multiply_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float32"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_bfloat16_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    in_dtype = "bfloat16"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
+    # Generate input data in float32 and then convert to bfloat16 using ml_dtypes.
+    x_float32 = np.random.uniform(low=0, high=5, size=x_shape).astype("float32")
+    y_float32 = np.random.uniform(low=0, high=5, size=y_shape).astype("float32")
+    x_bf16 = ml_dtypes.bfloat16(x_float32)
+    y_bf16 = ml_dtypes.bfloat16(y_float32)
+
+    # For the reference result, adjust y (if needed) in float32.
+    z = np.swapaxes(y_float32, -2, -1) if transpose_y else y_float32
+    args = (x_bf16, y_bf16)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref_out = np.matmul(x_float32, z).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-2, atol=1e-2)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [
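A note on the relaxed tolerances in the new test (an aside, not part of the commit): bfloat16 keeps an 8-bit significand, so each input already carries up to roughly 2**-8 ≈ 0.4% relative rounding error before the matmul runs, which is why rtol/atol are set to 1e-2 rather than the 1e-3 used for the fp8 test above. A quick way to see this, assuming numpy and ml_dtypes as in the test:

# Measure the round-trip error introduced by casting float32 -> bfloat16 -> float32.
import ml_dtypes
import numpy as np

x = np.random.uniform(low=0.1, high=5, size=(10, 32)).astype("float32")
x_roundtrip = x.astype(ml_dtypes.bfloat16).astype("float32")
rel_err = np.max(np.abs(x_roundtrip - x) / np.abs(x))
print(rel_err)  # bounded by about 2**-8 ≈ 3.9e-3 for normal values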
