NOTE(review): this patch was recovered from a copy whose line breaks were
collapsed and in which `<word>` spans appear to have been stripped. The template
argument in `isa<IntegerType>()` was restored from the adjacent "tensor or i1"
comment. The class name in the first hunk header and the empty
`constant_value =` CHECK fields look truncated in the same way and are left as
found; context-line indentation is reconstructed. Confirm against the tree
before applying.

diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 451923e0a624..4bd450c8c3eb 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -684,15 +684,41 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl {
         constantValue = lhsInfo.getConstantValue();
       }
     } else {
+      // The condition can be either a tensor or i1.
+      // If i1 is used as the condition, the entire tensor of either
+      // lhs or rhs is selected.
+      bool i1Cond = op.getOperand(0).getType().template isa<IntegerType>();
       for (auto d = 0; d < rank; ++d) {
-        constancy.push_back(
-            std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]),
-                     gcd(rhsInfo.getConstancy(d), condConstancy[d])));
-        divisibility.push_back(
-            std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
-        contiguity.push_back(
-            std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]),
-                     gcd(rhsInfo.getContiguity(d), condConstancy[d])));
+        if (i1Cond) {
+          constancy.push_back(
+              std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
+          divisibility.push_back(
+              std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
+          contiguity.push_back(
+              std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d)));
+        } else {
+          constancy.push_back(
+              std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]),
+                       gcd(rhsInfo.getConstancy(d), condConstancy[d])));
+          contiguity.push_back(
+              std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]),
+                       gcd(rhsInfo.getContiguity(d), condConstancy[d])));
+          if (contiguity.back() == lhsInfo.getContiguity(d) &&
+              contiguity.back() == rhsInfo.getContiguity(d)) {
+            // Contiguity not changed
+            divisibility.push_back(
+                gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
+          } else {
+            // Contiguity changed, we cannot use only divisibility.
+            // For example, the following case should have contiguity 2 and
+            // divisibility 2
+            // [[0, 1], [4, 5]]
+            // [[16, 17, 18, 19]]
+            divisibility.push_back(
+                std::min(gcd(lhsInfo.getDivisibility(d), contiguity.back()),
+                         gcd(rhsInfo.getDivisibility(d), contiguity.back())));
+          }
+        }
       }
       if (lhsInfo.getConstantValue().has_value() &&
           rhsInfo.getConstantValue().has_value() &&
diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir
index 8f7fffb10b32..0ea43128a510 100644
--- a/test/Analysis/test-alignment.mlir
+++ b/test/Analysis/test-alignment.mlir
@@ -276,7 +276,7 @@ tt.func @logic() {
 // -----
 
 // CHECK-LABEL: @select
-tt.func @select() {
+tt.func @select(%arg0 : i1, %arg1 : tensor<4xi1>) {
   // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value =
   %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
   // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
@@ -293,6 +293,22 @@ tt.func @select() {
   %5 = arith.select %4, %3, %7 : tensor<128xi1>
   // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value =
   %8 = arith.select %7, %3, %2 : tensor<128xi1>, tensor<128xi1>
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 1], constant_value =
+  %9 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi1>) -> tensor<128x1xi1>
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 1], constant_value =
+  %10 = tt.expand_dims %3 {axis = 1 : i32} : (tensor<128xi1>) -> tensor<128x1xi1>
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 1], constant_value =
+  %11 = arith.select %arg0, %9, %10 : tensor<128x1xi1>
+  // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [4], constant_value = 4
+  %cst = arith.constant dense<4> : tensor<4xi32>
+  // CHECK-NEXT: contiguity = [4], divisibility = [1073741824], constancy = [1], constant_value =
+  %12 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value =
+  %13 = arith.muli %12, %cst : tensor<4xi32>
+  // CHECK-NEXT: contiguity = [4], divisibility = [16], constancy = [1], constant_value =
+  %14 = tt.make_range {end = 20 : i32, start = 16 : i32} : tensor<4xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value =
+  %15 = arith.select %arg1, %12, %13 : tensor<4xi1>, tensor<4xi32>
   tt.return
 }
 