-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Parser] Deduplicate floating-point parsing functionality #116172
base: users/matthias-springer/parse_fp_int_lit
Are you sure you want to change the base?
[mlir][Parser] Deduplicate floating-point parsing functionality #116172
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe following functionality is duplicated in multiple places: trying to parse an APFloat from a floating point literal or an integer in hexadecimal representation (bit pattern). Move it to a common helper function. NFC apart from the slightly changed error messages. Depends on #116171. Full diff: https://github.com/llvm/llvm-project/pull/116172.diff 5 Files Affected:
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1e6cbc0ec51beb..bbd70d5980f8fe 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -288,32 +288,13 @@ class AsmParserImpl : public BaseT {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
-
- // Check for a floating point value.
- if (curTok.is(Token::floatliteral)) {
- auto val = curTok.getFloatingPointValue();
- if (!val)
- return emitErrorAtTok() << "floating point value too large";
- parser.consumeToken(Token::floatliteral);
- result = APFloat(isNegative ? -*val : *val);
- bool losesInfo;
- result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
- return success();
- }
-
- // Check for a hexadecimal float value.
- if (curTok.is(Token::integer)) {
- FailureOr<APFloat> apResult = parseFloatFromIntegerLiteral(
- emitErrorAtTok, curTok, isNegative, semantics);
- if (failed(apResult))
- return failure();
-
- result = *apResult;
- parser.consumeToken(Token::integer);
- return success();
- }
-
- return emitErrorAtTok() << "expected floating point literal";
+ FailureOr<APFloat> apResult =
+ parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics);
+ if (failed(apResult))
+ return failure();
+ parser.consumeToken();
+ result = *apResult;
+ return success();
}
/// Parse a floating point value from the stream.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index ba9be3b030453a..9ebada076cd042 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -658,36 +658,12 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
-
- // Handle hexadecimal float literals.
- if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
- auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
- FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
- emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
- if (failed(result))
- return failure();
-
- floatValues.push_back(*result);
- continue;
- }
-
- // Check to see if any decimal integers or booleans were parsed.
- if (!token.is(Token::floatliteral))
- return p.emitError()
- << "expected floating-point elements, but parsed integer";
-
- // Build the float values from tokens.
- auto val = token.getFloatingPointValue();
- if (!val)
- return p.emitError("floating point value too large for attribute");
-
- APFloat apVal(isNegative ? -*val : *val);
- if (!eltTy.isF64()) {
- bool unused;
- apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
- &unused);
- }
- floatValues.push_back(apVal);
+ auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+ FailureOr<APFloat> result = parseFloatFromLiteral(
+ emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+ if (failed(result))
+ return failure();
+ floatValues.push_back(*result);
}
return success();
}
@@ -905,34 +881,15 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
bool isNegative = p.consumeIf(Token::minus);
-
Token token = p.getToken();
- std::optional<APFloat> result;
- auto floatType = cast<FloatType>(type);
- if (p.consumeIf(Token::integer)) {
- // Parse an integer literal as a float.
- auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
- FailureOr<APFloat> fromIntLit = parseFloatFromIntegerLiteral(
- emitErrorAtTok, token, isNegative, floatType.getFloatSemantics());
- if (failed(fromIntLit))
- return failure();
- result = *fromIntLit;
- } else if (p.consumeIf(Token::floatliteral)) {
- // Parse a floating point literal.
- std::optional<double> val = token.getFloatingPointValue();
- if (!val)
- return failure();
- result = APFloat(isNegative ? -*val : *val);
- if (!type.isF64()) {
- bool unused;
- result->convert(floatType.getFloatSemantics(),
- APFloat::rmNearestTiesToEven, &unused);
- }
- } else {
- return p.emitError("expected integer or floating point literal");
- }
-
- append(result->bitcastToAPInt());
+ auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+ FailureOr<APFloat> fromIntLit =
+ parseFloatFromLiteral(emitErrorAtTok, token, isNegative,
+ cast<FloatType>(type).getFloatSemantics());
+ if (failed(fromIntLit))
+ return failure();
+ p.consumeToken();
+ append(fromIntLit->bitcastToAPInt());
return success();
}
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index ac7eec931b1250..15f3dd7a66c358 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -99,6 +99,29 @@ FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
return APFloat(semantics, truncatedValue);
}
+FailureOr<APFloat>
+detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics) {
+ // Check for a floating point value.
+ if (tok.is(Token::floatliteral)) {
+ auto val = tok.getFloatingPointValue();
+ if (!val)
+ return emitError() << "floating point value too large";
+
+ APFloat result(isNegative ? -*val : *val);
+ bool unused;
+ result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
+ return result;
+ }
+
+ // Check for a hexadecimal float value.
+ if (tok.is(Token::integer))
+ return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics);
+
+ return emitError() << "expected floating point literal";
+}
+
//===----------------------------------------------------------------------===//
// CodeComplete
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index fa29264ffe506a..ab445476a91923 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -22,6 +22,12 @@ parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics);
+/// Parse a floating point value from a literal.
+FailureOr<APFloat>
+parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics);
+
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 431c7b12b8f5fe..5098fe751fd01f 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () {
// -----
func.func @elementsattr_floattype2() -> () {
- // expected-error@+1 {{expected floating-point elements, but parsed integer}}
+ // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+ // expected-note@below {{add a trailing dot to make the literal a float}}
"foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> ()
}
@@ -138,21 +139,22 @@ func.func @float_in_int_tensor() {
// -----
func.func @float_in_bool_tensor() {
- // expected-error @+1 {{expected integer elements, but parsed floating-point}}
+ // expected-error@below {{expected integer elements, but parsed floating-point}}
"foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> ()
}
// -----
func.func @decimal_int_in_float_tensor() {
- // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+ // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+ // expected-note@below {{add a trailing dot to make the literal a float}}
"foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> ()
}
// -----
func.func @bool_in_float_tensor() {
- // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+ // expected-error @+1 {{expected floating point literal}}
"foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> ()
}
|
The following functionality is duplicated in multiple places: trying to parse an APFloat from a floating point literal or an integer in hexadecimal representation (bit pattern). Move it to a common helper function.
NFC apart from the slightly changed error messages. (We now print the exact same error messages regardless of whether the float is parsed standalone or inside of a tensor literal, etc.)
Depends on #116171.