From 491cd54c36889796ba6c62a9e372df65043c316b Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Wed, 15 Jan 2025 17:42:26 +0000 Subject: [PATCH] Code review --- .../StableHLOToTTIR/constant_op.mlir | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir index eb0aa5b95..83281dd2d 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir @@ -214,4 +214,99 @@ module @jit_constant attributes {} { // CHECK: return %{{[0-9]+}} : tensor<2x2xi32> return %0 : tensor<2x2xi64> } + + func.func public @test_uint8_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui8>}> : () -> tensor<1xui8> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xui8> + return %0 : tensor + } + + func.func public @test_uint8_splat() -> tensor<64xui8> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui8>}> : () -> tensor<64xui8> + %0 = stablehlo.constant dense<3> : tensor<64xui8> + // CHECK: return %{{[0-9]+}} : tensor<64xui8> + return %0 : tensor<64xui8> + } + + func.func public @test_uint8_multiple() -> tensor<2x2xui8> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui8>}> : () -> tensor<2x2xui8> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui8> + // CHECK: return %{{[0-9]+}} : tensor<2x2xui8> + return %0 : tensor<2x2xui8> + } + + func.func public @test_uint16_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui16>}> : () -> tensor<1xui16> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xui16> + return %0 : tensor + } + + func.func public @test_uint16_splat() -> tensor<64xui16> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui16>}> : () -> tensor<64xui16> + %0 = stablehlo.constant dense<3> : tensor<64xui16> + // CHECK: return %{{[0-9]+}} : tensor<64xui16> + return %0 : tensor<64xui16> + } + + func.func public @test_uint16_multiple() -> tensor<2x2xui16> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui16>}> : () -> tensor<2x2xui16> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui16> + // CHECK: return %{{[0-9]+}} : tensor<2x2xui16> + return %0 : tensor<2x2xui16> + } + + func.func public @test_uint32_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xui32>}> : () -> tensor<1xui32> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xui32> + return %0 : tensor + } + + func.func public @test_uint32_splat() -> tensor<64xui32> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xui32>}> : () -> tensor<64xui32> + %0 = stablehlo.constant dense<3> : tensor<64xui32> + // CHECK: return %{{[0-9]+}} : tensor<64xui32> + return %0 : tensor<64xui32> + } + + func.func public @test_uint32_multiple() -> tensor<2x2xui32> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xui32>}> : () -> tensor<2x2xui32> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui32> + // CHECK: return %{{[0-9]+}} : tensor<2x2xui32> + return %0 : tensor<2x2xui32> + } + + // TODO (kmitrovic) these should cast to ui32 not i32 + + func.func public @test_uint64_scalar() -> tensor { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32> + %0 = stablehlo.constant dense<3> : tensor + // CHECK: return %{{[0-9]+}} : tensor<1xi32> + return %0 : tensor + } + + func.func public @test_uint64_splat() -> tensor<64xui64> { + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<64xi32>}> : () -> tensor<64xi32> + %0 = stablehlo.constant dense<3> : tensor<64xui64> + // CHECK: return %{{[0-9]+}} : tensor<64xi32> + return %0 : tensor<64xui64> + } + + func.func public @test_uint64_multiple() -> tensor<2x2xui64> { + // The ugly regex after `dense` is necessary because double square opening + // brackets indicate substitution block in FileCheck syntax. + // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0, 1], [2, 3]]> : tensor<2x2xi32>}> : () -> tensor<2x2xi32> + %0 = stablehlo.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xui64> + // CHECK: return %{{[0-9]+}} : tensor<2x2xi32> + return %0 : tensor<2x2xui64> + } + }