[MKLDNN] Independent gradients requests check with respect to weights and bias of convolution #15497
Changes from 4 commits
80fde0a
9041993
a21f065
cde8ab8
8b2cee4
9ca0428
3fda51f
8eedac6
1f98778
b2301ba
952a2ba
9767772
4f41250
@@ -507,9 +507,9 @@ class MKLDNNConvBackward {
        mkldnn::primitive::at(*this->weight), *this->in_grad));
  }

  void SetWeightNewMem(const mkldnn::memory &data,
                       const mkldnn::memory &out_grad,
                       const mkldnn::memory &in_grad_weight) {
    if (this->data == nullptr)
      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
@@ -649,7 +649,7 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
     MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
     CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
-  if (req[conv::kWeight]) {
+  if (req[conv::kWeight] || req[conv::kBias]) {
     MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, data,
                                                    weight, bias, out_grad, fwd_pd);
     if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
@@ -662,21 +662,21 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
         in_grad[conv::kWeight],
         convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
         req[conv::kWeight]);
-    mkldnn_output_t in_grad_bias;
-    if (param.no_bias) {
-      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
-                                    *in_grad_weight.second);
-      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
-    } else {
-      in_grad_bias = CreateMKLDNNMem(
+    if (!param.no_bias && req[conv::kBias]) {
+      auto in_grad_bias = CreateMKLDNNMem(
           in_grad[conv::kBias],
           convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
       convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
                                     *in_grad_weight.second, *in_grad_bias.second);
       MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
       CommitOutput(in_grad[conv::kBias], in_grad_bias);
+    } else {
+      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+                                    *in_grad_weight.second);
+      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
     }
-    CommitOutput(in_grad[conv::kWeight], in_grad_weight);
+    if (req[conv::kWeight]) CommitOutput(in_grad[conv::kWeight], in_grad_weight);

Review comment: what's the behavior of req[conv::bias]?

Author reply: It has the same behavior as req[kWeight]. Both of them return the operation request type (…).

(A sketch of these request types, as seen from the Python frontend, follows this hunk.)
   }
   MKLDNNStream::Get()->Submit();
 }
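In the Python frontend, each entry of req corresponds to the executor's grad_req setting: "null" skips the gradient, "write" overwrites the gradient buffer, and "add" accumulates into it. Below is a minimal sketch of that behavior, assuming the MXNet 1.x symbolic bind API; the shapes and variable names are made up for illustration and are not part of this diff.

import mxnet as mx
import numpy as np

# Illustrative sketch (not part of this PR): how the operation request types
# behave as seen from the Python frontend. "write" overwrites the gradient
# buffer, "add" accumulates into it, and "null" leaves it untouched.
x = mx.sym.Variable('x')
w = mx.sym.Variable('w')
conv = mx.sym.Convolution(x, w, num_filter=4, kernel=(3, 3), no_bias=True)

ctx = mx.cpu()
args = {'x': mx.nd.random.normal(shape=(1, 4, 8, 8)),
        'w': mx.nd.random.normal(shape=(4, 4, 3, 3))}
grads = {'x': mx.nd.zeros((1, 4, 8, 8)), 'w': mx.nd.zeros((4, 4, 3, 3))}

exe = conv.bind(ctx, args, args_grad=grads,
                grad_req={'x': 'null', 'w': 'add'})
exe.forward(is_train=True)
exe.backward(exe.outputs[0])
first = grads['w'].asnumpy().copy()

exe.forward(is_train=True)
exe.backward(exe.outputs[0])
second = grads['w'].asnumpy()

# "add" accumulated the second backward pass on top of the first one.
np.testing.assert_allclose(second, 2 * first, rtol=1e-3, atol=1e-3)
# "null" left the data-gradient buffer untouched.
assert np.allclose(grads['x'].asnumpy(), 0)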
@@ -1907,6 +1907,85 @@ def test_depthwise_convolution():
     for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays):
         np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-3)

@with_seed()
def test_convolution_independent_gradients():
    reqs = ["null", "write", "add"]
    var_names = ["x", "w", "b"]
    dims = [1, 2]
    num_bases = [1, 16, 64]
    kernel_xs = [3, 5]
    stride_xs = [1, 2]
    pad_xs = [0, 1]
    in_sizes = [7, 32]
    no_biases = [True, False]
    for dim, num_base, kernel_x, stride_x, pad_x, in_size, no_bias in \
            itertools.product(dims, num_bases, kernel_xs, stride_xs, pad_xs, in_sizes, no_biases):
        # Prepare params shape
        kernel = (kernel_x,) * dim
        stride = (stride_x,) * dim
        pad = (pad_x,) * dim
        num_filter = num_base
        x_shape = (2, num_base) + (in_size,) * dim
        w_shape = (num_filter, num_base) + kernel

        # Symbols definition
        x = mx.sym.Variable('x')
        w = mx.sym.Variable('w')
        b = mx.sym.Variable('b') if not no_bias else None
        conv = mx.sym.Convolution(x, w, b, num_filter=num_filter,
                                  kernel=kernel, stride=stride, pad=pad, no_bias=no_bias)

        for req_kind in reqs:
            # Binding args for conv with possible dependent gradients
            base_args = {
                'x': mx.nd.random.normal(shape=x_shape),
                'w': mx.nd.random.normal(shape=w_shape),
                'b': mx.nd.random.normal(shape=(num_filter, )) if not no_bias else None}
            args1 = copy.deepcopy(base_args)
            grad1 = {
                'x': mx.nd.zeros(shape=x_shape),
                'w': mx.nd.zeros(shape=w_shape),
                'b': mx.nd.zeros(shape=(num_filter, )) if not no_bias else None}

            grad_req1 = [req_kind] * 3
            grad_req1 = dict(zip(var_names, grad_req1))

            ctx = default_context()
            exe1 = conv.bind(ctx, args1, args_grad=grad1, grad_req=grad_req1)
            exe1.forward(is_train=True)
            exe1.backward(exe1.outputs[0])

            for x_req, w_req, b_req in itertools.product(reqs, repeat=3):
                # Binding args for conv with independent gradients
                args2 = copy.deepcopy(base_args)  # Deepcopy the same params of `exe1`
                grad2 = {
                    'x': mx.nd.zeros(shape=x_shape),
                    'w': mx.nd.zeros(shape=w_shape),
                    'b': mx.nd.zeros(shape=(num_filter, )) if not no_bias else None}
                grad_req2 = {"x": x_req, "w": w_req, "b": b_req}
                exe2 = conv.bind(ctx, args2, args_grad=grad2, grad_req=grad_req2)

                exe2.forward(is_train=True)
                np.testing.assert_allclose(exe1.outputs[0].asnumpy(),
                                           exe2.outputs[0].asnumpy(), rtol=1e-3, atol=1e-3)

                exe2.backward(exe2.outputs[0])
                for var_name in var_names:
                    if var_name == "b" and no_bias:
                        continue
                    if grad_req2[var_name] == "null":

Review comment: We don't have such a case?

Author reply: Yup. It is a very corner case of only requesting the gradient with respect to bias.

(A sketch of this bias-only case follows this diff.)
                        exe2_var_grad = grad2[var_name].asnumpy()
                        np.testing.assert_allclose(exe2_var_grad,
                                                   np.zeros_like(exe2_var_grad), rtol=1e-3, atol=1e-3)
                    if grad_req2[var_name] != grad_req1[var_name]:
                        continue
                    np.testing.assert_allclose(args1[var_name].asnumpy(),
                                               args2[var_name].asnumpy(), rtol=1e-3, atol=1e-3)
                    np.testing.assert_allclose(grad1[var_name].asnumpy(),
                                               grad2[var_name].asnumpy(), rtol=1e-3, atol=1e-3)

def gen_broadcast_data(idx):
    # Manually set test cases
    binary_op_data_shape = np.array(
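The corner case discussed above, requesting only the bias gradient, looks like this from the Python side. This is a minimal sketch assuming the MXNet 1.x bind API, with made-up shapes; it is not part of this diff.

import mxnet as mx
import numpy as np

# Corner case from the review thread: only the bias gradient is requested.
# Illustrative sketch, not part of this PR's diff.
x = mx.sym.Variable('x')
w = mx.sym.Variable('w')
b = mx.sym.Variable('b')
conv = mx.sym.Convolution(x, w, b, num_filter=8, kernel=(3, 3))

ctx = mx.cpu()
args = {'x': mx.nd.random.normal(shape=(2, 8, 7, 7)),
        'w': mx.nd.random.normal(shape=(8, 8, 3, 3)),
        'b': mx.nd.random.normal(shape=(8,))}
grads = {name: mx.nd.zeros(arr.shape) for name, arr in args.items()}

# Before this change, the MKLDNN backward path was guarded only by
# req[conv::kWeight], so a bias-only request like this one was not handled
# independently.
exe = conv.bind(ctx, args, args_grad=grads,
                grad_req={'x': 'null', 'w': 'null', 'b': 'write'})
exe.forward(is_train=True)
exe.backward(exe.outputs[0])

assert not np.allclose(grads['b'].asnumpy(), 0)  # bias gradient was written
assert np.allclose(grads['w'].asnumpy(), 0)      # weight buffer untouched
assert np.allclose(grads['x'].asnumpy(), 0)      # data buffer untouched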
Review comment: Suggest checking req[conv::kWeight] here.

Author reply: Sure, I see. There is an unnecessary primitive registration without the check enabled. Thanks.
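The registration point can be summarized as a small mimic of the guarded control flow introduced by this diff. The function and key names below are hypothetical, purely for illustration; the real logic lives in the C++ hunks above.

# Hypothetical Python mimic of the guarded MKLDNN backward-weights path from
# the C++ diff above; plan_backward_weights and its keys are illustrative only.
def plan_backward_weights(req, no_bias):
    """Decide which backward-weights work is needed for the given requests.

    req maps 'weight' and 'bias' to 'null', 'write', or 'add';
    no_bias is True when the convolution has no bias term.
    """
    plan = {'register_weights_primitive': False,
            'commit_weight_grad': False,
            'commit_bias_grad': False}
    # Outer check from the diff: enter only if either gradient is requested,
    # so no primitive is registered when both requests are "null".
    if req['weight'] != 'null' or req['bias'] != 'null':
        plan['register_weights_primitive'] = True
        # Bias gradient is committed only when a bias exists and is requested.
        if not no_bias and req['bias'] != 'null':
            plan['commit_bias_grad'] = True
        # Weight gradient is committed only when it is requested
        # (the req[conv::kWeight] check suggested in this thread).
        if req['weight'] != 'null':
            plan['commit_weight_grad'] = True
    return plan


# Neither gradient requested: nothing is registered or committed.
assert plan_backward_weights({'weight': 'null', 'bias': 'null'}, no_bias=False) == \
    {'register_weights_primitive': False, 'commit_weight_grad': False,
     'commit_bias_grad': False}

# Bias-only request: the primitive runs, but only the bias gradient is committed.
assert plan_backward_weights({'weight': 'null', 'bias': 'write'}, no_bias=False) == \
    {'register_weights_primitive': True, 'commit_weight_grad': False,
     'commit_bias_grad': True}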