Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ci/jenkins/docker-images.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
# This data file is read when Jenkins runs a job to determine the docker images.
[jenkins]
ci_arm: tlcpack/ci-arm:20250226-223225-63bc315f
ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f
ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f
ci_cpu: tlcpack/ci_cpu:20250226-223225-63bc315f_patch
ci_gpu: tlcpack/ci-gpu:20250226-223225-63bc315f_patch
ci_hexagon: tlcpack/ci-hexagon:20250226-223225-63bc315f
ci_i386: tlcpack/ci-i386:20250226-223225-63bc315f
ci_lint: tlcpack/ci-lint:20250226-223225-63bc315f
Expand Down
4 changes: 2 additions & 2 deletions ci/jenkins/unity_jenkinsfile.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import org.jenkinsci.plugins.pipeline.modeldefinition.Utils

// NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f'
ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f'
ci_gpu = 'tlcpack/ci-gpu:20250226-223225-63bc315f_patch'
ci_cpu = 'tlcpack/ci-cpu:20250226-223225-63bc315f_patch'
// <--- End of regex-scanned config.

// Parameters to allow overriding (in Jenkins UI), the images
Expand Down
22 changes: 21 additions & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class DataType {
kBFloat = kDLBfloat,
kE4M3Float = 6U,
kE5M2Float = 7U,
kE2M1Float = 8U,
kCustomBegin = 129
};
/*! \brief default constructor */
Expand Down Expand Up @@ -87,6 +88,9 @@ class DataType {
if (code == kE4M3Float || code == kE5M2Float) {
ICHECK_EQ(bits, 8);
}
if (code == kE2M1Float) {
ICHECK_EQ(bits, 4);
}
}
/*! \return The type code. */
int code() const { return static_cast<int>(data_.code); }
Expand Down Expand Up @@ -126,9 +130,13 @@ class DataType {
code() == DataType::kE5M2Float) &&
bits() == 8;
}
/*! \return whether type is a 4-bit float type (the only 4-bit float code checked is e2m1). */
bool is_float4() const {
  return bits() == 4 && code() == DataType::kE2M1Float;
}
/*! \return whether type is an 8-bit float with the e4m3 type code. */
bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }

/*! \return whether type is an 8-bit float with the e5m2 type code. */
bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }

/*! \return whether type is a 4-bit float with the e2m1 type code. */
bool is_e2m1_float4() const { return (code() == DataType::kE2M1Float && bits() == 4); }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
Expand Down Expand Up @@ -253,6 +261,12 @@ class DataType {
* \return The constructed data type.
*/
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
/*!
 * \brief Construct an NV float4 e2m1 datatype (4 bits per element).
 * \param lanes The number of lanes.
 * \return The constructed data type.
 */
static DataType NVFloat4E2M1(int lanes = 1) {
  constexpr int kNumBits = 4;
  return DataType(kE2M1Float, kNumBits, lanes);
}
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes.
Expand Down Expand Up @@ -299,7 +313,7 @@ inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
dtype == DataType::Int(1)) {
dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1()) {
return 1;
}
ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
Expand Down Expand Up @@ -385,6 +399,8 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
return "e4m3_float";
case DataType::kE5M2Float:
return "e5m2_float";
case DataType::kE2M1Float:
return "e2m1_float";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
Expand Down Expand Up @@ -466,6 +482,10 @@ inline DLDataType String2DLDataType(std::string s) {
t.code = DataType::kE5M2Float;
t.bits = 8;
scan = s.c_str() + 10;
} else if (s.substr(0, 10) == "e2m1_float") {
t.code = DataType::kE2M1Float;
t.bits = 4;
scan = s.c_str() + 10;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,8 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E4M3Float8, DataType::NVFloat8E4M3);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E5M2Float8, DataType::NVFloat8E5M2);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(E2M1Float4, DataType::NVFloat4E2M1);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
}
}
if (t.is_float() || t.is_bfloat16() || t.is_float8())
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float4())
return FloatImm(t, static_cast<double>(value), span);
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class DataTypeCode(object):
BFLOAT = 4
E4M3Float = 6
E5M2Float = 7
E2M1Float = 8


class DataType(ctypes.Structure):
Expand All @@ -82,6 +83,7 @@ class DataType(ctypes.Structure):
DataTypeCode.BFLOAT: "bfloat",
DataTypeCode.E4M3Float: "e4m3_float",
DataTypeCode.E5M2Float: "e5m2_float",
DataTypeCode.E2M1Float: "e2m1_float",
}
NUMPY2STR = {
np.dtype(np.bool_): "bool",
Expand Down Expand Up @@ -112,6 +114,7 @@ class DataType(ctypes.Structure):
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
"e2m1_float4": {"type_code": DataTypeCode.E2M1Float, "bits": 4, "lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
Expand Down Expand Up @@ -168,6 +171,9 @@ def __init__(self, type_str):
elif head.startswith("e5m2_float"):
self.type_code = DataTypeCode.E5M2Float
head = head[10:]
elif head.startswith("e2m1_float"):
self.type_code = DataTypeCode.E2M1Float
head = head[10:]
elif head.startswith("custom"):
# pylint: disable=import-outside-toplevel
import tvm.runtime._ffi_api
Expand Down Expand Up @@ -232,6 +238,7 @@ def itemsize(self):
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "e2m1_float4"

RPC_SESS_MASK = 128

Expand Down
16 changes: 16 additions & 0 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,19 @@ def have_fp8(compute_version):
if major >= 9:
return True
return False


@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp4")
def have_fp4(compute_version):
    """Report whether the given compute capability provides fp4 support.

    Parameters
    ----------
    compute_version : str
        GPU capability, in the form accepted by ``parse_compute_version``.

    Returns
    -------
    supported : bool
        True when the parsed capability is exactly 10.0, False otherwise.
    """
    major, minor = parse_compute_version(compute_version)
    # fp4 is supported on Blackwell (10.0).
    # NOTE(review): the previous comment said "or later architectures", but the
    # check below admits only 10.0 -- confirm which of the two is intended.
    return bool(major == 10 and minor == 0)
25 changes: 22 additions & 3 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ def copyfrom(self, source_array):
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
if dtype.startswith("e2m1_float4"):
data_bits = source_array.view(dtype="uint8")
if data_bits.size % 2:
data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
data_bits = data_bits.reshape(-1, 2)
packed = ((data_bits[:, 0] & 0x0F) << 4) | (data_bits[:, 1] & 0x0F)
source_array = packed.astype(np.int8)
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
Expand Down Expand Up @@ -254,20 +261,32 @@ def numpy(self):
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
)
if dtype == "e2m1_float4":
if ml_dtypes is not None:
dtype = ml_dtypes.float4_e2m1fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e2m1_float4 array to numpy."
)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
if old_dtype.startswith("e2m1_float4"):
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize // 2)
else:
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
if old_dtype == "int4":
if old_dtype == "int4" or old_dtype.startswith("e2m1_float4"):
length = np_arr.size
np_arr = np_arr.view("int8")
np_arr_ret = np.empty((length,), dtype="int8")
np_arr = np_arr.reshape((length,))
old_index = np.bitwise_and(np_arr, 0x0F)
even_index = np.bitwise_and(np_arr >> 4, 0x0F)
np_arr_ret[1::2] = old_index[0 : length // 2]
np_arr_ret[0::2] = even_index[0 : length // 2]
return np_arr_ret.reshape(shape)
return np_arr_ret.reshape(shape).view(dtype)

return np_arr

def copyto(self, target, mem_scope=None):
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,14 @@ def func(
e5m2_float8x32 = func_gen(("E5M2Float8x32"))
e5m2_float8x64 = func_gen(("E5M2Float8x64"))

e2m1_float4 = func_gen(("E2M1Float4"))
e2m1_float4x4 = func_gen(("E2M1Float4x4"))
e2m1_float4x8 = func_gen(("E2M1Float4x8"))
e2m1_float4x16 = func_gen(("E2M1Float4x16"))
e2m1_float4x32 = func_gen(("E2M1Float4x32"))
e2m1_float4x64 = func_gen(("E2M1Float4x64"))


# pylint: enable=invalid-name


Expand Down Expand Up @@ -2005,31 +2013,37 @@ def wrapped(*args, **kwargs):
"uint64x64",
"e4m3_float8",
"e5m2_float8",
"e2m1_float4",
"float16",
"float32",
"float64",
"e4m3_float8x4",
"e5m2_float8x4",
"e2m1_float4x4",
"float16x4",
"float32x4",
"float64x4",
"e4m3_float8x8",
"e5m2_float8x8",
"e2m1_float4x8",
"float16x8",
"float32x8",
"float64x8",
"e4m3_float8x16",
"e5m2_float8x16",
"e2m1_float4x16",
"float16x16",
"float32x16",
"float64x16",
"e4m3_float8x32",
"e5m2_float8x32",
"e2m1_float4x32",
"float16x32",
"float32x32",
"float64x32",
"e4m3_float8x64",
"e5m2_float8x64",
"e2m1_float4x64",
"float16x64",
"float32x64",
"float64x64",
Expand Down
7 changes: 6 additions & 1 deletion src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";

ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() ||
ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4() ||
dtype.code() >= DataType::kCustomBegin)
<< "ValueError: FloatImm supports only float, but " << dtype << " was supplied.";

Expand All @@ -137,6 +137,11 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) {
<< dtype;
ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " exceeds maximum of "
<< dtype;
} else if (dtype.is_float4()) {
ICHECK_GE(value, -support::kMaxE2M1)
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, support::kMaxE2M1)
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
}
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/registry.h>

#include "runtime_base.h"
#include "tvm/runtime/data_type.h"

extern "C" {
// C-mangled dlpack deleter.
Expand All @@ -53,6 +54,8 @@ inline void VerifyDataType(DLDataType dtype) {
return;
else if (dtype.bits == 4 && dtype.code == kDLInt)
return;
else if (dtype.bits == 4 && dtype.code == DataType::kE2M1Float)
return;
else
ICHECK_EQ(dtype.bits % 8, 0);
}
Expand Down
3 changes: 3 additions & 0 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,9 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.E2M1Float4").set_body_typed(E2M1Float4);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E2M1Float4", E2M1Float4);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
Expand Down
3 changes: 3 additions & 0 deletions src/support/scalars.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ constexpr double kMaxE4M3 = 448;
// See https://arxiv.org/pdf/2209.05433.pdf
constexpr double kMaxE5M2 = 57344;

// Maximum value of the 4-bit e2m1 float format:
// 2^2 * (1 + 1/2)
constexpr double kMaxE2M1 = 6.0;

} // namespace support
} // namespace tvm

Expand Down
2 changes: 2 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
}
} else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
etype = llvm::Type::getInt8Ty(*ctx);
} else if (dtype.code() == DataType::kE2M1Float) {
etype = llvm::Type::getIntNTy(*ctx, 4);
}
if (!dtype.is_scalar()) {
#if TVM_LLVM_VERSION >= 110
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp
}

std::string index_str = PrintExpr(index);
if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) {
// This is a special case, because CodegenCUDA::PrintType()
// returns "int" for bool and for 4-bit integers. In most cases,
// we divide by the number of lanes to determine the index.
Expand Down
Loading
Loading