forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MultiMarginCriterion.cu
122 lines (104 loc) · 3.46 KB
/
MultiMarginCriterion.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <THCUNN/THCUNN.h>
#include <THCUNN/common.h>
#include <TH/THHalf.h>
#include <THC/THCNumerics.cuh>
#include <THC/THCTensor.hpp>
#include <THC/THCStorage.hpp>
#define MULTIMARGIN_THREADS 128
// Forward pass of the multi-class margin (hinge) criterion, one block per row.
//
// Launch layout assumed by the indexing below: block k handles row k of
// `input` (row length `dim`) and writes output[k]. blockDim.x must not exceed
// MULTIMARGIN_THREADS, because each thread owns one slot of the shared `buffer`.
//
// For row k with y = target[k] (used as a 0-based class index) computes
//   sum over i != y with z_i > 0 of  w * z_i^P,   z_i = margin - x[y] + x[i]
// where P == 1 gives the linear term and any other P the squared term, and
// w = weights[y] when `weights` is non-null (1 otherwise). The sum is divided
// by dim, and additionally by nframe when sizeAverage is set.
//
// Preconditions (not checked here): 0 <= target[k] < dim.
template <int P, typename Dtype, typename Acctype>
__global__ void cunn_MultiMarginCriterion_updateOutput_kernel(Dtype *output, Dtype *input, THCIndex_t *target, Dtype *weights, int nframe, int dim, bool sizeAverage, Dtype margin)
{
  __shared__ Acctype buffer[MULTIMARGIN_THREADS];
  int k = blockIdx.x;
  // Widen before multiplying: k*dim in int overflows once nframe*dim > INT_MAX.
  Dtype *input_k = input + static_cast<long long>(k) * dim;
  Dtype *output_k = output + k;
  int target_k = ((int)target[k]);
  Dtype input_target_k = input_k[target_k];

  int i_start = threadIdx.x;
  int i_end = dim;
  int i_step = blockDim.x;

  // Each thread accumulates a private partial sum over a strided subset of classes.
  buffer[threadIdx.x] = 0;
  for (int i = i_start; i < i_end; i += i_step)
  {
    // The target class itself contributes no margin term.
    if (i == target_k)
      continue;
    Dtype z = margin - input_target_k + input_k[i];
    if (z > 0) {
      Dtype h = (P == 1) ? z : z*z;
      if (weights)
        h *= weights[target_k];
      buffer[threadIdx.x] += h;
    }
  }
  // Make every partial sum visible before thread 0 reads them.
  __syncthreads();

  // Reduce: thread 0 serially sums the per-thread partials.
  if (threadIdx.x == 0)
  {
    Acctype sum = 0;
    for (int i = 0; i < (int)blockDim.x; i++)
      sum += buffer[i];
    // Do all scaling in Acctype and convert once, so a reduced-precision Dtype
    // (e.g. half) is rounded a single time and output is written only once.
    Acctype result = sum / dim;
    if (sizeAverage)
      result /= nframe;
    *output_k = ScalarConvert<Acctype, Dtype>::to(result);
  }
}
// Backward pass of the multi-class margin (hinge) criterion, one block per row.
//
// Launch layout assumed by the indexing below: block k handles row k of
// `input` / `gradInput` (row length `dim`). blockDim.x must not exceed
// MULTIMARGIN_THREADS, because each thread owns one slot of the shared `buffer`.
//
// With y = target[k] (0-based class index), z_i = margin - x[y] + x[i], and
//   g = 1 / (nframe * dim)  if sizeAverage && reduce,  else 1 / dim,
// each non-target class i gets the gradient
//   g (P == 1) or 2*g*z_i (otherwise)   when z_i > 0, else 0.
// That value is multiplied by weights[y] when `weights` is non-null. The target
// class gets the negated sum of the non-target gradients. Finally every entry of
// the row is multiplied by the incoming gradient: gradOutput[0] when `reduce`
// is non-zero, gradOutput[k] otherwise.
//
// Preconditions (not checked here): 0 <= target[k] < dim.
template <int P, typename Dtype, typename Acctype>
__global__ void cunn_MultiMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
                                                                 Dtype *gradOutput,
                                                                 Dtype *input,
                                                                 THCIndex_t *target,
                                                                 Dtype *weights,
                                                                 int nframe,
                                                                 int dim,
                                                                 bool sizeAverage,
                                                                 Dtype margin,
                                                                 int reduce)
{
  __shared__ Acctype buffer[MULTIMARGIN_THREADS];
  int k = blockIdx.x;
  // Widen before multiplying: k*dim in int overflows once nframe*dim > INT_MAX.
  long long row_offset = static_cast<long long>(k) * dim;
  Dtype *input_k = input + row_offset;
  Dtype *gradInput_k = gradInput + row_offset;
  int target_k = ((int)target[k]);
  Dtype input_target_k = input_k[target_k];

  // With reduction one scalar gradOutput applies to every row; otherwise each
  // row has its own entry.
  Dtype *gradOutput_k = gradOutput;
  if (!reduce) {
    gradOutput_k += k;
  }

  // Normalisation factor. Multiply in Acctype so nframe*dim cannot overflow
  // int, and avoid double-precision literals in the arithmetic.
  Acctype g = (sizeAverage && reduce)
      ? Acctype(1) / (static_cast<Acctype>(nframe) * static_cast<Acctype>(dim))
      : Acctype(1) / static_cast<Acctype>(dim);

  int i_start = threadIdx.x;
  int i_end = dim;
  int i_step = blockDim.x;

  // Each thread writes the gradient for its strided subset of non-target
  // classes and accumulates their negated sum (the target's gradient share).
  buffer[threadIdx.x] = 0;
  for (int i = i_start; i < i_end; i += i_step)
  {
    if (i == target_k)
      continue;
    Dtype z = margin - input_target_k + input_k[i];
    if (z > 0)
    {
      Dtype h = ScalarConvert<Acctype, Dtype>::to((P == 1) ? g : 2*g*z);
      if (weights)
        h *= weights[target_k];
      buffer[threadIdx.x] -= h;
      gradInput_k[i] = h;
    }
    else
      gradInput_k[i] = ScalarConvert<int, Dtype>::to(0);
  }
  // Make every partial sum visible before thread 0 reads them.
  __syncthreads();

  // Reduce: thread 0 writes the target-class gradient.
  if (threadIdx.x == 0)
  {
    Acctype gradInput_target_k = 0;
    for (int i = 0; i < (int)blockDim.x; i++)
      gradInput_target_k += buffer[i];
    gradInput_k[target_k] = ScalarConvert<Acctype, Dtype>::to(gradInput_target_k);
  }
  // The thread that visits index target_k in the scaling loop below is
  // generally not thread 0. Without this barrier it could scale
  // gradInput_k[target_k] before thread 0 has written it, leaving the target
  // gradient unscaled by gradOutput. Reached by all threads (not divergent).
  __syncthreads();

  // Chain rule: scale the whole row by the incoming gradient.
  for (int i = i_start; i < i_end; i += i_step)
  {
    gradInput_k[i] *= *gradOutput_k;
  }
}
#include <THCUNN/generic/MultiMarginCriterion.cu>
#include <THC/THCGenerateFloatTypes.h>
#undef MULTIMARGIN_THREADS