-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5d8c602
commit 1dd7293
Showing
23 changed files
with
431 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -168,6 +168,7 @@ enum ReductionOpType: uint32 { | |
Sum, | ||
Mean, | ||
Max, | ||
Prod, | ||
} | ||
|
||
table ReductionOp { | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
82 changes: 82 additions & 0 deletions
82
test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_prod_op.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// REQUIRES: stablehlo | ||
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s | ||
module @jit_reduce_prod attributes {} { | ||
func.func public @test_reduce_prod_4to3dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32x4xf32> { | ||
// CHECK-LABEL: func.func public @test_reduce_prod_4to3dim | ||
// CHECK: tensor.empty | ||
// CHECK: "ttir.prod" | ||
// CHECK-SAME: dim_arg = [1 : i32] | ||
// CHECK-SAME: keep_dim = false | ||
// CHECK-SAME: tensor<128x10x32x4xf32> | ||
// CHECK-SAME: -> tensor<128x32x4xf32> | ||
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128x32x4xf32> | ||
return %0 : tensor<128x32x4xf32> | ||
} | ||
|
||
func.func public @test_reduce_prod_4to0dim(%arg0: tensor<128x10x32x4xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> { | ||
// CHECK-LABEL: func.func public @test_reduce_prod_4to0dim | ||
// CHECK: tensor.empty | ||
// CHECK: "ttir.prod" | ||
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32] | ||
// CHECK-SAME: keep_dim = false | ||
// CHECK-SAME: tensor<128x10x32x4xbf16> | ||
// CHECK-SAME: -> tensor<1xbf16> | ||
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1, 2, 3] : (tensor<128x10x32x4xbf16>, tensor<bf16>) -> tensor<bf16> | ||
return %0 : tensor<bf16> | ||
} | ||
|
||
func.func public @test_reduce_prod_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x4xf32> { | ||
// CHECK: tensor.empty | ||
// CHECK: "ttir.prod" | ||
// CHECK-SAME: dim_arg = [1 : i32] | ||
// CHECK-SAME: keep_dim = false | ||
// CHECK-SAME: tensor<128x10x4xf32> | ||
// CHECK-SAME: -> tensor<128x4xf32> | ||
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10x4xf32>, tensor<f32>) -> tensor<128x4xf32> | ||
return %0 : tensor<128x4xf32> | ||
} | ||
|
||
func.func public @test_reduce_prod_3to0dim(%arg0: tensor<128x10x4xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> { | ||
// CHECK: tensor.empty | ||
// CHECK: "ttir.prod" | ||
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32] | ||
// CHECK-SAME: keep_dim = false | ||
// CHECK-SAME: tensor<128x10x4xbf16> | ||
// CHECK-SAME: -> tensor<1xbf16> | ||
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1, 2] : (tensor<128x10x4xbf16>, tensor<bf16>) -> tensor<bf16> | ||
return %0 : tensor<bf16> | ||
} | ||
|
||
func.func public @test_reduce_prod_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> { | ||
// CHECK: tensor.empty | ||
// CHECK: "ttir.prod" | ||
// CHECK-SAME: dim_arg = [1 : i32] | ||
// CHECK-SAME: keep_dim = false | ||
// CHECK-SAME: tensor<128x10xf32> | ||
// CHECK-SAME: -> tensor<128xf32> | ||
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32> | ||
return %0 : tensor<128xf32> | ||
} | ||
|
||
func.func public @test_reduce_prod_2to0dim(%arg0: tensor<128x10xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> { | ||
// CHECK: tensor.empty | ||
// CHECK: "ttir.prod" | ||
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32] | ||
// CHECK-SAME: keep_dim = false | ||
// CHECK-SAME: tensor<128x10xbf16> | ||
// CHECK-SAME: -> tensor<1xbf16> | ||
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1] : (tensor<128x10xbf16>, tensor<bf16>) -> tensor<bf16> | ||
return %0 : tensor<bf16> | ||
} | ||
|
||
func.func public @test_reduce_prod_1to0dim(%arg0: tensor<128xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> { | ||
// CHECK: tensor.empty | ||
// CHECK: "ttir.prod" | ||
// CHECK-SAME: dim_arg = [0 : i32] | ||
// CHECK-SAME: keep_dim = false | ||
// CHECK-SAME: tensor<128xbf16> | ||
// CHECK-SAME: -> tensor<1xbf16> | ||
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0] : (tensor<128xbf16>, tensor<bf16>) -> tensor<bf16> | ||
return %0 : tensor<bf16> | ||
} | ||
} |
25 changes: 25 additions & 0 deletions
25
test/ttmlir/Dialect/TTIR/reduce_ops/negative_reduce_prod_op.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s | ||
// Negative tests for reduce(prod) op | ||
|
||
// CHECK: error: 'ttir.prod' op TTNN only supports reduce(prod) along one dimension or all dimensions. | ||
func.func public @test_reduce_prod_multiple_dims(%arg0: tensor<128x10x32x4xf32>) -> tensor<128x32xf32> { | ||
%0 = tensor.empty() : tensor<128x32xf32> | ||
%1 = "ttir.prod"(%arg0, %0) <{dim_arg = [1 : i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128x32xf32>) -> tensor<128x32xf32> | ||
return %1 : tensor<128x32xf32> | ||
} | ||
|
||
// ----- | ||
// CHECK: error: 'ttir.prod' op TTNN only supports Reduce(prod) along all dimensions for bfloat16 datatype. | ||
func.func public @test_reduce_prod_all_dims_f32(%arg0: tensor<128x10x32x4xf32>) -> tensor<1xf32> { | ||
%0 = tensor.empty() : tensor<1xf32> | ||
%1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0: i32, 1 : i32, 2: i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<1xf32>) -> tensor<1xf32> | ||
return %1 : tensor<1xf32> | ||
} | ||
|
||
// ----- | ||
// error: 'ttir.prod' op Input tensor rank is greater than 4 for reduce(product). | ||
func.func public @test_reduce_prod_higher_rank(%arg0: tensor<128x10x32x4x1xf32>) -> tensor<10x32x4x1xf32> { | ||
%0 = tensor.empty() : tensor<10x32x4x1xf32> | ||
%1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128x10x32x4x1xf32>, tensor<10x32x4x1xf32>) -> tensor<10x32x4x1xf32> | ||
return %1 : tensor<10x32x4x1xf32> | ||
} |
File renamed without changes.
File renamed without changes.
Oops, something went wrong.