diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h index c8a40b56f6..ebc3ba015b 100644 --- a/include/ttmlir/Utils.h +++ b/include/ttmlir/Utils.h @@ -11,6 +11,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Error.h" #include @@ -202,6 +203,51 @@ getBroadcastDimensions(llvm::ArrayRef inputShape, return broadcastShape; } +// For a given llvm::APInt value, returns it as a C++ integer type T. +template +inline T integerAs(const llvm::APInt &value) { + if constexpr (std::is_signed_v) { + return static_cast(value.getSExtValue()); + } else { + static_assert(std::is_unsigned_v, + "T must be signed or unsigned integer type"); + return static_cast(value.getZExtValue()); + } +} + +// For a given mlir::Attribute attr, returns a pair of integers of type +// ReturnTy. If attr is an IntegerAttr, it's interpreted as a (value(attr), +// value(attr)) pair of values, where value(attr) is of type ScalarTy. If attr +// is a DenseArrayAttr of size 2, it's interpreted as a +// (attr[0], attr[1]) pair of values. Otherwise, returns an error message. +template +inline llvm::Expected> +getPairOfInteger(mlir::Attribute attr) { + ReturnTy x{}; + ReturnTy y{}; + // If attr is IntgerAttr, it's interpreted as a (attr, attr) pair of values. + if (auto value = mlir::dyn_cast(attr)) { + x = y = integerAs(value.getValue()); + // If attr is DenseArrayAttr, it's interpreted as a (attr[0], attr[1]) pair + // of values if it has size 2. + } else if (auto tuple = mlir::dyn_cast< + ::mlir::detail::DenseArrayAttrImpl>(attr); + tuple.size() == 2) { + x = tuple[0]; + y = tuple[1]; + // Otherwise, it's an error. + } else if (tuple) { + return llvm::createStringError( + "Expected integer or pair of integers, got tuple of size %lu", + tuple.size()); + } else { + return llvm::createStringError("Unexpected attribute type"); + } + + return std::make_pair(x, y); +} + } // namespace ttmlir::utils #endif