
heter for collective #37613

Merged: 6 commits merged into PaddlePaddle:develop on Dec 6, 2021

Conversation

kuizhiqing
Member

PR types

New features

PR changes

Others

Describe

Heterogeneous mixed training refers to training a model with heterogeneous hardware. Only dygraph mode is supported for now. GPU/NPU/XPU are the target devices for this prototype work.

The basic idea is very similar to a hierarchical communication topology: the lower layer reduces the data within each node, while the upper layer reduces across all nodes globally.

[Figure: illustration of the hierarchical communication topology]
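To make the hierarchical idea concrete, here is a minimal, hypothetical sketch of a two-level all-reduce written with plain MPI. This is not the code in this PR (which uses Paddle's GLOO/NCCL/HCCL parallel contexts); it only illustrates the pattern described above: each node first reduces its gradients onto a per-node leader, the leaders all-reduce across nodes, and the global result is broadcast back to every rank on the node.

```cpp
// hierarchical_allreduce.cpp -- illustrative sketch only, not Paddle's implementation.
#include <mpi.h>
#include <cstdio>
#include <vector>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  int world_rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);

  // Lower layer: a communicator containing only the ranks on this node.
  MPI_Comm node_comm;
  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, world_rank,
                      MPI_INFO_NULL, &node_comm);
  int node_rank;
  MPI_Comm_rank(node_comm, &node_rank);

  // Upper layer: a communicator containing one "leader" rank per node.
  // Non-leader ranks get MPI_COMM_NULL.
  MPI_Comm leader_comm;
  MPI_Comm_split(MPI_COMM_WORLD, node_rank == 0 ? 0 : MPI_UNDEFINED,
                 world_rank, &leader_comm);

  std::vector<float> grad(4, static_cast<float>(world_rank));  // fake gradients
  std::vector<float> result(grad.size(), 0.0f);
  const int count = static_cast<int>(grad.size());

  // Step 1: reduce within each node onto the node leader.
  MPI_Reduce(grad.data(), result.data(), count, MPI_FLOAT, MPI_SUM, 0,
             node_comm);

  // Step 2: all-reduce across node leaders (the global layer).
  if (node_rank == 0) {
    MPI_Allreduce(MPI_IN_PLACE, result.data(), count, MPI_FLOAT, MPI_SUM,
                  leader_comm);
  }

  // Step 3: broadcast the global result back to every rank on the node.
  MPI_Bcast(result.data(), count, MPI_FLOAT, 0, node_comm);

  if (world_rank == 0) {
    std::printf("allreduced value: %f\n", result[0]);
  }

  if (leader_comm != MPI_COMM_NULL) MPI_Comm_free(&leader_comm);
  MPI_Comm_free(&node_comm);
  MPI_Finalize();
  return 0;
}
```

Run with, for example, `mpirun -np 8 ./hierarchical_allreduce`. The two-level split keeps cross-node traffic to one rank per node, which is the same motivation behind the hierarchical topology described above.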

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.


@sandyhouse left a comment


This all looks like dygraph code; it can't support static graph mode, right?

@@ -176,6 +176,11 @@ void GLOOParallelContext::AllReduce(const framework::SelectedRows &src,
}
}

void GLOOParallelContext::BroadCast(framework::Variable *src, int ring_id) {


Broadcast? "Broadcast" is a single word, so the method should be spelled Broadcast rather than BroadCast. Also, this interface has no implementation here; why add it?

@@ -47,6 +47,8 @@ class GLOOParallelContext : public ParallelContext {
framework::Variable* dst, int ring_id,
bool use_calc_stream) override;

void BroadCast(framework::Variable* src, int ring_id) override;


  1. Same as above.
  2. Why does the gloo interface need a ring_id argument?

@@ -158,6 +158,29 @@ void HCCLParallelContext::AllReduceByStream(const framework::Variable &src,
}
}

void HCCLParallelContext::BroadCast(framework::Variable *src, int ring_id) {


BroadCast -> Broadcast?

@@ -127,6 +135,20 @@ void NCCLParallelContext::AllReduceByStream(const framework::Variable &src,
AllReduce(src, dst, strategy_, ring_id, use_calc_stream);
}

void NCCLParallelContext::BroadCast(framework::Variable *src, int ring_id) {


BroadCast -> Broadcast?

@@ -60,6 +60,8 @@ class NCCLParallelContext : public ParallelContext {
framework::Variable* dst, int ring_id,
bool use_calc_stream) override;

void BroadCast(framework::Variable* src, int ring_id) override;


Same as above.

@@ -56,6 +56,8 @@ class ParallelContext {
framework::Variable* dst, int ring_id,
bool use_calc_stream) = 0;

virtual void BroadCast(framework::Variable* src, int ring_id) = 0;


Same as above.

@@ -41,6 +42,9 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
DivNRanks(tensor, nranks, context);
#endif
} else if (platform::is_npu_place(tensor->place())) {
// TODO(kuizhiqing)
VLOG(4) << "divnrank for npu not support yet";


Abort?


@sandyhouse left a comment


LGTM

@zhiqiu (Contributor) left a comment


LGTM for const_cast

@sandyhouse merged commit 1bdb857 into PaddlePaddle:develop on Dec 6, 2021
Zjq9409 pushed a commit to Zjq9409/Paddle that referenced this pull request Dec 10, 2021