Skip to content

Commit 8d7c83c

Browse files
change NV naming; prevent CUDA codegen on unsupported dtypes
Co-authored-by: DerrickYLJ <[email protected]>
1 parent 1c36432 commit 8d7c83c

File tree

6 files changed

+45
-50
lines changed

6 files changed

+45
-50
lines changed

include/tvm/runtime/data_type.h

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -295,83 +295,83 @@ class DataType {
295295
*/
296296
static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
297297
/*!
298-
* \brief Construct NV float8 e3m4 datatype.
298+
* \brief Construct float8 e3m4 datatype.
299299
* \param lanes The number of lanes
300300
* \return The constructed data type.
301301
*/
302-
static DataType NVFloat8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); }
302+
static DataType Float8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); }
303303

304304
/*!
305-
* \brief Construct NV float8 e4m3 datatype.
305+
* \brief Construct float8 e4m3 datatype.
306306
* \param lanes The number of lanes
307307
* \return The constructed data type.
308308
*/
309-
static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); }
309+
static DataType Float8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); }
310310

311311
/*!
312-
* \brief Construct NV float8 e4m3b11fnuz datatype.
312+
* \brief Construct float8 e4m3b11fnuz datatype.
313313
* \param lanes The number of lanes
314314
* \return The constructed data type.
315315
*/
316-
static DataType NVFloat8E4M3B11FNUZ(int lanes = 1) {
316+
static DataType Float8E4M3B11FNUZ(int lanes = 1) {
317317
return DataType(kFloat8_e4m3b11fnuz, 8, lanes);
318318
}
319319

320320
/*!
321-
* \brief Construct NV float8 e4m3fn datatype.
321+
* \brief Construct float8 e4m3fn datatype.
322322
* \param lanes The number of lanes
323323
* \return The constructed data type.
324324
*/
325-
static DataType NVFloat8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); }
325+
static DataType Float8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); }
326326

327327
/*!
328-
* \brief Construct NV float8 e4m3fnuz datatype.
328+
* \brief Construct float8 e4m3fnuz datatype.
329329
* \param lanes The number of lanes
330330
* \return The constructed data type.
331331
*/
332-
static DataType NVFloat8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); }
332+
static DataType Float8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); }
333333

334334
/*!
335-
* \brief Construct NV float8 e5m2 datatype.
335+
* \brief Construct float8 e5m2 datatype.
336336
* \param lanes The number of lanes
337337
* \return The constructed data type.
338338
*/
339-
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); }
339+
static DataType Float8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); }
340340

341341
/*!
342-
* \brief Construct NV float8 e5m2fnuz datatype.
342+
* \brief Construct float8 e5m2fnuz datatype.
343343
* \param lanes The number of lanes
344344
* \return The constructed data type.
345345
*/
346-
static DataType NVFloat8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); }
346+
static DataType Float8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); }
347347

348348
/*!
349-
* \brief Construct NV float8 e8m0fnu datatype.
349+
* \brief Construct float8 e8m0fnu datatype.
350350
* \param lanes The number of lanes
351351
* \return The constructed data type.
352352
*/
353-
static DataType NVFloat8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); }
353+
static DataType Float8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); }
354354

355355
/*!
356-
* \brief Construct NV float6 e2m3fn datatype.
356+
* \brief Construct float6 e2m3fn datatype.
357357
* \param lanes The number of lanes
358358
* \return The constructed data type.
359359
*/
360-
static DataType NVFloat6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); }
360+
static DataType Float6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); }
361361

362362
/*!
363-
* \brief Construct NV float6 e3m2fn datatype.
363+
* \brief Construct float6 e3m2fn datatype.
364364
* \param lanes The number of lanes
365365
* \return The constructed data type.
366366
*/
367-
static DataType NVFloat6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); }
367+
static DataType Float6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); }
368368

369369
/*!
370-
* \brief Construct NV float4 e2m1fn datatype.
370+
* \brief Construct float4 e2m1fn datatype.
371371
* \param lanes The number of lanes
372372
* \return The constructed data type.
373373
*/
374-
static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
374+
static DataType Float4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
375375
/*!
376376
* \brief Construct a bool type.
377377
* \param lanes The number of lanes.
@@ -418,8 +418,8 @@ inline int GetVectorBytes(DataType dtype) {
418418
int data_bits = dtype.bits() * dtype.lanes();
419419
// allow bool to exist
420420
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
421-
dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN() ||
422-
dtype == DataType::NVFloat6E2M3FN() || dtype == DataType::NVFloat6E3M2FN()) {
421+
dtype == DataType::Int(1) || dtype == DataType::Float4E2M1FN() ||
422+
dtype == DataType::Float6E2M3FN() || dtype == DataType::Float6E3M2FN()) {
423423
return 1;
424424
}
425425
ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -504,20 +504,19 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
504504
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
505505
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
506506

507-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, DataType::NVFloat8E3M4);
508-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, DataType::NVFloat8E4M3);
509-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ,
510-
DataType::NVFloat8E4M3B11FNUZ);
511-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::NVFloat8E4M3FN);
512-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, DataType::NVFloat8E4M3FNUZ);
513-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::NVFloat8E5M2);
514-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, DataType::NVFloat8E5M2FNUZ);
515-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, DataType::NVFloat8E8M0FNU);
516-
517-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, DataType::NVFloat6E2M3FN);
518-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::NVFloat6E3M2FN);
519-
520-
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::NVFloat4E2M1FN);
507+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, DataType::Float8E3M4);
508+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, DataType::Float8E4M3);
509+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ, DataType::Float8E4M3B11FNUZ);
510+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::Float8E4M3FN);
511+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, DataType::Float8E4M3FNUZ);
512+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::Float8E5M2);
513+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, DataType::Float8E5M2FNUZ);
514+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, DataType::Float8E8M0FNU);
515+
516+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, DataType::Float6E2M3FN);
517+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::Float6E3M2FN);
518+
519+
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::Float4E2M1FN);
521520

522521
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
523522
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());

src/relax/op/tensor/qdq.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
5050
const auto* attrs = call->attrs.as<QuantizeAttrs>();
5151
if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) &&
5252
attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) &&
53-
attrs->out_dtype != DataType::NVFloat8E4M3() &&
54-
attrs->out_dtype != DataType::NVFloat8E5M2()) {
53+
attrs->out_dtype != DataType::Float8E4M3FN() && attrs->out_dtype != DataType::Float8E5M2()) {
5554
ctx->ReportFatal(Diagnostic::Error(call)
5655
<< "Unsupported output datatype attribute for operation: '"
5756
<< attrs->out_dtype);
@@ -145,8 +144,8 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
145144
// Check input datatype:
146145
if (input_sinfo->dtype != DataType::Int(8) && input_sinfo->dtype != DataType::UInt(8) &&
147146
input_sinfo->dtype != DataType::Int(16) && input_sinfo->dtype != DataType::UInt(16) &&
148-
input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype != DataType::NVFloat8E4M3() &&
149-
input_sinfo->dtype != DataType::NVFloat8E5M2() && input_sinfo->dtype != DataType::Float(16) &&
147+
input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype != DataType::Float8E4M3FN() &&
148+
input_sinfo->dtype != DataType::Float8E5M2() && input_sinfo->dtype != DataType::Float(16) &&
150149
input_sinfo->dtype != DataType::Float(32)) {
151150
ctx->ReportFatal(Diagnostic::Error(call)
152151
<< "Unsupported input datatype for operation: " << attrs->out_dtype);

src/target/source/codegen_cuda.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,9 @@ std::string GetFP8Type(DataType type) {
6060
}
6161
stream << "__nv_fp8";
6262
std::string suffix;
63-
if (type.code() == DataType::kFloat8_e4m3fn || type.code() == DataType::kFloat8_e4m3fnuz ||
64-
type.code() == DataType::kFloat8_e4m3 || type.code() == DataType::kFloat8_e4m3b11fnuz) {
63+
if (type.code() == DataType::kFloat8_e4m3fn) {
6564
suffix = "_e4m3";
66-
} else if (type.code() == DataType::kFloat8_e5m2 || type.code() == DataType::kFloat8_e5m2fnuz) {
65+
} else if (type.code() == DataType::kFloat8_e5m2) {
6766
suffix = "_e5m2";
6867
} else if (type.code() == DataType::kFloat8_e8m0fnu) {
6968
suffix = "_e8m0";

src/tir/transforms/vectorize_loop.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
503503
if (value.dtype().is_scalable_vector()) {
504504
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value});
505505
} else {
506-
int new_lanes = (op->dtype != DataType::NVFloat4E2M1FN() &&
507-
op->args[0].dtype() != DataType::NVFloat4E2M1FN())
506+
int new_lanes = (op->dtype != DataType::Float4E2M1FN() &&
507+
op->args[0].dtype() != DataType::Float4E2M1FN())
508508
? (value.dtype().bits() * value.dtype().lanes()) / op->dtype.bits()
509509
: value.dtype().lanes();
510510
return Call(op->dtype.with_lanes(new_lanes), op->op, {value});

tests/python/codegen/test_target_codegen_cuda_fp8.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@
4040
"input",
4141
[
4242
("float8_e4m3fn", "__nv_fp8_e4m3"),
43-
("float8_e4m3fnuz", "__nv_fp8_e4m3"),
4443
("float8_e5m2", "__nv_fp8_e5m2"),
45-
("float8_e5m2fnuz", "__nv_fp8_e5m2"),
4644
],
4745
)
4846
@tvm.testing.requires_cuda_compute_version(10)
@@ -90,7 +88,7 @@ def add(
9088

9189
@pytest.mark.parametrize(
9290
"dtype",
93-
["float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", "float8_e8m0fnu"],
91+
["float8_e4m3fn", "float8_e5m2", "float8_e8m0fnu"],
9492
)
9593
@tvm.testing.requires_cuda_compute_version(10)
9694
def test_fp8_packing(dtype):

0 commit comments

Comments
 (0)