Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Parser] Deduplicate floating-point parsing functionality #116172

Open
wants to merge 1 commit into
base: users/matthias-springer/parse_fp_int_lit
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 14, 2024

The following functionality is duplicated in multiple places: trying to parse an APFloat from a floating point literal or an integer in hexadecimal representation (bit pattern). Move it to a common helper function.

NFC apart from the slightly changed error messages. (We now print the exact same error messages regardless of whether the float is parsed standalone or inside of a tensor literal, etc.)

Depends on #116171.

@matthias-springer matthias-springer changed the title [mlir][Parser] Deduplicate fp parsing functionality [mlir][Parser] Deduplicate floating-point parsing functionality Nov 14, 2024
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 14, 2024
@llvmbot
Copy link

llvmbot commented Nov 14, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The following functionality is duplicated in multiple places: trying to parse an APFloat from a floating point literal or an integer in hexadecimal representation (bit pattern). Move it to a common helper function.

NFC apart from the slightly changed error messages.

Depends on #116171.


Full diff: https://github.com/llvm/llvm-project/pull/116172.diff

5 Files Affected:

  • (modified) mlir/lib/AsmParser/AsmParserImpl.h (+7-26)
  • (modified) mlir/lib/AsmParser/AttributeParser.cpp (+14-57)
  • (modified) mlir/lib/AsmParser/Parser.cpp (+23)
  • (modified) mlir/lib/AsmParser/Parser.h (+6)
  • (modified) mlir/test/IR/invalid-builtin-attributes.mlir (+6-4)
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1e6cbc0ec51beb..bbd70d5980f8fe 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -288,32 +288,13 @@ class AsmParserImpl : public BaseT {
     bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
     auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
-
-    // Check for a floating point value.
-    if (curTok.is(Token::floatliteral)) {
-      auto val = curTok.getFloatingPointValue();
-      if (!val)
-        return emitErrorAtTok() << "floating point value too large";
-      parser.consumeToken(Token::floatliteral);
-      result = APFloat(isNegative ? -*val : *val);
-      bool losesInfo;
-      result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
-      return success();
-    }
-
-    // Check for a hexadecimal float value.
-    if (curTok.is(Token::integer)) {
-      FailureOr<APFloat> apResult = parseFloatFromIntegerLiteral(
-          emitErrorAtTok, curTok, isNegative, semantics);
-      if (failed(apResult))
-        return failure();
-
-      result = *apResult;
-      parser.consumeToken(Token::integer);
-      return success();
-    }
-
-    return emitErrorAtTok() << "expected floating point literal";
+    FailureOr<APFloat> apResult =
+        parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics);
+    if (failed(apResult))
+      return failure();
+    parser.consumeToken();
+    result = *apResult;
+    return success();
   }
 
   /// Parse a floating point value from the stream.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index ba9be3b030453a..9ebada076cd042 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -658,36 +658,12 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
     const Token &token = signAndToken.second;
-
-    // Handle hexadecimal float literals.
-    if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
-      auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
-      FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
-          emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
-      if (failed(result))
-        return failure();
-
-      floatValues.push_back(*result);
-      continue;
-    }
-
-    // Check to see if any decimal integers or booleans were parsed.
-    if (!token.is(Token::floatliteral))
-      return p.emitError()
-             << "expected floating-point elements, but parsed integer";
-
-    // Build the float values from tokens.
-    auto val = token.getFloatingPointValue();
-    if (!val)
-      return p.emitError("floating point value too large for attribute");
-
-    APFloat apVal(isNegative ? -*val : *val);
-    if (!eltTy.isF64()) {
-      bool unused;
-      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
-                    &unused);
-    }
-    floatValues.push_back(apVal);
+    auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+    FailureOr<APFloat> result = parseFloatFromLiteral(
+        emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+    if (failed(result))
+      return failure();
+    floatValues.push_back(*result);
   }
   return success();
 }
@@ -905,34 +881,15 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
 
 ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
   bool isNegative = p.consumeIf(Token::minus);
-
   Token token = p.getToken();
-  std::optional<APFloat> result;
-  auto floatType = cast<FloatType>(type);
-  if (p.consumeIf(Token::integer)) {
-    // Parse an integer literal as a float.
-    auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
-    FailureOr<APFloat> fromIntLit = parseFloatFromIntegerLiteral(
-        emitErrorAtTok, token, isNegative, floatType.getFloatSemantics());
-    if (failed(fromIntLit))
-      return failure();
-    result = *fromIntLit;
-  } else if (p.consumeIf(Token::floatliteral)) {
-    // Parse a floating point literal.
-    std::optional<double> val = token.getFloatingPointValue();
-    if (!val)
-      return failure();
-    result = APFloat(isNegative ? -*val : *val);
-    if (!type.isF64()) {
-      bool unused;
-      result->convert(floatType.getFloatSemantics(),
-                      APFloat::rmNearestTiesToEven, &unused);
-    }
-  } else {
-    return p.emitError("expected integer or floating point literal");
-  }
-
-  append(result->bitcastToAPInt());
+  auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+  FailureOr<APFloat> fromIntLit =
+      parseFloatFromLiteral(emitErrorAtTok, token, isNegative,
+                            cast<FloatType>(type).getFloatSemantics());
+  if (failed(fromIntLit))
+    return failure();
+  p.consumeToken();
+  append(fromIntLit->bitcastToAPInt());
   return success();
 }
 
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index ac7eec931b1250..15f3dd7a66c358 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -99,6 +99,29 @@ FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
   return APFloat(semantics, truncatedValue);
 }
 
+FailureOr<APFloat>
+detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+                              const Token &tok, bool isNegative,
+                              const llvm::fltSemantics &semantics) {
+  // Check for a floating point value.
+  if (tok.is(Token::floatliteral)) {
+    auto val = tok.getFloatingPointValue();
+    if (!val)
+      return emitError() << "floating point value too large";
+
+    APFloat result(isNegative ? -*val : *val);
+    bool unused;
+    result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
+    return result;
+  }
+
+  // Check for a hexadecimal float value.
+  if (tok.is(Token::integer))
+    return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics);
+
+  return emitError() << "expected floating point literal";
+}
+
 //===----------------------------------------------------------------------===//
 // CodeComplete
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index fa29264ffe506a..ab445476a91923 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -22,6 +22,12 @@ parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
                              const Token &tok, bool isNegative,
                              const llvm::fltSemantics &semantics);
 
+/// Parse a floating point value from a literal.
+FailureOr<APFloat>
+parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+                      const Token &tok, bool isNegative,
+                      const llvm::fltSemantics &semantics);
+
 //===----------------------------------------------------------------------===//
 // Parser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 431c7b12b8f5fe..5098fe751fd01f 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () {
 // -----
 
 func.func @elementsattr_floattype2() -> () {
-  // expected-error@+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+  // expected-note@below {{add a trailing dot to make the literal a float}}
   "foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> ()
 }
 
@@ -138,21 +139,22 @@ func.func @float_in_int_tensor() {
 // -----
 
 func.func @float_in_bool_tensor() {
-  // expected-error @+1 {{expected integer elements, but parsed floating-point}}
+  // expected-error@below {{expected integer elements, but parsed floating-point}}
   "foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> ()
 }
 
 // -----
 
 func.func @decimal_int_in_float_tensor() {
-  // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+  // expected-note@below {{add a trailing dot to make the literal a float}}
   "foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> ()
 }
 
 // -----
 
 func.func @bool_in_float_tensor() {
-  // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error @+1 {{expected floating point literal}}
   "foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> ()
 }
 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants