diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index fbe4d0536789..7d318773f354 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -43,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp, if (threadIdx.x == 0) { int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (pos < (rows * width)) out[pos] = sum; + if (pos < width) out[pos] = sum; } }