diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 40574741da..642039e53d 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -291,16 +291,24 @@ class StableHLOToTTIRConstantOpConversionPattern return rebuildValueAttr(valueAttr, 1); } case 8: { - return rebuildValueAttr(valueAttr, 8); + return elementType.isUnsignedInteger() + ? rebuildValueAttr(valueAttr, 8) + : rebuildValueAttr(valueAttr, 8); } case 16: { - return rebuildValueAttr(valueAttr, 16); + return elementType.isUnsignedInteger() + ? rebuildValueAttr(valueAttr, 16) + : rebuildValueAttr(valueAttr, 16); } case 32: { - return rebuildValueAttr(valueAttr, 32); + return elementType.isUnsignedInteger() + ? rebuildValueAttr(valueAttr, 32) + : rebuildValueAttr(valueAttr, 32); } case 64: { - return rebuildValueAttr(valueAttr, 32); + return elementType.isUnsignedInteger() + ? rebuildValueAttr(valueAttr, 32) + : rebuildValueAttr(valueAttr, 32); } default: { assert(false && "Unsupported integer type."); @@ -331,7 +339,8 @@ class StableHLOToTTIRConstantOpConversionPattern // Extract the values (using the given ElementType) and create new data // structure. This is used to convert scalars (of type boolean, int8, int16, - // int32, and int64) and tensors (of type boolean and int64). + // int32, int64, uint8, uint16, uint32, uint64) and tensors (of type boolean + // and int64). template mlir::ElementsAttr rebuildValueAttr(mlir::ElementsAttr valueAttr, size_t bitWidth) const { diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 2e84eb3471..dbd1e17e5c 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -713,7 +713,7 @@ class ConstantOpConversionPattern Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op); float fillValue = valueAttr.getElementType().isInteger() - ? getIntegerValue(valueAttr) + ? getFloatFromIntegerValue(valueAttr) : valueAttr.getSplatValue().convertToFloat(); ::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue); @@ -743,19 +743,29 @@ class ConstantOpConversionPattern return success(); } - float getIntegerValue(mlir::ElementsAttr valueAttr) const { + float getFloatFromIntegerValue(mlir::ElementsAttr valueAttr) const { size_t bitWidth = valueAttr.getElementType().getIntOrFloatBitWidth(); + Type elementType = valueAttr.getElementType(); + switch (bitWidth) { case 1: return static_cast(valueAttr.getSplatValue()); case 8: - return static_cast(valueAttr.getSplatValue()); + return elementType.isUnsignedInteger() + ? static_cast(valueAttr.getSplatValue()) + : static_cast(valueAttr.getSplatValue()); case 16: - return static_cast(valueAttr.getSplatValue()); + return elementType.isUnsignedInteger() + ? static_cast(valueAttr.getSplatValue()) + : static_cast(valueAttr.getSplatValue()); case 32: - return static_cast(valueAttr.getSplatValue()); + return elementType.isUnsignedInteger() + ? static_cast(valueAttr.getSplatValue()) + : static_cast(valueAttr.getSplatValue()); case 64: - return static_cast(valueAttr.getSplatValue()); + return elementType.isUnsignedInteger() + ? static_cast(valueAttr.getSplatValue()) + : static_cast(valueAttr.getSplatValue()); } assert(false && "Unsupported integer type."); }