PR tensorflow#14862: Add SPMD config option to specify zero cost method for gather/scatter.

Imported from GitHub PR openxla/xla#14862

Issue tensorflow#13304

In SPMD handling of gather/scatter, the partitioning strategy is hard-coded to the IndexParallel strategy. This is not optimal for all topologies. This PR makes the partitioning method an SPMD config option, defaulting to IndexParallel to maintain the existing behavior.

Clang-format also fixed some formatting. Tests were added and all tests pass.
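
For reference, a minimal sketch of how a caller might select the zero-cost method through the new options (illustrative only, not part of this commit; it assumes the existing SpmdPartitioner constructor that takes partition/replica counts plus an options struct, and the partition counts shown are placeholders):

  #include "xla/service/spmd/spmd_partitioner.h"

  xla::spmd::SpmdPartitionerOptions options;
  // Method the cost model treats as zero-cost (i.e. prioritizes) for gather.
  options.gather_partition_method =
      xla::spmd::PartitioningMethod::kTrivialSlicedOperand;
  // Method the cost model treats as zero-cost for scatter.
  options.scatter_partition_method =
      xla::spmd::PartitioningMethod::kIndexPassthrough;
  // Both options default to PartitioningMethod::kIndexParallel.
  xla::spmd::SpmdPartitioner partitioner(/*num_partitions=*/8,
                                         /*num_replicas=*/1, options);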
Copybara import of the project:

--
7f83c21573f24cd4e314b13ce2e349dd6194b451 by ptoulme-aws <[email protected]>:

Add SPMD config option to specify zero cost method for gather/scatter.

Merging this change closes tensorflow#14862

PiperOrigin-RevId: 652736743
ptoulme-aws authored and tensorflower-gardener committed Jul 16, 2024
1 parent 15c227b commit 5c3785b
Showing 3 changed files with 284 additions and 79 deletions.
194 changes: 117 additions & 77 deletions third_party/xla/xla/service/spmd/gather_scatter_handler.cc
@@ -46,10 +46,11 @@ limitations under the License.

namespace xla {
namespace spmd {

namespace {

using hlo_sharding_util::GroupedSharding;
PartitioningMethod gather_partition_method = PartitioningMethod::kIndexParallel;
PartitioningMethod scatter_partition_method =
    PartitioningMethod::kIndexParallel;

// Generates per-group partitioned hlo based on given grouped sharding.
PartitionedHlo PerGroupPartitionedHlo(
@@ -723,6 +724,22 @@ GatherPartitionMethods() {
"PartitionGatherIndexPassthroughDimensions"}};
}

// Helper function to get the gather partitioning method.
decltype(PartitionGather)* GetGatherPartitionMethod(PartitioningMethod method) {
  switch (method) {
    case PartitioningMethod::kIndexParallel:
      return PartitionGatherIndexParallelDimensions;
    case PartitioningMethod::kOperandPassthrough:
      return PartitionGatherOperandPassthroughDimensions;
    case PartitioningMethod::kTrivialSlicedOperand:
      return PartitionGatherTrivialSlicedOperandDimensions;
    case PartitioningMethod::kIndexPassthrough:
      return PartitionGatherIndexPassthroughDimensions;
    default:
      return PartitionGatherIndexParallelDimensions;
  }
}

// Estimates the memory and communication cost for each partitioning method for
// gather.
std::pair<int64_t, int64_t> GatherPartitionMethodCostModel(
@@ -731,9 +748,12 @@ std::pair<int64_t, int64_t> GatherPartitionMethodCostModel(
    const PartitionedHlo& indices, const Shape& output_shape,
    const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> slice_sizes, SpmdPartitioningVisitor* visitor) {
  decltype(PartitionGather)* zero_cost_method =
      GetGatherPartitionMethod(gather_partition_method);
  if (partition_method == zero_cost_method) {
    // Always prioritize the user's chosen partitioning, and assume it has zero
    // cost.
    // This defaults to IndexParallel.
    return {0, 0};
  }
  return EvaluatePartitionCost(gather, partition_method, gather, operand,
@@ -838,6 +858,7 @@ absl::Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
batch_dims.push_back(i);
}
}
gather_partition_method = options().gather_partition_method;
TF_ASSIGN_OR_RETURN(
HloInstruction * pgather,
PartitionGather(gather, operand, indices, gather->shape(),
@@ -1292,82 +1313,80 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexPassthroughDimensions(
// results.
return nullptr;
}
  HloInstruction* identity;
  switch (*reduction_opcode) {
    case HloOpcode::kAdd:
    case HloOpcode::kOr:
      identity = CreateZero(per_group_operand.hlo()->shape(), b);
      break;
    case HloOpcode::kMultiply:
    case HloOpcode::kAnd:
      identity = CreateOne(per_group_operand.hlo()->shape(), b);
      break;
    case HloOpcode::kMinimum:
      identity = CreateConstant(
          per_group_operand.hlo()->shape(),
          LiteralUtil::MaxValue(scatter->shape().element_type()), b);
      break;
    case HloOpcode::kMaximum:
      identity = CreateConstant(
          per_group_operand.hlo()->shape(),
          LiteralUtil::MinValue(scatter->shape().element_type()), b);
      break;
    default:
      return nullptr;
  }
  // Update partition_id for partial replicate.
  auto partition_id = indices.state().partition_id;
  if (indices.sharding().ReplicateOnLastTileDim()) {
    auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
        indices.sharding(),
        {indices.sharding().tile_assignment().num_dimensions() - 1});
    auto per_group_partitioner_state = CreatePerGroupPartitioningState(
        indices.state(), sharding_grouped.device_groups, b);
    partition_id = per_group_partitioner_state.partition_id;
  }
  // To avoid accumulating the initial operand multiple times during
  // all-reduce, we use identity operands for all non-zero partitions.
  auto not_partition_zero = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::MakeScalarShape(PRED), partition_id));
  not_partition_zero = b->AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::ChangeElementType(identity->shape(), PRED), not_partition_zero,
      {}));
  auto select_operand =
      b->AddInstruction(HloInstruction::HloInstruction::CreateTernary(
          identity->shape(), HloOpcode::kSelect, not_partition_zero, identity,
          per_group_operand.hlo()));
  PartitionedHlo new_operand =
      per_group_operand.CloneWithNewHlo(select_operand);
  std::vector<PartitionedHlo> per_group_new_operands = {new_operand};
  std::vector<PartitionedHlo> per_group_updates = {
      PerGroupPartitionedHlo(updates[0], update_grouped, b, clean_ups)};
  PartitionedHlo per_group_indices =
      PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups);
  auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * pscatter,
      PartitionScatter(
          scatter, per_group_new_operands, per_group_indices, per_group_updates,
          pshape,
          HloSharding::Single(scatter->shape(), output_grouped.sharding),
          slice_sizes, visitor, allow_recursive));
  // All-reduce along all dims in operand sharding -- this is OK because the
  // operand is not sharded on index_vector_dim.
  std::vector<int64_t> all_dims(indices.rank());
  absl::c_iota(all_dims, 0);
  auto all_reduce = operands[0].state().partitioner->AllReduceAlongShardingDims(
      b, pscatter, original_indices_sharding, indices.state().next_channel_id,
      all_dims, operands[0].state().collective_ops_creator,
      scatter->to_apply());
  all_reduce->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped));
  if (allow_recursive) {
    VLOG(5) << "[Scatter partitioning]: Partitioned as index passthrough";
  }
  return PartitionedHlo(all_reduce, output_shape, operands[0].state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition a Scatter when it's sliced in a dimension in the operand that is
@@ -1487,14 +1506,31 @@ absl::StatusOr<HloInstruction*> PartitionScatterTrivialSlicedOperandDimensions(
// Returns a full list of partitioning methods used for scatter.
std::vector<std::pair<decltype(PartitionScatter)*, absl::string_view>>
ScatterPartitionMethods() {
  return {{PartitionScatterIndexParallelDimensions,
           "PartitionScatterIndexParallelDimensions"},
          {PartitionScatterOperandPassthroughDimensions,
           "PartitionScatterOperandPassthroughDimensions"},
          {PartitionScatterTrivialSlicedOperandDimensions,
           "PartitionScatterTrivialSlicedOperandDimensions"},
          {PartitionScatterIndexPassthroughDimensions,
           "PartitionScatterIndexPassthroughDimensions"}};
}

// Helper function to get the actual scatter partitioning method.
decltype(PartitionScatter)* GetScatterPartitionMethod(
    PartitioningMethod method) {
  switch (method) {
    case PartitioningMethod::kIndexParallel:
      return PartitionScatterIndexParallelDimensions;
    case PartitioningMethod::kOperandPassthrough:
      return PartitionScatterOperandPassthroughDimensions;
    case PartitioningMethod::kTrivialSlicedOperand:
      return PartitionScatterTrivialSlicedOperandDimensions;
    case PartitioningMethod::kIndexPassthrough:
      return PartitionScatterIndexPassthroughDimensions;
    default:
      return PartitionScatterIndexParallelDimensions;
  }
}

// Estimates the memory and communication for each partitioning method for
@@ -1506,7 +1542,10 @@ std::pair<int64_t, int64_t> ScatterPartitionMethodCostModel(
    const std::vector<PartitionedHlo>& updates, const Shape& output_shape,
    const HloSharding& output_sharding, absl::Span<const int64_t> slice_sizes,
    SpmdPartitioningVisitor* visitor) {
  decltype(PartitionScatter)* zero_cost_method =
      GetScatterPartitionMethod(scatter_partition_method);

  if (partition_method == zero_cost_method) {
    // Always prioritize index parallel partitioning, and assume it has zero
    // cost.
    return {0, 0};
@@ -1679,6 +1718,7 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
break;
}
}
scatter_partition_method = options().scatter_partition_method;
std::vector<int64_t> slice_sizes = hlo_sharding_util::GetScatterSliceSize(
operands[0].base_shape(), updates[0].base_shape(), dnums);

16 changes: 16 additions & 0 deletions third_party/xla/xla/service/spmd/spmd_partitioner.h
@@ -52,6 +52,14 @@ limitations under the License.
namespace xla {
namespace spmd {

// Enum representing the partitioning methods for gather and scatter.
enum class PartitioningMethod {
  kIndexParallel,
  kOperandPassthrough,
  kTrivialSlicedOperand,
  kIndexPassthrough,
};

struct SpmdPartitionerOptions {
// Always exchange halo on LHS for all convolutions. If false, backprop filter
// convolution exchanges halo on RHS.
@@ -100,6 +108,14 @@ struct SpmdPartitionerOptions {
  // Whether to disable rewrite for dots that share the same
  // operand as an already rewritten windowed einsum loop.
  bool disable_ag_rewrite_for_multiple_consumers = false;

  // Partitioning method to prioritize for gather operations.
  PartitioningMethod gather_partition_method =
      PartitioningMethod::kIndexParallel;

  // Partitioning method to prioritize for scatter operations.
  PartitioningMethod scatter_partition_method =
      PartitioningMethod::kIndexParallel;
};

// Class to wrap the computation builder to capture information during SPMD
