From ca31652cdbeb6ea187589dea546ff8019274f8b2 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Thu, 2 Nov 2023 03:23:25 -0700 Subject: [PATCH] PR #6601: Add Comparison with False/True cases to algebraic_simplifier HandleCompare Imported from GitHub PR https://github.com/openxla/xla/pull/6601 Added Comparison of PRED param with False/True constants cases to algebraic_simplifier HandleCompare Cases: ``` IF A is PRED: A != false -> A (1) false != A -> A (2) A == true -> A (3) true == A -> A (4) ``` I used example code from `TEST_F(AlgebraicSimplifierTest, AndTrue)` for the tests @majnemer Can you review it? Copybara import of the project: -- 6d8ee9d51ec2d0f14e05e8e1840b8444c179883e by Alexander Pivovarov : Add Comparison with False/True cases to algebraic_simplifier HandleCompare Merging this change closes #6601 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/6601 from apivovarov:compare_false_true 6d8ee9d51ec2d0f14e05e8e1840b8444c179883e PiperOrigin-RevId: 578793421 --- xla/service/algebraic_simplifier.cc | 22 +++++++ xla/service/algebraic_simplifier_test.cc | 84 ++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index ca341f27b0ca5..829115546a4dc 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -4605,6 +4605,28 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) { return ReplaceInstruction(compare, MakeScalarLike(compare, true)); } } + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + if (compare->comparison_direction() == ComparisonDirection::kNe) { + // A != false -> A + if (IsAll(rhs, false)) { + return ReplaceInstruction(compare, lhs); + } + // false != A -> A + if (IsAll(lhs, false)) { + return ReplaceInstruction(compare, rhs); + } + } else if (compare->comparison_direction() == ComparisonDirection::kEq) { + // A == true -> A + if (IsAll(rhs, true)) { + return ReplaceInstruction(compare, lhs); + } + // true == A -> A + if (IsAll(lhs, true)) { + return ReplaceInstruction(compare, rhs); + } + } + } return OkStatus(); } diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index 4084e650294be..c29c3331fea33 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -8009,6 +8009,90 @@ TEST_F(AlgebraicSimplifierTest, CompareSimplifiedReversed) { .WithComparisonDirection(ComparisonDirection::kLt))); } +// Test that A != False is simplified to A +TEST_F(AlgebraicSimplifierTest, NeFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateCompare( + r0pred, param0, const_false, ComparisonDirection::kNe)); + + auto computation = m->AddEntryComputationWithLayouts(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCompare); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that False != A is simplified to A +TEST_F(AlgebraicSimplifierTest, NeFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateCompare( + r0pred, const_false, param0, ComparisonDirection::kNe)); + + auto computation = m->AddEntryComputationWithLayouts(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCompare); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that A == True is simplified to A +TEST_F(AlgebraicSimplifierTest, EqTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateCompare( + r0pred, param0, const_true, ComparisonDirection::kEq)); + + auto computation = m->AddEntryComputationWithLayouts(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCompare); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that True == A is simplified to A +TEST_F(AlgebraicSimplifierTest, EqTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateCompare( + r0pred, const_true, param0, ComparisonDirection::kEq)); + + auto computation = m->AddEntryComputationWithLayouts(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCompare); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { // Some backends may have better performance by treating an outer product as a // Dot, rather than a broadcast Multiply