Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ class DataType {
code() == DataType::kE5M2Float) &&
bits() == 8;
}
bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }

bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/contrib/tvmjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@

import numpy as np

try:
import ml_dtypes
except ImportError:
ml_dtypes = None

import tvm
from tvm._ffi.libinfo import find_lib_path

Expand Down Expand Up @@ -295,6 +300,20 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
arr = tvm.nd.empty(shape, dtype, device=device)
assert offset + nbytes <= len(raw_data)
buffer_source = raw_data[offset : offset + nbytes]
if dtype == "e4m3_float8":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
)
if dtype == "e5m2_float8":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e5m2
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
)
if encode_format == "f32-to-bf16" and dtype == "float32":
data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
arr.copyfrom(_convert_bf16_to_f32(data))
Expand Down
16 changes: 14 additions & 2 deletions python/tvm/relax/backend/contrib/cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@
from ..utils import has_leaking_intermediate_variables


def _is_supported_dtype(lhs_dtype, rhs_dtype):
def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
"""Check if dtypes in the given workload are supported by cuBLAS BYOC."""
if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
# The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
return out_dtype != "e5m2_float8"
return (
(lhs_dtype == "float16" and rhs_dtype == "float16")
or (lhs_dtype == "float32" and rhs_dtype == "float32")
Expand All @@ -42,10 +45,12 @@ def _check_matmul(context: PatternCheckContext) -> bool:
return False
lhs = context.annotated_expr["lhs"]
rhs = context.annotated_expr["rhs"]
matmul_call = context.annotated_expr["root"]

lhs_dtype = lhs.struct_info.dtype
rhs_dtype = rhs.struct_info.dtype
if not _is_supported_dtype(lhs_dtype, rhs_dtype):
out_dtype = matmul_call.struct_info.dtype
if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
return False

lhs_shape = lhs.struct_info.shape.values
Expand All @@ -62,6 +67,13 @@ def _check_matmul(context: PatternCheckContext) -> bool:
if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0:
# Rows number must be multiples of 4 for IGEMM
return False
elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
# Matrix dimensions must be multiples of 16. This requirement is missing from the cuBLAS
# docs, but it was observed during testing.
if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 16 != 0:
return False
if not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int)) or rhs_shape[-2] % 16 != 0:
return False

lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
Expand Down
27 changes: 16 additions & 11 deletions python/tvm/relax/transform/legalize_ops/qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def te_quantize(
def quantize_compute(*indices):
scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
round_val = te.round(data[indices] / scale_value) + zp_value
scaled = data[indices] / scale_value
round_val = (te.round(scaled) if "int" in out_dtype else scaled) + zp_value
return clip_cast(round_val, out_dtype)

output_shape = data.shape
Expand All @@ -75,15 +76,18 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
Compute datatype: float32

Example of lowering:
qnn.dequantize(data, scale, zp, "float32") -->
sub = subtract(cast(data, "int32"), zp)
out = multiply(cast(sub, "float32"), scale)

qnn.dequantize(data, scale, zp, "float16") -->
sub = subtract(cast(data, "int32"), zp)
mul = multiply(cast(sub, "float32"), cast(scale, "float32"))
clipped_out = clip(mul, float32(-65504.0), float32(65504.0))
out = cast(clipped_out, "float16")

dtype = ["int32"|"float32"]

qnn.dequantize(data, scale, zp, "float32") -->
sub = subtract(cast(data, dtype), zp)
out = multiply(cast(sub, "float32"), scale)

qnn.dequantize(data, scale, zp, "float16") -->
sub = subtract(cast(data, dtype), zp)
mul = multiply(cast(sub, "float32"), cast(scale, "float32"))
clipped_out = clip(mul, float32(-65504.0), float32(65504.0))
out = cast(clipped_out, "float16")
"""
axis = call.attrs.axis
out_dtype = call.attrs.out_dtype
Expand All @@ -96,7 +100,8 @@ def te_dequantize(
def dequantize_compute(*indices):
scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
sub = te.subtract(data[indices].astype("int32"), zp_value)
dtype = "float32" if "float" in data.dtype else "int32"
sub = te.subtract(data[indices].astype(dtype), zp_value)
out = te.multiply(sub, scale_value.astype("float32"))
if out_dtype == "float32":
return out
Expand Down
4 changes: 4 additions & 0 deletions src/relax/backend/contrib/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ inline std::string DType2String(const tvm::DataType dtype) {
std::ostringstream os;
if (dtype.is_float()) {
os << "float";
} else if (dtype.is_e4m3_float8()) {
os << "e4m3_float";
} else if (dtype.is_e5m2_float8()) {
os << "e5m2_float";
} else if (dtype.is_int()) {
os << "int";
} else if (dtype.is_uint()) {
Expand Down
18 changes: 12 additions & 6 deletions src/relax/op/tensor/qdq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ TVM_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize);
StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
const auto* attrs = call->attrs.as<QuantizeAttrs>();
if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) &&
attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16)) {
attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) &&
attrs->out_dtype != DataType::NVFloat8E4M3() &&
attrs->out_dtype != DataType::NVFloat8E5M2()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Unsupported output datatype attribute for operation: '"
<< attrs->out_dtype);
Expand All @@ -73,9 +75,10 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
}

// Check datatype of zero_point param:
if (zp_sinfo->dtype != DataType::Int(8)) {
if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "zero_point param datatype should be int8, but got " << zp_sinfo->dtype);
<< "zero_point param datatype should be 'int8' or 'float16', but got "
<< zp_sinfo->dtype);
}

// Check that "axis" attribute is not out of range:
Expand Down Expand Up @@ -142,7 +145,9 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
// Check input datatype:
if (input_sinfo->dtype != DataType::Int(8) && input_sinfo->dtype != DataType::UInt(8) &&
input_sinfo->dtype != DataType::Int(16) && input_sinfo->dtype != DataType::UInt(16) &&
input_sinfo->dtype != DataType::Int(32)) {
input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype != DataType::NVFloat8E4M3() &&
input_sinfo->dtype != DataType::NVFloat8E5M2() && input_sinfo->dtype != DataType::Float(16) &&
input_sinfo->dtype != DataType::Float(32)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Unsupported input datatype for operation: " << attrs->out_dtype);
}
Expand All @@ -155,9 +160,10 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
}

// Check datatype of zero_point param:
if (zp_sinfo->dtype != DataType::Int(8)) {
if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "zero_point param datatype should be int8, but got " << zp_sinfo->dtype);
<< "zero_point param datatype should be 'int8' or 'float16', but got "
<< zp_sinfo->dtype);
}

// Check that "axis" attribute is not out of range:
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
ab_type = CUDA_R_16F;
} else if (TypeMatch(A->dtype, kDLInt, 8)) {
ab_type = CUDA_R_8I;
} else if (TypeMatch(A->dtype, DataType::TypeCode::kE4M3Float, 8)) {
ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kE4M3Float, 8));
ab_type = CUDA_R_8F_E4M3;
}

if (TypeMatch(C->dtype, kDLFloat, 16)) {
Expand Down
2 changes: 2 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ PrimExpr max_value(const DataType& dtype, Span span) {
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::max(), span);
} else if (dtype.is_float8()) {
// according to https://arxiv.org/pdf/2209.05433.pdf
if (dtype.code() == DataType::TypeCode::kE5M2Float) {
return FloatImm(dtype, 57344.0, span);
} else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
Expand Down Expand Up @@ -303,6 +304,7 @@ PrimExpr min_value(const DataType& dtype, Span span) {
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
} else if (dtype.is_float8()) {
// according to https://arxiv.org/pdf/2209.05433.pdf
if (dtype.code() == DataType::TypeCode::kE5M2Float) {
return FloatImm(dtype, -57344.0, span);
} else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
Expand Down
59 changes: 59 additions & 0 deletions tests/python/relax/test_codegen_cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from tvm.relax.testing import get_relax_matmul_module
from tvm.script import relax as R

try:
import ml_dtypes
except ImportError:
ml_dtypes = None


@pytest.fixture(autouse=True)
def reset_seed():
Expand Down Expand Up @@ -226,6 +231,60 @@ def test_matmul_igemm_offload(
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, out_dtype",
[
((10, 32), (64, 32), True, "float32"),
((32, 16), (32, 16), True, "float16"),
((2, 10, 32), (2, 64, 32), True, "float32"),
],
)
def test_matmul_fp8_offload(
x_shape,
y_shape,
transpose_y,
out_dtype,
):
in_dtype = "e4m3_float8"
mod = get_relax_matmul_module(
x_shape,
y_shape,
in_dtype,
out_dtype,
bias_shape=None,
transposed_y=transpose_y,
activation=None,
)
numpytype = "float8_e4m3fn"
x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
z = np.swapaxes(y, -2, -1) if transpose_y else y
args = (x, y)

out = get_result_with_relax_cublas_offload(mod, args)
ref_out = np.matmul(x, z).astype(out_dtype)

tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize(
"M, N, K, out_dtype, partition_done",
[
(15, 64, 32, "float32", True),
(15, 64, 32, "e4m3_float8", True),
(15, 64, 32, "e5m2_float8", False),
(16, 32, 60, "float32", False),
(16, 30, 64, "float32", False),
],
)
def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, partition_done):
mod = get_relax_matmul_module((M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=True)
mod = partition_for_cublas(mod)
func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
assert func_name in mod["main"].script()


def test_cublas_partition_matmul_without_bias():
# cuBLAS does not handle 2D bias (residual input)
mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
Expand Down
37 changes: 37 additions & 0 deletions tests/python/relax/test_op_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,42 @@ def test_qdq_op_infer_struct_info_symbolic():
)


def test_qdq_e4m3_float8_op_infer_struct_info_symbolic():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor((n, 3), "float32"))
dx = relax.Var("dx", R.Tensor((n, 3), "e4m3_float8"))
s = relax.Var("s", R.Tensor([3], "float32"))
zp = relax.Var("zp", R.Tensor([3], "float16"))
_check_inference(
bb,
relax.op.quantize(x, s, zp, 1, "e4m3_float8"),
relax.TensorStructInfo((n, 3), "e4m3_float8"),
)
_check_inference(
bb,
relax.op.dequantize(dx, s, zp, 1, "float32"),
relax.TensorStructInfo((n, 3), "float32"),
)


def test_qdq_e5m2_float8_op_infer_struct_info_symbolic():
dtype = "e5m2_float8"
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor((n, 3), "float32"))
dx = relax.Var("dx", R.Tensor((n, 3), dtype))
s = relax.Var("s", R.Tensor([3], "float32"))
zp = relax.Var("zp", R.Tensor([3], "float16"))
_check_inference(
bb, relax.op.quantize(x, s, zp, 1, dtype), relax.TensorStructInfo((n, 3), dtype)
)
_check_inference(
bb,
relax.op.dequantize(dx, s, zp, 1, "float32"),
relax.TensorStructInfo((n, 3), "float32"),
)


if __name__ == "__main__":
tvm.testing.main()