concat_layer.cpp

#include <vector>

#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {

template <typename Dtype>
void ConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const ConcatParameter& concat_param = this->layer_param_.concat_param();
  CHECK(!(concat_param.has_axis() && concat_param.has_concat_dim()))
      << "Either axis or concat_dim should be specified; not both.";
}

template <typename Dtype>
void ConcatLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int num_axes = bottom[0]->num_axes();
  const ConcatParameter& concat_param = this->layer_param_.concat_param();
  if (concat_param.has_concat_dim()) {
    concat_axis_ = static_cast<int>(concat_param.concat_dim());
    // Don't allow negative indexing for concat_dim, a uint32 -- almost
    // certainly unintended.
    CHECK_GE(concat_axis_, 0) << "casting concat_dim from uint32 to int32 "
        << "produced negative result; concat_dim must satisfy "
        << "0 <= concat_dim < " << kMaxBlobAxes;
    CHECK_LT(concat_axis_, num_axes) << "concat_dim out of range.";
  } else {
    concat_axis_ = bottom[0]->CanonicalAxisIndex(concat_param.axis());
  }
  // Initialize with the first blob.
  vector<int> top_shape = bottom[0]->shape();
  num_concats_ = bottom[0]->count(0, concat_axis_);
  concat_input_size_ = bottom[0]->count(concat_axis_ + 1);
  int bottom_count_sum = bottom[0]->count();
  for (int i = 1; i < bottom.size(); ++i) {
    CHECK_EQ(num_axes, bottom[i]->num_axes())
        << "All inputs must have the same #axes.";
    for (int j = 0; j < num_axes; ++j) {
      if (j == concat_axis_) { continue; }
      CHECK_EQ(top_shape[j], bottom[i]->shape(j))
          << "All inputs must have the same shape, except at concat_axis.";
    }
    bottom_count_sum += bottom[i]->count();
    top_shape[concat_axis_] += bottom[i]->shape(concat_axis_);
  }
  top[0]->Reshape(top_shape);
  CHECK_EQ(bottom_count_sum, top[0]->count());
}
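
// Worked example (illustrative shapes, not asserted anywhere in this file):
// concatenating bottoms of shape (2, 3, 4, 5) and (2, 6, 4, 5) along axis 1
// produces a top of shape (2, 9, 4, 5). num_concats_ = 2 (the product of the
// dimensions before the concat axis) and concat_input_size_ = 4 * 5 = 20 (the
// product of the dimensions after it), so each caffe_copy below moves
// 3 * 20 = 60 or 6 * 20 = 120 values per outer index n.
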
template <typename Dtype>
void ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  Dtype* top_data = top[0]->mutable_cpu_data();
  int offset_concat_axis = 0;
  const int top_concat_axis = top[0]->shape(concat_axis_);
  for (int i = 0; i < bottom.size(); ++i) {
    const Dtype* bottom_data = bottom[i]->cpu_data();
    const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
    // For each outer index n, copy this bottom's contiguous slice of
    // bottom_concat_axis * concat_input_size_ values into the top blob at the
    // running offset along the concat axis.
    for (int n = 0; n < num_concats_; ++n) {
      caffe_copy(bottom_concat_axis * concat_input_size_,
          bottom_data + n * bottom_concat_axis * concat_input_size_,
          top_data + (n * top_concat_axis + offset_concat_axis)
              * concat_input_size_);
    }
    offset_concat_axis += bottom_concat_axis;
  }
}

template <typename Dtype>
void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  const Dtype* top_diff = top[0]->cpu_diff();
  int offset_concat_axis = 0;
  const int top_concat_axis = top[0]->shape(concat_axis_);
  for (int i = 0; i < bottom.size(); ++i) {
    const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
    if (propagate_down[i]) {
      Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
      // Copy the corresponding section of the top diff back out into this
      // bottom's diff; bottoms with propagate_down[i] == false are skipped,
      // but the running offset along the concat axis still advances below.
      for (int n = 0; n < num_concats_; ++n) {
        caffe_copy(bottom_concat_axis * concat_input_size_, top_diff +
            (n * top_concat_axis + offset_concat_axis) * concat_input_size_,
            bottom_diff + n * bottom_concat_axis * concat_input_size_);
      }
    }
    offset_concat_axis += bottom_concat_axis;
  }
}

#ifdef CPU_ONLY
STUB_GPU(ConcatLayer);
#endif

INSTANTIATE_CLASS(ConcatLayer);
REGISTER_LAYER_CLASS(Concat);

}  // namespace caffe
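
For context, a minimal sketch of exercising this layer directly from C++ follows. It is not part of the original file; the blob shapes and names are illustrative, and it assumes a CPU build of this Caffe version, where ConcatLayer is declared in caffe/vision_layers.hpp.

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/vision_layers.hpp"

int main() {
  using caffe::Blob;
  using caffe::ConcatLayer;
  using caffe::LayerParameter;

  // Two 2x3x4x5 inputs concatenated along axis 1 (channels) -> 2x6x4x5 output.
  Blob<float> bottom_a(2, 3, 4, 5), bottom_b(2, 3, 4, 5), top_blob;
  std::vector<Blob<float>*> bottom, top;
  bottom.push_back(&bottom_a);
  bottom.push_back(&bottom_b);
  top.push_back(&top_blob);

  LayerParameter param;
  param.mutable_concat_param()->set_axis(1);
  ConcatLayer<float> layer(param);
  layer.SetUp(bottom, top);    // runs LayerSetUp and Reshape above
  layer.Forward(bottom, top);  // dispatches to Forward_cpu in CPU mode

  // top_blob now holds bottom_a and bottom_b stacked along the channel axis.
  return 0;
}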