-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Parser] Add nan
and inf
keywords
#116176
base: users/matthias-springer/parse_fp_dedup
Are you sure you want to change the base?
[mlir][Parser] Add nan
and inf
keywords
#116176
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd two new keywords for parsing Note: The printer still prints Depends on #116172. Full diff: https://github.com/llvm/llvm-project/pull/116176.diff 8 Files Affected:
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 97547fb577e0ec..40ad7ba92552ed 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -311,6 +311,8 @@ struct APFloatBase {
static unsigned int semanticsIntSizeInBits(const fltSemantics&, bool);
static bool semanticsHasZero(const fltSemantics &);
static bool semanticsHasSignedRepr(const fltSemantics &);
+ static bool semanticsHasInf(const fltSemantics &);
+ static bool semanticsHasNan(const fltSemantics &);
// Returns true if any number described by \p Src can be precisely represented
// by a normal (not subnormal) value in \p Dst.
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index c566d489d11b03..8b9d9af2ca65b3 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -375,6 +375,15 @@ bool APFloatBase::semanticsHasSignedRepr(const fltSemantics &semantics) {
return semantics.hasSignedRepr;
}
+bool APFloatBase::semanticsHasInf(const fltSemantics &semantics) {
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly
+ && semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
+bool APFloatBase::semanticsHasNan(const fltSemantics &semantics) {
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
bool APFloatBase::isRepresentableAsNormalIn(const fltSemantics &Src,
const fltSemantics &Dst) {
// Exponent range must be larger.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 9ebada076cd042..68da950f09e568 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
+#include <cmath>
#include <optional>
using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
// Parse floating point and integer attributes.
case Token::floatliteral:
+ case Token::kw_inf:
+ case Token::kw_nan:
return parseFloatAttr(type, /*isNegative=*/false);
case Token::integer:
return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseDecOrHexAttr(type, /*isNegative=*/true);
- if (getToken().is(Token::floatliteral))
+ if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+ getToken().is(Token::kw_nan))
return parseFloatAttr(type, /*isNegative=*/true);
return (emitWrongTokenError(
@@ -342,10 +347,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
/// Parse a float attribute.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
- auto val = getToken().getFloatingPointValue();
- if (!val)
- return (emitError("floating point value too large for attribute"), nullptr);
- consumeToken(Token::floatliteral);
+ const Token tok = getToken();
+ consumeToken();
if (!type) {
// Default to F64 when no type is specified.
if (!consumeIf(Token::colon))
@@ -353,10 +356,16 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
else if (!(type = parseType()))
return nullptr;
}
- if (!isa<FloatType>(type))
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
return (emitError("floating point value not valid for specified type"),
nullptr);
- return FloatAttr::get(type, isNegative ? -*val : *val);
+ auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
+ FailureOr<APFloat> result = parseFloatFromLiteral(
+ emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+ if (failed(result))
+ return Attribute();
+ return FloatAttr::get(floatType, *result);
}
/// Construct an APint from a parsed value, a known attribute type and
@@ -623,7 +632,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
}
// Check to see if floating point values were parsed.
- if (token.is(Token::floatliteral)) {
+ if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
return p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
}
@@ -731,6 +740,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a boolean element.
case Token::kw_true:
case Token::kw_false:
+ case Token::kw_inf:
+ case Token::kw_nan:
case Token::floatliteral:
case Token::integer:
storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -740,7 +751,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a signed integer or a negative floating-point element.
case Token::minus:
p.consumeToken(Token::minus);
- if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+ if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+ Token::integer))
return p.emitError("expected integer or floating point literal");
storage.emplace_back(/*isNegative=*/true, p.getToken());
p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 15f3dd7a66c358..6ff43b1749b64e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -103,11 +103,36 @@ FailureOr<APFloat>
detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics) {
+ // Check for inf keyword.
+ if (tok.is(Token::kw_inf)) {
+ if (!APFloat::semanticsHasInf(semantics)) {
+ emitError() << "floating point type does not support infinity";
+ return failure();
+ }
+ return APFloat::getInf(semantics, isNegative);
+ }
+
+ // Check for NaN keyword.
+ if (tok.is(Token::kw_nan)) {
+ if (!APFloat::semanticsHasNan(semantics)) {
+ emitError() << "floating point type does not support NaN";
+ return failure();
+ }
+ return APFloat::getNaN(semantics, isNegative);
+ }
+
// Check for a floating point value.
if (tok.is(Token::floatliteral)) {
auto val = tok.getFloatingPointValue();
- if (!val)
- return emitError() << "floating point value too large";
+ if (!val) {
+ emitError() << "floating point value too large";
+ return failure();
+ }
+ if (std::fpclassify(*val) == FP_ZERO &&
+ !APFloat::semanticsHasZero(semantics)) {
+ emitError() << "floating point type does not support zero";
+ return failure();
+ }
APFloat result(isNegative ? -*val : *val);
bool unused;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 49da8c3dea5fa5..9208c8adddcfce 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
TOK_KEYWORD(index)
+TOK_KEYWORD(inf)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)
TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mod)
+TOK_KEYWORD(nan)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
TOK_KEYWORD(size)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..c86b2b5f63f016 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minimumf %c0, %arg0 : f32
%1 = arith.minimumf %arg0, %arg0 : f32
%2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maximumf %c0, %arg0 : f32
%1 = arith.maximumf %arg0, %arg0 : f32
%2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
// CHECK-DAG: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
- %nan = arith.constant 0x7fffffff : f32
+ %nan = arith.constant nan : f32
%0 = arith.cmpf olt, %nan, %arg0 : f32
%1 = arith.cmpf olt, %arg0, %nan : f32
%2 = arith.cmpf ugt, %nan, %arg0 : f32
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..1fbb4986ab2ea7 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f128
float_attr = 2. : f128
} : () -> ()
+ "test.float_attrs"() {
+ // Note: nan/inf are printed in binary format because there may be multiple
+ // nan/inf representations.
+ // CHECK: float_attr = 0x7FC00000 : f32
+ float_attr = nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7C : f8E4M3
+ float_attr = nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFFC00000 : f32
+ float_attr = -nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFC : f8E4M3
+ float_attr = -nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7F800000 : f32
+ float_attr = inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x78 : f8E4M3
+ float_attr = inf : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFF800000 : f32
+ float_attr = -inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xF8 : f8E4M3
+ float_attr = -inf : f8E4M3
+ } : () -> ()
return
}
+// -----
+
+func.func @float_nan_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support NaN}}
+ float_attr = nan : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
+func.func @float_inf_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support infinity}}
+ float_attr = inf : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Test integer attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b8861198d596b0..28b656b0da5f1a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -41,7 +41,7 @@ func.func @tanh() {
call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @tanh_f32(%nan) : (f32) -> ()
return
@@ -87,15 +87,15 @@ func.func @log() {
call @log_f32(%zero) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @log_f32(%nan) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 0.693147
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -141,11 +141,11 @@ func.func @log2() {
call @log2_f32(%neg_one) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log2_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 1.58496
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -192,11 +192,11 @@ func.func @log1p() {
call @log1p_f32(%neg_two) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log1p_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 9.99995e-06
- %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+ %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -247,7 +247,7 @@ func.func @erf() {
call @erf_f32(%val7) : (f32) -> ()
// CHECK: -1
- %negativeInf = arith.constant 0xff800000 : f32
+ %negativeInf = arith.constant -inf : f32
call @erf_f32(%negativeInf) : (f32) -> ()
// CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
// CHECK: 1
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @erf_f32(%inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @erf_f32(%nan) : (f32) -> ()
return
@@ -306,15 +306,15 @@ func.func @exp() {
call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @exp_f32(%inf) : (f32) -> ()
// CHECK: 0
- %negative_inf = arith.constant 0xff800000 : f32
+ %negative_inf = arith.constant -inf : f32
call @exp_f32(%negative_inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @exp_f32(%nan) : (f32) -> ()
return
@@ -358,19 +358,19 @@ func.func @expm1() {
call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: -1
- %neg_inf = arith.constant 0xff800000 : f32
+ %neg_inf = arith.constant -inf : f32
call @expm1_f32(%neg_inf) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @expm1_f32(%inf) : (f32) -> ()
// CHECK: -1, inf, 1e-10
- %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+ %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @expm1_f32(%nan) : (f32) -> ()
return
|
@llvm/pr-subscribers-llvm-adt Author: Matthias Springer (matthias-springer) ChangesAdd two new keywords for parsing Note: The printer still prints Depends on #116172. Full diff: https://github.com/llvm/llvm-project/pull/116176.diff 8 Files Affected:
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 97547fb577e0ec..40ad7ba92552ed 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -311,6 +311,8 @@ struct APFloatBase {
static unsigned int semanticsIntSizeInBits(const fltSemantics&, bool);
static bool semanticsHasZero(const fltSemantics &);
static bool semanticsHasSignedRepr(const fltSemantics &);
+ static bool semanticsHasInf(const fltSemantics &);
+ static bool semanticsHasNan(const fltSemantics &);
// Returns true if any number described by \p Src can be precisely represented
// by a normal (not subnormal) value in \p Dst.
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index c566d489d11b03..8b9d9af2ca65b3 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -375,6 +375,15 @@ bool APFloatBase::semanticsHasSignedRepr(const fltSemantics &semantics) {
return semantics.hasSignedRepr;
}
+bool APFloatBase::semanticsHasInf(const fltSemantics &semantics) {
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly
+ && semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
+bool APFloatBase::semanticsHasNan(const fltSemantics &semantics) {
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
bool APFloatBase::isRepresentableAsNormalIn(const fltSemantics &Src,
const fltSemantics &Dst) {
// Exponent range must be larger.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 9ebada076cd042..68da950f09e568 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
+#include <cmath>
#include <optional>
using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
// Parse floating point and integer attributes.
case Token::floatliteral:
+ case Token::kw_inf:
+ case Token::kw_nan:
return parseFloatAttr(type, /*isNegative=*/false);
case Token::integer:
return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseDecOrHexAttr(type, /*isNegative=*/true);
- if (getToken().is(Token::floatliteral))
+ if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+ getToken().is(Token::kw_nan))
return parseFloatAttr(type, /*isNegative=*/true);
return (emitWrongTokenError(
@@ -342,10 +347,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
/// Parse a float attribute.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
- auto val = getToken().getFloatingPointValue();
- if (!val)
- return (emitError("floating point value too large for attribute"), nullptr);
- consumeToken(Token::floatliteral);
+ const Token tok = getToken();
+ consumeToken();
if (!type) {
// Default to F64 when no type is specified.
if (!consumeIf(Token::colon))
@@ -353,10 +356,16 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
else if (!(type = parseType()))
return nullptr;
}
- if (!isa<FloatType>(type))
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
return (emitError("floating point value not valid for specified type"),
nullptr);
- return FloatAttr::get(type, isNegative ? -*val : *val);
+ auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
+ FailureOr<APFloat> result = parseFloatFromLiteral(
+ emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+ if (failed(result))
+ return Attribute();
+ return FloatAttr::get(floatType, *result);
}
/// Construct an APint from a parsed value, a known attribute type and
@@ -623,7 +632,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
}
// Check to see if floating point values were parsed.
- if (token.is(Token::floatliteral)) {
+ if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
return p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
}
@@ -731,6 +740,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a boolean element.
case Token::kw_true:
case Token::kw_false:
+ case Token::kw_inf:
+ case Token::kw_nan:
case Token::floatliteral:
case Token::integer:
storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -740,7 +751,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a signed integer or a negative floating-point element.
case Token::minus:
p.consumeToken(Token::minus);
- if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+ if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+ Token::integer))
return p.emitError("expected integer or floating point literal");
storage.emplace_back(/*isNegative=*/true, p.getToken());
p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 15f3dd7a66c358..6ff43b1749b64e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -103,11 +103,36 @@ FailureOr<APFloat>
detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics) {
+ // Check for inf keyword.
+ if (tok.is(Token::kw_inf)) {
+ if (!APFloat::semanticsHasInf(semantics)) {
+ emitError() << "floating point type does not support infinity";
+ return failure();
+ }
+ return APFloat::getInf(semantics, isNegative);
+ }
+
+ // Check for NaN keyword.
+ if (tok.is(Token::kw_nan)) {
+ if (!APFloat::semanticsHasNan(semantics)) {
+ emitError() << "floating point type does not support NaN";
+ return failure();
+ }
+ return APFloat::getNaN(semantics, isNegative);
+ }
+
// Check for a floating point value.
if (tok.is(Token::floatliteral)) {
auto val = tok.getFloatingPointValue();
- if (!val)
- return emitError() << "floating point value too large";
+ if (!val) {
+ emitError() << "floating point value too large";
+ return failure();
+ }
+ if (std::fpclassify(*val) == FP_ZERO &&
+ !APFloat::semanticsHasZero(semantics)) {
+ emitError() << "floating point type does not support zero";
+ return failure();
+ }
APFloat result(isNegative ? -*val : *val);
bool unused;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 49da8c3dea5fa5..9208c8adddcfce 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
TOK_KEYWORD(index)
+TOK_KEYWORD(inf)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)
TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mod)
+TOK_KEYWORD(nan)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
TOK_KEYWORD(size)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..c86b2b5f63f016 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minimumf %c0, %arg0 : f32
%1 = arith.minimumf %arg0, %arg0 : f32
%2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maximumf %c0, %arg0 : f32
%1 = arith.maximumf %arg0, %arg0 : f32
%2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
// CHECK-DAG: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
- %nan = arith.constant 0x7fffffff : f32
+ %nan = arith.constant nan : f32
%0 = arith.cmpf olt, %nan, %arg0 : f32
%1 = arith.cmpf olt, %arg0, %nan : f32
%2 = arith.cmpf ugt, %nan, %arg0 : f32
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..1fbb4986ab2ea7 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f128
float_attr = 2. : f128
} : () -> ()
+ "test.float_attrs"() {
+ // Note: nan/inf are printed in binary format because there may be multiple
+ // nan/inf representations.
+ // CHECK: float_attr = 0x7FC00000 : f32
+ float_attr = nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7C : f8E4M3
+ float_attr = nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFFC00000 : f32
+ float_attr = -nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFC : f8E4M3
+ float_attr = -nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7F800000 : f32
+ float_attr = inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x78 : f8E4M3
+ float_attr = inf : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFF800000 : f32
+ float_attr = -inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xF8 : f8E4M3
+ float_attr = -inf : f8E4M3
+ } : () -> ()
return
}
+// -----
+
+func.func @float_nan_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support NaN}}
+ float_attr = nan : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
+func.func @float_inf_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support infinity}}
+ float_attr = inf : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Test integer attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b8861198d596b0..28b656b0da5f1a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -41,7 +41,7 @@ func.func @tanh() {
call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @tanh_f32(%nan) : (f32) -> ()
return
@@ -87,15 +87,15 @@ func.func @log() {
call @log_f32(%zero) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @log_f32(%nan) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 0.693147
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -141,11 +141,11 @@ func.func @log2() {
call @log2_f32(%neg_one) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log2_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 1.58496
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -192,11 +192,11 @@ func.func @log1p() {
call @log1p_f32(%neg_two) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log1p_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 9.99995e-06
- %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+ %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -247,7 +247,7 @@ func.func @erf() {
call @erf_f32(%val7) : (f32) -> ()
// CHECK: -1
- %negativeInf = arith.constant 0xff800000 : f32
+ %negativeInf = arith.constant -inf : f32
call @erf_f32(%negativeInf) : (f32) -> ()
// CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
// CHECK: 1
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @erf_f32(%inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @erf_f32(%nan) : (f32) -> ()
return
@@ -306,15 +306,15 @@ func.func @exp() {
call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @exp_f32(%inf) : (f32) -> ()
// CHECK: 0
- %negative_inf = arith.constant 0xff800000 : f32
+ %negative_inf = arith.constant -inf : f32
call @exp_f32(%negative_inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @exp_f32(%nan) : (f32) -> ()
return
@@ -358,19 +358,19 @@ func.func @expm1() {
call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: -1
- %neg_inf = arith.constant 0xff800000 : f32
+ %neg_inf = arith.constant -inf : f32
call @expm1_f32(%neg_inf) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @expm1_f32(%inf) : (f32) -> ()
// CHECK: -1, inf, 1e-10
- %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+ %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @expm1_f32(%nan) : (f32) -> ()
return
|
You can test this locally with the following command:git-clang-format --diff 51530aeea8c18804034881c87236d1ab5ceb274f 045afe88e53f873cf027ab92af32c120a1d47d63 --extensions cpp,h -- llvm/include/llvm/ADT/APFloat.h llvm/lib/Support/APFloat.cpp mlir/lib/AsmParser/AttributeParser.cpp mlir/lib/AsmParser/Parser.cpp View the diff from clang-format here.diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 8b9d9af2ca..e82071e5b7 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -376,8 +376,8 @@ bool APFloatBase::semanticsHasSignedRepr(const fltSemantics &semantics) {
}
bool APFloatBase::semanticsHasInf(const fltSemantics &semantics) {
- return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly
- && semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly &&
+ semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
}
bool APFloatBase::semanticsHasNan(const fltSemantics &semantics) {
|
please separate out the APInt changes and include unit tests in a change just to llvm separate from/before the mlir side of this |
Add two new keywords for parsing
nan
/inf
floating-point literals. This is more convenient that writing the integer hexadecimal bit patterns by hand (which differ depending on the floating-point type).Note: The printer still prints
nan
/inf
literals as integer hexadecimals. That's because there can be multiplenan
/inf
bit patterns. When parsing anan
/inf
keyword, the exact bit pattern is unspecified: we use whateverAPFloat::getInf
/APFloat::getNaN
returns.TODO: Update more test cases.
Depends on #116172.