From 0c0842761b5b6b0f620ac9562c50439376fb67d0 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 11 Jul 2024 00:06:22 -0700 Subject: [PATCH] Simplify more constraints and mods. Additional mod simplification was a coincidence due to refactoring the sum-splitting logic. I can factor this out into a separate CL if desired. - enables additional vectorization when there is a constraint on the loop symbol that is redundant. - fixes some inconsistencies in div/mod simplification (should lead to better code in some cases, see loop emitter test). - replaces an ad-hoc simplification in fusion_emitter.cc with a more general one Also remove most of the change detector tests in reduction_test. PiperOrigin-RevId: 651294205 --- .../service/gpu/fusions/concatenate_mlir.cc | 1 + .../xla/service/gpu/fusions/fusion_emitter.cc | 8 +- .../xla/service/gpu/fusions/loop_mlir_test.cc | 4 +- .../xla/xla/service/gpu/fusions/loop_test.cc | 4 +- .../xla/xla/service/gpu/fusions/reduction.cc | 2 + .../xla/service/gpu/fusions/reduction_mlir.cc | 22 +- .../gpu/fusions/reduction_mlir_test.cc | 10 +- .../xla/service/gpu/fusions/reduction_test.cc | 346 +----------------- .../gpu/fusions/transpose_mlir_test.cc | 5 +- .../xla/service/gpu/fusions/transpose_test.cc | 8 +- .../xla/xla/service/gpu/model/indexing_map.cc | 257 ++++++++----- .../xla/xla/service/gpu/model/indexing_map.h | 17 +- .../service/gpu/model/indexing_map_test.cc | 58 ++- 13 files changed, 273 insertions(+), 469 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc index 355030ad49fcf8..df717c95152c83 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc @@ -114,6 +114,7 @@ absl::Status MlirConcatenateFusion::EmitEntryFunction( auto thread_id_to_output_map = ComposeIndexingMaps( ComposeIndexingMaps(thread_id_to_input_map, input_to_output_map), epilogue_indexing); + thread_id_to_output_map.Simplify(); auto loop_nest_body_builder = [&, operand_index = operand_index]( diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index 3f7ac5b3db2d8e..d8cd2f7d321570 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -188,13 +188,7 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap( mlir::AffineMap::get(/*dimCount=*/6, /*symbolCount=*/2, output_dims, ctx), dim_vars, range_vars, /*rt_vars=*/{}); - // Remove the unroll_elem_id symbol if unrolling divides num_elements. - if (num_elements % unroll_factor == 0) { - indexing_map.AddConstraint(linear_index.replace({{unroll_elem_id, c0}}), - Interval{0, num_elements - unroll_factor}); - } else { - indexing_map.AddConstraint(linear_index, Interval{0, num_elements - 1}); - } + indexing_map.AddConstraint(linear_index, Interval{0, num_elements - 1}); indexing_map.Simplify(); return indexing_map; } diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc index 190b704e82d93c..4b2d838ad20866 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -56,7 +56,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000) mod 100, ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200, - (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id + ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id ) domain: th_x in [0, 128) @@ -67,7 +67,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { bl_z in [0, 1) chunk_id in [0, 12) unroll_id in [0, 4) - (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999997) + th_x + bl_x * 128 + chunk_id * 129024 in [0, 1500000) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_test.cc index c28491b18448a7..d4898f2a2e7fcb 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_test.cc @@ -90,7 +90,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000) mod 100, ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200, - (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id + ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id ) domain: th_x in [0, 128) @@ -101,7 +101,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { bl_z in [0, 1) chunk_id in [0, 12) unroll_id in [0, 4) - (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999997) + th_x + bl_x * 128 + chunk_id * 129024 in [0, 1500000) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/reduction.cc index 8c72ff84668396..df50bf092fef64 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction.cc @@ -1295,6 +1295,7 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( }(); AddGroupIdConstraint(map, root_index, groups_); + map.Simplify(); return map; } @@ -1321,6 +1322,7 @@ std::optional ReductionInfo::ComputeThreadIdToInputIndexing( GetBitcastMap(tiling_.GetXlaShape(), hero.operand(hero_operand_index)->shape(), ctx)); AddGroupIdConstraint(map, root_index, groups_); + map.Simplify(); return map; } diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index 3493aaa414d924..23830afe48ed65 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -564,10 +564,13 @@ std::optional MlirReductionFusion::ComputeThreadIdToInputIndexing( .indexing_maps[hero_operand_index] .begin()); } - auto map = ComputeReductionInputIndexing(ctx); - AddGroupIdConstraint(map, root_index, groups_); - return map * GetBitcastMap(input_shape_, - hero.operand(hero_operand_index)->shape(), ctx); + auto projected_map = ComputeReductionInputIndexing(ctx); + AddGroupIdConstraint(projected_map, root_index, groups_); + auto map = projected_map * + GetBitcastMap(input_shape_, + hero.operand(hero_operand_index)->shape(), ctx); + map.Simplify(); + return map; } std::optional MlirReductionFusion::ComputeThreadIdToOutputIndexing( @@ -578,6 +581,7 @@ std::optional MlirReductionFusion::ComputeThreadIdToOutputIndexing( GetBitcastMap(input_shape_, analysis_.fusion_root(root_index).shape(), ctx)); AddGroupIdConstraint(map, root_index, groups_); + map.Simplify(); return map; } @@ -594,10 +598,12 @@ std::optional MlirReductionFusion::ComputeThreadIdToOutputIndexing( const auto& hero = analysis_.fusion_hero(root_index).instruction(); auto physical_shape = ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape()); - return projected_indexing * - GetBitcastMap(ShapeUtil::MakeShapeWithDescendingLayout( - PrimitiveType::U8, output_shape), - physical_shape, ctx); + auto map = projected_indexing * + GetBitcastMap(ShapeUtil::MakeShapeWithDescendingLayout( + PrimitiveType::U8, output_shape), + physical_shape, ctx); + map.Simplify(); + return map; } SmallVector MlirReductionFusion::EvaluateEpilogue( diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc index 491cb3df5b71cc..8a1f880dc64a68 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -670,7 +670,7 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) { (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( d3 floordiv 11, d0 floordiv 32 + s0 * 32, - (d3 mod 11) * 32 + d0 mod 32 + s1 + (d3 mod 11) * 32 + d0 mod 32 ) domain: d0 in [0, 1024) @@ -681,24 +681,24 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) { d5 in [0, 1) s0 in [0, 33) s1 in [0, 1) - (d3 mod 11) * 32 + d0 mod 32 + s1 in [0, 321) + (d3 mod 11) * 32 + d0 mod 32 in [0, 321) d0 floordiv 32 + s0 * 32 in [0, 1051) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0] -> ( - d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 + s0 + d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 ) domain: - d0 in [0, 1024) + d0 in [0, 993) d1 in [0, 1) d2 in [0, 1) d3 in [0, 143) d4 in [0, 1) d5 in [0, 1) s0 in [0, 1) - (d3 mod 11) * 32 + d0 floordiv 32 + s0 in [0, 321) + (d3 mod 11) * 32 + d0 floordiv 32 in [0, 321) d0 mod 32 in [0, 1) )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_test.cc index 491a80c84ad215..4ce5dacd6bdb79 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_test.cc @@ -76,9 +76,9 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( - (d3 * 8 + d0 floordiv 32) floordiv 64, - (d3 * 8 + d0 floordiv 32) mod 64, - (d0 mod 32 + s2 * 32) * 2 + s3 + d3 floordiv 8, + (d3 mod 8) * 8 + d0 floordiv 32, + (d0 mod 32) * 2 + s2 * 64 + s3 ) domain: d0 in [0, 256) @@ -91,356 +91,22 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { s1 in [0, 1) s2 in [0, 8) s3 in [0, 2) - d0 mod 32 + s2 * 32 in [0, 256) - d3 * 8 + d0 floordiv 32 in [0, 6400) )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5) -> ( - (d3 * 8 + d0 floordiv 32) floordiv 64, - (d3 * 8 + d0 floordiv 32) mod 64 + d3 floordiv 8, + (d3 mod 8) * 8 + d0 floordiv 32 ) domain: - d0 in [0, 256) + d0 in [0, 225) d1 in [0, 1) d2 in [0, 1) d3 in [0, 800) d4 in [0, 1) d5 in [0, 1) d0 mod 32 in [0, 1) - d3 * 8 + d0 floordiv 32 in [0, 6400) - )")); -} - -TEST_F(ReductionTest, ThreadIndexingMultiRowReduction) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - fusion { - %input = f32[100,64,4] parameter(0) - %c0 = f32[] constant(0) - ROOT reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add - } - - ENTRY entry { - %input = f32[100,64,4] parameter(0) - ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - ReductionFusion fusion(analysis); - - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 + (d0 floordiv 4) floordiv 64, - (d0 floordiv 4) mod 64, - d0 mod 4 - ) - domain: - d0 in [0, 256) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 100) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 1) - s1 in [0, 1) - s2 in [0, 1) - d0 mod 4 in [0, 4) - d3 * 64 + d0 floordiv 4 in [0, 6400) - )")); - EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5) -> ( - d3 + (d0 floordiv 4) floordiv 64, - (d0 floordiv 4) mod 64 - ) - domain: - d0 in [0, 256) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 100) - d4 in [0, 1) - d5 in [0, 1) - d0 mod 4 in [0, 1) - d3 * 64 + d0 floordiv 4 in [0, 6400) - )")); -} - -TEST_F(ReductionTest, ThreadIndexingColumnReduction) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - fusion { - %input = f32[100,64,32] parameter(0) - %c0 = f32[] constant(0) - ROOT reduce = f32[100,32] reduce(%input, %c0), dimensions={1}, to_apply=add - } - - ENTRY entry { - %input = f32[100,64,32] parameter(0) - ROOT %fusion = f32[100,32] fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - ReductionFusion fusion(analysis); - - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3, - d0 floordiv 32 + s1 * 32, - d0 mod 32 - ) - domain: - d0 in [0, 1024) d1 in [0, 1) d2 in [0, 1) - d3 in [0, 100) d4 in [0, 1) d5 in [0, 1) - s0 in [0, 1) s1 in [0, 128) s2 in [0, 1) - d0 floordiv 32 + s1 * 32 in [0, 64) - d0 mod 32 in [0, 32) - )")); - EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5) -> ( - d3, - d0 floordiv 32 - ) - domain: - d0 in [0, 1024) d1 in [0, 1) d2 in [0, 1) - d3 in [0, 100) d4 in [0, 1) d5 in [0, 1) - d0 mod 32 in [0, 1) - )")); -} - -TEST_F(ReductionTest, ThreadIndexingOutputLayout) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - fusion { - %input = f32[100,64,512] parameter(0) - %c0 = f32[] constant(0) - ROOT reduce = f32[100,64]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add - } - - ENTRY entry { - %input = f32[100,64,512] parameter(0) - ROOT %fusion = f32[100,64]{0,1} fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - ReductionFusion fusion(analysis); - - EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5) -> ( - (d3 * 8 + d0 floordiv 32) floordiv 64, - (d3 * 8 + d0 floordiv 32) mod 64 - ) - domain: - d0 in [0, 256) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 800) - d4 in [0, 1) - d5 in [0, 1) - d0 mod 32 in [0, 1) - d3 * 8 + d0 floordiv 32 in [0, 6400) - )")); -} - -TEST_F(ReductionTest, ThreadIndexingSideOutput) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - fusion { - %input = f32[100,64,512] parameter(0) - %c0 = f32[] constant(0) - %log = f32[100,64,512] log(%input) - %reduce = f32[100,64] reduce(%input, %c0), dimensions={2}, to_apply=add - ROOT tuple = (f32[100,64], f32[100,64,512]) tuple(%reduce, %log) - } - - ENTRY entry { - %input = f32[100,64,512] parameter(0) - ROOT %fusion = (f32[100,64], f32[100,64,512]) fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - ReductionFusion fusion(analysis); - - constexpr char kExpectedIndexing[] = R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( - d3 floordiv 8, - (d3 mod 8) * 8 + d0 floordiv 32, - (d0 mod 32) * 2 + s2 * 64 + s3 - ) - domain: - d0 in [0, 256) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 800) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 1) - s1 in [0, 1) - s2 in [0, 8) - s3 in [0, 2) - )"; - auto input_indexing = - fusion.ComputeThreadIdToInputIndexing(1, 0, &mlir_context_); - input_indexing->Simplify(); - EXPECT_THAT(input_indexing->ToString(), - MatchIndexingString(kExpectedIndexing)); - auto output_indexing = - fusion.ComputeThreadIdToOutputIndexing(1, &mlir_context_); - output_indexing->Simplify(); - EXPECT_THAT(output_indexing->ToString(), - MatchIndexingString(kExpectedIndexing)); -} - -TEST_F(ReductionTest, ThreadIndexingVectorized) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - fusion { - %input = f32[1024, 8192] parameter(0) - %c0 = f32[] constant(0) - ROOT reduce = f32[1024]{0} reduce(f32[1024, 8192] %input, f32[] %c0), - dimensions={1}, to_apply=add - } - ENTRY entry { - %input = f32[1024, 8192] parameter(0) - ROOT %fusion = f32[1024] fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - ReductionFusion fusion(analysis); - - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( - d3, - (d0 + s2 * 512) * 2 + s3 - ) - domain: - d0 in [0, 512) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 1024) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 1) - s1 in [0, 1) - s2 in [0, 8) - s3 in [0, 2) - d0 + s2 * 512 in [0, 4096) - )")); -} - -TEST_F(ReductionTest, ThreadIndexingBroadcastSideOutput) { - auto module = ParseAndReturnVerifiedModule(R"( - %add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - %fusion { - %p0 = f32[6,6] parameter(0) - %c0 = f32[] constant(0) - %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add - %broadcast = f32[6,6] broadcast(%reduce), dimensions={} - ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce) - } - ENTRY main { - %p0 = f32[6,6] parameter(0) - ROOT %fusion = (f32[6,6], f32[]) fusion(%p0), kind=kInput, calls=%fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - ReductionFusion fusion(analysis); - EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - (d0 + s2 * 32) floordiv 6, - (d0 + s2 * 32) mod 6 - ) - domain: - d0 in [0, 32) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 1) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 1) - s1 in [0, 1) - s2 in [0, 16) - d0 + s2 * 32 in [0, 36) - )")); - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> () - domain: - d0 in [0, 32) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 1) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 1) - s1 in [0, 1) - s2 in [0, 16) - (d0 + s2 * 32) mod 6 in [0, 6) - d0 + s2 * 32 in [0, 36) )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index 33d18cd2461ce6..dd06b695fdfdc8 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -109,7 +109,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( d3 floordiv 2, - (d3 * 32 + s0 * 4) mod 64 + d0 floordiv 32, + (d3 mod 2) * 32 + s0 * 4 + d0 floordiv 32, d0 mod 32 ) domain: @@ -317,11 +317,12 @@ TEST_F(MlirTransposeFusionTest, FusedTranspose210) { // // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x1x33xf32> // CHECK: %[[SHMEM_WITH_VALS:.*]] = scf.for - // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: %[[C0]] to %[[C5]] step %[[C1]] // CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) // CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fused_computation_exp // CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc index 6255b09cca63d7..41b009b7af44f7 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_test.cc @@ -138,7 +138,7 @@ TEST_F(TransposeTest, ThreadIndexing201) { MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d3 floordiv 2, - (d3 * 32 + s1 * 4) mod 64 + d0 floordiv 32, + (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32, d0 mod 32 ) domain: @@ -213,10 +213,9 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { d3 in [0, 2) d4 in [0, 1) d5 in [0, 1) - s0 in [0, 8) + s0 in [0, 6) s1 in [0, 1) s2 in [0, 1) - d0 floordiv 32 + s0 * 4 in [0, 24) d0 mod 32 in [0, 24) )")); EXPECT_THAT( @@ -235,10 +234,9 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { d3 in [0, 2) d4 in [0, 1) d5 in [0, 1) - s0 in [0, 8) + s0 in [0, 6) s1 in [0, 1) s2 in [0, 1) - d0 floordiv 32 + s0 * 4 in [0, 24) d0 mod 32 in [0, 24) )")); } diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 2c64b7ba6740fb..11504ae9b2cd97 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -125,6 +125,14 @@ class AffineExprSimplifier { mlir::AffineExpr Simplify(mlir::AffineExpr expr); + // Performs AffineExpr simplification for all constraints. + // Returns true if simplification was performed. + bool SimplifyConstraintExprs(IndexingMap& map); + + // Performs range simplification for all constraints. + // Returns true if simplification was performed. + bool SimplifyConstraintRanges(IndexingMap& map); + private: std::optional GetConstantRhs(mlir::AffineExpr expr, AffineExprKind kind); @@ -162,6 +170,14 @@ class AffineExprSimplifier { mlir::AffineExpr SimplifyWithMlir(mlir::AffineExpr expr, int num_dims, int num_symbols); + bool SimplifyConstraintRangeOnce(AffineExpr* expr, Interval* range); + bool SimplifyConstraintRange(AffineExpr* expr, Interval* range); + bool SimplifyAddConstraint(AffineExpr* add, Interval* range); + + // Splits a nested sum into a * gcd + b. + std::tuple SplitSumByGcd( + AffineExpr sum); + mlir::AffineMap SimplifyWithMlir(mlir::AffineMap map) { llvm::SmallVector exprs; for (auto e : map.getResults()) { @@ -221,35 +237,14 @@ AffineExpr AffineExprSimplifier::RewriteMod(AffineBinaryOpExpr mod) { }); new_lhs = new_lhs + (extracted_constant % m); - Interval no_multiplier_range{0, 0}; - std::optional multiplier_gcd = std::nullopt; - VisitSummands(new_lhs, [&](AffineExpr expr) { - if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { - if (multiplier_gcd.has_value()) { - multiplier_gcd = std::gcd(*multiplier_gcd, *multiplier); - } else { - multiplier_gcd = *multiplier; - } - } else { - auto range = range_evaluator_->ComputeExpressionRange(expr); - no_multiplier_range.lower += range.lower; - no_multiplier_range.upper += range.upper; - } - }); - + auto [multiplied, multiplier_gcd, not_multiplied] = SplitSumByGcd(new_lhs); mlir::AffineExpr extracted = getAffineConstantExpr(0, mod.getContext()); - if (multiplier_gcd.has_value()) { - if (m % *multiplier_gcd == 0 && no_multiplier_range.lower >= 0 && - no_multiplier_range.upper < *multiplier_gcd) { - // Remove everything that doesn't have a multiplier. - new_lhs = MapSummands(new_lhs, [&](AffineExpr expr) { - if (GetConstantRhs(expr, AffineExprKind::Mul)) { - return expr; - } - extracted = extracted + expr; - return zero; - }); - } + if (multiplier_gcd != 1 && m % multiplier_gcd == 0 && + Interval{0, multiplier_gcd - 1}.Contains( + range_evaluator_->ComputeExpressionRange(not_multiplied))) { + // Remove everything that doesn't have a multiplier. + new_lhs = multiplied * multiplier_gcd; + extracted = not_multiplied; } return new_lhs % mod.getRHS() + extracted; @@ -291,38 +286,27 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, return expr; }); - // The gcd of all multipliers and the divisor. - int64_t multiplier_divisor_gcd = divisor; - Interval no_multiplier_range{0, 0}; std::optional inner_divisor = std::nullopt; int num_inner_divisors = 0; VisitSummands(new_dividend, [&](AffineExpr summand) { - if (auto multiplier = GetConstantRhs(summand, AffineExprKind::Mul)) { - multiplier_divisor_gcd = std::gcd(multiplier_divisor_gcd, *multiplier); - } else { - no_multiplier_range = no_multiplier_range + - range_evaluator_->ComputeExpressionRange(summand); - } - if (auto divisor = GetConstantRhs(summand, AffineExprKind::FloorDiv)) { inner_divisor = divisor; ++num_inner_divisors; } }); + // Split `new_dividend` into `multiplied * multiplier_gcd + not_multiplied`. + auto [multiplied, multiplier_gcd, not_multiplied] = + SplitSumByGcd(new_dividend); + int64_t multiplier_divisor_gcd = std::gcd(divisor, multiplier_gcd); + // Consider an expression like: `(x * 6 + y) / 9`. if the range of `y` is at // most `[0; 3)`, we can rewrite it to `(x * 2) / 3`, since `y` can't affect // the result. - if (no_multiplier_range.lower >= 0 && - no_multiplier_range.upper < multiplier_divisor_gcd) { - new_dividend = MapSummands(new_dividend, [&](AffineExpr summand) { - if (auto mult = GetConstantRhs(summand, AffineExprKind::Mul)) { - return GetLhs(summand) * (*mult / multiplier_divisor_gcd); - } - // This has no multiplier and we previously determined it can't affect - // the result of the division. - return zero; - }); + if (multiplier_divisor_gcd != 1 && + Interval{0, multiplier_divisor_gcd - 1}.Contains( + range_evaluator_->ComputeExpressionRange(not_multiplied))) { + new_dividend = multiplied * (multiplier_gcd / multiplier_divisor_gcd); divisor /= multiplier_divisor_gcd; } @@ -642,9 +626,51 @@ AffineMap AffineExprSimplifier::Simplify(AffineMap affine_map) { affine_map.getContext())); } -// Simplifies a constraint range, i.e. a constraint d0 + x in [lb, ub] will +bool AffineExprSimplifier::SimplifyAddConstraint(AffineExpr* add, + Interval* range) { + if (add->getKind() != AffineExprKind::Add) { + return false; + } + + auto rhs_range = range_evaluator_->ComputeExpressionRange(GetRhs(*add)); + if (rhs_range.IsPoint()) { + *add = GetLhs(*add); + range->lower -= rhs_range.lower; + range->upper -= rhs_range.lower; + return true; + } + + if (range->lower != 0) { + return false; + } + + // Split the sum into `multiplied * multiplier_gcd + not_multiplied`. + // 0 <= a * gcd + b <= ub + // 0 <= a * gcd <= ub - b + // 0 <= a <= (ub - b) floordiv gcd + // If `(ub - b) floordiv gcd` is a constant, that means the value of b is + // irrelevant to this constraint. + auto [multiplied, multiplier_gcd, not_multiplied] = SplitSumByGcd(*add); + if (multiplier_gcd == 1) { + // If we didn't split anything, there's nothing to do. + return false; + } + + Interval difference_range = + Interval{range->upper, range->upper} - + range_evaluator_->ComputeExpressionRange(not_multiplied); + if (!difference_range.FloorDiv(multiplier_gcd).IsPoint()) { + return false; + } + + *add = multiplied * multiplier_gcd; + return true; +} + +// Simplifies a constraint range, e.g. a constraint d0 + x in [lb, ub] will // become d0 in [lb - x, ub - x]. Also supports *, floorDiv. -bool SimplifyConstraintRangeOnce(AffineExpr* expr, Interval* range) { +bool AffineExprSimplifier::SimplifyConstraintRangeOnce(AffineExpr* expr, + Interval* range) { switch (expr->getKind()) { case AffineExprKind::DimId: case AffineExprKind::SymbolId: @@ -652,25 +678,20 @@ bool SimplifyConstraintRangeOnce(AffineExpr* expr, Interval* range) { case AffineExprKind::Constant: { return false; } + case AffineExprKind::Add: + return SimplifyAddConstraint(expr, range); default: { auto binary_op = mlir::cast(*expr); CHECK(binary_op); auto lhs = binary_op.getLHS(); - auto rhs = binary_op.getRHS(); - auto constant = mlir::dyn_cast(rhs); - if (!constant) { + auto rhs_range = range_evaluator_->ComputeExpressionRange(GetRhs(*expr)); + if (!rhs_range.IsPoint()) { return false; } + int64_t rhs_cst = rhs_range.lower; switch (expr->getKind()) { - case AffineExprKind::Add: { - int64_t shift = constant.getValue(); - range->lower -= shift; - range->upper -= shift; - *expr = lhs; - return true; - } case AffineExprKind::Mul: { - int64_t factor = constant.getValue(); + int64_t factor = rhs_cst; if (factor < 0) { factor *= -1; range->lower *= -1; @@ -683,7 +704,7 @@ bool SimplifyConstraintRangeOnce(AffineExpr* expr, Interval* range) { return true; } case AffineExprKind::FloorDiv: { - int64_t divisor = constant.getValue(); + int64_t divisor = rhs_cst; if (divisor < 0) { divisor *= -1; range->lower *= -1; @@ -704,7 +725,8 @@ bool SimplifyConstraintRangeOnce(AffineExpr* expr, Interval* range) { } // Repeatedly simplifies the range of the constraint. -bool SimplifyConstraintRange(AffineExpr* expr, Interval* range) { +bool AffineExprSimplifier::SimplifyConstraintRange(AffineExpr* expr, + Interval* range) { bool is_simplified = false; while (SimplifyConstraintRangeOnce(expr, range)) { is_simplified = true; @@ -894,6 +916,34 @@ Interval Interval::operator*(const Interval& rhs) const { return mul(rhs.lower).Union(mul(rhs.upper)); } +Interval Interval::operator-() const { + int64_t ub = lower == std::numeric_limits::min() + ? std::numeric_limits::max() + : -lower; + int64_t lb = upper == std::numeric_limits::max() + ? std::numeric_limits::min() + : -upper; + return Interval{lb, ub}; +} + +Interval Interval::FloorDiv(int64_t rhs) const { + auto saturate_div = [](int64_t lhs, int64_t rhs) { + constexpr int64_t kMin = std::numeric_limits::min(); + constexpr int64_t kMax = std::numeric_limits::max(); + if (lhs == kMin) { + return rhs > 0 ? kMin : kMax; + } + if (lhs == kMax) { + return rhs > 0 ? kMax : kMin; + } + return xla::gpu::FloorDiv(lhs, rhs); + }; + + int64_t a = saturate_div(lower, rhs); + int64_t b = saturate_div(upper, rhs); + return {std::min(a, b), std::max(a, b)}; +} + std::ostream& operator<<(std::ostream& out, const Interval& range) { range.Print(out); return out; @@ -1052,10 +1102,6 @@ void IndexingMap::AddConstraint(mlir::AffineExpr expr, Interval range) { } return; } - if (SimplifyConstraintRange(&expr, &range)) { - AddConstraint(expr, range); - return; - } auto [it, inserted] = constraints_.insert({expr, range}); if (!inserted) { it->second = it->second.Intersect(range); @@ -1065,6 +1111,10 @@ void IndexingMap::AddConstraint(mlir::AffineExpr expr, Interval range) { } } +void IndexingMap::EraseConstraint(mlir::AffineExpr expr) { + constraints_.erase(expr); +} + bool IndexingMap::ConstraintsSatisfied( ArrayRef dim_const_exprs, ArrayRef symbol_const_exprs) const { @@ -1265,10 +1315,16 @@ bool IndexingMap::Simplify() { // Simplify constraints to shrink the lower/upper bounds of dims and symbols. bool constraints_were_simplified = false; + + // Simplify affine_map using the optimized ranges. + // Potentially, we can be smarter about recreating the range_evaluator. + RangeEvaluator range_evaluator(GetDimensionBounds(), GetSymbolBounds(), + GetMLIRContext()); + AffineExprSimplifier simplifier(&range_evaluator); while (true) { bool did_simplify = false; - did_simplify |= SimplifyConstraintExprs(); - did_simplify |= SimplifyConstraintRanges(); + did_simplify |= simplifier.SimplifyConstraintExprs(*this); + did_simplify |= simplifier.SimplifyConstraintRanges(*this); if (!did_simplify) { break; } @@ -1276,12 +1332,7 @@ bool IndexingMap::Simplify() { } // Simplify dependent constraints. constraints_were_simplified |= MergeModConstraints(); - // Simplify affine_map using the optimized ranges. - // Potentially, we can be smarter about recreating the range_evaluator. - RangeEvaluator range_evaluator(GetDimensionBounds(), GetSymbolBounds(), - GetMLIRContext()); - AffineMap simplified_affine_map = - AffineExprSimplifier(&range_evaluator).Simplify(affine_map_); + AffineMap simplified_affine_map = simplifier.Simplify(affine_map_); bool affine_map_was_simplified = simplified_affine_map != affine_map_; if (affine_map_was_simplified) { affine_map_ = simplified_affine_map; @@ -1290,19 +1341,16 @@ bool IndexingMap::Simplify() { rtvars_were_eliminated; } -bool IndexingMap::SimplifyConstraintExprs() { +bool AffineExprSimplifier::SimplifyConstraintExprs(IndexingMap& map) { // Simplify affine expression in the constraints_. - RangeEvaluator range_evaluator(GetDimensionBounds(), GetSymbolBounds(), - GetMLIRContext()); - AffineExprSimplifier simplifier(&range_evaluator); std::vector to_remove; std::vector> to_add; - for (const auto& [expr, range] : constraints_) { - AffineExpr simplified = simplifier.Simplify(expr); + for (const auto& [expr, range] : map.GetConstraints()) { + AffineExpr simplified = Simplify(expr); // Skip constraints that are always satisfied. Interval evaluated_range = - range_evaluator.ComputeExpressionRange(simplified); + range_evaluator_->ComputeExpressionRange(simplified); if (evaluated_range.upper <= range.upper && evaluated_range.lower >= range.lower) { to_remove.push_back(expr); @@ -1313,18 +1361,18 @@ bool IndexingMap::SimplifyConstraintExprs() { to_remove.push_back(expr); } for (const auto& expr : to_remove) { - constraints_.erase(expr); + map.EraseConstraint(expr); } for (const auto& [expr, range] : to_add) { - AddConstraint(expr, range); + map.AddConstraint(expr, range); } return !to_add.empty(); } -bool IndexingMap::SimplifyConstraintRanges() { +bool AffineExprSimplifier::SimplifyConstraintRanges(IndexingMap& map) { std::vector to_remove; std::vector> to_add; - for (const auto& [expr, range] : constraints_) { + for (const auto& [expr, range] : map.GetConstraints()) { AffineExpr simplified_expr = expr; Interval simplified_range = range; if (SimplifyConstraintRange(&simplified_expr, &simplified_range)) { @@ -1333,14 +1381,47 @@ bool IndexingMap::SimplifyConstraintRanges() { } } for (const auto& expr : to_remove) { - constraints_.erase(expr); + map.EraseConstraint(expr); } for (const auto& [expr, range] : to_add) { - AddConstraint(expr, range); + map.AddConstraint(expr, range); } return !to_add.empty(); } +std::tuple AffineExprSimplifier::SplitSumByGcd( + AffineExpr sum) { + std::optional multiplier_gcd = std::nullopt; + AffineExpr zero = getAffineConstantExpr(0, sum.getContext()); + AffineExpr no_multiplier = zero; + VisitSummands(sum, [&](AffineExpr expr) { + if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { + if (multiplier_gcd.has_value()) { + multiplier_gcd = std::gcd(*multiplier_gcd, *multiplier); + } else { + multiplier_gcd = *multiplier; + } + } + }); + + // If nothing had a multiplier, or the GCD was 1, there's nothing to split. + if (multiplier_gcd.value_or(1) == 1) { + return {zero, 1, sum}; + } + + auto scaled = MapSummands(sum, [&](AffineExpr expr) { + if (auto multiplier = GetConstantRhs(expr, AffineExprKind::Mul)) { + // Rescale the multiplier. + return GetLhs(expr) * (*multiplier / *multiplier_gcd); + } + // Extract the summand. + no_multiplier = no_multiplier + expr; + return zero; + }); + + return {scaled, *multiplier_gcd, no_multiplier}; +} + namespace { struct UsedParameters { diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index e35bff974af890..81a2e6d7edb720 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -57,6 +57,9 @@ struct Interval { return value >= lower && value <= upper; } + // Returns true if this interval contains the entire other interval. + bool Contains(Interval other) const { return Intersect(other) == other; } + // The result of a range comparison. We wrap std::optional in a struct to // avoid accidental implicit conversion to bool: // if (range < 42) { @@ -114,6 +117,11 @@ struct Interval { // Computes the range of the product of the two intervals. Implements // saturating semantics. Interval operator*(const Interval& rhs) const; + // Computes the range of the difference of the two intervals. Implements + // saturating semantics. + Interval operator-(const Interval& rhs) const { return *this + (-rhs); } + Interval operator-() const; + Interval FloorDiv(int64_t rhs) const; Interval min(const Interval& rhs) const { return {std::min(lower, rhs.lower), std::min(upper, rhs.upper)}; @@ -331,6 +339,7 @@ class IndexingMap { // ranges. void AddConstraint(mlir::AffineExpr expr, Interval range); void ClearConstraints() { constraints_.clear(); } + void EraseConstraint(mlir::AffineExpr expr); // Evaluates the constraints at a given point and returns `true` if all // constraints are satisfied. @@ -385,14 +394,6 @@ class IndexingMap { private: IndexingMap() = default; - // Performs AffineExpr simplification for all constraints. - // Returns true if simplification was performed. - bool SimplifyConstraintExprs(); - - // Performs range simplification for all constraints. - // Returns true if simplification was performed. - bool SimplifyConstraintRanges(); - // Merges "mod" constraints for the same AffineExpr. // Returns true if simplification was performed. bool MergeModConstraints(); diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 403d34bba7b5fb..d252974afa1ebd 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -368,6 +368,7 @@ TEST_F(IndexingMapTest, {50, 60}, {70, 20}); indexing_map.AddConstraint(ParseAffineExpr("s1 floordiv 20", &mlir_context_), Interval{2, 2}); + EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } @@ -422,7 +423,7 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { indexing_map.AddConstraint(ParseAffineExpr("(d0 mod 8) + 5", &mlir_context_), Interval{50, 54}); - + EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: @@ -431,6 +432,53 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { )")); } +TEST_F(IndexingMapTest, + ConstraintIntervalSimplification_Sum_IndependentOfSymbol) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1)", &mlir_context_), + {2000}, {2, 3}); + + indexing_map.AddConstraint( + ParseAffineExpr("d0 * 6 + s0 * 3 + s1", &mlir_context_), + Interval{0, 599}); + EXPECT_TRUE(indexing_map.Simplify()); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1) + domain: + d0 in [0, 100) + s0 in [0, 2) + s1 in [0, 3) + )")); +} + +TEST_F(IndexingMapTest, + ConstraintIntervalSimplification_Sum_NotIndependentOfSymbol) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1)", &mlir_context_), + {2000}, {2, 3}); + + indexing_map.AddConstraint( + ParseAffineExpr("d0 * 6 + s0 * 3 + s1", &mlir_context_), + Interval{0, 598}); + EXPECT_FALSE(indexing_map.Simplify()); +} + +TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0)[s0] -> (d0 * 6 + s0 * 3)", &mlir_context_), {2000}, + {2}); + + indexing_map.AddConstraint(ParseAffineExpr("d0 * 6 + s0 * 3", &mlir_context_), + Interval{0, 599}); + EXPECT_TRUE(indexing_map.Simplify()); + EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + (d0)[s0] -> (d0 * 6 + s0 * 3) + domain: + d0 in [0, 100) + s0 in [0, 2) + )")); +} + TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorPositiveBounds) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( @@ -438,6 +486,7 @@ TEST_F(IndexingMapTest, indexing_map.AddConstraint(ParseAffineExpr("d0 floordiv 8", &mlir_context_), Interval{5, 11}); + EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: @@ -453,6 +502,7 @@ TEST_F(IndexingMapTest, indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv 3", &mlir_context_), Interval{-11, -5}); + EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: @@ -469,6 +519,7 @@ TEST_F(IndexingMapTest, indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv -3", &mlir_context_), Interval{-11, -5}); + EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: @@ -484,6 +535,7 @@ TEST_F(IndexingMapTest, indexing_map.AddConstraint(ParseAffineExpr("d0 * 8", &mlir_context_), Interval{14, 33}); + EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: @@ -499,6 +551,7 @@ TEST_F(IndexingMapTest, indexing_map.AddConstraint(ParseAffineExpr("s0 * 3", &mlir_context_), Interval{-11, -5}); + EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: @@ -515,6 +568,7 @@ TEST_F(IndexingMapTest, indexing_map.AddConstraint(ParseAffineExpr("s0 * -3", &mlir_context_), Interval{-11, -5}); + EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: @@ -783,7 +837,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1, s2, s3] -> ( - (s0 * 458752 + s2 * 4 + s3 * 512) mod 20000 + s1 + ((s0 * 114688 + s3 * 128 + s2) mod 5000) * 4 + s1 ) domain: s0 in [0, 872)