
Commit 10aaa40

ibsidorenko authored and elvin-n committed

[CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#63)

Co-authored-by: Andrey Malyshev <[email protected]>

1 parent a64d1f1 commit 10aaa40

10 files changed: 175 additions, 19 deletions

include/tvm/runtime/data_type.h
Lines changed: 9 additions & 0 deletions

@@ -126,6 +126,15 @@ class DataType {
             code() == DataType::kE5M2Float) &&
            bits() == 8;
   }
+  bool is_e4m3_float8() const {
+    return (code() == DataType::kE4M3Float &&
+            bits() == 8);
+  }
+
+  bool is_e5m2_float8() const {
+    return (code() == DataType::kE5M2Float &&
+            bits() == 8);
+  }
   /*! \return whether type is a float16 type. */
   bool is_float16() const { return is_float() && bits() == 16; }
   /*! \return whether type is a bfloat16 type. */
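
The two predicates mirror the existing is_float8() check but discriminate the two FP8 layouts by type code. For reference, the same distinction is visible from the Python side through the registered dtype strings; a minimal sketch, assuming a TVM build with the FP8 type codes registered:

    import tvm

    # "e4m3_float8" is the dtype string used throughout this commit
    fp8 = tvm.DataType("e4m3_float8")
    assert fp8.bits == 8  # both FP8 variants are 8-bit; the type code picks the layout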

python/tvm/contrib/tvmjs.py
Lines changed: 19 additions & 0 deletions

@@ -28,6 +28,11 @@
 
 import numpy as np
 
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
+
 import tvm
 from tvm._ffi.libinfo import find_lib_path
 
@@ -295,6 +300,20 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
         arr = tvm.nd.empty(shape, dtype, device=device)
         assert offset + nbytes <= len(raw_data)
         buffer_source = raw_data[offset : offset + nbytes]
+        if dtype == "e4m3_float8":
+            if ml_dtypes is not None:
+                dtype = ml_dtypes.float8_e4m3fn
+            else:
+                raise RuntimeError(
+                    "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
+                )
+        if dtype == "e5m2_float8":
+            if ml_dtypes is not None:
+                dtype = ml_dtypes.float8_e5m2
+            else:
+                raise RuntimeError(
+                    "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
+                )
         if encode_format == "f32-to-bf16" and dtype == "float32":
             data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
             arr.copyfrom(_convert_bf16_to_f32(data))
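
A minimal sketch of the decode path this enables, assuming ml_dtypes is installed (the bytes and shape below are made up for illustration):

    import numpy as np
    import ml_dtypes

    raw = bytes(range(8))  # stand-in for the raw_data slice from the cache
    arr = np.frombuffer(raw, dtype=ml_dtypes.float8_e4m3fn).reshape(2, 4)
    print(arr.astype(np.float32))  # upcast so the values print readably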

python/tvm/relax/backend/contrib/cublas.py
Lines changed: 14 additions & 2 deletions

@@ -28,8 +28,11 @@
 from ..utils import has_leaking_intermediate_variables
 
 
-def _is_supported_dtype(lhs_dtype, rhs_dtype):
+def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
     """Check if dtypes in the given workload are supported by cuBLAS BYOC."""
+    if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+        # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
+        return out_dtype != "e5m2_float8"
     return (
         (lhs_dtype == "float16" and rhs_dtype == "float16")
         or (lhs_dtype == "float32" and rhs_dtype == "float32")
@@ -42,10 +45,12 @@ def _check_matmul(context: PatternCheckContext) -> bool:
         return False
     lhs = context.annotated_expr["lhs"]
     rhs = context.annotated_expr["rhs"]
+    matmul_call = context.annotated_expr["root"]
 
     lhs_dtype = lhs.struct_info.dtype
     rhs_dtype = rhs.struct_info.dtype
-    if not _is_supported_dtype(lhs_dtype, rhs_dtype):
+    out_dtype = matmul_call.struct_info.dtype
+    if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
         return False
 
     lhs_shape = lhs.struct_info.shape.values
@@ -62,6 +67,13 @@ def _check_matmul(context: PatternCheckContext) -> bool:
         if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0:
             # Rows number must be multiples of 4 for IGEMM
            return False
+    elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+        # Matrix dimensions must be multiples of 16. This requirement is missing from the cuBLAS
+        # docs, but it was observed during testing.
+        if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 16 != 0:
+            return False
+        if not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int)) or rhs_shape[-2] % 16 != 0:
+            return False
 
     lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
     rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
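
To make the new shape rule concrete: with transposed y, an x of shape (10, 32) against a y of shape (64, 32) gives rhs_shape = (64, 32); both trailing dimensions are multiples of 16, so the matmul is offloaded, while a 30 or 60 in either position rejects it. A standalone sketch of just the FP8 predicate (the helper name is made up):

    def fp8_shapes_ok(rhs_shape):
        # e4m3 GEMM: both trailing dims of the (transposed) rhs must be multiples of 16
        return rhs_shape[-1] % 16 == 0 and rhs_shape[-2] % 16 == 0

    assert fp8_shapes_ok((64, 32))      # offloaded
    assert not fp8_shapes_ok((30, 64))  # rejected: 30 is not a multiple of 16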

python/tvm/relax/transform/legalize_ops/qdq.py
Lines changed: 16 additions & 11 deletions

@@ -52,7 +52,8 @@ def te_quantize(
     def quantize_compute(*indices):
         scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
         zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
-        round_val = te.round(data[indices] / scale_value) + zp_value
+        scaled = data[indices] / scale_value
+        round_val = (te.round(scaled) if "int" in out_dtype else scaled) + zp_value
         return clip_cast(round_val, out_dtype)
 
     output_shape = data.shape
@@ -75,15 +76,18 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
     Compute datatype: float32
 
     Example of lowering:
-      qnn.dequantize(data, scale, zp, "float32") -->
-          sub = subtract(cast(data, "int32"), zp)
-          out = multiply(cast(sub, "float32"), scale)
-
-      qnn.dequantize(data, scale, zp, "float16") -->
-          sub = subtract(cast(data, "int32"), zp)
-          mul = multiply(cast(sub, "float32"), cast(scale, "float32"))
-          clipped_out = clip(mul, float32(-65504.0), float32(65504.0))
-          out = cast(clipped_out, "float16")
+
+      dtype = ["int32"|"float32"]
+
+      qnn.dequantize(data, scale, zp, "float32") -->
+          sub = subtract(cast(data, dtype), zp)
+          out = multiply(cast(sub, "float32"), scale)
+
+      qnn.dequantize(data, scale, zp, "float16") -->
+          sub = subtract(cast(data, dtype), zp)
+          mul = multiply(cast(sub, "float32"), cast(scale, "float32"))
+          clipped_out = clip(mul, float32(-65504.0), float32(65504.0))
+          out = cast(clipped_out, "float16")
     """
     axis = call.attrs.axis
     out_dtype = call.attrs.out_dtype
@@ -96,7 +100,8 @@ def te_dequantize(
     def dequantize_compute(*indices):
         scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
         zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
-        sub = te.subtract(data[indices].astype("int32"), zp_value)
+        dtype = "float32" if "float" in data.dtype else "int32"
+        sub = te.subtract(data[indices].astype(dtype), zp_value)
         out = te.multiply(sub, scale_value.astype("float32"))
         if out_dtype == "float32":
             return out
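
A numpy sketch of the legalized FP8 quantize path above: rounding is skipped for float out_dtype, and the clip_cast bounds are the e4m3 limits that max_value/min_value return (see src/tir/op/op.cc below). This assumes ml_dtypes; the helper name is made up:

    import numpy as np
    import ml_dtypes

    def quantize_e4m3(data, scale, zp):
        scaled = data / scale + zp                # no te.round for float targets
        clipped = np.clip(scaled, -448.0, 448.0)  # e4m3 finite range
        return clipped.astype(ml_dtypes.float8_e4m3fn)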

src/relax/backend/contrib/utils.h
Lines changed: 4 additions & 0 deletions

@@ -72,6 +72,10 @@ inline std::string DType2String(const tvm::DataType dtype) {
   std::ostringstream os;
   if (dtype.is_float()) {
     os << "float";
+  } else if (dtype.is_e4m3_float8()) {
+    os << "e4m3_float";
+  } else if (dtype.is_e5m2_float8()) {
+    os << "e5m2_float";
   } else if (dtype.is_int()) {
     os << "int";
   } else if (dtype.is_uint()) {
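
DType2String appends the bit width after the base name, which is why the new branches write "e4m3_float" rather than "e4m3_float8"; a sketch of the composition:

    def dtype_to_string(base: str, bits: int) -> str:  # mirrors DType2String's tail
        return f"{base}{bits}"

    assert dtype_to_string("e4m3_float", 8) == "e4m3_float8"
    assert dtype_to_string("e5m2_float", 8) == "e5m2_float8"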

src/relax/op/tensor/qdq.cc
Lines changed: 12 additions & 6 deletions

@@ -49,7 +49,8 @@ TVM_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize);
 StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
   const auto* attrs = call->attrs.as<QuantizeAttrs>();
   if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) &&
-      attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16)) {
+      attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) &&
+      attrs->out_dtype != DataType::NVFloat8E4M3() && attrs->out_dtype != DataType::NVFloat8E5M2()) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Unsupported output datatype attribute for operation: '"
                      << attrs->out_dtype);
@@ -73,9 +74,10 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be int8, but got " << zp_sinfo->dtype);
+                     << "zero_point param datatype should be 'int8' or 'float16', but got "
+                     << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
@@ -142,7 +144,10 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
   // Check input datatype:
   if (input_sinfo->dtype != DataType::Int(8) && input_sinfo->dtype != DataType::UInt(8) &&
       input_sinfo->dtype != DataType::Int(16) && input_sinfo->dtype != DataType::UInt(16) &&
-      input_sinfo->dtype != DataType::Int(32)) {
+      input_sinfo->dtype != DataType::Int(32) &&
+      input_sinfo->dtype != DataType::NVFloat8E4M3() &&
+      input_sinfo->dtype != DataType::NVFloat8E5M2() &&
+      input_sinfo->dtype != DataType::Float(16) && input_sinfo->dtype != DataType::Float(32)) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Unsupported input datatype for operation: " << attrs->out_dtype);
   }
@@ -155,9 +160,10 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be int8, but got " << zp_sinfo->dtype);
+                     << "zero_point param datatype should be 'int8' or 'float16', but got "
+                     << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
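
A minimal sketch of what the relaxed checks now accept, mirroring the tests added in test_op_qdq.py below (float16 zero point, e4m3 output):

    from tvm import relax
    from tvm.script import relax as R

    x = relax.Var("x", R.Tensor((4, 3), "float32"))
    s = relax.Var("s", R.Tensor([3], "float32"))
    zp = relax.Var("zp", R.Tensor([3], "float16"))  # float16 zp is now legal
    q = relax.op.quantize(x, s, zp, 1, "e4m3_float8")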

src/runtime/contrib/cublas/cublas.cc
Lines changed: 3 additions & 0 deletions

@@ -161,6 +161,9 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
     ab_type = CUDA_R_16F;
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
+  } else if (TypeMatch(A->dtype, DataType::TypeCode::kE4M3Float, 8)) {
+    ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kE4M3Float, 8));
+    ab_type = CUDA_R_8F_E4M3;
   }
 
   if (TypeMatch(C->dtype, kDLFloat, 16)) {
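
In Python pseudocode, the A/B operand dispatch now reads (only the branches shown in this hunk):

    AB_TYPE = {
        "float16": "CUDA_R_16F",
        "int8": "CUDA_R_8I",
        "e4m3_float8": "CUDA_R_8F_E4M3",  # B is ICHECKed to be e4m3 as well
    }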

src/tir/op/op.cc
Lines changed: 2 additions & 0 deletions

@@ -263,6 +263,7 @@ PrimExpr max_value(const DataType& dtype, Span span) {
   } else if (dtype.is_bfloat16()) {
     return FloatImm(dtype, std::numeric_limits<float>::max(), span);
   } else if (dtype.is_float8()) {
+    // according to https://arxiv.org/pdf/2209.05433.pdf
     if (dtype.code() == DataType::TypeCode::kE5M2Float) {
       return FloatImm(dtype, 57344.0, span);
     } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
@@ -303,6 +304,7 @@ PrimExpr min_value(const DataType& dtype, Span span) {
   } else if (dtype.is_bfloat16()) {
     return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
   } else if (dtype.is_float8()) {
+    // according to https://arxiv.org/pdf/2209.05433.pdf
     if (dtype.code() == DataType::TypeCode::kE5M2Float) {
       return FloatImm(dtype, -57344.0, span);
     } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
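
The constants can be cross-checked against ml_dtypes, which implements the formats from the same FP8 paper (https://arxiv.org/abs/2209.05433); a quick verification sketch:

    import numpy as np
    import ml_dtypes

    assert np.finfo(ml_dtypes.float8_e4m3fn).max == 448.0   # e4m3: max finite value 448
    assert np.finfo(ml_dtypes.float8_e5m2).max == 57344.0   # e5m2: max finite value 57344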

tests/python/relax/test_codegen_cublas.py
Lines changed: 59 additions & 0 deletions

@@ -25,6 +25,11 @@
 from tvm.relax.testing import get_relax_matmul_module
 from tvm.script import relax as R
 
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
+
 
 @pytest.fixture(autouse=True)
 def reset_seed():
@@ -226,6 +231,60 @@ def test_matmul_igemm_offload(
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float16"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_fp8_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    in_dtype = "e4m3_float8"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    z = np.swapaxes(y, -2, -1) if transpose_y else y
+    args = (x, y)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref_out = np.matmul(x, z).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+    "M, N, K, out_dtype, partition_done",
+    [
+        (15, 64, 32, "float32", True),
+        (15, 64, 32, "e4m3_float8", True),
+        (15, 64, 32, "e5m2_float8", False),
+        (16, 32, 60, "float32", False),
+        (16, 30, 64, "float32", False),
+    ],
+)
+def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, partition_done):
+    mod = get_relax_matmul_module((M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=True)
+    mod = partition_for_cublas(mod)
+    func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
+    assert func_name in mod["main"].script()
+
+
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
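
The new cases can be run in isolation with pytest's keyword filter, assuming a CUDA device whose cuBLASLt supports FP8 and the ml_dtypes package:

    pytest tests/python/relax/test_codegen_cublas.py -k fp8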

tests/python/relax/test_op_qdq.py
Lines changed: 37 additions & 0 deletions

@@ -68,5 +68,42 @@ def test_qdq_op_infer_struct_info_symbolic():
     )
 
 
+def test_qdq_e4m3_float8_op_infer_struct_info_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((n, 3), "float32"))
+    dx = relax.Var("dx", R.Tensor((n, 3), "e4m3_float8"))
+    s = relax.Var("s", R.Tensor([3], "float32"))
+    zp = relax.Var("zp", R.Tensor([3], "float16"))
+    _check_inference(
+        bb,
+        relax.op.quantize(x, s, zp, 1, "e4m3_float8"),
+        relax.TensorStructInfo((n, 3), "e4m3_float8"),
+    )
+    _check_inference(
+        bb,
+        relax.op.dequantize(dx, s, zp, 1, "float32"),
+        relax.TensorStructInfo((n, 3), "float32"),
+    )
+
+
+def test_qdq_e5m2_float8_op_infer_struct_info_symbolic():
+    dtype = "e5m2_float8"
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((n, 3), "float32"))
+    dx = relax.Var("dx", R.Tensor((n, 3), dtype))
+    s = relax.Var("s", R.Tensor([3], "float32"))
+    zp = relax.Var("zp", R.Tensor([3], "float16"))
+    _check_inference(
+        bb, relax.op.quantize(x, s, zp, 1, dtype), relax.TensorStructInfo((n, 3), dtype)
+    )
+    _check_inference(
+        bb,
+        relax.op.dequantize(dx, s, zp, 1, "float32"),
+        relax.TensorStructInfo((n, 3), "float32"),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
