-
Notifications
You must be signed in to change notification settings - Fork 395
/
Copy pathprelu_layer.cu
128 lines (113 loc) · 4.3 KB
/
prelu_layer.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include <algorithm>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
// CUDA kernel for the PReLU forward pass.
//
// For every element i in [0, n):
//   out[i] = in[i]                    if in[i] > 0
//   out[i] = in[i] * slope_data[c]    otherwise
// where c = ((i / dim) % channels) / div_factor.
//
//   n           number of elements to process
//   channels    channel count used to recover the channel of element i
//   dim         number of consecutive elements that share one channel
//   in, out     input / output buffers; the host wrapper passes the same
//               buffer for both when it runs in place, so these pointers are
//               intentionally not marked __restrict__
//   slope_data  negative-side slopes, indexed by c
//   div_factor  divisor applied to the channel index; the host passes
//               `channels` for a shared slope (so c is always 0) and 1 for
//               per-channel slopes
template <typename Dtype>
__global__ void PReLUForward(const int n, const int channels, const int dim,
    const Dtype* in, Dtype* out, const Dtype* slope_data,
    const int div_factor) {
  CUDA_KERNEL_LOOP(index, n) {
    const int channel = (index / dim) % channels;
    const int slope_idx = channel / div_factor;
    // Load once; each element is read before it is written, so aliasing of
    // in/out at the same index is harmless.
    const Dtype x = in[index];
    if (x > 0) {
      out[index] = x;
    } else {
      out[index] = x * slope_data[slope_idx];
    }
  }
}
// CUDA kernel for the PReLU backward pass with respect to the input.
//
// For every element i in [0, n):
//   out_diff[i] = in_diff[i] * ((in_data[i] > 0)
//                               + (in_data[i] <= 0) * slope_data[c])
// where c = ((i / dim) % channels) / div_factor.
//
//   n           number of elements to process
//   channels    channel count used to recover the channel of element i
//   dim         number of consecutive elements that share one channel
//   in_diff     incoming gradient
//   in_data     forward-pass input values
//   out_diff    gradient written for the input; the host wrapper may pass
//               the same buffer as in_diff when running in place
//   slope_data  negative-side slopes, indexed by c
//   div_factor  divisor applied to the channel index (`channels` for a shared
//               slope, 1 for per-channel slopes, as chosen by the host)
template <typename Dtype>
__global__ void PReLUBackward(const int n, const int channels, const int dim,
    const Dtype* in_diff, const Dtype* in_data, Dtype* out_diff,
    const Dtype* slope_data, const int div_factor) {
  CUDA_KERNEL_LOOP(index, n) {
    const int slope_idx = ((index / dim) % channels) / div_factor;
    const Dtype x = in_data[index];
    // Branch-free on purpose: both indicator terms are evaluated and summed,
    // so the slope is always read and multiplied, even for positive inputs.
    const Dtype positive_term = static_cast<Dtype>(x > 0);
    const Dtype negative_term = (x <= 0) * slope_data[slope_idx];
    out_diff[index] = in_diff[index] * (positive_term + negative_term);
  }
}
// CUDA kernel producing the element-wise contribution to the slope gradient.
//
// For every element i in [0, n):
//   out_diff[i] = in_diff[i] * in_data[i] * (in_data[i] <= 0)
// i.e. the product of incoming gradient and input where the input is
// non-positive, and that product times zero elsewhere. Reducing these values
// into the slope gradient is left to the caller.
//
//   n         number of elements to process
//   in_diff   incoming gradient
//   in_data   forward-pass input values
//   out_diff  per-element result buffer
template <typename Dtype>
__global__ void PReLUParamBackward(const int n, const Dtype* in_diff,
    const Dtype* in_data, Dtype* out_diff) {
  CUDA_KERNEL_LOOP(index, n) {
    const Dtype x = in_data[index];
    const Dtype grad_times_input = in_diff[index] * x;
    out_diff[index] = grad_times_input * (x <= 0);
  }
}
// GPU forward pass: runs PReLUForward over every element of bottom[0] and
// writes the result to top[0].
//
// When top[0] and bottom[0] are the same blob (in-place), the input is first
// copied into bottom_memory_ so that the untouched values are still available
// to Backward_gpu after the kernel overwrites them.
template <typename Dtype>
void PReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  Blob<Dtype>* const input = bottom[0];
  Blob<Dtype>* const output = top[0];

  const Dtype* src = input->gpu_data();
  Dtype* dst = output->mutable_gpu_data();
  const int num_elements = input->count();
  // Number of elements per channel: product of all axes from axis 2 onward.
  const int inner_dim = input->count(2);
  const int num_channels = input->channels();
  const Dtype* slopes = this->blobs_[0]->gpu_data();

  // With a shared slope, dividing the channel index by the channel count maps
  // every element to slope 0; otherwise each channel keeps its own slope.
  int div_factor = 1;
  if (channel_shared_) {
    div_factor = num_channels;
  }

  // Preserve the input before it is overwritten by an in-place forward.
  const bool in_place = (output == input);
  if (in_place) {
    caffe_copy(num_elements, src, bottom_memory_.mutable_gpu_data());
  }

  // NOLINT_NEXT_LINE(whitespace/operators)
  PReLUForward<Dtype><<<CAFFE_GET_BLOCKS(num_elements),
      CAFFE_CUDA_NUM_THREADS>>>(
      num_elements, num_channels, inner_dim, src, dst, slopes, div_factor);
  CUDA_POST_KERNEL_CHECK;
}
// GPU backward pass: accumulates the gradient of the slope parameter and
// computes the gradient with respect to bottom[0].
//
// Order matters: with in-place computation top[0] and bottom[0] are the same
// blob, so writing the bottom diff overwrites the top diff. The parameter
// gradient, which reads the top diff, is therefore computed first.
template <typename Dtype>
void PReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  Blob<Dtype>* const input = bottom[0];
  Blob<Dtype>* const output = top[0];

  const Dtype* bottom_data = input->gpu_data();
  const Dtype* top_diff = output->gpu_diff();
  const int num_elements = input->count();
  // Number of elements per channel: product of all axes from axis 2 onward.
  const int inner_dim = input->count(2);
  const int num_channels = input->channels();

  // In-place forward overwrote the input; use the copy saved in
  // bottom_memory_ by Forward_gpu instead.
  if (output == input) {
    bottom_data = bottom_memory_.gpu_data();
  }

  // ---- Gradient with respect to the slope parameter ----
  if (this->param_propagate_down_[0]) {
    Dtype* slope_diff = this->blobs_[0]->mutable_gpu_diff();
    const int per_sample = num_channels * inner_dim;
    const int num_samples = input->num();
    // Running total used only in the shared-slope case.
    Dtype shared_total = 0.;

    for (int sample = 0; sample < num_samples; ++sample) {
      // Element-wise contribution of this sample, written to the scratch
      // buffer backward_buff_.
      // NOLINT_NEXT_LINE(whitespace/operators)
      PReLUParamBackward<Dtype><<<CAFFE_GET_BLOCKS(per_sample),
          CAFFE_CUDA_NUM_THREADS>>>(
          per_sample, top_diff + output->offset(sample),
          bottom_data + input->offset(sample),
          backward_buff_.mutable_gpu_diff());
      CUDA_POST_KERNEL_CHECK;

      if (!channel_shared_) {
        // Per-channel slopes: treat the scratch buffer as a
        // (channels x inner_dim) matrix, multiply it by multiplier_ and
        // accumulate into slope_diff (the trailing 1. keeps the existing
        // value). NOTE(review): multiplier_ is presumably all ones, making
        // this a per-channel sum -- confirm in the layer setup.
        caffe_gpu_gemv<Dtype>(CblasNoTrans, num_channels, inner_dim, 1.,
            backward_buff_.gpu_diff(), multiplier_.gpu_data(), 1.,
            slope_diff);
        continue;
      }

      // Shared slope: reduce the whole scratch buffer to one scalar and add
      // it to the running total on the host.
      Dtype partial;
      caffe_gpu_dot<Dtype>(per_sample, backward_buff_.gpu_diff(),
          multiplier_.gpu_data(), &partial);
      shared_total += partial;
    }

    if (channel_shared_) {
      // Fold the accumulated total into the slope gradient.
      caffe_gpu_add_scalar(this->blobs_[0]->count(), Dtype(shared_total),
          slope_diff);
    }
  }

  // ---- Gradient with respect to the input ----
  if (!propagate_down[0]) {
    return;
  }
  Dtype* bottom_diff = input->mutable_gpu_diff();
  const Dtype* slopes = this->blobs_[0]->gpu_data();
  // Shared slope maps every channel index to 0; otherwise one slope per
  // channel.
  int div_factor = 1;
  if (channel_shared_) {
    div_factor = num_channels;
  }
  // NOLINT_NEXT_LINE(whitespace/operators)
  PReLUBackward<Dtype><<<CAFFE_GET_BLOCKS(num_elements),
      CAFFE_CUDA_NUM_THREADS>>>(
      num_elements, num_channels, inner_dim, top_diff, bottom_data,
      bottom_diff, slopes, div_factor);
  CUDA_POST_KERNEL_CHECK;
}
// Invokes the GPU instantiation macro for PReLULayer.
// NOTE(review): the macro is defined outside this file; it presumably emits
// explicit instantiations of Forward_gpu/Backward_gpu for the supported
// Dtype values -- confirm against its definition.
INSTANTIATE_LAYER_GPU_FUNCS(PReLULayer);
} // namespace caffe