[MXNET-614] Adding Synchronized Batch Normalization #11502
Conversation
Help Wanted for passing the CI Test!!
'ndev': num_devices, 'key': self.prefix}

def _get_num_devices(self):
    # Caution: if not using all the GPUs, please manually set num_devices
add the warning to the docstring rather than showing a comment here
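A minimal sketch of how that caution could live in the Gluon docstring instead; the class skeleton and wording below are illustrative, not the PR's exact code:

```python
from mxnet.gluon.nn import BatchNorm

class SyncBatchNorm(BatchNorm):
    """Cross-GPU Synchronized Batch Normalization (SyncBN).

    Parameters
    ----------
    num_devices : int, default is the number of visible GPUs
        Number of devices to synchronize statistics across.

        .. warning:: If not all GPUs on the machine are used,
            ``num_devices`` must be set manually; it cannot be
            inferred reliably from the context list.
    """

    def __init__(self, num_devices=None, **kwargs):
        super(SyncBatchNorm, self).__init__(**kwargs)
        self._num_devices = num_devices  # validated when the block is used
```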
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
# include <condition_variable>
space between # and include?
template<class T>
class SharedND {
 private:
  int nDev;
convention for private member variables is xxx_
and camel for functions, which is correct right now
std::lock_guard<std::mutex> lock(mutex_);
auto it = registry_.find(key);
if (it != registry_.end()) return it->second;
T *newT = new T(ndev);
the memory pointed to by these raw pointers is never released
addressed in the class destructor :)
https://github.com/zhanghang1989/incubator-mxnet/blob/cc60d11f44e37e954a627c8db43bd1b6fc45e68d/src/operator/contrib/sync_batch_norm-inl.h#L160-L164
Thanks @RogerChern! The comments in the destructor are really helpful.
Finally passed the CI test. Please take a look and let me know if you have further comments. @zhreshold @eric-haibin-lin @piiswrong. Thanks! Docs are deployed here: http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-11502/31/api/python/gluon/contrib.html?highlight=syncbatchnorm#mxnet.gluon.contrib.nn.SyncBatchNorm
some minor suggestions
_assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]),
                     _find_bn(bn2).running_var.data(ctx_list[0]))
input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
#print('input1.grad', input1.grad)
Remove unused code
Yeah, will do. Thx
_assert_tensor_close(input1.grad, input2grad)


def test_sync_batchnorm():
    def get_num_devices():
There's test_utils.list_gpus()
That is slightly different. list_gpus() doesn’t consider CUDA_VISIBLE_DEVICES
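A hedged sketch of the distinction, using a hypothetical get_num_devices helper that honors CUDA_VISIBLE_DEVICES (the helper name and fallback logic are illustrative, not the PR's exact code):

```python
import os
import mxnet as mx

def get_num_devices():
    # test_utils.list_gpus() enumerates every GPU on the machine,
    # regardless of what CUDA_VISIBLE_DEVICES exposes to this process.
    all_gpus = len(mx.test_utils.list_gpus())
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if not visible or not visible.strip():
        return all_gpus
    # Respect the mask: only count devices the process can actually see.
    return min(all_gpus, len([d for d in visible.split(',') if d.strip()]))
```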
@@ -1909,6 +1909,91 @@ def test_context_num_gpus():
    # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
    assert mx.context.num_gpus() > 0


def _check_batchnorm_result(input, num_devices=1, cuda=False):
    from mxnet.gluon.utils import split_and_load
    def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3):
Will assert_almost_equal do?
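For reference, a sketch of how the local helper could simply delegate to the existing test utility; this assumes the same default tolerances as the PR's helper:

```python
import mxnet as mx

def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3):
    # Delegate to the shared utility instead of re-implementing the check;
    # inputs are NDArrays, so convert to NumPy first.
    mx.test_utils.assert_almost_equal(a.asnumpy(), b.asnumpy(),
                                      atol=atol, rtol=rtol)
```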
}

~SharedND() {
  mshadow::FreeSpace(&mean_);
check for data_inited_ before freeing memory
I agree. Will make the changes. Thx
  }
}

T* Retrieve(mshadow::Shape<1> shape, int index) {
need doc for these member functions
~GlobalShared() {
  for (auto it = registry_.begin(); it != registry_.end(); it++) {
    T *ptr = it->second;
    delete ptr;
again, you have to guarantee you are deleting valid pointers, since they are not initialized in the constructor but in a public function
If not inited, the map should be empty
}
~GlobalSharedRank() {
  for (auto it = registry_.begin(); it != registry_.end(); it++) {
    T *ptr = it->second;
same here
If not inited, the hash map should be empty
ok, should be fine
mshadow::Shape2(5, mean.shape_[0]), s);
Tensor<xpu, 1> gmean = workspace[0];
Tensor<xpu, 1> gvar = workspace[1];
// Tensor<xpu, 1> tmp = workspace[2];
remove unused
Comments added. The rest LGTM now.
@indhub FYI
The SyncBatchNorm class doesn't seem to be available from the mxnet-cu91 nightly. It's visible in the regular mxnet nightly. Are these changes merged fully?
@miteshyh
@miteshyh would you be able to update and use cu92? I heard from @bhavinthaker that NVIDIA discontinued support for cu91, so we intend to do the same.
Thanks @szha, I downgraded to cu90 as cu92 doesn't have clean support on my hardware yet, and it works. However, while I train ADE20K with GluonCV I get "socket.error: [Errno 111] Connection refused" after a few (@551) iterations; I have raised a separate issue for the same. This happens with/without SyncBatchNorm.
* sync batch norm
* global rank and barrier
* lint
* cpplint
* pylint
* doc
* add ref
* customized barrier
* cpplint
* get rid of pthread
* address comments
* warning
* pylint
* gpu unitest
* gpu 0
* mv to cpu test
* Revert "mv to cpu test" This reverts commit 24543c9.
* ndev = 2
* debuging
* sum prod
* lint
* contrib, ngpu
* code style
* code style
* forward backward
* test
* cpu test
* fix deconstruction
* doc indent
* doc
* doc
* address comments
* typo
* asnumpy
Hello, @RogerChern. I also met a deadlock issue while training PSPNet on
Please set the
Hello, @zhanghang1989. Thank you for your reply. I will try it tomorrow morning and update the result with you. Update: Hello, @zhanghang1989. I am not quite sure whether you suggested that I explicitly set
Hi Hang, I used your sync_bn implementation with the MXNet symbol API. However, it reduced the performance of my network. I wonder whether you have ever tried your sync_bn with the symbol API rather than Gluon. Thanks
Asked here #8458 (comment)
How to use it?
Description
Adding Synchronized Batch Normalization
Thanks @eric-haibin-lin for the great help!
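A minimal usage sketch, assuming the Gluon API at the docs link above (mxnet.gluon.contrib.nn.SyncBatchNorm); the network, batch shape, and device count are illustrative only:

```python
import mxnet as mx
from mxnet import autograd, gluon
from mxnet.gluon.contrib.nn import SyncBatchNorm

num_gpus = 2  # assumed device count for this sketch
ctx_list = [mx.gpu(i) for i in range(num_gpus)]

net = gluon.nn.HybridSequential()
net.add(gluon.nn.Conv2D(16, kernel_size=3, padding=1),
        SyncBatchNorm(num_devices=num_gpus),  # sync mean/var across GPUs
        gluon.nn.Activation('relu'))
net.initialize(ctx=ctx_list)

# Split one batch across the devices; every slice contributes to the
# shared batch statistics during the forward pass.
data = mx.nd.random.uniform(shape=(8, 3, 32, 32))
slices = gluon.utils.split_and_load(data, ctx_list)
with autograd.record():
    outputs = [net(x) for x in slices]
```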
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments