Skip to content

Commit

Permalink
Use RunAndCheckHloRewrite in collective_permute_cycle_decomposer_test.cc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685025195
  • Loading branch information
toli-y authored and tensorflower-gardener committed Oct 12, 2024
1 parent 40da335 commit b833d6f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 91 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ xla_cc_test(
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:test_utils",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@local_tsl//tsl/platform:statusor",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ class CollectivePermuteCycleDecomposer : public HloModulePass {
return "collective-permute-cycle-decomposer";
}

using HloPassInterface::Run;
// Runs CollectivePermuteCycleDecomposer pass on computations in 'module'.
// Returns whether the 'module' was changed.
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,27 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/test_utils.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {

using ::testing::HasSubstr;
using CollectivePermuteCycleDecomposerTest = HloTestBase;
using Decomposer = CollectivePermuteCycleDecomposer;

TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
const absl::string_view kModuleStr = R"(
// Returns the HloPrintOptions shared by the FileCheck-based tests in this
// file: printing of operand shapes and inclusion of layouts in shapes are both
// switched off before the module is rendered with ToString().
HloPrintOptions PrintOptions() {
  HloPrintOptions print_options;
  // The two flags are independent of each other; both are disabled.
  print_options.set_include_layout_in_shapes(false);
  print_options.set_print_operand_shape(false);
  return print_options;
}

TEST_F(CollectivePermuteCycleDecomposerTest, NoCycle_NotTransformed) {
absl::string_view kHlo = R"(
HloModule test
ENTRY test_computation {
p = u32[8,8] parameter(0)
Expand All @@ -47,17 +54,14 @@ TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule((kModuleStr)));
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_FALSE(changed);
TF_ASSERT_OK(RunAndCheckHloRewrite(kHlo, Decomposer(0), false));
}

TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
TEST_F(CollectivePermuteCycleDecomposerTest, HonorsThreshold) {
// When `size of data` >= `threshold`, the collective-permute is decomposed;
// otherwise it stays as it is.
const absl::string_view kModuleStr = R"(
// u32[4,2] = 4*4*2 = 32 bytes
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[4,2] parameter(0)
Expand All @@ -66,16 +70,9 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule((kModuleStr)));
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/33);
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
RunHloPass(CollectivePermuteCycleDecomposer(33), module.get()));
EXPECT_FALSE(changed);
TF_ASSERT_OK_AND_ASSIGN(
changed, RunHloPass(CollectivePermuteCycleDecomposer(16), module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(33), false));
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(32), true));
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(16), true));
}

TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
Expand All @@ -84,7 +81,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
// 2. They should split over the value of partition-id.
// 3. The metadata and frontend_attributes are propagated to split
// collectives.
const absl::string_view kModuleStr = R"(
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[8,8] parameter(0)
Expand All @@ -94,30 +91,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule((kModuleStr)));
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);

TF_CHECK_OK(VerifyHloModule(module.get(), false, true));
HloPrintOptions options;
options.set_print_operand_shape(false);
options.set_include_layout_in_shapes(false);
EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"(
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
// CHECK-DAG: %[[partition_id:.+]] = u32[] partition-id()
// CHECK-DAG: %[[c0:.+]] = u32[] constant(0)
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[partition_id]], %[[c0]]), direction=EQ
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1,
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1,
// CHECK-SAME{LITERAL}: source_target_pairs={{3,0}}, frontend_attributes={_xla_send_recv_validation={{3,10}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2,
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2,
// CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}}, frontend_attributes={_xla_send_recv_validation={{0,7},{1,8},{2,9}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
// CHECK-DAG: }
)"));
Expand All @@ -127,7 +115,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
// For a forward cycle, this checks:
// 1. Split collectives should not have channel-id
// 2. Split collectives are combined based on replica-id.
const absl::string_view kModuleStr = R"(
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[8,8] parameter(0)
Expand All @@ -136,17 +124,9 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule((kModuleStr)));
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
TF_CHECK_OK(VerifyHloModule(module.get(), false, true));

HloPrintOptions options;
options.set_print_operand_shape(false);
options.set_include_layout_in_shapes(false);
EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"(
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
// CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id()
// CHECK-DAG: %[[c0:.+]] = u32[] constant(0)
Expand All @@ -155,17 +135,17 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
// CHECK-SAME{LITERAL}: {{3,0}}
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
// CHECK-SAME{LITERAL}: {{0,1},{1,2},{2,3}}
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
// CHECK-DAG: }
)"));
}

TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
const absl::string_view kModuleStr = R"(
absl::string_view hlo = R"(
HloModule test
while_cond {
Expand Down Expand Up @@ -198,11 +178,8 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body
ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule((kModuleStr)));
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
HloCollectivePermuteInstruction* cp1 =
DynCast<HloCollectivePermuteInstruction>(
FindInstruction(module.get(), "cp.backward"));
Expand All @@ -222,7 +199,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
// 1. Metadata is propagated to split collectives.
// 2. Frontend attributes are accurately split.
// 3. The split collectives have channel IDs.
const absl::string_view kModuleStr = R"(
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[8,8] parameter(0)
Expand All @@ -232,29 +209,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule((kModuleStr)));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
RunHloPass(CollectivePermuteCycleDecomposer(0), module.get()));
EXPECT_TRUE(changed);
TF_CHECK_OK(VerifyHloModule(module.get(), true, false));
HloPrintOptions options;
options.set_print_operand_shape(false);
options.set_include_layout_in_shapes(false);
EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"(
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
// CHECK-DAG: %[[partition:.+]] = u32[] partition-id()
// CHECK-DAG: %[[three:.+]] = u32[] constant(3)
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[partition]], %[[three]]), direction=EQ
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1, source_target_pairs=
// CHECK-SAME{LITERAL}: {{0,3}}, frontend_attributes={_xla_send_recv_validation={{0,7}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2, source_target_pairs=
// CHECK-SAME{LITERAL}: {{1,0},{2,1},{3,2}}, frontend_attributes={_xla_send_recv_validation={{1,8},{2,9},{3,10}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
// CHECK-DAG: ROOT %{{.+}} = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
// CHECK-DAG: }
)"));
Expand All @@ -264,7 +233,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycleNoChannel) {
// For backward cycle, this checks:
// 1. Split collectives do not have a channel-id
// 2. Split collectives are combined based on the value of replica-id.
const absl::string_view kModuleStr = R"(
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[8,8] parameter(0)
Expand All @@ -273,28 +242,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycleNoChannel) {
frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule((kModuleStr)));
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
HloPrintOptions options;
options.set_print_operand_shape(false);
options.set_include_layout_in_shapes(false);
TF_CHECK_OK(VerifyHloModule(module.get(), false, true));
EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"(
TF_ASSERT_OK_AND_ASSIGN(auto module,
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
// CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id()
// CHECK-DAG: %[[three:.+]] = u32[] constant(3)
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[replica_id]], %[[three]]), direction=EQ
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
// CHECK-SAME{LITERAL}: {{0,3}}, frontend_attributes={_xla_send_recv_validation={{0,7}}}
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
// CHECK-SAME{LITERAL}: {{1,0},{2,1},{3,2}}, frontend_attributes={_xla_send_recv_validation={{1,8},{2,9},{3,10}}}
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
// CHECK-DAG: }
)"));
Expand Down

0 comments on commit b833d6f

Please sign in to comment.