Skip to content

Commit

Permalink
Add test for broadcast of constant.
Browse files Browse the repository at this point in the history
Add another test case for the MLIR loop emitter.

PiperOrigin-RevId: 644654520
  • Loading branch information
akuegel authored and tensorflower-gardener committed Jun 19, 2024
1 parent b9359c0 commit e68ffab
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,42 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) {
)"));
}

TEST_F(MlirLoopFusionTest, Constant_Broadcast) {
auto kHloString = R"(
HloModule module
bcast {
zero = bf16[] constant(0)
ROOT broadcast = bf16[2,16,48]{2,1,0} broadcast(zero), dimensions={}
}
ENTRY entry {
ROOT %fusion = bf16[2,16,48]{2,1,0} fusion(), kind=kLoop, calls=bcast
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1 * 1024 + d0)>
// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (((d1 * 4 + d0 floordiv 256) floordiv 3) mod 2)>
// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1) -> (((d1 * 64 + d0 floordiv 16) floordiv 3) mod 16)>
// CHECK: #[[MAP3:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)>
// CHECK: func.func @fused_computation(%[[ARG0:.*]]: tensor<2x16x48xbf16>
// CHECK: %[[UPPER_BOUND:.*]] = arith.constant 1535 : index
// CHECK: %[[THREAD_ID:.*]] = gpu.thread_id
// CHECK: %[[BLOCK_ID:.*]] = gpu.block_id
// CHECK: %[[LINEAR:.*]] = xla_gpu.apply_indexing #[[MAP0]]
// CHECL: %[[IN_BOUNDS:.*]] = arith.cmpi sle, %[[LINEAR]], %[[UPPER_BOUND]] : index
// scf.if %[[IN_BOUNDS]]
// CHECK: %[[I0:.*]] = xla_gpu.apply_indexing #[[MAP1]]
// CHECK: %[[I1:.*]] = xla_gpu.apply_indexing #[[MAP2]]
// CHECK: %[[I2:.*]] = xla_gpu.apply_indexing #[[MAP3]]
// CHECK: %[[BCAST:.*]] = xla_gpu.pure_call @bcast_broadcast
// CHECK: %[[INSERTED:.*]] = tensor.insert %[[BCAST]] into %[[ARG0]][%[[I0]], %[[I1]], %[[I2]]]
// CHECK: func.func private @bcast_broadcast
// CHECK: arith.constant 0.000000e+00
)"));
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0}));
}

TEST_F(MlirLoopFusionTest, NoCodeDuplication) {
// This test HLO is copied from
// xla/service/fusion_node_indexing_evaluation_test.cc.
Expand Down

0 comments on commit e68ffab

Please sign in to comment.