adding unit test for MKLDNN FullyConnected operator (#12985)
* adding unit test for MKLDNN FullyConnected operator

* removing mkldnn filter

* removing mkldnn filter
mseth10 authored and anirudh2290 committed Nov 15, 2018
1 parent 7541021 commit cf991ff
Showing 1 changed file with 155 additions and 0 deletions: tests/cpp/operator/mkldnn_operator_test.cc
@@ -213,6 +213,36 @@ OpAttrs GetLRNBackwardsOp() {
return attrs;
}

OpAttrs GetFullyConnectedOp() {
OpAttrs attrs;
attrs.attrs.op = Op::Get("FullyConnected");
attrs.attrs.dict.insert({"num_hidden", "20"});
attrs.num_inputs = 3;
attrs.num_outputs = 1;
attrs.attrs.op->attr_parser(&attrs.attrs);
attrs.requests.insert(OpReqType::kWriteTo);
attrs.input_types = ArrayTypes::Normal |
ArrayTypes::MKLDNN |
ArrayTypes::NormalReshaped |
ArrayTypes::MKLDNNReshaped;
attrs.output_types = ArrayTypes::Normal |
ArrayTypes::MKLDNN |
ArrayTypes::NormalReshaped |
ArrayTypes::MKLDNNReshaped;
return attrs;
}
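
For reference, both dispatch paths exercised by this test are expected to agree on the FullyConnected forward result, out = data * weight^T + bias, where weight has shape (num_hidden, input_dim). A minimal standalone sketch of that reference math follows (hypothetical helper, not part of the committed test file; the name FCForwardReference is illustrative):

#include <cstddef>
#include <vector>

// Naive FullyConnected reference: out[n][h] = bias[h] + sum_d data[n][d] * weight[h][d].
// data is (N, D) row-major, weight is (num_hidden, D), bias is (num_hidden).
std::vector<float> FCForwardReference(const std::vector<float> &data,
                                      const std::vector<float> &weight,
                                      const std::vector<float> &bias,
                                      size_t N, size_t D, size_t num_hidden) {
  std::vector<float> out(N * num_hidden);
  for (size_t n = 0; n < N; ++n) {
    for (size_t h = 0; h < num_hidden; ++h) {
      float acc = bias[h];
      for (size_t d = 0; d < D; ++d)
        acc += data[n * D + d] * weight[h * D + d];
      out[n * num_hidden + h] = acc;
    }
  }
  return out;
}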

OpAttrs GetFullyConnectedBackwardsOp() {
OpAttrs attrs;
attrs.attrs.op = Op::Get("_backward_FullyConnected");
attrs.attrs.dict.insert({"num_hidden", "20"});
attrs.num_inputs = 3;
attrs.num_outputs = 3;
attrs.attrs.op->attr_parser(&attrs.attrs);
attrs.requests.insert(OpReqType::kWriteTo);
return attrs;
}
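
As the assignments to backwards_input and backwards_outputs in TestFullyConnectedOp below show, the backward op's three inputs and three outputs line up as:

// _backward_FullyConnected argument layout assumed by this test:
//   inputs  (num_inputs  = 3): { out_grad, data, weight }
//   outputs (num_outputs = 3): { data_grad, weight_grad, bias_grad }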

void AssertEqual(const std::vector<NDArray *> &in_arrs,
const std::vector<NDArray *> &out_arrs,
float rtol = 1e-5, float atol = 1e-8) {
@@ -557,6 +587,125 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
}
}

// Computes second dimension of FC weight matrix based on input shape
uint32_t GetFCWeightDim2(const nnvm::TShape &arr) {
uint32_t dim = 1;
for (int i = 1; i < arr.ndim(); i++) {
dim *= arr[i];
}
return dim;
}
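
For example (illustrative shapes): FullyConnected flattens every dimension after the first, so the weight matrix's second dimension is the product of the trailing input dimensions.

// GetFCWeightDim2 on a (2, 3, 4) input returns 3 * 4 = 12, i.e. the input is
// treated as (2, 12) and, with num_hidden = 20, the weight shape becomes (20, 12).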

void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
std::vector<NDArray*> inputs(forward_attrs.num_inputs);
std::vector<NDArray*> outputs(forward_attrs.num_outputs);
std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);

std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);

std::vector<OpReqType> req(forward_attrs.num_outputs);
std::vector<OpReqType> back_req(backwards_attrs.num_outputs);

TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, true);
std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);

std::string str_hid = forward_attrs.attrs.dict.at("num_hidden");
int num_hid = std::stoi(str_hid);

if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) {
for (size_t i1 = 0; i1 < in_arrs.size(); i1++) {
auto in_arr = in_arrs[i1];
auto in_shape = in_arr.arr.shape();
if (in_shape.ndim() < 2)
continue;

nnvm::TShape wt_shape(2);
wt_shape[0] = num_hid;
wt_shape[1] = GetFCWeightDim2(in_shape);
NDArray weights(wt_shape, Context());
InitDefaultArray(&weights, false);

nnvm::TShape bias_shape(1);
bias_shape[0] = num_hid;
NDArray bias(bias_shape, Context());
InitDefaultArray(&bias, false);

inputs[0] = &in_arr.arr;
inputs[1] = &weights;
inputs[2] = &bias;

nnvm::TShape out_shape(2);
out_shape[0] = in_shape[0];
out_shape[1] = num_hid;

for (int i = 0; i < forward_attrs.num_outputs; i++) {
out_arrs[i] =
GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types);
ex_out_arrs[i] =
GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types);
}

for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
for (int i = 0; i < forward_attrs.num_outputs; i++) {
req[i] = kWriteTo;
outputs[i] = &out_arrs[i][output_i].arr;
ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
}
Imperative::Get()->set_is_training(true);

PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
Imperative::Get()->InvokeOp(
Context(), forward_attrs.attrs, inputs, outputs, req,
DispatchMode::kFCompute, mxnet::OpStatePtr());
Imperative::Get()->InvokeOp(
Context(), forward_attrs.attrs, inputs, ex_outputs, req,
DispatchMode::kFComputeEx, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
AssertEqual(outputs, ex_outputs);

// The backward pass is tested at the same time, since it needs the forward output.
backwards_input[0] = outputs[0]; // output grad
backwards_input[1] = inputs[0]; // input
backwards_input[2] = inputs[1]; // weights

auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true)[i1];
NDArray back_weights(wt_shape, Context());
NDArray back_bias(bias_shape, Context());
backwards_outputs[0] = &tmp_output.arr;
backwards_outputs[1] = &back_weights;
backwards_outputs[2] = &back_bias;

auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true)[i1];
NDArray back_ex_weights(wt_shape, Context());
NDArray back_ex_bias(bias_shape, Context());
backwards_ex_outputs[0] = &tmp_output2.arr;
backwards_ex_outputs[1] = &back_ex_weights;
backwards_ex_outputs[2] = &back_ex_bias;

for (int i = 0; i < backwards_attrs.num_outputs; i++)
back_req[i] = kWriteTo;

std::cout << "Backwards: ";
PrintVerifyMsg(out_arrs[0][output_i], tmp_output);
Imperative::Get()->InvokeOp(
Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
Imperative::Get()->InvokeOp(
Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
AssertEqual(backwards_outputs, backwards_ex_outputs);
}
}
}
}
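
Note that the backward pass is checked in the same loop: the forward output doubles as the incoming gradient, and fresh arrays of the original input shape (fetched again via GetTestInputArrays) receive the data gradient, so kFCompute and kFComputeEx can be compared with AssertEqual for every input/output layout combination.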

void TestPoolingOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
std::vector<NDArray*> inputs(forward_attrs.num_inputs);
std::vector<NDArray*> outputs(forward_attrs.num_outputs);
@@ -717,6 +866,12 @@ TEST(IMPERATIVE, LRNOp) {
TestOpEx(forward_attrs, backwards_attrs);
}

TEST(IMPERATIVE, FullyConnectedOp) {
OpAttrs forward_attrs = GetFullyConnectedOp();
OpAttrs backwards_attrs = GetFullyConnectedBackwardsOp();
TestFullyConnectedOp(forward_attrs, backwards_attrs);
}
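
Like the other cases in this file, the new test can be run in isolation through the standard googletest filter, e.g. by passing --gtest_filter=IMPERATIVE.FullyConnectedOp to the C++ test binary (the binary's name and location depend on how MXNet was built).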

TEST(IMPERATIVE, PoolingOp) {
for (int dim = 2; dim < 4; dim++) {
for (int kernel = 1; kernel < 4; kernel++) {