Simplify more constraints and mods.
The additional mod simplification is a byproduct of refactoring the
sum-splitting logic; I can factor it out into a separate CL if desired.

- enables additional vectorization when there is a redundant constraint on
  the loop symbol.
- fixes some inconsistencies in div/mod simplification (should lead to better
  code in some cases; see the loop emitter tests and the sketch below).
- replaces an ad-hoc simplification in fusion_emitter.cc with a more general
  one.

Also remove most of the change detector tests in reduction_test.

PiperOrigin-RevId: 651294205
jreiffers authored and tensorflower-gardener committed Jul 11, 2024
1 parent 19ad21f commit 0c08427
Showing 13 changed files with 273 additions and 469 deletions.
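
As background for the div/mod changes below, here is a minimal standalone sketch (not XLA code; the loop bounds are made up and only sample the domain) of the rewrite the simplifier performs when every term of a sum under a mod shares a common factor that also divides the modulus:

#include <cassert>
#include <cstdint>

// Checks that (4 * y) mod 300 == (y mod 75) * 4 on the concrete expression
// from the loop emitter tests: th_x * 4 + bl_x * 512 + chunk_id * 516096
//   == 4 * (th_x + bl_x * 128 + chunk_id * 129024), and 300 == 4 * 75.
int main() {
  for (int64_t th_x = 0; th_x < 128; ++th_x) {
    for (int64_t bl_x = 0; bl_x < 1008; bl_x += 7) {  // sampled, not exhaustive
      for (int64_t chunk_id = 0; chunk_id < 12; ++chunk_id) {
        int64_t old_form = (th_x * 4 + bl_x * 512 + chunk_id * 516096) % 300;
        int64_t new_form = ((bl_x * 128 + chunk_id * 129024 + th_x) % 75) * 4;
        assert(old_form == new_form);
      }
    }
  }
  return 0;
}

The unroll_id term stays outside the mod in both forms, so it is unaffected by the rewrite.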
@@ -114,6 +114,7 @@ absl::Status MlirConcatenateFusion::EmitEntryFunction(
auto thread_id_to_output_map = ComposeIndexingMaps(
ComposeIndexingMaps(thread_id_to_input_map, input_to_output_map),
epilogue_indexing);
+  thread_id_to_output_map.Simplify();

auto loop_nest_body_builder =
[&, operand_index = operand_index](
8 changes: 1 addition & 7 deletions third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc
@@ -188,13 +188,7 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap(
mlir::AffineMap::get(/*dimCount=*/6,
/*symbolCount=*/2, output_dims, ctx),
dim_vars, range_vars, /*rt_vars=*/{});
-  // Remove the unroll_elem_id symbol if unrolling divides num_elements.
-  if (num_elements % unroll_factor == 0) {
-    indexing_map.AddConstraint(linear_index.replace({{unroll_elem_id, c0}}),
-                               Interval{0, num_elements - unroll_factor});
-  } else {
-    indexing_map.AddConstraint(linear_index, Interval{0, num_elements - 1});
-  }
+  indexing_map.AddConstraint(linear_index, Interval{0, num_elements - 1});
indexing_map.Simplify();
return indexing_map;
}
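
For context on why the removed branch is no longer needed, a sketch under the assumption that linear_index has the form base * unroll_factor + unroll_elem_id with unroll_elem_id in [0, unroll_factor), which appears to be how the map above builds it (base is shorthand here, not a variable in the code). When unroll_factor divides num_elements:

  base * unroll_factor <= num_elements - unroll_factor
    =>  base * unroll_factor + unroll_elem_id <= num_elements - 1     (unroll_elem_id <= unroll_factor - 1)
  base * unroll_factor + unroll_elem_id <= num_elements - 1
    =>  base * unroll_factor <= num_elements - 1                      (unroll_elem_id >= 0)
    =>  base * unroll_factor <= num_elements - unroll_factor          (the left side is a multiple of unroll_factor, and unroll_factor divides num_elements)

So the unconditional constraint admits exactly the same points as the old special case, and Simplify() can derive the tighter bound on its own.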
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
@@ -56,7 +56,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) {
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000) mod 100,
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
-      (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id
+      ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
)
domain:
th_x in [0, 128)
@@ -67,7 +67,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) {
bl_z in [0, 1)
chunk_id in [0, 12)
unroll_id in [0, 4)
-      (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999997)
+      th_x + bl_x * 128 + chunk_id * 129024 in [0, 1500000)
)"));
}
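
For the record, the old and new expectations above describe the same indexing; the new form just factors the common multiplier 4 out of the mod and divides the constraint through by it:

  th_x * 4 + bl_x * 512 + chunk_id * 516096  =  4 * (th_x + bl_x * 128 + chunk_id * 129024)
  (4 * y) mod 300  =  (y mod 75) * 4                          (since 300 = 4 * 75 and y >= 0)
  4 * y in [0, 5999997)  <=>  y in [0, 1500000)               (4 * y <= 5999996  <=>  y <= 1499999)

The identical change appears in the non-MLIR loop emitter test below.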

4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/loop_test.cc
@@ -90,7 +90,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) {
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000) mod 100,
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
-      (th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id
+      ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
)
domain:
th_x in [0, 128)
@@ -101,7 +101,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) {
bl_z in [0, 1)
chunk_id in [0, 12)
unroll_id in [0, 4)
-      (th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999997)
+      th_x + bl_x * 128 + chunk_id * 129024 in [0, 1500000)
)"));
}

2 changes: 2 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/reduction.cc
@@ -1295,6 +1295,7 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
}();

AddGroupIdConstraint(map, root_index, groups_);
+  map.Simplify();
return map;
}

@@ -1321,6 +1322,7 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToInputIndexing(
GetBitcastMap(tiling_.GetXlaShape(),
hero.operand(hero_operand_index)->shape(), ctx));
AddGroupIdConstraint(map, root_index, groups_);
+  map.Simplify();
return map;
}

22 changes: 14 additions & 8 deletions third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
@@ -564,10 +564,13 @@ std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToInputIndexing(
.indexing_maps[hero_operand_index]
.begin());
}
-  auto map = ComputeReductionInputIndexing(ctx);
-  AddGroupIdConstraint(map, root_index, groups_);
-  return map * GetBitcastMap(input_shape_,
-                             hero.operand(hero_operand_index)->shape(), ctx);
+  auto projected_map = ComputeReductionInputIndexing(ctx);
+  AddGroupIdConstraint(projected_map, root_index, groups_);
+  auto map = projected_map *
+             GetBitcastMap(input_shape_,
+                           hero.operand(hero_operand_index)->shape(), ctx);
+  map.Simplify();
+  return map;
}

std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToOutputIndexing(
@@ -578,6 +581,7 @@ std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToOutputIndexing(
GetBitcastMap(input_shape_, analysis_.fusion_root(root_index).shape(),
ctx));
AddGroupIdConstraint(map, root_index, groups_);
+  map.Simplify();
return map;
}

@@ -594,10 +598,12 @@ std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToOutputIndexing(
const auto& hero = analysis_.fusion_hero(root_index).instruction();
auto physical_shape =
ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape());
-  return projected_indexing *
-         GetBitcastMap(ShapeUtil::MakeShapeWithDescendingLayout(
-                           PrimitiveType::U8, output_shape),
-                       physical_shape, ctx);
+  auto map = projected_indexing *
+             GetBitcastMap(ShapeUtil::MakeShapeWithDescendingLayout(
+                               PrimitiveType::U8, output_shape),
+                           physical_shape, ctx);
+  map.Simplify();
+  return map;
}

SmallVector<Value> MlirReductionFusion::EvaluateEpilogue(
10 changes: 5 additions & 5 deletions third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
@@ -670,7 +670,7 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) {
(d0, d1, d2, d3, d4, d5)[s0, s1] -> (
d3 floordiv 11,
d0 floordiv 32 + s0 * 32,
-        (d3 mod 11) * 32 + d0 mod 32 + s1
+        (d3 mod 11) * 32 + d0 mod 32
)
domain:
d0 in [0, 1024)
@@ -681,24 +681,24 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) {
d5 in [0, 1)
s0 in [0, 33)
s1 in [0, 1)
-      (d3 mod 11) * 32 + d0 mod 32 + s1 in [0, 321)
+      (d3 mod 11) * 32 + d0 mod 32 in [0, 321)
d0 floordiv 32 + s0 * 32 in [0, 1051)
)"));
EXPECT_THAT(
fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0] -> (
-        d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 + s0
+        d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32
)
domain:
-      d0 in [0, 1024)
+      d0 in [0, 993)
d1 in [0, 1)
d2 in [0, 1)
d3 in [0, 143)
d4 in [0, 1)
d5 in [0, 1)
s0 in [0, 1)
-      (d3 mod 11) * 32 + d0 floordiv 32 + s0 in [0, 321)
+      (d3 mod 11) * 32 + d0 floordiv 32 in [0, 321)
d0 mod 32 in [0, 1)
)"));
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
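
Worked out from the printed domains, the expectation changes in this test follow from two facts:

  s1 in [0, 1)  and  s0 in [0, 1)   =>   s1 = s0 = 0, so the trailing "+ s1" / "+ s0" terms drop out
  d0 mod 32 in [0, 1)  and  d0 in [0, 1024)   =>   d0 is a multiple of 32 and at most 992, i.e. d0 in [0, 993)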