@@ -46,35 +46,47 @@ def get_autotune_config():
46
46
47
47
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
               BLOCK_SIZE: tl.constexpr):
    """RMS-normalize one row of the input per program instance.

    out[row, j] = x[row, j] * rsqrt(mean(x[row, :] ** 2) + eps) * g[j]

    The row is processed in BLOCK_SIZE-wide chunks so that n_cols may
    exceed the largest usable block: a first pass accumulates the row's
    sum of squares, a second pass re-loads each chunk and applies the
    normalization factor and the elementwise gain ``g``.
    """
    row_idx = tl.program_id(0)  # one program per row
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # Pass 1: accumulate the row's sum of squares, block by block.
    sq_sum = 0.0
    for b in tl.range(0, n_cols, BLOCK_SIZE):
        col_offsets = b + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols
        # other=0.0 keeps out-of-range lanes from polluting the sum.
        row_block = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0,
                            cache_modifier=".cg")
        sq_sum += tl.sum(row_block * row_block, axis=-1)

    # Divide by n_cols once at the end (one divide instead of one per
    # block, and less accumulated rounding), then take the reciprocal RMS.
    inv_rms = tl.rsqrt(sq_sum / n_cols + eps)

    # Pass 2: reload each block, scale by inv_rms and the gain, store.
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    for b in tl.range(0, n_cols, BLOCK_SIZE):
        col_offsets = b + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols
        row_block = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0,
                            cache_modifier=".cg")
        g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0,
                    cache_modifier=".cg")
        tl.store(output_row_start_ptr + col_offsets, row_block * inv_rms * g, mask=mask)
73
82
74
83
75
84
def triton_rmsnorm (x , g , epsilon = 1e-6 ):
76
85
n_rows , n_cols = x .shape
77
- BLOCK_SIZE = triton .next_power_of_2 (n_cols )
86
+ #Restricting BLOCK_SIZE to 64Kb is an important optimization. Otherwise,
87
+ #performance can drop significantly for larger n_cols.
88
+ MAX_FUSED_SIZE = 65536 // x .element_size ()
89
+ BLOCK_SIZE = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
78
90
79
91
y = torch .empty_like (x , device = 'cuda' )
80
92
@@ -84,7 +96,6 @@ def triton_rmsnorm(x, g, epsilon=1e-6):
84
96
85
97
return y
86
98
87
-
88
99
def torch_rmsnorm (x , g ):
89
100
M , N = x .shape
90
101
if hasattr (torch .nn , 'RMSNorm' ):
@@ -95,15 +106,7 @@ def torch_rmsnorm(x, g):
95
106
rms_norm = torch .div (x , rms .unsqueeze (1 ).repeat (1 , N )) * g
96
107
return rms_norm
97
108
98
-
99
- @pytest .mark .parametrize ('M, N' , [
100
- (1 , 4 ),
101
- (2 , 10 ),
102
- (8192 , 4096 ),
103
- (4096 , 8192 ),
104
- (1 , 8192 ),
105
- (873 , 1245 ),
106
- ])
109
+ @pytest .mark .parametrize ('M, N' , [(1 , 4 ), (2 , 10 ), (8192 , 4096 ), (4096 , 8192 ), (1 , 8192 ), (873 , 1245 ), (1 , 98304 )])
107
110
def test_rmsnorm (M , N ):
108
111
torch .manual_seed (0 )
109
112
x = torch .randn (M , N , device = 'cuda' )
@@ -112,6 +115,7 @@ def test_rmsnorm(M, N):
112
115
113
116
y_torch = torch_rmsnorm (x , g )
114
117
118
+ print (f"y_triton={ y_triton } " )
115
119
assert torch .allclose (y_triton , y_torch ), (y_triton , y_torch )
116
120
117
121
0 commit comments