-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[Bugfix] fix the wrong use of extract_slice and insert_slice #6956
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,15 @@ | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from .utils import prepare_chunk_indices | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| import triton.language.extra.cann.extension as extension # type: ignore | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| insert_slice = extension.insert_slice | ||||||||||||||||||||||||||||||||||||
| extract_slice = extension.extract_slice | ||||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you could move the same logic to triton_utils.py |
||||||||||||||||||||||||||||||||||||
| insert_slice = tl.insert_slice | ||||||||||||||||||||||||||||||||||||
| extract_slice = tl.extract_slice | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
Comment on lines
+17
to
25
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the issue in It's recommended to add a check for the existence of these functions in
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) | ||||||||||||||||||||||||||||||||||||
| @triton.jit(do_not_specialize=["T"]) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -80,7 +89,7 @@ def solve_tril_16x16_kernel( | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # 4 Use mask to safely load data | ||||||||||||||||||||||||||||||||||||
| b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32) | ||||||||||||||||||||||||||||||||||||
| b_A = tl.insert_slice( | ||||||||||||||||||||||||||||||||||||
| b_A = insert_slice( | ||||||||||||||||||||||||||||||||||||
| ful=b_A, | ||||||||||||||||||||||||||||||||||||
| sub=b_A_subrec16[None, :, :], # (1, 16, 16) | ||||||||||||||||||||||||||||||||||||
| offsets=[blkid, 0, 0], | ||||||||||||||||||||||||||||||||||||
|
|
@@ -100,15 +109,15 @@ def solve_tril_16x16_kernel( | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # for loop to update N_BLOCKS row vector | ||||||||||||||||||||||||||||||||||||
| for i in range(1, 16): | ||||||||||||||||||||||||||||||||||||
| nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1)) | ||||||||||||||||||||||||||||||||||||
| nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1)) | ||||||||||||||||||||||||||||||||||||
| b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16)) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2)) | ||||||||||||||||||||||||||||||||||||
| dot_product = tl.sum(dot_tmp, 0) | ||||||||||||||||||||||||||||||||||||
| b_a = b_a + dot_product | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| b_a_new_expanded = b_a[:, None, :] | ||||||||||||||||||||||||||||||||||||
| b_A = tl.insert_slice( | ||||||||||||||||||||||||||||||||||||
| b_A = insert_slice( | ||||||||||||||||||||||||||||||||||||
| ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -276,9 +285,9 @@ def merge_16x16_to_64x64_inverse_kernel( | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # build Ai_22_32 (32 * 32) | ||||||||||||||||||||||||||||||||||||
| Ai_22_32 = tl.zeros((32, 32), tl.float32) | ||||||||||||||||||||||||||||||||||||
| Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32) | ||||||||||||||||||||||||||||||||||||
| offs_m = i_t * 64 + 32 + tl.arange(0, 32) | ||||||||||||||||||||||||||||||||||||
|
|
@@ -290,9 +299,9 @@ def merge_16x16_to_64x64_inverse_kernel( | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # build Ai_11_32 (32 * 32) | ||||||||||||||||||||||||||||||||||||
| Ai_11_32 = tl.zeros((32, 32), tl.float32) | ||||||||||||||||||||||||||||||||||||
| Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
| Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1)) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,15 @@ | |
|
|
||
| from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num | ||
|
|
||
| try: | ||
| import triton.language.extra.cann.extension as extension # type: ignore | ||
|
|
||
| insert_slice = extension.insert_slice | ||
| extract_slice = extension.extract_slice | ||
| except ImportError: | ||
| insert_slice = tl.insert_slice | ||
| extract_slice = tl.extract_slice | ||
|
|
||
|
Comment on lines
+25
to
33
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The conditional import for Please implement a more robust fallback mechanism to ensure these functions are available or to provide a clear error if they are not. try:
import triton.language.extra.cann.extension as extension
insert_slice = extension.insert_slice
extract_slice = extension.extract_slice
except ImportError:
if hasattr(tl, 'insert_slice') and hasattr(tl, 'extract_slice'):
insert_slice = tl.insert_slice
extract_slice = tl.extract_slice
else:
raise RuntimeError("Neither triton.language.extra.cann.extension nor tl provides insert_slice/extract_slice.") |
||
|
|
||
| @triton.jit | ||
| def split_qkv_rmsnorm_rope_kernel( | ||
|
|
@@ -79,13 +88,13 @@ def split_qkv_rmsnorm_rope_kernel( | |
| sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) | ||
| cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) | ||
| sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) | ||
| x1 = tl.extract_slice( | ||
| x1 = extract_slice( | ||
| normalized_values, | ||
| offsets=(0, 0), | ||
| sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), | ||
| strides=(1, 1), | ||
| ) | ||
| x2 = tl.extract_slice( | ||
| x2 = extract_slice( | ||
| normalized_values, | ||
| offsets=(0, HALF_HEAD_DIM), | ||
| sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), | ||
|
|
@@ -95,14 +104,14 @@ def split_qkv_rmsnorm_rope_kernel( | |
| roped_q2 = x2 * cos + x1 * sin | ||
|
|
||
| roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) | ||
| roped_q = tl.insert_slice( | ||
| roped_q = insert_slice( | ||
| roped_q, | ||
| roped_q1, | ||
| offsets=(0, 0), | ||
| sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), | ||
| strides=(1, 1), | ||
| ) | ||
| roped_q = tl.insert_slice( | ||
| roped_q = insert_slice( | ||
| roped_q, | ||
| roped_q2, | ||
| offsets=(0, HALF_HEAD_DIM), | ||
|
|
@@ -145,13 +154,13 @@ def split_qkv_rmsnorm_rope_kernel( | |
| sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM) | ||
| cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM) | ||
| sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM) | ||
| x1 = tl.extract_slice( | ||
| x1 = extract_slice( | ||
| normalized_values, | ||
| offsets=(0, 0), | ||
| sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), | ||
| strides=(1, 1), | ||
| ) | ||
| x2 = tl.extract_slice( | ||
| x2 = extract_slice( | ||
| normalized_values, | ||
| offsets=(0, HALF_HEAD_DIM), | ||
| sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), | ||
|
|
@@ -161,14 +170,14 @@ def split_qkv_rmsnorm_rope_kernel( | |
| roped_k2 = x2 * cos + x1 * sin | ||
|
|
||
| roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) | ||
| roped_k = tl.insert_slice( | ||
| roped_k = insert_slice( | ||
| roped_k, | ||
| roped_k1, | ||
| offsets=(0, 0), | ||
| sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), | ||
| strides=(1, 1), | ||
| ) | ||
| roped_k = tl.insert_slice( | ||
| roped_k = insert_slice( | ||
| roped_k, | ||
| roped_k2, | ||
| offsets=(0, HALF_HEAD_DIM), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current
try-except ImportErrorblock assumes thattl.extract_slicewill always be available iftriton.language.extra.cann.extensioncannot be imported. However, the PR description indicates that these functions might not be in the standardtriton.languagelibrary for the TritonAscend main branch. Iftl.extract_sliceis indeed missing, this fallback will result in anAttributeErrorat runtime.To make this more robust, consider explicitly checking for the existence of
tl.extract_sliceor handlingAttributeErrorin the fallback, or raising a more specific error if neither source provides the function.