diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 21b257e40c37..9e30cd8fa628 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -213,6 +213,36 @@ OpAttrs GetLRNBackwardsOp() {
   return attrs;
 }
 
+OpAttrs GetFullyConnectedOp() {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("FullyConnected");
+  attrs.attrs.dict.insert({"num_hidden", "20"});
+  attrs.num_inputs = 3;
+  attrs.num_outputs = 1;
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.input_types = ArrayTypes::Normal |
+      ArrayTypes::MKLDNN |
+      ArrayTypes::NormalReshaped |
+      ArrayTypes::MKLDNNReshaped;
+  attrs.output_types = ArrayTypes::Normal |
+      ArrayTypes::MKLDNN |
+      ArrayTypes::NormalReshaped |
+      ArrayTypes::MKLDNNReshaped;
+  return attrs;
+}
+
+OpAttrs GetFullyConnectedBackwardsOp() {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("_backward_FullyConnected");
+  attrs.attrs.dict.insert({"num_hidden", "20"});
+  attrs.num_inputs = 3;
+  attrs.num_outputs = 3;
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  attrs.requests.insert(OpReqType::kWriteTo);
+  return attrs;
+}
+
 void AssertEqual(const std::vector<NDArray *> &in_arrs,
                  const std::vector<NDArray *> &out_arrs,
                  float rtol = 1e-5, float atol = 1e-8) {
@@ -557,6 +587,125 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
   }
 }
 
+// Computes the second dimension of the FC weight matrix based on the input shape
+uint32_t GetFCWeightDim2(const nnvm::TShape &arr) {
+  uint32_t dim = 1;
+  for (int i = 1; i < arr.ndim(); i++) {
+    dim *= arr[i];
+  }
+  return dim;
+}
+
+void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
+  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
+
+  std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
+  std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
+  std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+
+  std::vector<OpReqType> req(forward_attrs.num_outputs);
+  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, true);
+  std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
+  std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);
+
+  std::string str_hid = const_cast<OpAttrs&>(forward_attrs).attrs.dict["num_hidden"];
+  int num_hid = std::stoi(str_hid);
+
+  if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) {
+    for (int i1 = 0; i1 < in_arrs.size(); i1++) {
+      auto in_arr = in_arrs[i1];
+      auto in_shape = in_arr.arr.shape();
+      if (in_shape.ndim() < 2)
+        continue;
+
+      nnvm::TShape wt_shape(2);
+      wt_shape[0] = num_hid;
+      wt_shape[1] = GetFCWeightDim2(in_shape);
+      NDArray weights(wt_shape, Context());
+      InitDefaultArray(&weights, false);
+
+      nnvm::TShape bias_shape(1);
+      bias_shape[0] = num_hid;
+      NDArray bias(bias_shape, Context());
+      InitDefaultArray(&bias, false);
+
+      inputs[0] = &in_arr.arr;
+      inputs[1] = &weights;
+      inputs[2] = &bias;
+
+      nnvm::TShape out_shape(2);
+      out_shape[0] = in_shape[0];
+      out_shape[1] = num_hid;
+
+      for (int i = 0; i < forward_attrs.num_outputs; i++) {
+        out_arrs[i] =
+            GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types);
+        ex_out_arrs[i] =
+            GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types);
+      }
+
+      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+        for (int i = 0; i < forward_attrs.num_outputs; i++) {
+          req[i] = kWriteTo;
+          outputs[i] = &out_arrs[i][output_i].arr;
+          ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
+        }
+        Imperative::Get()->set_is_training(true);
+
+        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+        Imperative::Get()->InvokeOp(
+            Context(), forward_attrs.attrs, inputs, outputs, req,
+            DispatchMode::kFCompute, mxnet::OpStatePtr());
+        Imperative::Get()->InvokeOp(
+            Context(), forward_attrs.attrs, inputs, ex_outputs, req,
+            DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+        Engine::Get()->WaitForAll();
+        AssertEqual(outputs, ex_outputs);
+
+        // The backward pass runs in the same iteration since it needs the forward output.
+        backwards_input[0] = outputs[0];  // output grad
+        backwards_input[1] = inputs[0];   // input
+        backwards_input[2] = inputs[1];   // weights
+
+        auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true)[i1];
+        NDArray back_weights(wt_shape, Context());
+        NDArray back_bias(bias_shape, Context());
+        backwards_outputs[0] = &tmp_output.arr;
+        backwards_outputs[1] = &back_weights;
+        backwards_outputs[2] = &back_bias;
+
+        auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true)[i1];
+        NDArray back_ex_weights(wt_shape, Context());
+        NDArray back_ex_bias(bias_shape, Context());
+        backwards_ex_outputs[0] = &tmp_output2.arr;
+        backwards_ex_outputs[1] = &back_ex_weights;
+        backwards_ex_outputs[2] = &back_ex_bias;
+
+        for (int i = 0; i < backwards_attrs.num_outputs; i++)
+          back_req[i] = kWriteTo;
+
+        std::cout << "Backwards: ";
+        PrintVerifyMsg(out_arrs[0][output_i], tmp_output);
+        Imperative::Get()->InvokeOp(
+            Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
+            back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+        Imperative::Get()->InvokeOp(
+            Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
+            back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+        Engine::Get()->WaitForAll();
+        AssertEqual(backwards_outputs, backwards_ex_outputs);
+      }
+    }
+  }
+}
+
 void TestPoolingOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
   std::vector<NDArray*> inputs(forward_attrs.num_inputs);
   std::vector<NDArray*> outputs(forward_attrs.num_outputs);
@@ -717,6 +866,12 @@ TEST(IMPERATIVE, LRNOp) {
   TestOpEx(forward_attrs, backwards_attrs);
 }
 
+TEST(IMPERATIVE, FullyConnectedOp) {
+  OpAttrs forward_attrs = GetFullyConnectedOp();
+  OpAttrs backwards_attrs = GetFullyConnectedBackwardsOp();
+  TestFullyConnectedOp(forward_attrs, backwards_attrs);
+}
+
 TEST(IMPERATIVE, PoolingOp) {
   for (int dim = 2; dim < 4; dim++) {
     for (int kernel = 1; kernel < 4; kernel++) {
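
For reference, the weight shape built in TestFullyConnectedOp follows FullyConnected's flattening rule: an input of shape (N, d1, ..., dk) is treated as (N, d1*...*dk), so the weight matrix must be (num_hidden, d1*...*dk). Below is a minimal standalone sketch of the same rule, independent of the test helpers; the function name FlattenedFeatureDim is illustrative, not part of the codebase.

#include <cstdint>
#include <vector>

// Multiplies every dimension after the batch axis, mirroring what
// GetFCWeightDim2 computes from an nnvm::TShape in the diff above.
uint32_t FlattenedFeatureDim(const std::vector<uint32_t> &shape) {
  uint32_t dim = 1;
  for (size_t i = 1; i < shape.size(); ++i)
    dim *= shape[i];
  return dim;
}

// Example: a (32, 3, 28, 28) input gives 3 * 28 * 28 = 2352, so with
// num_hidden = 20 the weights are (20, 2352) and the output is (32, 20).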
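
The core pattern in both test loops is a parity check: identical inputs go through the reference path (DispatchMode::kFCompute) and the MKLDNN path (DispatchMode::kFComputeEx), and the outputs must agree within rtol/atol. A hedged sketch of such a comparison on raw float buffers follows; AssertNearlyEqual is an illustrative stand-in for the file's AssertEqual, using the same default tolerances.

#include <cassert>
#include <cmath>
#include <vector>

// Element-wise closeness check with the defaults AssertEqual declares
// above (rtol = 1e-5, atol = 1e-8).
void AssertNearlyEqual(const std::vector<float> &ref,
                       const std::vector<float> &mkldnn_out,
                       float rtol = 1e-5f, float atol = 1e-8f) {
  assert(ref.size() == mkldnn_out.size());
  for (size_t i = 0; i < ref.size(); ++i)
    assert(std::fabs(ref[i] - mkldnn_out[i]) <=
           atol + rtol * std::fabs(mkldnn_out[i]));
}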