This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.5.x] [MKLDNN] Independent gradients requests check with respect to weights… #15805

Merged
merged 2 commits on Aug 13, 2019
15 changes: 7 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -507,9 +507,9 @@ class MKLDNNConvBackward {
         mkldnn::primitive::at(*this->weight), *this->in_grad));
   }
 
-  void SetWeightNewMem(const mkldnn::memory &data,
-                       const mkldnn::memory &out_grad,
-                       const mkldnn::memory &in_grad_weight) {
+  void SetWeightNewMem(const mkldnn::memory &data,
+                       const mkldnn::memory &out_grad,
+                       const mkldnn::memory &in_grad_weight) {
     if (this->data == nullptr)
       this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
           bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
@@ -649,7 +649,7 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
     MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
     CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
-  if (req[conv::kWeight]) {
+  if (req[conv::kWeight] || req[conv::kBias]) {
     MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, data,
                                                    weight, bias, out_grad, fwd_pd);
     if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
@@ -662,17 +662,16 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
         in_grad[conv::kWeight],
         convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
         req[conv::kWeight]);
-    mkldnn_output_t in_grad_bias;
     if (param.no_bias) {
       convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
-                                    *in_grad_weight.second);
+                                    *in_grad_weight.second);
       MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
     } else {
-      in_grad_bias = CreateMKLDNNMem(
+      auto in_grad_bias = CreateMKLDNNMem(
           in_grad[conv::kBias],
           convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
       convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
-                                    *in_grad_weight.second, *in_grad_bias.second);
+                                    *in_grad_weight.second, *in_grad_bias.second);
       MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
       CommitOutput(in_grad[conv::kBias], in_grad_bias);
     }
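For reference, the functional change in this file is the dispatch condition: previously the backward-weights branch ran only when a gradient was requested for the weight, so requesting the bias gradient alone (weight grad_req set to "null") never reached the MKLDNN backward-weights primitive and the bias gradient stayed at its initial value. The sketch below is a minimal, hypothetical repro of that request pattern, written against the MXNet 1.x symbol API used by the new unit test; shapes and names are illustrative only.

import mxnet as mx
import numpy as np

# Hypothetical repro sketch (not part of the PR): request only the bias gradient.
ctx = mx.cpu()
x = mx.sym.Variable('x')
w = mx.sym.Variable('w')
b = mx.sym.Variable('b')
conv = mx.sym.Convolution(x, w, b, num_filter=4, kernel=(3, 3), no_bias=False)

args = {'x': mx.nd.random.normal(shape=(1, 3, 8, 8), ctx=ctx),
        'w': mx.nd.random.normal(shape=(4, 3, 3, 3), ctx=ctx),
        'b': mx.nd.random.normal(shape=(4,), ctx=ctx)}
grads = {'x': mx.nd.zeros((1, 3, 8, 8), ctx=ctx),
         'w': mx.nd.zeros((4, 3, 3, 3), ctx=ctx),
         'b': mx.nd.zeros((4,), ctx=ctx)}

# Only the bias gradient is requested; with the old req[conv::kWeight]-only check,
# the MKLDNN path skipped this branch and grads['b'] could remain all zeros.
exe = conv.bind(ctx, args, args_grad=grads,
                grad_req={'x': 'null', 'w': 'null', 'b': 'write'})
exe.forward(is_train=True)
exe.backward(exe.outputs[0])
assert not np.allclose(grads['b'].asnumpy(), 0.0)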
84 changes: 84 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -1906,6 +1906,90 @@ def test_depthwise_convolution():
                            for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays):
                                np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-3)


@with_seed()
def test_convolution_independent_gradients():
zixuanweeei (Contributor) commented on Aug 12, 2019:

We have reported flakiness in this UT and added some notes. It seems the cherry-picked commit is behind the one merged on the master branch, where it looks good. Would you mind checking it again? Thanks. https://github.com/apache/incubator-mxnet/blob/57927a932610581a351e4aed4fb7c9209d0e57c2/tests/python/unittest/test_operator.py#L2003-L2006

Reply (Contributor):
Sorry, my mistake. The flakiness is fixed by #15631. I think these two should be taken into v1.5.x together.

    # NOTE(zixuanweeei): Flaky test tracked by https://github.com/apache/incubator-mxnet/issues/15603.
    # GPU context will be enabled after figuring out the possible issue tracked at
    # https://github.com/apache/incubator-mxnet/issues/15638.
    ctx = mx.cpu()
    atol = 1.0e-3
    rtol = 1.0e-3
    reqs = ["null", "write", "add"]
    var_names = ["x", "w", "b"]
    dims = [1, 2]
    num_bases = [1, 16, 64]
    kernel_xs = [3, 5]
    stride_xs = [1, 2]
    pad_xs = [0, 1]
    in_sizes = [7, 32]
    no_biases = [True, False]
    for dim, num_base, kernel_x, stride_x, pad_x, in_size, no_bias in \
            itertools.product(dims, num_bases, kernel_xs, stride_xs, pad_xs, in_sizes, no_biases):
        # Prepare params shape
        kernel = (kernel_x,) * dim
        stride = (stride_x,) * dim
        pad = (pad_x,) * dim
        num_filter = num_base
        x_shape = (2, num_base) + (in_size,) * dim
        w_shape = (num_filter, num_base) + kernel

        # Symbols definition
        x = mx.sym.Variable('x')
        w = mx.sym.Variable('w')
        b = mx.sym.Variable('b') if not no_bias else None
        conv = mx.sym.Convolution(x, w, b, num_filter=num_filter,
            kernel=kernel, stride=stride, pad=pad, no_bias=no_bias)

        for req_kind in reqs:
            # Binding args for conv with possible dependent gradients
            base_args = {
                'x': mx.nd.random.normal(shape=x_shape, ctx=ctx),
                'w': mx.nd.random.normal(shape=w_shape, ctx=ctx),
                'b': mx.nd.random.normal(shape=(num_filter, ), ctx=ctx) if not no_bias else None}
            args1 = copy.deepcopy(base_args)
            grad1 = {
                'x': mx.nd.zeros(shape=x_shape, ctx=ctx),
                'w': mx.nd.zeros(shape=w_shape, ctx=ctx),
                'b': mx.nd.zeros(shape=(num_filter, ), ctx=ctx) if not no_bias else None}

            grad_req1 = [req_kind] * 3
            grad_req1 = dict(zip(var_names, grad_req1))

            exe1 = conv.bind(ctx, args1, args_grad=grad1, grad_req=grad_req1)
            exe1.forward(is_train=True)
            exe1.backward(exe1.outputs[0])

            for x_req, w_req, b_req in itertools.product(reqs, repeat=3):
                # Binding args for conv with independent gradients
                args2 = copy.deepcopy(base_args)    # Deepcopy the same params of `exe1`
                grad2 = {
                    'x': mx.nd.zeros(shape=x_shape, ctx=ctx),
                    'w': mx.nd.zeros(shape=w_shape, ctx=ctx),
                    'b': mx.nd.zeros(shape=(num_filter, ), ctx=ctx) if not no_bias else None}
                grad_req2 = {"x": x_req, "w": w_req, "b": b_req}
                exe2 = conv.bind(ctx, args2, args_grad=grad2, grad_req=grad_req2)

                exe2.forward(is_train=True)
                np.testing.assert_allclose(exe1.outputs[0].asnumpy(),
                    exe2.outputs[0].asnumpy(), rtol=rtol, atol=atol)

                exe2.backward(exe2.outputs[0])
                for var_name in var_names:
                    if var_name == "b" and no_bias:
                        continue
                    if grad_req2[var_name] == "null":
                        exe2_var_grad = grad2[var_name].asnumpy()
                        np.testing.assert_allclose(exe2_var_grad,
                            np.zeros_like(exe2_var_grad), rtol=rtol, atol=atol)
                    if grad_req2[var_name] != grad_req1[var_name]:
                        continue
                    np.testing.assert_allclose(args1[var_name].asnumpy(),
                        args2[var_name].asnumpy(), rtol=rtol, atol=atol)
                    np.testing.assert_allclose(grad1[var_name].asnumpy(),
                        grad2[var_name].asnumpy(), rtol=rtol, atol=atol)


def gen_broadcast_data(idx):
    # Manually set test cases
    binary_op_data_shape = np.array(
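A closing note on the new test: it leans on MXNet's grad_req semantics, where "null" leaves the supplied gradient buffer untouched, "write" overwrites it, and "add" accumulates into it. That is why exe2 (independent per-variable requests) is compared against exe1 only for variables whose request kind matches, and why "null" gradients are checked against zeros. Below is a minimal sketch of those semantics, using a hypothetical square symbol rather than Convolution.

import mxnet as mx
import numpy as np

# Hypothetical illustration of grad_req semantics (not part of the PR).
ctx = mx.cpu()
x = mx.sym.Variable('x')
y = mx.sym.square(x)                         # dy/dx = 2 * x

args = {'x': mx.nd.array([3.0], ctx=ctx)}
grads = {'x': mx.nd.array([10.0], ctx=ctx)}  # pre-existing gradient buffer

# With grad_req='add', backward accumulates into the buffer: 10 + 2 * 3 = 16.
exe = y.bind(ctx, args, args_grad=grads, grad_req={'x': 'add'})
exe.forward(is_train=True)
exe.backward(mx.nd.ones_like(exe.outputs[0]))
np.testing.assert_allclose(grads['x'].asnumpy(), [16.0])

# With grad_req='null' the buffer would stay at its initial value, which is what
# the test asserts via np.zeros_like for unrequested gradients.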