Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
CR
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Sep 6, 2019
1 parent b0ca25f commit 1f3afcc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 20 deletions.
25 changes: 6 additions & 19 deletions src/operator/nn/fully_connected-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,10 @@ namespace fullc {
enum FullyConnectedOpInputs {kData, kWeight, kBias};
enum FullyConnectedOpResource {kTempSpace};
enum FullyConnectedOpOutputs {kOut};
enum FullyConnectedGradGradOutputs {
kOyGrad,
kXGradGrad,
kWGradGrad,
kBGradGrad
};
enum Inputs {
kOxGrad,
kOwGrad,
};
enum InputsBias {
kObGrad = 2,
kOyBias,
};
enum InputsNoBias {
kOy = 2,
};
enum FullyConnectedGradGradOutputs { kOyGrad, kXGradGrad, kWGradGrad, kBGradGrad };
enum GradGradInputs { kOxGrad, kOwGrad, };
enum GradGradInputsBias { kObGrad = 2, kOyBias, };
enum GradGradInputsNoBias { kOy = 2, };
} // namespace fullc

namespace quantized_fullc {
Expand Down Expand Up @@ -363,8 +350,8 @@ void FullyConnectedGradGradCompute(const nnvm::NodeAttrs& attrs,
x_grad_grad = FlattenAs2DTail<xpu, DType>(outputs[kXGradGrad], ctx);
w_grad_grad = FlattenAs2DTail<xpu, DType>(outputs[kWGradGrad], ctx);
}
linalg_gemm(o_y, o_w_grad, x_grad_grad, false, false, stream);
linalg_gemm(o_y, o_x_grad, w_grad_grad, true, false, stream);
linalg_gemm(o_y, o_w_grad, x_grad_grad, false, false, stream, req[kXGradGrad]);
linalg_gemm(o_y, o_x_grad, w_grad_grad, true, false, stream, req[kWGradGrad]);
// 3rd order not supported
Fill(stream, o_y_grad_blob, kWriteTo, static_cast<DType>(0));
/* TODO(larroy) bias is not supported yet as there's no bias input to backward. Bias grad grad is
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.test_utils import assert_almost_equal, same
from mxnet.test_utils import assert_almost_equal
from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
from common import (setup_module, with_seed, assertRaises, teardown,
assert_raises_cudnn_not_satisfied)
Expand Down

0 comments on commit 1f3afcc

Please sign in to comment.