Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 49 additions & 36 deletions python/sgl_kernel_npu/sgl_kernel_npu/norm/split_qkv_rmsnorm_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@ def split_qkv_rmsnorm_rope_kernel(
Q_BLOCK_SIZE: tl.constexpr,
KV_BLOCK_SIZE: tl.constexpr,
BIAS: tl.constexpr,
NORMS: tl.constexpr,
HEAD_DIM: tl.constexpr,
HALF_HEAD_DIM: tl.constexpr,
):
row_pid = tl.program_id(0)
col_pid = tl.program_id(1)
row_step = tl.num_programs(0)
# q
weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
if NORMS:
weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size
Expand All @@ -41,25 +43,27 @@ def split_qkv_rmsnorm_rope_kernel(
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
valid_mask = col_indices < q_hidden_size
input_values = (
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
.to(tl.float32)
.reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
)
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
Q_BLOCK_SIZE // HEAD_DIM, 1
)
normalized_values = (
input_values * reciprocal_std
) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values + bias_values).to(
tl.bfloat16
input_values = tl.load(
input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0
).reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
if NORMS:
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
Q_BLOCK_SIZE // HEAD_DIM, 1
)
normalized_values = (
input_values * reciprocal_std
) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (
normalized_values * weight_values + bias_values
).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
normalized_values = input_values.to(tl.bfloat16)

# rope
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
Expand Down Expand Up @@ -102,7 +106,8 @@ def split_qkv_rmsnorm_rope_kernel(
output_offset += output_offset_step

# k
weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
if NORMS:
weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size + q_hidden_size
Expand All @@ -116,20 +121,24 @@ def split_qkv_rmsnorm_rope_kernel(
.to(tl.float32)
.reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
)
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
KV_BLOCK_SIZE // HEAD_DIM, 1
)
normalized_values = (
input_values * reciprocal_std
) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values + bias_values).to(
tl.bfloat16
if NORMS:
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
KV_BLOCK_SIZE // HEAD_DIM, 1
)
normalized_values = (
input_values * reciprocal_std
) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (
normalized_values * weight_values + bias_values
).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
else:
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
normalized_values = input_values.to(tl.bfloat16)

# # rope
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
Expand Down Expand Up @@ -192,14 +201,14 @@ def split_qkv_rmsnorm_rope(
input,
sin,
cos,
q_weight,
k_weight,
q_hidden_size,
kv_hidden_size,
head_dim,
eps,
q_bias,
k_bias,
eps=None,
q_weight=None,
k_weight=None,
q_bias=None,
k_bias=None,
):
_, num_vectorcore = get_device_properties()

Expand All @@ -222,6 +231,7 @@ def split_qkv_rmsnorm_rope(
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols
BIAS = q_bias is not None
NORMS = eps is not None

kernel = kernels.get(
(
Expand All @@ -232,6 +242,7 @@ def split_qkv_rmsnorm_rope(
Q_BLOCK_SIZE,
KV_BLOCK_SIZE,
BIAS,
NORMS,
),
None,
)
Expand All @@ -255,6 +266,7 @@ def split_qkv_rmsnorm_rope(
Q_BLOCK_SIZE,
KV_BLOCK_SIZE,
BIAS,
NORMS,
head_dim,
head_dim // 2,
grid=(
Expand All @@ -272,6 +284,7 @@ def split_qkv_rmsnorm_rope(
Q_BLOCK_SIZE,
KV_BLOCK_SIZE,
BIAS,
NORMS,
)
] = kernel

Expand Down
18 changes: 16 additions & 2 deletions tests/python/sgl_kernel_npu/test_add_rmsnorm_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@ def test_add_rmsnorm_bias():
residual = torch.randn(3, hidden_size).to(torch.bfloat16).npu()
weight = torch.randn(hidden_size).to(torch.bfloat16).npu()
bias = torch.randn(hidden_size).to(torch.bfloat16).npu()
res1, res2 = add_rmsnorm_bias(input, residual, weight, bias, 1e-6)
res1, res2 = add_rmsnorm_bias(
input,
residual,
weight,
1e-6,
norm_bias=bias,
quant_scale=None,
quant_offset=None,
)
ans1, ans2 = add_rmsnorm_bias_quant_golden(input, residual, weight, bias, 1e-6)

assert (
Expand Down Expand Up @@ -65,7 +73,13 @@ def test_add_rmsnorm_bias():
quant_scale = torch.randn(hidden_size).to(torch.bfloat16).npu()
quant_offset = torch.randn(hidden_size).to(torch.bfloat16).npu()
res1, res2 = add_rmsnorm_bias(
input, residual, weight, bias, 1e-6, quant_scale, quant_offset
input,
residual,
weight,
1e-6,
norm_bias=bias,
quant_scale=quant_scale,
quant_offset=quant_offset,
)
ans1, ans2 = add_rmsnorm_bias_quant_golden(
input, residual, weight, bias, 1e-6, quant_scale, quant_offset
Expand Down
Loading