Skip to content

Commit

Permalink
[mlir][IR] Remove builder API + caching for low-precision FP types (#…
Browse files Browse the repository at this point in the history
…123321)

Remove builder API (e.g., `b.getFloat4E2M1FNType()`) and caching in
`MLIRContext` for low-precision FP types. Types are still cached in the
type uniquer.

For details, see:
https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28

Note for LLVM integration: Use `b.getType<Float4E2M1FNType>()` or
`Float4E2M1FNType::get(b.getContext())` instead of
`b.getFloat4E2M1FNType()`.
  • Loading branch information
matthias-springer authored Jan 18, 2025
1 parent 67c3f2b commit f494346
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 158 deletions.
11 changes: 0 additions & 11 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,6 @@ class Builder {
Attribute metadata = Attribute());

// Types.
FloatType getFloat4E2M1FNType();
FloatType getFloat6E2M3FNType();
FloatType getFloat6E3M2FNType();
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3Type();
FloatType getFloat8E4M3FNType();
FloatType getFloat8E5M2FNUZType();
FloatType getFloat8E4M3FNUZType();
FloatType getFloat8E4M3B11FNUZType();
FloatType getFloat8E3M4Type();
FloatType getFloat8E8M0FNUType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
Expand Down
20 changes: 13 additions & 7 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
}

// Float types that are cached in MLIRContext.
class Builtin_CachedFloatType<string name, string mnemonic,
list<string> declaredInterfaceMethods = []>
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
Expand Down Expand Up @@ -326,52 +332,52 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
//===----------------------------------------------------------------------===//
// BFloat16Type

def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "bfloat16 floating-point type";
}

//===----------------------------------------------------------------------===//
// Float16Type

def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "16-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// FloatTF32Type

def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> {
let summary = "TF32 floating-point type";
}

//===----------------------------------------------------------------------===//
// Float32Type

def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "32-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// Float64Type

def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> {
def Builtin_Float64 : Builtin_CachedFloatType<"Float64", "f64"> {
let summary = "64-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// Float80Type

def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> {
def Builtin_Float80 : Builtin_CachedFloatType<"Float80", "f80"> {
let summary = "80-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// Float128Type

def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> {
def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
let summary = "128-bit floating-point type";
}

Expand Down
26 changes: 13 additions & 13 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -330,31 +330,31 @@ def F80 : F<80>;
def F128 : F<128>;

def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"$_builder.getBF16Type()">;
BuildableType<"$_builder.getType<BFloat16Type>()">;
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
BuildableType<"$_builder.getTF32Type()">;
BuildableType<"$_builder.getType<FloatTF32Type>()">;
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
BuildableType<"$_builder.getFloat8E4M3FNType()">;
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
BuildableType<"$_builder.getFloat8E5M2Type()">;
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
BuildableType<"$_builder.getFloat8E4M3Type()">;
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
BuildableType<"$_builder.getFloat8E3M4Type()">;
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
BuildableType<"$_builder.getFloat4E2M1FNType()">;
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
BuildableType<"$_builder.getFloat6E2M3FNType()">;
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
BuildableType<"$_builder.getFloat6E3M2FNType()">;
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
BuildableType<"$_builder.getFloat8E8M0FNUType()">;
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;

def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
"complex-type", "::mlir::ComplexType">;
Expand Down
36 changes: 18 additions & 18 deletions mlir/lib/AsmParser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,58 +309,58 @@ Type Parser::parseNonFunctionType() {
// float-type
case Token::kw_f4E2M1FN:
consumeToken(Token::kw_f4E2M1FN);
return builder.getFloat4E2M1FNType();
return builder.getType<Float4E2M1FNType>();
case Token::kw_f6E2M3FN:
consumeToken(Token::kw_f6E2M3FN);
return builder.getFloat6E2M3FNType();
return builder.getType<Float6E2M3FNType>();
case Token::kw_f6E3M2FN:
consumeToken(Token::kw_f6E3M2FN);
return builder.getFloat6E3M2FNType();
return builder.getType<Float6E3M2FNType>();
case Token::kw_f8E5M2:
consumeToken(Token::kw_f8E5M2);
return builder.getFloat8E5M2Type();
return builder.getType<Float8E5M2Type>();
case Token::kw_f8E4M3:
consumeToken(Token::kw_f8E4M3);
return builder.getFloat8E4M3Type();
return builder.getType<Float8E4M3Type>();
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
return builder.getFloat8E4M3FNType();
return builder.getType<Float8E4M3FNType>();
case Token::kw_f8E5M2FNUZ:
consumeToken(Token::kw_f8E5M2FNUZ);
return builder.getFloat8E5M2FNUZType();
return builder.getType<Float8E5M2FNUZType>();
case Token::kw_f8E4M3FNUZ:
consumeToken(Token::kw_f8E4M3FNUZ);
return builder.getFloat8E4M3FNUZType();
return builder.getType<Float8E4M3FNUZType>();
case Token::kw_f8E4M3B11FNUZ:
consumeToken(Token::kw_f8E4M3B11FNUZ);
return builder.getFloat8E4M3B11FNUZType();
return builder.getType<Float8E4M3B11FNUZType>();
case Token::kw_f8E3M4:
consumeToken(Token::kw_f8E3M4);
return builder.getFloat8E3M4Type();
return builder.getType<Float8E3M4Type>();
case Token::kw_f8E8M0FNU:
consumeToken(Token::kw_f8E8M0FNU);
return builder.getFloat8E8M0FNUType();
return builder.getType<Float8E8M0FNUType>();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
return builder.getType<BFloat16Type>();
case Token::kw_f16:
consumeToken(Token::kw_f16);
return builder.getF16Type();
return builder.getType<Float16Type>();
case Token::kw_tf32:
consumeToken(Token::kw_tf32);
return builder.getTF32Type();
return builder.getType<FloatTF32Type>();
case Token::kw_f32:
consumeToken(Token::kw_f32);
return builder.getF32Type();
return builder.getType<Float32Type>();
case Token::kw_f64:
consumeToken(Token::kw_f64);
return builder.getF64Type();
return builder.getType<Float64Type>();
case Token::kw_f80:
consumeToken(Token::kw_f80);
return builder.getF80Type();
return builder.getType<Float80Type>();
case Token::kw_f128:
consumeToken(Token::kw_f128);
return builder.getF128Type();
return builder.getType<Float128Type>();

// index-type
case Token::kw_index:
Expand Down
32 changes: 16 additions & 16 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,22 +361,22 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
.Case("f8E5M2", b.getFloat8E5M2Type())
.Case("f8E4M3", b.getFloat8E4M3Type())
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
.Case("f8E3M4", b.getFloat8E3M4Type())
.Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
.Case("bf16", b.getBF16Type())
.Case("f16", b.getF16Type())
.Case("f32", b.getF32Type())
.Case("f64", b.getF64Type())
.Case("f80", b.getF80Type())
.Case("f128", b.getF128Type())
.Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
.Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
.Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
.Case("f8E5M2", b.getType<Float8E5M2Type>())
.Case("f8E4M3", b.getType<Float8E4M3Type>())
.Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
.Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
.Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
.Case("f8E3M4", b.getType<Float8E3M4Type>())
.Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
.Case("bf16", b.getType<BFloat16Type>())
.Case("f16", b.getType<Float16Type>())
.Case("f32", b.getType<Float32Type>())
.Case("f64", b.getType<Float64Type>())
.Case("f80", b.getType<Float80Type>())
.Case("f128", b.getType<Float128Type>())
.Default(std::nullopt);
}

Expand Down
38 changes: 0 additions & 38 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,44 +34,6 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//

FloatType Builder::getFloat4E2M1FNType() {
return Float4E2M1FNType::get(context);
}

FloatType Builder::getFloat6E2M3FNType() {
return Float6E2M3FNType::get(context);
}

FloatType Builder::getFloat6E3M2FNType() {
return Float6E3M2FNType::get(context);
}

FloatType Builder::getFloat8E5M2Type() { return Float8E5M2Type::get(context); }

FloatType Builder::getFloat8E4M3Type() { return Float8E4M3Type::get(context); }

FloatType Builder::getFloat8E4M3FNType() {
return Float8E4M3FNType::get(context);
}

FloatType Builder::getFloat8E5M2FNUZType() {
return Float8E5M2FNUZType::get(context);
}

FloatType Builder::getFloat8E4M3FNUZType() {
return Float8E4M3FNUZType::get(context);
}

FloatType Builder::getFloat8E4M3B11FNUZType() {
return Float8E4M3B11FNUZType::get(context);
}

FloatType Builder::getFloat8E3M4Type() { return Float8E3M4Type::get(context); }

FloatType Builder::getFloat8E8M0FNUType() {
return Float8E8M0FNUType::get(context);
}

FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }

FloatType Builder::getF16Type() { return Float16Type::get(context); }
Expand Down
55 changes: 0 additions & 55 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,17 +221,6 @@ class MLIRContextImpl {
llvm::DenseMap<StringRef, AbstractType *> nameToType;

/// Cached Type Instances.
Float4E2M1FNType f4E2M1FNTy;
Float6E2M3FNType f6E2M3FNTy;
Float6E3M2FNType f6E3M2FNTy;
Float8E5M2Type f8E5M2Ty;
Float8E4M3Type f8E4M3Ty;
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
Float8E3M4Type f8E3M4Ty;
Float8E8M0FNUType f8E8M0FNUTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
FloatTF32Type tf32Ty;
Expand Down Expand Up @@ -317,17 +306,6 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)

//// Types.
/// Floating-point Types.
impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
Expand Down Expand Up @@ -1044,39 +1022,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }

Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
return context->getImpl().f4E2M1FNTy;
}
Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
return context->getImpl().f6E2M3FNTy;
}
Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
return context->getImpl().f6E3M2FNTy;
}
Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
return context->getImpl().f8E5M2Ty;
}
Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
return context->getImpl().f8E4M3Ty;
}
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNTy;
}
Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E5M2FNUZTy;
}
Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNUZTy;
}
Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3B11FNUZTy;
}
Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
return context->getImpl().f8E3M4Ty;
}
Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
return context->getImpl().f8E8M0FNUTy;
}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}
Expand Down

0 comments on commit f494346

Please sign in to comment.