diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index 3843c2b6ea8b..959016bf10e3 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -173,7 +173,7 @@ void launch_dequantize(T* output, unsigned thd_cnt = (hidden_dim - 1) / threads + 1; hid_cnt = hid_cnt > 0 ? hid_cnt : 1; - unsigned blocks = output_size / hid_cnt / groups; + unsigned blocks = (output_size + hid_cnt * groups - 1) / (hid_cnt * groups); dim3 block_dims(threads); dim3 grid_dims(groups, blocks);