diff --git a/xla/hlo/analysis/while_loop_analysis.cc b/xla/hlo/analysis/while_loop_analysis.cc index 6e69f2f277ad9..ccccda556676f 100644 --- a/xla/hlo/analysis/while_loop_analysis.cc +++ b/xla/hlo/analysis/while_loop_analysis.cc @@ -177,7 +177,7 @@ static std::optional GetUniqueGTEDependenceIndex( [](const HloInstruction* inst) -> ReplaceType { return ReplaceType::kReplaceParam; }, - /*cross_computation=*/false, /*inline_calls_and_fusions=*/false, + /*cross_computation=*/false, /*inline_calls_and_fusions=*/true, /*run_verifier=*/false); HloComputation* entry = extracted->entry_computation(); diff --git a/xla/hlo/analysis/while_loop_analysis_test.cc b/xla/hlo/analysis/while_loop_analysis_test.cc index ab69ff36512a6..7539acd531395 100644 --- a/xla/hlo/analysis/while_loop_analysis_test.cc +++ b/xla/hlo/analysis/while_loop_analysis_test.cc @@ -953,5 +953,46 @@ TEST_F(WhileLoopAnalysisTest, EXPECT_EQ(trip_count, std::nullopt); } +TEST_F(WhileLoopAnalysisTest, GetIndvarIndexShouldWorkWhenParamIsCopied) { + const char* hlo = R"( + HloModule test + + fused_copy { + param.1 = (s32[],s32[]) parameter(0) + ROOT copy = (s32[], s32[]) copy(param.1) + } + + body { + param.1 = (s32[], s32[]) parameter(0) + copy_fusion = (s32[], s32[]) fusion(param.1), kind=kInput, calls=fused_copy + iter.1 = s32[] get-tuple-element(copy_fusion), index=0 + c.1 = s32[] constant(1) + add.1 = s32[] add(iter.1, c.1) + data.1 = s32[] get-tuple-element(copy_fusion), index=1 + ROOT tuple = (s32[], s32[]) tuple(add.1, data.1) + } + + condition { + param = (s32[], s32[]) parameter(0) + iter = s32[] get-tuple-element(param), index=0 + c.10 = s32[] constant(10) + ROOT compare = pred[] compare(iter, c.10), direction=LT + } + + ENTRY main { + c0 = s32[] constant(0) + data = s32[] parameter(0) + tuple = (s32[], s32[]) tuple(c0, data) + ROOT while = (s32[], s32[]) while(tuple), body=body, condition=condition + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + HloInstruction* while_op = m->entry_computation()->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + EXPECT_EQ(GetLoopInductionVarTupleIdx(while_op), 0); +} + } // namespace } // namespace xla diff --git a/xla/tools/hlo_extractor.cc b/xla/tools/hlo_extractor.cc index d0a8a34f14494..bbed5c97fa89d 100644 --- a/xla/tools/hlo_extractor.cc +++ b/xla/tools/hlo_extractor.cc @@ -329,10 +329,20 @@ absl::Status Inline(HloModule* module) { for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( - instruction, HloInstruction::CreateCall( - instruction->shape(), instruction->operands(), - instruction->fused_instructions_computation()))); + HloInstruction* new_instruction = + computation->AddInstruction(HloInstruction::CreateCall( + /*shape=*/instruction->shape(), + /*operands=*/instruction->operands(), + /*computation=*/ + instruction->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation + ->ReplaceInstruction( + /*old_instruction=*/instruction, + /*new_instruction=*/new_instruction, + /*preserve_sharding=*/false, + /*relay_control_dependency=*/true, + /*remove_unused_operands=*/true) + .status()); } } }