Merge pull request apache#81 from tornadomeet/master
fix bug of Softmax3DKernel for multi-output
piiswrong committed Dec 4, 2015
2 parents da39052 + 6d01a40 commit 00ca771
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions mshadow/cuda/tensor_gpu-inl.cuh
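
The fix addresses a thread-coverage bug: the old kernels handled only the single third-dimension position n = threadIdx.x, guarded by if (n < dst.size(2)), so when that dimension exceeded the thread-block width, every position at index blockDim.x or beyond was never computed. The patched kernels instead stride each thread across the whole dimension in steps of n_size = 1 << n_bits. A minimal standalone CUDA sketch of that pattern (the kernel name and buffers are illustrative, not from the patch):

// Each thread covers n, n + n_size, n + 2*n_size, ..., so a dimension
// larger than the block width is still fully processed.
__global__ void StridedLoopSketch(float *out, const float *in, int nmax) {
  const int n_size = blockDim.x;  // the patch uses a compile-time 1 << n_bits
  for (int n_index = threadIdx.x; n_index < nmax; n_index += n_size) {
    out[n_index] = in[n_index];
  }
}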
@@ -289,50 +289,55 @@ inline void SoftmaxGrad(Tensor<gpu, 2, DType> &dst,
                        dst.size(1));
 }
 
-template<typename DType>
+template<int n_bits, typename DType>
 __global__ void Softmax3DGradKernel(Tensor<gpu, 3, DType> dst,
                                     const Tensor<gpu, 3, DType> src,
                                     const Tensor<gpu, 2, DType> label) {
   const index_t xmax = dst.size(1);
+  const index_t nmax = dst.size(2);
+  const unsigned n_size = 1 << n_bits;
   const int y = blockIdx.x;
   const int n = threadIdx.x;
-  if (n < dst.size(2)) {
-    const int k = static_cast<int>(label[y][n]);
+
+  for (index_t n_index = n; n_index < nmax; n_index += n_size) {
+    const int k = static_cast<int>(label[y][n_index]);
     for (index_t i = 0; i < xmax; ++i) {
       if (i == k) {
-        dst[y][i][n] = src[y][i][n] - 1.0f;
+        dst[y][i][n_index] = src[y][i][n_index] - 1.0f;
       } else {
-        dst[y][i][n] = src[y][i][n];
+        dst[y][i][n_index] = src[y][i][n_index];
       }
     }
   }
 }
 
-template<typename DType>
+template<int n_bits, typename DType>
 __global__ void Softmax3DKernel(Tensor<gpu, 3, DType> dst,
                                 const Tensor<gpu, 3, DType> src) {
   const index_t xmax = dst.size(1);
+  const index_t nmax = dst.size(2);
+  const unsigned n_size = 1 << n_bits;
   const int y = blockIdx.x;
   const int n = threadIdx.x;
-  if (n < dst.size(2)) {
-    DType smax = src[y][0][n];
+
+  for (index_t n_index = n; n_index < nmax; n_index += n_size) {
+    DType smax = src[y][0][n_index];
     for (index_t i = 1; i < xmax; ++i) {
-      smax = max(smax, src[y][i][n]);
+      smax = max(smax, src[y][i][n_index]);
     }
     DType ssum = 0.0f;
     for (index_t i = 0; i < xmax; ++i) {
-      DType p = expf(src[y][i][n] - smax);
+      DType p = expf(src[y][i][n_index] - smax);
       ssum += p;
-      dst[y][i][n] = p;
+      dst[y][i][n_index] = p;
     }
     for (index_t i = 0; i < xmax; ++i) {
-      dst[y][i][n] /= ssum;
+      dst[y][i][n_index] /= ssum;
     }
   }
 }
 
 template<typename DType>
 inline void Softmax(Tensor<gpu, 3, DType> &dst,
                     const Tensor<gpu, 3, DType> &src) {
@@ -341,7 +346,7 @@ inline void Softmax(Tensor<gpu, 3, DType> &dst,
   CHECK_EQ(dst.shape_, src.shape_) << "Softmax: shape mismatch";
   CheckLaunchParam(dimGrid, dimBlock, "Softmax");
   cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
-  Softmax3DKernel<DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src);
+  Softmax3DKernel<kBaseThreadBits, DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src);
 }


@@ -356,7 +361,7 @@ inline void SoftmaxGrad(Tensor<gpu, 3, DType> &dst,
   CHECK_EQ(dst.size(2), label.size(1)) << "SoftmaxGrad: label shape mismatch";
   CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad");
   cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
-  Softmax3DGradKernel<DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src, label);
+  Softmax3DGradKernel<kBaseThreadBits, DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src, label);
 }
 
 }  // namespace cuda
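
A note on the new template argument: the launchers now instantiate the kernels with kBaseThreadBits, so the kernels' stride n_size = 1 << n_bits equals the thread-block width. The dim3 setup itself sits in lines collapsed out of this diff; the sketch below assumes mshadow's usual configuration (kBaseThreadBits == 8, one block per index along dim 0), which is an assumption rather than part of the patch:

// Assumed launch setup inside a launcher like Softmax above
// (these lines are collapsed out of the visible diff):
const int kBaseThreadBits = 8;
const int kBaseThreadNum = 1 << kBaseThreadBits;  // 256 threads per block
dim3 dimBlock(kBaseThreadNum);  // threads stride over dst.size(2)
dim3 dimGrid(dst.size(0));      // one block per batch index y
// With n_bits = kBaseThreadBits, thread n also handles n + 256, n + 512, ...
// Before the patch, dst entries at index >= 256 along dim 2 were left unwritten.
Softmax3DKernel<kBaseThreadBits, DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src);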