Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions vllm_ascend/ops/triton/activation/swiglu_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num

try:
import triton.language.extra.cann.extension as extension # type: ignore

extract_slice = extension.extract_slice
except ImportError:
extract_slice = tl.extract_slice
Comment on lines +6 to +11
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current try/except ImportError block assumes that tl.extract_slice is always available whenever triton.language.extra.cann.extension cannot be imported. However, the PR description indicates that these functions may be absent from the standard triton.language module on the TritonAscend main branch. If tl.extract_slice is indeed missing, the fallback assignment raises an AttributeError as soon as this module is imported, because it executes at module level rather than inside the kernel.

To make this more robust, explicitly check that tl.extract_slice exists (or catch AttributeError) in the fallback branch, and raise a specific, descriptive error if neither source provides the function.

try:
    import triton.language.extra.cann.extension as extension
    extract_slice = extension.extract_slice
except ImportError:
    if hasattr(tl, 'extract_slice'):
        extract_slice = tl.extract_slice
    else:
        raise RuntimeError("Neither triton.language.extra.cann.extension.extract_slice nor tl.extract_slice is available.")



@triton.jit
def _swiglu_quant_kernel(
Expand Down Expand Up @@ -40,8 +47,8 @@ def _swiglu_quant_kernel(
# swiglu
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
cur_x = tl.load(x_ptr + x_offsets)
x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
x1 = extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
x2 = extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
out = x1 * tl.sigmoid(x1) * x2

# quant
Expand All @@ -50,7 +57,7 @@ def _swiglu_quant_kernel(
# store scale
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
tmp_out = extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")

Expand Down
27 changes: 18 additions & 9 deletions vllm_ascend/ops/triton/fla/solve_tril.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@

from .utils import prepare_chunk_indices

try:
import triton.language.extra.cann.extension as extension # type: ignore

insert_slice = extension.insert_slice
extract_slice = extension.extract_slice
except ImportError:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

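You could move this shared import-fallback logic into triton_utils.py so that each kernel module imports insert_slice/extract_slice from one place instead of repeating the try/except block.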

insert_slice = tl.insert_slice
extract_slice = tl.extract_slice

Comment on lines +17 to 25
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the issue in swiglu_quant.py, the fallback in the except ImportError block assumes tl.insert_slice and tl.extract_slice are always present. If these functions are not part of the standard triton.language module in some environments, this could lead to an AttributeError.

It's recommended to add a check for the existence of these functions in tl or to catch AttributeError to provide a more graceful failure or a clearer error message.

Suggested change
- try:
-     import triton.language.extra.cann.extension as extension
-     insert_slice = extension.insert_slice
-     extract_slice = extension.extract_slice
- except ImportError:
-     insert_slice = tl.insert_slice
-     extract_slice = tl.extract_slice
+ try:
+     import triton.language.extra.cann.extension as extension
+     insert_slice = extension.insert_slice
+     extract_slice = extension.extract_slice
+ except ImportError:
+     if hasattr(tl, 'insert_slice') and hasattr(tl, 'extract_slice'):
+         insert_slice = tl.insert_slice
+         extract_slice = tl.extract_slice
+     else:
+         raise RuntimeError("Neither triton.language.extra.cann.extension nor tl provides insert_slice/extract_slice.")


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
Expand Down Expand Up @@ -80,7 +89,7 @@ def solve_tril_16x16_kernel(

# 4 Use mask to safely load data
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
b_A = tl.insert_slice(
b_A = insert_slice(
ful=b_A,
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
offsets=[blkid, 0, 0],
Expand All @@ -100,15 +109,15 @@ def solve_tril_16x16_kernel(

# for loop to update N_BLOCKS row vector
for i in range(1, 16):
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))

dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
dot_product = tl.sum(dot_tmp, 0)
b_a = b_a + dot_product

b_a_new_expanded = b_a[:, None, :]
b_A = tl.insert_slice(
b_A = insert_slice(
ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
)

Expand Down Expand Up @@ -276,9 +285,9 @@ def merge_16x16_to_64x64_inverse_kernel(

# build Ai_22_32 (32 * 32)
Ai_22_32 = tl.zeros((32, 32), tl.float32)
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))

# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
Expand All @@ -290,9 +299,9 @@ def merge_16x16_to_64x64_inverse_kernel(

# build Ai_11_32 (32 * 32)
Ai_11_32 = tl.zeros((32, 32), tl.float32)
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))

Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")

Expand Down
25 changes: 17 additions & 8 deletions vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num

try:
import triton.language.extra.cann.extension as extension # type: ignore

insert_slice = extension.insert_slice
extract_slice = extension.extract_slice
except ImportError:
insert_slice = tl.insert_slice
extract_slice = tl.extract_slice

Comment on lines +25 to 33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The conditional import for insert_slice and extract_slice has the same potential issue as identified in the other files. If triton.language.extra.cann.extension fails to import, the fallback to tl.insert_slice and tl.extract_slice might fail with an AttributeError if these functions are not defined in the standard triton.language module for the current environment.

Please implement a more robust fallback mechanism to ensure these functions are available or to provide a clear error if they are not.

try:
    import triton.language.extra.cann.extension as extension
    insert_slice = extension.insert_slice
    extract_slice = extension.extract_slice
except ImportError:
    if hasattr(tl, 'insert_slice') and hasattr(tl, 'extract_slice'):
        insert_slice = tl.insert_slice
        extract_slice = tl.extract_slice
    else:
        raise RuntimeError("Neither triton.language.extra.cann.extension nor tl provides insert_slice/extract_slice.")


@triton.jit
def split_qkv_rmsnorm_rope_kernel(
Expand Down Expand Up @@ -79,13 +88,13 @@ def split_qkv_rmsnorm_rope_kernel(
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
x1 = tl.extract_slice(
x1 = extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
x2 = extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
Expand All @@ -95,14 +104,14 @@ def split_qkv_rmsnorm_rope_kernel(
roped_q2 = x2 * cos + x1 * sin

roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
roped_q = tl.insert_slice(
roped_q = insert_slice(
roped_q,
roped_q1,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_q = tl.insert_slice(
roped_q = insert_slice(
roped_q,
roped_q2,
offsets=(0, HALF_HEAD_DIM),
Expand Down Expand Up @@ -145,13 +154,13 @@ def split_qkv_rmsnorm_rope_kernel(
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
x1 = tl.extract_slice(
x1 = extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
x2 = extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
Expand All @@ -161,14 +170,14 @@ def split_qkv_rmsnorm_rope_kernel(
roped_k2 = x2 * cos + x1 * sin

roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
roped_k = tl.insert_slice(
roped_k = insert_slice(
roped_k,
roped_k1,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_k = tl.insert_slice(
roped_k = insert_slice(
roped_k,
roped_k2,
offsets=(0, HALF_HEAD_DIM),
Expand Down