Skip to content

Commit

Permalink
Pair of integers util function
Browse files Browse the repository at this point in the history
  • Loading branch information
azecevicTT committed Jan 13, 2025
1 parent 5b4122a commit 940ed69
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions include/ttmlir/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Error.h"

#include <cstdint>

Expand Down Expand Up @@ -202,6 +203,51 @@ getBroadcastDimensions(llvm::ArrayRef<int64_t> inputShape,
return broadcastShape;
}

// For a given llvm::APInt value, returns it as a C++ integer type T.
template <typename T>
inline T integerAs(const llvm::APInt &value) {
if constexpr (std::is_signed_v<T>) {
return static_cast<T>(value.getSExtValue());
} else {
static_assert(std::is_unsigned_v<T>,
"T must be signed or unsigned integer type");
return static_cast<T>(value.getZExtValue());
}
}

// For a given mlir::Attribute attr, returns a pair of integers of type
// ReturnTy. If attr is an IntegerAttr, it's interpreted as a (value(attr),
// value(attr)) pair of values, where value(attr) is of type ScalarTy. If attr
// is a DenseArrayAttr<VectorElementTy> of size 2, it's interpreted as a
// (attr[0], attr[1]) pair of values. Otherwise, returns an error message.
template <typename ScalarTy, typename VectorElementTy = ScalarTy,
typename ReturnTy = ScalarTy>
inline llvm::Expected<std::pair<ReturnTy, ReturnTy>>
getPairOfInteger(mlir::Attribute attr) {
ReturnTy x{};
ReturnTy y{};
// If attr is IntgerAttr, it's interpreted as a (attr, attr) pair of values.
if (auto value = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
x = y = integerAs<ScalarTy>(value.getValue());
// If attr is DenseArrayAttr, it's interpreted as a (attr[0], attr[1]) pair
// of values if it has size 2.
} else if (auto tuple = mlir::dyn_cast<
::mlir::detail::DenseArrayAttrImpl<VectorElementTy>>(attr);
tuple.size() == 2) {
x = tuple[0];
y = tuple[1];
// Otherwise, it's an error.
} else if (tuple) {
return llvm::createStringError(
"Expected integer or pair of integers, got tuple of size %lu",
tuple.size());
} else {
return llvm::createStringError("Unexpected attribute type");
}

return std::make_pair(x, y);
}

} // namespace ttmlir::utils

#endif

0 comments on commit 940ed69

Please sign in to comment.