diff --git a/lightllm/common/basemodel/triton_kernel/gather_token_id.py b/lightllm/common/basemodel/triton_kernel/gather_token_id.py index c3d95d300..147944747 100644 --- a/lightllm/common/basemodel/triton_kernel/gather_token_id.py +++ b/lightllm/common/basemodel/triton_kernel/gather_token_id.py @@ -16,6 +16,7 @@ def _fwd_kernel_scatter( num_size, HAS_OUT_IS_NONE: tl.constexpr, BLOCK: tl.constexpr, + OLD_VERSION_TRITON: tl.constexpr, ): block_index = tl.program_id(0) block_range = block_index * BLOCK + tl.arange(0, BLOCK) @@ -27,6 +28,8 @@ def _fwd_kernel_scatter( if not HAS_OUT_IS_NONE: cur_has_out = tl.load(b_has_out + block_range, mask=block_mask, other=False) + if OLD_VERSION_TRITON: + cur_has_out = cur_has_out != 0 tl.store( req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, cur_next_token_id, @@ -76,6 +79,7 @@ def scatter_token( num_size=batch_size, HAS_OUT_IS_NONE=b_has_out is None, BLOCK=BLOCK, + OLD_VERSION_TRITON=triton.__version__ < "3.2.0", num_warps=num_warps, num_stages=1, ) diff --git a/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py b/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py index 638ad92d3..6f14c6d5a 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py @@ -125,6 +125,7 @@ def _token_id_counter_update_kernel( batch_size, HAS_MASK: tl.constexpr, BLOCK: tl.constexpr, + OLD_VERSION_TRITON: tl.constexpr, ): block_start_index = tl.program_id(0) * BLOCK @@ -136,6 +137,8 @@ def _token_id_counter_update_kernel( if HAS_MASK: mask = tl.load(mask_ptr + offs, mask=loc_mask, other=False) + if OLD_VERSION_TRITON: + mask = mask != 0 tl.atomic_add( req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n, 1, @@ -170,6 +173,7 @@ def update_req_to_token_id_counter( batch_size=batch_size, HAS_MASK=has_mask, BLOCK=BLOCK, + OLD_VERSION_TRITON=triton.__version__ < "3.2.0", num_warps=1, ) return