Skip to content

Commit

Permalink
Support for uints in constant op
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Jan 15, 2025
1 parent d1a5e78 commit 9ee8657
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
16 changes: 12 additions & 4 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,24 @@ class StableHLOToTTIRConstantOpConversionPattern
return rebuildValueAttr<bool>(valueAttr, 1);
}
case 8: {
return rebuildValueAttr<int8_t>(valueAttr, 8);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint8_t>(valueAttr, 8)
: rebuildValueAttr<int8_t>(valueAttr, 8);
}
case 16: {
return rebuildValueAttr<int16_t>(valueAttr, 16);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint16_t>(valueAttr, 16)
: rebuildValueAttr<int16_t>(valueAttr, 16);
}
case 32: {
return rebuildValueAttr<int32_t>(valueAttr, 32);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint32_t>(valueAttr, 32)
: rebuildValueAttr<int32_t>(valueAttr, 32);
}
case 64: {
return rebuildValueAttr<int64_t>(valueAttr, 32);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint64_t>(valueAttr, 32)
: rebuildValueAttr<int64_t>(valueAttr, 32);
}
default: {
assert(false && "Unsupported integer type.");
Expand Down
18 changes: 14 additions & 4 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,17 +745,27 @@ class ConstantOpConversionPattern

float getIntegerValue(mlir::ElementsAttr valueAttr) const {
size_t bitWidth = valueAttr.getElementType().getIntOrFloatBitWidth();
Type elementType = valueAttr.getElementType();

switch (bitWidth) {
case 1:
return static_cast<float>(valueAttr.getSplatValue<bool>());
case 8:
return static_cast<float>(valueAttr.getSplatValue<int8_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint8_t>())
: static_cast<float>(valueAttr.getSplatValue<int8_t>());
case 16:
return static_cast<float>(valueAttr.getSplatValue<int16_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint16_t>())
: static_cast<float>(valueAttr.getSplatValue<int16_t>());
case 32:
return static_cast<float>(valueAttr.getSplatValue<int>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint32_t>())
: static_cast<float>(valueAttr.getSplatValue<int32_t>());
case 64:
return static_cast<float>(valueAttr.getSplatValue<int64_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint64_t>())
: static_cast<float>(valueAttr.getSplatValue<int64_t>());
}
assert(false && "Unsupported integer type.");
}
Expand Down

0 comments on commit 9ee8657

Please sign in to comment.