diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp index 8fc3538421..6a6e7e9d52 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp @@ -58,12 +58,19 @@ class StablehloTypeConverter : public TypeConverter { changed = true; } else if (bitWidth == 64) { // Convert 64 bit integer element type to 32 bit integer. - if (isa(type.getElementType())) { - elementType = IntegerType::get(context, 32); + // If element is unsigned, we explicitly assign it Unsigned semantics to + // it (like `ui32`). Otherwise, we don't explicitly use Signed semantics + // (like `si32`), but rather Signless (like `i32`) which is the default. + if (isa(elementType)) { + elementType = IntegerType::get( + context, 32, + elementType.isUnsignedInteger() + ? IntegerType::SignednessSemantics::Unsigned + : IntegerType::SignednessSemantics::Signless); changed = true; } // Convert 64 bit float element type to 32 bit float. - else if (isa(type.getElementType())) { + else if (isa(elementType)) { elementType = FloatType::getF32(context); changed = true; } diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 40574741da..642039e53d 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -291,16 +291,24 @@ class StableHLOToTTIRConstantOpConversionPattern return rebuildValueAttr(valueAttr, 1); } case 8: { - return rebuildValueAttr(valueAttr, 8); + return elementType.isUnsignedInteger() + ? rebuildValueAttr(valueAttr, 8) + : rebuildValueAttr(valueAttr, 8); } case 16: { - return rebuildValueAttr(valueAttr, 16); + return elementType.isUnsignedInteger() + ? rebuildValueAttr(valueAttr, 16) + : rebuildValueAttr(valueAttr, 16); } case 32: { - return rebuildValueAttr(valueAttr, 32); + return elementType.isUnsignedInteger() + ? rebuildValueAttr(valueAttr, 32) + : rebuildValueAttr(valueAttr, 32); } case 64: { - return rebuildValueAttr(valueAttr, 32); + return elementType.isUnsignedInteger() + ? rebuildValueAttr(valueAttr, 32) + : rebuildValueAttr(valueAttr, 32); } default: { assert(false && "Unsupported integer type."); @@ -331,7 +339,8 @@ class StableHLOToTTIRConstantOpConversionPattern // Extract the values (using the given ElementType) and create new data // structure. This is used to convert scalars (of type boolean, int8, int16, - // int32, and int64) and tensors (of type boolean and int64). + // int32, int64, uint8, uint16, uint32, uint64) and tensors (of type boolean + // and int64). template mlir::ElementsAttr rebuildValueAttr(mlir::ElementsAttr valueAttr, size_t bitWidth) const { diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 2e84eb3471..dbd1e17e5c 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -713,7 +713,7 @@ class ConstantOpConversionPattern Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op); float fillValue = valueAttr.getElementType().isInteger() - ? getIntegerValue(valueAttr) + ? getFloatFromIntegerValue(valueAttr) : valueAttr.getSplatValue().convertToFloat(); ::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue); @@ -743,19 +743,29 @@ class ConstantOpConversionPattern return success(); } - float getIntegerValue(mlir::ElementsAttr valueAttr) const { + float getFloatFromIntegerValue(mlir::ElementsAttr valueAttr) const { size_t bitWidth = valueAttr.getElementType().getIntOrFloatBitWidth(); + Type elementType = valueAttr.getElementType(); + switch (bitWidth) { case 1: return static_cast(valueAttr.getSplatValue()); case 8: - return static_cast(valueAttr.getSplatValue()); + return elementType.isUnsignedInteger() + ? static_cast(valueAttr.getSplatValue()) + : static_cast(valueAttr.getSplatValue()); case 16: - return static_cast(valueAttr.getSplatValue()); + return elementType.isUnsignedInteger() + ? static_cast(valueAttr.getSplatValue()) + : static_cast(valueAttr.getSplatValue()); case 32: - return static_cast(valueAttr.getSplatValue()); + return elementType.isUnsignedInteger() + ? static_cast(valueAttr.getSplatValue()) + : static_cast(valueAttr.getSplatValue()); case 64: - return static_cast(valueAttr.getSplatValue()); + return elementType.isUnsignedInteger() + ? static_cast(valueAttr.getSplatValue()) + : static_cast(valueAttr.getSplatValue()); } assert(false && "Unsupported integer type."); } diff --git a/test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir b/test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir index 0cbe0385d0..b388daeb15 100644 --- a/test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir +++ b/test/ttmlir/Conversion/ArithToStableHLO/constant_op.mlir @@ -1,15 +1,66 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s module @jit_constant attributes {} { - func.func public @test_splat() -> tensor<64xf32> { - %0 = arith.constant dense<0.3> : tensor<64xf32> - // CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]] + func.func public @test_scalar_float() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = arith.constant dense<3.0> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xf32> + return %0 : tensor + } + + func.func public @test_splat_float() -> tensor<64xf32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32> + %0 = arith.constant dense<3.0> : tensor<64xf32> + // CHECK: return %{{[0-9]+}} : tensor<64xf32> return %0 : tensor<64xf32> } - func.func public @test_multiple() -> tensor<2x2xf32> { + func.func public @test_multiple_float() -> tensor<2x2xf32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> %0 = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> - // CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]] + // CHECK: return %{{[0-9]+}} : tensor<2x2xf32> return %0 : tensor<2x2xf32> } + + func.func public @test_scalar_int() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32> + %0 = arith.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xi32> + return %0 : tensor + } + + func.func public @test_splat_int() -> tensor<64xi32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi32>}> : () -> tensor<64xi32> + %0 = arith.constant dense<3> : tensor<64xi32> + // CHECK: return %{{[0-9]+}} : tensor<64xi32> + return %0 : tensor<64xi32> + } + + func.func public @test_multiple_int() -> tensor<2x2xi32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32> + %0 = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> + // CHECK: return %{{[0-9]+}} : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } + + func.func public @test_scalar_uint() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui32>}> : () -> tensor<1xui32> + %0 = arith.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xui32> + return %0 : tensor + } + + func.func public @test_splat_uint() -> tensor<64xui32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui32>}> : () -> tensor<64xui32> + %0 = arith.constant dense<3> : tensor<64xui32> + // CHECK: return %{{[0-9]+}} : tensor<64xui32> + return %0 : tensor<64xui32> + } + + func.func public @test_multiple_uint() -> tensor<2x2xui32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui32>}> : () -> tensor<2x2xui32> + %0 = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui32> + // CHECK: return %{{[0-9]+}} : tensor<2x2xui32> + return %0 : tensor<2x2xui32> + } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir index eb0aa5b951..878d9f1652 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir @@ -214,4 +214,96 @@ module @jit_constant attributes {} { // CHECK: return %{{[0-9]+}} : tensor<2x2xi32> return %0 : tensor<2x2xi64> } + + func.func public @test_uint8_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui8>}> : () -> tensor<1xui8> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xui8> + return %0 : tensor + } + + func.func public @test_uint8_splat() -> tensor<64xui8> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui8>}> : () -> tensor<64xui8> + %0 = stablehlo.constant dense<3> : tensor<64xui8> + // CHECK: return %{{[0-9]+}} : tensor<64xui8> + return %0 : tensor<64xui8> + } + + func.func public @test_uint8_multiple() -> tensor<2x2xui8> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui8>}> : () -> tensor<2x2xui8> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui8> + // CHECK: return %{{[0-9]+}} : tensor<2x2xui8> + return %0 : tensor<2x2xui8> + } + + func.func public @test_uint16_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui16>}> : () -> tensor<1xui16> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xui16> + return %0 : tensor + } + + func.func public @test_uint16_splat() -> tensor<64xui16> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui16>}> : () -> tensor<64xui16> + %0 = stablehlo.constant dense<3> : tensor<64xui16> + // CHECK: return %{{[0-9]+}} : tensor<64xui16> + return %0 : tensor<64xui16> + } + + func.func public @test_uint16_multiple() -> tensor<2x2xui16> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui16>}> : () -> tensor<2x2xui16> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui16> + // CHECK: return %{{[0-9]+}} : tensor<2x2xui16> + return %0 : tensor<2x2xui16> + } + + func.func public @test_uint32_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui32>}> : () -> tensor<1xui32> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xui32> + return %0 : tensor + } + + func.func public @test_uint32_splat() -> tensor<64xui32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui32>}> : () -> tensor<64xui32> + %0 = stablehlo.constant dense<3> : tensor<64xui32> + // CHECK: return %{{[0-9]+}} : tensor<64xui32> + return %0 : tensor<64xui32> + } + + func.func public @test_uint32_multiple() -> tensor<2x2xui32> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui32>}> : () -> tensor<2x2xui32> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui32> + // CHECK: return %{{[0-9]+}} : tensor<2x2xui32> + return %0 : tensor<2x2xui32> + } + + func.func public @test_uint64_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui32>}> : () -> tensor<1xui32> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xui32> + return %0 : tensor + } + + func.func public @test_uint64_splat() -> tensor<64xui64> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui32>}> : () -> tensor<64xui32> + %0 = stablehlo.constant dense<3> : tensor<64xui64> + // CHECK: return %{{[0-9]+}} : tensor<64xui32> + return %0 : tensor<64xui64> + } + + func.func public @test_uint64_multiple() -> tensor<2x2xui64> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui32>}> : () -> tensor<2x2xui32> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui64> + // CHECK: return %{{[0-9]+}} : tensor<2x2xui32> + return %0 : tensor<2x2xui64> + } } diff --git a/test/ttmlir/Dialect/TTNN/simple_constant.mlir b/test/ttmlir/Dialect/TTNN/simple_constant.mlir index 53de9a5ee1..b212164905 100644 --- a/test/ttmlir/Dialect/TTNN/simple_constant.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_constant.mlir @@ -18,6 +18,12 @@ module attributes {} { return %0 : tensor<64x128xi32> } + func.func @test_empty_uint() -> tensor<64x128xui32> { + %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xui32>}> : () -> tensor<64x128xui32> + // CHECK: %{{[0-9]+}} = "ttnn.full" + return %0 : tensor<64x128xui32> + } + func.func @test_empty_bfloat16() -> tensor<64x128xbf16> { %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16> // CHECK: %{{[0-9]+}} = "ttnn.full" @@ -54,6 +60,14 @@ module attributes {} { return %0 : tensor<64x128xi32> } + func.func @test_full_uint() -> tensor<64x128xui32> { + // CHECK: %{{[0-9]+}} = "ttnn.full" + // CHECK-SAME: fillValue = 1.000000e+00 : f32 + // CHECK-SAME: tensor<64x128xui32 + %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xui32>}> : () -> tensor<64x128xui32> + return %0 : tensor<64x128xui32> + } + func.func @test_full_bfloat16() -> tensor<64x128xbf16> { // CHECK: %{{[0-9]+}} = "ttnn.full" // CHECK-SAME: fillValue = 1.000000e+00 : f32 diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_ui16.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_ui16.mlir new file mode 100644 index 0000000000..0144da3084 --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_ui16.mlir @@ -0,0 +1,43 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s + +module @jit_constant attributes {} { + func.func public @test_uint16_scalar() -> tensor { + // CHECK-LABEL: func.func public @test_uint16_scalar + // CHECK: ttnn.full + // CHECK-SAME: fillValue = 3.000000e+00 : f32 + // CHECK-SAME: -> tensor<1xui16 + %0 = stablehlo.constant dense<3> : tensor + return %0 : tensor + } + + func.func public @test_uint16_scalar_empty() -> tensor { + // CHECK-LABEL: func.func public @test_uint16_scalar_empty + // CHECK: ttnn.full + // CHECK-SAME: -> tensor<1xui16 + %0 = stablehlo.constant dense<0> : tensor + return %0 : tensor + } + + func.func public @test_uint16_empty() -> tensor<64x128xui16> { + // CHECK-LABEL: func.func public @test_uint16_empty + // CHECK: ttnn.full + // CHECK-SAME: -> tensor<64x128xui16 + %0 = stablehlo.constant dense<0> : tensor<64x128xui16> + return %0 : tensor<64x128xui16> + } + + func.func public @test_uint16_splat() -> tensor<64x128xui16> { + // CHECK-LABEL: func.func public @test_uint16_splat + // CHECK: ttnn.full + // CHECK-SAME: fillValue = 3.000000e+00 : f32 + // CHECK-SAME: -> tensor<64x128xui16 + %0 = stablehlo.constant dense<3> : tensor<64x128xui16> + return %0 : tensor<64x128xui16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_ui32.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_ui32.mlir new file mode 100644 index 0000000000..029d45f9ce --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_ui32.mlir @@ -0,0 +1,43 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s + +module @jit_constant attributes {} { + func.func public @test_uint32_scalar() -> tensor { + // CHECK-LABEL: func.func public @test_uint32_scalar + // CHECK: ttnn.full + // CHECK-SAME: fillValue = 3.000000e+00 : f32 + // CHECK-SAME: -> tensor<1xui32 + %0 = stablehlo.constant dense<3> : tensor + return %0 : tensor + } + + func.func public @test_uint32_scalar_empty() -> tensor { + // CHECK-LABEL: func.func public @test_uint32_scalar_empty + // CHECK: ttnn.full + // CHECK-SAME: -> tensor<1xui32 + %0 = stablehlo.constant dense<0> : tensor + return %0 : tensor + } + + func.func public @test_uint32_empty() -> tensor<64x128xui32> { + // CHECK-LABEL: func.func public @test_uint32_empty + // CHECK: ttnn.full + // CHECK-SAME: -> tensor<64x128xui32 + %0 = stablehlo.constant dense<0> : tensor<64x128xui32> + return %0 : tensor<64x128xui32> + } + + func.func public @test_uint32_splat() -> tensor<64x128xui32> { + // CHECK-LABEL: func.func public @test_uint32_splat + // CHECK: ttnn.full + // CHECK-SAME: fillValue = 3.000000e+00 : f32 + // CHECK-SAME: -> tensor<64x128xui32 + %0 = stablehlo.constant dense<3> : tensor<64x128xui32> + return %0 : tensor<64x128xui32> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_ui64.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_ui64.mlir new file mode 100644 index 0000000000..8f24ebdfe8 --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_ui64.mlir @@ -0,0 +1,43 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s + +module @jit_constant attributes {} { + func.func public @test_uint64_scalar() -> tensor { + // CHECK-LABEL: func.func public @test_uint64_scalar + // CHECK: ttnn.full + // CHECK-SAME: fillValue = 3.000000e+00 : f32 + // CHECK-SAME: -> tensor<1xui32 + %0 = stablehlo.constant dense<3> : tensor + return %0 : tensor + } + + func.func public @test_uint64_scalar_empty() -> tensor { + // CHECK-LABEL: func.func public @test_uint64_scalar_empty + // CHECK: ttnn.full + // CHECK-SAME: -> tensor<1xui32 + %0 = stablehlo.constant dense<0> : tensor + return %0 : tensor + } + + func.func public @test_uint64_empty() -> tensor<64x128xui32> { + // CHECK-LABEL: func.func public @test_uint64_empty + // CHECK: ttnn.full + // CHECK-SAME: -> tensor<64x128xui32 + %0 = stablehlo.constant dense<0> : tensor<64x128xui32> + return %0 : tensor<64x128xui32> + } + + func.func public @test_uint64_splat() -> tensor<64x128xui32> { + // CHECK-LABEL: func.func public @test_uint64_splat + // CHECK: ttnn.full + // CHECK-SAME: fillValue = 3.000000e+00 : f32 + // CHECK-SAME: -> tensor<64x128xui32 + %0 = stablehlo.constant dense<3> : tensor<64x128xui32> + return %0 : tensor<64x128xui32> + } +} diff --git a/test/ttmlir/Silicon/TTNN/simple_constant.mlir b/test/ttmlir/Silicon/TTNN/simple_constant.mlir index 35728f0a93..5e135e4521 100644 --- a/test/ttmlir/Silicon/TTNN/simple_constant.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_constant.mlir @@ -8,6 +8,12 @@ module @sysmem_creation attributes {} { return %0 : tensor<64x128xi32> } + func.func @test_empty_uint() -> tensor<64x128xui32> { + %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xui32>}> : () -> tensor<64x128xui32> + // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] + return %0 : tensor<64x128xui32> + } + func.func @test_empty_float() -> tensor<64x128xf32> { %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] @@ -20,13 +26,18 @@ module @sysmem_creation attributes {} { return %0 : tensor<1x1xf32> } - func.func @test_full_int() -> tensor<64x128xi32> { %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi32>}> : () -> tensor<64x128xi32> // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] return %0 : tensor<64x128xi32> } + func.func @test_full_uint() -> tensor<64x128xui32> { + %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xui32>}> : () -> tensor<64x128xui32> + // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] + return %0 : tensor<64x128xui32> + } + func.func @test_full_float() -> tensor<64x128xf32> { %0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]