@@ -47,15 +47,29 @@ def get_autotune_config():
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
 def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
-               BLOCK_SIZE: tl.constexpr):
+               BLOCK_SIZE: tl.constexpr, use_mask: tl.constexpr):
     row_start = tl.program_id(0)
     row_idx = row_start

     #Calculate squared mean by block
     row_start_ptr = input_ptr + row_idx * input_row_stride
     row_sum = 0.0
-    for b in tl.range(0, n_cols, BLOCK_SIZE):
-        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+    loop_num = tl.cdiv(n_cols, BLOCK_SIZE)
+    if use_mask:
+        loop_num -= 1
+    #for b in tl.range(0, n_cols, BLOCK_SIZE):
+    loop_num_t = loop_num
+    for b in tl.range(0, loop_num_t):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        mask = col_offsets < n_cols  #note: computed but unused in this unmasked loop
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")
+        row_block = row_block * row_block  #square every value in the block
+        row_sum += (tl.sum(row_block, axis=-1) / n_cols
+                    )  #tl.sum across the block, divide by n_cols, add to the running sum
+
+    if use_mask:
+        col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
         mask = col_offsets < n_cols
         row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
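The intent of the change: a masked `tl.load` pays for bounds predication on every iteration, so the main loop now covers only full blocks with unmasked loads, while the single (possibly partial) last block keeps the masked load behind `use_mask`. Note that `use_mask=0` is presumably only safe when `n_cols` is a multiple of `BLOCK_SIZE`, since every load is then unguarded. A minimal NumPy sketch of the same main-loop/tail split (function and variable names here are illustrative, not from the kernel):

```python
import numpy as np

def mean_sq_split(row, block_size, use_mask=True):
    """Mean of squares of `row`, accumulated block by block, mirroring the
    kernel: full blocks run unmasked, the tail keeps the bounds check."""
    n_cols = row.shape[0]
    loop_num = -(-n_cols // block_size)  # ceiling division, like tl.cdiv
    if use_mask:
        loop_num -= 1  # peel the last block off the main loop
    acc = 0.0
    for b in range(loop_num):  # full blocks: no bounds check needed
        block = row[b * block_size:(b + 1) * block_size]
        acc += float(np.sum(block * block)) / n_cols
    if use_mask:  # tail block: the slice bound plays the role of tl.load's mask
        block = row[loop_num * block_size:n_cols]
        acc += float(np.sum(block * block)) / n_cols
    return acc

row = np.arange(10.0)
assert np.isclose(mean_sq_split(row, 4), np.mean(row * row))
```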
@@ -67,9 +81,26 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
     row_norm = tl.rsqrt(row_norm)

     #Blocked normalization
+    loop_num_t = loop_num
     output_row_start_ptr = output_ptr + row_idx * output_row_stride
-    for b in tl.range(0, n_cols, BLOCK_SIZE):
-        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+    #for b in tl.range(0, n_cols, BLOCK_SIZE):
+    for b in tl.range(0, loop_num_t):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        mask = col_offsets < n_cols  #note: computed but unused in this unmasked loop
+        #row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  #load block of input
+        #g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg")  #load block of g
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")  #load block of input
+        g = tl.load(g_ptr + col_offsets, cache_modifier=".cg")  #load block of g
+        output = row_block * row_norm  #element-wise multiply with rms_norm
+        output = output * g  #element-wise multiply with g
+
+        output_ptrs = output_row_start_ptr + col_offsets
+        #tl.store(output_ptrs, output, mask=mask)
+        tl.store(output_ptrs, output)
+
+    if use_mask:
+        col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
         mask = col_offsets < n_cols
         row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  #load block of input
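For checking the numerics, the per-row computation is standard RMSNorm. The `eps` handling sits in lines elided between these hunks, so its placement below is an assumption (adding it inside the rsqrt is the usual convention):

```python
import torch

def rmsnorm_ref(x, g, eps=1e-6):
    # Row-wise RMSNorm: y = x / sqrt(mean(x^2) + eps) * g; eps placement assumed
    row_norm = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * row_norm * g
```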
@@ -81,7 +112,7 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
         tl.store(output_ptrs, output, mask=mask)


-def rmsnorm(x, epsilon=1e-6):
+def rmsnorm(x, epsilon=1e-6, use_mask=1):
     n_rows, n_cols = x.shape
     #Restricting BLOCK_SIZE to 64KB is an important optimization. Otherwise,
     #performance can drop significantly for larger n_cols.
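The BLOCK_SIZE computation itself falls in lines elided from this diff; a plausible shape for it, matching the pattern in the Triton tutorials and sketched here as an assumption rather than quoted from the file:

```python
import torch
import triton

def pick_block_size(x: torch.Tensor, n_cols: int) -> int:
    # Cap the per-row block at 64KB worth of elements (the optimization the
    # comment above refers to), rounded up to a power of two for tl.arange.
    max_fused_size = 65536 // x.element_size()
    return min(max_fused_size, triton.next_power_of_2(n_cols))
```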
@@ -93,7 +124,7 @@ def rmsnorm(x, epsilon=1e-6):

     num_programs = n_rows
     grid = lambda meta: (num_programs, )
-    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE)
+    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, use_mask)

     return y
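A quick way to exercise both paths; the shapes and the device string are illustrative, and since `g` is created inside `rmsnorm` the two calls are only compared against each other:

```python
import torch

x = torch.randn(32, 8192, device='cuda', dtype=torch.float32)
y_masked = rmsnorm(x, use_mask=1)    # safe for any n_cols
y_unmasked = rmsnorm(x, use_mask=0)  # relies on n_cols splitting into whole blocks
torch.testing.assert_close(y_masked, y_unmasked)
```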