Skip to content

Commit

Permalink
[MKLDNN] Independent gradients requests check with respect to weights…
Browse files Browse the repository at this point in the history
… and bias of convolution (apache#15497)

* Independent req[kBias] and req[kWeight] check

* Add UT for independent conv gradient requests

* Update conv independent grad UT with no_bias enabled

* Check req[kWeight] for avoiding unnecessary prim registration

* Check `OpReqTpye` in CommitOutput automatically

* Lock cudnn autotune for accurate conv output

* Ignore independent gradients test on GPU

* Trigger CI

* Sets a low bar for autotuned cudnn convolution
  • Loading branch information
zixuanweeei authored and test committed Aug 8, 2019
1 parent 73a0400 commit 18081ee
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -507,9 +507,9 @@ class MKLDNNConvBackward {
mkldnn::primitive::at(*this->weight), *this->in_grad));
}

void SetWeightNewMem(const mkldnn::memory &data,
const mkldnn::memory &out_grad,
const mkldnn::memory &in_grad_weight) {
void SetWeightNewMem(const mkldnn::memory &data,
const mkldnn::memory &out_grad,
const mkldnn::memory &in_grad_weight) {
if (this->data == nullptr)
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
Expand Down Expand Up @@ -649,7 +649,7 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
CommitOutput(in_grad[conv::kData], in_grad_mem);
}
if (req[conv::kWeight]) {
if (req[conv::kWeight] || req[conv::kBias]) {
MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, data,
weight, bias, out_grad, fwd_pd);
if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
Expand All @@ -662,17 +662,16 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
in_grad[conv::kWeight],
convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
req[conv::kWeight]);
mkldnn_output_t in_grad_bias;
if (param.no_bias) {
convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
*in_grad_weight.second);
*in_grad_weight.second);
MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
} else {
in_grad_bias = CreateMKLDNNMem(
auto in_grad_bias = CreateMKLDNNMem(
in_grad[conv::kBias],
convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
*in_grad_weight.second, *in_grad_bias.second);
*in_grad_weight.second, *in_grad_bias.second);
MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
CommitOutput(in_grad[conv::kBias], in_grad_bias);
}
Expand Down
82 changes: 82 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1906,6 +1906,88 @@ def test_depthwise_convolution():
for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays):
np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-3)


@with_seed()
def test_convolution_independent_gradients():
ctx = default_context()
# set a low bar for autotuned cudnn conv
atol = 1.0e-1 if ctx.device_type == "gpu" else 1.0e-3
rtol = 1.0e-2 if ctx.device_type == "gpu" else 1.0e-3
reqs = ["null", "write", "add"]
var_names = ["x", "w", "b"]
dims = [1, 2]
num_bases = [1, 16, 64]
kernel_xs = [3, 5]
stride_xs = [1, 2]
pad_xs = [0, 1]
in_sizes = [7, 32]
no_biases = [True, False]
for dim, num_base, kernel_x, stride_x, pad_x , in_size, no_bias in \
itertools.product(dims, num_bases, kernel_xs, stride_xs, pad_xs, in_sizes, no_biases):
# Prepare params shape
kernel = (kernel_x,) * dim
stride = (stride_x,) * dim
pad = (pad_x,) * dim
num_filter = num_base
x_shape = (2, num_base) + (in_size,) * dim
w_shape = (num_filter, num_base) + kernel

# Symbols definition
x = mx.sym.Variable('x')
w = mx.sym.Variable('w')
b = mx.sym.Variable('b') if not no_bias else None
conv = mx.sym.Convolution(x, w, b, num_filter=num_filter,
kernel=kernel, stride=stride, pad=pad, no_bias=no_bias)

for req_kind in reqs:
# Binding args for conv with possible dependent gradients
base_args = {
'x': mx.nd.random.normal(shape=x_shape),
'w': mx.nd.random.normal(shape=w_shape),
'b': mx.nd.random.normal(shape=(num_filter, )) if not no_bias else None}
args1 = copy.deepcopy(base_args)
grad1 = {
'x': mx.nd.zeros(shape=x_shape),
'w': mx.nd.zeros(shape=w_shape),
'b': mx.nd.zeros(shape=(num_filter, )) if not no_bias else None}

grad_req1 = [req_kind] * 3
grad_req1 = dict(zip(var_names, grad_req1))

exe1 = conv.bind(ctx, args1, args_grad=grad1, grad_req=grad_req1)
exe1.forward(is_train=True)
exe1.backward(exe1.outputs[0])

for x_req, w_req, b_req in itertools.product(reqs, repeat=3):
# Binding args for conv with independent gradients
args2 = copy.deepcopy(base_args) # Deepcopy the same params of `exe1`
grad2 = {
'x': mx.nd.zeros(shape=x_shape),
'w': mx.nd.zeros(shape=w_shape),
'b': mx.nd.zeros(shape=(num_filter, )) if not no_bias else None}
grad_req2 = {"x": x_req, "w": w_req, "b": b_req}
exe2 = conv.bind(ctx, args2, args_grad=grad2, grad_req=grad_req2)

exe2.forward(is_train=True)
np.testing.assert_allclose(exe1.outputs[0].asnumpy(),
exe2.outputs[0].asnumpy(), rtol=rtol, atol=atol)

exe2.backward(exe2.outputs[0])
for var_name in var_names:
if var_name == "b" and no_bias:
continue
if grad_req2[var_name] == "null":
exe2_var_grad = grad2[var_name].asnumpy()
np.testing.assert_allclose(exe2_var_grad,
np.zeros_like(exe2_var_grad), rtol=rtol, atol=atol)
if grad_req2[var_name] != grad_req1[var_name]:
continue
np.testing.assert_allclose(args1[var_name].asnumpy(),
args2[var_name].asnumpy(), rtol=rtol, atol=atol)
np.testing.assert_allclose(grad1[var_name].asnumpy(),
grad2[var_name].asnumpy(), rtol=rtol, atol=atol)


def gen_broadcast_data(idx):
# Manually set test cases
binary_op_data_shape = np.array(
Expand Down

0 comments on commit 18081ee

Please sign in to comment.