Skip to content

Commit 207f73a

Browse files
author
Rahul Batra
committed
Add use mask for loads
1 parent bbdc10b commit 207f73a

File tree

1 file changed

+38
-7
lines changed

1 file changed

+38
-7
lines changed

python/perf-kernels/rmsnorm.py

+38-7
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,29 @@ def get_autotune_config():
4747
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
4848
@triton.jit
4949
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
50-
BLOCK_SIZE: tl.constexpr):
50+
BLOCK_SIZE: tl.constexpr, use_mask: tl.constexpr):
5151
row_start = tl.program_id(0)
5252
row_idx = row_start
5353

5454
#Calculate squared mean by block
5555
row_start_ptr = input_ptr + row_idx * input_row_stride
5656
row_sum = 0.0
57-
for b in tl.range(0, n_cols, BLOCK_SIZE):
58-
col_offsets = b + tl.arange(0, BLOCK_SIZE)
57+
loop_num = tl.cdiv(n_cols, BLOCK_SIZE)
58+
if use_mask:
59+
loop_num -= 1
60+
#for b in tl.range(0, n_cols, BLOCK_SIZE):
61+
loop_num_t = loop_num
62+
for b in tl.range(0, loop_num_t):
63+
col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
64+
input_ptrs = row_start_ptr + col_offsets
65+
mask = col_offsets < n_cols
66+
row_block = tl.load(input_ptrs, cache_modifier=".cg")
67+
row_block = row_block * row_block #square every value in the block
68+
row_sum += (tl.sum(row_block, axis=-1) / n_cols
69+
) #tl.sum across the row, divide by n_cols, and add to the running sum
70+
71+
if use_mask:
72+
col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
5973
input_ptrs = row_start_ptr + col_offsets
6074
mask = col_offsets < n_cols
6175
row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
@@ -67,9 +81,26 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
6781
row_norm = tl.rsqrt(row_norm)
6882

6983
#Blocked normalization
84+
loop_num_t = loop_num
7085
output_row_start_ptr = output_ptr + row_idx * output_row_stride
71-
for b in tl.range(0, n_cols, BLOCK_SIZE):
72-
col_offsets = b + tl.arange(0, BLOCK_SIZE)
86+
#for b in tl.range(0, n_cols, BLOCK_SIZE):
87+
for b in tl.range(0, loop_num_t):
88+
col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
89+
input_ptrs = row_start_ptr + col_offsets
90+
mask = col_offsets < n_cols
91+
#row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") #load block of input
92+
#g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg") #load block of g
93+
row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block of input
94+
g = tl.load(g_ptr + col_offsets, cache_modifier=".cg") #load block of g
95+
output = row_block * row_norm #element-wise multiply by row_norm (reciprocal RMS)
96+
output = output * g #element wise multiplication with g
97+
98+
output_ptrs = output_row_start_ptr + col_offsets
99+
#tl.store(output_ptrs, output, mask=mask)
100+
tl.store(output_ptrs, output)
101+
102+
if use_mask:
103+
col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
73104
input_ptrs = row_start_ptr + col_offsets
74105
mask = col_offsets < n_cols
75106
row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") #load block of input
@@ -81,7 +112,7 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
81112
tl.store(output_ptrs, output, mask=mask)
82113

83114

84-
def rmsnorm(x, epsilon=1e-6):
115+
def rmsnorm(x, epsilon=1e-6, use_mask=1):
85116
n_rows, n_cols = x.shape
86117
#Restricting BLOCK_SIZE to 64Kb is an important optimization. Otherwise,
87118
#performance can drop significantly for larger n_cols.
@@ -93,7 +124,7 @@ def rmsnorm(x, epsilon=1e-6):
93124

94125
num_programs = n_rows
95126
grid = lambda meta: (num_programs, )
96-
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE)
127+
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, use_mask)
97128

98129
return y
99130

0 commit comments

Comments
 (0)