diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/norm/add_rmsnorm_bias.py b/python/sgl_kernel_npu/sgl_kernel_npu/norm/add_rmsnorm_bias.py
index 89b4161c2..e63024fbc 100644
--- a/python/sgl_kernel_npu/sgl_kernel_npu/norm/add_rmsnorm_bias.py
+++ b/python/sgl_kernel_npu/sgl_kernel_npu/norm/add_rmsnorm_bias.py
@@ -144,3 +144,75 @@ def add_rmsnorm_bias(
         batch_size,
     )
     return output, output2
+
+
+@triton.jit
+def add_gemma_rms_norm_kernel(
+    hidden_state_ptr,
+    hidden_state_stride_bs,
+    weight_ptr,
+    residual_ptr,
+    add_output_ptr,
+    norm_output_ptr,
+    variance_epsilon,
+    batch,
+    dim: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+):
+    core_id = tl.program_id(0)
+    core_num = tl.num_programs(0)
+    batch_per_core = tl.cdiv(batch, core_num)
+    start_batch = core_id * batch_per_core
+    end_batch = tl.minimum(start_batch + batch_per_core, batch)
+    offset_d = tl.arange(0, dim)
+
+    for row_start in tl.range(start_batch, end_batch, BLOCK_M):
+        offset_row = row_start + tl.arange(0, BLOCK_M)
+        offset_hidden = offset_row[:, None] * hidden_state_stride_bs + offset_d[None, :]
+        mask_hidden = offset_row < batch
+        mask_bs = mask_hidden[:, None]
+
+        x = tl.load(hidden_state_ptr + offset_hidden, mask=mask_bs)
+        residual = tl.load(residual_ptr + offset_hidden, mask=mask_bs)
+        add_val = x + residual
+        tl.store(add_output_ptr + offset_hidden, add_val, mask=mask_bs)
+
+        x_fp32 = add_val.to(tl.float32)
+        w = tl.load(weight_ptr + offset_d).to(tl.float32)
+        variance = tl.sum(x_fp32 * x_fp32, axis=-1) / dim
+        x_fp32 = x_fp32 * tl.rsqrt(variance[:, None] + variance_epsilon)
+        x_fp32 = x_fp32 * (w + 1.0)
+        output = x_fp32.to(x.dtype)
+        tl.store(norm_output_ptr + offset_hidden, output, mask=mask_bs)
+
+
+def add_gemma_rms_norm(
+    hidden_state,
+    weight,
+    residual,
+    variance_epsilon,
+):
+    batch, dim = hidden_state.shape
+    if dim > 2048:
+        raise NotImplementedError("dim > 2048 not supported")
+    ROW_BLOCK_SIZE = 4  # A safe default balancing parallelism and register pressure.
+    BLOCK_M = min(ROW_BLOCK_SIZE, batch)
+
+    _, num_vectorcore = get_device_properties()
+    grid = (num_vectorcore,)
+    add_output = torch.empty_like(hidden_state)
+    norm_output = torch.empty_like(hidden_state)
+
+    add_gemma_rms_norm_kernel[grid](
+        hidden_state,
+        hidden_state.stride(0),
+        weight,
+        residual,
+        add_output,
+        norm_output,
+        variance_epsilon,
+        batch,
+        dim,
+        BLOCK_M,
+    )
+    return norm_output, add_output
diff --git a/tests/python/sgl_kernel_npu/test_add_rmsnorm_bias.py b/tests/python/sgl_kernel_npu/test_add_rmsnorm_bias.py
index 82bf5f41a..8ac99a6fd 100644
--- a/tests/python/sgl_kernel_npu/test_add_rmsnorm_bias.py
+++ b/tests/python/sgl_kernel_npu/test_add_rmsnorm_bias.py
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias
+from sgl_kernel_npu.norm.add_rmsnorm_bias import add_gemma_rms_norm, add_rmsnorm_bias
 
 
 def add_rmsnorm_bias_quant_golden(
@@ -99,5 +99,58 @@ def test_add_rmsnorm_bias():
     )
 
 
+def reference_add_gemma_rms_norm(hidden_state, weight, residual, variance_epsilon):
+    # Step 1: Add
+    add_output = hidden_state + residual
+
+    # Step 2: RMS Norm (Gemma style: x * (w + 1) / sqrt(mean(x^2) + eps))
+    dtype = add_output.dtype
+    add_output_fp32 = add_output.to(torch.float32)
+    variance = torch.mean(add_output_fp32**2, dim=-1, keepdim=True)
+    norm_output_fp32 = add_output_fp32 * torch.rsqrt(variance + variance_epsilon)
+    norm_output_fp32 = norm_output_fp32 * (weight.to(torch.float32) + 1.0)
+    norm_output = norm_output_fp32.to(dtype)
+
+    return norm_output, add_output
+
+
+def test_add_gemma_rms_norm():
+    torch.manual_seed(0)
+    device = torch.device("npu")
+
+    test_cases = [
+        (8, 512),
+        (16, 1024),
+        (32, 2048),
+        (1, 256),
+    ]
+
+    variance_epsilon = 1e-6
+
+    for batch, dim in test_cases:
+        print(f"Testing batch={batch}, dim={dim}")
+
+        hidden_state = torch.randn(batch, dim, device=device, dtype=torch.float16)
+        residual = torch.randn(batch, dim, device=device, dtype=torch.float16)
+        weight = torch.randn(dim, device=device, dtype=torch.float16)
+
+        # Triton output
+        norm_out_triton, add_out_triton = add_gemma_rms_norm(
+            hidden_state, weight, residual, variance_epsilon
+        )
+
+        # Reference output
+        norm_out_ref, add_out_ref = reference_add_gemma_rms_norm(
+            hidden_state, weight, residual, variance_epsilon
+        )
+
+        # Compare
+        assert torch.allclose(add_out_triton, add_out_ref, atol=1e-2, rtol=1e-2)
+        assert torch.allclose(norm_out_triton, norm_out_ref, atol=1e-2, rtol=1e-2)
+
+    print("All tests passed!")
+
+
 if __name__ == "__main__":
     test_add_rmsnorm_bias()
+    test_add_gemma_rms_norm()
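The kernel above splits the `batch` rows across the NPU vector cores (`tl.num_programs(0)`), and each core then walks its slice in `BLOCK_M`-row tiles. Below is a minimal host-side sketch of that partitioning, for illustration only; `partition_rows` is a hypothetical helper that mirrors the kernel's `tl.cdiv`-based split and is not part of this change.

# Host-side sketch of the row-to-core scheduling used by add_gemma_rms_norm_kernel.
# Illustrative only: the real loop runs on-device; names mirror the kernel arguments.
def partition_rows(batch: int, core_num: int, block_m: int):
    batch_per_core = -(-batch // core_num)  # ceil division, mirrors tl.cdiv(batch, core_num)
    plan = []
    for core_id in range(core_num):
        start = core_id * batch_per_core
        end = min(start + batch_per_core, batch)
        for row_start in range(start, end, block_m):
            # Each tile covers rows [row_start, row_start + block_m), masked to rows < batch.
            plan.append((core_id, row_start, min(row_start + block_m, batch)))
    return plan


if __name__ == "__main__":
    # e.g. 10 rows on 3 cores with block_m=4 -> [(0, 0, 4), (1, 4, 8), (2, 8, 10)]
    print(partition_rows(batch=10, core_num=3, block_m=4))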