Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Parser] Add nan and inf keywords #116176

Open
wants to merge 1 commit into
base: users/matthias-springer/parse_fp_dedup
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 14, 2024

Add two new keywords for parsing nan / inf floating-point literals. This is more convenient that writing the integer hexadecimal bit patterns by hand (which differ depending on the floating-point type).

Note: The printer still prints nan / inf literals as integer hexadecimals. That's because there can be multiple nan / inf bit patterns. When parsing a nan / inf keyword, the exact bit pattern is unspecified: we use whatever APFloat::getInf/APFloat::getNaN returns.

TODO: Update more test cases.

Depends on #116172.

@llvmbot
Copy link

llvmbot commented Nov 14, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add two new keywords for parsing nan / inf floating-point literals. This is more convenient that writing the integer hexadecimal bit patterns by hand (which differ depending on the floating-point type).

Note: The printer still prints nan / inf literals as integer hexadecimals. That's because there can be multiple nan / inf bit patterns. When parsing a nan / inf keyword, the exact bit pattern is unspecified: we use whatever APFloat::getInf/APFloat::getNaN returns.

Depends on #116172.


Full diff: https://github.com/llvm/llvm-project/pull/116176.diff

8 Files Affected:

  • (modified) llvm/include/llvm/ADT/APFloat.h (+2)
  • (modified) llvm/lib/Support/APFloat.cpp (+9)
  • (modified) mlir/lib/AsmParser/AttributeParser.cpp (+21-9)
  • (modified) mlir/lib/AsmParser/Parser.cpp (+27-2)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+2)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+5-5)
  • (modified) mlir/test/IR/attribute.mlir (+54)
  • (modified) mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir (+18-18)
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 97547fb577e0ec..40ad7ba92552ed 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -311,6 +311,8 @@ struct APFloatBase {
   static unsigned int semanticsIntSizeInBits(const fltSemantics&, bool);
   static bool semanticsHasZero(const fltSemantics &);
   static bool semanticsHasSignedRepr(const fltSemantics &);
+  static bool semanticsHasInf(const fltSemantics &);
+  static bool semanticsHasNan(const fltSemantics &);
 
   // Returns true if any number described by \p Src can be precisely represented
   // by a normal (not subnormal) value in \p Dst.
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index c566d489d11b03..8b9d9af2ca65b3 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -375,6 +375,15 @@ bool APFloatBase::semanticsHasSignedRepr(const fltSemantics &semantics) {
   return semantics.hasSignedRepr;
 }
 
+bool APFloatBase::semanticsHasInf(const fltSemantics &semantics) {
+  return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly
+      && semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
+bool APFloatBase::semanticsHasNan(const fltSemantics &semantics) {
+  return semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
 bool APFloatBase::isRepresentableAsNormalIn(const fltSemantics &Src,
                                             const fltSemantics &Dst) {
   // Exponent range must be larger.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 9ebada076cd042..68da950f09e568 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Endian.h"
+#include <cmath>
 #include <optional>
 
 using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse floating point and integer attributes.
   case Token::floatliteral:
+  case Token::kw_inf:
+  case Token::kw_nan:
     return parseFloatAttr(type, /*isNegative=*/false);
   case Token::integer:
     return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
     consumeToken(Token::minus);
     if (getToken().is(Token::integer))
       return parseDecOrHexAttr(type, /*isNegative=*/true);
-    if (getToken().is(Token::floatliteral))
+    if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+        getToken().is(Token::kw_nan))
       return parseFloatAttr(type, /*isNegative=*/true);
 
     return (emitWrongTokenError(
@@ -342,10 +347,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
 
 /// Parse a float attribute.
 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
-  auto val = getToken().getFloatingPointValue();
-  if (!val)
-    return (emitError("floating point value too large for attribute"), nullptr);
-  consumeToken(Token::floatliteral);
+  const Token tok = getToken();
+  consumeToken();
   if (!type) {
     // Default to F64 when no type is specified.
     if (!consumeIf(Token::colon))
@@ -353,10 +356,16 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
     else if (!(type = parseType()))
       return nullptr;
   }
-  if (!isa<FloatType>(type))
+  auto floatType = dyn_cast<FloatType>(type);
+  if (!floatType)
     return (emitError("floating point value not valid for specified type"),
             nullptr);
-  return FloatAttr::get(type, isNegative ? -*val : *val);
+  auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
+  FailureOr<APFloat> result = parseFloatFromLiteral(
+      emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+  if (failed(result))
+    return Attribute();
+  return FloatAttr::get(floatType, *result);
 }
 
 /// Construct an APint from a parsed value, a known attribute type and
@@ -623,7 +632,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
     }
 
     // Check to see if floating point values were parsed.
-    if (token.is(Token::floatliteral)) {
+    if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
       return p.emitError(tokenLoc)
              << "expected integer elements, but parsed floating-point";
     }
@@ -731,6 +740,8 @@ ParseResult TensorLiteralParser::parseElement() {
   // Parse a boolean element.
   case Token::kw_true:
   case Token::kw_false:
+  case Token::kw_inf:
+  case Token::kw_nan:
   case Token::floatliteral:
   case Token::integer:
     storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -740,7 +751,8 @@ ParseResult TensorLiteralParser::parseElement() {
   // Parse a signed integer or a negative floating-point element.
   case Token::minus:
     p.consumeToken(Token::minus);
-    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+    if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+                            Token::integer))
       return p.emitError("expected integer or floating point literal");
     storage.emplace_back(/*isNegative=*/true, p.getToken());
     p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 15f3dd7a66c358..6ff43b1749b64e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -103,11 +103,36 @@ FailureOr<APFloat>
 detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
                               const Token &tok, bool isNegative,
                               const llvm::fltSemantics &semantics) {
+  // Check for inf keyword.
+  if (tok.is(Token::kw_inf)) {
+    if (!APFloat::semanticsHasInf(semantics)) {
+      emitError() << "floating point type does not support infinity";
+      return failure();
+    }
+    return APFloat::getInf(semantics, isNegative);
+  }
+
+  // Check for NaN keyword.
+  if (tok.is(Token::kw_nan)) {
+    if (!APFloat::semanticsHasNan(semantics)) {
+      emitError() << "floating point type does not support NaN";
+      return failure();
+    }
+    return APFloat::getNaN(semantics, isNegative);
+  }
+
   // Check for a floating point value.
   if (tok.is(Token::floatliteral)) {
     auto val = tok.getFloatingPointValue();
-    if (!val)
-      return emitError() << "floating point value too large";
+    if (!val) {
+      emitError() << "floating point value too large";
+      return failure();
+    }
+    if (std::fpclassify(*val) == FP_ZERO &&
+        !APFloat::semanticsHasZero(semantics)) {
+      emitError() << "floating point type does not support zero";
+      return failure();
+    }
 
     APFloat result(isNegative ? -*val : *val);
     bool unused;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 49da8c3dea5fa5..9208c8adddcfce 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
 TOK_KEYWORD(for)
 TOK_KEYWORD(func)
 TOK_KEYWORD(index)
+TOK_KEYWORD(inf)
 TOK_KEYWORD(loc)
 TOK_KEYWORD(max)
 TOK_KEYWORD(memref)
 TOK_KEYWORD(min)
 TOK_KEYWORD(mod)
+TOK_KEYWORD(nan)
 TOK_KEYWORD(none)
 TOK_KEYWORD(offset)
 TOK_KEYWORD(size)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..c86b2b5f63f016 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %inf = arith.constant 0x7F800000 : f32
+  %inf = arith.constant inf : f32
   %0 = arith.minimumf %c0, %arg0 : f32
   %1 = arith.minimumf %arg0, %arg0 : f32
   %2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %-inf = arith.constant 0xFF800000 : f32
+  %-inf = arith.constant -inf : f32
   %0 = arith.maximumf %c0, %arg0 : f32
   %1 = arith.maximumf %arg0, %arg0 : f32
   %2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %inf = arith.constant 0x7F800000 : f32
+  %inf = arith.constant inf : f32
   %0 = arith.minnumf %c0, %arg0 : f32
   %1 = arith.minnumf %arg0, %arg0 : f32
   %2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %-inf = arith.constant 0xFF800000 : f32
+  %-inf = arith.constant -inf : f32
   %0 = arith.maxnumf %c0, %arg0 : f32
   %1 = arith.maxnumf %arg0, %arg0 : f32
   %2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true
 //   CHECK-DAG:   %[[F:.*]] = arith.constant false
 //       CHECK:   return %[[F]], %[[F]], %[[T]], %[[T]]
-  %nan = arith.constant 0x7fffffff : f32
+  %nan = arith.constant nan : f32
   %0 = arith.cmpf olt, %nan, %arg0 : f32
   %1 = arith.cmpf olt, %arg0, %nan : f32
   %2 = arith.cmpf ugt, %nan, %arg0 : f32
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..1fbb4986ab2ea7 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f128
     float_attr = 2. : f128
   } : () -> ()
+  "test.float_attrs"() {
+    // Note: nan/inf are printed in binary format because there may be multiple
+    // nan/inf representations.
+    // CHECK: float_attr = 0x7FC00000 : f32
+    float_attr = nan : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x7C : f8E4M3
+    float_attr = nan : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFFC00000 : f32
+    float_attr = -nan : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFC : f8E4M3
+    float_attr = -nan : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x7F800000 : f32
+    float_attr = inf : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x78 : f8E4M3
+    float_attr = inf : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFF800000 : f32
+    float_attr = -inf : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xF8 : f8E4M3
+    float_attr = -inf : f8E4M3
+  } : () -> ()
   return
 }
 
+// -----
+
+func.func @float_nan_unsupported() {
+  "test.float_attrs"() {
+    // expected-error @below{{floating point type does not support NaN}}
+    float_attr = nan : f4E2M1FN
+  } : () -> ()
+}
+
+// -----
+
+func.func @float_inf_unsupported() {
+  "test.float_attrs"() {
+    // expected-error @below{{floating point type does not support infinity}}
+    float_attr = inf : f4E2M1FN
+  } : () -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Test integer attributes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b8861198d596b0..28b656b0da5f1a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -41,7 +41,7 @@ func.func @tanh() {
   call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @tanh_f32(%nan) : (f32) -> ()
 
  return
@@ -87,15 +87,15 @@ func.func @log() {
   call @log_f32(%zero) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @log_f32(%nan) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 0.693147
-  %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
+  %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
   call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -141,11 +141,11 @@ func.func @log2() {
   call @log2_f32(%neg_one) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log2_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 1.58496
-  %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+  %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
   call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -192,11 +192,11 @@ func.func @log1p() {
   call @log1p_f32(%neg_two) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log1p_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 9.99995e-06
-  %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+  %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
   call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -247,7 +247,7 @@ func.func @erf() {
   call @erf_f32(%val7) : (f32) -> ()
 
   // CHECK: -1
-  %negativeInf = arith.constant 0xff800000 : f32
+  %negativeInf = arith.constant -inf : f32
   call @erf_f32(%negativeInf) : (f32) -> ()
 
   // CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
   call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
 
   // CHECK: 1
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @erf_f32(%inf) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @erf_f32(%nan) : (f32) -> ()
 
   return
@@ -306,15 +306,15 @@ func.func @exp() {
   call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @exp_f32(%inf) : (f32) -> ()
 
   // CHECK: 0
-  %negative_inf = arith.constant 0xff800000 : f32
+  %negative_inf = arith.constant -inf : f32
   call @exp_f32(%negative_inf) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @exp_f32(%nan) : (f32) -> ()
 
   return
@@ -358,19 +358,19 @@ func.func @expm1() {
   call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
 
   // CHECK: -1
-  %neg_inf = arith.constant 0xff800000 : f32
+  %neg_inf = arith.constant -inf : f32
   call @expm1_f32(%neg_inf) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @expm1_f32(%inf) : (f32) -> ()
 
   // CHECK: -1, inf, 1e-10
-  %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+  %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
   call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @expm1_f32(%nan) : (f32) -> ()
 
   return

@llvmbot
Copy link

llvmbot commented Nov 14, 2024

@llvm/pr-subscribers-llvm-adt

Author: Matthias Springer (matthias-springer)

Changes

Add two new keywords for parsing nan / inf floating-point literals. This is more convenient that writing the integer hexadecimal bit patterns by hand (which differ depending on the floating-point type).

Note: The printer still prints nan / inf literals as integer hexadecimals. That's because there can be multiple nan / inf bit patterns. When parsing a nan / inf keyword, the exact bit pattern is unspecified: we use whatever APFloat::getInf/APFloat::getNaN returns.

Depends on #116172.


Full diff: https://github.com/llvm/llvm-project/pull/116176.diff

8 Files Affected:

  • (modified) llvm/include/llvm/ADT/APFloat.h (+2)
  • (modified) llvm/lib/Support/APFloat.cpp (+9)
  • (modified) mlir/lib/AsmParser/AttributeParser.cpp (+21-9)
  • (modified) mlir/lib/AsmParser/Parser.cpp (+27-2)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+2)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+5-5)
  • (modified) mlir/test/IR/attribute.mlir (+54)
  • (modified) mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir (+18-18)
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 97547fb577e0ec..40ad7ba92552ed 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -311,6 +311,8 @@ struct APFloatBase {
   static unsigned int semanticsIntSizeInBits(const fltSemantics&, bool);
   static bool semanticsHasZero(const fltSemantics &);
   static bool semanticsHasSignedRepr(const fltSemantics &);
+  static bool semanticsHasInf(const fltSemantics &);
+  static bool semanticsHasNan(const fltSemantics &);
 
   // Returns true if any number described by \p Src can be precisely represented
   // by a normal (not subnormal) value in \p Dst.
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index c566d489d11b03..8b9d9af2ca65b3 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -375,6 +375,15 @@ bool APFloatBase::semanticsHasSignedRepr(const fltSemantics &semantics) {
   return semantics.hasSignedRepr;
 }
 
+bool APFloatBase::semanticsHasInf(const fltSemantics &semantics) {
+  return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly
+      && semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
+bool APFloatBase::semanticsHasNan(const fltSemantics &semantics) {
+  return semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+}
+
 bool APFloatBase::isRepresentableAsNormalIn(const fltSemantics &Src,
                                             const fltSemantics &Dst) {
   // Exponent range must be larger.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 9ebada076cd042..68da950f09e568 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Endian.h"
+#include <cmath>
 #include <optional>
 
 using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse floating point and integer attributes.
   case Token::floatliteral:
+  case Token::kw_inf:
+  case Token::kw_nan:
     return parseFloatAttr(type, /*isNegative=*/false);
   case Token::integer:
     return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
     consumeToken(Token::minus);
     if (getToken().is(Token::integer))
       return parseDecOrHexAttr(type, /*isNegative=*/true);
-    if (getToken().is(Token::floatliteral))
+    if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+        getToken().is(Token::kw_nan))
       return parseFloatAttr(type, /*isNegative=*/true);
 
     return (emitWrongTokenError(
@@ -342,10 +347,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
 
 /// Parse a float attribute.
 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
-  auto val = getToken().getFloatingPointValue();
-  if (!val)
-    return (emitError("floating point value too large for attribute"), nullptr);
-  consumeToken(Token::floatliteral);
+  const Token tok = getToken();
+  consumeToken();
   if (!type) {
     // Default to F64 when no type is specified.
     if (!consumeIf(Token::colon))
@@ -353,10 +356,16 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
     else if (!(type = parseType()))
       return nullptr;
   }
-  if (!isa<FloatType>(type))
+  auto floatType = dyn_cast<FloatType>(type);
+  if (!floatType)
     return (emitError("floating point value not valid for specified type"),
             nullptr);
-  return FloatAttr::get(type, isNegative ? -*val : *val);
+  auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
+  FailureOr<APFloat> result = parseFloatFromLiteral(
+      emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+  if (failed(result))
+    return Attribute();
+  return FloatAttr::get(floatType, *result);
 }
 
 /// Construct an APint from a parsed value, a known attribute type and
@@ -623,7 +632,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
     }
 
     // Check to see if floating point values were parsed.
-    if (token.is(Token::floatliteral)) {
+    if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
       return p.emitError(tokenLoc)
              << "expected integer elements, but parsed floating-point";
     }
@@ -731,6 +740,8 @@ ParseResult TensorLiteralParser::parseElement() {
   // Parse a boolean element.
   case Token::kw_true:
   case Token::kw_false:
+  case Token::kw_inf:
+  case Token::kw_nan:
   case Token::floatliteral:
   case Token::integer:
     storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -740,7 +751,8 @@ ParseResult TensorLiteralParser::parseElement() {
   // Parse a signed integer or a negative floating-point element.
   case Token::minus:
     p.consumeToken(Token::minus);
-    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+    if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+                            Token::integer))
       return p.emitError("expected integer or floating point literal");
     storage.emplace_back(/*isNegative=*/true, p.getToken());
     p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 15f3dd7a66c358..6ff43b1749b64e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -103,11 +103,36 @@ FailureOr<APFloat>
 detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
                               const Token &tok, bool isNegative,
                               const llvm::fltSemantics &semantics) {
+  // Check for inf keyword.
+  if (tok.is(Token::kw_inf)) {
+    if (!APFloat::semanticsHasInf(semantics)) {
+      emitError() << "floating point type does not support infinity";
+      return failure();
+    }
+    return APFloat::getInf(semantics, isNegative);
+  }
+
+  // Check for NaN keyword.
+  if (tok.is(Token::kw_nan)) {
+    if (!APFloat::semanticsHasNan(semantics)) {
+      emitError() << "floating point type does not support NaN";
+      return failure();
+    }
+    return APFloat::getNaN(semantics, isNegative);
+  }
+
   // Check for a floating point value.
   if (tok.is(Token::floatliteral)) {
     auto val = tok.getFloatingPointValue();
-    if (!val)
-      return emitError() << "floating point value too large";
+    if (!val) {
+      emitError() << "floating point value too large";
+      return failure();
+    }
+    if (std::fpclassify(*val) == FP_ZERO &&
+        !APFloat::semanticsHasZero(semantics)) {
+      emitError() << "floating point type does not support zero";
+      return failure();
+    }
 
     APFloat result(isNegative ? -*val : *val);
     bool unused;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 49da8c3dea5fa5..9208c8adddcfce 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
 TOK_KEYWORD(for)
 TOK_KEYWORD(func)
 TOK_KEYWORD(index)
+TOK_KEYWORD(inf)
 TOK_KEYWORD(loc)
 TOK_KEYWORD(max)
 TOK_KEYWORD(memref)
 TOK_KEYWORD(min)
 TOK_KEYWORD(mod)
+TOK_KEYWORD(nan)
 TOK_KEYWORD(none)
 TOK_KEYWORD(offset)
 TOK_KEYWORD(size)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..c86b2b5f63f016 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %inf = arith.constant 0x7F800000 : f32
+  %inf = arith.constant inf : f32
   %0 = arith.minimumf %c0, %arg0 : f32
   %1 = arith.minimumf %arg0, %arg0 : f32
   %2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %-inf = arith.constant 0xFF800000 : f32
+  %-inf = arith.constant -inf : f32
   %0 = arith.maximumf %c0, %arg0 : f32
   %1 = arith.maximumf %arg0, %arg0 : f32
   %2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %inf = arith.constant 0x7F800000 : f32
+  %inf = arith.constant inf : f32
   %0 = arith.minnumf %c0, %arg0 : f32
   %1 = arith.minnumf %arg0, %arg0 : f32
   %2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %-inf = arith.constant 0xFF800000 : f32
+  %-inf = arith.constant -inf : f32
   %0 = arith.maxnumf %c0, %arg0 : f32
   %1 = arith.maxnumf %arg0, %arg0 : f32
   %2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true
 //   CHECK-DAG:   %[[F:.*]] = arith.constant false
 //       CHECK:   return %[[F]], %[[F]], %[[T]], %[[T]]
-  %nan = arith.constant 0x7fffffff : f32
+  %nan = arith.constant nan : f32
   %0 = arith.cmpf olt, %nan, %arg0 : f32
   %1 = arith.cmpf olt, %arg0, %nan : f32
   %2 = arith.cmpf ugt, %nan, %arg0 : f32
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..1fbb4986ab2ea7 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f128
     float_attr = 2. : f128
   } : () -> ()
+  "test.float_attrs"() {
+    // Note: nan/inf are printed in binary format because there may be multiple
+    // nan/inf representations.
+    // CHECK: float_attr = 0x7FC00000 : f32
+    float_attr = nan : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x7C : f8E4M3
+    float_attr = nan : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFFC00000 : f32
+    float_attr = -nan : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFC : f8E4M3
+    float_attr = -nan : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x7F800000 : f32
+    float_attr = inf : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x78 : f8E4M3
+    float_attr = inf : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFF800000 : f32
+    float_attr = -inf : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xF8 : f8E4M3
+    float_attr = -inf : f8E4M3
+  } : () -> ()
   return
 }
 
+// -----
+
+func.func @float_nan_unsupported() {
+  "test.float_attrs"() {
+    // expected-error @below{{floating point type does not support NaN}}
+    float_attr = nan : f4E2M1FN
+  } : () -> ()
+}
+
+// -----
+
+func.func @float_inf_unsupported() {
+  "test.float_attrs"() {
+    // expected-error @below{{floating point type does not support infinity}}
+    float_attr = inf : f4E2M1FN
+  } : () -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Test integer attributes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b8861198d596b0..28b656b0da5f1a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -41,7 +41,7 @@ func.func @tanh() {
   call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @tanh_f32(%nan) : (f32) -> ()
 
  return
@@ -87,15 +87,15 @@ func.func @log() {
   call @log_f32(%zero) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @log_f32(%nan) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 0.693147
-  %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
+  %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
   call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -141,11 +141,11 @@ func.func @log2() {
   call @log2_f32(%neg_one) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log2_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 1.58496
-  %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+  %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
   call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -192,11 +192,11 @@ func.func @log1p() {
   call @log1p_f32(%neg_two) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log1p_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 9.99995e-06
-  %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+  %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
   call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -247,7 +247,7 @@ func.func @erf() {
   call @erf_f32(%val7) : (f32) -> ()
 
   // CHECK: -1
-  %negativeInf = arith.constant 0xff800000 : f32
+  %negativeInf = arith.constant -inf : f32
   call @erf_f32(%negativeInf) : (f32) -> ()
 
   // CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
   call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
 
   // CHECK: 1
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @erf_f32(%inf) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @erf_f32(%nan) : (f32) -> ()
 
   return
@@ -306,15 +306,15 @@ func.func @exp() {
   call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @exp_f32(%inf) : (f32) -> ()
 
   // CHECK: 0
-  %negative_inf = arith.constant 0xff800000 : f32
+  %negative_inf = arith.constant -inf : f32
   call @exp_f32(%negative_inf) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @exp_f32(%nan) : (f32) -> ()
 
   return
@@ -358,19 +358,19 @@ func.func @expm1() {
   call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
 
   // CHECK: -1
-  %neg_inf = arith.constant 0xff800000 : f32
+  %neg_inf = arith.constant -inf : f32
   call @expm1_f32(%neg_inf) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @expm1_f32(%inf) : (f32) -> ()
 
   // CHECK: -1, inf, 1e-10
-  %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+  %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
   call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @expm1_f32(%nan) : (f32) -> ()
 
   return

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 51530aeea8c18804034881c87236d1ab5ceb274f 045afe88e53f873cf027ab92af32c120a1d47d63 --extensions cpp,h -- llvm/include/llvm/ADT/APFloat.h llvm/lib/Support/APFloat.cpp mlir/lib/AsmParser/AttributeParser.cpp mlir/lib/AsmParser/Parser.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 8b9d9af2ca..e82071e5b7 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -376,8 +376,8 @@ bool APFloatBase::semanticsHasSignedRepr(const fltSemantics &semantics) {
 }
 
 bool APFloatBase::semanticsHasInf(const fltSemantics &semantics) {
-  return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly
-      && semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
+  return semantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly &&
+         semantics.nonFiniteBehavior != fltNonfiniteBehavior::FiniteOnly;
 }
 
 bool APFloatBase::semanticsHasNan(const fltSemantics &semantics) {

@dwblaikie
Copy link
Collaborator

please separate out the APInt changes and include unit tests in a change just to llvm separate from/before the mlir side of this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants