diff --git a/vllm_ascend/ops/triton/activation/swiglu_quant.py b/vllm_ascend/ops/triton/activation/swiglu_quant.py index 7ec2cbaf36a..73dc27dcd9e 100644 --- a/vllm_ascend/ops/triton/activation/swiglu_quant.py +++ b/vllm_ascend/ops/triton/activation/swiglu_quant.py @@ -1,7 +1,8 @@ import torch from vllm.triton_utils import tl, triton -from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num +from vllm_ascend.ops.triton.triton_utils import (extract_slice, + get_vectorcore_num) @triton.jit @@ -41,14 +42,14 @@ def _swiglu_quant_kernel( # swiglu x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS) cur_x = tl.load(x_ptr + x_offsets) - x1 = tl.extract_slice(cur_x, - offsets=(0, ), - sizes=(HALF_COLS, ), - strides=(1, )) - x2 = tl.extract_slice(cur_x, - offsets=(HALF_COLS, ), - sizes=(HALF_COLS, ), - strides=(1, )) + x1 = extract_slice(cur_x, + offsets=(0, ), + sizes=(HALF_COLS, ), + strides=(1, )) + x2 = extract_slice(cur_x, + offsets=(HALF_COLS, ), + sizes=(HALF_COLS, ), + strides=(1, )) out = x1 * tl.sigmoid(x1) * x2 # quant @@ -57,10 +58,10 @@ def _swiglu_quant_kernel( # store scale tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty)) for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE): - tmp_out = tl.extract_slice(out, - offsets=(col_blk_idx, ), - sizes=(COL_BLOCK_SIZE, ), - strides=(1, )) + tmp_out = extract_slice(out, + offsets=(col_blk_idx, ), + sizes=(COL_BLOCK_SIZE, ), + strides=(1, )) tmp_out = (tmp_out.to(tl.float32) / scale).to( x_ptr.dtype.element_ty) tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate") diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py index a80003207ca..f911d77c1d0 100644 --- a/vllm_ascend/ops/triton/fla/solve_tril.py +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -13,6 +13,8 @@ import torch from vllm.triton_utils import tl, triton +from vllm_ascend.ops.triton.triton_utils import extract_slice, insert_slice + from .utils import prepare_chunk_indices @@ -78,7 +80,7 @@ def solve_tril_16x16_kernel( # 4 Use mask to safely load data b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32) - b_A = tl.insert_slice( + b_A = insert_slice( ful=b_A, sub=b_A_subrec16[None, :, :], # (1, 16, 16) offsets=[blkid, 0, 0], @@ -97,9 +99,9 @@ def solve_tril_16x16_kernel( # for loop to update N_BLOCKS row vector for i in range(1, 16): - nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), - (1, 16 * N_BLOCKS), - (16 * N_BLOCKS, 1)) + nblks_vec16 = -extract_slice(local_ori_A, (i, 0), + (1, 16 * N_BLOCKS), + (16 * N_BLOCKS, 1)) b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16)) dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2)) @@ -107,11 +109,11 @@ def solve_tril_16x16_kernel( b_a = b_a + dot_product b_a_new_expanded = b_a[:, None, :] - b_A = tl.insert_slice(ful=b_A, - sub=b_a_new_expanded, - offsets=[0, i, 0], - sizes=[N_BLOCKS, 1, 16], - strides=[1, 1, 1]) + b_A = insert_slice(ful=b_A, + sub=b_a_new_expanded, + offsets=[0, i, 0], + sizes=[N_BLOCKS, 1, 16], + strides=[1, 1, 1]) on_diagonal = (rows == cols) b_A = tl.where(on_diagonal, b_A + 1.0, b_A) @@ -288,9 +290,9 @@ def merge_16x16_to_64x64_inverse_kernel( # build Ai_22_32 (32 * 32) Ai_22_32 = tl.zeros((32, 32), tl.float32) - Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1)) - Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1)) - Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1)) + Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1)) + Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1)) + Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1)) # load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32) offs_m = i_t * 64 + 32 + tl.arange(0, 32) @@ -302,9 +304,9 @@ def merge_16x16_to_64x64_inverse_kernel( # build Ai_11_32 (32 * 32) Ai_11_32 = tl.zeros((32, 32), tl.float32) - Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1)) - Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1)) - Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1)) + Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1)) + Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1)) + Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1)) Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee") diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 01ff79bf26b..6c77139fecd 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -21,7 +21,9 @@ import triton.language as tl # type: ignore from vllm.utils.torch_utils import direct_register_custom_op -from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num +from vllm_ascend.ops.triton.triton_utils import (extract_slice, + get_vectorcore_num, + insert_slice) @triton.jit @@ -83,13 +85,13 @@ def split_qkv_rmsnorm_rope_kernel( sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) - x1 = tl.extract_slice( + x1 = extract_slice( normalized_values, offsets=(0, 0), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - x2 = tl.extract_slice( + x2 = extract_slice( normalized_values, offsets=(0, HALF_HEAD_DIM), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), @@ -100,14 +102,14 @@ def split_qkv_rmsnorm_rope_kernel( roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - roped_q = tl.insert_slice( + roped_q = insert_slice( roped_q, roped_q1, offsets=(0, 0), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - roped_q = tl.insert_slice( + roped_q = insert_slice( roped_q, roped_q2, offsets=(0, HALF_HEAD_DIM), @@ -153,13 +155,13 @@ def split_qkv_rmsnorm_rope_kernel( sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) - x1 = tl.extract_slice( + x1 = extract_slice( normalized_values, offsets=(0, 0), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - x2 = tl.extract_slice( + x2 = extract_slice( normalized_values, offsets=(0, HALF_HEAD_DIM), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), @@ -170,14 +172,14 @@ def split_qkv_rmsnorm_rope_kernel( roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - roped_k = tl.insert_slice( + roped_k = insert_slice( roped_k, roped_k1, offsets=(0, 0), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - roped_k = tl.insert_slice( + roped_k = insert_slice( roped_k, roped_k2, offsets=(0, HALF_HEAD_DIM), diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py index 142815572ea..86b73467659 100644 --- a/vllm_ascend/ops/triton/reject_sample.py +++ b/vllm_ascend/ops/triton/reject_sample.py @@ -17,7 +17,7 @@ from vllm.triton_utils import tl, triton -from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num +from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num def cal_grid_and_block_size(batch_size: int): @@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton( tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask) for pos in tl.range(0, BLOCK_SIZE): - draft_token_id1 = tl.get_element(draft_token_id, (pos, )) - target_argmax1 = tl.get_element(target_argmax_id, (pos, )) + draft_token_id1 = get_element(draft_token_id, (pos, )) + target_argmax1 = get_element(target_argmax_id, (pos, )) position = block_idx * BLOCK_SIZE + pos if draft_token_id1 == target_argmax1: bonus_renew_1( @@ -113,10 +113,10 @@ def rejection_greedy_sample_triton( num_draft_tokens = end_idx - start_idx for pos in tl.range(0, BLOCK_SIZE): - num_tokens1 = tl.get_element(num_draft_tokens, (pos, )) + num_tokens1 = get_element(num_draft_tokens, (pos, )) rejected = False - start_idx1 = tl.get_element(start_idx, (pos, )) - is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, )) + start_idx1 = get_element(start_idx, (pos, )) + is_greedy_mask1 = get_element(is_greedy_mask, (pos, )) position = block_idx * BLOCK_SIZE + pos for i in range(num_tokens1): if not rejected: @@ -167,12 +167,12 @@ def rejection_random_sample_kernel( end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask) n_num_draft_tokens = end_idxs - start_idxs for req_i in range(BLOCK_SIZE): - not_greedy = tl.get_element(not_greedy_mask, (req_i, )) + not_greedy = get_element(not_greedy_mask, (req_i, )) if not_greedy: rejected = False - start_idx = tl.get_element(start_idxs, (req_i, )) + start_idx = get_element(start_idxs, (req_i, )) req_idx = block_idx * BLOCK_SIZE + req_i - num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, )) + num_draft_tokens = get_element(n_num_draft_tokens, (req_i, )) for pos in range(num_draft_tokens): if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + @@ -234,9 +234,9 @@ def expand_kernel( src_val = tl.where(src_val == replace_from, replace_to, src_val) for i in tl.range(0, BLOCK_SIZE): - num_tokens1 = tl.get_element(num_tokens, (i, )) - start_idx1 = tl.get_element(start_idx, (i, )) - src_val1 = tl.get_element(src_val, (i, )) + num_tokens1 = get_element(num_tokens, (i, )) + start_idx1 = get_element(start_idx, (i, )) + src_val1 = get_element(src_val, (i, )) offset1 = tl.arange(0, MAX_NUM_TOKENS) tl.store(output_ptr + start_idx1 + offset1, src_val1, @@ -292,7 +292,7 @@ def sample_recovered_tokens_kernel( other=float("-inf")) new_p = prob / q recovered_id = tl.argmax(new_p, axis=-1) - max_p = tl.get_element(new_p, (recovered_id, )) + max_p = get_element(new_p, (recovered_id, )) if max_p > global_max_p: global_max_p = max_p global_recovered_id = vocab_start + recovered_id @@ -318,7 +318,7 @@ def sample_recovered_tokens_kernel( other=float("-inf")) new_p = prob / q recovered_id = tl.argmax(new_p, axis=-1) - max_p = tl.get_element(new_p, (recovered_id, )) + max_p = get_element(new_p, (recovered_id, )) if max_p > global_max_p: global_max_p = max_p global_recovered_id = vocab_start + recovered_id @@ -407,16 +407,16 @@ def rejection_random_sample_block_verify_kernel( end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask) n_num_draft_tokens = end_idxs - start_idxs for req_i in range(BLOCK_SIZE): - not_greedy = tl.get_element(not_greedy_mask, (req_i, )) + not_greedy = get_element(not_greedy_mask, (req_i, )) if not_greedy: rejected = False pi = 1.0 uniform_prob = 1.0 last_accepted_token_pos = -1 - start_idx = tl.get_element(start_idxs, (req_i, )) + start_idx = get_element(start_idxs, (req_i, )) req_idx = block_idx * BLOCK_SIZE + req_i - num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, )) + num_draft_tokens = get_element(n_num_draft_tokens, (req_i, )) for pos in range(num_draft_tokens): draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) diff --git a/vllm_ascend/ops/triton/triton_utils.py b/vllm_ascend/ops/triton/triton_utils.py index 6b0ac9645d2..ea8aab64077 100644 --- a/vllm_ascend/ops/triton/triton_utils.py +++ b/vllm_ascend/ops/triton/triton_utils.py @@ -1,10 +1,48 @@ from typing import Any, Dict import torch -from vllm.triton_utils import HAS_TRITON, triton +from vllm.triton_utils import HAS_TRITON, tl, triton _NUM_AICORE = -1 _NUM_VECTORCORE = -1 +_extension_module = None + +if HAS_TRITON: + try: + import triton.language.extra.cann.extension as _extension_module # type: ignore + except ImportError: + _extension_module = None + + +def _resolve_triton_ascend_op(op_name: str): + if not HAS_TRITON: + raise RuntimeError( + f"Triton op '{op_name}' cannot be resolved because HAS_TRITON is False" + ) + + if _extension_module is not None: + extension_op = getattr(_extension_module, op_name, None) + if extension_op is not None: + return extension_op + + tl_op = getattr(tl, op_name, None) + if tl_op is not None: + return tl_op + + raise RuntimeError( + f"Failed to resolve Triton op '{op_name}': " + "neither triton.language.extra.cann.extension nor triton.language provides it." + ) + + +if HAS_TRITON: + insert_slice = _resolve_triton_ascend_op("insert_slice") + extract_slice = _resolve_triton_ascend_op("extract_slice") + get_element = _resolve_triton_ascend_op("get_element") +else: + insert_slice = None + extract_slice = None + get_element = None def init_device_properties_triton():