@@ -46,35 +46,47 @@ def get_autotune_config():
46
46
47
47
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
               BLOCK_SIZE: tl.constexpr):
    """RMS-normalize one row of the input in BLOCK_SIZE-wide column chunks.

    The program handles the single row selected by ``tl.program_id(0)``.
    The row is traversed twice: the first pass accumulates mean(x^2), the
    second pass reloads the row and stores ``x * rsqrt(mean(x^2) + eps) * g``.

    NOTE(review): ``n_rows`` is only used as an autotune key; the kernel does
    not loop over rows, so the launch grid presumably has to provide one
    program per row -- confirm against the caller.
    """
    row_idx = tl.program_id(0)
    in_row_ptr = input_ptr + row_idx * input_row_stride
    out_row_ptr = output_ptr + row_idx * output_row_stride

    # Pass 1: running mean of squares. Each chunk's sum of squares is divided
    # by n_cols (the full row length, not BLOCK_SIZE) before being accumulated.
    mean_sq = 0.0
    for col_start in tl.range(0, n_cols, BLOCK_SIZE):
        cols = col_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        # Masked-out lanes load 0.0, so they add nothing to the sum of squares.
        x = tl.load(in_row_ptr + cols, mask=in_bounds, other=0.0, cache_modifier=".cg")
        mean_sq += tl.sum(x * x, axis=-1) / n_cols

    # Reciprocal RMS of the row; eps is added before the rsqrt.
    inv_rms = tl.rsqrt(mean_sq + eps)

    # Pass 2: reload each chunk, scale by the reciprocal RMS and by g, store.
    for col_start in tl.range(0, n_cols, BLOCK_SIZE):
        cols = col_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        x = tl.load(in_row_ptr + cols, mask=in_bounds, other=0.0, cache_modifier=".cg")
        g = tl.load(g_ptr + cols, mask=in_bounds, other=0.0, cache_modifier=".cg")
        tl.store(out_row_ptr + cols, x * inv_rms * g, mask=in_bounds)
73
82
74
83
75
84
def rmsnorm (x , epsilon = 1e-6 ):
76
85
n_rows , n_cols = x .shape
77
- BLOCK_SIZE = triton .next_power_of_2 (n_cols )
86
+ #Restricting BLOCK_SIZE to 64 KB worth of elements is an important optimization. Otherwise,
87
+ #performance can drop significantly for larger n_cols.
88
+ MAX_FUSED_SIZE = 65536 // x .element_size ()
89
+ BLOCK_SIZE = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
78
90
79
91
y = torch .empty_like (x , device = 'cuda' )
80
92
g = torch .ones ((1 , n_cols ), device = 'cuda' )
@@ -87,21 +99,15 @@ def rmsnorm(x, epsilon=1e-6):
87
99
88
100
89
101
def run_rmsnorm(M, N):
    """Call ``rmsnorm`` on a seeded random (M, N) CUDA tensor and return its output."""
    print(f"Running RMSNorm for shape ({M}, {N})")
    # Fixed seed so every invocation with the same shape sees the same input.
    torch.manual_seed(0)
    return rmsnorm(torch.randn(M, N, device='cuda'))
95
108
96
109
97
- @pytest .mark .parametrize ('M, N' , [
98
- (1 , 4 ),
99
- (2 , 10 ),
100
- (8192 , 4096 ),
101
- (4096 , 8192 ),
102
- (1 , 8192 ),
103
- (873 , 1245 ),
104
- ])
110
+ @pytest .mark .parametrize ('M, N' , [(1 , 4 ), (2 , 10 ), (8192 , 4096 ), (4096 , 8192 ), (1 , 8192 ), (873 , 1245 ), (1 , 98304 )])
105
111
def test_rmsnorm (M , N ):
106
112
torch .manual_seed (0 )
107
113
x = torch .randn (M , N , device = 'cuda' )
@@ -110,6 +116,7 @@ def test_rmsnorm(M, N):
110
116
rms_norm = torch .nn .RMSNorm (N , device = 'cuda' )
111
117
y_torch = rms_norm (x )
112
118
119
+ print (f"y_triton={ y_triton } " )
113
120
assert torch .allclose (y_triton , y_torch ), (y_triton , y_torch )
114
121
115
122
0 commit comments