[Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 #18504
Changes from 20 commits
```diff
@@ -262,15 +262,26 @@ class CuDNNBatchNormOp {

  private:
   void Init(const TBlob &in_data) {
-    if (in_data.ndim() == 4) {
-      for (int i = 0; i < 4; ++i)
-        shape_[i] = in_data.shape_[i];
+    CHECK_GE(param_.axis, 0);
+    CHECK_LT(param_.axis, in_data.ndim());
+    if (param_.axis == 1) {
+      if (in_data.ndim() == 4) {
+        for (int i = 0; i < 4; ++i)
+          shape_[i] = in_data.shape_[i];
+      } else {
+        // when in_data.ndim() != 4
+        shape_[0] = in_data.shape_[0];
+        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+        shape_[2] = 1;
+        shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+      }
     } else {
-      // when in_data.ndim() != 4
-      shape_[0] = in_data.shape_[0];
-      shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+      shape_[1] = in_data.shape_[param_.axis];
       shape_[2] = 1;
-      shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+                                                              static_cast<int>(in_data.ndim())));
     }

     CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
```

Review thread on the `static_cast` lines:

Reviewer: Why do we need `static_cast` here? Why can't we do it like line 276?

Author: The return dtype of […]

Reviewer: Can we add `static_cast` to line 276 as well?

Reviewer: `in_data.ndim()` shouldn't need a `static_cast` though, right?

Author: Yes, the dtype of […]. The signature of […]. Hi @mseth10, do you have any suggestion about whether to use […]?

Reviewer: IMO you can ignore the compiler warnings (if any) for `int32_t` to `int`.

Author: I see. I have updated it :)
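For readers following the diff, here is a minimal self-contained sketch of the 4-D collapsing that the new `axis != 1` branch performs before calling `cudnnSetTensor4dDescriptor`. It is not part of the PR: the helper name `CollapseTo4d` and the plain `std::vector<int64_t>` type are placeholders for MXNet's `TShape` / `dim_t`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in (illustration only) for the reshape done in
// CuDNNBatchNormOp::Init when param_.axis != 1: collapse an arbitrary-rank
// shape into the 4-D layout
//   (prod(shape[0:axis]), shape[axis], 1, prod(shape[axis+1:ndim]))
// expected by cudnnSetTensor4dDescriptor.
std::vector<int64_t> CollapseTo4d(const std::vector<int64_t>& shape, int axis) {
  const int ndim = static_cast<int>(shape.size());
  assert(axis >= 0 && axis < ndim);  // mirrors the CHECK_GE / CHECK_LT in the PR
  auto prod = [&](int begin, int end) {
    return std::accumulate(shape.begin() + begin, shape.begin() + end,
                           static_cast<int64_t>(1), std::multiplies<int64_t>());
  };
  return {prod(0, axis), shape[axis], 1, prod(axis + 1, ndim)};
}

int main() {
  // A 5-D channels-last input (e.g. NDHWC) normalized over the last axis.
  const std::vector<int64_t> shape = {2, 3, 4, 5, 6};
  const std::vector<int64_t> s = CollapseTo4d(shape, 4);
  // Prints (120, 6, 1, 1): the 2*3*4*5 leading elements fold into N,
  // C = 6, and the (empty) trailing product is 1.
  std::printf("(%lld, %lld, %lld, %lld)\n",
              static_cast<long long>(s[0]), static_cast<long long>(s[1]),
              static_cast<long long>(s[2]), static_cast<long long>(s[3]));
  return 0;
}
```

Folding every leading axis into N and every trailing axis into the last dimension is what lets the operator accept any `ndim > 0`, which is what the next review thread asks about.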
Review thread on the removed `ndim` checks:

Reviewer: We are removing the check for `ndim == 4` here, and another lighter check for `ndim == 1 || ndim == 2 || ndim == 4` present in `SupportMKLDNN`. Does that mean `ndim` can be anything > 0? What are the allowed values for `ndim`?

Author: Yes, `ndim` could be anything > 0. If the shape of input A is `shape`, the input will be reshaped into `(prod(shape[0:axis]), shape[axis], 1, prod(shape[axis+1:len(shape)]))`.

Reviewer: Thanks for the explanation! Does the table continue further for ndim > 5? Or should we place a check for that?

Author: Yes, it supports ndim > 5 too.
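As a concrete illustration of that mapping (the numbers are an example, not taken from the PR): a 5-D channels-last input of shape (8, 16, 16, 16, 32) normalized over axis = 4 is viewed as (prod((8, 16, 16, 16)), 32, 1, prod(())) = (32768, 32, 1, 1), so any ndim > 0 collapses to a valid 4-D descriptor and no upper bound on ndim is required.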