Skip to content

Commit 918bcbc

Browse files
Rahul Batra, wunhuang
Rahul Batra
authored and
wunhuang
committed
RMSNorm blocked implementation
1 parent 086312b commit 918bcbc

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

python/perf-kernels/rmsnorm.py

+35-31
Original file line numberDiff line numberDiff line change
@@ -46,35 +46,47 @@ def get_autotune_config():
4646

4747
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
               BLOCK_SIZE: tl.constexpr):
    """Blocked RMSNorm: each program normalizes one row of the input.

    Two passes over the row in BLOCK_SIZE-sized chunks:
      1. accumulate the mean of squares of the row,
      2. scale each element by rsqrt(mean_sq + eps) and by the gain g,
         storing the result into the output row.
    Grid is assumed to be (n_rows,) so program_id(0) selects the row.
    """
    row_idx = tl.program_id(0)

    # Pass 1: accumulate the squared mean block by block.
    row_start_ptr = input_ptr + row_idx * input_row_stride
    row_sum = 0.0
    for b in tl.range(0, n_cols, BLOCK_SIZE):
        col_offsets = b + tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        mask = col_offsets < n_cols
        row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
        row_block = row_block * row_block  # square every value in the block
        # Partial sum of squares divided by n_cols, added to the running mean-of-squares.
        row_sum += (tl.sum(row_block, axis=-1) / n_cols)

    row_norm = row_sum + eps
    row_norm = tl.rsqrt(row_norm)  # normalization factor: 1/sqrt(mean_sq + eps)

    # Pass 2: blocked normalization — reload each block, scale, and store.
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    for b in tl.range(0, n_cols, BLOCK_SIZE):
        col_offsets = b + tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        mask = col_offsets < n_cols
        row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  # reload block of input
        g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg")  # load block of g
        output = row_block * row_norm  # elementwise multiply by the normalization factor
        output = output * g  # elementwise multiply by g
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, output, mask=mask)
7382

7483

7584
def triton_rmsnorm(x, g, epsilon=1e-6):
7685
n_rows, n_cols = x.shape
77-
BLOCK_SIZE = triton.next_power_of_2(n_cols)
86+
#Restricting BLOCK_SIZE to 64 KB is an important optimization. Otherwise,
87+
#performance can drop significantly for larger n_cols.
88+
MAX_FUSED_SIZE = 65536 // x.element_size()
89+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
7890

7991
y = torch.empty_like(x, device='cuda')
8092

@@ -84,7 +96,6 @@ def triton_rmsnorm(x, g, epsilon=1e-6):
8496

8597
return y
8698

87-
8899
def torch_rmsnorm(x, g):
89100
M, N = x.shape
90101
if hasattr(torch.nn, 'RMSNorm'):
@@ -95,15 +106,7 @@ def torch_rmsnorm(x, g):
95106
rms_norm = torch.div(x, rms.unsqueeze(1).repeat(1, N)) * g
96107
return rms_norm
97108

98-
99-
@pytest.mark.parametrize('M, N', [
100-
(1, 4),
101-
(2, 10),
102-
(8192, 4096),
103-
(4096, 8192),
104-
(1, 8192),
105-
(873, 1245),
106-
])
109+
@pytest.mark.parametrize('M, N', [(1, 4), (2, 10), (8192, 4096), (4096, 8192), (1, 8192), (873, 1245), (1, 98304)])
107110
def test_rmsnorm(M, N):
108111
torch.manual_seed(0)
109112
x = torch.randn(M, N, device='cuda')
@@ -112,6 +115,7 @@ def test_rmsnorm(M, N):
112115

113116
y_torch = torch_rmsnorm(x, g)
114117

118+
print(f"y_triton={y_triton}")
115119
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
116120

117121

0 commit comments

Comments
 (0)