Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 80 additions & 93 deletions xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ constexpr double kMaxCostEpsilon = 1.0001;
// same amount.
constexpr double kMemoryMultiplier = 1e-6;

// Any memory terms below this threshold will be dropped (to reduce MIP size).
// The threshold is relative: it is multiplied by the request's memory budget
// before being compared against an individual memory cost.
constexpr double kTinyTermThreshold = 1e-6;

// Any memory segments differing by less than this amount are skipped (reduces
// MIP size). The threshold is relative: it is multiplied by the request's
// memory budget before being compared against the segment-to-segment memory
// difference.
constexpr double kSimilarityThreshold = 1e-2;

bool AutoShardingSolverResult::operator==(
const AutoShardingSolverResult& other) const {
return status == other.status &&
Expand Down Expand Up @@ -216,24 +210,6 @@ AutoShardingSolverRequest ScaleRequest(
return scaled_request;
}

// Returns how much one liveness segment differs from the previous one: the
// sum, over every index that is live in exactly one of `live_prev` and
// `live_curr`, of the largest cost listed for that index in `c`.
//
// An index whose cost list is empty contributes zero (dereferencing the result
// of std::max_element on an empty range would be undefined behavior).
//
// `request` is not consulted; it is retained so existing call sites keep
// compiling unchanged.
double MemoryDifference(
    const AutoShardingSolverRequest& request,
    const tsl::protobuf::RepeatedPtrField<AutoShardingSolverRequest_Costs>& c,
    const absl::flat_hash_set<int64_t>& live_prev,
    const absl::flat_hash_set<int64_t>& live_curr) {
  (void)request;
  // Largest cost across all entries for `idx`, or 0 if there are none.
  const auto max_cost = [&c](int64_t idx) -> double {
    const auto& costs = c.at(idx).costs();
    if (costs.empty()) return 0.0;
    return *std::max_element(costs.begin(), costs.end());
  };
  double memory_diff = 0.0;  // How much this segment differs from the last.
  // Walk the symmetric difference directly instead of materializing the union
  // of both sets: first the indices that stopped being live...
  for (int64_t idx : live_prev) {
    if (!live_curr.contains(idx)) memory_diff += max_cost(idx);
  }
  // ...then the indices that newly became live.
  for (int64_t idx : live_curr) {
    if (!live_prev.contains(idx)) memory_diff += max_cost(idx);
  }
  return memory_diff;
}

// Taking an auto-sharding problem (`request`) as an input, calls the OR tools
// CP-SAT solver and outputs a solution to the input problem.
//
Expand Down Expand Up @@ -486,75 +462,6 @@ AutoShardingSolverResult CallORToolsSolver(
}
// c.
if (request.memory_budget() > 0) {
int tiny_term_count = 0;
int segment_similarity_skips = 0;
absl::flat_hash_set<int64_t> live_nodes_prev, live_edges_prev;
for (LivenessIdx time_idx = 0; time_idx < request.live_size(); ++time_idx) {
// Decide whether this segment is similar enough to be skipped.
absl::flat_hash_set<int64_t> live_nodes_curr, live_edges_curr;
const auto& live_nodes = request.live(time_idx).nodes();
live_nodes_curr.insert(live_nodes.begin(), live_nodes.end());
double memory_diff = MemoryDifference(request, request.memory_costs(),
live_nodes_prev, live_nodes_curr);
if (!request.live_edges().empty() && request.enable_memory_edge_costs()) {
const auto& live_edges = request.live_edges(time_idx).edges();
live_edges_curr.insert(live_edges.begin(), live_edges.end());
memory_diff += MemoryDifference(request, request.memory_edge_costs(),
live_edges_prev, live_edges_curr);
}
if (memory_diff < kSimilarityThreshold * request.memory_budget()) {
++segment_similarity_skips;
continue;
}
live_nodes_prev = live_nodes_curr;
live_edges_prev = live_edges_curr;
MPConstraint* constraint =
solver->MakeRowConstraint(-MPSolver::infinity(), MPSolver::infinity(),
absl::StrCat("mem[", time_idx, "]"));
if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0);
double tiny_term_total = 0.0; // Used to trim the memory budget downward.
for (NodeIdx node_idx : request.live(time_idx).nodes()) {
double tiny_term_max = 0.0;
for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) {
double memory_cost = request.memory_costs(node_idx).costs(j);
if (memory_cost < kTinyTermThreshold * request.memory_budget()) {
tiny_term_max = std::max(tiny_term_max, memory_cost);
if (memory_cost > 0.0) ++tiny_term_count;
continue;
}
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(s[node_idx][j]);
constraint->SetCoefficient(s[node_idx][j],
accumulated_coefficient + memory_cost);
}
tiny_term_total += tiny_term_max;
}
if (!request.live_edges().empty() && request.enable_memory_edge_costs()) {
for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) {
double tiny_term_max = 0.0;
for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) {
double memory_cost = request.memory_edge_costs(edge_idx).costs(j);
if (memory_cost < kTinyTermThreshold * request.memory_budget()) {
tiny_term_max = std::max(tiny_term_max, memory_cost);
if (memory_cost > 0.0) ++tiny_term_count;
continue;
}
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(e[edge_idx][j]);
constraint->SetCoefficient(e[edge_idx][j],
accumulated_coefficient + memory_cost);
}
tiny_term_total += tiny_term_max;
}
}
constraint->SetUB(kMemoryMultiplier *
(request.memory_budget() - tiny_term_total));
}
LOG(INFO) << "Number of tiny terms: " << tiny_term_count;
LOG(INFO) << "Skipped " << segment_similarity_skips << " segments out of "
<< request.live().size() << " due to similarity";
if (overbudget_var) {
solver->MutableObjective()->SetCoefficient(
overbudget_var,
Expand Down Expand Up @@ -789,6 +696,73 @@ std::vector<EdgeStrategyIdx> GetChosenEdgeStrategy(
return chosen_edge_strategy;
}

// Scans every liveness timestep of `request` and returns the index of the one
// whose memory usage (under the currently chosen node/edge strategies) exceeds
// the memory budget by the largest amount. Returns -1 when no timestep is
// strictly over budget.
//
// `s` and `e` are forwarded to GetChosenNodeStrategy / GetChosenEdgeStrategy
// to obtain the chosen strategy index for each node and edge.
LivenessIdx FindPeakLiveness(const AutoShardingSolverRequest& request,
                             const std::vector<std::vector<MPVariable*>>& s,
                             const std::vector<std::vector<MPVariable*>>& e) {
  const std::vector<NodeStrategyIdx> node_choices =
      GetChosenNodeStrategy(request, s);
  const std::vector<EdgeStrategyIdx> edge_choices =
      GetChosenEdgeStrategy(request, e);
  // Edge memory only counts when live-edge data exists and the feature is on.
  const bool include_edges =
      !request.live_edges().empty() && request.enable_memory_edge_costs();

  LivenessIdx worst_time_idx = -1;
  double worst_excess = 0.0;  // Only strictly positive excesses qualify.
  for (LivenessIdx t = 0; t < request.live_size(); ++t) {
    double usage = 0.0;
    for (NodeIdx node_idx : request.live(t).nodes()) {
      usage += request.memory_costs(node_idx).costs(node_choices[node_idx]);
    }
    if (include_edges) {
      for (EdgeIdx edge_idx : request.live_edges(t).edges()) {
        usage +=
            request.memory_edge_costs(edge_idx).costs(edge_choices[edge_idx]);
      }
    }
    const double excess = usage - request.memory_budget();
    if (excess > worst_excess) {
      worst_excess = excess;
      worst_time_idx = t;
    }
  }
  return worst_time_idx;
}

// Imposes a new memory constraint at the given location.
void ImposeMemoryConstraint(const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
const std::vector<std::vector<MPVariable*>>& e,
const MPVariable* overbudget_var,
const MPVariable* makespan_var, MPSolver& solver,
LivenessIdx time_idx) {
MPConstraint* constraint =
solver.MakeRowConstraint(-MPSolver::infinity(), MPSolver::infinity(),
absl::StrCat("mem[", time_idx, "]"));
if (overbudget_var) constraint->SetCoefficient(overbudget_var, -1.0);
for (NodeIdx node_idx : request.live(time_idx).nodes()) {
for (NodeStrategyIdx j = 0; j < s[node_idx].size(); ++j) {
double memory_cost = request.memory_costs(node_idx).costs(j);
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(s[node_idx][j]);
constraint->SetCoefficient(s[node_idx][j],
accumulated_coefficient + memory_cost);
}
}
if (!request.live_edges().empty() && request.enable_memory_edge_costs()) {
for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) {
for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) {
double memory_cost = request.memory_edge_costs(edge_idx).costs(j);
memory_cost *= kMemoryMultiplier;
const double accumulated_coefficient =
constraint->GetCoefficient(e[edge_idx][j]);
constraint->SetCoefficient(e[edge_idx][j],
accumulated_coefficient + memory_cost);
}
}
}
constraint->SetUB(kMemoryMultiplier * request.memory_budget());
}

AutoShardingSolverResult SolveAndExtractSolution(
const AutoShardingSolverRequest& request,
const std::vector<std::vector<MPVariable*>>& s,
Expand All @@ -797,6 +771,19 @@ AutoShardingSolverResult SolveAndExtractSolution(
MPSolver& solver) {
absl::Time start_time = absl::Now();
auto status = solver.Solve();
if (request.memory_budget() > 0) {
absl::flat_hash_set<LivenessIdx> peak_times;
while (status == operations_research::MPSolver::OPTIMAL) {
const LivenessIdx peak_time_idx = FindPeakLiveness(request, s, e);
if (peak_time_idx == -1 || peak_times.contains(peak_time_idx)) break;
peak_times.insert(peak_time_idx);
ImposeMemoryConstraint(request, s, e, overbudget_var, makespan_var,
solver, peak_time_idx);
status = solver.Solve();
}
LOG(INFO) << "Imposed " << peak_times.size()
<< " memory constraints out of " << request.live_size();
}
absl::Time end_time = absl::Now();
auto duration = end_time - start_time;
LOG(INFO) << "Solver took " << absl::ToInt64Milliseconds(duration) << " ms";
Expand Down