Skip to content

Commit

Permalink
Merge pull request apache#66 from piiswrong/master
Browse files Browse the repository at this point in the history
fixed crash
  • Loading branch information
tqchen committed Oct 24, 2015
2 parents ded43f1 + b11c77b commit 28ffc0a
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions mshadow/cuda/tensor_gpu-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ inline void SoftmaxGrad(Tensor<gpu, 2, DType> &dst,
}

template<typename DType>
__global__ void Softmax3DGradKernel(Tensor<gpu, 3, DType> &dst,
const Tensor<gpu, 3, DType> &src,
const Tensor<gpu, 2, DType> &label) {
__global__ void Softmax3DGradKernel(Tensor<gpu, 3, DType> dst,
const Tensor<gpu, 3, DType> src,
const Tensor<gpu, 2, DType> label) {
const index_t xmax = dst.size(1);
const int y = blockIdx.x;
const int n = threadIdx.x;
Expand All @@ -310,8 +310,8 @@ __global__ void Softmax3DGradKernel(Tensor<gpu, 3, DType> &dst,
}

template<typename DType>
__global__ void Softmax3DKernel(Tensor<gpu, 3, DType> &dst,
const Tensor<gpu, 3, DType> &src) {
__global__ void Softmax3DKernel(Tensor<gpu, 3, DType> dst,
const Tensor<gpu, 3, DType> src) {
const index_t xmax = dst.size(1);
const int y = blockIdx.x;
const int n = threadIdx.x;
Expand All @@ -337,7 +337,7 @@ template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> &dst,
const Tensor<gpu, 3, DType> &src) {
dim3 dimBlock(kBaseThreadNum);
dim3 dimGrid(dst.size(0), dst.size(2));
dim3 dimGrid(dst.size(0));
CHECK_EQ(dst.shape_, src.shape_) << "Softmax: shape mismatch";
CheckLaunchParam(dimGrid, dimBlock, "Softmax");
cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
Expand All @@ -350,7 +350,7 @@ inline void SoftmaxGrad(Tensor<gpu, 3, DType> &dst,
const Tensor<gpu, 3, DType> &src,
const Tensor<gpu, 2, DType> &label) {
dim3 dimBlock(kBaseThreadNum);
dim3 dimGrid(dst.size(0), dst.size(2));
dim3 dimGrid(dst.size(0));
CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch";
CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch";
CHECK_EQ(dst.size(2), label.size(1)) << "SoftmaxGrad: label shape mismatch";
Expand Down

0 comments on commit 28ffc0a

Please sign in to comment.