This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
fully_connected-inl.h
254 lines (235 loc) · 9.91 KB
/
fully_connected-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file fully_connected-inl.h
* \brief fully connected operator and symbol
*/
#ifndef MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
#define MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../linalg.h"
#include "../../common/utils.h"
namespace mxnet {
namespace op {
// Positional indices of the FullyConnected operator's inputs, resources and
// outputs, so kernels can index blob vectors by meaning instead of by number.
// NOTE: these are plain (unscoped) enums inside namespace mxnet::op, so they
// are visible to every translation unit that includes this header.
namespace fullc {
enum FullyConnectedOpInputs {
  kData = 0,
  kWeight = 1,
  kBias = 2
};
enum FullyConnectedOpResource {
  kTempSpace = 0
};
enum FullyConnectedOpOutputs {
  kOut = 0
};
}  // namespace fullc

// Positional indices for the quantized FullyConnected variant: the min/max
// range inputs and the (output, output-min, output-max) outputs.
namespace quantized_fullc {
enum QuantizedFCInputMinMax {
  kDataMin = 0,
  kDataMax = 1,
  kWeightMin = 2,
  kWeightMax = 3,
  kBiasMin = 4,
  kBiasMax = 5
};
enum QuantizedFCOutputs {
  kOut = 0,
  kOutMin = 1,
  kOutMax = 2
};
}  // namespace quantized_fullc
/*!
 * \brief Parsed attributes of the FullyConnected operator.
 *
 * num_hidden -- number of output features (must be >= 1).
 * no_bias    -- when true the operator takes no bias input (default false).
 * flatten    -- when true every axis but the first is collapsed before the
 *               matrix product (default true).
 */
struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
  int num_hidden;
  bool no_bias;
  bool flatten;
  DMLC_DECLARE_PARAMETER(FullyConnectedParam) {
    DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
        .describe("Number of hidden nodes of the output.");
    DMLC_DECLARE_FIELD(no_bias).set_default(false)
        .describe("Whether to disable bias parameter.");
    DMLC_DECLARE_FIELD(flatten).set_default(true)
        .describe("Whether to collapse all but the first axis of the input data tensor.");
  }
  // Two parameter sets are equal when all three attributes match.
  bool operator==(const FullyConnectedParam& other) const {
    if (num_hidden != other.num_hidden) {
      return false;
    }
    if (no_bias != other.no_bias) {
      return false;
    }
    return flatten == other.flatten;
  }
};
template<typename xpu, typename DType>
void FCForward(const OpContext &ctx, const FullyConnectedParam ¶m,
const std::vector<TBlob> &in_data, const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data) {
using namespace mshadow;
using namespace mshadow::expr;
if (req[fullc::kOut] == kNullOp) return;
CHECK_EQ(req[fullc::kOut], kWriteTo);
// TODO(bing): check the BLAS Handle, be careful
// maybe need blas handle from context
// TODO(bing): judge shape to remove flatten op
Stream<xpu> *s = ctx.get_stream<xpu>();
#if defined(__CUDACC__)
CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
<< "Must init CuBLAS handle in stream";
#endif // __CUDACC__
const mxnet::TShape& ishape = in_data[fullc::kData].shape_;
const mxnet::TShape& oshape = out_data[fullc::kOut].shape_;
Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> data, out;
if (!param.flatten) {
data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape.ProdShape(0, ishape.ndim()-1), ishape[ishape.ndim()-1]), s);
out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
} else {
data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
}
CHECK_EQ(data.shape_[1], wmat.shape_[1])
<< "Incomplete weight tensor detected: weight.data().shape[1] != prod(data.data().shape[1:])."
" This is not supported by FCForward. If weight is in row_sparse format,"
" please make sure all row ids are present.";
// Legacy approach shown here for comparison:
// out = dot(data, wmat.T());
linalg_gemm(data, wmat, out, false, true, s);
if (!param.no_bias) {
Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get_with_shape<xpu, 1, DType>(
Shape1(wmat.shape_[0]), s);
CHECK_EQ(bias.shape_[0], wmat.shape_[0])
<< "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
" This is not supported by FCForward. If bias is in row_sparse format, please"
" make sure all row ids are present.";
out += repmat(bias, data.size(0));
}
}
template<typename xpu, typename DType>
void FCBackward(const OpContext &ctx, const FullyConnectedParam ¶m,
const std::vector<TBlob> &out_grad, const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req, const std::vector<TBlob> &in_grad) {
using namespace mshadow;
using namespace mshadow::expr;
// TODO(bing): check the BLAS Handle, be careful
// maybe need blas handle from context
Stream<xpu> *s = ctx.get_stream<xpu>();
const mxnet::TShape& ishape = in_data[fullc::kData].shape_;
const mxnet::TShape& oshape = out_grad[fullc::kOut].shape_;
Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> data, grad, gdata;
if (!param.flatten) {
data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape.ProdShape(0, ishape.ndim()-1), ishape[ishape.ndim()-1]), s);
grad = out_grad[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape.ProdShape(0, ishape.ndim()-1), ishape[ishape.ndim()-1]), s);
} else {
data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
grad = out_grad[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
}
#if defined(__CUDACC__)
CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
<< "Must init CuBLAS handle in stream";
#endif
// backprop
CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
// gradient of weight
Tensor<xpu, 2, DType> gwmat = in_grad[fullc::kWeight].get<xpu, 2, DType>(s);
// Legacy approach shown here for comparison:
// out = Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));
linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
// gradient of bias
if (!param.no_bias) {
Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
Assign(gbias, req[fullc::kBias], sum_rows(grad));
}
// gradient of data
// Legacy approach shown here for comparison:
// Assign(gdata, req[fullc::kData], dot(grad, wmat));
linalg_gemm(grad, wmat, gdata, false, false, s, req[fullc::kData]);
}
/*!
 * \brief Forward compute function for FullyConnected: validates arity and
 *        dispatches FCForward on the dtype of the data input.
 *
 * inputs is [data, weight] when param.no_bias is set, else [data, weight, bias];
 * there is exactly one output (and therefore one req entry). Only float32 and
 * float64 are handled here; float16 and any other dtype abort via LOG(FATAL).
 */
template<typename xpu>
void FullyConnectedCompute(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  uint32_t in_expected = param.no_bias ? 2 : 3;
  CHECK_EQ(inputs.size(), in_expected);
  CHECK_EQ(outputs.size(), 1U);
  // FCForward indexes req[fullc::kOut]; make sure that entry exists.
  CHECK_EQ(req.size(), 1U);
  int dtype = inputs[0].type_flag_;
  switch (dtype) {
    case mshadow::kFloat32:
      FCForward<xpu, float>(ctx, param, inputs, req, outputs);
      break;
    case mshadow::kFloat64:
      FCForward<xpu, double>(ctx, param, inputs, req, outputs);
      break;
    case mshadow::kFloat16:
      // Adjacent literals are concatenated verbatim, so the trailing space matters.
      LOG(FATAL) << "float16 fully connected layer is currently "
                    "only supported by CuDNN version.";
      break;
    default:
      LOG(FATAL) << "Unsupported type " << dtype;
  }
}
/*!
 * \brief Backward compute function for FullyConnected: validates arity and
 *        dispatches FCBackward on the dtype of the incoming gradient.
 *
 * inputs is always [out_grad, data, weight]. outputs/req hold the gradients of
 * [data, weight] when param.no_bias is set, else [data, weight, bias].
 * Only float32 and float64 are handled here; float16 and any other dtype abort
 * via LOG(FATAL).
 */
template<typename xpu>
void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  uint32_t out_expected = param.no_bias ? 2 : 3;
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), out_expected);
  CHECK_EQ(req.size(), out_expected);
  // Split inputs into the output gradient and the forward inputs (data, weight).
  std::vector<TBlob> out_grad{inputs[0]};
  std::vector<TBlob> in_data(inputs.begin() + 1, inputs.end());
  int dtype = inputs[0].type_flag_;
  switch (dtype) {
    case mshadow::kFloat32:
      FCBackward<xpu, float>(ctx, param, out_grad, in_data, req, outputs);
      break;
    case mshadow::kFloat64:
      FCBackward<xpu, double>(ctx, param, out_grad, in_data, req, outputs);
      break;
    case mshadow::kFloat16:
      // Adjacent literals are concatenated verbatim, so the trailing space matters.
      LOG(FATAL) << "float16 fully connected layer is currently "
                    "only supported by CuDNN version.";
      break;
    default:
      LOG(FATAL) << "Unsupported type " << dtype;
  }
}
} // namespace op
} // namespace mxnet
namespace std {
/*!
 * \brief Hash of FullyConnectedParam combining num_hidden, no_bias and flatten,
 *        the same fields compared by FullyConnectedParam::operator==, so equal
 *        parameter sets hash equally.
 *
 * operator() is const, as the standard Hash requirements expect: hashers are
 * commonly invoked through a const reference (e.g. inside unordered containers).
 */
template<>
struct hash<mxnet::op::FullyConnectedParam> {
  size_t operator()(const mxnet::op::FullyConnectedParam& val) const {
    size_t ret = 0;
    ret = dmlc::HashCombine(ret, val.num_hidden);
    ret = dmlc::HashCombine(ret, val.no_bias);
    ret = dmlc::HashCombine(ret, val.flatten);
    return ret;
  }
};
}  // namespace std
#endif // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_