Support XPU for dygraph auto-parallel #70997

Merged
@@ -79,6 +79,10 @@ void SubMeshToGlobalReshardFunction::Eval(phi::DeviceContext* dev_ctx,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
VLOG(3) << "Call SubMeshToGlobalReshardFunction Eval";
#if defined(PADDLE_WITH_XPU)
PADDLE_THROW(::common::errors::Unimplemented(
"Not supported PSendKernel/PRecv on xpu yet."));
#else
const TensorDistAttr& in_dist_attr = in.dist_attr();
const ProcessMesh& in_process_mesh = in_dist_attr.process_mesh();
const ProcessMesh& out_process_mesh = out_dist_attr.process_mesh();
@@ -132,6 +136,7 @@ void SubMeshToGlobalReshardFunction::Eval(phi::DeviceContext* dev_ctx,
GetMutableTensor(out));
}
SetDistProps(out, in.dims(), out_dist_attr);
#endif
}

} // namespace phi::distributed
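The same guard recurs in every reshard function this PR touches: XPU builds raise Unimplemented for collectives that have no BKCL-backed kernel yet, while all other builds keep the existing communication path. A minimal sketch of the pattern (illustrative only; `SomeCollective` and its argument list are placeholders, not code from this diff):

#if defined(PADDLE_WITH_XPU)
  // XPU builds fail loudly until the collective gains a BKCL implementation.
  PADDLE_THROW(::common::errors::Unimplemented(
      "Not supported SomeCollective on xpu yet."));
#else
  // Other builds dispatch through the reshard macros as before.
  RESHARD_FUNCTOR_WITH_COMM(
      dev_ctx, SomeCollective, dtype, process_ids, in_tensor, &out_tensor);
#endif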
@@ -67,14 +67,18 @@ void ReshardPToSWithPadding(DeviceContext* dev_ctx,
}

DenseTensor out_reduce_scatter;
#if defined(PADDLE_WITH_XPU)
PADDLE_THROW(::common::errors::Unimplemented(
"Not supported Reducescatter on xpu yet."));
#else
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
ReduceScatter,
dtype,
process_ids,
in_reduce_scatter,
static_cast<int64_t>(process_ids.size()),
&out_reduce_scatter);

#endif
DenseTensor out_result;
if (split_axis != 0) {
RESHARD_FUNCTOR(
@@ -60,7 +60,6 @@ void RToXExpandReshardFunction::Eval(phi::DeviceContext* dev_ctx,
int64_t cur_global_rank = GetCurGlobalRank();
int64_t root_rank = in_process_ids[0];
auto all_process_ids = GetUnionProcessIds(in_process_ids, out_process_ids);
bool dynamic_shape = true;
auto dtype = in.dtype();
const auto& out_partial_status = out_dist_attr.partial_status();
bool cur_rank_in_out_mesh =
@@ -72,27 +71,37 @@ void RToXExpandReshardFunction::Eval(phi::DeviceContext* dev_ctx,
if (root_rank == cur_global_rank) {
for (const auto& out_process_id : out_process_ids) {
if (out_process_id != root_rank) {
#if defined(PADDLE_WITH_XPU)
PADDLE_THROW(::common::errors::Unimplemented(
"Not supported PSendKernel on xpu yet."));
#else
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
PSendKernel,
dtype,
all_process_ids,
in.value(),
out_process_id,
dynamic_shape);
/*dynamic_shape=*/true);
#endif
}
}
if (cur_rank_in_out_mesh) {
result_value = in.value();
}
} else {
#if defined(PADDLE_WITH_XPU)
PADDLE_THROW(
::common::errors::Unimplemented("Not supported PRecv on xpu yet."));
#else
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
PRecv,
dtype,
all_process_ids,
root_rank,
{} /*out_shape*/,
dynamic_shape,
/*dynamic_shape=*/true,
&result_value);
#endif
}

if (cur_rank_in_out_mesh) {
44 changes: 30 additions & 14 deletions paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc
@@ -109,25 +109,41 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
PADDLE_THROW(common::errors::Unimplemented(
"Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
#endif
} else if (phi::CustomContext::classof(&dev_ctx)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
CommContextManager::CreateXCCLCommContext(
store, unique_comm_key, dev_ctx.GetPlace(), rank, world_size);
#endif
} else {
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
else if (phi::GPUContext::classof(&dev_ctx)) { // NOLINT
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (phi::GPUContext::classof(&dev_ctx)) {
CommContextManager::CreateNCCLCommContext(store,
unique_comm_key,
static_cast<int>(rank),
static_cast<int>(world_size));
}
CommContextManager::CreateNCCLCommContext(store,
unique_comm_key,
static_cast<int>(rank),
static_cast<int>(world_size));
#else
PADDLE_THROW(common::errors::Unimplemented(
"CommContext is only supported on CPU and GPU for now, other devices "
"will be supported later."));
"Cannot use nccl on GPU, please turn WITH_NCCL flag on."));
#endif
}
#elif defined(PADDLE_WITH_XPU)
else if (phi::XPUContext::classof(&dev_ctx)) { // NOLINT
#if defined(PADDLE_WITH_XPU_BKCL)
CommContextManager::CreateBKCLCommContext(store,
unique_comm_key,
static_cast<int>(rank),
static_cast<int>(world_size));
#else
PADDLE_THROW(common::errors::Unimplemented(
"Cannot use xpu on GPU, please turn WITH_XPU_BKCL flag on."));
#endif
}
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
else if (phi::CustomContext::classof(&dev_ctx)) { // NOLINT
CommContextManager::CreateXCCLCommContext(
store, unique_comm_key, dev_ctx.GetPlace(), rank, world_size);
}
#endif
else { // NOLINT
PADDLE_THROW(common::errors::Unimplemented(
"CommContext is only supported CPU, GPU, XPU, and CustomDevice."));
}
}

auto* comm_context = CommContextManager::GetInstance().Get(unique_comm_key);
120 changes: 79 additions & 41 deletions paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h
@@ -79,27 +79,63 @@ phi::DDim InferShapeForReshardFromReplicate(
const TensorDistAttr& dist_attr);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU."; \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_CPU( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else if (phi::GPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on GPU."; \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_GPU( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const GPUContext&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else { \
PADDLE_THROW(common::errors::Unimplemented( \
"The %s in reshard only supported on CPU and GPU for now.", \
#fn_name)); \
} \
#define DEVICE_CONTEXT GPUContext
#elif defined(PADDLE_WITH_XPU)
#define DEVICE_CONTEXT XPUContext
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
#define DEVICE_CONTEXT CustomContext
#endif

// Some reshard functions support fewer data types on XPU than on GPU. For
// example, `Transpose`, `Split`, and `Divide` do not support the double type.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define PD_VISIT_RESHARD_TYPES PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES
#else
#define PD_VISIT_RESHARD_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT16, paddle::float16, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE_BFLOAT16(NAME, __VA_ARGS__) \
default: \
PD_THROW("Reshard function " #NAME \
" is not implemented" \
" for data type `", \
__dtype__, \
"`"); \
} \
}()
#endif
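For reference, the reduced visitor is invoked the same way as the full GPU one; `RESHARD_FUNCTOR_IMPL` below does exactly this with `fn_name<data_t>`. A minimal usage sketch, assuming a hypothetical helper `ReshardCastHelper` templated on the element type:

PD_VISIT_RESHARD_TYPES(
    in.dtype(), "ReshardCastHelper", ([&] {
      // `data_t` is bound by the visitor to int, int64_t, float, float16,
      // or bfloat16; double is deliberately excluded on XPU.
      ReshardCastHelper<data_t>(static_cast<const XPUContext&>(*dev_ctx),
                                in, &out);
    }));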

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU)
#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU."; \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_CPU( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else if (DEVICE_CONTEXT::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on device."; \
PD_VISIT_RESHARD_TYPES( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const DEVICE_CONTEXT&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else { \
PADDLE_THROW(common::errors::Unimplemented( \
"The %s in reshard only supported on CPU, GPU, and XPU for now.", \
#fn_name)); \
} \
} while (0)
#else
#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
@@ -130,35 +166,37 @@ phi::DDim InferShapeForReshardFromReplicate(
RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, __VA_ARGS__); \
} while (0)

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define RESHARD_FUNCTOR_WITHOUT_DTYPE(dev_ctx, fn_name, ...) \
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name \
<< "`without DType in Resharding on CPU."; \
fn_name(static_cast<const CPUContext&>(*dev_ctx), __VA_ARGS__); \
} else if (phi::GPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name \
<< "`without DType in Resharding on GPU."; \
fn_name(static_cast<const GPUContext&>(*dev_ctx), __VA_ARGS__); \
} else { \
PADDLE_THROW(common::errors::Unimplemented( \
"The %s in reshard only supported on CPU and GPU for now.", \
#fn_name)); \
} \
} while (0)
#else
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU)
#define RESHARD_FUNCTOR_WITHOUT_DTYPE(dev_ctx, fn_name, ...) \
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name \
<< "`without DType in Resharding on CPU."; \
<< "`without DType in Resharding on CPU."; \
fn_name(static_cast<const CPUContext&>(*dev_ctx), __VA_ARGS__); \
} else if (DEVICE_CONTEXT::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name \
<< "`without DType in Resharding on device."; \
fn_name(static_cast<const DEVICE_CONTEXT&>(*dev_ctx), __VA_ARGS__); \
} else { \
PADDLE_THROW(common::errors::Unimplemented( \
"The %s in reshard only supported on CPU for now.", #fn_name)); \
"The %s in reshard only supported CPU, GPU, and XPU Device", \
#fn_name)); \
} \
} while (0)
#else
#define RESHARD_FUNCTOR_WITHOUT_DTYPE(dev_ctx, fn_name, ...) \
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name \
<< "`without DType in Resharding on CPU."; \
fn_name(static_cast<const CPUContext&>(*dev_ctx), __VA_ARGS__); \
} else { \
PADDLE_THROW(common::errors::Unimplemented( \
"The %s in reshard only supported CPU, GPU, and XPU Device.", \
#fn_name)); \
} \
} while (0)
#endif

#define RESHARD_SHORTCUT_IF_FALSE(expr) \
@@ -42,8 +42,13 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx,
// For balanced split to replicate, we need to do all gather first.
// If the input value doesn't split on axis 0, we need to split
// and concat on specific axis.
#if defined(PADDLE_WITH_XPU)
PADDLE_THROW(
::common::errors::Unimplemented("Not supported AllGather on xpu yet."));
#else
RESHARD_FUNCTOR_WITH_COMM(
dev_ctx, AllGather, dtype, process_ids, in, num_of_process, out);
#endif

if (split_axis != 0 || padding_nums != 0) {
IntArray sections(std::vector<int64_t>(num_of_process, in.dims()[0]));
@@ -97,12 +97,17 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
}

// 2. use all to all to switch data to other ranks
#if defined(PADDLE_WITH_XPU)
PADDLE_THROW(
::common::errors::Unimplemented("Not supported AllToAll on xpu yet."));
#else
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
AllToAll,
dtype,
in_process_ids,
in_all_to_all,
GetMutableTensor(out));
#endif

// 3. postprocess, reshape and transpose the output tensor
if (in_split_axis != 0) {
@@ -55,14 +55,6 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
const auto& out_process_mesh = out_dist_attr.process_mesh();
const auto& out_process_ids = out_process_mesh.process_ids();
auto all_process_ids = GetUnionProcessIds(in_process_ids, out_process_ids);
auto dtype = in.dtype();
// TODO(liyurui): Use dynamic shape will lead to poor performance, but we
// don't have any other good idea now. For the following reasons:
// 1. We can not ensure the meta being right deduce by the infermeta.
// 2. The meta of some kernels can't decide in compile time.
// 3. DenseTensor with empty value only need infermeta and skip the real
// kernel execution.
bool dynamic_shape = true;

// TODO(GhostScreaming): After cross-mesh reshard, current device may
// needs to execute next layer. When it construct next layer's backward
@@ -86,28 +78,44 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
int64_t src = iter.first;
int64_t dst = iter.second;
if (src == cur_global_rank) {
#if defined(PADDLE_WITH_XPU)
PADDLE_THROW(::common::errors::Unimplemented(
"Not supported PSendKernel on xpu yet."));
#else
VLOG(3) << "Send from src " << src << " to dst " << dst;
int64_t dst_local_rank = GetLocalRankInParticipate(all_process_ids, dst);
// Since the send kernel only has inputs, infermeta is not needed here;
// just use the kernel directly.
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
PSendKernel,
dtype,
in.dtype(),
all_process_ids,
in.value(),
dst_local_rank,
dynamic_shape);
/*dynamic_shape=*/true);
// TODO(liyurui): Use dynamic shape will lead to poor performance, but we
// don't have any other good idea now. For the following reasons:
// 1. We can not ensure the meta being right deduce by the infermeta.
// 2. The meta of some kernels can't decide in compile time.
// 3. DenseTensor with empty value only need infermeta and skip the real
// kernel execution.
#endif
} else if (dst == cur_global_rank) {
#if defined(PADDLE_WITH_XPU)
PADDLE_THROW(::common::errors::Unimplemented(
"Not supported PRecvKernel on xpu yet."));
#else
VLOG(3) << "Recv from src " << src << " to dst " << dst;
int64_t src_local_rank = GetLocalRankInParticipate(all_process_ids, src);
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
PRecv,
dtype,
in.dtype(),
all_process_ids,
src_local_rank,
{} /*out_shape*/,
dynamic_shape,
/*dynamic_shape=*/true,
GetMutableTensor(out));
#endif
}
}
SetDistProps(out, in.dims(), out_dist_attr);