Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class DataType {
kBFloat = kDLBfloat,
kE4M3Float = 6U,
kE5M2Float = 7U,
kE2M1Float = 8U,
kFloat4E2M1Fn = 8U,
kCustomBegin = 129
};
/*! \brief default constructor */
Expand Down Expand Up @@ -88,7 +88,7 @@ class DataType {
if (code == kE4M3Float || code == kE5M2Float) {
ICHECK_EQ(bits, 8);
}
if (code == kE2M1Float) {
if (code == kFloat4E2M1Fn) {
ICHECK_EQ(bits, 4);
}
}
Expand Down Expand Up @@ -131,12 +131,10 @@ class DataType {
bits() == 8;
}
/*! \return whether type is a float4 type. */
bool is_float4() const { return code() == DataType::kE2M1Float && bits() == 4; }
bool is_float4() const { return code() == DataType::kFloat4E2M1Fn && bits() == 4; }
bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }

bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }

bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && bits() == 4); }
bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4E2M1Fn && bits() == 4); }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
Expand Down Expand Up @@ -262,11 +260,11 @@ class DataType {
*/
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
/*!
* \brief Construct NV float4 e2m1 datatype.
* \brief Construct NV float4_e2m1fn datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat4E2M1(int lanes = 1) { return DataType(kE2M1Float, 4, lanes); }
static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4E2M1Fn, 4, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes.
Expand Down Expand Up @@ -313,7 +311,7 @@ inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) {
dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN()) {
return 1;
}
ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
Expand Down Expand Up @@ -399,8 +397,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
return "e4m3_float";
case DataType::kE5M2Float:
return "e5m2_float";
case DataType::kE2M1Float:
return "e2m1_float";
case DataType::kFloat4E2M1Fn:
return "float4_e2m1fn";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
Expand Down Expand Up @@ -458,6 +456,18 @@ inline DLDataType String2DLDataType(std::string s) {
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt;
scan = s.c_str() + 4;
} else if (s.substr(0, 13) == "float4_e2m1fn") {
// Avoid being treated as "float"
t.code = DataType::kFloat4E2M1Fn;
t.bits = 4;
scan = s.c_str() + 13;
char* endpt = nullptr;
if (*scan == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(scan + 1, &endpt, 10));
scan = endpt;
}
ICHECK(scan == s.c_str() + s.length()) << "unknown type " << s;
return t;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat;
scan = s.c_str() + 5;
Expand All @@ -482,10 +492,6 @@ inline DLDataType String2DLDataType(std::string s) {
t.code = DataType::kE5M2Float;
t.bits = 8;
scan = s.c_str() + 10;
} else if (s.substr(0, 10) == "e2m1_float") {
t.code = DataType::kE2M1Float;
t.bits = 4;
scan = s.c_str() + 10;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, DataType::NVFloat4E2M1);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1fn, DataType::NVFloat4E2M1FN);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
Expand Down
23 changes: 15 additions & 8 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class DataTypeCode(object):
BFLOAT = 4
E4M3Float = 6
E5M2Float = 7
E2M1Float = 8
FLOAT4E2M1FN = 8


class DataType(ctypes.Structure):
Expand All @@ -83,7 +83,7 @@ class DataType(ctypes.Structure):
DataTypeCode.BFLOAT: "bfloat",
DataTypeCode.E4M3Float: "e4m3_float",
DataTypeCode.E5M2Float: "e5m2_float",
DataTypeCode.E2M1Float: "e2m1_float",
DataTypeCode.FLOAT4E2M1FN: "float4_e2m1fn",
}
NUMPY2STR = {
np.dtype(np.bool_): "bool",
Expand Down Expand Up @@ -114,7 +114,7 @@ class DataType(ctypes.Structure):
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
"e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4, "lanes": 1},
"float4_e2m1fn": {"type_code": DataTypeCode.FLOAT4E2M1FN, "bits": 4, "lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
Expand Down Expand Up @@ -155,6 +155,11 @@ def __init__(self, type_str):
elif head.startswith("uint"):
self.type_code = DataTypeCode.UINT
head = head[4:]
elif head.startswith("float4_e2m1fn"):
# Avoid being treated as "float"
self.type_code = DataTypeCode.FLOAT4E2M1FN
bits = 4
head = ""
elif head.startswith("float"):
self.type_code = DataTypeCode.FLOAT
head = head[5:]
Expand All @@ -171,9 +176,6 @@ def __init__(self, type_str):
elif head.startswith("e5m2_float"):
self.type_code = DataTypeCode.E5M2Float
head = head[10:]
elif head.startswith("e2m1_float"):
self.type_code = DataTypeCode.E2M1Float
head = head[10:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
Expand Down Expand Up @@ -201,7 +203,12 @@ def __repr__(self):
import tvm.runtime._ffi_api

type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
x = "%s%d" % (type_name, self.bits)
if self.type_code in [
DataTypeCode.FLOAT4E2M1FN,
]:
x = type_name
else:
x = "%s%d" % (type_name, self.bits)
lanes_as_int = ctypes.c_int16(self.lanes).value
if lanes_as_int > 1:
x += "x%d" % self.lanes
Expand Down Expand Up @@ -238,7 +245,7 @@ def itemsize(self):
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn"

RPC_SESS_MASK = 128

Expand Down
14 changes: 9 additions & 5 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def copyfrom(self, source_array):
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
if dtype.startswith("e2m1_float4"):
if self.dtype.startswith("float4_e2m1fn") and self.dtype != "float4_e2m1fn":
# float4_e2m1fn in numpy is not packed.
# So we need to pack the input data when converting to vectorized float4_e2m1fn type.
data_bits = source_array.view(dtype="uint8")
if data_bits.size % 2:
data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
Expand Down Expand Up @@ -261,22 +263,24 @@ def numpy(self):
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
)
if dtype == "e2m1_float4":
if dtype == "float4_e2m1fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float4_e2m1fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e2m1_float4 array to numpy."
"ml_dtypes is not installed, cannot convert float4_e2m1fn array to numpy."
)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
if old_dtype.startswith("e2m1_float4"):
if old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn":
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
else:
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
if old_dtype == "int4" or (
old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn"
):
length = np_arr.size
np_arr = np_arr.view("int8")
np_arr_ret = np.empty((length,), dtype="int8")
Expand Down
31 changes: 16 additions & 15 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import functools
import inspect
from numbers import Integral
import sys
from numbers import Integral
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# isort: off
Expand All @@ -29,8 +29,7 @@

import numpy as np # type: ignore

from tvm import tir
from tvm import ir
from tvm import ir, tir
from tvm.ir import Type
from tvm.ir.base import deprecated
from tvm.runtime import String, convert, ndarray
Expand Down Expand Up @@ -1457,12 +1456,13 @@ def func(
e5m2_float8x32 = func_gen(("E5M2Float8x32"))
e5m2_float8x64 = func_gen(("E5M2Float8x64"))

e2m1_float4 = func_gen(("E2M1Float4"))
e2m1_float4x4 = func_gen(("E2M1Float4x4"))
e2m1_float4x8 = func_gen(("E2M1Float4x8"))
e2m1_float4x16 = func_gen(("E2M1Float4x16"))
e2m1_float4x32 = func_gen(("E2M1Float4x32"))
e2m1_float4x64 = func_gen(("E2M1Float4x64"))
float4_e2m1fn = func_gen(("Float4E2M1fn"))
float4_e2m1fnx2 = func_gen(("Float4E2M1fnx2"))
float4_e2m1fnx4 = func_gen(("Float4E2M1fnx4"))
float4_e2m1fnx8 = func_gen(("Float4E2M1fnx8"))
float4_e2m1fnx16 = func_gen(("Float4E2M1fnx16"))
float4_e2m1fnx32 = func_gen(("Float4E2M1fnx32"))
float4_e2m1fnx64 = func_gen(("Float4E2M1fnx64"))


# pylint: enable=invalid-name
Expand Down Expand Up @@ -2013,37 +2013,38 @@ def wrapped(*args, **kwargs):
"uint64x64",
"e4m3_float8",
"e5m2_float8",
"e2m1_float4",
"float4_e2m1fn",
"float16",
"float32",
"float64",
"float4_e2m1fnx2",
"e4m3_float8x4",
"e5m2_float8x4",
"e2m1_float4x4",
"float4_e2m1fnx4",
"float16x4",
"float32x4",
"float64x4",
"e4m3_float8x8",
"e5m2_float8x8",
"e2m1_float4x8",
"float4_e2m1fnx8",
"float16x8",
"float32x8",
"float64x8",
"e4m3_float8x16",
"e5m2_float8x16",
"e2m1_float4x16",
"float4_e2m1fnx16",
"float16x16",
"float32x16",
"float64x16",
"e4m3_float8x32",
"e5m2_float8x32",
"e2m1_float4x32",
"float4_e2m1fnx32",
"float16x32",
"float32x32",
"float64x32",
"e4m3_float8x64",
"e5m2_float8x64",
"e2m1_float4x64",
"float4_e2m1fnx64",
"float16x64",
"float32x64",
"float64x64",
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ inline void VerifyDataType(DLDataType dtype) {
return;
else if (dtype.bits == 4 && dtype.code == kDLInt)
return;
else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float)
else if (dtype.bits == 4 && dtype.code == DataType::kFloat4E2M1Fn)
return;
else
ICHECK_EQ(dtype.bits % 8, 0);
Expand Down
4 changes: 2 additions & 2 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1fn").set_body_typed(Float4E2M1fn);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1fn", Float4E2M1fn);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
}
} else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
etype = llvm::Type::getInt8Ty(*ctx);
} else if (dtype.code() == DataType::kE2M1Float) {
} else if (dtype.code() == DataType::kFloat4E2M1Fn) {
etype = llvm::Type::getIntNTy(*ctx, 4);
}
if (!dtype.is_scalar()) {
Expand Down
8 changes: 7 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,11 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI
}
}

if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
// A float4_e2m1fn element has 4 bits, which is an incomplete byte.
// So we cannot vector load it.
can_vector_load = false;
}
if (can_vector_load) {
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
Expand Down Expand Up @@ -839,7 +844,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
} else {
arith::PVar<PrimExpr> base;

if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) {
if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr) &&
!value_dtype.is_float4_e2m1fn()) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
} else {
Expand Down
Loading
Loading