This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix lint
ptrendx committed Jul 15, 2019
1 parent: e21b046 · commit: 4c0a179
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/operator/nn/softmax-inl.h
@@ -233,7 +233,7 @@ __device__ inline T warp_reduce(T value, OP redfun) {
 
 template <typename OP>
 __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
-  float v = float(value);
+  float v = static_cast<float>(value);
   v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
   v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
   v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
@@ -288,7 +288,8 @@ __global__ void softmax_compute_kernel2(const DType *in, OType *out, const index
     __syncthreads();
   }
   if (my_id < warp_size) {
-    AType my_value = warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return ::max(x, y); });
+    AType my_value = warp_reduce(scratch[threadIdx.x],
+                                 [](AType x, AType y) { return ::max(x, y); });
     scratch[threadIdx.x] = my_value;
   }
   __syncthreads();
@@ -311,7 +312,8 @@ __global__ void softmax_compute_kernel2(const DType *in, OType *out, const index
     __syncthreads();
   }
   if (my_id < warp_size) {
-    AType my_value = warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return x + y;});
+    AType my_value = warp_reduce(scratch[threadIdx.x],
+                                 [](AType x, AType y) { return x + y;});
     scratch[threadIdx.x] = my_value;
   }
   __syncthreads();
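For context, the lines being rewrapped implement a warp-level tree reduction: each of the 32 lanes repeatedly combines its value with a lane 16, 8, 4, 2, then 1 positions away via __shfl_down_sync, so lane 0 ends up holding the reduction of the whole warp, and the half-precision overload promotes to float first so the shuffles and the reduction run in single precision. The sketch below is a minimal standalone illustration of that pattern, not the MXNet source: the kernel name reduce_demo, the output pointers, and the launch shape are made up for the example, while warp_reduce, __shfl_down_sync, and the max/sum lambdas mirror the diff.

#include <cuda_fp16.h>

// Warp-wide reduction over 32 lanes: after the five shuffle/redfun steps,
// lane 0 holds redfun applied across every lane's input value.
template <typename OP>
__device__ inline float warp_reduce(float v, OP redfun) {
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
  return v;
}

// Hypothetical demo kernel: reduces 32 half-precision inputs to their float
// sum and max, promoting each value to float before reducing, as the patched
// half_t overload in the diff does.
__global__ void reduce_demo(const __half *in, float *sum_out, float *max_out) {
  float v = __half2float(in[threadIdx.x]);
  float s = warp_reduce(v, [](float x, float y) { return x + y; });
  float m = warp_reduce(v, [](float x, float y) { return ::max(x, y); });
  if (threadIdx.x == 0) {  // lane 0 holds the final results
    *sum_out = s;
    *max_out = m;
  }
}

Launched as reduce_demo<<<1, 32>>>(d_in, d_sum, d_max) on device buffers, lane 0 writes the warp-wide sum and max after the five shuffle steps; no shared memory or __syncthreads() is needed because all communication stays within one warp.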
