From b531798b82784217a2e338dc134d9f13967c0977 Mon Sep 17 00:00:00 2001 From: wangx700 Date: Tue, 3 Mar 2026 16:32:40 +0800 Subject: [PATCH 1/3] fix the wrong use of extract_slice and insert_slice. In the TritonAscend main branch, extract_slice and insert_slice are located in the triton.language.extra.cann.extension library, rather than in the triton.language library Signed-off-by: wangx700 --- .../ops/triton/activation/swiglu_quant.py | 11 +++++--- vllm_ascend/ops/triton/fla/solve_tril.py | 25 ++++++++++++------- .../linearnorm/split_qkv_rmsnorm_rope.py | 23 +++++++++++------ 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/vllm_ascend/ops/triton/activation/swiglu_quant.py b/vllm_ascend/ops/triton/activation/swiglu_quant.py index b0ef78130f6..2d12c95220b 100644 --- a/vllm_ascend/ops/triton/activation/swiglu_quant.py +++ b/vllm_ascend/ops/triton/activation/swiglu_quant.py @@ -3,6 +3,11 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num +try: + import triton.language.extra.cann.extension as extension + extract_slice = extension.extract_slice +except ImportError: + extract_slice = tl.extract_slice @triton.jit def _swiglu_quant_kernel( @@ -40,8 +45,8 @@ def _swiglu_quant_kernel( # swiglu x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS) cur_x = tl.load(x_ptr + x_offsets) - x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,)) - x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,)) + x1 = extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,)) + x2 = extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,)) out = x1 * tl.sigmoid(x1) * x2 # quant @@ -50,7 +55,7 @@ def _swiglu_quant_kernel( # store scale tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty)) for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE): - tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,)) + tmp_out = extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,)) tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty) tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate") diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py index 62a943fbc08..42efa14de8f 100644 --- a/vllm_ascend/ops/triton/fla/solve_tril.py +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -14,6 +14,13 @@ from .utils import prepare_chunk_indices +try: + import triton.language.extra.cann.extension as extension + insert_slice = extension.insert_slice + extract_slice = extension.extract_slice +except ImportError: + insert_slice = tl.insert_slice + extract_slice = tl.extract_slice @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.jit(do_not_specialize=["T"]) @@ -80,7 +87,7 @@ def solve_tril_16x16_kernel( # 4 Use mask to safely load data b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32) - b_A = tl.insert_slice( + b_A = insert_slice( ful=b_A, sub=b_A_subrec16[None, :, :], # (1, 16, 16) offsets=[blkid, 0, 0], @@ -100,7 +107,7 @@ def solve_tril_16x16_kernel( # for loop to update N_BLOCKS row vector for i in range(1, 16): - nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1)) + nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1)) b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16)) dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2)) @@ -108,7 +115,7 @@ def solve_tril_16x16_kernel( b_a = b_a + dot_product b_a_new_expanded = b_a[:, None, :] - b_A = tl.insert_slice( + b_A = insert_slice( ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1] ) @@ -276,9 +283,9 @@ def merge_16x16_to_64x64_inverse_kernel( # build Ai_22_32 (32 * 32) Ai_22_32 = tl.zeros((32, 32), tl.float32) - Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1)) - Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1)) - Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1)) + Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1)) + Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1)) + Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1)) # load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32) offs_m = i_t * 64 + 32 + tl.arange(0, 32) @@ -290,9 +297,9 @@ def merge_16x16_to_64x64_inverse_kernel( # build Ai_11_32 (32 * 32) Ai_11_32 = tl.zeros((32, 32), tl.float32) - Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1)) - Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1)) - Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1)) + Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1)) + Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1)) + Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1)) Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee") diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 18135aa7a19..59baa1a0a7b 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -22,6 +22,13 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num +try: + import triton.language.extra.cann.extension as extension + insert_slice = extension.insert_slice + extract_slice = extension.extract_slice +except ImportError: + insert_slice = tl.insert_slice + extract_slice = tl.extract_slice @triton.jit def split_qkv_rmsnorm_rope_kernel( @@ -79,13 +86,13 @@ def split_qkv_rmsnorm_rope_kernel( sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) - x1 = tl.extract_slice( + x1 = extract_slice( normalized_values, offsets=(0, 0), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - x2 = tl.extract_slice( + x2 = extract_slice( normalized_values, offsets=(0, HALF_HEAD_DIM), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), @@ -95,14 +102,14 @@ def split_qkv_rmsnorm_rope_kernel( roped_q2 = x2 * cos + x1 * sin roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - roped_q = tl.insert_slice( + roped_q = insert_slice( roped_q, roped_q1, offsets=(0, 0), sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - roped_q = tl.insert_slice( + roped_q = insert_slice( roped_q, roped_q2, offsets=(0, HALF_HEAD_DIM), @@ -145,13 +152,13 @@ def split_qkv_rmsnorm_rope_kernel( sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) - x1 = tl.extract_slice( + x1 = extract_slice( normalized_values, offsets=(0, 0), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - x2 = tl.extract_slice( + x2 = extract_slice( normalized_values, offsets=(0, HALF_HEAD_DIM), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), @@ -161,14 +168,14 @@ def split_qkv_rmsnorm_rope_kernel( roped_k2 = x2 * cos + x1 * sin roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) - roped_k = tl.insert_slice( + roped_k = insert_slice( roped_k, roped_k1, offsets=(0, 0), sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - roped_k = tl.insert_slice( + roped_k = insert_slice( roped_k, roped_k2, offsets=(0, HALF_HEAD_DIM), From 560ec28bc97ecacf552a47465dfeab3dc0069277 Mon Sep 17 00:00:00 2001 From: wangx700 Date: Tue, 3 Mar 2026 17:28:39 +0800 Subject: [PATCH 2/3] fix ruff format. Signed-off-by: wangx700 --- vllm_ascend/ops/triton/activation/swiglu_quant.py | 2 ++ vllm_ascend/ops/triton/fla/solve_tril.py | 2 ++ vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/vllm_ascend/ops/triton/activation/swiglu_quant.py b/vllm_ascend/ops/triton/activation/swiglu_quant.py index 2d12c95220b..6f462408baa 100644 --- a/vllm_ascend/ops/triton/activation/swiglu_quant.py +++ b/vllm_ascend/ops/triton/activation/swiglu_quant.py @@ -5,10 +5,12 @@ try: import triton.language.extra.cann.extension as extension + extract_slice = extension.extract_slice except ImportError: extract_slice = tl.extract_slice + @triton.jit def _swiglu_quant_kernel( x_ptr, diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py index 42efa14de8f..e7cacd1bd76 100644 --- a/vllm_ascend/ops/triton/fla/solve_tril.py +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -16,12 +16,14 @@ try: import triton.language.extra.cann.extension as extension + insert_slice = extension.insert_slice extract_slice = extension.extract_slice except ImportError: insert_slice = tl.insert_slice extract_slice = tl.extract_slice + @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.jit(do_not_specialize=["T"]) def solve_tril_16x16_kernel( diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 59baa1a0a7b..0a41ca76879 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -24,12 +24,14 @@ try: import triton.language.extra.cann.extension as extension + insert_slice = extension.insert_slice extract_slice = extension.extract_slice except ImportError: insert_slice = tl.insert_slice extract_slice = tl.extract_slice + @triton.jit def split_qkv_rmsnorm_rope_kernel( input_ptr, From 0feb4ea2231969a23d05d87f860c12578494c709 Mon Sep 17 00:00:00 2001 From: wangx700 Date: Tue, 3 Mar 2026 17:40:05 +0800 Subject: [PATCH 3/3] fix mypy test with "# type: ignore" Signed-off-by: wangx700 --- vllm_ascend/ops/triton/activation/swiglu_quant.py | 2 +- vllm_ascend/ops/triton/fla/solve_tril.py | 2 +- vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/ops/triton/activation/swiglu_quant.py b/vllm_ascend/ops/triton/activation/swiglu_quant.py index 6f462408baa..e651ff67170 100644 --- a/vllm_ascend/ops/triton/activation/swiglu_quant.py +++ b/vllm_ascend/ops/triton/activation/swiglu_quant.py @@ -4,7 +4,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num try: - import triton.language.extra.cann.extension as extension + import triton.language.extra.cann.extension as extension # type: ignore extract_slice = extension.extract_slice except ImportError: diff --git a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py index e7cacd1bd76..55e3847328b 100644 --- a/vllm_ascend/ops/triton/fla/solve_tril.py +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -15,7 +15,7 @@ from .utils import prepare_chunk_indices try: - import triton.language.extra.cann.extension as extension + import triton.language.extra.cann.extension as extension # type: ignore insert_slice = extension.insert_slice extract_slice = extension.extract_slice diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 0a41ca76879..f8964744dd7 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -23,7 +23,7 @@ from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num try: - import triton.language.extra.cann.extension as extension + import triton.language.extra.cann.extension as extension # type: ignore insert_slice = extension.insert_slice extract_slice = extension.extract_slice