Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -714,16 +714,27 @@ class AsmParser {
return *parseResult;
}

/// Parse a decimal integer value from the stream.
template <typename IntT>
ParseResult parseDecimalInteger(IntT &result) {
auto loc = getCurrentLocation();
OptionalParseResult parseResult = parseOptionalDecimalInteger(result);
if (!parseResult.has_value())
return emitError(loc, "expected decimal integer value");
return *parseResult;
}

/// Parse an optional integer value from the stream.
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;

template <typename IntT>
OptionalParseResult parseOptionalInteger(IntT &result) {
private:
template <typename IntT, typename ParseFn>
OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
ParseFn &&parseFn) {
auto loc = getCurrentLocation();

// Parse the unsigned variant.
APInt uintResult;
OptionalParseResult parseResult = parseOptionalInteger(uintResult);
OptionalParseResult parseResult = parseFn(uintResult);
if (!parseResult.has_value() || failed(*parseResult))
return parseResult;

Expand All @@ -737,6 +748,20 @@ class AsmParser {
return success();
}

public:
template <typename IntT>
OptionalParseResult parseOptionalInteger(IntT &result) {
return parseOptionalIntegerAndCheck(
result, [&](APInt &result) { return parseOptionalInteger(result); });
}

template <typename IntT>
OptionalParseResult parseOptionalDecimalInteger(IntT &result) {
return parseOptionalIntegerAndCheck(result, [&](APInt &result) {
return parseOptionalDecimalInteger(result);
});
}

/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList.
enum class Delimiter {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,11 @@ class AsmParserImpl : public BaseT {
return parser.parseOptionalInteger(result);
}

/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalDecimalInteger(APInt &result) override {
return parser.parseOptionalDecimalInteger(result);
}

/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
Expand Down
40 changes: 40 additions & 0 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Alignment.h"
Expand Down Expand Up @@ -307,6 +308,45 @@ OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
return success();
}

/// Parse an optional integer value only in decimal format from the stream.
OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
Token curToken = getToken();
if (curToken.isNot(Token::integer, Token::minus)) {
return std::nullopt;
}

bool negative = consumeIf(Token::minus);
Token curTok = getToken();
if (parseToken(Token::integer, "expected integer value")) {
return failure();
}

StringRef spelling = curTok.getSpelling();
// If the integer is in hexadecimal return only the 0. The lexer has already
// moved past the entire hexidecimal encoded integer so we reset the lex
// pointer to just past the 0 we actualy want to consume.
if (spelling[0] == '0' && spelling.size() > 1 &&
llvm::toLower(spelling[1]) == 'x') {
result = 0;
state.lex.resetPointer(spelling.data() + 1);
consumeToken();
return success();
}

if (spelling.getAsInteger(10, result))
return emitError(curTok.getLoc(), "integer value too large");

// Make sure we have a zero at the top so we return the right signedness.
if (result.isNegative())
result = result.zext(result.getBitWidth() + 1);

// Process the negative sign if present.
if (negative)
result.negate();

return success();
}

/// Parse a floating point value from an integer literal token.
ParseResult Parser::parseFloatFromIntegerLiteral(
std::optional<APFloat> &result, const Token &tok, bool isNegative,
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/AsmParser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ class Parser {
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result);

/// Parse an optional integer value only in decimal format from the stream.
OptionalParseResult parseOptionalDecimalInteger(APInt &result);

/// Parse a floating point value from an integer literal token.
ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
const Token &tok, bool isNegative,
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
let mnemonic = "attr_with_trait";
}

// An attribute of a list of decimal formatted integers in similar format to shapes.
def TestDecimalShapeAttr : Test_Attr<"TestDecimalShape"> {
let mnemonic = "decimal_shape";

let parameters = (ins ArrayRefParameter<"int64_t">:$shape);

let hasCustomAssemblyFormat = 1;
}

// Test support for ElementsAttrInterface.
def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> {
let mnemonic = "i64_elements";
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

#include "TestAttributes.h"
#include "TestDialect.h"
#include "TestTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Hashing.h"
Expand Down Expand Up @@ -63,6 +66,39 @@ void CompoundAAttr::print(AsmPrinter &printer) const {
// CompoundAAttr
//===----------------------------------------------------------------------===//

Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess()){
return Attribute();
}
SmallVector<int64_t> shape;
if (parser.parseOptionalGreater()) {
auto parseDecimal = [&]() {
shape.emplace_back();
auto parseResult = parser.parseOptionalDecimalInteger(shape.back());
if (!parseResult.has_value() || failed(*parseResult)) {
parser.emitError(parser.getCurrentLocation()) << "expected an integer";
return failure();
}
return success();
};
if (failed(parseDecimal())) {
return Attribute();
}
while (failed(parser.parseOptionalGreater())) {
if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) {
return Attribute();
}
}
}
return get(parser.getContext(), shape);
}

void TestDecimalShapeAttr::print(AsmPrinter &printer) const {
printer << "<";
llvm::interleave(getShape(), printer, "x");
printer << ">";
}

Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
SmallVector<uint64_t> elements;
if (parser.parseLess() || parser.parseLSquare())
Expand Down
27 changes: 26 additions & 1 deletion mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func private @compoundA()
// CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]>
Expand All @@ -19,3 +19,28 @@ func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qu
func.func private @overriddenAttr() attributes {
foo = #test.override_builder<5>
}

// CHECK-LABEL: @decimalIntegerShapeEmpty
// CHECK-SAME: foo = #test.decimal_shape<>
func.func private @decimalIntegerShapeEmpty() attributes {
foo = #test.decimal_shape<>
}

// CHECK-LABEL: @decimalIntegerShape
// CHECK-SAME: foo = #test.decimal_shape<5>
func.func private @decimalIntegerShape() attributes {
foo = #test.decimal_shape<5>
}

// CHECK-LABEL: @decimalIntegerShapeMultiple
// CHECK-SAME: foo = #test.decimal_shape<0x3x7>
func.func private @decimalIntegerShapeMultiple() attributes {
foo = #test.decimal_shape<0x3x7>
}

// -----

func.func private @hexdecimalInteger() attributes {
// expected-error @below {{expected an integer}}
sdg = #test.decimal_shape<1x0xb>
}