Skip to content

Commit

Permalink
Support for uints in constant op
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Jan 15, 2025
1 parent d1a5e78 commit 876d169
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
19 changes: 14 additions & 5 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,24 @@ class StableHLOToTTIRConstantOpConversionPattern
return rebuildValueAttr<bool>(valueAttr, 1);
}
case 8: {
return rebuildValueAttr<int8_t>(valueAttr, 8);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint8_t>(valueAttr, 8)
: rebuildValueAttr<int8_t>(valueAttr, 8);
}
case 16: {
return rebuildValueAttr<int16_t>(valueAttr, 16);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint16_t>(valueAttr, 16)
: rebuildValueAttr<int16_t>(valueAttr, 16);
}
case 32: {
return rebuildValueAttr<int32_t>(valueAttr, 32);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint32_t>(valueAttr, 32)
: rebuildValueAttr<int32_t>(valueAttr, 32);
}
case 64: {
return rebuildValueAttr<int64_t>(valueAttr, 32);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint64_t>(valueAttr, 32)
: rebuildValueAttr<int64_t>(valueAttr, 32);
}
default: {
assert(false && "Unsupported integer type.");
Expand Down Expand Up @@ -331,7 +339,8 @@ class StableHLOToTTIRConstantOpConversionPattern

// Extract the values (using the given ElementType) and create new data
// structure. This is used to convert scalars (of type boolean, int8, int16,
// int32, and int64) and tensors (of type boolean and int64).
// int32, int64, uint8, uint16, uint32, uint64) and tensors (of type boolean
// and int64).
template <typename ElementType>
mlir::ElementsAttr rebuildValueAttr(mlir::ElementsAttr valueAttr,
size_t bitWidth) const {
Expand Down
22 changes: 16 additions & 6 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ class ConstantOpConversionPattern
Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
float fillValue =
valueAttr.getElementType().isInteger()
? getIntegerValue(valueAttr)
? getFloatFromIntegerValue(valueAttr)
: valueAttr.getSplatValue<mlir::APFloat>().convertToFloat();

::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue);
Expand Down Expand Up @@ -743,19 +743,29 @@ class ConstantOpConversionPattern
return success();
}

float getIntegerValue(mlir::ElementsAttr valueAttr) const {
float getFloatFromIntegerValue(mlir::ElementsAttr valueAttr) const {
size_t bitWidth = valueAttr.getElementType().getIntOrFloatBitWidth();
Type elementType = valueAttr.getElementType();

switch (bitWidth) {
case 1:
return static_cast<float>(valueAttr.getSplatValue<bool>());
case 8:
return static_cast<float>(valueAttr.getSplatValue<int8_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint8_t>())
: static_cast<float>(valueAttr.getSplatValue<int8_t>());
case 16:
return static_cast<float>(valueAttr.getSplatValue<int16_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint16_t>())
: static_cast<float>(valueAttr.getSplatValue<int16_t>());
case 32:
return static_cast<float>(valueAttr.getSplatValue<int>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint32_t>())
: static_cast<float>(valueAttr.getSplatValue<int32_t>());
case 64:
return static_cast<float>(valueAttr.getSplatValue<int64_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint64_t>())
: static_cast<float>(valueAttr.getSplatValue<int64_t>());
}
assert(false && "Unsupported integer type.");
}
Expand Down

0 comments on commit 876d169

Please sign in to comment.