Skip to content

Commit

Permalink
PR #6601: Add Comparison with False/True cases to algebraic_simplifie…
Browse files Browse the repository at this point in the history
…r HandleCompare

Imported from GitHub PR #6601

Added Comparison of PRED param with False/True constants cases to algebraic_simplifier HandleCompare

Cases:
```
IF A is PRED:

  A != false -> A        (1)
  false != A -> A        (2)
  A == true -> A         (3)
  true == A -> A         (4)
```

I used example code from `TEST_F(AlgebraicSimplifierTest, AndTrue)` for the tests

@majnemer Can you review it?
Copybara import of the project:

--
6d8ee9d by Alexander Pivovarov <[email protected]>:

Add Comparison with False/True cases to algebraic_simplifier HandleCompare

Merging this change closes #6601

COPYBARA_INTEGRATE_REVIEW=#6601 from apivovarov:compare_false_true 6d8ee9d
PiperOrigin-RevId: 578793421
  • Loading branch information
apivovarov authored and copybara-github committed Nov 2, 2023
1 parent d6ec2c8 commit ca31652
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
22 changes: 22 additions & 0 deletions xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4605,6 +4605,28 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
}
}
if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
if (compare->comparison_direction() == ComparisonDirection::kNe) {
// A != false -> A
if (IsAll(rhs, false)) {
return ReplaceInstruction(compare, lhs);
}
// false != A -> A
if (IsAll(lhs, false)) {
return ReplaceInstruction(compare, rhs);
}
} else if (compare->comparison_direction() == ComparisonDirection::kEq) {
// A == true -> A
if (IsAll(rhs, true)) {
return ReplaceInstruction(compare, lhs);
}
// true == A -> A
if (IsAll(lhs, true)) {
return ReplaceInstruction(compare, rhs);
}
}
}
return OkStatus();
}

Expand Down
84 changes: 84 additions & 0 deletions xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8009,6 +8009,90 @@ TEST_F(AlgebraicSimplifierTest, CompareSimplifiedReversed) {
.WithComparisonDirection(ComparisonDirection::kLt)));
}

// Test that A != False is simplified to A
TEST_F(AlgebraicSimplifierTest, NeFalse) {
auto m = CreateNewVerifiedModule();
Shape r0pred = ShapeUtil::MakeShape(PRED, {});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0pred, "param0"));
HloInstruction* const_false = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateCompare(
r0pred, param0, const_false, ComparisonDirection::kNe));

auto computation = m->AddEntryComputationWithLayouts(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kCompare);
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).value());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}

// Test that False != A is simplified to A
TEST_F(AlgebraicSimplifierTest, NeFalse2) {
auto m = CreateNewVerifiedModule();
Shape r0pred = ShapeUtil::MakeShape(PRED, {});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0pred, "param0"));
HloInstruction* const_false = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateCompare(
r0pred, const_false, param0, ComparisonDirection::kNe));

auto computation = m->AddEntryComputationWithLayouts(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kCompare);
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).value());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}

// Test that A == True is simplified to A
TEST_F(AlgebraicSimplifierTest, EqTrue) {
auto m = CreateNewVerifiedModule();
Shape r0pred = ShapeUtil::MakeShape(PRED, {});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0pred, "param0"));
HloInstruction* const_true = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
builder.AddInstruction(HloInstruction::CreateCompare(
r0pred, param0, const_true, ComparisonDirection::kEq));

auto computation = m->AddEntryComputationWithLayouts(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kCompare);
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).value());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}

// Test that True == A is simplified to A
TEST_F(AlgebraicSimplifierTest, EqTrue2) {
auto m = CreateNewVerifiedModule();
Shape r0pred = ShapeUtil::MakeShape(PRED, {});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0pred, "param0"));
HloInstruction* const_true = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
builder.AddInstruction(HloInstruction::CreateCompare(
r0pred, const_true, param0, ComparisonDirection::kEq));

auto computation = m->AddEntryComputationWithLayouts(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kCompare);
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).value());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}

TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
// Some backends may have better performance by treating an outer product as a
// Dot, rather than a broadcast Multiply
Expand Down

0 comments on commit ca31652

Please sign in to comment.