Skip to content

Commit

Permalink
Merge pull request apache#87 from tornadomeet/master
Browse files Browse the repository at this point in the history
[RFC] add ignore-label for image segmentation
  • Loading branch information
tqchen committed Dec 20, 2015
2 parents 120acae + 4e32830 commit 47521c6
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 2 deletions.
46 changes: 44 additions & 2 deletions mshadow/cuda/tensor_gpu-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ __global__ void SoftmaxKernel(DstPlan dst, SrcPlan src, index_t xmax) {
}
}
}

template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> &dst,
const Tensor<gpu, 2, DType> &src) {
Expand Down Expand Up @@ -311,6 +312,34 @@ __global__ void Softmax3DGradKernel(Tensor<gpu, 3, DType> dst,
}
}

/*!
 * \brief softmax gradient over dim-1 of a 3D tensor, with an ignore label.
 *  Launch layout: one block per outer slice y (blockIdx.x); the
 *  (1 << n_bits) threads of a block stride over the trailing dimension.
 *  Columns whose label equals ignore_label receive a zero gradient;
 *  all other columns get dst = src - one_hot(label).
 */
template<int n_bits, typename DType>
__global__ void Softmax3DGradKernel(Tensor<gpu, 3, DType> dst,
                                    const Tensor<gpu, 3, DType> src,
                                    const Tensor<gpu, 2, DType> label,
                                    DType ignore_label) {
  const index_t xmax = dst.size(1);
  const index_t nmax = dst.size(2);
  const unsigned n_size = 1 << n_bits;
  const int y = blockIdx.x;
  for (index_t col = threadIdx.x; col < nmax; col += n_size) {
    const int target = static_cast<int>(label[y][col]);
    if (target == static_cast<int>(ignore_label)) {
      // ignored column: zero the gradient for every class
      for (index_t i = 0; i < xmax; ++i) {
        dst[y][i][col] = 0.0f;
      }
      continue;
    }
    // dst = src - one_hot(target) along the class dimension
    for (index_t i = 0; i < xmax; ++i) {
      if (static_cast<int>(i) == target) {
        dst[y][i][col] = src[y][i][col] - 1.0f;
      } else {
        dst[y][i][col] = src[y][i][col];
      }
    }
  }
}

template<int n_bits, typename DType>
__global__ void Softmax3DKernel(Tensor<gpu, 3, DType> dst,
const Tensor<gpu, 3, DType> src) {
Expand All @@ -337,7 +366,6 @@ __global__ void Softmax3DKernel(Tensor<gpu, 3, DType> dst,
}
}


template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> &dst,
const Tensor<gpu, 3, DType> &src) {
Expand All @@ -349,7 +377,6 @@ inline void Softmax(Tensor<gpu, 3, DType> &dst,
Softmax3DKernel<kBaseThreadBits, DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src);
}


template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 3, DType> &dst,
const Tensor<gpu, 3, DType> &src,
Expand All @@ -364,6 +391,21 @@ inline void SoftmaxGrad(Tensor<gpu, 3, DType> &dst,
Softmax3DGradKernel<kBaseThreadBits, DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src, label);
}

/*!
 * \brief launch the ignore-label softmax-gradient kernel on a 3D tensor.
 *  Grid: one block per dst.size(0) slice; block: kBaseThreadNum threads
 *  striding the trailing dimension. Shapes are validated before launch.
 * \param dst destination gradient, same shape as src
 * \param src softmax output
 * \param label target indices, shape (dst.size(0), dst.size(2))
 * \param ignore_label label value whose gradient is forced to zero
 */
template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 3, DType> &dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label,
                        const DType &ignore_label) {
  CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch";
  CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch";
  CHECK_EQ(dst.size(2), label.size(1)) << "SoftmaxGrad: label shape mismatch";
  dim3 dimBlock(kBaseThreadNum);
  dim3 dimGrid(dst.size(0));
  CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad");
  cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
  Softmax3DGradKernel<kBaseThreadBits, DType>
      <<<dimGrid, dimBlock, 0, stream>>>(dst, src, label, ignore_label);
}

} // namespace cuda
} // namespace mshadow
#endif // MSHADOW_CUDA_TENSOR_GPU_INL_CUH_
25 changes: 25 additions & 0 deletions mshadow/tensor_cpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,31 @@ inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
}
}

/*!
 * \brief CPU softmax gradient for a 3D tensor with an ignore label.
 *  Computes dst = src - one_hot(label) along dimension 1; any column
 *  whose label equals ignore_label gets a zero gradient instead.
 * \param dst destination gradient, shape (y, x, n), same shape as src
 * \param src softmax output
 * \param label target class indices, shape (y, n)
 * \param ignore_label label value whose gradient is forced to zero
 */
template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
                        const Tensor<cpu, 3, DType> &src,
                        const Tensor<cpu, 2, DType> &label,
                        const DType &ignore_label) {
  // y outermost so label[y][n] is scanned row by row
  for (index_t y = 0; y < dst.size(0); ++y) {
    for (index_t n = 0; n < dst.size(2); ++n) {
      // signed int, not index_t: a negative ignore_label (e.g. -1) must
      // compare correctly; this also matches the GPU kernel's `int k`
      const int k = static_cast<int>(label[y][n]);
      if (k == static_cast<int>(ignore_label)) {
        for (index_t x = 0; x < dst.size(1); ++x) {
          dst[y][x][n] = 0.0f;
        }
      } else {
        for (index_t x = 0; x < dst.size(1); ++x) {
          if (static_cast<int>(x) == k) {
            dst[y][x][n] = src[y][x][n] - 1.0f;
          } else {
            dst[y][x][n] = src[y][x][n];
          }
        }
      }
    }
  }
}

template<typename DType>
inline void Softmax(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 2, DType> &energy) {
Expand Down
8 changes: 8 additions & 0 deletions mshadow/tensor_gpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ inline void SoftmaxGrad(Tensor<gpu, 3, DType> dst,
cuda::SoftmaxGrad(dst, src, label);
}

/*!
 * \brief softmax gradient of a 3D tensor with an ignore label; columns
 *  whose label equals ignore_label get a zero gradient. Delegates to the
 *  CUDA implementation in namespace cuda.
 * \param dst destination gradient tensor
 * \param src softmax output tensor, same shape as dst
 * \param label target class indices
 * \param ignore_label label value whose gradient is forced to zero
 */
template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 3, DType> dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label,
                        const DType &ignore_label) {
  cuda::SoftmaxGrad(dst, src, label, ignore_label);
}

} // namespace mshadow
#endif // __CUDACC__
#endif // MSHADOW_TENSOR_GPU_INL_H_

0 comments on commit 47521c6

Please sign in to comment.