Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions vllm_ascend/ops/triton/activation/swiglu_quant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
from vllm.triton_utils import tl, triton

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import (extract_slice,
get_vectorcore_num)


@triton.jit
Expand Down Expand Up @@ -41,14 +42,14 @@ def _swiglu_quant_kernel(
# swiglu
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
cur_x = tl.load(x_ptr + x_offsets)
x1 = tl.extract_slice(cur_x,
offsets=(0, ),
sizes=(HALF_COLS, ),
strides=(1, ))
x2 = tl.extract_slice(cur_x,
offsets=(HALF_COLS, ),
sizes=(HALF_COLS, ),
strides=(1, ))
x1 = extract_slice(cur_x,
offsets=(0, ),
sizes=(HALF_COLS, ),
strides=(1, ))
x2 = extract_slice(cur_x,
offsets=(HALF_COLS, ),
sizes=(HALF_COLS, ),
strides=(1, ))
out = x1 * tl.sigmoid(x1) * x2

# quant
Expand All @@ -57,10 +58,10 @@ def _swiglu_quant_kernel(
# store scale
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
tmp_out = tl.extract_slice(out,
offsets=(col_blk_idx, ),
sizes=(COL_BLOCK_SIZE, ),
strides=(1, ))
tmp_out = extract_slice(out,
offsets=(col_blk_idx, ),
sizes=(COL_BLOCK_SIZE, ),
strides=(1, ))
tmp_out = (tmp_out.to(tl.float32) / scale).to(
x_ptr.dtype.element_ty)
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")
Expand Down
32 changes: 17 additions & 15 deletions vllm_ascend/ops/triton/fla/solve_tril.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import torch
from vllm.triton_utils import tl, triton

from vllm_ascend.ops.triton.triton_utils import extract_slice, insert_slice

from .utils import prepare_chunk_indices


Expand Down Expand Up @@ -78,7 +80,7 @@ def solve_tril_16x16_kernel(
# 4 Use mask to safely load data
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask,
other=0.0).to(tl.float32)
b_A = tl.insert_slice(
b_A = insert_slice(
ful=b_A,
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
offsets=[blkid, 0, 0],
Expand All @@ -97,21 +99,21 @@ def solve_tril_16x16_kernel(

# for loop to update N_BLOCKS row vector
for i in range(1, 16):
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0),
(1, 16 * N_BLOCKS),
(16 * N_BLOCKS, 1))
nblks_vec16 = -extract_slice(local_ori_A, (i, 0),
(1, 16 * N_BLOCKS),
(16 * N_BLOCKS, 1))
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))

dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
dot_product = tl.sum(dot_tmp, 0)
b_a = b_a + dot_product

b_a_new_expanded = b_a[:, None, :]
b_A = tl.insert_slice(ful=b_A,
sub=b_a_new_expanded,
offsets=[0, i, 0],
sizes=[N_BLOCKS, 1, 16],
strides=[1, 1, 1])
b_A = insert_slice(ful=b_A,
sub=b_a_new_expanded,
offsets=[0, i, 0],
sizes=[N_BLOCKS, 1, 16],
strides=[1, 1, 1])

on_diagonal = (rows == cols)
b_A = tl.where(on_diagonal, b_A + 1.0, b_A)
Expand Down Expand Up @@ -288,9 +290,9 @@ def merge_16x16_to_64x64_inverse_kernel(

# build Ai_22_32 (32 * 32)
Ai_22_32 = tl.zeros((32, 32), tl.float32)
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))

# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
Expand All @@ -302,9 +304,9 @@ def merge_16x16_to_64x64_inverse_kernel(

# build Ai_11_32 (32 * 32)
Ai_11_32 = tl.zeros((32, 32), tl.float32)
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))

Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")

Expand Down
20 changes: 11 additions & 9 deletions vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import triton.language as tl # type: ignore
from vllm.utils.torch_utils import direct_register_custom_op

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import (extract_slice,
get_vectorcore_num,
insert_slice)


@triton.jit
Expand Down Expand Up @@ -83,13 +85,13 @@ def split_qkv_rmsnorm_rope_kernel(
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
x1 = tl.extract_slice(
x1 = extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
x2 = extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
Expand All @@ -100,14 +102,14 @@ def split_qkv_rmsnorm_rope_kernel(

roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
dtype=tl.bfloat16)
roped_q = tl.insert_slice(
roped_q = insert_slice(
roped_q,
roped_q1,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_q = tl.insert_slice(
roped_q = insert_slice(
roped_q,
roped_q2,
offsets=(0, HALF_HEAD_DIM),
Expand Down Expand Up @@ -153,13 +155,13 @@ def split_qkv_rmsnorm_rope_kernel(
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
x1 = tl.extract_slice(
x1 = extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
x2 = extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
Expand All @@ -170,14 +172,14 @@ def split_qkv_rmsnorm_rope_kernel(

roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM),
dtype=tl.bfloat16)
roped_k = tl.insert_slice(
roped_k = insert_slice(
roped_k,
roped_k1,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_k = tl.insert_slice(
roped_k = insert_slice(
roped_k,
roped_k2,
offsets=(0, HALF_HEAD_DIM),
Expand Down
34 changes: 17 additions & 17 deletions vllm_ascend/ops/triton/reject_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from vllm.triton_utils import tl, triton

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num


def cal_grid_and_block_size(batch_size: int):
Expand Down Expand Up @@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton(
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)

for pos in tl.range(0, BLOCK_SIZE):
draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
draft_token_id1 = get_element(draft_token_id, (pos, ))
target_argmax1 = get_element(target_argmax_id, (pos, ))
position = block_idx * BLOCK_SIZE + pos
if draft_token_id1 == target_argmax1:
bonus_renew_1(
Expand Down Expand Up @@ -113,10 +113,10 @@ def rejection_greedy_sample_triton(
num_draft_tokens = end_idx - start_idx

for pos in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
num_tokens1 = get_element(num_draft_tokens, (pos, ))
rejected = False
start_idx1 = tl.get_element(start_idx, (pos, ))
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, ))
start_idx1 = get_element(start_idx, (pos, ))
is_greedy_mask1 = get_element(is_greedy_mask, (pos, ))
position = block_idx * BLOCK_SIZE + pos
for i in range(num_tokens1):
if not rejected:
Expand Down Expand Up @@ -167,12 +167,12 @@ def rejection_random_sample_kernel(
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
not_greedy = get_element(not_greedy_mask, (req_i, ))
if not_greedy:
rejected = False
start_idx = tl.get_element(start_idxs, (req_i, ))
start_idx = get_element(start_idxs, (req_i, ))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
num_draft_tokens = get_element(n_num_draft_tokens, (req_i, ))
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx +
Expand Down Expand Up @@ -234,9 +234,9 @@ def expand_kernel(
src_val = tl.where(src_val == replace_from, replace_to, src_val)

for i in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_tokens, (i, ))
start_idx1 = tl.get_element(start_idx, (i, ))
src_val1 = tl.get_element(src_val, (i, ))
num_tokens1 = get_element(num_tokens, (i, ))
start_idx1 = get_element(start_idx, (i, ))
src_val1 = get_element(src_val, (i, ))
offset1 = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx1 + offset1,
src_val1,
Expand Down Expand Up @@ -292,7 +292,7 @@ def sample_recovered_tokens_kernel(
other=float("-inf"))
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
max_p = get_element(new_p, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
Expand All @@ -318,7 +318,7 @@ def sample_recovered_tokens_kernel(
other=float("-inf"))
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
max_p = get_element(new_p, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
Expand Down Expand Up @@ -407,16 +407,16 @@ def rejection_random_sample_block_verify_kernel(
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
not_greedy = get_element(not_greedy_mask, (req_i, ))
if not_greedy:

rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1
start_idx = tl.get_element(start_idxs, (req_i, ))
start_idx = get_element(start_idxs, (req_i, ))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
num_draft_tokens = get_element(n_num_draft_tokens, (req_i, ))

for pos in range(num_draft_tokens):
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
Expand Down
40 changes: 39 additions & 1 deletion vllm_ascend/ops/triton/triton_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,48 @@
from typing import Any, Dict

import torch
from vllm.triton_utils import HAS_TRITON, triton
from vllm.triton_utils import HAS_TRITON, tl, triton

# Module-level state. Both counters start at the sentinel value -1; the code
# that assigns their real values is not visible in this chunk.
# NOTE(review): presumably a cache of device core counts -- confirm against
# init_device_properties_triton / get_vectorcore_num.
_NUM_AICORE = -1
_NUM_VECTORCORE = -1
# Handle to the optional CANN extension module for triton.language.
# Remains None unless the guarded import below succeeds.
_extension_module = None

if HAS_TRITON:
    try:
        # Optional import: only attempted when Triton itself is available.
        import triton.language.extra.cann.extension as _extension_module  # type: ignore
    except ImportError:
        # Extension not importable; _resolve_triton_ascend_op then falls back
        # to looking the op up on triton.language (tl).
        _extension_module = None


def _resolve_triton_ascend_op(op_name: str):
    """Look up a Triton op by name, preferring the CANN extension module.

    The op is searched for first on ``_extension_module`` (when that module
    was imported successfully) and then on ``tl``. The first attribute found
    whose value is not ``None`` is returned.

    Args:
        op_name: Attribute name of the op, e.g. ``"insert_slice"``.

    Returns:
        The attribute object found on the first provider that exposes it.

    Raises:
        RuntimeError: If ``HAS_TRITON`` is false, or if neither provider
            exposes ``op_name``.
    """
    if not HAS_TRITON:
        raise RuntimeError(
            f"Triton op '{op_name}' cannot be resolved because HAS_TRITON is False"
        )

    # Search order matters: the extension module wins over triton.language
    # whenever both define the same op name.
    providers = [tl] if _extension_module is None else [_extension_module, tl]
    for provider in providers:
        candidate = getattr(provider, op_name, None)
        if candidate is not None:
            return candidate

    raise RuntimeError(
        f"Failed to resolve Triton op '{op_name}': "
        "neither triton.language.extra.cann.extension nor triton.language provides it."
    )


# Bind the slice/element ops once at import time so that other modules can
# import them from here. Without Triton the names still exist, bound to None.
if HAS_TRITON:
    # Resolution order is insert_slice, extract_slice, get_element; any
    # failure propagates as the RuntimeError raised by the resolver.
    insert_slice, extract_slice, get_element = (
        _resolve_triton_ascend_op(_op_name)
        for _op_name in ("insert_slice", "extract_slice", "get_element")
    )
else:
    insert_slice = extract_slice = get_element = None


def init_device_properties_triton():
Expand Down
Loading