-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Parser] Add nan
and inf
keywords
#116176
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd two new keywords for parsing Note: The printer still prints Depends on #116172. Full diff: https://github.com/llvm/llvm-project/pull/116176.diff 8 Files Affected:
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 97547fb577e0ec..40ad7ba92552ed 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -311,6 +311,8 @@ struct APFloatBase {
static unsigned int semanticsIntSizeInBits(const fltSemantics&, bool);
static bool semanticsHasZero(const fltSemantics &);
static bool semanticsHasSignedRepr(const fltSemantics &);
+ static bool semanticsHasInf(const fltSemantics &);
+ static bool semanticsHasNan(const fltSemantics &);
// Returns true if any number described by \p Src can be precisely represented
// by a normal (not subnormal) value in \p Dst.
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index c566d489d11b03..8b9d9af2ca65b3 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -375,6 +375,15 @@ bool APFloatBase::semanticsHasSignedRepr(const fltSemantics &semantics) {
return semantics.hasSignedRepr;
}
+bool APFloatBase::semanticsHasInf(const fltSemantics &semantics) {
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly
+ && semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
+bool APFloatBase::semanticsHasNan(const fltSemantics &semantics) {
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
bool APFloatBase::isRepresentableAsNormalIn(const fltSemantics &Src,
const fltSemantics &Dst) {
// Exponent range must be larger.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 9ebada076cd042..68da950f09e568 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
+#include <cmath>
#include <optional>
using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
// Parse floating point and integer attributes.
case Token::floatliteral:
+ case Token::kw_inf:
+ case Token::kw_nan:
return parseFloatAttr(type, /*isNegative=*/false);
case Token::integer:
return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseDecOrHexAttr(type, /*isNegative=*/true);
- if (getToken().is(Token::floatliteral))
+ if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+ getToken().is(Token::kw_nan))
return parseFloatAttr(type, /*isNegative=*/true);
return (emitWrongTokenError(
@@ -342,10 +347,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
/// Parse a float attribute.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
- auto val = getToken().getFloatingPointValue();
- if (!val)
- return (emitError("floating point value too large for attribute"), nullptr);
- consumeToken(Token::floatliteral);
+ const Token tok = getToken();
+ consumeToken();
if (!type) {
// Default to F64 when no type is specified.
if (!consumeIf(Token::colon))
@@ -353,10 +356,16 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
else if (!(type = parseType()))
return nullptr;
}
- if (!isa<FloatType>(type))
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
return (emitError("floating point value not valid for specified type"),
nullptr);
- return FloatAttr::get(type, isNegative ? -*val : *val);
+ auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
+ FailureOr<APFloat> result = parseFloatFromLiteral(
+ emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+ if (failed(result))
+ return Attribute();
+ return FloatAttr::get(floatType, *result);
}
/// Construct an APint from a parsed value, a known attribute type and
@@ -623,7 +632,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
}
// Check to see if floating point values were parsed.
- if (token.is(Token::floatliteral)) {
+ if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
return p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
}
@@ -731,6 +740,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a boolean element.
case Token::kw_true:
case Token::kw_false:
+ case Token::kw_inf:
+ case Token::kw_nan:
case Token::floatliteral:
case Token::integer:
storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -740,7 +751,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a signed integer or a negative floating-point element.
case Token::minus:
p.consumeToken(Token::minus);
- if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+ if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+ Token::integer))
return p.emitError("expected integer or floating point literal");
storage.emplace_back(/*isNegative=*/true, p.getToken());
p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 15f3dd7a66c358..6ff43b1749b64e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -103,11 +103,36 @@ FailureOr<APFloat>
detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics) {
+ // Check for inf keyword.
+ if (tok.is(Token::kw_inf)) {
+ if (!APFloat::semanticsHasInf(semantics)) {
+ emitError() << "floating point type does not support infinity";
+ return failure();
+ }
+ return APFloat::getInf(semantics, isNegative);
+ }
+
+ // Check for NaN keyword.
+ if (tok.is(Token::kw_nan)) {
+ if (!APFloat::semanticsHasNan(semantics)) {
+ emitError() << "floating point type does not support NaN";
+ return failure();
+ }
+ return APFloat::getNaN(semantics, isNegative);
+ }
+
// Check for a floating point value.
if (tok.is(Token::floatliteral)) {
auto val = tok.getFloatingPointValue();
- if (!val)
- return emitError() << "floating point value too large";
+ if (!val) {
+ emitError() << "floating point value too large";
+ return failure();
+ }
+ if (std::fpclassify(*val) == FP_ZERO &&
+ !APFloat::semanticsHasZero(semantics)) {
+ emitError() << "floating point type does not support zero";
+ return failure();
+ }
APFloat result(isNegative ? -*val : *val);
bool unused;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 49da8c3dea5fa5..9208c8adddcfce 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
TOK_KEYWORD(index)
+TOK_KEYWORD(inf)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)
TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mod)
+TOK_KEYWORD(nan)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
TOK_KEYWORD(size)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..c86b2b5f63f016 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minimumf %c0, %arg0 : f32
%1 = arith.minimumf %arg0, %arg0 : f32
%2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maximumf %c0, %arg0 : f32
%1 = arith.maximumf %arg0, %arg0 : f32
%2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
// CHECK-DAG: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
- %nan = arith.constant 0x7fffffff : f32
+ %nan = arith.constant nan : f32
%0 = arith.cmpf olt, %nan, %arg0 : f32
%1 = arith.cmpf olt, %arg0, %nan : f32
%2 = arith.cmpf ugt, %nan, %arg0 : f32
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..1fbb4986ab2ea7 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f128
float_attr = 2. : f128
} : () -> ()
+ "test.float_attrs"() {
+ // Note: nan/inf are printed in binary format because there may be multiple
+ // nan/inf representations.
+ // CHECK: float_attr = 0x7FC00000 : f32
+ float_attr = nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7C : f8E4M3
+ float_attr = nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFFC00000 : f32
+ float_attr = -nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFC : f8E4M3
+ float_attr = -nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7F800000 : f32
+ float_attr = inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x78 : f8E4M3
+ float_attr = inf : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFF800000 : f32
+ float_attr = -inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xF8 : f8E4M3
+ float_attr = -inf : f8E4M3
+ } : () -> ()
return
}
+// -----
+
+func.func @float_nan_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support NaN}}
+ float_attr = nan : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
+func.func @float_inf_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support infinity}}
+ float_attr = inf : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Test integer attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b8861198d596b0..28b656b0da5f1a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -41,7 +41,7 @@ func.func @tanh() {
call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @tanh_f32(%nan) : (f32) -> ()
return
@@ -87,15 +87,15 @@ func.func @log() {
call @log_f32(%zero) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @log_f32(%nan) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 0.693147
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -141,11 +141,11 @@ func.func @log2() {
call @log2_f32(%neg_one) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log2_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 1.58496
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -192,11 +192,11 @@ func.func @log1p() {
call @log1p_f32(%neg_two) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log1p_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 9.99995e-06
- %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+ %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -247,7 +247,7 @@ func.func @erf() {
call @erf_f32(%val7) : (f32) -> ()
// CHECK: -1
- %negativeInf = arith.constant 0xff800000 : f32
+ %negativeInf = arith.constant -inf : f32
call @erf_f32(%negativeInf) : (f32) -> ()
// CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
// CHECK: 1
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @erf_f32(%inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @erf_f32(%nan) : (f32) -> ()
return
@@ -306,15 +306,15 @@ func.func @exp() {
call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @exp_f32(%inf) : (f32) -> ()
// CHECK: 0
- %negative_inf = arith.constant 0xff800000 : f32
+ %negative_inf = arith.constant -inf : f32
call @exp_f32(%negative_inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @exp_f32(%nan) : (f32) -> ()
return
@@ -358,19 +358,19 @@ func.func @expm1() {
call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: -1
- %neg_inf = arith.constant 0xff800000 : f32
+ %neg_inf = arith.constant -inf : f32
call @expm1_f32(%neg_inf) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @expm1_f32(%inf) : (f32) -> ()
// CHECK: -1, inf, 1e-10
- %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+ %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @expm1_f32(%nan) : (f32) -> ()
return
|
@llvm/pr-subscribers-llvm-adt Author: Matthias Springer (matthias-springer) ChangesAdd two new keywords for parsing Note: The printer still prints Depends on #116172. Full diff: https://github.com/llvm/llvm-project/pull/116176.diff 8 Files Affected:
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 97547fb577e0ec..40ad7ba92552ed 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -311,6 +311,8 @@ struct APFloatBase {
static unsigned int semanticsIntSizeInBits(const fltSemantics&, bool);
static bool semanticsHasZero(const fltSemantics &);
static bool semanticsHasSignedRepr(const fltSemantics &);
+ static bool semanticsHasInf(const fltSemantics &);
+ static bool semanticsHasNan(const fltSemantics &);
// Returns true if any number described by \p Src can be precisely represented
// by a normal (not subnormal) value in \p Dst.
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index c566d489d11b03..8b9d9af2ca65b3 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -375,6 +375,15 @@ bool APFloatBase::semanticsHasSignedRepr(const fltSemantics &semantics) {
return semantics.hasSignedRepr;
}
+bool APFloatBase::semanticsHasInf(const fltSemantics &semantics) {
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly
+ && semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
+bool APFloatBase::semanticsHasNan(const fltSemantics &semantics) {
+ return semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
bool APFloatBase::isRepresentableAsNormalIn(const fltSemantics &Src,
const fltSemantics &Dst) {
// Exponent range must be larger.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 9ebada076cd042..68da950f09e568 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
+#include <cmath>
#include <optional>
using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
// Parse floating point and integer attributes.
case Token::floatliteral:
+ case Token::kw_inf:
+ case Token::kw_nan:
return parseFloatAttr(type, /*isNegative=*/false);
case Token::integer:
return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseDecOrHexAttr(type, /*isNegative=*/true);
- if (getToken().is(Token::floatliteral))
+ if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+ getToken().is(Token::kw_nan))
return parseFloatAttr(type, /*isNegative=*/true);
return (emitWrongTokenError(
@@ -342,10 +347,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
/// Parse a float attribute.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
- auto val = getToken().getFloatingPointValue();
- if (!val)
- return (emitError("floating point value too large for attribute"), nullptr);
- consumeToken(Token::floatliteral);
+ const Token tok = getToken();
+ consumeToken();
if (!type) {
// Default to F64 when no type is specified.
if (!consumeIf(Token::colon))
@@ -353,10 +356,16 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
else if (!(type = parseType()))
return nullptr;
}
- if (!isa<FloatType>(type))
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
return (emitError("floating point value not valid for specified type"),
nullptr);
- return FloatAttr::get(type, isNegative ? -*val : *val);
+ auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
+ FailureOr<APFloat> result = parseFloatFromLiteral(
+ emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+ if (failed(result))
+ return Attribute();
+ return FloatAttr::get(floatType, *result);
}
/// Construct an APint from a parsed value, a known attribute type and
@@ -623,7 +632,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
}
// Check to see if floating point values were parsed.
- if (token.is(Token::floatliteral)) {
+ if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
return p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
}
@@ -731,6 +740,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a boolean element.
case Token::kw_true:
case Token::kw_false:
+ case Token::kw_inf:
+ case Token::kw_nan:
case Token::floatliteral:
case Token::integer:
storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -740,7 +751,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a signed integer or a negative floating-point element.
case Token::minus:
p.consumeToken(Token::minus);
- if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+ if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+ Token::integer))
return p.emitError("expected integer or floating point literal");
storage.emplace_back(/*isNegative=*/true, p.getToken());
p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 15f3dd7a66c358..6ff43b1749b64e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -103,11 +103,36 @@ FailureOr<APFloat>
detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics) {
+ // Check for inf keyword.
+ if (tok.is(Token::kw_inf)) {
+ if (!APFloat::semanticsHasInf(semantics)) {
+ emitError() << "floating point type does not support infinity";
+ return failure();
+ }
+ return APFloat::getInf(semantics, isNegative);
+ }
+
+ // Check for NaN keyword.
+ if (tok.is(Token::kw_nan)) {
+ if (!APFloat::semanticsHasNan(semantics)) {
+ emitError() << "floating point type does not support NaN";
+ return failure();
+ }
+ return APFloat::getNaN(semantics, isNegative);
+ }
+
// Check for a floating point value.
if (tok.is(Token::floatliteral)) {
auto val = tok.getFloatingPointValue();
- if (!val)
- return emitError() << "floating point value too large";
+ if (!val) {
+ emitError() << "floating point value too large";
+ return failure();
+ }
+ if (std::fpclassify(*val) == FP_ZERO &&
+ !APFloat::semanticsHasZero(semantics)) {
+ emitError() << "floating point type does not support zero";
+ return failure();
+ }
APFloat result(isNegative ? -*val : *val);
bool unused;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 49da8c3dea5fa5..9208c8adddcfce 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
TOK_KEYWORD(index)
+TOK_KEYWORD(inf)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)
TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mod)
+TOK_KEYWORD(nan)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
TOK_KEYWORD(size)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..c86b2b5f63f016 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minimumf %c0, %arg0 : f32
%1 = arith.minimumf %arg0, %arg0 : f32
%2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maximumf %c0, %arg0 : f32
%1 = arith.maximumf %arg0, %arg0 : f32
%2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %inf = arith.constant 0x7F800000 : f32
+ %inf = arith.constant inf : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
- %-inf = arith.constant 0xFF800000 : f32
+ %-inf = arith.constant -inf : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
// CHECK-DAG: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
- %nan = arith.constant 0x7fffffff : f32
+ %nan = arith.constant nan : f32
%0 = arith.cmpf olt, %nan, %arg0 : f32
%1 = arith.cmpf olt, %arg0, %nan : f32
%2 = arith.cmpf ugt, %nan, %arg0 : f32
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..1fbb4986ab2ea7 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f128
float_attr = 2. : f128
} : () -> ()
+ "test.float_attrs"() {
+ // Note: nan/inf are printed in binary format because there may be multiple
+ // nan/inf representations.
+ // CHECK: float_attr = 0x7FC00000 : f32
+ float_attr = nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7C : f8E4M3
+ float_attr = nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFFC00000 : f32
+ float_attr = -nan : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFC : f8E4M3
+ float_attr = -nan : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x7F800000 : f32
+ float_attr = inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0x78 : f8E4M3
+ float_attr = inf : f8E4M3
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xFF800000 : f32
+ float_attr = -inf : f32
+ } : () -> ()
+ "test.float_attrs"() {
+ // CHECK: float_attr = 0xF8 : f8E4M3
+ float_attr = -inf : f8E4M3
+ } : () -> ()
return
}
+// -----
+
+func.func @float_nan_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support NaN}}
+ float_attr = nan : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
+func.func @float_inf_unsupported() {
+ "test.float_attrs"() {
+ // expected-error @below{{floating point type does not support infinity}}
+ float_attr = inf : f4E2M1FN
+ } : () -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Test integer attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b8861198d596b0..28b656b0da5f1a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -41,7 +41,7 @@ func.func @tanh() {
call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @tanh_f32(%nan) : (f32) -> ()
return
@@ -87,15 +87,15 @@ func.func @log() {
call @log_f32(%zero) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @log_f32(%nan) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 0.693147
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -141,11 +141,11 @@ func.func @log2() {
call @log2_f32(%neg_one) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log2_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 1.58496
- %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+ %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -192,11 +192,11 @@ func.func @log1p() {
call @log1p_f32(%neg_two) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @log1p_f32(%inf) : (f32) -> ()
// CHECK: -inf, nan, inf, 9.99995e-06
- %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+ %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
return
@@ -247,7 +247,7 @@ func.func @erf() {
call @erf_f32(%val7) : (f32) -> ()
// CHECK: -1
- %negativeInf = arith.constant 0xff800000 : f32
+ %negativeInf = arith.constant -inf : f32
call @erf_f32(%negativeInf) : (f32) -> ()
// CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
// CHECK: 1
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @erf_f32(%inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @erf_f32(%nan) : (f32) -> ()
return
@@ -306,15 +306,15 @@ func.func @exp() {
call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @exp_f32(%inf) : (f32) -> ()
// CHECK: 0
- %negative_inf = arith.constant 0xff800000 : f32
+ %negative_inf = arith.constant -inf : f32
call @exp_f32(%negative_inf) : (f32) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @exp_f32(%nan) : (f32) -> ()
return
@@ -358,19 +358,19 @@ func.func @expm1() {
call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
// CHECK: -1
- %neg_inf = arith.constant 0xff800000 : f32
+ %neg_inf = arith.constant -inf : f32
call @expm1_f32(%neg_inf) : (f32) -> ()
// CHECK: inf
- %inf = arith.constant 0x7f800000 : f32
+ %inf = arith.constant inf : f32
call @expm1_f32(%inf) : (f32) -> ()
// CHECK: -1, inf, 1e-10
- %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+ %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
// CHECK: nan
- %nan = arith.constant 0x7fc00000 : f32
+ %nan = arith.constant nan : f32
call @expm1_f32(%nan) : (f32) -> ()
return
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
please separate out the APInt changes and include unit tests in a change just to llvm separate from/before the mlir side of this |
e140e08
to
1f81180
Compare
8f237ae
to
a361fd4
Compare
86c9eff
to
80171c9
Compare
80171c9
to
fdde14e
Compare
fdde14e
to
ba4cb5c
Compare
I'm kind of iffy on this mostly because I don't like the "unspecified" nature of things like this. That being said, if others find this useful I have a few comments:
|
Do you know why Edit: I just saw that NaN can carry a payload. Which I guess means that these return a well-specified value. Supporting payloads the syntax would probably go to far... |
Add two new keywords for parsing
nan
/inf
floating-point literals. This is more convenient that writing the integer hexadecimal bit patterns by hand (which differ depending on the floating-point type).Note: The printer still prints
nan
/inf
literals as integer hexadecimals. That's because there can be multiplenan
/inf
bit patterns. When parsing anan
/inf
keyword, the exact bit pattern is unspecified: we use whateverAPFloat::getInf
/APFloat::getNaN
returns.