Skip to content

Commit

Permalink
Allow propagations on reduce to occur
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674646492
  • Loading branch information
amitsabne1 authored and Google-ML-Automation committed Sep 14, 2024
1 parent 72eec79 commit dedab4f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
9 changes: 5 additions & 4 deletions xla/service/space_to_batch_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1747,11 +1747,12 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];
// Support the trivial case where none of the batch and split spatial dim
// are being reduced.
return !absl::c_linear_search(reduce_dims, batch_dim) &&
!absl::c_linear_search(reduce_dims, space_dim);
if (!absl::c_linear_search(reduce_dims, batch_dim) &&
!absl::c_linear_search(reduce_dims, space_dim)) {
return true;
}

// Support only the trivial case where both batch and split spatial dim are
// being reduced
// If both batch and space dim are being reduced, propagate.
return absl::c_linear_search(reduce_dims, batch_dim) &&
absl::c_linear_search(reduce_dims, space_dim);
}
Expand Down
33 changes: 33 additions & 0 deletions xla/service/space_to_batch_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,5 +385,38 @@ TEST_F(SpaceToBatchConverterTest, ReduceDegenerateDim) {
EXPECT_THAT(root->operand(0), op::Slice());
}

TEST_F(SpaceToBatchConverterTest, PropagateOnReduce) {
std::string hlo_string = R"(
HloModule xla_computation_unknown.14
region_0.134 {
Arg_0.135 = f32[] parameter(0)
Arg_1.136 = f32[] parameter(1)
ROOT add.137 = f32[] add(Arg_0.135, Arg_1.136)
}
ENTRY main.140 {
p0 = bf16[1,512,32,128]{3,2,1,0} parameter(0)
p1 = f32[3,3,128,128]{3,2,1,0} parameter(1)
%convolution.755 = f32[1,512,32,128]{3,2,1,0}
convolution(p0, p1),
window={size=3x3 pad=1_1x1_1 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
%constant.19458 = f32[] constant(0)
ROOT %reduce.1354 = f32[128]{0} reduce(%convolution.755, %constant.19458),
dimensions={0,1,2}, to_apply=%region_0.134
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));

auto computation = module->entry_computation();
SpaceToBatchConverter converter(
SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8});
ASSERT_TRUE(converter.Run(module.get()).value());

HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Reduce());
}

} // namespace
} // namespace xla

0 comments on commit dedab4f

Please sign in to comment.