-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-1107] Fix CPUPinned unexpected behaviour #12031
Changes from 7 commits
9e23ac9
42dc88f
ca1a0f9
92ecd95
1738b78
4ca556a
5c02c35
af4590d
3cbf4eb
1b7611a
ec51ada
156b368
e7a340f
e316cb8
01b0d6f
0ea50cc
fea651f
b361966
cfdcb29
3914ffd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -720,6 +720,9 @@ class CommDevice : public Comm { | |
int n = static_cast<int>(gpus.size()); | ||
int enabled = 0; | ||
std::vector<int> p2p(n*n); | ||
|
||
// Restores active device to what it was before EnableP2P | ||
mxnet::common::cuda::SetDevice set_device; | ||
for (int i = 0; i < n; ++i) { | ||
cudaSetDevice(gpus[i]); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As @eric-haibin-lin suggested, replace this call with a member function of your RAII object `SetDevice` (its name before the rename). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Replace this call with a member function of the `SetDevice` object (which is to be renamed), as @eric-haibin-lin suggested. |
||
for (int j = 0; j < n; j++) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -339,6 +339,7 @@ class CommDeviceTree : public CommDevice { | |
int n = static_cast<int>(gpus.size()); | ||
int enabled = 0; | ||
std::vector<int> p2p(n*n); | ||
mxnet::common::cuda::SetDevice set_device; | ||
for (int i = 0; i < n; ++i) { | ||
cudaSetDevice(gpus[i]); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here. |
||
for (int j = 0; j < n; j++) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -428,6 +428,7 @@ class KVStoreNCCL : public KVStoreLocal { | |
mutate_vars.push_back(ptr(dst[i])->var()); | ||
} | ||
Engine::Get()->PushSync([this](RunContext rctx) { | ||
mxnet::common::cuda::SetDevice set_device; | ||
for (auto cur : nccl_data_) { | ||
CUDA_CALL(cudaSetDevice(cur.second.dev_id)); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here. |
||
CUDA_CALL(cudaStreamSynchronize(cur.second.stream)); | ||
|
@@ -479,6 +480,7 @@ class KVStoreNCCL : public KVStoreLocal { | |
std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU)); | ||
std::vector<ncclComm_t> comms(devs.size()); | ||
ncclCommInitAll(&(comms[0]), devs.size(), &(device_ids_[0])); | ||
mxnet::common::cuda::SetDevice set_device; | ||
for (size_t i = 0; i < devs.size(); ++i) { | ||
NCCLEntry e; | ||
e.dev_id = device_ids_[i]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename to `DeviceScope`, add a member function `SetDevice`, and remove all occurrences of `cudaGetDevice`?