diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index fa435cb3155ed..ae412c7227f8e 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -714,16 +714,27 @@ class AsmParser { return *parseResult; } + /// Parse a decimal integer value from the stream. + template + ParseResult parseDecimalInteger(IntT &result) { + auto loc = getCurrentLocation(); + OptionalParseResult parseResult = parseOptionalDecimalInteger(result); + if (!parseResult.has_value()) + return emitError(loc, "expected decimal integer value"); + return *parseResult; + } + /// Parse an optional integer value from the stream. virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; + virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0; - template - OptionalParseResult parseOptionalInteger(IntT &result) { + private: + template + OptionalParseResult parseOptionalIntegerAndCheck(IntT &result, + ParseFn &&parseFn) { auto loc = getCurrentLocation(); - - // Parse the unsigned variant. APInt uintResult; - OptionalParseResult parseResult = parseOptionalInteger(uintResult); + OptionalParseResult parseResult = parseFn(uintResult); if (!parseResult.has_value() || failed(*parseResult)) return parseResult; @@ -737,6 +748,20 @@ class AsmParser { return success(); } + public: + template + OptionalParseResult parseOptionalInteger(IntT &result) { + return parseOptionalIntegerAndCheck( + result, [&](APInt &result) { return parseOptionalInteger(result); }); + } + + template + OptionalParseResult parseOptionalDecimalInteger(IntT &result) { + return parseOptionalIntegerAndCheck(result, [&](APInt &result) { + return parseOptionalDecimalInteger(result); + }); + } + /// These are the supported delimiters around operand lists and region /// argument lists, used by parseOperandList. enum class Delimiter { diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 8f22be80865bf..b12687833e3fd 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -322,6 +322,11 @@ class AsmParserImpl : public BaseT { return parser.parseOptionalInteger(result); } + /// Parse an optional integer value from the stream. + OptionalParseResult parseOptionalDecimalInteger(APInt &result) override { + return parser.parseOptionalDecimalInteger(result); + } + /// Parse a list of comma-separated items with an optional delimiter. If a /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 7181f13d3c8bb..2e4c4a36d46b9 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -41,6 +41,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Alignment.h" @@ -307,6 +308,45 @@ OptionalParseResult Parser::parseOptionalInteger(APInt &result) { return success(); } +/// Parse an optional integer value only in decimal format from the stream. +OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) { + Token curToken = getToken(); + if (curToken.isNot(Token::integer, Token::minus)) { + return std::nullopt; + } + + bool negative = consumeIf(Token::minus); + Token curTok = getToken(); + if (parseToken(Token::integer, "expected integer value")) { + return failure(); + } + + StringRef spelling = curTok.getSpelling(); + // If the integer is in hexadecimal return only the 0. The lexer has already + // moved past the entire hexidecimal encoded integer so we reset the lex + // pointer to just past the 0 we actualy want to consume. + if (spelling[0] == '0' && spelling.size() > 1 && + llvm::toLower(spelling[1]) == 'x') { + result = 0; + state.lex.resetPointer(spelling.data() + 1); + consumeToken(); + return success(); + } + + if (spelling.getAsInteger(10, result)) + return emitError(curTok.getLoc(), "integer value too large"); + + // Make sure we have a zero at the top so we return the right signedness. + if (result.isNegative()) + result = result.zext(result.getBitWidth() + 1); + + // Process the negative sign if present. + if (negative) + result.negate(); + + return success(); +} + /// Parse a floating point value from an integer literal token. ParseResult Parser::parseFloatFromIntegerLiteral( std::optional &result, const Token &tok, bool isNegative, diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index b959e67b8e258..4caab499e1a0e 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -144,6 +144,9 @@ class Parser { /// Parse an optional integer value from the stream. OptionalParseResult parseOptionalInteger(APInt &result); + /// Parse an optional integer value only in decimal format from the stream. + OptionalParseResult parseOptionalDecimalInteger(APInt &result); + /// Parse a floating point value from an integer literal token. ParseResult parseFloatFromIntegerLiteral(std::optional &result, const Token &tok, bool isNegative, diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index a0a1cd30ed8ae..8f109f8ce5e6d 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -81,6 +81,15 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> { let mnemonic = "attr_with_trait"; } +// An attribute of a list of decimal formatted integers in similar format to shapes. +def TestDecimalShapeAttr : Test_Attr<"TestDecimalShape"> { + let mnemonic = "decimal_shape"; + + let parameters = (ins ArrayRefParameter<"int64_t">:$shape); + + let hasCustomAssemblyFormat = 1; +} + // Test support for ElementsAttrInterface. def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> { let mnemonic = "i64_elements"; diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index b66dfbfcf0895..e09ea10906164 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -13,9 +13,12 @@ #include "TestAttributes.h" #include "TestDialect.h" +#include "TestTypes.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Hashing.h" @@ -63,6 +66,39 @@ void CompoundAAttr::print(AsmPrinter &printer) const { // CompoundAAttr //===----------------------------------------------------------------------===// +Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess()){ + return Attribute(); + } + SmallVector shape; + if (parser.parseOptionalGreater()) { + auto parseDecimal = [&]() { + shape.emplace_back(); + auto parseResult = parser.parseOptionalDecimalInteger(shape.back()); + if (!parseResult.has_value() || failed(*parseResult)) { + parser.emitError(parser.getCurrentLocation()) << "expected an integer"; + return failure(); + } + return success(); + }; + if (failed(parseDecimal())) { + return Attribute(); + } + while (failed(parser.parseOptionalGreater())) { + if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) { + return Attribute(); + } + } + } + return get(parser.getContext(), shape); +} + +void TestDecimalShapeAttr::print(AsmPrinter &printer) const { + printer << "<"; + llvm::interleave(getShape(), printer, "x"); + printer << ">"; +} + Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { SmallVector elements; if (parser.parseLess() || parser.parseLSquare()) diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir index ee92ea06a208c..89ad3594eebd8 100644 --- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func private @compoundA() // CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]> @@ -19,3 +19,28 @@ func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qu func.func private @overriddenAttr() attributes { foo = #test.override_builder<5> } + +// CHECK-LABEL: @decimalIntegerShapeEmpty +// CHECK-SAME: foo = #test.decimal_shape<> +func.func private @decimalIntegerShapeEmpty() attributes { + foo = #test.decimal_shape<> +} + +// CHECK-LABEL: @decimalIntegerShape +// CHECK-SAME: foo = #test.decimal_shape<5> +func.func private @decimalIntegerShape() attributes { + foo = #test.decimal_shape<5> +} + +// CHECK-LABEL: @decimalIntegerShapeMultiple +// CHECK-SAME: foo = #test.decimal_shape<0x3x7> +func.func private @decimalIntegerShapeMultiple() attributes { + foo = #test.decimal_shape<0x3x7> +} + +// ----- + +func.func private @hexdecimalInteger() attributes { +// expected-error @below {{expected an integer}} + sdg = #test.decimal_shape<1x0xb> +}