diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 2a09737f66c62..27868ee689510 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -70,12 +70,6 @@ constexpr double kMaxCostEpsilon = 1.0001; // same amount. constexpr double kMemoryMultiplier = 1e-6; -// Any memory terms below this threshold will be dropped (to reduce MIP size). -constexpr double kTinyTermThreshold = 1e-6; - -// Any memory segments differing by this amount are skipped (reduces MIP size). -constexpr double kSimilarityThreshold = 1e-2; - bool AutoShardingSolverResult::operator==( const AutoShardingSolverResult& other) const { return status == other.status && @@ -216,24 +210,6 @@ AutoShardingSolverRequest ScaleRequest( return scaled_request; } -double MemoryDifference( - const AutoShardingSolverRequest& request, - const tsl::protobuf::RepeatedPtrField& c, - const absl::flat_hash_set& live_prev, - const absl::flat_hash_set& live_curr) { - double memory_diff = 0.0; // How much this segment differs from the last. - absl::flat_hash_set live_union; - live_union.insert(live_prev.begin(), live_prev.end()); - live_union.insert(live_curr.begin(), live_curr.end()); - for (int64_t idx : live_union) { - if (!live_prev.contains(idx) || !live_curr.contains(idx)) { - memory_diff += - *std::max_element(c.at(idx).costs().begin(), c.at(idx).costs().end()); - } - } - return memory_diff; -} - // Taking an auto-sharding problem (`request`) as an input, calls the OR tools // CP-SAT solver and outputs a solution to the input problem. // @@ -486,75 +462,6 @@ AutoShardingSolverResult CallORToolsSolver( } // c. 
if (request.memory_budget() > 0) { - int tiny_term_count = 0; - int segment_similarity_skips = 0; - absl::flat_hash_set live_nodes_prev, live_edges_prev; - for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) { - // Decide whether this segment is similar enough to be skipped. - absl::flat_hash_set live_nodes_curr, live_edges_curr; - const auto& live_nodes = request.live(time_idx).nodes(); - live_nodes_curr.insert(live_nodes.begin(), live_nodes.end()); - double memory_diff = MemoryDifference(request, request.memory_costs(), - live_nodes_prev, live_nodes_curr); - if (!request.live_edges().empty() && request.enable_memory_edge_costs()) { - const auto& live_edges = request.live_edges(time_idx).edges(); - live_edges_curr.insert(live_edges.begin(), live_edges.end()); - memory_diff += MemoryDifference(request, request.memory_edge_costs(), - live_edges_prev, live_edges_curr); - } - if (memory_diff < kSimilarityThreshold * request.memory_budget()) { - ++segment_similarity_skips; - continue; - } - live_nodes_prev = live_nodes_curr; - live_edges_prev = live_edges_curr; - MPConstraint* constraint = - solver->MakeRowConstraint(-MPSolver::infinity(), MPSolver::infinity(), - absl::StrCat("mem[", time_idx, "]")); - if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0); - double tiny_term_total = 0.0; // Used to trim the memory budget downward. 
- for (NodeIdx node_idx : request.live(time_idx).nodes()) { - double tiny_term_max = 0.0; - for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) { - double memory_cost = request.memory_costs(node_idx).costs(j); - if (memory_cost < kTinyTermThreshold * request.memory_budget()) { - tiny_term_max = std::max(tiny_term_max, memory_cost); - if (memory_cost > 0.0) ++tiny_term_count; - continue; - } - memory_cost *= kMemoryMultiplier; - const double accumulated_coefficient = - constraint->GetCoefficient(s[node_idx][j]); - constraint->SetCoefficient(s[node_idx][j], - accumulated_coefficient + memory_cost); - } - tiny_term_total += tiny_term_max; - } - if (!request.live_edges().empty() && request.enable_memory_edge_costs()) { - for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) { - double tiny_term_max = 0.0; - for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - double memory_cost = request.memory_edge_costs(edge_idx).costs(j); - if (memory_cost < kTinyTermThreshold * request.memory_budget()) { - tiny_term_max = std::max(tiny_term_max, memory_cost); - if (memory_cost > 0.0) ++tiny_term_count; - continue; - } - memory_cost *= kMemoryMultiplier; - const double accumulated_coefficient = - constraint->GetCoefficient(e[edge_idx][j]); - constraint->SetCoefficient(e[edge_idx][j], - accumulated_coefficient + memory_cost); - } - tiny_term_total += tiny_term_max; - } - } - constraint->SetUB(kMemoryMultiplier * - (request.memory_budget() - tiny_term_total)); - } - LOG(INFO) << "Number of tiny terms: " << tiny_term_count; - LOG(INFO) << "Skipped " << segment_similarity_skips << " segments out of " - << request.live().size() << " due to similarity"; if (overbudget_var) { solver->MutableObjective()->SetCoefficient( overbudget_var, @@ -789,6 +696,73 @@ std::vector GetChosenEdgeStrategy( return chosen_edge_strategy; } +// Finds the timestep with the largest memory overbudget (-1 if no such value). 
+//
+// For every liveness step, sums the (unscaled) memory cost of the strategy
+// reported by GetChosenNodeStrategy / GetChosenEdgeStrategy for each live
+// node (and each live edge, when edge memory costs are enabled), then
+// compares that total against `request.memory_budget()`.
+//
+// `s` and `e` are indexed as [node_idx][strategy] and [edge_idx][strategy].
+// Returns the step whose usage exceeds the budget by the largest amount; when
+// several steps tie, the earliest one is returned.
+LivenessIdx FindPeakLiveness(const AutoShardingSolverRequest& request,
+                             const std::vector<std::vector<MPVariable*>>& s,
+                             const std::vector<std::vector<MPVariable*>>& e) {
+  const std::vector<NodeStrategyIdx> chosen_node_strategy =
+      GetChosenNodeStrategy(request, s);
+  const std::vector<EdgeStrategyIdx> chosen_edge_strategy =
+      GetChosenEdgeStrategy(request, e);
+  // Edge memory is only counted when the request carries edge liveness data
+  // and has edge memory costs enabled; this does not vary per timestep.
+  const bool use_edge_memory =
+      !request.live_edges().empty() && request.enable_memory_edge_costs();
+  LivenessIdx peak_time_idx = -1;
+  double peak_overbudget = 0.0;
+  for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) {
+    double memory_usage = 0.0;
+    for (NodeIdx node_idx : request.live(time_idx).nodes()) {
+      const NodeStrategyIdx j = chosen_node_strategy[node_idx];
+      memory_usage += request.memory_costs(node_idx).costs(j);
+    }
+    if (use_edge_memory) {
+      // NOTE(review): assumes live_edges() has an entry for every liveness
+      // step whenever it is non-empty -- confirm against the request builder.
+      for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) {
+        const EdgeStrategyIdx j = chosen_edge_strategy[edge_idx];
+        memory_usage += request.memory_edge_costs(edge_idx).costs(j);
+      }
+    }
+    const double overbudget = memory_usage - request.memory_budget();
+    // Strict comparison against a 0.0 starting value: steps at or under
+    // budget never qualify, so -1 is returned when nothing is over budget.
+    if (peak_overbudget < overbudget) {
+      peak_overbudget = overbudget;
+      peak_time_idx = time_idx;
+    }
+  }
+  return peak_time_idx;
+}
+
+// Imposes a new memory constraint at the given location.
+//
+// Adds a row named "mem[<time_idx>]" to `solver`. Each strategy variable of
+// every node live at `time_idx` (and of every live edge, when edge memory
+// costs are enabled) receives its memory cost as coefficient, and the row's
+// upper bound is the memory budget. Both sides are scaled by
+// kMemoryMultiplier. When `overbudget_var` is non-null it enters the row with
+// coefficient -1, i.e. the row reads: scaled usage - overbudget_var <= scaled
+// budget.
+//
+// `s` and `e` are indexed as [node_idx][strategy] and [edge_idx][strategy].
+// `makespan_var` is accepted but not referenced by this function.
+void ImposeMemoryConstraint(const AutoShardingSolverRequest& request,
+                            const std::vector<std::vector<MPVariable*>>& s,
+                            const std::vector<std::vector<MPVariable*>>& e,
+                            const MPVariable* overbudget_var,
+                            const MPVariable* makespan_var, MPSolver& solver,
+                            LivenessIdx time_idx) {
+  // The row starts unbounded on both sides; the upper bound is set at the end.
+  MPConstraint* constraint =
+      solver.MakeRowConstraint(-MPSolver::infinity(), MPSolver::infinity(),
+                               absl::StrCat("mem[", time_idx, "]"));
+  if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0);
+  for (NodeIdx node_idx : request.live(time_idx).nodes()) {
+    for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) {
+      double memory_cost = request.memory_costs(node_idx).costs(j);
+      memory_cost *= kMemoryMultiplier;
+      // Add to whatever coefficient the variable already has in this row
+      // rather than overwriting it.
+      const double accumulated_coefficient =
+          constraint->GetCoefficient(s[node_idx][j]);
+      constraint->SetCoefficient(s[node_idx][j],
+                                 accumulated_coefficient + memory_cost);
+    }
+  }
+  if (!request.live_edges().empty() && request.enable_memory_edge_costs()) {
+    for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) {
+      for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) {
+        double memory_cost = request.memory_edge_costs(edge_idx).costs(j);
+        memory_cost *= kMemoryMultiplier;
+        const double accumulated_coefficient =
+            constraint->GetCoefficient(e[edge_idx][j]);
+        constraint->SetCoefficient(e[edge_idx][j],
+                                   accumulated_coefficient + memory_cost);
+      }
+    }
+  }
+  constraint->SetUB(kMemoryMultiplier * request.memory_budget());
+}
+ AutoShardingSolverResult SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, @@ -797,6 +771,19 @@ AutoShardingSolverResult SolveAndExtractSolution( MPSolver& solver) { absl::Time start_time = absl::Now(); auto status = solver.Solve(); + if (request.memory_budget() > 0) { + absl::flat_hash_set peak_times; + while (status == operations_research::MPSolver::OPTIMAL) { + const LivenessIdx peak_time_idx = FindPeakLiveness(request, s, e); + if (peak_time_idx == -1 || peak_times.contains(peak_time_idx)) break; + peak_times.insert(peak_time_idx); 
+ ImposeMemoryConstraint(request, s, e, overbudget_var, makespan_var, + solver, peak_time_idx); + status = solver.Solve(); + } + LOG(INFO) << "Imposed " << peak_times.size() + << " memory constraints out of " << request.live_size(); + } absl::Time end_time = absl::Now(); auto duration = end_time - start_time; LOG(INFO) << "Solver took " << absl::ToInt64Milliseconds(duration) << " ms";