
Commit 381b660

Author: Rahul Batra (committed)
RMSNorm blocked implementation

Parent: 086312b

1 file changed: python/perf-kernels/rmsnorm.py (+63, -31 lines)
@@ -46,35 +46,67 @@ def get_autotune_config():
 
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
-def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
+def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
                BLOCK_SIZE: tl.constexpr):
-    row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
-    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_idx = tl.program_id(0)
+
+    # Calculate the squared mean block by block
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    row_sum = 0.0
+    n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
+    # tl.device_print("n_cols_blks", n_cols_blks)
+    for b in tl.range(0, n_cols_blks):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")
+        row_block = row_block * row_block  # square every value in the block
+        row_sum += (tl.sum(row_block, axis=-1) / n_cols)  # tl.sum across the row
+
+    col_offsets = n_cols_blks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
     mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
-        row_start_ptr = input_ptr + row_idx * input_row_stride
+    row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
+    row_block = row_block * row_block  # square every value in the block
+    row_sum += (tl.sum(row_block, axis=-1) / n_cols)  # tl.sum across the row
+
+
+    row_norm = row_sum + eps
+    row_norm = tl.rsqrt(row_norm)
+
+    # Blocked normalization
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    # for b in tl.range(0, n_cols, BLOCK_SIZE):
+    for b in tl.range(0, n_cols_blks):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
-        input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-        row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
-        g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
-        row_norm = row * row  # square each value
-        row_norm = tl.sum(row_norm, axis=-1)  # sum across columns (axis=-1)
-        row_norm = row_norm / n_cols  # divide by n_cols
-        row_norm = row_norm + epsilon  # add epsilon
-        row_norm = tl.rsqrt(row_norm)  # take rsqrt; this is the normalization value
-        rms_norm = row * row_norm  # multiply each x by the normalization value
-        rms_norm = rms_norm * g  # element-wise multiplication with g
-
-        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")  # load a block of input
+        g = tl.load(g_ptr + col_offsets, cache_modifier=".cg")  # load a block of g
+        output = row_block * row_norm  # element-wise multiply with the normalization factor
+        output = output * g  # element-wise multiplication with g
+
         output_ptrs = output_row_start_ptr + col_offsets
-        output_ptrs = tl.multiple_of(output_ptrs, (16, ))
-        tl.store(output_ptrs, rms_norm, mask=mask)
+        tl.store(output_ptrs, output)
+
+    col_offsets = n_cols_blks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  # load the tail block of input
+    g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg")  # load the tail block of g
+    output = row_block * row_norm  # element-wise multiply with the normalization factor
+    output = output * g  # element-wise multiplication with g
+
+    # tl.device_print("output", output)
+    output_ptrs = output_row_start_ptr + col_offsets
+    tl.store(output_ptrs, output, mask=mask)
+
 
 
 def triton_rmsnorm(x, g, epsilon=1e-6):
     n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    # Restricting BLOCK_SIZE to 64 KB is an important optimization. Otherwise,
+    # performance can drop significantly for larger n_cols.
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
 
     y = torch.empty_like(x, device='cuda')
 
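Editor's aside (not part of the commit): the rewritten kernel assigns one row per program and makes two blocked passes over it: first accumulating the mean of squares over full blocks plus a masked tail block, then re-reading each block to apply the normalization factor and the gain g. A minimal host-side sketch of the same scheme in plain PyTorch; the name rmsnorm_blocked_reference and the block_size argument are illustrative, not from the file:

    import torch

    def rmsnorm_blocked_reference(x, g, eps=1e-6, block_size=1024):
        # Pass 1: accumulate mean(x^2) one block of columns at a time,
        # mirroring the kernel's full-block loop plus its masked tail block.
        n_rows, n_cols = x.shape
        row_sum = torch.zeros(n_rows, dtype=x.dtype, device=x.device)
        for start in range(0, n_cols, block_size):
            blk = x[:, start:start + block_size]
            row_sum += (blk * blk).sum(dim=-1) / n_cols
        row_norm = torch.rsqrt(row_sum + eps)  # per-row normalization factor

        # Pass 2: re-read each block, scale it, and apply the gain g.
        out = torch.empty_like(x)
        for start in range(0, n_cols, block_size):
            blk = x[:, start:start + block_size]
            out[:, start:start + block_size] = blk * row_norm[:, None] * g[start:start + block_size]
        return out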
@@ -84,7 +116,6 @@ def triton_rmsnorm(x, g, epsilon=1e-6):
 
     return y
 
-
 def torch_rmsnorm(x, g):
     M, N = x.shape
     if hasattr(torch.nn, 'RMSNorm'):
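To make the 64 KB cap in triton_rmsnorm above concrete (a worked example, not from the commit): x.element_size() is the number of bytes per element, so the cap on BLOCK_SIZE depends on the input dtype.

    import torch

    # 65536 bytes per block, divided by bytes per element:
    assert 65536 // torch.empty(1, dtype=torch.float32).element_size() == 16384
    assert 65536 // torch.empty(1, dtype=torch.float16).element_size() == 32768
    # For an fp32 row of n_cols = 98304: next_power_of_2(98304) is 131072, so
    # BLOCK_SIZE = min(16384, 131072) = 16384 and the kernel walks each row as
    # five full blocks plus one masked tail block.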
@@ -95,15 +126,17 @@ def torch_rmsnorm(x, g):
     rms_norm = torch.div(x, rms.unsqueeze(1).repeat(1, N)) * g
     return rms_norm
 
-
+# yapf: disable
 @pytest.mark.parametrize('M, N', [
-    (1, 4),
-    (2, 10),
-    (8192, 4096),
-    (4096, 8192),
-    (1, 8192),
-    (873, 1245),
-])
+    (1, 4),
+    (2, 10),
+    (8192, 4096),
+    (4096, 8192),
+    (1, 8192),
+    (873, 1245),
+    (1, 98304)
+])
+# yapf: enable
 def test_rmsnorm(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
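Of the parametrized shapes, the new (1, 98304) case is the only one whose n_cols exceeds the capped BLOCK_SIZE, so it is the case that actually exercises the multi-block loops and the masked tail path added above. Assuming a GPU machine with pytest installed, this file's tests can be run with, for example, pytest python/perf-kernels/rmsnorm.py -k test_rmsnorm.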
@@ -114,7 +147,6 @@ def test_rmsnorm(M, N):
 
     assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
 
-
 #Benchmark
 arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
 
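For completeness, a minimal end-to-end sketch of the updated entry points, assuming a CUDA- or ROCm-capable PyTorch/Triton install and that this file is importable as a module named rmsnorm (that import path is an assumption, not from the commit):

    import torch
    from rmsnorm import triton_rmsnorm, torch_rmsnorm  # assumed import path

    x = torch.randn(8, 98304, device='cuda')  # n_cols spans several blocks in fp32
    g = torch.randn(98304, device='cuda')
    y = triton_rmsnorm(x, g)     # blocked Triton kernel
    y_ref = torch_rmsnorm(x, g)  # PyTorch reference
    assert torch.allclose(y, y_ref)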