Skip to content

Commit

Permalink
Preserve the frontend attributes associated with an HLO when partitioning it into a partitioned HLO through the spmd_partitioner pass.
Browse files Browse the repository at this point in the history

As part of this change, we broke down the `SpmdBuilder::AddInstruction` into multiple smaller functions.

PiperOrigin-RevId: 684223463
  • Loading branch information
tensorflower-gardener committed Oct 10, 2024
1 parent ffda817 commit e0e9dca
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 112 deletions.
268 changes: 156 additions & 112 deletions third_party/xla/xla/service/spmd/spmd_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,135 +246,179 @@ HloInstruction* SpmdBuilder::AddInstruction(
HloInstruction* hlo =
HloComputation::Builder::AddInstruction(std::move(instruction));
if (visiting_hlo_) {
hlo->set_metadata(visiting_hlo_->metadata());
std::shared_ptr<const HloSharding> prev_sharding = hlo->sharding_ptr();
visiting_hlo_->SetupDerivedInstruction(hlo);
if (prev_sharding != nullptr) {
hlo->set_sharding(*prev_sharding);
} else {
hlo->clear_sharding();
}
instructions_[visiting_hlo_].push_back(hlo);
}
if (hlo->opcode() == HloOpcode::kBroadcast) {
for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
if (!absl::c_linear_search(hlo->dimensions(), i)) {
broadcast_dims_[hlo].insert(i);
SetBroadcastDimsForAddedHlo(*hlo);
return hlo;
}

void SpmdBuilder::SetBroadcastDimsForAddedHlo(const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kBroadcast) {
for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
if (!absl::c_linear_search(hlo.dimensions(), i)) {
broadcast_dims_[&hlo].insert(i);
}
}
}
if (hlo->IsElementwise() && hlo->operand_count() > 0 &&
if (hlo.IsElementwise() && hlo.operand_count() > 0 &&
// Copy can have a tuple result.
hlo->shape().IsArray()) {
absl::flat_hash_set<int64_t> broadcast_dims;
for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
broadcast_dims.insert(i);
hlo.shape().IsArray()) {
SetBroadcastDimsForElementwise(hlo);
}
if (hlo.opcode() == HloOpcode::kTranspose) {
SetBroadcastDimsForTranspose(hlo);
}
if (hlo.opcode() == HloOpcode::kReshape &&
Product(hlo.shape().dimensions()) > 0) {
SetBroadcastDimsForReshape(hlo);
}
if (hlo.opcode() == HloOpcode::kSlice ||
hlo.opcode() == HloOpcode::kDynamicSlice) {
SetBroadcastDimsForSlice(hlo);
}
if (hlo.opcode() == HloOpcode::kPad) {
SetBroadcastDimsForPad(hlo);
}
}

// Propagates the operand's broadcast dims through a kReshape by walking the
// before/after dimension sizes as two stacks and matching dims that split
// (before % after == 0) or merge (after % before == 0). Output dims start as
// all-broadcast and are erased whenever the matched input dim is not
// broadcast; if the dim sizes cannot be matched, all remaining output dims
// are marked non-broadcast.
// NOTE(review): this span was reconstructed from a diff view that interleaved
// removed pre-refactor lines; the stale `hlo->` code was dropped.
void SpmdBuilder::SetBroadcastDimsForReshape(const HloInstruction& hlo) {
  CHECK(hlo.opcode() == HloOpcode::kReshape);

  auto it = broadcast_dims_.find(hlo.operand(0));
  if (it == broadcast_dims_.end()) {
    return;
  }
  // Start with every output dim assumed broadcast; erase as we disprove.
  std::vector<int64_t> iota_dims(hlo.shape().rank());
  absl::c_iota(iota_dims, 0);
  absl::flat_hash_set<int64_t> reshape_broadcast_dims(iota_dims.begin(),
                                                      iota_dims.end());

  absl::Span<const int64_t> operand_dims = hlo.operand(0)->shape().dimensions();
  absl::Span<const int64_t> hlo_dims = hlo.shape().dimensions();
  // Stacks hold dim sizes with the first dimension on top (reverse order).
  std::vector<int64_t> before_dim_size_stack(operand_dims.rbegin(),
                                             operand_dims.rend());
  std::vector<int64_t> after_dim_size_stack(hlo_dims.rbegin(), hlo_dims.rend());

  auto erase_reshape_broadcast_dims = [&reshape_broadcast_dims](int64_t from,
                                                                int64_t to) {
    for (int64_t i = from; i < to; ++i) {
      reshape_broadcast_dims.erase(i);
    }
  };

  while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
    int64_t before_size = before_dim_size_stack.back();
    int64_t after_size = after_dim_size_stack.back();
    int64_t current_before_dim =
        hlo.operand(0)->shape().rank() - before_dim_size_stack.size();
    int64_t current_after_dim =
        hlo.shape().rank() - after_dim_size_stack.size();
    before_dim_size_stack.pop_back();
    after_dim_size_stack.pop_back();
    if (!it->second.contains(current_before_dim)) {
      reshape_broadcast_dims.erase(current_after_dim);
    }
    if (before_size == after_size) {
      continue;
    }
    if (before_size % after_size == 0) {
      // Split dim.
      before_dim_size_stack.push_back(before_size / after_size);
    } else if (after_size % before_size == 0) {
      // Merge dim.
      after_dim_size_stack.push_back(after_size / before_size);
    } else {
      // Other cases, mark all remaining dims as non-broadcast.
      erase_reshape_broadcast_dims(current_after_dim, hlo.shape().rank());
      break;
    }
  }
  // Only record a result when both shapes were fully consumed and something
  // survived the intersection.
  bool has_broadcast_dims = !reshape_broadcast_dims.empty() &&
                            before_dim_size_stack.empty() &&
                            after_dim_size_stack.empty();
  if (has_broadcast_dims) {
    broadcast_dims_[&hlo] = std::move(reshape_broadcast_dims);
  }
}
if (hlo->opcode() == HloOpcode::kReshape &&
Product(hlo->shape().dimensions()) > 0) {
auto it = broadcast_dims_.find(hlo->operand(0));
if (it != broadcast_dims_.end()) {
absl::flat_hash_set<int64_t> reshape_broadcast_dims;
for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
reshape_broadcast_dims.insert(i);
}
std::vector<int64_t> before_dim_size_stack;
std::vector<int64_t> after_dim_size_stack;
const int64_t operand0_rank = hlo->operand(0)->shape().rank();
const int64_t hlo_shape_rank = hlo->shape().rank();
before_dim_size_stack.reserve(operand0_rank);
after_dim_size_stack.reserve(hlo_shape_rank);
for (int64_t i = operand0_rank - 1; i >= 0; --i) {
before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i));
}
for (int64_t i = hlo_shape_rank - 1; i >= 0; --i) {
after_dim_size_stack.push_back(hlo->shape().dimensions(i));
}
while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
int64_t before_size = before_dim_size_stack.back();
int64_t after_size = after_dim_size_stack.back();
int64_t current_before_dim =
hlo->operand(0)->shape().rank() - before_dim_size_stack.size();
int64_t current_after_dim =
hlo->shape().rank() - after_dim_size_stack.size();
before_dim_size_stack.pop_back();
after_dim_size_stack.pop_back();
if (!it->second.contains(current_before_dim)) {
reshape_broadcast_dims.erase(current_after_dim);
}
if (before_size == after_size) {
continue;
}
if (before_size % after_size == 0) {
// Split dim.
before_dim_size_stack.push_back(before_size / after_size);
} else if (after_size % before_size == 0) {
// Merge dim.
after_dim_size_stack.push_back(after_size / before_size);
} else {
// Other cases, mark all remaining dims as non-broadcast.
for (int64_t i = current_after_dim; i < hlo->shape().rank(); ++i) {
reshape_broadcast_dims.erase(i);
}
break;
}
}
if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) {
reshape_broadcast_dims.clear();
}
if (!reshape_broadcast_dims.empty()) {
broadcast_dims_[hlo] = std::move(reshape_broadcast_dims);
}
}
}

void SpmdBuilder::SetBroadcastDimsForTranspose(const HloInstruction& hlo) {
CHECK(hlo.opcode() == HloOpcode::kTranspose);
auto it = broadcast_dims_.find(hlo.operand(0));
if (it == broadcast_dims_.end()) {
return;
}
absl::flat_hash_set<int64_t> xpose_broadcast_dims;
std::vector<int64_t> reverse_map(hlo.shape().rank());
for (int64_t i = 0; i < reverse_map.size(); ++i) {
reverse_map[hlo.dimensions(i)] = i;
}
for (int64_t dim : it->second) {
xpose_broadcast_dims.insert(reverse_map[dim]);
}
broadcast_dims_[&hlo] = std::move(xpose_broadcast_dims);
}

// Keeps the operand's broadcast dims across a kPad, but only for dims the pad
// leaves untouched (zero low/high/interior padding); padded dims gain new
// values and can no longer be treated as broadcast.
// NOTE(review): this span was reconstructed from a diff view that interleaved
// removed pre-refactor lines; the stale `hlo->` code was dropped.
void SpmdBuilder::SetBroadcastDimsForPad(const HloInstruction& hlo) {
  CHECK(hlo.opcode() == HloOpcode::kPad);
  auto it = broadcast_dims_.find(hlo.operand(0));
  if (it == broadcast_dims_.end()) {
    return;
  }
  absl::flat_hash_set<int64_t> pad_broadcast_dims;
  for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
    const auto& dim = hlo.padding_config().dimensions(i);
    if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
        dim.interior_padding() == 0 && it->second.contains(i)) {
      pad_broadcast_dims.insert(i);
    }
  }
  if (!pad_broadcast_dims.empty()) {
    broadcast_dims_[&hlo] = std::move(pad_broadcast_dims);
  }
}

// A (dynamic-)slice along a broadcast dim still reads identical values, so
// the operand's broadcast dims carry over to the slice unchanged.
void SpmdBuilder::SetBroadcastDimsForSlice(const HloInstruction& hlo) {
  CHECK(hlo.opcode() == HloOpcode::kSlice ||
        hlo.opcode() == HloOpcode::kDynamicSlice);
  const auto it = broadcast_dims_.find(hlo.operand(0));
  if (it == broadcast_dims_.end()) {
    return;
  }
  // Copy first: inserting via operator[] may rehash and invalidate `it`.
  auto operand_broadcast_dims = it->second;
  broadcast_dims_[&hlo] = std::move(operand_broadcast_dims);
}

// Intersects the recorded broadcast dims across all operands of an
// elementwise op: an output dim is broadcast only if every operand records it
// as broadcast. An operand with no entry at all clears the candidate set.
// Fixes vs. the pasted span: removed a stray `return hlo;` (diff residue in a
// void function) and renamed the inner loop variable, which shadowed the
// outer operand index `i`.
void SpmdBuilder::SetBroadcastDimsForElementwise(const HloInstruction& hlo) {
  CHECK(hlo.IsElementwise());
  if (hlo.operand_count() == 0 || hlo.shape().IsTuple()) {
    return;
  }
  // Start with every output dim as a candidate.
  absl::flat_hash_set<int64_t> broadcast_dims;
  for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
    broadcast_dims.insert(i);
  }
  for (int64_t operand_idx = 0; operand_idx < hlo.operand_count();
       ++operand_idx) {
    auto it = broadcast_dims_.find(hlo.operand(operand_idx));
    if (it == broadcast_dims_.end()) {
      broadcast_dims.clear();
      break;
    }
    for (int64_t dim = 0; dim < hlo.shape().rank(); ++dim) {
      if (!it->second.contains(dim)) {
        broadcast_dims.erase(dim);
      }
    }
  }
  if (!broadcast_dims.empty()) {
    broadcast_dims_[&hlo] = std::move(broadcast_dims);
  }
}

PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target,
Expand Down
12 changes: 12 additions & 0 deletions third_party/xla/xla/service/spmd/spmd_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ class SpmdBuilder : public HloComputation::Builder {
}

private:
// Sets the broadcast dims for the newly added/created hlo.
void SetBroadcastDimsForAddedHlo(const HloInstruction& hlo);

// Propagates the operand's broadcast dims through a kReshape by matching
// split/merged dimension sizes.
void SetBroadcastDimsForReshape(const HloInstruction& hlo);

// Re-indexes the operand's broadcast dims through a kTranspose permutation.
void SetBroadcastDimsForTranspose(const HloInstruction& hlo);

// Keeps operand broadcast dims that a kPad leaves completely unpadded.
void SetBroadcastDimsForPad(const HloInstruction& hlo);

// Forwards the operand's broadcast dims for kSlice / kDynamicSlice.
void SetBroadcastDimsForSlice(const HloInstruction& hlo);

// Intersects broadcast dims across all operands of an elementwise op.
void SetBroadcastDimsForElementwise(const HloInstruction& hlo);
// Currently visiting instruction.
HloInstruction* visiting_hlo_;

Expand Down

0 comments on commit e0e9dca

Please sign in to comment.