Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions python/sgl_kernel_npu/sgl_kernel_npu/norm/add_rmsnorm_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,75 @@ def add_rmsnorm_bias(
batch_size,
)
return output, output2


@triton.jit
def add_gemma_rms_norm_kernel(
    hidden_state_ptr,
    hidden_state_stride_bs,
    weight_ptr,
    residual_ptr,
    add_output_ptr,
    norm_output_ptr,
    variance_epsilon,
    batch,
    dim: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """Fused residual-add + Gemma-style RMSNorm kernel.

    For each row computes:
        add_output  = hidden_state + residual
        norm_output = add_output * rsqrt(mean(add_output^2) + eps) * (weight + 1)

    The Gemma variant scales by ``weight + 1`` instead of ``weight``.
    Rows are statically partitioned across cores; each core processes its
    own contiguous range of rows in tiles of BLOCK_M.

    Args:
        hidden_state_ptr: input tensor of shape (batch, dim).
        hidden_state_stride_bs: row stride of the input (elements).
        weight_ptr: per-channel scale of shape (dim,).
        residual_ptr: residual tensor, same shape/stride as hidden_state.
        add_output_ptr: output buffer for hidden_state + residual.
        norm_output_ptr: output buffer for the normalized result.
        variance_epsilon: epsilon added to the variance for stability.
        batch: number of rows.
        dim: row width; assumed to fit in one tile (caller caps it at 2048).
        BLOCK_M: number of rows processed per loop iteration.
    """
    core_id = tl.program_id(0)
    core_num = tl.num_programs(0)
    # Static partition: each core owns rows [start_batch, end_batch).
    batch_per_core = tl.cdiv(batch, core_num)
    start_batch = core_id * batch_per_core
    end_batch = tl.minimum(start_batch + batch_per_core, batch)
    offset_d = tl.arange(0, dim)

    for row_start in tl.range(start_batch, end_batch, BLOCK_M):
        offset_row = row_start + tl.arange(0, BLOCK_M)
        offset_hidden = offset_row[:, None] * hidden_state_stride_bs + offset_d[None, :]
        # BUGFIX: mask against this core's own partition end (end_batch), not
        # the global batch. With `offset_row < batch`, when batch_per_core is
        # not a multiple of BLOCK_M the tail tile of one core spilled into
        # rows owned by the next core — redundant recomputation and
        # concurrent duplicate stores to the same rows.
        mask_hidden = offset_row < end_batch
        mask_bs = mask_hidden[:, None]

        x = tl.load(hidden_state_ptr + offset_hidden, mask=mask_bs)
        residual = tl.load(residual_ptr + offset_hidden, mask=mask_bs)
        add_val = x + residual
        tl.store(add_output_ptr + offset_hidden, add_val, mask=mask_bs)

        # Normalize in fp32 for numerical stability, then cast back.
        x_fp32 = add_val.to(tl.float32)
        w = tl.load(weight_ptr + offset_d).to(tl.float32)
        variance = tl.sum(x_fp32 * x_fp32, axis=-1) / dim
        x_fp32 = x_fp32 * tl.rsqrt(variance[:, None] + variance_epsilon)
        x_fp32 = x_fp32 * (w + 1.0)  # Gemma convention: scale by (weight + 1)
        output = x_fp32.to(x.dtype)
        tl.store(norm_output_ptr + offset_hidden, output, mask=mask_bs)


def add_gemma_rms_norm(
    hidden_state,
    weight,
    residual,
    variance_epsilon,
):
    """Fused residual-add + Gemma-style RMSNorm (host-side launcher).

    Computes ``add_output = hidden_state + residual`` and
    ``norm_output = RMSNorm(add_output) * (weight + 1)`` in a single
    Triton kernel launch.

    Args:
        hidden_state: tensor of shape (batch, dim).
        weight: per-channel scale of shape (dim,).
        residual: tensor of shape (batch, dim).
        variance_epsilon: epsilon added to the variance for stability.

    Returns:
        Tuple ``(norm_output, add_output)``, both shaped like hidden_state.

    Raises:
        NotImplementedError: if dim > 2048 (the kernel loads a full row
            as one tile, so the row width is capped).
        ValueError: if residual or weight shapes do not match hidden_state.
    """
    batch, dim = hidden_state.shape
    if dim > 2048:
        raise NotImplementedError("dim > 2048 not supported")
    # Validate shapes up front — a mismatch would otherwise cause silent
    # out-of-bounds reads/writes inside the kernel.
    if residual.shape != hidden_state.shape:
        raise ValueError(
            f"residual shape {tuple(residual.shape)} must match "
            f"hidden_state shape {tuple(hidden_state.shape)}"
        )
    if weight.shape != (dim,):
        raise ValueError(f"weight shape {tuple(weight.shape)} must be ({dim},)")
    ROW_BLOCK_SIZE = 4  # A safe default balancing parallelism and register pressure.
    BLOCK_M = min(ROW_BLOCK_SIZE, batch)

    _, num_vectorcore = get_device_properties()
    grid = (num_vectorcore,)
    add_output = torch.empty_like(hidden_state)
    norm_output = torch.empty_like(hidden_state)

    add_gemma_rms_norm_kernel[grid](
        hidden_state,
        hidden_state.stride(0),
        weight,
        residual,
        add_output,
        norm_output,
        variance_epsilon,
        batch,
        dim,
        BLOCK_M,
    )
    return norm_output, add_output
55 changes: 54 additions & 1 deletion tests/python/sgl_kernel_npu/test_add_rmsnorm_bias.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import torch
from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias
from sgl_kernel_npu.norm.add_rmsnorm_bias import add_gemma_rms_norm, add_rmsnorm_bias


def add_rmsnorm_bias_quant_golden(
Expand Down Expand Up @@ -99,5 +99,58 @@ def test_add_rmsnorm_bias():
)


def reference_add_gemma_rms_norm(hidden_state, weight, residual, variance_epsilon):
    """Pure-PyTorch golden reference for fused add + Gemma RMSNorm.

    Returns ``(norm_output, add_output)`` where
    ``add_output = hidden_state + residual`` and
    ``norm_output = add_output * rsqrt(mean(add_output^2) + eps) * (weight + 1)``,
    computed in float32 and cast back to the input dtype.
    """
    summed = hidden_state + residual
    orig_dtype = summed.dtype

    # Gemma-style RMSNorm: normalize in fp32, scale by (weight + 1).
    x = summed.float()
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    normed = x * torch.rsqrt(mean_sq + variance_epsilon) * (weight.float() + 1.0)

    return normed.to(orig_dtype), summed


def test_add_gemma_rms_norm():
    """Check the Triton NPU kernel against the pure-PyTorch reference
    across several (batch, dim) shapes."""
    torch.manual_seed(0)
    device = torch.device("npu")

    shapes = (
        (8, 512),
        (16, 1024),
        (32, 2048),
        (1, 256),
    )
    variance_epsilon = 1e-6

    for rows, cols in shapes:
        print(f"Testing batch={rows}, dim={cols}")

        # Random fp16 inputs on the NPU (draw order matters for the seed).
        hidden_state = torch.randn(rows, cols, device=device, dtype=torch.float16)
        residual = torch.randn(rows, cols, device=device, dtype=torch.float16)
        weight = torch.randn(cols, device=device, dtype=torch.float16)

        norm_out_triton, add_out_triton = add_gemma_rms_norm(
            hidden_state, weight, residual, variance_epsilon
        )
        norm_out_ref, add_out_ref = reference_add_gemma_rms_norm(
            hidden_state, weight, residual, variance_epsilon
        )

        # fp16 tolerance for both the add and the normalized output.
        assert torch.allclose(add_out_triton, add_out_ref, atol=1e-2, rtol=1e-2)
        assert torch.allclose(norm_out_triton, norm_out_ref, atol=1e-2, rtol=1e-2)

    print("All tests passed!")


if __name__ == "__main__":
test_add_rmsnorm_bias()
test_add_gemma_rms_norm()