Skip to content

Commit

Permalink
Remove dead code from HloRematerialization and fix comment.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653503548
  • Loading branch information
SandSnip3r authored and tensorflower-gardener committed Jul 18, 2024
1 parent f391d84 commit 22556d3
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions third_party/xla/xla/service/hlo_rematerialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,18 +179,10 @@ class HloRematerialization : public HloModulePass {
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

protected:
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
// backend. Rematerialized instructions will be added to the HLO computation
// and inserted into 'order'.
absl::StatusOr<bool> RematerializeComputation(HloComputation* computation,
HloSchedule* schedule,
int64_t memory_limit_bytes,
int64_t min_remat_size) {
return RematerializeComputation(computation, schedule, memory_limit_bytes,
min_remat_size, /*execution_threads=*/{});
}

// Rematerializes instructions within the given computation. 'schedule'
// constains the order in which the computation's instructions will be emitted
// in the backend. Rematerialized instructions will be added to the HLO
// computation and inserted into 'schedule'.
virtual absl::StatusOr<bool> RematerializeComputation(
HloComputation* computation, HloSchedule* schedule,
int64_t memory_limit_bytes, int64_t min_remat_size,
Expand Down

0 comments on commit 22556d3

Please sign in to comment.