diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index daea2a23d6fbe..cd8d3ee0af72b 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -61,17 +61,6 @@ class Builder { Attribute metadata = Attribute()); // Types. - FloatType getFloat4E2M1FNType(); - FloatType getFloat6E2M3FNType(); - FloatType getFloat6E3M2FNType(); - FloatType getFloat8E5M2Type(); - FloatType getFloat8E4M3Type(); - FloatType getFloat8E4M3FNType(); - FloatType getFloat8E5M2FNUZType(); - FloatType getFloat8E4M3FNUZType(); - FloatType getFloat8E4M3B11FNUZType(); - FloatType getFloat8E3M4Type(); - FloatType getFloat8E8M0FNUType(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getTF32Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index fc50b28c09e41..4f09d2e41e7ce 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -85,6 +85,12 @@ class Builtin_FloatType]> { +} + +// Float types that are cached in MLIRContext. +class Builtin_CachedFloatType declaredInterfaceMethods = []> + : Builtin_FloatType { let extraClassDeclaration = [{ static }] # name # [{Type get(MLIRContext *context); }]; @@ -326,7 +332,7 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> { //===----------------------------------------------------------------------===// // BFloat16Type -def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16", +def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16", /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> { let summary = "bfloat16 floating-point type"; } @@ -334,7 +340,7 @@ def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16", //===----------------------------------------------------------------------===// // Float16Type -def Builtin_Float16 : Builtin_FloatType<"Float16", "f16", +def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16", /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> { let summary = "16-bit floating-point type"; } @@ -342,14 +348,14 @@ def Builtin_Float16 : Builtin_FloatType<"Float16", "f16", //===----------------------------------------------------------------------===// // FloatTF32Type -def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> { +def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> { let summary = "TF32 floating-point type"; } //===----------------------------------------------------------------------===// // Float32Type -def Builtin_Float32 : Builtin_FloatType<"Float32", "f32", +def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32", /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> { let summary = "32-bit floating-point type"; } @@ -357,21 +363,21 @@ def Builtin_Float32 : Builtin_FloatType<"Float32", "f32", //===----------------------------------------------------------------------===// // Float64Type -def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> { +def Builtin_Float64 : Builtin_CachedFloatType<"Float64", "f64"> { let summary = "64-bit floating-point type"; } //===----------------------------------------------------------------------===// // Float80Type -def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> { +def Builtin_Float80 : Builtin_CachedFloatType<"Float80", "f80"> { let summary = "80-bit floating-point type"; } //===----------------------------------------------------------------------===// // Float128Type -def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> { +def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> { let summary = "128-bit floating-point type"; } diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index b9f8c1ed19470..6f52195c1d7c9 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -330,31 +330,31 @@ def F80 : F<80>; def F128 : F<128>; def BF16 : Type, "bfloat16 type">, - BuildableType<"$_builder.getBF16Type()">; + BuildableType<"$_builder.getType()">; def TF32 : Type, "tf32 type">, - BuildableType<"$_builder.getTF32Type()">; + BuildableType<"$_builder.getType()">; def F8E4M3FN : Type, "f8E4M3FN type">, - BuildableType<"$_builder.getFloat8E4M3FNType()">; + BuildableType<"$_builder.getType()">; def F8E5M2 : Type, "f8E5M2 type">, - BuildableType<"$_builder.getFloat8E5M2Type()">; + BuildableType<"$_builder.getType()">; def F8E4M3 : Type, "f8E4M3 type">, - BuildableType<"$_builder.getFloat8E4M3Type()">; + BuildableType<"$_builder.getType()">; def F8E4M3FNUZ : Type, "f8E4M3FNUZ type">, - BuildableType<"$_builder.getFloat8E4M3FNUZType()">; + BuildableType<"$_builder.getType()">; def F8E4M3B11FNUZ : Type, "f8E4M3B11FNUZ type">, - BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">; + BuildableType<"$_builder.getType()">; def F8E5M2FNUZ : Type, "f8E5M2FNUZ type">, - BuildableType<"$_builder.getFloat8E5M2FNUZType()">; + BuildableType<"$_builder.getType()">; def F8E3M4 : Type, "f8E3M4 type">, - BuildableType<"$_builder.getFloat8E3M4Type()">; + BuildableType<"$_builder.getType()">; def F4E2M1FN : Type, "f4E2M1FN type">, - BuildableType<"$_builder.getFloat4E2M1FNType()">; + BuildableType<"$_builder.getType()">; def F6E2M3FN : Type, "f6E2M3FN type">, - BuildableType<"$_builder.getFloat6E2M3FNType()">; + BuildableType<"$_builder.getType()">; def F6E3M2FN : Type, "f6E3M2FN type">, - BuildableType<"$_builder.getFloat6E3M2FNType()">; + BuildableType<"$_builder.getType()">; def F8E8M0FNU : Type, "f8E8M0FNU type">, - BuildableType<"$_builder.getFloat8E8M0FNUType()">; + BuildableType<"$_builder.getType()">; def AnyComplex : Type($_self)">, "complex-type", "::mlir::ComplexType">; diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index c614eb39b364b..21bb0ec3d0d51 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -309,58 +309,58 @@ Type Parser::parseNonFunctionType() { // float-type case Token::kw_f4E2M1FN: consumeToken(Token::kw_f4E2M1FN); - return builder.getFloat4E2M1FNType(); + return builder.getType(); case Token::kw_f6E2M3FN: consumeToken(Token::kw_f6E2M3FN); - return builder.getFloat6E2M3FNType(); + return builder.getType(); case Token::kw_f6E3M2FN: consumeToken(Token::kw_f6E3M2FN); - return builder.getFloat6E3M2FNType(); + return builder.getType(); case Token::kw_f8E5M2: consumeToken(Token::kw_f8E5M2); - return builder.getFloat8E5M2Type(); + return builder.getType(); case Token::kw_f8E4M3: consumeToken(Token::kw_f8E4M3); - return builder.getFloat8E4M3Type(); + return builder.getType(); case Token::kw_f8E4M3FN: consumeToken(Token::kw_f8E4M3FN); - return builder.getFloat8E4M3FNType(); + return builder.getType(); case Token::kw_f8E5M2FNUZ: consumeToken(Token::kw_f8E5M2FNUZ); - return builder.getFloat8E5M2FNUZType(); + return builder.getType(); case Token::kw_f8E4M3FNUZ: consumeToken(Token::kw_f8E4M3FNUZ); - return builder.getFloat8E4M3FNUZType(); + return builder.getType(); case Token::kw_f8E4M3B11FNUZ: consumeToken(Token::kw_f8E4M3B11FNUZ); - return builder.getFloat8E4M3B11FNUZType(); + return builder.getType(); case Token::kw_f8E3M4: consumeToken(Token::kw_f8E3M4); - return builder.getFloat8E3M4Type(); + return builder.getType(); case Token::kw_f8E8M0FNU: consumeToken(Token::kw_f8E8M0FNU); - return builder.getFloat8E8M0FNUType(); + return builder.getType(); case Token::kw_bf16: consumeToken(Token::kw_bf16); - return builder.getBF16Type(); + return builder.getType(); case Token::kw_f16: consumeToken(Token::kw_f16); - return builder.getF16Type(); + return builder.getType(); case Token::kw_tf32: consumeToken(Token::kw_tf32); - return builder.getTF32Type(); + return builder.getType(); case Token::kw_f32: consumeToken(Token::kw_f32); - return builder.getF32Type(); + return builder.getType(); case Token::kw_f64: consumeToken(Token::kw_f64); - return builder.getF64Type(); + return builder.getType(); case Token::kw_f80: consumeToken(Token::kw_f80); - return builder.getF80Type(); + return builder.getType(); case Token::kw_f128: consumeToken(Token::kw_f128); - return builder.getF128Type(); + return builder.getType(); // index-type case Token::kw_index: diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index 0fa7d32184411..39c9005e449e3 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -361,22 +361,22 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef values, std::optional parseFloatType(MLIRContext *ctx, StringRef name) { Builder b(ctx); return llvm::StringSwitch>(name) - .Case("f4E2M1FN", b.getFloat4E2M1FNType()) - .Case("f6E2M3FN", b.getFloat6E2M3FNType()) - .Case("f6E3M2FN", b.getFloat6E3M2FNType()) - .Case("f8E5M2", b.getFloat8E5M2Type()) - .Case("f8E4M3", b.getFloat8E4M3Type()) - .Case("f8E4M3FN", b.getFloat8E4M3FNType()) - .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType()) - .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType()) - .Case("f8E3M4", b.getFloat8E3M4Type()) - .Case("f8E8M0FNU", b.getFloat8E8M0FNUType()) - .Case("bf16", b.getBF16Type()) - .Case("f16", b.getF16Type()) - .Case("f32", b.getF32Type()) - .Case("f64", b.getF64Type()) - .Case("f80", b.getF80Type()) - .Case("f128", b.getF128Type()) + .Case("f4E2M1FN", b.getType()) + .Case("f6E2M3FN", b.getType()) + .Case("f6E3M2FN", b.getType()) + .Case("f8E5M2", b.getType()) + .Case("f8E4M3", b.getType()) + .Case("f8E4M3FN", b.getType()) + .Case("f8E5M2FNUZ", b.getType()) + .Case("f8E4M3FNUZ", b.getType()) + .Case("f8E3M4", b.getType()) + .Case("f8E8M0FNU", b.getType()) + .Case("bf16", b.getType()) + .Case("f16", b.getType()) + .Case("f32", b.getType()) + .Case("f64", b.getType()) + .Case("f80", b.getType()) + .Case("f128", b.getType()) .Default(std::nullopt); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8439b063f2634..d57a7ca07ede5 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -34,44 +34,6 @@ Location Builder::getFusedLoc(ArrayRef locs, Attribute metadata) { // Types. //===----------------------------------------------------------------------===// -FloatType Builder::getFloat4E2M1FNType() { - return Float4E2M1FNType::get(context); -} - -FloatType Builder::getFloat6E2M3FNType() { - return Float6E2M3FNType::get(context); -} - -FloatType Builder::getFloat6E3M2FNType() { - return Float6E3M2FNType::get(context); -} - -FloatType Builder::getFloat8E5M2Type() { return Float8E5M2Type::get(context); } - -FloatType Builder::getFloat8E4M3Type() { return Float8E4M3Type::get(context); } - -FloatType Builder::getFloat8E4M3FNType() { - return Float8E4M3FNType::get(context); -} - -FloatType Builder::getFloat8E5M2FNUZType() { - return Float8E5M2FNUZType::get(context); -} - -FloatType Builder::getFloat8E4M3FNUZType() { - return Float8E4M3FNUZType::get(context); -} - -FloatType Builder::getFloat8E4M3B11FNUZType() { - return Float8E4M3B11FNUZType::get(context); -} - -FloatType Builder::getFloat8E3M4Type() { return Float8E3M4Type::get(context); } - -FloatType Builder::getFloat8E8M0FNUType() { - return Float8E8M0FNUType::get(context); -} - FloatType Builder::getBF16Type() { return BFloat16Type::get(context); } FloatType Builder::getF16Type() { return Float16Type::get(context); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index b9e745fdf4a13..87782e84dd6e4 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -221,17 +221,6 @@ class MLIRContextImpl { llvm::DenseMap nameToType; /// Cached Type Instances. - Float4E2M1FNType f4E2M1FNTy; - Float6E2M3FNType f6E2M3FNTy; - Float6E3M2FNType f6E3M2FNTy; - Float8E5M2Type f8E5M2Ty; - Float8E4M3Type f8E4M3Ty; - Float8E4M3FNType f8E4M3FNTy; - Float8E5M2FNUZType f8E5M2FNUZTy; - Float8E4M3FNUZType f8E4M3FNUZTy; - Float8E4M3B11FNUZType f8E4M3B11FNUZTy; - Float8E3M4Type f8E3M4Ty; - Float8E8M0FNUType f8E8M0FNUTy; BFloat16Type bf16Ty; Float16Type f16Ty; FloatTF32Type tf32Ty; @@ -317,17 +306,6 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) //// Types. /// Floating-point Types. - impl->f4E2M1FNTy = TypeUniquer::get(this); - impl->f6E2M3FNTy = TypeUniquer::get(this); - impl->f6E3M2FNTy = TypeUniquer::get(this); - impl->f8E5M2Ty = TypeUniquer::get(this); - impl->f8E4M3Ty = TypeUniquer::get(this); - impl->f8E4M3FNTy = TypeUniquer::get(this); - impl->f8E5M2FNUZTy = TypeUniquer::get(this); - impl->f8E4M3FNUZTy = TypeUniquer::get(this); - impl->f8E4M3B11FNUZTy = TypeUniquer::get(this); - impl->f8E3M4Ty = TypeUniquer::get(this); - impl->f8E8M0FNUTy = TypeUniquer::get(this); impl->bf16Ty = TypeUniquer::get(this); impl->f16Ty = TypeUniquer::get(this); impl->tf32Ty = TypeUniquer::get(this); @@ -1044,39 +1022,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) { /// This should not be used directly. StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } -Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) { - return context->getImpl().f4E2M1FNTy; -} -Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) { - return context->getImpl().f6E2M3FNTy; -} -Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) { - return context->getImpl().f6E3M2FNTy; -} -Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) { - return context->getImpl().f8E5M2Ty; -} -Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) { - return context->getImpl().f8E4M3Ty; -} -Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) { - return context->getImpl().f8E4M3FNTy; -} -Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) { - return context->getImpl().f8E5M2FNUZTy; -} -Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) { - return context->getImpl().f8E4M3FNUZTy; -} -Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) { - return context->getImpl().f8E4M3B11FNUZTy; -} -Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) { - return context->getImpl().f8E3M4Ty; -} -Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) { - return context->getImpl().f8E8M0FNUTy; -} BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; }