Skip to content

Commit

Permalink
Support for uints in constant op
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Jan 16, 2025
1 parent d129a9e commit 4e199b7
Show file tree
Hide file tree
Showing 10 changed files with 343 additions and 20 deletions.
13 changes: 10 additions & 3 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,19 @@ class StablehloTypeConverter : public TypeConverter {
changed = true;
} else if (bitWidth == 64) {
// Convert 64 bit integer element type to 32 bit integer.
if (isa<IntegerType>(type.getElementType())) {
elementType = IntegerType::get(context, 32);
// If element is unsigned, we explicitly assign it Unsigned semantics to
// it (like `ui32`). Otherwise, we don't explicitly use Signed semantics
// (like `si32`), but rather Signless (like `i32`) which is the default.
if (isa<IntegerType>(elementType)) {
elementType = IntegerType::get(
context, 32,
elementType.isUnsignedInteger()
? IntegerType::SignednessSemantics::Unsigned
: IntegerType::SignednessSemantics::Signless);
changed = true;
}
// Convert 64 bit float element type to 32 bit float.
else if (isa<FloatType>(type.getElementType())) {
else if (isa<FloatType>(elementType)) {
elementType = FloatType::getF32(context);
changed = true;
}
Expand Down
19 changes: 14 additions & 5 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,24 @@ class StableHLOToTTIRConstantOpConversionPattern
return rebuildValueAttr<bool>(valueAttr, 1);
}
case 8: {
return rebuildValueAttr<int8_t>(valueAttr, 8);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint8_t>(valueAttr, 8)
: rebuildValueAttr<int8_t>(valueAttr, 8);
}
case 16: {
return rebuildValueAttr<int16_t>(valueAttr, 16);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint16_t>(valueAttr, 16)
: rebuildValueAttr<int16_t>(valueAttr, 16);
}
case 32: {
return rebuildValueAttr<int32_t>(valueAttr, 32);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint32_t>(valueAttr, 32)
: rebuildValueAttr<int32_t>(valueAttr, 32);
}
case 64: {
return rebuildValueAttr<int64_t>(valueAttr, 32);
return elementType.isUnsignedInteger()
? rebuildValueAttr<uint64_t>(valueAttr, 32)
: rebuildValueAttr<int64_t>(valueAttr, 32);
}
default: {
assert(false && "Unsupported integer type.");
Expand Down Expand Up @@ -331,7 +339,8 @@ class StableHLOToTTIRConstantOpConversionPattern

// Extract the values (using the given ElementType) and create new data
// structure. This is used to convert scalars (of type boolean, int8, int16,
// int32, and int64) and tensors (of type boolean and int64).
// int32, int64, uint8, uint16, uint32, uint64) and tensors (of type boolean
// and int64).
template <typename ElementType>
mlir::ElementsAttr rebuildValueAttr(mlir::ElementsAttr valueAttr,
size_t bitWidth) const {
Expand Down
22 changes: 16 additions & 6 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ class ConstantOpConversionPattern
Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
float fillValue =
valueAttr.getElementType().isInteger()
? getIntegerValue(valueAttr)
? getFloatFromIntegerValue(valueAttr)
: valueAttr.getSplatValue<mlir::APFloat>().convertToFloat();

::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue);
Expand Down Expand Up @@ -743,19 +743,29 @@ class ConstantOpConversionPattern
return success();
}

float getIntegerValue(mlir::ElementsAttr valueAttr) const {
float getFloatFromIntegerValue(mlir::ElementsAttr valueAttr) const {
size_t bitWidth = valueAttr.getElementType().getIntOrFloatBitWidth();
Type elementType = valueAttr.getElementType();

switch (bitWidth) {
case 1:
return static_cast<float>(valueAttr.getSplatValue<bool>());
case 8:
return static_cast<float>(valueAttr.getSplatValue<int8_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint8_t>())
: static_cast<float>(valueAttr.getSplatValue<int8_t>());
case 16:
return static_cast<float>(valueAttr.getSplatValue<int16_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint16_t>())
: static_cast<float>(valueAttr.getSplatValue<int16_t>());
case 32:
return static_cast<float>(valueAttr.getSplatValue<int>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint32_t>())
: static_cast<float>(valueAttr.getSplatValue<int32_t>());
case 64:
return static_cast<float>(valueAttr.getSplatValue<int64_t>());
return elementType.isUnsignedInteger()
? static_cast<float>(valueAttr.getSplatValue<uint64_t>())
: static_cast<float>(valueAttr.getSplatValue<int64_t>());
}
assert(false && "Unsupported integer type.");
}
Expand Down
61 changes: 56 additions & 5 deletions test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir
Original file line number Diff line number Diff line change
@@ -1,15 +1,66 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_constant attributes {} {
func.func public @test_splat() -> tensor<64xf32> {
%0 = arith.constant dense<0.3> : tensor<64xf32>
// CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
func.func public @test_scalar_float() -> tensor<f32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = arith.constant dense<3.0> : tensor<f32>
// CHECK: return %{{[0-9]+}} : tensor<1xf32>
return %0 : tensor<f32>
}

func.func public @test_splat_float() -> tensor<64xf32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32>
%0 = arith.constant dense<3.0> : tensor<64xf32>
// CHECK: return %{{[0-9]+}} : tensor<64xf32>
return %0 : tensor<64xf32>
}

func.func public @test_multiple() -> tensor<2x2xf32> {
func.func public @test_multiple_float() -> tensor<2x2xf32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
%0 = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
// CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
// CHECK: return %{{[0-9]+}} : tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

func.func public @test_scalar_int() -> tensor<i32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
%0 = arith.constant dense<3> : tensor<i32>
// CHECK: return %{{[0-9]+}} : tensor<1xi32>
return %0 : tensor<i32>
}

func.func public @test_splat_int() -> tensor<64xi32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi32>}> : () -> tensor<64xi32>
%0 = arith.constant dense<3> : tensor<64xi32>
// CHECK: return %{{[0-9]+}} : tensor<64xi32>
return %0 : tensor<64xi32>
}

func.func public @test_multiple_int() -> tensor<2x2xi32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
%0 = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
// CHECK: return %{{[0-9]+}} : tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}

func.func public @test_scalar_uint() -> tensor<ui32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui32>}> : () -> tensor<1xui32>
%0 = arith.constant dense<3> : tensor<ui32>
// CHECK: return %{{[0-9]+}} : tensor<1xui32>
return %0 : tensor<ui32>
}

func.func public @test_splat_uint() -> tensor<64xui32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui32>}> : () -> tensor<64xui32>
%0 = arith.constant dense<3> : tensor<64xui32>
// CHECK: return %{{[0-9]+}} : tensor<64xui32>
return %0 : tensor<64xui32>
}

func.func public @test_multiple_uint() -> tensor<2x2xui32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui32>}> : () -> tensor<2x2xui32>
%0 = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui32>
// CHECK: return %{{[0-9]+}} : tensor<2x2xui32>
return %0 : tensor<2x2xui32>
}
}
92 changes: 92 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,96 @@ module @jit_constant attributes {} {
// CHECK: return %{{[0-9]+}} : tensor<2x2xi32>
return %0 : tensor<2x2xi64>
}

func.func public @test_uint8_scalar() -> tensor<ui8> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui8>}> : () -> tensor<1xui8>
%0 = stablehlo.constant dense<3> : tensor<ui8>
// CHECK: return %{{[0-9]+}} : tensor<1xui8>
return %0 : tensor<ui8>
}

func.func public @test_uint8_splat() -> tensor<64xui8> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui8>}> : () -> tensor<64xui8>
%0 = stablehlo.constant dense<3> : tensor<64xui8>
// CHECK: return %{{[0-9]+}} : tensor<64xui8>
return %0 : tensor<64xui8>
}

func.func public @test_uint8_multiple() -> tensor<2x2xui8> {
// The ugly regex after `dense` is necessary because double square opening
// brackets indicate substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui8>}> : () -> tensor<2x2xui8>
%0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui8>
// CHECK: return %{{[0-9]+}} : tensor<2x2xui8>
return %0 : tensor<2x2xui8>
}

func.func public @test_uint16_scalar() -> tensor<ui16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui16>}> : () -> tensor<1xui16>
%0 = stablehlo.constant dense<3> : tensor<ui16>
// CHECK: return %{{[0-9]+}} : tensor<1xui16>
return %0 : tensor<ui16>
}

func.func public @test_uint16_splat() -> tensor<64xui16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui16>}> : () -> tensor<64xui16>
%0 = stablehlo.constant dense<3> : tensor<64xui16>
// CHECK: return %{{[0-9]+}} : tensor<64xui16>
return %0 : tensor<64xui16>
}

func.func public @test_uint16_multiple() -> tensor<2x2xui16> {
// The ugly regex after `dense` is necessary because double square opening
// brackets indicate substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui16>}> : () -> tensor<2x2xui16>
%0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui16>
// CHECK: return %{{[0-9]+}} : tensor<2x2xui16>
return %0 : tensor<2x2xui16>
}

func.func public @test_uint32_scalar() -> tensor<ui32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui32>}> : () -> tensor<1xui32>
%0 = stablehlo.constant dense<3> : tensor<ui32>
// CHECK: return %{{[0-9]+}} : tensor<1xui32>
return %0 : tensor<ui32>
}

func.func public @test_uint32_splat() -> tensor<64xui32> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui32>}> : () -> tensor<64xui32>
%0 = stablehlo.constant dense<3> : tensor<64xui32>
// CHECK: return %{{[0-9]+}} : tensor<64xui32>
return %0 : tensor<64xui32>
}

func.func public @test_uint32_multiple() -> tensor<2x2xui32> {
// The ugly regex after `dense` is necessary because double square opening
// brackets indicate substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui32>}> : () -> tensor<2x2xui32>
%0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui32>
// CHECK: return %{{[0-9]+}} : tensor<2x2xui32>
return %0 : tensor<2x2xui32>
}

func.func public @test_uint64_scalar() -> tensor<ui64> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui32>}> : () -> tensor<1xui32>
%0 = stablehlo.constant dense<3> : tensor<ui64>
// CHECK: return %{{[0-9]+}} : tensor<1xui32>
return %0 : tensor<ui64>
}

func.func public @test_uint64_splat() -> tensor<64xui64> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui32>}> : () -> tensor<64xui32>
%0 = stablehlo.constant dense<3> : tensor<64xui64>
// CHECK: return %{{[0-9]+}} : tensor<64xui32>
return %0 : tensor<64xui64>
}

func.func public @test_uint64_multiple() -> tensor<2x2xui64> {
// The ugly regex after `dense` is necessary because double square opening
// brackets indicate substitution block in FileCheck syntax.
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui32>}> : () -> tensor<2x2xui32>
%0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui64>
// CHECK: return %{{[0-9]+}} : tensor<2x2xui32>
return %0 : tensor<2x2xui64>
}
}
14 changes: 14 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_constant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ module attributes {} {
return %0 : tensor<64x128xi32>
}

func.func @test_empty_uint() -> tensor<64x128xui32> {
%0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xui32>}> : () -> tensor<64x128xui32>
// CHECK: %{{[0-9]+}} = "ttnn.full"
return %0 : tensor<64x128xui32>
}

func.func @test_empty_bfloat16() -> tensor<64x128xbf16> {
%0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16>
// CHECK: %{{[0-9]+}} = "ttnn.full"
Expand Down Expand Up @@ -54,6 +60,14 @@ module attributes {} {
return %0 : tensor<64x128xi32>
}

func.func @test_full_uint() -> tensor<64x128xui32> {
// CHECK: %{{[0-9]+}} = "ttnn.full"
// CHECK-SAME: fillValue = 1.000000e+00 : f32
// CHECK-SAME: tensor<64x128xui32
%0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xui32>}> : () -> tensor<64x128xui32>
return %0 : tensor<64x128xui32>
}

func.func @test_full_bfloat16() -> tensor<64x128xbf16> {
// CHECK: %{{[0-9]+}} = "ttnn.full"
// CHECK-SAME: fillValue = 1.000000e+00 : f32
Expand Down
43 changes: 43 additions & 0 deletions test/ttmlir/Silicon/StableHLO/Constant/constant_ui16.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

module @jit_constant attributes {} {
func.func public @test_uint16_scalar() -> tensor<ui16> {
// CHECK-LABEL: func.func public @test_uint16_scalar
// CHECK: ttnn.full
// CHECK-SAME: fillValue = 3.000000e+00 : f32
// CHECK-SAME: -> tensor<1xui16
%0 = stablehlo.constant dense<3> : tensor<ui16>
return %0 : tensor<ui16>
}

func.func public @test_uint16_scalar_empty() -> tensor<ui16> {
// CHECK-LABEL: func.func public @test_uint16_scalar_empty
// CHECK: ttnn.full
// CHECK-SAME: -> tensor<1xui16
%0 = stablehlo.constant dense<0> : tensor<ui16>
return %0 : tensor<ui16>
}

func.func public @test_uint16_empty() -> tensor<64x128xui16> {
// CHECK-LABEL: func.func public @test_uint16_empty
// CHECK: ttnn.full
// CHECK-SAME: -> tensor<64x128xui16
%0 = stablehlo.constant dense<0> : tensor<64x128xui16>
return %0 : tensor<64x128xui16>
}

func.func public @test_uint16_splat() -> tensor<64x128xui16> {
// CHECK-LABEL: func.func public @test_uint16_splat
// CHECK: ttnn.full
// CHECK-SAME: fillValue = 3.000000e+00 : f32
// CHECK-SAME: -> tensor<64x128xui16
%0 = stablehlo.constant dense<3> : tensor<64x128xui16>
return %0 : tensor<64x128xui16>
}
}
43 changes: 43 additions & 0 deletions test/ttmlir/Silicon/StableHLO/Constant/constant_ui32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

module @jit_constant attributes {} {
func.func public @test_uint32_scalar() -> tensor<ui32> {
// CHECK-LABEL: func.func public @test_uint32_scalar
// CHECK: ttnn.full
// CHECK-SAME: fillValue = 3.000000e+00 : f32
// CHECK-SAME: -> tensor<1xui32
%0 = stablehlo.constant dense<3> : tensor<ui32>
return %0 : tensor<ui32>
}

func.func public @test_uint32_scalar_empty() -> tensor<ui32> {
// CHECK-LABEL: func.func public @test_uint32_scalar_empty
// CHECK: ttnn.full
// CHECK-SAME: -> tensor<1xui32
%0 = stablehlo.constant dense<0> : tensor<ui32>
return %0 : tensor<ui32>
}

func.func public @test_uint32_empty() -> tensor<64x128xui32> {
// CHECK-LABEL: func.func public @test_uint32_empty
// CHECK: ttnn.full
// CHECK-SAME: -> tensor<64x128xui32
%0 = stablehlo.constant dense<0> : tensor<64x128xui32>
return %0 : tensor<64x128xui32>
}

func.func public @test_uint32_splat() -> tensor<64x128xui32> {
// CHECK-LABEL: func.func public @test_uint32_splat
// CHECK: ttnn.full
// CHECK-SAME: fillValue = 3.000000e+00 : f32
// CHECK-SAME: -> tensor<64x128xui32
%0 = stablehlo.constant dense<3> : tensor<64x128xui32>
return %0 : tensor<64x128xui32>
}
}
Loading

0 comments on commit 4e199b7

Please sign in to comment.