
MKLDNN fallback when not recording gradients and calling backwards #12411

Closed
wants to merge 17 commits
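For context, this is the scenario the patch targets (and that the updated Gluon test below exercises): gradients are recorded and backward is run with train_mode=False, and the CPU backward kernels should then fall back from MKLDNN to the native implementation. A minimal sketch, assuming a simple Gluon Conv2D block (the block and shapes are illustrative, not part of this PR):

```python
import mxnet as mx
from mxnet import autograd, gluon

# Illustrative block; any operator with an MKLDNN backward kernel
# (activation, batch norm, convolution, pooling, FC, LRN, ...) is
# gated the same way by this PR.
net = gluon.nn.Conv2D(channels=8, kernel_size=3)
net.initialize()

x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
x.attach_grad()

# Record and run backward in inference mode (train_mode=False);
# with this change the backward pass takes the native CPU path
# instead of the MKLDNN kernels.
with autograd.record(train_mode=False):
    y = net(x)
y.backward(train_mode=False)

print(x.grad.shape)
```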
2 changes: 1 addition & 1 deletion src/operator/nn/activation.cc
@@ -91,7 +91,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
- if (SupportMKLDNN(inputs[0])) {
+ if (SupportMKLDNN(inputs[0]) && ctx.need_grad) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
MKLDNNActivationBackward(attrs, ctx, inputs[0], relu ? inputs[1] : inputs[2], req[0],
3 changes: 2 additions & 1 deletion src/operator/nn/batch_norm.cc
@@ -421,7 +421,8 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
TShape shape = inputs[0].shape();
// MKLDNN batchnorm only works well on the special MKLDNN layout.
if (SupportMKLDNNBN(inputs[0], param)
- && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
+ && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())
+ && ctx.need_grad) {
std::vector<NDArray> out_grad(1);
std::vector<NDArray> out_data(3);
std::vector<NDArray> in_data(3);
2 changes: 1 addition & 1 deletion src/operator/nn/convolution.cc
@@ -73,7 +73,7 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
- if (SupportMKLDNNConv(params, inputs[0])) {
+ if (SupportMKLDNNConv(params, inputs[0]) && ctx.need_grad) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
2 changes: 1 addition & 1 deletion src/operator/nn/deconvolution.cc
@@ -312,7 +312,7 @@ static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
- if (SupportMKLDNNDeconv(param, inputs[0])) {
+ if (SupportMKLDNNDeconv(param, inputs[0]) && ctx.need_grad) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNDeconvolutionBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req,
2 changes: 1 addition & 1 deletion src/operator/nn/fully_connected.cc
@@ -141,7 +141,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
- if (SupportMKLDNN(inputs[0])) {
+ if (SupportMKLDNN(inputs[0]) && ctx.need_grad) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
2 changes: 1 addition & 1 deletion src/operator/nn/lrn.cc
@@ -133,7 +133,7 @@ void LRNGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const NDArray &in_data = inputs[1];
const NDArray &in_grad = outputs[0];

- if (SupportMKLDNN(inputs[0])) {
+ if (SupportMKLDNN(inputs[0]) && ctx.need_grad) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNLRNBackward(ctx, param, out_grad, in_data, req[0], in_grad);
MKLDNN_OPCHECK_RUN(LRNGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
3 changes: 2 additions & 1 deletion src/operator/nn/pooling.cc
@@ -270,7 +270,8 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,


if (SupportMKLDNN(inputs[0])
- && SupportMKLDNNPooling(param, inputs[0].shape())) {
+ && SupportMKLDNNPooling(param, inputs[0].shape())
+ && ctx.need_grad) {
const NDArray &out_grad = inputs[0];
const NDArray *workspace = nullptr;
const NDArray *in_data = nullptr;
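Each hunk above applies the same gate: the MKLDNN backward kernel runs only when ctx.need_grad is set, and otherwise the operator falls back to the existing native CPU gradient code. As a rough sanity check, and only as a sketch (it assumes a convolution-only block, whose gradient should not depend on the train mode), the two runs can be compared directly:

```python
import mxnet as mx
from mxnet import autograd, gluon
from mxnet.test_utils import assert_almost_equal

net = gluon.nn.Conv2D(channels=8, kernel_size=3)
net.initialize()
x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
x.attach_grad()

def input_grad(train_mode):
    # train_mode=False is the case this PR changes: the backward pass
    # should fall back to the native CPU implementation.
    with autograd.record(train_mode=train_mode):
        y = net(x)
    y.backward(train_mode=train_mode)
    return x.grad.copy()

# A plain convolution has no train/inference behavioural difference,
# so the gradients from the two modes are expected to agree.
assert_almost_equal(input_grad(True).asnumpy(), input_grad(False).asnumpy(),
                    rtol=1e-3, atol=1e-5)
```

This mirrors, at a much smaller scale, what the updated check_hybrid_static_memory test does by looping over train_modes.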
26 changes: 14 additions & 12 deletions tests/python/unittest/test_gluon.py
@@ -1209,7 +1209,7 @@ def test_zero_grad():
    grad = net.collect_params()['test_zero_grad_weight'].grad()
    assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)

-def check_hybrid_static_memory(**kwargs):
+def check_hybrid_static_memory(train_modes, **kwargs):
    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
    x.attach_grad()

@@ -1221,27 +1221,29 @@ def check_hybrid_static_memory(**kwargs):
    net1(x)
    net2(x)

-    def test(net, x):
-        with mx.autograd.record():
+    def test(net, x, train_mode=True):
+        with mx.autograd.record(train_mode=train_mode):
            y = net(x) + net(x)
-        y.backward()
+        y.backward(train_mode=train_mode)

        grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}

        return y, grads

-    y1, grads1 = test(net1, x)
-    y2, grads2 = test(net2, x)
+    for train_mode in train_modes:
+        y1, grads1 = test(net1, x, train_mode)
+        y2, grads2 = test(net2, x, train_mode)

-    assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
-    for key in grads1:
-        assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
+        assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
+        for key in grads1:
+            assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)

@with_seed()
def test_hybrid_static_memory():
-    check_hybrid_static_memory()
-    check_hybrid_static_memory(static_alloc=True)
-    check_hybrid_static_memory(static_alloc=True, static_shape=True)
+    check_hybrid_static_memory(train_modes=[True, False])
+    check_hybrid_static_memory(train_modes=[True, False], static_alloc=True)
+    # TODO: MKLDNN (issue #13445) does not work with static_shape backwards
+    check_hybrid_static_memory(train_modes=[True], static_alloc=True, static_shape=True)

def check_hybrid_static_memory_switching(**kwargs):
net = gluon.model_zoo.vision.get_resnet(